"""medi-llm / inference.py: batch inference for the Medi-LLM triage classifier."""
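# Example invocation (paths below are illustrative; point them at your own
# checkpoint, test CSV, and image folder):
#   python inference.py --model_path models/medi_llm_multimodal.pt \
#       --csv_path test_samples.csv --mode multimodal --image_dir data/images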
import argparse
from datetime import datetime
from pathlib import Path

import pandas as pd
import torch
import yaml
from sklearn.metrics import accuracy_score, f1_score, classification_report
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm

from src.triage_dataset import TriageDataset
from src.multimodal_model import MediLLMModel
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
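
# Deterministic preprocessing for inference: fixed resize then 224x224 center crop,
# with no random augmentation.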
inference_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.CenterCrop(224),
    transforms.ToTensor()
])
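
# Load the YAML config and return the section for the selected mode
# (text, image, or multimodal).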
def load_config(config_path="config/config.yaml", mode="multimodal"):
    with open(config_path) as f:
        config = yaml.safe_load(f)
    return config[mode]
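
# Batched inference: collect predictions plus patient IDs, EMR text, image paths,
# and ground-truth labels (when present) for the output CSV.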
def predict(model, dataloader, device):
    model.eval()
    all_preds, all_truths = [], []
    all_texts, all_paths, all_ids = [], [], []
    inv_map = {0: "low", 1: "medium", 2: "high"}
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Running inference"):
            input_ids = batch.get("input_ids", None)
            attention_mask = batch.get("attention_mask", None)
            images = batch.get("image", None)
            if input_ids is not None:
                input_ids = input_ids.to(device)
                attention_mask = attention_mask.to(device)
            if images is not None:
                images = images.to(device)
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, image=images)
            preds = torch.argmax(outputs, dim=1).cpu().tolist()
            all_preds.extend(preds)
            batch_size = len(preds)
            # Patient ID (original from dataset)
            all_ids.extend(batch["patient_id"])
            # EMR text
            all_texts.extend(batch.get("emr_text", [""] * batch_size))
            # Image path
            all_paths.extend(batch.get("image_path", [""] * batch_size))
            # True labels (may be absent for unlabeled data)
            if "label" in batch:
                all_truths.extend([inv_map.get(label.item(), "") for label in batch["label"]])
            else:
                all_truths.extend([""] * batch_size)
    return all_preds, all_truths, all_texts, all_paths, all_ids
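
# Map integer class indices back to triage level names.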
def inverse_label_map():
    return {0: "low", 1: "medium", 2: "high"}
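
# CLI entry point: parse arguments, validate paths, run inference, and write results.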
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--csv_path", type=str, default="test_samples.csv", help="Path to the test CSV file")
    parser.add_argument("--mode", type=str, choices=["text", "image", "multimodal"], default="multimodal", help="Input modality")
    parser.add_argument("--model_path", type=str, required=True, help="Path to the trained model checkpoint")
    parser.add_argument("--config", type=str, default="config/config.yaml", help="Path to config file")
    parser.add_argument("--image_dir", type=str, default="data/images", help="Path to the images folder")
    parser.add_argument("--output_csv", type=str, help="Optional custom path to the output file")
    parser.add_argument("--batch_size", type=int, help="Optional override for batch size")
    parser.add_argument("--save_misclassified_only", action="store_true", help="Save only misclassified samples")
    args = parser.parse_args()
    # Validate input paths before doing any work
    if not Path(args.csv_path).exists():
        raise FileNotFoundError(f"❌ CSV file not found at: {args.csv_path}")
    if not Path(args.model_path).exists():
        raise FileNotFoundError(f"❌ Model checkpoint not found at: {args.model_path}")
    if not Path(args.config).exists():
        raise FileNotFoundError(f"❌ Config file not found at: {args.config}")
    if args.mode in ["image", "multimodal"] and not Path(args.image_dir).exists():
        raise FileNotFoundError(f"❌ Image directory not found at: {args.image_dir}")
    # Always generate a mode-specific output file name if none is provided
    if not args.output_csv:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        args.output_csv = f"predictions_{args.mode}_{timestamp}.csv"

    config = load_config(config_path=args.config, mode=args.mode)
    batch_size = args.batch_size or config["batch_size"]
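
    # Build the evaluation dataset and dataloader (no shuffling, so output order matches the CSV)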
    dataset = TriageDataset(
        csv_file=args.csv_path,
        mode=args.mode,
        image_base_dir=args.image_dir,
        transform=inference_transform  # avoid random augmentations during inference
    )
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
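
    # Instantiate the model and load the trained checkpoint weights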
    model = MediLLMModel(
        mode=args.mode,
        dropout=config["dropout"],
        hidden_dim=config["hidden_dim"]
    )
    model.load_state_dict(torch.load(args.model_path, map_location=DEVICE))
    model.to(DEVICE)
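
    # Run inference and assemble the results table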
    preds, truths, texts, paths, ids = predict(model, dataloader, DEVICE)
    label_inv_map = inverse_label_map()
    pred_labels = [label_inv_map[p] for p in preds]

    df = pd.DataFrame({
        "patient_id": ids,
        "predicted": pred_labels,
        "truth_label": truths,
        "emr_text": texts,
        "image_path": paths,
    })
    # Filter to misclassified rows if requested
    if args.save_misclassified_only:
        df = df[df["predicted"] != df["truth_label"]]

    print(df[["patient_id", "predicted", "truth_label"]])

    # Ensure the output directory exists
    output_path = Path(args.output_csv)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # Save predictions
    df.to_csv(output_path, index=False)
    print(f"✅ Saved predictions to {output_path}")
    print(f"\n🔎 Processed {len(preds)} samples; saved {len(df)} "
          f"({'misclassified only' if args.save_misclassified_only else 'all'}).")
    # Print classification report + metrics only if ground-truth labels exist
    if truths and all(label in ["low", "medium", "high"] for label in truths):
        print("\n📊 Classification Report:")
        print(classification_report(truths, pred_labels))
        acc = accuracy_score(truths, pred_labels)
        f1 = f1_score(truths, pred_labels, average="weighted")
        print(f"\n🎯 Accuracy: {acc:.4f}")
        print(f"\n🎯 Weighted F1 Score: {f1:.4f}")
if __name__ == "__main__":
    main()