# universal_lora_trainer_accelerate_singlefile_dynamic.py
"""
Universal Dynamic LoRA Trainer (Accelerate + PEFT + Gradio)
- Auto-detects base model type (Flux, SD, ChronoEdit, QwenEdit, etc.)
- Auto-selects correct adapter target (unet, transformer, text_encoder)
- Supports CSV and Parquet datasets
- Uploads adapter to HF Hub using HF_TOKEN (env only)
"""
import os, torch, gradio as gr, pandas as pd, numpy as np
from pathlib import Path
from tqdm.auto import tqdm
from huggingface_hub import create_repo, upload_folder, hf_hub_download
from diffusers import DiffusionPipeline
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T, torchvision
from peft import LoraConfig, get_peft_model
from accelerate import Accelerator
import torch.nn as nn
# Optional: ChronoEdit + QwenEdit
try:
from chronoedit_diffusers.pipeline_chronoedit import ChronoEditPipeline
CHRONOEDIT_AVAILABLE = True
except Exception:
CHRONOEDIT_AVAILABLE = False
try:
from qwenimage.pipeline_qwenimage_edit_plus import QwenImageEditPipeline
QWENEDIT_AVAILABLE = True
except Exception:
QWENEDIT_AVAILABLE = False
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".webp"}
VIDEO_EXTS = {".mp4", ".mov", ".avi", ".mkv"}
def is_hub_repo_like(s): return "/" in s and not Path(s).exists()
def download_from_hf(repo_id, filename, token=None):
token = token or os.environ.get("HF_TOKEN")
return hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset", token=token)
# ---------------- Dataset ----------------
class MediaTextDataset(Dataset):
def __init__(self, source, csv_name="dataset.csv", max_frames=5):
self.is_hub = is_hub_repo_like(source)
self.source = source
token = os.environ.get("HF_TOKEN")
if self.is_hub:
file_path = download_from_hf(source, csv_name, token)
else:
file_path = Path(source) / csv_name
if not Path(file_path).exists():
alt = Path(str(file_path).replace(".csv", ".parquet"))
if alt.exists(): file_path = alt
self.df = pd.read_parquet(file_path) if str(file_path).endswith(".parquet") else pd.read_csv(file_path)
self.root = Path(source) if not self.is_hub else None
self.img_tf = T.Compose([T.ToPILImage(), T.Resize((512,512)), T.ToTensor(), T.Normalize([0.5]*3, [0.5]*3)])
self.video_tf = T.Compose([T.ToPILImage(), T.Resize((128,256)), T.ToTensor(), T.Normalize([0.5]*3, [0.5]*3)])
self.max_frames = max_frames
def __len__(self): return len(self.df)
def _maybe_dl(self, fname): return str(Path(self.root)/fname) if self.root else download_from_hf(self.source, fname)
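    # __getitem__: images become one normalized 512x512 tensor; videos are sampled
    # to max_frames evenly spaced frames, each resized to 128x256.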
def __getitem__(self, i):
rec = self.df.iloc[i]
p = Path(self._maybe_dl(rec["file_name"]))
if p.suffix.lower() in IMAGE_EXTS:
img = torchvision.io.read_image(str(p))
if isinstance(img, torch.Tensor): img = img.permute(1,2,0).numpy()
return {"type": "image", "image": self.img_tf(img), "caption": rec["text"]}
elif p.suffix.lower() in VIDEO_EXTS:
vid,_,_ = torchvision.io.read_video(str(p))
total, idxs = len(vid), []
            if total == 0: return {"type": "video", "frames": torch.zeros((self.max_frames, 3, 128, 256)), "caption": rec["text"]}
if total < self.max_frames: idxs = list(range(total))+[total-1]*(self.max_frames-total)
else: idxs = np.linspace(0,total-1,self.max_frames).round().astype(int)
frames = torch.stack([self.video_tf(vid[j].numpy()) for j in idxs])
return {"type": "video", "frames": frames, "caption": rec["text"]}
else: raise RuntimeError(f"Unsupported file {p}")
# ---------------- Dynamic pipeline loader ----------------
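# The pipeline class is picked from the model id string: "chrono" -> ChronoEditPipeline,
# "qwen" -> QwenImageEditPipeline (when the optional imports above are available),
# anything else falls back to DiffusionPipeline auto-detection.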
def load_pipeline_auto(base_model, dtype=torch.float16):
low = base_model.lower()
if "chrono" in low and CHRONOEDIT_AVAILABLE:
print(f"Using ChronoEdit pipeline for {base_model}")
return ChronoEditPipeline.from_pretrained(base_model, torch_dtype=dtype)
elif "qwen" in low and QWENEDIT_AVAILABLE:
print(f"Using QwenEdit pipeline for {base_model}")
return QwenImageEditPipeline.from_pretrained(base_model, torch_dtype=dtype)
else:
print(f"Using Diffusion pipeline for {base_model}")
return DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype)
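# Map task type / model name to the pipeline attribute that receives the LoRA adapter:
# prompt-lora or Qwen -> text_encoder, video / Chrono / Wan -> transformer, else unet.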
def infer_target_for_task(task_type, model_name):
if task_type == "prompt-lora" or "qwen" in model_name.lower():
return "text_encoder"
elif task_type == "text-video" or "chrono" in model_name.lower() or "wan" in model_name.lower():
return "transformer"
else:
return "unet"
def find_target_modules(model):
names = [n for n,_ in model.named_modules()]
    targets = sorted({n.split(".")[-1] for n in names if any(k in n for k in ["to_q", "to_k", "to_v", "q_proj", "v_proj"])})
    return targets or ["to_q", "to_k", "to_v", "to_out"]
# ---------------- Training ----------------
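# train_lora wraps the selected pipeline component in a PEFT LoRA adapter and runs a
# minimal AdamW loop. Note that the image loss below is a simplified stand-in
# (MSE between VAE latents and random noise), not a full denoising objective.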
def train_lora(base_model, dataset_src, csv_name, task_type, output_dir, epochs=1, lr=1e-4, r=8, alpha=16):
    accelerator = Accelerator()
    pipe = load_pipeline_auto(base_model)
    pipe.to(accelerator.device)  # the VAE/encoders must sit on the training device for the loss below
target = infer_target_for_task(task_type, base_model)
if not hasattr(pipe, target): raise RuntimeError(f"Pipeline has no {target}")
lcfg = LoraConfig(r=r, lora_alpha=alpha, target_modules=find_target_modules(getattr(pipe, target)), lora_dropout=0.0)
lora_module = get_peft_model(getattr(pipe, target), lcfg)
dataset = MediaTextDataset(dataset_src, csv_name)
    loader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=lambda batch: batch[0])  # identity collate: keep each sample as a plain dict
lora_module, opt, loader = accelerator.prepare(lora_module, torch.optim.AdamW(lora_module.parameters(), lr=lr), loader)
mse = nn.MSELoss(); logs=[]
for ep in range(epochs):
for i,b in enumerate(tqdm(loader, desc=f"Epoch {ep+1}")):
            ex = b  # collate_fn above yields the raw sample dict
            loss = None
            if ex["type"] == "image" and hasattr(pipe, "vae"):
                # Simplified stand-in objective: MSE between VAE latents and random noise.
                img = ex["image"].unsqueeze(0).to(accelerator.device, dtype=pipe.vae.dtype)
                lat = pipe.vae.encode(img).latent_dist.sample() * pipe.vae.config.scaling_factor
                noise = torch.randn_like(lat)
                loss = mse(lat, noise)
            if loss is not None and loss.requires_grad:
                accelerator.backward(loss); opt.step(); opt.zero_grad()
                logs.append(f"step {i} loss {loss.item():.4f}")
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    accelerator.unwrap_model(lora_module).save_pretrained(output_dir)
return output_dir, logs[-20:]
# ---------------- Upload ----------------
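# Push the saved adapter folder to the Hub; requires HF_TOKEN in the environment.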
def upload_adapter(local, repo_id):
token=os.environ.get("HF_TOKEN")
if not token: raise RuntimeError("HF_TOKEN missing")
    create_repo(repo_id, exist_ok=True, token=token)
    upload_folder(folder_path=local, repo_id=repo_id, repo_type="model", token=token)
return f"https://huggingface.co/{repo_id}"
# ---------------- Gradio UI ----------------
def run_ui():
with gr.Blocks() as demo:
gr.Markdown("# π Universal Dynamic LoRA Trainer (Flux / ChronoEdit / QwenEdit)")
with gr.Row():
base_model=gr.Textbox(label="Base model", value="black-forest-labs/FLUX.1-dev")
dataset=gr.Textbox(label="Dataset folder or HF repo", value="./dataset")
csvname=gr.Textbox(label="CSV/Parquet file", value="dataset.csv")
task=gr.Dropdown(["text-image","text-video","prompt-lora"], label="Task type", value="text-image")
out=gr.Textbox(label="Output dir", value="./adapter_out")
repo=gr.Textbox(label="Upload HF repo (optional)", value="")
with gr.Row():
r=gr.Slider(1,64,value=8,label="LoRA rank"); a=gr.Slider(1,64,value=16,label="LoRA alpha")
ep=gr.Number(value=1,label="Epochs"); lr=gr.Number(value=1e-4,label="Learning rate")
        btn = gr.Button("Start Training")
        logs = gr.Textbox(label="Logs", lines=12)
        img = gr.Image(label="Sample Output (optional)")
        link = gr.Textbox(label="Uploaded adapter URL")
def launch(bm,ds,csv,t,out_dir,r_,a_,ep_,lr_,repo_):
try:
out,log=train_lora(bm,ds,csv,t,out_dir,int(ep_),float(lr_),int(r_),int(a_))
link=upload_adapter(out,repo_) if repo_ else None
return "\n".join(log), None, link
except Exception as e:
return f"β {e}", None, None
        btn.click(launch, [base_model, dataset, csvname, task, out, r, a, ep, lr, repo], [logs, img, link])
return demo
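# Example launch (hf_xxx is a placeholder token; Gradio then serves on port 7860):
#   HF_TOKEN=hf_xxx python universal_lora_trainer_accelerate_singlefile_dynamic.py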
if __name__=="__main__":
run_ui().launch(server_name="0.0.0.0",server_port=7860)