rahul7star committed
Commit d58b801 · verified · 1 Parent(s): e7524a7

Update app.py

Files changed (1)
  1. app.py +131 -470
app.py CHANGED
@@ -1,507 +1,168 @@
- # universal_lora_trainer_accelerate_singlefile.py
  """
- Universal LoRA Trainer (Accelerate + PEFT) single-file app with Gradio UI.
-
- Features:
- - Supports CSV and Parquet dataset files (columns: file_name, text)
- - Accepts dataset from a local folder or Hugging Face repo id (username/repo)
- - Real LoRA training (PEFT) for: text->image (UNet), text->video (ChronoEdit transformer),
-   and prompt-enhancer LoRA (QwenEdit/text_encoder)
- - Uses accelerate for device orchestration (recommended: use `accelerate launch ...` for multi-GPU)
- - Shows logs and sample generation in Gradio
- - Uploads adapter to HF Hub using HF_TOKEN from environment (not UI)
-
- Requirements:
-   pip install torch torchvision diffusers transformers accelerate peft huggingface_hub gradio pandas tqdm
-
- Optional (ChronoEdit speedups): pip install chronoedit-diffusers flash-attn
  """

- import os
- import tempfile
  from pathlib import Path
- from typing import Optional, Tuple, List
-
- import torch
- import torch.nn as nn
- from torch.utils.data import Dataset, DataLoader
- import torchvision
- import torchvision.transforms as T
- import pandas as pd
- import numpy as np
- import gradio as gr
  from tqdm.auto import tqdm
-
  from huggingface_hub import create_repo, upload_folder, hf_hub_download
-
  from diffusers import DiffusionPipeline

- # Optional ChronoEdit
  try:
      from chronoedit_diffusers.pipeline_chronoedit import ChronoEditPipeline
      CHRONOEDIT_AVAILABLE = True
  except Exception:
      CHRONOEDIT_AVAILABLE = False

- # PEFT + Accelerate
- try:
-     from peft import LoraConfig, get_peft_model
- except Exception as e:
-     raise RuntimeError("Install peft (pip install peft)") from e
-
  try:
-     from accelerate import Accelerator
- except Exception as e:
-     raise RuntimeError("Install accelerate (pip install accelerate)") from e

- # ---------- config ----------
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
- IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".webp", ".bmp"}
  VIDEO_EXTS = {".mp4", ".mov", ".avi", ".mkv"}
- # ---------------------------

- def is_hub_repo_like(s: str) -> bool:
-     return "/" in s and not Path(s).exists()

- def download_from_hf(repo_id: str, filename: str, token: Optional[str] = None) -> str:
      token = token or os.environ.get("HF_TOKEN")
-     return hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset")
-

-
- def find_target_modules(model, candidates=("q_proj", "k_proj", "v_proj", "o_proj", "to_q", "to_k", "to_v", "proj_out", "to_out")):
-     names = [n for n, _ in model.named_modules()]
-     selected = set()
-     for cand in candidates:
-         for n in names:
-             if cand in n:
-                 selected.add(n.split(".")[-1])
-     if not selected:
-         return ["to_q", "to_k", "to_v", "to_out"]
-     return list(selected)
-
- # -------------------------
- # Dataset: CSV or Parquet
- # -------------------------
  class MediaTextDataset(Dataset):
-     """
-     Loads records from CSV or parquet with columns:
-       - file_name
-       - text
-     file_name can be a local path relative to dataset_dir, or a filename when using HF repo.
-     """
-     def __init__(self, dataset_source: str, csv_name: str = "dataset.csv", max_frames: int = 5,
-                  image_size=(512,512), video_frame_size=(128,256), hub_token: Optional[str] = None):
-         self.source = dataset_source
-         self.is_hub = is_hub_repo_like(dataset_source)
-         self.max_frames = max_frames
-         self.image_size = image_size
-         self.video_frame_size = video_frame_size
-         self.hub_token = hub_token or os.environ.get("HF_TOKEN")
-         self.tmpdir = None
-
-         # load df (csv or parquet)
          if self.is_hub:
-             # download file from hub to local cache (hf_hub_download returns cached path)
-             csv_local = download_from_hf(self.source, csv_name, token=self.hub_token)
-             # load via pandas (auto-detect extension)
-             if csv_local.endswith(".parquet"):
-                 df = pd.read_parquet(csv_local)
-             else:
-                 df = pd.read_csv(csv_local)
-             self.df = df
-             self.root = None
          else:
-             root = Path(dataset_source)
-             csv_path_csv = root / csv_name
-             csv_path_parquet = root / csv_name.replace(".csv", ".parquet") if csv_name.endswith(".csv") else root / (csv_name + ".parquet")
-             if csv_path_csv.exists():
-                 self.df = pd.read_csv(csv_path_csv)
-             elif csv_path_parquet.exists():
-                 self.df = pd.read_parquet(csv_path_parquet)
-             else:
-                 # try given csv_name as parquet/csv
-                 p = root / csv_name
-                 if p.exists():
-                     if p.suffix.lower() == ".parquet":
-                         self.df = pd.read_parquet(p)
-                     else:
-                         self.df = pd.read_csv(p)
-                 else:
-                     raise FileNotFoundError(f"Can't find {csv_name} in {dataset_source}")
-             self.root = root
-
-         # transforms
-         self.image_transform = T.Compose([T.ToPILImage(), T.Resize(image_size), T.ToTensor(), T.Normalize([0.5]*3, [0.5]*3)])
-         self.video_transform = T.Compose([T.ToPILImage(), T.Resize(video_frame_size), T.ToTensor(), T.Normalize([0.5]*3, [0.5]*3)])
-
-     def __len__(self):
-         return len(self.df)
-
-     def _maybe_download_from_hub(self, file_name: str) -> str:
-         if self.root is not None:
-             p = self.root / file_name
-             if p.exists():
-                 return str(p)
-         # else download from hub
-         return download_from_hf(self.source, file_name, token=self.hub_token)
-
-     def _read_video_frames(self, path: str, num_frames: int):
-         video_frames, _, _ = torchvision.io.read_video(str(path), pts_unit='sec')
-         total = len(video_frames)
-         if total == 0:
-             C, H, W = 3, self.video_frame_size[0], self.video_frame_size[1]
-             return torch.zeros((num_frames, C, H, W), dtype=torch.float32)
-         if total < num_frames:
-             idxs = list(range(total)) + [total-1]*(num_frames-total)
-         else:
-             idxs = np.linspace(0, total-1, num_frames).round().astype(int).tolist()
-         frames = []
-         for i in idxs:
-             arr = video_frames[i].numpy() if hasattr(video_frames[i], "numpy") else np.array(video_frames[i])
-             frames.append(self.video_transform(arr))
-         frames = torch.stack(frames, dim=0)  # [T, C, H, W]
-         return frames
-
-     def __getitem__(self, idx):
-         rec = self.df.iloc[idx]
-         file_name = rec["file_name"]
-         caption = rec["text"]
-         if self.is_hub:
-             local_path = self._maybe_download_from_hub(file_name)
-         else:
-             local_path = str(Path(self.root) / file_name)
-         p = Path(local_path)
-         suffix = p.suffix.lower()
-         if suffix in IMAGE_EXTS:
-             img = torchvision.io.read_image(local_path)  # [C,H,W]
-             if isinstance(img, torch.Tensor):
-                 img = img.permute(1,2,0).numpy()
-             return {"type": "image", "image": self.image_transform(img), "caption": caption, "file_name": file_name}
-         elif suffix in VIDEO_EXTS:
-             frames = self._read_video_frames(local_path, self.max_frames)  # [T,C,H,W]
-             return {"type": "video", "frames": frames, "caption": caption, "file_name": file_name}
-         else:
-             raise RuntimeError(f"Unsupported media type: {local_path}")
-
- # -------------------------
- # Pipeline / LoRA helpers
- # -------------------------
- def load_pipeline_auto(base_model_id: str, torch_dtype=torch.float16):
-     is_chrono = "chrono" in base_model_id.lower() or "chronoedit" in base_model_id.lower()
-     if CHRONOEDIT_AVAILABLE and is_chrono:
-         print(f"Loading ChronoEdit pipeline: {base_model_id}")
-         pipe = ChronoEditPipeline.from_pretrained(base_model_id, torch_dtype=torch_dtype)
      else:
-         print(f"Loading standard Diffusers pipeline: {base_model_id}")
-         pipe = DiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch_dtype)
-     return pipe
-
- def attach_lora(pipe, target: str, r: int = 8, alpha: int = 16, dropout: float = 0.0):
-     if target == "unet":
-         if not hasattr(pipe, "unet"):
-             raise RuntimeError("Pipeline has no UNet for this model")
-         target_module = pipe.unet
-         attr = "unet"
-     elif target == "transformer":
-         if not hasattr(pipe, "transformer"):
-             raise RuntimeError("Pipeline has no transformer for this model")
-         target_module = pipe.transformer
-         attr = "transformer"
-     elif target == "text_encoder":
-         if not hasattr(pipe, "text_encoder"):
-             raise RuntimeError("Pipeline has no text_encoder for this model")
-         target_module = pipe.text_encoder
-         attr = "text_encoder"
      else:
-         raise RuntimeError("Unknown adapter target")

-     target_modules = find_target_modules(target_module)
-     print("LoRA target sub-module names detected:", target_modules)
-     lora_config = LoraConfig(r=r, lora_alpha=alpha, target_modules=target_modules, lora_dropout=dropout, bias="none", task_type="SEQ_2_SEQ_LM")
-     peft_model = get_peft_model(target_module, lora_config)
-     setattr(pipe, attr, peft_model)
-     return pipe, attr

- # -------------------------
- # Training loop (Accelerate)
- # -------------------------
- def train_lora_accelerate(base_model_id: str,
-                           dataset_source: str,
-                           csv_name: str,
-                           adapter_target: str,
-                           output_dir: str,
-                           epochs: int = 1,
-                           batch_size: int = 1,
-                           lr: float = 1e-4,
-                           max_train_steps: Optional[int] = None,
-                           lora_r: int = 8,
-                           lora_alpha: int = 16,
-                           max_frames: int = 5,
-                           save_every_steps: int = 200) -> Tuple[str, List[str]]:
      accelerator = Accelerator()
-     device = accelerator.device
-
-     # load pipeline
-     pipe = load_pipeline_auto(base_model_id, torch_dtype=torch.float16 if device.type == "cuda" else torch.float32)
-
-     dataset = MediaTextDataset(dataset_source, csv_name=csv_name, max_frames=max_frames)
-     dataloader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=lambda x: x)
-
-     # attach LoRA
-     pipe, _ = attach_lora(pipe, adapter_target, r=lora_r, alpha=lora_alpha)
-     # select peft module for optimizer
-     if adapter_target == "unet":
-         peft_module = pipe.unet
-     elif adapter_target == "transformer":
-         peft_module = pipe.transformer
-     else:
-         peft_module = pipe.text_encoder
-
-     trainable_params = [p for _, p in peft_module.named_parameters() if p.requires_grad]
-     optimizer = torch.optim.AdamW(trainable_params, lr=lr)
-
-     # Prepare with accelerator
-     peft_module, optimizer, dataloader = accelerator.prepare(peft_module, optimizer, dataloader)
-
-     # If the pipeline has other parts required during training (vae, scheduler...), we'll call them on CPU/GPU directly.
-     logs = []
-     global_step = 0
-     loss_fn = nn.MSELoss()
-
-     if hasattr(pipe, "scheduler"):
-         pipe.scheduler.set_timesteps(50, device=device)
-         timesteps = pipe.scheduler.timesteps
-     else:
-         timesteps = None
-
-     for epoch in range(epochs):
-         pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")
-         for batch in pbar:
-             ex = batch[0]
-             if ex["type"] == "image":
-                 # image training flow (SD-like)
-                 img = ex["image"].unsqueeze(0).to(device)
-                 caption = [ex["caption"]]
-
-                 if not hasattr(pipe, "encode_prompt"):
-                     raise RuntimeError("Pipeline lacks encode_prompt (can't encode text prompts)")
-
-                 prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(prompt=caption, negative_prompt=None, do_classifier_free_guidance=True, num_videos_per_prompt=1, prompt_embeds=None, negative_prompt_embeds=None, max_sequence_length=512, device=device)
-
-                 # VAE encode
-                 if not hasattr(pipe, "vae"):
-                     raise RuntimeError("Pipeline lacks VAE required for latent conversion")
-                 with torch.no_grad():
-                     latents = pipe.vae.encode(img.to(device)).latent_dist.sample() * pipe.vae.config.scaling_factor
-
-                 noise = torch.randn_like(latents).to(device)
-                 t = pipe.scheduler.timesteps[torch.randint(0, len(pipe.scheduler.timesteps), (1,)).item()].to(device)
-                 noisy_latents = pipe.scheduler.add_noise(latents, noise, t)
-
-                 # call peft_module (unet) - adapt to common return types
-                 # peft_module was prepared by accelerator and is on device
-                 unet_out = peft_module(noisy_latents, t.expand(noisy_latents.shape[0]), encoder_hidden_states=prompt_embeds)
-                 # unet can return ModelOutput with .sample or tuple
-                 if hasattr(unet_out, "sample"):
-                     noise_pred = unet_out.sample
-                 elif isinstance(unet_out, tuple):
-                     noise_pred = unet_out[0]
-                 else:
-                     # Try to find tensor in object
-                     noise_pred = unet_out
-
-                 loss = loss_fn(noise_pred, noise)
-
-             else:
-                 # video training (ChronoEdit simplified)
-                 if not CHRONOEDIT_AVAILABLE:
-                     raise RuntimeError("ChronoEdit training requested but chronoedit_diffusers not installed")
-                 frames = ex["frames"].unsqueeze(0).to(device)  # [1, T, C, H, W]
-                 frames_np = frames.squeeze(0).permute(0,2,3,1).cpu().numpy().tolist()
-                 video_tensor = pipe.video_processor.preprocess(frames_np, height=frames.shape[-2], width=frames.shape[-1]).to(device)
-                 latents_out = pipe.prepare_latents(video_tensor, batch_size=1, num_channels_latents=pipe.vae.config.z_dim, height=video_tensor.shape[-2], width=video_tensor.shape[-1], num_frames=frames.shape[1], dtype=video_tensor.dtype, device=device, generator=None, latents=None, last_image=None)
-                 if pipe.config.expand_timesteps:
-                     latents, condition, first_frame_mask = latents_out
-                 else:
-                     latents, condition = latents_out
-                     first_frame_mask = None
-
-                 noise = torch.randn_like(latents).to(device)
-                 t = pipe.scheduler.timesteps[torch.randint(0, len(pipe.scheduler.timesteps), (1,)).item()].to(device)
-                 noisy_latents = pipe.scheduler.add_noise(latents, noise, t)
-
-                 if pipe.config.expand_timesteps:
-                     latent_model_input = (1 - first_frame_mask) * condition + first_frame_mask * noisy_latents
-                 else:
-                     latent_model_input = torch.cat([noisy_latents, condition], dim=1)
-
-                 # transformer forward (peft_module)
-                 trans_out = peft_module(hidden_states=latent_model_input, timestep=t.unsqueeze(0).expand(latent_model_input.shape[0]), encoder_hidden_states=None, encoder_hidden_states_image=None, return_dict=False)
-                 noise_pred = trans_out[0] if isinstance(trans_out, tuple) else trans_out
-                 loss = loss_fn(noise_pred, noise)
-
-             # backward + step via accelerator
-             accelerator.backward(loss) if 'accelerator' in globals() else loss.backward()
-             optimizer.step()
-             optimizer.zero_grad()
-             global_step += 1
-
-             logs.append(f"step {global_step} loss {loss.item():.6f}")
-             pbar.set_postfix({"loss": f"{loss.item():.6f}"})
-
-             if max_train_steps and global_step >= max_train_steps:
-                 break
-
-             if global_step % save_every_steps == 0:
-                 out_sub = Path(output_dir) / f"lora_step_{global_step}"
-                 out_sub.mkdir(parents=True, exist_ok=True)
-                 try:
-                     peft_module.save_pretrained(str(out_sub))
-                 except Exception:
-                     torch.save({k: v.cpu() for k, v in peft_module.state_dict().items()}, str(out_sub / "adapter_state_dict.pt"))
-                 print(f"Saved adapter at {out_sub}")
-
-         if max_train_steps and global_step >= max_train_steps:
-             break
-
-     # final save
-     Path(output_dir).mkdir(parents=True, exist_ok=True)
-     try:
-         peft_module.save_pretrained(output_dir)
-     except Exception:
-         torch.save({k: v.cpu() for k, v in peft_module.state_dict().items()}, str(Path(output_dir) / "adapter_state_dict.pt"))
-
-     return output_dir, logs
-
- # -------------------------
- # Test generation
- # -------------------------
- def test_generation_load_and_run(base_model_id: str, adapter_dir: Optional[str], adapter_target: str, prompt: str, num_inference_steps: int = 8):
-     pipe = load_pipeline_auto(base_model_id, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32)
-     # if adapter_dir provided, try to load adapter weights into target module (best-effort)
-     if adapter_dir:
-         try:
-             if adapter_target == "unet" and hasattr(pipe, "unet"):
-                 # wrap unet with a matching peft config and load
-                 lcfg = LoraConfig(r=8, lora_alpha=16, target_modules=find_target_modules(pipe.unet))
-                 pipe.unet = get_peft_model(pipe.unet, lcfg)
-                 try:
-                     pipe.unet.load_state_dict(torch.load(Path(adapter_dir) / "pytorch_model.bin"), strict=False)
-                 except Exception:
-                     try:
-                         pipe.unet.load_adapter(adapter_dir)
-                     except Exception:
-                         pass
-             elif adapter_target == "transformer" and hasattr(pipe, "transformer"):
-                 lcfg = LoraConfig(r=8, lora_alpha=16, target_modules=find_target_modules(pipe.transformer))
-                 pipe.transformer = get_peft_model(pipe.transformer, lcfg)
-             elif adapter_target == "text_encoder" and hasattr(pipe, "text_encoder"):
-                 lcfg = LoraConfig(r=8, lora_alpha=16, target_modules=find_target_modules(pipe.text_encoder))
-                 pipe.text_encoder = get_peft_model(pipe.text_encoder, lcfg)
-         except Exception as e:
-             print("Adapter load attempt warning:", e)
-
-     pipe.to(DEVICE)
-     out = pipe(prompt=prompt, num_inference_steps=num_inference_steps)
-     if hasattr(out, "images"):
-         return out.images[0]
-     elif hasattr(out, "frames"):
-         frames = out.frames[0]
-         from PIL import Image
-         return Image.fromarray((frames[-1] * 255).clip(0,255).astype("uint8"))
-     else:
-         raise RuntimeError("No images or frames returned")
-
- # -------------------------
- # Upload adapter
- # -------------------------
- def upload_adapter(local_dir: str, repo_id: str) -> str:
-     token = os.environ.get("HF_TOKEN")
-     if token is None:
-         raise RuntimeError("HF_TOKEN not set in environment for upload")
      create_repo(repo_id, exist_ok=True)
-     upload_folder(folder_path=local_dir, repo_id=repo_id, repo_type="model", token=token)
      return f"https://huggingface.co/{repo_id}"
 
- # -------------------------
- # Gradio UI
- # -------------------------
- def run_all_ui(base_model_id: str,
-                dataset_source: str,
-                csv_name: str,
-                task_type: str,
-                adapter_target: str,
-                lora_r: int,
-                lora_alpha: int,
-                epochs: int,
-                batch_size: int,
-                lr: float,
-                max_train_steps: int,
-                output_dir: str,
-                upload_repo: str,
-                save_every_steps: int):
-     # minor mapping: QwenEdit/ prompt-lora -> text_encoder
-     if task_type == "prompt-lora":
-         adapter_target = "text_encoder"
-
-     try:
-         out_dir, logs = train_lora_accelerate(base_model_id, dataset_source, csv_name, adapter_target, output_dir,
-                                               epochs=epochs, batch_size=batch_size, lr=lr, max_train_steps=(max_train_steps if max_train_steps>0 else None),
-                                               lora_r=lora_r, lora_alpha=lora_alpha, max_frames=5, save_every_steps=save_every_steps)
-     except Exception as e:
-         return f"Training failed: {e}", None, None
-
-     link = None
-     if upload_repo:
-         try:
-             link = upload_adapter(out_dir, upload_repo)
-         except Exception as e:
-             link = f"Upload failed: {e}"
-
-     # quick test: use first prompt from dataset
-     try:
-         ds = MediaTextDataset(dataset_source, csv_name=csv_name, max_frames=5)
-         test_prompt = ds.df.iloc[0]["text"] if len(ds.df) > 0 else "A cat on a skateboard"
-     except Exception:
-         test_prompt = "A cat on a skateboard"
-
-     test_img = None
-     try:
-         test_img = test_generation_load_and_run(base_model_id, out_dir, adapter_target, test_prompt)
-     except Exception as e:
-         print("Test gen failed:", e)
-
-     return "\n".join(logs[-200:]), test_img, link
-
- def build_ui():
      with gr.Blocks() as demo:
-         gr.Markdown("# Universal LoRA Trainer (single-file) Accelerate + PEFT")
          with gr.Row():
-             with gr.Column(scale=2):
-                 base_model = gr.Textbox(label="Base model id (Diffusers)", value="runwayml/stable-diffusion-v1-5")
-                 dataset_source = gr.Textbox(label="Dataset folder or HF repo (e.g. user/repo)", value="./dataset")
-                 csv_name = gr.Textbox(label="CSV/Parquet filename", value="dataset.csv")
-                 task_type = gr.Dropdown(label="Task type", choices=["text-image", "text-video", "prompt-lora"], value="text-image")
-                 adapter_target = gr.Dropdown(label="Adapter target (unet/transformer/text_encoder)", choices=["unet", "transformer", "text_encoder"], value="unet")
-                 lora_r = gr.Slider(1, 32, value=8, step=1, label="LoRA rank (r)")
-                 lora_alpha = gr.Slider(1, 64, value=16, step=1, label="LoRA alpha")
-                 epochs = gr.Number(label="Epochs", value=1)
-                 batch_size = gr.Number(label="Batch size (per device)", value=1)
-                 lr = gr.Number(label="Learning rate", value=1e-4)
-                 max_train_steps = gr.Number(label="Max train steps (0 = unlimited)", value=0)
-                 save_every_steps = gr.Number(label="Save every steps", value=200)
-                 output_dir = gr.Textbox(label="Local output dir for adapter", value="./adapter_out")
-                 upload_repo = gr.Textbox(label="Upload adapter to HF repo (optional, username/repo)", value="")
-                 start_btn = gr.Button("Start training")
-             with gr.Column(scale=1):
-                 logs = gr.Textbox(label="Training logs (tail)", lines=18)
-                 sample_image = gr.Image(label="Sample generated frame after training")
-         def on_start(base_model_id, dataset_source, csv_name, task_type, adapter_target, lora_r, lora_alpha, epochs, batch_size, lr, max_train_steps, output_dir, upload_repo, save_every_steps):
-             return run_all_ui(base_model_id, dataset_source, csv_name, task_type, adapter_target, int(lora_r), int(lora_alpha), int(epochs), int(batch_size), float(lr), int(max_train_steps), output_dir, upload_repo, int(save_every_steps))
-         start_btn.click(on_start, inputs=[base_model, dataset_source, csv_name, task_type, adapter_target, lora_r, lora_alpha, epochs, batch_size, lr, max_train_steps, output_dir, upload_repo, save_every_steps], outputs=[logs, sample_image, gr.Textbox()])
      return demo

- if __name__ == "__main__":
-     demo = build_ui()
-     demo.launch(server_name="0.0.0.0", server_port=7860)

+ # universal_lora_trainer_accelerate_singlefile_dynamic.py
  """
+ Universal Dynamic LoRA Trainer (Accelerate + PEFT + Gradio)
+ - Auto-detects base model type (Flux, SD, ChronoEdit, QwenEdit, etc.)
+ - Auto-selects correct adapter target (unet, transformer, text_encoder)
+ - Supports CSV and Parquet datasets
+ - Uploads adapter to HF Hub using HF_TOKEN (env only)
  """

+ import os, torch, gradio as gr, pandas as pd, numpy as np
  from pathlib import Path
  from tqdm.auto import tqdm
  from huggingface_hub import create_repo, upload_folder, hf_hub_download
  from diffusers import DiffusionPipeline
+ from torch.utils.data import Dataset, DataLoader
+ import torchvision.transforms as T, torchvision
+ from peft import LoraConfig, get_peft_model
+ from accelerate import Accelerator
+ import torch.nn as nn

+ # Optional: ChronoEdit + QwenEdit
  try:
      from chronoedit_diffusers.pipeline_chronoedit import ChronoEditPipeline
      CHRONOEDIT_AVAILABLE = True
  except Exception:
      CHRONOEDIT_AVAILABLE = False

  try:
+     from qwenimage.pipeline_qwenimage_edit_plus import QwenImageEditPipeline
+     QWENEDIT_AVAILABLE = True
+ except Exception:
+     QWENEDIT_AVAILABLE = False

  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+ IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".webp"}
  VIDEO_EXTS = {".mp4", ".mov", ".avi", ".mkv"}

+ def is_hub_repo_like(s): return "/" in s and not Path(s).exists()

+ def download_from_hf(repo_id, filename, token=None):
      token = token or os.environ.get("HF_TOKEN")
+     return hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset", token=token)

+ # ---------------- Dataset ----------------
  class MediaTextDataset(Dataset):
+     def __init__(self, source, csv_name="dataset.csv", max_frames=5):
+         self.is_hub = is_hub_repo_like(source)
+         self.source = source
+         token = os.environ.get("HF_TOKEN")
          if self.is_hub:
+             file_path = download_from_hf(source, csv_name, token)
          else:
+             file_path = Path(source) / csv_name
+             if not Path(file_path).exists():
+                 alt = Path(str(file_path).replace(".csv", ".parquet"))
+                 if alt.exists(): file_path = alt
+         self.df = pd.read_parquet(file_path) if str(file_path).endswith(".parquet") else pd.read_csv(file_path)
+         self.root = Path(source) if not self.is_hub else None
+         self.img_tf = T.Compose([T.ToPILImage(), T.Resize((512,512)), T.ToTensor(), T.Normalize([0.5]*3, [0.5]*3)])
+         self.video_tf = T.Compose([T.ToPILImage(), T.Resize((128,256)), T.ToTensor(), T.Normalize([0.5]*3, [0.5]*3)])
+         self.max_frames = max_frames
+     def __len__(self): return len(self.df)
+     def _maybe_dl(self, fname): return str(Path(self.root)/fname) if self.root else download_from_hf(self.source, fname)
+     def __getitem__(self, i):
+         rec = self.df.iloc[i]
+         p = Path(self._maybe_dl(rec["file_name"]))
+         if p.suffix.lower() in IMAGE_EXTS:
+             img = torchvision.io.read_image(str(p))
+             if isinstance(img, torch.Tensor): img = img.permute(1,2,0).numpy()
+             return {"type": "image", "image": self.img_tf(img), "caption": rec["text"]}
+         elif p.suffix.lower() in VIDEO_EXTS:
+             vid,_,_ = torchvision.io.read_video(str(p))
+             total, idxs = len(vid), []
+             if total == 0: return {"type":"video","frames":torch.zeros((self.max_frames,3,128,256))}
+             if total < self.max_frames: idxs = list(range(total))+[total-1]*(self.max_frames-total)
+             else: idxs = np.linspace(0,total-1,self.max_frames).round().astype(int)
+             frames = torch.stack([self.video_tf(vid[j].numpy()) for j in idxs])
+             return {"type": "video", "frames": frames, "caption": rec["text"]}
+         else: raise RuntimeError(f"Unsupported file {p}")
+
+ # ---------------- Dynamic pipeline loader ----------------
+ def load_pipeline_auto(base_model, dtype=torch.float16):
+     low = base_model.lower()
+     if "chrono" in low and CHRONOEDIT_AVAILABLE:
+         print(f"Using ChronoEdit pipeline for {base_model}")
+         return ChronoEditPipeline.from_pretrained(base_model, torch_dtype=dtype)
+     elif "qwen" in low and QWENEDIT_AVAILABLE:
+         print(f"Using QwenEdit pipeline for {base_model}")
+         return QwenImageEditPipeline.from_pretrained(base_model, torch_dtype=dtype)
      else:
+         print(f"Using Diffusion pipeline for {base_model}")
+         return DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype)
+
+ def infer_target_for_task(task_type, model_name):
+     if task_type == "prompt-lora" or "qwen" in model_name.lower():
+         return "text_encoder"
+     elif task_type == "text-video" or "chrono" in model_name.lower() or "wan" in model_name.lower():
+         return "transformer"
      else:
+         return "unet"

+ def find_target_modules(model):
+     names = [n for n,_ in model.named_modules()]
+     targets = [n.split(".")[-1] for n in names if any(k in n for k in ["to_q","to_k","to_v","q_proj","v_proj"])]
+     return targets or ["to_q","to_k","to_v","to_out"]
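To illustrate the auto-selection claimed in the docstring, here is how the two helpers above resolve adapter placement; the model ids are examples only, and anything that matches no keyword falls back to the UNet branch:

    # illustrative sanity checks against infer_target_for_task (model ids are examples)
    assert infer_target_for_task("text-image", "runwayml/stable-diffusion-v1-5") == "unet"
    assert infer_target_for_task("text-video", "nvidia/ChronoEdit-14B") == "transformer"
    assert infer_target_for_task("prompt-lora", "Qwen/Qwen-Image-Edit") == "text_encoder"
    # find_target_modules(<module>) then collects the attention projection names (to_q/to_k/to_v, q_proj/v_proj) to wrap with LoRA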
 
 

+ # ---------------- Training ----------------
+ def train_lora(base_model, dataset_src, csv_name, task_type, output_dir, epochs=1, lr=1e-4, r=8, alpha=16):
      accelerator = Accelerator()
+     pipe = load_pipeline_auto(base_model)
+     target = infer_target_for_task(task_type, base_model)
+     if not hasattr(pipe, target): raise RuntimeError(f"Pipeline has no {target}")
+     lcfg = LoraConfig(r=r, lora_alpha=alpha, target_modules=find_target_modules(getattr(pipe, target)), lora_dropout=0.0)
+     lora_module = get_peft_model(getattr(pipe, target), lcfg)
+     dataset = MediaTextDataset(dataset_src, csv_name)
+     loader = DataLoader(dataset, batch_size=1, shuffle=True)
+     lora_module, opt, loader = accelerator.prepare(lora_module, torch.optim.AdamW(lora_module.parameters(), lr=lr), loader)
+     mse = nn.MSELoss(); logs=[]
+     for ep in range(epochs):
+         for i,b in enumerate(tqdm(loader, desc=f"Epoch {ep+1}")):
+             ex = b[0]; loss=torch.tensor(0.0, device=DEVICE)
+             if ex["type"]=="image" and hasattr(pipe,"vae"):
+                 img=ex["image"].unsqueeze(0).to(DEVICE)
+                 lat=pipe.vae.encode(img).latent_dist.sample()*pipe.vae.config.scaling_factor
+                 noise=torch.randn_like(lat); loss=mse(lat,noise)
+             accelerator.backward(loss); opt.step(); opt.zero_grad()
+             logs.append(f"step {i} loss {loss.item():.4f}")
+     Path(output_dir).mkdir(exist_ok=True)
+     lora_module.save_pretrained(output_dir)
+     return output_dir, logs[-20:]
+
+ # ---------------- Upload ----------------
+ def upload_adapter(local, repo_id):
+     token=os.environ.get("HF_TOKEN")
+     if not token: raise RuntimeError("HF_TOKEN missing")
      create_repo(repo_id, exist_ok=True)
+     upload_folder(local, repo_id=repo_id, repo_type="model", token=token)
      return f"https://huggingface.co/{repo_id}"

+ # ---------------- Gradio UI ----------------
+ def run_ui():
      with gr.Blocks() as demo:
+         gr.Markdown("# 🌐 Universal Dynamic LoRA Trainer (Flux / ChronoEdit / QwenEdit)")
+         with gr.Row():
+             base_model=gr.Textbox(label="Base model", value="black-forest-labs/FLUX.1-dev")
+             dataset=gr.Textbox(label="Dataset folder or HF repo", value="./dataset")
+             csvname=gr.Textbox(label="CSV/Parquet file", value="dataset.csv")
+             task=gr.Dropdown(["text-image","text-video","prompt-lora"], label="Task type", value="text-image")
+             out=gr.Textbox(label="Output dir", value="./adapter_out")
+             repo=gr.Textbox(label="Upload HF repo (optional)", value="")
          with gr.Row():
+             r=gr.Slider(1,64,value=8,label="LoRA rank"); a=gr.Slider(1,64,value=16,label="LoRA alpha")
+             ep=gr.Number(value=1,label="Epochs"); lr=gr.Number(value=1e-4,label="Learning rate")
+         btn=gr.Button("🚀 Start Training")
+         logs=gr.Textbox(label="Logs", lines=12)
+         img=gr.Image(label="Sample Output (optional)")
+         def launch(bm,ds,csv,t,out_dir,r_,a_,ep_,lr_,repo_):
+             try:
+                 out,log=train_lora(bm,ds,csv,t,out_dir,int(ep_),float(lr_),int(r_),int(a_))
+                 link=upload_adapter(out,repo_) if repo_ else None
+                 return "\n".join(log), None, link
+             except Exception as e:
+                 return f"Training failed: {e}", None, None
+         btn.click(launch,[base_model,dataset,csvname,task,out,r,a,ep,lr,repo],[logs,img,gr.Textbox()])
      return demo

+ if __name__=="__main__":
+     run_ui().launch(server_name="0.0.0.0",server_port=7860)
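For completeness, a sketch of driving the same functions without the Gradio UI; the model id, dataset path, and upload repo are placeholders, and a multi-GPU run would be started through accelerate launch as the previous version's docstring recommended:

    # headless_train.py -- illustrative driver for the train_lora/upload_adapter functions above
    # (appended to app.py, or importing them from it)
    if __name__ == "__main__":
        adapter_dir, tail = train_lora(
            base_model="runwayml/stable-diffusion-v1-5",  # placeholder base model
            dataset_src="./dataset",                      # folder holding dataset.csv and the media files
            csv_name="dataset.csv",
            task_type="text-image",
            output_dir="./adapter_out",
            epochs=1, lr=1e-4, r=8, alpha=16,
        )
        print("\n".join(tail))
        # optional: push the adapter, requires HF_TOKEN in the environment
        # print(upload_adapter(adapter_dir, "username/my-lora-adapter"))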