Preetham22 committed
Commit 562137e · 1 Parent(s): b136189

Auto-format code with Black
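Black rewrites Python into one deterministic style: double quotes, an 88-character default line length, and a trailing "magic" comma on any collection or call it has to split across lines, while multi-line literals that fit on one line and carry no trailing comma get collapsed. Every hunk below is an instance of one of those rules. A minimal sketch of the collapsing behavior through Black's Python API (assuming the black package is installed; the sample string mirrors the triage_map in generate_emr_csv_v1.py and is illustrative, not taken from the repo):

import black

# A multi-line dict with no trailing comma, like the triage_map reformatted below.
src = '''triage_map = {
    "COVID": "high",
    "NORMAL": "low",
    "VIRAL PNEUMONIA": "medium"
}
'''

# black.format_str applies the same rules as the `black` command-line tool.
print(black.format_str(src, mode=black.Mode(line_length=88)))
# -> triage_map = {"COVID": "high", "NORMAL": "low", "VIRAL PNEUMONIA": "medium"}

Had the source carried a trailing comma after "medium", the magic trailing comma rule would have kept the dict exploded at one entry per line instead.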
experiments/csv_file_generator_iterations/generate_emr_csv_final.py CHANGED
@@ -16,7 +16,7 @@ SAMPLES_PER_CLASS = 300
 categories = {
     "COVID": IMAGES_DIR / "COVID",
     "NORMAL": IMAGES_DIR / "NORMAL",
-    "VIRAL PNEUMONIA": IMAGES_DIR / "VIRAL PNEUMONIA"
+    "VIRAL PNEUMONIA": IMAGES_DIR / "VIRAL PNEUMONIA",
 }
 
 # Shared ambiguous templates
@@ -36,7 +36,7 @@ shared_diagnosis = [
     "Further tests required to confirm diagnosis.",
     "Findings are borderline; clinical judgment advised.",
     "Observation warranted due to overlapping signs.",
-    "Initial assessment inconclusive."
+    "Initial assessment inconclusive.",
 ]
 
 # Noise sentences
@@ -52,8 +52,8 @@ neutral_noise = [
 
 def random_token():
     prefix = "ID"
-    letters = ''.join(random.choices(string.ascii_uppercase, k=2))
-    digits = ''.join(random.choices(string.digits, k=2))
+    letters = "".join(random.choices(string.ascii_uppercase, k=2))
+    digits = "".join(random.choices(string.digits, k=2))
     return f"{prefix}-{letters}{digits}"
 
 
@@ -97,7 +97,7 @@ def build_emr(label, i):
         intro,
         random.choice(shared_symptoms),
         vitals,
-        random.choice(shared_diagnosis)
+        random.choice(shared_diagnosis),
     ]
 
     # Optionally inject a mild class-specific clue (with low probability)
@@ -122,9 +122,13 @@ def build_emr(label, i):
 # Generate records
 records = []
 for label, img_dir in categories.items():
-    image_files = sorted([f for f in img_dir.glob("*") if f.suffix.lower() in [".png", ".jpg", ".jpeg"]])
+    image_files = sorted(
+        [f for f in img_dir.glob("*") if f.suffix.lower() in [".png", ".jpg", ".jpeg"]]
+    )
     for i in range(SAMPLES_PER_CLASS):
-        image_path = str(random.choice(image_files).relative_to(IMAGES_DIR.parent.parent))
+        image_path = str(
+            random.choice(image_files).relative_to(IMAGES_DIR.parent.parent)
+        )
         text = build_emr(label, i)
         triage = triage_map[label]
         records.append([f"{label}-{i + 1}", image_path, text, triage])
experiments/csv_file_generator_iterations/generate_emr_csv_v1.py CHANGED
@@ -14,15 +14,11 @@ SAMPLES_PER_CLASS = 300 # 300 * 3 = 900 total
 categories = {
     "COVID": IMAGES_DIR / "COVID",
     "NORMAL": IMAGES_DIR / "NORMAL",
-    "VIRAL PNEUMONIA": IMAGES_DIR / "VIRAL PNEUMONIA"
+    "VIRAL PNEUMONIA": IMAGES_DIR / "VIRAL PNEUMONIA",
 }
 
 # Triage mapping
-triage_map = {
-    "COVID": "high",
-    "NORMAL": "low",
-    "VIRAL PNEUMONIA": "medium"
-}
+triage_map = {"COVID": "high", "NORMAL": "low", "VIRAL PNEUMONIA": "medium"}
 
 # --- Noise Sentences ---
 noise_sentences = [
@@ -43,7 +39,7 @@ noise_sentences = [
     "Patient remains alert and cooperative.",
     "No medication administered at this stage.",
     "Doctor recommends home resr and observation.",
-    "Evaluation ongoing for possible infection."
+    "Evaluation ongoing for possible infection.",
 ]
 
 # --- ambiguity sentences ---
@@ -52,18 +48,14 @@ ambiguous_templates = [
     "Normal oxygen levels observed. Slight wheeze on auscultation.",
     "Patient reports chest discomfort but vitals are stable.",
     "No known exposure. Minor throat irritation present.",
-    "Slight fatigue without other systemic symptoms."
+    "Slight fatigue without other systemic symptoms.",
 ]
 
 # --- Vitals & Symptoms ---
 
 
 def get_oxygen(label):
-    base_ranges = {
-        "COVID": (85, 94),
-        "VIRAL PNEUMONIA": (88, 95),
-        "NORMAL": (96, 99)
-    }
+    base_ranges = {"COVID": (85, 94), "VIRAL PNEUMONIA": (88, 95), "NORMAL": (96, 99)}
     base_min, base_max = base_ranges[label]
     # Apply + or - 1 blur, clamping between 80 and 100
     oxygen = random.randint(base_min - 1, base_max + 1)
@@ -112,7 +104,7 @@ def build_emr(label, i):
             f"{name} ({age}) complains of dry cough for {days} days.",
             f"{name} experiencing low-grade fever and SPO2 at {oxygen}%.",
             f"{name} reports breathlessness. X-ray indicates mild infiltrates.",
-        ]
+        ],
     }
 
     # Diagnosis Observations
@@ -120,18 +112,18 @@ def build_emr(label, i):
         "COVID": [
             "Findings suggest viral respiratory infection.",
            "Signs consistent with COVID-19 infection.",
-            "Clinical features align with COVID diagnosis."
+            "Clinical features align with COVID diagnosis.",
         ],
         "NORMAL": [
             "No signs of respiratory infection.",
             "No abnormal findings detected.",
-            "Checkup results within normal limits."
+            "Checkup results within normal limits.",
         ],
         "VIRAL PNEUMONIA": [
             "X-ray shows patchy infiltrates.",
             "Suspected viral origin of symptoms.",
-            "Clinical signs indicate viral pneumonia."
-        ]
+            "Clinical signs indicate viral pneumonia.",
+        ],
     }
 
     # Construct sentence pool
@@ -158,7 +150,9 @@ for label, img_dir in categories.items():
     )
     for i in range(SAMPLES_PER_CLASS):
         patient_id = f"{label}-{i + 1}"
-        image_path = str(random.choice(image_files).relative_to(IMAGES_DIR.parent.parent))
+        image_path = str(
+            random.choice(image_files).relative_to(IMAGES_DIR.parent.parent)
+        )
         emr_text = build_emr(label, i)
         triage_level = triage_map[label]
         records.append([patient_id, image_path, emr_text, triage_level])
experiments/csv_file_generator_iterations/generate_emr_csv_v2.py CHANGED
@@ -16,7 +16,7 @@ SAMPLES_PER_CLASS = 300
 categories = {
     "COVID": IMAGES_DIR / "COVID",
     "NORMAL": IMAGES_DIR / "NORMAL",
-    "VIRAL PNEUMONIA": IMAGES_DIR / "VIRAL PNEUMONIA"
+    "VIRAL PNEUMONIA": IMAGES_DIR / "VIRAL PNEUMONIA",
 }
 
 # Shared ambiguous templates
@@ -43,8 +43,8 @@ neutral_noise = [
 # ---Patient random token genrator ---
 def random_token():
     prefix = "ID"
-    letters = ''.join(random.choices(string.ascii_uppercase, k=2))
-    digits = ''.join(random.choices(string.digits, k=2))
+    letters = "".join(random.choices(string.ascii_uppercase, k=2))
+    digits = "".join(random.choices(string.digits, k=2))
     return f"{prefix}-{letters}{digits}"
 
 
@@ -79,27 +79,41 @@ def build_emr(label, i):
     temp = get_temp(label)
     days = get_days()
 
-    general_intro = f"Patient {patient_id}, a {age}, presents with symptoms for {days} days."
+    general_intro = (
+        f"Patient {patient_id}, a {age}, presents with symptoms for {days} days."
+    )
     vitals = f"Temperature recorded at {temp}°F, SPO2 levels at {oxygen}%."
 
     # Label-specific (but fuzzy) symptoms
     symptoms = {
-        "COVID": ["Complains of fatigue and shortness of breath.", "Dry cough with mild fever noted."],
-        "NORMAL": ["No major complaints; here for general checkup.", "Reports good health, no active issues."],
-        "VIRAL PNEUMONIA": ["Persistent cough and mild fever observed.", "Slight wheezing with chest tightness."]
+        "COVID": [
+            "Complains of fatigue and shortness of breath.",
+            "Dry cough with mild fever noted.",
+        ],
+        "NORMAL": [
+            "No major complaints; here for general checkup.",
+            "Reports good health, no active issues.",
+        ],
+        "VIRAL PNEUMONIA": [
+            "Persistent cough and mild fever observed.",
+            "Slight wheezing with chest tightness.",
+        ],
     }
 
     diagnosis = {
         "COVID": ["Viral etiology suspected.", "COVID infection not ruled out."],
         "NORMAL": ["Unlikely presence of infection.", "Clinical impression is benign."],
-        "VIRAL PNEUMONIA": ["Signs may indicate atypical pneumonia.", "Possible viral infection of lower tract."]
+        "VIRAL PNEUMONIA": [
+            "Signs may indicate atypical pneumonia.",
+            "Possible viral infection of lower tract.",
+        ],
     }
 
     body = [
         general_intro,
         random.choice(symptoms[label]),
         vitals,
-        random.choice(diagnosis[label])
+        random.choice(diagnosis[label]),
     ]
 
     # Inject 1–2 ambiguous or neutral sentences
@@ -115,9 +129,13 @@ def build_emr(label, i):
 # Generate records
 records = []
 for label, img_dir in categories.items():
-    image_files = sorted([f for f in img_dir.glob("*") if f.suffix.lower() in [".png", ".jpg", ".jpeg"]])
+    image_files = sorted(
+        [f for f in img_dir.glob("*") if f.suffix.lower() in [".png", ".jpg", ".jpeg"]]
+    )
     for i in range(SAMPLES_PER_CLASS):
-        image_path = str(random.choice(image_files).relative_to(IMAGES_DIR.parent.parent))
+        image_path = str(
+            random.choice(image_files).relative_to(IMAGES_DIR.parent.parent)
+        )
         text = build_emr(label, i)
         triage = triage_map[label]
         records.append([f"{label}-{i + 1}", image_path, text, triage])
experiments/train_optuna.py CHANGED
@@ -43,11 +43,7 @@ def objective(trial, mode):
         project=f"mediLLM-tune-{mode}",
         name=f"{mode}-trial-{trial.number}-v5-{wandb.util.generate_id()}",
         group="SoftLabelTrials",
-        config={
-            "dataset_version": "softlabels",
-            "dataset_size": 900,
-            "mode": mode
-        }
+        config={"dataset_version": "softlabels", "dataset_size": 900, "mode": mode},
     )
 
     # --- Hyperparameters ---
@@ -85,7 +81,9 @@ def objective(trial, mode):
             images = images.to(device)
 
         optimizer.zero_grad()
-        outputs = model(input_ids=input_ids, attention_mask=attention_mask, image=images)
+        outputs = model(
+            input_ids=input_ids, attention_mask=attention_mask, image=images
+        )
         loss = criterion(outputs, labels)
         loss.backward()
         optimizer.step()
@@ -108,7 +106,9 @@ def objective(trial, mode):
             if images is not None:
                 images = images.to(device)
 
-            outputs = model(input_ids=input_ids, attention_mask=attention_mask, image=images)
+            outputs = model(
+                input_ids=input_ids, attention_mask=attention_mask, image=images
+            )
             preds = torch.argmax(outputs, dim=1).cpu().numpy()
             all_preds.extend(preds)
             all_labels.extend(labels.cpu().numpy())
@@ -117,21 +117,28 @@ def objective(trial, mode):
     acc = accuracy_score(all_labels, all_preds)
 
     # Log to W&B and Optuna
-    wandb.log({
-        "val_f1_score": f1,
-        "val_accuracy": acc,
-        "lr": lr,
-        "dropout": dropout,
-        "hidden_dim": hidden_dim,
-        "batch_size": batch_size
-    })
+    wandb.log(
+        {
+            "val_f1_score": f1,
+            "val_accuracy": acc,
+            "lr": lr,
+            "dropout": dropout,
+            "hidden_dim": hidden_dim,
+            "batch_size": batch_size,
+        }
+    )
 
     # Confusion Matrix
     cm = confusion_matrix(all_labels, all_preds)
     plt.figure(figsize=(6, 5))
-    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
-                xticklabels=["low", "medium", "high"],
-                yticklabels=["low", "medium", "high"])
+    sns.heatmap(
+        cm,
+        annot=True,
+        fmt="d",
+        cmap="Blues",
+        xticklabels=["low", "medium", "high"],
+        yticklabels=["low", "medium", "high"],
+    )
     plt.title(f"Confusion Matrix - {mode} Trial {trial.number}")
     plt.xlabel("Predicted")
     plt.ylabel("True")
@@ -142,8 +149,16 @@ def objective(trial, mode):
 
 def get_args():
     parser = argparse.ArgumentParser(description="Run Optuna hyperparameter search")
-    parser.add_argument("--n_trials", type=int, default=10, help="Number of Optuna trials to run")
-    parser.add_argument("--mode", type=str, choices=["text", "image", "multimodal"], required=True, help="Input mode")
+    parser.add_argument(
+        "--n_trials", type=int, default=10, help="Number of Optuna trials to run"
+    )
+    parser.add_argument(
+        "--mode",
+        type=str,
+        choices=["text", "image", "multimodal"],
+        required=True,
+        help="Input mode",
+    )
     return parser.parse_args()
 
 
@@ -152,16 +167,17 @@ if __name__ == "__main__":
     mode = args.mode
 
     study = optuna.create_study(
-        study_name=f"mediLLM_{mode}_optuna",
-        direction="maximize"
+        study_name=f"mediLLM_{mode}_optuna", direction="maximize"
     )
     with tqdm(total=args.n_trials, desc=f"Optuna Trials [{mode}]") as pbar:
+
        def wrapped_objective(trial):
            try:
                return objective(trial, mode)
            finally:
                wandb.finish()
                pbar.update(1)
+
        study.optimize(wrapped_objective, n_trials=args.n_trials)
 
     print(f"✅ Best F1 score for {mode}: {study.best_value}")
@@ -176,7 +192,7 @@ if __name__ == "__main__":
         "dropout": float(study.best_params["dropout"]),
         "hidden_dim": int(study.best_params["hidden_dim"]),
         "batch_size": int(study.best_params["bs"]),
-        "epochs": 5
+        "epochs": 5,
     }
 
     # Load existing or start new
@@ -210,7 +226,7 @@ if __name__ == "__main__":
         "dropout": float(study.best_params["dropout"]),
         "hidden_dim": int(study.best_params["hidden_dim"]),
         "batch_size": int(study.best_params["bs"]),
-        "epochs": 5
+        "epochs": 5,
     }
 
     # Export to config.yaml
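The sns.heatmap and parser.add_argument hunks above show the complementary rule: a call that cannot fit within the line limit even after wrapping is exploded to one argument per line, and a trailing comma is added so it stays exploded on later runs. A small sketch of that behavior (same assumption that the black package is installed; the snippet is illustrative):

import black

# A hand-wrapped call that is too wide for Black's 88-character limit,
# shaped like the old sns.heatmap call above.
src = (
    'sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",\n'
    '            xticklabels=["low", "medium", "high"],\n'
    '            yticklabels=["low", "medium", "high"])\n'
)

# Black re-wraps it: one keyword argument per line plus the magic trailing comma.
print(black.format_str(src, mode=black.Mode()))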
src/data_preprocessing.py CHANGED
@@ -33,11 +33,7 @@ def preprocess_text(text):
     if not isinstance(text, str):
         raise ValueError("Input text must be a string.")
     return tokenizer(
-        text,
-        truncation=True,
-        padding="max_length",
-        max_length=128,
-        return_tensors="pt"
+        text, truncation=True, padding="max_length", max_length=128, return_tensors="pt"
     )
 
 
@@ -60,9 +56,7 @@ if __name__ == "__main__":
     print("Data loaded successfully.")
     # apply function applies to each row in the 'image_path' column and joins
     # the base directory with the relative path
-    df["image_path"] = df["image_path"].apply(
-        lambda p: os.path.join(base_dir, p)
-    )
+    df["image_path"] = df["image_path"].apply(lambda p: os.path.join(base_dir, p))
     print("Sample record:")
     print(df.iloc[0])
 
src/generate_emr_csv.py CHANGED
@@ -132,8 +132,7 @@ def generate_dataset():
         )
         for i in range(SAMPLES_PER_CLASS):
             image_path = str(
-                random.choice(image_files)
-                .relative_to(IMAGES_DIR.parent.parent)
+                random.choice(image_files).relative_to(IMAGES_DIR.parent.parent)
             )
             text = build_emr(label, i)
             triage = triage_map[label]
@@ -143,12 +142,7 @@ def generate_dataset():
     random.shuffle(records)
     with open(OUTPUT_FILE, "w", newline="") as f:
         writer = csv.writer(f)
-        writer.writerow([
-            "patient_id",
-            "image_path",
-            "emr_text",
-            "triage_level"
-        ])
+        writer.writerow(["patient_id", "image_path", "emr_text", "triage_level"])
         writer.writerows(records)
 
     print(f"✅ Softlabel EMR dataset generated at {OUTPUT_FILE}")
src/multimodal_model.py CHANGED
@@ -7,9 +7,7 @@ from transformers import AutoModel # Pretrained text encoders
 class MediLLMModel(nn.Module):
     def __init__(
         self,
-        text_model_name=(
-            "emilyalsentzer/Bio_ClinicalBERT"
-        ),
+        text_model_name=("emilyalsentzer/Bio_ClinicalBERT"),
         # Bio_ClinicalBERT is a pretrained model on clinical notes,
         # output to 3 classes i.e triage levels
         num_classes=3,
@@ -30,9 +28,7 @@ class MediLLMModel(nn.Module):
             text_model_name
         )  # Automodel returns base model without a classification head,
         # just embeddings
-        self.text_hidden_size = (
-            self.text_encoder.config.hidden_size
-        )
+        self.text_hidden_size = self.text_encoder.config.hidden_size
         # Dimensionality of hidden states i.e embedding vector size returned by
         # the text_encoder for each token, 768 for Bert models
 
@@ -75,16 +71,12 @@ class MediLLMModel(nn.Module):
         self.classifier = nn.Sequential(
             nn.Linear(fusion_dim, hidden_dim),  # Dense layer
             nn.ReLU(),  # Non-linear activation function
-            nn.Dropout(
-                dropout
-            ),  # randomly zeroes 30 percent of neuron outputs
+            nn.Dropout(dropout),  # randomly zeroes 30 percent of neuron outputs
             # to prevent over-fitting
             nn.Linear(hidden_dim, num_classes),  # Final Classification output
         )
 
-    def forward(
-        self, input_ids=None, attention_mask=None, image=None
-    ):
+    def forward(self, input_ids=None, attention_mask=None, image=None):
         # input_ids shape: [batch, seq_length]
         # attention_mask: mask to ignore padding, same shape as input_ids
         # image: [batch, 3, 224, 224]
@@ -122,6 +114,4 @@ class MediLLMModel(nn.Module):
         # -> [batch_size, 2816]
 
         # Return logits for each class, later apply softmax during evaluation
-        return self.classifier(
-            features
-        )
+        return self.classifier(features)
src/train.py CHANGED
@@ -166,9 +166,7 @@ def train_model(mode="multimodal"):
         )  # Save labels for metric computation
 
         # Calculating classification metrics (Accuracy and F1)
-        acc = accuracy_score(
-            all_labels, all_preds
-        )  # Evaluate full-epoch performance
+        acc = accuracy_score(all_labels, all_preds)  # Evaluate full-epoch performance
         f1 = f1_score(all_labels, all_preds, average="weighted")
         # 1) binary: Binary Classification(F1 score of +ve class only)
         # 2) macro: Computes F1 for each class independently, then averages,
@@ -240,9 +238,7 @@ def train_model(mode="multimodal"):
     )  # Saves the model weights only not total architecture to reuse later
 
     # Plot accuracy
-    plot_path = os.path.join(
-        base_dir, "assets", f"model_training_curve_{mode}.png"
-    )
+    plot_path = os.path.join(base_dir, "assets", f"model_training_curve_{mode}.png")
     plt.plot(train_acc, label="Train Acc")
     plt.plot(val_acc, label="Val Acc")
     plt.legend()
src/triage_dataset.py CHANGED
@@ -39,9 +39,7 @@ class TriageDataset(Dataset):
                     scale=(0.9, 1.0),
                     interpolation=InterpolationMode.BILINEAR,
                 ),  # Slight zoom-in/out
-                transforms.RandomRotation(
-                    degrees=10
-                ),  # + or - 10° rotation
+                transforms.RandomRotation(degrees=10),  # + or - 10° rotation
                 transforms.ColorJitter(
                     brightness=0.3, contrast=0.3
                 ),  # simulate slight imaging variations
tests/test_generate_emr_csv.py CHANGED
@@ -71,9 +71,7 @@ def test_total_and_per_class_counts(load_emr_csv):
     assert len(load_emr_csv) == 900, "Total records should be 900"
     counts = Counter(row["triage_level"] for row in load_emr_csv)
     for cls in EXPECTED_CLASSES:
-        assert counts[cls] == EXPECTED_SAMPLES_PER_CLASS, (
-            f"{cls} count mismatch"
-        )
+        assert counts[cls] == EXPECTED_SAMPLES_PER_CLASS, f"{cls} count mismatch"
 
 
 def test_patient_id_format_and_uniqueness(load_emr_csv):
@@ -94,9 +92,7 @@ def test_emr_text_quality(load_emr_csv):
 def test_image_path_format(load_emr_csv):
     for row in load_emr_csv:
         path = row["image_path"]
-        assert path.endswith((".jpg", ".jpeg", ".png")), (
-            f"Invalid image path: {path}"
-        )
+        assert path.endswith((".jpg", ".jpeg", ".png")), f"Invalid image path: {path}"
 
 
 def test_ambiguous_and_noise_injection(load_emr_csv):
tests/test_multimodal_model.py CHANGED
@@ -44,9 +44,7 @@ def test_text_only(dummy_inputs):
         input_ids=dummy_inputs["input_ids"],
         attention_mask=dummy_inputs["attention_mask"],
     )
-    assert outputs.shape == (BATCH_SIZE, 3), (
-        "Incorrect output shape for text-only mode"
-    )
+    assert outputs.shape == (BATCH_SIZE, 3), "Incorrect output shape for text-only mode"
 
 
 def test_image_only(dummy_inputs):
tests/test_triage_dataset.py CHANGED
@@ -24,9 +24,7 @@ def test_dataset_loading(mode):
     sample = dataset[0]
 
     if mode in ["text", "multimodal"]:
-        assert "input_ids" in sample, (
-            "Missing input_ids in text/multimodal mode"
-        )
+        assert "input_ids" in sample, "Missing input_ids in text/multimodal mode"
         assert (
             "attention_mask" in sample
         ), "Missing attention_mask in text/multimodal mode"