"""Gradio demo: semantic segmentation of coral-reef images with SegFormer.

The app loads a fine-tuned SegFormer checkpoint once at import time, runs it
on an uploaded image, and shows the colour-coded mask, a blended overlay, a
legend plot and an optional single-class highlight in a gallery.
"""

import io

import gradio as gr
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor

DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
MODEL_ID = "EPFL-ECEO/segformer-b5-finetuned-coralscapes-1024-1024"

# The processor and model are loaded once, at import time, and shared by
# every request handled by the app.
processor = SegformerImageProcessor.from_pretrained(MODEL_ID)
model = SegformerForSemanticSegmentation.from_pretrained(MODEL_ID).to(DEVICE)
model.eval()

# Class id (as a string) -> human-readable class name. Id 0 is not listed
# here; it is treated as "unlabeled" when LABELS_LIST is built below.
id2label = {
    "1": "seagrass",
    "2": "trash",
    "3": "other coral dead",
    "4": "other coral bleached",
    "5": "sand",
    "6": "other coral alive",
    "7": "human",
    "8": "transect tools",
    "9": "fish",
    "10": "algae covered substrate",
    "11": "other animal",
    "12": "unknown hard substrate",
    "13": "background",
    "14": "dark",
    "15": "transect line",
    "16": "massive/meandering bleached",
    "17": "massive/meandering alive",
    "18": "rubble",
    "19": "branching bleached",
    "20": "branching dead",
    "21": "millepora",
    "22": "branching alive",
    "23": "massive/meandering dead",
    "24": "clam",
    "25": "acropora alive",
    "26": "sea cucumber",
    "27": "turbinaria",
    "28": "table acropora alive",
    "29": "sponge",
    "30": "anemone",
    "31": "pocillopora alive",
    "32": "table acropora dead",
    "33": "meandering bleached",
    "34": "stylophora alive",
    "35": "sea urchin",
    "36": "meandering alive",
    "37": "meandering dead",
    "38": "crown of thorn",
    "39": "dead clam",
}
# Class name -> integer class id (inverse of id2label).
label2id = {label: int(id_) for id_, label in id2label.items()}

# Class name -> RGB colour used when painting the segmentation mask.
label2color = {
    "human": [255, 0, 0],
    "background": [29, 162, 216],
    "fish": [255, 255, 0],
    "sand": [194, 178, 128],
    "rubble": [161, 153, 128],
    "unknown hard substrate": [125, 125, 125],
    "algae covered substrate": [125, 163, 125],
    "dark": [31, 31, 31],
    "branching bleached": [252, 231, 240],
    "branching dead": [123, 50, 86],
    "branching alive": [226, 91, 157],
    "stylophora alive": [255, 111, 194],
    "pocillopora alive": [255, 146, 150],
    "acropora alive": [236, 128, 255],
    "table acropora alive": [189, 119, 255],
    "table acropora dead": [85, 53, 116],
    "millepora": [244, 150, 115],
    "turbinaria": [228, 255, 119],
    "other coral bleached": [250, 224, 225],
    "other coral dead": [114, 60, 61],
    "other coral alive": [224, 118, 119],
    "massive/meandering alive": [236, 150, 21],
    "massive/meandering dead": [134, 86, 18],
    "massive/meandering bleached": [255, 248, 228],
    "meandering alive": [230, 193, 0],
    "meandering dead": [119, 100, 14],
    "meandering bleached": [251, 243, 216],
    "transect line": [0, 255, 0],
    "transect tools": [8, 205, 12],
    "sea urchin": [0, 142, 255],
    "sea cucumber": [0, 231, 255],
    "anemone": [0, 255, 189],
    "sponge": [240, 80, 80],
    "clam": [189, 255, 234],
    "other animal": [0, 255, 255],
    "trash": [255, 0, 134],
    "seagrass": [125, 222, 125],
    "crown of thorn": [179, 245, 234],
    "dead clam": [89, 155, 134],
}

# Index i of LABELS_LIST / COLORMAP corresponds to class id i; id 0 is the
# black "unlabeled" entry.
# NOTE(review): assumes the checkpoint's output class indices follow this same
# 0..39 numbering -- confirm against the model config.
LABELS_LIST = ["unlabeled"] + [id2label[str(i)] for i in range(1, 40)]
COLORMAP = np.asarray(
    [[0, 0, 0]] + [label2color[label] for label in LABELS_LIST[1:]],
    dtype=np.uint8,
)

# Defaults shared by the UI widgets and by the inference function, so the
# function can also be invoked with the image alone (as gr.Examples does).
DEFAULT_OVERLAY_OPACITY = 0.5
DEFAULT_FOCUS_CLASS = "seagrass"


def label_to_color_image(label):
    """Map a 2-D array of class ids to an RGB image using COLORMAP.

    Args:
        label: 2-D integer array of class ids.

    Returns:
        uint8 array of shape ``label.shape + (3,)``.

    Raises:
        ValueError: if ``label`` is not 2-D or contains an id that is not
            covered by COLORMAP.
    """
    if label.ndim != 2:
        raise ValueError("Expect 2-D input label")
    if np.max(label) >= len(COLORMAP):
        raise ValueError("label value too large.")
    return COLORMAP[label]


def draw_plot(pred_img, seg_np):
    """Build a figure with the overlay image and a legend of present classes.

    Axes are created and drawn through the figure object rather than through
    pyplot's implicit "current figure", so concurrent requests cannot draw
    onto each other's figures.

    Args:
        pred_img: image array shown in the left panel.
        seg_np: 2-D array of predicted class ids; only the ids occurring in
            it appear in the legend.

    Returns:
        The matplotlib figure. The caller is responsible for closing it.
    """
    fig = plt.figure(figsize=(20, 15))
    grid_spec = gridspec.GridSpec(1, 2, width_ratios=[6, 1])

    image_ax = fig.add_subplot(grid_spec[0])
    image_ax.imshow(pred_img)
    image_ax.axis("off")

    label_names = np.asarray(LABELS_LIST)
    # Column vector of every class id -> (num_classes, 1, 3) colour strip.
    full_label_map = np.arange(len(label_names)).reshape(len(label_names), 1)
    full_color_map = label_to_color_image(full_label_map)

    unique_labels = np.unique(seg_np.astype("uint8"))
    legend_colors = full_color_map[unique_labels].astype(np.uint8)

    legend_ax = fig.add_subplot(grid_spec[1])
    legend_ax.imshow(legend_colors, interpolation="nearest")
    legend_ax.yaxis.tick_right()
    legend_ax.set_yticks(list(range(len(unique_labels))))
    legend_ax.set_yticklabels(label_names[unique_labels].tolist())
    legend_ax.set_xticks([])
    legend_ax.tick_params(width=0.0, labelsize=25)
    return fig


def _figure_to_pil(fig):
    """Render ``fig`` to an in-memory PNG and return it as a PIL image.

    The figure is always closed, even when rendering fails, so figures do not
    accumulate in pyplot's registry across requests.
    """
    try:
        buf = io.BytesIO()
        fig.savefig(buf, format="png", bbox_inches="tight")
        buf.seek(0)
        # .copy() detaches the image from the buffer before it goes away.
        return Image.open(buf).copy()
    finally:
        plt.close(fig)


def run_inference_with_components(
    input_img: np.ndarray,
    overlay_opacity: float = DEFAULT_OVERLAY_OPACITY,
    focus_class: str = DEFAULT_FOCUS_CLASS,
):
    """Segment an image and assemble the gallery shown in the UI.

    Args:
        input_img: image as a numpy array, or ``None`` when nothing has been
            uploaded.
        overlay_opacity: blend weight of the colour mask; values outside
            [0, 1] are clamped.
        focus_class: class name to highlight on its own; names that are not
            in ``label2id`` are ignored.

    Returns:
        A ``(plot, gallery)`` tuple. ``plot`` is always ``None`` (it feeds a
        hidden placeholder component); ``gallery`` is a list of
        ``(image, caption)`` pairs, empty when ``input_img`` is ``None``.
    """
    if input_img is None:
        return None, []

    # Keep the blend weights inside [0, 1]; anything else would wrap around
    # when the blended image is cast back to uint8.
    overlay_opacity = float(np.clip(overlay_opacity, 0.0, 1.0))

    img = Image.fromarray(input_img.astype(np.uint8))
    if img.mode != "RGB":
        img = img.convert("RGB")

    inputs = processor(images=img, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits

    # PIL reports (width, height); interpolate expects (height, width).
    upsampled = torch.nn.functional.interpolate(
        logits,
        size=img.size[::-1],
        mode="bilinear",
        align_corners=False,
    )
    seg_np = upsampled.argmax(dim=1)[0].cpu().numpy().astype(np.uint8)

    color_seg_full = COLORMAP[seg_np]
    original_img_np = np.array(img)
    pred_img = (
        original_img_np * (1 - overlay_opacity) + color_seg_full * overlay_opacity
    ).astype(np.uint8)

    legend_img_pil = _figure_to_pil(draw_plot(pred_img, seg_np))

    gallery_outputs = [
        (original_img_np, "Original Image"),
        (color_seg_full, "Segmentation Mask (Color Map)"),
        (pred_img, f"Overlayed Result (Opacity: {overlay_opacity:.2f})"),
        (legend_img_pil, "Overlay + Legend Plot"),
    ]

    focus_class_id = label2id.get(focus_class)
    if focus_class_id is not None:
        # Blend the class colour into the pixels predicted as that class only.
        highlight_mask = seg_np == focus_class_id
        highlight_color = COLORMAP[focus_class_id]
        highlighted_overlay = original_img_np.copy()
        highlighted_overlay[highlight_mask] = (
            highlighted_overlay[highlight_mask] * (1 - overlay_opacity)
            + highlight_color * overlay_opacity
        ).astype(np.uint8)
        gallery_outputs.append((highlighted_overlay, f"Highlighted: {focus_class}"))

    return None, gallery_outputs


with gr.Blocks(
    css=".output-image {height: 500px !important;} .gradio-container {max-width: 1200px;}"
) as demo:
    gr.Markdown("# Coral Segmentation")

    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(type="numpy", label="Input Image", height=300)
            run_button = gr.Button("Run Segmentation", variant="primary")
            overlay_opacity_slider = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=DEFAULT_OVERLAY_OPACITY,
                step=0.05,
                label="Overlay Opacity (Mask Blending)",
                interactive=True,
            )
            focus_class_dropdown = gr.Dropdown(
                choices=LABELS_LIST[1:],
                value=DEFAULT_FOCUS_CLASS,
                label="Select Class to Highlight",
                interactive=True,
            )

        with gr.Column(scale=2):
            output_gallery = gr.Gallery(
                label="Segmentation Results (Images and Legend Plot)",
                height=600,
                preview=True,
                columns=2,
                object_fit="contain",
            )
            # Hidden placeholder that receives the function's first return value.
            dummy_plot = gr.Plot(visible=False)

    run_button.click(
        fn=run_inference_with_components,
        inputs=[input_image, overlay_opacity_slider, focus_class_dropdown],
        outputs=[dummy_plot, output_gallery],
        queue=True,
    )

    # Only the image is wired in as an input here, so clicking an example
    # calls the function with one argument and relies on its defaults for
    # opacity and focus class.
    gr.Examples(
        examples=[
            "coral_sample_1.png",
            "coral_sample_2.png",
            "coral_sample_3.png",
        ],
        inputs=input_image,
        outputs=[dummy_plot, output_gallery],
        fn=run_inference_with_components,
        cache_examples=False,
        run_on_click=True,
    )

if __name__ == "__main__":
    demo.launch()