# radctr-ai / app.py
# Author: n0wm3 — last commit f4d2612 ("Update app.py")
import gradio as gr
import cv2
import numpy as np
import torch
from transformers import AutoModel
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the pretrained chest X-ray model once at import time, switch it to
# inference mode (eval) and move it to the selected device.
# NOTE(review): trust_remote_code=True executes Python code fetched from the
# model repository; consider pinning `revision=` to an audited commit.
model = (
    AutoModel.from_pretrained("ianpan/chest-x-ray-basic", trust_remote_code=True)
    .eval()
    .to(device)
)
def calculate_ctr(mask):
    """Estimate the cardiothoracic ratio (CTR) from a 2-D label mask.

    Pixels labelled 1 or 2 are treated together as lung; pixels labelled 3
    as heart. Each structure's horizontal extent is the span between the
    left-most and right-most column index where that label occurs.

    Args:
        mask: 2-D label array indexed as (row, column).

    Returns:
        ``(ctr, thorax_left, thorax_right, heart_left, heart_right)``.
        If either the lung labels or the heart label is absent, all five
        are ``None``. Otherwise the four extents are ``int`` column indices
        and ``ctr`` is ``(heart_right - heart_left) / (thorax_right -
        thorax_left)`` as a float, or ``None`` when the lung extent spans
        a single column (zero width).
    """
    # Only the column coordinate of each labelled pixel matters here.
    _, lung_cols = np.nonzero((mask == 1) | (mask == 2))
    _, heart_cols = np.nonzero(mask == 3)

    # Guard: without both structures there is nothing to measure.
    if not lung_cols.size or not heart_cols.size:
        return None, None, None, None, None

    thorax_left, thorax_right = int(lung_cols.min()), int(lung_cols.max())
    heart_left, heart_right = int(heart_cols.min()), int(heart_cols.max())

    thorax_width = thorax_right - thorax_left
    heart_width = heart_right - heart_left
    ratio = float(heart_width / thorax_width) if thorax_width != 0 else None

    return ratio, thorax_left, thorax_right, heart_left, heart_right
def _run_model(image):
    """Run the model on one PIL image and derive the CTR measurements.

    Shared by the "Viewer" (``analyze``) and "JSON API" (``get_points``)
    handlers.

    Args:
        image: PIL image; converted to 8-bit grayscale (mode "L") first.

    Returns:
        A 12-tuple in exactly this order (both callers unpack it
        positionally, so the order is part of the contract):

        ``(img, mask, h, w, ctr, thorax_left, thorax_right, heart_left,
        heart_right, view_idx, age_pred, female_prob)``

        - img: grayscale image as a NumPy array, shape (h, w).
        - mask: argmax label map resized back to (h, w) with
          nearest-neighbour interpolation, so label values stay intact.
        - h, w: height and width of the input image in pixels.
        - ctr, thorax_left, thorax_right, heart_left, heart_right: the
          result of ``calculate_ctr(mask)``; any of them may be None.
        - view_idx: int argmax of the model's "view" output.
        - age_pred: the model's "age" output as a float.
        - female_prob: the model's "female" output as a float. Callers
          threshold it at 0.5 as a probability; presumably the model
          already applies a sigmoid — TODO confirm against the model card.
    """
    img = np.array(image.convert("L"))
    h, w = img.shape[:2]

    x = model.preprocess(img)
    # Add two leading singleton dims (presumably batch and channel).
    x = torch.from_numpy(x).unsqueeze(0).unsqueeze(0).float()

    with torch.inference_mode():
        out = model(x.to(device))

    # Argmax over dim 1 (presumably the class dim), take the single batch
    # item, then upsample to the input size. cv2.resize takes dsize as
    # (width, height); INTER_NEAREST avoids inventing in-between labels.
    mask_small = out["mask"].argmax(1)[0].cpu().numpy()
    mask = cv2.resize(mask_small.astype("uint8"), (w, h), interpolation=cv2.INTER_NEAREST)

    view_idx = out["view"].argmax(1).item()
    age_pred = float(out["age"].item())
    female_prob = float(out["female"].item())

    ctr, thorax_left, thorax_right, heart_left, heart_right = calculate_ctr(mask)
    return (
        img,
        mask,
        h,
        w,
        ctr,
        thorax_left,
        thorax_right,
        heart_left,
        heart_right,
        view_idx,
        age_pred,
        female_prob,
    )
# ---------- 1) Visual demo: segmentation overlay + text summary ----------
def analyze(image):
    """Gradio handler for the "Viewer" tab.

    Args:
        image: PIL image from the upload widget, or ``None``.

    Returns:
        ``(blended, text)`` where ``blended`` is an RGB array with the label
        colours alpha-blended over the X-ray (0.7 image / 0.3 overlay) and
        ``text`` is a newline-joined summary of CTR, predicted view, age and
        sex followed by disclaimers. Returns ``(None, "No image uploaded.")``
        when no image was provided.
    """
    if image is None:
        return None, "No image uploaded."

    img, mask, _h, _w, ctr, *_extents, view_idx, age_pred, female_prob = _run_model(image)

    # Paint each label (1/2 = lung, 3 = heart, as used by calculate_ctr) on a
    # copy of the RGB image, then blend the painted copy back over it.
    base = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
    painted = base.copy()
    label_colours = (
        (1, [0, 255, 0]),
        (2, [0, 128, 255]),
        (3, [255, 0, 0]),
    )
    for label, rgb in label_colours:
        painted[mask == label] = rgb
    blended = cv2.addWeighted(base, 0.7, painted, 0.3, 0)

    view = {0: "AP", 1: "PA", 2: "lateral"}.get(view_idx, "unknown")

    if ctr is None:
        ctr_line = "CTR: could not be reliably calculated (segmentation issue)."
    else:
        ctr_line = f"CTR: {ctr:.2f}"
    sex = "Female" if female_prob >= 0.5 else "Male"

    report = [
        ctr_line,
        f"View (model): {view}",
        f"Predicted age: {age_pred:.0f} years",
        f"Predicted sex: {sex} (prob={female_prob:.2f})",
        "",
        "⚠️ Research/educational use only, NOT for clinical decision-making.",
    ]
    if view != "PA":
        report.append("⚠️ CTR is normally interpreted on PA view. Interpret with caution.")

    return blended, "\n".join(report)
# "Viewer" tab: upload an X-ray, get the colour overlay and a text report.
# String options come first; the input component is still created before the
# output components, as in the original construction order.
visual_demo = gr.Interface(
    title="AI CTR helper (research only)",
    description="Segments heart and lungs and estimates CTR using 'ianpan/chest-x-ray-basic'. Research use only.",
    fn=analyze,
    inputs=gr.Image(type="pil", label="Chest X-ray (PNG/JPG) – frontal view"),
    outputs=[
        gr.Image(label="Segmentation overlay"),
        gr.Textbox(label="AI output"),
    ],
)
# ---------- 2) JSON points API (consumed by the external Lovable front-end app) ----------
def get_points(image):
    """Gradio handler for the "JSON API" tab.

    Args:
        image: PIL image from the upload widget, or ``None``.

    Returns:
        A dict with keys ``image_width``, ``image_height``, ``ctr``,
        ``thorax_left_px``, ``thorax_right_px``, ``heart_left_px``,
        ``heart_right_px`` and ``view_idx``. The CTR and extent values come
        from ``_run_model`` and may be None when segmentation fails.
        Returns ``{"error": "No image uploaded"}`` when no image was given.
    """
    if image is None:
        return {"error": "No image uploaded"}

    (
        _img,
        _mask,
        height,
        width,
        ratio,
        t_left,
        t_right,
        h_left,
        h_right,
        view,
        _age,
        _female,
    ) = _run_model(image)

    return dict(
        image_width=width,
        image_height=height,
        ctr=ratio,
        thorax_left_px=t_left,
        thorax_right_px=t_right,
        heart_left_px=h_left,
        heart_right_px=h_right,
        view_idx=int(view),
    )
# "JSON API" tab: same model, but returns raw measurements as JSON.
# String options come first; the input component is still created before the
# output component, as in the original construction order.
points_api = gr.Interface(
    title="CTR points API",
    description="Returns thorax/heart x-coordinates and CTR as JSON.",
    # Fixed endpoint name so programmatic clients (e.g. gradio_client) can
    # call this handler by name instead of the default one.
    api_name="ctr_points",
    fn=get_points,
    inputs=gr.Image(type="pil", label="Chest X-ray (PNG/JPG) – frontal view"),
    outputs=gr.JSON(label="CTR points JSON"),
)
# Combine both interfaces into one app, one tab each.
demo = gr.TabbedInterface(
    interface_list=[visual_demo, points_api],
    tab_names=["Viewer", "JSON API"],
)

# Start the server only when run as a script, not on import.
if __name__ == "__main__":
    demo.launch()