import gradio as gr
from io import BytesIO
import os
import sys
import matplotlib.pyplot as plt
import matplotlib
import numpy as np
from PIL import Image
from omegaconf import OmegaConf
import torch
from torchvision import transforms as T
from diffusers import AutoencoderDC

from revq.models.quantizer import sinkhorn
from revq.models.preprocessor import Preprocessor
from revq.models.revq import ReVQ
from revq.models.revq_quantizer import Quantizer
from revq.models.vqgan_hf import VQModelHF
from revq.utils.init import seed_everything

# matplotlib.rcParams['font.family'] = 'Times New Roman'

seed_everything(42)

#################
handler = None
device = torch.device("cpu")
#################
def load_preprocessor(device, is_eval: bool = True, ckpt_path: str = "./ckpt/preprocessor.pth"):
    preprocessor = Preprocessor(
        input_data_size=[32, 8, 8]
    ).to(device)
    preprocessor.load_state_dict(
        torch.load(ckpt_path, map_location=device, weights_only=True)
    )
    if is_eval:
        preprocessor.eval()
    return preprocessor
# ReVQ: for reset strategy
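# The reset demo below trains two toy 2-D VQ codebooks on the same data:
# one calls quantizer.reset() every step (which, in this codebase, is
# expected to move inactive codes back toward the data), the other never
# resets, so codes initialized far from the data can stay dead.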
def fig_to_array(fig):
    buf = BytesIO()
    fig.savefig(buf, format='png')  # use png instead of webp
    buf.seek(0)
    image = Image.open(buf)
    return np.array(image)

def get_codebook(quantizer):
    with torch.no_grad():
        codes = quantizer.embeddings.squeeze().detach()
    return codes
def draw_fig(ax, quantizer, data, color="r", title=""):
    codes = get_codebook(quantizer)
    ax.scatter(data[:, 0], data[:, 1], s=60, marker="*")
    if color == "r":
        ax.scatter(codes[:, 0], codes[:, 1], s=40, c='red', alpha=0.5)
    else:
        ax.scatter(codes[:, 0], codes[:, 1], s=40, c='green', alpha=0.5)
    ax.set_xlim(-5, 10)
    ax.set_ylim(-10, 5)
    ax.tick_params(axis='x', labelsize=22)
    ax.tick_params(axis='y', labelsize=22)
    ax.set_xticks(np.arange(-5, 11, 5))
    ax.set_yticks(np.arange(-10, 6, 5))
    ax.grid(linestyle='--', color='#333333', alpha=0.7)
    ax.set_title(title, fontsize=24)

def draw_arrow(ax, start, end):
    for i in range(len(start)):
        ax.arrow(start[i][0], start[i][1], end[i][0] - start[i][0], end[i][1] - start[i][1],
                 head_width=0.1, head_length=0.1, fc='orange', ec='orange', alpha=0.8,
                 ls="-", lw=1)
def draw_reset_result(num_data=16, num_code=12):
    fig_reset, ax_reset = plt.subplots(1, 6, figsize=(36, 6), dpi=400)
    fig_nreset, ax_nreset = plt.subplots(1, 6, figsize=(36, 6), dpi=400)
    x = torch.randn(num_data, 1) * 2 + 5
    y = torch.randn(num_data, 1) * 2 - 5
    data = torch.cat([x, y], dim=1)
    quantizer = Quantizer(TYPE='vq', code_dim=2, num_code=num_code, num_group=1, tokens_per_data=1)
    optimizer = torch.optim.SGD(quantizer.parameters(), lr=0.1)
    quantizer_nreset = Quantizer(TYPE='vq', code_dim=2, num_code=num_code, num_group=1, tokens_per_data=1, auto_reset=False)
    optimizer_nreset = torch.optim.SGD(quantizer_nreset.parameters(), lr=0.1)
    draw_fig(ax_reset[0], quantizer, data, color='g', title="Initialization")
    draw_fig(ax_nreset[0], quantizer_nreset, data, color='r', title="Initialization")
    ax_reset[0].legend(["Data", "Code"], loc="upper right", fontsize=24)
    ax_nreset[0].legend(["Data", "Code"], loc="upper right", fontsize=24)
    i_list = [1, 3, 10, 50, 200]
    count = 0
    for i in range(500):
        optimizer.zero_grad()
        optimizer_nreset.zero_grad()
        output_dict = quantizer(data.unsqueeze(1))
        output_dict_nreset = quantizer_nreset(data.unsqueeze(1))
        quant_data = output_dict["x_quant"].squeeze()
        quant_data_nreset = output_dict_nreset["x_quant"].squeeze()
        indices = output_dict["indices"].squeeze()
        indices_nreset = output_dict_nreset["indices"].squeeze()
        loss = torch.mean((quant_data - data) ** 2)
        loss_nreset = torch.mean((quant_data_nreset - data) ** 2)
        loss.backward()
        loss_nreset.backward()
        optimizer.step()
        optimizer_nreset.step()
        if (i + 1) in i_list:
            count += 1
            draw_fig(ax_reset[count], quantizer, data, color='g', title=f"Iters: {i+1}, MSE: {loss.item():.1f}")
            draw_arrow(ax_reset[count], quant_data.detach().numpy(), data.numpy())
            draw_fig(ax_nreset[count], quantizer_nreset, data, color='r', title=f"Iters: {i+1}, MSE: {loss_nreset.item():.1f}")
            draw_arrow(ax_nreset[count], quant_data_nreset.detach().numpy(), data.numpy())
        quantizer.reset()
    fig_reset.suptitle("VQ Codebook Training with Reset", fontsize=24, y=1.05)
    fig_nreset.suptitle("VQ Codebook Training without Reset", fontsize=24, y=1.05)
    img_reset = fig_to_array(fig_reset)
    img_nreset = fig_to_array(fig_nreset)
    plt.close(fig_reset)  # free the figures so repeated Gradio calls do not leak memory
    plt.close(fig_nreset)
    return img_nreset, img_reset
# end
# ReVQ: for multi-group
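# The multi-group demo below quantizes 2-D points as two 1-D tokens.
# With num_group=1 both coordinates share one scalar codebook; with
# num_group=2 each coordinate gets its own codebook, and the effective
# 2-D codebook is the Cartesian product of the two groups (num_code**2
# combinations from only 2 * num_code stored scalars).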
def get_codebook_v2(quantizer):
    with torch.no_grad():
        embedding = quantizer.embeddings
        if quantizer.num_group == 1:
            group1 = embedding[0].squeeze()
            group2 = embedding[0].squeeze()
        else:
            group1 = embedding[0].squeeze()
            group2 = embedding[1].squeeze()
        codes = torch.cartesian_prod(group1, group2)
    return codes
def draw_fig_v2(ax, quantizer, data, color='r', title=""):
    codes = get_codebook_v2(quantizer)
    ax.scatter(data[:, 0], data[:, 1], s=60, marker="*")
    if color == "r":
        ax.scatter(codes[:, 0], codes[:, 1], s=20, c='red', alpha=0.5)
    else:
        ax.scatter(codes[:, 0], codes[:, 1], s=20, c='green', alpha=0.5)
    ax.plot([-12, 12], [-12, 12], color='orange', linestyle='--', linewidth=2)
    ax.set_xlim(-12, 12)
    ax.set_ylim(-12, 12)
    ax.tick_params(axis='x', labelsize=22)
    ax.tick_params(axis='y', labelsize=22)
    ax.set_xticks(np.arange(-10, 11, 5))
    ax.set_yticks(np.arange(-10, 11, 5))
    ax.grid(linestyle='--', color='#333333', alpha=0.7)
    ax.set_title(title, fontsize=26)
def draw_multi_group_result(num_data=16, num_code=12):
    fig_s, ax_s = plt.subplots(1, 6, figsize=(36, 6), dpi=400)
    fig_m, ax_m = plt.subplots(1, 6, figsize=(36, 6), dpi=400)
    x = torch.randn(num_data, 1) * 3 + 4
    y = torch.randn(num_data, 1) * 3 - 4
    data = torch.cat([x, y], dim=1)
    quantizer_s = Quantizer(TYPE='vq', code_dim=1, num_code=num_code, num_group=1, tokens_per_data=2)
    optimizer_s = torch.optim.SGD(quantizer_s.parameters(), lr=0.1)
    quantizer_m = Quantizer(TYPE='vq', code_dim=1, num_code=num_code, num_group=2, tokens_per_data=2)
    optimizer_m = torch.optim.SGD(quantizer_m.parameters(), lr=0.1)
    draw_fig_v2(ax_s[0], quantizer_s, data, color='r', title="Initialization")
    draw_fig_v2(ax_m[0], quantizer_m, data, color='g', title="Initialization")
    ax_s[0].legend(["Data", "Code"], loc="upper right", fontsize=24)
    ax_m[0].legend(["Data", "Code"], loc="upper right", fontsize=24)
    i_list = [5, 20, 50, 200, 1000]
    count = 0
    for i in range(1500):
        optimizer_s.zero_grad()
        optimizer_m.zero_grad()
        quant_data_s = quantizer_s(data.unsqueeze(-1))["x_quant"].squeeze()
        quant_data_m = quantizer_m(data.unsqueeze(-1))["x_quant"].squeeze()
        loss_s = torch.mean((quant_data_s - data) ** 2)
        loss_m = torch.mean((quant_data_m - data) ** 2)
        loss_s.backward()
        loss_m.backward()
        optimizer_s.step()
        optimizer_m.step()
        if (i + 1) in i_list:
            count += 1
            draw_fig_v2(ax_s[count], quantizer_s, data, color='r', title=f"Iters: {i+1}, MSE: {loss_s.item():.1f}")
            draw_fig_v2(ax_m[count], quantizer_m, data, color='g', title=f"Iters: {i+1}, MSE: {loss_m.item():.1f}")
        quantizer_s.reset()
        quantizer_m.reset()
    fig_s.suptitle("VQ Codebook Training with Single Group", fontsize=24, y=1.05)
    fig_m.suptitle("VQ Codebook Training with Multi Group", fontsize=24, y=1.05)
    img_s = fig_to_array(fig_s)
    img_m = fig_to_array(fig_m)
    plt.close(fig_s)  # free the figures so repeated Gradio calls do not leak memory
    plt.close(fig_m)
    return img_s, img_m
# end
# ReVQ: for image reconstruction
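# Reconstruction pipeline used by the Handler below: BaseVQ and VQGAN
# reconstruct directly via their own encode/decode, while ReVQ runs on
# DC-AE latents: DC-AE encode -> Preprocessor -> ReVQ quantize/decode ->
# Preprocessor.inverse -> DC-AE decode.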
class Handler:
    def __init__(self, device):
        self.transform = T.Compose([
            T.Resize(256),
            T.CenterCrop(256),
            T.ToTensor()
        ])
        self.device = device
        self.basevq = VQModelHF.from_pretrained("BorelTHU/basevq-16x16x4")
        self.basevq.to(self.device)
        self.basevq.eval()
        self.vqgan = VQModelHF.from_pretrained("BorelTHU/vqgan-16x16")
        self.vqgan.to(self.device)
        self.vqgan.eval()
        self.optvq = VQModelHF.from_pretrained("BorelTHU/optvq-16x16x4")
        self.optvq.to(self.device)
        self.optvq.eval()
        self.vae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers")
        self.vae.to(self.device)
        self.vae.eval()
        self.preprocessor = load_preprocessor(self.device)
        self.revq = ReVQ.from_pretrained("AndyRaoTHU/revq-512T")
        self.revq.to(self.device)
        self.revq.eval()
        # print("Models loaded successfully!")

    def tensor_to_image(self, tensor):
        img = tensor.squeeze(0).cpu().permute(1, 2, 0).numpy()
        img = (img + 1) / 2 * 255
        img = np.clip(img, 0, 255)  # guard against decoder outputs outside [-1, 1]
        img = img.astype("uint8")
        return img
    def process_image(self, img: np.ndarray):
        img = Image.fromarray(img.astype("uint8"))
        img = self.transform(img)
        img = img.unsqueeze(0).to(self.device)
        with torch.no_grad():
            img = 2 * img - 1  # map [0, 1] to [-1, 1]
            # basevq
            quant, *_ = self.basevq.encode(img)
            basevq_rec = self.basevq.decode(quant)
            # vqgan
            quant, *_ = self.vqgan.encode(img)
            vqgan_rec = self.vqgan.decode(quant)
            # revq
            lat = self.vae.encode(img).latent
            lat = lat.contiguous()
            lat = self.preprocessor(lat)
            lat = self.revq.quantize(lat)
            revq_rec = self.revq.decode(lat)
            revq_rec = revq_rec.contiguous()
            revq_rec = self.preprocessor.inverse(revq_rec)
            revq_rec = self.vae.decode(revq_rec).sample
        # tensor to PIL image
        img = self.tensor_to_image(img)
        basevq_rec = self.tensor_to_image(basevq_rec)
        vqgan_rec = self.tensor_to_image(vqgan_rec)
        revq_rec = self.tensor_to_image(revq_rec)
        return basevq_rec, vqgan_rec, revq_rec
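# Minimal standalone usage sketch (hypothetical file paths; assumes the
# pretrained checkpoints above are reachable and ./ckpt/preprocessor.pth
# exists locally):
#   handler = Handler(device=torch.device("cpu"))
#   arr = np.array(Image.open("example.jpg").convert("RGB"))
#   basevq_rec, vqgan_rec, revq_rec = handler.process_image(arr)
#   Image.fromarray(revq_rec).save("revq_rec.png")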
if __name__ == "__main__":
    # create the model handler
    handler = Handler(device=device)
    print("Creating Gradio interface...")

    # Demo 1 interface: image reconstruction
    demo1 = gr.Interface(
        fn=handler.process_image,
        inputs=gr.Image(label="Input Image", type="numpy"),
        outputs=[
            gr.Image(label="BaseVQ Reconstruction", type="numpy"),
            gr.Image(label="VQGAN Reconstruction", type="numpy"),
            gr.Image(label="ReVQ Reconstruction", type="numpy"),
        ],
        title="Demo 1: Image Reconstruction",
        description="Upload an image to see how different VQ models (BaseVQ, VQGAN, ReVQ) reconstruct it from latent codes."
    )

    with gr.Blocks() as demo2:
        gr.Markdown("## Demo 2: Codebook Reset Strategy Visualization")
        gr.Markdown("Visualizes codebook and data movement at different training steps with or without the codebook reset strategy.")
        with gr.Row():
            num_data = gr.Slider(label="num_data", value=16, minimum=10, maximum=20, step=1)
            num_code = gr.Slider(label="num_code", value=12, minimum=8, maximum=16, step=1)
        submit_btn = gr.Button("Run Visualization")
        with gr.Column():  # stacked outputs
            out_without_reset = gr.Image(label="Without Reset")
            out_with_reset = gr.Image(label="With Reset")
        submit_btn.click(fn=draw_reset_result, inputs=[num_data, num_code], outputs=[out_without_reset, out_with_reset])
    with gr.Blocks() as demo3:
        gr.Markdown("## Demo 3: Channel Multi-Group Strategy Visualization")
        gr.Markdown("Visualizes codebook and data movement at different training steps with or without the multi-group strategy.")
        with gr.Row():
            num_data = gr.Slider(label="num_data", value=32, minimum=28, maximum=40, step=1)
            num_code = gr.Slider(label="num_code", value=8, minimum=6, maximum=10, step=1)
        submit_btn = gr.Button("Run Visualization")
        with gr.Column():  # stacked outputs
            out_s = gr.Image(label="Single Group")
            out_m = gr.Image(label="Multi Group")
        submit_btn.click(fn=draw_multi_group_result, inputs=[num_data, num_code], outputs=[out_s, out_m])

    demo = gr.TabbedInterface(
        interface_list=[demo1, demo2, demo3],
        tab_names=["Image Reconstruction", "Reset Strategy", "Channel Multi-Group Strategy"]
    )
    demo.launch(share=True)