In [1]:
import os, sys

# Automatically adds project root to Python's import path
project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
if project_root not in sys.path:
 sys.path.append(project_root)

import torch
from transformers import AutoTokenizer
from PIL import Image
from torchvision import transforms
from src.multimodal_model import MediLLMModel


In [2]:
# Load model
model = MediLLMModel()
model.eval()

MediLLMModel(
 (text_encoder): BertModel(
 (embeddings): BertEmbeddings(
 (word_embeddings): Embedding(28996, 768, padding_idx=0)
 (position_embeddings): Embedding(512, 768)
 (token_type_embeddings): Embedding(2, 768)
 (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
 (dropout): Dropout(p=0.1, inplace=False)
 )
 (encoder): BertEncoder(
 (layer): ModuleList(
 (0-11): 12 x BertLayer(
 (attention): BertAttention(
 (self): BertSdpaSelfAttention(
 (query): Linear(in_features=768, out_features=768, bias=True)
 (key): Linear(in_features=768, out_features=768, bias=True)
 (value): Linear(in_features=768, out_features=768, bias=True)
 (dropout): Dropout(p=0.1, inplace=False)
 )
 (output): BertSelfOutput(
 (dense): Linear(in_features=768, out_features=768, bias=True)
 (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
 (dropout): Dropout(p=0.1, inplace=False)
 )
 )
 (intermediate): BertIntermediate(
 (dense): Linear(in_features=768, out_features=3072, bias=True)

In [3]:
# Dummy text
tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
text = "Patient reports mild chest pain and fatigue for 3 days."
tokens = tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=128)

In [7]:
# Dummy image
img_path = os.path.join(project_root, "data", "images", "NORMAL", "NORMAL-1.png")
if not os.path.exists(img_path):
 raise FileNotFoundError(f"Image not found at {img_path}")
else:
 img = Image.open(img_path).convert("RGB")
 
transform = transforms.Compose([
 transforms.Resize((224, 224)),
 transforms.ToTensor(),
])

img_tensor = transform(img).unsqueeze(0) # Adds a another dimension at position 0, i.e. batch number as deep learning models expects batch as input also [batch, channels, height, width]


In [8]:
# Run model
with torch.no_grad():
 out = model(tokens['input_ids'], tokens['attention_mask'], img_tensor)
 probs = torch.softmax(out, dim=1)

print("Prediction probabilities:", probs)

Prediction probabilities: tensor([[0.3228, 0.3539, 0.3233]])
