import os
import sys
import torch
import yaml
import numpy as np
from pathlib import Path
from transformers import AutoTokenizer
from torchvision import transforms
from huggingface_hub import hf_hub_download

ROOT_DIR = Path(__file__).resolve().parent.parent.parent
sys.path.append(str(ROOT_DIR))

from src.multimodal_model import MediLLMModel
from app.utils.gradcam_utils import register_hooks, generate_gradcam

# --------------------
# Runtime / Hub config
# --------------------
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Map modes -> filenames in the HF model repo
HF_MODEL_REPO = os.getenv("HF_MODEL_REPO", "Preetham22/medi-llm-weights")
_raw_rev = os.getenv("HF_WEIGHTS_REV", None)
HF_WEIGHTS_REV = _raw_rev if (_raw_rev and _raw_rev.strip()) else None  # optional (commit/tag/branch), can be None

FILENAMES = {
    "text": "medi_llm_state_dict_text.pth",
    "image": "medi_llm_state_dict_image.pth",
    "multimodal": "medi_llm_state_dict_multimodal.pth",
}


def have_internet():
    """Return True if huggingface.co is reachable over HTTPS (quick 3-second check)."""
    try:
        import socket
        socket.create_connection(("huggingface.co", 443), timeout=3).close()
        return True
    except Exception:
        return False


def resolve_weights_path(mode: str) -> str:
    """Download (or reuse cached)  weights for the given mode from HF Hub."""
    if mode not in FILENAMES:
        raise ValueError(f"Unknown mode '{mode}'. Expected one of {list(FILENAMES)}.")
    filename = FILENAMES[mode]

    # 1) Prefer a file already present in the Space repo
    local_path = ROOT_DIR / filename
    if local_path.exists():
        return str(local_path)

    # 2) If no local file and no internet, bail early
    if not have_internet():
        raise RuntimeError(
            f"❌ Internet is disabled and weights are not present locally.\n"
            f"  Upload '{filename}' to this Space or enable Network access."
        )

    # 3) Otherwise, download from Hub
    try:
        return hf_hub_download(
            repo_id=HF_MODEL_REPO,
            filename=filename,
            revision=HF_WEIGHTS_REV,         # can be None -> default branch
            repo_type="model",               # change to "dataset" if needed
            local_dir=str(ROOT_DIR),         # keep a copy in the repo dir
            local_dir_use_symlinks=False,    # avoid symlink weirdness
            token=None,                      # public repo; pass a token for private weights
        )

    except Exception as e:
        raise RuntimeError(
            f"Failed to fetch weights '{filename}' from repo '{HF_MODEL_REPO}'. "
            f"Either enable Network access for this Space or commit the file locally. "
            f"Original error: {e}"
        )
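
# Example (hypothetical revision value): point the loader at a specific weights
# repo/revision via environment variables before launching the app:
#   export HF_MODEL_REPO="Preetham22/medi-llm-weights"
#   export HF_WEIGHTS_REV="main"      # or a commit hash / tag
# resolve_weights_path("multimodal") then reuses a local copy if present, otherwise
# downloads that file from the Hub.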


# ----------------------
# Labels / preprocessing
# ----------------------
inv_map = {0: "low", 1: "medium", 2: "high"}

# Tokenizer and image transform
tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
image_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.CenterCrop(224),
    transforms.ToTensor()
])


# ----------------------
# Model load
# ----------------------
def _safe_torch_load(path: str, map_location: torch.device):
    """
    Prefer weights_only=True (newer PyTorch versions), but fall back if not supported.
    """
    try:
        return torch.load(path, map_location=map_location, weights_only=True)  # weights_only on recent PyTorch
    except TypeError:
        return torch.load(path, map_location=map_location)


def load_model(mode: str, config_path: str = str(Path("config/config.yaml").resolve())):
    """
    Load MediLLMModel for the given mode and populate weights from HF Hub.
    Expects config/config.yaml with keys per mode (dropout, hidden_dim).
    """
    with open(config_path, "r") as f:
        cfg_all = yaml.safe_load(f)
    if mode not in cfg_all:
        raise KeyError(f"Mode '{mode}' not found in {config_path}. Keys: {list(cfg_all.keys())}")
    config = cfg_all[mode]

    # Build model
    model = MediLLMModel(
        mode=mode,
        dropout=config["dropout"],
        hidden_dim=config["hidden_dim"]
    )

    # Download weights & load
    weights_path = resolve_weights_path(mode)
    state = _safe_torch_load(weights_path, map_location=DEVICE)

    # Sometimes checkpoints save as {'state_dict': ...}
    if isinstance(state, dict) and "state_dict" in state:
        state = state["state_dict"]

    try:
        model.load_state_dict(state)  # strict by default
    except RuntimeError as e:
        # allow non-strict loading for minor mismatches (e.g., buffer names)
        try:
            model.load_state_dict(state, strict=False)
            print(f"⚠️ Loaded with strict=False due to: {e}")
        except Exception:
            raise

    model.to(DEVICE)
    model.eval()
    return model
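
# Illustrative config/config.yaml layout assumed by load_model (values are placeholders,
# not the trained settings):
#   text:
#     dropout: 0.3
#     hidden_dim: 256
#   image:
#     dropout: 0.3
#     hidden_dim: 256
#   multimodal:
#     dropout: 0.3
#     hidden_dim: 512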


# -----------------------
# Attention rollout utils
# -----------------------
def attention_rollout(attentions, last_k=4, residual_alpha=0.5):
    """
    attentions: tuple/list of per-layer attentions; each is (B, H, S, S)
    last_k: only roll back through the last k layers (keeps contrast)
    residual_alpha: how much identity to add before normalizing (preserve token self-info)
    returns: [B, S, S] rollout matrix, or None if input is invalid
    """
    if attentions is None:
        return None
    if isinstance(attentions, (list, tuple)) and len(attentions) == 0:
        return None

    first = attentions[0]
    if first is None or first.ndim != 4:
        return None  # expect [B, H, S, S]

    B, H, S, _ = first.shape
    eye = torch.eye(S, device=first.device).unsqueeze(0).expand(B, S, S)  # [B, S, S]

    L = len(attentions)
    if last_k is None:
        last_k = L
    if last_k <= 0:
        # No layers selected -> return identity (no propagation)
        return eye.clone()

    start = max(0, L - last_k)
    A = None
    for layer in range(start, L):
        a = attentions[layer]
        if a is None or a.ndim != 4 or a.shape[0] != B or a.shape[-1] != S:
            # Skip malformed layer
            continue
        a = a.mean(dim=1)  # [B, S, S] (avg heads)
        a = a + float(residual_alpha) * eye
        a = a / (a.sum(dim=-1, keepdim=True) + 1e-12)  # row-normalize
        A = a if A is None else torch.bmm(A, a)

    # if we never multiplied (e.g., every layer was skipped), fall back to identity
    return A if A is not None else eye.clone()  # [B,S,S]
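
# Shape sketch (illustrative; assumes a 12-layer encoder and the usual [B, H, S, S]
# attention maps, e.g. B=1, H=12, S=128):
#   roll = attention_rollout(attentions, last_k=4, residual_alpha=0.5)  # -> [1, 128, 128]
#   roll[0, 0] is the CLS row: how strongly CLS attends to every token after rollout.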


def merge_wordpieces(tokens, scores):
    """Merge WordPiece fragments ("##"-prefixed) back into whole words, averaging their scores."""
    merged_tokens, merged_scores = [], []
    cur_tok, cur_scores = "", []
    for t, s in zip(tokens, scores):
        if t.startswith("##"):
            cur_tok += t[2:]
            cur_scores.append(s)
        else:
            if cur_tok:
                merged_tokens.append(cur_tok)
                merged_scores.append(sum(cur_scores) / max(1, len(cur_scores)))
            cur_tok, cur_scores = t, [s]
    if cur_tok:
        merged_tokens.append(cur_tok)
        merged_scores.append(sum(cur_scores) / max(1, len(cur_scores)))
    return merged_tokens, merged_scores
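
# Illustrative example with made-up tokens/scores:
#   merge_wordpieces(["chest", "pain", "radia", "##ting"], [0.4, 0.3, 0.2, 0.1])
#   -> (["chest", "pain", "radiating"], [0.4, 0.3, 0.15])   # "##" pieces merged, scores averaged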


def _normalize_for_display_wordlevel(attn_scores, normalize_mode="visual", temperature=0.30):
    """
    Convert raw *word-level* token scores into:
      - probabilistic mode: probabilities that sum to 1.0 (100%), with labels like "0.237 | 23.7% (contrib)"
      - visual mode: min-max + gamma scaling (contrast, not sum-to-100), with labels like "0.68 | visual score"

    Returns:
        attn_final: np.ndarray of floats in [0, 1] for color scale
        labels: list[str] per token (tooltip text; first number stays up front for your color_map bucketing)
    """
    attn_array = np.array(attn_scores, dtype=float)

    if normalize_mode == "probabilistic":
        # ---- percentage view that sums up to 100% ----
        attn_array = np.maximum(attn_array, 0.0)
        if attn_array.max() > 0:
            attn_array = attn_array / (attn_array.max() + 1e-12)  # scale to [0, 1] for stability
        # sharpen (lower temp => peakier)
        attn_array = np.power(attn_array + 1e-12, 1.0 / max(1e-6, float(temperature)))
        prob = attn_array / (attn_array.sum() + 1e-12)
        percent = prob * 100.0

        # keep prob (0..1) for color scale; label with % contrib
        labels = [f"{prob[i]:.3f} | {percent[i]:.1f}% (contrib)" for i in range(len(prob))]
        return prob, labels
    else:
        # ---- visual: min-max + gamma (contrast, not sum-to-100) ---
        if attn_array.max() > attn_array.min():
            attn_array0 = (attn_array - attn_array.min()) / (attn_array.max() - attn_array.min() + 1e-8)
            attn_array0 = np.clip(np.power(attn_array0, 0.75), 0.1, 1.0)
        else:
            attn_array0 = np.zeros_like(attn_array)
        labels = [f"{attn_array0[i]:.2f} | visual score" for i in range(len(attn_array0))]
        return attn_array0, labels
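
# Label-format sketch (hypothetical inputs): both modes return (values, labels) where the
# leading number in each label drives the color scale.
#   _normalize_for_display_wordlevel(scores, normalize_mode="probabilistic")
#     -> values sum to 1.0; labels look like "0.237 | 23.7% (contrib)"
#   _normalize_for_display_wordlevel(scores, normalize_mode="visual")
#     -> min-max + gamma values in [0, 1]; labels look like "0.68 | visual score"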


# ------------------
# Prediction
# ------------------
def predict(
    model,
    mode,
    emr_text=None,
    image=None,
    normalize_mode="visual",
    need_token_vis=False,
    use_rollout=False
):
    """
    normalize_mode: "visual" (min-max + gamma boost) or "probabilistic" (softmax)
    need_token_vis: request/compute token-level attentions (Doctor mode + text/multimodal)
    use_rollout: use attention rollout across layers
    """
    input_ids = attention_mask = img_tensor = None
    cam_image = None
    highlighted_tokens = None
    top5 = []

    if mode in ["text", "multimodal"] and emr_text:
        text_tokens = tokenizer(
            emr_text,
            return_tensors="pt",
            truncation=True,
            padding="max_length",
            max_length=128,
        )
        input_ids = text_tokens["input_ids"].to(DEVICE)
        attention_mask = text_tokens["attention_mask"].to(DEVICE)

    if mode in ["image", "multimodal"] and image:
        img_tensor = image_transform(image).unsqueeze(0).to(DEVICE)

    # Register hooks for Grad-CAM only when the image pathway is used
    if mode in ["image", "multimodal"]:
        activations, gradients, fwd_handle, bwd_handle = register_hooks(model)
        model.zero_grad()

    # === Forward ===
    # Only enable attentions when planning to visualize them
    outputs = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        image=img_tensor,
        output_attentions=bool(need_token_vis and (mode in ["text", "multimodal"])),
        return_raw_attentions=bool(use_rollout and need_token_vis)
    )

    logits = outputs["logits"]
    if logits.numel() == 0:
        raise ValueError("Model returned empty logits. Check input format.")

    probs = torch.softmax(logits, dim=1)
    pred = torch.argmax(probs, dim=1).item()
    confidence = probs.squeeze()[pred].item()

    # === Grad-CAM ===
    if mode in ["image", "multimodal"]:
        # Backprop the predicted class score to populate Grad-CAM activations/gradients
        logits[0, pred].backward(retain_graph=True)
        cam_image = generate_gradcam(image, activations, gradients)
        fwd_handle.remove()
        bwd_handle.remove()

    # === Token-level attention ===
    if need_token_vis and (mode in ["text", "multimodal"]):
        token_attn_scores = None

        if use_rollout and outputs.get("raw_attentions") is not None:
            # partial rollout over the last k layers
            # roll: [B, S, S]; roll[b, 0, :] is CLS-to-all-tokens attention for batch item b
            roll = attention_rollout(outputs["raw_attentions"], last_k=4, residual_alpha=0.5)  # [B, S, S]
            if roll is not None:
                # pick the CLS row (index 0) of the first batch item
                cls_to_tokens = roll[0, 0].detach().cpu().numpy().tolist()
                token_attn_scores = cls_to_tokens
        elif outputs.get("token_attentions") is not None:
            token_attn_scores = outputs["token_attentions"].squeeze().tolist()

        if token_attn_scores is not None:
            # Filter out special/pad tokens and align scores to wordpieces
            ids = input_ids[0].tolist()
            amask = attention_mask[0].tolist() if attention_mask is not None else [1] * len(ids)
            wp_all = tokenizer.convert_ids_to_tokens(ids, skip_special_tokens=False)
            special_ids = set(tokenizer.all_special_ids)
            keep_idx = [i for i, (tid, m) in enumerate(zip(ids, amask)) if (tid not in special_ids) and (m == 1)]
            wp_tokens = [wp_all[i] for i in keep_idx]
            wp_scores = [token_attn_scores[i] if i < len(token_attn_scores) else 0.0 for i in keep_idx]

            # Merge wordpieces into words
            word_tokens, attn_scores = merge_wordpieces(wp_tokens, wp_scores)

            # Build Top-5 (probabilistic normalization for ranking)
            _probs_for_rank, _ = _normalize_for_display_wordlevel(
                attn_scores, normalize_mode="probabilistic", temperature=0.30
            )
            pairs = list(zip(word_tokens, _probs_for_rank))
            pairs.sort(key=lambda x: x[1], reverse=True)
            top5 = [(tok, float(p * 100.0)) for tok, p in pairs[:5]]

            # Final display (probabilistic or visual)
            attn_final, labels = _normalize_for_display_wordlevel(
                attn_scores,
                normalize_mode=normalize_mode,
                temperature=0.30,
            )

            highlighted_tokens = [(tok, labels[i]) for i, tok in enumerate(word_tokens)]

        print("🧪 Normalization Mode Received:", normalize_mode)
        if highlighted_tokens:
            print("🟣 Highlighted tokens sample:", highlighted_tokens[:5])
        else:
            print("🟣 No highlighted tokens (no text or attentions unavailable).")

    return inv_map[pred], cam_image, highlighted_tokens, confidence, probs.tolist(), top5
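

# Minimal smoke-test sketch (not part of the app entry point): run from the repo root so
# that config/config.yaml resolves. The EMR text below is a made-up example.
if __name__ == "__main__":
    demo_model = load_model("text")
    label, _, _, conf, _, top5 = predict(
        demo_model,
        mode="text",
        emr_text="Patient reports chest pain and shortness of breath for two days.",
        normalize_mode="visual",
        need_token_vis=True,
    )
    print(f"Predicted triage level: {label} (confidence {conf:.2f})")
    print("Top-5 attention tokens:", top5)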