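"""Vector quantization modules: a single codebook lookup (VectorQuantize) and a
residual stack of such quantizers (ResidualVQ). The registered EMA buffers
(`inited`, `cluster_size`, `embed_avg`) are placeholders that are unused at
inference time."""
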
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.nn.utils.parametrizations import weight_norm


def WNConv1d(*args, **kwargs):
    return weight_norm(nn.Conv1d(*args, **kwargs))


def WNConvTranspose1d(*args, **kwargs):
    return weight_norm(nn.ConvTranspose1d(*args, **kwargs))


class VectorQuantize(nn.Module):
    def __init__(
        self,
        input_dim: int,
        codebook_size: int,
        codebook_dim: int,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim

        self.in_project = (
            WNConv1d(
                self.input_dim, self.codebook_dim, kernel_size=1
            )  # (B, D, T) -> (B, D', T)
            if self.input_dim != self.codebook_dim
            else nn.Identity()
        )
        self.out_project = (
            WNConv1d(
                self.codebook_dim, self.input_dim, kernel_size=1
            )  # (B, D', T) -> (B, D, T)
            if self.input_dim != self.codebook_dim
            else nn.Identity()
        )

        # Initialize codebook and EMA buffers
        self.register_buffer(
            "codebook", torch.zeros(codebook_size, codebook_dim).float()
        )  # (codebook_size, D'), ensure fp32
        # Placeholder; not used at inference time
        self.register_buffer("inited", torch.tensor([True], dtype=torch.bool))  # (1,)
        self.register_buffer(
            "cluster_size", torch.zeros(codebook_size).float()
        )  # (codebook_size), ensure fp32
        self.register_buffer(
            "embed_avg", self.codebook.clone().float()
        )  # (codebook_size, D'), ensure fp32

    def decode_code(self, embed_id: torch.Tensor):  # embed_id: (B, T)
        embed = (
            F.embedding(embed_id, self.codebook).transpose(1, 2).float()
        )  # (B, D', T), ensure fp32
        return embed

    def encode_code(self, z: torch.Tensor):  # z: (B, D, T)
        z = z.float()  # Ensure fp32
        z_e = self.in_project(z).float()  # (B, D', T), ensure fp32

        # Rearrange for quantization
        encodings = rearrange(z_e, "b d t -> (b t) d").float()  # (B*T, D'), ensure fp32

        # Nearest-neighbor search: squared L2 distances via the expansion
        # ||e - c||^2 = ||e||^2 - 2 e.c + ||c||^2
        dist = (
            encodings.pow(2).sum(1, keepdim=True)  # ||e||^2: (B*T, 1)
            - 2 * encodings @ self.codebook.float().t()  # -2 e.c: (B*T, codebook_size)
            + self.codebook.float().pow(2).sum(1, keepdim=True).t()  # ||c||^2: (1, codebook_size)
        )  # dist: (B*T, codebook_size)

        indices = dist.argmin(1)  # (B*T,), index of the closest codebook entry
        indices = rearrange(indices, "(b t) -> b t", b=z.size(0))  # (B, T)

        # Get quantized vectors
        z_q = self.decode_code(indices).float()  # (B, D', T), ensure fp32

        # Straight-through estimator
        z_q = z_e + (z_q - z_e).detach()  # (B, D', T)
        z_q = self.out_project(z_q).float()  # (B, D, T), ensure fp32

        # z_q: (B, D, T), indices: (B, T)
        return z_q, indices


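# Standalone usage sketch for VectorQuantize (illustrative; added here, not
# part of the original file). With input_dim == codebook_dim both projections
# collapse to nn.Identity, so encode_code reduces to a nearest-codeword lookup
# plus the straight-through estimator above (z_q carries z_e's gradients):
#
#     vq = VectorQuantize(input_dim=256, codebook_size=1024, codebook_dim=256)
#     z = torch.randn(4, 256, 100)      # (B, D, T)
#     z_q, indices = vq.encode_code(z)  # z_q: (B, D, T), indices: (B, T)
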
class ResidualVQ(nn.Module):
    def __init__(
        self,
        input_dim: int = 768,  # Input dimension (independent of the RVQ itself)
        rvq_dim: Optional[int] = None,  # RVQ dimension; if it differs from input_dim/output_dim, input_dim->rvq_dim and rvq_dim->output_dim projections are added
        output_dim: Optional[int] = None,  # Output dimension (independent of the RVQ itself)
        num_quantizers: int = 8,
        codebook_size: int = 1024,
        codebook_dim: int = 256,  # Dimension of each codebook; if it differs from rvq_dim, rvq_dim->codebook_dim and codebook_dim->rvq_dim projections are added
    ):
        super().__init__()
        self.input_dim = input_dim

        self.num_quantizers = num_quantizers
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        self.rvq_dim = rvq_dim

        self.input_proj = (
            WNConv1d(input_dim, rvq_dim, kernel_size=1)
            if input_dim != rvq_dim
            else nn.Identity()
        )
        self.output_proj = (
            WNConv1d(rvq_dim, output_dim, kernel_size=1)
            if rvq_dim != output_dim
            else nn.Identity()
        )

        self.quantizers = nn.ModuleList(
            [
                VectorQuantize(
                    input_dim=rvq_dim,
                    codebook_size=self.codebook_size,
                    codebook_dim=codebook_dim,
                )
                for _ in range(num_quantizers)
            ]
        )

    def encode_codes(self, z: torch.Tensor):
        z = self.input_proj(z)
        residual = z.clone().float()  # (B, D, T), ensure fp32
        all_indices = []
        # Quantize stage by stage, subtracting each stage's reconstruction
        # from the running residual
        for quantizer in self.quantizers:
            z_q_i, indices_i = quantizer.encode_code(residual)  # (B, D, T), (B, T)
            residual = residual - z_q_i
            all_indices.append(indices_i)  # (B, T)
        all_indices = torch.stack(all_indices)  # (nq, B, T)
        return all_indices

    def decode_codes(self, codes):  # codes: (nq, B, T)
        """Decode codes from multiple quantizers to embeddings.

        Args:
            codes: Tensor of shape (nq, B, T) containing code indices for each quantizer.

        Returns:
            emb: Tensor of shape (B, D, T) representing the decoded embeddings.
        """
        nq, B, T = codes.shape
        device = codes.device
        emb = torch.zeros(
            B, self.rvq_dim, T, device=device, dtype=torch.float32
        )  # (B, D, T)
        for i, quantizer in enumerate(self.quantizers[:nq]):
            code_i = codes[i]  # (B, T)
            quantized_i = quantizer.decode_code(code_i)  # (B, D', T)
            emb += quantizer.out_project(quantized_i)  # Accumulate quantized embeddings
        emb = self.output_proj(emb)  # (B, D, T), apply output projection
        return emb  # (B, D, T)
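

# ---------------------------------------------------------------------------
# Smoke test (added for illustration; not part of the original module). It
# only checks tensor shapes: the codebooks here are zero-initialized, so the
# decoded embeddings are not meaningful until real codebooks are loaded.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)

    rvq = ResidualVQ(
        input_dim=768,
        rvq_dim=256,
        output_dim=768,
        num_quantizers=8,
        codebook_size=1024,
        codebook_dim=256,
    )

    x = torch.randn(2, 768, 50)  # (B, D, T): batch of 2, 50 frames
    codes = rvq.encode_codes(x)  # (nq, B, T) = (8, 2, 50), long indices
    emb = rvq.decode_codes(codes)  # (B, D, T) = (2, 768, 50)
    print(codes.shape, codes.dtype, emb.shape)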