Preetham22 commited on
Commit
9218201
·
1 Parent(s): 56f8ce8

Add tests, formatting, and CI enhancements with pre-commit support

Browse files
.flake8 ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # .flake8
2
+ [flake8]
3
+ max-line-length = 88
4
+ ignore = E501, E402
5
+ exclude =
6
+ .git,
7
+ __pycache__,
8
+ .venv,
9
+ env,
10
+ build,
11
+ dist
.github/workflows/ci.yml CHANGED
@@ -23,12 +23,19 @@ jobs:
23
  run: |
24
  python -m pip install --upgrade pip
25
  pip install -r requirements.txt
26
- pip install pytest flake8
27
 
28
- - name: ✅ Lint code
 
 
 
 
 
 
 
29
  run: |
30
- flake8 src/ --ignore=E501
31
 
32
  - name: 🧪 Run unit tests
33
  run: |
34
- pytest tests/
 
23
  run: |
24
  python -m pip install --upgrade pip
25
  pip install -r requirements.txt
26
+ pip install pytest flake8 black isort
27
 
28
+ - name: ✅ Lint code with flake8
29
+ run: flake8
30
+
31
+ - name: 🔧 Check code format with black
32
+ run: |
33
+ black --check .
34
+
35
+ - name: 📦 Check import order with isort
36
  run: |
37
+ isort . --check-only
38
 
39
  - name: 🧪 Run unit tests
40
  run: |
41
+ pytest --cov=src tests/
.gitignore CHANGED
@@ -3,6 +3,7 @@ data/
3
  checkpoints/
4
  __pycache__/
5
  *.py[cod]
 
6
 
7
  # Weights & Biases
8
  wandb/
 
3
  checkpoints/
4
  __pycache__/
5
  *.py[cod]
6
+ .coverage
7
 
8
  # Weights & Biases
9
  wandb/
.pre-commit-config.yaml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ repos:
2
+ - repo: https://github.com/pycqa/flake8
3
+ rev: 6.1.0
4
+ hooks:
5
+ - id: flake8
6
+ additional_dependencies: []
7
+ args: ["--ignore=E501,E402"]
experiments/train_optuna.py CHANGED
@@ -14,7 +14,7 @@ from torch.utils.data import DataLoader, Subset
14
  from torch.nn import CrossEntropyLoss
15
  from torch.optim import Adam
16
  from sklearn.model_selection import StratifiedShuffleSplit
17
- from sklearn.metrics import accuracy_score, f1_score, classification_report, confusion_matrix
18
 
19
 
20
  # Setup base path
 
14
  from torch.nn import CrossEntropyLoss
15
  from torch.optim import Adam
16
  from sklearn.model_selection import StratifiedShuffleSplit
17
+ from sklearn.metrics import accuracy_score, f1_score, confusion_matrix
18
 
19
 
20
  # Setup base path
pyproject.toml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ [tool.black]
2
+ line-length = 88
3
+
4
+ [tool.isort]
5
+ profile = "black"
6
+ line_length = 88
requirements.txt CHANGED
@@ -36,4 +36,6 @@ python-multipart>0.0.6
36
 
37
  # Linting and testing
38
  pytest>=7.4.0
 
 
39
  flake8>=6.1.0
 
36
 
37
  # Linting and testing
38
  pytest>=7.4.0
39
+ pytest-cov>=4.1
40
+ pre-commit>=3.5.0
41
  flake8>=6.1.0
src/generate_emr_csv.py CHANGED
@@ -29,7 +29,7 @@ shared_symptoms = [
29
  "Vital signs mostly stable; slight variation in temperature.",
30
  ]
31
 
32
- # Overlapping diagnosis clues
33
  shared_diagnosis = [
34
  "Symptoms could relate to a range of viral infections.",
35
  "Presentation not distinctly matching any single infection.",
@@ -120,28 +120,39 @@ def build_emr(label, i):
120
 
121
 
122
  # Generate records
123
- records = []
124
- for label, img_dir in categories.items():
125
- image_files = sorted(
126
- [
127
- f
128
- for f in img_dir.glob("*")
129
- if f.suffix.lower() in [".png", ".jpg", ".jpeg"]
130
- ]
131
- )
132
- for i in range(SAMPLES_PER_CLASS):
133
- image_path = str(
134
- random.choice(image_files).relative_to(IMAGES_DIR.parent.parent)
135
  )
136
- text = build_emr(label, i)
137
- triage = triage_map[label]
138
- records.append([f"{label}-{i+1}", image_path, text, triage])
139
-
140
- # Shuffle + write
141
- random.shuffle(records)
142
- with open(OUTPUT_FILE, "w", newline="") as f:
143
- writer = csv.writer(f)
144
- writer.writerow(["patient_id", "image_path", "emr_text", "triage_level"])
145
- writer.writerows(records)
146
-
147
- print(f" Softlabel EMR dataset generated at {OUTPUT_FILE}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  "Vital signs mostly stable; slight variation in temperature.",
30
  ]
31
 
32
+ # Overlapping diagnosis clues to add ambiguity
33
  shared_diagnosis = [
34
  "Symptoms could relate to a range of viral infections.",
35
  "Presentation not distinctly matching any single infection.",
 
120
 
121
 
122
  # Generate records
123
+ def generate_dataset():
124
+ records = []
125
+ for label, img_dir in categories.items():
126
+ image_files = sorted(
127
+ [
128
+ f
129
+ for f in img_dir.glob("*")
130
+ if f.suffix.lower() in [".png", ".jpg", ".jpeg"]
131
+ ]
 
 
 
132
  )
133
+ for i in range(SAMPLES_PER_CLASS):
134
+ image_path = str(
135
+ random.choice(image_files)
136
+ .relative_to(IMAGES_DIR.parent.parent)
137
+ )
138
+ text = build_emr(label, i)
139
+ triage = triage_map[label]
140
+ records.append([f"{label}-{i+1}", image_path, text, triage])
141
+
142
+ # Shuffle + write
143
+ random.shuffle(records)
144
+ with open(OUTPUT_FILE, "w", newline="") as f:
145
+ writer = csv.writer(f)
146
+ writer.writerow([
147
+ "patient_id",
148
+ "image_path",
149
+ "emr_text",
150
+ "triage_level"
151
+ ])
152
+ writer.writerows(records)
153
+
154
+ print(f"✅ Softlabel EMR dataset generated at {OUTPUT_FILE}")
155
+
156
+
157
+ if __name__ == "__main__":
158
+ generate_dataset()
tests/__init__.py ADDED
File without changes
tests/test_generate_emr_csv.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import csv
3
+ import sys
4
+ import pytest
5
+ from collections import Counter
6
+
7
+ # Add repo root to the sys.path
8
+ BASE_DIR = os.path.dirname(os.path.dirname(__file__))
9
+ if BASE_DIR not in sys.path:
10
+ sys.path.append(BASE_DIR)
11
+
12
+ from src.generate_emr_csv import generate_dataset, OUTPUT_FILE
13
+
14
+
15
+ CSV_PATH = OUTPUT_FILE
16
+ EXPECTED_CLASSES = {"low", "medium", "high"}
17
+ EXPECTED_COLUMNS = ["patient_id", "image_path", "emr_text", "triage_level"]
18
+ EXPECTED_SAMPLES_PER_CLASS = 300
19
+
20
+ AMBIGUOUS_PHRASES = [
21
+ "Symptoms could relate to a range of viral infections.",
22
+ "Presentation not distinctly matching any single infection.",
23
+ "Further tests required to confirm diagnosis.",
24
+ "Findings are borderline; clinical judgment advised.",
25
+ "Observation warranted due to overlapping signs.",
26
+ "Initial assessment inconclusive.",
27
+ ]
28
+
29
+ SHARED_SYMPTOMS = [
30
+ "Mild cough and slight fever reported.",
31
+ "General fatigue and throat irritation present.",
32
+ "Breathing mildly labored during physical exertion.",
33
+ "No major respiratory distress; mild wheezing noted.",
34
+ "Occasional chest tightness reported.",
35
+ "Vital signs mostly stable; slight variation in temperature.",
36
+ ]
37
+
38
+ NOISE_SENTENCES = [
39
+ "Patient is cooperative and alert.",
40
+ "Dietary habits unremarkable.",
41
+ "Hydration status normal.",
42
+ "Follow-up advised if symptoms persist.",
43
+ "No notable family medical history.",
44
+ "No medications currently administered.",
45
+ ]
46
+
47
+
48
+ def test_dataset_generation_runs():
49
+ generate_dataset()
50
+ assert CSV_PATH.exists(), "CSV file should be generated"
51
+ with open(OUTPUT_FILE, "r") as f:
52
+ lines = f.readlines()
53
+ assert len(lines) > 1 # Header + Content
54
+
55
+
56
+ @pytest.fixture(scope="module")
57
+ def load_emr_csv():
58
+ assert CSV_PATH.exists(), f"CSV file not found at: {CSV_PATH}"
59
+ with open(CSV_PATH, newline="") as f:
60
+ reader = csv.DictReader(f)
61
+ rows = list(reader)
62
+ return rows
63
+
64
+
65
+ def test_csv_structure(load_emr_csv):
66
+ row = load_emr_csv[0]
67
+ assert set(row.keys()) == set(EXPECTED_COLUMNS), "CSV columns mismatch"
68
+
69
+
70
+ def test_total_and_per_class_counts(load_emr_csv):
71
+ assert len(load_emr_csv) == 900, "Total records should be 900"
72
+ counts = Counter(row["triage_level"] for row in load_emr_csv)
73
+ for cls in EXPECTED_CLASSES:
74
+ assert counts[cls] == EXPECTED_SAMPLES_PER_CLASS, (
75
+ f"{cls} count mismatch"
76
+ )
77
+
78
+
79
+ def test_patient_id_format_and_uniqueness(load_emr_csv):
80
+ ids = [row["patient_id"] for row in load_emr_csv]
81
+ assert all(id and "-" in id for id in ids), "Malformed patient IDs found"
82
+ assert len(set(ids)) == 900, "Duplicate patient IDs found"
83
+
84
+
85
+ def test_emr_text_quality(load_emr_csv):
86
+ for row in load_emr_csv:
87
+ text = row["emr_text"]
88
+ assert (
89
+ isinstance(text, str) and len(text.split()) > 10
90
+ ), "EMR text too short or malformed"
91
+ assert "Temperature" in text and "SPO2" in text, "Vitals info missing"
92
+
93
+
94
+ def test_image_path_format(load_emr_csv):
95
+ for row in load_emr_csv:
96
+ path = row["image_path"]
97
+ assert path.endswith((".jpg", ".jpeg", ".png")), (
98
+ f"Invalid image path: {path}"
99
+ )
100
+
101
+
102
+ def test_ambiguous_and_noise_injection(load_emr_csv):
103
+ ambiguous_hits = 0
104
+ symptom_hits = 0
105
+ noise_hits = 0
106
+
107
+ for row in load_emr_csv:
108
+ text = row["emr_text"]
109
+ if any(phrase in text for phrase in AMBIGUOUS_PHRASES):
110
+ ambiguous_hits += 1
111
+ if any(symptom in text for symptom in SHARED_SYMPTOMS):
112
+ symptom_hits += 1
113
+ if any(noise in text for noise in NOISE_SENTENCES):
114
+ noise_hits += 1
115
+
116
+ assert ambiguous_hits > 800, "Ambiguous phrases missing in too many EMRs"
117
+ assert symptom_hits > 800, "Shared symptom clues underrepresented"
118
+ assert noise_hits > 700, "Too few EMRs contain noise sentences"
119
+
120
+
121
+ def test_label_validity(load_emr_csv):
122
+ for row in load_emr_csv:
123
+ assert (
124
+ row["triage_level"] in EXPECTED_CLASSES
125
+ ), f"Invalid label: {row['triage_level']}"
126
+
127
+
128
+ def test_no_empty_fields(load_emr_csv):
129
+ for row in load_emr_csv:
130
+ for col in EXPECTED_COLUMNS:
131
+ assert row[col].strip(), f"Empty field found in colum '{col}'"
tests/test_multimodal_model.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+ import torch
4
+ import pytest
5
+ from transformers import AutoTokenizer
6
+
7
+
8
+ # Add repo root to the sys.path
9
+ BASE_DIR = os.path.dirname(os.path.dirname(__file__))
10
+ if BASE_DIR not in sys.path:
11
+ sys.path.append(BASE_DIR)
12
+
13
+ from src.multimodal_model import MediLLMModel
14
+
15
+ BATCH_SIZE = 2
16
+ SEQ_LEN = 128
17
+ IMAGE_SIZE = (3, 224, 224)
18
+ TEXT_MODEL_NAME = "emilyalsentzer/Bio_ClinicalBERT"
19
+
20
+ tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME)
21
+
22
+
23
+ @pytest.fixture
24
+ def dummy_inputs():
25
+ text_batch = ["Patient reports mild cough and fever."] * BATCH_SIZE
26
+ encoding = tokenizer(
27
+ text_batch,
28
+ padding="max_length",
29
+ truncation=True,
30
+ max_length=SEQ_LEN,
31
+ return_tensors="pt",
32
+ )
33
+ return {
34
+ "input_ids": encoding["input_ids"],
35
+ "attention_mask": encoding["attention_mask"],
36
+ "image": torch.randn(BATCH_SIZE, *IMAGE_SIZE),
37
+ }
38
+
39
+
40
+ def test_text_only(dummy_inputs):
41
+ model = MediLLMModel(mode="text")
42
+ model.eval()
43
+ outputs = model(
44
+ input_ids=dummy_inputs["input_ids"],
45
+ attention_mask=dummy_inputs["attention_mask"],
46
+ )
47
+ assert outputs.shape == (BATCH_SIZE, 3), (
48
+ "Incorrect output shape for text-only mode"
49
+ )
50
+
51
+
52
+ def test_image_only(dummy_inputs):
53
+ model = MediLLMModel(mode="image")
54
+ model.eval()
55
+ outputs = model(image=dummy_inputs["image"])
56
+ assert outputs.shape == (
57
+ BATCH_SIZE,
58
+ 3,
59
+ ), "Incorrect output shape for image-only mode"
60
+
61
+
62
+ def test_multimodal(dummy_inputs):
63
+ model = MediLLMModel(mode="multimodal")
64
+ model.eval()
65
+ outputs = model(
66
+ input_ids=dummy_inputs["input_ids"],
67
+ attention_mask=dummy_inputs["attention_mask"],
68
+ image=dummy_inputs["image"],
69
+ )
70
+ assert outputs.shape == (
71
+ BATCH_SIZE,
72
+ 3,
73
+ ), "Incorrect output shape for multimodal mode"
tests/test_triage_dataset.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import pytest
4
+ import torch
5
+
6
+ base_dir = os.path.dirname(os.path.dirname(__file__))
7
+ if base_dir not in sys.path:
8
+ sys.path.append(base_dir)
9
+
10
+ from src.triage_dataset import TriageDataset
11
+
12
+ # Path to CSV and example image should match the local structure
13
+ CSV_PATH = os.path.join(base_dir, "data", "emr_records.csv")
14
+
15
+
16
+ @pytest.mark.parametrize("mode", ["text", "image", "multimodal"])
17
+ def test_dataset_loading(mode):
18
+ dataset = TriageDataset(csv_file=CSV_PATH, mode=mode)
19
+
20
+ # Check dataset length
21
+ assert len(dataset) == 900, "Expected 900 records in the dataset"
22
+
23
+ # Check one sample
24
+ sample = dataset[0]
25
+
26
+ if mode in ["text", "multimodal"]:
27
+ assert "input_ids" in sample, (
28
+ "Missing input_ids in text/multimodal mode"
29
+ )
30
+ assert (
31
+ "attention_mask" in sample
32
+ ), "Missing attention_mask in text/multimodal mode"
33
+ assert sample["input_ids"].shape[0] == 128, "Incorrect token length"
34
+
35
+ if mode in ["image", "multimodal"]:
36
+ assert "image" in sample, "Missing image in image/multimodal mode"
37
+ assert isinstance(sample["image"], torch.Tensor), "Image not a tensor"
38
+ assert sample["image"].shape[1:] == (224, 224), "Incorrect image size"
39
+
40
+ assert "label" in sample, "Missing label"
41
+ assert sample["label"].item() in [0, 1, 2], "Invalid label value"