AndyRaoTHU commited on
Commit
84b5521
·
1 Parent(s): 209cdb2
revq/models/backbone/dcae.py CHANGED
@@ -240,6 +240,7 @@ class DCAEEncoder(nn.Module):
240
  class DCAEDecoder(nn.Module):
241
  def __init__(
242
  self,
 
243
  in_channels: int,
244
  latent_channels: int,
245
  attention_head_dim: int = 32,
 
240
  class DCAEDecoder(nn.Module):
241
  def __init__(
242
  self,
243
+ TYPE: str,
244
  in_channels: int,
245
  latent_channels: int,
246
  attention_head_dim: int = 32,
revq/models/revq_quantizer.py CHANGED
@@ -15,7 +15,8 @@ def find_nearest(x, y):
15
  return min_dist, indices
16
 
17
  class Quantizer(nn.Module):
18
- def __init__(self, code_dim: int = 128, num_code: int = 1024,
 
19
  num_group: int = None, tokens_per_data: int = 4,
20
  auto_reset: bool = True, reset_ratio: float = 0.2,
21
  reset_noise: float = 0.1
 
15
  return min_dist, indices
16
 
17
  class Quantizer(nn.Module):
18
+ def __init__(self, TYPE: str = "vq",
19
+ code_dim: int = 4, num_code: int = 1024,
20
  num_group: int = None, tokens_per_data: int = 4,
21
  auto_reset: bool = True, reset_ratio: float = 0.2,
22
  reset_noise: float = 0.1