File size: 3,065 Bytes
199c8cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20077d0
 
b5f2ae4
 
 
 
e7f69aa
20077d0
 
 
 
 
 
 
 
209cdb2
20077d0
 
199c8cd
acc4c00
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
199c8cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20077d0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
# ------------------------------------------------------------------------------
# OptVQ: Preventing Local Pitfalls in Vector Quantization via Optimal Transport
# Copyright (c) 2024 Borui Zhang. All Rights Reserved.
# Licensed under the MIT License [see LICENSE for details]
# ------------------------------------------------------------------------------

# Convert a Pytorch model to a Hugging Face model

import torch.nn as nn
import torch
import os

from huggingface_hub import PyTorchModelHubMixin
from huggingface_hub import hf_hub_download
from omegaconf import OmegaConf
from revq.models.backbone.dcae import DCAEDecoder
from revq.models.revq_quantizer import Quantizer

class Viewer:
    def __init__(self, input_size, code_dim):
        self.input_size = input_size
        self.code_dim = code_dim

    def shuffle(self, x: torch.Tensor):
        # (B, C, H, W) -> (B, T, D)
        B, C, H, W = x.size()
        x = x.view(B, -1, self.code_dim)
        return x

    def unshuffle(self, x: torch.Tensor):
        # (B, T, D) -> (B, C, H, W)
        B, T, D = x.size()
        C, H, W = self.input_size
        x = x.view(B, C, H, W)
        return x

class ReVQ(PyTorchModelHubMixin, nn.Module):
    @classmethod
    def _from_pretrained(cls, model_id: str, **kwargs):
        print(f"Loading ReVQ model from {model_id}...")
        config_path = hf_hub_download(repo_id=model_id, filename="512T_NC=16384.yaml")
        ckpt_path = hf_hub_download(repo_id=model_id, filename="ckpt.pth")

        full_cfg = OmegaConf.load(config_path)
        model_cfg = full_cfg.get("model", {})
        decoder_config = model_cfg.get("decoder", {})
        quantize_config = model_cfg.get("quantizer", {})

        model = cls(decoder=decoder_config, 
                    quantize=quantize_config, 
                    ckpt_path=ckpt_path,
                    )

        return model
    
    def __init__(self, 
            decoder: dict = {},
            quantize: dict = {},
            ckpt_path: str = None,
        ):
        super(ReVQ, self).__init__()
        self.decoder = DCAEDecoder(**decoder)
        self.quantizer = self.setup_quantizer(quantize)
        self.viewer = Viewer(input_size=[32, 8, 8], code_dim=4)

        if ckpt_path is not None and os.path.exists(ckpt_path):
            # Load the model checkpoint
            checkpoint = torch.load(ckpt_path, map_location="cpu")
            self.decoder.load_state_dict(checkpoint["decoder"])
            self.quantizer.load_state_dict(checkpoint["quantizer"])
    
    def setup_quantizer(self, quantizer_config):
        quantizer = Quantizer(**quantizer_config)
        return quantizer
    
    def quantize(self, x):
        x_shuffle = self.viewer.shuffle(x)
        quant_shuffle = self.quantizer(x_shuffle)["x_quant"]
        quant = self.viewer.unshuffle(quant_shuffle)
        return quant

    def decode(self, x):
        x_rec = self.decoder(x)
        return x_rec

    def forward(self, x):
        quant = self.quantize(x)
        rec = self.decode(quant)
        return quant, rec