# universal_lora_trainer_accelerate_singlefile_dynamic.py
"""
Universal Dynamic LoRA Trainer (Accelerate + PEFT + Gradio)
- Auto-detects base model type (Flux, SD, ChronoEdit, QwenEdit, etc.)
- Auto-selects correct adapter target (unet, transformer, text_encoder)
- Supports CSV and Parquet datasets
- Uploads adapter to HF Hub using HF_TOKEN (env only)
"""
import os, torch, gradio as gr, pandas as pd, numpy as np
from pathlib import Path
from tqdm.auto import tqdm
from huggingface_hub import create_repo, upload_folder, hf_hub_download
from diffusers import DiffusionPipeline
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T, torchvision
from peft import LoraConfig, get_peft_model
from accelerate import Accelerator
import torch.nn as nn
# Optional: ChronoEdit + QwenEdit
try:
    from chronoedit_diffusers.pipeline_chronoedit import ChronoEditPipeline
    CHRONOEDIT_AVAILABLE = True
except Exception:
    CHRONOEDIT_AVAILABLE = False
try:
    from qwenimage.pipeline_qwenimage_edit_plus import QwenImageEditPipeline
    QWENEDIT_AVAILABLE = True
except Exception:
    QWENEDIT_AVAILABLE = False
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".webp"}
VIDEO_EXTS = {".mp4", ".mov", ".avi", ".mkv"}

def is_hub_repo_like(s): return "/" in s and not Path(s).exists()

def download_from_hf(repo_id, filename, token=None):
    token = token or os.environ.get("HF_TOKEN")
    return hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset", token=token)
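# Usage note (illustrative): `download_from_hf("user/my-dataset", "dataset.csv")`
# returns the local cache path of that file from a Hub *dataset* repo. The repo
# ID here is a placeholder, not a real dataset.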
# ---------------- Dataset ----------------
class MediaTextDataset(Dataset):
    def __init__(self, source, csv_name="dataset.csv", max_frames=5):
        self.is_hub = is_hub_repo_like(source)
        self.source = source
        token = os.environ.get("HF_TOKEN")
        if self.is_hub:
            file_path = download_from_hf(source, csv_name, token)
        else:
            file_path = Path(source) / csv_name
            if not Path(file_path).exists():
                alt = Path(str(file_path).replace(".csv", ".parquet"))
                if alt.exists(): file_path = alt
        self.df = pd.read_parquet(file_path) if str(file_path).endswith(".parquet") else pd.read_csv(file_path)
        self.root = Path(source) if not self.is_hub else None
        self.img_tf = T.Compose([T.ToPILImage(), T.Resize((512, 512)), T.ToTensor(), T.Normalize([0.5]*3, [0.5]*3)])
        self.video_tf = T.Compose([T.ToPILImage(), T.Resize((128, 256)), T.ToTensor(), T.Normalize([0.5]*3, [0.5]*3)])
        self.max_frames = max_frames

    def __len__(self): return len(self.df)

    def _maybe_dl(self, fname): return str(Path(self.root) / fname) if self.root else download_from_hf(self.source, fname)

    def __getitem__(self, i):
        rec = self.df.iloc[i]
        p = Path(self._maybe_dl(rec["file_name"]))
        if p.suffix.lower() in IMAGE_EXTS:
            img = torchvision.io.read_image(str(p))  # CHW uint8 tensor
            img = img.permute(1, 2, 0).numpy()       # HWC layout for ToPILImage
            return {"type": "image", "image": self.img_tf(img), "caption": rec["text"]}
        elif p.suffix.lower() in VIDEO_EXTS:
            vid, _, _ = torchvision.io.read_video(str(p))  # THWC uint8
            total = len(vid)
            if total == 0:
                return {"type": "video", "frames": torch.zeros((self.max_frames, 3, 128, 256)), "caption": rec["text"]}
            if total < self.max_frames:
                idxs = list(range(total)) + [total - 1] * (self.max_frames - total)  # pad with last frame
            else:
                idxs = np.linspace(0, total - 1, self.max_frames).round().astype(int)  # uniform sampling
            frames = torch.stack([self.video_tf(vid[j].numpy()) for j in idxs])
            return {"type": "video", "frames": frames, "caption": rec["text"]}
        else:
            raise RuntimeError(f"Unsupported file {p}")
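# Expected metadata schema (derived from the column accesses above): one row per
# media file, e.g. a dataset.csv such as
#
#   file_name,text
#   img_0001.png,"a red bicycle leaning against a brick wall"
#   clip_0001.mp4,"a timelapse of clouds over a city"
#
# File names and captions here are made up for illustration.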
# ---------------- Dynamic pipeline loader ----------------
def load_pipeline_auto(base_model, dtype=torch.float16):
    low = base_model.lower()
    if "chrono" in low and CHRONOEDIT_AVAILABLE:
        print(f"Using ChronoEdit pipeline for {base_model}")
        return ChronoEditPipeline.from_pretrained(base_model, torch_dtype=dtype)
    elif "qwen" in low and QWENEDIT_AVAILABLE:
        print(f"Using QwenEdit pipeline for {base_model}")
        return QwenImageEditPipeline.from_pretrained(base_model, torch_dtype=dtype)
    else:
        print(f"Using Diffusion pipeline for {base_model}")
        return DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype)
def infer_target_for_task(task_type, model_name):
    low = model_name.lower()
    if task_type == "prompt-lora" or "qwen" in low:
        return "text_encoder"
    # Flux and DiT-style video models expose a `transformer`, not a `unet`.
    elif task_type == "text-video" or any(k in low for k in ("chrono", "wan", "flux")):
        return "transformer"
    else:
        return "unet"
def find_target_modules(model):
    names = [n for n, _ in model.named_modules()]
    targets = sorted({n.split(".")[-1] for n in names if any(k in n for k in ["to_q", "to_k", "to_v", "q_proj", "v_proj"])})
    # Fallback covers standard diffusers attention blocks (the output projection is `to_out.0`).
    return targets or ["to_q", "to_k", "to_v", "to_out.0"]
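# Example (hedged): for a standard SD UNet this typically yields
# ["to_k", "to_q", "to_v"] (attention projections); for a CLIP-style text
# encoder, ["q_proj", "v_proj"]. Actual names depend on the model.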
# ---------------- Training ----------------
def train_lora(base_model, dataset_src, csv_name, task_type, output_dir, epochs=1, lr=1e-4, r=8, alpha=16):
    accelerator = Accelerator()
    # fp16 weights are only practical on GPU; fall back to fp32 on CPU.
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    pipe = load_pipeline_auto(base_model, dtype=dtype)
    pipe.to(accelerator.device)
    target = infer_target_for_task(task_type, base_model)
    if getattr(pipe, target, None) is None:
        raise RuntimeError(f"Pipeline has no {target}")
    lcfg = LoraConfig(r=r, lora_alpha=alpha, target_modules=find_target_modules(getattr(pipe, target)), lora_dropout=0.0)
    lora_module = get_peft_model(getattr(pipe, target), lcfg)
    dataset = MediaTextDataset(dataset_src, csv_name)
    # batch_size=1 with an identity-style collate_fn: each batch is the raw sample
    # dict (default collation would wrap "type"/"caption" strings in lists).
    loader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=lambda batch: batch[0])
    opt = torch.optim.AdamW(lora_module.parameters(), lr=lr)
    lora_module, opt, loader = accelerator.prepare(lora_module, opt, loader)
    mse = nn.MSELoss(); logs = []
    for ep in range(epochs):
        for i, ex in enumerate(tqdm(loader, desc=f"Epoch {ep+1}")):
            # Placeholder objective: regress image latents toward noise. Note it
            # does not route activations through the LoRA-adapted module; see the
            # denoising-loss sketch below for the usual objective.
            if ex["type"] != "image" or not hasattr(pipe, "vae"):
                continue  # video samples are loaded but not trained on in this loop
            img = ex["image"].unsqueeze(0).to(accelerator.device, dtype=pipe.vae.dtype)
            lat = pipe.vae.encode(img).latent_dist.sample() * pipe.vae.config.scaling_factor
            noise = torch.randn_like(lat)
            loss = mse(lat, noise)
            accelerator.backward(loss)
            opt.step(); opt.zero_grad()
            logs.append(f"step {i} loss {loss.item():.4f}")
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    lora_module.save_pretrained(output_dir)
    return output_dir, logs[-20:]
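# --- Hedged sketch: the standard denoising objective ---
# The loop above is a smoke-test loss. For a real text-to-image LoRA, the usual
# recipe is noise prediction: add scheduler noise to the latents at a random
# timestep and regress the model's prediction against that noise. The helper
# below is an unused illustration assuming an SD-style pipeline (a DDPM-like
# scheduler with `add_noise`, a UNet taking `encoder_hidden_states`); `emb`
# would come from the pipeline's text encoder.
def denoising_loss_sketch(pipe, unet_lora, lat, emb):
    t = torch.randint(0, pipe.scheduler.config.num_train_timesteps, (lat.shape[0],), device=lat.device)
    noise = torch.randn_like(lat)
    noisy = pipe.scheduler.add_noise(lat, noise, t)               # forward diffusion step
    pred = unet_lora(noisy, t, encoder_hidden_states=emb).sample  # LoRA-adapted forward pass
    return nn.functional.mse_loss(pred.float(), noise.float())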
# ---------------- Upload ----------------
def upload_adapter(local, repo_id):
    token = os.environ.get("HF_TOKEN")
    if not token: raise RuntimeError("HF_TOKEN missing")
    create_repo(repo_id, exist_ok=True, token=token)
    # upload_folder takes keyword-only arguments.
    upload_folder(folder_path=local, repo_id=repo_id, repo_type="model", token=token)
    return f"https://huggingface.co/{repo_id}"
# ---------------- Gradio UI ----------------
def run_ui():
    with gr.Blocks() as demo:
        gr.Markdown("# 🚀 Universal Dynamic LoRA Trainer (Flux / ChronoEdit / QwenEdit)")
        with gr.Row():
            base_model = gr.Textbox(label="Base model", value="black-forest-labs/FLUX.1-dev")
            dataset = gr.Textbox(label="Dataset folder or HF repo", value="./dataset")
            csvname = gr.Textbox(label="CSV/Parquet file", value="dataset.csv")
        task = gr.Dropdown(["text-image", "text-video", "prompt-lora"], label="Task type", value="text-image")
        out = gr.Textbox(label="Output dir", value="./adapter_out")
        repo = gr.Textbox(label="Upload HF repo (optional)", value="")
        with gr.Row():
            r = gr.Slider(1, 64, value=8, label="LoRA rank"); a = gr.Slider(1, 64, value=16, label="LoRA alpha")
            ep = gr.Number(value=1, label="Epochs"); lr = gr.Number(value=1e-4, label="Learning rate")
        btn = gr.Button("🚀 Start Training")
        logs = gr.Textbox(label="Logs", lines=12)
        img = gr.Image(label="Sample Output (optional)")
        link = gr.Textbox(label="Uploaded repo URL")  # must be created inside Blocks so it renders

        def launch(bm, ds, csv, t, out_dir, r_, a_, ep_, lr_, repo_):
            try:
                out_path, log = train_lora(bm, ds, csv, t, out_dir, int(ep_), float(lr_), int(r_), int(a_))
                url = upload_adapter(out_path, repo_) if repo_ else None
                return "\n".join(log), None, url
            except Exception as e:
                return f"❌ {e}", None, None

        btn.click(launch, [base_model, dataset, csvname, task, out, r, a, ep, lr, repo], [logs, img, link])
    return demo

if __name__ == "__main__":
    run_ui().launch(server_name="0.0.0.0", server_port=7860)
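# Run note: launching directly with `python universal_lora_trainer_accelerate_singlefile_dynamic.py`
# is fine here; Accelerator works in single-process mode, and the Gradio server
# binds to 0.0.0.0:7860 (the default port expected by HF Spaces).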