import os
from itertools import islice

import cv2
import numpy as np
import onnxruntime as ort
import gradio as gr
from PIL import Image
from huggingface_hub import hf_hub_download

from pallete import colormap
from labels import cloth_labels, fashion_labels, ADE20k_labels
|
REPO_ID = "leonelhs/segmentators"
|
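# The three SegFormer ONNX exports are pulled from the Hugging Face Hub;
# hf_hub_download caches them locally, so later runs skip the download.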
ADE20k_path = hf_hub_download(repo_id=REPO_ID, filename="segformer/segformer-b5-finetuned-ade-640-640.onnx")
fashion_path = hf_hub_download(repo_id=REPO_ID, filename="segformer/segformer-b3-fashion.onnx")
clothes_path = hf_hub_download(repo_id=REPO_ID, filename="segformer/segformer_b2_clothes.onnx")
|
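# CPU-only ONNX Runtime sessions; let each operator use every available core.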
sess_options = ort.SessionOptions()
sess_options.intra_op_num_threads = os.cpu_count()
|
session_ade20k = ort.InferenceSession(ADE20k_path, sess_options, providers=["CPUExecutionProvider"])
session_cloth = ort.InferenceSession(clothes_path, sess_options, providers=["CPUExecutionProvider"])
session_fashion = ort.InferenceSession(fashion_path, sess_options, providers=["CPUExecutionProvider"])
|
|
def predict(input_img, model="ADE20k"):
    # Pick the session and label set for the chosen model; ADE20k is the default.
    session = session_ade20k
    labels = ADE20k_labels
|
    if model == "Cloth":
        session = session_cloth
        labels = cloth_labels
    elif model == "Fashion":
        session = session_fashion
        labels = fashion_labels
|
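    # Preprocess: resize to the fixed 640x640 model input, convert OpenCV's BGR
    # to RGB, then add a batch axis and move channels first -> (1, 3, 640, 640).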
    img = cv2.imread(input_img)
    img = cv2.resize(img, (640, 640)).astype(np.float32)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img_batch = np.expand_dims(img, axis=0)
    img_batch = np.transpose(img_batch, (0, 3, 1, 2))
|
    inputs = {'input': img_batch}
    logits = session.run(None, inputs)[0]
|
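    # Postprocess: logits come back as (1, num_classes, h, w); take the per-pixel
    # argmax and upsample the class-id map back to 640x640 with nearest-neighbour
    # interpolation so no new label values appear at region borders.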
    logits = np.transpose(logits, (0, 2, 3, 1))
    segmented_mask = np.argmax(logits, axis=-1)[0].astype("float32")
    segmented_mask = cv2.resize(segmented_mask, (640, 640), interpolation=cv2.INTER_NEAREST).astype("uint8")
|
    parts = []
    unique_labels = np.unique(segmented_mask)
|
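    # Build one grayscale mask per detected class; gr.AnnotatedImage takes a base
    # image plus a list of (mask, label) pairs to overlay.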
    for label in unique_labels:
        part = np.where(segmented_mask == label)
        color_seg = np.full((640, 640, 3), 0, dtype=np.uint8)
        color_seg[part[0], part[1], :] = colormap[label]
        color_seg = cv2.cvtColor(color_seg, cv2.COLOR_BGR2GRAY)
        parts.append((color_seg, labels[label]))
|
    return Image.fromarray(img.astype("uint8")), parts
|
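# Quick sanity check of the parser outside the UI (hypothetical image path):
#   base, parts = predict("example.jpg", model="Cloth")
#   for mask, name in parts: print(name, mask.shape)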
|
with gr.Blocks(title="SegFormer") as app:
    navbar = gr.Navbar(visible=True, main_page_name="Workspace")
    gr.Markdown("## SegFormer ONNX")
    with gr.Row():
        with gr.Column(scale=1):
            inp = gr.Image(type="filepath", label="Upload Image")
            mod = gr.Dropdown(choices=["ADE20k", "Cloth", "Fashion"], label="Segmentation model", value="ADE20k")
            btn_predict = gr.Button("Parse")
        with gr.Column(scale=2):
            out = gr.AnnotatedImage(label="Parsed image (annotated)")
|
    btn_predict.click(predict, inputs=[inp, mod], outputs=[out])
|
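# Secondary page at /about: renders README.md, skipping the leading lines
# (presumably the Space's YAML metadata header).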
with app.route("About this", "/about"):
    with open("README.md") as f:
        for line in islice(f, 12, None):
            gr.Markdown(line.strip())
|
# Enable the request queue before launching; queue() called after launch() never takes effect.
app.queue()
app.launch(share=False, debug=True, show_error=True, mcp_server=True, pwa=True)