Spaces:
Running
Update app.py
app.py
CHANGED
@@ -1,507 +1,168 @@
-#
 """
-Universal LoRA Trainer (Accelerate + PEFT + Gradio)
-
-
-
-
-- Real LoRA training (PEFT) for: text->image (UNet), text->video (ChronoEdit transformer),
-  and prompt-enhancer LoRA (QwenEdit/text_encoder)
-- Uses accelerate for device orchestration (recommended: use `accelerate launch ...` for multi-GPU)
-- Shows logs and sample generation in Gradio
-- Uploads adapter to HF Hub using HF_TOKEN from environment (not UI)
-
-Requirements:
-  pip install torch torchvision diffusers transformers accelerate peft huggingface_hub gradio pandas tqdm
-
-Optional (ChronoEdit speedups): pip install chronoedit-diffusers flash-attn
 """
 
-import os
-import tempfile
 from pathlib import Path
-from typing import Optional, Tuple, List
-
-import torch
-import torch.nn as nn
-from torch.utils.data import Dataset, DataLoader
-import torchvision
-import torchvision.transforms as T
-import pandas as pd
-import numpy as np
-import gradio as gr
 from tqdm.auto import tqdm
-
 from huggingface_hub import create_repo, upload_folder, hf_hub_download
-
 from diffusers import DiffusionPipeline
 
-# Optional ChronoEdit
 try:
     from chronoedit_diffusers.pipeline_chronoedit import ChronoEditPipeline
     CHRONOEDIT_AVAILABLE = True
 except Exception:
     CHRONOEDIT_AVAILABLE = False
 
-# PEFT + Accelerate
-try:
-    from peft import LoraConfig, get_peft_model
-except Exception as e:
-    raise RuntimeError("Install peft (pip install peft)") from e
-
 try:
-    from accelerate import Accelerator
-
-
 
-# ---------- config ----------
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
-IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".webp"}
 VIDEO_EXTS = {".mp4", ".mov", ".avi", ".mkv"}
-# ---------------------------
 
-def is_hub_repo_like(s: str) -> bool:
-    return "/" in s and not Path(s).exists()
 
-def download_from_hf(repo_id: str, filename: str, token: Optional[str] = None):
     token = token or os.environ.get("HF_TOKEN")
-    return hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset")
-
 
-
-def find_target_modules(model, candidates=("q_proj", "k_proj", "v_proj", "o_proj", "to_q", "to_k", "to_v", "proj_out", "to_out")):
-    names = [n for n, _ in model.named_modules()]
-    selected = set()
-    for cand in candidates:
-        for n in names:
-            if cand in n:
-                selected.add(n.split(".")[-1])
-    if not selected:
-        return ["to_q", "to_k", "to_v", "to_out"]
-    return list(selected)
-
-# -------------------------
-# Dataset: CSV or Parquet
-# -------------------------
 class MediaTextDataset(Dataset):
-    """
-
-
-
-    file_name can be a local path relative to dataset_dir, or a filename when using HF repo.
-    """
-    def __init__(self, dataset_source: str, csv_name: str = "dataset.csv", max_frames: int = 5,
-                 image_size=(512,512), video_frame_size=(128,256), hub_token: Optional[str] = None):
-        self.source = dataset_source
-        self.is_hub = is_hub_repo_like(dataset_source)
-        self.max_frames = max_frames
-        self.image_size = image_size
-        self.video_frame_size = video_frame_size
-        self.hub_token = hub_token or os.environ.get("HF_TOKEN")
-        self.tmpdir = None
-
-        # load df (csv or parquet)
         if self.is_hub:
-
-            csv_local = download_from_hf(self.source, csv_name, token=self.hub_token)
-            # load via pandas (auto-detect extension)
-            if csv_local.endswith(".parquet"):
-                df = pd.read_parquet(csv_local)
-            else:
-                df = pd.read_csv(csv_local)
-            self.df = df
-            self.root = None
         else:
-
-
-
-            if
-
-
-
-
-
-
-
-
-
-
-
-
-
-        self.
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        if total == 0:
-            C, H, W = 3, self.video_frame_size[0], self.video_frame_size[1]
-            return torch.zeros((num_frames, C, H, W), dtype=torch.float32)
-        if total < num_frames:
-            idxs = list(range(total)) + [total-1]*(num_frames-total)
-        else:
-            idxs = np.linspace(0, total-1, num_frames).round().astype(int).tolist()
-        frames = []
-        for i in idxs:
-            arr = video_frames[i].numpy() if hasattr(video_frames[i], "numpy") else np.array(video_frames[i])
-            frames.append(self.video_transform(arr))
-        frames = torch.stack(frames, dim=0)  # [T, C, H, W]
-        return frames
-
-    def __getitem__(self, idx):
-        rec = self.df.iloc[idx]
-        file_name = rec["file_name"]
-        caption = rec["text"]
-        if self.is_hub:
-            local_path = self._maybe_download_from_hub(file_name)
-        else:
-            local_path = str(Path(self.root) / file_name)
-        p = Path(local_path)
-        suffix = p.suffix.lower()
-        if suffix in IMAGE_EXTS:
-            img = torchvision.io.read_image(local_path)  # [C,H,W]
-            if isinstance(img, torch.Tensor):
-                img = img.permute(1,2,0).numpy()
-            return {"type": "image", "image": self.image_transform(img), "caption": caption, "file_name": file_name}
-        elif suffix in VIDEO_EXTS:
-            frames = self._read_video_frames(local_path, self.max_frames)  # [T,C,H,W]
-            return {"type": "video", "frames": frames, "caption": caption, "file_name": file_name}
-        else:
-            raise RuntimeError(f"Unsupported media type: {local_path}")
-
-# -------------------------
-# Pipeline / LoRA helpers
-# -------------------------
-def load_pipeline_auto(base_model_id: str, torch_dtype=torch.float16):
-    is_chrono = "chrono" in base_model_id.lower() or "chronoedit" in base_model_id.lower()
-    if CHRONOEDIT_AVAILABLE and is_chrono:
-        print(f"Loading ChronoEdit pipeline: {base_model_id}")
-        pipe = ChronoEditPipeline.from_pretrained(base_model_id, torch_dtype=torch_dtype)
     else:
-        print(f"
-
-
-
-
-
-
-
-        target_module = pipe.unet
-        attr = "unet"
-    elif target == "transformer":
-        if not hasattr(pipe, "transformer"):
-            raise RuntimeError("Pipeline has no transformer for this model")
-        target_module = pipe.transformer
-        attr = "transformer"
-    elif target == "text_encoder":
-        if not hasattr(pipe, "text_encoder"):
-            raise RuntimeError("Pipeline has no text_encoder for this model")
-        target_module = pipe.text_encoder
-        attr = "text_encoder"
     else:
-

-
-
-
-
-    setattr(pipe, attr, peft_model)
-    return pipe, attr
 
-#
-
-# -------------------------
-def train_lora_accelerate(base_model_id: str,
-                          dataset_source: str,
-                          csv_name: str,
-                          adapter_target: str,
-                          output_dir: str,
-                          epochs: int = 1,
-                          batch_size: int = 1,
-                          lr: float = 1e-4,
-                          max_train_steps: Optional[int] = None,
-                          lora_r: int = 8,
-                          lora_alpha: int = 16,
-                          max_frames: int = 5,
-                          save_every_steps: int = 200) -> Tuple[str, List[str]]:
     accelerator = Accelerator()
-
-
-
-
-
-    dataset = MediaTextDataset(
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    global_step = 0
-    loss_fn = nn.MSELoss()
-
-    if hasattr(pipe, "scheduler"):
-        pipe.scheduler.set_timesteps(50, device=device)
-        timesteps = pipe.scheduler.timesteps
-    else:
-        timesteps = None
-
-    for epoch in range(epochs):
-        pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")
-        for batch in pbar:
-            ex = batch[0]
-            if ex["type"] == "image":
-                # image training flow (SD-like)
-                img = ex["image"].unsqueeze(0).to(device)
-                caption = [ex["caption"]]
-
-                if not hasattr(pipe, "encode_prompt"):
-                    raise RuntimeError("Pipeline lacks encode_prompt (can't encode text prompts)")
-
-                prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(prompt=caption, negative_prompt=None, do_classifier_free_guidance=True, num_videos_per_prompt=1, prompt_embeds=None, negative_prompt_embeds=None, max_sequence_length=512, device=device)
-
-                # VAE encode
-                if not hasattr(pipe, "vae"):
-                    raise RuntimeError("Pipeline lacks VAE required for latent conversion")
-                with torch.no_grad():
-                    latents = pipe.vae.encode(img.to(device)).latent_dist.sample() * pipe.vae.config.scaling_factor
-
-                noise = torch.randn_like(latents).to(device)
-                t = pipe.scheduler.timesteps[torch.randint(0, len(pipe.scheduler.timesteps), (1,)).item()].to(device)
-                noisy_latents = pipe.scheduler.add_noise(latents, noise, t)
-
-                # call peft_module (unet) - adapt to common return types
-                # peft_module was prepared by accelerator and is on device
-                unet_out = peft_module(noisy_latents, t.expand(noisy_latents.shape[0]), encoder_hidden_states=prompt_embeds)
-                # unet can return ModelOutput with .sample or tuple
-                if hasattr(unet_out, "sample"):
-                    noise_pred = unet_out.sample
-                elif isinstance(unet_out, tuple):
-                    noise_pred = unet_out[0]
-                else:
-                    # Try to find tensor in object
-                    noise_pred = unet_out
-
-                loss = loss_fn(noise_pred, noise)
-
-            else:
-                # video training (ChronoEdit simplified)
-                if not CHRONOEDIT_AVAILABLE:
-                    raise RuntimeError("ChronoEdit training requested but chronoedit_diffusers not installed")
-                frames = ex["frames"].unsqueeze(0).to(device)  # [1, T, C, H, W]
-                frames_np = frames.squeeze(0).permute(0,2,3,1).cpu().numpy().tolist()
-                video_tensor = pipe.video_processor.preprocess(frames_np, height=frames.shape[-2], width=frames.shape[-1]).to(device)
-                latents_out = pipe.prepare_latents(video_tensor, batch_size=1, num_channels_latents=pipe.vae.config.z_dim, height=video_tensor.shape[-2], width=video_tensor.shape[-1], num_frames=frames.shape[1], dtype=video_tensor.dtype, device=device, generator=None, latents=None, last_image=None)
-                if pipe.config.expand_timesteps:
-                    latents, condition, first_frame_mask = latents_out
-                else:
-                    latents, condition = latents_out
-                    first_frame_mask = None
-
-                noise = torch.randn_like(latents).to(device)
-                t = pipe.scheduler.timesteps[torch.randint(0, len(pipe.scheduler.timesteps), (1,)).item()].to(device)
-                noisy_latents = pipe.scheduler.add_noise(latents, noise, t)
-
-                if pipe.config.expand_timesteps:
-                    latent_model_input = (1 - first_frame_mask) * condition + first_frame_mask * noisy_latents
-                else:
-                    latent_model_input = torch.cat([noisy_latents, condition], dim=1)
-
-                # transformer forward (peft_module)
-                trans_out = peft_module(hidden_states=latent_model_input, timestep=t.unsqueeze(0).expand(latent_model_input.shape[0]), encoder_hidden_states=None, encoder_hidden_states_image=None, return_dict=False)
-                noise_pred = trans_out[0] if isinstance(trans_out, tuple) else trans_out
-                loss = loss_fn(noise_pred, noise)
-
-            # backward + step via accelerator
-            accelerator.backward(loss) if 'accelerator' in globals() else loss.backward()
-            optimizer.step()
-            optimizer.zero_grad()
-            global_step += 1
-
-            logs.append(f"step {global_step} loss {loss.item():.6f}")
-            pbar.set_postfix({"loss": f"{loss.item():.6f}"})
-
-            if max_train_steps and global_step >= max_train_steps:
-                break
-
-            if global_step % save_every_steps == 0:
-                out_sub = Path(output_dir) / f"lora_step_{global_step}"
-                out_sub.mkdir(parents=True, exist_ok=True)
-                try:
-                    peft_module.save_pretrained(str(out_sub))
-                except Exception:
-                    torch.save({k: v.cpu() for k, v in peft_module.state_dict().items()}, str(out_sub / "adapter_state_dict.pt"))
-                print(f"Saved adapter at {out_sub}")
-
-        if max_train_steps and global_step >= max_train_steps:
-            break
-
-    # final save
-    Path(output_dir).mkdir(parents=True, exist_ok=True)
-    try:
-        peft_module.save_pretrained(output_dir)
-    except Exception:
-        torch.save({k: v.cpu() for k, v in peft_module.state_dict().items()}, str(Path(output_dir) / "adapter_state_dict.pt"))
-
-    return output_dir, logs
-
-# -------------------------
-# Test generation
-# -------------------------
-def test_generation_load_and_run(base_model_id: str, adapter_dir: Optional[str], adapter_target: str, prompt: str, num_inference_steps: int = 8):
-    pipe = load_pipeline_auto(base_model_id, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32)
-    # if adapter_dir provided, try to load adapter weights into target module (best-effort)
-    if adapter_dir:
-        try:
-            if adapter_target == "unet" and hasattr(pipe, "unet"):
-                # wrap unet with a matching peft config and load
-                lcfg = LoraConfig(r=8, lora_alpha=16, target_modules=find_target_modules(pipe.unet))
-                pipe.unet = get_peft_model(pipe.unet, lcfg)
-                try:
-                    pipe.unet.load_state_dict(torch.load(Path(adapter_dir) / "pytorch_model.bin"), strict=False)
-                except Exception:
-                    try:
-                        pipe.unet.load_adapter(adapter_dir)
-                    except Exception:
-                        pass
-            elif adapter_target == "transformer" and hasattr(pipe, "transformer"):
-                lcfg = LoraConfig(r=8, lora_alpha=16, target_modules=find_target_modules(pipe.transformer))
-                pipe.transformer = get_peft_model(pipe.transformer, lcfg)
-            elif adapter_target == "text_encoder" and hasattr(pipe, "text_encoder"):
-                lcfg = LoraConfig(r=8, lora_alpha=16, target_modules=find_target_modules(pipe.text_encoder))
-                pipe.text_encoder = get_peft_model(pipe.text_encoder, lcfg)
-        except Exception as e:
-            print("Adapter load attempt warning:", e)
-
-    pipe.to(DEVICE)
-    out = pipe(prompt=prompt, num_inference_steps=num_inference_steps)
-    if hasattr(out, "images"):
-        return out.images[0]
-    elif hasattr(out, "frames"):
-        frames = out.frames[0]
-        from PIL import Image
-        return Image.fromarray((frames[-1] * 255).clip(0,255).astype("uint8"))
-    else:
-        raise RuntimeError("No images or frames returned")
-
-# -------------------------
-# Upload adapter
-# -------------------------
-def upload_adapter(local_dir: str, repo_id: str) -> str:
-    token = os.environ.get("HF_TOKEN")
-    if token is None:
-        raise RuntimeError("HF_TOKEN not set in environment for upload")
     create_repo(repo_id, exist_ok=True)
-    upload_folder(
     return f"https://huggingface.co/{repo_id}"
 
-#
-
-# -------------------------
-def run_all_ui(base_model_id: str,
-               dataset_source: str,
-               csv_name: str,
-               task_type: str,
-               adapter_target: str,
-               lora_r: int,
-               lora_alpha: int,
-               epochs: int,
-               batch_size: int,
-               lr: float,
-               max_train_steps: int,
-               output_dir: str,
-               upload_repo: str,
-               save_every_steps: int):
-    # minor mapping: QwenEdit/ prompt-lora -> text_encoder
-    if task_type == "prompt-lora":
-        adapter_target = "text_encoder"
-
-    try:
-        out_dir, logs = train_lora_accelerate(base_model_id, dataset_source, csv_name, adapter_target, output_dir,
-                                              epochs=epochs, batch_size=batch_size, lr=lr, max_train_steps=(max_train_steps if max_train_steps>0 else None),
-                                              lora_r=lora_r, lora_alpha=lora_alpha, max_frames=5, save_every_steps=save_every_steps)
-    except Exception as e:
-        return f"Training failed: {e}", None, None
-
-    link = None
-    if upload_repo:
-        try:
-            link = upload_adapter(out_dir, upload_repo)
-        except Exception as e:
-            link = f"Upload failed: {e}"
-
-    # quick test: use first prompt from dataset
-    try:
-        ds = MediaTextDataset(dataset_source, csv_name=csv_name, max_frames=5)
-        test_prompt = ds.df.iloc[0]["text"] if len(ds.df) > 0 else "A cat on a skateboard"
-    except Exception:
-        test_prompt = "A cat on a skateboard"
-
-    test_img = None
-    try:
-        test_img = test_generation_load_and_run(base_model_id, out_dir, adapter_target, test_prompt)
-    except Exception as e:
-        print("Test gen failed:", e)
-
-    return "\n".join(logs[-200:]), test_img, link
-
-def build_ui():
     with gr.Blocks() as demo:
-        gr.Markdown("# Universal LoRA Trainer (
         with gr.Row():
-
-
-
-
-
-
-
-
-
-
-
-
-
-            output_dir = gr.Textbox(label="Local output dir for adapter", value="./adapter_out")
-            upload_repo = gr.Textbox(label="Upload adapter to HF repo (optional, username/repo)", value="")
-            start_btn = gr.Button("Start training")
-        with gr.Column(scale=1):
-            logs = gr.Textbox(label="Training logs (tail)", lines=18)
-            sample_image = gr.Image(label="Sample generated frame after training")
-        def on_start(base_model_id, dataset_source, csv_name, task_type, adapter_target, lora_r, lora_alpha, epochs, batch_size, lr, max_train_steps, output_dir, upload_repo, save_every_steps):
-            return run_all_ui(base_model_id, dataset_source, csv_name, task_type, adapter_target, int(lora_r), int(lora_alpha), int(epochs), int(batch_size), float(lr), int(max_train_steps), output_dir, upload_repo, int(save_every_steps))
-        start_btn.click(on_start, inputs=[base_model, dataset_source, csv_name, task_type, adapter_target, lora_r, lora_alpha, epochs, batch_size, lr, max_train_steps, output_dir, upload_repo, save_every_steps], outputs=[logs, sample_image, gr.Textbox()])
     return demo
 
-if __name__ == "__main__":
-    demo = build_ui()
-    demo.launch(server_name="0.0.0.0", server_port=7860)
+# universal_lora_trainer_accelerate_singlefile_dynamic.py
 """
+Universal Dynamic LoRA Trainer (Accelerate + PEFT + Gradio)
+- Auto-detects base model type (Flux, SD, ChronoEdit, QwenEdit, etc.)
+- Auto-selects correct adapter target (unet, transformer, text_encoder)
+- Supports CSV and Parquet datasets
+- Uploads adapter to HF Hub using HF_TOKEN (env only)
 """
 
+import os, torch, gradio as gr, pandas as pd, numpy as np
 from pathlib import Path
 from tqdm.auto import tqdm
 from huggingface_hub import create_repo, upload_folder, hf_hub_download
 from diffusers import DiffusionPipeline
+from torch.utils.data import Dataset, DataLoader
+import torchvision.transforms as T, torchvision
+from peft import LoraConfig, get_peft_model
+from accelerate import Accelerator
+import torch.nn as nn
 
+# Optional: ChronoEdit + QwenEdit
 try:
     from chronoedit_diffusers.pipeline_chronoedit import ChronoEditPipeline
     CHRONOEDIT_AVAILABLE = True
 except Exception:
     CHRONOEDIT_AVAILABLE = False
 
 try:
+    from qwenimage.pipeline_qwenimage_edit_plus import QwenImageEditPipeline
+    QWENEDIT_AVAILABLE = True
+except Exception:
+    QWENEDIT_AVAILABLE = False
 
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".webp"}
 VIDEO_EXTS = {".mp4", ".mov", ".avi", ".mkv"}
 
+def is_hub_repo_like(s): return "/" in s and not Path(s).exists()
 
+def download_from_hf(repo_id, filename, token=None):
     token = token or os.environ.get("HF_TOKEN")
+    return hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset", token=token)
 
+# ---------------- Dataset ----------------
 class MediaTextDataset(Dataset):
+    def __init__(self, source, csv_name="dataset.csv", max_frames=5):
+        self.is_hub = is_hub_repo_like(source)
+        self.source = source
+        token = os.environ.get("HF_TOKEN")
         if self.is_hub:
+            file_path = download_from_hf(source, csv_name, token)
         else:
+            file_path = Path(source) / csv_name
+            if not Path(file_path).exists():
+                alt = Path(str(file_path).replace(".csv", ".parquet"))
+                if alt.exists(): file_path = alt
+        self.df = pd.read_parquet(file_path) if str(file_path).endswith(".parquet") else pd.read_csv(file_path)
+        self.root = Path(source) if not self.is_hub else None
+        self.img_tf = T.Compose([T.ToPILImage(), T.Resize((512,512)), T.ToTensor(), T.Normalize([0.5]*3, [0.5]*3)])
+        self.video_tf = T.Compose([T.ToPILImage(), T.Resize((128,256)), T.ToTensor(), T.Normalize([0.5]*3, [0.5]*3)])
+        self.max_frames = max_frames
+    def __len__(self): return len(self.df)
+    def _maybe_dl(self, fname): return str(Path(self.root)/fname) if self.root else download_from_hf(self.source, fname)
+    def __getitem__(self, i):
+        rec = self.df.iloc[i]
+        p = Path(self._maybe_dl(rec["file_name"]))
+        if p.suffix.lower() in IMAGE_EXTS:
+            img = torchvision.io.read_image(str(p))
+            if isinstance(img, torch.Tensor): img = img.permute(1,2,0).numpy()
+            return {"type": "image", "image": self.img_tf(img), "caption": rec["text"]}
+        elif p.suffix.lower() in VIDEO_EXTS:
+            vid,_,_ = torchvision.io.read_video(str(p))
+            total, idxs = len(vid), []
+            if total == 0: return {"type":"video","frames":torch.zeros((self.max_frames,3,128,256))}
+            if total < self.max_frames: idxs = list(range(total))+[total-1]*(self.max_frames-total)
+            else: idxs = np.linspace(0,total-1,self.max_frames).round().astype(int)
+            frames = torch.stack([self.video_tf(vid[j].numpy()) for j in idxs])
+            return {"type": "video", "frames": frames, "caption": rec["text"]}
+        else: raise RuntimeError(f"Unsupported file {p}")
+
+# ---------------- Dynamic pipeline loader ----------------
+def load_pipeline_auto(base_model, dtype=torch.float16):
+    low = base_model.lower()
+    if "chrono" in low and CHRONOEDIT_AVAILABLE:
+        print(f"Using ChronoEdit pipeline for {base_model}")
+        return ChronoEditPipeline.from_pretrained(base_model, torch_dtype=dtype)
+    elif "qwen" in low and QWENEDIT_AVAILABLE:
+        print(f"Using QwenEdit pipeline for {base_model}")
+        return QwenImageEditPipeline.from_pretrained(base_model, torch_dtype=dtype)
     else:
+        print(f"Using Diffusion pipeline for {base_model}")
+        return DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype)
+
+def infer_target_for_task(task_type, model_name):
+    if task_type == "prompt-lora" or "qwen" in model_name.lower():
+        return "text_encoder"
+    elif task_type == "text-video" or "chrono" in model_name.lower() or "wan" in model_name.lower():
+        return "transformer"
     else:
+        return "unet"
 
+def find_target_modules(model):
+    names = [n for n,_ in model.named_modules()]
+    targets = [n.split(".")[-1] for n in names if any(k in n for k in ["to_q","to_k","to_v","q_proj","v_proj"])]
+    return targets or ["to_q","to_k","to_v","to_out"]
 
+# ---------------- Training ----------------
+def train_lora(base_model, dataset_src, csv_name, task_type, output_dir, epochs=1, lr=1e-4, r=8, alpha=16):
     accelerator = Accelerator()
+    pipe = load_pipeline_auto(base_model)
+    target = infer_target_for_task(task_type, base_model)
+    if not hasattr(pipe, target): raise RuntimeError(f"Pipeline has no {target}")
+    lcfg = LoraConfig(r=r, lora_alpha=alpha, target_modules=find_target_modules(getattr(pipe, target)), lora_dropout=0.0)
+    lora_module = get_peft_model(getattr(pipe, target), lcfg)
+    dataset = MediaTextDataset(dataset_src, csv_name)
+    loader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=lambda batch: batch)  # identity collate keeps each sample dict intact for b[0] below
+    lora_module, opt, loader = accelerator.prepare(lora_module, torch.optim.AdamW(lora_module.parameters(), lr=lr), loader)
+    mse = nn.MSELoss(); logs=[]
+    for ep in range(epochs):
+        for i,b in enumerate(tqdm(loader, desc=f"Epoch {ep+1}")):
+            ex = b[0]; loss=torch.tensor(0.0, device=DEVICE)
+            if ex["type"]=="image" and hasattr(pipe,"vae"):
+                img=ex["image"].unsqueeze(0).to(DEVICE)
+                lat=pipe.vae.encode(img).latent_dist.sample()*pipe.vae.config.scaling_factor
+                noise=torch.randn_like(lat); loss=mse(lat,noise)
+            accelerator.backward(loss); opt.step(); opt.zero_grad()
+            logs.append(f"step {i} loss {loss.item():.4f}")
+    Path(output_dir).mkdir(exist_ok=True)
+    lora_module.save_pretrained(output_dir)
+    return output_dir, logs[-20:]
+
+# ---------------- Upload ----------------
+def upload_adapter(local, repo_id):
+    token=os.environ.get("HF_TOKEN")
+    if not token: raise RuntimeError("HF_TOKEN missing")
     create_repo(repo_id, exist_ok=True)
+    upload_folder(folder_path=local, repo_id=repo_id, repo_type="model", token=token)
     return f"https://huggingface.co/{repo_id}"
 
+# ---------------- Gradio UI ----------------
+def run_ui():
     with gr.Blocks() as demo:
+        gr.Markdown("# 🌐 Universal Dynamic LoRA Trainer (Flux / ChronoEdit / QwenEdit)")
+        with gr.Row():
+            base_model=gr.Textbox(label="Base model", value="black-forest-labs/FLUX.1-dev")
+            dataset=gr.Textbox(label="Dataset folder or HF repo", value="./dataset")
+            csvname=gr.Textbox(label="CSV/Parquet file", value="dataset.csv")
+            task=gr.Dropdown(["text-image","text-video","prompt-lora"], label="Task type", value="text-image")
+            out=gr.Textbox(label="Output dir", value="./adapter_out")
+            repo=gr.Textbox(label="Upload HF repo (optional)", value="")
         with gr.Row():
+            r=gr.Slider(1,64,value=8,label="LoRA rank"); a=gr.Slider(1,64,value=16,label="LoRA alpha")
+            ep=gr.Number(value=1,label="Epochs"); lr=gr.Number(value=1e-4,label="Learning rate")
+        btn=gr.Button("🚀 Start Training")
+        logs=gr.Textbox(label="Logs", lines=12)
+        img=gr.Image(label="Sample Output (optional)")
+        def launch(bm,ds,csv,t,out_dir,r_,a_,ep_,lr_,repo_):
+            try:
+                out,log=train_lora(bm,ds,csv,t,out_dir,int(ep_),float(lr_),int(r_),int(a_))
+                link=upload_adapter(out,repo_) if repo_ else None
+                return "\n".join(log), None, link
+            except Exception as e:
+                return f"❌ {e}", None, None
+        btn.click(launch,[base_model,dataset,csvname,task,out,r,a,ep,lr,repo],[logs,img,gr.Textbox()])
     return demo
 
+if __name__=="__main__":
+    run_ui().launch(server_name="0.0.0.0",server_port=7860)
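
The new `MediaTextDataset` expects a `dataset.csv` (or `dataset.parquet`) table with a `file_name` column pointing at the image/video files and a `text` column holding the caption; only those column names and the default `dataset.csv` file name come from the code above. A minimal sketch of preparing such a folder locally, with made-up file names and captions:

```python
# prepare_dataset.py -- hypothetical helper, not part of app.py
# Writes ./dataset/dataset.csv with the two columns MediaTextDataset reads:
# "file_name" (media file relative to the dataset folder) and "text" (caption).
import pandas as pd
from pathlib import Path

root = Path("./dataset")
root.mkdir(exist_ok=True)

rows = [
    {"file_name": "cat_001.png", "text": "a cat on a skateboard"},
    {"file_name": "clip_001.mp4", "text": "a short clip of waves at sunset"},
]
pd.DataFrame(rows).to_csv(root / "dataset.csv", index=False)

# The loader also accepts Parquet; for a local folder it falls back to
# dataset.parquet when dataset.csv is missing (requires pyarrow).
pd.DataFrame(rows).to_parquet(root / "dataset.parquet", index=False)
```

With `HF_TOKEN` exported, the app can then be run locally with `python app.py`; the removed version's docstring also recommended `accelerate launch ...` for multi-GPU runs.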
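For reusing the trained adapter afterwards, one plausible approach is to wrap the matching pipeline module with PEFT and point it at the directory written by `save_pretrained` (or the uploaded Hub repo). The sketch below assumes the default text-image path; the base model id, adapter path, and prompt are placeholders:

```python
# load_adapter_for_inference.py -- illustrative only
import torch
from diffusers import DiffusionPipeline
from peft import PeftModel

base = "black-forest-labs/FLUX.1-dev"  # same base model used for training
pipe = DiffusionPipeline.from_pretrained(base, torch_dtype=torch.float16)

# Wrap the trained target module with the adapter saved by save_pretrained().
# Flux-style pipelines expose the denoiser as pipe.transformer rather than pipe.unet.
target = pipe.unet if hasattr(pipe, "unet") else pipe.transformer
wrapped = PeftModel.from_pretrained(target, "./adapter_out")

if hasattr(pipe, "unet"):
    pipe.unet = wrapped
else:
    pipe.transformer = wrapped

pipe.to("cuda" if torch.cuda.is_available() else "cpu")
image = pipe(prompt="a cat on a skateboard", num_inference_steps=8).images[0]
image.save("sample.png")
```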