# universal_lora_trainer_accelerate_singlefile_dynamic.py
"""
Universal Dynamic LoRA Trainer (Accelerate + PEFT + Gradio)
- Auto-detects base model type (Flux, SD, ChronoEdit, QwenEdit, etc.)
- Auto-selects correct adapter target (unet, transformer, text_encoder)
- Supports CSV and Parquet datasets
- Uploads adapter to HF Hub using HF_TOKEN (env only)
"""
import os, torch, gradio as gr, pandas as pd, numpy as np
from pathlib import Path
from tqdm.auto import tqdm
from huggingface_hub import create_repo, upload_folder, hf_hub_download
from diffusers import DiffusionPipeline
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T, torchvision
from peft import LoraConfig, get_peft_model
from accelerate import Accelerator
import torch.nn as nn
# Optional: ChronoEdit + QwenEdit
try:
    from chronoedit_diffusers.pipeline_chronoedit import ChronoEditPipeline
    CHRONOEDIT_AVAILABLE = True
except Exception:
    CHRONOEDIT_AVAILABLE = False
try:
    from qwenimage.pipeline_qwenimage_edit_plus import QwenImageEditPipeline
    QWENEDIT_AVAILABLE = True
except Exception:
    QWENEDIT_AVAILABLE = False
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".webp"}
VIDEO_EXTS = {".mp4", ".mov", ".avi", ".mkv"}
def is_hub_repo_like(s): return "/" in s and not Path(s).exists()
def download_from_hf(repo_id, filename, token=None):
    token = token or os.environ.get("HF_TOKEN")
    return hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset", token=token)
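# Usage sketch (hypothetical repo id) -- resolves one file from a HF *dataset*
# repo, falling back to the HF_TOKEN env var for private repos:
#   path = download_from_hf("someuser/my-dataset", "dataset.csv")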
# ---------------- Dataset ----------------
class MediaTextDataset(Dataset):
    def __init__(self, source, csv_name="dataset.csv", max_frames=5):
        self.is_hub = is_hub_repo_like(source)
        self.source = source
        token = os.environ.get("HF_TOKEN")
        if self.is_hub:
            file_path = download_from_hf(source, csv_name, token)
        else:
            file_path = Path(source) / csv_name
            if not Path(file_path).exists():
                # Fall back to a Parquet file with the same stem
                alt = Path(str(file_path).replace(".csv", ".parquet"))
                if alt.exists(): file_path = alt
        self.df = pd.read_parquet(file_path) if str(file_path).endswith(".parquet") else pd.read_csv(file_path)
        self.root = Path(source) if not self.is_hub else None
        self.img_tf = T.Compose([T.ToPILImage(), T.Resize((512, 512)), T.ToTensor(), T.Normalize([0.5]*3, [0.5]*3)])
        self.video_tf = T.Compose([T.ToPILImage(), T.Resize((128, 256)), T.ToTensor(), T.Normalize([0.5]*3, [0.5]*3)])
        self.max_frames = max_frames

    def __len__(self): return len(self.df)

    def _maybe_dl(self, fname):
        # Local datasets resolve against the root dir; Hub datasets download per file
        return str(Path(self.root) / fname) if self.root else download_from_hf(self.source, fname)

    def __getitem__(self, i):
        rec = self.df.iloc[i]
        p = Path(self._maybe_dl(rec["file_name"]))
        if p.suffix.lower() in IMAGE_EXTS:
            # Force 3-channel RGB so Normalize([0.5]*3, ...) always matches
            img = torchvision.io.read_image(str(p), mode=torchvision.io.ImageReadMode.RGB)
            img = img.permute(1, 2, 0).numpy()
            return {"type": "image", "image": self.img_tf(img), "caption": rec["text"]}
        elif p.suffix.lower() in VIDEO_EXTS:
            vid, _, _ = torchvision.io.read_video(str(p))
            total = len(vid)
            if total == 0:
                return {"type": "video", "frames": torch.zeros((self.max_frames, 3, 128, 256)), "caption": rec["text"]}
            if total < self.max_frames:
                # Pad short clips by repeating the last frame
                idxs = list(range(total)) + [total - 1] * (self.max_frames - total)
            else:
                # Sample max_frames evenly across the clip
                idxs = np.linspace(0, total - 1, self.max_frames).round().astype(int)
            frames = torch.stack([self.video_tf(vid[j].numpy()) for j in idxs])
            return {"type": "video", "frames": frames, "caption": rec["text"]}
        else:
            raise RuntimeError(f"Unsupported file {p}")
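# Usage sketch (hypothetical paths): both a local folder and a Hub dataset
# repo id work as `source`; captions come back under the "caption" key.
#   ds = MediaTextDataset("./dataset", csv_name="dataset.csv")
#   ex = ds[0]
#   print(ex["type"], ex["caption"])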
# ---------------- Dynamic pipeline loader ----------------
def load_pipeline_auto(base_model, dtype=torch.float16):
    low = base_model.lower()
    if "chrono" in low and CHRONOEDIT_AVAILABLE:
        print(f"Using ChronoEdit pipeline for {base_model}")
        return ChronoEditPipeline.from_pretrained(base_model, torch_dtype=dtype)
    elif "qwen" in low and QWENEDIT_AVAILABLE:
        print(f"Using QwenEdit pipeline for {base_model}")
        return QwenImageEditPipeline.from_pretrained(base_model, torch_dtype=dtype)
    else:
        print(f"Using Diffusion pipeline for {base_model}")
        return DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype)
def infer_target_for_task(task_type, model_name):
    # Heuristic: prompt-LoRA and Qwen models train the text encoder; video,
    # ChronoEdit, and Wan models use a DiT-style transformer; everything else a UNet.
    low = model_name.lower()
    if task_type == "prompt-lora" or "qwen" in low:
        return "text_encoder"
    elif task_type == "text-video" or "chrono" in low or "wan" in low:
        return "transformer"
    else:
        return "unet"
def find_target_modules(model):
    # Collect attention projection layers by suffix; PEFT matches target_modules
    # against module-name suffixes.
    names = [n for n, _ in model.named_modules()]
    targets = sorted({n.split(".")[-1] for n in names if any(k in n for k in ["to_q", "to_k", "to_v", "q_proj", "v_proj"])})
    # Fallback uses "to_out.0" rather than "to_out", which is a ModuleList
    # (not a Linear layer) in diffusers attention blocks.
    return targets or ["to_q", "to_k", "to_v", "to_out.0"]
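# For a Stable-Diffusion-style UNet this typically yields names such as
# ["to_k", "to_q", "to_v"]; for text encoders, ["k_proj", "q_proj", "v_proj"]-style
# names instead (exact lists vary by model architecture).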
# ---------------- Training ----------------
def train_lora(base_model, dataset_src, csv_name, task_type, output_dir, epochs=1, lr=1e-4, r=8, alpha=16):
    accelerator = Accelerator()
    pipe = load_pipeline_auto(base_model)
    pipe.to(DEVICE)
    target = infer_target_for_task(task_type, base_model)
    if not hasattr(pipe, target):
        raise RuntimeError(f"Pipeline has no {target}")
    lcfg = LoraConfig(r=r, lora_alpha=alpha, target_modules=find_target_modules(getattr(pipe, target)), lora_dropout=0.0)
    lora_module = get_peft_model(getattr(pipe, target), lcfg)
    setattr(pipe, target, lora_module)  # reattach the LoRA-wrapped module to the pipeline
    dataset = MediaTextDataset(dataset_src, csv_name)
    # batch_size=1 with an identity collate so each batch is one raw example dict
    # (the default collate would turn the dict values into batched tensors/lists)
    loader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=lambda batch: batch[0])
    opt = torch.optim.AdamW(lora_module.parameters(), lr=lr)
    lora_module, opt, loader = accelerator.prepare(lora_module, opt, loader)
    mse = nn.MSELoss()
    logs = []
    for ep in range(epochs):
        for i, ex in enumerate(tqdm(loader, desc=f"Epoch {ep+1}")):
            loss = torch.tensor(0.0, device=DEVICE)
            if ex["type"] == "image" and hasattr(pipe, "vae"):
                img = ex["image"].unsqueeze(0).to(DEVICE, dtype=pipe.vae.dtype)
                lat = pipe.vae.encode(img).latent_dist.sample() * pipe.vae.config.scaling_factor
                noise = torch.randn_like(lat)
                # NOTE: placeholder objective -- it regresses latents toward random
                # noise and never runs the LoRA-wrapped module; a real trainer would
                # noise the latents and predict that noise with the unet/transformer.
                loss = mse(lat, noise)
            if loss.requires_grad:  # skip backward on the constant zero loss (e.g. video examples)
                accelerator.backward(loss)
                opt.step()
                opt.zero_grad()
            logs.append(f"step {i} loss {loss.item():.4f}")
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    lora_module.save_pretrained(output_dir)
    return output_dir, logs[-20:]
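# Usage sketch (hypothetical dataset path and output dir):
#   out_dir, tail = train_lora(
#       "black-forest-labs/FLUX.1-dev", "./dataset", "dataset.csv",
#       "text-image", "./adapter_out", epochs=1, lr=1e-4, r=8, alpha=16)
#   print("\n".join(tail))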
# ---------------- Upload ----------------
def upload_adapter(local, repo_id):
    token = os.environ.get("HF_TOKEN")
    if not token:
        raise RuntimeError("HF_TOKEN missing")
    create_repo(repo_id, exist_ok=True, token=token)
    # upload_folder takes keyword-only arguments
    upload_folder(folder_path=local, repo_id=repo_id, repo_type="model", token=token)
    return f"https://huggingface.co/{repo_id}"
# ---------------- Gradio UI ----------------
def run_ui():
    with gr.Blocks() as demo:
        gr.Markdown("# 🌐 Universal Dynamic LoRA Trainer (Flux / ChronoEdit / QwenEdit)")
        with gr.Row():
            base_model = gr.Textbox(label="Base model", value="black-forest-labs/FLUX.1-dev")
            dataset = gr.Textbox(label="Dataset folder or HF repo", value="./dataset")
        csvname = gr.Textbox(label="CSV/Parquet file", value="dataset.csv")
        task = gr.Dropdown(["text-image", "text-video", "prompt-lora"], label="Task type", value="text-image")
        out = gr.Textbox(label="Output dir", value="./adapter_out")
        repo = gr.Textbox(label="Upload HF repo (optional)", value="")
        with gr.Row():
            r = gr.Slider(1, 64, value=8, label="LoRA rank")
            a = gr.Slider(1, 64, value=16, label="LoRA alpha")
            ep = gr.Number(value=1, label="Epochs")
            lr = gr.Number(value=1e-4, label="Learning rate")
        btn = gr.Button("🚀 Start Training")
        logs = gr.Textbox(label="Logs", lines=12)
        img = gr.Image(label="Sample Output (optional)")
        link = gr.Textbox(label="Upload link")

        def launch(bm, ds, csv, t, out_dir, r_, a_, ep_, lr_, repo_):
            try:
                out_path, log = train_lora(bm, ds, csv, t, out_dir, int(ep_), float(lr_), int(r_), int(a_))
                url = upload_adapter(out_path, repo_) if repo_ else None
                return "\n".join(log), None, url
            except Exception as e:
                return f"❌ {e}", None, None

        btn.click(launch, [base_model, dataset, csvname, task, out, r, a, ep, lr, repo], [logs, img, link])
    return demo
if __name__=="__main__":
run_ui().launch(server_name="0.0.0.0",server_port=7860)