Preetham22 committed on
Commit
b136189
·
1 Parent(s): 9218201

formatting changes

Browse files
experiments/csv_file_generator_iterations/generate_emr_csv_final.py CHANGED
@@ -49,12 +49,14 @@ neutral_noise = [
49
  "No medications currently administered.",
50
  ]
51
 
 
52
  def random_token():
53
  prefix = "ID"
54
  letters = ''.join(random.choices(string.ascii_uppercase, k=2))
55
  digits = ''.join(random.choices(string.digits, k=2))
56
  return f"{prefix}-{letters}{digits}"
57
 
 
58
  def get_oxygen(label):
59
  # Soft blur across classes
60
  if label == "NORMAL":
@@ -64,18 +66,22 @@ def get_oxygen(label):
64
  else:
65
  return random.randint(87, 94)
66
 
 
67
  def get_temp(label):
68
  if label == "NORMAL":
69
  return round(random.uniform(97.5, 99.0), 1)
70
  else:
71
  return round(random.uniform(98.8, 102.5), 1)
72
 
 
73
  def get_age():
74
  return random.randint(18, 85)
75
 
 
76
  def get_days():
77
  return random.randint(1, 10)
78
 
 
79
  def build_emr(label, i):
80
  pid = random_token()
81
  age = f"{get_age()}-year-old"
@@ -112,6 +118,7 @@ def build_emr(label, i):
112
  random.shuffle(body[1:]) # Keep intro in position 0
113
  return " ".join(body)
114
 
 
115
  # Generate records
116
  records = []
117
  for label, img_dir in categories.items():
@@ -120,7 +127,7 @@ for label, img_dir in categories.items():
120
  image_path = str(random.choice(image_files).relative_to(IMAGES_DIR.parent.parent))
121
  text = build_emr(label, i)
122
  triage = triage_map[label]
123
- records.append([f"{label}-{i+1}", image_path, text, triage])
124
 
125
  # Shuffle + write
126
  random.shuffle(records)
 
49
  "No medications currently administered.",
50
  ]
51
 
52
+
53
  def random_token():
54
  prefix = "ID"
55
  letters = ''.join(random.choices(string.ascii_uppercase, k=2))
56
  digits = ''.join(random.choices(string.digits, k=2))
57
  return f"{prefix}-{letters}{digits}"
58
 
59
+
60
  def get_oxygen(label):
61
  # Soft blur across classes
62
  if label == "NORMAL":
 
66
  else:
67
  return random.randint(87, 94)
68
 
69
+
70
  def get_temp(label):
71
  if label == "NORMAL":
72
  return round(random.uniform(97.5, 99.0), 1)
73
  else:
74
  return round(random.uniform(98.8, 102.5), 1)
75
 
76
+
77
  def get_age():
78
  return random.randint(18, 85)
79
 
80
+
81
  def get_days():
82
  return random.randint(1, 10)
83
 
84
+
85
  def build_emr(label, i):
86
  pid = random_token()
87
  age = f"{get_age()}-year-old"
 
118
  random.shuffle(body[1:]) # Keep intro in position 0
119
  return " ".join(body)
120
 
121
+
122
  # Generate records
123
  records = []
124
  for label, img_dir in categories.items():
 
127
  image_path = str(random.choice(image_files).relative_to(IMAGES_DIR.parent.parent))
128
  text = build_emr(label, i)
129
  triage = triage_map[label]
130
+ records.append([f"{label}-{i + 1}", image_path, text, triage])
131
 
132
  # Shuffle + write
133
  random.shuffle(records)
experiments/csv_file_generator_iterations/generate_emr_csv_v1.py CHANGED
@@ -8,7 +8,7 @@ IMAGES_DIR = CURRENT_DIR.parent / "data" / "images"
8
  OUTPUT_FILE = CURRENT_DIR.parent / "data" / "emr_records_extended.csv"
9
 
10
  # Sample size
11
- SAMPLES_PER_CLASS = 300 # 300 * 3 = 900 total
12
 
13
  # Categories and labels
14
  categories = {
@@ -56,6 +56,8 @@ ambiguous_templates = [
56
  ]
57
 
58
  # --- Vitals & Symptoms ---
 
 
59
  def get_oxygen(label):
60
  base_ranges = {
61
  "COVID": (85, 94),
@@ -67,6 +69,7 @@ def get_oxygen(label):
67
  oxygen = random.randint(base_min - 1, base_max + 1)
68
  return min(100, max(80, oxygen))
69
 
 
70
  def get_temp(label):
71
  if label == "NORMAL":
72
  base_min, base_max = 97.0, 98.6
@@ -76,21 +79,23 @@ def get_temp(label):
76
  # Apply + or - 0.5°F blur and clamp between 95-105°F
77
  temp = random.uniform(base_min - 0.5, base_max + 0.5)
78
  return round(min(105.0, max(95.0, temp)), 1)
79
-
 
80
  def get_days():
81
  return random.randint(1, 14)
82
 
 
83
  def get_age():
84
  return random.randint(18, 80)
85
 
 
86
  # --- Templates ---
87
  def build_emr(label, i):
88
- name = f"Patient-{label}-{i+1}"
89
  age = f"{get_age()}-year-old"
90
  days = get_days()
91
  temp = get_temp(label)
92
  oxygen = get_oxygen(label)
93
-
94
  # Symptoms Pool
95
  symptoms = {
96
  "COVID": [
@@ -138,12 +143,12 @@ def build_emr(label, i):
138
 
139
  # adding noise to 90% of cases
140
  if random.random() < 0.9:
141
- for _ in range(random.randint(1,2)):
142
  body.insert(random.randint(0, len(body)), random.choice(noise_sentences))
143
-
144
  random.shuffle(body)
145
  return " ".join(body)
146
 
 
147
  # Generate dataset
148
  records = []
149
  for label, img_dir in categories.items():
@@ -152,7 +157,7 @@ for label, img_dir in categories.items():
152
  [f for f in img_dir.glob("*") if f.suffix.lower() in valid_exts]
153
  )
154
  for i in range(SAMPLES_PER_CLASS):
155
- patient_id = f"{label}-{i+1}"
156
  image_path = str(random.choice(image_files).relative_to(IMAGES_DIR.parent.parent))
157
  emr_text = build_emr(label, i)
158
  triage_level = triage_map[label]
 
8
  OUTPUT_FILE = CURRENT_DIR.parent / "data" / "emr_records_extended.csv"
9
 
10
  # Sample size
11
+ SAMPLES_PER_CLASS = 300 # 300 * 3 = 900 total
12
 
13
  # Categories and labels
14
  categories = {
 
56
  ]
57
 
58
  # --- Vitals & Symptoms ---
59
+
60
+
61
  def get_oxygen(label):
62
  base_ranges = {
63
  "COVID": (85, 94),
 
69
  oxygen = random.randint(base_min - 1, base_max + 1)
70
  return min(100, max(80, oxygen))
71
 
72
+
73
  def get_temp(label):
74
  if label == "NORMAL":
75
  base_min, base_max = 97.0, 98.6
 
79
  # Apply + or - 0.5°F blur and clamp between 95-105°F
80
  temp = random.uniform(base_min - 0.5, base_max + 0.5)
81
  return round(min(105.0, max(95.0, temp)), 1)
82
+
83
+
84
  def get_days():
85
  return random.randint(1, 14)
86
 
87
+
88
  def get_age():
89
  return random.randint(18, 80)
90
 
91
+
92
  # --- Templates ---
93
  def build_emr(label, i):
94
+ name = f"Patient-{label}-{i + 1}"
95
  age = f"{get_age()}-year-old"
96
  days = get_days()
97
  temp = get_temp(label)
98
  oxygen = get_oxygen(label)
 
99
  # Symptoms Pool
100
  symptoms = {
101
  "COVID": [
 
143
 
144
  # adding noise to 90% of cases
145
  if random.random() < 0.9:
146
+ for _ in range(random.randint(1, 2)):
147
  body.insert(random.randint(0, len(body)), random.choice(noise_sentences))
 
148
  random.shuffle(body)
149
  return " ".join(body)
150
 
151
+
152
  # Generate dataset
153
  records = []
154
  for label, img_dir in categories.items():
 
157
  [f for f in img_dir.glob("*") if f.suffix.lower() in valid_exts]
158
  )
159
  for i in range(SAMPLES_PER_CLASS):
160
+ patient_id = f"{label}-{i + 1}"
161
  image_path = str(random.choice(image_files).relative_to(IMAGES_DIR.parent.parent))
162
  emr_text = build_emr(label, i)
163
  triage_level = triage_map[label]
experiments/csv_file_generator_iterations/generate_emr_csv_v2.py CHANGED
@@ -39,6 +39,7 @@ neutral_noise = [
39
  "Patient expresses concern about possible flu.",
40
  ]
41
 
 
42
  # ---Patient random token genrator ---
43
  def random_token():
44
  prefix = "ID"
@@ -46,11 +47,13 @@ def random_token():
46
  digits = ''.join(random.choices(string.digits, k=2))
47
  return f"{prefix}-{letters}{digits}"
48
 
 
49
  # Vitals (blurred)
50
  def get_oxygen(label):
51
  base = {"COVID": (85, 94), "VIRAL PNEUMONIA": (89, 96), "NORMAL": (96, 99)}
52
  min_, max_ = base[label]
53
- return min(100, max(80, random.randint(min_-1, max_+1)))
 
54
 
55
  def get_temp(label):
56
  if label == "NORMAL":
@@ -59,8 +62,14 @@ def get_temp(label):
59
  min_, max_ = 99.0, 103.0
60
  return round(random.uniform(min_ - 0.6, max_ + 0.6), 1)
61
 
62
- def get_age(): return random.randint(18, 85)
63
- def get_days(): return random.randint(1, 10)
 
 
 
 
 
 
64
 
65
  # EMR generator
66
  def build_emr(label, i):
@@ -102,6 +111,7 @@ def build_emr(label, i):
102
  random.shuffle(body[1:])
103
  return " ".join(body)
104
 
 
105
  # Generate records
106
  records = []
107
  for label, img_dir in categories.items():
@@ -110,7 +120,7 @@ for label, img_dir in categories.items():
110
  image_path = str(random.choice(image_files).relative_to(IMAGES_DIR.parent.parent))
111
  text = build_emr(label, i)
112
  triage = triage_map[label]
113
- records.append([f"{label}-{i+1}", image_path, text, triage])
114
 
115
  # Shuffle + Write
116
  random.shuffle(records)
 
39
  "Patient expresses concern about possible flu.",
40
  ]
41
 
42
+
43
  # ---Patient random token genrator ---
44
  def random_token():
45
  prefix = "ID"
 
47
  digits = ''.join(random.choices(string.digits, k=2))
48
  return f"{prefix}-{letters}{digits}"
49
 
50
+
51
  # Vitals (blurred)
52
  def get_oxygen(label):
53
  base = {"COVID": (85, 94), "VIRAL PNEUMONIA": (89, 96), "NORMAL": (96, 99)}
54
  min_, max_ = base[label]
55
+ return min(100, max(80, random.randint(min_ - 1, max_ + 1)))
56
+
57
 
58
  def get_temp(label):
59
  if label == "NORMAL":
 
62
  min_, max_ = 99.0, 103.0
63
  return round(random.uniform(min_ - 0.6, max_ + 0.6), 1)
64
 
65
+
66
+ def get_age():
67
+ return random.randint(18, 85)
68
+
69
+
70
+ def get_days():
71
+ return random.randint(1, 10)
72
+
73
 
74
  # EMR generator
75
  def build_emr(label, i):
 
111
  random.shuffle(body[1:])
112
  return " ".join(body)
113
 
114
+
115
  # Generate records
116
  records = []
117
  for label, img_dir in categories.items():
 
120
  image_path = str(random.choice(image_files).relative_to(IMAGES_DIR.parent.parent))
121
  text = build_emr(label, i)
122
  triage = triage_map[label]
123
+ records.append([f"{label}-{i + 1}", image_path, text, triage])
124
 
125
  # Shuffle + Write
126
  random.shuffle(records)
experiments/train_optuna.py CHANGED
@@ -1,18 +1,18 @@
1
  import os
2
  import sys
3
- import torch
4
  import optuna
5
  import yaml
6
  import json
7
- import wandb
8
  import argparse
9
- import matplotlib.pyplot as plt
10
- import seaborn as sns
11
 
12
  from tqdm import tqdm
13
- from torch.utils.data import DataLoader, Subset
14
  from torch.nn import CrossEntropyLoss
15
- from torch.optim import Adam
16
  from sklearn.model_selection import StratifiedShuffleSplit
17
  from sklearn.metrics import accuracy_score, f1_score, confusion_matrix
18
 
@@ -30,13 +30,14 @@ from src.multimodal_model import MediLLMModel
30
 
31
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
32
 
 
33
  def stratified_split(dataset, val_ratio=0.2, seed=42, label_column="triage_level"):
34
- label_map = {"low": 0, "medium": 1, "high": 2}
35
  labels = [dataset.df.iloc[i][label_column] for i in range(len(dataset))]
36
  sss = StratifiedShuffleSplit(n_splits=1, test_size=val_ratio, random_state=seed)
37
  train_idx, val_idx = next(sss.split(range(len(dataset)), labels))
38
  return Subset(dataset, train_idx), Subset(dataset, val_idx)
39
 
 
40
  def objective(trial, mode):
41
  wandb.init(
42
  project=f"mediLLM-tune-{mode}",
@@ -69,7 +70,7 @@ def objective(trial, mode):
69
 
70
  for epoch in range(2):
71
  model.train()
72
- loop = tqdm(train_loader, desc=f"[{mode}] Epoch {epoch+1}/2", leave=False)
73
  for batch in loop:
74
  input_ids = batch.get("input_ids", None)
75
  attention_mask = batch.get("attention_mask", None)
@@ -136,18 +137,19 @@ def objective(trial, mode):
136
  plt.ylabel("True")
137
  wandb.log({f"{mode}_confusion_matrix/trial_{trial.number}": wandb.Image(plt)})
138
  plt.close()
139
-
140
  return f1
141
 
 
142
  def get_args():
143
  parser = argparse.ArgumentParser(description="Run Optuna hyperparameter search")
144
  parser.add_argument("--n_trials", type=int, default=10, help="Number of Optuna trials to run")
145
  parser.add_argument("--mode", type=str, choices=["text", "image", "multimodal"], required=True, help="Input mode")
146
  return parser.parse_args()
147
 
148
- if __name__=="__main__":
 
149
  args = get_args()
150
- mode = args.mode
151
 
152
  study = optuna.create_study(
153
  study_name=f"mediLLM_{mode}_optuna",
@@ -160,7 +162,6 @@ if __name__=="__main__":
160
  finally:
161
  wandb.finish()
162
  pbar.update(1)
163
-
164
  study.optimize(wrapped_objective, n_trials=args.n_trials)
165
 
166
  print(f"✅ Best F1 score for {mode}: {study.best_value}")
@@ -196,7 +197,6 @@ if __name__=="__main__":
196
 
197
  # Export to config.yaml
198
  config_path = os.path.join(base_dir, "config", "config.yaml")
199
-
200
  # Make sure config directory exists in the root
201
  os.makedirs(os.path.dirname(config_path), exist_ok=True)
202
 
@@ -218,6 +218,3 @@ if __name__=="__main__":
218
  yaml.dump(config, f, sort_keys=False)
219
 
220
  print(f"✅ Best hyperparameters for [{mode}] saved in config.yaml")
221
-
222
-
223
-
 
1
  import os
2
  import sys
3
+ import torch
4
  import optuna
5
  import yaml
6
  import json
7
+ import wandb
8
  import argparse
9
+ import matplotlib.pyplot as plt
10
+ import seaborn as sns
11
 
12
  from tqdm import tqdm
13
+ from torch.utils.data import DataLoader, Subset
14
  from torch.nn import CrossEntropyLoss
15
+ from torch.optim import Adam
16
  from sklearn.model_selection import StratifiedShuffleSplit
17
  from sklearn.metrics import accuracy_score, f1_score, confusion_matrix
18
 
 
30
 
31
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
32
 
33
+
34
  def stratified_split(dataset, val_ratio=0.2, seed=42, label_column="triage_level"):
 
35
  labels = [dataset.df.iloc[i][label_column] for i in range(len(dataset))]
36
  sss = StratifiedShuffleSplit(n_splits=1, test_size=val_ratio, random_state=seed)
37
  train_idx, val_idx = next(sss.split(range(len(dataset)), labels))
38
  return Subset(dataset, train_idx), Subset(dataset, val_idx)
39
 
40
+
41
  def objective(trial, mode):
42
  wandb.init(
43
  project=f"mediLLM-tune-{mode}",
 
70
 
71
  for epoch in range(2):
72
  model.train()
73
+ loop = tqdm(train_loader, desc=f"[{mode}] Epoch {epoch + 1}/2", leave=False)
74
  for batch in loop:
75
  input_ids = batch.get("input_ids", None)
76
  attention_mask = batch.get("attention_mask", None)
 
137
  plt.ylabel("True")
138
  wandb.log({f"{mode}_confusion_matrix/trial_{trial.number}": wandb.Image(plt)})
139
  plt.close()
 
140
  return f1
141
 
142
+
143
  def get_args():
144
  parser = argparse.ArgumentParser(description="Run Optuna hyperparameter search")
145
  parser.add_argument("--n_trials", type=int, default=10, help="Number of Optuna trials to run")
146
  parser.add_argument("--mode", type=str, choices=["text", "image", "multimodal"], required=True, help="Input mode")
147
  return parser.parse_args()
148
 
149
+
150
+ if __name__ == "__main__":
151
  args = get_args()
152
+ mode = args.mode
153
 
154
  study = optuna.create_study(
155
  study_name=f"mediLLM_{mode}_optuna",
 
162
  finally:
163
  wandb.finish()
164
  pbar.update(1)
 
165
  study.optimize(wrapped_objective, n_trials=args.n_trials)
166
 
167
  print(f"✅ Best F1 score for {mode}: {study.best_value}")
 
197
 
198
  # Export to config.yaml
199
  config_path = os.path.join(base_dir, "config", "config.yaml")
 
200
  # Make sure config directory exists in the root
201
  os.makedirs(os.path.dirname(config_path), exist_ok=True)
202
 
 
218
  yaml.dump(config, f, sort_keys=False)
219
 
220
  print(f"✅ Best hyperparameters for [{mode}] saved in config.yaml")
 
 
 
src/generate_emr_csv.py CHANGED
@@ -137,7 +137,7 @@ def generate_dataset():
137
  )
138
  text = build_emr(label, i)
139
  triage = triage_map[label]
140
- records.append([f"{label}-{i+1}", image_path, text, triage])
141
 
142
  # Shuffle + write
143
  random.shuffle(records)
 
137
  )
138
  text = build_emr(label, i)
139
  triage = triage_map[label]
140
+ records.append([f"{label}-{i + 1}", image_path, text, triage])
141
 
142
  # Shuffle + write
143
  random.shuffle(records)
src/train.py CHANGED
@@ -123,7 +123,7 @@ def train_model(mode="multimodal"):
123
  all_preds, all_labels = [], []
124
 
125
  for batch in tqdm(
126
- train_loader, desc=f"[{mode}] Epoch {epoch+1}"
127
  ): # Load a batch of text, images, and labels to GPU or CPU
128
  input_ids = batch.get("input_ids", None)
129
  attention_mask = batch.get("attention_mask", None)
 
123
  all_preds, all_labels = [], []
124
 
125
  for batch in tqdm(
126
+ train_loader, desc=f"[{mode}] Epoch {epoch + 1}"
127
  ): # Load a batch of text, images, and labels to GPU or CPU
128
  input_ids = batch.get("input_ids", None)
129
  attention_mask = batch.get("attention_mask", None)