import gradio as gr
import torch
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, LMSDiscreteScheduler
from my_model import unet_2d_condition
import json
import numpy as np
from PIL import Image, ImageDraw, ImageFont
from functools import partial
import math
from utils import compute_ca_loss
from gradio import processing_utils
from typing import Optional
import spaces
import warnings
import sys

sys.tracebacklimit = 0
class Blocks(gr.Blocks):
    def __init__(
            self,
            theme: str = "default",
            analytics_enabled: Optional[bool] = None,
            mode: str = "blocks",
            title: str = "Gradio",
            css: Optional[str] = None,
            **kwargs,
    ):
        self.extra_configs = {
            'thumbnail': kwargs.pop('thumbnail', ''),
            'url': kwargs.pop('url', 'https://gradio.app/'),
            'creator': kwargs.pop('creator', '@teamGradio'),
        }

        super(Blocks, self).__init__(theme, analytics_enabled, mode, title, css, **kwargs)
        warnings.filterwarnings("ignore")

    def get_config_file(self):
        config = super(Blocks, self).get_config_file()
        for k, v in self.extra_configs.items():
            config[k] = v
        return config
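# Render the user-drawn boxes and their grounding labels onto a 512x512 preview image.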
def draw_box(boxes=[], texts=[], img=None):
    if len(boxes) == 0 and img is None:
        return None

    if img is None:
        img = Image.new('RGB', (512, 512), (255, 255, 255))
    colors = ["red", "olive", "blue", "green", "orange", "brown", "cyan", "purple"]
    draw = ImageDraw.Draw(img)
    # Assumes PIL can locate DejaVuSansMono.ttf (bundled with the Space or installed system-wide).
    font = ImageFont.truetype("DejaVuSansMono.ttf", size=18)
    print(boxes)
    for bid, box in enumerate(boxes):
        draw.rectangle([box[0], box[1], box[2], box[3]], outline=colors[bid % len(colors)], width=4)
        anno_text = texts[bid]
        # Filled label bar along the bottom edge of the box, then the label text in white.
        draw.rectangle(
            [box[0], box[3] - int(font.size * 1.2), box[0] + int((len(anno_text) + 0.8) * font.size * 0.6), box[3]],
            outline=colors[bid % len(colors)], fill=colors[bid % len(colors)], width=4)
        draw.text([box[0] + int(font.size * 0.2), box[3] - int(font.size * 1.2)], anno_text, font=font,
                  fill=(255, 255, 255))
    return img
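# Tile a list of images into a grid with at most two columns (helper; not called in this demo).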
def get_concat(ims):
    if len(ims) == 1:
        n_col = 1
    else:
        n_col = 2
    n_row = math.ceil(len(ims) / 2)
    dst = Image.new('RGB', (ims[0].width * n_col, ims[0].height * n_row), color="white")
    for i, im in enumerate(ims):
        row_id = i // n_col
        col_id = i % n_col
        dst.paste(im, (im.width * col_id, im.height * row_id))
    return dst


def binarize(x):
    return (x != 0).astype('uint8') * 255


def sized_center_crop(img, cropx, cropy):
    y, x = img.shape[:2]
    startx = x // 2 - (cropx // 2)
    starty = y // 2 - (cropy // 2)
    return img[starty:starty + cropy, startx:startx + cropx]


def sized_center_fill(img, fill, cropx, cropy):
    y, x = img.shape[:2]
    startx = x // 2 - (cropx // 2)
    starty = y // 2 - (cropy // 2)
    img[starty:starty + cropy, startx:startx + cropx] = fill
    return img


def sized_center_mask(img, cropx, cropy):
    y, x = img.shape[:2]
    startx = x // 2 - (cropx // 2)
    starty = y // 2 - (cropy // 2)
    center_region = img[starty:starty + cropy, startx:startx + cropx].copy()
    img = (img * 0.2).astype('uint8')
    img[starty:starty + cropy, startx:startx + cropx] = center_region
    return img


def center_crop(img, HW=None, tgt_size=(512, 512)):
    if HW is None:
        H, W = img.shape[:2]
        HW = min(H, W)
    img = sized_center_crop(img, HW, HW)
    img = Image.fromarray(img)
    img = img.resize(tgt_size)
    return np.array(img)
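# Sketch-pad callback: diff the current canvas mask against the previous one to detect a newly
# drawn box, record it in the session state, and render all boxes with their grounding labels.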
def draw(input, grounding_texts, new_image_trigger, state):
    if type(input) == dict:
        # gr.Paint / ImageEditor passes a dict; the flattened drawing lives in 'composite'.
        mask = input['composite']
    else:
        mask = input

    if mask.ndim == 3:
        mask = 255 - mask[..., 0]

    image_scale = 1.0

    mask = binarize(mask)
    if type(mask) != np.ndarray:
        mask = np.array(mask)

    # No background image is kept for the parsed view; draw_box() creates a blank canvas.
    image = None
    if mask.sum() == 0:
        state = {}

    if 'boxes' not in state:
        state['boxes'] = []

    if 'masks' not in state or len(state['masks']) == 0:
        state['masks'] = []
        last_mask = np.zeros_like(mask)
    else:
        last_mask = state['masks'][-1]

    if type(mask) == np.ndarray and mask.size > 1:
        diff_mask = mask - last_mask
    else:
        diff_mask = np.zeros([])

    if diff_mask.sum() > 0:
        x1x2 = np.where(diff_mask.max(0) != 0)[0]
        y1y2 = np.where(diff_mask.max(1) != 0)[0]
        y1, y2 = y1y2.min(), y1y2.max()
        x1, x2 = x1x2.min(), x1x2.max()

        if (x2 - x1 > 5) and (y2 - y1 > 5):
            state['masks'].append(mask.copy())
            state['boxes'].append((x1, y1, x2, y2))

    grounding_texts = [x.strip() for x in grounding_texts.split(';')]
    grounding_texts = [x for x in grounding_texts if len(x) > 0]
    if len(grounding_texts) < len(state['boxes']):
        grounding_texts += [f'Obj. {bid + 1}' for bid in range(len(grounding_texts), len(state['boxes']))]

    box_image = draw_box(state['boxes'], grounding_texts, image)

    return [box_image, new_image_trigger, image_scale, state]
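# Reset the sketch pad, parsed-box view, generated image and session state.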
def clear(sketch_pad_trigger, batch_size, state, switch_task=False):
    sketch_pad_trigger = sketch_pad_trigger + 1
    blank_samples = batch_size % 2 if batch_size > 1 else 0
    out_images = [None]
    # state = {}
    return [None, sketch_pad_trigger, None, 1.0] + out_images + [{}]
def main():
    css = """
    #img2img_image, #img2img_image > .fixed-height, #img2img_image > .fixed-height > div, #img2img_image > .fixed-height > div > img
    {
        height: var(--height) !important;
        max-height: var(--height) !important;
        min-height: var(--height) !important;
    }
    #paper-info a {
        color: #008AD7;
        text-decoration: none;
    }
    #paper-info a:hover {
        cursor: pointer;
        text-decoration: none;
    }
    .tooltip {
        color: #555;
        position: relative;
        display: inline-block;
        cursor: pointer;
    }
    .tooltip .tooltiptext {
        visibility: hidden;
        width: 400px;
        background-color: #555;
        color: #fff;
        text-align: center;
        padding: 5px;
        border-radius: 5px;
        position: absolute;
        z-index: 1; /* Set z-index to 1 */
        left: 10px;
        top: 100%;
        opacity: 0;
        transition: opacity 0.3s;
    }
    .tooltip:hover .tooltiptext {
        visibility: visible;
        opacity: 1;
        z-index: 9999; /* Set a high z-index value when hovering */
    }
    """

    rescale_js = """
    function(x) {
        const root = document.querySelector('gradio-app').shadowRoot || document.querySelector('gradio-app');
        let image_scale = parseFloat(root.querySelector('#image_scale input').value) || 1.0;
        const image_width = root.querySelector('#img2img_image').clientWidth;
        const target_height = parseInt(image_width * image_scale);
        document.body.style.setProperty('--height', `${target_height}px`);
        root.querySelectorAll('button.justify-center.rounded')[0].style.display = 'none';
        root.querySelectorAll('button.justify-center.rounded')[1].style.display = 'none';
        return x;
    }
    """
    with open('./conf/unet/config.json') as f:
        unet_config = json.load(f)

    unet = unet_2d_condition.UNet2DConditionModel(**unet_config).from_pretrained('runwayml/stable-diffusion-v1-5',
                                                                                 subfolder="unet")
    tokenizer = CLIPTokenizer.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="text_encoder")
    vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    unet.to(device)
    text_encoder.to(device)
    vae.to(device)
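    # Gradio callback: convert the drawn boxes and grounding phrases into normalized bboxes and
    # prompt-token positions, then run layout-guided sampling. If this Space runs on ZeroGPU
    # hardware, this is the function one would decorate with @spaces.GPU (an assumption: the
    # `spaces` import above suggests it, but no decorator is present in the original file).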
    def generate(unet, vae, tokenizer, text_encoder, language_instruction, grounding_texts, sketch_pad,
                 loss_threshold, guidance_scale, batch_size, rand_seed, max_step, loss_scale, max_iter,
                 state):
        if 'boxes' not in state:
            state['boxes'] = []
        boxes = state['boxes']
        grounding_texts = [x.strip() for x in grounding_texts.split(';')]
        if len(boxes) != len(grounding_texts):
            if len(boxes) < len(grounding_texts):
                raise ValueError("""The number of boxes should be equal to the number of grounding objects.
Number of boxes drawn: {}, number of grounding tokens: {}.
Please draw boxes accordingly on the sketch pad.""".format(len(boxes), len(grounding_texts)))
            grounding_texts = grounding_texts + [""] * (len(boxes) - len(grounding_texts))

        # Normalize pixel coordinates to [0, 1] and wrap each box in a list (one box per object).
        boxes = (np.asarray(boxes) / 512).tolist()
        boxes = [[box] for box in boxes]
        grounding_instruction = json.dumps({obj: box for obj, box in zip(grounding_texts, boxes)})

        # Map every grounding word to its token position in the prompt (+1 to skip the
        # start-of-text token); each word must therefore appear verbatim in the text prompt.
        language_instruction_list = language_instruction.strip('.').split(' ')
        object_positions = []
        for obj in grounding_texts:
            obj_position = []
            for word in obj.split(' '):
                obj_first_index = language_instruction_list.index(word) + 1
                obj_position.append(obj_first_index)
            object_positions.append(obj_position)

        device = 'cuda' if torch.cuda.is_available() else 'cpu'

        gen_images = inference(device, unet, vae, tokenizer, text_encoder, language_instruction, boxes,
                               object_positions, batch_size, loss_scale, loss_threshold, max_iter, max_step,
                               rand_seed, guidance_scale)

        # Only one output image component is wired to this callback, so return the first sample
        # together with the updated state.
        return [gen_images[0], state]
    '''
    inference model
    '''
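    # Layout-guided denoising: during the first `max_index_step` scheduler steps, repeatedly
    # backpropagate a cross-attention loss (measuring how well the attention maps cover the
    # target boxes) into the latents before taking the usual classifier-free-guidance step.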
    def inference(device, unet, vae, tokenizer, text_encoder, prompt, bboxes, object_positions, batch_size, loss_scale,
                  loss_threshold, max_iter, max_index_step, rand_seed, guidance_scale):
        uncond_input = tokenizer(
            [""] * 1, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt"
        )
        uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]

        input_ids = tokenizer(
            prompt,
            padding="max_length",
            truncation=True,
            max_length=tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids[0].unsqueeze(0).to(device)
        text_embeddings = torch.cat([uncond_embeddings, text_encoder(input_ids)[0]])

        generator = torch.manual_seed(rand_seed)  # Seed generator to create the initial latent noise

        latents = torch.randn(
            (batch_size, 4, 64, 64),
            generator=generator,
        ).to(device)

        noise_scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
                                               num_train_timesteps=1000)
        noise_scheduler.set_timesteps(51)

        latents = latents * noise_scheduler.init_noise_sigma

        loss = torch.tensor(10000)

        for index, t in enumerate(noise_scheduler.timesteps):
            iteration = 0

            # Backward guidance: nudge the latents so the cross-attention maps concentrate
            # inside the target boxes.
            while loss.item() / loss_scale > loss_threshold and iteration < max_iter and index < max_index_step:
                latents = latents.requires_grad_(True)
                latent_model_input = latents
                latent_model_input = noise_scheduler.scale_model_input(latent_model_input, t)

                noise_pred, attn_map_integrated_up, attn_map_integrated_mid, attn_map_integrated_down = \
                    unet(latent_model_input, t, encoder_hidden_states=text_encoder(input_ids)[0])

                # Update latents with guidance from the cross-attention loss
                loss = compute_ca_loss(attn_map_integrated_mid, attn_map_integrated_up, bboxes=bboxes,
                                       object_positions=object_positions) * loss_scale

                print(loss.item() / loss_scale)

                grad_cond = torch.autograd.grad(loss.requires_grad_(True), [latents])[0]

                latents = latents - grad_cond * noise_scheduler.sigmas[index] ** 2
                iteration += 1
                torch.cuda.empty_cache()
            torch.cuda.empty_cache()

            with torch.no_grad():
                latent_model_input = torch.cat([latents] * 2)

                latent_model_input = noise_scheduler.scale_model_input(latent_model_input, t)
                noise_pred, attn_map_integrated_up, attn_map_integrated_mid, attn_map_integrated_down = \
                    unet(latent_model_input, t, encoder_hidden_states=text_embeddings)

                noise_pred = noise_pred.sample

                # Perform classifier-free guidance
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                latents = noise_scheduler.step(noise_pred, t, latents).prev_sample
                torch.cuda.empty_cache()

        # Decode the final latents into images
        with torch.no_grad():
            latents = 1 / 0.18215 * latents
            image = vae.decode(latents).sample
            image = (image / 2 + 0.5).clamp(0, 1)
            image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
            images = (image * 255).round().astype("uint8")
            pil_images = [Image.fromarray(image) for image in images]
            return pil_images
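    # Build the Gradio interface: prompt/grounding inputs, a sketch pad for drawing boxes,
    # the parsed-box preview, the generated image, and the advanced sampling options.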
    with Blocks(
            css=css,
            analytics_enabled=False,
            title="Layout-Guidance demo",
    ) as demo:
        description = """<p style="text-align: center; font-weight: bold;">
        <span style="font-size: 28px">Layout Guidance</span>
        <br>
        <span style="font-size: 18px" id="paper-info">
        [<a href=" " target="_blank">Project Page</a>]
        [<a href=" " target="_blank">Paper</a>]
        [<a href=" " target="_blank">GitHub</a>]
        </span>
        </p>
        """
        gr.HTML(description)

        with gr.Column():
            language_instruction = gr.Textbox(
                label="Text Prompt",
            )
            grounding_instruction = gr.Textbox(
                label="Grounding instruction (separated by semicolons)",
            )
            sketch_pad_trigger = gr.Number(value=0, visible=False)
            sketch_pad_resize_trigger = gr.Number(value=0, visible=False)
            init_white_trigger = gr.Number(value=0, visible=False)
            image_scale = gr.Number(value=0, elem_id="image_scale", visible=False)
            new_image_trigger = gr.Number(value=0, visible=False)

            with gr.Row():
                sketch_pad = gr.Paint(label="Sketch Pad", container=False, layers=False, scale=1,
                                      elem_id="img2img_image", canvas_size=(512, 512))
                out_imagebox = gr.Image(type="pil", label="Parsed Sketch Pad")
                out_gen_1 = gr.Image(type="pil", visible=True, label="Generated Image")

            with gr.Row():
                clear_btn = gr.Button(value='Clear')
                gen_btn = gr.Button(value='Generate')
            with gr.Accordion("Advanced Options", open=False):
                with gr.Column():
                    description = """<div class="tooltip">Loss Scale Factor ⓘ
                    <span class="tooltiptext">The scale factor of the backward guidance loss. The larger it is, the stronger the layout control, but fidelity can sometimes suffer.</span>
                    </div>
                    <div class="tooltip">Guidance Scale ⓘ
                    <span class="tooltiptext">The scale factor of classifier-free guidance.</span>
                    </div>
                    <div class="tooltip">Max Iteration per Step ⓘ
                    <span class="tooltiptext">The maximum number of backward-guidance iterations within each diffusion inference step.</span>
                    </div>
                    <div class="tooltip">Loss Threshold ⓘ
                    <span class="tooltiptext">The loss threshold. If the loss computed from the cross-attention maps is smaller than this threshold, backward guidance is stopped.</span>
                    </div>
                    <div class="tooltip">Max Step of Backward Guidance ⓘ
                    <span class="tooltiptext">The maximum number of diffusion inference steps to which backward guidance is applied.</span>
                    </div>
                    """
                    gr.HTML(description)
                    Loss_scale = gr.Slider(minimum=0, maximum=500, step=5, value=30, label="Loss Scale Factor")
                    guidance_scale = gr.Slider(minimum=0, maximum=50, step=0.5, value=7.5, label="Guidance Scale")
                    batch_size = gr.Slider(minimum=1, maximum=4, step=1, value=1, label="Number of Samples",
                                           visible=False)
                    max_iter = gr.Slider(minimum=0, maximum=10, step=1, value=5, label="Max Iteration per Step")
                    loss_threshold = gr.Slider(minimum=0, maximum=1, step=0.1, value=0.2, label="Loss Threshold")
                    max_step = gr.Slider(minimum=0, maximum=50, step=1, value=10, label="Max Step of Backward Guidance")
                    rand_seed = gr.Slider(minimum=0, maximum=1000, step=1, value=445, label="Random Seed")
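            # Per-session state: accumulated sketch masks and the bounding boxes derived from them.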
            state = gr.State({})

            class Controller:
                def __init__(self):
                    self.calls = 0
                    self.tracks = 0
                    self.resizes = 0
                    self.scales = 0

                def init_white(self, init_white_trigger):
                    self.calls += 1
                    return np.ones((512, 512), dtype='uint8') * 255, 1.0, init_white_trigger + 1

                def change_n_samples(self, n_samples):
                    # Not wired to any event in this demo.
                    blank_samples = n_samples % 2 if n_samples > 1 else 0
                    return [gr.update(visible=True) for _ in range(n_samples + blank_samples)] \
                        + [gr.update(visible=False) for _ in range(4 - n_samples - blank_samples)]

            controller = Controller()
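            # Event wiring: redraw the parsed boxes whenever the sketch pad or grounding text
            # changes, reset everything on Clear, and run generation on Generate.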
            demo.load(
                lambda x: x + 1,
                inputs=sketch_pad_trigger,
                outputs=sketch_pad_trigger,
                queue=False)

            sketch_pad.change(
                draw,
                inputs=[sketch_pad, grounding_instruction, sketch_pad_resize_trigger, state],
                outputs=[out_imagebox, sketch_pad_resize_trigger, image_scale, state],
                concurrency_limit=1,
                queue=False,
            )
            grounding_instruction.change(
                draw,
                inputs=[sketch_pad, grounding_instruction, sketch_pad_resize_trigger, state],
                outputs=[out_imagebox, sketch_pad_resize_trigger, image_scale, state],
                concurrency_limit=1,
                queue=False,
            )
            clear_btn.click(
                clear,
                inputs=[sketch_pad_trigger, batch_size, state],
                outputs=[sketch_pad, sketch_pad_trigger, out_imagebox, image_scale, out_gen_1, state],
                concurrency_limit=1,
                queue=False)
            sketch_pad_trigger.change(
                controller.init_white,
                inputs=[init_white_trigger],
                outputs=[sketch_pad, image_scale, init_white_trigger],
                concurrency_limit=1,
                queue=False)
            gen_btn.click(
                fn=partial(generate, unet, vae, tokenizer, text_encoder),
                inputs=[
                    language_instruction, grounding_instruction, sketch_pad,
                    loss_threshold, guidance_scale, batch_size, rand_seed,
                    max_step,
                    Loss_scale, max_iter,
                    state,
                ],
                outputs=[out_gen_1, state],
                concurrency_limit=1,
                queue=True,
            )
            sketch_pad_resize_trigger.change(
                None,
                None,
                sketch_pad_resize_trigger,
                js=rescale_js,
                concurrency_limit=1,
                queue=False)
            init_white_trigger.change(
                None,
                None,
                init_white_trigger,
                js=rescale_js,
                concurrency_limit=1,
                queue=False)
        with gr.Column():
            gr.Examples(
                examples=[
                    [
                        # "images/input.png",
                        "A hello kitty toy is playing with a purple ball.",
                        "hello kitty;ball",
                        "images/hello_kitty_results.png"
                    ],
                ],
                inputs=[language_instruction, grounding_instruction, out_gen_1],
                outputs=None,
                fn=None,
                cache_examples=False,
            )
        description = """<p>The source code of this demo is adapted from the <a href="https://huggingface.co/spaces/gligen/demo/tree/main">GLIGEN demo</a>. Thanks!</p>"""
        gr.HTML(description)
    demo.queue(api_open=False)
    demo.launch(share=False, show_api=False, show_error=True)


if __name__ == '__main__':
    main()