Preetham22 commited on
Commit
54d948e
·
1 Parent(s): 83c4f3c

Add changes for CI complicance

Browse files
.gitignore CHANGED
@@ -24,4 +24,13 @@ logs/
24
  .env
25
 
26
  # logs
27
- *.log
 
 
 
 
 
 
 
 
 
 
24
  .env
25
 
26
  # logs
27
+ *.log
28
+
29
+ # --- EXCEPTIONS (ALLOW) ---
30
+ !data/dummy_images/
31
+ !data/dummy_images/*.jpg
32
+ !data/dummy_images/*.jpeg
33
+ !data/dummy_images/*.png
34
+
35
+ # Allow test CSV for CI
36
+ !data/test_emr_records.csv
src/generate_emr_csv.py CHANGED
@@ -1,23 +1,22 @@
 
1
  import random
2
  import csv
3
  import string
4
  from pathlib import Path
5
 
 
 
 
 
 
 
6
  # Paths
7
  CURRENT_DIR = Path(__file__).resolve().parent
8
- IMAGES_DIR = CURRENT_DIR.parent / "data" / "images"
9
  OUTPUT_FILE = CURRENT_DIR.parent / "data" / "emr_records.csv"
10
 
11
  # Label to triage
12
  triage_map = {"COVID": "high", "NORMAL": "low", "VIRAL PNEUMONIA": "medium"}
13
- SAMPLES_PER_CLASS = 300
14
-
15
- # Folders
16
- categories = {
17
- "COVID": IMAGES_DIR / "COVID",
18
- "NORMAL": IMAGES_DIR / "NORMAL",
19
- "VIRAL PNEUMONIA": IMAGES_DIR / "VIRAL PNEUMONIA",
20
- }
21
 
22
  # Shared ambiguous templates
23
  shared_symptoms = [
@@ -120,7 +119,17 @@ def build_emr(label, i):
120
 
121
 
122
  # Generate records
123
- def generate_dataset():
 
 
 
 
 
 
 
 
 
 
124
  records = []
125
  for label, img_dir in categories.items():
126
  image_files = sorted(
@@ -136,7 +145,7 @@ def generate_dataset():
136
 
137
  for i in range(SAMPLES_PER_CLASS):
138
  image_path = str(
139
- random.choice(image_files).relative_to(IMAGES_DIR.parent.parent)
140
  )
141
  text = build_emr(label, i)
142
  triage = triage_map[label]
@@ -144,13 +153,13 @@ def generate_dataset():
144
 
145
  # Shuffle + write
146
  random.shuffle(records)
147
- with open(OUTPUT_FILE, "w", newline="") as f:
148
  writer = csv.writer(f)
149
  writer.writerow(["patient_id", "image_path", "emr_text", "triage_level"])
150
  writer.writerows(records)
151
 
152
- print(f"✅ Softlabel EMR dataset generated at {OUTPUT_FILE}")
153
 
154
 
155
  if __name__ == "__main__":
156
- generate_dataset()
 
1
+ import os
2
  import random
3
  import csv
4
  import string
5
  from pathlib import Path
6
 
7
+ # Detect CI environment
8
+ IS_CI = os.getenv("CI", "false").lower() == "true"
9
+
10
+ # Set number of samples accordingly
11
+ SAMPLES_PER_CLASS = 3 if IS_CI else 300 # Reduced for CI to speed up tests
12
+
13
  # Paths
14
  CURRENT_DIR = Path(__file__).resolve().parent
15
+ IMAGES_DIR = CURRENT_DIR.parent / "data" / "images" # Absolute path of images folder
16
  OUTPUT_FILE = CURRENT_DIR.parent / "data" / "emr_records.csv"
17
 
18
  # Label to triage
19
  triage_map = {"COVID": "high", "NORMAL": "low", "VIRAL PNEUMONIA": "medium"}
 
 
 
 
 
 
 
 
20
 
21
  # Shared ambiguous templates
22
  shared_symptoms = [
 
119
 
120
 
121
  # Generate records
122
+ def generate_dataset(image_dir_override=None, output_path_override=None):
123
+ root_image_dir = image_dir_override or IMAGES_DIR
124
+ output_file = output_path_override or OUTPUT_FILE
125
+
126
+ # Folders
127
+ categories = {
128
+ "COVID": root_image_dir / "COVID", # Absolute path of Image labels
129
+ "NORMAL": root_image_dir / "NORMAL",
130
+ "VIRAL PNEUMONIA": root_image_dir / "VIRAL PNEUMONIA",
131
+ }
132
+
133
  records = []
134
  for label, img_dir in categories.items():
135
  image_files = sorted(
 
145
 
146
  for i in range(SAMPLES_PER_CLASS):
147
  image_path = str(
148
+ random.choice(image_files).relative_to(root_image_dir.parent.parent) # path of image respective to the project root
149
  )
150
  text = build_emr(label, i)
151
  triage = triage_map[label]
 
153
 
154
  # Shuffle + write
155
  random.shuffle(records)
156
+ with open(output_file, "w", newline="") as f:
157
  writer = csv.writer(f)
158
  writer.writerow(["patient_id", "image_path", "emr_text", "triage_level"])
159
  writer.writerows(records)
160
 
161
+ print(f"✅ EMR dataset generated at {output_file}")
162
 
163
 
164
  if __name__ == "__main__":
165
+ generate_dataset(image_dir_override=IMAGES_DIR, output_path_override=OUTPUT_FILE)
src/triage_dataset.py CHANGED
@@ -1,5 +1,6 @@
1
  import os
2
  import torch
 
3
  from torch.utils.data import Dataset
4
  from PIL import Image
5
  from torchvision import transforms
@@ -7,6 +8,9 @@ from torchvision.transforms import InterpolationMode
7
  import pandas as pd
8
  from transformers import AutoTokenizer
9
 
 
 
 
10
 
11
  class TriageDataset(Dataset):
12
  def __init__(
@@ -16,6 +20,7 @@ class TriageDataset(Dataset):
16
  max_length=128,
17
  transform=None,
18
  mode="multimodal",
 
19
  ):
20
  assert mode in [
21
  "text",
@@ -27,6 +32,10 @@ class TriageDataset(Dataset):
27
  self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
28
  self.max_length = max_length
29
  self.mode = mode.lower()
 
 
 
 
30
 
31
  self.transform = (
32
  transform
@@ -78,11 +87,15 @@ class TriageDataset(Dataset):
78
 
79
  if self.mode in ["image", "multimodal"]:
80
  # Process image
81
- base_dir = os.path.dirname(os.path.dirname(__file__))
82
- image_path = os.path.join(base_dir, row["image_path"])
 
83
 
84
- if not os.path.exists(image_path):
85
- raise FileNotFoundError(f"Image file not found: {image_path}")
 
 
 
86
  image = Image.open(image_path).convert("RGB")
87
  output["image"] = self.transform(image)
88
 
 
1
  import os
2
  import torch
3
+ from pathlib import Path
4
  from torch.utils.data import Dataset
5
  from PIL import Image
6
  from torchvision import transforms
 
8
  import pandas as pd
9
  from transformers import AutoTokenizer
10
 
11
+ # Check if running in CI environment
12
+ IS_CI = os.getenv("CI", "false").lower() == "true"
13
+
14
 
15
  class TriageDataset(Dataset):
16
  def __init__(
 
20
  max_length=128,
21
  transform=None,
22
  mode="multimodal",
23
+ image_base_dir=None,
24
  ):
25
  assert mode in [
26
  "text",
 
32
  self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
33
  self.max_length = max_length
34
  self.mode = mode.lower()
35
+ if self.mode in ["image", "multimodal"]:
36
+ if image_base_dir is None:
37
+ raise ValueError("image directory must be provided for image or multimodal mode.")
38
+ self.image_base_dir = Path(image_base_dir).resolve()
39
 
40
  self.transform = (
41
  transform
 
87
 
88
  if self.mode in ["image", "multimodal"]:
89
  # Process image
90
+ image_path = Path(row["image_path"])
91
+ if not image_path.is_absolute():
92
+ image_path = self.image_base_dir / image_path
93
 
94
+ if not image_path.exists():
95
+ if IS_CI:
96
+ raise FileNotFoundError(f"[CI] Image file not found: {image_path}")
97
+ else:
98
+ raise FileNotFoundError(f"[LOCAL] Image file not found: {image_path}")
99
  image = Image.open(image_path).convert("RGB")
100
  output["image"] = self.transform(image)
101
 
tests/test_generate_emr_csv.py CHANGED
@@ -1,22 +1,30 @@
1
  import os
 
2
  import csv
3
  import sys
4
  import pytest
5
  from pathlib import Path
6
- from collections import Counter
7
 
8
- # Add repo root to the sys.path
9
- BASE_DIR = os.path.dirname(os.path.dirname(__file__))
10
- if BASE_DIR not in sys.path:
11
- sys.path.append(BASE_DIR)
12
 
13
- from src.generate_emr_csv import generate_dataset
14
 
 
 
15
 
16
- CSV_PATH = Path(BASE_DIR) / "data" / "emr_records.csv"
17
- EXPECTED_CLASSES = {"low", "medium", "high"}
 
 
 
 
 
18
  EXPECTED_COLUMNS = ["patient_id", "image_path", "emr_text", "triage_level"]
19
- EXPECTED_SAMPLES_PER_CLASS = 300
 
20
 
21
  AMBIGUOUS_PHRASES = [
22
  "Symptoms could relate to a range of viral infections.",
@@ -46,54 +54,67 @@ NOISE_SENTENCES = [
46
  ]
47
 
48
 
49
- def test_dataset_generation_runs():
50
- generate_dataset()
51
- assert CSV_PATH.exists(), "CSV file should be generated"
52
- with open(CSV_PATH, "r") as f:
53
- lines = f.readlines()
54
- assert len(lines) > 1 # Header + Content
55
 
56
 
57
- @pytest.fixture(scope="module")
58
- def load_emr_csv():
59
  assert CSV_PATH.exists(), f"CSV file not found at: {CSV_PATH}"
 
 
 
 
 
 
 
 
 
 
60
  with open(CSV_PATH, newline="") as f:
61
  reader = csv.DictReader(f)
62
  rows = list(reader)
63
- return rows
64
-
65
 
66
- def test_csv_structure(load_emr_csv):
67
- row = load_emr_csv[0]
68
- assert set(row.keys()) == set(EXPECTED_COLUMNS), "CSV columns mismatch"
69
 
 
 
 
70
 
71
- def test_total_and_per_class_counts(load_emr_csv):
72
- assert len(load_emr_csv) == 900, "Total records should be 900"
73
- counts = Counter(row["triage_level"] for row in load_emr_csv)
74
- for cls in EXPECTED_CLASSES:
75
- assert counts[cls] == EXPECTED_SAMPLES_PER_CLASS, f"{cls} count mismatch"
76
 
77
 
78
  def test_patient_id_format_and_uniqueness(load_emr_csv):
79
- ids = [row["patient_id"] for row in load_emr_csv]
80
- assert all(id and "-" in id for id in ids), "Malformed patient IDs found"
81
- assert len(set(ids)) == 900, "Duplicate patient IDs found"
 
 
 
 
82
 
83
 
84
- def test_emr_text_quality(load_emr_csv):
85
- for row in load_emr_csv:
86
- text = row["emr_text"]
87
- assert (
88
- isinstance(text, str) and len(text.split()) > 10
89
- ), "EMR text too short or malformed"
90
- assert "Temperature" in text and "SPO2" in text, "Vitals info missing"
 
 
91
 
92
 
93
- def test_image_path_format(load_emr_csv):
94
- for row in load_emr_csv:
95
- path = row["image_path"]
96
- assert path.endswith((".jpg", ".jpeg", ".png")), f"Invalid image path: {path}"
 
 
 
 
97
 
98
 
99
  def test_ambiguous_and_noise_injection(load_emr_csv):
@@ -101,28 +122,34 @@ def test_ambiguous_and_noise_injection(load_emr_csv):
101
  symptom_hits = 0
102
  noise_hits = 0
103
 
104
- for row in load_emr_csv:
105
- text = row["emr_text"]
106
- if any(phrase in text for phrase in AMBIGUOUS_PHRASES):
107
- ambiguous_hits += 1
108
- if any(symptom in text for symptom in SHARED_SYMPTOMS):
109
- symptom_hits += 1
110
- if any(noise in text for noise in NOISE_SENTENCES):
111
- noise_hits += 1
 
 
112
 
113
  assert ambiguous_hits > 800, "Ambiguous phrases missing in too many EMRs"
114
  assert symptom_hits > 800, "Shared symptom clues underrepresented"
115
  assert noise_hits > 700, "Too few EMRs contain noise sentences"
116
 
117
 
118
- def test_label_validity(load_emr_csv):
119
- for row in load_emr_csv:
120
- assert (
121
- row["triage_level"] in EXPECTED_CLASSES
122
- ), f"Invalid label: {row['triage_level']}"
 
 
123
 
124
 
125
- def test_no_empty_fields(load_emr_csv):
126
- for row in load_emr_csv:
127
- for col in EXPECTED_COLUMNS:
128
- assert row[col].strip(), f"Empty field found in colum '{col}'"
 
 
 
1
  import os
2
+ import re
3
  import csv
4
  import sys
5
  import pytest
6
  from pathlib import Path
 
7
 
8
+ # Add src/ to path so we can import from it
9
+ BASE_DIR = Path(__file__).resolve().parent.parent
10
+ SRC_DIR = BASE_DIR / "src"
11
+ sys.path.insert(0, str(SRC_DIR))
12
 
13
+ from generate_emr_csv import generate_dataset, OUTPUT_FILE
14
 
15
+ # Determine if running in CI
16
+ IS_CI = os.getenv("CI", "false").lower() == "true"
17
 
18
+ # Paths
19
+ DATA_DIR = BASE_DIR / "data"
20
+ DUMMY_IMAGES_DIR = DATA_DIR / "dummy_images"
21
+ REAL_IMAGES_DIR = DATA_DIR / "images"
22
+ CSV_PATH = DATA_DIR / ("test_emr_records.csv" if IS_CI else OUTPUT_FILE)
23
+
24
+ # Constants
25
  EXPECTED_COLUMNS = ["patient_id", "image_path", "emr_text", "triage_level"]
26
+ EXPECTED_CLASSES = ["low", "medium", "high"]
27
+ EXPECTED_SAMPLES_PER_CLASS = 3 if IS_CI else 300
28
 
29
  AMBIGUOUS_PHRASES = [
30
  "Symptoms could relate to a range of viral infections.",
 
54
  ]
55
 
56
 
57
+ @pytest.fixture(scope="module", autouse=True)
58
+ def generate_csv_for_test():
59
+ image_dir = DUMMY_IMAGES_DIR if IS_CI else REAL_IMAGES_DIR
60
+ generate_dataset(image_dir_override=image_dir, output_path_override=CSV_PATH)
 
 
61
 
62
 
63
+ def test_csv_exists():
 
64
  assert CSV_PATH.exists(), f"CSV file not found at: {CSV_PATH}"
65
+
66
+
67
+ def test_csv_structure():
68
+ with open(CSV_PATH, newline="") as f:
69
+ reader = csv.reader(f)
70
+ header = next(reader)
71
+ assert set(header) == set(EXPECTED_COLUMNS), "CSV columns mismatch"
72
+
73
+
74
+ def test_total_and_per_class_counts():
75
  with open(CSV_PATH, newline="") as f:
76
  reader = csv.DictReader(f)
77
  rows = list(reader)
 
 
78
 
79
+ expected_total = EXPECTED_SAMPLES_PER_CLASS * len(EXPECTED_CLASSES)
80
+ assert len(rows) == expected_total
 
81
 
82
+ counts = {"low": 0, "medium": 0, "high": 0}
83
+ for row in rows:
84
+ counts[row["triage_level"]] += 1
85
 
86
+ assert all(c == EXPECTED_SAMPLES_PER_CLASS for c in counts.values)
 
 
 
 
87
 
88
 
89
  def test_patient_id_format_and_uniqueness(load_emr_csv):
90
+ with open(CSV_PATH, newline="") as f:
91
+ reader = csv.DictReader(f)
92
+ ids = [row["patient_id"] for row in reader]
93
+ assert len(ids) == len(set(ids)), "Duplicate patient IDs found"
94
+ pattern = re.compile(r"^ID-[A-Z]{2}\d{2}$")
95
+ for pid in ids:
96
+ assert pattern.match(pid), f"Invalid patient ID format: {pid}"
97
 
98
 
99
+ def test_emr_text_quality():
100
+ with open(CSV_PATH, newline="") as f:
101
+ reader = csv.DictReader(f)
102
+ for row in reader:
103
+ text = row["emr_text"]
104
+ assert (
105
+ isinstance(text, str) and len(text.split()) > 10
106
+ ), "EMR text too short or malformed"
107
+ assert "Temperature" in text and "SPO2" in text, "Vitals info missing"
108
 
109
 
110
+ def test_image_path_format():
111
+ expected_path = DUMMY_IMAGES_DIR.relative_to(BASE_DIR) if IS_CI else REAL_IMAGES_DIR.relative_to(BASE_DIR)
112
+ with open(CSV_PATH, newline="") as f:
113
+ reader = csv.DictReader(f)
114
+ for row in reader:
115
+ path = row["image_path"]
116
+ assert path.startswith(expected_path), f"Image path should start with '{expected_path}', got: {path}"
117
+ assert path.endswith((".jpg", ".jpeg", ".png")), f"Invalid image path: {path}"
118
 
119
 
120
  def test_ambiguous_and_noise_injection(load_emr_csv):
 
122
  symptom_hits = 0
123
  noise_hits = 0
124
 
125
+ with open(CSV_PATH, newline="") as f:
126
+ reader = csv.DictReader(f)
127
+ for row in reader:
128
+ text = row["emr_text"]
129
+ if any(phrase in text for phrase in AMBIGUOUS_PHRASES):
130
+ ambiguous_hits += 1
131
+ if any(symptom in text for symptom in SHARED_SYMPTOMS):
132
+ symptom_hits += 1
133
+ if any(noise in text for noise in NOISE_SENTENCES):
134
+ noise_hits += 1
135
 
136
  assert ambiguous_hits > 800, "Ambiguous phrases missing in too many EMRs"
137
  assert symptom_hits > 800, "Shared symptom clues underrepresented"
138
  assert noise_hits > 700, "Too few EMRs contain noise sentences"
139
 
140
 
141
+ def test_label_validity():
142
+ with open(CSV_PATH, newline="") as f:
143
+ reader = csv.DictReader(f)
144
+ for row in reader:
145
+ assert (
146
+ row["triage_level"] in EXPECTED_CLASSES
147
+ ), f"Invalid label: {row['triage_level']}"
148
 
149
 
150
+ def test_no_empty_fields():
151
+ with open(CSV_PATH, newline="") as f:
152
+ reader = csv.DictReader(f)
153
+ for row in reader:
154
+ for key, val in row.items():
155
+ assert val.strip() != "", f"Empty field found for {key}"
tests/test_multimodal_model.py CHANGED
@@ -1,12 +1,12 @@
1
  import sys
2
- import os
3
  import torch
4
  import pytest
5
- from transformers import AutoTokenizer
 
6
 
7
 
8
- # Add repo root to the sys.path
9
- BASE_DIR = os.path.dirname(os.path.dirname(__file__))
10
  if BASE_DIR not in sys.path:
11
  sys.path.append(BASE_DIR)
12
 
@@ -15,57 +15,83 @@ from src.multimodal_model import MediLLMModel
15
  BATCH_SIZE = 2
16
  SEQ_LEN = 128
17
  IMAGE_SIZE = (3, 224, 224)
18
- TEXT_MODEL_NAME = "emilyalsentzer/Bio_ClinicalBERT"
19
-
20
- tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME)
21
 
22
 
23
  @pytest.fixture
24
  def dummy_inputs():
25
- text_batch = ["Patient reports mild cough and fever."] * BATCH_SIZE
26
- encoding = tokenizer(
27
- text_batch,
28
- padding="max_length",
29
- truncation=True,
30
- max_length=SEQ_LEN,
31
- return_tensors="pt",
32
- )
33
  return {
34
- "input_ids": encoding["input_ids"],
35
- "attention_mask": encoding["attention_mask"],
36
  "image": torch.randn(BATCH_SIZE, *IMAGE_SIZE),
37
  }
38
 
39
 
40
- def test_text_only(dummy_inputs):
 
 
 
 
 
 
 
 
 
 
41
  model = MediLLMModel(mode="text")
42
  model.eval()
43
  outputs = model(
44
  input_ids=dummy_inputs["input_ids"],
45
- attention_mask=dummy_inputs["attention_mask"],
46
  )
47
- assert outputs.shape == (BATCH_SIZE, 3), "Incorrect output shape for text-only mode"
48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
- def test_image_only(dummy_inputs):
51
  model = MediLLMModel(mode="image")
52
  model.eval()
53
  outputs = model(image=dummy_inputs["image"])
54
- assert outputs.shape == (
55
- BATCH_SIZE,
56
- 3,
57
- ), "Incorrect output shape for image-only mode"
58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
- def test_multimodal(dummy_inputs):
61
  model = MediLLMModel(mode="multimodal")
62
  model.eval()
63
  outputs = model(
64
  input_ids=dummy_inputs["input_ids"],
65
- attention_mask=dummy_inputs["attention_mask"],
66
  image=dummy_inputs["image"],
67
  )
68
- assert outputs.shape == (
69
- BATCH_SIZE,
70
- 3,
71
- ), "Incorrect output shape for multimodal mode"
 
1
  import sys
 
2
  import torch
3
  import pytest
4
+ from pathlib import Path
5
+ from unittest.mock import patch, MagicMock
6
 
7
 
8
+ # Add repo root to sys.path
9
+ BASE_DIR = Path(__file__).resolve().parent.parent
10
  if BASE_DIR not in sys.path:
11
  sys.path.append(BASE_DIR)
12
 
 
15
  BATCH_SIZE = 2
16
  SEQ_LEN = 128
17
  IMAGE_SIZE = (3, 224, 224)
 
 
 
18
 
19
 
20
  @pytest.fixture
21
  def dummy_inputs():
 
 
 
 
 
 
 
 
22
  return {
23
+ "input_ids": torch.randint(0, 30522, (BATCH_SIZE, SEQ_LEN)), # dummy token IDs
24
+ "attention_mask": torch.ones(BATCH_SIZE, SEQ_LEN),
25
  "image": torch.randn(BATCH_SIZE, *IMAGE_SIZE),
26
  }
27
 
28
 
29
+ @patch("src.multimodal_model.AutoModel.from_pretrained")
30
+ @patch("src.multimodal_model.timm.create_model")
31
+ def test_text_only(mock_create_model, mock_auto_model, dummy_inputs):
32
+ # Mock text encoder
33
+ mock_text_encoder = MagicMock()
34
+ mock_text_encoder.config.hidden_size = 768
35
+ mock_text_encoder.return_value = MagicMock(
36
+ last_hidden_state=torch.randn(BATCH_SIZE, SEQ_LEN, 768)
37
+ )
38
+ mock_auto_model.return_value = mock_text_encoder
39
+
40
  model = MediLLMModel(mode="text")
41
  model.eval()
42
  outputs = model(
43
  input_ids=dummy_inputs["input_ids"],
44
+ attention_mask=dummy_inputs["attention_mask"]
45
  )
 
46
 
47
+ assert outputs.shape == (BATCH_SIZE, 3)
48
+ probs = torch.softmax(outputs, dim=1)
49
+ assert torch.allclose(probs.sum(dim=1), torch.ones(BATCH_SIZE), atol=1e-5)
50
+
51
+
52
+ @patch("src.multimodal_model.Automodel.from_pretrained")
53
+ @patch("src.multimodal_model.timm.create_model")
54
+ def test_image_only(mock_create_model, mock_auto_model, dummy_inputs):
55
+ # Mock image encoder
56
+ mock_image_encoder = MagicMock()
57
+ mock_image_encoder.num_features = 2048
58
+ mock_image_encoder.return_value = torch.randn(BATCH_SIZE, 2048)
59
+ mock_create_model.return_value = mock_image_encoder
60
 
 
61
  model = MediLLMModel(mode="image")
62
  model.eval()
63
  outputs = model(image=dummy_inputs["image"])
 
 
 
 
64
 
65
+ assert outputs.shape == (BATCH_SIZE, 3)
66
+ probs = torch.softmax(outputs, dim=1)
67
+ assert torch.allclose(probs.sum(dim=1), torch.ones(BATCH_SIZE), atol=1e-5)
68
+
69
+
70
+ @patch("src.multimodal_model.AutoModel.from_pretrained")
71
+ @patch("src.multimodal_model.timm.create_model")
72
+ def test_multimodal(mock_create_model, mock_auto_model, dummy_inputs):
73
+ # Mock text encoder
74
+ mock_text_encoder = MagicMock()
75
+ mock_text_encoder.config.hidden_size = 768
76
+ mock_text_encoder.return_value = MagicMock(
77
+ last_hidden_state=torch.randn(BATCH_SIZE, SEQ_LEN, 768)
78
+ )
79
+ mock_auto_model.return_value = mock_text_encoder
80
+
81
+ # Mock image encoder
82
+ mock_image_encoder = MagicMock()
83
+ mock_image_encoder.num_features = 2048
84
+ mock_image_encoder.return_value = torch.randn(BATCH_SIZE, 2048)
85
+ mock_create_model.return_value = mock_image_encoder
86
 
 
87
  model = MediLLMModel(mode="multimodal")
88
  model.eval()
89
  outputs = model(
90
  input_ids=dummy_inputs["input_ids"],
91
+ atttention_mask=dummy_inputs["attention_mask"],
92
  image=dummy_inputs["image"],
93
  )
94
+
95
+ assert outputs.shape == (BATCH_SIZE, 3)
96
+ probs = torch.softmax(outputs, dim=1)
97
+ assert torch.allclose(probs.sum(dim=1), torch.ones(BATCH_SIZE), atol=1e-5)
tests/test_triage_dataset.py CHANGED
@@ -3,23 +3,36 @@ import sys
3
  import pytest
4
  import torch
5
  import pandas as pd
 
6
 
7
- base_dir = os.path.dirname(os.path.dirname(__file__))
8
- if base_dir not in sys.path:
9
- sys.path.append(base_dir)
 
10
 
11
- from src.triage_dataset import TriageDataset
12
 
13
- # Path to CSV and example image should match the local structure
14
- CSV_PATH = os.path.join(base_dir, "data", "emr_records.csv")
 
 
 
 
 
 
 
15
 
16
 
17
  @pytest.mark.parametrize("mode", ["text", "image", "multimodal"])
18
  def test_dataset_loading(mode):
19
- dataset = TriageDataset(csv_file=CSV_PATH, mode=mode)
 
 
 
 
20
 
21
  # Check dataset length
22
- assert len(dataset) == 900, "Expected 900 records in the dataset"
23
 
24
  # Check one sample
25
  sample = dataset[0]
 
3
  import pytest
4
  import torch
5
  import pandas as pd
6
+ from pathlib import Path
7
 
8
+ # Setup path to src/
9
+ BASE_DIR = Path(__file__).resolve().parent.parent
10
+ SRC_DIR = BASE_DIR / "src"
11
+ sys.path.insert(0, str(SRC_DIR))
12
 
13
+ from triage_dataset import TriageDataset
14
 
15
+ # Detect CI environment
16
+ IS_CI = os.getenv("CI", "false").lower() == "true"
17
+
18
+ # Paths
19
+ DATA_DIR = BASE_DIR / "data"
20
+ CSV_PATH = DATA_DIR / ("test_emr_records.csv" if IS_CI else "emr_records.csv")
21
+ IMAGE_DIR = DATA_DIR / ("dummy_images" if IS_CI else "images")
22
+ EXPECTED_SAMPLES_PER_CLASS = 3 if IS_CI else 300
23
+ EXPECTED_TOTAL = 3 * 3 if IS_CI else 300 * 3 # 3 classes
24
 
25
 
26
  @pytest.mark.parametrize("mode", ["text", "image", "multimodal"])
27
  def test_dataset_loading(mode):
28
+ kwargs = {"csv_file": CSV_PATH, "mode": mode}
29
+ if mode in ["image", "multimodal"]:
30
+ kwargs["image_base_dir"] = IMAGE_DIR
31
+
32
+ dataset = TriageDataset(**kwargs)
33
 
34
  # Check dataset length
35
+ assert len(dataset) == EXPECTED_TOTAL, f"Expected {EXPECTED_TOTAL} records in the dataset"
36
 
37
  # Check one sample
38
  sample = dataset[0]
tools/generate_dummy_images.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ import shutil
3
+
4
+ # Define image categories and paths
5
+ LABELS = ["COVID", "NORMAL", "VIRAL PNEUMONIA"]
6
+ BASE_DIR = Path(__file__).resolve().parent.parent
7
+ DATA_DIR = BASE_DIR / "data" / "images"
8
+ DST_DIR = BASE_DIR / "data" / "dummy_images"
9
+ NUM_IMAGES_PER_CLASS = 3 # keep small for CI
10
+
11
+
12
+ def create_dummy_images():
13
+ for label in LABELS:
14
+ src_dir = DATA_DIR / label
15
+ dst_dir = DST_DIR / label
16
+ dst_dir.mkdir(parents=True, exist_ok=True)
17
+
18
+ image_files = sorted([f for f in src_dir.glob("*") if f.is_file()])
19
+ for i, img_path in enumerate(image_files[:NUM_IMAGES_PER_CLASS]):
20
+ ext = img_path.suffix
21
+ dummy_filename = f"dummy_{i + 1}{ext}"
22
+ dst_path = dst_dir / dummy_filename
23
+ shutil.copy(img_path, dst_path)
24
+
25
+ print(f"✅ Dummy image copies created in: {DST_DIR}")
26
+
27
+
28
+ if __name__ == "__main__":
29
+ create_dummy_images()