Spaces:
Runtime error
Runtime error
Zongsheng
commited on
Commit
·
4cd2c6a
1
Parent(s):
f857ecf
add resize for arbitraty size
Browse files- sampler.py +5 -4
sampler.py
CHANGED
|
@@ -166,6 +166,11 @@ class DifIRSampler(BaseSampler):
|
|
| 166 |
# basical image restoration
|
| 167 |
device = next(self.model.parameters()).device
|
| 168 |
y0 = y0.to(device=device, dtype=torch.float32)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 169 |
if need_restoration:
|
| 170 |
with torch.no_grad():
|
| 171 |
if model_kwargs_ir is None:
|
|
@@ -176,10 +181,6 @@ class DifIRSampler(BaseSampler):
|
|
| 176 |
im_hq = y0
|
| 177 |
im_hq.clamp_(0.0, 1.0)
|
| 178 |
|
| 179 |
-
h_old, w_old = im_hq.shape[2:4]
|
| 180 |
-
if not (h_old == self.configs.im_size and w_old == self.configs.im_size):
|
| 181 |
-
im_hq = resize(im_hq, out_shape=(self.configs.im_size,) * 2).to(torch.float32)
|
| 182 |
-
|
| 183 |
# diffuse for im_hq
|
| 184 |
yt = self.diffusion.q_sample(
|
| 185 |
x_start=post_fun(im_hq),
|
|
|
|
| 166 |
# basical image restoration
|
| 167 |
device = next(self.model.parameters()).device
|
| 168 |
y0 = y0.to(device=device, dtype=torch.float32)
|
| 169 |
+
|
| 170 |
+
h_old, w_old = y0.shape[2:4]
|
| 171 |
+
if not (h_old == self.configs.im_size and w_old == self.configs.im_size):
|
| 172 |
+
y0 = resize(y0, out_shape=(self.configs.im_size,) * 2).to(torch.float32)
|
| 173 |
+
|
| 174 |
if need_restoration:
|
| 175 |
with torch.no_grad():
|
| 176 |
if model_kwargs_ir is None:
|
|
|
|
| 181 |
im_hq = y0
|
| 182 |
im_hq.clamp_(0.0, 1.0)
|
| 183 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 184 |
# diffuse for im_hq
|
| 185 |
yt = self.diffusion.q_sample(
|
| 186 |
x_start=post_fun(im_hq),
|