Upload app.py

app.py CHANGED
@@ -11,7 +11,7 @@ import math
 from utils import compute_ca_loss
 from gradio import processing_utils
 from typing import Optional
-
+import spaces
 import warnings

 import sys
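The added `import spaces` enables Hugging Face ZeroGPU, where the process holds a GPU only while a function decorated with `@spaces.GPU` runs. A minimal sketch of that pattern (the pipeline name below is illustrative, not from this repo):

```python
import spaces  # must be imported before CUDA is initialized
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

@spaces.GPU(duration=180)  # request a GPU for up to 180 s per call
def run(prompt: str):
    pipe.to("cuda")
    return pipe(prompt, num_inference_steps=30).images[0]
```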
@@ -67,96 +67,7 @@ def draw_box(boxes=[], texts=[], img=None):
               fill=(255, 255, 255))
     return img

-'''
-inference model
-'''
-
-def inference(device, unet, vae, tokenizer, text_encoder, prompt, bboxes, object_positions, batch_size, loss_scale, loss_threshold, max_iter, max_index_step, rand_seed, guidance_scale):
-    uncond_input = tokenizer(
-        [""] * 1, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt"
-    )
-    uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]
-
-    input_ids = tokenizer(
-        prompt,
-        padding="max_length",
-        truncation=True,
-        max_length=tokenizer.model_max_length,
-        return_tensors="pt",
-    ).input_ids[0].unsqueeze(0).to(device)
-    # text_embeddings = text_encoder(input_ids)[0]
-    text_embeddings = torch.cat([uncond_embeddings, text_encoder(input_ids)[0]])
-    # text_embeddings[1, 1, :] = text_embeddings[1, 2, :]
-    generator = torch.manual_seed(rand_seed)  # Seed generator to create the inital latent noise
-
-    latents = torch.randn(
-        (batch_size, 4, 64, 64),
-        generator=generator,
-    ).to(device)
-
-    noise_scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
-
-    # generator = torch.Generator("cuda").manual_seed(1024)
-    noise_scheduler.set_timesteps(51)
-
-    latents = latents * noise_scheduler.init_noise_sigma
-
-    loss = torch.tensor(10000)
-
-    for index, t in enumerate(noise_scheduler.timesteps):
-        iteration = 0
-
-        while loss.item() / loss_scale > loss_threshold and iteration < max_iter and index < max_index_step:
-            latents = latents.requires_grad_(True)
-
-            # latent_model_input = torch.cat([latents] * 2)
-            latent_model_input = latents
-
-            latent_model_input = noise_scheduler.scale_model_input(latent_model_input, t)
-            noise_pred, attn_map_integrated_up, attn_map_integrated_mid, attn_map_integrated_down = \
-                unet(latent_model_input, t, encoder_hidden_states=text_encoder(input_ids)[0])
-
-            # update latents with guidence from gaussian blob
-
-            loss = compute_ca_loss(attn_map_integrated_mid, attn_map_integrated_up, bboxes=bboxes,
-                                   object_positions=object_positions) * loss_scale
-
-            print(loss.item() / loss_scale)
-
-            grad_cond = torch.autograd.grad(loss.requires_grad_(True), [latents])[0]
-
-            latents = latents - grad_cond * noise_scheduler.sigmas[index] ** 2
-            iteration += 1
-            torch.cuda.empty_cache()
-        torch.cuda.empty_cache()
-
-
-        with torch.no_grad():
-
-            latent_model_input = torch.cat([latents] * 2)
-
-            latent_model_input = noise_scheduler.scale_model_input(latent_model_input, t)
-            noise_pred, attn_map_integrated_up, attn_map_integrated_mid, attn_map_integrated_down = \
-                unet(latent_model_input, t, encoder_hidden_states=text_embeddings)
-
-            noise_pred = noise_pred.sample
-
-            # perform classifier-free guidance
-            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
-
-            latents = noise_scheduler.step(noise_pred, t, latents).prev_sample
-            torch.cuda.empty_cache()
-    # Decode image
-    with torch.no_grad():
-        # print("decode image")
-        latents = 1 / 0.18215 * latents
-        image = vae.decode(latents).sample
-        image = (image / 2 + 0.5).clamp(0, 1)
-        image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
-        images = (image * 255).round().astype("uint8")
-        pil_images = [Image.fromarray(image) for image in images]
-        return pil_images

 def get_concat(ims):
     if len(ims) == 1:
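The deleted `inference` function (re-added inside `main()` further down) implements layout guidance: at early timesteps it repeatedly backpropagates a cross-attention loss into the latents before taking the ordinary scheduler step. A schematic of the update rule under the same variable names (a sketch, not the repo's code):

```python
import torch

def guided_latent_step(latents, loss, sigma):
    """One guidance step: move latents down the gradient of the layout loss.

    `loss` must be computed from `latents` with requires_grad enabled;
    the step is scaled by sigma**2, as in the diff above.
    """
    grad_cond = torch.autograd.grad(loss, [latents])[0]
    return latents - grad_cond * sigma ** 2
```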
@@ -172,42 +83,6 @@ def get_concat(ims):
     return dst


-def generate(unet, vae, tokenizer, text_encoder, language_instruction, grounding_texts, sketch_pad,
-             loss_threshold, guidance_scale, batch_size, rand_seed, max_step, loss_scale, max_iter,
-             state):
-    if 'boxes' not in state:
-        state['boxes'] = []
-    boxes = state['boxes']
-    grounding_texts = [x.strip() for x in grounding_texts.split(';')]
-    # assert len(boxes) == len(grounding_texts)
-    if len(boxes) != len(grounding_texts):
-        if len(boxes) < len(grounding_texts):
-            raise ValueError("""The number of boxes should be equal to the number of grounding objects.
-Number of boxes drawn: {}, number of grounding tokens: {}.
-Please draw boxes accordingly on the sketch pad.""".format(len(boxes), len(grounding_texts)))
-        grounding_texts = grounding_texts + [""] * (len(boxes) - len(grounding_texts))
-
-    boxes = (np.asarray(boxes) / 512).tolist()
-    boxes = [[box] for box in boxes]
-    grounding_instruction = json.dumps({obj: box for obj, box in zip(grounding_texts, boxes)})
-    language_instruction_list = language_instruction.strip('.').split(' ')
-    object_positions = []
-    for obj in grounding_texts:
-        obj_position = []
-        for word in obj.split(' '):
-            obj_first_index = language_instruction_list.index(word) + 1
-            obj_position.append(obj_first_index)
-        object_positions.append(obj_position)
-    device = 'cuda' if torch.cuda.is_available() else 'cpu'
-
-    gen_images = inference(device, unet, vae, tokenizer, text_encoder, language_instruction, boxes, object_positions, batch_size, loss_scale, loss_threshold, max_iter, max_step, rand_seed, guidance_scale)
-
-    blank_samples = batch_size % 2 if batch_size > 1 else 0
-    gen_images = [gr.Image.update(value=x, visible=True) for i, x in enumerate(gen_images)] \
-                 + [gr.Image.update(value=None, visible=True) for _ in range(blank_samples)] \
-                 + [gr.Image.update(value=None, visible=False) for _ in range(4 - batch_size - blank_samples)]
-
-    return gen_images + [state]


 def binarize(x):
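`generate` (also moved into `main()` below) converts each grounding phrase into the 1-based token positions of its words in the prompt; the `+ 1` accounts for the tokenizer's start-of-sequence token, and `list.index` raises ValueError if a grounding word is absent from the prompt. A worked example of that mapping (a sketch):

```python
prompt = "a cat and a dog"
words = prompt.strip('.').split(' ')   # ['a', 'cat', 'and', 'a', 'dog']
phrases = ["cat", "dog"]
object_positions = [[words.index(w) + 1 for w in p.split(' ')] for p in phrases]
print(object_positions)                # [[2], [5]]
```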
@@ -251,8 +126,9 @@ def center_crop(img, HW=None, tgt_size=(512, 512)):

 def draw(input, grounding_texts, new_image_trigger, state):
     if type(input) == dict:
-        image = input['image']
-        mask = input['mask']
+        # import pdb; pdb.set_trace()
+        # image = input['composite']
+        mask = input['composite']
     else:
         mask = input
     if mask.ndim == 3:
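The change in `draw` tracks Gradio 4's sketch components: `gr.Paint`/`gr.ImageEditor` hand callbacks a dict whose flattened drawing sits under the `'composite'` key (alongside `'background'` and `'layers'`), replacing the Gradio 3 `'image'`/`'mask'` pair. A defensive sketch:

```python
import numpy as np

def extract_mask(payload):
    # Gradio 4 editor payloads are dicts; plain arrays pass through unchanged.
    if isinstance(payload, dict):
        return payload['composite']
    return payload

print(extract_mask({'composite': np.zeros((4, 4))}).shape)  # (4, 4)
```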
@@ -307,7 +183,7 @@ def clear(task, sketch_pad_trigger, batch_size, state, switch_task=False):
     if task != 'Grounded Inpainting':
         sketch_pad_trigger = sketch_pad_trigger + 1
     blank_samples = batch_size % 2 if batch_size > 1 else 0
-    out_images = [gr.Image.update(value=None, visible=True) for i in range(batch_size)]
+    out_images = [gr.Image.change(value=None, visible=True) for i in range(batch_size)]
     # state = {}
     return [None, sketch_pad_trigger, None, 1.0] + out_images + [{}]
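Note that `gr.Image.change(...)` registers an event listener rather than building a component update, and `gr.Image.update(...)` no longer exists in Gradio 4; the idiomatic migration returns component constructors or `gr.update(...)` from the callback. A sketch of how `clear` would typically look against Gradio 4 (an assumption about intent, not what this commit ships):

```python
import gradio as gr

def clear_outputs(batch_size):
    # Return one update per output slot; unused slots get visible=False.
    shown = [gr.Image(value=None, visible=True) for _ in range(batch_size)]
    hidden = [gr.Image(value=None, visible=False) for _ in range(4 - batch_size)]
    return shown + hidden
```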
@@ -387,6 +263,139 @@ def main():
     text_encoder.to(device)
     vae.to(device)

+    def generate(unet, vae, tokenizer, text_encoder, language_instruction, grounding_texts, sketch_pad,
+                 loss_threshold, guidance_scale, batch_size, rand_seed, max_step, loss_scale, max_iter,
+                 state):
+        if 'boxes' not in state:
+            state['boxes'] = []
+        boxes = state['boxes']
+        grounding_texts = [x.strip() for x in grounding_texts.split(';')]
+        # assert len(boxes) == len(grounding_texts)
+        if len(boxes) != len(grounding_texts):
+            if len(boxes) < len(grounding_texts):
+                raise ValueError("""The number of boxes should be equal to the number of grounding objects.
+Number of boxes drawn: {}, number of grounding tokens: {}.
+Please draw boxes accordingly on the sketch pad.""".format(len(boxes), len(grounding_texts)))
+            grounding_texts = grounding_texts + [""] * (len(boxes) - len(grounding_texts))
+
+        boxes = (np.asarray(boxes) / 512).tolist()
+        boxes = [[box] for box in boxes]
+        grounding_instruction = json.dumps({obj: box for obj, box in zip(grounding_texts, boxes)})
+        language_instruction_list = language_instruction.strip('.').split(' ')
+        object_positions = []
+        for obj in grounding_texts:
+            obj_position = []
+            for word in obj.split(' '):
+                obj_first_index = language_instruction_list.index(word) + 1
+                obj_position.append(obj_first_index)
+            object_positions.append(obj_position)
+        device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+        gen_images = inference(device, unet, vae, tokenizer, text_encoder, language_instruction, boxes,
+                               object_positions, batch_size, loss_scale, loss_threshold, max_iter, max_step, rand_seed,
+                               guidance_scale)
+
+        blank_samples = batch_size % 2 if batch_size > 1 else 0
+        gen_images = [gr.Image.update(value=x, visible=True) for i, x in enumerate(gen_images)] \
+                     + [gr.Image.change(fn=None, show_api=True) for _ in range(blank_samples)] \
+                     + [gr.Image.change(fn=None, show_api=False) for _ in range(4 - batch_size - blank_samples)]
+
+        return gen_images + [state]
+
+    '''
+    inference model
+    '''
+
+    @spaces.GPU(duration=180)
+    def inference(device, unet, vae, tokenizer, text_encoder, prompt, bboxes, object_positions, batch_size, loss_scale,
+                  loss_threshold, max_iter, max_index_step, rand_seed, guidance_scale):
+        uncond_input = tokenizer(
+            [""] * 1, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt"
+        )
+        uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]
+
+        input_ids = tokenizer(
+            prompt,
+            padding="max_length",
+            truncation=True,
+            max_length=tokenizer.model_max_length,
+            return_tensors="pt",
+        ).input_ids[0].unsqueeze(0).to(device)
+        # text_embeddings = text_encoder(input_ids)[0]
+        text_embeddings = torch.cat([uncond_embeddings, text_encoder(input_ids)[0]])
+        # text_embeddings[1, 1, :] = text_embeddings[1, 2, :]
+        generator = torch.manual_seed(rand_seed)  # Seed generator to create the initial latent noise
+
+        latents = torch.randn(
+            (batch_size, 4, 64, 64),
+            generator=generator,
+        ).to(device)
+
+        noise_scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
+                                               num_train_timesteps=1000)
+
+        # generator = torch.Generator("cuda").manual_seed(1024)
+        noise_scheduler.set_timesteps(51)
+
+        latents = latents * noise_scheduler.init_noise_sigma
+
+        loss = torch.tensor(10000)
+
+        for index, t in enumerate(noise_scheduler.timesteps):
+            iteration = 0
+
+            while loss.item() / loss_scale > loss_threshold and iteration < max_iter and index < max_index_step:
+                latents = latents.requires_grad_(True)
+
+                # latent_model_input = torch.cat([latents] * 2)
+                latent_model_input = latents
+
+                latent_model_input = noise_scheduler.scale_model_input(latent_model_input, t)
+                noise_pred, attn_map_integrated_up, attn_map_integrated_mid, attn_map_integrated_down = \
+                    unet(latent_model_input, t, encoder_hidden_states=text_encoder(input_ids)[0])
+
+                # update latents with guidance from gaussian blob
+
+                loss = compute_ca_loss(attn_map_integrated_mid, attn_map_integrated_up, bboxes=bboxes,
+                                       object_positions=object_positions) * loss_scale
+
+                print(loss.item() / loss_scale)
+
+                grad_cond = torch.autograd.grad(loss.requires_grad_(True), [latents])[0]
+
+                latents = latents - grad_cond * noise_scheduler.sigmas[index] ** 2
+                iteration += 1
+                torch.cuda.empty_cache()
+            torch.cuda.empty_cache()
+
+            with torch.no_grad():
+                latent_model_input = torch.cat([latents] * 2)
+
+                latent_model_input = noise_scheduler.scale_model_input(latent_model_input, t)
+                noise_pred, attn_map_integrated_up, attn_map_integrated_mid, attn_map_integrated_down = \
+                    unet(latent_model_input, t, encoder_hidden_states=text_embeddings)
+
+                noise_pred = noise_pred.sample
+
+                # perform classifier-free guidance
+                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                latents = noise_scheduler.step(noise_pred, t, latents).prev_sample
+                torch.cuda.empty_cache()
+        # Decode image
+        with torch.no_grad():
+            # print("decode image")
+            latents = 1 / 0.18215 * latents
+            image = vae.decode(latents).sample
+            image = (image / 2 + 0.5).clamp(0, 1)
+            image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
+            images = (image * 255).round().astype("uint8")
+            pil_images = [Image.fromarray(image) for image in images]
+            return pil_images
+
+
     with Blocks(
             css=css,
             analytics_enabled=False,
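Two standard Stable Diffusion pieces inside `inference` are worth spelling out: the duplicated latent batch lets one UNet call produce both the unconditional and the text-conditioned noise predictions, which are recombined with the guidance scale, and the 1/0.18215 factor undoes the SD v1 latent scaling before VAE decoding. A condensed sketch:

```python
import torch

def cfg_combine(noise_pred, guidance_scale):
    # The batch is ordered [unconditional; conditional], matching
    # torch.cat([uncond_embeddings, ...]) in the diff.
    uncond, text = noise_pred.chunk(2)
    return uncond + guidance_scale * (text - uncond)

def decode_latents(vae, latents):
    image = vae.decode(latents / 0.18215).sample  # undo SD v1 latent scaling
    return (image / 2 + 0.5).clamp(0, 1)          # map [-1, 1] -> [0, 1]
```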
@@ -418,7 +427,7 @@ def main():


         with gr.Row():
-            sketch_pad = gr.Paint(label="Sketch Pad", elem_id="img2img_image")
+            sketch_pad = gr.Paint(label="Sketch Pad", container=False, layers=False, scale=1, elem_id="img2img_image")
             out_imagebox = gr.Image(type="pil", label="Parsed Sketch Pad")
             out_gen_1 = gr.Image(type="pil", visible=True, label="Generated Image")

@@ -479,7 +488,7 @@ def main():
        inputs=sketch_pad_trigger,
        outputs=sketch_pad_trigger,
        queue=False)
-    sketch_pad.edit(
+    sketch_pad.change(
        draw,
        inputs=[sketch_pad, grounding_instruction, sketch_pad_resize_trigger, state],
        outputs=[out_imagebox, sketch_pad_resize_trigger, image_scale, state],
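The UI changes follow Gradio 4: `gr.Paint` is the ImageEditor preset for freehand drawing, and event wiring hangs off listener methods such as `.change`. A self-contained sketch of the wiring shape used here (the handler and components are placeholders, not this app's code):

```python
import gradio as gr

def on_draw(payload, state):
    mask = payload['composite'] if isinstance(payload, dict) else payload
    return mask, state

with gr.Blocks() as demo:
    state = gr.State({})
    with gr.Row():
        sketch_pad = gr.Paint(label="Sketch Pad", container=False, layers=False, scale=1)
        out_imagebox = gr.Image(label="Parsed Sketch Pad")
    # .change fires whenever the component's value changes
    sketch_pad.change(on_draw, inputs=[sketch_pad, state],
                      outputs=[out_imagebox, state], queue=False)
```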
@@ -519,13 +528,13 @@ def main():
        None,
        None,
        sketch_pad_resize_trigger,
-        _js=rescale_js,
+        js=rescale_js,
        queue=False)
    init_white_trigger.change(
        None,
        None,
        init_white_trigger,
-        _js=rescale_js,
+        js=rescale_js,
        queue=False)

    with gr.Column():
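The `js=rescale_js` lines reflect Gradio 4 renaming the `_js` event argument to `js`; passing `fn=None` together with `js` runs the snippet purely client-side. A runnable sketch (the JavaScript body is illustrative):

```python
import gradio as gr

with gr.Blocks() as demo:
    btn = gr.Button("Log size")
    out = gr.Textbox(label="Viewport width")
    # fn=None with js=... runs entirely in the browser (Gradio 3.x spelled this `_js=`)
    btn.click(None, None, out, js="() => String(window.innerWidth)", queue=False)

demo.launch()
```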
@@ -546,7 +555,7 @@ def main():
    description = """<p> The source code of this demo is modified from the <a href="https://huggingface.co/spaces/gligen/demo/tree/main">GLIGEN</a> demo. Thanks! </p>"""
    gr.HTML(description)

-    demo.queue(concurrency_count=1, api_open=False)
+    demo.queue(api_open=False)
    demo.launch(share=False, show_api=False, show_error=True)

 if __name__ == '__main__':
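`demo.queue()` also changed across versions: Gradio 4 dropped `concurrency_count`, so the migrated call keeps only `api_open=False`; app-wide concurrency is now set with `default_concurrency_limit`. A sketch of the migrated queue/launch shape:

```python
import gradio as gr

with gr.Blocks() as demo:
    gr.Markdown("queue/launch shape used by this app")

# Gradio 4: concurrency_count is gone; default_concurrency_limit replaces it
demo.queue(api_open=False, default_concurrency_limit=1)
demo.launch(share=False, show_api=False, show_error=True)
```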