import gradio as gr
import numpy as np
import random
import torch
import spaces
from PIL import Image
from diffusers import FlowMatchEulerDiscreteScheduler

from optimization import optimize_pipeline_
from qwenimage.pipeline_qwenimage_edit_plus import QwenImageEditPlusPipeline
from qwenimage.transformer_qwenimage import QwenImageTransformer2DModel
from qwenimage.qwen_fa3_processor import QwenDoubleStreamAttnProcessorFA3

import math
import os
import tempfile

from huggingface_hub import hf_hub_download

HF_MODEL = os.environ.get("HF_UPLOAD_REPO", "rahul7star/qwen-edit-img-repo")
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"

BASE_PROMPTS = {
    "front": "Move the camera to a front-facing position showing the full character. Background is plain white.",
    "back": "Move the camera to a back-facing position showing the full character. Background is plain white.",
    "left": "Move the camera to a side (left) profile view. Background is plain white.",
    "right": "Move the camera to a side (right) profile view. Background is plain white.",
    "45_left": "Rotate camera 45° left showing the full character",
    "45_right": "Rotate camera 45° right showing the full character",
    "top_down": "Switch to top-down view showing the full character",
    "low_angle": "Switch to low-angle view",
    "close_up": "Switch to close-up lens",
    "medium_close_up": "Switch to medium close-up lens",
    "zoom_out": "Switch to zoom out lens",
}

RESOLUTIONS = {
    "1:4": (512, 2048),
    "1:3": (576, 1728),
    "nearly 9:16": (768, 1344),
    "nearly 2:3": (832, 1216),
    "3:4": (896, 1152),
}
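# Every preset above is a (width, height) pair whose sides are multiples of 64 and whose
# area is roughly one megapixel; resize_to_preset() below simply rescales the finished
# image to the selected pair.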
MAX_SEED = np.iinfo(np.int32).max

def upload_image_and_prompt_cpu(input_image, prompt_text) -> str:
    """Upload the input image and prompt text to a dated folder on the Hugging Face Hub.

    Returns the folder path (inside HF_MODEL) that the files were uploaded to.
    """
    from datetime import datetime
    import uuid, shutil
    from huggingface_hub import HfApi

    api = HfApi()
    print(prompt_text)
    today_str = datetime.now().strftime("%Y-%m-%d")
    unique_subfolder = f"Upload-Image-{uuid.uuid4().hex[:8]}"
    hf_folder = f"{today_str}/{unique_subfolder}"

    # Write the image to a temporary PNG file (input may be a file path or a PIL image).
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_img:
        if isinstance(input_image, str):
            shutil.copy(input_image, tmp_img.name)
        else:
            input_image.save(tmp_img.name, format="PNG")
        tmp_img_path = tmp_img.name

    api.upload_file(
        path_or_fileobj=tmp_img_path,
        path_in_repo=f"{hf_folder}/input_image.png",
        repo_id=HF_MODEL,
        repo_type="model",
        token=os.environ.get("HUGGINGFACE_HUB_TOKEN"),
    )

    # Write the prompt text alongside the image.
    summary_file = tempfile.NamedTemporaryFile(delete=False, suffix=".txt").name
    with open(summary_file, "w", encoding="utf-8") as f:
        f.write(prompt_text)

    api.upload_file(
        path_or_fileobj=summary_file,
        path_in_repo=f"{hf_folder}/summary.txt",
        repo_id=HF_MODEL,
        repo_type="model",
        token=os.environ.get("HUGGINGFACE_HUB_TOKEN"),
    )

    os.remove(tmp_img_path)
    os.remove(summary_file)
    return hf_folder
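# upload_image_and_prompt_cpu() is not wired into the UI below; it is a standalone
# logging helper. If you want to archive each request to the Hub (assuming
# HUGGINGFACE_HUB_TOKEN and HF_UPLOAD_REPO are configured), one possible hook is to
# call it from generate_turnaround() after the input image has been validated, e.g.
# upload_image_and_prompt_cpu(input_image, extra_prompt).
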
scheduler_config = {
    "base_image_seq_len": 256,
    "base_shift": math.log(3),
    "invert_sigmas": False,
    "max_image_seq_len": 8192,
    "max_shift": math.log(3),
    "num_train_timesteps": 1000,
    "shift": 1.0,
    "shift_terminal": None,
    "stochastic_sampling": False,
    "time_shift_type": "exponential",
    "use_beta_sigmas": False,
    "use_dynamic_shifting": True,
    "use_exponential_sigmas": False,
    "use_karras_sigmas": False,
}
scheduler = FlowMatchEulerDiscreteScheduler.from_config(scheduler_config)
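# With base_shift == max_shift == log(3), the dynamic shifting above resolves to a
# constant timestep shift of 3 regardless of image sequence length; this scheduler
# setup is presumably tuned for the few-step Lightning LoRA fused below.
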
pipe = QwenImageEditPlusPipeline.from_pretrained(
    "Qwen/Qwen-Image-Edit-2509",
    scheduler=scheduler,
    torch_dtype=dtype,
).to(device)

# Fuse the 4-step Lightning LoRA into the base weights, then drop the adapter so only
# the fused weights remain in memory.
pipe.load_lora_weights(
    "2vXpSwA7/iroiro-lora",
    weight_name="qwen_lora/Qwen-Image-Edit-2509-Lightning-4steps-V1.0-bf16_dim1.safetensors",
)
pipe.fuse_lora(lora_scale=1.0)
pipe.unload_lora_weights()

# Swap in the custom transformer class and the FA3 attention processor, then run
# optimize_pipeline_ once with dummy inputs (presumably to warm up / compile the pipeline).
pipe.transformer.__class__ = QwenImageTransformer2DModel
pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
optimize_pipeline_(pipe, image=[Image.new("RGB", (1024, 1024)), Image.new("RGB", (1024, 1024))], prompt="prompt")

def _append_prompt(base: str, extra: str) -> str:
    extra = (extra or "").strip()
    return (base if not extra else f"{base} {extra}").strip()

def generate_single_view(input_images, prompt, seed, num_inference_steps, true_guidance_scale):
    # Cast the seed in case the UI hands us a float.
    generator = torch.Generator(device=device).manual_seed(int(seed))
    result = pipe(
        image=input_images if input_images else None,
        prompt=prompt,
        negative_prompt=" ",
        num_inference_steps=num_inference_steps,
        generator=generator,
        true_cfg_scale=true_guidance_scale,
        num_images_per_prompt=1,
    ).images
    return result[0]
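# With the default true_guidance_scale of 1.0 the pipeline skips the extra
# negative-prompt (true CFG) pass, which suits the 4-step Lightning weights; values
# above 1.0 enable guidance at the cost of a second forward pass per step.
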
def resize_to_preset(img: Image.Image, preset_key: str) -> Image.Image:
    w, h = RESOLUTIONS[preset_key]
    return img.resize((w, h), Image.LANCZOS)

def concat_images_horizontally(images, bg_color=(255, 255, 255)):
    images = [img.convert("RGB") for img in images if img is not None]
    if not images:
        return None
    # Scale every image to the tallest height, then paste them side by side.
    h = max(img.height for img in images)
    resized = []
    for img in images:
        if img.height != h:
            w = int(img.width * (h / img.height))
            img = img.resize((w, h), Image.LANCZOS)
        resized.append(img)
    w_total = sum(img.width for img in resized)
    canvas = Image.new("RGB", (w_total, h), bg_color)
    x = 0
    for img in resized:
        canvas.paste(img, (x, 0))
        x += img.width
    return canvas

@spaces.GPU()
def generate_turnaround(
    image,
    extra_prompt="",
    preset_key="nearly 9:16",
    seed=42,
    randomize_seed=False,
    true_guidance_scale=1.0,
    num_inference_steps=4,
    progress=gr.Progress(track_tqdm=True),
):
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    if image is None:
        # Return a flat list that matches the outputs wired up in run_button.click().
        return [*([None] * (len(BASE_PROMPTS) + 1)), seed, "❌ Please upload an input image"]

    input_image = image.convert("RGB") if isinstance(image, Image.Image) else Image.open(image).convert("RGB")
    pil_images = [input_image]

    results = {}
    total = len(BASE_PROMPTS)
    for i, (key, base_prompt) in enumerate(BASE_PROMPTS.items(), start=1):
        progress(i / total, desc=f"Generating {key}...")
        prompt_full = _append_prompt(base_prompt, extra_prompt)
        img = generate_single_view(pil_images, prompt_full, seed + i, num_inference_steps, true_guidance_scale)
        results[key] = resize_to_preset(img, preset_key)

    concat = concat_images_horizontally(list(results.values()))
    return [*results.values(), concat, seed, f"✅ Generated images for {len(results)} views plus a concatenated image"]
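# Note: the list returned by generate_turnaround() must stay in the same order as the
# outputs passed to run_button.click() below: one image per view, then the concatenated
# image, the seed, and the status text.
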
css = """ |
|
|
#col-container {margin: 0 auto; max-width: 1400px;} |
|
|
.image-container img {object-fit: contain !important; max-width: 100%; max-height: 100%;} |
|
|
.notice {background: #fff5f5; border: 1px solid #fca5a5; color: #7f1d1d; padding: 12px 14px; border-radius: 10px; font-weight: 600; line-height: 1.5; margin-bottom: 10px;} |
|
|
""" |
|
|
|
|
|
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        input_image = gr.Image(label="Input image", type="pil", height=500)
        extra_prompt = gr.Textbox(
            label="Additional prompt (appended to the end of each view prompt)",
            placeholder="high detail, anime style, soft lighting, 4k",
            lines=2,
        )
        preset_dropdown = gr.Dropdown(
            label="Output resolution preset",
            choices=list(RESOLUTIONS.keys()),
            value="nearly 9:16",
        )
        run_button = gr.Button("🎨 Start generation", variant="primary")
        status_text = gr.Textbox(label="Status", interactive=False)

        # One output image per view prompt, plus a final concatenated sheet.
        result_images = []
        for key in BASE_PROMPTS.keys():
            result_images.append(gr.Image(label=key.capitalize(), type="pil", format="png", height=400, show_download_button=True))
        result_concat = gr.Image(label="Concatenated image (all views)", type="pil", format="png", height=400, show_download_button=True)

        with gr.Accordion("⚙️ Advanced settings", open=False):
            seed_slider = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
            randomize_seed_checkbox = gr.Checkbox(label="Random seed", value=True)
            guidance_scale_slider = gr.Slider(label="True guidance scale", minimum=1.0, maximum=10.0, step=0.1, value=1.0)
            num_steps_slider = gr.Slider(label="Inference steps", minimum=1, maximum=40, step=1, value=4)

        run_button.click(
            fn=generate_turnaround,
            inputs=[input_image, extra_prompt, preset_dropdown, seed_slider, randomize_seed_checkbox, guidance_scale_slider, num_steps_slider],
            outputs=[*result_images, result_concat, seed_slider, status_text],
        )

if __name__ == "__main__":
    demo.launch()