AndyRaoTHU committed on
Commit
4359005
·
1 Parent(s): 57b27d9
Files changed (1) hide show
  1. app.py +23 -61
app.py CHANGED
@@ -23,9 +23,6 @@ from revq.models.vqgan_hf import VQModelHF
23
  from diffusers import AutoencoderDC
24
 
25
  #################
26
- N_data = 50
27
- N_code = 20
28
- dim = 2
29
  handler = None
30
  device = torch.device("cpu")
31
  #################
@@ -54,10 +51,13 @@ def get_codebook(quantizer):
54
  codes = quantizer.embeddings.squeeze().detach()
55
  return codes
56
 
57
- def draw_fig(ax, quantizer, data, title=""):
58
  codes = get_codebook(quantizer)
59
  ax.scatter(data[:, 0], data[:, 1], s=60, marker="*")
60
- ax.scatter(codes[:, 0], codes[:, 1], s=40, c='red', alpha=0.5)
 
 
 
61
  ax.set_xlim(-5, 10)
62
  ax.set_ylim(-10, 5)
63
  ax.tick_params(axis='x', labelsize=22)
@@ -83,8 +83,8 @@ def draw_reset_result(num_data=16, num_code=12):
83
  optimizer = torch.optim.SGD(quantizer.parameters(), lr=0.1)
84
  quantizer_nreset = Quantizer(TYPE='vq', code_dim=2, num_code=num_code, num_group=1, tokens_per_data=1, auto_reset=False)
85
  optimizer_nreset = torch.optim.SGD(quantizer_nreset.parameters(), lr=0.1)
86
- draw_fig(ax_reset[0], quantizer, data, title=f"Initialization")
87
- draw_fig(ax_nreset[0], quantizer_nreset, data, title=f"Initialization")
88
  ax_reset[0].legend(["Data", "Code"], loc="upper right", fontsize=24)
89
  ax_nreset[0].legend(["Data", "Code"], loc="upper right", fontsize=24)
90
 
@@ -109,10 +109,10 @@ def draw_reset_result(num_data=16, num_code=12):
109
 
110
  if (i+1) in i_list:
111
  count += 1
112
- draw_fig(ax_reset[count], quantizer, data, title=f"Iters: {i+1}, MSE: {loss.item():.1f}")
113
  draw_arrow(ax_reset[count], quant_data.detach().numpy(), data.numpy())
114
 
115
- draw_fig(ax_nreset[count], quantizer_nreset, data, title=f"Iters: {i+1}, MSE: {loss_nreset.item():.1f}")
116
  draw_arrow(ax_nreset[count], quant_data_nreset.detach().numpy(), data.numpy())
117
 
118
  quantizer.reset()
@@ -123,11 +123,10 @@ def draw_reset_result(num_data=16, num_code=12):
123
  img_reset = fig_to_array(fig_reset)
124
  img_nreset = fig_to_array(fig_nreset)
125
 
126
- return img_reset, img_nreset
127
 
128
  # end
129
 
130
-
131
  # ReVQ: for multi-group
132
  def get_codebook_v2(quantizer):
133
  with torch.no_grad():
@@ -141,10 +140,13 @@ def get_codebook_v2(quantizer):
141
  codes = torch.cartesian_prod(group1, group2)
142
  return codes
143
 
144
- def draw_fig_v2(ax, quantizer, data, title=""):
145
  codes = get_codebook_v2(quantizer)
146
  ax.scatter(data[:, 0], data[:, 1], s=60, marker="*")
147
- ax.scatter(codes[:, 0], codes[:, 1], s=20, c='red', alpha=0.5)
 
 
 
148
  ax.plot([-12, 12], [-12, 12], color='orange', linestyle='--', linewidth=2)
149
  ax.set_xlim(-12, 12)
150
  ax.set_ylim(-12, 12)
@@ -166,8 +168,8 @@ def draw_multi_group_result(num_data=16, num_code=12):
166
  optimizer_s = torch.optim.SGD(quantizer_s.parameters(), lr=0.1)
167
  quantizer_m = Quantizer(TYPE='vq', code_dim=1, num_code=num_code, num_group=2, tokens_per_data=2)
168
  optimizer_m = torch.optim.SGD(quantizer_m.parameters(), lr=0.1)
169
- draw_fig_v2(ax_s[0], quantizer_s, data, title=f"Initialization")
170
- draw_fig_v2(ax_m[0], quantizer_m, data, title=f"Initialization")
171
  ax_s[0].legend(["Data", "Code"], loc="upper right", fontsize=24)
172
  ax_m[0].legend(["Data", "Code"], loc="upper right", fontsize=24)
173
  i_list = [5, 20, 50, 200, 1000]
@@ -187,8 +189,8 @@ def draw_multi_group_result(num_data=16, num_code=12):
187
 
188
  if (i+1) in i_list:
189
  count += 1
190
- draw_fig_v2(ax_s[count], quantizer_s, data, title=f"Iters: {i+1}, MSE: {loss_s.item():.1f}")
191
- draw_fig_v2(ax_m[count], quantizer_m, data, title=f"Iters: {i+1}, MSE: {loss_m.item():.1f}")
192
 
193
  quantizer_s.reset()
194
  quantizer_m.reset()
@@ -204,7 +206,6 @@ def draw_multi_group_result(num_data=16, num_code=12):
204
  # end
205
 
206
  # ReVQ: for image reconstruction
207
-
208
  class Handler:
209
  def __init__(self, device):
210
  self.transform = T.Compose([
@@ -293,7 +294,7 @@ if __name__ == "__main__":
293
 
294
  with gr.Blocks() as demo2:
295
  gr.Markdown("## Demo 2: Codebook Reset Strategy Visualization")
296
- gr.Markdown("Visualizes codebook and data movement at different training steps.")
297
 
298
  with gr.Row():
299
  num_data = gr.Slider(label="num_data", value=16, minimum=10, maximum=20, step=1)
@@ -302,15 +303,15 @@ if __name__ == "__main__":
302
  submit_btn = gr.Button("Run Visualization")
303
 
304
  with gr.Column(): # 垂直输出
305
- out_with_reset = gr.Image(label="With Reset")
306
  out_without_reset = gr.Image(label="Without Reset")
 
307
 
308
- submit_btn.click(fn=draw_reset_result, inputs=[num_data, num_code], outputs=[out_with_reset, out_without_reset])
309
 
310
 
311
  with gr.Blocks() as demo3:
312
  gr.Markdown("## Demo 3: Channel Multi-Group Strategy Visualization")
313
- gr.Markdown("Visualizes codebook and data movement at different training steps with multi-group strategy.")
314
 
315
  with gr.Row():
316
  num_data = gr.Slider(label="num_data", value=32, minimum=28, maximum=40, step=1)
@@ -324,48 +325,9 @@ if __name__ == "__main__":
324
 
325
  submit_btn.click(fn=draw_multi_group_result, inputs=[num_data, num_code], outputs=[out_s, out_m])
326
 
327
- # 合并两个 interface 成 Tabbed UI
328
- # demo = gr.TabbedInterface(
329
- # interface_list=[demo1, demo2],
330
- # tab_names=["Image Reconstruction", "Reset Strategy"]
331
- # )
332
  demo = gr.TabbedInterface(
333
  interface_list=[demo1, demo2, demo3],
334
  tab_names=["Image Reconstruction", "Reset Strategy", "Channel Multi-Group Strategy"]
335
  )
336
 
337
  demo.launch(share=True)
338
-
339
- # create the interface
340
- # with gr.Blocks() as demo:
341
- # gr.Textbox(value="This demo shows the image reconstruction comparison between ReVQ and other methods. The input image is resized to 256 x 256 and then fed into the models. The output images are the reconstructed images from the latent codes.", label="Demo 1: Image reconstruction results")
342
- # with gr.Row():
343
- # with gr.Column():
344
- # image_input = gr.Image(label="Input data", image_mode="RGB", type="numpy")
345
- # btn_demo1 = gr.Button(value="Run reconstruction")
346
- # image_basevq = gr.Image(label="BaseVQ rec.")
347
- # image_vqgan = gr.Image(label="VQGAN rec.")
348
- # image_revq = gr.Image(label="ReVQ rec.")
349
- # btn_demo1.click(fn=handler.process_image, inputs=[image_input], outputs=[image_basevq, image_vqgan, image_revq])
350
-
351
- # gr.Textbox(value="This demo shows the 2D visualizations of nearest neighbor and optimal transport (OptVQ) methods. The data points are randomly generated from a normal distribution, and the matching results are shown as arrows with different colors.", label="Demo 2: 2D visualizations of matching results")
352
- # gr.Markdown("### Demo 2: 2D visualizations of matching results\n"
353
- # "This demo shows the 2D visualizations of nearest neighbor and optimal transport (OptVQ) methods. "
354
- # "The data points are randomly generated from a normal distribution, and the matching results are shown as arrows with different colors.")
355
- # with gr.Row():
356
- # with gr.Column():
357
- # input_x = gr.Slider(label="x", value=0, minimum=-10, maximum=10, step=0.1)
358
- # input_y = gr.Slider(label="y", value=0, minimum=-10, maximum=10, step=0.1)
359
- # input_std = gr.Slider(label="std", value=1, minimum=0, maximum=5, step=0.1)
360
- # btn_demo2 = gr.Button(value="Run 2D example")
361
- # output_nn = gr.Image(label="NN", interactive=False, type="numpy")
362
- # output_optvq = gr.Image(label="OptVQ", interactive=False, type="numpy")
363
-
364
- # # set the function
365
- # input_x.change(fn=draw_process, inputs=[input_x, input_y, input_std], outputs=[output_nn, output_optvq])
366
- # input_y.change(fn=draw_process, inputs=[input_x, input_y, input_std], outputs=[output_nn, output_optvq])
367
- # input_std.change(fn=draw_process, inputs=[input_x, input_y, input_std], outputs=[output_nn, output_optvq])
368
- # btn_demo2.click(fn=draw_process, inputs=[input_x, input_y, input_std], outputs=[output_nn, output_optvq])
369
- # btn_demo2.click(fn=draw_process, inputs=[input_x, input_y, input_std], outputs=[output_nn, output_optvq])
370
-
371
- # demo.launch()
 
23
  from diffusers import AutoencoderDC
24
 
25
  #################
 
 
 
26
  handler = None
27
  device = torch.device("cpu")
28
  #################
 
51
  codes = quantizer.embeddings.squeeze().detach()
52
  return codes
53
 
54
+ def draw_fig(ax, quantizer, data, color="r", title=""):
55
  codes = get_codebook(quantizer)
56
  ax.scatter(data[:, 0], data[:, 1], s=60, marker="*")
57
+ if color == "r":
58
+ ax.scatter(codes[:, 0], codes[:, 1], s=40, c='red', alpha=0.5)
59
+ else:
60
+ ax.scatter(codes[:, 0], codes[:, 1], s=40, c='green', alpha=0.5)
61
  ax.set_xlim(-5, 10)
62
  ax.set_ylim(-10, 5)
63
  ax.tick_params(axis='x', labelsize=22)
 
83
  optimizer = torch.optim.SGD(quantizer.parameters(), lr=0.1)
84
  quantizer_nreset = Quantizer(TYPE='vq', code_dim=2, num_code=num_code, num_group=1, tokens_per_data=1, auto_reset=False)
85
  optimizer_nreset = torch.optim.SGD(quantizer_nreset.parameters(), lr=0.1)
86
+ draw_fig(ax_reset[0], quantizer, data, color='g', title=f"Initialization")
87
+ draw_fig(ax_nreset[0], quantizer_nreset, data, color='r', title=f"Initialization")
88
  ax_reset[0].legend(["Data", "Code"], loc="upper right", fontsize=24)
89
  ax_nreset[0].legend(["Data", "Code"], loc="upper right", fontsize=24)
90
 
 
109
 
110
  if (i+1) in i_list:
111
  count += 1
112
+ draw_fig(ax_reset[count], quantizer, data, color='g', title=f"Iters: {i+1}, MSE: {loss.item():.1f}")
113
  draw_arrow(ax_reset[count], quant_data.detach().numpy(), data.numpy())
114
 
115
+ draw_fig(ax_nreset[count], quantizer_nreset, data, color='r', title=f"Iters: {i+1}, MSE: {loss_nreset.item():.1f}")
116
  draw_arrow(ax_nreset[count], quant_data_nreset.detach().numpy(), data.numpy())
117
 
118
  quantizer.reset()
 
123
  img_reset = fig_to_array(fig_reset)
124
  img_nreset = fig_to_array(fig_nreset)
125
 
126
+ return img_nreset, img_reset
127
 
128
  # end
129
 
 
130
  # ReVQ: for multi-group
131
  def get_codebook_v2(quantizer):
132
  with torch.no_grad():
 
140
  codes = torch.cartesian_prod(group1, group2)
141
  return codes
142
 
143
+ def draw_fig_v2(ax, quantizer, data, color='r', title=""):
144
  codes = get_codebook_v2(quantizer)
145
  ax.scatter(data[:, 0], data[:, 1], s=60, marker="*")
146
+ if color == "r":
147
+ ax.scatter(codes[:, 0], codes[:, 1], s=20, c='red', alpha=0.5)
148
+ else:
149
+ ax.scatter(codes[:, 0], codes[:, 1], s=20, c='green', alpha=0.5)
150
  ax.plot([-12, 12], [-12, 12], color='orange', linestyle='--', linewidth=2)
151
  ax.set_xlim(-12, 12)
152
  ax.set_ylim(-12, 12)
 
168
  optimizer_s = torch.optim.SGD(quantizer_s.parameters(), lr=0.1)
169
  quantizer_m = Quantizer(TYPE='vq', code_dim=1, num_code=num_code, num_group=2, tokens_per_data=2)
170
  optimizer_m = torch.optim.SGD(quantizer_m.parameters(), lr=0.1)
171
+ draw_fig_v2(ax_s[0], quantizer_s, data, color='r', title=f"Initialization")
172
+ draw_fig_v2(ax_m[0], quantizer_m, data, color='g', title=f"Initialization")
173
  ax_s[0].legend(["Data", "Code"], loc="upper right", fontsize=24)
174
  ax_m[0].legend(["Data", "Code"], loc="upper right", fontsize=24)
175
  i_list = [5, 20, 50, 200, 1000]
 
189
 
190
  if (i+1) in i_list:
191
  count += 1
192
+ draw_fig_v2(ax_s[count], quantizer_s, data, color='r', title=f"Iters: {i+1}, MSE: {loss_s.item():.1f}")
193
+ draw_fig_v2(ax_m[count], quantizer_m, data, color='g', title=f"Iters: {i+1}, MSE: {loss_m.item():.1f}")
194
 
195
  quantizer_s.reset()
196
  quantizer_m.reset()
 
206
  # end
207
 
208
  # ReVQ: for image reconstruction
 
209
  class Handler:
210
  def __init__(self, device):
211
  self.transform = T.Compose([
 
294
 
295
  with gr.Blocks() as demo2:
296
  gr.Markdown("## Demo 2: Codebook Reset Strategy Visualization")
297
+ gr.Markdown("Visualizes codebook and data movement at different training steps with or without codebook reset strategy.")
298
 
299
  with gr.Row():
300
  num_data = gr.Slider(label="num_data", value=16, minimum=10, maximum=20, step=1)
 
303
  submit_btn = gr.Button("Run Visualization")
304
 
305
  with gr.Column(): # 垂直输出
 
306
  out_without_reset = gr.Image(label="Without Reset")
307
+ out_with_reset = gr.Image(label="With Reset")
308
 
309
+ submit_btn.click(fn=draw_reset_result, inputs=[num_data, num_code], outputs=[out_without_reset, out_with_reset])
310
 
311
 
312
  with gr.Blocks() as demo3:
313
  gr.Markdown("## Demo 3: Channel Multi-Group Strategy Visualization")
314
+ gr.Markdown("Visualizes codebook and data movement at different training steps with or without multi-group strategy.")
315
 
316
  with gr.Row():
317
  num_data = gr.Slider(label="num_data", value=32, minimum=28, maximum=40, step=1)
 
325
 
326
  submit_btn.click(fn=draw_multi_group_result, inputs=[num_data, num_code], outputs=[out_s, out_m])
327
 
 
 
 
 
 
328
  demo = gr.TabbedInterface(
329
  interface_list=[demo1, demo2, demo3],
330
  tab_names=["Image Reconstruction", "Reset Strategy", "Channel Multi-Group Strategy"]
331
  )
332
 
333
  demo.launch(share=True)