medi-llm / tests /test_triage_dataset.py
Preetham22's picture
fix image paths
be196af
import os
import sys
import pytest
import torch
import pandas as pd
from pathlib import Path
# Setup path to src/
BASE_DIR = Path(__file__).resolve().parent.parent
SRC_DIR = BASE_DIR / "src"
sys.path.insert(0, str(SRC_DIR))
from triage_dataset import TriageDataset
# Detect CI environment
IS_CI = os.getenv("CI", "false").lower() == "true"
# Paths
DATA_DIR = BASE_DIR / "data"
CSV_PATH = DATA_DIR / ("test_emr_records.csv" if IS_CI else "emr_records.csv")
IMAGE_DIR = (DATA_DIR / "dummy_images").resolve() if IS_CI else (DATA_DIR / "images").resolve()
EXPECTED_SAMPLES_PER_CLASS = 3 if IS_CI else 300
EXPECTED_TOTAL = 3 * 3 if IS_CI else 300 * 3 # 3 classes
@pytest.mark.parametrize("mode", ["text", "image", "multimodal"])
def test_dataset_loading(mode):
kwargs = {"csv_file": CSV_PATH, "mode": mode}
if mode in ["image", "multimodal"]:
kwargs["image_base_dir"] = IMAGE_DIR
dataset = TriageDataset(**kwargs)
# Check dataset length
assert len(dataset) == EXPECTED_TOTAL, f"Expected {EXPECTED_TOTAL} records in the dataset"
# Check one sample
sample = dataset[0]
if mode in ["text", "multimodal"]:
assert "input_ids" in sample, "Missing input_ids in text/multimodal mode"
assert (
"attention_mask" in sample
), "Missing attention_mask in text/multimodal mode"
assert sample["input_ids"].shape[0] == 128, "Incorrect token length"
if mode in ["image", "multimodal"]:
assert "image" in sample, "Missing image in image/multimodal mode"
assert isinstance(sample["image"], torch.Tensor), "Image not a tensor"
assert sample["image"].shape[1:] == (224, 224), "Incorrect image size"
assert "label" in sample, "Missing label"
assert sample["label"].item() in [0, 1, 2], "Invalid label value"
def test_missing_image_raises_error(tmp_path):
# Create a temporary CSV file with an invalid image path
fake_csv = tmp_path / "fake_emr_records.csv"
fake_df = pd.DataFrame([{
"patient_id": "ID-XX99",
"image_path": "data/images/NORMAL/non_existent_image.jpg",
"emr_text": "Patient ID-XX99 reports symptoms. Temperature recorded at 98.6°F and SPO2 at 97%.",
"triage_level": "low"
}])
fake_df.to_csv(fake_csv, index=False)
# Instantiate the dataset in image mode
dataset = TriageDataset(csv_file=fake_csv, mode="image", image_base_dir=IMAGE_DIR)
# Expect a FileNotFoundError when trying to access the missing image
with pytest.raises(FileNotFoundError, match="Image file not found"):
_ = dataset[0]