import os
import sys
import torch
import yaml
import numpy as np
from pathlib import Path
from transformers import AutoTokenizer
from torchvision import transforms
from huggingface_hub import hf_hub_download

ROOT_DIR = Path(__file__).resolve().parent.parent.parent
sys.path.append(str(ROOT_DIR))

from src.multimodal_model import MediLLMModel
from app.utils.gradcam_utils import register_hooks, generate_gradcam

# --------------------
# Runtime / Hub config
# --------------------
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Map modes -> filenames in the HF model repo
HF_MODEL_REPO = os.getenv("HF_MODEL_REPO", "Preetham22/medi-llm-weights")
_raw_rev = os.getenv("HF_WEIGHTS_REV", None)
HF_WEIGHTS_REV = _raw_rev if (_raw_rev and _raw_rev.strip()) else None  # optional (commit/tag/branch), can be None
FILENAMES = {
    "text": "medi_llm_state_dict_text.pth",
    "image": "medi_llm_state_dict_image.pth",
    "multimodal": "medi_llm_state_dict_multimodal.pth",
}


def have_internet():
    """Best-effort connectivity check against huggingface.co."""
    try:
        import socket
        socket.create_connection(("huggingface.co", 443), timeout=3).close()
        return True
    except Exception:
        return False


def resolve_weights_path(mode: str) -> str:
    """Download (or reuse cached) weights for the given mode from the HF Hub."""
    if mode not in FILENAMES:
        raise ValueError(f"Unknown mode '{mode}'. Expected one of {list(FILENAMES)}.")
    filename = FILENAMES[mode]

    # 1) Prefer a file already present in the Space repo
    local_path = ROOT_DIR / filename
    if local_path.exists():
        return str(local_path)

    # 2) If no local file and no internet, bail early
    if not have_internet():
        raise RuntimeError(
            f"❌ Internet is disabled and weights are not present locally.\n"
            f"   Upload '{filename}' to this Space or enable Network access."
        )

    # 3) Otherwise, download from the Hub
    try:
        return hf_hub_download(
            repo_id=HF_MODEL_REPO,
            filename=filename,
            revision=HF_WEIGHTS_REV,        # can be None -> default branch
            repo_type="model",              # change to "dataset" if needed
            local_dir=str(ROOT_DIR),        # keep a copy in the repo dir
            local_dir_use_symlinks=False,   # avoid symlink weirdness
            token=None,                     # public repo
        )
    except Exception as e:
        raise RuntimeError(
            f"Failed to fetch weights '{filename}' from repo '{HF_MODEL_REPO}'. "
            f"Either enable Network access for this Space or commit the file locally. "
            f"Original error: {e}"
        )


# ----------------------
# Labels / preprocessing
# ----------------------
inv_map = {0: "low", 1: "medium", 2: "high"}

# Tokenizer and image transform
tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
image_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.CenterCrop(224),
    transforms.ToTensor()
])


# ----------------------
# Model load
# ----------------------
def _safe_torch_load(path: str, map_location: torch.device):
    """
    Prefer weights_only=True (newer PyTorch), but fall back if not supported.
    """
    try:
        return torch.load(path, map_location=map_location, weights_only=True)  # newer PyTorch
    except TypeError:
        return torch.load(path, map_location=map_location)


def load_model(mode: str, config_path: str = str(Path("config/config.yaml").resolve())):
    """
    Load MediLLMModel for the given mode and populate weights from the HF Hub.
    Expects config/config.yaml with keys per mode (dropout, hidden_dim).
    """
    with open(config_path, "r") as f:
        cfg_all = yaml.safe_load(f)

    if mode not in cfg_all:
        raise KeyError(f"Mode '{mode}' not found in {config_path}. Keys: {list(cfg_all.keys())}")

    config = cfg_all[mode]

    # Build model
    model = MediLLMModel(
        mode=mode,
        dropout=config["dropout"],
        hidden_dim=config["hidden_dim"]
    )

    # Download weights & load
    weights_path = resolve_weights_path(mode)
    state = _safe_torch_load(weights_path, map_location=DEVICE)

    # Some checkpoints are saved as {'state_dict': ...}
    if isinstance(state, dict) and "state_dict" in state:
        state = state["state_dict"]

    try:
        model.load_state_dict(state)  # strict by default
    except RuntimeError as e:
        # Allow non-strict loading for minor mismatches (e.g. buffer names)
        try:
            model.load_state_dict(state, strict=False)
            print(f"⚠️ Loaded with strict=False due to: {e}")
        except Exception:
            raise

    model.to(DEVICE)
    model.eval()
    return model
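
# For reference, load_model() only needs a block per mode with "dropout" and
# "hidden_dim" in config/config.yaml. A minimal file might look like the sketch
# below (illustrative values only, not the project's tuned hyperparameters):
#
#   text:
#     dropout: 0.3
#     hidden_dim: 256
#   image:
#     dropout: 0.3
#     hidden_dim: 256
#   multimodal:
#     dropout: 0.3
#     hidden_dim: 512
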

# -----------------------
# Attention rollout utils
# -----------------------
def attention_rollout(attentions, last_k=4, residual_alpha=0.5):
    """
    attentions: tuple/list of per-layer attentions; each is [B, H, S, S]
    last_k: only roll back through the last k layers (keeps contrast)
    residual_alpha: how much identity to add before normalizing (preserves token self-info)
    returns: [B, S, S] rollout matrix, or None if the input is invalid
    """
    if attentions is None:
        return None
    if isinstance(attentions, (list, tuple)) and len(attentions) == 0:
        return None

    first = attentions[0]
    if first is None or first.ndim != 4:
        return None  # expect [B, H, S, S]

    B, H, S, _ = first.shape
    eye = torch.eye(S, device=first.device).unsqueeze(0).expand(B, S, S)  # [B, S, S]

    L = len(attentions)
    if last_k is None:
        last_k = L
    if last_k <= 0:
        # No layers selected -> return identity (no propagation)
        return eye.clone()

    start = max(0, L - last_k)
    A = None
    for layer in range(start, L):
        a = attentions[layer]
        if a is None or a.ndim != 4 or a.shape[0] != B or a.shape[-1] != S:
            # Skip malformed layers
            continue
        a = a.mean(dim=1)                              # [B, S, S] (average over heads)
        a = a + float(residual_alpha) * eye
        a = a / (a.sum(dim=-1, keepdim=True) + 1e-12)  # row-normalize
        A = a if A is None else torch.bmm(A, a)

    # If we never multiplied (e.g. all layers were skipped), fall back to identity
    return A if A is not None else eye.clone()  # [B, S, S]
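
# Illustrative sanity check (comment only, not executed): with a single layer of
# perfectly uniform attention and the residual term disabled, the rollout stays
# uniform, since products of row-stochastic matrices remain row-stochastic:
#   attention_rollout((torch.full((1, 8, 4, 4), 0.25),), last_k=1, residual_alpha=0.0)[0, 0]
#   -> tensor([0.2500, 0.2500, 0.2500, 0.2500])
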

def merge_wordpieces(tokens, scores):
    """Glue BERT wordpieces ("##"-prefixed) back into whole words, averaging their scores."""
    merged_tokens, merged_scores = [], []
    cur_tok, cur_scores = "", []
    for t, s in zip(tokens, scores):
        if t.startswith("##"):
            cur_tok += t[2:]
            cur_scores.append(s)
        else:
            if cur_tok:
                merged_tokens.append(cur_tok)
                merged_scores.append(sum(cur_scores) / max(1, len(cur_scores)))
            cur_tok, cur_scores = t, [s]
    if cur_tok:
        merged_tokens.append(cur_tok)
        merged_scores.append(sum(cur_scores) / max(1, len(cur_scores)))
    return merged_tokens, merged_scores


def _normalize_for_display_wordlevel(attn_scores, normalize_mode="visual", temperature=0.30):
    """
    Convert raw *word-level* token scores into:
      - probabilistic mode: probabilities that sum to 1.0 (100%),
        with labels like "0.237 | 23.7% (contrib)"
      - visual mode: min-max + gamma scaling (contrast, not sum-to-100),
        with labels like "0.68 | visual score"

    Returns:
      attn_final: np.ndarray of floats in [0, 1] for the color scale
      labels: list[str] per token (tooltip text; the leading number stays up front
              for the color_map bucketing)
    """
    attn_array = np.array(attn_scores, dtype=float)

    if normalize_mode == "probabilistic":
        # ---- percentage view that sums to 100% ----
        attn_array = np.maximum(attn_array, 0.0)
        if attn_array.max() > 0:
            attn_array = attn_array / (attn_array.max() + 1e-12)  # scale to [0, 1] for stability
        # Sharpen (lower temperature => peakier distribution)
        attn_array = np.power(attn_array + 1e-12, 1.0 / max(1e-6, float(temperature)))
        prob = attn_array / (attn_array.sum() + 1e-12)
        percent = prob * 100.0
        # Keep prob (0..1) for the color scale; label with % contribution
        labels = [f"{prob[i]:.3f} | {percent[i]:.1f}% (contrib)" for i in range(len(prob))]
        return prob, labels
    else:
        # ---- visual: min-max + gamma (contrast, not sum-to-100) ----
        if attn_array.max() > attn_array.min():
            attn_array0 = (attn_array - attn_array.min()) / (attn_array.max() - attn_array.min() + 1e-8)
            attn_array0 = np.clip(np.power(attn_array0, 0.75), 0.1, 1.0)
        else:
            attn_array0 = np.zeros_like(attn_array)
        labels = [f"{attn_array0[i]:.2f} | visual score" for i in range(len(attn_array0))]
        return attn_array0, labels
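
# Illustrative example (comment only, not executed): wordpieces are merged back
# into whole words with their scores averaged, e.g.
#   merge_wordpieces(["pneu", "##mon", "##ia", "fever"], [0.6, 0.3, 0.3, 0.1])
#   -> (["pneumonia", "fever"], [~0.4, 0.1])
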

# ------------------
# Prediction
# ------------------
def predict(
    model,
    mode,
    emr_text=None,
    image=None,
    normalize_mode="visual",
    need_token_vis=False,
    use_rollout=False
):
    """
    normalize_mode: "visual" (min-max + gamma boost) or "probabilistic"
                    (temperature-sharpened contributions that sum to 100%)
    need_token_vis: request/compute token-level attentions (Doctor mode + text/multimodal)
    use_rollout: use attention rollout across layers
    """
    input_ids = attention_mask = img_tensor = None
    cam_image = None
    highlighted_tokens = None
    top5 = []

    if mode in ["text", "multimodal"] and emr_text:
        text_tokens = tokenizer(
            emr_text,
            return_tensors="pt",
            truncation=True,
            padding="max_length",
            max_length=128,
        )
        input_ids = text_tokens["input_ids"].to(DEVICE)
        attention_mask = text_tokens["attention_mask"].to(DEVICE)

    if mode in ["image", "multimodal"] and image:
        img_tensor = image_transform(image).unsqueeze(0).to(DEVICE)

    # Only register hooks for Grad-CAM when needed
    if mode in ["image", "multimodal"]:
        activations, gradients, fwd_handle, bwd_handle = register_hooks(model)
        model.zero_grad()

    # === Forward ===
    # Only enable attentions when planning to visualize them
    outputs = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        image=img_tensor,
        output_attentions=bool(need_token_vis and (mode in ["text", "multimodal"])),
        return_raw_attentions=bool(use_rollout and need_token_vis)
    )
    logits = outputs["logits"]
    if logits.numel() == 0:
        raise ValueError("Model returned empty logits. Check input format.")

    probs = torch.softmax(logits, dim=1)
    pred = torch.argmax(probs, dim=1).item()
    confidence = probs.squeeze()[pred].item()

    # === Grad-CAM ===
    if mode in ["image", "multimodal"]:
        # Backpropagate from the predicted class logit to populate the Grad-CAM hooks
        logits[0, pred].backward(retain_graph=True)
        cam_image = generate_gradcam(image, activations, gradients)
        fwd_handle.remove()
        bwd_handle.remove()

    # === Token-level attention ===
    if need_token_vis and (mode in ["text", "multimodal"]):
        token_attn_scores = None

        if use_rollout and outputs.get("raw_attentions") is not None:
            # Partial rollout over the last few layers; roll[b, 0, :] is CLS-to-all-tokens
            roll = attention_rollout(outputs["raw_attentions"], last_k=4, residual_alpha=0.5)  # [B, S, S]
            if roll is not None:
                cls_to_tokens = roll[0, 0].detach().cpu().numpy().tolist()  # CLS row
                token_attn_scores = cls_to_tokens
        elif outputs.get("token_attentions") is not None:
            token_attn_scores = outputs["token_attentions"].squeeze().tolist()

        if token_attn_scores is not None:
            # Filter out specials/pad + align to wordpieces
            ids = input_ids[0].tolist()
            amask = attention_mask[0].tolist() if attention_mask is not None else [1] * len(ids)
            wp_all = tokenizer.convert_ids_to_tokens(ids, skip_special_tokens=False)
            special_ids = set(tokenizer.all_special_ids)
            keep_idx = [
                i for i, (tid, m) in enumerate(zip(ids, amask))
                if (tid not in special_ids) and (m == 1)
            ]
            wp_tokens = [wp_all[i] for i in keep_idx]
            wp_scores = [token_attn_scores[i] if i < len(token_attn_scores) else 0.0 for i in keep_idx]

            # Merge wordpieces into words
            word_tokens, attn_scores = merge_wordpieces(wp_tokens, wp_scores)

            # Build Top-5 (probabilistic normalization for ranking)
            _probs_for_rank, _ = _normalize_for_display_wordlevel(
                attn_scores, normalize_mode="probabilistic", temperature=0.30
            )
            pairs = list(zip(word_tokens, _probs_for_rank))
            pairs.sort(key=lambda x: x[1], reverse=True)
            top5 = [(tok, float(p * 100.0)) for tok, p in pairs[:5]]

            # Final display (probabilistic or visual)
            attn_final, labels = _normalize_for_display_wordlevel(
                attn_scores,
                normalize_mode=normalize_mode,
                temperature=0.30,
            )
            highlighted_tokens = [(tok, labels[i]) for i, tok in enumerate(word_tokens)]

    print("🧪 Normalization mode received:", normalize_mode)
    if highlighted_tokens:
        print("🟣 Highlighted tokens sample:", highlighted_tokens[:5])
    else:
        print("🟣 No highlighted tokens (no text or attentions unavailable).")

    return inv_map[pred], cam_image, highlighted_tokens, confidence, probs.tolist(), top5
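

if __name__ == "__main__":
    # Minimal smoke test (illustrative only): assumes config/config.yaml and the
    # text-mode weights are reachable from this environment; the EMR sentence
    # below is made up for demonstration.
    demo_model = load_model("text")
    label, _cam, _tokens, conf, _all_probs, top5 = predict(
        demo_model,
        mode="text",
        emr_text="65-year-old male with chest pain, shortness of breath, and elevated troponin.",
        normalize_mode="probabilistic",
        need_token_vis=True,
    )
    print(f"Predicted label: {label} (confidence {conf:.2f})")
    print("Top-5 contributing words:", top5)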