Spaces:
Sleeping
Sleeping
Commit
·
c8b8d35
1
Parent(s):
b6f867a
Made changes to optimize hyperparams for ablation study
Browse files- experiments/train_optuna.py +107 -80
- src/train.py +56 -37
experiments/train_optuna.py
CHANGED
|
@@ -1,26 +1,29 @@
|
|
| 1 |
import os
|
| 2 |
import sys
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
|
|
|
|
|
|
|
| 4 |
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 5 |
# Automatically add Project root to python import path
|
| 6 |
base_dir = os.path.dirname(os.path.dirname(__file__))
|
| 7 |
if base_dir not in sys.path:
|
| 8 |
sys.path.append(base_dir)
|
| 9 |
|
| 10 |
-
import torch
|
| 11 |
-
import optuna
|
| 12 |
-
from torch.utils.data import DataLoader, Subset
|
| 13 |
-
from torch.nn import CrossEntropyLoss
|
| 14 |
-
from torch.optim import Adam
|
| 15 |
-
from sklearn.model_selection import StratifiedShuffleSplit
|
| 16 |
-
from sklearn.metrics import accuracy_score, f1_score, classification_report, confusion_matrix
|
| 17 |
-
from tqdm import tqdm
|
| 18 |
-
import matplotlib.pyplot as plt
|
| 19 |
-
import seaborn as sns
|
| 20 |
-
import wandb
|
| 21 |
-
import json
|
| 22 |
-
import yaml
|
| 23 |
-
import argparse
|
| 24 |
|
| 25 |
from src.triage_dataset import TriageDataset
|
| 26 |
from src.multimodal_model import MediLLMModel
|
|
@@ -36,12 +39,13 @@ def stratified_split(dataset, val_ratio=0.2, seed=42, label_column="triage_level
|
|
| 36 |
|
| 37 |
def objective(trial):
|
| 38 |
wandb.init(
|
| 39 |
-
project="mediLLM-
|
| 40 |
-
name=f"trial-{trial.number}-
|
| 41 |
group="SoftLabelTrials",
|
| 42 |
config={
|
| 43 |
"dataset_version": "softlabels",
|
| 44 |
-
"dataset_size": 900
|
|
|
|
| 45 |
}
|
| 46 |
)
|
| 47 |
|
|
@@ -51,7 +55,7 @@ def objective(trial):
|
|
| 51 |
hidden_dim = trial.suggest_categorical("hidden_dim", [128, 256, 512])
|
| 52 |
batch_size = trial.suggest_categorical("bs", [4, 8, 16])
|
| 53 |
|
| 54 |
-
model = MediLLMModel(dropout=dropout, hidden_dim=hidden_dim).to(device)
|
| 55 |
wandb.watch(model)
|
| 56 |
|
| 57 |
dataset = TriageDataset(os.path.join(base_dir, "data", "emr_records.csv"))
|
|
@@ -65,132 +69,155 @@ def objective(trial):
|
|
| 65 |
|
| 66 |
for epoch in range(2):
|
| 67 |
model.train()
|
| 68 |
-
loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/2", leave=False)
|
| 69 |
for batch in loop:
|
| 70 |
-
input_ids = batch
|
| 71 |
-
attention_mask = batch
|
| 72 |
-
images = batch
|
| 73 |
labels = batch["label"].to(device)
|
| 74 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
optimizer.zero_grad()
|
| 76 |
outputs = model(input_ids=input_ids, attention_mask=attention_mask, image=images)
|
| 77 |
loss = criterion(outputs, labels)
|
| 78 |
loss.backward()
|
| 79 |
optimizer.step()
|
| 80 |
-
|
| 81 |
loop.set_postfix(loss=loss.item())
|
| 82 |
|
| 83 |
# Validation
|
| 84 |
model.eval()
|
| 85 |
all_preds, all_labels = [], []
|
| 86 |
with torch.no_grad():
|
| 87 |
-
for batch in tqdm(val_loader, desc="Validating", leave=False):
|
| 88 |
-
input_ids = batch
|
| 89 |
-
attention_mask = batch
|
| 90 |
-
images = batch
|
| 91 |
labels = batch["label"].to(device)
|
| 92 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 93 |
outputs = model(input_ids=input_ids, attention_mask=attention_mask, image=images)
|
| 94 |
preds = torch.argmax(outputs, dim=1).cpu().numpy()
|
| 95 |
all_preds.extend(preds)
|
| 96 |
all_labels.extend(labels.cpu().numpy())
|
| 97 |
|
| 98 |
f1 = f1_score(all_labels, all_preds, average="weighted")
|
| 99 |
-
|
| 100 |
-
print(classification_report(all_labels, all_preds, target_names=["low", "medium", "high"]))
|
| 101 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
cm = confusion_matrix(all_labels, all_preds)
|
| 103 |
plt.figure(figsize=(6, 5))
|
| 104 |
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
|
| 105 |
xticklabels=["low", "medium", "high"],
|
| 106 |
yticklabels=["low", "medium", "high"])
|
| 107 |
-
plt.title(f"Confusion Matrix - Trial {trial.number}")
|
| 108 |
plt.xlabel("Predicted")
|
| 109 |
plt.ylabel("True")
|
| 110 |
-
wandb.log({f"
|
| 111 |
plt.close()
|
| 112 |
|
| 113 |
-
# Log to W&B and Optuna
|
| 114 |
-
wandb.log({
|
| 115 |
-
"f1_score": f1,
|
| 116 |
-
"accuracy": accuracy_score(all_labels, all_preds),
|
| 117 |
-
"lr": lr,
|
| 118 |
-
"dropout": dropout,
|
| 119 |
-
"hidden_dim": hidden_dim,
|
| 120 |
-
"batch_size": batch_size
|
| 121 |
-
})
|
| 122 |
return f1
|
| 123 |
|
| 124 |
def get_args():
|
| 125 |
parser = argparse.ArgumentParser(description="Run Optuna hyperparameter search")
|
| 126 |
parser.add_argument("--n_trials", type=int, default=10, help="Number of Optuna trials to run")
|
|
|
|
| 127 |
return parser.parse_args()
|
| 128 |
|
| 129 |
if __name__=="__main__":
|
| 130 |
args = get_args()
|
|
|
|
| 131 |
|
| 132 |
study = optuna.create_study(
|
| 133 |
-
study_name="
|
| 134 |
direction="maximize"
|
| 135 |
)
|
| 136 |
-
with tqdm(total=args.n_trials, desc="Optuna Trials") as pbar:
|
| 137 |
def wrapped_objective(trial):
|
| 138 |
try:
|
| 139 |
-
|
| 140 |
-
return result
|
| 141 |
finally:
|
| 142 |
wandb.finish()
|
| 143 |
pbar.update(1)
|
| 144 |
|
| 145 |
study.optimize(wrapped_objective, n_trials=args.n_trials)
|
| 146 |
|
| 147 |
-
print("Best F1 score
|
| 148 |
-
print("Best hyperparameters:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 149 |
|
| 150 |
-
|
| 151 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 152 |
|
| 153 |
-
#
|
| 154 |
-
os.
|
|
|
|
|
|
|
| 155 |
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 159 |
|
| 160 |
# Export to config.yaml
|
| 161 |
-
|
| 162 |
-
config_path = os.path.join(config_dir, "config.yaml")
|
| 163 |
|
| 164 |
# Make sure config directory exists in the root
|
| 165 |
-
os.makedirs(
|
| 166 |
-
|
| 167 |
-
# If the config file doesn't exist, create a default one
|
| 168 |
-
if not os.path.exists(config_path):
|
| 169 |
-
with open(config_path, "w") as f:
|
| 170 |
-
f.write(
|
| 171 |
-
"model:\n"
|
| 172 |
-
" dropout: 0.3\n"
|
| 173 |
-
" hidden_dim: 256\n\n"
|
| 174 |
-
"train:\n"
|
| 175 |
-
" lr: 2e-5\n"
|
| 176 |
-
" batch_size: 8\n"
|
| 177 |
-
" epochs: 5\n\n"
|
| 178 |
-
"wandb:\n"
|
| 179 |
-
" project: medi-llm-final\n"
|
| 180 |
-
)
|
| 181 |
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
|
|
|
| 185 |
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
|
|
|
|
|
|
|
|
|
| 190 |
|
| 191 |
-
#
|
| 192 |
with open(config_path, "w") as f:
|
| 193 |
-
yaml.dump(
|
|
|
|
|
|
|
| 194 |
|
| 195 |
|
| 196 |
|
|
|
|
| 1 |
import os
|
| 2 |
import sys
|
| 3 |
+
import torch
|
| 4 |
+
import optuna
|
| 5 |
+
import yaml
|
| 6 |
+
import json
|
| 7 |
+
import wandb
|
| 8 |
+
import argparse
|
| 9 |
+
import matplotlib.pyplot as plt
|
| 10 |
+
import seaborn as sns
|
| 11 |
+
|
| 12 |
+
from tqdm import tqdm
|
| 13 |
+
from torch.utils.data import DataLoader, Subset
|
| 14 |
+
from torch.nn import CrossEntropyLoss
|
| 15 |
+
from torch.optim import Adam
|
| 16 |
+
from sklearn.model_selection import StratifiedShuffleSplit
|
| 17 |
+
from sklearn.metrics import accuracy_score, f1_score, classification_report, confusion_matrix
|
| 18 |
|
| 19 |
+
|
| 20 |
+
# Setup base path
|
| 21 |
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 22 |
# Automatically add Project root to python import path
|
| 23 |
base_dir = os.path.dirname(os.path.dirname(__file__))
|
| 24 |
if base_dir not in sys.path:
|
| 25 |
sys.path.append(base_dir)
|
| 26 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
|
| 28 |
from src.triage_dataset import TriageDataset
|
| 29 |
from src.multimodal_model import MediLLMModel
|
|
|
|
| 39 |
|
| 40 |
def objective(trial):
|
| 41 |
wandb.init(
|
| 42 |
+
project=f"mediLLM-tune-{mode}",
|
| 43 |
+
name=f"{mode}-trial-{trial.number}-v5-{wandb.util.generate_id()}",
|
| 44 |
group="SoftLabelTrials",
|
| 45 |
config={
|
| 46 |
"dataset_version": "softlabels",
|
| 47 |
+
"dataset_size": 900,
|
| 48 |
+
"mode": mode
|
| 49 |
}
|
| 50 |
)
|
| 51 |
|
|
|
|
| 55 |
hidden_dim = trial.suggest_categorical("hidden_dim", [128, 256, 512])
|
| 56 |
batch_size = trial.suggest_categorical("bs", [4, 8, 16])
|
| 57 |
|
| 58 |
+
model = MediLLMModel(dropout=dropout, hidden_dim=hidden_dim, mode=mode).to(device)
|
| 59 |
wandb.watch(model)
|
| 60 |
|
| 61 |
dataset = TriageDataset(os.path.join(base_dir, "data", "emr_records.csv"))
|
|
|
|
| 69 |
|
| 70 |
for epoch in range(2):
|
| 71 |
model.train()
|
| 72 |
+
loop = tqdm(train_loader, desc=f"[{mode}] Epoch {epoch+1}/2", leave=False)
|
| 73 |
for batch in loop:
|
| 74 |
+
input_ids = batch.get("input_ids", None)
|
| 75 |
+
attention_mask = batch.get("attention_mask", None)
|
| 76 |
+
images = batch.get("image", None)
|
| 77 |
labels = batch["label"].to(device)
|
| 78 |
|
| 79 |
+
if input_ids is not None:
|
| 80 |
+
input_ids = input_ids.to(device)
|
| 81 |
+
if attention_mask is not None:
|
| 82 |
+
attention_mask = attention_mask.to(device)
|
| 83 |
+
if images is not None:
|
| 84 |
+
images = images.to(device)
|
| 85 |
+
|
| 86 |
optimizer.zero_grad()
|
| 87 |
outputs = model(input_ids=input_ids, attention_mask=attention_mask, image=images)
|
| 88 |
loss = criterion(outputs, labels)
|
| 89 |
loss.backward()
|
| 90 |
optimizer.step()
|
|
|
|
| 91 |
loop.set_postfix(loss=loss.item())
|
| 92 |
|
| 93 |
# Validation
|
| 94 |
model.eval()
|
| 95 |
all_preds, all_labels = [], []
|
| 96 |
with torch.no_grad():
|
| 97 |
+
for batch in tqdm(val_loader, desc=f"[{mode}] Validating", leave=False):
|
| 98 |
+
input_ids = batch.get("input_ids", None)
|
| 99 |
+
attention_mask = batch.get("attention_mask", None)
|
| 100 |
+
images = batch.get("image", None)
|
| 101 |
labels = batch["label"].to(device)
|
| 102 |
|
| 103 |
+
if input_ids is not None:
|
| 104 |
+
input_ids = input_ids.to(device)
|
| 105 |
+
if attention_mask is not None:
|
| 106 |
+
attention_mask = attention_mask.to(device)
|
| 107 |
+
if images is not None:
|
| 108 |
+
images = images.to(device)
|
| 109 |
+
|
| 110 |
outputs = model(input_ids=input_ids, attention_mask=attention_mask, image=images)
|
| 111 |
preds = torch.argmax(outputs, dim=1).cpu().numpy()
|
| 112 |
all_preds.extend(preds)
|
| 113 |
all_labels.extend(labels.cpu().numpy())
|
| 114 |
|
| 115 |
f1 = f1_score(all_labels, all_preds, average="weighted")
|
| 116 |
+
acc = accuracy_score(all_labels, all_preds)
|
|
|
|
| 117 |
|
| 118 |
+
# Log to W&B and Optuna
|
| 119 |
+
wandb.log({
|
| 120 |
+
"val_f1_score": f1,
|
| 121 |
+
"val_accuracy": acc,
|
| 122 |
+
"lr": lr,
|
| 123 |
+
"dropout": dropout,
|
| 124 |
+
"hidden_dim": hidden_dim,
|
| 125 |
+
"batch_size": batch_size
|
| 126 |
+
})
|
| 127 |
+
|
| 128 |
+
# Confusion Matrix
|
| 129 |
cm = confusion_matrix(all_labels, all_preds)
|
| 130 |
plt.figure(figsize=(6, 5))
|
| 131 |
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
|
| 132 |
xticklabels=["low", "medium", "high"],
|
| 133 |
yticklabels=["low", "medium", "high"])
|
| 134 |
+
plt.title(f"Confusion Matrix - {mode} Trial {trial.number}")
|
| 135 |
plt.xlabel("Predicted")
|
| 136 |
plt.ylabel("True")
|
| 137 |
+
wandb.log({f"{mode}_confusion_matrix/trial_{trial.number}": wandb.Image(plt)})
|
| 138 |
plt.close()
|
| 139 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 140 |
return f1
|
| 141 |
|
| 142 |
def get_args(argv=None):
    """Parse the command-line options for the Optuna hyperparameter search.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, which is
            exactly what the previous zero-argument version did. Passing an
            explicit list lets the parser be driven programmatically.

    Returns:
        argparse.Namespace with:
            n_trials (int): number of Optuna trials to run (default 10).
            mode (str): one of "text", "image" or "multimodal" (required).

    Raises:
        SystemExit: raised by argparse when ``--mode`` is missing or an
            argument value is invalid.
    """
    parser = argparse.ArgumentParser(description="Run Optuna hyperparameter search")
    parser.add_argument("--n_trials", type=int, default=10, help="Number of Optuna trials to run")
    parser.add_argument(
        "--mode",
        type=str,
        choices=["text", "image", "multimodal"],
        required=True,
        help="Input mode",
    )
    return parser.parse_args(argv)
|
| 147 |
|
| 148 |
if __name__ == "__main__":
    # Script entry point: run an Optuna study for the selected input mode,
    # then persist the best hyperparameters to assets/best_hyperparams.json
    # and to the per-mode section of config/config.yaml.
    args = get_args()
    # objective() is declared as objective(trial) and reads `mode` as a
    # module-level global, so it must be bound here before the study starts.
    mode = args.mode

    study = optuna.create_study(
        study_name=f"mediLLM_{mode}_optuna",
        direction="maximize"
    )
    with tqdm(total=args.n_trials, desc=f"Optuna Trials [{mode}]") as pbar:
        def wrapped_objective(trial):
            try:
                # Call with the trial only: objective(trial, mode) does not
                # match the one-parameter definition and raised TypeError.
                return objective(trial)
            finally:
                # Always close the W&B run and advance the bar, even when
                # the trial fails.
                wandb.finish()
                pbar.update(1)

        study.optimize(wrapped_objective, n_trials=args.n_trials)

    print(f"✅ Best F1 score for {mode}: {study.best_value}")
    print(f"✅ Best hyperparameters: {study.best_params}")

    # Normalise the winning trial once; both output files store the same
    # entry. Optuna's "bs" parameter is exported under the name "batch_size".
    best = study.best_params
    best_params_entry = {
        "lr": float(best["lr"]),
        "dropout": float(best["dropout"]),
        "hidden_dim": int(best["hidden_dim"]),
        "batch_size": int(best["bs"]),
        "epochs": 5,  # fixed value, not part of the search space above
    }

    # Save best hyperparameters to JSON per mode
    json_path = os.path.join(base_dir, "assets", "best_hyperparams.json")
    os.makedirs(os.path.dirname(json_path), exist_ok=True)

    # Load existing entries so results for other modes are preserved
    if os.path.exists(json_path):
        with open(json_path, "r", encoding="utf-8") as f:
            best_params_all = json.load(f)
    else:
        best_params_all = {}

    best_params_all[mode] = best_params_entry

    # Write back
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(best_params_all, f, indent=4)

    print(f"✅ Saved best hyperparameters for [{mode}] to best_hyperparams.json")

    # Export to config.yaml
    config_path = os.path.join(base_dir, "config", "config.yaml")

    # Make sure config directory exists in the root
    os.makedirs(os.path.dirname(config_path), exist_ok=True)

    # Merge into the existing config (an empty file loads as None, hence `or {}`)
    config = {}
    if os.path.exists(config_path):
        with open(config_path, "r", encoding="utf-8") as f:
            config = yaml.safe_load(f) or {}

    config[mode] = dict(best_params_entry)

    with open(config_path, "w", encoding="utf-8") as f:
        yaml.dump(config, f, sort_keys=False)

    print(f"✅ Best hyperparameters for [{mode}] saved in config.yaml")
|
| 221 |
|
| 222 |
|
| 223 |
|
src/train.py
CHANGED
|
@@ -1,50 +1,69 @@
|
|
| 1 |
import torch # PyTorch core utility for model training
|
| 2 |
import os
|
| 3 |
import sys
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
|
|
|
|
| 5 |
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 6 |
# Automatically add Project root to python import path
|
| 7 |
base_dir = os.path.dirname(os.path.dirname(__file__))
|
| 8 |
if base_dir not in sys.path:
|
| 9 |
sys.path.append(base_dir)
|
| 10 |
|
| 11 |
-
|
| 12 |
-
import yaml
|
| 13 |
-
from torch.utils.data import DataLoader, Subset # Dataloader to batch and feed data to model, random split to split dataset into train and validation sets
|
| 14 |
-
from torch.nn import CrossEntropyLoss # PyTorch core utility for model training
|
| 15 |
-
from torch.optim import Adam # PyTorch core utility for model training, Adam is the Optimizer a gradient descent model
|
| 16 |
-
from sklearn.metrics import accuracy_score, f1_score # Evaluation metrics
|
| 17 |
-
from sklearn.model_selection import StratifiedShuffleSplit
|
| 18 |
-
from tqdm import tqdm # loading bar for loops
|
| 19 |
-
import matplotlib.pyplot as plt # for plotting
|
| 20 |
from src.triage_dataset import TriageDataset # Dataset Class
|
| 21 |
from src.multimodal_model import MediLLMModel # Multimodal Model
|
| 22 |
|
| 23 |
-
def load_config():
|
| 24 |
-
|
| 25 |
-
|
| 26 |
|
| 27 |
-
#
|
| 28 |
-
os.makedirs(config_dir, exist_ok=True)
|
| 29 |
-
|
| 30 |
-
# If the config file doesn't exist, create a default one
|
| 31 |
if not os.path.exists(config_path):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
with open(config_path, "w") as f:
|
| 33 |
-
|
| 34 |
-
"model:\n"
|
| 35 |
-
" dropout: 0.3\n"
|
| 36 |
-
" hidden_dim: 256\n\n"
|
| 37 |
-
"train:\n"
|
| 38 |
-
" lr: 2e-5\n"
|
| 39 |
-
" batch_size: 8\n"
|
| 40 |
-
" epochs: 5\n\n"
|
| 41 |
-
"wandb:\n"
|
| 42 |
-
" project: medi-llm-final\n"
|
| 43 |
-
)
|
| 44 |
|
| 45 |
# otherwise export to yaml
|
| 46 |
with open(config_path, "r") as f:
|
| 47 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
|
| 49 |
def stratified_split(dataset, val_ratio=0.2, seed=42):
|
| 50 |
labels = [dataset.df.iloc[i]["triage_level"] for i in range(len(dataset))]
|
|
@@ -53,33 +72,33 @@ def stratified_split(dataset, val_ratio=0.2, seed=42):
|
|
| 53 |
return Subset(dataset, tran_idx), Subset(dataset, val_idx)
|
| 54 |
|
| 55 |
def train_model(mode="multimodal"): # Function to instantiate model and data, train, validate, plot results and save the model
|
| 56 |
-
|
| 57 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Use GPU if available or else use CPU
|
| 58 |
|
| 59 |
-
dataset_dir = os.path.join(base_dir, "data", "
|
| 60 |
dataset = TriageDataset(
|
| 61 |
csv_file=dataset_dir,
|
| 62 |
mode=mode
|
| 63 |
)
|
| 64 |
|
| 65 |
model = MediLLMModel(
|
| 66 |
-
dropout=
|
| 67 |
-
hidden_dim=
|
| 68 |
mode = mode
|
| 69 |
).to(device) # moves the model to selected device
|
| 70 |
|
| 71 |
train_set, val_set = stratified_split(dataset)
|
| 72 |
-
batch_size =
|
| 73 |
|
| 74 |
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True) # Create data in batches to the model
|
| 75 |
val_loader = DataLoader(val_set, batch_size=batch_size)
|
| 76 |
|
| 77 |
criterion = CrossEntropyLoss() # Calculate difference between model prediction and true labels
|
| 78 |
-
optimizer = Adam(model.parameters(), lr=
|
| 79 |
|
| 80 |
train_acc, val_acc = [], [] # Lists to store accuracy per epoch for plotting
|
| 81 |
|
| 82 |
-
for epoch in range(
|
| 83 |
model.train() # Activate training the model, enable dropout
|
| 84 |
all_preds, all_labels = [], []
|
| 85 |
|
|
@@ -152,11 +171,11 @@ def train_model(mode="multimodal"): # Function to instantiate model and data, tr
|
|
| 152 |
print(f"Val Accuracy: {val_acc_epoch:.4f}, F1 Score: {val_f1:.4f}")
|
| 153 |
|
| 154 |
# Save model
|
| 155 |
-
model_path = os.path.join(base_dir, f"
|
| 156 |
torch.save(model.state_dict(), model_path) # Saves the model weights only not total architecture to reuse later
|
| 157 |
|
| 158 |
# Plot accuracy
|
| 159 |
-
plot_path = os.path.join(base_dir, "assets", f"
|
| 160 |
plt.plot(train_acc, label="Train Acc")
|
| 161 |
plt.plot(val_acc, label="Val Acc")
|
| 162 |
plt.legend()
|
|
|
|
| 1 |
import torch # PyTorch core utility for model training
|
| 2 |
import os
|
| 3 |
import sys
|
| 4 |
+
import yaml
|
| 5 |
+
import argparse
|
| 6 |
+
import matplotlib.pyplot as plt # for plotting
|
| 7 |
+
|
| 8 |
+
from tqdm import tqdm # loading bar for loops
|
| 9 |
+
from torch.utils.data import DataLoader, Subset # DataLoader batches and feeds data to the model; Subset holds the train and validation splits
|
| 10 |
+
from torch.nn import CrossEntropyLoss # PyTorch core utility for model training
|
| 11 |
+
from torch.optim import Adam # PyTorch core utility for model training, Adam is the Optimizer a gradient descent model
|
| 12 |
+
from sklearn.metrics import accuracy_score, f1_score # Evaluation metrics
|
| 13 |
+
from sklearn.model_selection import StratifiedShuffleSplit
|
| 14 |
+
|
| 15 |
|
| 16 |
+
# Setup base path
|
| 17 |
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 18 |
# Automatically add Project root to python import path
|
| 19 |
base_dir = os.path.dirname(os.path.dirname(__file__))
|
| 20 |
if base_dir not in sys.path:
|
| 21 |
sys.path.append(base_dir)
|
| 22 |
|
| 23 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
from src.triage_dataset import TriageDataset # Dataset Class
|
| 25 |
from src.multimodal_model import MediLLMModel # Multimodal Model
|
| 26 |
|
| 27 |
+
def load_config(mode):
    """Return the hyperparameter section of ``config/config.yaml`` for *mode*.

    The config directory is created if needed. When the file does not exist
    yet, it is first written with identical default settings for the
    "text", "image" and "multimodal" modes.

    Args:
        mode: Top-level key to look up in the YAML file.

    Returns:
        The value stored under ``mode`` in the config file; for the defaults
        written here that is a dict with the keys ``lr``, ``dropout``,
        ``hidden_dim``, ``batch_size`` and ``epochs``.

    Raises:
        ValueError: if the file has no entry for ``mode``. This now also
            covers an empty or non-mapping file, which previously failed
            with a TypeError.
    """
    config_path = os.path.join(base_dir, "config", "config.yaml")
    os.makedirs(os.path.dirname(config_path), exist_ok=True)

    # If the config file doesn't exist, create it with defaults for all modes
    if not os.path.exists(config_path):
        defaults = {
            "lr": 2e-5,
            "dropout": 0.3,
            "hidden_dim": 256,
            "batch_size": 8,
            "epochs": 5,
        }
        # Each mode gets its own copy: dumping one shared dict object three
        # times would make PyYAML emit anchors/aliases instead of plain maps.
        default_config = {name: dict(defaults) for name in ("text", "image", "multimodal")}
        with open(config_path, "w", encoding="utf-8") as f:
            yaml.dump(default_config, f)

    # Read the (possibly just-created) config back from disk
    with open(config_path, "r", encoding="utf-8") as f:
        # safe_load returns None for an empty file; treat that as "no entries"
        config = yaml.safe_load(f) or {}

    if not isinstance(config, dict) or mode not in config:
        raise ValueError(f"No config found for mode '{mode}' in config.yaml")

    return config[mode]
|
| 67 |
|
| 68 |
def stratified_split(dataset, val_ratio=0.2, seed=42):
|
| 69 |
labels = [dataset.df.iloc[i]["triage_level"] for i in range(len(dataset))]
|
|
|
|
| 72 |
return Subset(dataset, tran_idx), Subset(dataset, val_idx)
|
| 73 |
|
| 74 |
def train_model(mode="multimodal"): # Function to instantiate model and data, train, validate, plot results and save the model
|
| 75 |
+
cfg = load_config(mode)
|
| 76 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Use GPU if available or else use CPU
|
| 77 |
|
| 78 |
+
dataset_dir = os.path.join(base_dir, "data", "emr_records.csv")
|
| 79 |
dataset = TriageDataset(
|
| 80 |
csv_file=dataset_dir,
|
| 81 |
mode=mode
|
| 82 |
)
|
| 83 |
|
| 84 |
model = MediLLMModel(
|
| 85 |
+
dropout=cfg["dropout"],
|
| 86 |
+
hidden_dim=cfg["hidden_dim"],
|
| 87 |
mode = mode
|
| 88 |
).to(device) # moves the model to selected device
|
| 89 |
|
| 90 |
train_set, val_set = stratified_split(dataset)
|
| 91 |
+
batch_size = cfg["batch_size"]
|
| 92 |
|
| 93 |
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True) # Create data in batches to the model
|
| 94 |
val_loader = DataLoader(val_set, batch_size=batch_size)
|
| 95 |
|
| 96 |
criterion = CrossEntropyLoss() # Calculate difference between model prediction and true labels
|
| 97 |
+
optimizer = Adam(model.parameters(), lr=cfg["lr"]) # Adaptive learning rate optimizer for fast-converging
|
| 98 |
|
| 99 |
train_acc, val_acc = [], [] # Lists to store accuracy per epoch for plotting
|
| 100 |
|
| 101 |
+
for epoch in range(cfg["epochs"]):
|
| 102 |
model.train() # Activate training the model, enable dropout
|
| 103 |
all_preds, all_labels = [], []
|
| 104 |
|
|
|
|
| 171 |
print(f"Val Accuracy: {val_acc_epoch:.4f}, F1 Score: {val_f1:.4f}")
|
| 172 |
|
| 173 |
# Save model
|
| 174 |
+
model_path = os.path.join(base_dir, f"medi_llm_model_{mode}.pth")
|
| 175 |
torch.save(model.state_dict(), model_path) # Saves the model weights only not total architecture to reuse later
|
| 176 |
|
| 177 |
# Plot accuracy
|
| 178 |
+
plot_path = os.path.join(base_dir, "assets", f"model_training_curve_{mode}.png")
|
| 179 |
plt.plot(train_acc, label="Train Acc")
|
| 180 |
plt.plot(val_acc, label="Val Acc")
|
| 181 |
plt.legend()
|