AndyRaoTHU committed on
Commit
1388f28
·
1 Parent(s): e7b8988
Files changed (1) hide show
  1. app.py +114 -100
app.py CHANGED
@@ -111,16 +111,16 @@ def draw_optvq(data, code):
111
  image = Image.open(buf)
112
  return image
113
 
114
- # def draw_process(x, y, std):
115
- # data = torch.randn(N_data, dim)
116
- # code = torch.randn(N_code, dim) * std
117
- # code[:, 0] += x
118
- # code[:, 1] += y
119
 
120
- # image_NN = draw_NN(data, code)
121
- # image_optvq = draw_optvq(data, code)
122
 
123
- # return image_NN, image_optvq
124
 
125
  class Handler:
126
  def __init__(self, device):
@@ -172,14 +172,14 @@ class Handler:
172
  quant, *_ = self.vqgan.encode(img)
173
  vqgan_rec = self.vqgan.decode(quant)
174
  # revq
175
- # lat = self.vae.encode(img).latent
176
- # lat = lat.contiguous()
177
- # lat = self.preprocesser(lat)
178
- # lat = self.revq.quantize(lat)
179
- # revq_rec = self.revq.decode(lat)
180
- # revq_rec = revq_rec.contiguous()
181
- # revq_rec = self.preprocesser.inverse(revq_rec)
182
- # revq_rec = self.vae.decode(revq_rec).sample
183
  # optvq
184
  quant, *_ = self.optvq.encode(img)
185
  optvq_rec = self.optvq.decode(quant)
@@ -188,92 +188,106 @@ class Handler:
188
  img = self.tensor_to_image(img)
189
  basevq_rec = self.tensor_to_image(basevq_rec)
190
  vqgan_rec = self.tensor_to_image(vqgan_rec)
191
- # revq_rec = self.tensor_to_image(revq_rec)
192
  optvq_rec = self.tensor_to_image(optvq_rec)
193
  # print("Shapes:", img.shape, basevq_rec.shape, vqgan_rec.shape, revq_rec.shape)
194
- # return img, basevq_rec, vqgan_rec, revq_rec
195
- return basevq_rec, vqgan_rec, optvq_rec
196
 
197
- def draw_process(x, y, std):
198
- img = (np.random.rand(256, 256, 3) * 255).astype(np.uint8)
199
- return img, img
200
-
201
- demo2 = gr.Interface(
202
- fn=draw_process,
203
- inputs=[
204
- gr.Slider(label="x", value=0, minimum=-10, maximum=10, step=0.1),
205
- gr.Slider(label="y", value=0, minimum=-10, maximum=10, step=0.1),
206
- gr.Slider(label="std", value=1, minimum=0, maximum=5, step=0.1)
207
- ],
208
- outputs=[
209
- gr.Image(label="NN", type="numpy"),
210
- gr.Image(label="OptVQ", type="numpy")
211
- ],
212
- title="Demo 2: 2D Matching Visualization",
213
- description="Visualize nearest neighbor vs. optimal transport matching for synthetic 2D data."
214
- )
215
-
216
- demo2.launch()
217
-
218
- # if __name__ == "__main__":
219
- # # create the model handler
220
- # # handler = Handler(device=device)
221
-
222
- # print("Creating Gradio interface...")
223
-
224
- # demo2 = gr.Interface(
225
  # fn=draw_process,
226
- # inputs=[
227
- # gr.Slider(label="x", value=0, minimum=-10, maximum=10, step=0.1),
228
- # gr.Slider(label="y", value=0, minimum=-10, maximum=10, step=0.1),
229
- # gr.Slider(label="std", value=1, minimum=0, maximum=5, step=0.1)
230
- # ],
231
- # outputs=[
232
- # gr.Image(label="NN", type="numpy"),
233
- # gr.Image(label="OptVQ", type="numpy")
234
- # ],
235
- # title="Demo 2: 2D Matching Visualization",
236
- # description="Visualize nearest neighbor vs. optimal transport matching for synthetic 2D data."
237
- # )
238
-
239
- # # 合并两个 interface 成 Tabbed UI
240
- # demo = gr.TabbedInterface(
241
- # interface_list=[demo2],
242
- # tab_names=["2D Matching"]
243
- # )
244
-
245
- # demo.launch()
246
-
247
- # # create the interface
248
- # # with gr.Blocks() as demo:
249
- # # gr.Textbox(value="This demo shows the image reconstruction comparison between ReVQ and other methods. The input image is resized to 256 x 256 and then fed into the models. The output images are the reconstructed images from the latent codes.", label="Demo 1: Image reconstruction results")
250
- # # with gr.Row():
251
- # # with gr.Column():
252
- # # image_input = gr.Image(label="Input data", image_mode="RGB", type="numpy")
253
- # # btn_demo1 = gr.Button(value="Run reconstruction")
254
- # # image_basevq = gr.Image(label="BaseVQ rec.")
255
- # # image_vqgan = gr.Image(label="VQGAN rec.")
256
- # # image_revq = gr.Image(label="ReVQ rec.")
257
- # # btn_demo1.click(fn=handler.process_image, inputs=[image_input], outputs=[image_basevq, image_vqgan, image_revq])
258
-
259
- # # gr.Textbox(value="This demo shows the 2D visualizations of nearest neighbor and optimal transport (OptVQ) methods. The data points are randomly generated from a normal distribution, and the matching results are shown as arrows with different colors.", label="Demo 2: 2D visualizations of matching results")
260
- # # gr.Markdown("### Demo 2: 2D visualizations of matching results\n"
261
- # # "This demo shows the 2D visualizations of nearest neighbor and optimal transport (OptVQ) methods. "
262
- # # "The data points are randomly generated from a normal distribution, and the matching results are shown as arrows with different colors.")
263
- # # with gr.Row():
264
- # # with gr.Column():
265
- # # input_x = gr.Slider(label="x", value=0, minimum=-10, maximum=10, step=0.1)
266
- # # input_y = gr.Slider(label="y", value=0, minimum=-10, maximum=10, step=0.1)
267
- # # input_std = gr.Slider(label="std", value=1, minimum=0, maximum=5, step=0.1)
268
- # # btn_demo2 = gr.Button(value="Run 2D example")
269
- # # output_nn = gr.Image(label="NN", interactive=False, type="numpy")
270
- # # output_optvq = gr.Image(label="OptVQ", interactive=False, type="numpy")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
271
 
272
- # # # set the function
273
- # # input_x.change(fn=draw_process, inputs=[input_x, input_y, input_std], outputs=[output_nn, output_optvq])
274
- # # input_y.change(fn=draw_process, inputs=[input_x, input_y, input_std], outputs=[output_nn, output_optvq])
275
- # # input_std.change(fn=draw_process, inputs=[input_x, input_y, input_std], outputs=[output_nn, output_optvq])
276
- # # btn_demo2.click(fn=draw_process, inputs=[input_x, input_y, input_std], outputs=[output_nn, output_optvq])
277
- # # btn_demo2.click(fn=draw_process, inputs=[input_x, input_y, input_std], outputs=[output_nn, output_optvq])
278
-
279
- # # demo.launch()
 
111
  image = Image.open(buf)
112
  return image
113
 
114
+ def draw_process(x, y, std):
115
+ data = torch.randn(N_data, dim)
116
+ code = torch.randn(N_code, dim) * std
117
+ code[:, 0] += x
118
+ code[:, 1] += y
119
 
120
+ image_NN = draw_NN(data, code)
121
+ image_optvq = draw_optvq(data, code)
122
 
123
+ return image_NN, image_optvq
124
 
125
  class Handler:
126
  def __init__(self, device):
 
172
  quant, *_ = self.vqgan.encode(img)
173
  vqgan_rec = self.vqgan.decode(quant)
174
  # revq
175
+ lat = self.vae.encode(img).latent
176
+ lat = lat.contiguous()
177
+ lat = self.preprocesser(lat)
178
+ lat = self.revq.quantize(lat)
179
+ revq_rec = self.revq.decode(lat)
180
+ revq_rec = revq_rec.contiguous()
181
+ revq_rec = self.preprocesser.inverse(revq_rec)
182
+ revq_rec = self.vae.decode(revq_rec).sample
183
  # optvq
184
  quant, *_ = self.optvq.encode(img)
185
  optvq_rec = self.optvq.decode(quant)
 
188
  img = self.tensor_to_image(img)
189
  basevq_rec = self.tensor_to_image(basevq_rec)
190
  vqgan_rec = self.tensor_to_image(vqgan_rec)
191
+ revq_rec = self.tensor_to_image(revq_rec)
192
  optvq_rec = self.tensor_to_image(optvq_rec)
193
  # print("Shapes:", img.shape, basevq_rec.shape, vqgan_rec.shape, revq_rec.shape)
194
+ return basevq_rec, vqgan_rec, revq_rec
195
+ # return basevq_rec, vqgan_rec, optvq_rec
196
 
197
+ # def draw_process(x, y, std):
198
+ # img = (np.random.rand(256, 256, 3) * 255).astype(np.uint8)
199
+ # return img, img
200
+
201
+ # demo2 = gr.Interface(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
202
  # fn=draw_process,
203
+ # inputs=[
204
+ # gr.Slider(label="x", value=0, minimum=-10, maximum=10, step=0.1),
205
+ # gr.Slider(label="y", value=0, minimum=-10, maximum=10, step=0.1),
206
+ # gr.Slider(label="std", value=1, minimum=0, maximum=5, step=0.1)
207
+ # ],
208
+ # outputs=[
209
+ # gr.Image(label="NN", type="numpy"),
210
+ # gr.Image(label="OptVQ", type="numpy")
211
+ # ],
212
+ # title="Demo 2: 2D Matching Visualization",
213
+ # description="Visualize nearest neighbor vs. optimal transport matching for synthetic 2D data."
214
+ # )
215
+
216
+ # demo2.launch()
217
+
218
+ if __name__ == "__main__":
219
+ # create the model handler
220
+ handler = Handler(device=device)
221
+
222
+ print("Creating Gradio interface...")
223
+
224
+ # Demo 1 接口:图像重建
225
+ demo1 = gr.Interface(
226
+ fn=handler.process_image,
227
+ inputs=gr.Image(label="Input Image", type="numpy"),
228
+ outputs=[
229
+ gr.Image(label="BaseVQ Reconstruction", type="numpy"),
230
+ gr.Image(label="VQGAN Reconstruction", type="numpy"),
231
+ gr.Image(label="ReVQ Reconstruction", type="numpy"),
232
+ # 若启用 ReVQ:gr.Image(label="ReVQ Reconstruction", type="numpy"),
233
+ ],
234
+ title="Demo 1: Image Reconstruction",
235
+ description="Upload an image to see how different VQ models (BaseVQ, VQGAN, ReVQ) reconstruct it from latent codes."
236
+ )
237
+
238
+ demo2 = gr.Interface(
239
+ fn=draw_process,
240
+ inputs=[
241
+ gr.Slider(label="x", value=0, minimum=-10, maximum=10, step=0.1),
242
+ gr.Slider(label="y", value=0, minimum=-10, maximum=10, step=0.1),
243
+ gr.Slider(label="std", value=1, minimum=0, maximum=5, step=0.1)
244
+ ],
245
+ outputs=[
246
+ gr.Image(label="NN", type="numpy"),
247
+ gr.Image(label="OptVQ", type="numpy")
248
+ ],
249
+ title="Demo 2: 2D Matching Visualization",
250
+ description="Visualize nearest neighbor vs. optimal transport matching for synthetic 2D data."
251
+ )
252
+
253
+ # 合并两个 interface 成 Tabbed UI
254
+ demo = gr.TabbedInterface(
255
+ interface_list=[demo1, demo2],
256
+ tab_names=["Image Reconstruction", "2D Matching"]
257
+ )
258
+
259
+ demo.launch(share=True)
260
+
261
+ # create the interface
262
+ # with gr.Blocks() as demo:
263
+ # gr.Textbox(value="This demo shows the image reconstruction comparison between ReVQ and other methods. The input image is resized to 256 x 256 and then fed into the models. The output images are the reconstructed images from the latent codes.", label="Demo 1: Image reconstruction results")
264
+ # with gr.Row():
265
+ # with gr.Column():
266
+ # image_input = gr.Image(label="Input data", image_mode="RGB", type="numpy")
267
+ # btn_demo1 = gr.Button(value="Run reconstruction")
268
+ # image_basevq = gr.Image(label="BaseVQ rec.")
269
+ # image_vqgan = gr.Image(label="VQGAN rec.")
270
+ # image_revq = gr.Image(label="ReVQ rec.")
271
+ # btn_demo1.click(fn=handler.process_image, inputs=[image_input], outputs=[image_basevq, image_vqgan, image_revq])
272
+
273
+ # gr.Textbox(value="This demo shows the 2D visualizations of nearest neighbor and optimal transport (OptVQ) methods. The data points are randomly generated from a normal distribution, and the matching results are shown as arrows with different colors.", label="Demo 2: 2D visualizations of matching results")
274
+ # gr.Markdown("### Demo 2: 2D visualizations of matching results\n"
275
+ # "This demo shows the 2D visualizations of nearest neighbor and optimal transport (OptVQ) methods. "
276
+ # "The data points are randomly generated from a normal distribution, and the matching results are shown as arrows with different colors.")
277
+ # with gr.Row():
278
+ # with gr.Column():
279
+ # input_x = gr.Slider(label="x", value=0, minimum=-10, maximum=10, step=0.1)
280
+ # input_y = gr.Slider(label="y", value=0, minimum=-10, maximum=10, step=0.1)
281
+ # input_std = gr.Slider(label="std", value=1, minimum=0, maximum=5, step=0.1)
282
+ # btn_demo2 = gr.Button(value="Run 2D example")
283
+ # output_nn = gr.Image(label="NN", interactive=False, type="numpy")
284
+ # output_optvq = gr.Image(label="OptVQ", interactive=False, type="numpy")
285
 
286
+ # # set the function
287
+ # input_x.change(fn=draw_process, inputs=[input_x, input_y, input_std], outputs=[output_nn, output_optvq])
288
+ # input_y.change(fn=draw_process, inputs=[input_x, input_y, input_std], outputs=[output_nn, output_optvq])
289
+ # input_std.change(fn=draw_process, inputs=[input_x, input_y, input_std], outputs=[output_nn, output_optvq])
290
+ # btn_demo2.click(fn=draw_process, inputs=[input_x, input_y, input_std], outputs=[output_nn, output_optvq])
291
+ # btn_demo2.click(fn=draw_process, inputs=[input_x, input_y, input_std], outputs=[output_nn, output_optvq])
292
+
293
+ # demo.launch()