# NOTE: page-scrape artifact ("Spaces: / Sleeping / Sleeping") commented out
# so the module remains valid Python.
| import argparse | |
| import pandas as pd | |
| import torch | |
| import yaml | |
| from torch.utils.data import DataLoader | |
| from sklearn.metrics import accuracy_score, f1_score, classification_report | |
| from pathlib import Path | |
| from tqdm import tqdm | |
| from datetime import datetime | |
| from torchvision import transforms | |
| from src.triage_dataset import TriageDataset | |
| from src.multimodal_model import MediLLMModel | |
# Run on GPU when available, otherwise fall back to CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Deterministic eval-time preprocessing: resize to 256x256, then center-crop
# to 224x224 (no random augmentation at inference time).
# NOTE(review): no Normalize step is applied here — confirm the model was
# trained without per-channel normalization, otherwise input stats mismatch.
inference_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.CenterCrop(224),
    transforms.ToTensor()
])
def load_config(config_path="config/config.yaml", mode="multimodal"):
    """Load the YAML config file and return the section for the given mode.

    Args:
        config_path: Path to the YAML configuration file.
        mode: Top-level section to return ("text", "image", or "multimodal").

    Returns:
        dict: The configuration sub-dictionary for ``mode``.

    Raises:
        FileNotFoundError: If ``config_path`` does not exist.
        KeyError: If ``mode`` is not a top-level key in the config file.
    """
    # Explicit encoding so the config parses identically across platforms.
    with open(config_path, encoding="utf-8") as f:
        config = yaml.safe_load(f)
    return config[mode]
def predict(model, dataloader, device):
    """Run inference over ``dataloader`` and collect predictions plus metadata.

    Returns five parallel lists: predicted class indices, truth label names
    (empty string when no label is present), EMR texts, image paths, and
    patient ids, in dataset order.
    """
    index_to_label = {0: "low", 1: "medium", 2: "high"}
    predictions = []
    truth_names = []
    emr_texts = []
    image_paths = []
    patient_ids = []

    model.eval()
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Running inference"):
            token_ids = batch.get("input_ids", None)
            mask = batch.get("attention_mask", None)
            pixels = batch.get("image", None)

            # Move whichever modalities are present onto the target device.
            if token_ids is not None:
                token_ids = token_ids.to(device)
                mask = mask.to(device)
            if pixels is not None:
                pixels = pixels.to(device)

            logits = model(input_ids=token_ids, attention_mask=mask, image=pixels)
            batch_preds = torch.argmax(logits, dim=1).cpu().tolist()
            predictions.extend(batch_preds)

            n = len(batch_preds)
            # Patient ID comes straight from the dataset (no default).
            patient_ids.extend(batch["patient_id"])
            emr_texts.extend(batch.get("emr_text", [""] * n))
            image_paths.extend(batch.get("image_path", [""] * n))

            # Ground-truth labels are optional; pad with empty strings when absent.
            if "label" in batch:
                truth_names.extend(index_to_label.get(lbl.item(), "") for lbl in batch["label"])
            else:
                truth_names.extend([""] * n)

    return predictions, truth_names, emr_texts, image_paths, patient_ids
def inverse_label_map():
    """Return the mapping from integer class index to triage severity name."""
    severities = ("low", "medium", "high")
    return dict(enumerate(severities))
def main():
    """CLI entry point: run triage inference on a CSV of samples.

    Parses arguments, validates input paths, builds the dataset and model for
    the requested mode, runs inference, writes a predictions CSV, and — when
    ground-truth labels are present for every sample — prints accuracy,
    weighted F1, and a classification report.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--csv_path", type=str, default="test_samples.csv", help="Path to test file")
    parser.add_argument("--mode", type=str, choices=["text", "image", "multimodal"], default="multimodal", help="mode of data")
    parser.add_argument("--model_path", type=str, required=True, help="path to the model")
    parser.add_argument("--config", type=str, default="config/config.yaml", help="Path to config file")
    parser.add_argument("--image_dir", type=str, default="data/images", help="path to images folder")
    parser.add_argument("--output_csv", type=str, help="Optional custom path to output file")
    parser.add_argument("--batch_size", type=int, help="Optional override for batch size")
    parser.add_argument("--save_misclassified_only", action="store_true", help="Save only misclassified samples")
    args = parser.parse_args()

    # Fail fast on missing inputs before loading any heavy components.
    if not Path(args.csv_path).exists():
        raise FileNotFoundError(f"❌ CSV file not found at {args.csv_path}")
    if not Path(args.model_path).exists():
        raise FileNotFoundError(f"❌ Model checkpoint not found at: {args.model_path}")
    if not Path(args.config).exists():
        raise FileNotFoundError(f"❌ Config file not found at: {args.config}")
    if args.mode in ["image", "multimodal"] and not Path(args.image_dir).exists():
        raise FileNotFoundError(f"❌ Image directory not found at: {args.image_dir}")

    # Always generate a mode-specific, timestamped output file if not provided.
    if not args.output_csv:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        args.output_csv = f"predictions_{args.mode}_{timestamp}.csv"

    config = load_config(config_path=args.config, mode=args.mode)
    batch_size = args.batch_size or config["batch_size"]

    dataset = TriageDataset(
        csv_file=args.csv_path,
        mode=args.mode,
        image_base_dir=args.image_dir,
        transform=inference_transform  # deterministic transform; no random augmentation at inference
    )
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    model = MediLLMModel(
        mode=args.mode,
        dropout=config["dropout"],
        hidden_dim=config["hidden_dim"]
    )
    # NOTE(review): torch.load without weights_only — fine for trusted local
    # checkpoints; confirm the checkpoint source before loading untrusted files.
    model.load_state_dict(torch.load(args.model_path, map_location=DEVICE))
    model.to(DEVICE)

    preds, truths, texts, paths, ids = predict(model, dataloader, DEVICE)
    label_inv_map = inverse_label_map()
    pred_labels = [label_inv_map[p] for p in preds]

    df = pd.DataFrame({
        "patient_id": ids,
        "predicted": pred_labels,
        "truth_label": truths,
        "emr_text": texts,
        "image_path": paths,
    })

    # Keep only the rows the model got wrong, if requested.
    if args.save_misclassified_only:
        df = df[df["predicted"] != df["truth_label"]]

    print(df[["patient_id", "predicted", "truth_label"]])

    # Ensure the output directory exists before writing.
    output_path = Path(args.output_csv)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    df.to_csv(output_path, index=False)
    print(f"✅ Saved predictions to {output_path}")
    # Report the number of rows actually written (len(df)), not the raw
    # prediction count — they differ under --save_misclassified_only.
    print(f"\n🔎 Processed {len(df)} samples ({'misclassified only' if args.save_misclassified_only else 'all'}).")

    # Print metrics only when every sample carries a valid ground-truth label.
    # The leading `truths` check prevents a vacuously-true all() on an empty
    # list, which would crash classification_report.
    if truths and all(label in ["low", "medium", "high"] for label in truths):
        print("\n📊 Classification Report:")
        print(classification_report(truths, pred_labels))
        acc = accuracy_score(truths, pred_labels)
        f1 = f1_score(truths, pred_labels, average="weighted")
        print(f"\n🎯 Accuracy: {acc:.4f}")
        print(f"\n🎯 Weighted F1 Score: {f1:.4f}")


if __name__ == "__main__":
    main()