# universal_lora_trainer_accelerate_singlefile_dynamic.py
"""
Universal Dynamic LoRA Trainer (Accelerate + PEFT + Gradio)
- Gemma LLM default
- Robust batch handling (fixes KeyError: 0)
- Streams logs to Gradio (includes progress %)
- Supports CSV/Parquet datasets from the Hugging Face Hub or local folders
"""
import os
import torch
import gradio as gr
import pandas as pd
import numpy as np
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
from peft import LoraConfig, get_peft_model
from accelerate import Accelerator
from huggingface_hub import hf_hub_download, create_repo, upload_folder

# transformers optional
try:
    from transformers import AutoTokenizer, AutoModelForCausalLM
    TRANSFORMERS_AVAILABLE = True
except Exception:
    TRANSFORMERS_AVAILABLE = False

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# ---------------- Helpers ----------------
def is_hub_repo_like(s):
    """Treat strings like 'user/repo' as Hub dataset ids when no matching local path exists."""
    return "/" in s and not Path(s).exists()

def download_from_hf(repo_id, filename, token=None):
    token = token or os.environ.get("HF_TOKEN")
    return hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset", token=token)
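# Example usage of the helpers above (illustrative repo id and filename):
#   is_hub_repo_like("user/my-dataset")      -> True  (contains "/", no such local path)
#   is_hub_repo_like("./data")               -> False (local folder exists)
#   download_from_hf("user/my-dataset", "train.csv")  # returns the local cached file path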
# ---------------- Dataset ----------------
class MediaTextDataset(Dataset):
    def __init__(self, source, csv_name="dataset.csv", text_columns=None, max_records=None):
        self.is_hub = is_hub_repo_like(source)
        token = os.environ.get("HF_TOKEN")
        if self.is_hub:
            file_path = download_from_hf(source, csv_name, token)
        else:
            file_path = Path(source) / csv_name
        # fall back to Parquet if the CSV is missing
        if not Path(file_path).exists():
            alt = Path(str(file_path).replace(".csv", ".parquet"))
            if alt.exists():
                file_path = alt
            else:
                raise FileNotFoundError(f"Dataset file not found: {file_path}")
        self.df = pd.read_parquet(file_path) if str(file_path).endswith(".parquet") else pd.read_csv(file_path)
        if max_records:
            self.df = self.df.head(max_records)
        self.text_columns = text_columns or ["short_prompt", "long_prompt"]
        print(f"[DEBUG] Loaded dataset: {file_path}, columns: {list(self.df.columns)}")
        print(f"[DEBUG] Sample rows:\n{self.df.head(3)}")

    def __len__(self):
        return len(self.df)

    def __getitem__(self, i):
        rec = self.df.iloc[i]
        out = {"text": {}}
        for col in self.text_columns:
            out["text"][col] = rec[col] if col in rec else ""
        return out
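# Example item returned by MediaTextDataset.__getitem__ (illustrative values, default columns):
#   {"text": {"short_prompt": "a cat", "long_prompt": "a fluffy cat curled up on a sofa"}}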
# ---------------- Model loader ----------------
def load_pipeline_auto(base_model, dtype=torch.float16):
    if "gemma" in base_model.lower():
        if not TRANSFORMERS_AVAILABLE:
            raise RuntimeError("Transformers is not installed; it is required for LLM support.")
        print(f"[INFO] Using Gemma LLM for {base_model}")
        tokenizer = AutoTokenizer.from_pretrained(base_model)
        model = AutoModelForCausalLM.from_pretrained(base_model, torch_dtype=dtype)
        return {"model": model, "tokenizer": tokenizer}
    else:
        raise NotImplementedError("Only Gemma LLMs are supported in this script.")

def find_target_modules(model):
    candidates = ["q_proj", "k_proj", "v_proj", "out_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
    linear_names = [n for n, m in model.named_modules() if isinstance(m, torch.nn.Linear)]
    # Deduplicate while keeping order: LoRA targets are matched by module-name suffix.
    targets = list(dict.fromkeys(n.split(".")[-1] for n in linear_names if any(c in n for c in candidates)))
    if not targets:
        targets = list(dict.fromkeys(n.split(".")[-1] for n in linear_names))
        print("[WARNING] No standard attention modules found; using all Linear layers for LoRA.")
    else:
        print(f"[INFO] LoRA target modules detected: {targets[:40]}{'...' if len(targets) > 40 else ''}")
    return targets
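# For Gemma-family models in transformers, the attention/MLP projections are typically named
# q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj and down_proj, so the candidate list
# above should match them directly.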
# ---------------- Batch unwrapping ----------------
def unwrap_batch(batch, short_col, long_col):
    """Normalise whatever the DataLoader yields into {"text": {short_col: ..., long_col: ...}}."""
    if isinstance(batch, (list, tuple)):
        ex = batch[0]
        if "text" in ex:
            return ex
        if "short" in ex and "long" in ex:
            return {"text": {short_col: ex.get("short", ""), long_col: ex.get("long", "")}}
        return {"text": ex}
    if isinstance(batch, dict):
        first_elem = {}
        is_batched = any(isinstance(v, (list, tuple, np.ndarray, torch.Tensor)) for v in batch.values())
        if is_batched:
            for k, v in batch.items():
                try:
                    first = v[0]
                except Exception:
                    first = v
                first_elem[k] = first
            if "text" in first_elem:
                t = first_elem["text"]
                if isinstance(t, (list, tuple)) and len(t) > 0:
                    return {"text": t[0] if isinstance(t[0], dict) else {short_col: t[0], long_col: ""}}
                if isinstance(t, dict):
                    return {"text": t}
                return {"text": {short_col: str(t), long_col: ""}}
            if ("short" in first_elem and "long" in first_elem) or (short_col in first_elem and long_col in first_elem):
                s = first_elem.get(short_col, first_elem.get("short", ""))
                l = first_elem.get(long_col, first_elem.get("long", ""))
                return {"text": {short_col: str(s), long_col: str(l)}}
            return {"text": {short_col: str(first_elem)}}
        if "text" in batch and isinstance(batch["text"], dict):
            # Default collate wraps each string in a single-element list; unwrap it here.
            t = {k: (v[0] if isinstance(v, (list, tuple)) and len(v) > 0 else v)
                 for k, v in batch["text"].items()}
            return {"text": t}
        s = batch.get(short_col, batch.get("short", ""))
        l = batch.get(long_col, batch.get("long", ""))
        return {"text": {short_col: str(s), long_col: str(l)}}
    return {"text": {short_col: str(batch), long_col: ""}}
# ---------------- Training (forward + backward + logs) ----------------
def train_lora_stream(base_model, dataset_src, csv_name, text_cols, output_dir,
                      epochs=1, lr=1e-4, r=8, alpha=16, batch_size=1, num_workers=0,
                      max_train_records=None):
    accelerator = Accelerator()
    pipe = load_pipeline_auto(base_model)
    model_obj = pipe["model"]
    tokenizer = pipe["tokenizer"]
    model_obj.train()
    # Some tokenizers ship without a pad token; fall back to EOS so padding works.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    target_modules = find_target_modules(model_obj)
    lcfg = LoraConfig(r=r, lora_alpha=alpha, target_modules=target_modules,
                      lora_dropout=0.0, task_type="CAUSAL_LM")
    lora_module = get_peft_model(model_obj, lcfg)
    dataset = MediaTextDataset(dataset_src, csv_name, text_columns=text_cols, max_records=max_train_records)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    optimizer = torch.optim.AdamW(lora_module.parameters(), lr=lr)
    lora_module, optimizer, loader = accelerator.prepare(lora_module, optimizer, loader)
    total_steps = max(1, epochs * len(loader))
    step_counter = 0
    logs = []
    yield "[DEBUG] Starting training loop...\n", 0.0
    for ep in range(epochs):
        yield f"[DEBUG] Epoch {ep+1}/{epochs}\n", step_counter / total_steps
        for i, batch in enumerate(loader):
            ex = unwrap_batch(batch, text_cols[0], text_cols[1])
            texts = ex.get("text", {})
            short_text = str(texts.get(text_cols[0], "") or "")
            long_text = str(texts.get(text_cols[1], "") or "")
            # Tokenize the short/long prompts as a text pair so both share one padded sequence.
            enc = tokenizer(
                short_text,
                text_pair=long_text,
                return_tensors="pt",
                padding="max_length",
                truncation=True,
                max_length=512,  # enforce the same length for every example
            )
            enc = {k: v.to(accelerator.device) for k, v in enc.items()}
            # Causal-LM labels: copy input_ids, but ignore padded positions in the loss.
            labels = enc["input_ids"].clone()
            labels[enc["attention_mask"] == 0] = -100
            enc["labels"] = labels
            # --- Forward pass ---
            outputs = lora_module(**enc)
            forward_loss = getattr(outputs, "loss", None)
            if forward_loss is None:
                # Fallback: next-token cross-entropy with shifted logits/labels.
                logits = outputs.logits if hasattr(outputs, "logits") else outputs[0]
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = labels[..., 1:].contiguous()
                forward_loss = torch.nn.functional.cross_entropy(
                    shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1), ignore_index=-100
                )
            logs.append(f"[DEBUG] Step {step_counter}, forward_loss: {forward_loss.item():.6f}")
            optimizer.zero_grad()
            accelerator.backward(forward_loss)
            optimizer.step()
            step_counter += 1
            yield "\n".join(logs[-10:]), step_counter / total_steps
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    lora_module.save_pretrained(output_dir)
    yield f"[INFO] ✅ LoRA saved to {output_dir}\n", 1.0
def upload_adapter(local, repo_id):
    token = os.environ.get("HF_TOKEN")
    if not token:
        raise RuntimeError("HF_TOKEN missing")
    create_repo(repo_id, exist_ok=True, token=token)
    # upload_folder takes keyword-only arguments in recent huggingface_hub releases.
    upload_folder(folder_path=local, repo_id=repo_id, repo_type="model", token=token)
    return f"https://huggingface.co/{repo_id}"
# ---------------- Gradio UI ----------------
def run_ui():
    with gr.Blocks() as demo:
        gr.Markdown("# 🚀 Universal Dynamic LoRA Trainer (Gemma LLM)")
        with gr.Row():
            base_model = gr.Textbox(label="Base model", value="google/gemma-3-4b-it")
            dataset = gr.Textbox(label="Dataset folder or HF repo", value="rahul7star/prompt-enhancer-dataset-01")
            csvname = gr.Textbox(label="CSV/Parquet file", value="train-00000-of-00001.csv")
            short_col = gr.Textbox(label="Short prompt column", value="short_prompt")
            long_col = gr.Textbox(label="Long prompt column", value="long_prompt")
            out = gr.Textbox(label="Output dir", value="./adapter_out")
            repo = gr.Textbox(label="Upload HF repo (optional)", value="rahul7star/gemma-3-270m-ccebc0")
        with gr.Row():
            batch_size = gr.Number(value=1, label="Batch size")
            num_workers = gr.Number(value=0, label="DataLoader num_workers")
            r = gr.Number(value=8, label="LoRA rank")
            a = gr.Number(value=16, label="LoRA alpha")
            ep = gr.Number(value=1, label="Epochs")
            lr = gr.Number(value=1e-4, label="Learning rate")
            max_records = gr.Number(value=1000, label="Max training records")
        logs = gr.Textbox(label="Logs (streaming)", lines=25)

        def launch(bm, ds, csv, sc, lc, out_dir, batch, num_w, r_, a_, ep_, lr_, max_rec, repo_):
            gen = train_lora_stream(
                bm, ds, csv, [sc, lc], out_dir,
                epochs=int(ep_), lr=float(lr_), r=int(r_), alpha=int(a_),
                batch_size=int(batch), num_workers=int(num_w),
                max_train_records=int(max_rec)
            )
            for item in gen:
                if isinstance(item, tuple):
                    text = item[0]
                else:
                    text = item
                yield text
            if repo_:
                link = upload_adapter(out_dir, repo_)
                yield f"[INFO] Uploaded to {link}\n"

        btn = gr.Button("🚀 Start Training")
        btn.click(fn=launch,
                  inputs=[base_model, dataset, csvname, short_col, long_col, out,
                          batch_size, num_workers, r, a, ep, lr, max_records, repo],
                  outputs=[logs],
                  queue=True)
    return demo

if __name__ == "__main__":
    run_ui().launch(server_name="0.0.0.0", server_port=7860, share=True)
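# Typical ways to launch this script (sketch; `accelerate launch` assumes `accelerate config`
# has been run or that the default settings are acceptable):
#   python universal_lora_trainer_accelerate_singlefile_dynamic.py
#   accelerate launch universal_lora_trainer_accelerate_singlefile_dynamic.py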