JoelWester committed
Commit 36d75a9 · verified · 1 Parent(s): d147972

Create app.py

Files changed (1)
  1. app.py +140 -0
app.py ADDED
@@ -0,0 +1,140 @@
import torch
import gradio as gr

from PIL import Image
import numpy as np

from transformers import CLIPImageProcessor, CLIPVisionModel
from diffusers import AutoencoderKL, DDPMScheduler
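# NOTE (assumption): the `src.diffusers...` and `utils...` imports below are
# repo-local modules bundled with this Space (the face-anon-simple code base),
# not part of the PyPI `diffusers` package.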
from src.diffusers.models.referencenet.referencenet_unet_2d_condition import (
    ReferenceNetModel,
)
from src.diffusers.models.referencenet.unet_2d_condition import UNet2DConditionModel
from src.diffusers.pipelines.referencenet.pipeline_referencenet import (
    StableDiffusionReferenceNetPipeline,
)
from utils.anonymize_faces_in_image import anonymize_faces_in_image
import face_alignment


DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

def load_pipeline():
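    """Assemble the StableDiffusionReferenceNetPipeline for face anonymization.

    The UNet and the two ReferenceNets are loaded from the
    hkung/face-anon-simple checkpoint, the VAE and DDPM scheduler from
    Stable Diffusion 2.1, and the CLIP processor/vision model serve as the
    pipeline's image feature extractor and encoder.
    """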
    face_model_id = "hkung/face-anon-simple"
    clip_model_id = "openai/clip-vit-large-patch14"
    sd_model_id = "stabilityai/stable-diffusion-2-1"

    unet = UNet2DConditionModel.from_pretrained(
        face_model_id, subfolder="unet", use_safetensors=True
    )
    referencenet = ReferenceNetModel.from_pretrained(
        face_model_id, subfolder="referencenet", use_safetensors=True
    )
    conditioning_referencenet = ReferenceNetModel.from_pretrained(
        face_model_id, subfolder="conditioning_referencenet", use_safetensors=True
    )
    vae = AutoencoderKL.from_pretrained(
        sd_model_id, subfolder="vae", use_safetensors=True
    )
    scheduler = DDPMScheduler.from_pretrained(
        sd_model_id, subfolder="scheduler", use_safetensors=True
    )
    feature_extractor = CLIPImageProcessor.from_pretrained(
        clip_model_id, use_safetensors=True
    )
    image_encoder = CLIPVisionModel.from_pretrained(
        clip_model_id, use_safetensors=True
    )

    pipe = StableDiffusionReferenceNetPipeline(
        unet=unet,
        referencenet=referencenet,
        conditioning_referencenet=conditioning_referencenet,
        vae=vae,
        feature_extractor=feature_extractor,
        image_encoder=image_encoder,
        scheduler=scheduler,
    )

    pipe = pipe.to(DEVICE)
    return pipe


# Load heavy stuff once at startup (better UX + energy-wise)
pipe = load_pipeline()
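# Fixed seed so repeated runs on the same input are reproducible.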
generator = torch.manual_seed(1)

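# 2D landmark model with the SFD face detector; passed to
# anonymize_faces_in_image to locate the faces in the input photo.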
fa = face_alignment.FaceAlignment(
    face_alignment.LandmarksType.TWO_D,
    face_detector="sfd",
    device=DEVICE,
)


def anonymize(
    image: np.ndarray,
    anonymization_degree: float = 1.25,
    num_inference_steps: int = 25,
    guidance_scale: float = 4.0,
):
    """
    Gradio callback: takes an RGB numpy image and returns an anonymized PIL image.
    """

    if image is None:
        return None

    pil_image = Image.fromarray(image)

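    # The helper from utils/ presumably crops each detected face, runs the
    # diffusion pipeline on a 512x512 face image, and blends the result back
    # into the original photo.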
    anon_image = anonymize_faces_in_image(
        image=pil_image,
        face_alignment=fa,
        pipe=pipe,
        generator=generator,
        face_image_size=512,
        num_inference_steps=int(num_inference_steps),
        guidance_scale=float(guidance_scale),
        anonymization_degree=float(anonymization_degree),
    )

    return anon_image

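# Minimal Gradio UI: an image input plus sliders for the three knobs
# exposed by anonymize().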
demo = gr.Interface(
    fn=anonymize,
    inputs=[
        gr.Image(type="numpy", label="Input image"),
        gr.Slider(
            minimum=0.5,
            maximum=2.0,
            step=0.05,
            value=1.25,
            label="Anonymization strength",
        ),
        gr.Slider(
            minimum=10,
            maximum=50,
            step=1,
            value=25,
            label="Diffusion steps (speed vs quality)",
        ),
        gr.Slider(
            minimum=1.0,
            maximum=10.0,
            step=0.1,
            value=4.0,
            label="Guidance scale",
        ),
    ],
    outputs=gr.Image(type="pil", label="Anonymized image"),
    title="Face Anonymization Made Simple",
    description=(
        "Upload a photo and anonymize all faces using the WACV 2025 "
        "\"Face Anonymization Made Simple\" model."
    ),
)


if __name__ == "__main__":
    demo.launch()