Commit: b136189
Parent(s): 9218201
formatting changes
Files changed:
- experiments/csv_file_generator_iterations/generate_emr_csv_final.py +8 -1
- experiments/csv_file_generator_iterations/generate_emr_csv_v1.py +12 -7
- experiments/csv_file_generator_iterations/generate_emr_csv_v2.py +14 -4
- experiments/train_optuna.py +13 -16
- src/generate_emr_csv.py +1 -1
- src/train.py +1 -1
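The changes below are almost entirely whitespace: two blank lines before each top-level def, spaces around the arithmetic operator inside f-string expressions (e.g. {i + 1} instead of {i+1}), and stray blank lines dropped, plus a couple of tiny touch-ups (a clarifying comment on SAMPLES_PER_CLASS, an unused label_map line removed). A minimal before/after sketch of the convention, using a made-up two-function module purely for illustration (pycodestyle reports the "before" layout as E302 and E226):

# before: one blank line between defs, no spaces around "+"
# def get_days():
#     return random.randint(1, 10)
# def build_label(i):
#     return f"record-{i+1}"

# after: the layout this commit applies throughout
import random


def get_days():
    return random.randint(1, 10)


def build_label(i):
    return f"record-{i + 1}"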
experiments/csv_file_generator_iterations/generate_emr_csv_final.py
CHANGED
@@ -49,12 +49,14 @@ neutral_noise = [
     "No medications currently administered.",
 ]
 
+
 def random_token():
     prefix = "ID"
     letters = ''.join(random.choices(string.ascii_uppercase, k=2))
     digits = ''.join(random.choices(string.digits, k=2))
     return f"{prefix}-{letters}{digits}"
 
+
 def get_oxygen(label):
     # Soft blur across classes
     if label == "NORMAL":
@@ -64,18 +66,22 @@ def get_oxygen(label):
     else:
         return random.randint(87, 94)
 
+
 def get_temp(label):
     if label == "NORMAL":
         return round(random.uniform(97.5, 99.0), 1)
     else:
         return round(random.uniform(98.8, 102.5), 1)
 
+
 def get_age():
     return random.randint(18, 85)
 
+
 def get_days():
     return random.randint(1, 10)
 
+
 def build_emr(label, i):
     pid = random_token()
     age = f"{get_age()}-year-old"
@@ -112,6 +118,7 @@ def build_emr(label, i):
     random.shuffle(body[1:])  # Keep intro in position 0
     return " ".join(body)
 
+
 # Generate records
 records = []
 for label, img_dir in categories.items():
@@ -120,7 +127,7 @@ for label, img_dir in categories.items():
         image_path = str(random.choice(image_files).relative_to(IMAGES_DIR.parent.parent))
         text = build_emr(label, i)
         triage = triage_map[label]
-        records.append([f"{label}-{i+1}", image_path, text, triage])
+        records.append([f"{label}-{i + 1}", image_path, text, triage])
 
 # Shuffle + write
 random.shuffle(records)
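For context, the helpers that pick up the extra blank lines above are small, self-contained generators. A runnable sketch of the two that appear in full in this hunk, random_token and get_temp, with a demo loop added here only for illustration:

import random
import string


def random_token():
    # "ID" plus two random uppercase letters and two digits, e.g. "ID-QX37"
    prefix = "ID"
    letters = ''.join(random.choices(string.ascii_uppercase, k=2))
    digits = ''.join(random.choices(string.digits, k=2))
    return f"{prefix}-{letters}{digits}"


def get_temp(label):
    # deliberately overlapping temperature ranges so labels are not trivially separable
    if label == "NORMAL":
        return round(random.uniform(97.5, 99.0), 1)
    else:
        return round(random.uniform(98.8, 102.5), 1)


for label in ("NORMAL", "COVID"):
    print(label, random_token(), get_temp(label))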
experiments/csv_file_generator_iterations/generate_emr_csv_v1.py
CHANGED
@@ -8,7 +8,7 @@ IMAGES_DIR = CURRENT_DIR.parent / "data" / "images"
 OUTPUT_FILE = CURRENT_DIR.parent / "data" / "emr_records_extended.csv"
 
 # Sample size
-SAMPLES_PER_CLASS = 300
+SAMPLES_PER_CLASS = 300  # 300 * 3 = 900 total
 
 # Categories and labels
 categories = {
@@ -56,6 +56,8 @@ ambiguous_templates = [
 ]
 
 # --- Vitals & Symptoms ---
+
+
 def get_oxygen(label):
     base_ranges = {
         "COVID": (85, 94),
@@ -67,6 +69,7 @@ def get_oxygen(label):
     oxygen = random.randint(base_min - 1, base_max + 1)
     return min(100, max(80, oxygen))
 
+
 def get_temp(label):
     if label == "NORMAL":
         base_min, base_max = 97.0, 98.6
@@ -76,21 +79,23 @@ def get_temp(label):
     # Apply + or - 0.5°F blur and clamp between 95-105°F
     temp = random.uniform(base_min - 0.5, base_max + 0.5)
     return round(min(105.0, max(95.0, temp)), 1)
-
+
+
 def get_days():
     return random.randint(1, 14)
 
+
 def get_age():
     return random.randint(18, 80)
 
+
 # --- Templates ---
 def build_emr(label, i):
-    name = f"Patient-{label}-{i+1}"
+    name = f"Patient-{label}-{i + 1}"
     age = f"{get_age()}-year-old"
     days = get_days()
     temp = get_temp(label)
     oxygen = get_oxygen(label)
-
     # Symptoms Pool
     symptoms = {
         "COVID": [
@@ -138,12 +143,12 @@ def build_emr(label, i):
 
     # adding noise to 90% of cases
     if random.random() < 0.9:
-        for _ in range(random.randint(1,2)):
+        for _ in range(random.randint(1, 2)):
             body.insert(random.randint(0, len(body)), random.choice(noise_sentences))
-
     random.shuffle(body)
     return " ".join(body)
 
+
 # Generate dataset
 records = []
 for label, img_dir in categories.items():
@@ -152,7 +157,7 @@ for label, img_dir in categories.items():
         [f for f in img_dir.glob("*") if f.suffix.lower() in valid_exts]
     )
     for i in range(SAMPLES_PER_CLASS):
-        patient_id = f"{label}-{i+1}"
+        patient_id = f"{label}-{i + 1}"
        image_path = str(random.choice(image_files).relative_to(IMAGES_DIR.parent.parent))
        emr_text = build_emr(label, i)
        triage_level = triage_map[label]
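The noise-injection block reformatted above drops one or two distractor sentences into 90% of generated notes at random positions, then shuffles the sentence order. A standalone sketch of the same logic; the noise_sentences values here are placeholders, since the real list is defined earlier in the file and is not shown in this diff:

import random

# placeholder distractors; generate_emr_csv_v1.py defines its own noise_sentences list
noise_sentences = [
    "Patient reports no known allergies.",
    "Family history is unremarkable.",
]


def add_noise(body):
    # 90% of records receive 1-2 noise sentences at random positions,
    # then the whole note is shuffled
    if random.random() < 0.9:
        for _ in range(random.randint(1, 2)):
            body.insert(random.randint(0, len(body)), random.choice(noise_sentences))
    random.shuffle(body)
    return " ".join(body)


print(add_noise(["54-year-old patient.", "Oxygen saturation 91%.", "Fever for 4 days."]))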
experiments/csv_file_generator_iterations/generate_emr_csv_v2.py
CHANGED
@@ -39,6 +39,7 @@ neutral_noise = [
     "Patient expresses concern about possible flu.",
 ]
 
+
 # ---Patient random token genrator ---
 def random_token():
     prefix = "ID"
@@ -46,11 +47,13 @@ def random_token():
     digits = ''.join(random.choices(string.digits, k=2))
     return f"{prefix}-{letters}{digits}"
 
+
 # Vitals (blurred)
 def get_oxygen(label):
     base = {"COVID": (85, 94), "VIRAL PNEUMONIA": (89, 96), "NORMAL": (96, 99)}
     min_, max_ = base[label]
-    return min(100, max(80, random.randint(min_-1, max_+1)))
+    return min(100, max(80, random.randint(min_ - 1, max_ + 1)))
+
 
 def get_temp(label):
     if label == "NORMAL":
@@ -59,8 +62,14 @@ def get_temp(label):
     min_, max_ = 99.0, 103.0
     return round(random.uniform(min_ - 0.6, max_ + 0.6), 1)
 
-
-def ...
+
+def get_age():
+    return random.randint(18, 85)
+
+
+def get_days():
+    return random.randint(1, 10)
+
 
 # EMR generator
 def build_emr(label, i):
@@ -102,6 +111,7 @@ def build_emr(label, i):
     random.shuffle(body[1:])
     return " ".join(body)
 
+
 # Generate records
 records = []
 for label, img_dir in categories.items():
@@ -110,7 +120,7 @@ for label, img_dir in categories.items():
         image_path = str(random.choice(image_files).relative_to(IMAGES_DIR.parent.parent))
         text = build_emr(label, i)
         triage = triage_map[label]
-        records.append([f"{label}-{i+1}", image_path, text, triage])
+        records.append([f"{label}-{i + 1}", image_path, text, triage])
 
 # Shuffle + Write
 random.shuffle(records)
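The get_oxygen helper visible in this diff samples from a per-class range, widens it by ±1 as blur, and clamps the draw to 80-100, so the three classes overlap at their edges. A quick sketch that reuses the same definition and adds a sampling loop, for illustration only, to show the resulting spread:

import random


def get_oxygen(label):
    # per-class SpO2 ranges, blurred by +/-1 and clamped to 80-100
    base = {"COVID": (85, 94), "VIRAL PNEUMONIA": (89, 96), "NORMAL": (96, 99)}
    min_, max_ = base[label]
    return min(100, max(80, random.randint(min_ - 1, max_ + 1)))


for label in ("COVID", "VIRAL PNEUMONIA", "NORMAL"):
    samples = [get_oxygen(label) for _ in range(1000)]
    print(label, min(samples), max(samples))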
experiments/train_optuna.py
CHANGED
@@ -1,18 +1,18 @@
 import os
 import sys
-import torch
+import torch
 import optuna
 import yaml
 import json
-import wandb
+import wandb
 import argparse
-import matplotlib.pyplot as plt
-import seaborn as sns
+import matplotlib.pyplot as plt
+import seaborn as sns
 
 from tqdm import tqdm
-from torch.utils.data import DataLoader, Subset
+from torch.utils.data import DataLoader, Subset
 from torch.nn import CrossEntropyLoss
-from torch.optim import Adam
+from torch.optim import Adam
 from sklearn.model_selection import StratifiedShuffleSplit
 from sklearn.metrics import accuracy_score, f1_score, confusion_matrix
 
@@ -30,13 +30,14 @@ from src.multimodal_model import MediLLMModel
 
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
+
 def stratified_split(dataset, val_ratio=0.2, seed=42, label_column="triage_level"):
-    label_map = {"low": 0, "medium": 1, "high": 2}
     labels = [dataset.df.iloc[i][label_column] for i in range(len(dataset))]
     sss = StratifiedShuffleSplit(n_splits=1, test_size=val_ratio, random_state=seed)
     train_idx, val_idx = next(sss.split(range(len(dataset)), labels))
     return Subset(dataset, train_idx), Subset(dataset, val_idx)
 
+
 def objective(trial, mode):
     wandb.init(
         project=f"mediLLM-tune-{mode}",
@@ -69,7 +70,7 @@
 
     for epoch in range(2):
         model.train()
-        loop = tqdm(train_loader, desc=f"[{mode}] Epoch {epoch+1}/2", leave=False)
+        loop = tqdm(train_loader, desc=f"[{mode}] Epoch {epoch + 1}/2", leave=False)
         for batch in loop:
             input_ids = batch.get("input_ids", None)
             attention_mask = batch.get("attention_mask", None)
@@ -136,18 +137,19 @@
     plt.ylabel("True")
     wandb.log({f"{mode}_confusion_matrix/trial_{trial.number}": wandb.Image(plt)})
     plt.close()
-
     return f1
 
+
 def get_args():
     parser = argparse.ArgumentParser(description="Run Optuna hyperparameter search")
     parser.add_argument("--n_trials", type=int, default=10, help="Number of Optuna trials to run")
     parser.add_argument("--mode", type=str, choices=["text", "image", "multimodal"], required=True, help="Input mode")
     return parser.parse_args()
 
-if __name__=="__main__":
+
+if __name__ == "__main__":
     args = get_args()
-    mode = args.mode
+    mode = args.mode
 
     study = optuna.create_study(
         study_name=f"mediLLM_{mode}_optuna",
@@ -160,7 +162,6 @@ if __name__=="__main__":
         finally:
             wandb.finish()
             pbar.update(1)
-
     study.optimize(wrapped_objective, n_trials=args.n_trials)
 
     print(f"✅ Best F1 score for {mode}: {study.best_value}")
@@ -196,7 +197,6 @@ if __name__=="__main__":
 
     # Export to config.yaml
     config_path = os.path.join(base_dir, "config", "config.yaml")
-
     # Make sure config directory exists in the root
     os.makedirs(os.path.dirname(config_path), exist_ok=True)
 
@@ -218,6 +218,3 @@ if __name__=="__main__":
         yaml.dump(config, f, sort_keys=False)
 
     print(f"✅ Best hyperparameters for [{mode}] saved in config.yaml")
-
-
-
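The stratified_split helper in this diff (which also loses an unused label_map line) relies on StratifiedShuffleSplit so the train and validation subsets keep the same triage_level class proportions. A self-contained sketch of the same idea; TinyDataset is a hypothetical stand-in for the repo's real dataset, which likewise exposes its labels through a .df DataFrame:

import pandas as pd
from torch.utils.data import Dataset, Subset
from sklearn.model_selection import StratifiedShuffleSplit


class TinyDataset(Dataset):
    # hypothetical stand-in: labels live in a .df frame, as in the real dataset class
    def __init__(self):
        self.df = pd.DataFrame({"triage_level": ["low", "medium", "high"] * 10})

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        return self.df.iloc[idx]["triage_level"]


def stratified_split(dataset, val_ratio=0.2, seed=42, label_column="triage_level"):
    labels = [dataset.df.iloc[i][label_column] for i in range(len(dataset))]
    sss = StratifiedShuffleSplit(n_splits=1, test_size=val_ratio, random_state=seed)
    train_idx, val_idx = next(sss.split(range(len(dataset)), labels))
    return Subset(dataset, train_idx), Subset(dataset, val_idx)


train_set, val_set = stratified_split(TinyDataset())
print(len(train_set), len(val_set))  # 24 / 6, each with the same low/medium/high mix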
src/generate_emr_csv.py
CHANGED
@@ -137,7 +137,7 @@ def generate_dataset():
             )
             text = build_emr(label, i)
             triage = triage_map[label]
-            records.append([f"{label}-{i+1}", image_path, text, triage])
+            records.append([f"{label}-{i + 1}", image_path, text, triage])
 
     # Shuffle + write
     random.shuffle(records)
src/train.py
CHANGED
@@ -123,7 +123,7 @@ def train_model(mode="multimodal"):
         all_preds, all_labels = [], []
 
         for batch in tqdm(
-            train_loader, desc=f"[{mode}] Epoch {epoch+1}"
+            train_loader, desc=f"[{mode}] Epoch {epoch + 1}"
         ):  # Load a batch of text, images, and labels to GPU or CPU
             input_ids = batch.get("input_ids", None)
             attention_mask = batch.get("attention_mask", None)