# universal_lora_trainer_accelerate_singlefile_dynamic.py
"""
Universal Dynamic LoRA Trainer (Accelerate + PEFT + Gradio)
- Auto-detects base model type (Flux, SD, ChronoEdit, QwenEdit, etc.)
- Auto-selects correct adapter target (unet, transformer, text_encoder)
- Supports CSV and Parquet datasets
- Uploads adapter to HF Hub using HF_TOKEN (env only)
"""

import os, torch, gradio as gr, pandas as pd, numpy as np
from pathlib import Path
from tqdm.auto import tqdm
from huggingface_hub import create_repo, upload_folder, hf_hub_download
from diffusers import DiffusionPipeline
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T, torchvision
from peft import LoraConfig, get_peft_model
from accelerate import Accelerator
import torch.nn as nn

# Optional: ChronoEdit + QwenEdit
try:
    from chronoedit_diffusers.pipeline_chronoedit import ChronoEditPipeline
    CHRONOEDIT_AVAILABLE = True
except Exception:
    CHRONOEDIT_AVAILABLE = False

try:
    from qwenimage.pipeline_qwenimage_edit_plus import QwenImageEditPipeline
    QWENEDIT_AVAILABLE = True
except Exception:
    QWENEDIT_AVAILABLE = False

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".webp"}
VIDEO_EXTS = {".mp4", ".mov", ".avi", ".mkv"}

def is_hub_repo_like(s): return "/" in s and not Path(s).exists()

def download_from_hf(repo_id, filename, token=None):
    token = token or os.environ.get("HF_TOKEN")
    return hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset", token=token)

# ---------------- Dataset ----------------
class MediaTextDataset(Dataset):
    def __init__(self, source, csv_name="dataset.csv", max_frames=5):
        self.is_hub = is_hub_repo_like(source)
        self.source = source
        token = os.environ.get("HF_TOKEN")
        if self.is_hub:
            try:
                file_path = download_from_hf(source, csv_name, token)
            except Exception:
                # The .csv may not exist on the Hub; fall back to a Parquet file of the same name.
                file_path = download_from_hf(source, csv_name.replace(".csv", ".parquet"), token)
        else:
            file_path = Path(source) / csv_name
            if not file_path.exists():
                alt = file_path.with_suffix(".parquet")
                if alt.exists(): file_path = alt
        self.df = pd.read_parquet(file_path) if str(file_path).endswith(".parquet") else pd.read_csv(file_path)
        self.root = Path(source) if not self.is_hub else None
        self.img_tf = T.Compose([T.ToPILImage(), T.Resize((512,512)), T.ToTensor(), T.Normalize([0.5]*3, [0.5]*3)])
        self.video_tf = T.Compose([T.ToPILImage(), T.Resize((128,256)), T.ToTensor(), T.Normalize([0.5]*3, [0.5]*3)])
        self.max_frames = max_frames
    def __len__(self): return len(self.df)
    def _maybe_dl(self, fname): return str(Path(self.root)/fname) if self.root else download_from_hf(self.source, fname)
    def __getitem__(self, i):
        rec = self.df.iloc[i]
        p = Path(self._maybe_dl(rec["file_name"]))
        if p.suffix.lower() in IMAGE_EXTS:
            # Force 3-channel RGB decoding so Normalize([0.5]*3, ...) always matches.
            img = torchvision.io.read_image(str(p), mode=torchvision.io.ImageReadMode.RGB)
            return {"type": "image", "image": self.img_tf(img.permute(1,2,0).numpy()), "caption": rec["text"]}
        elif p.suffix.lower() in VIDEO_EXTS:
            vid,_,_ = torchvision.io.read_video(str(p))
            total = len(vid)
            if total == 0:
                return {"type": "video", "frames": torch.zeros((self.max_frames,3,128,256)), "caption": rec["text"]}
            # Pad short clips by repeating the last frame; subsample long ones uniformly.
            if total < self.max_frames: idxs = list(range(total)) + [total-1]*(self.max_frames-total)
            else: idxs = np.linspace(0, total-1, self.max_frames).round().astype(int)
            frames = torch.stack([self.video_tf(vid[j].numpy()) for j in idxs])
            return {"type": "video", "frames": frames, "caption": rec["text"]}
        else: raise RuntimeError(f"Unsupported file {p}")
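
# Quick sanity check (hypothetical local folder laid out as in the header comment):
#   ds = MediaTextDataset("./dataset")
#   sample = ds[0]
#   print(sample["type"], sample["caption"])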

# ---------------- Dynamic pipeline loader ----------------
def load_pipeline_auto(base_model, dtype=None):
    # fp16 only makes sense on GPU; fall back to fp32 for CPU runs.
    dtype = dtype or (torch.float16 if torch.cuda.is_available() else torch.float32)
    low = base_model.lower()
    if "chrono" in low and CHRONOEDIT_AVAILABLE:
        print(f"Using ChronoEdit pipeline for {base_model}")
        return ChronoEditPipeline.from_pretrained(base_model, torch_dtype=dtype)
    elif "qwen" in low and QWENEDIT_AVAILABLE:
        print(f"Using QwenEdit pipeline for {base_model}")
        return QwenImageEditPipeline.from_pretrained(base_model, torch_dtype=dtype)
    else:
        print(f"Using generic DiffusionPipeline for {base_model}")
        return DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype)
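
# Routing is a simple substring heuristic on the repo id, e.g. (illustrative names)
# "nvidia/ChronoEdit-14B" -> ChronoEditPipeline, "Qwen/Qwen-Image-Edit" -> QwenImageEditPipeline;
# anything else relies on DiffusionPipeline's own pipeline-class auto-detection.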

def infer_target_for_task(task_type, model_name):
    low = model_name.lower()
    if task_type == "prompt-lora" or "qwen" in low:
        return "text_encoder"
    # Flux and the video models (ChronoEdit, Wan) expose a DiT-style `transformer`, not a UNet.
    elif task_type == "text-video" or any(k in low for k in ["chrono", "wan", "flux"]):
        return "transformer"
    else:
        return "unet"

def find_target_modules(model):
    # Collect the deduplicated leaf names of attention projections; PEFT matches
    # target_modules as name suffixes. For an SD UNet this typically yields
    # ["to_k", "to_q", "to_v"]; for CLIP-style text encoders, ["q_proj", "v_proj"].
    names = [n for n,_ in model.named_modules()]
    targets = sorted({n.split(".")[-1] for n in names if any(k in n for k in ["to_q","to_k","to_v","q_proj","v_proj"])})
    return targets or ["to_q","to_k","to_v","to_out.0"]

# ---------------- Training ----------------
def train_lora(base_model, dataset_src, csv_name, task_type, output_dir, epochs=1, lr=1e-4, r=8, alpha=16):
    accelerator = Accelerator()
    pipe = load_pipeline_auto(base_model)
    pipe.to(DEVICE)
    target = infer_target_for_task(task_type, base_model)
    if not hasattr(pipe, target): raise RuntimeError(f"Pipeline has no {target}")
    lcfg = LoraConfig(r=r, lora_alpha=alpha, target_modules=find_target_modules(getattr(pipe, target)), lora_dropout=0.0)
    lora_module = get_peft_model(getattr(pipe, target), lcfg)
    dataset = MediaTextDataset(dataset_src, csv_name)
    # Identity collate: each batch is one sample dict (default collation cannot stack the mixed image/video fields).
    loader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=lambda batch: batch[0])
    opt = torch.optim.AdamW(lora_module.parameters(), lr=lr)
    lora_module, opt, loader = accelerator.prepare(lora_module, opt, loader)
    mse = nn.MSELoss(); logs = []
    for ep in range(epochs):
        for i, ex in enumerate(tqdm(loader, desc=f"Epoch {ep+1}")):
            loss = torch.tensor(0.0, device=DEVICE, requires_grad=True)
            if ex["type"] == "image" and hasattr(pipe, "vae"):
                img = ex["image"].unsqueeze(0).to(DEVICE, dtype=pipe.vae.dtype)
                lat = pipe.vae.encode(img).latent_dist.sample() * pipe.vae.config.scaling_factor
                noise = torch.randn_like(lat)
                if target == "unet" and hasattr(pipe, "tokenizer") and hasattr(pipe, "text_encoder") and hasattr(getattr(pipe, "scheduler", None), "add_noise"):
                    # Classic epsilon-prediction step, so gradients actually reach the LoRA-wrapped UNet.
                    t = torch.randint(0, pipe.scheduler.config.num_train_timesteps, (1,), device=DEVICE)
                    ids = pipe.tokenizer(ex["caption"], padding="max_length", truncation=True,
                                         max_length=pipe.tokenizer.model_max_length, return_tensors="pt").input_ids.to(DEVICE)
                    with torch.no_grad():
                        emb = pipe.text_encoder(ids)[0]
                    pred = lora_module(pipe.scheduler.add_noise(lat, noise, t), t, encoder_hidden_states=emb).sample
                    loss = mse(pred.float(), noise.float())
                else:
                    # Placeholder objective for pipelines without the classic UNet interface;
                    # it keeps the loop running but does not update the LoRA weights.
                    loss = mse(lat.float(), noise.float())
            accelerator.backward(loss); opt.step(); opt.zero_grad()
            logs.append(f"step {i} loss {loss.item():.4f}")
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    lora_module.save_pretrained(output_dir)
    return output_dir, logs[-20:]
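
# Headless example (illustrative model id and paths; export HF_TOKEN first when the
# dataset or the upload target lives on the Hub):
#   out_dir, logs = train_lora("stable-diffusion-v1-5/stable-diffusion-v1-5",
#                              "./dataset", "dataset.csv", "text-image",
#                              "./adapter_out", epochs=1, lr=1e-4, r=8, alpha=16)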

# ---------------- Upload ----------------
def upload_adapter(local, repo_id):
    token=os.environ.get("HF_TOKEN")
    if not token: raise RuntimeError("HF_TOKEN missing")
    create_repo(repo_id, token=token, exist_ok=True)
    # upload_folder takes keyword-only arguments.
    upload_folder(folder_path=local, repo_id=repo_id, repo_type="model", token=token)
    return f"https://huggingface.co/{repo_id}"
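
# The token is read from the environment only, never from the UI, e.g.:
#   HF_TOKEN=hf_xxx python universal_lora_trainer_accelerate_singlefile_dynamic.py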

# ---------------- Gradio UI ----------------
def run_ui():
    with gr.Blocks() as demo:
        gr.Markdown("# 🌐 Universal Dynamic LoRA Trainer (Flux / ChronoEdit / QwenEdit)")
        with gr.Row():
            base_model=gr.Textbox(label="Base model", value="black-forest-labs/FLUX.1-dev")
            dataset=gr.Textbox(label="Dataset folder or HF repo", value="./dataset")
            csvname=gr.Textbox(label="CSV/Parquet file", value="dataset.csv")
            task=gr.Dropdown(["text-image","text-video","prompt-lora"], label="Task type", value="text-image")
            out=gr.Textbox(label="Output dir", value="./adapter_out")
            repo=gr.Textbox(label="Upload HF repo (optional)", value="")
        with gr.Row():
            r=gr.Slider(1,64,value=8,label="LoRA rank"); a=gr.Slider(1,64,value=16,label="LoRA alpha")
            ep=gr.Number(value=1,label="Epochs"); lr=gr.Number(value=1e-4,label="Learning rate")
            btn=gr.Button("πŸš€ Start Training")
        logs=gr.Textbox(label="Logs", lines=12)
        img=gr.Image(label="Sample Output (optional)")
        link=gr.Textbox(label="Uploaded repo link")
        def launch(bm,ds,csv,t,out_dir,r_,a_,ep_,lr_,repo_):
            try:
                adapter_dir,log=train_lora(bm,ds,csv,t,out_dir,int(ep_),float(lr_),int(r_),int(a_))
                url=upload_adapter(adapter_dir,repo_) if repo_ else None
                return "\n".join(log), None, url
            except Exception as e:
                return f"❌ {e}", None, None
        # Use the declared `link` textbox instead of instantiating an unlabeled component in the outputs list.
        btn.click(launch,[base_model,dataset,csvname,task,out,r,a,ep,lr,repo],[logs,img,link])
    return demo

if __name__=="__main__":
    run_ui().launch(server_name="0.0.0.0",server_port=7860)