Spaces:
Sleeping
Sleeping
| import os | |
| import sys | |
| import pytest | |
| import torch | |
| import pandas as pd | |
| from pathlib import Path | |
| # Setup path to src/ | |
| BASE_DIR = Path(__file__).resolve().parent.parent | |
| SRC_DIR = BASE_DIR / "src" | |
| sys.path.insert(0, str(SRC_DIR)) | |
| from triage_dataset import TriageDataset | |
| # Detect CI environment | |
| IS_CI = os.getenv("CI", "false").lower() == "true" | |
| # Paths | |
| DATA_DIR = BASE_DIR / "data" | |
| CSV_PATH = DATA_DIR / ("test_emr_records.csv" if IS_CI else "emr_records.csv") | |
| IMAGE_DIR = (DATA_DIR / "dummy_images").resolve() if IS_CI else (DATA_DIR / "images").resolve() | |
| EXPECTED_SAMPLES_PER_CLASS = 3 if IS_CI else 300 | |
| EXPECTED_TOTAL = 3 * 3 if IS_CI else 300 * 3 # 3 classes | |
| def test_dataset_loading(mode): | |
| kwargs = {"csv_file": CSV_PATH, "mode": mode} | |
| if mode in ["image", "multimodal"]: | |
| kwargs["image_base_dir"] = IMAGE_DIR | |
| dataset = TriageDataset(**kwargs) | |
| # Check dataset length | |
| assert len(dataset) == EXPECTED_TOTAL, f"Expected {EXPECTED_TOTAL} records in the dataset" | |
| # Check one sample | |
| sample = dataset[0] | |
| if mode in ["text", "multimodal"]: | |
| assert "input_ids" in sample, "Missing input_ids in text/multimodal mode" | |
| assert ( | |
| "attention_mask" in sample | |
| ), "Missing attention_mask in text/multimodal mode" | |
| assert sample["input_ids"].shape[0] == 128, "Incorrect token length" | |
| if mode in ["image", "multimodal"]: | |
| assert "image" in sample, "Missing image in image/multimodal mode" | |
| assert isinstance(sample["image"], torch.Tensor), "Image not a tensor" | |
| assert sample["image"].shape[1:] == (224, 224), "Incorrect image size" | |
| assert "label" in sample, "Missing label" | |
| assert sample["label"].item() in [0, 1, 2], "Invalid label value" | |
| def test_missing_image_raises_error(tmp_path): | |
| # Create a temporary CSV file with an invalid image path | |
| fake_csv = tmp_path / "fake_emr_records.csv" | |
| fake_df = pd.DataFrame([{ | |
| "patient_id": "ID-XX99", | |
| "image_path": "data/images/NORMAL/non_existent_image.jpg", | |
| "emr_text": "Patient ID-XX99 reports symptoms. Temperature recorded at 98.6°F and SPO2 at 97%.", | |
| "triage_level": "low" | |
| }]) | |
| fake_df.to_csv(fake_csv, index=False) | |
| # Instantiate the dataset in image mode | |
| dataset = TriageDataset(csv_file=fake_csv, mode="image", image_base_dir=IMAGE_DIR) | |
| # Expect a FileNotFoundError when trying to access the missing image | |
| with pytest.raises(FileNotFoundError, match="Image file not found"): | |
| _ = dataset[0] | |