rahul7star committed ee3f61d (verified) · 1 parent: d58b801

Create app_quant.py

Files changed (1): app_quant.py (+702, -0)
app_quant.py ADDED (702 lines):
# universal_lora_trainer_quant_dynamic.py
"""
Universal Dynamic LoRA Trainer (Accelerate + PEFT) with optional QLoRA 4-bit support.

- Supports CSV and Parquet dataset files (columns: file_name, text)
- Accepts a dataset from a local folder or a Hugging Face dataset repo id (username/repo)
- Real LoRA training (PEFT) for:
    * text->image (UNet)
    * text->video (ChronoEdit transformer)
    * prompt-enhancer (text_encoder / QwenEdit)
- Optional:
    * 4-bit quantization (bitsandbytes / QLoRA)
    * xFormers / FlashAttention
    * AdaLoRA (if available)
- Uses HF_TOKEN from the environment for upload
- Use `accelerate launch` for multi-GPU / optimized runs
"""
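
# Example usage (a sketch based on the docstring above, not a tested recipe):
# the folder layout, file names and captions below are assumptions for illustration.
#
#   dataset/
#     dataset.csv          # columns: file_name,text
#     images/cat01.png     # referenced by file_name
#     clips/wave01.mp4
#
#   dataset.csv:
#     file_name,text
#     images/cat01.png,"a cat riding a skateboard"
#     clips/wave01.mp4,"a slow ocean wave at sunset"
#
#   # single process
#   python app_quant.py
#   # multi-GPU / mixed precision (assumes `accelerate config` has been run)
#   accelerate launch app_quant.py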

import os
import math
import tempfile
from pathlib import Path
from typing import Optional, Tuple, List

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.transforms as T
import pandas as pd
import numpy as np
import gradio as gr
from tqdm.auto import tqdm

from huggingface_hub import create_repo, upload_folder, hf_hub_download, list_repo_files

from diffusers import DiffusionPipeline

# optional pip installs - guard imports
try:
    from chronoedit_diffusers.pipeline_chronoedit import ChronoEditPipeline
    CHRONOEDIT_AVAILABLE = True
except Exception:
    CHRONOEDIT_AVAILABLE = False

# Qwen image edit optional
try:
    from qwenimage.pipeline_qwenimage_edit_plus import QwenImageEditPipeline  # optional name
    QWENEDIT_AVAILABLE = True
except Exception:
    QWENEDIT_AVAILABLE = False

# BitsAndBytes (quantization)
try:
    from transformers import BitsAndBytesConfig
    BNB_AVAILABLE = True
except Exception:
    BitsAndBytesConfig = None
    BNB_AVAILABLE = False

# xFormers
try:
    import xformers  # noqa
    XFORMERS_AVAILABLE = True
except Exception:
    XFORMERS_AVAILABLE = False

# PEFT / AdaLoRA
try:
    from peft import LoraConfig, get_peft_model
    try:
        from peft import AdaLoraConfig  # optional
        ADALORA_AVAILABLE = True
    except Exception:
        AdaLoraConfig = None
        ADALORA_AVAILABLE = False
except Exception as e:
    raise RuntimeError("Install peft: pip install peft") from e

# Accelerate
try:
    from accelerate import Accelerator
except Exception as e:
    raise RuntimeError("Install accelerate: pip install accelerate") from e

# ------------------------
# Config
# ------------------------
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".webp", ".bmp"}
VIDEO_EXTS = {".mp4", ".mov", ".avi", ".mkv"}

# ------------------------
# Utilities
# ------------------------
def is_hub_repo_like(s: str) -> bool:
    return "/" in s and not Path(s).exists()

def download_from_hf(repo_id: str, filename: str, token: Optional[str] = None, repo_type: str = "dataset") -> str:
    token = token or os.environ.get("HF_TOKEN")
    return hf_hub_download(repo_id=repo_id, filename=filename, token=token, repo_type=repo_type)

def try_list_repo_files(repo_id: str, repo_type: str = "dataset", token: Optional[str] = None):
    token = token or os.environ.get("HF_TOKEN")
    try:
        return list_repo_files(repo_id, token=token, repo_type=repo_type)
    except Exception:
        return []

def find_target_modules(model, candidates=("q_proj", "k_proj", "v_proj", "o_proj", "to_q", "to_k", "to_v", "proj_out", "to_out")):
    # Match candidates against leaf module names only, so nested paths such as
    # "attn.to_out.0" do not contribute bare indices like "0" as target modules.
    selected = set()
    for name, _ in model.named_modules():
        leaf = name.split(".")[-1]
        if any(cand in leaf for cand in candidates):
            selected.add(leaf)
    if not selected:
        return ["to_q", "to_k", "to_v", "to_out"]
    return list(selected)
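
# A minimal, self-contained sketch of what find_target_modules picks up.
# ToyAttention is an assumption for illustration, not part of any real pipeline.
def _demo_find_target_modules():
    class ToyAttention(nn.Module):
        def __init__(self):
            super().__init__()
            self.to_q = nn.Linear(16, 16)
            self.to_k = nn.Linear(16, 16)
            self.to_v = nn.Linear(16, 16)
            self.proj_out = nn.Linear(16, 16)

    # Expected to contain "to_q", "to_k", "to_v" and "proj_out"
    return sorted(find_target_modules(ToyAttention()))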

# ------------------------
# Dataset class (CSV/Parquet)
# ------------------------
class MediaTextDataset(Dataset):
    """
    Loads records from CSV or Parquet with columns:
      - file_name (relative path in folder or filename inside HF dataset repo)
      - text
    """
    def __init__(self, dataset_source: str, csv_name: str = "dataset.csv", max_frames: int = 5,
                 image_size=(512, 512), video_frame_size=(128, 256), hub_token: Optional[str] = None):
        self.source = dataset_source
        self.is_hub = is_hub_repo_like(dataset_source)
        self.max_frames = max_frames
        self.image_size = image_size
        self.video_frame_size = video_frame_size
        self.hub_token = hub_token or os.environ.get("HF_TOKEN")

        # load dataframe (CSV or Parquet)
        if self.is_hub:
            # try CSV then Parquet; specify repo_type="dataset"
            searched = try_list_repo_files(self.source, repo_type="dataset", token=self.hub_token)
            # prefer exact csv_name
            try:
                csv_local = download_from_hf(self.source, csv_name, token=self.hub_token, repo_type="dataset")
            except Exception:
                # try .parquet variant
                alt = csv_name.replace(".csv", ".parquet") if csv_name.endswith(".csv") else csv_name + ".parquet"
                csv_local = download_from_hf(self.source, alt, token=self.hub_token, repo_type="dataset")
            if str(csv_local).endswith(".parquet"):
                df = pd.read_parquet(csv_local)
            else:
                df = pd.read_csv(csv_local)
            self.df = df
            self.root = None
        else:
            root = Path(dataset_source)
            csv_path = root / csv_name
            parquet_path = root / csv_name.replace(".csv", ".parquet") if csv_name.endswith(".csv") else root / (csv_name + ".parquet")
            if csv_path.exists():
                self.df = pd.read_csv(csv_path)
            elif parquet_path.exists():
                self.df = pd.read_parquet(parquet_path)
            else:
                p = root / csv_name
                if p.exists():
                    if p.suffix.lower() == ".parquet":
                        self.df = pd.read_parquet(p)
                    else:
                        self.df = pd.read_csv(p)
                else:
                    raise FileNotFoundError(f"Can't find {csv_name} in {dataset_source}")
            self.root = root

        # transforms
        self.image_transform = T.Compose([T.ToPILImage(), T.Resize(image_size), T.ToTensor(), T.Normalize([0.5]*3, [0.5]*3)])
        self.video_transform = T.Compose([T.ToPILImage(), T.Resize(video_frame_size), T.ToTensor(), T.Normalize([0.5]*3, [0.5]*3)])

    def __len__(self):
        return len(self.df)

    def _maybe_download_from_hub(self, file_name: str) -> str:
        if self.root is not None:
            p = self.root / file_name
            if p.exists():
                return str(p)
        # else download from dataset repo
        return download_from_hf(self.source, file_name, token=self.hub_token, repo_type="dataset")

    def _read_video_frames(self, path: str, num_frames: int):
        video_frames, _, _ = torchvision.io.read_video(str(path), pts_unit='sec')
        total = len(video_frames)
        if total == 0:
            C, H, W = 3, self.video_frame_size[0], self.video_frame_size[1]
            return torch.zeros((num_frames, C, H, W), dtype=torch.float32)
        if total < num_frames:
            idxs = list(range(total)) + [total - 1] * (num_frames - total)
        else:
            idxs = np.linspace(0, total - 1, num_frames).round().astype(int).tolist()
        frames = []
        for i in idxs:
            arr = video_frames[i].numpy() if hasattr(video_frames[i], "numpy") else np.array(video_frames[i])
            frames.append(self.video_transform(arr))
        frames = torch.stack(frames, dim=0)
        return frames

    def __getitem__(self, idx):
        rec = self.df.iloc[idx]
        file_name = rec["file_name"]
        caption = rec["text"]
        if self.is_hub:
            local_path = self._maybe_download_from_hub(file_name)
        else:
            local_path = str(Path(self.root) / file_name)
        p = Path(local_path)
        suffix = p.suffix.lower()
        if suffix in IMAGE_EXTS:
            img = torchvision.io.read_image(local_path)  # [C,H,W]
            if isinstance(img, torch.Tensor):
                img = img.permute(1, 2, 0).numpy()
            return {'type': 'image', 'image': self.image_transform(img), 'caption': caption, 'file_name': file_name}
        elif suffix in VIDEO_EXTS:
            frames = self._read_video_frames(local_path, self.max_frames)  # [T,C,H,W]
            return {'type': 'video', 'frames': frames, 'caption': caption, 'file_name': file_name}
        else:
            raise RuntimeError(f"Unsupported media type: {local_path}")
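
# A small usage sketch for MediaTextDataset (defined but never called). The local
# "./dataset" folder and its dataset.csv are assumptions for illustration.
def _demo_media_text_dataset(dataset_source: str = "./dataset"):
    ds = MediaTextDataset(dataset_source, csv_name="dataset.csv", max_frames=5)
    # batch_size=1 with an identity collate_fn mirrors what the trainer below does
    loader = DataLoader(ds, batch_size=1, shuffle=False, collate_fn=lambda x: x)
    first = next(iter(loader))[0]
    if first["type"] == "image":
        print("image sample:", first["image"].shape, "|", first["caption"])
    else:
        print("video sample:", first["frames"].shape, "|", first["caption"])
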
# ------------------------
# Pipeline loader with optional quantization
# ------------------------
def load_pipeline_auto(base_model_id: str, use_4bit: bool = False, bnb_config: Optional[object] = None, torch_dtype=torch.float16):
    low = base_model_id.lower()
    is_chrono = "chrono" in low or "wan" in low or "video" in low
    is_qwen = "qwen" in low or "qwenimage" in low
    # choose pipeline
    if is_chrono and CHRONOEDIT_AVAILABLE:
        print("Loading ChronoEdit pipeline")
        # ChronoEdit may not accept a quant config; try with the safer call
        if use_4bit and bnb_config is not None:
            pipe = ChronoEditPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16)  # quantized loading of ChronoEdit not widely supported
        else:
            pipe = ChronoEditPipeline.from_pretrained(base_model_id, torch_dtype=torch_dtype)
    elif is_qwen and QWENEDIT_AVAILABLE:
        print("Loading QWEN image-edit pipeline")
        pipe = QwenImageEditPipeline.from_pretrained(base_model_id, torch_dtype=torch_dtype)
    else:
        # fallback to DiffusionPipeline - supports quantization_config for diffusers+transformers
        print("Loading standard DiffusionPipeline:", base_model_id, "use_4bit=", use_4bit)
        if use_4bit and BNB_AVAILABLE and bnb_config is not None:
            pipe = DiffusionPipeline.from_pretrained(base_model_id, quantization_config=bnb_config, torch_dtype=torch.float16)
        else:
            pipe = DiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch_dtype)
    return pipe

# ------------------------
# Auto infer adapter target
# ------------------------
def infer_target_for_task(task_type: str, model_name: str) -> str:
    low = model_name.lower()
    if task_type == "prompt-lora" or "qwen" in low or "qwenedit" in low:
        return "text_encoder"
    if task_type == "text-video" or "chrono" in low or "wan" in low:
        return "transformer"
    # default
    return "unet"
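
# A hedged sketch of 4-bit (QLoRA-style) loading with the helpers above. The
# NF4 + double-quant settings mirror the config built inside train_lora_accelerate;
# whether a given pipeline actually accepts quantization_config depends on the model.
def _demo_load_4bit(base_model_id: str = "runwayml/stable-diffusion-v1-5"):
    bnb_conf = None
    if BNB_AVAILABLE:
        bnb_conf = BitsAndBytesConfig(
            load_in_4bit=True,                     # store weights in 4-bit NF4
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,        # quantize the quantization constants too
            bnb_4bit_compute_dtype=torch.float16,  # run matmuls in fp16
        )
    pipe = load_pipeline_auto(base_model_id, use_4bit=bnb_conf is not None, bnb_config=bnb_conf)
    target = infer_target_for_task("text-image", base_model_id)  # -> "unet" here
    return pipe, target
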
# ------------------------
# LoRA attach (supports AdaLoRA if available)
# ------------------------
def attach_lora(pipe, adapter_target: str, r: int = 8, alpha: int = 16, dropout: float = 0.0, use_adalora: bool = False):
    if adapter_target == "unet":
        if not hasattr(pipe, "unet"):
            raise RuntimeError("Pipeline has no UNet to attach LoRA")
        target_module = pipe.unet
        attr = "unet"
    elif adapter_target == "transformer":
        if not hasattr(pipe, "transformer"):
            raise RuntimeError("Pipeline has no transformer to attach LoRA")
        target_module = pipe.transformer
        attr = "transformer"
    elif adapter_target == "text_encoder":
        if not hasattr(pipe, "text_encoder"):
            raise RuntimeError("Pipeline has no text_encoder for prompt-LoRA")
        target_module = pipe.text_encoder
        attr = "text_encoder"
    else:
        raise RuntimeError("Unknown adapter_target")

    target_modules = find_target_modules(target_module)
    print("Detected target_modules for LoRA:", target_modules)

    if use_adalora and ADALORA_AVAILABLE:
        lora_config = AdaLoraConfig(
            r=r,
            lora_alpha=alpha,
            target_modules=target_modules,
            init_r=4,
            lora_dropout=dropout,
        )
    else:
        # No task_type here: the wrapped modules are UNets / DiT transformers /
        # text encoders, not seq2seq LMs, so the generic PeftModel wrapper
        # (which simply forwards to the base module) is what we want.
        lora_config = LoraConfig(
            r=r,
            lora_alpha=alpha,
            target_modules=target_modules,
            lora_dropout=dropout,
            bias="none",
        )

    peft_model = get_peft_model(target_module, lora_config)
    setattr(pipe, attr, peft_model)
    return pipe, attr
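
# A short sketch of attach_lora in isolation (assumes a pipeline that has a UNet).
# print_trainable_parameters comes from PEFT and confirms only LoRA weights train.
def _demo_attach_lora(base_model_id: str = "runwayml/stable-diffusion-v1-5"):
    pipe = load_pipeline_auto(base_model_id, torch_dtype=torch.float32)
    pipe, attr = attach_lora(pipe, "unet", r=8, alpha=16)
    peft_module = getattr(pipe, attr)
    peft_module.print_trainable_parameters()  # e.g. "trainable params: ... || all params: ..."
    return pipe, attr
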
# ------------------------
# Training loop (Accelerate-aware)
# ------------------------
def train_lora_accelerate(base_model_id: str,
                          dataset_source: str,
                          csv_name: str,
                          task_type: str,
                          adapter_target_override: Optional[str],
                          output_dir: str,
                          epochs: int = 1,
                          batch_size: int = 1,
                          lr: float = 1e-4,
                          max_train_steps: Optional[int] = None,
                          lora_r: int = 8,
                          lora_alpha: int = 16,
                          use_4bit: bool = False,
                          enable_xformers: bool = False,
                          use_adalora: bool = False,
                          gradient_accumulation_steps: int = 1,
                          mixed_precision: Optional[str] = None,
                          save_every_steps: int = 200,
                          max_frames: int = 5):

    # Setup Accelerator (gradient accumulation is applied via accelerator.accumulate below)
    accelerator = Accelerator(
        mixed_precision=mixed_precision or ("fp16" if torch.cuda.is_available() else "no"),
        gradient_accumulation_steps=max(1, int(gradient_accumulation_steps)),
    )
    device = accelerator.device

    # prepare bitsandbytes config if requested
    bnb_conf = None
    if use_4bit and BNB_AVAILABLE:
        bnb_conf = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )

    # Load pipeline (supports quant for standard diffusers)
    pipe = load_pipeline_auto(base_model_id, use_4bit=use_4bit, bnb_config=bnb_conf, torch_dtype=torch.float16 if device.type == "cuda" else torch.float32)

    # move non-quantized pipelines to the training device; 4-bit weights are placed
    # by bitsandbytes at load time and should not be moved again
    if not use_4bit:
        try:
            pipe.to(device)
        except Exception as e:
            print("Could not move pipeline to device:", e)

    # optionally enable memory efficient attention
    if enable_xformers:
        try:
            if hasattr(pipe, "enable_xformers_memory_efficient_attention"):
                pipe.enable_xformers_memory_efficient_attention()
            elif hasattr(pipe, "enable_attention_slicing"):
                pipe.enable_attention_slicing()
            print("xFormers / memory efficient attention enabled.")
        except Exception as e:
            print("Could not enable xformers:", e)

    # infer adapter target automatically if not overridden
    adapter_target = adapter_target_override if adapter_target_override else infer_target_for_task(task_type, base_model_id)
    print("Adapter target set to:", adapter_target)

    # attach LoRA
    pipe, attr = attach_lora(pipe, adapter_target, r=lora_r, alpha=lora_alpha, dropout=0.0, use_adalora=use_adalora)
    # pick the peft module for optimization
    peft_module = getattr(pipe, attr)

    # dataset + dataloader (we use batch_size=1 with an identity collate)
    dataset = MediaTextDataset(dataset_source, csv_name=csv_name, max_frames=max_frames)
    dataloader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=lambda x: x)

    # optimizer
    trainable_params = [p for n, p in peft_module.named_parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(trainable_params, lr=lr)

    # prepare objects with accelerator
    peft_module, optimizer, dataloader = accelerator.prepare(peft_module, optimizer, dataloader)

    # training loop
    logs = []
    global_step = 0
    loss_fn = nn.MSELoss()

    # scheduler setup if available
    if hasattr(pipe, "scheduler"):
        try:
            pipe.scheduler.set_timesteps(50, device=device)
            timesteps = pipe.scheduler.timesteps
        except Exception:
            timesteps = None
    else:
        timesteps = None

    # Training
    for epoch in range(int(epochs)):
        pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")
        for batch in pbar:
            example = batch[0]
            with accelerator.accumulate(peft_module):
                # image flow
                if example["type"] == "image":
                    img = example["image"].unsqueeze(0).to(device)
                    caption = [example["caption"]]

                    if not hasattr(pipe, "encode_prompt"):
                        raise RuntimeError("Pipeline lacks encode_prompt - cannot encode prompts")

                    # NOTE: this signature follows ChronoEdit/Wan-style encode_prompt;
                    # plain Stable Diffusion pipelines expose different argument names
                    prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
                        prompt=caption,
                        negative_prompt=None,
                        do_classifier_free_guidance=True,
                        num_videos_per_prompt=1,
                        prompt_embeds=None,
                        negative_prompt_embeds=None,
                        max_sequence_length=512,
                        device=device,
                    )

                    if not hasattr(pipe, "vae"):
                        raise RuntimeError("Pipeline lacks VAE - required for latent conversion")

                    with torch.no_grad():
                        latents = pipe.vae.encode(img.to(device)).latent_dist.sample() * pipe.vae.config.scaling_factor

                    noise = torch.randn_like(latents).to(device)
                    if timesteps is None:
                        t = torch.tensor(1, device=device)
                    else:
                        t = pipe.scheduler.timesteps[torch.randint(0, len(pipe.scheduler.timesteps), (1,)).item()].to(device)
                    noisy_latents = pipe.scheduler.add_noise(latents, noise, t)

                    # forward through peft_module (unet)
                    out = peft_module(noisy_latents, t.expand(noisy_latents.shape[0]), encoder_hidden_states=prompt_embeds)
                    if hasattr(out, "sample"):
                        noise_pred = out.sample
                    elif isinstance(out, tuple):
                        noise_pred = out[0]
                    else:
                        noise_pred = out

                    loss = loss_fn(noise_pred, noise)

                else:
                    # video flow (ChronoEdit simplified)
                    if not CHRONOEDIT_AVAILABLE:
                        raise RuntimeError("ChronoEdit training requested but not installed in environment")
                    frames = example["frames"].unsqueeze(0).to(device)  # [1, T, C, H, W]
                    frames_np = frames.squeeze(0).permute(0, 2, 3, 1).cpu().numpy().tolist()
                    video_tensor = pipe.video_processor.preprocess(frames_np, height=frames.shape[-2], width=frames.shape[-1]).to(device)
                    latents_out = pipe.prepare_latents(video_tensor, batch_size=1, num_channels_latents=pipe.vae.config.z_dim, height=video_tensor.shape[-2], width=video_tensor.shape[-1], num_frames=frames.shape[1], dtype=video_tensor.dtype, device=device, generator=None, latents=None, last_image=None)
                    if pipe.config.expand_timesteps:
                        latents, condition, first_frame_mask = latents_out
                    else:
                        latents, condition = latents_out
                        first_frame_mask = None

                    noise = torch.randn_like(latents).to(device)
                    t = pipe.scheduler.timesteps[torch.randint(0, len(pipe.scheduler.timesteps), (1,)).item()].to(device)
                    noisy_latents = pipe.scheduler.add_noise(latents, noise, t)

                    if pipe.config.expand_timesteps:
                        latent_model_input = (1 - first_frame_mask) * condition + first_frame_mask * noisy_latents
                    else:
                        latent_model_input = torch.cat([noisy_latents, condition], dim=1)

                    out = peft_module(hidden_states=latent_model_input, timestep=t.unsqueeze(0).expand(latent_model_input.shape[0]), encoder_hidden_states=None, encoder_hidden_states_image=None, return_dict=False)
                    noise_pred = out[0] if isinstance(out, tuple) else out
                    loss = loss_fn(noise_pred, noise)

                # backward and optimizer step (Accelerate skips the step while accumulating)
                accelerator.backward(loss)
                optimizer.step()
                optimizer.zero_grad()
            global_step += 1

            logs.append(f"step {global_step} loss {loss.item():.6f}")
            pbar.set_postfix({"loss": f"{loss.item():.6f}"})

            if max_train_steps and global_step >= max_train_steps:
                break

            if global_step % save_every_steps == 0:
                out_sub = Path(output_dir) / f"lora_step_{global_step}"
                out_sub.mkdir(parents=True, exist_ok=True)
                try:
                    peft_module.save_pretrained(str(out_sub))
                except Exception:
                    torch.save({k: v.cpu() for k, v in peft_module.state_dict().items()}, str(out_sub / "adapter_state_dict.pt"))
                print(f"Saved adapter at {out_sub}")

        if max_train_steps and global_step >= max_train_steps:
            break

    # final save
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    try:
        peft_module.save_pretrained(output_dir)
    except Exception:
        torch.save({k: v.cpu() for k, v in peft_module.state_dict().items()}, str(Path(output_dir) / "adapter_state_dict.pt"))

    return output_dir, logs
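
# A hedged, direct-call sketch of the trainer (the Gradio UI below is the real
# entry point). Paths, repo ids and hyperparameters here are placeholders.
def _demo_train_lora():
    out_dir, logs = train_lora_accelerate(
        base_model_id="runwayml/stable-diffusion-v1-5",
        dataset_source="./dataset",          # or "username/my-dataset" on the Hub
        csv_name="dataset.csv",
        task_type="text-image",
        adapter_target_override=None,        # auto-infers "unet" for this model
        output_dir="./adapter_out",
        epochs=1,
        max_train_steps=50,                  # keep short for a smoke test
        lora_r=8,
        lora_alpha=16,
        use_4bit=False,
        mixed_precision="fp16" if torch.cuda.is_available() else None,
    )
    print("adapter saved to:", out_dir)
    print("\n".join(logs[-5:]))
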
# ------------------------
# Test generation (best-effort)
# ------------------------
def test_generation_load_and_run(base_model_id: str, adapter_dir: Optional[str], adapter_target: str, prompt: str, use_4bit: bool = False):
    # load base pipeline (no heavy quant config)
    bnb_conf = None
    if use_4bit and BNB_AVAILABLE:
        bnb_conf = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4")
    pipe = load_pipeline_auto(base_model_id, use_4bit=use_4bit, bnb_config=bnb_conf, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32)

    # attempt to load adapter into target module (best-effort)
    try:
        if adapter_target == "unet" and hasattr(pipe, "unet"):
            lcfg = LoraConfig(r=8, lora_alpha=16, target_modules=find_target_modules(pipe.unet))
            pipe.unet = get_peft_model(pipe.unet, lcfg)
            try:
                pipe.unet.load_state_dict(torch.load(Path(adapter_dir) / "pytorch_model.bin"), strict=False)
            except Exception:
                try:
                    pipe.unet.load_adapter(adapter_dir)
                except Exception:
                    pass
        elif adapter_target == "transformer" and hasattr(pipe, "transformer"):
            lcfg = LoraConfig(r=8, lora_alpha=16, target_modules=find_target_modules(pipe.transformer))
            pipe.transformer = get_peft_model(pipe.transformer, lcfg)
        elif adapter_target == "text_encoder" and hasattr(pipe, "text_encoder"):
            lcfg = LoraConfig(r=8, lora_alpha=16, target_modules=find_target_modules(pipe.text_encoder))
            pipe.text_encoder = get_peft_model(pipe.text_encoder, lcfg)
    except Exception as e:
        print("Adapter load warning:", e)

    pipe.to(DEVICE)
    out = pipe(prompt=prompt, num_inference_steps=8)
    if hasattr(out, "images"):
        return out.images[0]
    elif hasattr(out, "frames"):
        frames = out.frames[0]
        from PIL import Image
        return Image.fromarray((frames[-1] * 255).clip(0, 255).astype("uint8"))
    raise RuntimeError("No images/frames returned")

# ------------------------
# Upload adapter to HF Hub
# ------------------------
def upload_adapter(local_dir: str, repo_id: str) -> str:
    token = os.environ.get("HF_TOKEN")
    if token is None:
        raise RuntimeError("HF_TOKEN not set in environment for upload")
    create_repo(repo_id, exist_ok=True, token=token)
    upload_folder(folder_path=local_dir, repo_id=repo_id, repo_type="model", token=token)
    return f"https://huggingface.co/{repo_id}"

# ------------------------
# UI: Boost info helper
# ------------------------
def boost_info_text(use_4bit: bool, enable_xformers: bool, mixed_precision: Optional[str], device_type: str):
    lines = []
    lines.append(f"Device: {device_type.upper()}")
    if use_4bit and BNB_AVAILABLE:
        lines.append("4-bit QLoRA enabled: ~4x memory saving (bitsandbytes NF4 + double quant).")
    else:
        lines.append("QLoRA disabled or bitsandbytes not installed.")
    if enable_xformers and XFORMERS_AVAILABLE:
        lines.append("xFormers/FlashAttention: memory-efficient attention enabled (faster & lower mem).")
    else:
        lines.append("xFormers not enabled or not installed.")
    if mixed_precision:
        lines.append(f"Mixed precision: {mixed_precision}")
    else:
        lines.append("Mixed precision: default (no automatic FP16/BF16).")
    lines.append("Expected: GPU + 4-bit + xFormers = fastest. CPU + 4-bit possible but slow.")
    return "\n".join(lines)

# ------------------------
# Gradio UI wiring
# ------------------------
def run_all_ui(base_model_id: str,
               dataset_source: str,
               csv_name: str,
               task_type: str,
               adapter_target_override: str,
               lora_r: int,
               lora_alpha: int,
               epochs: int,
               batch_size: int,
               lr: float,
               max_train_steps: int,
               output_dir: str,
               upload_repo: str,
               use_4bit: bool,
               enable_xformers: bool,
               use_adalora: bool,
               grad_accum: int,
               mixed_precision: str,
               save_every_steps: int):
    # map task_type -> adapter_target if override empty
    adapter_target = adapter_target_override or infer_target_for_task(task_type, base_model_id)
    try:
        out_dir, logs = train_lora_accelerate(
            base_model_id,
            dataset_source,
            csv_name,
            task_type,
            adapter_target,
            output_dir,
            epochs=epochs,
            batch_size=batch_size,
            lr=lr,
            max_train_steps=(max_train_steps if max_train_steps > 0 else None),
            lora_r=lora_r,
            lora_alpha=lora_alpha,
            use_4bit=use_4bit,
            enable_xformers=enable_xformers,
            use_adalora=use_adalora,
            gradient_accumulation_steps=grad_accum,
            mixed_precision=(mixed_precision if mixed_precision != "none" else None),
            save_every_steps=save_every_steps,
        )
    except Exception as e:
        return f"Training failed: {e}", None, None

    link = None
    if upload_repo:
        try:
            link = upload_adapter(out_dir, upload_repo)
        except Exception as e:
            link = f"Upload failed: {e}"

    # quick test generation using first dataset prompt
    try:
        ds = MediaTextDataset(dataset_source, csv_name=csv_name, max_frames=5)
        test_prompt = ds.df.iloc[0]["text"] if len(ds.df) > 0 else "A cat on a skateboard"
    except Exception:
        test_prompt = "A cat on a skateboard"

    test_img = None
    try:
        test_img = test_generation_load_and_run(base_model_id, out_dir, adapter_target, test_prompt, use_4bit=use_4bit)
    except Exception as e:
        print("Test gen failed:", e)

    return "\n".join(logs[-200:]), test_img, link

def build_ui():
    with gr.Blocks() as demo:
        gr.Markdown("# Universal LoRA Trainer — Quantization & Speedups (single-file)")
        with gr.Row():
            with gr.Column(scale=2):
                base_model = gr.Textbox(label="Base model id (Diffusers / ChronoEdit / Qwen)", value="runwayml/stable-diffusion-v1-5")
                dataset_source = gr.Textbox(label="Dataset folder or HF dataset repo (username/repo)", value="./dataset")
                csv_name = gr.Textbox(label="CSV/Parquet filename", value="dataset.csv")
                task_type = gr.Dropdown(label="Task type", choices=["text-image", "text-video", "prompt-lora"], value="text-image")
                adapter_target_override = gr.Textbox(label="Adapter target override (leave blank for auto)", value="")
                lora_r = gr.Slider(1, 64, value=8, step=1, label="LoRA rank (r)")
                lora_alpha = gr.Slider(1, 128, value=16, step=1, label="LoRA alpha")
                epochs = gr.Number(label="Epochs", value=1)
                batch_size = gr.Number(label="Batch size (per device)", value=1)
                lr = gr.Number(label="Learning rate", value=1e-4)
                max_train_steps = gr.Number(label="Max train steps (0 = unlimited)", value=0)
                save_every_steps = gr.Number(label="Save every steps", value=200)
                output_dir = gr.Textbox(label="Local output dir for adapter", value="./adapter_out")
                upload_repo = gr.Textbox(label="Upload adapter to HF repo (optional, username/repo)", value="")
            with gr.Column(scale=1):
                gr.Markdown("## Speed / Quantization")
                use_4bit = gr.Checkbox(label="Enable 4-bit QLoRA (bitsandbytes)", value=False)
                enable_xformers = gr.Checkbox(label="Enable xFormers / memory efficient attention", value=False)
                use_adalora = gr.Checkbox(label="Use AdaLoRA (if available in peft)", value=False)
                grad_accum = gr.Number(label="Gradient accumulation steps", value=1)
                mixed_precision = gr.Radio(choices=["none", "fp16", "bf16"], value=("fp16" if torch.cuda.is_available() else "none"), label="Mixed precision")
                gr.Markdown("### Boost Info")
                boost_info = gr.Textbox(label="Expected boost / notes", value="", lines=6)
        start_btn = gr.Button("Start Training")
        with gr.Row():
            logs = gr.Textbox(label="Training logs (tail)", lines=18)
            sample_image = gr.Image(label="Sample generated frame after training")
            upload_link = gr.Textbox(label="Upload link / status")

        def on_start(base_model, dataset_source, csv_name, task_type, adapter_target_override, lora_r, lora_alpha, epochs, batch_size, lr, max_train_steps, output_dir, upload_repo, use_4bit_val, enable_xformers_val, use_adalora_val, grad_accum_val, mixed_precision_val, save_every_steps):
            boost_text = boost_info_text(use_4bit_val, enable_xformers_val, mixed_precision_val, "gpu" if torch.cuda.is_available() else "cpu")
            # start training (blocking)
            logs_out, sample, link = run_all_ui(base_model, dataset_source, csv_name, task_type, adapter_target_override, int(lora_r), int(lora_alpha), int(epochs), int(batch_size), float(lr), int(max_train_steps), output_dir, upload_repo, use_4bit_val, enable_xformers_val, use_adalora_val, int(grad_accum_val), mixed_precision_val, int(save_every_steps))
            # route the boost summary and the log tail to their own textboxes
            return boost_text, logs_out, sample, link

        start_btn.click(on_start, inputs=[base_model, dataset_source, csv_name, task_type, adapter_target_override, lora_r, lora_alpha, epochs, batch_size, lr, max_train_steps, output_dir, upload_repo, use_4bit, enable_xformers, use_adalora, grad_accum, mixed_precision, save_every_steps], outputs=[boost_info, logs, sample_image, upload_link])
    return demo

if __name__ == "__main__":
    ui = build_ui()
    ui.launch(server_name="0.0.0.0", server_port=7860)