# coco-demo/interpretability_demo.py
import matplotlib.cm as cm
import numpy as np
import torch
from PIL import Image

from src.interpretability import cross_attention_to_image


def resize_for_display(pil_img, max_dim=5000):
    """Downscale a PIL image so its longest side is at most max_dim, preserving aspect ratio."""
    w, h = pil_img.size
    if max(w, h) <= max_dim:
        return pil_img
    scale = max_dim / max(w, h)
    new_w = int(w * scale)
    new_h = int(h * scale)
    return pil_img.resize((new_w, new_h), Image.LANCZOS)
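
# Example behavior (informal): a 6000x3000 input comes back as 5000x2500, while
# any image whose longest side is already <= max_dim is returned unchanged.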


@torch.no_grad()
def generate_rollout_for_demo(model, tokenizer, img, preprocess,
                              device="cuda", max_new_tokens=32, alpha=0.45):
    """Greedily decode a caption and collect a cross-attention overlay per step.

    Note: despite the name, this visualizes last-layer cross-attention at each
    decoding step rather than a full attention rollout across layers.
    """
    model.eval()
    img_tensor = preprocess(img).unsqueeze(0).to(device)
    vision_out = model.vision_encoder(img_tensor)
    img_embeds = vision_out["image_embeds"]
    # Ensure a sequence dimension: (batch, dim) -> (batch, 1, dim)
    if img_embeds.dim() == 2:
        img_embeds = img_embeds.unsqueeze(1)
    projected = model.projector(img_embeds)

    decoder_input_ids = torch.tensor(
        [[model.t5.config.decoder_start_token_id]], device=device
    )

    generated_ids = []
    avg_frames = []
    labels = []
    per_head_frames = []
    num_heads = None

    # Decode token-by-token so attentions can be captured at every step
    for step in range(max_new_tokens):
        outputs = model.t5(
            encoder_outputs=(projected,),
            decoder_input_ids=decoder_input_ids,
            output_attentions=True,
            return_dict=True,
        )

        # Cross-attention from the last decoder layer: (heads, tgt_len, src_len)
        last_cross = outputs.cross_attentions[-1][0]
        num_heads = last_cross.size(0)

        # Average over heads -> (tgt_len, src_len)
        attn_avg = last_cross.mean(dim=0)

        # Attention for the most recent decoder position (tgt index = -1)
        attn_vec = attn_avg[-1]  # shape: (src_len,)

        heat_avg = cross_attention_to_image(attn_vec)
        if isinstance(heat_avg, tuple):
            heat_avg = heat_avg[0]
        if isinstance(heat_avg, np.ndarray):
            heat_avg = Image.fromarray((heat_avg * 255).astype("uint8"))

        avg_frames.append(
            overlay_attention_for_demo(img_tensor, heat_avg, alpha=alpha)
        )

        # Per-head overlays for the same decoder position
        head_overlays = []
        for h in range(num_heads):
            attn_vec_h = last_cross[h][-1]  # (src_len,)
            hmap = cross_attention_to_image(attn_vec_h)
            if isinstance(hmap, tuple):
                hmap = hmap[0]
            if isinstance(hmap, np.ndarray):
                hmap = Image.fromarray((hmap * 255).astype("uint8"))
            head_overlays.append(
                overlay_attention_for_demo(img_tensor, hmap, alpha=alpha)
            )
        per_head_frames.append(head_overlays)

        # Greedy next-token selection
        next_token = outputs.logits[:, -1, :].argmax(-1)
        token_str = tokenizer.decode(next_token, skip_special_tokens=True)
        labels.append(f"Token #{step}: \"{token_str}\"")
        generated_ids.append(int(next_token))

        if next_token.item() == tokenizer.eos_token_id:
            break

        decoder_input_ids = torch.cat(
            [decoder_input_ids, next_token.unsqueeze(0)], dim=1
        )

    # Full caption from the greedily decoded ids
    caption = tokenizer.decode(generated_ids, skip_special_tokens=True)

    # Structured dict for the Gradio UI
    return {
        "caption": caption,
        "avg": {
            "frames": avg_frames,
            "labels": labels,
        },
        "heads": {
            "frames": per_head_frames,  # list[step][head] = PIL image
            "labels": labels,
            "num_heads": num_heads,
        },
    }
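
# Illustrative consumer (a sketch, not part of the original demo): the Gradio UI
# presumably steps through the returned frames alongside their labels, e.g.
#
#   result = generate_rollout_for_demo(model, tokenizer, img, preprocess)
#   for i, (frame, label) in enumerate(zip(result["avg"]["frames"],
#                                          result["avg"]["labels"])):
#       frame.save(f"step_{i:02d}.png")  # or hand the PIL frames to a gr.Gallery
#
# Here "model", "tokenizer", "img", and "preprocess" come from the surrounding
# app code, which is not included in this file.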


def overlay_attention_for_demo(image_tensor, heatmap, alpha=0.45):
    """Blend a single-channel attention heatmap (PIL image) over the preprocessed input image."""
    img = image_tensor[0].detach().cpu().permute(1, 2, 0).numpy()
    # Min-max normalize to [0, 1]; the epsilon guards against a constant image
    img = (img - img.min()) / (img.max() - img.min() + 1e-8)
    img_uint8 = (img * 255).astype("uint8")

    # Upsample the heatmap to the image resolution (PIL size is (width, height))
    heatmap = heatmap.resize((img_uint8.shape[1], img_uint8.shape[0]), Image.BILINEAR)
    heat_np = np.asarray(heatmap).astype("float32") / 255.0

    base = Image.fromarray(img_uint8).convert("RGBA")
    colored = cm.inferno(heat_np)  # colormap returns an RGBA float array
    colored_uint8 = (colored * 255).astype("uint8")
    heat = Image.fromarray(colored_uint8).convert("RGBA")
    heat.putalpha(int(alpha * 255))

    blended = Image.alpha_composite(base, heat)
    blended = blended.convert("RGB")
    return blended  # resize_for_display(blended, max_dim=500) if a smaller preview is wanted
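

if __name__ == "__main__":
    # Minimal smoke test for overlay_attention_for_demo (an illustrative sketch;
    # the real demo wires generate_rollout_for_demo into a Gradio app). The
    # random inputs below are stand-ins, not part of the original pipeline.
    dummy_img = torch.rand(1, 3, 224, 224)  # fake preprocessed image batch
    dummy_heat = Image.fromarray(
        (np.random.rand(16, 16) * 255).astype("uint8")  # fake 16x16 attention map
    )
    out = overlay_attention_for_demo(dummy_img, dummy_heat, alpha=0.45)
    print(out.size, out.mode)  # expected: (224, 224) RGB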