Spaces:
Sleeping
Sleeping
File size: 5,262 Bytes
9d9cc25 f294e5b 9d9cc25 c27b2d2 9d9cc25 c27b2d2 9d9cc25 c27b2d2 9d9cc25 c27b2d2 9d9cc25 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 |
import random
import csv
import string
from pathlib import Path
# ---------------------------------------------------------------------------
# Sampling configuration
# ---------------------------------------------------------------------------
# Number of images drawn for every label when building the test CSV.
SAMPLES_PER_CLASS = 10

# ---------------------------------------------------------------------------
# Filesystem layout (everything is resolved relative to this script)
# ---------------------------------------------------------------------------
CURRENT_DIR = Path(__file__).resolve().parent  # directory holding this file
PROJECT_ROOT = CURRENT_DIR.parent              # one level above the script
IMAGE_DIR = PROJECT_ROOT.joinpath("data", "images")
TRAIN_CSV_PATH = PROJECT_ROOT.joinpath("data", "emr_records.csv")
OUTPUT_CSV = PROJECT_ROOT.joinpath("test_samples.csv")

# ---------------------------------------------------------------------------
# Class labels and the triage level assigned to each of them
# ---------------------------------------------------------------------------
LABELS = ["COVID", "NORMAL", "VIRAL PNEUMONIA"]
triage_map = {
    "COVID": "high",
    "NORMAL": "low",
    "VIRAL PNEUMONIA": "medium",
}
# Sentence pools used to assemble synthetic EMR notes. One entry is picked
# at random from a pool each time a sentence of that kind is needed.

# Symptom descriptions.
alt_symptoms = [
    "The patient has noted intermittent chest pressure and occasional shortness of breath.",
    "A gradual onset of dry cough with mild respiratory discomfort has been documented.",
    "Reported complaints include mild fatigue and sporadic episodes of wheezing.",
    "Mild respiratory symptoms have progressed over several days.",
    "Episodes of throat irritation and general malaise observed.",
]

# Deliberately non-committal diagnostic statements.
alt_diagnosis = [
    "Clinical features are suggestive of a nonspecific viral etiology.",
    "Diagnosis remains unclear pending further laboratory confirmation.",
    "Preliminary indicators fall into a diagnostic grey area.",
    "No definitive pattern observed; further evaluation is warranted.",
    "Presentation overlaps multiple pulmonary conditions.",
]

# Neutral clinical observations injected as filler.
alt_noise = [
    "Patient remains oriented with stable hemodynamics.",
    "No remarkable family history or chronic illness reported.",
    "Nutritional intake and sleep patterns appear adequate.",
    "No prior admissions or surgical history disclosed.",
    "Standard precautions have been advised post-evaluation.",
]
def random_token():
    """Generate a synthetic patient identifier such as ``TEST-AB12``.

    The identifier is the literal ``TEST``, a hyphen, two random uppercase
    ASCII letters and two random decimal digits.
    """
    letter_pair = random.choices(string.ascii_uppercase, k=2)
    digit_pair = random.choices(string.digits, k=2)
    suffix = "".join(letter_pair + digit_pair)
    return "TEST-" + suffix
def get_oxygen(label):
    """Draw a random integer oxygen-saturation percentage for *label*.

    Inclusive ranges: ``"NORMAL"`` 94-100, ``"VIRAL PNEUMONIA"`` 90-96,
    and 87-94 for any other label.
    """
    ranges_by_label = {
        "NORMAL": (94, 100),
        "VIRAL PNEUMONIA": (90, 96),
    }
    # Every label not listed above falls back to the lowest range.
    lowest, highest = ranges_by_label.get(label, (87, 94))
    return random.randint(lowest, highest)
def get_temp(label):
    """Draw a random body temperature for *label*, rounded to one decimal.

    ``"NORMAL"`` is drawn from 97.5-99.0; every other label is drawn from
    98.8-102.5.
    """
    lowest, highest = (97.5, 99.0) if label == "NORMAL" else (98.8, 102.5)
    reading = random.uniform(lowest, highest)
    return round(reading, 1)
def get_age():
    """Draw a random integer age between 18 and 85, both inclusive."""
    youngest, oldest = 18, 85
    return random.randint(youngest, oldest)
def get_days():
    """Draw a random integer symptom duration between 1 and 10, inclusive."""
    shortest, longest = 1, 10
    return random.randint(shortest, longest)
def build_alt_emr(label):
    """Compose a synthetic EMR note for one patient of class *label*.

    The note always opens with an introductory sentence (patient token, age
    and symptom duration). The sentences after it are a random symptom
    description, the vitals (temperature and oxygen saturation drawn for
    *label*), a random non-committal diagnosis, optionally one
    label-specific clue (30% chance) and up to two neutral filler
    observations. Every sentence after the first is shuffled.

    Args:
        label: Class label; ``"COVID"``, ``"VIRAL PNEUMONIA"`` and
            ``"NORMAL"`` each have a dedicated clue sentence.

    Returns:
        The sentences joined with single spaces into one string.
    """
    pid = random_token()
    age = f"{get_age()} years old"
    days = get_days()
    temp = get_temp(label)
    oxygen = get_oxygen(label)

    sent_intro = (
        f"Patient {pid}, a {age} individual presented after experiencing "
        f"symptoms for approximately {days} days."
    )
    sent_vitals = (
        f"Vital measurements include a body temperature of {temp}°F and an "
        f"oxygen saturation level of {oxygen}%."
    )

    body = [
        sent_intro,
        random.choice(alt_symptoms),
        sent_vitals,
        random.choice(alt_diagnosis),
    ]

    # Only occasionally add a sentence that hints at the true class.
    if random.random() < 0.3:
        if label == "COVID":
            body.append("Anosmia has been intermittently observed over recent days.")
        elif label == "VIRAL PNEUMONIA":
            body.append("Radiographic evidence reveals dispersed infiltrative patterns.")
        elif label == "NORMAL":
            body.append("There are currently no active complaints from the patient.")

    # Inject up to two neutral clinical observations, never before the intro.
    # NOTE(review): the two draws are treated as independent (0, 1 or 2
    # fillers); confirm the second was not meant to depend on the first.
    if random.random() < 0.9:
        body.insert(random.randint(1, len(body)), random.choice(alt_noise))
    if random.random() < 0.5:
        body.insert(random.randint(1, len(body)), random.choice(alt_noise))

    # random.shuffle works in place, and slicing creates a copy, so shuffling
    # ``body[1:]`` directly would discard the result. Shuffle a named tail
    # and re-attach it so the first sentence stays intact.
    tail = body[1:]
    random.shuffle(tail)
    return " ".join([body[0], *tail])
def get_training_image_set():
    """Collect the image paths listed in the training CSV.

    Reads the ``image_path`` column of the file at ``TRAIN_CSV_PATH`` and
    strips surrounding whitespace from every value.

    Returns:
        A set of the stripped ``image_path`` strings.

    Raises:
        FileNotFoundError: If ``TRAIN_CSV_PATH`` does not exist.
    """
    if not TRAIN_CSV_PATH.exists():
        raise FileNotFoundError(f"Training CSV not found at {TRAIN_CSV_PATH}")

    seen_paths = set()
    with open(TRAIN_CSV_PATH, newline="") as handle:
        for record in csv.DictReader(handle):
            seen_paths.add(record["image_path"].strip())
    return seen_paths
def generate_test_csv():
    """Build a test CSV from images that do not appear in the training CSV.

    For every label in ``LABELS`` this collects the ``.png``/``.jpg``/
    ``.jpeg`` files under ``IMAGE_DIR / label``, drops those whose path
    (relative to ``PROJECT_ROOT``) is listed in the training CSV, randomly
    samples ``SAMPLES_PER_CLASS`` of the remainder and pairs each with a
    synthetic EMR note and the label's triage level. The combined rows are
    shuffled and written to ``OUTPUT_CSV`` with the header
    ``patient_id, emr_text, image_path, triage_level``.

    Raises:
        FileNotFoundError: Propagated from ``get_training_image_set`` when
            the training CSV is missing.
        ValueError: If a label has fewer than ``SAMPLES_PER_CLASS`` unseen
            images.
    """
    training_images = get_training_image_set()
    image_suffixes = {".png", ".jpg", ".jpeg"}
    records = []

    for label in LABELS:
        label_dir = IMAGE_DIR / label
        # Sorting gives random.sample a stable input order.
        image_files = sorted(
            f for f in label_dir.glob("*") if f.suffix.lower() in image_suffixes
        )
        unseen_images = [
            f for f in image_files
            if str(f.relative_to(PROJECT_ROOT)) not in training_images
        ]
        if len(unseen_images) < SAMPLES_PER_CLASS:
            # Trailing space keeps the two sentences of the message apart.
            raise ValueError(
                f"Not enough unseen images in {label_dir}. "
                f"Needed {SAMPLES_PER_CLASS}, found {len(unseen_images)}"
            )

        sampled_images = random.sample(unseen_images, SAMPLES_PER_CLASS)
        # NOTE(review): patient_id embeds the class label; confirm that
        # downstream consumers do not use it as a model input.
        for i, img_path in enumerate(sampled_images, start=1):
            relative_path = str(img_path.relative_to(PROJECT_ROOT))
            text = build_alt_emr(label)
            triage = triage_map[label]
            records.append([f"{label}-{i}", text, relative_path, triage])

    # Mix the classes so rows are not grouped by label in the output.
    random.shuffle(records)

    # The EMR text contains the non-ASCII "°" sign, so fix the encoding
    # explicitly instead of relying on the platform default.
    with open(OUTPUT_CSV, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["patient_id", "emr_text", "image_path", "triage_level"])
        writer.writerows(records)

    print(f"✅ test CSV file generated: {OUTPUT_CSV}")
    print(f"📦 Total samples: {len(records)} ({SAMPLES_PER_CLASS} per class)")


if __name__ == "__main__":
    generate_test_csv()
|