import traceback

import torch
import torch.nn.functional as F
import gradio as gr
import spaces
from diffusers import DDIMScheduler, DiffusionPipeline
from PIL import Image
from torchvision.transforms.functional import gaussian_blur, to_pil_image, to_tensor
def preprocess_image(input_image, device):
    image = to_tensor(input_image)
    image = image.unsqueeze_(0).float() * 2 - 1  # [0, 1] -> [-1, 1]
    if image.shape[1] != 3:  # grayscale input: replicate to 3 channels
        image = image.expand(-1, 3, -1, -1)
    image = F.interpolate(image, (1024, 1024))  # SDXL base resolution
    image = image.to(dtype).to(device)  # `dtype` is the module-level fp16 setting
    return image
def load_description(fp):
    with open(fp, 'r', encoding='utf-8') as f:
        content = f.read()
    return content
def preprocess_mask(input_mask, device):
    # The drawn layer from gr.ImageMask is RGBA: painted pixels are opaque,
    # untouched pixels are fully transparent. Build a binary L-mode mask where
    # painted pixels become white (255) and transparent ones black (0).
    r, g, b, alpha = input_mask.split()
    new_mask = Image.new("L", input_mask.size)
    for x in range(input_mask.width):
        for y in range(input_mask.height):
            if alpha.getpixel((x, y)) == 0:  # transparent pixel
                new_mask.putpixel((x, y), 0)  # set to black
            else:  # painted pixel
                new_mask.putpixel((x, y), 255)  # set to white
    mask = to_tensor(new_mask)
    mask = mask.unsqueeze_(0).float()  # values are 0 or 1
    mask = F.interpolate(mask, (1024, 1024))
    # Blur, then re-binarize: this dilates the mask slightly so the removal
    # region safely covers the object's edges.
    mask = gaussian_blur(mask, kernel_size=(77, 77))
    mask[mask < 0.1] = 0
    mask[mask >= 0.1] = 1
    mask = mask.to(dtype).to(device)
    return mask
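
# A behavior-equivalent sketch of the thresholding above (hypothetical helper,
# not wired into the app): PIL's Image.point applies the threshold in C, which
# avoids the per-pixel Python loop over the full canvas. Assumes the layer is
# RGBA, as gr.ImageMask provides.
def preprocess_mask_fast(input_mask, device):
    alpha = input_mask.split()[-1]  # alpha channel, mode "L"
    new_mask = alpha.point(lambda a: 255 if a > 0 else 0)  # opaque -> white
    mask = to_tensor(new_mask).unsqueeze_(0).float()
    mask = F.interpolate(mask, (1024, 1024))
    mask = gaussian_blur(mask, kernel_size=(77, 77))
    mask = (mask >= 0.1).float()  # re-binarize after the dilating blur
    return mask.to(dtype).to(device)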
def make_redder(img, mask, increase_factor=0.4):
    # Visualization helper (currently unused by the UI): boost the red channel
    # inside the masked region of a (3, H, W) image in [0, 1].
    img_redder = img.clone()
    mask_expanded = mask.expand_as(img)
    img_redder[0][mask_expanded[0] == 1] = torch.clamp(
        img_redder[0][mask_expanded[0] == 1] + increase_factor, 0, 1
    )
    return img_redder
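
# A sketch of how make_redder could be used for debugging (hypothetical helper,
# not called anywhere): tint the removal region of a result image red and save
# it. Assumes `result_pil` is a PIL image and `mask` is the (1, 1, 1024, 1024)
# tensor produced by preprocess_mask.
def save_mask_overlay(result_pil, mask, path="mask_overlay.png"):
    from torchvision.utils import save_image
    img = to_tensor(result_pil)  # (3, H, W) in [0, 1]
    m = F.interpolate(mask.float().cpu(), img.shape[-2:])  # match image size
    save_image(make_redder(img, m[0]), path)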
# Model loading parameters
is_cpu_offload_enabled = False
is_attention_slicing_enabled = True

# Load model
dtype = torch.float16
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
scheduler = DDIMScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    clip_sample=False,
    set_alpha_to_one=False,
)
model_path = "stabilityai/stable-diffusion-xl-base-1.0"
pipeline = DiffusionPipeline.from_pretrained(
    model_path,
    custom_pipeline="pipeline_stable_diffusion_xl_attentive_eraser.py",
    scheduler=scheduler,
    variant="fp16",
    use_safetensors=True,
    torch_dtype=dtype,
).to(device)
if is_attention_slicing_enabled:
    pipeline.enable_attention_slicing()
if is_cpu_offload_enabled:
    pipeline.enable_model_cpu_offload()
@spaces.GPU  # request a GPU for each call when running on a ZeroGPU Space
def remove(gradio_image, rm_guidance_scale=9, num_inference_steps=50, seed=42, strength=0.8, similarity_suppression_steps=9, similarity_suppression_scale=0.3):
    try:
        generator = torch.Generator(device).manual_seed(seed)
        prompt = ""  # removal is unconditional: no text prompt needed
        source_image_pure = gradio_image["background"]
        mask_image_pure = gradio_image["layers"][0]
        source_image = preprocess_image(source_image_pure.convert('RGB'), device)
        mask = preprocess_mask(mask_image_pure, device)
        START_STEP = 0  # AAS start step
        END_STEP = int(strength * num_inference_steps)  # AAS end step
        LAYER = 34  # 0-23 down, 24-33 mid, 34-69 up / AAS start layer
        END_LAYER = 70  # AAS end layer
        ss_steps = similarity_suppression_steps
        ss_scale = similarity_suppression_scale
        image = pipeline(
            prompt=prompt,
            image=source_image,
            mask_image=mask,
            height=1024,
            width=1024,
            AAS=True,  # enable AAS
            strength=strength,  # inpainting strength
            rm_guidance_scale=rm_guidance_scale,  # removal guidance scale
            ss_steps=ss_steps,  # similarity suppression steps
            ss_scale=ss_scale,  # similarity suppression scale
            AAS_start_step=START_STEP,
            AAS_start_layer=LAYER,
            AAS_end_layer=END_LAYER,
            num_inference_steps=num_inference_steps,  # AAS_end_step = int(strength * num_inference_steps)
            generator=generator,
            guidance_scale=1,
        ).images[0]
        print('Inference: DONE.')
        pil_mask = to_pil_image(mask.squeeze(0))
        return [source_image_pure, pil_mask, image]  # original, mask, result for the gallery
    except Exception:
        print(traceback.format_exc())
        raise gr.Error("Object removal failed; see the logs for details.")
with gr.Blocks() as demo:
    gr.HTML(load_description("assets/title.md"))
    with gr.Row():
        with gr.Column():
            with gr.Accordion("Advanced Options", open=False):
                guidance_scale = gr.Slider(
                    minimum=1,
                    maximum=20,
                    value=9,
                    step=0.1,
                    label="Guidance Scale",
                )
                num_steps = gr.Slider(
                    minimum=5,
                    maximum=100,
                    value=50,
                    step=1,
                    label="Steps",
                )
                seed = gr.Slider(
                    minimum=42,
                    maximum=999999,
                    value=42,
                    step=1,
                    label="Seed",
                )
                strength = gr.Slider(
                    minimum=0,
                    maximum=1,
                    value=0.8,
                    step=0.1,
                    label="Strength",
                )
                similarity_suppression_steps = gr.Slider(
                    minimum=0,
                    maximum=10,
                    value=9,
                    step=1,
                    label="Similarity Suppression Steps",
                )
                similarity_suppression_scale = gr.Slider(
                    minimum=0,
                    maximum=1,
                    value=0.3,
                    step=0.1,
                    label="Similarity Suppression Scale",
                )
            input_image = gr.ImageMask(
                type="pil", label="Input Image", crop_size=(1200, 1200), layers=False
            )
        with gr.Column():
            with gr.Row():
                with gr.Column():
                    run_button = gr.Button("Generate")
                    result = gr.Gallery(
                        label="Generated images",
                        show_label=False,
                        elem_id="gallery",
                        columns=[3],
                        rows=[1],
                        object_fit="contain",
                        height="auto",
                    )
    run_button.click(
        fn=remove,
        inputs=[input_image, guidance_scale, num_steps, seed, strength, similarity_suppression_steps, similarity_suppression_scale],
        outputs=result,
    )
demo.queue(max_size=12).launch(share=False)