leonelhs commited on
Commit
12b5ec5
·
1 Parent(s): 1082d5a

change to onnx

Browse files
Files changed (5) hide show
  1. .gitignore +4 -2
  2. README.md +1 -1
  3. app-onnx.py +118 -0
  4. export_onnx.py +32 -0
  5. requirements.txt +10 -2
.gitignore CHANGED
@@ -157,8 +157,10 @@ cython_debug/
157
  # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158
  # and can be added to the global gitignore or merged into this file. For a more nuclear
159
  # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
- #.idea/
161
 
162
  *.jpeg
163
  *.png
164
- .DS_Store
 
 
 
157
  # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158
  # and can be added to the global gitignore or merged into this file. For a more nuclear
159
  # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
+ .idea/
161
 
162
  *.jpeg
163
  *.png
164
+ .DS_Store
165
+
166
+ saved_models/*
README.md CHANGED
@@ -5,7 +5,7 @@ colorFrom: red
5
  colorTo: red
6
  sdk: gradio
7
  sdk_version: 5.49.0
8
- app_file: app.py
9
  pinned: false
10
  license: openrail
11
  short_description: Removes background using DIS
 
5
  colorTo: red
6
  sdk: gradio
7
  sdk_version: 5.49.0
8
+ app_file: app-onnx.py
9
  pinned: false
10
  license: openrail
11
  short_description: Removes background using DIS
app-onnx.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #######################################################################################
2
+ #
3
+ # MIT License
4
+ #
5
+ # Copyright (c) [2025] [leonelhs@gmail.com]
6
+ #
7
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
8
+ # of this software and associated documentation files (the "Software"), to deal
9
+ # in the Software without restriction, including without limitation the rights
10
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11
+ # copies of the Software, and to permit persons to whom the Software is
12
+ # furnished to do so, subject to the following conditions:
13
+ #
14
+ # The above copyright notice and this permission notice shall be included in all
15
+ # copies or substantial portions of the Software.
16
+ #
17
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23
+ # SOFTWARE.
24
+ #
25
+ #######################################################################################
26
+
27
+ # This file implements an API endpoint for DIS background image removal system.
28
+ # [Self space] - [https://huggingface.co/spaces/leonelhs/removebg]
29
+ #
30
+ # Source code is based on or inspired by several projects.
31
+ # For more details and proper attribution, please refer to the following resources:
32
+ #
33
+ # - [DIS] - [https://github.com/xuebinqin/DIS]
34
+ # - [removebg] - [https://huggingface.co/spaces/gaviego/removebg]
35
+ # https://github.com/gaurav0651/dis-bg-remover
36
+
37
+ from itertools import islice
38
+
39
+ import cv2
40
+ import gradio as gr
41
+ import numpy as np
42
+ import onnxruntime as ort
43
+ from PIL import Image
44
+ from huggingface_hub import hf_hub_download
45
+
46
+ REPO_ID = "leonelhs/removators"
47
+
48
+ # Load the ONNX model
49
+ model_path = hf_hub_download(repo_id=REPO_ID, filename='isnet.onnx')
50
+
51
+ session = ort.InferenceSession(model_path)
52
+
53
+ def normalize(image, mean, std):
54
+ """Normalize a numpy image with mean and standard deviation."""
55
+ return (image / 255.0 - mean) / std
56
+
57
+ def predict(image_path):
58
+ input_size = (1024, 1024)
59
+
60
+ img = cv2.imread(image_path, cv2.IMREAD_COLOR)
61
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # Convert from BGR to RGB if using OpenCV
62
+
63
+ # If image is grayscale, convert to RGB
64
+ if len(img.shape) == 2:
65
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
66
+
67
+ # Normalize the image using NumPy
68
+ img = img.astype(np.float32) # Convert to float
69
+ im_normalized = normalize(img, mean=[0.5, 0.5, 0.5], std=[1.0, 1.0, 1.0])
70
+
71
+ # Resize the image
72
+ img_resized = cv2.resize(im_normalized, input_size, interpolation=cv2.INTER_LINEAR)
73
+ img_resized = np.transpose(img_resized, (2, 0, 1)) # CHW format
74
+ img_resized = np.expand_dims(img_resized, axis=0) # Add batch dimension
75
+
76
+ # Run inference
77
+ img_resized = img_resized.astype(np.float32)
78
+ ort_inputs = {session.get_inputs()[0].name: img_resized}
79
+ prediction = session.run(None, ort_inputs)
80
+
81
+ # Process the model output
82
+ result = prediction[0][0] # Assuming single output and single batch
83
+ result = np.clip(result, 0, 1) # Assuming you want to clip the result to [0, 1]
84
+ result = (result * 255).astype(np.uint8) # Rescale to [0, 255]
85
+ result = np.transpose(result, (1, 2, 0)) # HWC format
86
+ # Resize to original shape
87
+ original_shape = img.shape[:2]
88
+ return cv2.resize(result, (original_shape[1], original_shape[0]), interpolation=cv2.INTER_LINEAR)
89
+
90
+
91
+ def cuts(image):
92
+ mask = predict(image)
93
+ mask = Image.fromarray(mask).convert('L')
94
+ cutted = Image.open(image).convert("RGB")
95
+ cutted.putalpha(mask)
96
+ return [image, cutted], mask
97
+
98
+ with gr.Blocks(title="DIS") as app:
99
+ navbar = gr.Navbar(visible=True, main_page_name="Workspace")
100
+ gr.Markdown("## Dichotomous Image Segmentation")
101
+ with gr.Row():
102
+ with gr.Column(scale=1):
103
+ inp_image = gr.Image(type="filepath", label="Upload Image")
104
+ btn_predict = gr.Button(variant="primary", value="Remove background")
105
+ with gr.Column(scale=2):
106
+ with gr.Row():
107
+ preview = gr.ImageSlider(type="filepath", label="Comparer")
108
+
109
+ btn_predict.click(cuts, inputs=[inp_image], outputs=[preview, inp_image])
110
+
111
+ with app.route("Readme", "/readme"):
112
+ with open("README.md") as f:
113
+ for line in islice(f, 12, None):
114
+ gr.Markdown(line.strip())
115
+
116
+
117
+ app.launch(share=False, debug=True, show_error=True, mcp_server=True, pwa=True)
118
+ app.queue()
export_onnx.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from huggingface_hub import hf_hub_download
3
+ from models.isnet import ISNetDIS
4
+
5
+ REPO_ID = "leonelhs/removators"
6
+
7
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
8
+
9
+ net = ISNetDIS()
10
+
11
+ model_path = hf_hub_download(repo_id=REPO_ID, filename='isnet.pth')
12
+ net.load_state_dict(torch.load(model_path, map_location=device))
13
+ net.to(device)
14
+ net.eval()
15
+
16
+ dummy_input = torch.ones(1, 3, 1024, 1024)
17
+
18
+ # Export the model
19
+ torch.onnx.export(
20
+ net, # model
21
+ dummy_input, # example input
22
+ "linear_model.onnx", # output file
23
+ input_names=["input"], # name inputs
24
+ output_names=["output"], # name outputs
25
+ dynamic_axes={ # allow variable batch size
26
+ "input": {0: "batch_size"},
27
+ "output": {0: "batch_size"}
28
+ },
29
+ opset_version=17 # ONNX version
30
+ )
31
+
32
+ print("Model exported to linear_model.onnx")
requirements.txt CHANGED
@@ -1,2 +1,10 @@
1
- torch>=2.8.0
2
- torchvision>=0.23.0
 
 
 
 
 
 
 
 
 
1
+ # Enable only for pythorch app.py
2
+ # torch>=2.8.0
3
+ # torchvision>=0.23.0
4
+
5
+ # Requirements for app-onnx.py
6
+ numpy==2.2.6
7
+ onnxruntime==1.22.1
8
+ opencv-python-headless==4.12.0.88
9
+ pillow==11.3.0
10
+