Create app_quant.py

app_quant.py (new file)
# universal_lora_trainer_quant_dynamic.py
"""
Universal Dynamic LoRA Trainer (Accelerate + PEFT) with optional QLoRA 4-bit support.

- Supports CSV and Parquet dataset files (columns: file_name, text)
- Accepts dataset from a local folder or Hugging Face dataset repo id (username/repo)
- Real LoRA training (PEFT) for:
    * text->image (UNet)
    * text->video (ChronoEdit transformer)
    * prompt-enhancer (text_encoder / QwenEdit)
- Optional:
    * 4-bit quantization (bitsandbytes / QLoRA)
    * xFormers / FlashAttention
    * AdaLoRA (if available)
- Uses HF_TOKEN from environment for upload
- Use `accelerate launch` for multi-GPU / optimized run
"""

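# ---------------------------------------------------------------
# Illustrative usage notes (examples only; adjust ids/paths to your setup).
# The docstring above recommends `accelerate launch`; a typical invocation is:
#
#   accelerate launch app_quant.py      # multi-GPU / optimized run
#   python app_quant.py                 # plain single-process Gradio UI
#
# Expected dataset layout (local folder or HF dataset repo). The file names
# below are hypothetical placeholders; only the two columns are required:
#
#   dataset/
#     dataset.csv          # columns: file_name, text
#     img_001.png
#     clip_001.mp4
#
#   dataset.csv:
#     file_name,text
#     img_001.png,"a red bicycle leaning against a brick wall"
#     clip_001.mp4,"a timelapse of clouds over a mountain"
# ---------------------------------------------------------------
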
import os
import math
import tempfile
from pathlib import Path
from typing import Optional, Tuple, List

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.transforms as T
import pandas as pd
import numpy as np
import gradio as gr
from tqdm.auto import tqdm

from huggingface_hub import create_repo, upload_folder, hf_hub_download, list_repo_files

from diffusers import DiffusionPipeline

# optional pip installs - guard imports
try:
    from chronoedit_diffusers.pipeline_chronoedit import ChronoEditPipeline
    CHRONOEDIT_AVAILABLE = True
except Exception:
    CHRONOEDIT_AVAILABLE = False

# Qwen image edit optional
try:
    from qwenimage.pipeline_qwenimage_edit_plus import QwenImageEditPipeline  # optional name
    QWENEDIT_AVAILABLE = True
except Exception:
    QWENEDIT_AVAILABLE = False

# BitsAndBytes (quantization)
try:
    from transformers import BitsAndBytesConfig
    BNB_AVAILABLE = True
except Exception:
    BitsAndBytesConfig = None
    BNB_AVAILABLE = False

# xFormers
try:
    import xformers  # noqa
    XFORMERS_AVAILABLE = True
except Exception:
    XFORMERS_AVAILABLE = False

# PEFT / AdaLoRA
try:
    from peft import LoraConfig, get_peft_model
    try:
        from peft import AdaLoraConfig  # optional
        ADALORA_AVAILABLE = True
    except Exception:
        AdaLoraConfig = None
        ADALORA_AVAILABLE = False
except Exception as e:
    raise RuntimeError("Install peft: pip install peft") from e

# Accelerate
try:
    from accelerate import Accelerator
except Exception as e:
    raise RuntimeError("Install accelerate: pip install accelerate") from e

# ------------------------
# Config
# ------------------------
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".webp", ".bmp"}
VIDEO_EXTS = {".mp4", ".mov", ".avi", ".mkv"}

# ------------------------
# Utilities
# ------------------------
def is_hub_repo_like(s: str) -> bool:
    return "/" in s and not Path(s).exists()

def download_from_hf(repo_id: str, filename: str, token: Optional[str] = None, repo_type: str = "dataset") -> str:
    token = token or os.environ.get("HF_TOKEN")
    return hf_hub_download(repo_id=repo_id, filename=filename, token=token, repo_type=repo_type)

def try_list_repo_files(repo_id: str, repo_type: str = "dataset", token: Optional[str] = None):
    token = token or os.environ.get("HF_TOKEN")
    try:
        return list_repo_files(repo_id, token=token, repo_type=repo_type)
    except Exception:
        return []

def find_target_modules(model, candidates=("q_proj", "k_proj", "v_proj", "o_proj", "to_q", "to_k", "to_v", "proj_out", "to_out")):
    # Collect LoRA-injectable leaf modules (Linear/Conv) whose names contain a candidate substring.
    selected = set()
    for name, module in model.named_modules():
        if not isinstance(module, (nn.Linear, nn.Conv2d)):
            continue
        if not any(cand in name for cand in candidates):
            continue
        parts = name.split(".")
        leaf = parts[-1]
        # keep the parent name for indexed children, e.g. "to_out.0" instead of the ambiguous "0"
        if leaf.isdigit() and len(parts) >= 2:
            leaf = ".".join(parts[-2:])
        selected.add(leaf)
    if not selected:
        return ["to_q", "to_k", "to_v", "to_out"]
    return sorted(selected)

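# Example (illustrative, not a guaranteed output): for a typical Stable Diffusion
# UNet, find_target_modules(pipe.unet) tends to pick up the attention projections,
# e.g. something like ["to_k", "to_out.0", "to_q", "to_v"]; the exact set depends
# entirely on the base model's module names.
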
# ------------------------
# Dataset class (CSV/Parquet)
# ------------------------
class MediaTextDataset(Dataset):
    """
    Loads records from CSV or Parquet with columns:
      - file_name (relative path in folder or filename inside HF dataset repo)
      - text
    """
    def __init__(self, dataset_source: str, csv_name: str = "dataset.csv", max_frames: int = 5,
                 image_size=(512, 512), video_frame_size=(128, 256), hub_token: Optional[str] = None):
        self.source = dataset_source
        self.is_hub = is_hub_repo_like(dataset_source)
        self.max_frames = max_frames
        self.image_size = image_size
        self.video_frame_size = video_frame_size
        self.hub_token = hub_token or os.environ.get("HF_TOKEN")

        # load dataframe (CSV or Parquet)
        if self.is_hub:
            # try CSV then Parquet; specify repo_type="dataset"
            searched = try_list_repo_files(self.source, repo_type="dataset", token=self.hub_token)
            # prefer the exact csv_name
            try:
                csv_local = download_from_hf(self.source, csv_name, token=self.hub_token, repo_type="dataset")
            except Exception:
                # try the .parquet variant
                alt = csv_name.replace(".csv", ".parquet") if csv_name.endswith(".csv") else csv_name + ".parquet"
                csv_local = download_from_hf(self.source, alt, token=self.hub_token, repo_type="dataset")
            if str(csv_local).endswith(".parquet"):
                df = pd.read_parquet(csv_local)
            else:
                df = pd.read_csv(csv_local)
            self.df = df
            self.root = None
        else:
            root = Path(dataset_source)
            csv_path = root / csv_name
            parquet_path = root / csv_name.replace(".csv", ".parquet") if csv_name.endswith(".csv") else root / (csv_name + ".parquet")
            if csv_path.exists():
                self.df = pd.read_csv(csv_path)
            elif parquet_path.exists():
                self.df = pd.read_parquet(parquet_path)
            else:
                p = root / csv_name
                if p.exists():
                    if p.suffix.lower() == ".parquet":
                        self.df = pd.read_parquet(p)
                    else:
                        self.df = pd.read_csv(p)
                else:
                    raise FileNotFoundError(f"Can't find {csv_name} in {dataset_source}")
            self.root = root

        # transforms
        self.image_transform = T.Compose([T.ToPILImage(), T.Resize(image_size), T.ToTensor(), T.Normalize([0.5] * 3, [0.5] * 3)])
        self.video_transform = T.Compose([T.ToPILImage(), T.Resize(video_frame_size), T.ToTensor(), T.Normalize([0.5] * 3, [0.5] * 3)])

    def __len__(self):
        return len(self.df)

    def _maybe_download_from_hub(self, file_name: str) -> str:
        if self.root is not None:
            p = self.root / file_name
            if p.exists():
                return str(p)
        # else download from the dataset repo
        return download_from_hf(self.source, file_name, token=self.hub_token, repo_type="dataset")

    def _read_video_frames(self, path: str, num_frames: int):
        video_frames, _, _ = torchvision.io.read_video(str(path), pts_unit="sec")
        total = len(video_frames)
        if total == 0:
            C, H, W = 3, self.video_frame_size[0], self.video_frame_size[1]
            return torch.zeros((num_frames, C, H, W), dtype=torch.float32)
        if total < num_frames:
            idxs = list(range(total)) + [total - 1] * (num_frames - total)
        else:
            idxs = np.linspace(0, total - 1, num_frames).round().astype(int).tolist()
        frames = []
        for i in idxs:
            arr = video_frames[i].numpy() if hasattr(video_frames[i], "numpy") else np.array(video_frames[i])
            frames.append(self.video_transform(arr))
        frames = torch.stack(frames, dim=0)
        return frames

    def __getitem__(self, idx):
        rec = self.df.iloc[idx]
        file_name = rec["file_name"]
        caption = rec["text"]
        if self.is_hub:
            local_path = self._maybe_download_from_hub(file_name)
        else:
            local_path = str(Path(self.root) / file_name)
        p = Path(local_path)
        suffix = p.suffix.lower()
        if suffix in IMAGE_EXTS:
            # force 3-channel RGB so the Normalize([0.5]*3, ...) transform always matches
            img = torchvision.io.read_image(local_path, mode=torchvision.io.ImageReadMode.RGB)  # [C,H,W]
            if isinstance(img, torch.Tensor):
                img = img.permute(1, 2, 0).numpy()
            return {"type": "image", "image": self.image_transform(img), "caption": caption, "file_name": file_name}
        elif suffix in VIDEO_EXTS:
            frames = self._read_video_frames(local_path, self.max_frames)  # [T,C,H,W]
            return {"type": "video", "frames": frames, "caption": caption, "file_name": file_name}
        else:
            raise RuntimeError(f"Unsupported media type: {local_path}")

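# Illustrative sketch (never called): how the dataset is consumed during training.
# "./dataset" and "dataset.csv" are the UI defaults, used here only as example values.
def _example_dataset_usage():
    ds = MediaTextDataset("./dataset", csv_name="dataset.csv", max_frames=5)
    dl = DataLoader(ds, batch_size=1, shuffle=True, collate_fn=lambda x: x)
    for batch in dl:
        example = batch[0]  # identity collate: a list holding one record dict
        if example["type"] == "image":
            print(example["file_name"], tuple(example["image"].shape), example["caption"])
        else:
            print(example["file_name"], tuple(example["frames"].shape), example["caption"])
        break
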
# ------------------------
# Pipeline loader with optional quantization
# ------------------------
def load_pipeline_auto(base_model_id: str, use_4bit: bool = False, bnb_config: Optional[object] = None, torch_dtype=torch.float16):
    low = base_model_id.lower()
    is_chrono = "chrono" in low or "wan" in low or "video" in low
    is_qwen = "qwen" in low or "qwenimage" in low
    # choose pipeline
    if is_chrono and CHRONOEDIT_AVAILABLE:
        print("Loading ChronoEdit pipeline")
        # ChronoEdit may not accept a quant config; use the safer call
        if use_4bit and bnb_config is not None:
            pipe = ChronoEditPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16)  # quantized loading of ChronoEdit is not widely supported
        else:
            pipe = ChronoEditPipeline.from_pretrained(base_model_id, torch_dtype=torch_dtype)
    elif is_qwen and QWENEDIT_AVAILABLE:
        print("Loading Qwen image-edit pipeline")
        pipe = QwenImageEditPipeline.from_pretrained(base_model_id, torch_dtype=torch_dtype)
    else:
        # fall back to DiffusionPipeline - supports quantization_config for diffusers + transformers
        print("Loading standard DiffusionPipeline:", base_model_id, "use_4bit=", use_4bit)
        if use_4bit and BNB_AVAILABLE and bnb_config is not None:
            pipe = DiffusionPipeline.from_pretrained(base_model_id, quantization_config=bnb_config, torch_dtype=torch.float16)
        else:
            pipe = DiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch_dtype)
    return pipe

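# Illustrative sketch (never called): loading the base pipeline with the same NF4
# 4-bit settings the trainer builds when "Enable 4-bit QLoRA" is ticked. Requires
# bitsandbytes; the model id is the UI default and only an example value.
def _example_load_quantized_pipeline():
    bnb_conf = None
    if BNB_AVAILABLE:
        bnb_conf = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )
    return load_pipeline_auto("runwayml/stable-diffusion-v1-5",
                              use_4bit=bnb_conf is not None,
                              bnb_config=bnb_conf)
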
# ------------------------
# Auto-infer adapter target
# ------------------------
def infer_target_for_task(task_type: str, model_name: str) -> str:
    low = model_name.lower()
    if task_type == "prompt-lora" or "qwen" in low or "qwenedit" in low:
        return "text_encoder"
    if task_type == "text-video" or "chrono" in low or "wan" in low:
        return "transformer"
    # default
    return "unet"

# ------------------------
# LoRA attach (supports AdaLoRA if available)
# ------------------------
def attach_lora(pipe, adapter_target: str, r: int = 8, alpha: int = 16, dropout: float = 0.0, use_adalora: bool = False):
    if adapter_target == "unet":
        if not hasattr(pipe, "unet"):
            raise RuntimeError("Pipeline has no UNet to attach LoRA to")
        target_module = pipe.unet
        attr = "unet"
    elif adapter_target == "transformer":
        if not hasattr(pipe, "transformer"):
            raise RuntimeError("Pipeline has no transformer to attach LoRA to")
        target_module = pipe.transformer
        attr = "transformer"
    elif adapter_target == "text_encoder":
        if not hasattr(pipe, "text_encoder"):
            raise RuntimeError("Pipeline has no text_encoder for prompt LoRA")
        target_module = pipe.text_encoder
        attr = "text_encoder"
    else:
        raise RuntimeError(f"Unknown adapter_target: {adapter_target}")

    target_modules = find_target_modules(target_module)
    print("Detected target_modules for LoRA:", target_modules)

    if use_adalora and ADALORA_AVAILABLE:
        lora_config = AdaLoraConfig(
            r=r,
            lora_alpha=alpha,
            target_modules=target_modules,
            init_r=4,
            lora_dropout=dropout,
        )
    else:
        # no task_type here: these are diffusion/encoder modules, not a Seq2Seq LM
        lora_config = LoraConfig(
            r=r,
            lora_alpha=alpha,
            target_modules=target_modules,
            lora_dropout=dropout,
            bias="none",
        )

    peft_model = get_peft_model(target_module, lora_config)
    setattr(pipe, attr, peft_model)
    return pipe, attr

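# Illustrative sketch (never called): attaching LoRA to a UNet and checking how many
# parameters actually require gradients. The model id is the UI default, example only.
def _example_attach_lora():
    pipe = load_pipeline_auto("runwayml/stable-diffusion-v1-5")
    pipe, attr = attach_lora(pipe, adapter_target="unet", r=8, alpha=16)
    peft_module = getattr(pipe, attr)
    trainable = sum(p.numel() for p in peft_module.parameters() if p.requires_grad)
    total = sum(p.numel() for p in peft_module.parameters())
    print(f"trainable LoRA params: {trainable} / {total}")
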
# ------------------------
# Training loop (Accelerate-aware)
# ------------------------
def train_lora_accelerate(base_model_id: str,
                          dataset_source: str,
                          csv_name: str,
                          task_type: str,
                          adapter_target_override: Optional[str],
                          output_dir: str,
                          epochs: int = 1,
                          batch_size: int = 1,
                          lr: float = 1e-4,
                          max_train_steps: Optional[int] = None,
                          lora_r: int = 8,
                          lora_alpha: int = 16,
                          use_4bit: bool = False,
                          enable_xformers: bool = False,
                          use_adalora: bool = False,
                          gradient_accumulation_steps: int = 1,
                          mixed_precision: Optional[str] = None,
                          save_every_steps: int = 200,
                          max_frames: int = 5):

    # Set up the Accelerator (no trailing comma: that would turn this into a 1-tuple)
    accelerator = Accelerator(
        mixed_precision=mixed_precision or ("fp16" if torch.cuda.is_available() else "no"),
        gradient_accumulation_steps=gradient_accumulation_steps,
    )
    device = accelerator.device

    # prepare the bitsandbytes config if requested
    bnb_conf = None
    if use_4bit and BNB_AVAILABLE:
        bnb_conf = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )

    # Load the pipeline (supports quantization for standard diffusers pipelines)
    pipe = load_pipeline_auto(base_model_id, use_4bit=use_4bit, bnb_config=bnb_conf, torch_dtype=torch.float16 if device.type == "cuda" else torch.float32)

    # move non-quantized modules (VAE, text encoder, ...) to the training device
    try:
        pipe.to(device)
    except Exception as e:
        print("Could not move pipeline to device:", e)

    # optionally enable memory-efficient attention
    if enable_xformers:
        try:
            if hasattr(pipe, "enable_xformers_memory_efficient_attention"):
                pipe.enable_xformers_memory_efficient_attention()
            elif hasattr(pipe, "enable_attention_slicing"):
                pipe.enable_attention_slicing()
            print("xFormers / memory-efficient attention enabled.")
        except Exception as e:
            print("Could not enable xformers:", e)

    # infer the adapter target automatically if not overridden
    adapter_target = adapter_target_override if adapter_target_override else infer_target_for_task(task_type, base_model_id)
    print("Adapter target set to:", adapter_target)

    # attach LoRA
    pipe, attr = attach_lora(pipe, adapter_target, r=lora_r, alpha=lora_alpha, dropout=0.0, use_adalora=use_adalora)
    # pick the PEFT-wrapped module for optimization
    peft_module = getattr(pipe, attr)

    # dataset + dataloader (identity collate, one example per step)
    dataset = MediaTextDataset(dataset_source, csv_name=csv_name, max_frames=max_frames)
    dataloader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=lambda x: x)

    # optimizer over the trainable (LoRA) parameters only
    trainable_params = [p for n, p in peft_module.named_parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(trainable_params, lr=lr)

    # prepare objects with the accelerator
    peft_module, optimizer, dataloader = accelerator.prepare(peft_module, optimizer, dataloader)

    # training loop state
    logs = []
    global_step = 0
    loss_fn = nn.MSELoss()

    # scheduler setup if available
    if hasattr(pipe, "scheduler"):
        try:
            pipe.scheduler.set_timesteps(50, device=device)
            timesteps = pipe.scheduler.timesteps
        except Exception:
            timesteps = None
    else:
        timesteps = None

    # Training
    for epoch in range(int(epochs)):
        pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")
        for batch in pbar:
            example = batch[0]
            # image flow
            if example["type"] == "image":
                img = example["image"].unsqueeze(0).to(device)
                caption = [example["caption"]]

                if not hasattr(pipe, "encode_prompt"):
                    raise RuntimeError("Pipeline lacks encode_prompt - cannot encode prompts")

                prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
                    prompt=caption,
                    negative_prompt=None,
                    do_classifier_free_guidance=True,
                    num_videos_per_prompt=1,
                    prompt_embeds=None,
                    negative_prompt_embeds=None,
                    max_sequence_length=512,
                    device=device,
                )

                if not hasattr(pipe, "vae"):
                    raise RuntimeError("Pipeline lacks VAE - required for latent conversion")

                with torch.no_grad():
                    latents = pipe.vae.encode(img.to(device)).latent_dist.sample() * pipe.vae.config.scaling_factor

                noise = torch.randn_like(latents).to(device)
                if timesteps is None:
                    t = torch.tensor(1, device=device)
                else:
                    t = pipe.scheduler.timesteps[torch.randint(0, len(pipe.scheduler.timesteps), (1,)).item()].to(device)
                noisy_latents = pipe.scheduler.add_noise(latents, noise, t)

                # forward through peft_module (unet)
                out = peft_module(noisy_latents, t.expand(noisy_latents.shape[0]), encoder_hidden_states=prompt_embeds)
                if hasattr(out, "sample"):
                    noise_pred = out.sample
                elif isinstance(out, tuple):
                    noise_pred = out[0]
                else:
                    noise_pred = out

                loss = loss_fn(noise_pred, noise)

            else:
                # video flow (ChronoEdit, simplified)
                if not CHRONOEDIT_AVAILABLE:
                    raise RuntimeError("ChronoEdit training requested but not installed in environment")
                frames = example["frames"].unsqueeze(0).to(device)  # [1, T, C, H, W]
                frames_np = frames.squeeze(0).permute(0, 2, 3, 1).cpu().numpy().tolist()
                video_tensor = pipe.video_processor.preprocess(frames_np, height=frames.shape[-2], width=frames.shape[-1]).to(device)
                latents_out = pipe.prepare_latents(video_tensor, batch_size=1, num_channels_latents=pipe.vae.config.z_dim, height=video_tensor.shape[-2], width=video_tensor.shape[-1], num_frames=frames.shape[1], dtype=video_tensor.dtype, device=device, generator=None, latents=None, last_image=None)
                if pipe.config.expand_timesteps:
                    latents, condition, first_frame_mask = latents_out
                else:
                    latents, condition = latents_out
                    first_frame_mask = None

                noise = torch.randn_like(latents).to(device)
                t = pipe.scheduler.timesteps[torch.randint(0, len(pipe.scheduler.timesteps), (1,)).item()].to(device)
                noisy_latents = pipe.scheduler.add_noise(latents, noise, t)

                if pipe.config.expand_timesteps:
                    latent_model_input = (1 - first_frame_mask) * condition + first_frame_mask * noisy_latents
                else:
                    latent_model_input = torch.cat([noisy_latents, condition], dim=1)

                out = peft_module(hidden_states=latent_model_input, timestep=t.unsqueeze(0).expand(latent_model_input.shape[0]), encoder_hidden_states=None, encoder_hidden_states_image=None, return_dict=False)
                noise_pred = out[0] if isinstance(out, tuple) else out
                loss = loss_fn(noise_pred, noise)

            # backward and optimizer step (accelerator)
            accelerator.backward(loss)
            optimizer.step()
            optimizer.zero_grad()
            global_step += 1

            logs.append(f"step {global_step} loss {loss.item():.6f}")
            pbar.set_postfix({"loss": f"{loss.item():.6f}"})

            if max_train_steps and global_step >= max_train_steps:
                break

            if global_step % save_every_steps == 0:
                out_sub = Path(output_dir) / f"lora_step_{global_step}"
                out_sub.mkdir(parents=True, exist_ok=True)
                try:
                    peft_module.save_pretrained(str(out_sub))
                except Exception:
                    torch.save({k: v.cpu() for k, v in peft_module.state_dict().items()}, str(out_sub / "adapter_state_dict.pt"))
                print(f"Saved adapter at {out_sub}")

        if max_train_steps and global_step >= max_train_steps:
            break

    # final save
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    try:
        peft_module.save_pretrained(output_dir)
    except Exception:
        torch.save({k: v.cpu() for k, v in peft_module.state_dict().items()}, str(Path(output_dir) / "adapter_state_dict.pt"))

    return output_dir, logs

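# Illustrative sketch (never called): invoking the trainer directly, bypassing the UI.
# The model id, paths and step budget below are example values only.
def _example_train_call():
    out_dir, logs = train_lora_accelerate(
        base_model_id="runwayml/stable-diffusion-v1-5",
        dataset_source="./dataset",
        csv_name="dataset.csv",
        task_type="text-image",
        adapter_target_override=None,
        output_dir="./adapter_out",
        epochs=1,
        max_train_steps=100,
        lora_r=8,
        lora_alpha=16,
    )
    print("adapter saved to:", out_dir)
    print("\n".join(logs[-5:]))
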
# ------------------------
# Test generation (best-effort)
# ------------------------
def test_generation_load_and_run(base_model_id: str, adapter_dir: Optional[str], adapter_target: str, prompt: str, use_4bit: bool = False):
    # load the base pipeline (no heavy quant config unless requested)
    bnb_conf = None
    if use_4bit and BNB_AVAILABLE:
        bnb_conf = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4")
    pipe = load_pipeline_auto(base_model_id, use_4bit=use_4bit, bnb_config=bnb_conf, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32)

    # attempt to load the adapter into the target module (best-effort)
    try:
        if adapter_target == "unet" and hasattr(pipe, "unet"):
            lcfg = LoraConfig(r=8, lora_alpha=16, target_modules=find_target_modules(pipe.unet))
            pipe.unet = get_peft_model(pipe.unet, lcfg)
            try:
                pipe.unet.load_state_dict(torch.load(Path(adapter_dir) / "pytorch_model.bin"), strict=False)
            except Exception:
                try:
                    pipe.unet.load_adapter(adapter_dir)
                except Exception:
                    pass
        elif adapter_target == "transformer" and hasattr(pipe, "transformer"):
            lcfg = LoraConfig(r=8, lora_alpha=16, target_modules=find_target_modules(pipe.transformer))
            pipe.transformer = get_peft_model(pipe.transformer, lcfg)
        elif adapter_target == "text_encoder" and hasattr(pipe, "text_encoder"):
            lcfg = LoraConfig(r=8, lora_alpha=16, target_modules=find_target_modules(pipe.text_encoder))
            pipe.text_encoder = get_peft_model(pipe.text_encoder, lcfg)
    except Exception as e:
        print("Adapter load warning:", e)

    pipe.to(DEVICE)
    out = pipe(prompt=prompt, num_inference_steps=8)
    if hasattr(out, "images"):
        return out.images[0]
    elif hasattr(out, "frames"):
        frames = out.frames[0]
        from PIL import Image
        return Image.fromarray((frames[-1] * 255).clip(0, 255).astype("uint8"))
    raise RuntimeError("No images/frames returned")

# ------------------------
# Upload adapter to HF Hub
# ------------------------
def upload_adapter(local_dir: str, repo_id: str) -> str:
    token = os.environ.get("HF_TOKEN")
    if token is None:
        raise RuntimeError("HF_TOKEN not set in environment for upload")
    create_repo(repo_id, exist_ok=True, token=token)
    upload_folder(folder_path=local_dir, repo_id=repo_id, repo_type="model", token=token)
    return f"https://huggingface.co/{repo_id}"

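# Illustrative sketch (never called): uploading a trained adapter. HF_TOKEN must be
# set in the environment (e.g. as a Space secret); "username/my-lora-adapter" is a
# placeholder repo id and "./adapter_out" is the UI default output directory.
def _example_upload():
    url = upload_adapter("./adapter_out", "username/my-lora-adapter")
    print("Uploaded to:", url)
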
# ------------------------
# UI: Boost info helper
# ------------------------
def boost_info_text(use_4bit: bool, enable_xformers: bool, mixed_precision: Optional[str], device_type: str):
    lines = []
    lines.append(f"Device: {device_type.upper()}")
    if use_4bit and BNB_AVAILABLE:
        lines.append("4-bit QLoRA enabled: ~4x memory saving (bitsandbytes NF4 + double quant).")
    else:
        lines.append("QLoRA disabled or bitsandbytes not installed.")
    if enable_xformers and XFORMERS_AVAILABLE:
        lines.append("xFormers/FlashAttention: memory-efficient attention enabled (faster & lower memory).")
    else:
        lines.append("xFormers not enabled or not installed.")
    if mixed_precision:
        lines.append(f"Mixed precision: {mixed_precision}")
    else:
        lines.append("Mixed precision: default (no automatic FP16/BF16).")
    lines.append("Expected: GPU + 4-bit + xFormers = fastest. CPU + 4-bit is possible but slow.")
    return "\n".join(lines)

# ------------------------
# Gradio UI wiring
# ------------------------
def run_all_ui(base_model_id: str,
               dataset_source: str,
               csv_name: str,
               task_type: str,
               adapter_target_override: str,
               lora_r: int,
               lora_alpha: int,
               epochs: int,
               batch_size: int,
               lr: float,
               max_train_steps: int,
               output_dir: str,
               upload_repo: str,
               use_4bit: bool,
               enable_xformers: bool,
               use_adalora: bool,
               grad_accum: int,
               mixed_precision: str,
               save_every_steps: int):
    # map task_type -> adapter_target if the override is empty
    adapter_target = adapter_target_override or infer_target_for_task(task_type, base_model_id)
    try:
        out_dir, logs = train_lora_accelerate(
            base_model_id,
            dataset_source,
            csv_name,
            task_type,
            adapter_target,
            output_dir,
            epochs=epochs,
            batch_size=batch_size,
            lr=lr,
            max_train_steps=(max_train_steps if max_train_steps > 0 else None),
            lora_r=lora_r,
            lora_alpha=lora_alpha,
            use_4bit=use_4bit,
            enable_xformers=enable_xformers,
            use_adalora=use_adalora,
            gradient_accumulation_steps=grad_accum,
            mixed_precision=(mixed_precision if mixed_precision != "none" else None),
            save_every_steps=save_every_steps,
        )
    except Exception as e:
        return f"Training failed: {e}", None, None

    link = None
    if upload_repo:
        try:
            link = upload_adapter(out_dir, upload_repo)
        except Exception as e:
            link = f"Upload failed: {e}"

    # quick test generation using the first dataset prompt
    try:
        ds = MediaTextDataset(dataset_source, csv_name=csv_name, max_frames=5)
        test_prompt = ds.df.iloc[0]["text"] if len(ds.df) > 0 else "A cat on a skateboard"
    except Exception:
        test_prompt = "A cat on a skateboard"

    test_img = None
    try:
        test_img = test_generation_load_and_run(base_model_id, out_dir, adapter_target, test_prompt, use_4bit=use_4bit)
    except Exception as e:
        print("Test gen failed:", e)

    return "\n".join(logs[-200:]), test_img, link

def build_ui():
    with gr.Blocks() as demo:
        gr.Markdown("# Universal LoRA Trainer — Quantization & Speedups (single-file)")
        with gr.Row():
            with gr.Column(scale=2):
                base_model = gr.Textbox(label="Base model id (Diffusers / ChronoEdit / Qwen)", value="runwayml/stable-diffusion-v1-5")
                dataset_source = gr.Textbox(label="Dataset folder or HF dataset repo (username/repo)", value="./dataset")
                csv_name = gr.Textbox(label="CSV/Parquet filename", value="dataset.csv")
                task_type = gr.Dropdown(label="Task type", choices=["text-image", "text-video", "prompt-lora"], value="text-image")
                adapter_target_override = gr.Textbox(label="Adapter target override (leave blank for auto)", value="")
                lora_r = gr.Slider(1, 64, value=8, step=1, label="LoRA rank (r)")
                lora_alpha = gr.Slider(1, 128, value=16, step=1, label="LoRA alpha")
                epochs = gr.Number(label="Epochs", value=1)
                batch_size = gr.Number(label="Batch size (per device)", value=1)
                lr = gr.Number(label="Learning rate", value=1e-4)
                max_train_steps = gr.Number(label="Max train steps (0 = unlimited)", value=0)
                save_every_steps = gr.Number(label="Save every steps", value=200)
                output_dir = gr.Textbox(label="Local output dir for adapter", value="./adapter_out")
                upload_repo = gr.Textbox(label="Upload adapter to HF repo (optional, username/repo)", value="")
            with gr.Column(scale=1):
                gr.Markdown("## Speed / Quantization")
                use_4bit = gr.Checkbox(label="Enable 4-bit QLoRA (bitsandbytes)", value=False)
                enable_xformers = gr.Checkbox(label="Enable xFormers / memory efficient attention", value=False)
                use_adalora = gr.Checkbox(label="Use AdaLoRA (if available in peft)", value=False)
                grad_accum = gr.Number(label="Gradient accumulation steps", value=1)
                mixed_precision = gr.Radio(choices=["none", "fp16", "bf16"], value=("fp16" if torch.cuda.is_available() else "none"), label="Mixed precision")
                gr.Markdown("### Boost Info")
                boost_info = gr.Textbox(label="Expected boost / notes", value="", lines=6)
                start_btn = gr.Button("Start Training")
        with gr.Row():
            logs = gr.Textbox(label="Training logs (tail)", lines=18)
            sample_image = gr.Image(label="Sample generated frame after training")
            upload_link = gr.Textbox(label="Upload link / status")

        def on_start(base_model, dataset_source, csv_name, task_type, adapter_target_override, lora_r, lora_alpha, epochs, batch_size, lr, max_train_steps, output_dir, upload_repo, use_4bit_val, enable_xformers_val, use_adalora_val, grad_accum_val, mixed_precision_val, save_every_steps):
            boost_text = boost_info_text(use_4bit_val, enable_xformers_val, mixed_precision_val, "gpu" if torch.cuda.is_available() else "cpu")
            # start training (blocking)
            logs_out, sample, link = run_all_ui(base_model, dataset_source, csv_name, task_type, adapter_target_override, int(lora_r), int(lora_alpha), int(epochs), int(batch_size), float(lr), int(max_train_steps), output_dir, upload_repo, use_4bit_val, enable_xformers_val, use_adalora_val, int(grad_accum_val), mixed_precision_val, int(save_every_steps))
            return boost_text + "\n\n" + logs_out, sample, link

        start_btn.click(on_start, inputs=[base_model, dataset_source, csv_name, task_type, adapter_target_override, lora_r, lora_alpha, epochs, batch_size, lr, max_train_steps, output_dir, upload_repo, use_4bit, enable_xformers, use_adalora, grad_accum, mixed_precision, save_every_steps], outputs=[boost_info, sample_image, upload_link])
    return demo

if __name__ == "__main__":
    ui = build_ui()
    ui.launch(server_name="0.0.0.0", server_port=7860)