# medi-llm / src/train.py
# Committed by: Preetham22
# Last commit: fix model paths, retrain, perform inference (4e592d4)
import torch # PyTorch core utility for model training
import os
import sys
import yaml
import json
import argparse
import matplotlib.pyplot as plt # for plotting
import random
import numpy as np
from tqdm import tqdm # loading bar for loops
from torch.utils.data import (
DataLoader,
Subset,
) # Dataloader to batch and feed data to model,
# random split to split dataset into train and validation sets
from torch.nn import (
CrossEntropyLoss,
) # PyTorch core utility for model training
from torch.optim import Adam # PyTorch core utility for model training,
# Adam is the Optimizer, a gradient descent model
from sklearn.metrics import accuracy_score, f1_score, classification_report # Evaluation metrics
from sklearn.model_selection import StratifiedShuffleSplit
# Turn off HuggingFace tokenizers' internal parallelism (presumably to avoid
# its fork-after-parallelism warning once DataLoader processes fork).
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Project root = parent of the directory that contains this file (src/..).
# abspath() guards against a relative __file__ (e.g. `python train.py` run
# from inside src/ on older interpreters): without it the double dirname
# collapses to "" and every path built from base_dir silently resolves
# against the current working directory instead of the project root.
base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# Make `src.*` importable when this file is executed directly as a script.
if base_dir not in sys.path:
    sys.path.append(base_dir)
from src.triage_dataset import TriageDataset # Dataset Class
from src.multimodal_model import MediLLMModel # Mutlimodal Model
def set_seed(seed=42):
    """Seed every random number generator used in training for reproducibility.

    Seeds Python's ``random`` module, NumPy's global RNG, the PyTorch CPU
    generator and all CUDA generators. It then switches cuDNN to
    deterministic kernels with benchmark autotuning disabled.

    Args:
        seed: Integer seed applied to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    cudnn = torch.backends.cudnn
    # Deterministic kernels and no autotuning -> repeatable results per run
    cudnn.deterministic = True
    cudnn.benchmark = False
def load_config(mode, config_path=None):
    """Return the hyperparameter section for ``mode`` from the YAML config.

    If the config file does not exist yet, it is first created with the same
    default hyperparameters for every supported mode ("text", "image",
    "multimodal").

    Args:
        mode: Top-level key of the config whose section is returned.
        config_path: Optional path to the YAML file. Defaults to
            ``<base_dir>/config/config.yaml``.

    Returns:
        dict: The hyperparameters stored under ``mode``.

    Raises:
        ValueError: If the file has no section for ``mode``. This includes
            an empty file: ``yaml.safe_load`` returns ``None`` for that,
            which previously surfaced as a confusing ``TypeError``.
    """
    if config_path is None:
        config_path = os.path.join(base_dir, "config", "config.yaml")
    config_dir = os.path.dirname(config_path)
    if config_dir:
        os.makedirs(config_dir, exist_ok=True)

    # First run: seed the file with defaults for all modes.
    if not os.path.exists(config_path):
        default_hparams = {
            "lr": 2e-5,
            "dropout": 0.3,
            "hidden_dim": 256,
            "batch_size": 8,
            "epochs": 5,
        }
        # One independent copy per mode: yaml.dump writes anchors/aliases
        # (&id001 / *id001) when the same dict object appears more than once.
        default_config = {
            m: dict(default_hparams) for m in ("text", "image", "multimodal")
        }
        with open(config_path, "w", encoding="utf-8") as f:
            yaml.dump(default_config, f)

    # Read the (possibly just-created) config back.
    with open(config_path, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f) or {}  # empty file -> no sections

    if mode not in config:
        raise ValueError(
            f"No config found for mode '{mode}' in {os.path.basename(config_path)}"
        )
    return config[mode]
def stratified_split(dataset, val_ratio=0.2, seed=42, label_col="triage_level"):
    """Split ``dataset`` into train/validation subsets that preserve class ratios.

    Args:
        dataset: Dataset exposing a pandas ``df`` attribute whose first
            ``len(dataset)`` rows line up with the dataset's indices.
        val_ratio: Fraction of samples placed in the validation subset.
        seed: Random state for the shuffle, so the split is reproducible.
        label_col: Column of ``dataset.df`` used as the stratification label.

    Returns:
        tuple: ``(train_subset, val_subset)`` as ``torch.utils.data.Subset``.

    Raises:
        IndexError: If ``dataset.df`` has fewer rows than ``len(dataset)``.
    """
    n_samples = len(dataset)
    # One vectorized positional lookup instead of df.iloc[i] per sample,
    # which materializes a full row Series each time. Out-of-range positions
    # still raise IndexError, as the per-row version did.
    labels = dataset.df[label_col].iloc[np.arange(n_samples)].to_numpy()

    splitter = StratifiedShuffleSplit(
        n_splits=1,
        test_size=val_ratio,
        random_state=seed,
    )
    # X only contributes its length; stratification is driven by `labels`.
    train_idx, val_idx = next(splitter.split(np.zeros(n_samples), labels))
    return Subset(dataset, train_idx), Subset(dataset, val_idx)
# Function to instantiate model and data, train, validate, plot results
# and save the model (plus the private helpers it uses).

# Class names in label-index order. CrossEntropyLoss requires integer targets
# in [0, num_classes), so the labels here are 0..2; mapping 0/1/2 to
# low/medium/high follows the original target_names. TODO: confirm against
# TriageDataset's label encoding.
_TRIAGE_CLASSES = ["low", "medium", "high"]


def _batch_to_device(batch, device):
    """Return ``(input_ids, attention_mask, images, labels)`` of ``batch`` on ``device``.

    ``batch`` is a dict yielded by the DataLoader. Optional inputs missing
    from it come back as ``None`` and are passed to the model unchanged.
    ``"label"`` is mandatory.
    """
    input_ids = batch.get("input_ids", None)
    attention_mask = batch.get("attention_mask", None)
    images = batch.get("image", None)
    labels = batch["label"].to(device)
    if input_ids is not None:
        input_ids = input_ids.to(device)
    if attention_mask is not None:
        attention_mask = attention_mask.to(device)
    if images is not None:
        images = images.to(device)
    return input_ids, attention_mask, images, labels


def _save_accuracy_plot(train_acc, val_acc, mode, plot_path):
    """Plot per-epoch train/val accuracy into a fresh figure saved at ``plot_path``."""
    # The output directory (assets/) may not exist on a fresh checkout.
    os.makedirs(os.path.dirname(plot_path), exist_ok=True)
    # A dedicated figure, closed afterwards, so curves from earlier
    # train_model() calls in the same process are never drawn on top.
    fig, ax = plt.subplots()
    try:
        epochs = range(1, len(train_acc) + 1)  # 1-based, like the tqdm labels
        ax.plot(epochs, train_acc, label="Train Acc")
        ax.plot(epochs, val_acc, label="Val Acc")
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Accuracy")
        ax.legend()
        ax.set_title(f"Accuracy: Train vs Validation ({mode})")
        fig.savefig(plot_path)
    finally:
        plt.close(fig)


def train_model(mode="multimodal", use_wandb=False):
    """Train ``MediLLMModel`` for ``mode``, then save weights, plot and metrics.

    Steps:
    1. Seed RNGs and load the ``mode`` hyperparameters via ``load_config``.
    2. Build ``TriageDataset``; the image directory is added for the
       "image" and "multimodal" modes.
    3. Split the data with ``stratified_split`` (20% validation by default).
    4. Train for ``cfg["epochs"]`` epochs with Adam and cross-entropy,
       tracking accuracy and weighted F1 on both splits.
    5. Write the outputs listed below.

    Outputs:
    * ``<base_dir>/medi_llm_state_dict_{mode}.pth``: state dict after the
      last epoch.
    * ``<base_dir>/assets/model_training_curve_{mode}.png``.
    * ``<base_dir>/results/metrics_{mode}.json``.
    * ``<base_dir>/results/classification_report_{mode}.json``.

    Args:
        mode: One of "text", "image", "multimodal".
        use_wandb: If True, log metrics, the checkpoint and the curve to
            Weights & Biases (project "MediLLM_Final_v2").

    Raises:
        ValueError: If the config has no ``mode`` section, or its ``epochs``
            is < 1. The final-metric summary needs at least one epoch.
    """
    set_seed(42)
    if use_wandb:
        import wandb  # optional dependency, only imported when logging

    cfg = load_config(mode)
    num_epochs = cfg["epochs"]
    if num_epochs < 1:
        raise ValueError(
            f"'epochs' must be >= 1 for mode '{mode}', got {num_epochs}"
        )

    # Use GPU if available or else use CPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    dataset_kwargs = {
        "csv_file": os.path.join(base_dir, "data", "emr_records.csv"),
        "mode": mode,
    }
    if mode in ["image", "multimodal"]:
        dataset_kwargs["image_base_dir"] = os.path.join(base_dir, "data", "images")
    dataset = TriageDataset(**dataset_kwargs)

    model = MediLLMModel(
        dropout=cfg["dropout"], hidden_dim=cfg["hidden_dim"], mode=mode
    ).to(device)

    if use_wandb:
        wandb.init(project="MediLLM_Final_v2", name=f"train_{mode}", config=cfg)
        wandb.config.update({"mode": mode})

    train_set, val_set = stratified_split(dataset)
    batch_size = cfg["batch_size"]
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=batch_size)

    criterion = CrossEntropyLoss()  # logits vs. integer class labels
    optimizer = Adam(model.parameters(), lr=cfg["lr"])

    # Per-epoch metric history, used for the plot and the JSON summary
    train_acc, val_acc = [], []
    train_f1s, val_f1s = [], []

    for epoch in range(num_epochs):
        # ---- Training pass ----
        model.train()  # training mode (enables dropout)
        all_preds, all_labels = [], []
        for batch in tqdm(train_loader, desc=f"[{mode}] Epoch {epoch + 1}"):
            input_ids, attention_mask, images, labels = _batch_to_device(batch, device)
            optimizer.zero_grad()  # clear gradients from the previous step
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                image=images,
            )
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # Predicted class = argmax over the class dimension
            all_preds.extend(torch.argmax(outputs, dim=1).cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

        # "weighted" F1 averages per-class F1 by class support, which keeps
        # the score meaningful when triage levels are imbalanced.
        acc = accuracy_score(all_labels, all_preds)
        f1 = f1_score(all_labels, all_preds, average="weighted")
        train_acc.append(acc)
        train_f1s.append(f1)
        print(f"Train Accuracy: {acc:.4f}, F1 Score: {f1:.4f}")

        # ---- Validation pass ----
        model.eval()  # inference mode (dropout off)
        val_preds, val_labels = [], []
        with torch.no_grad():  # no autograd graph -> less memory
            for batch in val_loader:
                input_ids, attention_mask, images, labels = _batch_to_device(batch, device)
                outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    image=images,
                )
                val_preds.extend(torch.argmax(outputs, dim=1).cpu().numpy())
                val_labels.extend(labels.cpu().numpy())

        val_acc_epoch = accuracy_score(val_labels, val_preds)
        val_f1 = f1_score(val_labels, val_preds, average="weighted")
        val_acc.append(val_acc_epoch)
        val_f1s.append(val_f1)
        print(f"Val Accuracy: {val_acc_epoch:.4f}, F1 Score: {val_f1:.4f}")

        if use_wandb:
            wandb.log({
                "epoch": epoch + 1,
                "train/accuracy": acc,
                "train/f1": f1,
                "val/accuracy": val_acc_epoch,
                "val/f1": val_f1
            })

    # Save only the weights (state dict) of the final epoch
    model_path = os.path.join(base_dir, f"medi_llm_state_dict_{mode}.pth")
    torch.save(model.state_dict(), model_path)
    print(f"💾 Saved model weights and biases to {model_path}")
    if use_wandb:
        wandb.save(model_path)

    plot_path = os.path.join(base_dir, "assets", f"model_training_curve_{mode}.png")
    _save_accuracy_plot(train_acc, val_acc, mode, plot_path)
    print(f"✅ Saved training curve to {plot_path}")
    if use_wandb:
        wandb.log({"training_curve": wandb.Image(plot_path)})

    # Save training metrics to JSON
    results = {
        "train_acc": train_acc,
        "val_acc": val_acc,
        "train_f1": train_f1s,
        "val_f1": val_f1s,
        "final_train_acc": train_acc[-1],
        "final_val_acc": val_acc[-1],
        "final_train_f1": train_f1s[-1],
        "final_val_f1": val_f1s[-1]
    }
    results_dir = os.path.join(base_dir, "results")
    os.makedirs(results_dir, exist_ok=True)
    results_path = os.path.join(results_dir, f"metrics_{mode}.json")
    with open(results_path, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)
    print(f"📊 Saved training metrics to {results_path}")

    # Per-class report on the last epoch's validation predictions. Passing
    # `labels` explicitly keeps all three rows, and avoids a ValueError from
    # the target_names/classes size mismatch, when some class never occurs
    # in the validation labels or predictions.
    class_report = classification_report(
        val_labels,
        val_preds,
        labels=list(range(len(_TRIAGE_CLASSES))),
        target_names=_TRIAGE_CLASSES,
        output_dict=True,
        zero_division=0,
    )
    print("\n🗓️ Classification Report (Per Class on Validation Set):")
    for cls in _TRIAGE_CLASSES:
        metrics = class_report[cls]
        print(f"{cls:>9} -> Precision: {metrics['precision']:.3f}, Recall: {metrics['recall']:.3f}, F1: {metrics['f1-score']:.3f}")

    class_report_path = os.path.join(results_dir, f"classification_report_{mode}.json")
    with open(class_report_path, "w", encoding="utf-8") as f:
        json.dump(class_report, f, indent=2)
    print(f"📊 Saved per-class metrics to {class_report_path}")

    if use_wandb:
        # One log call -> all class-wise metrics land on the same W&B step
        classwise = {}
        for cls in _TRIAGE_CLASSES:
            classwise[f"classwise/{cls}_precision"] = class_report[cls]["precision"]
            classwise[f"classwise/{cls}_recall"] = class_report[cls]["recall"]
            classwise[f"classwise/{cls}_f1"] = class_report[cls]["f1-score"]
        wandb.log(classwise)
        wandb.finish()
def parse_args():
    """Parse the training script's command-line options.

    Returns:
        argparse.Namespace with two fields. ``mode`` is one of "text",
        "image", "multimodal" (default "multimodal"). ``wandb`` is True
        when the ``--wandb`` flag is given.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--mode",
        default="multimodal",
        choices=["text", "image", "multimodal"],
    )
    cli.add_argument(
        "--wandb",
        help="Enable Weights & Biases logging",
        action="store_true",
    )
    parsed = cli.parse_args()
    return parsed
if __name__ == "__main__":
    # Script entry point: runs only when the file is executed directly,
    # never when it is imported as a module.
    cli_args = parse_args()
    train_model(mode=cli_args.mode, use_wandb=cli_args.wandb)