AndyRaoTHU committed on
Commit
63e0d46
·
1 Parent(s): af6c0a4

add multi-group

Browse files
Files changed (1) hide show
  1. app.py +96 -85
app.py CHANGED
@@ -41,89 +41,6 @@ def load_preprocessor(device, is_eval: bool = True, ckpt_path: str = "./ckpt/pre
41
  preprocessor.eval()
42
  return preprocessor
43
 
44
- def nearest(src, trg):
45
- dis_mat = torch.cdist(src, trg)
46
- min_idx = torch.argmin(dis_mat, dim=-1)
47
- return min_idx
48
-
49
- def normalize(A, dim, mode="all"):
50
- if mode == "all":
51
- A = (A - A.mean()) / (A.std() + 1e-6)
52
- A = A - A.min()
53
- elif mode == "dim":
54
- A = A / dim
55
- elif mode == "null":
56
- pass
57
- return A
58
-
59
- def draw_NN(data, code):
60
- # nearest neighbor method
61
- indices = nearest(data, code)
62
- data = data.numpy()
63
- code = code.numpy()
64
-
65
- plt.figure(figsize=(3, 2.5), dpi=400)
66
- # draw arrows in blue color, alpha=0.5
67
- for i in range(data.shape[0]):
68
- idx = indices[i].item()
69
- start = data[i]
70
- end = code[idx]
71
- plt.arrow(start[0], start[1], end[0] - start[0], end[1] - start[1],
72
- head_width=0.05, head_length=0.05, fc='red', ec='red', alpha=0.6,
73
- ls="-", lw=0.5)
74
- plt.scatter(data[:, 0], data[:, 1], s=10, marker="o", c="gray", label="Data")
75
- plt.scatter(code[:, 0], code[:, 1], s=25, marker="*", c="blue", label="Code")
76
- plt.legend(loc="lower right")
77
- plt.grid(color="gray", alpha=0.8, ls="-.", lw=0.5)
78
- plt.title("Nearest neighbor")
79
-
80
- buf = BytesIO()
81
- plt.savefig(buf, format="png")
82
- buf.seek(0)
83
- image = Image.open(buf)
84
- return image
85
-
86
- def draw_optvq(data, code):
87
- cost = torch.cdist(data, code, p=2.0)
88
- cost = normalize(cost, dim, mode="all")
89
- Q = sinkhorn(cost, n_iters=5, epsilon=10, is_distributed=False)
90
- indices = torch.argmax(Q, dim=-1)
91
- data = data.numpy()
92
- code = code.numpy()
93
-
94
- plt.figure(figsize=(3, 2.5), dpi=400)
95
- # draw arrows in blue color, alpha=0.5
96
- for i in range(data.shape[0]):
97
- idx = indices[i].item()
98
- start = data[i]
99
- end = code[idx]
100
- plt.arrow(start[0], start[1], end[0] - start[0], end[1] - start[1],
101
- head_width=0.05, head_length=0.05, fc='green', ec='green', alpha=0.6,
102
- ls="-", lw=0.5)
103
- plt.scatter(data[:, 0], data[:, 1], s=10, marker="o", c="gray", label="Data")
104
- plt.scatter(code[:, 0], code[:, 1], s=25, marker="*", c="blue", label="Code")
105
- plt.legend(loc="lower right")
106
- plt.grid(color="gray", alpha=0.8, ls="-.", lw=0.5)
107
- plt.title("Optimal Transport (OptVQ)")
108
-
109
- buf = BytesIO()
110
- plt.savefig(buf, format="png")
111
- buf.seek(0)
112
- image = Image.open(buf)
113
- return image
114
-
115
- def draw_process(x, y, std):
116
- data = torch.randn(N_data, dim)
117
- code = torch.randn(N_code, dim) * std
118
- code[:, 0] += x
119
- code[:, 1] += y
120
-
121
- image_NN = draw_NN(data, code)
122
- image_optvq = draw_optvq(data, code)
123
-
124
- return image_NN, image_optvq
125
-
126
-
127
  # ReVQ: for reset strategy
128
  def fig_to_array(fig):
129
  buf = BytesIO()
@@ -211,6 +128,83 @@ def draw_reset_result(num_data=16, num_code=12):
211
  # end
212
 
213
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
214
  class Handler:
215
  def __init__(self, device):
216
  self.transform = T.Compose([
@@ -326,14 +320,31 @@ if __name__ == "__main__":
326
 
327
  submit_btn.click(fn=draw_reset_result, inputs=[num_data, num_code], outputs=[out_with_reset, out_without_reset])
328
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
329
  # Merge the two interfaces into a tabbed UI
330
  # demo = gr.TabbedInterface(
331
  # interface_list=[demo1, demo2],
332
  # tab_names=["Image Reconstruction", "Reset Strategy"]
333
  # )
334
  demo = gr.TabbedInterface(
335
- interface_list=[demo2],
336
- tab_names=["Reset Strategy"]
337
  )
338
 
339
  demo.launch(share=True)
 
41
  preprocessor.eval()
42
  return preprocessor
43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  # ReVQ: for reset strategy
45
  def fig_to_array(fig):
46
  buf = BytesIO()
 
128
  # end
129
 
130
 
131
+ # ReVQ: for multi-group
132
def get_codebook_v2(quantizer):
    """Build the 2-D code grid spanned by a quantizer's per-group codebooks.

    The first entry of ``quantizer.embeddings`` supplies the x-axis codes.
    The y-axis codes come from the second entry, unless
    ``quantizer.num_group == 1``, in which case the first entry is reused
    for both axes. The result is the Cartesian product of the two code lists.

    Args:
        quantizer: object exposing ``embeddings`` (indexable, yielding
            tensors) and ``num_group``.

    Returns:
        A tensor with one row per (x_code, y_code) pair and two columns,
        computed under ``torch.no_grad()``.
    """
    with torch.no_grad():
        embedding = quantizer.embeddings
        # With a single group the same codebook is shared by both axes.
        y_group = 0 if quantizer.num_group == 1 else 1
        # squeeze() collapses a one-entry codebook to a 0-d tensor, which
        # cartesian_prod rejects; atleast_1d restores the length-1 vector
        # and leaves every other shape untouched.
        x_codes = torch.atleast_1d(embedding[0].squeeze())
        y_codes = torch.atleast_1d(embedding[y_group].squeeze())
        return torch.cartesian_prod(x_codes, y_codes)
143
+
144
def draw_fig_v2(ax, quantizer, data, title=""):
    """Draw one panel of the multi-group demo: data points plus the code grid.

    Args:
        ax: matplotlib axes to draw on.
        quantizer: quantizer whose codes are plotted (passed to
            ``get_codebook_v2``).
        data: 2-column array/tensor of data points; column 0 is x, column 1 is y.
        title: text placed above the panel.
    """
    # Use the multi-group helper: it expands the per-group codebooks into
    # 2-D points, which is what the scatter below indexes as columns 0 and 1.
    codes = get_codebook_v2(quantizer)

    ax.scatter(data[:, 0], data[:, 1], s=60, marker="*")
    ax.scatter(codes[:, 0], codes[:, 1], s=20, c='red', alpha=0.5)
    # Dashed diagonal reference line y = x across the whole view.
    ax.plot([-12, 12], [-12, 12], color='orange', linestyle='--', linewidth=2)

    ax.set_xlim(-12, 12)
    ax.set_ylim(-12, 12)
    ax.tick_params(axis='x', labelsize=22)
    ax.tick_params(axis='y', labelsize=22)
    ticks = np.arange(-10, 11, 5)
    ax.set_xticks(ticks)
    ax.set_yticks(ticks)
    ax.grid(linestyle='--', color='#333333', alpha=0.7)
    ax.set_title(f"{title}", fontsize=26)
157
+
158
+
159
def draw_multi_group_result(num_data=16, num_code=12):
    """Train a single-group and a two-group VQ codebook side by side and plot snapshots.

    Random 2-D data is generated, then two ``Quantizer`` instances
    (``num_group=1`` and ``num_group=2``) are trained with SGD on the MSE
    between quantized and original data. The codebook state is drawn at
    initialization and after 5, 20, 50, 200 and 1000 iterations, giving one
    row of six panels per quantizer.

    Args:
        num_data: number of data points to sample.
        num_code: number of codes per quantizer.

    Returns:
        ``(img_s, img_m)``: the rendered single-group and multi-group figures
        as produced by ``fig_to_array``.
    """
    # Values arriving from UI sliders may be floats; sizes must be ints.
    num_data, num_code = int(num_data), int(num_code)

    snapshot_steps = (5, 20, 50, 200, 1000)
    n_panels = len(snapshot_steps) + 1  # +1 for the initialization panel

    fig_s, ax_s = plt.subplots(1, n_panels, figsize=(36, 6), dpi=400)
    fig_m, ax_m = plt.subplots(1, n_panels, figsize=(36, 6), dpi=400)
    try:
        x = torch.randn(num_data, 1) * 3 + 4
        y = torch.randn(num_data, 1) * 3 - 4
        data = torch.cat([x, y], dim=1)
        tokens = data.unsqueeze(-1)  # loop-invariant quantizer input

        quantizer_s = Quantizer(TYPE='vq', code_dim=1, num_code=num_code, num_group=1, tokens_per_data=2)
        optimizer_s = torch.optim.SGD(quantizer_s.parameters(), lr=0.1)
        quantizer_m = Quantizer(TYPE='vq', code_dim=1, num_code=num_code, num_group=2, tokens_per_data=2)
        optimizer_m = torch.optim.SGD(quantizer_m.parameters(), lr=0.1)

        draw_fig_v2(ax_s[0], quantizer_s, data, title="Initialization")
        draw_fig_v2(ax_m[0], quantizer_m, data, title="Initialization")
        ax_s[0].legend(["Data", "Code"], loc="upper right", fontsize=24)
        ax_m[0].legend(["Data", "Code"], loc="upper right", fontsize=24)

        panel = 0
        # Nothing is drawn after the last snapshot, so training stops there.
        for step in range(1, max(snapshot_steps) + 1):
            optimizer_s.zero_grad()
            optimizer_m.zero_grad()
            quant_data_s = quantizer_s(tokens)["x_quant"].squeeze()
            quant_data_m = quantizer_m(tokens)["x_quant"].squeeze()
            loss_s = torch.mean((quant_data_s - data) ** 2)
            loss_m = torch.mean((quant_data_m - data) ** 2)
            loss_s.backward()
            loss_m.backward()
            optimizer_s.step()
            optimizer_m.step()

            if step in snapshot_steps:
                panel += 1
                draw_fig_v2(ax_s[panel], quantizer_s, data, title=f"Iters: {step}, MSE: {loss_s.item():.1f}")
                draw_fig_v2(ax_m[panel], quantizer_m, data, title=f"Iters: {step}, MSE: {loss_m.item():.1f}")

            # reset() is invoked once per iteration on both quantizers
            # (its semantics are defined by Quantizer, outside this block).
            quantizer_s.reset()
            quantizer_m.reset()

        fig_s.suptitle("VQ Codebook Training with Single Group", fontsize=24, y=1.05)
        fig_m.suptitle("VQ Codebook Training with Multi Group", fontsize=24, y=1.05)

        img_s = fig_to_array(fig_s)
        img_m = fig_to_array(fig_m)
        return img_s, img_m
    finally:
        # pyplot keeps every figure alive until it is closed; release both
        # so repeated UI calls do not accumulate large canvases.
        # Closing an already-closed figure is harmless.
        plt.close(fig_s)
        plt.close(fig_m)
203
+
204
+ # end
205
+
206
+ # ReVQ: for image reconstruction
207
+
208
  class Handler:
209
  def __init__(self, device):
210
  self.transform = T.Compose([
 
320
 
321
  submit_btn.click(fn=draw_reset_result, inputs=[num_data, num_code], outputs=[out_with_reset, out_without_reset])
322
 
323
+
324
+ with gr.Blocks() as demo3:
325
+ gr.Markdown("## Demo 3: Channel Multi-Group Strategy Visualization")
326
+ gr.Markdown("Visualizes codebook and data movement at different training steps with multi-group strategy.")
327
+
328
+ with gr.Row():
329
+ num_data = gr.Slider(label="num_data", value=32, minimum=28, maximum=40, step=1)
330
+ num_code = gr.Slider(label="num_code", value=8, minimum=6, maximum=10, step=1)
331
+
332
+ submit_btn = gr.Button("Run Visualization")
333
+
334
+ with gr.Column(): # 垂直输出
335
+ out_s = gr.Image(label="Single Group")
336
+ out_m = gr.Image(label="Multi Group")
337
+
338
+ submit_btn.click(fn=draw_multi_group_result, inputs=[num_data, num_code], outputs=[out_s, out_m])
339
+
340
  # Merge the two interfaces into a tabbed UI
341
  # demo = gr.TabbedInterface(
342
  # interface_list=[demo1, demo2],
343
  # tab_names=["Image Reconstruction", "Reset Strategy"]
344
  # )
345
  demo = gr.TabbedInterface(
346
+ interface_list=[demo2, demo3],
347
+ tab_names=["Reset Strategy", "Channel Multi-Group Strategy"]
348
  )
349
 
350
  demo.launch(share=True)