Spaces:
Sleeping
Sleeping
Commit
·
4359005
1
Parent(s):
57b27d9
update
Browse files
app.py
CHANGED
|
@@ -23,9 +23,6 @@ from revq.models.vqgan_hf import VQModelHF
|
|
| 23 |
from diffusers import AutoencoderDC
|
| 24 |
|
| 25 |
#################
|
| 26 |
-
N_data = 50
|
| 27 |
-
N_code = 20
|
| 28 |
-
dim = 2
|
| 29 |
handler = None
|
| 30 |
device = torch.device("cpu")
|
| 31 |
#################
|
|
@@ -54,10 +51,13 @@ def get_codebook(quantizer):
|
|
| 54 |
codes = quantizer.embeddings.squeeze().detach()
|
| 55 |
return codes
|
| 56 |
|
| 57 |
-
def draw_fig(ax, quantizer, data, title=""):
|
| 58 |
codes = get_codebook(quantizer)
|
| 59 |
ax.scatter(data[:, 0], data[:, 1], s=60, marker="*")
|
| 60 |
-
|
|
|
|
|
|
|
|
|
|
| 61 |
ax.set_xlim(-5, 10)
|
| 62 |
ax.set_ylim(-10, 5)
|
| 63 |
ax.tick_params(axis='x', labelsize=22)
|
|
@@ -83,8 +83,8 @@ def draw_reset_result(num_data=16, num_code=12):
|
|
| 83 |
optimizer = torch.optim.SGD(quantizer.parameters(), lr=0.1)
|
| 84 |
quantizer_nreset = Quantizer(TYPE='vq', code_dim=2, num_code=num_code, num_group=1, tokens_per_data=1, auto_reset=False)
|
| 85 |
optimizer_nreset = torch.optim.SGD(quantizer_nreset.parameters(), lr=0.1)
|
| 86 |
-
draw_fig(ax_reset[0], quantizer, data, title=f"Initialization")
|
| 87 |
-
draw_fig(ax_nreset[0], quantizer_nreset, data, title=f"Initialization")
|
| 88 |
ax_reset[0].legend(["Data", "Code"], loc="upper right", fontsize=24)
|
| 89 |
ax_nreset[0].legend(["Data", "Code"], loc="upper right", fontsize=24)
|
| 90 |
|
|
@@ -109,10 +109,10 @@ def draw_reset_result(num_data=16, num_code=12):
|
|
| 109 |
|
| 110 |
if (i+1) in i_list:
|
| 111 |
count += 1
|
| 112 |
-
draw_fig(ax_reset[count], quantizer, data, title=f"Iters: {i+1}, MSE: {loss.item():.1f}")
|
| 113 |
draw_arrow(ax_reset[count], quant_data.detach().numpy(), data.numpy())
|
| 114 |
|
| 115 |
-
draw_fig(ax_nreset[count], quantizer_nreset, data, title=f"Iters: {i+1}, MSE: {loss_nreset.item():.1f}")
|
| 116 |
draw_arrow(ax_nreset[count], quant_data_nreset.detach().numpy(), data.numpy())
|
| 117 |
|
| 118 |
quantizer.reset()
|
|
@@ -123,11 +123,10 @@ def draw_reset_result(num_data=16, num_code=12):
|
|
| 123 |
img_reset = fig_to_array(fig_reset)
|
| 124 |
img_nreset = fig_to_array(fig_nreset)
|
| 125 |
|
| 126 |
-
return
|
| 127 |
|
| 128 |
# end
|
| 129 |
|
| 130 |
-
|
| 131 |
# ReVQ: for multi-group
|
| 132 |
def get_codebook_v2(quantizer):
|
| 133 |
with torch.no_grad():
|
|
@@ -141,10 +140,13 @@ def get_codebook_v2(quantizer):
|
|
| 141 |
codes = torch.cartesian_prod(group1, group2)
|
| 142 |
return codes
|
| 143 |
|
| 144 |
-
def draw_fig_v2(ax, quantizer, data, title=""):
|
| 145 |
codes = get_codebook_v2(quantizer)
|
| 146 |
ax.scatter(data[:, 0], data[:, 1], s=60, marker="*")
|
| 147 |
-
|
|
|
|
|
|
|
|
|
|
| 148 |
ax.plot([-12, 12], [-12, 12], color='orange', linestyle='--', linewidth=2)
|
| 149 |
ax.set_xlim(-12, 12)
|
| 150 |
ax.set_ylim(-12, 12)
|
|
@@ -166,8 +168,8 @@ def draw_multi_group_result(num_data=16, num_code=12):
|
|
| 166 |
optimizer_s = torch.optim.SGD(quantizer_s.parameters(), lr=0.1)
|
| 167 |
quantizer_m = Quantizer(TYPE='vq', code_dim=1, num_code=num_code, num_group=2, tokens_per_data=2)
|
| 168 |
optimizer_m = torch.optim.SGD(quantizer_m.parameters(), lr=0.1)
|
| 169 |
-
draw_fig_v2(ax_s[0], quantizer_s, data, title=f"Initialization")
|
| 170 |
-
draw_fig_v2(ax_m[0], quantizer_m, data, title=f"Initialization")
|
| 171 |
ax_s[0].legend(["Data", "Code"], loc="upper right", fontsize=24)
|
| 172 |
ax_m[0].legend(["Data", "Code"], loc="upper right", fontsize=24)
|
| 173 |
i_list = [5, 20, 50, 200, 1000]
|
|
@@ -187,8 +189,8 @@ def draw_multi_group_result(num_data=16, num_code=12):
|
|
| 187 |
|
| 188 |
if (i+1) in i_list:
|
| 189 |
count += 1
|
| 190 |
-
draw_fig_v2(ax_s[count], quantizer_s, data, title=f"Iters: {i+1}, MSE: {loss_s.item():.1f}")
|
| 191 |
-
draw_fig_v2(ax_m[count], quantizer_m, data, title=f"Iters: {i+1}, MSE: {loss_m.item():.1f}")
|
| 192 |
|
| 193 |
quantizer_s.reset()
|
| 194 |
quantizer_m.reset()
|
|
@@ -204,7 +206,6 @@ def draw_multi_group_result(num_data=16, num_code=12):
|
|
| 204 |
# end
|
| 205 |
|
| 206 |
# ReVQ: for image reconstruction
|
| 207 |
-
|
| 208 |
class Handler:
|
| 209 |
def __init__(self, device):
|
| 210 |
self.transform = T.Compose([
|
|
@@ -293,7 +294,7 @@ if __name__ == "__main__":
|
|
| 293 |
|
| 294 |
with gr.Blocks() as demo2:
|
| 295 |
gr.Markdown("## Demo 2: Codebook Reset Strategy Visualization")
|
| 296 |
-
gr.Markdown("Visualizes codebook and data movement at different training steps.")
|
| 297 |
|
| 298 |
with gr.Row():
|
| 299 |
num_data = gr.Slider(label="num_data", value=16, minimum=10, maximum=20, step=1)
|
|
@@ -302,15 +303,15 @@ if __name__ == "__main__":
|
|
| 302 |
submit_btn = gr.Button("Run Visualization")
|
| 303 |
|
| 304 |
with gr.Column(): # 垂直输出
|
| 305 |
-
out_with_reset = gr.Image(label="With Reset")
|
| 306 |
out_without_reset = gr.Image(label="Without Reset")
|
|
|
|
| 307 |
|
| 308 |
-
submit_btn.click(fn=draw_reset_result, inputs=[num_data, num_code], outputs=[
|
| 309 |
|
| 310 |
|
| 311 |
with gr.Blocks() as demo3:
|
| 312 |
gr.Markdown("## Demo 3: Channel Multi-Group Strategy Visualization")
|
| 313 |
-
gr.Markdown("Visualizes codebook and data movement at different training steps with multi-group strategy.")
|
| 314 |
|
| 315 |
with gr.Row():
|
| 316 |
num_data = gr.Slider(label="num_data", value=32, minimum=28, maximum=40, step=1)
|
|
@@ -324,48 +325,9 @@ if __name__ == "__main__":
|
|
| 324 |
|
| 325 |
submit_btn.click(fn=draw_multi_group_result, inputs=[num_data, num_code], outputs=[out_s, out_m])
|
| 326 |
|
| 327 |
-
# 合并两个 interface 成 Tabbed UI
|
| 328 |
-
# demo = gr.TabbedInterface(
|
| 329 |
-
# interface_list=[demo1, demo2],
|
| 330 |
-
# tab_names=["Image Reconstruction", "Reset Strategy"]
|
| 331 |
-
# )
|
| 332 |
demo = gr.TabbedInterface(
|
| 333 |
interface_list=[demo1, demo2, demo3],
|
| 334 |
tab_names=["Image Reconstruction", "Reset Strategy", "Channel Multi-Group Strategy"]
|
| 335 |
)
|
| 336 |
|
| 337 |
demo.launch(share=True)
|
| 338 |
-
|
| 339 |
-
# create the interface
|
| 340 |
-
# with gr.Blocks() as demo:
|
| 341 |
-
# gr.Textbox(value="This demo shows the image reconstruction comparison between ReVQ and other methods. The input image is resized to 256 x 256 and then fed into the models. The output images are the reconstructed images from the latent codes.", label="Demo 1: Image reconstruction results")
|
| 342 |
-
# with gr.Row():
|
| 343 |
-
# with gr.Column():
|
| 344 |
-
# image_input = gr.Image(label="Input data", image_mode="RGB", type="numpy")
|
| 345 |
-
# btn_demo1 = gr.Button(value="Run reconstruction")
|
| 346 |
-
# image_basevq = gr.Image(label="BaseVQ rec.")
|
| 347 |
-
# image_vqgan = gr.Image(label="VQGAN rec.")
|
| 348 |
-
# image_revq = gr.Image(label="ReVQ rec.")
|
| 349 |
-
# btn_demo1.click(fn=handler.process_image, inputs=[image_input], outputs=[image_basevq, image_vqgan, image_revq])
|
| 350 |
-
|
| 351 |
-
# gr.Textbox(value="This demo shows the 2D visualizations of nearest neighbor and optimal transport (OptVQ) methods. The data points are randomly generated from a normal distribution, and the matching results are shown as arrows with different colors.", label="Demo 2: 2D visualizations of matching results")
|
| 352 |
-
# gr.Markdown("### Demo 2: 2D visualizations of matching results\n"
|
| 353 |
-
# "This demo shows the 2D visualizations of nearest neighbor and optimal transport (OptVQ) methods. "
|
| 354 |
-
# "The data points are randomly generated from a normal distribution, and the matching results are shown as arrows with different colors.")
|
| 355 |
-
# with gr.Row():
|
| 356 |
-
# with gr.Column():
|
| 357 |
-
# input_x = gr.Slider(label="x", value=0, minimum=-10, maximum=10, step=0.1)
|
| 358 |
-
# input_y = gr.Slider(label="y", value=0, minimum=-10, maximum=10, step=0.1)
|
| 359 |
-
# input_std = gr.Slider(label="std", value=1, minimum=0, maximum=5, step=0.1)
|
| 360 |
-
# btn_demo2 = gr.Button(value="Run 2D example")
|
| 361 |
-
# output_nn = gr.Image(label="NN", interactive=False, type="numpy")
|
| 362 |
-
# output_optvq = gr.Image(label="OptVQ", interactive=False, type="numpy")
|
| 363 |
-
|
| 364 |
-
# # set the function
|
| 365 |
-
# input_x.change(fn=draw_process, inputs=[input_x, input_y, input_std], outputs=[output_nn, output_optvq])
|
| 366 |
-
# input_y.change(fn=draw_process, inputs=[input_x, input_y, input_std], outputs=[output_nn, output_optvq])
|
| 367 |
-
# input_std.change(fn=draw_process, inputs=[input_x, input_y, input_std], outputs=[output_nn, output_optvq])
|
| 368 |
-
# btn_demo2.click(fn=draw_process, inputs=[input_x, input_y, input_std], outputs=[output_nn, output_optvq])
|
| 369 |
-
# btn_demo2.click(fn=draw_process, inputs=[input_x, input_y, input_std], outputs=[output_nn, output_optvq])
|
| 370 |
-
|
| 371 |
-
# demo.launch()
|
|
|
|
| 23 |
from diffusers import AutoencoderDC
|
| 24 |
|
| 25 |
#################
|
|
|
|
|
|
|
|
|
|
| 26 |
handler = None
|
| 27 |
device = torch.device("cpu")
|
| 28 |
#################
|
|
|
|
| 51 |
codes = quantizer.embeddings.squeeze().detach()
|
| 52 |
return codes
|
| 53 |
|
| 54 |
+
def draw_fig(ax, quantizer, data, color="r", title=""):
|
| 55 |
codes = get_codebook(quantizer)
|
| 56 |
ax.scatter(data[:, 0], data[:, 1], s=60, marker="*")
|
| 57 |
+
if color == "r":
|
| 58 |
+
ax.scatter(codes[:, 0], codes[:, 1], s=40, c='red', alpha=0.5)
|
| 59 |
+
else:
|
| 60 |
+
ax.scatter(codes[:, 0], codes[:, 1], s=40, c='green', alpha=0.5)
|
| 61 |
ax.set_xlim(-5, 10)
|
| 62 |
ax.set_ylim(-10, 5)
|
| 63 |
ax.tick_params(axis='x', labelsize=22)
|
|
|
|
| 83 |
optimizer = torch.optim.SGD(quantizer.parameters(), lr=0.1)
|
| 84 |
quantizer_nreset = Quantizer(TYPE='vq', code_dim=2, num_code=num_code, num_group=1, tokens_per_data=1, auto_reset=False)
|
| 85 |
optimizer_nreset = torch.optim.SGD(quantizer_nreset.parameters(), lr=0.1)
|
| 86 |
+
draw_fig(ax_reset[0], quantizer, data, color='g', title=f"Initialization")
|
| 87 |
+
draw_fig(ax_nreset[0], quantizer_nreset, data, color='r', title=f"Initialization")
|
| 88 |
ax_reset[0].legend(["Data", "Code"], loc="upper right", fontsize=24)
|
| 89 |
ax_nreset[0].legend(["Data", "Code"], loc="upper right", fontsize=24)
|
| 90 |
|
|
|
|
| 109 |
|
| 110 |
if (i+1) in i_list:
|
| 111 |
count += 1
|
| 112 |
+
draw_fig(ax_reset[count], quantizer, data, color='g', title=f"Iters: {i+1}, MSE: {loss.item():.1f}")
|
| 113 |
draw_arrow(ax_reset[count], quant_data.detach().numpy(), data.numpy())
|
| 114 |
|
| 115 |
+
draw_fig(ax_nreset[count], quantizer_nreset, data, color='r', title=f"Iters: {i+1}, MSE: {loss_nreset.item():.1f}")
|
| 116 |
draw_arrow(ax_nreset[count], quant_data_nreset.detach().numpy(), data.numpy())
|
| 117 |
|
| 118 |
quantizer.reset()
|
|
|
|
| 123 |
img_reset = fig_to_array(fig_reset)
|
| 124 |
img_nreset = fig_to_array(fig_nreset)
|
| 125 |
|
| 126 |
+
return img_nreset, img_reset
|
| 127 |
|
| 128 |
# end
|
| 129 |
|
|
|
|
| 130 |
# ReVQ: for multi-group
|
| 131 |
def get_codebook_v2(quantizer):
|
| 132 |
with torch.no_grad():
|
|
|
|
| 140 |
codes = torch.cartesian_prod(group1, group2)
|
| 141 |
return codes
|
| 142 |
|
| 143 |
+
def draw_fig_v2(ax, quantizer, data, color='r', title=""):
|
| 144 |
codes = get_codebook_v2(quantizer)
|
| 145 |
ax.scatter(data[:, 0], data[:, 1], s=60, marker="*")
|
| 146 |
+
if color == "r":
|
| 147 |
+
ax.scatter(codes[:, 0], codes[:, 1], s=20, c='red', alpha=0.5)
|
| 148 |
+
else:
|
| 149 |
+
ax.scatter(codes[:, 0], codes[:, 1], s=20, c='green', alpha=0.5)
|
| 150 |
ax.plot([-12, 12], [-12, 12], color='orange', linestyle='--', linewidth=2)
|
| 151 |
ax.set_xlim(-12, 12)
|
| 152 |
ax.set_ylim(-12, 12)
|
|
|
|
| 168 |
optimizer_s = torch.optim.SGD(quantizer_s.parameters(), lr=0.1)
|
| 169 |
quantizer_m = Quantizer(TYPE='vq', code_dim=1, num_code=num_code, num_group=2, tokens_per_data=2)
|
| 170 |
optimizer_m = torch.optim.SGD(quantizer_m.parameters(), lr=0.1)
|
| 171 |
+
draw_fig_v2(ax_s[0], quantizer_s, data, color='r', title=f"Initialization")
|
| 172 |
+
draw_fig_v2(ax_m[0], quantizer_m, data, color='g', title=f"Initialization")
|
| 173 |
ax_s[0].legend(["Data", "Code"], loc="upper right", fontsize=24)
|
| 174 |
ax_m[0].legend(["Data", "Code"], loc="upper right", fontsize=24)
|
| 175 |
i_list = [5, 20, 50, 200, 1000]
|
|
|
|
| 189 |
|
| 190 |
if (i+1) in i_list:
|
| 191 |
count += 1
|
| 192 |
+
draw_fig_v2(ax_s[count], quantizer_s, data, color='r', title=f"Iters: {i+1}, MSE: {loss_s.item():.1f}")
|
| 193 |
+
draw_fig_v2(ax_m[count], quantizer_m, data, color='g', title=f"Iters: {i+1}, MSE: {loss_m.item():.1f}")
|
| 194 |
|
| 195 |
quantizer_s.reset()
|
| 196 |
quantizer_m.reset()
|
|
|
|
| 206 |
# end
|
| 207 |
|
| 208 |
# ReVQ: for image reconstruction
|
|
|
|
| 209 |
class Handler:
|
| 210 |
def __init__(self, device):
|
| 211 |
self.transform = T.Compose([
|
|
|
|
| 294 |
|
| 295 |
with gr.Blocks() as demo2:
|
| 296 |
gr.Markdown("## Demo 2: Codebook Reset Strategy Visualization")
|
| 297 |
+
gr.Markdown("Visualizes codebook and data movement at different training steps with or without codebook reset strategy.")
|
| 298 |
|
| 299 |
with gr.Row():
|
| 300 |
num_data = gr.Slider(label="num_data", value=16, minimum=10, maximum=20, step=1)
|
|
|
|
| 303 |
submit_btn = gr.Button("Run Visualization")
|
| 304 |
|
| 305 |
with gr.Column(): # 垂直输出
|
|
|
|
| 306 |
out_without_reset = gr.Image(label="Without Reset")
|
| 307 |
+
out_with_reset = gr.Image(label="With Reset")
|
| 308 |
|
| 309 |
+
submit_btn.click(fn=draw_reset_result, inputs=[num_data, num_code], outputs=[out_without_reset, out_with_reset])
|
| 310 |
|
| 311 |
|
| 312 |
with gr.Blocks() as demo3:
|
| 313 |
gr.Markdown("## Demo 3: Channel Multi-Group Strategy Visualization")
|
| 314 |
+
gr.Markdown("Visualizes codebook and data movement at different training steps with or without multi-group strategy.")
|
| 315 |
|
| 316 |
with gr.Row():
|
| 317 |
num_data = gr.Slider(label="num_data", value=32, minimum=28, maximum=40, step=1)
|
|
|
|
| 325 |
|
| 326 |
submit_btn.click(fn=draw_multi_group_result, inputs=[num_data, num_code], outputs=[out_s, out_m])
|
| 327 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 328 |
demo = gr.TabbedInterface(
|
| 329 |
interface_list=[demo1, demo2, demo3],
|
| 330 |
tab_names=["Image Reconstruction", "Reset Strategy", "Channel Multi-Group Strategy"]
|
| 331 |
)
|
| 332 |
|
| 333 |
demo.launch(share=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|