{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "966675dc", "metadata": {}, "outputs": [], "source": [ "import os, sys\n", "\n", "# Automatically adds project root to Python's import path\n", "project_root = os.path.abspath(os.path.join(os.getcwd(), \"..\"))\n", "if project_root not in sys.path:\n", " sys.path.append(project_root)\n", "\n", "import torch\n", "from transformers import AutoTokenizer\n", "from PIL import Image\n", "from torchvision import transforms\n", "from src.multimodal_model import MediLLMModel\n" ] }, { "cell_type": "code", "execution_count": 2, "id": "0714cb83", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "MediLLMModel(\n", " (text_encoder): BertModel(\n", " (embeddings): BertEmbeddings(\n", " (word_embeddings): Embedding(28996, 768, padding_idx=0)\n", " (position_embeddings): Embedding(512, 768)\n", " (token_type_embeddings): Embedding(2, 768)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (encoder): BertEncoder(\n", " (layer): ModuleList(\n", " (0-11): 12 x BertLayer(\n", " (attention): BertAttention(\n", " (self): BertSdpaSelfAttention(\n", " (query): Linear(in_features=768, out_features=768, bias=True)\n", " (key): Linear(in_features=768, out_features=768, bias=True)\n", " (value): Linear(in_features=768, out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (output): BertSelfOutput(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (intermediate): BertIntermediate(\n", " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", " (intermediate_act_fn): GELUActivation()\n", " )\n", " (output): BertOutput(\n", " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, 
elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " )\n", " )\n", " (pooler): BertPooler(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", " (activation): Tanh()\n", " )\n", " )\n", " (image_encoder): ResNet(\n", " (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n", " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (act1): ReLU(inplace=True)\n", " (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n", " (layer1): Sequential(\n", " (0): Bottleneck(\n", " (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (act1): ReLU(inplace=True)\n", " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (drop_block): Identity()\n", " (act2): ReLU(inplace=True)\n", " (aa): Identity()\n", " (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (act3): ReLU(inplace=True)\n", " (downsample): Sequential(\n", " (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (1): Bottleneck(\n", " (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (act1): ReLU(inplace=True)\n", " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (drop_block): Identity()\n", " (act2): 
ReLU(inplace=True)\n", " (aa): Identity()\n", " (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (act3): ReLU(inplace=True)\n", " )\n", " (2): Bottleneck(\n", " (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (act1): ReLU(inplace=True)\n", " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (drop_block): Identity()\n", " (act2): ReLU(inplace=True)\n", " (aa): Identity()\n", " (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (act3): ReLU(inplace=True)\n", " )\n", " )\n", " (layer2): Sequential(\n", " (0): Bottleneck(\n", " (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (act1): ReLU(inplace=True)\n", " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (drop_block): Identity()\n", " (act2): ReLU(inplace=True)\n", " (aa): Identity()\n", " (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (act3): ReLU(inplace=True)\n", " (downsample): Sequential(\n", " (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", " (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (1): Bottleneck(\n", " (conv1): Conv2d(512, 128, 
kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (act1): ReLU(inplace=True)\n", " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (drop_block): Identity()\n", " (act2): ReLU(inplace=True)\n", " (aa): Identity()\n", " (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (act3): ReLU(inplace=True)\n", " )\n", " (2): Bottleneck(\n", " (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (act1): ReLU(inplace=True)\n", " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (drop_block): Identity()\n", " (act2): ReLU(inplace=True)\n", " (aa): Identity()\n", " (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (act3): ReLU(inplace=True)\n", " )\n", " (3): Bottleneck(\n", " (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (act1): ReLU(inplace=True)\n", " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (drop_block): Identity()\n", " (act2): ReLU(inplace=True)\n", " (aa): Identity()\n", " (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn3): BatchNorm2d(512, 
eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (act3): ReLU(inplace=True)\n", " )\n", " )\n", " (layer3): Sequential(\n", " (0): Bottleneck(\n", " (conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (act1): ReLU(inplace=True)\n", " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (drop_block): Identity()\n", " (act2): ReLU(inplace=True)\n", " (aa): Identity()\n", " (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (act3): ReLU(inplace=True)\n", " (downsample): Sequential(\n", " (0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", " (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (1): Bottleneck(\n", " (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (act1): ReLU(inplace=True)\n", " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (drop_block): Identity()\n", " (act2): ReLU(inplace=True)\n", " (aa): Identity()\n", " (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (act3): ReLU(inplace=True)\n", " )\n", " (2): Bottleneck(\n", " (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " 
(act1): ReLU(inplace=True)\n", " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (drop_block): Identity()\n", " (act2): ReLU(inplace=True)\n", " (aa): Identity()\n", " (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (act3): ReLU(inplace=True)\n", " )\n", " (3): Bottleneck(\n", " (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (act1): ReLU(inplace=True)\n", " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (drop_block): Identity()\n", " (act2): ReLU(inplace=True)\n", " (aa): Identity()\n", " (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (act3): ReLU(inplace=True)\n", " )\n", " (4): Bottleneck(\n", " (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (act1): ReLU(inplace=True)\n", " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (drop_block): Identity()\n", " (act2): ReLU(inplace=True)\n", " (aa): Identity()\n", " (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (act3): ReLU(inplace=True)\n", " )\n", " (5): Bottleneck(\n", " 
(conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (act1): ReLU(inplace=True)\n", " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (drop_block): Identity()\n", " (act2): ReLU(inplace=True)\n", " (aa): Identity()\n", " (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (act3): ReLU(inplace=True)\n", " )\n", " )\n", " (layer4): Sequential(\n", " (0): Bottleneck(\n", " (conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (act1): ReLU(inplace=True)\n", " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (drop_block): Identity()\n", " (act2): ReLU(inplace=True)\n", " (aa): Identity()\n", " (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (act3): ReLU(inplace=True)\n", " (downsample): Sequential(\n", " (0): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", " (1): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (1): Bottleneck(\n", " (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (act1): ReLU(inplace=True)\n", " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): 
BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (drop_block): Identity()\n", " (act2): ReLU(inplace=True)\n", " (aa): Identity()\n", " (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (act3): ReLU(inplace=True)\n", " )\n", " (2): Bottleneck(\n", " (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (act1): ReLU(inplace=True)\n", " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (drop_block): Identity()\n", " (act2): ReLU(inplace=True)\n", " (aa): Identity()\n", " (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (act3): ReLU(inplace=True)\n", " )\n", " )\n", " (global_pool): SelectAdaptivePool2d(pool_type=avg, flatten=Flatten(start_dim=1, end_dim=-1))\n", " (fc): Identity()\n", " )\n", " (classifier): Sequential(\n", " (0): Linear(in_features=2816, out_features=256, bias=True)\n", " (1): ReLU()\n", " (2): Dropout(p=0.3, inplace=False)\n", " (3): Linear(in_features=256, out_features=3, bias=True)\n", " )\n", ")" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Load model\n", "model = MediLLMModel()\n", "model.eval()" ] }, { "cell_type": "code", "execution_count": 3, "id": "630b0c4e", "metadata": {}, "outputs": [], "source": [ "# Dummy text\n", "tokenizer = AutoTokenizer.from_pretrained(\"emilyalsentzer/Bio_ClinicalBERT\")\n", "text = \"Patient reports mild chest pain and fatigue for 3 days.\"\n", "tokens = tokenizer(text, return_tensors=\"pt\", padding=\"max_length\", 
truncation=True, max_length=128)" ] }, { "cell_type": "code", "execution_count": 7, "id": "0c51794c", "metadata": {}, "outputs": [], "source": [ "# Dummy image\n", "img_path = os.path.join(project_root, \"data\", \"images\", \"NORMAL\", \"NORMAL-1.png\")\n", "if not os.path.exists(img_path):\n", " raise FileNotFoundError(f\"Image not found at {img_path}\")\n", "else:\n", " img = Image.open(img_path).convert(\"RGB\")\n", " \n", "transform = transforms.Compose([\n", " transforms.Resize((224, 224)),\n", " transforms.ToTensor(),\n", "])\n", "\n", "img_tensor = transform(img).unsqueeze(0) # Adds a batch dimension at position 0, since deep learning models expect batched input: [batch, channels, height, width]\n" ] }, { "cell_type": "code", "execution_count": 8, "id": "f56f6bf0", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Prediction probabilities: tensor([[0.3228, 0.3539, 0.3233]])\n" ] } ], "source": [ "# Run model\n", "with torch.no_grad():\n", " out = model(tokens['input_ids'], tokens['attention_mask'], img_tensor)\n", " probs = torch.softmax(out, dim=1)\n", "\n", "print(\"Prediction probabilities:\", probs)" ] } ], "metadata": { "kernelspec": { "display_name": "medi-llm", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.18" } }, "nbformat": 4, "nbformat_minor": 5 }