|
|
import gradio as gr |
|
|
import cv2 |
|
|
import numpy as np |
|
|
import torch |
|
|
from transformers import AutoModel |
|
|
|
|
|
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the checkpoint once at import time, switch it to inference mode and
# move it to the chosen device in a single chain.
# NOTE(review): trust_remote_code=True executes the Python code shipped in
# the Hub repository; consider pinning `revision=` to a reviewed commit.
model = AutoModel.from_pretrained(
    "ianpan/chest-x-ray-basic", trust_remote_code=True
).eval().to(device)
|
|
|
|
|
|
|
|
def calculate_ctr(mask, lung_labels=(1, 2), heart_label=3):
    """Estimate the cardiothoracic ratio (CTR) from a 2-D label mask.

    The horizontal extent of every pixel labelled as lung (any value in
    ``lung_labels``) is taken as the thoracic width. The horizontal extent
    of every pixel labelled ``heart_label`` is taken as the cardiac width.
    CTR is cardiac width / thoracic width. Widths are ``max_col - min_col``
    over all matching pixels, so stray pixels anywhere in the image widen
    the extent.

    Args:
        mask: 2-D label array indexed as ``(row, column)``.
        lung_labels: sequence of label values counted as lung.
            Defaults to ``(1, 2)``.
        heart_label: label value counted as heart. Defaults to ``3``.

    Returns:
        ``(ctr, thorax_left, thorax_right, heart_left, heart_right)``.
        The four extents are inclusive column indices as Python ``int``.
        ``ctr`` is a ``float``, or ``None`` when the lung extent is zero
        pixels wide. All five values are ``None`` when either structure is
        missing from the mask.
    """
    mask = np.asarray(mask)

    # Only column extents are needed, so project each structure onto the
    # x-axis rather than materialising every (row, col) coordinate.
    lung_cols = np.flatnonzero(np.isin(mask, lung_labels).any(axis=0))
    heart_cols = np.flatnonzero((mask == heart_label).any(axis=0))

    if lung_cols.size == 0 or heart_cols.size == 0:
        return None, None, None, None, None

    # flatnonzero yields ascending indices, so the ends are the min and max.
    thorax_left, thorax_right = int(lung_cols[0]), int(lung_cols[-1])
    heart_left, heart_right = int(heart_cols[0]), int(heart_cols[-1])

    lung_range = thorax_right - thorax_left
    heart_range = heart_right - heart_left

    # A single-column lung extent would divide by zero; report "no CTR".
    ctr = None if lung_range == 0 else float(heart_range / lung_range)

    return ctr, thorax_left, thorax_right, heart_left, heart_right
|
|
|
|
|
|
|
|
def _run_model(image):
    """Run the chest X-ray model on a PIL image and derive CTR measurements.

    Shared by the Viewer (``analyze``) and JSON (``get_points``) handlers.

    Args:
        image: a PIL image; converted to 8-bit single-channel ``"L"`` mode
            before inference.

    Returns:
        A 12-tuple in exactly this order::

            (img, mask, h, w, ctr, thorax_left, thorax_right,
             heart_left, heart_right, view_idx, age_pred, female_prob)

        * ``img``: the ``(h, w)`` grayscale array given to ``model.preprocess``.
        * ``mask``: ``(h, w)`` uint8 label map. It is the per-pixel argmax of
          the model's ``"mask"`` output, resized back to the input size with
          nearest-neighbour interpolation.
        * ``h``, ``w``: input height and width in pixels.
        * ``ctr`` .. ``heart_right``: the output of ``calculate_ctr(mask)``.
          These are ``None`` when lungs or heart are absent.
        * ``view_idx``: argmax index of the model's ``"view"`` output.
        * ``age_pred``: the model's ``"age"`` output as a ``float``.
        * ``female_prob``: the model's ``"female"`` output as a ``float``.
          It is assumed to already be a probability in [0, 1]; TODO confirm
          against the model card.
    """
    img = np.array(image.convert("L"))
    h, w = img.shape[:2]

    # NOTE(review): assumes model.preprocess returns a 2-D (H', W') array;
    # the two unsqueezes add batch and channel dims -> (1, 1, H', W').
    x = model.preprocess(img)
    x = torch.from_numpy(x).unsqueeze(0).unsqueeze(0).float()

    with torch.inference_mode():
        out = model(x.to(device))

    # Batch size is 1, so take element [0] / .item() below.
    mask_small = out["mask"].argmax(1)[0].cpu().numpy()
    # NOTE(review): resizing straight to (w, h) assumes preprocess only
    # rescales; if it crops or pads, the mask would be misaligned - confirm.
    mask = cv2.resize(
        mask_small.astype("uint8"), (w, h), interpolation=cv2.INTER_NEAREST
    )

    view_idx = out["view"].argmax(1).item()
    age_pred = float(out["age"].item())
    female_prob = float(out["female"].item())

    ctr, thorax_left, thorax_right, heart_left, heart_right = calculate_ctr(mask)

    return (
        img,
        mask,
        h,
        w,
        ctr,
        thorax_left,
        thorax_right,
        heart_left,
        heart_right,
        view_idx,
        age_pred,
        female_prob,
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def analyze(image):
    """Build the Viewer tab outputs for one uploaded image.

    Returns ``(overlay, report)``:

    * ``overlay`` is an RGB image. It is a 70/30 weighted blend of the
      grayscale input and a copy tinted by label: 1 -> green,
      2 -> ``[0, 128, 255]``, 3 -> red. ``calculate_ctr`` treats labels 1 and 2
      as lung and label 3 as heart.
    * ``report`` is a newline-joined text summary: CTR, view, age, sex, a
      research-only disclaimer, and a caution when the view is not PA.

    When ``image`` is ``None``, returns ``(None, "No image uploaded.")``
    without running the model.
    """
    if image is None:
        return None, "No image uploaded."

    (
        gray, mask, _h, _w, ctr,
        _thorax_l, _thorax_r, _heart_l, _heart_r,
        view_idx, age_pred, female_prob,
    ) = _run_model(image)

    rgb = cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)
    tinted = rgb.copy()
    for label, colour in ((1, [0, 255, 0]), (2, [0, 128, 255]), (3, [255, 0, 0])):
        tinted[mask == label] = colour
    blended = cv2.addWeighted(rgb, 0.7, tinted, 0.3, 0)

    # NOTE(review): index -> name mapping assumed from the model's view head;
    # confirm against the model card.
    view = {0: "AP", 1: "PA", 2: "lateral"}.get(view_idx, "unknown")

    if ctr is None:
        ctr_line = "CTR: could not be reliably calculated (segmentation issue)."
    else:
        ctr_line = f"CTR: {ctr:.2f}"
    sex = "Female" if female_prob >= 0.5 else "Male"

    report = [
        ctr_line,
        f"View (model): {view}",
        f"Predicted age: {age_pred:.0f} years",
        f"Predicted sex: {sex} (prob={female_prob:.2f})",
        "",
        "⚠️ Research/educational use only, NOT for clinical decision-making.",
    ]
    if view != "PA":
        report.append("⚠️ CTR is normally interpreted on PA view. Interpret with caution.")

    return blended, "\n".join(report)
|
|
|
|
|
|
|
|
# Viewer tab: upload a chest X-ray; ``analyze`` returns the colour-coded
# segmentation overlay and a plain-text report.
visual_demo = gr.Interface(
    fn=analyze,
    inputs=gr.Image(type="pil", label="Chest X-ray (PNG/JPG) – frontal view"),
    outputs=[gr.Image(label="Segmentation overlay"), gr.Textbox(label="AI output")],
    description=(
        "Segments heart and lungs and estimates CTR using "
        "'ianpan/chest-x-ray-basic'. Research use only."
    ),
    title="AI CTR helper (research only)",
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_points(image):
    """Return the CTR measurements for one image as a JSON-serialisable dict.

    Keys are ``image_width``, ``image_height``, ``ctr`` (``None`` if it
    could not be computed), the four inclusive column indices
    ``thorax_left_px``, ``thorax_right_px``, ``heart_left_px`` and
    ``heart_right_px`` (``None`` when lungs or heart are missing), and
    the raw ``view_idx``.

    When ``image`` is ``None``, returns ``{"error": "No image uploaded"}``
    without running the model.
    """
    if image is None:
        return {"error": "No image uploaded"}

    (
        _gray, _mask, height, width, ctr,
        thorax_l, thorax_r, heart_l, heart_r,
        view_idx, _age, _female,
    ) = _run_model(image)

    return {
        "image_width": width,
        "image_height": height,
        "ctr": ctr,
        "thorax_left_px": thorax_l,
        "thorax_right_px": thorax_r,
        "heart_left_px": heart_l,
        "heart_right_px": heart_r,
        "view_idx": int(view_idx),
    }
|
|
|
|
|
|
|
|
# JSON tab: same upload as the Viewer, but ``get_points`` returns the raw
# measurements. The API endpoint is registered under the name "ctr_points".
points_api = gr.Interface(
    fn=get_points,
    inputs=gr.Image(type="pil", label="Chest X-ray (PNG/JPG) – frontal view"),
    outputs=gr.JSON(label="CTR points JSON"),
    api_name="ctr_points",
    title="CTR points API",
    description="Returns thorax/heart x-coordinates and CTR as JSON.",
)
|
|
|
|
|
# Put both interfaces into one app, one tab each.
demo = gr.TabbedInterface(
    interface_list=[visual_demo, points_api],
    tab_names=["Viewer", "JSON API"],
)

if __name__ == "__main__":
    # Start the server only when executed as a script, not on import.
    demo.launch()
|
|
|