File size: 7,206 Bytes
f93c46a
af86b36
 
 
f93c46a
4e592d4
 
 
 
 
 
 
 
 
 
f93c46a
 
54b189d
af86b36
562137e
af86b36
 
 
 
54b189d
af86b36
54b189d
af86b36
 
 
 
 
 
54b189d
f93c46a
 
af86b36
 
 
 
562137e
af86b36
 
f93c46a
 
 
 
 
 
 
 
 
 
 
 
 
 
54b189d
af86b36
 
 
 
 
 
 
 
 
 
 
f93c46a
54b189d
 
af86b36
54b189d
af86b36
54b189d
af86b36
 
54b189d
af86b36
 
 
 
562137e
af86b36
 
f93c46a
 
42e56c5
af86b36
 
 
 
42e56c5
af86b36
42e56c5
 
 
af86b36
 
 
 
 
 
42e56c5
af86b36
42e56c5
f93c46a
42e56c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f93c46a
42e56c5
 
 
 
 
af86b36
 
 
42e56c5
af86b36
 
 
4e592d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f93c46a
af86b36
42e56c5
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
import torch  # Deep learning framework
import torch.nn as nn
import timm  # For Image models like ResNet
from transformers import AutoModel  # Pretrained text encoders

# ================================================
# NOTE:
# Future upgrades can include:
#   - MLP Fusion
#   - Bilinear Pooling
#   - Attention-based Fusion
#   - Cross-modal Transformers
# These require structural changes to the forward() logic.
# ================================================


class MediLLMModel(nn.Module):
    """Triage classifier over clinical text, medical images, or both.

    Text is encoded with a Hugging Face ``AutoModel`` (Bio_ClinicalBERT by
    default) and summarised by the hidden state at position 0 (the [CLS]
    token for BERT-style tokenizers). Images are encoded with a timm
    ResNet-50 whose classification head is removed. In ``"multimodal"`` mode
    the two feature vectors are concatenated (late fusion) and passed to a
    small MLP classification head.

    Args:
        text_model_name: Hugging Face model id or path for the text encoder.
        num_classes: Number of output classes (default 3, the triage levels).
        dropout: Dropout probability used inside the classification head.
        hidden_dim: Width of the classification head's hidden layer.
        mode: ``"text"``, ``"image"`` or ``"multimodal"`` (case-insensitive).

    Raises:
        ValueError: If ``mode`` is not one of the supported modes.
    """

    # Modes accepted by the constructor (compared after lower-casing).
    _MODES = ("text", "image", "multimodal")

    def __init__(
        self,
        text_model_name="emilyalsentzer/Bio_ClinicalBERT",
        num_classes=3,
        dropout=0.3,
        hidden_dim=256,
        mode="multimodal",
    ):
        super().__init__()

        # Normalise before validating so mixed-case names ("Text") are
        # accepted; raise rather than assert so the check survives `python -O`.
        normalized_mode = mode.lower() if isinstance(mode, str) else mode
        if normalized_mode not in self._MODES:
            raise ValueError(
                "Mode must be one of: 'text', 'image', or 'multimodal' "
                f"(got {mode!r})"
            )
        self.mode = normalized_mode

        # NOTE(review): both encoders are built regardless of `mode`, so a
        # single-modality model still downloads and holds the unused
        # encoder. Skipping it would change the state_dict keys, so existing
        # checkpoints would need migrating.

        # Text encoder. AutoModel returns the bare transformer without a task
        # head, i.e. per-token hidden states only.
        self.text_encoder = AutoModel.from_pretrained(text_model_name)
        # Per-token embedding width (768 for BERT-base models).
        self.text_hidden_size = self.text_encoder.config.hidden_size

        # Image encoder: timm ResNet-50. num_classes=0 drops the final
        # fully-connected layer, so the model returns pooled features.
        #
        # ResNet-50 = stem conv + 16 bottleneck blocks x 3 convs (48 layers)
        # + a final FC layer (removed here). Each bottleneck block:
        #   x -> Conv1x1 -> BN -> ReLU   (reduce channels)
        #     -> Conv3x3 -> BN -> ReLU   (main conv)
        #     -> Conv1x1 -> BN           (restore channels)
        #     -> + skip(x)  (x projected by a conv when shapes differ)
        #     -> ReLU
        self.image_encoder = timm.create_model(
            "resnet50", pretrained=True, num_classes=0
        )
        # Pooled feature width (2048 for ResNet-50).
        self.image_hidden_size = self.image_encoder.num_features

        # Input width of the classification head depends on which features
        # are fed to it in the selected mode.
        fusion_dims = {
            "text": self.text_hidden_size,
            "image": self.image_hidden_size,
            "multimodal": self.text_hidden_size + self.image_hidden_size,
        }
        fusion_dim = fusion_dims[self.mode]

        # Classification head producing one unnormalised logit per class.
        # Layer order is kept stable so state_dict keys (classifier.0/3.*)
        # do not change.
        self.classifier = nn.Sequential(
            nn.Linear(fusion_dim, hidden_dim),
            nn.ReLU(),
            # Zeroes a `dropout` fraction of activations during training.
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_classes),
        )

    def _encode_text(
        self, input_ids, attention_mask, output_attentions, return_raw_attentions
    ):
        """Encode text; return ``(cls_embedding, token_attn, raw_attn)``.

        ``token_attn`` is None unless ``output_attentions`` is set, and
        ``raw_attn`` is None unless ``return_raw_attentions`` is set.

        Raises:
            ValueError: If ``input_ids`` is None.
        """
        if input_ids is None:
            raise ValueError(f"input_ids is required in '{self.mode}' mode")

        # The encoder must produce attentions if either output needs them.
        need_attentions = bool(output_attentions or return_raw_attentions)
        text_outputs = self.text_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_attentions=need_attentions,
        )

        # last_hidden_state: [batch, seq_len, hidden]; take position 0 (the
        # [CLS] token for BERT-style tokenizers) as the sequence summary.
        cls_embedding = text_outputs.last_hidden_state[:, 0, :]

        token_attn_scores = None
        raw_attentions = None
        if need_attentions:
            # One tensor per layer, each [batch, heads, seq_len, seq_len].
            attention_maps = text_outputs.attentions
            if output_attentions:
                # Last layer, averaged over heads; row 0 is how strongly
                # position 0 attends to every token -> [batch, seq_len].
                token_attn_scores = attention_maps[-1].mean(dim=1)[:, 0, :]
            if return_raw_attentions:
                raw_attentions = attention_maps
        return cls_embedding, token_attn_scores, raw_attentions

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        image=None,
        output_attentions=False,
        return_raw_attentions=False,
    ):
        """Compute class logits (and optional text attentions) for a batch.

        Args:
            input_ids: Token ids ``[batch, seq_len]``; required in ``"text"``
                and ``"multimodal"`` modes.
            attention_mask: Padding mask, same shape as ``input_ids``.
            image: Image batch (e.g. ``[batch, 3, 224, 224]``); required in
                ``"image"`` and ``"multimodal"`` modes.
            output_attentions: If true, also return the last-layer,
                head-averaged attention from position 0 to every token.
            return_raw_attentions: If true, also return the encoder's full
                per-layer attentions (independently of ``output_attentions``).

        Returns:
            dict with ``"logits"`` ``[batch, num_classes]`` (unnormalised;
            apply softmax for probabilities), ``"token_attentions"``
            ``[batch, seq_len]`` or None, and ``"raw_attentions"`` or None.
            Attention entries are always None in ``"image"`` mode.

        Raises:
            ValueError: If an input required by the current mode is missing.
        """
        use_text = self.mode in ("text", "multimodal")
        use_image = self.mode in ("image", "multimodal")

        cls_embedding = token_attn_scores = raw_attentions = None
        if use_text:
            cls_embedding, token_attn_scores, raw_attentions = self._encode_text(
                input_ids, attention_mask, output_attentions, return_raw_attentions
            )

        image_feat = None
        if use_image:
            if image is None:
                raise ValueError(f"image is required in '{self.mode}' mode")
            # Pooled ResNet features, [batch, 2048] for ResNet-50.
            image_feat = self.image_encoder(image)

        if self.mode == "text":
            features = cls_embedding
        elif self.mode == "image":
            features = image_feat
        else:
            # Late fusion: concatenate along the feature axis ->
            # [batch, text_hidden + image_hidden] (768 + 2048 = 2816 here).
            features = torch.cat((cls_embedding, image_feat), dim=1)
            # Alternative fusion strategies (each needs new layers in
            # __init__): bilinear pooling over the outer product of the two
            # vectors, cross-modal attention (text queries over image
            # features), a cross-modal transformer encoder over
            # [text, image] tokens with positional embeddings, or an MLP
            # over the concatenation.

        logits = self.classifier(features)
        return {
            "logits": logits,
            "token_attentions": token_attn_scores,  # [batch, seq_len] or None
            "raw_attentions": raw_attentions,
        }