# mlhw6-task3 / app.py
import gradio as gr
import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
import io
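
# Select the device and load the Coralscapes-finetuned SegFormer checkpoint.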
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
MODEL_ID = "EPFL-ECEO/segformer-b5-finetuned-coralscapes-1024-1024"
processor = SegformerImageProcessor.from_pretrained(MODEL_ID)
model = SegformerForSemanticSegmentation.from_pretrained(MODEL_ID).to(DEVICE)
model.eval()
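
# Class-id -> name mapping for the 39 Coralscapes classes (id 0 is reserved for "unlabeled").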
id2label = {
    "1": "seagrass", "2": "trash", "3": "other coral dead", "4": "other coral bleached",
    "5": "sand", "6": "other coral alive", "7": "human", "8": "transect tools",
    "9": "fish", "10": "algae covered substrate", "11": "other animal", "12": "unknown hard substrate",
    "13": "background", "14": "dark", "15": "transect line", "16": "massive/meandering bleached",
    "17": "massive/meandering alive", "18": "rubble", "19": "branching bleached", "20": "branching dead",
    "21": "millepora", "22": "branching alive", "23": "massive/meandering dead", "24": "clam",
    "25": "acropora alive", "26": "sea cucumber", "27": "turbinaria", "28": "table acropora alive",
    "29": "sponge", "30": "anemone", "31": "pocillopora alive", "32": "table acropora dead",
    "33": "meandering bleached", "34": "stylophora alive", "35": "sea urchin", "36": "meandering alive",
    "37": "meandering dead", "38": "crown of thorn", "39": "dead clam"
}
label2id = {label: int(id_) for id_, label in id2label.items()}
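
# RGB color used to render each class in the segmentation mask.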
label2color = {
    "human": [255, 0, 0], "background": [29, 162, 216], "fish": [255, 255, 0],
    "sand": [194, 178, 128], "rubble": [161, 153, 128], "unknown hard substrate": [125, 125, 125],
    "algae covered substrate": [125, 163, 125], "dark": [31, 31, 31],
    "branching bleached": [252, 231, 240], "branching dead": [123, 50, 86],
    "branching alive": [226, 91, 157], "stylophora alive": [255, 111, 194],
    "pocillopora alive": [255, 146, 150], "acropora alive": [236, 128, 255],
    "table acropora alive": [189, 119, 255], "table acropora dead": [85, 53, 116],
    "millepora": [244, 150, 115], "turbinaria": [228, 255, 119],
    "other coral bleached": [250, 224, 225], "other coral dead": [114, 60, 61],
    "other coral alive": [224, 118, 119], "massive/meandering alive": [236, 150, 21],
    "massive/meandering dead": [134, 86, 18], "massive/meandering bleached": [255, 248, 228],
    "meandering alive": [230, 193, 0], "meandering dead": [119, 100, 14],
    "meandering bleached": [251, 243, 216], "transect line": [0, 255, 0],
    "transect tools": [8, 205, 12], "sea urchin": [0, 142, 255],
    "sea cucumber": [0, 231, 255], "anemone": [0, 255, 189], "sponge": [240, 80, 80],
    "clam": [189, 255, 234], "other animal": [0, 255, 255], "trash": [255, 0, 134],
    "seagrass": [125, 222, 125], "crown of thorn": [179, 245, 234], "dead clam": [89, 155, 134]
}
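
# Index 0 ("unlabeled") maps to black; indices 1-39 follow the Coralscapes palette above.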
LABELS_LIST = ["unlabeled"] + [id2label[str(i)] for i in range(1, 40)]
COLORMAP = np.asarray([[0, 0, 0]] + [label2color[label] for label in LABELS_LIST[1:]], dtype=np.uint8)

def label_to_color_image(label):
    """Map a 2-D array of class ids to an RGB image using COLORMAP."""
    if label.ndim != 2:
        raise ValueError("Expect 2-D input label")
    if np.max(label) >= len(COLORMAP):
        raise ValueError("label value too large.")
    return COLORMAP[label]

def draw_plot(pred_img, seg_np):
    """Render the blended prediction next to a legend of the classes present in the mask."""
    fig = plt.figure(figsize=(20, 15))
    grid_spec = gridspec.GridSpec(1, 2, width_ratios=[6, 1])
    # Left panel: the blended prediction image.
    plt.subplot(grid_spec[0])
    plt.imshow(pred_img)
    plt.axis('off')
    # Right panel: one color swatch per class that actually appears in the mask.
    LABEL_NAMES = np.asarray(LABELS_LIST)
    FULL_LABEL_MAP = np.arange(len(LABEL_NAMES)).reshape(len(LABEL_NAMES), 1)
    FULL_COLOR_MAP = label_to_color_image(FULL_LABEL_MAP)
    unique_labels = np.unique(seg_np.astype("uint8"))
    ax = plt.subplot(grid_spec[1])
    legend_colors = FULL_COLOR_MAP[unique_labels].astype(np.uint8)
    plt.imshow(legend_colors, interpolation="nearest")
    ax.yaxis.tick_right()
    plt.yticks(range(len(unique_labels)), LABEL_NAMES[unique_labels])
    plt.xticks([], [])
    ax.tick_params(width=0.0, labelsize=25)
    return fig

def run_inference_with_components(input_img: np.ndarray, overlay_opacity: float, focus_class: str):
    """Segment the input image and return gallery items: original, mask, overlay, legend, optional highlight."""
    if input_img is None:
        return None, []
    img = Image.fromarray(input_img.astype(np.uint8))
    if img.mode != "RGB":
        img = img.convert("RGB")
    # Forward pass through SegFormer.
    inputs = processor(images=img, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
    # Upsample logits back to the input resolution; PIL size is (W, H), interpolate expects (H, W).
    upsampled = torch.nn.functional.interpolate(
        logits,
        size=img.size[::-1],
        mode="bilinear",
        align_corners=False
    )
    seg_np = upsampled.argmax(dim=1)[0].cpu().numpy().astype(np.uint8)
    # Colorize the class map and blend it with the original image.
    color_seg_full = COLORMAP[seg_np]
    original_img_np = np.array(img)
    pred_img = (original_img_np * (1 - overlay_opacity) + color_seg_full * overlay_opacity).astype(np.uint8)
    # Render the overlay + legend figure and convert it to a PIL image for the gallery.
    matplotlib_plot_fig = draw_plot(pred_img, seg_np)
    buf = io.BytesIO()
    matplotlib_plot_fig.savefig(buf, format='png', bbox_inches='tight')
    buf.seek(0)
    legend_img_pil = Image.open(buf).copy()
    plt.close(matplotlib_plot_fig)
    gallery_outputs = [
        (original_img_np, "Original Image"),
        (color_seg_full, "Segmentation Mask (Color Map)"),
        (pred_img, f"Overlayed Result (Opacity: {overlay_opacity:.2f})"),
        (legend_img_pil, "Overlay + Legend Plot")
    ]
    # Optionally highlight only the selected class on top of the original image.
    if focus_class in label2id:
        focus_class_id = label2id[focus_class]
        if focus_class_id < len(LABELS_LIST):
            highlight_mask = (seg_np == focus_class_id)
            highlight_color = COLORMAP[focus_class_id]
            highlighted_overlay = original_img_np.copy()
            highlighted_overlay[highlight_mask] = (
                highlighted_overlay[highlight_mask] * (1 - overlay_opacity) +
                highlight_color * overlay_opacity
            ).astype(np.uint8)
            gallery_outputs.append((highlighted_overlay, f"Highlighted: {focus_class}"))
    return None, gallery_outputs
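
# Gradio UI: image input and controls on the left, result gallery on the right.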
with gr.Blocks(css=".output-image {height: 500px !important;} .gradio-container {max-width: 1200px;}") as demo:
    gr.Markdown("# Coral Segmentation")
    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(type="numpy", label="Input Image", height=300)
            run_button = gr.Button("Run Segmentation", variant="primary")
            overlay_opacity_slider = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=0.5,
                step=0.05,
                label="Overlay Opacity (Mask Blending)",
                interactive=True
            )
            focus_class_dropdown = gr.Dropdown(
                choices=LABELS_LIST[1:],
                value="seagrass",
                label="Select Class to Highlight",
                interactive=True
            )
        with gr.Column(scale=2):
            output_gallery = gr.Gallery(
                label="Segmentation Results (Images and Legend Plot)",
                height=600,
                preview=True,
                columns=2,
                object_fit="contain"
            )
    dummy_plot = gr.Plot(visible=False)
    run_button.click(
        fn=run_inference_with_components,
        inputs=[input_image, overlay_opacity_slider, focus_class_dropdown],
        outputs=[dummy_plot, output_gallery],
        queue=True
    )
    gr.Examples(
        examples=[
            "coral_sample_1.png",
            "coral_sample_2.png",
            "coral_sample_3.png"
        ],
        inputs=input_image,
        outputs=[dummy_plot, output_gallery],
        fn=run_inference_with_components,
        cache_examples=False,
        run_on_click=True
    )

if __name__ == "__main__":
    demo.launch()