import gradio as gr
import cv2
import numpy as np
import torch
from transformers import AutoModel

device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModel.from_pretrained("ianpan/chest-x-ray-basic", trust_remote_code=True)
model = model.eval().to(device)
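
# The model is multi-task: a single forward pass returns a dict with "mask"
# (lung/heart segmentation logits), "view" (view-classification logits), and
# "age"/"female" prediction heads, which the code below indexes directly.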

def calculate_ctr(mask):
    # Cardiothoracic ratio: widest horizontal cardiac extent divided by the
    # widest horizontal thoracic (lung) extent, both measured in pixels.
    lungs = np.zeros_like(mask, dtype=np.uint8)
    lungs[(mask == 1) | (mask == 2)] = 1
    heart = (mask == 3).astype("uint8")
    lung_y, lung_x = np.where(lungs == 1)
    heart_y, heart_x = np.where(heart == 1)
    if lung_x.size == 0 or heart_x.size == 0:
        return None, None, None, None, None
    thorax_left = int(lung_x.min())
    thorax_right = int(lung_x.max())
    heart_left = int(heart_x.min())
    heart_right = int(heart_x.max())
    lung_range = thorax_right - thorax_left
    heart_range = heart_right - heart_left
    if lung_range == 0:
        ctr = None
    else:
        ctr = float(heart_range / lung_range)
    return ctr, thorax_left, thorax_right, heart_left, heart_right
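
# Worked example for calculate_ctr (illustrative numbers, not from a real scan):
# with a thoracic extent of [100, 500] px and a cardiac extent of [220, 420] px,
# CTR = (420 - 220) / (500 - 100) = 200 / 400 = 0.50, right at the classic
# cutoff (CTR > 0.5 on a PA film suggests cardiomegaly).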

def _run_model(image):
    """Shared logic: PIL image -> (grayscale image, mask, height, width, CTR,
    thorax/heart x-coordinates, view index, predicted age, female probability)."""
    img = np.array(image.convert("L"))
    h, w = img.shape[:2]
    x = model.preprocess(img)
    x = torch.from_numpy(x).unsqueeze(0).unsqueeze(0).float()
    with torch.inference_mode():
        out = model(x.to(device))
    mask_small = out["mask"].argmax(1)[0].cpu().numpy()
    # Upsample the predicted mask back to the input resolution; nearest-neighbour
    # interpolation keeps the integer class labels intact.
    mask = cv2.resize(mask_small.astype("uint8"), (w, h), interpolation=cv2.INTER_NEAREST)
    view_idx = out["view"].argmax(1).item()
    age_pred = float(out["age"].item())
    female_prob = float(out["female"].item())
    ctr, thorax_left, thorax_right, heart_left, heart_right = calculate_ctr(mask)
    return (
        img,
        mask,
        h,
        w,
        ctr,
        thorax_left,
        thorax_right,
        heart_left,
        heart_right,
        view_idx,
        age_pred,
        female_prob,
    )
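
# Quick local sanity check (sketch; "sample_cxr.png" is a placeholder path):
#
#   from PIL import Image
#   results = _run_model(Image.open("sample_cxr.png"))
#   print("CTR:", results[4])  # index 4 is the CTR in the returned tuple
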
# ---------- 1) Visual demo (what you already have) ----------
def analyze(image):
    if image is None:
        return None, "No image uploaded."
    (
        img,
        mask,
        h,
        w,
        ctr,
        thorax_left,
        thorax_right,
        heart_left,
        heart_right,
        view_idx,
        age_pred,
        female_prob,
    ) = _run_model(image)
    # Tint the segmentation on top of the grayscale image: classes 1 and 2 are
    # the lungs, class 3 is the heart.
    color = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
    overlay = color.copy()
    overlay[mask == 1] = [0, 255, 0]
    overlay[mask == 2] = [0, 128, 255]
    overlay[mask == 3] = [255, 0, 0]
    blended = cv2.addWeighted(color, 0.7, overlay, 0.3, 0)
    view_map = {0: "AP", 1: "PA", 2: "lateral"}
    view = view_map.get(view_idx, "unknown")
    lines = []
    if ctr is not None:
        lines.append(f"CTR: {ctr:.2f}")
    else:
        lines.append("CTR: could not be reliably calculated (segmentation issue).")
    lines.extend([
        f"View (model): {view}",
        f"Predicted age: {age_pred:.0f} years",
        f"Predicted sex: {'Female' if female_prob >= 0.5 else 'Male'} (prob={female_prob:.2f})",
        "",
        "⚠️ Research/educational use only, NOT for clinical decision-making.",
    ])
    if view != "PA":
        lines.append("⚠️ CTR is normally interpreted on a PA view. Interpret with caution.")
    return blended, "\n".join(lines)

visual_demo = gr.Interface(
    fn=analyze,
    inputs=gr.Image(type="pil", label="Chest X-ray (PNG/JPG) – frontal view"),
    outputs=[
        gr.Image(label="Segmentation overlay"),
        gr.Textbox(label="AI output"),
    ],
    title="AI CTR helper (research only)",
    description=(
        "Segments heart and lungs and estimates CTR using 'ianpan/chest-x-ray-basic'. "
        "Research use only."
    ),
)

# ---------- 2) JSON points API (for your Lovable app) ----------
def get_points(image):
    if image is None:
        return {"error": "No image uploaded"}
    (
        img,
        mask,
        h,
        w,
        ctr,
        thorax_left,
        thorax_right,
        heart_left,
        heart_right,
        view_idx,
        age_pred,
        female_prob,
    ) = _run_model(image)
    result = {
        "image_width": w,
        "image_height": h,
        "ctr": ctr,
        "thorax_left_px": thorax_left,
        "thorax_right_px": thorax_right,
        "heart_left_px": heart_left,
        "heart_right_px": heart_right,
        "view_idx": int(view_idx),
    }
    return result
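
# Example response shape (illustrative values only):
#   {"image_width": 1024, "image_height": 1024, "ctr": 0.47,
#    "thorax_left_px": 91, "thorax_right_px": 951,
#    "heart_left_px": 395, "heart_right_px": 799, "view_idx": 1}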

points_api = gr.Interface(
    fn=get_points,
    inputs=gr.Image(type="pil", label="Chest X-ray (PNG/JPG) – frontal view"),
    outputs=gr.JSON(label="CTR points JSON"),
    title="CTR points API",
    description="Returns thorax/heart x-coordinates and CTR as JSON.",
    api_name="ctr_points",  # important for programmatic calls
)
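
# Programmatic call from another app (sketch; "YOUR-SPACE-URL" is a placeholder
# for wherever this app is hosted, and the gradio_client package is required):
#
#   from gradio_client import Client, handle_file
#   client = Client("YOUR-SPACE-URL")
#   points = client.predict(handle_file("chest.png"), api_name="/ctr_points")
#   print(points["ctr"])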

demo = gr.TabbedInterface(
    [visual_demo, points_api],
    ["Viewer", "JSON API"],
)

if __name__ == "__main__":
    demo.launch()