File size: 2,591 Bytes
9218201
 
 
 
83c4f3c
54d948e
9218201
54d948e
 
 
 
9218201
54d948e
9218201
54d948e
 
 
 
 
 
be196af
54d948e
 
9218201
 
 
 
54d948e
 
 
 
 
9218201
 
54d948e
9218201
 
 
 
 
562137e
9218201
 
 
 
 
 
 
 
 
 
 
 
83c4f3c
 
 
 
 
 
 
 
 
 
 
 
 
 
5b886c6
83c4f3c
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import os
import sys
import pytest
import torch
import pandas as pd
from pathlib import Path

# Set up the import path so modules under src/ can be imported below.
# The project root is taken to be the directory two levels above this file.
_THIS_DIR = Path(__file__).resolve().parent
BASE_DIR = _THIS_DIR.parent
SRC_DIR = BASE_DIR.joinpath("src")
# Insert at position 0 so src/ is searched before any other sys.path entry.
sys.path.insert(0, os.fspath(SRC_DIR))

from triage_dataset import TriageDataset

# CI detection: true only when the CI environment variable is the string
# "true" (compared case-insensitively); unset or any other value means local.
IS_CI = os.environ.get("CI", "false").lower() == "true"

# Data locations and expected dataset size, selected by environment.
DATA_DIR = BASE_DIR / "data"
if IS_CI:
    # CI run: test CSV, dummy image directory, 3 samples per class.
    CSV_PATH = DATA_DIR / "test_emr_records.csv"
    IMAGE_DIR = (DATA_DIR / "dummy_images").resolve()
    EXPECTED_SAMPLES_PER_CLASS = 3
else:
    # Non-CI run: full CSV and image directory, 300 samples per class.
    CSV_PATH = DATA_DIR / "emr_records.csv"
    IMAGE_DIR = (DATA_DIR / "images").resolve()
    EXPECTED_SAMPLES_PER_CLASS = 300
EXPECTED_TOTAL = EXPECTED_SAMPLES_PER_CLASS * 3  # 3 classes


@pytest.mark.parametrize("mode", ["text", "image", "multimodal"])
def test_dataset_loading(mode):
    """Check that TriageDataset loads and yields a well-formed first sample.

    Runs once per ``mode`` ("text", "image", "multimodal"). The dataset is
    built from ``CSV_PATH``; ``image_base_dir`` is passed only for modes that
    use images. The test then verifies:

    * the dataset length equals ``EXPECTED_TOTAL``;
    * text/multimodal samples carry ``input_ids`` and ``attention_mask``,
      with ``input_ids`` having a first dimension of 128;
    * image/multimodal samples carry an ``image`` tensor whose dimensions
      after the first are (224, 224);
    * every sample carries a ``label`` whose ``.item()`` is 0, 1 or 2.
    """
    uses_text = mode in ("text", "multimodal")
    uses_image = mode in ("image", "multimodal")

    # Only image-bearing modes receive an image base directory.
    dataset_args = {"csv_file": CSV_PATH, "mode": mode}
    if uses_image:
        dataset_args["image_base_dir"] = IMAGE_DIR
    dataset = TriageDataset(**dataset_args)

    # Size check
    assert len(dataset) == EXPECTED_TOTAL, f"Expected {EXPECTED_TOTAL} records in the dataset"

    # Inspect the first record only
    first = dataset[0]

    if uses_text:
        assert "input_ids" in first, "Missing input_ids in text/multimodal mode"
        assert "attention_mask" in first, "Missing attention_mask in text/multimodal mode"
        token_ids = first["input_ids"]
        assert token_ids.shape[0] == 128, "Incorrect token length"

    if uses_image:
        assert "image" in first, "Missing image in image/multimodal mode"
        image = first["image"]
        assert isinstance(image, torch.Tensor), "Image not a tensor"
        assert image.shape[1:] == (224, 224), "Incorrect image size"

    assert "label" in first, "Missing label"
    label_value = first["label"].item()
    assert label_value in (0, 1, 2), "Invalid label value"


def test_missing_image_raises_error(tmp_path):
    """Check that indexing a record with a missing image raises an error.

    Writes a one-row CSV into pytest's ``tmp_path`` whose ``image_path``
    names a file that is not expected to exist, builds an image-mode
    ``TriageDataset`` from it, and expects ``dataset[0]`` to raise
    ``FileNotFoundError`` with a message matching "Image file not found".
    """
    # Single EMR record pointing at a non-existent image file.
    record = {
        "patient_id": "ID-XX99",
        "image_path": "data/images/NORMAL/non_existent_image.jpg",
        "emr_text": "Patient ID-XX99 reports symptoms. Temperature recorded at 98.6°F and SPO2 at 97%.",
        "triage_level": "low",
    }
    csv_path = tmp_path / "fake_emr_records.csv"
    pd.DataFrame([record]).to_csv(csv_path, index=False)

    # Build the dataset in image mode.
    dataset = TriageDataset(csv_file=csv_path, mode="image", image_base_dir=IMAGE_DIR)

    # The error is expected on item access, so only the lookup sits inside
    # the raises-context.
    with pytest.raises(FileNotFoundError, match="Image file not found"):
        _ = dataset[0]