Spaces:
Sleeping
Sleeping
| import random | |
| import csv | |
| import string | |
| from pathlib import Path | |
# ---------------------------------------------------------------------------
# Sampling configuration
# ---------------------------------------------------------------------------
# Number of images drawn for each entry of LABELS.
SAMPLES_PER_CLASS = 10

# ---------------------------------------------------------------------------
# Filesystem layout (everything is resolved relative to this script)
# ---------------------------------------------------------------------------
CURRENT_DIR = Path(__file__).resolve().parent
PROJECT_ROOT = CURRENT_DIR.parent
IMAGE_DIR = PROJECT_ROOT.joinpath("data", "images")
TRAIN_CSV_PATH = PROJECT_ROOT.joinpath("data", "emr_records.csv")
OUTPUT_CSV = PROJECT_ROOT.joinpath("test_samples.csv")

# ---------------------------------------------------------------------------
# Class labels and the triage level written to the CSV for each of them
# ---------------------------------------------------------------------------
LABELS = ["COVID", "NORMAL", "VIRAL PNEUMONIA"]

triage_map = {
    "COVID": "high",
    "NORMAL": "low",
    "VIRAL PNEUMONIA": "medium",
}

# ---------------------------------------------------------------------------
# Sentence pools used when assembling the synthetic EMR text
# ---------------------------------------------------------------------------
# Symptom descriptions; one is picked per note.
alt_symptoms = [
    "The patient has noted intermittent chest pressure and occasional shortness of breath.",
    "A gradual onset of dry cough with mild respiratory discomfort has been documented.",
    "Reported complaints include mild fatigue and sporadic episodes of wheezing.",
    "Mild respiratory symptoms have progressed over several days.",
    "Episodes of throat irritation and general malaise observed.",
]

# Diagnostic remarks; one is picked per note.
alt_diagnosis = [
    "Clinical features are suggestive of a nonspecific viral etiology.",
    "Diagnosis remains unclear pending further laboratory confirmation.",
    "Preliminary indicators fall into a diagnostic grey area.",
    "No definitive pattern observed; further evaluation is warranted.",
    "Presentation overlaps multiple pulmonary conditions.",
]

# Neutral filler observations that are randomly inserted into a note.
alt_noise = [
    "Patient remains oriented with stable hemodynamics.",
    "No remarkable family history or chronic illness reported.",
    "Nutritional intake and sleep patterns appear adequate.",
    "No prior admissions or surgical history disclosed.",
    "Standard precautions have been advised post-evaluation.",
]
def random_token():
    """Build a throwaway patient identifier such as ``TEST-QX07``.

    Returns:
        str: ``"TEST-"`` followed by two random uppercase ASCII letters and
        then two random decimal digits.
    """
    # Letters are drawn before digits so the RNG is consumed in a fixed order.
    alpha_part = random.choices(string.ascii_uppercase, k=2)
    numeric_part = random.choices(string.digits, k=2)
    return "TEST-" + "".join(alpha_part + numeric_part)
def get_oxygen(label):
    """Sample an oxygen-saturation percentage for a class label.

    Args:
        label: Class name. ``"NORMAL"`` draws from 94-100 and
            ``"VIRAL PNEUMONIA"`` from 90-96; any other value (including
            ``"COVID"``) draws from 87-94.

    Returns:
        int: A uniformly drawn integer, both bounds inclusive.
    """
    spo2_ranges = {
        "NORMAL": (94, 100),
        "VIRAL PNEUMONIA": (90, 96),
    }
    # Labels without an explicit entry share the lowest range.
    lowest, highest = spo2_ranges.get(label, (87, 94))
    return random.randint(lowest, highest)
def get_temp(label):
    """Sample a body temperature for a class label.

    Args:
        label: Class name. ``"NORMAL"`` draws from 97.5-99.0; every other
            label draws from 98.8-102.5.

    Returns:
        float: The drawn value rounded to one decimal place. The caller in
        this file formats it with a ``°F`` suffix.
    """
    if label == "NORMAL":
        low_bound, high_bound = 97.5, 99.0
    else:
        low_bound, high_bound = 98.8, 102.5
    reading = random.uniform(low_bound, high_bound)
    return round(reading, 1)
def get_age():
    """Return a random integer age between 18 and 85, both inclusive."""
    youngest, oldest = 18, 85
    return random.randint(youngest, oldest)
def get_days():
    """Return a random symptom duration in days, between 1 and 10 inclusive."""
    shortest, longest = 1, 10
    return random.randint(shortest, longest)
def build_alt_emr(label):
    """Compose a synthetic free-text EMR note for one test sample.

    The note always opens with an introductory sentence (patient token, age,
    symptom duration). It is followed, in random order, by one symptom
    sentence, one vitals sentence, one diagnostic remark, an optional
    label-specific hint and up to two neutral filler observations.

    Args:
        label: Class name; drives the sampled vitals and the optional hint
            sentence (``"COVID"``, ``"VIRAL PNEUMONIA"`` or ``"NORMAL"``).

    Returns:
        str: The sentences joined with single spaces.
    """
    pid = random_token()
    age = f"{get_age()} years old"
    days = get_days()
    temp = get_temp(label)
    oxygen = get_oxygen(label)

    sent_intro = (
        f"Patient {pid}, a {age} individual presented after experiencing "
        f"symptoms for approximately {days} days."
    )
    sent_vitals = (
        f"Vital measurements include a body temperature of {temp}°F and an "
        f"oxygen saturation level of {oxygen}%."
    )

    body = [
        sent_intro,
        random.choice(alt_symptoms),
        sent_vitals,
        random.choice(alt_diagnosis),
    ]

    # Roughly 30% of notes carry one sentence that hints at the true label.
    if random.random() < 0.3:
        if label == "COVID":
            body.append("Anosmia has been intermittently observed over recent days.")
        elif label == "VIRAL PNEUMONIA":
            body.append("Radiographic evidence reveals dispersed infiltrative patterns.")
        elif label == "NORMAL":
            body.append("There are currently no active complaints from the patient.")

    # Inject up to two neutral clinical observations, never before the intro
    # (insert positions start at index 1).
    # NOTE(review): the source's indentation was lost; the two rolls below are
    # assumed to be independent rather than nested - confirm against the
    # original file.
    if random.random() < 0.9:
        body.insert(random.randint(1, len(body)), random.choice(alt_noise))
    if random.random() < 0.5:
        body.insert(random.randint(1, len(body)), random.choice(alt_noise))

    # Keep the first sentence intact and shuffle everything after it.
    # random.shuffle() works in place, so shuffling the expression `body[1:]`
    # directly would only reorder a temporary copy and leave `body` unchanged.
    tail = body[1:]
    random.shuffle(tail)
    return " ".join([sent_intro] + tail)
def get_training_image_set(csv_path=None):
    """Collect the image paths already referenced by the training records.

    Args:
        csv_path: Optional path to the records CSV. When omitted, the
            module-level ``TRAIN_CSV_PATH`` is used, which matches the
            previous zero-argument behavior.

    Returns:
        set[str]: The whitespace-stripped values of the ``image_path`` column.

    Raises:
        FileNotFoundError: If the CSV file does not exist.
        KeyError: If a data row is read from a CSV that has no
            ``image_path`` column.
    """
    source = Path(TRAIN_CSV_PATH if csv_path is None else csv_path)
    if not source.exists():
        raise FileNotFoundError(f"Training CSV not found at {source}")

    seen_paths = set()
    with open(source, newline="") as f:
        for row in csv.DictReader(f):
            value = row["image_path"]
            # DictReader fills the missing trailing fields of a short row
            # with None; skip those instead of crashing on None.strip().
            if value is not None:
                seen_paths.add(value.strip())
    return seen_paths
def generate_test_csv():
    """Write a shuffled test CSV of images that the training CSV does not list.

    For every label in ``LABELS`` the function lists the image files in
    ``IMAGE_DIR / label``, drops those whose project-relative path appears in
    the training CSV, samples ``SAMPLES_PER_CLASS`` of the remainder and pairs
    each with a synthetic EMR note and the label's triage level. All rows are
    shuffled together and written to ``OUTPUT_CSV``.

    Raises:
        FileNotFoundError: If the training CSV is missing (propagated from
            ``get_training_image_set``).
        ValueError: If a label directory holds fewer than
            ``SAMPLES_PER_CLASS`` unseen images.
    """
    training_images = get_training_image_set()
    image_suffixes = (".png", ".jpg", ".jpeg")
    records = []

    for label in LABELS:
        label_dir = IMAGE_DIR / label
        # Sorted so the candidate order does not depend on directory listing order.
        image_files = sorted(
            f for f in label_dir.glob("*") if f.suffix.lower() in image_suffixes
        )
        # NOTE(review): the comparison uses str() of the relative path, so it
        # presumably relies on the training CSV using the same path separator
        # as this platform - verify if the CSV was produced elsewhere.
        unseen_images = [
            f for f in image_files
            if str(f.relative_to(PROJECT_ROOT)) not in training_images
        ]
        if len(unseen_images) < SAMPLES_PER_CLASS:
            # The trailing space keeps the two sentences from running together.
            raise ValueError(
                f"Not enough unseen images in {label_dir}. "
                f"Needed {SAMPLES_PER_CLASS}, found {len(unseen_images)}"
            )

        sampled_images = random.sample(unseen_images, SAMPLES_PER_CLASS)
        for i, img_path in enumerate(sampled_images):
            relative_path = str(img_path.relative_to(PROJECT_ROOT))
            text = build_alt_emr(label)
            triage = triage_map[label]
            records.append([f"{label}-{i + 1}", text, relative_path, triage])

    # Mix the classes so rows are not grouped by label in the output.
    random.shuffle(records)

    with open(OUTPUT_CSV, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["patient_id", "emr_text", "image_path", "triage_level"])
        writer.writerows(records)

    print(f"✅ test CSV file generated: {OUTPUT_CSV}")
    # Report the configured per-class count rather than a hard-coded number.
    print(f"📦 Total samples: {len(records)} ({SAMPLES_PER_CLASS} per class)")
# Script entry point: build the test CSV only when this file is run directly,
# not when it is imported.
if __name__ == "__main__":
    generate_test_csv()