AndyRaoTHU commited on
Commit
6e14098
·
1 Parent(s): f0dcf96
Files changed (1) hide show
  1. app.py +10 -1
app.py CHANGED
@@ -140,6 +140,10 @@ class Handler:
140
  self.vqgan.to(self.device)
141
  self.vqgan.eval()
142
 
 
 
 
 
143
  self.vae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers")
144
  self.vae.to(self.device)
145
  self.vae.eval()
@@ -176,14 +180,19 @@ class Handler:
176
  revq_rec = revq_rec.contiguous()
177
  revq_rec = self.preprocesser.inverse(revq_rec)
178
  revq_rec = self.vae.decode(revq_rec).sample
 
 
 
179
 
180
  # tensor to PIL image
181
  img = self.tensor_to_image(img)
182
  basevq_rec = self.tensor_to_image(basevq_rec)
183
  vqgan_rec = self.tensor_to_image(vqgan_rec)
184
  revq_rec = self.tensor_to_image(revq_rec)
 
185
  print("Shapes:", img.shape, basevq_rec.shape, vqgan_rec.shape, revq_rec.shape)
186
- return img, basevq_rec, vqgan_rec, revq_rec
 
187
 
188
  if __name__ == "__main__":
189
  # create the model handler
 
140
  self.vqgan.to(self.device)
141
  self.vqgan.eval()
142
 
143
+ self.optvq = VQModelHF.from_pretrained("BorelTHU/optvq-16x16x4")
144
+ self.optvq.to(self.device)
145
+ self.optvq.eval()
146
+
147
  self.vae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers")
148
  self.vae.to(self.device)
149
  self.vae.eval()
 
180
  revq_rec = revq_rec.contiguous()
181
  revq_rec = self.preprocesser.inverse(revq_rec)
182
  revq_rec = self.vae.decode(revq_rec).sample
183
+ # optvq
184
+ quant, *_ = self.optvq.encode(img)
185
+ optvq_rec = self.optvq.decode(quant)
186
 
187
  # tensor to PIL image
188
  img = self.tensor_to_image(img)
189
  basevq_rec = self.tensor_to_image(basevq_rec)
190
  vqgan_rec = self.tensor_to_image(vqgan_rec)
191
  revq_rec = self.tensor_to_image(revq_rec)
192
+ optvq_rec = self.tensor_to_image(optvq_rec)
193
  print("Shapes:", img.shape, basevq_rec.shape, vqgan_rec.shape, revq_rec.shape)
194
+ # return img, basevq_rec, vqgan_rec, revq_rec
195
+ return img, basevq_rec, vqgan_rec, optvq_rec
196
 
197
  if __name__ == "__main__":
198
  # create the model handler