File size: 5,262 Bytes
9d9cc25
 
 
 
 
 
 
 
 
 
 
f294e5b
9d9cc25
 
c27b2d2
 
 
9d9cc25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c27b2d2
9d9cc25
 
c27b2d2
 
9d9cc25
 
 
 
c27b2d2
9d9cc25
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import random
import csv
import string
from pathlib import Path

# Sampling and filesystem configuration.
# Number of held-out images drawn for every label.
SAMPLES_PER_CLASS = 10
# Directory holding this script; the project root is one level above it.
CURRENT_DIR = Path(__file__).resolve().parent
PROJECT_ROOT = CURRENT_DIR.parent
IMAGE_DIR = PROJECT_ROOT.joinpath("data", "images")
TRAIN_CSV_PATH = PROJECT_ROOT.joinpath("data", "emr_records.csv")
OUTPUT_CSV = PROJECT_ROOT.joinpath("test_samples.csv")
# Class labels; each one is also a sub-directory name under IMAGE_DIR.
LABELS = ["COVID", "NORMAL", "VIRAL PNEUMONIA"]

# Triage priority written to the output CSV for each class label.
triage_map = {
    "COVID": "high",
    "NORMAL": "low",
    "VIRAL PNEUMONIA": "medium",
}

# Symptom descriptions; build_alt_emr picks one per note regardless of label.
alt_symptoms = [
    "The patient has noted intermittent chest pressure and occasional shortness of breath.",
    "A gradual onset of dry cough with mild respiratory discomfort has been documented.",
    "Reported complaints include mild fatigue and sporadic episodes of wheezing.",
    "Mild respiratory symptoms have progressed over several days.",
    "Episodes of throat irritation and general malaise observed.",
]

# Diagnostic remarks that do not name a specific condition; build_alt_emr
# picks one per note regardless of label.
alt_diagnosis = [
    "Clinical features are suggestive of a nonspecific viral etiology.",
    "Diagnosis remains unclear pending further laboratory confirmation.",
    "Preliminary indicators fall into a diagnostic grey area.",
    "No definitive pattern observed; further evaluation is warranted.",
    "Presentation overlaps multiple pulmonary conditions.",
]

# Neutral filler observations unrelated to the label; build_alt_emr adds
# between zero and two of these to each note.
alt_noise = [
    "Patient remains oriented with stable hemodynamics.",
    "No remarkable family history or chronic illness reported.",
    "Nutritional intake and sleep patterns appear adequate.",
    "No prior admissions or surgical history disclosed.",
    "Standard precautions have been advised post-evaluation.",
]


def random_token():
    """Return a synthetic patient token such as ``TEST-AB12``.

    The token is the fixed prefix ``TEST``, a hyphen, two random uppercase
    ASCII letters, then two random decimal digits.
    """
    alpha = random.choices(string.ascii_uppercase, k=2)
    numeric = random.choices(string.digits, k=2)
    return "TEST-" + "".join(alpha) + "".join(numeric)


def get_oxygen(label):
    """Draw a simulated oxygen-saturation percentage for *label*.

    Inclusive ranges: NORMAL 94-100, VIRAL PNEUMONIA 90-96, and any other
    label (including COVID) 87-94.
    """
    label_ranges = (
        ("NORMAL", (94, 100)),
        ("VIRAL PNEUMONIA", (90, 96)),
    )
    for name, (low, high) in label_ranges:
        if label == name:
            return random.randint(low, high)
    # Fallback covers COVID and any unrecognised label.
    return random.randint(87, 94)


def get_temp(label):
    """Draw a simulated body temperature in °F, rounded to one decimal.

    NORMAL samples fall in 97.5-99.0; every other label uses 98.8-102.5.
    """
    low, high = (97.5, 99.0) if label == "NORMAL" else (98.8, 102.5)
    reading = random.uniform(low, high)
    return round(reading, 1)


def get_age():
    """Return a random adult age in whole years, 18 to 85 inclusive."""
    youngest, oldest = 18, 85
    return random.randint(youngest, oldest)


def get_days():
    """Return a random symptom duration in days, 1 to 10 inclusive."""
    shortest, longest = 1, 10
    return random.randint(shortest, longest)


def build_alt_emr(label):
    """Compose one synthetic free-text EMR note for a patient with *label*.

    The note always starts with an introduction sentence (generated patient
    token, age and symptom duration). The remaining sentences are shuffled
    behind the introduction:

    * one sentence from ``alt_symptoms``;
    * a vitals sentence whose temperature and SpO2 ranges depend on *label*;
    * one sentence from ``alt_diagnosis``;
    * with 30% probability, a hint specific to "COVID", "VIRAL PNEUMONIA"
      or "NORMAL" (any other label gets no hint);
    * zero, one or two sentences from ``alt_noise`` (added with 90% and 50%
      probability respectively).

    Args:
        label: Class label used to pick vitals ranges and the optional hint.

    Returns:
        The sentences joined with single spaces.
    """
    pid = random_token()
    age = f"{get_age()} years old"
    days = get_days()
    temp = get_temp(label)
    oxygen = get_oxygen(label)

    sent_intro = f"Patient {pid}, a {age} individual presented after experiencing symptoms for approximately {days} days."
    sent_vitals = f"Vital measurements include a body temperature of {temp}°F and an oxygen saturation level of {oxygen}%."

    body = [
        sent_intro,
        random.choice(alt_symptoms),
        sent_vitals,
        random.choice(alt_diagnosis),
    ]

    if random.random() < 0.3:
        if label == "COVID":
            body.append("Anosmia has been intermittently observed over recent days.")
        elif label == "VIRAL PNEUMONIA":
            body.append("Radiographic evidence reveals dispersed infiltrative patterns.")
        elif label == "NORMAL":
            body.append("There are currently no active complaints from the patient.")

    # Inject 0-2 neutral clinical observations. Their position is decided by
    # the shuffle below, so appending is sufficient.
    if random.random() < 0.9:
        body.append(random.choice(alt_noise))
    if random.random() < 0.5:
        body.append(random.choice(alt_noise))

    # random.shuffle works in place and returns None. Calling it on the slice
    # ``body[1:]`` would only shuffle a temporary copy and leave ``body``
    # unchanged. Shuffle an explicit copy and write it back so the
    # introduction stays first.
    tail = body[1:]
    random.shuffle(tail)
    body[1:] = tail
    return " ".join(body)


def get_training_image_set(csv_path=None):
    """Return the set of image paths referenced by the training EMR CSV.

    Args:
        csv_path: Optional path of the CSV to read. Defaults to
            ``TRAIN_CSV_PATH``, so existing zero-argument calls are unchanged.

    Returns:
        The ``image_path`` column values with surrounding whitespace stripped.
        Rows whose ``image_path`` cell is missing (short rows) or blank are
        skipped.

    Raises:
        FileNotFoundError: If the CSV does not exist.
        KeyError: If a data row has no ``image_path`` column.
    """
    path = TRAIN_CSV_PATH if csv_path is None else Path(csv_path)
    if not path.exists():
        raise FileNotFoundError(f"Training CSV not found at {path}")

    image_paths = set()
    # Explicit UTF-8: EMR text may contain non-ASCII characters (this script's
    # own generator writes "°F"), and the platform default encoding is not
    # UTF-8 everywhere.
    with open(path, newline="", encoding="utf-8") as fh:
        for row in csv.DictReader(fh):
            cell = row["image_path"]
            # DictReader fills cells missing from short rows with None.
            if cell is None:
                continue
            cell = cell.strip()
            if cell:
                image_paths.add(cell)
    return image_paths


def generate_test_csv():
    """Write OUTPUT_CSV with SAMPLES_PER_CLASS held-out samples per label.

    For each label in LABELS, image files (.png/.jpg/.jpeg, any case) directly
    under IMAGE_DIR/<label> that are not listed in the training CSV are
    collected, and SAMPLES_PER_CLASS of them are sampled at random. Each is
    paired with a synthetic note from build_alt_emr and the label's triage
    level. All rows are shuffled before being written.

    Raises:
        FileNotFoundError: If the training CSV is missing.
        ValueError: If a label directory has fewer than SAMPLES_PER_CLASS
            unseen images.
    """
    # Compare and write paths in POSIX form. str(Path) uses backslashes on
    # Windows, which would never match forward-slash entries from the
    # training CSV and silently let training images leak into the test set.
    # Training entries are normalised too, in case that CSV was written on
    # Windows.
    training_images = {p.replace("\\", "/") for p in get_training_image_set()}
    image_suffixes = {".png", ".jpg", ".jpeg"}
    records = []

    for label in LABELS:
        label_dir = IMAGE_DIR / label
        image_files = sorted(
            img for img in label_dir.glob("*")
            if img.is_file() and img.suffix.lower() in image_suffixes
        )
        unseen_images = [
            img for img in image_files
            if img.relative_to(PROJECT_ROOT).as_posix() not in training_images
        ]

        if len(unseen_images) < SAMPLES_PER_CLASS:
            raise ValueError(f"Not enough unseen images in {label_dir}. "
                             f"Needed {SAMPLES_PER_CLASS}, found {len(unseen_images)}")
        sampled_images = random.sample(unseen_images, SAMPLES_PER_CLASS)

        triage = triage_map[label]
        for i, img_path in enumerate(sampled_images, start=1):
            relative_path = img_path.relative_to(PROJECT_ROOT).as_posix()
            text = build_alt_emr(label)
            records.append([f"{label}-{i}", text, relative_path, triage])

    random.shuffle(records)
    # Explicit UTF-8: the EMR text contains "°F", which the platform default
    # encoding may not represent consistently for downstream readers.
    with open(OUTPUT_CSV, "w", newline="", encoding="utf-8") as fh:
        writer = csv.writer(fh)
        writer.writerow(["patient_id", "emr_text", "image_path", "triage_level"])
        writer.writerows(records)

    print(f"✅ test CSV file generated: {OUTPUT_CSV}")
    print(f"📦 Total samples: {len(records)} ({SAMPLES_PER_CLASS} per class)")


if __name__ == "__main__":
    # Script entry point: build the held-out test CSV at OUTPUT_CSV.
    generate_test_csv()