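"""
Batch inference script for the MediLLM triage model.

Loads a trained MediLLMModel checkpoint, runs it over a CSV of test samples in
text, image, or multimodal mode, writes per-sample predictions to a CSV, and
prints accuracy, weighted F1, and a classification report when ground-truth
labels are available.

Example invocation (the script filename and checkpoint path below are
placeholders; substitute the actual paths):

    python inference.py \
        --model_path checkpoints/multimodal_model.pt \
        --csv_path test_samples.csv \
        --mode multimodal \
        --save_misclassified_only
"""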
import argparse
import pandas as pd
import torch
import yaml
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score, f1_score, classification_report
from pathlib import Path
from tqdm import tqdm
from datetime import datetime
from torchvision import transforms
from src.triage_dataset import TriageDataset
from src.multimodal_model import MediLLMModel
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Deterministic preprocessing for inference: resize + center crop, no random augmentation
inference_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.CenterCrop(224),
    transforms.ToTensor()
])


def load_config(config_path="config/config.yaml", mode="multimodal"):
    with open(config_path) as f:
        config = yaml.safe_load(f)
    return config[mode]


def predict(model, dataloader, device):
    """Run the model over the dataloader and collect predictions plus sample metadata."""
    model.eval()
    all_preds, all_truths = [], []
    all_texts, all_paths, all_ids = [], [], []
    inv_map = {0: "low", 1: "medium", 2: "high"}
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Running inference"):
            input_ids = batch.get("input_ids", None)
            attention_mask = batch.get("attention_mask", None)
            images = batch.get("image", None)
            if input_ids is not None:
                input_ids = input_ids.to(device)
                attention_mask = attention_mask.to(device)
            if images is not None:
                images = images.to(device)
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, image=images)
            preds = torch.argmax(outputs, dim=1).cpu().tolist()
            all_preds.extend(preds)
            batch_size = len(preds)
            # Patient ID (original from dataset)
            all_ids.extend(batch["patient_id"])
            # EMR text
            all_texts.extend(batch.get("emr_text", [""] * batch_size))
            # Image path
            all_paths.extend(batch.get("image_path", [""] * batch_size))
            # True labels (empty string when the CSV has no label column)
            if "label" in batch:
                all_truths.extend([inv_map.get(label.item(), "") for label in batch["label"]])
            else:
                all_truths.extend([""] * batch_size)
    return all_preds, all_truths, all_texts, all_paths, all_ids


def inverse_label_map():
    return {0: "low", 1: "medium", 2: "high"}


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--csv_path", type=str, default="test_samples.csv", help="Path to test CSV file")
    parser.add_argument("--mode", type=str, choices=["text", "image", "multimodal"], default="multimodal", help="Data mode")
    parser.add_argument("--model_path", type=str, required=True, help="Path to the model checkpoint")
    parser.add_argument("--config", type=str, default="config/config.yaml", help="Path to config file")
    parser.add_argument("--image_dir", type=str, default="data/images", help="Path to images folder")
    parser.add_argument("--output_csv", type=str, help="Optional custom path to output file")
    parser.add_argument("--batch_size", type=int, help="Optional override for batch size")
    parser.add_argument("--save_misclassified_only", action="store_true", help="Save only misclassified samples")
    args = parser.parse_args()

    # Checks
    if not Path(args.csv_path).exists():
        raise FileNotFoundError(f"❌ CSV file not found at: {args.csv_path}")
    if not Path(args.model_path).exists():
        raise FileNotFoundError(f"❌ Model checkpoint not found at: {args.model_path}")
    if not Path(args.config).exists():
        raise FileNotFoundError(f"❌ Config file not found at: {args.config}")
    if args.mode in ["image", "multimodal"] and not Path(args.image_dir).exists():
        raise FileNotFoundError(f"❌ Image directory not found at: {args.image_dir}")

    # Always generate a mode-specific output file if not provided
    if not args.output_csv:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        args.output_csv = f"predictions_{args.mode}_{timestamp}.csv"

    config = load_config(config_path=args.config, mode=args.mode)
    batch_size = args.batch_size or config["batch_size"]

    dataset = TriageDataset(
        csv_file=args.csv_path,
        mode=args.mode,
        image_base_dir=args.image_dir,
        transform=inference_transform  # avoid random augmentations during inference
    )
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    model = MediLLMModel(
        mode=args.mode,
        dropout=config["dropout"],
        hidden_dim=config["hidden_dim"]
    )
    model.load_state_dict(torch.load(args.model_path, map_location=DEVICE))
    model.to(DEVICE)

    preds, truths, texts, paths, ids = predict(model, dataloader, DEVICE)
    label_inv_map = inverse_label_map()
    pred_labels = [label_inv_map[p] for p in preds]

    df = pd.DataFrame({
        "patient_id": ids,
        "predicted": pred_labels,
        "truth_label": truths,
        "emr_text": texts,
        "image_path": paths,
    })

    # Filter misclassified rows if needed
    if args.save_misclassified_only:
        df = df[df["predicted"] != df["truth_label"]]
        print(df[["patient_id", "predicted", "truth_label"]])

    # Ensure output directory exists
    output_path = Path(args.output_csv)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # Save predictions
    df.to_csv(output_path, index=False)
    print(f"✅ Saved predictions to {output_path}")
print(f"\n🔎 Processed {len(preds)} samples ({'missclassified only' if args.save_misclassified_only else 'all'}).")
# print classification report + metrics if labels exist
if all(label in ["low", "medium", "high"] for label in truths):
print("\n📊 Classification Report:")
print(classification_report(truths, pred_labels))
acc = accuracy_score(truths, pred_labels)
f1 = f1_score(truths, pred_labels, average="weighted")
print(f"\n🎯 Accuracy: {acc:.4f}")
print(f"\n🎯 Weighted F1 Score: {f1:.4f}")


if __name__ == "__main__":
    main()