Preetham22 committed on
Commit
ca6de2f
·
1 Parent(s): 68e9545

added a notebook to test the model

Browse files
Files changed (2) hide show
  1. .gitignore +2 -0
  2. notebooks/test_model.ipynb +415 -0
.gitignore CHANGED
@@ -1,3 +1,5 @@
1
  # Ignore Data files for tracking
2
  data/
3
  checkpoints/
 
 
 
1
  # Ignore Data files for tracking
2
  data/
3
  checkpoints/
4
+ __pycache__/
5
+ *.py[cod]
notebooks/test_model.ipynb ADDED
@@ -0,0 +1,415 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "966675dc",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "import os, sys\n",
11
+ "\n",
12
+ "# Automatically adds project root to Python's import path\n",
13
+ "project_root = os.path.abspath(os.path.join(os.getcwd(), \"..\"))\n",
14
+ "if project_root not in sys.path:\n",
15
+ " sys.path.append(project_root)\n",
16
+ "\n",
17
+ "import torch\n",
18
+ "from transformers import AutoTokenizer\n",
19
+ "from PIL import Image\n",
20
+ "from torchvision import transforms\n",
21
+ "from src.multimodal_model import MediLLMModel\n"
22
+ ]
23
+ },
24
+ {
25
+ "cell_type": "code",
26
+ "execution_count": 2,
27
+ "id": "0714cb83",
28
+ "metadata": {},
29
+ "outputs": [
30
+ {
31
+ "data": {
32
+ "text/plain": [
33
+ "MediLLMModel(\n",
34
+ " (text_encoder): BertModel(\n",
35
+ " (embeddings): BertEmbeddings(\n",
36
+ " (word_embeddings): Embedding(28996, 768, padding_idx=0)\n",
37
+ " (position_embeddings): Embedding(512, 768)\n",
38
+ " (token_type_embeddings): Embedding(2, 768)\n",
39
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
40
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
41
+ " )\n",
42
+ " (encoder): BertEncoder(\n",
43
+ " (layer): ModuleList(\n",
44
+ " (0-11): 12 x BertLayer(\n",
45
+ " (attention): BertAttention(\n",
46
+ " (self): BertSdpaSelfAttention(\n",
47
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
48
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
49
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
50
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
51
+ " )\n",
52
+ " (output): BertSelfOutput(\n",
53
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
54
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
55
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
56
+ " )\n",
57
+ " )\n",
58
+ " (intermediate): BertIntermediate(\n",
59
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
60
+ " (intermediate_act_fn): GELUActivation()\n",
61
+ " )\n",
62
+ " (output): BertOutput(\n",
63
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
64
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
65
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
66
+ " )\n",
67
+ " )\n",
68
+ " )\n",
69
+ " )\n",
70
+ " (pooler): BertPooler(\n",
71
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
72
+ " (activation): Tanh()\n",
73
+ " )\n",
74
+ " )\n",
75
+ " (image_encoder): ResNet(\n",
76
+ " (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
77
+ " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
78
+ " (act1): ReLU(inplace=True)\n",
79
+ " (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
80
+ " (layer1): Sequential(\n",
81
+ " (0): Bottleneck(\n",
82
+ " (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
83
+ " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
84
+ " (act1): ReLU(inplace=True)\n",
85
+ " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
86
+ " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
87
+ " (drop_block): Identity()\n",
88
+ " (act2): ReLU(inplace=True)\n",
89
+ " (aa): Identity()\n",
90
+ " (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
91
+ " (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
92
+ " (act3): ReLU(inplace=True)\n",
93
+ " (downsample): Sequential(\n",
94
+ " (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
95
+ " (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
96
+ " )\n",
97
+ " )\n",
98
+ " (1): Bottleneck(\n",
99
+ " (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
100
+ " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
101
+ " (act1): ReLU(inplace=True)\n",
102
+ " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
103
+ " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
104
+ " (drop_block): Identity()\n",
105
+ " (act2): ReLU(inplace=True)\n",
106
+ " (aa): Identity()\n",
107
+ " (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
108
+ " (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
109
+ " (act3): ReLU(inplace=True)\n",
110
+ " )\n",
111
+ " (2): Bottleneck(\n",
112
+ " (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
113
+ " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
114
+ " (act1): ReLU(inplace=True)\n",
115
+ " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
116
+ " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
117
+ " (drop_block): Identity()\n",
118
+ " (act2): ReLU(inplace=True)\n",
119
+ " (aa): Identity()\n",
120
+ " (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
121
+ " (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
122
+ " (act3): ReLU(inplace=True)\n",
123
+ " )\n",
124
+ " )\n",
125
+ " (layer2): Sequential(\n",
126
+ " (0): Bottleneck(\n",
127
+ " (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
128
+ " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
129
+ " (act1): ReLU(inplace=True)\n",
130
+ " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
131
+ " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
132
+ " (drop_block): Identity()\n",
133
+ " (act2): ReLU(inplace=True)\n",
134
+ " (aa): Identity()\n",
135
+ " (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
136
+ " (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
137
+ " (act3): ReLU(inplace=True)\n",
138
+ " (downsample): Sequential(\n",
139
+ " (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
140
+ " (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
141
+ " )\n",
142
+ " )\n",
143
+ " (1): Bottleneck(\n",
144
+ " (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
145
+ " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
146
+ " (act1): ReLU(inplace=True)\n",
147
+ " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
148
+ " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
149
+ " (drop_block): Identity()\n",
150
+ " (act2): ReLU(inplace=True)\n",
151
+ " (aa): Identity()\n",
152
+ " (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
153
+ " (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
154
+ " (act3): ReLU(inplace=True)\n",
155
+ " )\n",
156
+ " (2): Bottleneck(\n",
157
+ " (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
158
+ " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
159
+ " (act1): ReLU(inplace=True)\n",
160
+ " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
161
+ " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
162
+ " (drop_block): Identity()\n",
163
+ " (act2): ReLU(inplace=True)\n",
164
+ " (aa): Identity()\n",
165
+ " (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
166
+ " (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
167
+ " (act3): ReLU(inplace=True)\n",
168
+ " )\n",
169
+ " (3): Bottleneck(\n",
170
+ " (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
171
+ " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
172
+ " (act1): ReLU(inplace=True)\n",
173
+ " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
174
+ " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
175
+ " (drop_block): Identity()\n",
176
+ " (act2): ReLU(inplace=True)\n",
177
+ " (aa): Identity()\n",
178
+ " (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
179
+ " (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
180
+ " (act3): ReLU(inplace=True)\n",
181
+ " )\n",
182
+ " )\n",
183
+ " (layer3): Sequential(\n",
184
+ " (0): Bottleneck(\n",
185
+ " (conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
186
+ " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
187
+ " (act1): ReLU(inplace=True)\n",
188
+ " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
189
+ " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
190
+ " (drop_block): Identity()\n",
191
+ " (act2): ReLU(inplace=True)\n",
192
+ " (aa): Identity()\n",
193
+ " (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
194
+ " (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
195
+ " (act3): ReLU(inplace=True)\n",
196
+ " (downsample): Sequential(\n",
197
+ " (0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
198
+ " (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
199
+ " )\n",
200
+ " )\n",
201
+ " (1): Bottleneck(\n",
202
+ " (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
203
+ " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
204
+ " (act1): ReLU(inplace=True)\n",
205
+ " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
206
+ " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
207
+ " (drop_block): Identity()\n",
208
+ " (act2): ReLU(inplace=True)\n",
209
+ " (aa): Identity()\n",
210
+ " (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
211
+ " (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
212
+ " (act3): ReLU(inplace=True)\n",
213
+ " )\n",
214
+ " (2): Bottleneck(\n",
215
+ " (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
216
+ " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
217
+ " (act1): ReLU(inplace=True)\n",
218
+ " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
219
+ " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
220
+ " (drop_block): Identity()\n",
221
+ " (act2): ReLU(inplace=True)\n",
222
+ " (aa): Identity()\n",
223
+ " (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
224
+ " (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
225
+ " (act3): ReLU(inplace=True)\n",
226
+ " )\n",
227
+ " (3): Bottleneck(\n",
228
+ " (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
229
+ " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
230
+ " (act1): ReLU(inplace=True)\n",
231
+ " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
232
+ " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
233
+ " (drop_block): Identity()\n",
234
+ " (act2): ReLU(inplace=True)\n",
235
+ " (aa): Identity()\n",
236
+ " (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
237
+ " (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
238
+ " (act3): ReLU(inplace=True)\n",
239
+ " )\n",
240
+ " (4): Bottleneck(\n",
241
+ " (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
242
+ " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
243
+ " (act1): ReLU(inplace=True)\n",
244
+ " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
245
+ " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
246
+ " (drop_block): Identity()\n",
247
+ " (act2): ReLU(inplace=True)\n",
248
+ " (aa): Identity()\n",
249
+ " (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
250
+ " (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
251
+ " (act3): ReLU(inplace=True)\n",
252
+ " )\n",
253
+ " (5): Bottleneck(\n",
254
+ " (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
255
+ " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
256
+ " (act1): ReLU(inplace=True)\n",
257
+ " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
258
+ " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
259
+ " (drop_block): Identity()\n",
260
+ " (act2): ReLU(inplace=True)\n",
261
+ " (aa): Identity()\n",
262
+ " (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
263
+ " (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
264
+ " (act3): ReLU(inplace=True)\n",
265
+ " )\n",
266
+ " )\n",
267
+ " (layer4): Sequential(\n",
268
+ " (0): Bottleneck(\n",
269
+ " (conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
270
+ " (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
271
+ " (act1): ReLU(inplace=True)\n",
272
+ " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
273
+ " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
274
+ " (drop_block): Identity()\n",
275
+ " (act2): ReLU(inplace=True)\n",
276
+ " (aa): Identity()\n",
277
+ " (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
278
+ " (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
279
+ " (act3): ReLU(inplace=True)\n",
280
+ " (downsample): Sequential(\n",
281
+ " (0): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
282
+ " (1): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
283
+ " )\n",
284
+ " )\n",
285
+ " (1): Bottleneck(\n",
286
+ " (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
287
+ " (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
288
+ " (act1): ReLU(inplace=True)\n",
289
+ " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
290
+ " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
291
+ " (drop_block): Identity()\n",
292
+ " (act2): ReLU(inplace=True)\n",
293
+ " (aa): Identity()\n",
294
+ " (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
295
+ " (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
296
+ " (act3): ReLU(inplace=True)\n",
297
+ " )\n",
298
+ " (2): Bottleneck(\n",
299
+ " (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
300
+ " (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
301
+ " (act1): ReLU(inplace=True)\n",
302
+ " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
303
+ " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
304
+ " (drop_block): Identity()\n",
305
+ " (act2): ReLU(inplace=True)\n",
306
+ " (aa): Identity()\n",
307
+ " (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
308
+ " (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
309
+ " (act3): ReLU(inplace=True)\n",
310
+ " )\n",
311
+ " )\n",
312
+ " (global_pool): SelectAdaptivePool2d(pool_type=avg, flatten=Flatten(start_dim=1, end_dim=-1))\n",
313
+ " (fc): Identity()\n",
314
+ " )\n",
315
+ " (classifier): Sequential(\n",
316
+ " (0): Linear(in_features=2816, out_features=256, bias=True)\n",
317
+ " (1): ReLU()\n",
318
+ " (2): Dropout(p=0.3, inplace=False)\n",
319
+ " (3): Linear(in_features=256, out_features=3, bias=True)\n",
320
+ " )\n",
321
+ ")"
322
+ ]
323
+ },
324
+ "execution_count": 2,
325
+ "metadata": {},
326
+ "output_type": "execute_result"
327
+ }
328
+ ],
329
+ "source": [
330
+ "# Load model\n",
331
+ "model = MediLLMModel()\n",
332
+ "model.eval()"
333
+ ]
334
+ },
335
+ {
336
+ "cell_type": "code",
337
+ "execution_count": 3,
338
+ "id": "630b0c4e",
339
+ "metadata": {},
340
+ "outputs": [],
341
+ "source": [
342
+ "# Dummy text\n",
343
+ "tokenizer = AutoTokenizer.from_pretrained(\"emilyalsentzer/Bio_ClinicalBERT\")\n",
344
+ "text = \"Patient reports mild chest pain and fatigue for 3 days.\"\n",
345
+ "tokens = tokenizer(text, return_tensors=\"pt\", padding=\"max_length\", truncation=True, max_length=128)"
346
+ ]
347
+ },
348
+ {
349
+ "cell_type": "code",
350
+ "execution_count": 7,
351
+ "id": "0c51794c",
352
+ "metadata": {},
353
+ "outputs": [],
354
+ "source": [
355
+ "# Dummy image\n",
356
+ "img_path = os.path.join(project_root, \"data\", \"images\", \"NORMAL\", \"NORMAL-1.png\")\n",
357
+ "if not os.path.exists(img_path):\n",
358
+ " raise FileNotFoundError(f\"Image not found at {img_path}\")\n",
359
+ "else:\n",
360
+ " img = Image.open(img_path).convert(\"RGB\")\n",
361
+ " \n",
362
+ "transform = transforms.Compose([\n",
363
+ " transforms.Resize((224, 224)),\n",
364
+ " transforms.ToTensor(),\n",
365
+ "])\n",
366
+ "\n",
367
+ "img_tensor = transform(img).unsqueeze(0) # Adds a batch dimension at position 0, since deep learning models expect batched input shaped [batch, channels, height, width]\n"
368
+ ]
369
+ },
370
+ {
371
+ "cell_type": "code",
372
+ "execution_count": 8,
373
+ "id": "f56f6bf0",
374
+ "metadata": {},
375
+ "outputs": [
376
+ {
377
+ "name": "stdout",
378
+ "output_type": "stream",
379
+ "text": [
380
+ "Prediction probabilities: tensor([[0.3228, 0.3539, 0.3233]])\n"
381
+ ]
382
+ }
383
+ ],
384
+ "source": [
385
+ "# Run model\n",
386
+ "with torch.no_grad():\n",
387
+ " out = model(tokens['input_ids'], tokens['attention_mask'], img_tensor)\n",
388
+ " probs = torch.softmax(out, dim=1)\n",
389
+ "\n",
390
+ "print(\"Prediction probabilities:\", probs)"
391
+ ]
392
+ }
393
+ ],
394
+ "metadata": {
395
+ "kernelspec": {
396
+ "display_name": "medi-llm",
397
+ "language": "python",
398
+ "name": "python3"
399
+ },
400
+ "language_info": {
401
+ "codemirror_mode": {
402
+ "name": "ipython",
403
+ "version": 3
404
+ },
405
+ "file_extension": ".py",
406
+ "mimetype": "text/x-python",
407
+ "name": "python",
408
+ "nbconvert_exporter": "python",
409
+ "pygments_lexer": "ipython3",
410
+ "version": "3.10.18"
411
+ }
412
+ },
413
+ "nbformat": 4,
414
+ "nbformat_minor": 5
415
+ }