import torch  # PyTorch core utility for model training
import os
import sys
import yaml
import json
import argparse
import matplotlib.pyplot as plt  # for plotting
import random
import numpy as np
from tqdm import tqdm  # loading bar for loops
from torch.utils.data import (
    DataLoader,
    Subset,
)  # DataLoader batches and feeds data to the model,
# Subset wraps the stratified train/validation index splits
from torch.nn import (
    CrossEntropyLoss,
)  # PyTorch core utility for model training
from torch.optim import Adam  # PyTorch core utility for model training,
# Adam is the optimizer, a gradient-descent method
from sklearn.metrics import (
    accuracy_score,
    f1_score,
    classification_report,
)  # Evaluation metrics
from sklearn.model_selection import StratifiedShuffleSplit

# Disable HuggingFace tokenizers parallelism to avoid fork warnings
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Setup base path: automatically add project root to Python import path
base_dir = os.path.dirname(os.path.dirname(__file__))
if base_dir not in sys.path:
    sys.path.append(base_dir)

from src.triage_dataset import TriageDataset  # Dataset class
from src.multimodal_model import MediLLMModel  # Multimodal model


def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def load_config(mode):
    config_path = os.path.join(base_dir, "config", "config.yaml")
    os.makedirs(os.path.dirname(config_path), exist_ok=True)

    # If the config file doesn't exist, create it with defaults for all modes
    if not os.path.exists(config_path):
        default_config = {
            "text": {
                "lr": 2e-5,
                "dropout": 0.3,
                "hidden_dim": 256,
                "batch_size": 8,
                "epochs": 5,
            },
            "image": {
                "lr": 2e-5,
                "dropout": 0.3,
                "hidden_dim": 256,
                "batch_size": 8,
                "epochs": 5,
            },
            "multimodal": {
                "lr": 2e-5,
                "dropout": 0.3,
                "hidden_dim": 256,
                "batch_size": 8,
                "epochs": 5,
            },
        }
        with open(config_path, "w") as f:
            yaml.dump(default_config, f)

    # Otherwise load the existing YAML config
    with open(config_path, "r") as f:
        config = yaml.safe_load(f)

    if mode not in config:
        raise ValueError(f"No config found for mode '{mode}' in config.yaml")
    return config[mode]


def stratified_split(dataset, val_ratio=0.2, seed=42):
    labels = [dataset.df.iloc[i]["triage_level"] for i in range(len(dataset))]
    sss = StratifiedShuffleSplit(
        n_splits=1,
        test_size=val_ratio,
        random_state=seed,
    )
    train_idx, val_idx = next(sss.split(range(len(dataset)), labels))
    return Subset(dataset, train_idx), Subset(dataset, val_idx)

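
# Illustrative sketch (hypothetical toy labels, never called by the training
# flow): shows why stratified_split above is used instead of a purely random
# cut. StratifiedShuffleSplit keeps the class ratio of `labels` in both
# subsets, so a rare triage level cannot end up missing from validation.
def _demo_stratified_split():
    toy_labels = [0] * 10 + [1] * 5  # 2:1 class imbalance
    sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=0)
    train_idx, val_idx = next(sss.split(range(len(toy_labels)), toy_labels))
    # The validation split gets 3 of 15 samples and roughly preserves the
    # 2:1 ratio (about two 0s and one 1), unlike a random, possibly
    # skewed, draw
    return train_idx, val_idx
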

# Function to instantiate model and data, train, validate, plot results,
# and save the model
def train_model(mode="multimodal", use_wandb=False):
    set_seed(42)
    if use_wandb:
        import wandb

    cfg = load_config(mode)
    device = torch.device(
        "cuda" if torch.cuda.is_available() else "cpu"
    )  # Use GPU if available, otherwise fall back to CPU

    csv_path = os.path.join(base_dir, "data", "emr_records.csv")
    dataset_kwargs = {"csv_file": csv_path, "mode": mode}
    if mode in ["image", "multimodal"]:
        image_dir = os.path.join(base_dir, "data", "images")
        dataset_kwargs["image_base_dir"] = image_dir
    dataset = TriageDataset(**dataset_kwargs)

    model = MediLLMModel(
        dropout=cfg["dropout"], hidden_dim=cfg["hidden_dim"], mode=mode
    ).to(
        device
    )  # Move the model to the selected device

    if use_wandb:
        # Initialize Weights & Biases
        wandb.init(
            project="MediLLM_Final_v2", name=f"train_{mode}", config=cfg
        )
        wandb.config.update({"mode": mode})

    train_set, val_set = stratified_split(dataset)
    batch_size = cfg["batch_size"]
    train_loader = DataLoader(
        train_set, batch_size=batch_size, shuffle=True
    )  # Feed the training data to the model in shuffled batches
    val_loader = DataLoader(val_set, batch_size=batch_size)

    # Loss between model predictions and true labels
    criterion = CrossEntropyLoss()
    optimizer = Adam(
        model.parameters(), lr=cfg["lr"]
    )  # Adaptive-learning-rate optimizer for fast convergence

    # Lists to store accuracy per epoch for plotting
    train_acc, val_acc = [], []
    train_f1s, val_f1s = [], []

    for epoch in range(cfg["epochs"]):
        model.train()  # Put the model in training mode, enable dropout
        all_preds, all_labels = [], []

        for batch in tqdm(
            train_loader, desc=f"[{mode}] Epoch {epoch + 1}"
        ):  # Load a batch of text, images, and labels to GPU or CPU
            input_ids = batch.get("input_ids", None)
            attention_mask = batch.get("attention_mask", None)
            images = batch.get("image", None)
            labels = batch["label"].to(device)

            if input_ids is not None:
                input_ids = input_ids.to(device)
            if attention_mask is not None:
                attention_mask = attention_mask.to(device)
            if images is not None:
                images = images.to(device)

            """
            Each batch looks like this
            {
                "input_ids": torch.Size([8, 128]),
                "attention_mask": torch.Size([8, 128]),
                "image": torch.Size([8, 3, 224, 224]),
                "label": torch.Size([8])
            }
            """
            optimizer.zero_grad()  # Zero out gradients from previous batch
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                image=images,
            )  # Forward pass through the model
            loss = criterion(outputs, labels)  # Compute the loss value
            loss.backward()  # Backpropagation to compute gradients
            optimizer.step()  # Adjust the weights using the gradients

            preds = torch.argmax(outputs, dim=1).cpu().numpy()
            # Get the predicted class per sample and move to CPU & NumPy
            # for easier comparison
            all_preds.extend(preds)
            # Save predictions for metric computation.
            # extend() appends each element of preds to the list
            all_labels.extend(
                labels.cpu().numpy()
            )  # Save labels for metric computation

        # Calculate classification metrics (accuracy and F1) to evaluate
        # full-epoch performance
        acc = accuracy_score(all_labels, all_preds)
        f1 = f1_score(all_labels, all_preds, average="weighted")
        # f1_score averaging options:
        # 1) binary: binary classification (F1 of the positive class only)
        # 2) macro: computes F1 for each class independently, then
        #    averages; treats all classes equally
        # 3) micro: flattens all true and predicted labels, computes
        #    global TP, FP, FN, and derives F1 from those; works well with
        #    imbalanced data; equal to accuracy in binary classification,
        #    different in multi-class/multi-label
        # 4) weighted: computes F1 for each class, then averages them
        #    weighted by the number of samples per class; avoids bias
        #    toward small classes, suits real-world, imbalanced data while
        #    still reflecting per-class performance
        # 5) samples: for multi-label classification; computes F1 for each
        #    instance, then averages across all samples (row-wise, not
        #    class-wise)
        train_acc.append(acc)  # Append to a list for plotting
        train_f1s.append(f1)
        print(f"Train Accuracy: {acc:.4f}, F1 Score: {f1:.4f}")

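        # Worked toy example of weighted vs macro averaging (hypothetical
        # numbers, not executed): with y_true = [0, 0, 0, 1] and
        # y_pred = [0, 0, 1, 1]:
        #   class 0: P = 1.0, R = 0.667 -> F1 = 0.80
        #   class 1: P = 0.5, R = 1.0   -> F1 = 0.667
        #   macro F1    = (0.80 + 0.667) / 2         ~= 0.733
        #   weighted F1 = (3 * 0.80 + 1 * 0.667) / 4 ~= 0.767
        # The weighted average leans toward the majority class, which is
        # why it is used here for the imbalanced triage levels.
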
        # Validation loop
        model.eval()  # Disable dropout (norm layers switch to eval stats)
        val_preds, val_labels = [], []
        with torch.no_grad():  # Disable autograd to save memory
            for batch in val_loader:
                # Load a batch of validation text, images, and labels to
                # GPU or CPU
                input_ids = batch.get("input_ids", None)
                attention_mask = batch.get("attention_mask", None)
                images = batch.get("image", None)
                labels = batch["label"].to(device)

                if input_ids is not None:
                    input_ids = input_ids.to(device)
                if attention_mask is not None:
                    attention_mask = attention_mask.to(device)
                if images is not None:
                    images = images.to(device)

                outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    image=images,
                )
                preds = torch.argmax(outputs, dim=1).cpu().numpy()
                val_preds.extend(preds)
                val_labels.extend(labels.cpu().numpy())

        val_acc_epoch = accuracy_score(
            val_labels, val_preds
        )  # Validation accuracy: how many items the model got right out
        # of the total number of items
        val_f1 = f1_score(val_labels, val_preds, average="weighted")
        # F1 weighs both precision and recall, using the harmonic mean to
        # punish imbalance: if either one is low, it drags the score down.
        # Precision: how careful the model is when labelling an item
        # (TP / (TP + FP)).
        # Recall: how many real items it actually spotted
        # (TP / (TP + FN)).
        val_acc.append(val_acc_epoch)
        val_f1s.append(val_f1)
        print(f"Val Accuracy: {val_acc_epoch:.4f}, F1 Score: {val_f1:.4f}")

        # Log to Weights & Biases
        if use_wandb:
            wandb.log({
                "epoch": epoch + 1,
                "train/accuracy": acc,
                "train/f1": f1,
                "val/accuracy": val_acc_epoch,
                "val/f1": val_f1,
            })

    # Save model
    model_path = os.path.join(base_dir, f"medi_llm_state_dict_{mode}.pth")
    torch.save(
        model.state_dict(), model_path
    )  # Saves only the model weights and biases
    print(f"šŸ’¾ Saved model weights and biases to {model_path}")

    # Save to Weights & Biases
    if use_wandb:
        wandb.save(model_path)

    # Plot accuracy
    plot_path = os.path.join(
        base_dir, "assets", f"model_training_curve_{mode}.png"
    )
    os.makedirs(os.path.dirname(plot_path), exist_ok=True)
    plt.plot(train_acc, label="Train Acc")
    plt.plot(val_acc, label="Val Acc")
    plt.legend()
    plt.title(f"Accuracy: Train vs Validation ({mode})")
    plt.savefig(plot_path)
    print(f"āœ… Saved training curve to {plot_path}")
    if use_wandb:
        wandb.log({"training_curve": wandb.Image(plot_path)})

    # Save training metrics to JSON
    results = {
        "train_acc": train_acc,
        "val_acc": val_acc,
        "train_f1": train_f1s,
        "val_f1": val_f1s,
        "final_train_acc": train_acc[-1],
        "final_val_acc": val_acc[-1],
        "final_train_f1": train_f1s[-1],
        "final_val_f1": val_f1s[-1],
    }
    results_dir = os.path.join(base_dir, "results")
    os.makedirs(results_dir, exist_ok=True)
    results_path = os.path.join(results_dir, f"metrics_{mode}.json")
    with open(results_path, "w") as f:
        json.dump(results, f, indent=2)
    print(f"šŸ“Š Saved training metrics to {results_path}")

    # Classification report
    class_report = classification_report(
        val_labels,
        val_preds,
        output_dict=True,
        zero_division=0,
        target_names=["low", "medium", "high"],
    )
    print("\nšŸ—“ļø Classification Report (Per Class on Validation Set):")
    for cls, metrics in class_report.items():
        if cls in ["low", "medium", "high"]:
            print(
                f"{cls:>9} -> Precision: {metrics['precision']:.3f}, "
                f"Recall: {metrics['recall']:.3f}, "
                f"F1: {metrics['f1-score']:.3f}"
            )
    class_report_path = os.path.join(
        results_dir, f"classification_report_{mode}.json"
    )
    with open(class_report_path, "w") as f:
        json.dump(class_report, f, indent=2)
    print(f"šŸ“Š Saved per-class metrics to {class_report_path}")

    if use_wandb:
        for cls in ["low", "medium", "high"]:
            wandb.log({
                f"classwise/{cls}_precision": class_report[cls]["precision"],
                f"classwise/{cls}_recall": class_report[cls]["recall"],
                f"classwise/{cls}_f1": class_report[cls]["f1-score"],
            })
        wandb.finish()


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--mode",
        choices=["text", "image", "multimodal"],
        default="multimodal",
    )
    parser.add_argument(
        "--wandb",
        action="store_true",
        help="Enable Weights & Biases logging",
    )
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    train_model(
        mode=args.mode, use_wandb=args.wandb
    )  # Only runs when the file is executed directly, not when imported
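
# Example invocations (illustrative; assumes this file sits one level below
# the project root, e.g. scripts/train.py, with data/emr_records.csv and
# data/images/ in place; the script path shown is hypothetical):
#   python scripts/train.py --mode text
#   python scripts/train.py --mode multimodal --wandb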