Preetham22 committed
Commit c8b8d35 · 1 Parent(s): b6f867a

Optimize hyperparams per input mode for the ablation study

Files changed (2)
  1. experiments/train_optuna.py +107 -80
  2. src/train.py +56 -37
experiments/train_optuna.py CHANGED
@@ -1,26 +1,29 @@
import os
import sys
+ import torch
+ import optuna
+ import yaml
+ import json
+ import wandb
+ import argparse
+ import matplotlib.pyplot as plt
+ import seaborn as sns
+
+ from tqdm import tqdm
+ from torch.utils.data import DataLoader, Subset
+ from torch.nn import CrossEntropyLoss
+ from torch.optim import Adam
+ from sklearn.model_selection import StratifiedShuffleSplit
+ from sklearn.metrics import accuracy_score, f1_score, classification_report, confusion_matrix

+
+ # Setup base path
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Automatically add Project root to python import path
base_dir = os.path.dirname(os.path.dirname(__file__))
if base_dir not in sys.path:
    sys.path.append(base_dir)

- import torch
- import optuna
- from torch.utils.data import DataLoader, Subset
- from torch.nn import CrossEntropyLoss
- from torch.optim import Adam
- from sklearn.model_selection import StratifiedShuffleSplit
- from sklearn.metrics import accuracy_score, f1_score, classification_report, confusion_matrix
- from tqdm import tqdm
- import matplotlib.pyplot as plt
- import seaborn as sns
- import wandb
- import json
- import yaml
- import argparse

from src.triage_dataset import TriageDataset
from src.multimodal_model import MediLLMModel
@@ -36,12 +39,13 @@ def stratified_split(dataset, val_ratio=0.2, seed=42, label_column="triage_level

- def objective(trial):
+ def objective(trial, mode):
    wandb.init(
-         project="mediLLM-v2",
-         name=f"trial-{trial.number}-v4-{wandb.util.generate_id()}",
+         project=f"mediLLM-tune-{mode}",
+         name=f"{mode}-trial-{trial.number}-v5-{wandb.util.generate_id()}",
        group="SoftLabelTrials",
        config={
            "dataset_version": "softlabels",
-             "dataset_size": 900
+             "dataset_size": 900,
+             "mode": mode
        }
    )

@@ -51,7 +55,7 @@ def objective(trial, mode):
    hidden_dim = trial.suggest_categorical("hidden_dim", [128, 256, 512])
    batch_size = trial.suggest_categorical("bs", [4, 8, 16])

-     model = MediLLMModel(dropout=dropout, hidden_dim=hidden_dim).to(device)
+     model = MediLLMModel(dropout=dropout, hidden_dim=hidden_dim, mode=mode).to(device)
    wandb.watch(model)

    dataset = TriageDataset(os.path.join(base_dir, "data", "emr_records.csv"))
@@ -65,132 +69,155 @@ def objective(trial, mode):

    for epoch in range(2):
        model.train()
-         loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/2", leave=False)
+         loop = tqdm(train_loader, desc=f"[{mode}] Epoch {epoch+1}/2", leave=False)
        for batch in loop:
-             input_ids = batch["input_ids"].to(device)
-             attention_mask = batch["attention_mask"].to(device)
-             images = batch["image"].to(device)
+             input_ids = batch.get("input_ids", None)
+             attention_mask = batch.get("attention_mask", None)
+             images = batch.get("image", None)
            labels = batch["label"].to(device)

+             if input_ids is not None:
+                 input_ids = input_ids.to(device)
+             if attention_mask is not None:
+                 attention_mask = attention_mask.to(device)
+             if images is not None:
+                 images = images.to(device)
+
            optimizer.zero_grad()
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, image=images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
-
            loop.set_postfix(loss=loss.item())

    # Validation
    model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
-         for batch in tqdm(val_loader, desc="Validating", leave=False):
-             input_ids = batch["input_ids"].to(device)
-             attention_mask = batch["attention_mask"].to(device)
-             images = batch["image"].to(device)
+         for batch in tqdm(val_loader, desc=f"[{mode}] Validating", leave=False):
+             input_ids = batch.get("input_ids", None)
+             attention_mask = batch.get("attention_mask", None)
+             images = batch.get("image", None)
            labels = batch["label"].to(device)

+             if input_ids is not None:
+                 input_ids = input_ids.to(device)
+             if attention_mask is not None:
+                 attention_mask = attention_mask.to(device)
+             if images is not None:
+                 images = images.to(device)
+
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, image=images)
            preds = torch.argmax(outputs, dim=1).cpu().numpy()
            all_preds.extend(preds)
            all_labels.extend(labels.cpu().numpy())

    f1 = f1_score(all_labels, all_preds, average="weighted")
-     print(f"\n[Trial {trial.number}] Classification Report:")
-     print(classification_report(all_labels, all_preds, target_names=["low", "medium", "high"]))
+     acc = accuracy_score(all_labels, all_preds)

+     # Log to W&B and Optuna
+     wandb.log({
+         "val_f1_score": f1,
+         "val_accuracy": acc,
+         "lr": lr,
+         "dropout": dropout,
+         "hidden_dim": hidden_dim,
+         "batch_size": batch_size
+     })
+
+     # Confusion matrix
    cm = confusion_matrix(all_labels, all_preds)
    plt.figure(figsize=(6, 5))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
                xticklabels=["low", "medium", "high"],
                yticklabels=["low", "medium", "high"])
-     plt.title(f"Confusion Matrix - Trial {trial.number}")
+     plt.title(f"Confusion Matrix - {mode} Trial {trial.number}")
    plt.xlabel("Predicted")
    plt.ylabel("True")
-     wandb.log({f"confusion_matrix/trial_{trial.number}": wandb.Image(plt)})
+     wandb.log({f"{mode}_confusion_matrix/trial_{trial.number}": wandb.Image(plt)})
    plt.close()

-     # Log to W&B and Optuna
-     wandb.log({
-         "f1_score": f1,
-         "accuracy": accuracy_score(all_labels, all_preds),
-         "lr": lr,
-         "dropout": dropout,
-         "hidden_dim": hidden_dim,
-         "batch_size": batch_size
-     })
    return f1

def get_args():
    parser = argparse.ArgumentParser(description="Run Optuna hyperparameter search")
    parser.add_argument("--n_trials", type=int, default=10, help="Number of Optuna trials to run")
+     parser.add_argument("--mode", type=str, choices=["text", "image", "multimodal"], required=True, help="Input mode")
    return parser.parse_args()

if __name__ == "__main__":
    args = get_args()
+     mode = args.mode

    study = optuna.create_study(
-         study_name="mediLLM_v2",
+         study_name=f"mediLLM_{mode}_optuna",
        direction="maximize"
    )
-     with tqdm(total=args.n_trials, desc="Optuna Trials") as pbar:
+     with tqdm(total=args.n_trials, desc=f"Optuna Trials [{mode}]") as pbar:
        def wrapped_objective(trial):
            try:
-                 result = objective(trial)
-                 return result
+                 return objective(trial, mode)
            finally:
                wandb.finish()
                pbar.update(1)

        study.optimize(wrapped_objective, n_trials=args.n_trials)

-     print("Best F1 score achieved:", study.best_value)
-     print("Best hyperparameters:", study.best_params)
+     print(f"Best F1 score for {mode}: {study.best_value}")
+     print(f"Best hyperparameters: {study.best_params}")
+
+     # Save the best hyperparameters to JSON, keyed per mode
+     json_path = os.path.join(base_dir, "assets", "best_hyperparams.json")
+     os.makedirs(os.path.dirname(json_path), exist_ok=True)

-     # Save as JSON
-     assets_dir = os.path.join(base_dir, "assets")
+     best_params_entry = {
+         "lr": float(study.best_params["lr"]),
+         "dropout": float(study.best_params["dropout"]),
+         "hidden_dim": int(study.best_params["hidden_dim"]),
+         "batch_size": int(study.best_params["bs"]),
+         "epochs": 5
+     }

-     # Make sure assets directory exists in the root
-     os.makedirs(assets_dir, exist_ok=True)
+     # Load the existing file or start a new one
+     if os.path.exists(json_path):
+         with open(json_path, "r") as f:
+             best_params_all = json.load(f)

-     # Save the best hyperparameters
-     with open(os.path.join(assets_dir, "best_hyperparams.json"), "w") as f:
-         json.dump(study.best_params, f, indent=4)
+     else:
+         best_params_all = {}
+
+     best_params_all[mode] = best_params_entry
+
+     # Write back
+     with open(json_path, "w") as f:
+         json.dump(best_params_all, f, indent=4)
+
+     print(f"✅ Saved best hyperparameters for [{mode}] to best_hyperparams.json")

    # Export to config.yaml
-     config_dir = os.path.join(base_dir, "config")
-     config_path = os.path.join(config_dir, "config.yaml")
+     config_path = os.path.join(base_dir, "config", "config.yaml")

    # Make sure the config directory exists in the root
-     os.makedirs(config_dir, exist_ok=True)
-
-     # If the config file doesn't exist, create a default one
-     if not os.path.exists(config_path):
-         with open(config_path, "w") as f:
-             f.write(
-                 "model:\n"
-                 "  dropout: 0.3\n"
-                 "  hidden_dim: 256\n\n"
-                 "train:\n"
-                 "  lr: 2e-5\n"
-                 "  batch_size: 8\n"
-                 "  epochs: 5\n\n"
-                 "wandb:\n"
-                 "  project: medi-llm-final\n"
-             )
+     os.makedirs(os.path.dirname(config_path), exist_ok=True)

-     # Export to config.yaml
-     with open(config_path, "r") as f:
-         cfg = yaml.safe_load(f)
+     config = {}
+     if os.path.exists(config_path):
+         with open(config_path, "r") as f:
+             config = yaml.safe_load(f) or {}

-     cfg["model"]["dropout"] = float(study.best_params["dropout"])
-     cfg["model"]["hidden_dim"] = int(study.best_params["hidden_dim"])
-     cfg["train"]["lr"] = float(study.best_params["lr"])
-     cfg["train"]["batch_size"] = int(study.best_params["bs"])
+     config[mode] = {
+         "lr": float(study.best_params["lr"]),
+         "dropout": float(study.best_params["dropout"]),
+         "hidden_dim": int(study.best_params["hidden_dim"]),
+         "batch_size": int(study.best_params["bs"]),
+         "epochs": 5
+     }

-     # Save updated config
+     # Write the updated config
    with open(config_path, "w") as f:
-         yaml.dump(cfg, f, default_flow_style=False)
+         yaml.dump(config, f, sort_keys=False)
+
+     print(f"✅ Best hyperparameters for [{mode}] saved in config.yaml")
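After this change each input mode gets its own Optuna study via the required --mode flag (e.g. python experiments/train_optuna.py --mode text --n_trials 10), and assets/best_hyperparams.json is merged per mode instead of overwritten. A minimal sketch of reading the merged file back, assuming at least one study has completed; the consuming snippet is illustrative and not code from this commit:

    import json

    # best_hyperparams.json now maps mode -> tuned hyperparameters
    with open("assets/best_hyperparams.json") as f:
        best = json.load(f)

    # Each entry carries lr, dropout, hidden_dim, batch_size and epochs
    params = best["text"]
    print(params["lr"], params["hidden_dim"], params["batch_size"])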
src/train.py CHANGED
@@ -1,50 +1,69 @@
import torch # PyTorch core utility for model training
import os
import sys
+ import yaml
+ import argparse
+ import matplotlib.pyplot as plt # for plotting
+
+ from tqdm import tqdm # loading bar for loops
+ from torch.utils.data import DataLoader, Subset # DataLoader batches and feeds data to the model; Subset carves out the train/validation splits
+ from torch.nn import CrossEntropyLoss # PyTorch core utility for model training
+ from torch.optim import Adam # Adam optimizer, a gradient-descent method
+ from sklearn.metrics import accuracy_score, f1_score # Evaluation metrics
+ from sklearn.model_selection import StratifiedShuffleSplit
+

+ # Setup base path
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Automatically add Project root to python import path
base_dir = os.path.dirname(os.path.dirname(__file__))
if base_dir not in sys.path:
    sys.path.append(base_dir)

- import argparse
- import yaml
- from torch.utils.data import DataLoader, Subset # DataLoader to batch and feed data to the model, split dataset into train and validation sets
- from torch.nn import CrossEntropyLoss # PyTorch core utility for model training
- from torch.optim import Adam # PyTorch core utility for model training, Adam is the optimizer, a gradient descent method
- from sklearn.metrics import accuracy_score, f1_score # Evaluation metrics
- from sklearn.model_selection import StratifiedShuffleSplit
- from tqdm import tqdm # loading bar for loops
- import matplotlib.pyplot as plt # for plotting
+
from src.triage_dataset import TriageDataset # Dataset class
from src.multimodal_model import MediLLMModel # Multimodal model

- def load_config():
-     config_dir = os.path.join(base_dir, "config")
-     config_path = os.path.join(config_dir, "config.yaml")
+ def load_config(mode):
+     config_path = os.path.join(base_dir, "config", "config.yaml")
+     os.makedirs(os.path.dirname(config_path), exist_ok=True)

-     # Make sure the config directory exists in the root
-     os.makedirs(config_dir, exist_ok=True)
-
-     # If the config file doesn't exist, create a default one
+     # If the config file doesn't exist, create it with defaults for all modes
    if not os.path.exists(config_path):
+         default_config = {
+             "text": {
+                 "lr": 2e-5,
+                 "dropout": 0.3,
+                 "hidden_dim": 256,
+                 "batch_size": 8,
+                 "epochs": 5
+             },
+             "image": {
+                 "lr": 2e-5,
+                 "dropout": 0.3,
+                 "hidden_dim": 256,
+                 "batch_size": 8,
+                 "epochs": 5
+             },
+             "multimodal": {
+                 "lr": 2e-5,
+                 "dropout": 0.3,
+                 "hidden_dim": 256,
+                 "batch_size": 8,
+                 "epochs": 5
+             }
+         }
        with open(config_path, "w") as f:
-             f.write(
-                 "model:\n"
-                 "  dropout: 0.3\n"
-                 "  hidden_dim: 256\n\n"
-                 "train:\n"
-                 "  lr: 2e-5\n"
-                 "  batch_size: 8\n"
-                 "  epochs: 5\n\n"
-                 "wandb:\n"
-                 "  project: medi-llm-final\n"
-             )
+             yaml.dump(default_config, f)

    # Load the config and return the section for the requested mode
    with open(config_path, "r") as f:
-         return yaml.safe_load(f)
+         config = yaml.safe_load(f)
+
+     if mode not in config:
+         raise ValueError(f"No config found for mode '{mode}' in config.yaml")
+
+     return config[mode]

def stratified_split(dataset, val_ratio=0.2, seed=42):
    labels = [dataset.df.iloc[i]["triage_level"] for i in range(len(dataset))]
@@ -53,33 +72,33 @@ def stratified_split(dataset, val_ratio=0.2, seed=42):
    return Subset(dataset, tran_idx), Subset(dataset, val_idx)

def train_model(mode="multimodal"): # Instantiate the model and data, train, validate, plot results, and save the model
-     config = load_config()
+     cfg = load_config(mode)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Use the GPU if available, otherwise the CPU

-     dataset_dir = os.path.join(base_dir, "data", "emr_records_softlabels.csv")
+     dataset_dir = os.path.join(base_dir, "data", "emr_records.csv")
    dataset = TriageDataset(
        csv_file=dataset_dir,
        mode=mode
    )

    model = MediLLMModel(
-         dropout=config["model"]["dropout"],
-         hidden_dim=config["model"]["hidden_dim"],
+         dropout=cfg["dropout"],
+         hidden_dim=cfg["hidden_dim"],
        mode=mode
    ).to(device) # Move the model to the selected device

    train_set, val_set = stratified_split(dataset)
-     batch_size = config["train"]["batch_size"]
+     batch_size = cfg["batch_size"]

    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True) # Feed data to the model in batches
    val_loader = DataLoader(val_set, batch_size=batch_size)

    criterion = CrossEntropyLoss() # Measures the difference between model predictions and true labels
-     optimizer = Adam(model.parameters(), lr=config["train"]["lr"]) # Adaptive learning-rate optimizer for fast convergence
+     optimizer = Adam(model.parameters(), lr=cfg["lr"]) # Adaptive learning-rate optimizer for fast convergence

    train_acc, val_acc = [], [] # Accuracy per epoch, for plotting

-     for epoch in range(config["train"]["epochs"]):
+     for epoch in range(cfg["epochs"]):
        model.train() # Training mode: enables dropout
        all_preds, all_labels = [], []

@@ -152,11 +171,11 @@ def train_model(mode="multimodal"): # Function to instantiate model and data, tr
    print(f"Val Accuracy: {val_acc_epoch:.4f}, F1 Score: {val_f1:.4f}")

    # Save model
-     model_path = os.path.join(base_dir, f"medi_llm_model_softlabels{mode}.pth")
+     model_path = os.path.join(base_dir, f"medi_llm_model_{mode}.pth")
    torch.save(model.state_dict(), model_path) # Save only the weights (state_dict), not the full model, for later reuse

    # Plot accuracy
-     plot_path = os.path.join(base_dir, "assets", f"model_training_curve_softlabels{mode}.png")
+     plot_path = os.path.join(base_dir, "assets", f"model_training_curve_{mode}.png")
    plt.plot(train_acc, label="Train Acc")
    plt.plot(val_acc, label="Val Acc")
    plt.legend()
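With load_config(mode) in place, each ablation arm trains against its own section of config/config.yaml (or the defaults written on first use). A minimal sketch of driving all three arms, assuming the project root is on the import path; the driver loop is illustrative and not part of train.py:

    from src.train import train_model

    # One run per input mode; each picks up its own tuned lr, dropout,
    # hidden_dim, batch_size and epochs from config/config.yaml
    for mode in ["text", "image", "multimodal"]:
        train_model(mode=mode)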