AndyRaoTHU commited on
Commit
38eb8fe
·
1 Parent(s): 38e4fb0

update reset

Browse files
Files changed (1) hide show
  1. app.py +5 -21
app.py CHANGED
@@ -141,18 +141,17 @@ def draw_fig(ax, quantizer, data, title=""):
141
  ax.set_xticks(np.arange(-5, 11, 5))
142
  ax.set_yticks(np.arange(-10, 6, 5))
143
  ax.grid(linestyle='--', color='#333333', alpha=0.7)
144
- ax.set_title(f"{title}", fontsize=26)
145
 
146
  def draw_arrow(ax, start, end):
147
  for i in range(len(start)):
148
- # 不用用plt.annotate
149
  ax.arrow(start[i][0], start[i][1], end[i][0] - start[i][0], end[i][1] - start[i][1],
150
  head_width=0.1, head_length=0.1, fc='orange', ec='orange', alpha=0.8,
151
  ls="-", lw=1)
152
 
153
  def draw_reset_result(num_data=16, num_code=12):
154
- fig_reset, ax_reset = plt.subplots(1, 6, figsize=(36, 6), dpi=400)
155
- fig_nreset, ax_nreset = plt.subplots(1, 6, figsize=(36, 6), dpi=400)
156
  x = torch.randn(num_data, 1) * 2 + 5
157
  y = torch.randn(num_data, 1) * 2 - 5
158
  data = torch.cat([x, y], dim=1)
@@ -162,8 +161,8 @@ def draw_reset_result(num_data=16, num_code=12):
162
  optimizer_nreset = torch.optim.SGD(quantizer_nreset.parameters(), lr=0.1)
163
  draw_fig(ax_reset[0], quantizer, data, title=f"Initialization")
164
  draw_fig(ax_nreset[0], quantizer_nreset, data, title=f"Initialization")
165
- ax_reset[0].legend(["Data", "Code"], loc="upper right", fontsize=18)
166
- ax_nreset[0].legend(["Data", "Code"], loc="upper right", fontsize=18)
167
 
168
  i_list = [1, 3, 10, 50, 200]
169
 
@@ -288,21 +287,6 @@ if __name__ == "__main__":
288
  description="Upload an image to see how different VQ models (BaseVQ, VQGAN, ReVQ) reconstruct it from latent codes."
289
  )
290
 
291
- # demo2 = gr.Interface(
292
- # fn=draw_process,
293
- # inputs=[
294
- # gr.Slider(label="x", value=0, minimum=-10, maximum=10, step=0.1),
295
- # gr.Slider(label="y", value=0, minimum=-10, maximum=10, step=0.1),
296
- # gr.Slider(label="std", value=1, minimum=0, maximum=5, step=0.1)
297
- # ],
298
- # outputs=[
299
- # gr.Image(label="NN", type="numpy"),
300
- # gr.Image(label="OptVQ", type="numpy")
301
- # ],
302
- # title="Demo 2: 2D Matching Visualization",
303
- # description="Visualize nearest neighbor vs. optimal transport matching for synthetic 2D data."
304
- # )
305
-
306
  demo2 = gr.Interface(
307
  fn=draw_reset_result, # 不再需要 handler 包装
308
  inputs=[
 
141
  ax.set_xticks(np.arange(-5, 11, 5))
142
  ax.set_yticks(np.arange(-10, 6, 5))
143
  ax.grid(linestyle='--', color='#333333', alpha=0.7)
144
+ ax.set_title(f"{title}", fontsize=36)
145
 
146
  def draw_arrow(ax, start, end):
147
  for i in range(len(start)):
 
148
  ax.arrow(start[i][0], start[i][1], end[i][0] - start[i][0], end[i][1] - start[i][1],
149
  head_width=0.1, head_length=0.1, fc='orange', ec='orange', alpha=0.8,
150
  ls="-", lw=1)
151
 
152
  def draw_reset_result(num_data=16, num_code=12):
153
+ fig_reset, ax_reset = plt.subplots(1, 6, figsize=(72, 10), dpi=400)
154
+ fig_nreset, ax_nreset = plt.subplots(1, 6, figsize=(72, 10), dpi=400)
155
  x = torch.randn(num_data, 1) * 2 + 5
156
  y = torch.randn(num_data, 1) * 2 - 5
157
  data = torch.cat([x, y], dim=1)
 
161
  optimizer_nreset = torch.optim.SGD(quantizer_nreset.parameters(), lr=0.1)
162
  draw_fig(ax_reset[0], quantizer, data, title=f"Initialization")
163
  draw_fig(ax_nreset[0], quantizer_nreset, data, title=f"Initialization")
164
+ ax_reset[0].legend(["Data", "Code"], loc="upper right", fontsize=24)
165
+ ax_nreset[0].legend(["Data", "Code"], loc="upper right", fontsize=24)
166
 
167
  i_list = [1, 3, 10, 50, 200]
168
 
 
287
  description="Upload an image to see how different VQ models (BaseVQ, VQGAN, ReVQ) reconstruct it from latent codes."
288
  )
289
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
290
  demo2 = gr.Interface(
291
  fn=draw_reset_result, # 不再需要 handler 包装
292
  inputs=[