Spaces:
Sleeping
Sleeping
Commit
·
38eb8fe
1
Parent(s):
38e4fb0
update reset
Browse files
app.py
CHANGED
|
@@ -141,18 +141,17 @@ def draw_fig(ax, quantizer, data, title=""):
|
|
| 141 |
ax.set_xticks(np.arange(-5, 11, 5))
|
| 142 |
ax.set_yticks(np.arange(-10, 6, 5))
|
| 143 |
ax.grid(linestyle='--', color='#333333', alpha=0.7)
|
| 144 |
-
ax.set_title(f"{title}", fontsize=
|
| 145 |
|
| 146 |
def draw_arrow(ax, start, end):
|
| 147 |
for i in range(len(start)):
|
| 148 |
-
# 不用用plt.annotate
|
| 149 |
ax.arrow(start[i][0], start[i][1], end[i][0] - start[i][0], end[i][1] - start[i][1],
|
| 150 |
head_width=0.1, head_length=0.1, fc='orange', ec='orange', alpha=0.8,
|
| 151 |
ls="-", lw=1)
|
| 152 |
|
| 153 |
def draw_reset_result(num_data=16, num_code=12):
|
| 154 |
-
fig_reset, ax_reset = plt.subplots(1, 6, figsize=(
|
| 155 |
-
fig_nreset, ax_nreset = plt.subplots(1, 6, figsize=(
|
| 156 |
x = torch.randn(num_data, 1) * 2 + 5
|
| 157 |
y = torch.randn(num_data, 1) * 2 - 5
|
| 158 |
data = torch.cat([x, y], dim=1)
|
|
@@ -162,8 +161,8 @@ def draw_reset_result(num_data=16, num_code=12):
|
|
| 162 |
optimizer_nreset = torch.optim.SGD(quantizer_nreset.parameters(), lr=0.1)
|
| 163 |
draw_fig(ax_reset[0], quantizer, data, title=f"Initialization")
|
| 164 |
draw_fig(ax_nreset[0], quantizer_nreset, data, title=f"Initialization")
|
| 165 |
-
ax_reset[0].legend(["Data", "Code"], loc="upper right", fontsize=
|
| 166 |
-
ax_nreset[0].legend(["Data", "Code"], loc="upper right", fontsize=
|
| 167 |
|
| 168 |
i_list = [1, 3, 10, 50, 200]
|
| 169 |
|
|
@@ -288,21 +287,6 @@ if __name__ == "__main__":
|
|
| 288 |
description="Upload an image to see how different VQ models (BaseVQ, VQGAN, ReVQ) reconstruct it from latent codes."
|
| 289 |
)
|
| 290 |
|
| 291 |
-
# demo2 = gr.Interface(
|
| 292 |
-
# fn=draw_process,
|
| 293 |
-
# inputs=[
|
| 294 |
-
# gr.Slider(label="x", value=0, minimum=-10, maximum=10, step=0.1),
|
| 295 |
-
# gr.Slider(label="y", value=0, minimum=-10, maximum=10, step=0.1),
|
| 296 |
-
# gr.Slider(label="std", value=1, minimum=0, maximum=5, step=0.1)
|
| 297 |
-
# ],
|
| 298 |
-
# outputs=[
|
| 299 |
-
# gr.Image(label="NN", type="numpy"),
|
| 300 |
-
# gr.Image(label="OptVQ", type="numpy")
|
| 301 |
-
# ],
|
| 302 |
-
# title="Demo 2: 2D Matching Visualization",
|
| 303 |
-
# description="Visualize nearest neighbor vs. optimal transport matching for synthetic 2D data."
|
| 304 |
-
# )
|
| 305 |
-
|
| 306 |
demo2 = gr.Interface(
|
| 307 |
fn=draw_reset_result, # 不再需要 handler 包装
|
| 308 |
inputs=[
|
|
|
|
| 141 |
ax.set_xticks(np.arange(-5, 11, 5))
|
| 142 |
ax.set_yticks(np.arange(-10, 6, 5))
|
| 143 |
ax.grid(linestyle='--', color='#333333', alpha=0.7)
|
| 144 |
+
ax.set_title(f"{title}", fontsize=36)
|
| 145 |
|
| 146 |
def draw_arrow(ax, start, end):
|
| 147 |
for i in range(len(start)):
|
|
|
|
| 148 |
ax.arrow(start[i][0], start[i][1], end[i][0] - start[i][0], end[i][1] - start[i][1],
|
| 149 |
head_width=0.1, head_length=0.1, fc='orange', ec='orange', alpha=0.8,
|
| 150 |
ls="-", lw=1)
|
| 151 |
|
| 152 |
def draw_reset_result(num_data=16, num_code=12):
|
| 153 |
+
fig_reset, ax_reset = plt.subplots(1, 6, figsize=(72, 10), dpi=400)
|
| 154 |
+
fig_nreset, ax_nreset = plt.subplots(1, 6, figsize=(72, 10), dpi=400)
|
| 155 |
x = torch.randn(num_data, 1) * 2 + 5
|
| 156 |
y = torch.randn(num_data, 1) * 2 - 5
|
| 157 |
data = torch.cat([x, y], dim=1)
|
|
|
|
| 161 |
optimizer_nreset = torch.optim.SGD(quantizer_nreset.parameters(), lr=0.1)
|
| 162 |
draw_fig(ax_reset[0], quantizer, data, title=f"Initialization")
|
| 163 |
draw_fig(ax_nreset[0], quantizer_nreset, data, title=f"Initialization")
|
| 164 |
+
ax_reset[0].legend(["Data", "Code"], loc="upper right", fontsize=24)
|
| 165 |
+
ax_nreset[0].legend(["Data", "Code"], loc="upper right", fontsize=24)
|
| 166 |
|
| 167 |
i_list = [1, 3, 10, 50, 200]
|
| 168 |
|
|
|
|
| 287 |
description="Upload an image to see how different VQ models (BaseVQ, VQGAN, ReVQ) reconstruct it from latent codes."
|
| 288 |
)
|
| 289 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 290 |
demo2 = gr.Interface(
|
| 291 |
fn=draw_reset_result, # 不再需要 handler 包装
|
| 292 |
inputs=[
|