# medi-llm/src/multimodal_model.py
import torch # Deep learning framework
import torch.nn as nn
import timm # For image models like ResNet
from transformers import AutoModel # Pretrained text encoders
# ================================================
# NOTE:
# Future upgrades can include:
# - MLP Fusion (see the standalone sketch after this note)
# - Bilinear Pooling
# - Attention-based Fusion
# - Cross-modal Transformers
# These require structural changes to the forward() logic.
# ================================================
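# --- Illustrative sketch (not wired into MediLLMModel): a minimal MLP fusion
# module matching the "MLP Fusion" idea listed above. The class name, argument
# names, and default dimensions (768 for Bio_ClinicalBERT, 2048 for ResNet-50)
# are assumptions for illustration only.
class MLPFusion(nn.Module):
    def __init__(self, text_dim=768, image_dim=2048, hidden_dim=512, dropout=0.3):
        super().__init__()
        self.fc1 = nn.Linear(text_dim + image_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, text_feat, image_feat):
        # text_feat: [batch, text_dim], image_feat: [batch, image_dim]
        fused = torch.cat([text_feat, image_feat], dim=1)
        fused = self.dropout(torch.relu(self.fc1(fused)))
        fused = self.dropout(torch.relu(self.fc2(fused)))
        return fused  # [batch, hidden_dim]; feed into a classification head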
class MediLLMModel(nn.Module):
def __init__(
self,
        text_model_name="emilyalsentzer/Bio_ClinicalBERT",
        # Bio_ClinicalBERT: a BERT model pretrained on clinical notes
        num_classes=3,  # number of output classes, i.e. the triage levels
dropout=0.3,
hidden_dim=256,
mode="multimodal",
):
super(MediLLMModel, self).__init__() # Use constructor of nn.Module
        self.mode = mode.lower()
        assert self.mode in [
            "text",
            "image",
            "multimodal",
        ], "Mode must be one of: 'text', 'image', or 'multimodal'"
# Text encoder: Bio_ClinicalBERT
self.text_encoder = AutoModel.from_pretrained(
text_model_name
        ) # AutoModel returns the base model without a classification head,
# just embeddings
self.text_hidden_size = self.text_encoder.config.hidden_size
        # Dimensionality of the hidden states, i.e. the embedding vector size
        # returned by the text_encoder for each token (768 for BERT models)
# Image encoder: ResNet-50 via TIMM
"""
Bottle neck block used in ResNet-50
Input x
↓
Conv1x1 β†’ BN β†’ ReLU β†’ (reduce size)
Conv3x3 β†’ BN β†’ ReLU β†’ (main conv)
Conv1x1 β†’ BN β†’ (restore size)
↓
+ Add skip (possibly Conv-adjusted x)
↓
ReLU
"""
        self.image_encoder = timm.create_model(
            "resnet50", pretrained=True, num_classes=0
        )  # num_classes=0 drops the classification head, so the encoder
        # returns pooled features. ResNet-50 means 50 layers:
        # an initial conv + pooling layer (spatial size reduction),
        # 16 residual bottleneck blocks * 3 conv layers each = 48 layers,
        # and a final fully-connected layer (removed here).
        # Each block has 3 conv layers + a skip connection, i.e. the next
        # block receives F(x) + x rather than F(x) alone.
self.image_hidden_size = (
self.image_encoder.num_features
) # Size of ResNet output ---> 2048 for ResNet-50
# --- Determine fusion dimension based on the mode ---
if self.mode == "text":
fusion_dim = self.text_hidden_size
elif self.mode == "image":
fusion_dim = self.image_hidden_size
else:
fusion_dim = self.text_hidden_size + self.image_hidden_size
# --- Classification head ---
# pass the combined features into the classifier
self.classifier = nn.Sequential(
nn.Linear(fusion_dim, hidden_dim), # Dense layer
nn.ReLU(), # Non-linear activation function
            nn.Dropout(dropout),  # randomly zeroes a fraction (dropout, 0.3 by
            # default) of activations to prevent over-fitting
nn.Linear(hidden_dim, num_classes), # Final Classification output
)
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        image=None,
        output_attentions=False,
        return_raw_attentions=False,
    ):
# input_ids shape: [batch, seq_length]
# attention_mask: mask to ignore padding, same shape as input_ids
# image: [batch, 3, 224, 224]
# Text features
if self.mode in ["text", "multimodal"]:
text_outputs = self.text_encoder(
input_ids=input_ids,
attention_mask=attention_mask,
output_attentions=output_attentions,
)
# feed tokenized text into the BERT Model which returns a
# dictionary with last_hidden_state: [batch_size, seq_len,
# hidden_size], pooler_output: [batch_size, hidden_size]
# (CLS embeddings), hidden_states: List of tensors,
# attentions(weights): List of Tensors
            last_hidden = text_outputs.last_hidden_state  # all token embeddings,
            # [batch, seq_len, hidden_size]
            cls_embedding = last_hidden[:, 0, :]  # CLS token (position 0) for each
            # item in the batch -> [batch, hidden_size]
# Real token attention using last-layer CLS attention weights
# attentions = List[12 tensors] -> each [batch, heads, seq_len, seq_len]
token_attn_scores = None
raw_attentions = None
if output_attentions:
attention_maps = text_outputs.attentions
last_layer_attn = attention_maps[-1] # [batch, heads, seq_len, seq_len]
avg_attn = last_layer_attn.mean(dim=1) # Average across heads -> [batch, seq_len, seq_len]
token_attn_scores = avg_attn[:, 0, :] # CLS attends to all tokens -> [batch, seq_len]
if return_raw_attentions:
raw_attentions = attention_maps
else:
cls_embedding = None
token_attn_scores = None
raw_attentions = None
# Image features
if self.mode == "image":
features = self.image_encoder(image) # pass the image through ResNet, returns a [batch, 2048] tensor
elif self.mode == "text": # text
features = cls_embedding
else: # multimodal
image_feat = self.image_encoder(image)
features = torch.cat(
(cls_embedding, image_feat), dim=1
) # Concatenates text and image features along feature dimension
# [CLS vector from BERT] + [ResNet image vector]
# -> [batch_size, 2816]
# === Placeholder: Advanced Fusion Methods ===
# Option 1: Bilinear Fusion
# fused = torch.bmm(text_feat.unsqueeze(2), img_feat.unsqueeze(1)).view(batch_size, -1)
# fused = self.bilinear_fc(fused)
        # Option 2: Cross-Modal Attention (see the commented sketch after this list)
# - Use attention mechanism where one modality attends to another
# - E.g., compute attention weights over image using text as query
# - Requires custom attention modules or transformers
# Option 3: Cross-modal Transformer Encoder
# - Concatenate image and text features as tokens
# - Feed into transformer encoder with positional embeddings
        # Option 4: MLP fusion (concatenate features, then fully-connected layers)
# fused = torch.cat([text_feat, img_feat], dim=1)
# fused = self.dropout(torch.relu(self.fusion_fc1(fused)))
# fused = self.dropout(torch.relu(self.fusion_fc2(fused)))
# return self.classifier(fused)
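        # Illustrative sketch of Option 2 in comments (not executed; the module
        # and projection names are assumptions and would need to be defined in
        # __init__): the text CLS vector queries the image feature through
        # nn.MultiheadAttention.
        # cross_attn = nn.MultiheadAttention(embed_dim=256, num_heads=4, batch_first=True)
        # q = self.text_proj(cls_embedding).unsqueeze(1)        # [batch, 1, 256]
        # kv = self.image_proj(image_feat).unsqueeze(1)         # [batch, 1, 256]
        # attended, _ = cross_attn(query=q, key=kv, value=kv)   # [batch, 1, 256]
        # features = attended.squeeze(1)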
        # Return raw logits for each class; softmax is applied later, during evaluation
logits = self.classifier(features)
return {
"logits": logits,
"token_attentions": token_attn_scores, # [batch, seq_len] or None
"raw_attentions": raw_attentions if return_raw_attentions else None,
}
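
# --- Minimal usage sketch (assumptions: transformers' AutoTokenizer for
# Bio_ClinicalBERT and a random tensor standing in for a 224x224 RGB image;
# pretrained weights are downloaded on first run).
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
    model = MediLLMModel(mode="multimodal")
    model.eval()

    encoded = tokenizer(
        ["Patient reports severe chest pain and shortness of breath."],
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    dummy_image = torch.randn(1, 3, 224, 224)  # [batch, 3, 224, 224]

    with torch.no_grad():
        outputs = model(
            input_ids=encoded["input_ids"],
            attention_mask=encoded["attention_mask"],
            image=dummy_image,
            output_attentions=True,
        )

    probs = torch.softmax(outputs["logits"], dim=1)  # [batch, num_classes]
    print("Class probabilities:", probs)
    print("Token attention shape:", outputs["token_attentions"].shape)  # [batch, seq_len]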