Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import numpy as np | |
| from PIL import Image | |
| import matplotlib.pyplot as plt | |
| import matplotlib.gridspec as gridspec | |
| from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation | |
| import io | |
# Pick the compute device once at import: the default CUDA device when one is
# available, otherwise the CPU.
_cuda_available = torch.cuda.is_available()
DEVICE = torch.device("cuda" if _cuda_available else "cpu")

# Hub model id of the fine-tuned SegFormer checkpoint. The processor and the
# model are loaded once at import time and reused by every inference request.
MODEL_ID = "EPFL-ECEO/segformer-b5-finetuned-coralscapes-1024-1024"
processor = SegformerImageProcessor.from_pretrained(MODEL_ID)

model = SegformerForSemanticSegmentation.from_pretrained(MODEL_ID)
model = model.to(DEVICE)
# Evaluation mode: this app only runs inference.
model.eval()
# Class names in class-id order. Id 0 is reserved for "unlabeled" (prepended
# in LABELS_LIST below), so the first name here is class id 1.
_CLASS_NAMES = [
    "seagrass", "trash", "other coral dead", "other coral bleached",
    "sand", "other coral alive", "human", "transect tools",
    "fish", "algae covered substrate", "other animal", "unknown hard substrate",
    "background", "dark", "transect line", "massive/meandering bleached",
    "massive/meandering alive", "rubble", "branching bleached", "branching dead",
    "millepora", "branching alive", "massive/meandering dead", "clam",
    "acropora alive", "sea cucumber", "turbinaria", "table acropora alive",
    "sponge", "anemone", "pocillopora alive", "table acropora dead",
    "meandering bleached", "stylophora alive", "sea urchin", "meandering alive",
    "meandering dead", "crown of thorn", "dead clam",
]

# String id ("1".."39") -> class name, and the reverse with integer ids.
id2label = {str(class_id): name for class_id, name in enumerate(_CLASS_NAMES, start=1)}
label2id = {name: class_id for class_id, name in enumerate(_CLASS_NAMES, start=1)}

# RGB colour (0-255 per channel) for each class name. The order of this dict
# does not matter for COLORMAP, which is rebuilt in class-id order below.
label2color = {
    "human": [255, 0, 0],
    "background": [29, 162, 216],
    "fish": [255, 255, 0],
    "sand": [194, 178, 128],
    "rubble": [161, 153, 128],
    "unknown hard substrate": [125, 125, 125],
    "algae covered substrate": [125, 163, 125],
    "dark": [31, 31, 31],
    "branching bleached": [252, 231, 240],
    "branching dead": [123, 50, 86],
    "branching alive": [226, 91, 157],
    "stylophora alive": [255, 111, 194],
    "pocillopora alive": [255, 146, 150],
    "acropora alive": [236, 128, 255],
    "table acropora alive": [189, 119, 255],
    "table acropora dead": [85, 53, 116],
    "millepora": [244, 150, 115],
    "turbinaria": [228, 255, 119],
    "other coral bleached": [250, 224, 225],
    "other coral dead": [114, 60, 61],
    "other coral alive": [224, 118, 119],
    "massive/meandering alive": [236, 150, 21],
    "massive/meandering dead": [134, 86, 18],
    "massive/meandering bleached": [255, 248, 228],
    "meandering alive": [230, 193, 0],
    "meandering dead": [119, 100, 14],
    "meandering bleached": [251, 243, 216],
    "transect line": [0, 255, 0],
    "transect tools": [8, 205, 12],
    "sea urchin": [0, 142, 255],
    "sea cucumber": [0, 231, 255],
    "anemone": [0, 255, 189],
    "sponge": [240, 80, 80],
    "clam": [189, 255, 234],
    "other animal": [0, 255, 255],
    "trash": [255, 0, 134],
    "seagrass": [125, 222, 125],
    "crown of thorn": [179, 245, 234],
    "dead clam": [89, 155, 134],
}

# Index i holds the name of class id i (0 = "unlabeled").
LABELS_LIST = ["unlabeled", *_CLASS_NAMES]

# Row i is the RGB colour of class id i; row 0 ("unlabeled") is black.
COLORMAP = np.array(
    [[0, 0, 0], *(label2color[name] for name in LABELS_LIST[1:])],
    dtype=np.uint8,
)
def label_to_color_image(label, colormap=None):
    """Convert a 2-D map of class ids into a colour image.

    Args:
        label: 2-D integer array (or array-like) of class ids.
        colormap: optional colour lookup table whose row i is the colour of
            class id i. Defaults to the module-level ``COLORMAP``, so existing
            callers behave as before.

    Returns:
        The colormap rows gathered per pixel, with shape
        ``label.shape + colormap.shape[1:]``.

    Raises:
        ValueError: if ``label`` is not 2-D, or holds an id outside
            ``[0, len(colormap))``.
    """
    if colormap is None:
        colormap = COLORMAP
    colormap = np.asarray(colormap)
    label = np.asarray(label)
    if label.ndim != 2:
        raise ValueError("Expect 2-D input label")
    # min/max have no identity on empty arrays; an empty map simply yields an
    # empty image, so only validate when there is something to check.
    if label.size:
        # Negative ids would otherwise silently index from the end of the table.
        if np.min(label) < 0:
            raise ValueError("label value must be non-negative.")
        if np.max(label) >= len(colormap):
            raise ValueError("label value too large.")
    return colormap[label]
def draw_plot(pred_img, seg_np):
    """Build a figure showing ``pred_img`` beside a legend of the classes in ``seg_np``.

    Args:
        pred_img: image array to display (the blended overlay in this app).
        seg_np: 2-D array of class ids. Each distinct id gets one legend row,
            labelled with its name from ``LABELS_LIST`` and coloured via
            ``label_to_color_image``.

    Returns:
        A ``matplotlib.figure.Figure``. It is created without pyplot, so it
        does not need ``plt.close``. Callers may still pass it to
        ``plt.close``.
    """
    # Function-local so this block needs no new module-level import.
    from matplotlib.figure import Figure

    # Build the Figure directly rather than through pyplot. pyplot's
    # "current figure/axes" is process-global state, so requests handled on
    # different threads could otherwise draw into each other's figure.
    fig = Figure(figsize=(20, 15))
    grid = fig.add_gridspec(1, 2, width_ratios=[6, 1])

    image_ax = fig.add_subplot(grid[0])
    image_ax.imshow(pred_img)
    image_ax.axis("off")

    unique_labels = np.unique(seg_np.astype("uint8"))
    # Shape (k, 1, 3): one colour cell per present class, stacked vertically.
    legend_colors = label_to_color_image(unique_labels.reshape(-1, 1))

    legend_ax = fig.add_subplot(grid[1])
    legend_ax.imshow(legend_colors, interpolation="nearest")
    legend_ax.yaxis.tick_right()
    legend_ax.set_yticks(range(len(unique_labels)))
    legend_ax.set_yticklabels(np.asarray(LABELS_LIST)[unique_labels])
    legend_ax.set_xticks([])
    legend_ax.tick_params(width=0.0, labelsize=25)
    return fig
def _predict_mask(img):
    """Return per-pixel class ids for the PIL RGB image ``img`` as an (H, W) uint8 array."""
    inputs = processor(images=img, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        logits = model(**inputs).logits
    # Resize the (batch, classes, h, w) logits to the input image before the
    # argmax so the mask lines up pixel-for-pixel with the image. PIL reports
    # size as (W, H); interpolate expects (H, W).
    upsampled = torch.nn.functional.interpolate(
        logits, size=img.size[::-1], mode="bilinear", align_corners=False
    )
    return upsampled.argmax(dim=1)[0].cpu().numpy().astype(np.uint8)


def _blend(base, overlay, opacity):
    """Alpha-blend ``overlay`` (an image or one broadcast colour) onto ``base``; uint8 result.

    The uint8 cast truncates rather than rounds.
    """
    return (base * (1 - opacity) + overlay * opacity).astype(np.uint8)


def _render_legend_plot(pred_img, seg_np):
    """Render ``draw_plot``'s figure to a PIL image, always closing the figure."""
    fig = draw_plot(pred_img, seg_np)
    try:
        with io.BytesIO() as buf:
            fig.savefig(buf, format="png", bbox_inches="tight")
            buf.seek(0)
            # Image.open is lazy; copy() decodes the PNG while buf is still open.
            with Image.open(buf) as png:
                return png.copy()
    finally:
        # Release the figure even when savefig raises.
        plt.close(fig)


def run_inference_with_components(
    input_img: np.ndarray,
    overlay_opacity: float = 0.5,
    focus_class: "str | None" = None,
):
    """Segment ``input_img`` and assemble the result views for the gallery.

    Args:
        input_img: image array (as produced by ``gr.Image(type="numpy")``), or
            ``None`` when nothing has been uploaded. It is cast to uint8 and
            converted to RGB before inference.
        overlay_opacity: weight of the colour mask when blending it over the
            image, clamped to ``[0, 1]``. Defaults to 0.5 so callers that
            supply only the image (e.g. an examples table) still work.
        focus_class: class name to highlight in an extra view. Names not in
            ``label2id`` (including the default ``None``) skip that view.

    Returns:
        ``(None, gallery)``. The ``None`` fills the hidden plot output.
        ``gallery`` is a list of ``(image, caption)`` pairs: the original image,
        the colour mask, the overlay, the overlay-plus-legend plot, and
        optionally the highlighted class. ``(None, [])`` when ``input_img`` is
        ``None``.
    """
    if input_img is None:
        return None, []

    # The UI slider already bounds this, but direct callers may not. Outside
    # [0, 1] the blend leaves [0, 255] and wraps around in the uint8 cast.
    overlay_opacity = float(np.clip(overlay_opacity, 0.0, 1.0))

    img = Image.fromarray(np.asarray(input_img).astype(np.uint8))
    if img.mode != "RGB":
        img = img.convert("RGB")
    original_img_np = np.array(img)

    seg_np = _predict_mask(img)
    color_seg_full = COLORMAP[seg_np]
    pred_img = _blend(original_img_np, color_seg_full, overlay_opacity)

    gallery_outputs = [
        (original_img_np, "Original Image"),
        (color_seg_full, "Segmentation Mask (Color Map)"),
        (pred_img, f"Overlayed Result (Opacity: {overlay_opacity:.2f})"),
        (_render_legend_plot(pred_img, seg_np), "Overlay + Legend Plot"),
    ]

    if focus_class in label2id:
        focus_class_id = label2id[focus_class]
        highlight_mask = seg_np == focus_class_id
        highlighted_overlay = original_img_np.copy()
        # Tint only the pixels predicted as the focus class, using its colour.
        highlighted_overlay[highlight_mask] = _blend(
            original_img_np[highlight_mask], COLORMAP[focus_class_id], overlay_opacity
        )
        gallery_outputs.append((highlighted_overlay, f"Highlighted: {focus_class}"))

    return None, gallery_outputs
# Initial control values, shared with the example rows below so the examples
# run with the same settings the UI starts with.
DEFAULT_OVERLAY_OPACITY = 0.5
DEFAULT_FOCUS_CLASS = "seagrass"
EXAMPLE_IMAGES = ["coral_sample_1.png", "coral_sample_2.png", "coral_sample_3.png"]

with gr.Blocks(css=".output-image {height: 500px !important;} .gradio-container {max-width: 1200px;}") as demo:
    gr.Markdown("# Coral Segmentation")
    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(type="numpy", label="Input Image", height=300)
            run_button = gr.Button("Run Segmentation", variant="primary")
            overlay_opacity_slider = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=DEFAULT_OVERLAY_OPACITY,
                step=0.05,
                label="Overlay Opacity (Mask Blending)",
                interactive=True,
            )
            focus_class_dropdown = gr.Dropdown(
                choices=LABELS_LIST[1:],
                value=DEFAULT_FOCUS_CLASS,
                label="Select Class to Highlight",
                interactive=True,
            )
        with gr.Column(scale=2):
            output_gallery = gr.Gallery(
                label="Segmentation Results (Images and Legend Plot)",
                height=600,
                preview=True,
                columns=2,
                object_fit="contain",
            )
    # The handler returns (None, gallery); this hidden plot receives the None.
    dummy_plot = gr.Plot(visible=False)

    inference_inputs = [input_image, overlay_opacity_slider, focus_class_dropdown]
    inference_outputs = [dummy_plot, output_gallery]

    run_button.click(
        fn=run_inference_with_components,
        inputs=inference_inputs,
        outputs=inference_outputs,
        queue=True,
    )

    # With run_on_click=True, Gradio calls ``fn`` with exactly one value per
    # component listed in ``inputs``. The handler takes three arguments, so the
    # examples list all three inputs and each row supplies a value for each.
    gr.Examples(
        examples=[
            [path, DEFAULT_OVERLAY_OPACITY, DEFAULT_FOCUS_CLASS]
            for path in EXAMPLE_IMAGES
        ],
        inputs=inference_inputs,
        outputs=inference_outputs,
        fn=run_inference_with_components,
        cache_examples=False,
        run_on_click=True,
    )

if __name__ == "__main__":
    demo.launch()