File size: 21,874 Bytes
ca6de2f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "966675dc",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os, sys\n",
    "\n",
    "# Automatically adds project root to Python's import path\n",
    "project_root = os.path.abspath(os.path.join(os.getcwd(), \"..\"))\n",
    "if project_root not in sys.path:\n",
    "    sys.path.append(project_root)\n",
    "\n",
    "import torch\n",
    "from  transformers import AutoTokenizer\n",
    "from PIL import Image\n",
    "from torchvision import transforms\n",
    "from src.multimodal_model import MediLLMModel\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "0714cb83",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MediLLMModel(\n",
       "  (text_encoder): BertModel(\n",
       "    (embeddings): BertEmbeddings(\n",
       "      (word_embeddings): Embedding(28996, 768, padding_idx=0)\n",
       "      (position_embeddings): Embedding(512, 768)\n",
       "      (token_type_embeddings): Embedding(2, 768)\n",
       "      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "      (dropout): Dropout(p=0.1, inplace=False)\n",
       "    )\n",
       "    (encoder): BertEncoder(\n",
       "      (layer): ModuleList(\n",
       "        (0-11): 12 x BertLayer(\n",
       "          (attention): BertAttention(\n",
       "            (self): BertSdpaSelfAttention(\n",
       "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "            (output): BertSelfOutput(\n",
       "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "          )\n",
       "          (intermediate): BertIntermediate(\n",
       "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "            (intermediate_act_fn): GELUActivation()\n",
       "          )\n",
       "          (output): BertOutput(\n",
       "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "            (dropout): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "    (pooler): BertPooler(\n",
       "      (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "      (activation): Tanh()\n",
       "    )\n",
       "  )\n",
       "  (image_encoder): ResNet(\n",
       "    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
       "    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (act1): ReLU(inplace=True)\n",
       "    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
       "    (layer1): Sequential(\n",
       "      (0): Bottleneck(\n",
       "        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (act1): ReLU(inplace=True)\n",
       "        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (drop_block): Identity()\n",
       "        (act2): ReLU(inplace=True)\n",
       "        (aa): Identity()\n",
       "        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (act3): ReLU(inplace=True)\n",
       "        (downsample): Sequential(\n",
       "          (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (1): Bottleneck(\n",
       "        (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (act1): ReLU(inplace=True)\n",
       "        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (drop_block): Identity()\n",
       "        (act2): ReLU(inplace=True)\n",
       "        (aa): Identity()\n",
       "        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (act3): ReLU(inplace=True)\n",
       "      )\n",
       "      (2): Bottleneck(\n",
       "        (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (act1): ReLU(inplace=True)\n",
       "        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (drop_block): Identity()\n",
       "        (act2): ReLU(inplace=True)\n",
       "        (aa): Identity()\n",
       "        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (act3): ReLU(inplace=True)\n",
       "      )\n",
       "    )\n",
       "    (layer2): Sequential(\n",
       "      (0): Bottleneck(\n",
       "        (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (act1): ReLU(inplace=True)\n",
       "        (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (drop_block): Identity()\n",
       "        (act2): ReLU(inplace=True)\n",
       "        (aa): Identity()\n",
       "        (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (act3): ReLU(inplace=True)\n",
       "        (downsample): Sequential(\n",
       "          (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
       "          (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (1): Bottleneck(\n",
       "        (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (act1): ReLU(inplace=True)\n",
       "        (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (drop_block): Identity()\n",
       "        (act2): ReLU(inplace=True)\n",
       "        (aa): Identity()\n",
       "        (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (act3): ReLU(inplace=True)\n",
       "      )\n",
       "      (2): Bottleneck(\n",
       "        (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (act1): ReLU(inplace=True)\n",
       "        (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (drop_block): Identity()\n",
       "        (act2): ReLU(inplace=True)\n",
       "        (aa): Identity()\n",
       "        (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (act3): ReLU(inplace=True)\n",
       "      )\n",
       "      (3): Bottleneck(\n",
       "        (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (act1): ReLU(inplace=True)\n",
       "        (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (drop_block): Identity()\n",
       "        (act2): ReLU(inplace=True)\n",
       "        (aa): Identity()\n",
       "        (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (act3): ReLU(inplace=True)\n",
       "      )\n",
       "    )\n",
       "    (layer3): Sequential(\n",
       "      (0): Bottleneck(\n",
       "        (conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (act1): ReLU(inplace=True)\n",
       "        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (drop_block): Identity()\n",
       "        (act2): ReLU(inplace=True)\n",
       "        (aa): Identity()\n",
       "        (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (act3): ReLU(inplace=True)\n",
       "        (downsample): Sequential(\n",
       "          (0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
       "          (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (1): Bottleneck(\n",
       "        (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (act1): ReLU(inplace=True)\n",
       "        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (drop_block): Identity()\n",
       "        (act2): ReLU(inplace=True)\n",
       "        (aa): Identity()\n",
       "        (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (act3): ReLU(inplace=True)\n",
       "      )\n",
       "      (2): Bottleneck(\n",
       "        (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (act1): ReLU(inplace=True)\n",
       "        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (drop_block): Identity()\n",
       "        (act2): ReLU(inplace=True)\n",
       "        (aa): Identity()\n",
       "        (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (act3): ReLU(inplace=True)\n",
       "      )\n",
       "      (3): Bottleneck(\n",
       "        (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (act1): ReLU(inplace=True)\n",
       "        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (drop_block): Identity()\n",
       "        (act2): ReLU(inplace=True)\n",
       "        (aa): Identity()\n",
       "        (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (act3): ReLU(inplace=True)\n",
       "      )\n",
       "      (4): Bottleneck(\n",
       "        (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (act1): ReLU(inplace=True)\n",
       "        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (drop_block): Identity()\n",
       "        (act2): ReLU(inplace=True)\n",
       "        (aa): Identity()\n",
       "        (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (act3): ReLU(inplace=True)\n",
       "      )\n",
       "      (5): Bottleneck(\n",
       "        (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (act1): ReLU(inplace=True)\n",
       "        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (drop_block): Identity()\n",
       "        (act2): ReLU(inplace=True)\n",
       "        (aa): Identity()\n",
       "        (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (act3): ReLU(inplace=True)\n",
       "      )\n",
       "    )\n",
       "    (layer4): Sequential(\n",
       "      (0): Bottleneck(\n",
       "        (conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (act1): ReLU(inplace=True)\n",
       "        (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (drop_block): Identity()\n",
       "        (act2): ReLU(inplace=True)\n",
       "        (aa): Identity()\n",
       "        (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (act3): ReLU(inplace=True)\n",
       "        (downsample): Sequential(\n",
       "          (0): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
       "          (1): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (1): Bottleneck(\n",
       "        (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (act1): ReLU(inplace=True)\n",
       "        (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (drop_block): Identity()\n",
       "        (act2): ReLU(inplace=True)\n",
       "        (aa): Identity()\n",
       "        (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (act3): ReLU(inplace=True)\n",
       "      )\n",
       "      (2): Bottleneck(\n",
       "        (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (act1): ReLU(inplace=True)\n",
       "        (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (drop_block): Identity()\n",
       "        (act2): ReLU(inplace=True)\n",
       "        (aa): Identity()\n",
       "        (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (act3): ReLU(inplace=True)\n",
       "      )\n",
       "    )\n",
       "    (global_pool): SelectAdaptivePool2d(pool_type=avg, flatten=Flatten(start_dim=1, end_dim=-1))\n",
       "    (fc): Identity()\n",
       "  )\n",
       "  (classifier): Sequential(\n",
       "    (0): Linear(in_features=2816, out_features=256, bias=True)\n",
       "    (1): ReLU()\n",
       "    (2): Dropout(p=0.3, inplace=False)\n",
       "    (3): Linear(in_features=256, out_features=3, bias=True)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Load model\n",
    "model = MediLLMModel()\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "630b0c4e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dummy text\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"emilyalsentzer/Bio_ClinicalBERT\")\n",
    "text = \"Patient reports mild chest pain and fatigue for 3 days.\"\n",
    "tokens = tokenizer(text, return_tensors=\"pt\", padding=\"max_length\", truncation=True, max_length=128)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "0c51794c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dummy image\n",
    "img_path = os.path.join(project_root, \"data\", \"images\", \"NORMAL\", \"NORMAL-1.png\")\n",
    "if not os.path.exists(img_path):\n",
    "    raise FileNotFoundError(f\"Image not found at {img_path}\")\n",
    "else:\n",
    "    img = Image.open(img_path).convert(\"RGB\")\n",
    "    \n",
    "transform = transforms.Compose([\n",
    "    transforms.Resize((224, 224)),\n",
    "    transforms.ToTensor(),\n",
    "])\n",
    "\n",
    "img_tensor = transform(img).unsqueeze(0) # Adds a another dimension at position 0, i.e. batch number as deep learning models expects batch as input also [batch, channels, height, width]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "f56f6bf0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Prediction probabilities: tensor([[0.3228, 0.3539, 0.3233]])\n"
     ]
    }
   ],
   "source": [
    "# Run model\n",
    "with torch.no_grad():\n",
    "    out = model(tokens['input_ids'], tokens['attention_mask'], img_tensor)\n",
    "    probs = torch.softmax(out, dim=1)\n",
    "\n",
    "print(\"Prediction probabilities:\", probs)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "medi-llm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}