File size: 5,195 Bytes
54d948e
350af03
 
376c77f
350af03
 
54d948e
 
 
 
 
 
376c77f
350af03
54d948e
3154902
350af03
376c77f
 
350af03
376c77f
 
 
 
 
 
 
 
 
0c8a8ec
9218201
376c77f
 
 
 
 
 
af86b36
8bafee5
 
376c77f
 
 
 
 
 
 
 
e4a5d84
350af03
af86b36
376c77f
 
af86b36
 
376c77f
 
af86b36
0c8a8ec
376c77f
 
 
 
 
 
 
0c8a8ec
af86b36
0c8a8ec
8bafee5
376c77f
8bafee5
376c77f
e4a5d84
af86b36
8bafee5
376c77f
 
af86b36
376c77f
 
8bafee5
af86b36
8bafee5
376c77f
8bafee5
e4a5d84
0c8a8ec
 
376c77f
 
 
 
 
 
 
 
 
af86b36
376c77f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8bafee5
0c8a8ec
af86b36
376c77f
54d948e
 
 
 
 
 
 
 
 
 
 
9218201
 
 
 
 
 
 
 
af86b36
716c074
83c4f3c
 
716c074
9218201
 
10fdd1f
9218201
 
 
b136189
9218201
 
 
54d948e
9218201
562137e
9218201
 
54d948e
9218201
 
 
54d948e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
import os
import random
import csv
import string
from pathlib import Path

# ---------------------------------------------------------------------------
# Run mode: when the CI environment variable is "true" (case-insensitive) only
# a handful of records per class are generated, keeping test runs fast.
# ---------------------------------------------------------------------------
IS_CI = os.environ.get("CI", "false").lower() == "true"
SAMPLES_PER_CLASS = 300 if not IS_CI else 3

# Filesystem layout, anchored on the directory containing this script; the
# data folder is a sibling of that directory.
CURRENT_DIR = Path(__file__).resolve().parent
IMAGES_DIR = CURRENT_DIR.parent.joinpath("data", "images")
OUTPUT_FILE = CURRENT_DIR.parent.joinpath("data", "emr_records.csv")

# Class label -> triage level written to the CSV.
triage_map = {
    "COVID": "high",
    "NORMAL": "low",
    "VIRAL PNEUMONIA": "medium",
}

# Symptom sentences used for every class, so symptoms alone are ambiguous.
shared_symptoms = [
    "Mild cough and slight fever reported.",
    "General fatigue and throat irritation present.",
    "Breathing mildly labored during physical exertion.",
    "No major respiratory distress; mild wheezing noted.",
    "Occasional chest tightness reported.",
    "Vital signs mostly stable; slight variation in temperature.",
]

# Non-committal diagnosis sentences, likewise shared by every class.
shared_diagnosis = [
    "Symptoms could relate to a range of viral infections.",
    "Presentation not distinctly matching any single infection.",
    "Further tests required to confirm diagnosis.",
    "Findings are borderline; clinical judgment advised.",
    "Observation warranted due to overlapping signs.",
    "Initial assessment inconclusive.",
]

# Filler sentences carrying no clinical signal.
neutral_noise = [
    "Patient is cooperative and alert.",
    "Dietary habits unremarkable.",
    "Hydration status normal.",
    "Follow-up advised if symptoms persist.",
    "No notable family medical history.",
    "No medications currently administered.",
]


def random_token():
    """Return a pseudo-random patient token of the form ``ID-XXnn``.

    ``XX`` is two uppercase ASCII letters and ``nn`` two decimal digits. The
    module-level ``random`` generator is used, so tokens are reproducible
    under ``random.seed``. Tokens are not guaranteed to be unique.
    """
    letter_part = "".join(random.choices(string.ascii_uppercase, k=2))
    digit_part = "".join(random.choices(string.digits, k=2))
    return "ID-" + letter_part + digit_part


def get_oxygen(label):
    """Draw an integer SpO2 percentage for a patient of class *label*.

    The per-class ranges deliberately overlap ("soft blur") so that oxygen
    saturation alone does not separate the classes:
    NORMAL 94-100, VIRAL PNEUMONIA 90-96, any other label 87-94 (inclusive).
    """
    bounds = {
        "NORMAL": (94, 100),
        "VIRAL PNEUMONIA": (90, 96),
    }
    low, high = bounds.get(label, (87, 94))
    return random.randint(low, high)


def get_temp(label):
    """Draw a body temperature in degrees Fahrenheit, rounded to one decimal.

    NORMAL patients fall in 97.5-99.0; every other label in 98.8-102.5, which
    slightly overlaps the normal band.
    """
    low, high = (97.5, 99.0) if label == "NORMAL" else (98.8, 102.5)
    return round(random.uniform(low, high), 1)


def get_age():
    """Return a random adult age between 18 and 85 inclusive."""
    # randrange's upper bound is exclusive, hence 86.
    return random.randrange(18, 86)


def get_days():
    """Return a random symptom duration in days, from 1 to 10 inclusive."""
    # randrange's upper bound is exclusive, hence 11.
    return random.randrange(1, 11)


def build_emr(label, i):
    """Compose a synthetic free-text EMR note for one patient of class *label*.

    The note always opens with an intro sentence (patient token, age and
    symptom duration). The remaining sentences are shuffled so their order
    varies between notes. They are:

    - one shared symptom sentence;
    - a vitals line (temperature and SpO2);
    - one shared, non-committal diagnosis sentence;
    - with 30% probability, a class-specific clue;
    - zero to two neutral noise sentences.

    Args:
        label: Class label. "COVID", "VIRAL PNEUMONIA" and "NORMAL" have a
            dedicated clue sentence; any other label never gets one.
        i: Sample index within the class. Not used by this function; kept so
            existing callers keep working.

    Returns:
        The sentences joined with single spaces.
    """
    pid = random_token()
    age = f"{get_age()}-year-old"
    days = get_days()
    temp = get_temp(label)
    oxygen = get_oxygen(label)

    # get_days() can return 1, so pluralize correctly.
    day_word = "day" if days == 1 else "days"
    intro = f"Patient {pid}, a {age}, reports symptoms for {days} {day_word}."
    vitals = f"Temperature recorded at {temp}°F and SPO2 at {oxygen}%."

    # Shared symptoms + blurred logic
    body = [
        intro,
        random.choice(shared_symptoms),
        vitals,
        random.choice(shared_diagnosis),
    ]

    # Optionally inject a mild class-specific clue (with low probability)
    if random.random() < 0.3:
        if label == "COVID":
            body.append("Patient reports recent loss of taste.")
        elif label == "VIRAL PNEUMONIA":
            body.append("Chest X-ray shows scattered infiltrates.")
        elif label == "NORMAL":
            body.append("No active complaints at this time.")

    # Inject 1–2 noise sentences (never before the intro at index 0)
    if random.random() < 0.8:
        body.insert(random.randint(1, len(body)), random.choice(neutral_noise))
    if random.random() < 0.5:
        body.insert(random.randint(1, len(body)), random.choice(neutral_noise))

    # Shuffle everything after the intro. ``random.shuffle(body[1:])`` would
    # only shuffle a temporary slice copy and leave ``body`` unchanged, so
    # shuffle an explicit list and assign it back.
    tail = body[1:]
    random.shuffle(tail)
    body[1:] = tail
    return " ".join(body)


# Generate records
_IMAGE_SUFFIXES = (".png", ".jpg", ".jpeg")


def _list_images(img_dir):
    """Return the sorted image files (.png/.jpg/.jpeg, any case) in *img_dir*.

    Raises:
        FileNotFoundError: if *img_dir* does not exist or contains no images.
    """
    # Check existence first; otherwise the error message's iterdir() call
    # would itself fail with a less helpful FileNotFoundError.
    if not img_dir.is_dir():
        raise FileNotFoundError(f"Image folder {img_dir} does not exist.")
    image_files = sorted(
        f
        for f in img_dir.glob("*")
        if f.is_file() and f.suffix.lower() in _IMAGE_SUFFIXES
    )
    if not image_files:
        raise FileNotFoundError(
            f"No images found in {img_dir}. Folder contents: {list(img_dir.iterdir())}"
        )
    return image_files


def _project_relative(path):
    """Return *path* relative to the project root as a forward-slash string.

    The project root is ``CURRENT_DIR.parent``. The path is tried as given,
    then resolved. If it lies outside the root (e.g. an override image folder
    elsewhere on disk), the absolute resolved path is returned instead of
    raising ValueError.
    """
    root = CURRENT_DIR.parent
    for candidate in (path, path.resolve()):
        try:
            return candidate.relative_to(root).as_posix()
        except ValueError:
            continue
    return path.resolve().as_posix()


def generate_dataset(image_dir_override=None, output_path_override=None,
                     samples_per_class=None):
    """Generate the synthetic EMR CSV from the per-class image folders.

    For each class ("COVID", "NORMAL", "VIRAL PNEUMONIA"), this draws
    ``samples_per_class`` records. Each record pairs a randomly chosen image
    from ``<image dir>/<class>`` with a synthetic EMR note and the class's
    triage level. All records are shuffled and written to a CSV with the
    header ``patient_id, image_path, emr_text, triage_level``.

    Args:
        image_dir_override: Root image folder (str or Path); defaults to
            IMAGES_DIR when falsy.
        output_path_override: CSV destination (str or Path); defaults to
            OUTPUT_FILE when falsy. Missing parent folders are created.
        samples_per_class: Records per class; defaults to SAMPLES_PER_CLASS.

    Returns:
        The Path of the written CSV file.

    Raises:
        FileNotFoundError: if a class folder is missing or has no images.
    """
    # Path() lets callers pass plain strings as overrides.
    root_image_dir = Path(image_dir_override or IMAGES_DIR)
    output_file = Path(output_path_override or OUTPUT_FILE)
    n_samples = SAMPLES_PER_CLASS if samples_per_class is None else samples_per_class

    categories = {
        label: root_image_dir / label
        for label in ("COVID", "NORMAL", "VIRAL PNEUMONIA")
    }

    records = []
    for label, img_dir in categories.items():
        image_files = _list_images(img_dir)
        triage = triage_map[label]
        for i in range(n_samples):
            image_path = _project_relative(random.choice(image_files))
            text = build_emr(label, i)
            records.append([f"{label}-{i + 1}", image_path, text, triage])

    # Shuffle + write. The notes contain non-ASCII text ("°F"), so pin UTF-8
    # instead of relying on the platform's default encoding.
    random.shuffle(records)
    output_file.parent.mkdir(parents=True, exist_ok=True)
    with open(output_file, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["patient_id", "image_path", "emr_text", "triage_level"])
        writer.writerows(records)

    print(f"✅ EMR dataset generated at {output_file}")
    return output_file


if __name__ == "__main__":
    # Script entry point: build the dataset from the default image folder into
    # the default CSV path. generate_dataset falls back to IMAGES_DIR and
    # OUTPUT_FILE when no overrides are given.
    generate_dataset()