# universal_lora_trainer_quant_dynamic.py
"""
Universal Dynamic LoRA Trainer (Accelerate + PEFT) with optional QLoRA 4-bit support.
- Supports CSV and Parquet dataset files (columns: file_name, text)
- Accepts a dataset from a local folder or a Hugging Face dataset repo id (username/repo)
- Real LoRA training (PEFT) for:
    * text->image (UNet)
    * text->video (ChronoEdit transformer)
    * prompt-enhancer (text_encoder / QwenEdit)
- Optional:
    * 4-bit quantization (bitsandbytes / QLoRA)
    * xFormers / FlashAttention
    * AdaLoRA (if available)
- Uses HF_TOKEN from the environment for uploads
- Use `accelerate launch` for multi-GPU / optimized runs
"""
import os
from pathlib import Path
from typing import Optional

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.transforms as T
import pandas as pd
import numpy as np
import gradio as gr
from tqdm.auto import tqdm
from huggingface_hub import create_repo, upload_folder, hf_hub_download, list_repo_files
from diffusers import DiffusionPipeline
# optional pip installs - guard imports
try:
    from chronoedit_diffusers.pipeline_chronoedit import ChronoEditPipeline
    CHRONOEDIT_AVAILABLE = True
except Exception:
    CHRONOEDIT_AVAILABLE = False

# Qwen image edit (optional)
try:
    from qwenimage.pipeline_qwenimage_edit_plus import QwenImageEditPipeline  # optional name
    QWENEDIT_AVAILABLE = True
except Exception:
    QWENEDIT_AVAILABLE = False

# bitsandbytes (quantization)
try:
    from transformers import BitsAndBytesConfig
    BNB_AVAILABLE = True
except Exception:
    BitsAndBytesConfig = None
    BNB_AVAILABLE = False

# xFormers
try:
    import xformers  # noqa
    XFORMERS_AVAILABLE = True
except Exception:
    XFORMERS_AVAILABLE = False

# PEFT / AdaLoRA
try:
    from peft import LoraConfig, get_peft_model
    try:
        from peft import AdaLoraConfig  # optional
        ADALORA_AVAILABLE = True
    except Exception:
        AdaLoraConfig = None
        ADALORA_AVAILABLE = False
except Exception as e:
    raise RuntimeError("Install peft: pip install peft") from e

# Accelerate
try:
    from accelerate import Accelerator
except Exception as e:
    raise RuntimeError("Install accelerate: pip install accelerate") from e
# ------------------------
# Config
# ------------------------
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".webp", ".bmp"}
VIDEO_EXTS = {".mp4", ".mov", ".avi", ".mkv"}

# ------------------------
# Utilities
# ------------------------
def is_hub_repo_like(s: str) -> bool:
    return "/" in s and not Path(s).exists()
def download_from_hf(repo_id: str, filename: str, token: Optional[str] = None, repo_type: str = "dataset") -> str:
    token = token or os.environ.get("HF_TOKEN")
    return hf_hub_download(repo_id=repo_id, filename=filename, token=token, repo_type=repo_type)

def try_list_repo_files(repo_id: str, repo_type: str = "dataset", token: Optional[str] = None):
    token = token or os.environ.get("HF_TOKEN")
    try:
        return list_repo_files(repo_id, token=token, repo_type=repo_type)
    except Exception:
        return []
def find_target_modules(model, candidates=("q_proj", "k_proj", "v_proj", "o_proj", "to_q", "to_k", "to_v", "proj_out", "to_out")):
    # Collect leaf-module names that look like attention projections and can
    # actually host a LoRA adapter; PEFT matches target_modules by name suffix,
    # so we only keep nn.Linear leaves (container modules like ModuleList
    # cannot be wrapped).
    selected = set()
    for name, module in model.named_modules():
        if not isinstance(module, nn.Linear):
            continue
        leaf = name.split(".")[-1]
        if any(cand in leaf for cand in candidates):
            selected.add(leaf)
    if not selected:
        # conservative fallback for diffusers-style attention blocks
        # (to_out is a ModuleList in diffusers; its Linear lives at to_out.0)
        return ["to_q", "to_k", "to_v", "to_out.0"]
    return list(selected)
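# For a Stable Diffusion UNet this typically yields names like
# ["to_q", "to_k", "to_v", "proj_out"]; for a CLIP-style text encoder,
# ["q_proj", "k_proj", "v_proj", "o_proj"] (exact sets vary by model).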
# ------------------------
# Dataset class (CSV/Parquet)
# ------------------------
class MediaTextDataset(Dataset):
    """
    Loads records from CSV or Parquet with columns:
      - file_name (relative path in the folder, or filename inside the HF dataset repo)
      - text
    """
    def __init__(self, dataset_source: str, csv_name: str = "dataset.csv", max_frames: int = 5,
                 image_size=(512, 512), video_frame_size=(128, 256), hub_token: Optional[str] = None):
        self.source = dataset_source
        self.is_hub = is_hub_repo_like(dataset_source)
        self.max_frames = max_frames
        self.image_size = image_size
        self.video_frame_size = video_frame_size
        self.hub_token = hub_token or os.environ.get("HF_TOKEN")
        # load dataframe (CSV or Parquet)
        if self.is_hub:
            # prefer the exact csv_name, then fall back to a .parquet variant
            try:
                csv_local = download_from_hf(self.source, csv_name, token=self.hub_token, repo_type="dataset")
            except Exception:
                alt = csv_name.replace(".csv", ".parquet") if csv_name.endswith(".csv") else csv_name + ".parquet"
                csv_local = download_from_hf(self.source, alt, token=self.hub_token, repo_type="dataset")
            if str(csv_local).endswith(".parquet"):
                self.df = pd.read_parquet(csv_local)
            else:
                self.df = pd.read_csv(csv_local)
            self.root = None
        else:
            root = Path(dataset_source)
            csv_path = root / csv_name
            parquet_path = root / (csv_name.replace(".csv", ".parquet") if csv_name.endswith(".csv") else csv_name + ".parquet")
            if csv_path.exists():
                # csv_name itself may already point at a Parquet file
                self.df = pd.read_parquet(csv_path) if csv_path.suffix.lower() == ".parquet" else pd.read_csv(csv_path)
            elif parquet_path.exists():
                self.df = pd.read_parquet(parquet_path)
            else:
                raise FileNotFoundError(f"Can't find {csv_name} in {dataset_source}")
            self.root = root
        # transforms
        self.image_transform = T.Compose([T.ToPILImage(), T.Resize(image_size), T.ToTensor(), T.Normalize([0.5] * 3, [0.5] * 3)])
        self.video_transform = T.Compose([T.ToPILImage(), T.Resize(video_frame_size), T.ToTensor(), T.Normalize([0.5] * 3, [0.5] * 3)])
    def __len__(self):
        return len(self.df)

    def _maybe_download_from_hub(self, file_name: str) -> str:
        if self.root is not None:
            p = self.root / file_name
            if p.exists():
                return str(p)
        # else download from the dataset repo
        return download_from_hf(self.source, file_name, token=self.hub_token, repo_type="dataset")

    def _read_video_frames(self, path: str, num_frames: int):
        video_frames, _, _ = torchvision.io.read_video(str(path), pts_unit="sec")  # [T, H, W, C] uint8
        total = len(video_frames)
        if total == 0:
            C, H, W = 3, self.video_frame_size[0], self.video_frame_size[1]
            return torch.zeros((num_frames, C, H, W), dtype=torch.float32)
        if total < num_frames:
            # pad by repeating the last frame
            idxs = list(range(total)) + [total - 1] * (num_frames - total)
        else:
            # sample frames uniformly across the clip
            idxs = np.linspace(0, total - 1, num_frames).round().astype(int).tolist()
        frames = []
        for i in idxs:
            arr = video_frames[i].numpy() if hasattr(video_frames[i], "numpy") else np.array(video_frames[i])
            frames.append(self.video_transform(arr))
        return torch.stack(frames, dim=0)  # [T, C, H, W]
    def __getitem__(self, idx):
        rec = self.df.iloc[idx]
        file_name = rec["file_name"]
        caption = rec["text"]
        if self.is_hub:
            local_path = self._maybe_download_from_hub(file_name)
        else:
            local_path = str(Path(self.root) / file_name)
        suffix = Path(local_path).suffix.lower()
        if suffix in IMAGE_EXTS:
            # force 3 channels so Normalize([0.5]*3, ...) always applies
            img = torchvision.io.read_image(local_path, mode=torchvision.io.ImageReadMode.RGB)  # [C, H, W] uint8
            img = img.permute(1, 2, 0).numpy()  # HWC for ToPILImage
            return {"type": "image", "image": self.image_transform(img), "caption": caption, "file_name": file_name}
        elif suffix in VIDEO_EXTS:
            frames = self._read_video_frames(local_path, self.max_frames)  # [T, C, H, W]
            return {"type": "video", "frames": frames, "caption": caption, "file_name": file_name}
        else:
            raise RuntimeError(f"Unsupported media type: {local_path}")
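# Expected table layout (hypothetical paths; one media file + caption per row):
#   file_name,text
#   images/cat.png,"a cat on a skateboard"
#   clips/dog.mp4,"a dog running on the beach"
# Usage sketch:
#   ds = MediaTextDataset("./dataset", csv_name="dataset.csv")
#   sample = ds[0]  # {'type': 'image' | 'video', ..., 'caption': str}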
# ------------------------
# Pipeline loader with optional quantization
# ------------------------
def load_pipeline_auto(base_model_id: str, use_4bit: bool = False, bnb_config: Optional[object] = None, torch_dtype=torch.float16):
    low = base_model_id.lower()
    is_chrono = "chrono" in low or "wan" in low or "video" in low
    is_qwen = "qwen" in low
    # choose pipeline
    if is_chrono and CHRONOEDIT_AVAILABLE:
        print("Loading ChronoEdit pipeline")
        # Quantized loading of ChronoEdit is not widely supported, so the
        # bnb_config is ignored here and we fall back to plain fp16.
        if use_4bit and bnb_config is not None:
            pipe = ChronoEditPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16)
        else:
            pipe = ChronoEditPipeline.from_pretrained(base_model_id, torch_dtype=torch_dtype)
    elif is_qwen and QWENEDIT_AVAILABLE:
        print("Loading Qwen image-edit pipeline")
        pipe = QwenImageEditPipeline.from_pretrained(base_model_id, torch_dtype=torch_dtype)
    else:
        # fallback to DiffusionPipeline - supports quantization_config for diffusers+transformers
        print("Loading standard DiffusionPipeline:", base_model_id, "use_4bit=", use_4bit)
        if use_4bit and BNB_AVAILABLE and bnb_config is not None:
            pipe = DiffusionPipeline.from_pretrained(base_model_id, quantization_config=bnb_config, torch_dtype=torch.float16)
        else:
            pipe = DiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch_dtype)
    return pipe
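# Usage sketch (model id is the script's default; 4-bit needs bitsandbytes):
#   pipe = load_pipeline_auto("runwayml/stable-diffusion-v1-5", use_4bit=False)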
# ------------------------
# Auto infer adapter target
# ------------------------
def infer_target_for_task(task_type: str, model_name: str) -> str:
    low = model_name.lower()
    if task_type == "prompt-lora" or "qwen" in low:
        return "text_encoder"
    if task_type == "text-video" or "chrono" in low or "wan" in low:
        return "transformer"
    # default
    return "unet"
# ------------------------
# LoRA attach (supports AdaLoRA if available)
# ------------------------
def attach_lora(pipe, adapter_target: str, r: int = 8, alpha: int = 16, dropout: float = 0.0, use_adalora: bool = False):
    if adapter_target == "unet":
        if not hasattr(pipe, "unet"):
            raise RuntimeError("Pipeline has no UNet to attach LoRA")
        target_module = pipe.unet
        attr = "unet"
    elif adapter_target == "transformer":
        if not hasattr(pipe, "transformer"):
            raise RuntimeError("Pipeline has no transformer to attach LoRA")
        target_module = pipe.transformer
        attr = "transformer"
    elif adapter_target == "text_encoder":
        if not hasattr(pipe, "text_encoder"):
            raise RuntimeError("Pipeline has no text_encoder for prompt-LoRA")
        target_module = pipe.text_encoder
        attr = "text_encoder"
    else:
        raise RuntimeError("Unknown adapter_target")
    target_modules = find_target_modules(target_module)
    print("Detected target_modules for LoRA:", target_modules)
    if use_adalora and ADALORA_AVAILABLE:
        lora_config = AdaLoraConfig(
            r=r,
            lora_alpha=alpha,
            target_modules=target_modules,
            init_r=4,
            lora_dropout=dropout,
        )
    else:
        # no task_type here: these are diffusion/vision modules, not seq2seq LMs
        lora_config = LoraConfig(
            r=r,
            lora_alpha=alpha,
            target_modules=target_modules,
            lora_dropout=dropout,
            bias="none",
        )
    peft_model = get_peft_model(target_module, lora_config)
    setattr(pipe, attr, peft_model)
    return pipe, attr
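# Usage sketch:
#   pipe, attr = attach_lora(pipe, "unet", r=8, alpha=16)
#   getattr(pipe, attr).print_trainable_parameters()  # PEFT helper; prints trainable vs total params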
# ------------------------
# Training loop (Accelerate-aware)
# ------------------------
def train_lora_accelerate(base_model_id: str,
                          dataset_source: str,
                          csv_name: str,
                          task_type: str,
                          adapter_target_override: Optional[str],
                          output_dir: str,
                          epochs: int = 1,
                          batch_size: int = 1,
                          lr: float = 1e-4,
                          max_train_steps: Optional[int] = None,
                          lora_r: int = 8,
                          lora_alpha: int = 16,
                          use_4bit: bool = False,
                          enable_xformers: bool = False,
                          use_adalora: bool = False,
                          gradient_accumulation_steps: int = 1,
                          mixed_precision: Optional[str] = None,
                          save_every_steps: int = 200,
                          max_frames: int = 5):
    # Setup Accelerator (no trailing comma: that would make this a tuple)
    accelerator = Accelerator(
        gradient_accumulation_steps=gradient_accumulation_steps,
        mixed_precision=mixed_precision or ("fp16" if torch.cuda.is_available() else "no"),
    )
    device = accelerator.device
    # prepare bitsandbytes config if requested
    bnb_conf = None
    if use_4bit and BNB_AVAILABLE:
        bnb_conf = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )
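    # Rough memory math for the NF4 config above: 4-bit weights take ~0.5
    # byte/param vs 2 bytes/param in fp16, so a 1B-param module shrinks from
    # ~2 GB to ~0.5 GB (double quantization also compresses the per-block
    # quantization constants). Exact savings vary by model and overheads.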
    # Load pipeline (supports quant for standard diffusers)
    pipe = load_pipeline_auto(base_model_id, use_4bit=use_4bit, bnb_config=bnb_conf,
                              torch_dtype=torch.float16 if device.type == "cuda" else torch.float32)
    # optionally enable memory-efficient attention
    if enable_xformers:
        try:
            if hasattr(pipe, "enable_xformers_memory_efficient_attention"):
                pipe.enable_xformers_memory_efficient_attention()
            elif hasattr(pipe, "enable_attention_slicing"):
                pipe.enable_attention_slicing()
            print("xFormers / memory efficient attention enabled.")
        except Exception as e:
            print("Could not enable xformers:", e)
    # infer adapter target automatically if not overridden
    adapter_target = adapter_target_override if adapter_target_override else infer_target_for_task(task_type, base_model_id)
    print("Adapter target set to:", adapter_target)
    # attach LoRA
    pipe, attr = attach_lora(pipe, adapter_target, r=lora_r, alpha=lora_alpha, dropout=0.0, use_adalora=use_adalora)
    # pick the PEFT module for optimization
    peft_module = getattr(pipe, attr)
    # dataset + dataloader; we iterate one example at a time with an identity
    # collate_fn, so `batch_size` is accepted but currently unused
    dataset = MediaTextDataset(dataset_source, csv_name=csv_name, max_frames=max_frames)
    dataloader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=lambda x: x)
    # optimizer over the trainable (LoRA) parameters only
    trainable_params = [p for n, p in peft_module.named_parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(trainable_params, lr=lr)
    # prepare objects with accelerator
    peft_module, optimizer, dataloader = accelerator.prepare(peft_module, optimizer, dataloader)
    # training loop
    logs = []
    global_step = 0
    loss_fn = nn.MSELoss()
    # scheduler setup if available
    if hasattr(pipe, "scheduler"):
        try:
            pipe.scheduler.set_timesteps(50, device=device)
            timesteps = pipe.scheduler.timesteps
        except Exception:
            timesteps = None
    else:
        timesteps = None
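    # Training objective (standard diffusion epsilon-prediction):
    #   x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps,  eps ~ N(0, I)
    #   loss = || eps_theta(x_t, t, text_emb) - eps ||^2
    # scheduler.add_noise builds x_t; the LoRA-wrapped module predicts eps.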
    # Training
    for epoch in range(int(epochs)):
        pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")
        for batch in pbar:
            example = batch[0]
            # image flow
            if example["type"] == "image":
                img = example["image"].unsqueeze(0).to(device)
                caption = [example["caption"]]
                if not hasattr(pipe, "encode_prompt"):
                    raise RuntimeError("Pipeline lacks encode_prompt - cannot encode prompts")
                # encode_prompt signatures differ between image and video pipelines,
                # so try the Stable-Diffusion-style call first, then the video-style one
                try:
                    prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
                        prompt=caption,
                        device=device,
                        num_images_per_prompt=1,
                        do_classifier_free_guidance=False,
                        negative_prompt=None,
                    )
                except TypeError:
                    prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
                        prompt=caption,
                        negative_prompt=None,
                        do_classifier_free_guidance=True,
                        num_videos_per_prompt=1,
                        max_sequence_length=512,
                        device=device,
                    )
                if not hasattr(pipe, "vae"):
                    raise RuntimeError("Pipeline lacks VAE - required for latent conversion")
                with torch.no_grad():
                    latents = pipe.vae.encode(img).latent_dist.sample() * pipe.vae.config.scaling_factor
                noise = torch.randn_like(latents)
                if timesteps is None:
                    t = torch.tensor(1, device=device)
                else:
                    t = timesteps[torch.randint(0, len(timesteps), (1,)).item()].to(device)
                noisy_latents = pipe.scheduler.add_noise(latents, noise, t)
                # forward through peft_module (unet)
                out = peft_module(noisy_latents, t.expand(noisy_latents.shape[0]), encoder_hidden_states=prompt_embeds)
                if hasattr(out, "sample"):
                    noise_pred = out.sample
                elif isinstance(out, tuple):
                    noise_pred = out[0]
                else:
                    noise_pred = out
                loss = loss_fn(noise_pred, noise)
            else:
                # video flow (ChronoEdit, simplified; captions are not encoded here)
                if not CHRONOEDIT_AVAILABLE:
                    raise RuntimeError("ChronoEdit training requested but not installed in environment")
                frames = example["frames"].unsqueeze(0).to(device)  # [1, T, C, H, W]
                frames_np = frames.squeeze(0).permute(0, 2, 3, 1).cpu().numpy().tolist()
                video_tensor = pipe.video_processor.preprocess(frames_np, height=frames.shape[-2], width=frames.shape[-1]).to(device)
                latents_out = pipe.prepare_latents(video_tensor, batch_size=1, num_channels_latents=pipe.vae.config.z_dim,
                                                   height=video_tensor.shape[-2], width=video_tensor.shape[-1],
                                                   num_frames=frames.shape[1], dtype=video_tensor.dtype,
                                                   device=device, generator=None, latents=None, last_image=None)
                if pipe.config.expand_timesteps:
                    latents, condition, first_frame_mask = latents_out
                else:
                    latents, condition = latents_out
                    first_frame_mask = None
                noise = torch.randn_like(latents)
                t = pipe.scheduler.timesteps[torch.randint(0, len(pipe.scheduler.timesteps), (1,)).item()].to(device)
                noisy_latents = pipe.scheduler.add_noise(latents, noise, t)
                if pipe.config.expand_timesteps:
                    latent_model_input = (1 - first_frame_mask) * condition + first_frame_mask * noisy_latents
                else:
                    latent_model_input = torch.cat([noisy_latents, condition], dim=1)
                out = peft_module(hidden_states=latent_model_input, timestep=t.unsqueeze(0).expand(latent_model_input.shape[0]),
                                  encoder_hidden_states=None, encoder_hidden_states_image=None, return_dict=False)
                noise_pred = out[0] if isinstance(out, tuple) else out
                loss = loss_fn(noise_pred, noise)
            # backward and optimizer step; accumulate() makes optimizer.step()
            # a no-op until the gradient-accumulation boundary is reached
            with accelerator.accumulate(peft_module):
                accelerator.backward(loss)
                optimizer.step()
                optimizer.zero_grad()
            global_step += 1
            logs.append(f"step {global_step} loss {loss.item():.6f}")
            pbar.set_postfix({"loss": f"{loss.item():.6f}"})
            if global_step % save_every_steps == 0:
                out_sub = Path(output_dir) / f"lora_step_{global_step}"
                out_sub.mkdir(parents=True, exist_ok=True)
                try:
                    peft_module.save_pretrained(str(out_sub))
                except Exception:
                    torch.save({k: v.cpu() for k, v in peft_module.state_dict().items()}, str(out_sub / "adapter_state_dict.pt"))
                print(f"Saved adapter at {out_sub}")
            if max_train_steps and global_step >= max_train_steps:
                break
        if max_train_steps and global_step >= max_train_steps:
            break
    # final save
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    try:
        peft_module.save_pretrained(output_dir)
    except Exception:
        torch.save({k: v.cpu() for k, v in peft_module.state_dict().items()}, str(Path(output_dir) / "adapter_state_dict.pt"))
    return output_dir, logs
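# Usage sketch (blocking; values are illustrative):
#   out_dir, logs = train_lora_accelerate(
#       "runwayml/stable-diffusion-v1-5", "./dataset", "dataset.csv",
#       "text-image", None, "./adapter_out", epochs=1, max_train_steps=100)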
# ------------------------
# Test generation (best-effort)
# ------------------------
def test_generation_load_and_run(base_model_id: str, adapter_dir: Optional[str], adapter_target: str, prompt: str, use_4bit: bool = False):
    # load base pipeline (no heavy quant config)
    bnb_conf = None
    if use_4bit and BNB_AVAILABLE:
        bnb_conf = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16,
                                      bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4")
    pipe = load_pipeline_auto(base_model_id, use_4bit=use_4bit, bnb_config=bnb_conf,
                              torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32)
    # attempt to load the adapter into the target module (best-effort; PEFT
    # writes adapter_model.bin or adapter_model.safetensors depending on version)
    try:
        if adapter_target == "unet" and hasattr(pipe, "unet"):
            lcfg = LoraConfig(r=8, lora_alpha=16, target_modules=find_target_modules(pipe.unet))
            pipe.unet = get_peft_model(pipe.unet, lcfg)
            try:
                state = torch.load(Path(adapter_dir) / "adapter_model.bin", map_location="cpu")
                pipe.unet.load_state_dict(state, strict=False)
            except Exception:
                pass
        elif adapter_target == "transformer" and hasattr(pipe, "transformer"):
            lcfg = LoraConfig(r=8, lora_alpha=16, target_modules=find_target_modules(pipe.transformer))
            pipe.transformer = get_peft_model(pipe.transformer, lcfg)
        elif adapter_target == "text_encoder" and hasattr(pipe, "text_encoder"):
            lcfg = LoraConfig(r=8, lora_alpha=16, target_modules=find_target_modules(pipe.text_encoder))
            pipe.text_encoder = get_peft_model(pipe.text_encoder, lcfg)
    except Exception as e:
        print("Adapter load warning:", e)
    pipe.to(DEVICE)
    out = pipe(prompt=prompt, num_inference_steps=8)
    if hasattr(out, "images"):
        return out.images[0]
    elif hasattr(out, "frames"):
        frames = out.frames[0]
        from PIL import Image
        return Image.fromarray((frames[-1] * 255).clip(0, 255).astype("uint8"))
    raise RuntimeError("No images/frames returned")
# ------------------------
# Upload adapter to HF Hub
# ------------------------
def upload_adapter(local_dir: str, repo_id: str) -> str:
    token = os.environ.get("HF_TOKEN")
    if token is None:
        raise RuntimeError("HF_TOKEN not set in environment for upload")
    create_repo(repo_id, exist_ok=True, token=token)
    upload_folder(folder_path=local_dir, repo_id=repo_id, repo_type="model", token=token)
    return f"https://huggingface.co/{repo_id}"
# ------------------------
# UI: Boost info helper
# ------------------------
def boost_info_text(use_4bit: bool, enable_xformers: bool, mixed_precision: Optional[str], device_type: str):
    lines = []
    lines.append(f"Device: {device_type.upper()}")
    if use_4bit and BNB_AVAILABLE:
        lines.append("4-bit QLoRA enabled: ~4x memory saving (bitsandbytes NF4 + double quant).")
    else:
        lines.append("QLoRA disabled or bitsandbytes not installed.")
    if enable_xformers and XFORMERS_AVAILABLE:
        lines.append("xFormers/FlashAttention: memory-efficient attention enabled (faster & lower mem).")
    else:
        lines.append("xFormers not enabled or not installed.")
    if mixed_precision:
        lines.append(f"Mixed precision: {mixed_precision}")
    else:
        lines.append("Mixed precision: default (no automatic FP16/BF16).")
    lines.append("Expected: GPU + 4-bit + xFormers = fastest. CPU + 4-bit possible but slow.")
    return "\n".join(lines)
# ------------------------
# Gradio UI wiring
# ------------------------
def run_all_ui(base_model_id: str,
               dataset_source: str,
               csv_name: str,
               task_type: str,
               adapter_target_override: str,
               lora_r: int,
               lora_alpha: int,
               epochs: int,
               batch_size: int,
               lr: float,
               max_train_steps: int,
               output_dir: str,
               upload_repo: str,
               use_4bit: bool,
               enable_xformers: bool,
               use_adalora: bool,
               grad_accum: int,
               mixed_precision: str,
               save_every_steps: int):
    # map task_type -> adapter_target if the override is empty
    adapter_target = adapter_target_override or infer_target_for_task(task_type, base_model_id)
    try:
        out_dir, logs = train_lora_accelerate(
            base_model_id,
            dataset_source,
            csv_name,
            task_type,
            adapter_target,
            output_dir,
            epochs=epochs,
            batch_size=batch_size,
            lr=lr,
            max_train_steps=(max_train_steps if max_train_steps > 0 else None),
            lora_r=lora_r,
            lora_alpha=lora_alpha,
            use_4bit=use_4bit,
            enable_xformers=enable_xformers,
            use_adalora=use_adalora,
            gradient_accumulation_steps=grad_accum,
            mixed_precision=(mixed_precision if mixed_precision != "none" else None),
            save_every_steps=save_every_steps,
        )
    except Exception as e:
        return f"Training failed: {e}", None, None
    link = None
    if upload_repo:
        try:
            link = upload_adapter(out_dir, upload_repo)
        except Exception as e:
            link = f"Upload failed: {e}"
    # quick test generation using the first dataset prompt
    try:
        ds = MediaTextDataset(dataset_source, csv_name=csv_name, max_frames=5)
        test_prompt = ds.df.iloc[0]["text"] if len(ds.df) > 0 else "A cat on a skateboard"
    except Exception:
        test_prompt = "A cat on a skateboard"
    test_img = None
    try:
        test_img = test_generation_load_and_run(base_model_id, out_dir, adapter_target, test_prompt, use_4bit=use_4bit)
    except Exception as e:
        print("Test gen failed:", e)
    return "\n".join(logs[-200:]), test_img, link
def build_ui():
    with gr.Blocks() as demo:
        gr.Markdown("# Universal LoRA Trainer — Quantization & Speedups (single-file)")
        with gr.Row():
            with gr.Column(scale=2):
                base_model = gr.Textbox(label="Base model id (Diffusers / ChronoEdit / Qwen)", value="runwayml/stable-diffusion-v1-5")
                dataset_source = gr.Textbox(label="Dataset folder or HF dataset repo (username/repo)", value="./dataset")
                csv_name = gr.Textbox(label="CSV/Parquet filename", value="dataset.csv")
                task_type = gr.Dropdown(label="Task type", choices=["text-image", "text-video", "prompt-lora"], value="text-image")
                adapter_target_override = gr.Textbox(label="Adapter target override (leave blank for auto)", value="")
                lora_r = gr.Slider(1, 64, value=8, step=1, label="LoRA rank (r)")
                lora_alpha = gr.Slider(1, 128, value=16, step=1, label="LoRA alpha")
                epochs = gr.Number(label="Epochs", value=1)
                batch_size = gr.Number(label="Batch size (per device)", value=1)
                lr = gr.Number(label="Learning rate", value=1e-4)
                max_train_steps = gr.Number(label="Max train steps (0 = unlimited)", value=0)
                save_every_steps = gr.Number(label="Save every steps", value=200)
                output_dir = gr.Textbox(label="Local output dir for adapter", value="./adapter_out")
                upload_repo = gr.Textbox(label="Upload adapter to HF repo (optional, username/repo)", value="")
            with gr.Column(scale=1):
                gr.Markdown("## Speed / Quantization")
                use_4bit = gr.Checkbox(label="Enable 4-bit QLoRA (bitsandbytes)", value=False)
                enable_xformers = gr.Checkbox(label="Enable xFormers / memory efficient attention", value=False)
                use_adalora = gr.Checkbox(label="Use AdaLoRA (if available in peft)", value=False)
                grad_accum = gr.Number(label="Gradient accumulation steps", value=1)
                mixed_precision = gr.Radio(choices=["none", "fp16", "bf16"], value=("fp16" if torch.cuda.is_available() else "none"), label="Mixed precision")
                gr.Markdown("### Boost Info")
                boost_info = gr.Textbox(label="Expected boost / notes", value="", lines=6)
        start_btn = gr.Button("Start Training")
        with gr.Row():
            logs = gr.Textbox(label="Training logs (tail)", lines=18)
            sample_image = gr.Image(label="Sample generated frame after training")
        upload_link = gr.Textbox(label="Upload link / status")

        def on_start(base_model, dataset_source, csv_name, task_type, adapter_target_override, lora_r, lora_alpha,
                     epochs, batch_size, lr, max_train_steps, output_dir, upload_repo, use_4bit_val,
                     enable_xformers_val, use_adalora_val, grad_accum_val, mixed_precision_val, save_every_steps):
            boost_text = boost_info_text(use_4bit_val, enable_xformers_val, mixed_precision_val,
                                         "gpu" if torch.cuda.is_available() else "cpu")
            # start training (blocking)
            logs_out, sample, link = run_all_ui(base_model, dataset_source, csv_name, task_type, adapter_target_override,
                                                int(lora_r), int(lora_alpha), int(epochs), int(batch_size), float(lr),
                                                int(max_train_steps), output_dir, upload_repo, use_4bit_val,
                                                enable_xformers_val, use_adalora_val, int(grad_accum_val),
                                                mixed_precision_val, int(save_every_steps))
            return boost_text + "\n\n" + logs_out, sample, link

        start_btn.click(on_start,
                        inputs=[base_model, dataset_source, csv_name, task_type, adapter_target_override, lora_r,
                                lora_alpha, epochs, batch_size, lr, max_train_steps, output_dir, upload_repo,
                                use_4bit, enable_xformers, use_adalora, grad_accum, mixed_precision, save_every_steps],
                        outputs=[boost_info, sample_image, upload_link])
    return demo
if __name__ == "__main__":
    ui = build_ui()
    ui.launch(server_name="0.0.0.0", server_port=7860)