Commit 63e0d46 · Parent: af6c0a4 · add multi-group

app.py CHANGED
@@ -41,89 +41,6 @@ def load_preprocessor(device, is_eval: bool = True, ckpt_path: str = "./ckpt/pre
     preprocessor.eval()
     return preprocessor
 
-def nearest(src, trg):
-    dis_mat = torch.cdist(src, trg)
-    min_idx = torch.argmin(dis_mat, dim=-1)
-    return min_idx
-
-def normalize(A, dim, mode="all"):
-    if mode == "all":
-        A = (A - A.mean()) / (A.std() + 1e-6)
-        A = A - A.min()
-    elif mode == "dim":
-        A = A / dim
-    elif mode == "null":
-        pass
-    return A
-
-def draw_NN(data, code):
-    # nearest neighbor method
-    indices = nearest(data, code)
-    data = data.numpy()
-    code = code.numpy()
-
-    plt.figure(figsize=(3, 2.5), dpi=400)
-    # draw arrows in blue color, alpha=0.5
-    for i in range(data.shape[0]):
-        idx = indices[i].item()
-        start = data[i]
-        end = code[idx]
-        plt.arrow(start[0], start[1], end[0] - start[0], end[1] - start[1],
-                  head_width=0.05, head_length=0.05, fc='red', ec='red', alpha=0.6,
-                  ls="-", lw=0.5)
-    plt.scatter(data[:, 0], data[:, 1], s=10, marker="o", c="gray", label="Data")
-    plt.scatter(code[:, 0], code[:, 1], s=25, marker="*", c="blue", label="Code")
-    plt.legend(loc="lower right")
-    plt.grid(color="gray", alpha=0.8, ls="-.", lw=0.5)
-    plt.title("Nearest neighbor")
-
-    buf = BytesIO()
-    plt.savefig(buf, format="png")
-    buf.seek(0)
-    image = Image.open(buf)
-    return image
-
-def draw_optvq(data, code):
-    cost = torch.cdist(data, code, p=2.0)
-    cost = normalize(cost, dim, mode="all")
-    Q = sinkhorn(cost, n_iters=5, epsilon=10, is_distributed=False)
-    indices = torch.argmax(Q, dim=-1)
-    data = data.numpy()
-    code = code.numpy()
-
-    plt.figure(figsize=(3, 2.5), dpi=400)
-    # draw arrows in blue color, alpha=0.5
-    for i in range(data.shape[0]):
-        idx = indices[i].item()
-        start = data[i]
-        end = code[idx]
-        plt.arrow(start[0], start[1], end[0] - start[0], end[1] - start[1],
-                  head_width=0.05, head_length=0.05, fc='green', ec='green', alpha=0.6,
-                  ls="-", lw=0.5)
-    plt.scatter(data[:, 0], data[:, 1], s=10, marker="o", c="gray", label="Data")
-    plt.scatter(code[:, 0], code[:, 1], s=25, marker="*", c="blue", label="Code")
-    plt.legend(loc="lower right")
-    plt.grid(color="gray", alpha=0.8, ls="-.", lw=0.5)
-    plt.title("Optimal Transport (OptVQ)")
-
-    buf = BytesIO()
-    plt.savefig(buf, format="png")
-    buf.seek(0)
-    image = Image.open(buf)
-    return image
-
-def draw_process(x, y, std):
-    data = torch.randn(N_data, dim)
-    code = torch.randn(N_code, dim) * std
-    code[:, 0] += x
-    code[:, 1] += y
-
-    image_NN = draw_NN(data, code)
-    image_optvq = draw_optvq(data, code)
-
-    return image_NN, image_optvq
-
-
 # ReVQ: for reset strategy
 def fig_to_array(fig):
     buf = BytesIO()
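The removed draw_optvq above builds its assignment from a sinkhorn() helper whose implementation is not part of this diff. For orientation only, a minimal sketch of the entropic Sinkhorn normalization such a helper typically performs; the name sinkhorn_sketch, the division by epsilon, and the 1e-8 stabilizer are assumptions, not the repository's actual code.

import torch

def sinkhorn_sketch(cost: torch.Tensor, n_iters: int = 5, epsilon: float = 10.0) -> torch.Tensor:
    # Soft assignment from an (N_data, N_code) cost matrix: exponentiate the negative cost,
    # then alternately normalize rows and columns toward a doubly stochastic plan.
    Q = torch.exp(-cost / epsilon)
    for _ in range(n_iters):
        Q = Q / (Q.sum(dim=1, keepdim=True) + 1e-8)  # normalize over codes
        Q = Q / (Q.sum(dim=0, keepdim=True) + 1e-8)  # normalize over data
    return Q

cost = torch.cdist(torch.randn(16, 2), torch.randn(8, 2), p=2.0)
indices = torch.argmax(sinkhorn_sketch(cost), dim=-1)  # same argmax step as in draw_optvq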
@@ -211,6 +128,83 @@ def draw_reset_result(num_data=16, num_code=12):
 # end
 
 
+# ReVQ: for multi-group
+def get_codebook_v2(quantizer):
+    with torch.no_grad():
+        embedding = quantizer.embeddings
+        if quantizer.num_group == 1:
+            group1 = embedding[0].squeeze()
+            group2 = embedding[0].squeeze()
+        else:
+            group1 = embedding[0].squeeze()
+            group2 = embedding[1].squeeze()
+        codes = torch.cartesian_prod(group1, group2)
+    return codes
+
+def draw_fig_v2(ax, quantizer, data, title=""):
+    codes = get_codebook(quantizer)
+    ax.scatter(data[:, 0], data[:, 1], s=60, marker="*")
+    ax.scatter(codes[:, 0], codes[:, 1], s=20, c='red', alpha=0.5)
+    ax.plot([-12, 12], [-12, 12], color='orange', linestyle='--', linewidth=2)
+    ax.set_xlim(-12, 12)
+    ax.set_ylim(-12, 12)
+    ax.tick_params(axis='x', labelsize=22)
+    ax.tick_params(axis='y', labelsize=22)
+    ax.set_xticks(np.arange(-10, 11, 5))
+    ax.set_yticks(np.arange(-10, 11, 5))
+    ax.grid(linestyle='--', color='#333333', alpha=0.7)
+    ax.set_title(f"{title}", fontsize=26)
+
+
+def draw_multi_group_result(num_data=16, num_code=12):
+    fig_s, ax_s = plt.subplots(1, 6, figsize=(36, 6), dpi=400)
+    fig_m, ax_m = plt.subplots(1, 6, figsize=(36, 6), dpi=400)
+    x = torch.randn(num_data, 1) * 3 + 4
+    y = torch.randn(num_data, 1) * 3 - 4
+    data = torch.cat([x, y], dim=1)
+    quantizer_s = Quantizer(TYPE='vq', code_dim=1, num_code=num_code, num_group=1, tokens_per_data=2)
+    optimizer_s = torch.optim.SGD(quantizer_s.parameters(), lr=0.1)
+    quantizer_m = Quantizer(TYPE='vq', code_dim=1, num_code=num_code, num_group=2, tokens_per_data=2)
+    optimizer_m = torch.optim.SGD(quantizer_m.parameters(), lr=0.1)
+    draw_fig_v2(ax_s[0], quantizer_s, data, title=f"Initialization")
+    draw_fig_v2(ax_m[0], quantizer_m, data, title=f"Initialization")
+    ax_s[0].legend(["Data", "Code"], loc="upper right", fontsize=24)
+    ax_m[0].legend(["Data", "Code"], loc="upper right", fontsize=24)
+    i_list = [5, 20, 50, 200, 1000]
+
+    count = 0
+    for i in range(1500):
+        optimizer_s.zero_grad()
+        optimizer_m.zero_grad()
+        quant_data_s = quantizer_s(data.unsqueeze(-1))["x_quant"].squeeze()
+        quant_data_m = quantizer_m(data.unsqueeze(-1))["x_quant"].squeeze()
+        loss_s = torch.mean((quant_data_s - data) ** 2)
+        loss_m = torch.mean((quant_data_m - data) ** 2)
+        loss_s.backward()
+        loss_m.backward()
+        optimizer_s.step()
+        optimizer_m.step()
+
+        if (i+1) in i_list:
+            count += 1
+            draw_fig_v2(ax_s[count], quantizer_s, data, title=f"Iters: {i+1}, MSE: {loss_s.item():.1f}")
+            draw_fig_v2(ax_m[count], quantizer_m, data, title=f"Iters: {i+1}, MSE: {loss_m.item():.1f}")
+
+        quantizer_s.reset()
+        quantizer_m.reset()
+
+    fig_s.suptitle("VQ Codebook Training with Single Group", fontsize=24, y=1.05)
+    fig_m.suptitle("VQ Codebook Training with Multi Group", fontsize=24, y=1.05)
+
+    img_s = fig_to_array(fig_s)
+    img_m = fig_to_array(fig_m)
+
+    return img_s, img_m
+
+# end
+
+# ReVQ: for image reconstruction
+
 class Handler:
     def __init__(self, device):
         self.transform = T.Compose([
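The new get_codebook_v2 builds the plotted 2-D code grid as the Cartesian product of two 1-D per-group codebooks. A self-contained sketch of that composition, with made-up values; only the torch.cartesian_prod call mirrors the diff.

import torch

group1 = torch.tensor([-1.0, 0.0, 1.0])       # 1-D codes of group 1 (code_dim=1)
group2 = torch.tensor([-2.0, 2.0])            # 1-D codes of group 2
codes = torch.cartesian_prod(group1, group2)  # every (group1, group2) pairing
print(codes.shape)                            # torch.Size([6, 2])

With num_code entries per group, the multi-group panels drawn by draw_fig_v2 therefore show num_code ** 2 composite codes against the same data, whereas in the single-group case both channels reuse the same 1-D codebook (group1 == group2 above).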
@@ -326,14 +320,31 @@ if __name__ == "__main__":
 
     submit_btn.click(fn=draw_reset_result, inputs=[num_data, num_code], outputs=[out_with_reset, out_without_reset])
 
+
+    with gr.Blocks() as demo3:
+        gr.Markdown("## Demo 3: Channel Multi-Group Strategy Visualization")
+        gr.Markdown("Visualizes codebook and data movement at different training steps with the multi-group strategy.")
+
+        with gr.Row():
+            num_data = gr.Slider(label="num_data", value=32, minimum=28, maximum=40, step=1)
+            num_code = gr.Slider(label="num_code", value=8, minimum=6, maximum=10, step=1)
+
+        submit_btn = gr.Button("Run Visualization")
+
+        with gr.Column():  # stack the outputs vertically
+            out_s = gr.Image(label="Single Group")
+            out_m = gr.Image(label="Multi Group")
+
+        submit_btn.click(fn=draw_multi_group_result, inputs=[num_data, num_code], outputs=[out_s, out_m])
+
     # merge the interfaces into a tabbed UI
     # demo = gr.TabbedInterface(
     #     interface_list=[demo1, demo2],
     #     tab_names=["Image Reconstruction", "Reset Strategy"]
     # )
     demo = gr.TabbedInterface(
-        interface_list=[demo2],
-        tab_names=["Reset Strategy"]
+        interface_list=[demo2, demo3],
+        tab_names=["Reset Strategy", "Channel Multi-Group Strategy"]
     )
 
     demo.launch(share=True)
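The pieces visible in this diff are enough for a quick smoke test of the new tab's callback outside Gradio. The sketch below reuses only the constructor arguments and the "x_quant" output key shown above; the Quantizer class itself comes from elsewhere in the repository and is assumed to be importable.

import torch

# Same toy data distribution as draw_multi_group_result (num_data=32).
data = torch.cat([torch.randn(32, 1) * 3 + 4, torch.randn(32, 1) * 3 - 4], dim=1)
quantizer_m = Quantizer(TYPE='vq', code_dim=1, num_code=8, num_group=2, tokens_per_data=2)
x_quant = quantizer_m(data.unsqueeze(-1))["x_quant"].squeeze()
print(torch.mean((x_quant - data) ** 2).item())  # reconstruction MSE before any training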