# universal_lora_trainer_gradio.py
import spaces
import os
import torch
import gradio as gr
import pandas as pd
import numpy as np
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
from peft import LoraConfig, get_peft_model
from accelerate import Accelerator
from huggingface_hub import create_repo, upload_folder, hf_hub_download

# transformers optional
try:
    from transformers import AutoTokenizer, AutoModelForCausalLM
    TRANSFORMERS_AVAILABLE = True
except Exception:
    TRANSFORMERS_AVAILABLE = False


# ---------------- Helpers ----------------
def is_hub_repo_like(s):
    return "/" in s and not Path(s).exists()


def download_from_hf(repo_id, filename, token=None):
    token = token or os.environ.get("HF_TOKEN")
    return hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset", token=token)
# ---------------- Dataset ----------------
class MediaTextDataset(Dataset):
    def __init__(self, source, csv_name="dataset.csv", text_columns=None, max_records=None):
        self.is_hub = is_hub_repo_like(source)
        token = os.environ.get("HF_TOKEN")
        if self.is_hub:
            file_path = download_from_hf(source, csv_name, token)
        else:
            file_path = Path(source) / csv_name
            # fallback to parquet if CSV missing
            if not Path(file_path).exists():
                alt = Path(str(file_path).replace(".csv", ".parquet"))
                if alt.exists():
                    file_path = alt
                else:
                    raise FileNotFoundError(f"Dataset file not found: {file_path}")
        self.df = pd.read_parquet(file_path) if str(file_path).endswith(".parquet") else pd.read_csv(file_path)
        if max_records:
            self.df = self.df.head(max_records)
        self.text_columns = text_columns or ["short_prompt", "long_prompt"]

    def __len__(self):
        return len(self.df)

    def __getitem__(self, i):
        rec = self.df.iloc[i]
        out = {"text": {}}
        for col in self.text_columns:
            out["text"][col] = rec[col] if col in rec else ""
        return out
# ---------------- Model loader ----------------
def load_pipeline_auto(base_model, dtype=torch.float16):
    if "gemma" in base_model.lower():
        if not TRANSFORMERS_AVAILABLE:
            raise RuntimeError("Transformers not installed for LLM support.")
        tokenizer = AutoTokenizer.from_pretrained(base_model)
        model = AutoModelForCausalLM.from_pretrained(base_model, torch_dtype=dtype)
        return {"model": model, "tokenizer": tokenizer}
    else:
        raise NotImplementedError("Only Gemma LLMs are supported in this script.")


def find_target_modules(model):
    candidates = ["q_proj", "k_proj", "v_proj", "out_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
    names = [n for n, m in model.named_modules() if isinstance(m, torch.nn.Linear)]
    targets = [n.split(".")[-1] for n in names if any(c in n for c in candidates)]
    if not targets:
        targets = [n.split(".")[-1] for n, m in model.named_modules() if isinstance(m, torch.nn.Linear)]
    return targets
def unwrap_batch(batch, short_col, long_col):
    if isinstance(batch, (list, tuple)):
        ex = batch[0]
        if "text" in ex:
            return ex
        if "short" in ex and "long" in ex:
            return {"text": {short_col: ex.get("short", ""), long_col: ex.get("long", "")}}
        return {"text": ex}
    if isinstance(batch, dict):
        first_elem = {}
        is_batched = any(isinstance(v, (list, tuple, np.ndarray, torch.Tensor)) for v in batch.values())
        if is_batched:
            for k, v in batch.items():
                try:
                    first = v[0]
                except Exception:
                    first = v
                first_elem[k] = first
            if "text" in first_elem:
                t = first_elem["text"]
                if isinstance(t, (list, tuple)) and len(t) > 0:
                    return {"text": t[0] if isinstance(t[0], dict) else {short_col: t[0], long_col: ""}}
                if isinstance(t, dict):
                    return {"text": t}
                return {"text": {short_col: str(t), long_col: ""}}
            if ("short" in first_elem and "long" in first_elem) or (short_col in first_elem and long_col in first_elem):
                s = first_elem.get(short_col, first_elem.get("short", ""))
                l = first_elem.get(long_col, first_elem.get("long", ""))
                return {"text": {short_col: str(s), long_col: str(l)}}
            return {"text": {short_col: str(first_elem)}}
        if "text" in batch and isinstance(batch["text"], dict):
            return {"text": batch["text"]}
        s = batch.get(short_col, batch.get("short", ""))
        l = batch.get(long_col, batch.get("long", ""))
        return {"text": {short_col: str(s), long_col: str(l)}}
    return {"text": {short_col: str(batch), long_col: ""}}
# ---------------- LoRA Training ----------------
from tempfile import TemporaryDirectory


def train_lora_stream(base_model, dataset_src, csv_name, text_cols,
                      epochs=1, lr=1e-4, r=8, alpha=16, batch_size=1,
                      num_workers=0, max_train_records=None, hf_repo_id=None):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32
    accelerator = Accelerator()
    pipe = load_pipeline_auto(base_model, dtype=dtype)
    model_obj = pipe["model"]
    tokenizer = pipe["tokenizer"]
    model_obj.train()
    target_modules = find_target_modules(model_obj)
    lora_config = LoraConfig(r=r, lora_alpha=alpha, target_modules=target_modules, lora_dropout=0.0)
    lora_module = get_peft_model(model_obj, lora_config)
    dataset = MediaTextDataset(dataset_src, csv_name, text_columns=text_cols, max_records=max_train_records)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    optimizer = torch.optim.AdamW(lora_module.parameters(), lr=lr)
    lora_module, optimizer, loader = accelerator.prepare(lora_module, optimizer, loader)
    max_steps = 150
    step_counter = 0
    logs = []
    yield f"[INFO] Starting LoRA training on {device.upper()} (max {max_steps} steps)...\n", 0.0
    for ep in range(epochs):
        yield f"[DEBUG] Epoch {ep+1}/{epochs}\n", step_counter / max_steps
        for batch in loader:
            if step_counter >= max_steps:
                break
            ex = unwrap_batch(batch, text_cols[0], text_cols[1])
            texts = ex.get("text", {})
            short_text = str(texts.get(text_cols[0], "") or "")
            long_text = str(texts.get(text_cols[1], "") or "")
            enc = tokenizer(short_text, text_pair=long_text, return_tensors="pt",
                            padding="max_length", truncation=True, max_length=512)
            enc = {k: v.to(accelerator.device) for k, v in enc.items()}
            enc["labels"] = enc["input_ids"].clone()
            outputs = lora_module(**enc)
            loss = getattr(outputs, "loss", None)
            if loss is None:
                logits = outputs.logits if hasattr(outputs, "logits") else outputs[0]
                loss = torch.nn.functional.cross_entropy(
                    logits.view(-1, logits.size(-1)),
                    enc["labels"].view(-1),
                    ignore_index=tokenizer.pad_token_id,
                )
            optimizer.zero_grad()
            accelerator.backward(loss)
            optimizer.step()
            logs.append(f"[DEBUG] Step {step_counter}, Loss: {loss.item():.6f}")
            step_counter += 1
            yield "\n".join(logs[-10:]), step_counter / max_steps
        if step_counter >= max_steps:
            break
    # ---------------- Upload to HF ----------------
    HF_TOKEN = os.environ.get("HF_TOKEN")
    if not hf_repo_id:
        raise ValueError("❌ HF repo ID required for upload.")
    if not HF_TOKEN:
        raise ValueError("❌ HF_TOKEN missing.")
    hf_repo_id = hf_repo_id.strip()
    logs.append(f"[INFO] 🚀 Uploading LoRA to Hugging Face repo: {hf_repo_id}")
    create_repo(hf_repo_id, repo_type="model", exist_ok=True, token=HF_TOKEN)
    with TemporaryDirectory() as tmp_dir:
        lora_module.save_pretrained(tmp_dir)
        upload_folder(folder_path=tmp_dir, repo_id=hf_repo_id, repo_type="model", token=HF_TOKEN)
    link = f"https://huggingface.co/{hf_repo_id}"
    logs.append(f"[INFO] ✅ Uploaded successfully: {link}")
    yield "\n".join(logs), link
# ---------------- CPU Inference ----------------
from peft import PeftModel


def generate_long_prompt_cpu(base_model, lora_repo, short_prompt, max_length=200):
    device = torch.device("cpu")
    # Load base model in float32
    pipe = load_pipeline_auto(base_model, dtype=torch.float32)
    base_model_obj = pipe["model"].to(device)
    tokenizer = pipe["tokenizer"]
    base_model_obj.eval()
    # Load LoRA adapter on CPU
    lora_model = PeftModel.from_pretrained(
        base_model_obj,
        lora_repo,
        torch_dtype=torch.float32,
        device_map={"": device},
    )
    lora_model.eval()
    # OPTIONAL: merge LoRA into base model to avoid PEFT runtime issues
    merged_model = lora_model.merge_and_unload()
    merged_model.eval()
    # Tokenize input
    input_ids = tokenizer(short_prompt, return_tensors="pt").input_ids.to(device)
    # Generate safely
    with torch.no_grad():
        outputs = merged_model.generate(
            input_ids,
            max_length=max_length,
            do_sample=True,
            top_p=0.95,
            top_k=50,
        )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
# ---------------- Gradio UI ----------------
def run_ui():
    with gr.Blocks(title="Prompt Enhancer Trainer + Inference UI") as demo:
        gr.Markdown("# ✨ Prompt Enhancer Trainer + Inference Playground")
        gr.Markdown("Train, test, and debug your LoRA-enhanced Gemma model easily. Use ZeroGPU for training; CPU is enough for everything else.")
        gr.Markdown("""
🔗 **Quick Links:**
- [📂 View Dataset (rahul7star/prompt-enhancer-dataset-01)](https://huggingface.co/datasets/rahul7star/prompt-enhancer-dataset-01)
- [🤖 View Trained Model (rahul7star/gemma-3-270m-ccebc0)](https://huggingface.co/rahul7star/gemma-3-270m-ccebc0)
""")
        with gr.Tabs():
            # =========================================================
            # 1️⃣ TRAIN LORA TAB
            # =========================================================
            with gr.Tab("Train LoRA"):
                with gr.Row():
                    base_model = gr.Textbox(label="Base model", value="google/gemma-3-4b-it")
                    dataset = gr.Textbox(label="Dataset folder or HF repo", value="rahul7star/prompt-enhancer-dataset-01")
                    csvname = gr.Textbox(label="CSV/Parquet file", value="train-00000-of-00001.csv")
                    short_col = gr.Textbox(label="Short prompt column", value="short_prompt")
                    long_col = gr.Textbox(label="Long prompt column", value="long_prompt")
                    repo = gr.Textbox(label="HF repo to upload LoRA", value="rahul7star/gemma-3-270m-ccebc0")
                with gr.Row():
                    batch_size = gr.Number(value=1, label="Batch size")
                    num_workers = gr.Number(value=0, label="DataLoader num_workers")
                    r = gr.Number(value=8, label="LoRA rank")
                    a = gr.Number(value=16, label="LoRA alpha")
                    ep = gr.Number(value=1, label="Epochs")
                    lr = gr.Number(value=1e-4, label="Learning rate")
                    max_records = gr.Number(value=1000, label="Max training records")
                logs = gr.Textbox(label="Logs (streaming)", lines=25)
                def launch_train(bm, ds, csv, sc, lc, batch, num_w, r_, a_, ep_, lr_, max_rec, repo_):
                    gen = train_lora_stream(
                        bm, ds, csv, [sc, lc],
                        epochs=int(ep_), lr=float(lr_), r=int(r_), alpha=int(a_),
                        batch_size=int(batch), num_workers=int(num_w),
                        max_train_records=int(max_rec), hf_repo_id=repo_
                    )
                    # train_lora_stream yields (log_text, progress-or-link) pairs, but only the
                    # logs textbox is wired as an output, so forward just the log text.
                    for item in gen:
                        yield item[0] if isinstance(item, tuple) else item

                btn = gr.Button("🚀 Start Training")
                btn.click(
                    fn=launch_train,
                    inputs=[
                        base_model, dataset, csvname, short_col, long_col,
                        batch_size, num_workers, r, a, ep, lr, max_records, repo
                    ],
                    outputs=[logs],
                    queue=True
                )
            # =========================================================
            # 2️⃣ INFERENCE (CPU) TAB
            # =========================================================
            with gr.Tab("Inference (CPU)"):
                inf_base_model = gr.Textbox(label="Base model", value="google/gemma-3-4b-it")
                inf_lora_repo = gr.Textbox(label="LoRA HF repo", value="rahul7star/gemma-3-270m-ccebc0")
                short_prompt = gr.Textbox(label="Short prompt")
                long_prompt_out = gr.Textbox(label="Generated long prompt", lines=5)
                inf_btn = gr.Button("📝 Generate Long Prompt")
                inf_btn.click(
                    fn=generate_long_prompt_cpu,
                    inputs=[inf_base_model, inf_lora_repo, short_prompt],
                    outputs=[long_prompt_out]
                )
            # =========================================================
            # 3️⃣ SHOW TRAINABLE PARAMS TAB
            # =========================================================
            with gr.Tab("Show Trainable Params"):
                gr.Markdown("### 🧩 View Trainable Parameters in Your LoRA-Enhanced Model")
                with gr.Row():
                    base_model_name = gr.Textbox(label="Base Model", value="google/gemma-2b-it")
                    check_btn = gr.Button("🔍 Show Trainable Layers")
                param_output = gr.Textbox(label="Trainable Parameters Info", lines=30)

                def show_trainable_layers(base_model_name):
                    import torch
                    from peft import get_peft_model, LoraConfig
                    from transformers import AutoModelForCausalLM
                    import io
                    import contextlib

                    buf = io.StringIO()
                    print(f"[INFO] Loading base model: {base_model_name}", file=buf)
                    model = AutoModelForCausalLM.from_pretrained(base_model_name)
                    print("[INFO] Initializing LoRA configuration...", file=buf)
                    config = LoraConfig(
                        r=16,
                        lora_alpha=32,
                        target_modules=[
                            "q_proj", "k_proj", "v_proj",
                            "o_proj", "gate_proj", "up_proj", "down_proj"
                        ]
                    )
                    print("[INFO] Applying LoRA adapters...", file=buf)
                    model = get_peft_model(model, config)
                    print("[INFO] Counting trainable parameters...", file=buf)
                    with contextlib.redirect_stdout(buf):
                        model.print_trainable_parameters()
                    print("\n[INFO] Listing all LoRA-injected layers...", file=buf)
                    lora_layers = [name for name, _ in model.named_modules() if "lora" in name.lower()]
                    if not lora_layers:
                        print("⚠️ No LoRA layers detected. Check target_modules configuration.", file=buf)
                    else:
                        print(f"✅ Found {len(lora_layers)} LoRA-injected submodules:\n", file=buf)
                        for i, layer_name in enumerate(lora_layers[:200]):
                            print(f"  {i+1:03d}. {layer_name}", file=buf)
                        if len(lora_layers) > 200:
                            print(f"...and {len(lora_layers)-200} more layers (truncated)", file=buf)
                    explanation = """
────────────────────────────
### 🔍 What “Adapter (90)” Means
When you initialize LoRA on a large model like **Gemma**, the code scans the model
to find all modules that can receive LoRA layers — typically:
- **q_proj, k_proj, v_proj** → Query, Key, Value projections
- **o_proj / out_proj** → Output of attention
- **gate_proj, up_proj, down_proj** → Feed-forward MLPs

Each matching layer gets two small trainable matrices **(A, B)** injected.
So if you see:
> Adapter (90)

that means **90 total submodules** were wrapped with LoRA adapters.
You can view them above 👆, or print them programmatically with:
```python
for name, module in model.named_modules():
    if "lora" in name.lower():
        print(name)
```
"""
                    print(explanation, file=buf)
                    return buf.getvalue()

                check_btn.click(show_trainable_layers, inputs=[base_model_name], outputs=[param_output])
            # =========================================================
            # 4️⃣ CODE DEBUG TAB
            # =========================================================
            with gr.Tab("Code Debug"):
                gr.Markdown("### 🧩 Code Debug — Understand What's Happening Line by Line")
                gr.Markdown("""
#### 🧰 Step-by-Step Breakdown

**1️⃣ `f"[INFO] Loading base model: {base_model}"`**
→ Logs which model is being loaded (e.g., `google/gemma-2b-it`).

**2️⃣ `AutoModelForCausalLM.from_pretrained(base_model)`**
→ Downloads the base Gemma model weights and tokenizer.

**3️⃣ `get_peft_model(model, config)`**
→ Wraps the model with LoRA and injects adapters into `q_proj`, `k_proj`, `v_proj`, etc.

**4️⃣ Expected console output:**
[INFO] Loading base model: google/gemma-2b-it
[INFO] Preparing dataset...
[INFO] Injecting LoRA adapters...
trainable params: 3.5M || all params: 270M || trainable%: 1.3%

**5️⃣ `train_lora_stream(...)`**
→ Runs the training loop and streams live progress.

**6️⃣ `upload_folder(...)`**
→ Uploads all adapter files to your chosen HF repo.
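
For reference, a minimal sketch of that upload step as this script performs it (the repo ID below is a placeholder and `HF_TOKEN` is assumed to be set in the environment):

```python
import os
from huggingface_hub import create_repo, upload_folder

token = os.environ.get("HF_TOKEN")           # Hub auth token
repo_id = "your-username/my-lora-adapter"    # hypothetical target repo
create_repo(repo_id, repo_type="model", exist_ok=True, token=token)
# Upload a local folder containing adapter_model.safetensors + adapter_config.json
upload_folder(folder_path="./lora_out", repo_id=repo_id, repo_type="model", token=token)
```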
---
### 🔍 What “Adapter (90)” Means
When you initialize LoRA on Gemma, it finds **90 target layers** such as:
- `q_proj`, `k_proj`, `v_proj`
- `o_proj`
- `gate_proj`, `up_proj`, `down_proj`

Each layer gets small trainable matrices (A, B).
So:
> **Adapter (90)** → 90 modules modified by LoRA.

To list them:
```python
for name, module in model.named_modules():
    if "lora" in name.lower():
        print(name)
```
""")
            # =========================================================
            # 5️⃣ CODE EXPLAIN TAB
            # =========================================================
            with gr.Tab("Code Explain"):
                explain_md = gr.Markdown("""
### 🧩 Universal Dynamic LoRA Trainer & Inference — Code Explanation
This project provides an **end-to-end LoRA fine-tuning and inference system** for language models like **Gemma**, built with **Gradio**, **PEFT**, and **Accelerate**.
It supports both **training new LoRAs** and **generating text** with existing ones — all in a single interface.

---

#### **1️⃣ Imports Overview**
- **Core libs:** `os`, `torch`, `gradio`, `numpy`, `pandas`
- **Training libs:** `peft` (`LoraConfig`, `get_peft_model`), `accelerate` (`Accelerator`)
- **Modeling:** `transformers` (for the Gemma base model)
- **Hub integration:** `huggingface_hub` (for uploading adapters)
- **Spaces:** `spaces` — for execution within Hugging Face Spaces

---

#### **2️⃣ Dataset Loading**
- Uses a lightweight **MediaTextDataset** class to load:
  - CSV / Parquet files
  - or files pulled directly from a Hugging Face dataset repo
- Expects two columns (see the usage sketch below):
  - `short_prompt` → input text
  - `long_prompt` → target expanded text
- Supports batching, missing-column checks, and a configurable max record limit.
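
A minimal usage sketch of `MediaTextDataset` under those assumptions (the local folder path and record limit here are placeholders):

```python
from torch.utils.data import DataLoader

# Hypothetical local folder containing dataset.csv with short_prompt / long_prompt columns
ds = MediaTextDataset("./my_dataset", csv_name="dataset.csv",
                      text_columns=["short_prompt", "long_prompt"], max_records=100)
loader = DataLoader(ds, batch_size=1, shuffle=True)
print(len(ds), ds[0]["text"]["short_prompt"])
```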
---

#### **3️⃣ Model Loading & Preparation**
- Loads the **Gemma model and tokenizer** via `AutoModelForCausalLM` and `AutoTokenizer`.
- Automatically detects **target modules** (e.g. `q_proj`, `v_proj`) for LoRA injection, as in the sketch below.
- Supports `float16` or `bfloat16` precision with `Accelerator` for lower memory usage.
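
A short sketch of how the loader and target-module detection above fit together (the model ID is just the UI default; downloading Gemma weights requires accepting the model license on the Hub):

```python
import torch

# Uses the helpers defined earlier in this script
pipe = load_pipeline_auto("google/gemma-3-4b-it",
                          dtype=torch.float16 if torch.cuda.is_available() else torch.float32)
model, tokenizer = pipe["model"], pipe["tokenizer"]
print(sorted(set(find_target_modules(model))))  # e.g. ['down_proj', 'gate_proj', 'k_proj', ...]
```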
---

#### **4️⃣ LoRA Training Logic**
- Core formula:
\[
W_{eff} = W + \alpha \times (B A)
\]
- Only the **A** and **B** matrices are trainable; the base model weights stay frozen.
- Configurable parameters:
  `r` (rank), `alpha` (scaling), `epochs`, `lr`, `batch_size`
- Training logs stream live in the UI, showing step-by-step loss values.
- After training, the adapter is **saved locally** and **uploaded to the Hugging Face Hub** (one step of the loop is sketched below).
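
An illustrative sketch of a single LoRA training step with these settings (the prompts are made up; the real loop lives in `train_lora_stream` above):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model

base = "google/gemma-3-4b-it"   # UI default; any causal LM with q_proj/v_proj layers works
tokenizer = AutoTokenizer.from_pretrained(base)
model = AutoModelForCausalLM.from_pretrained(base, torch_dtype=torch.float32)

lora_cfg = LoraConfig(r=8, lora_alpha=16, target_modules=["q_proj", "v_proj"], lora_dropout=0.0)
lora_model = get_peft_model(model, lora_cfg)
optimizer = torch.optim.AdamW(lora_model.parameters(), lr=1e-4)

# One step: short prompt as input text, long prompt as the paired text (as in the loop above)
enc = tokenizer("a cat", text_pair="a fluffy cat on a sunny windowsill",
                return_tensors="pt", padding="max_length", truncation=True, max_length=64)
enc["labels"] = enc["input_ids"].clone()
loss = lora_model(**enc).loss
loss.backward()
optimizer.step()
optimizer.zero_grad()
```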
---

#### **5️⃣ CPU Inference Mode**
- Runs entirely on **CPU**; no GPU required.
- Loads the base Gemma model plus the trained LoRA weights (`PeftModel.from_pretrained`).
- Optionally merges the LoRA into the base model.
- Expands the short prompt → long descriptive text using standard generation parameters (e.g., top-p / top-k sampling), as sketched below.
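
A minimal sketch of that CPU inference flow (repo IDs are the UI defaults; merging is optional but mirrors `generate_long_prompt_cpu` above):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

base_id, lora_id = "google/gemma-3-4b-it", "rahul7star/gemma-3-270m-ccebc0"  # UI defaults
tokenizer = AutoTokenizer.from_pretrained(base_id)
base = AutoModelForCausalLM.from_pretrained(base_id, torch_dtype=torch.float32)
model = PeftModel.from_pretrained(base, lora_id).merge_and_unload()  # fold LoRA into the base weights
model.eval()

input_ids = tokenizer("a cat", return_tensors="pt").input_ids
with torch.no_grad():
    out = model.generate(input_ids, max_length=200, do_sample=True, top_p=0.95, top_k=50)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```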
---

#### **6️⃣ LoRA Internals Explained**
- LoRA injects low-rank matrices (A, B) into the **attention Linear layers**.
- Example:
\[
Q_{new} = Q + \alpha \times (B A)
\]
- Significantly reduces training cost:
  - Memory: ~1–2% of the full model
  - Compute: trains faster with minimal GPU load
- Scales to large models like Gemma 3B / 4B with rank ≤ 16 (see the arithmetic below).
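
To make the savings concrete, here is the parameter arithmetic for one hypothetical 4096×4096 projection at rank r = 8 (illustrative sizes, not Gemma's exact dimensions):

```python
d, r = 4096, 8
full = d * d                    # parameters in the frozen weight W
lora = r * d + d * r            # parameters in A (r x d) and B (d x r)
print(full, lora, f"{100 * lora / full:.2f}%")
# 16777216 65536 0.39%  -> the adapter trains well under 1% of that layer's weights
```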
---

#### **7️⃣ Gradio UI Structure**
- **Train LoRA Tab:**
  Configure the model, dataset, LoRA parameters, and upload target.
  Press **🚀 Start Training** to stream training logs live (streaming pattern sketched below).
- **Inference (CPU) Tab:**
  Type a short prompt → generates an expanded long-form version via the trained LoRA.
- **Code Explain Tab:**
  Detailed breakdown of the logic plus simulated console output below.
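
Live log streaming works by wiring a Python generator to a button click; a stripped-down sketch of that pattern (component names here are hypothetical):

```python
import time
import gradio as gr

def stream_logs(n_steps):
    lines = []
    for step in range(int(n_steps)):
        lines.append(f"[DEBUG] Step {step}")
        time.sleep(0.1)
        yield "\\n".join(lines)      # each yield updates the textbox in place

with gr.Blocks() as demo:
    steps = gr.Number(value=5, label="Steps")
    logs = gr.Textbox(label="Logs (streaming)", lines=10)
    gr.Button("Run").click(stream_logs, inputs=[steps], outputs=[logs], queue=True)

# demo.launch()
```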
---

### 🧾 Example Log Simulation
```python
print(f"[INFO] Loading base model: {base_model}")
# -> Loads Gemma base model (fp16) on CUDA
# [INFO] Base model google/gemma-3-4b-it loaded successfully

print(f"[INFO] Preparing dataset from: {dataset_path}")
# -> Loads dataset or CSV file
# [DATA] 980 samples loaded, columns: short_prompt, long_prompt

print("[INFO] Initializing LoRA configuration...")
# -> Creates LoraConfig(r=8, alpha=16, target_modules=['q_proj', 'v_proj'])
# [CONFIG] LoRA applied to 96 attention layers

print("[INFO] Starting training loop...")
# [TRAIN] Step 1 | Loss: 2.31
# [TRAIN] Step 50 | Loss: 1.42
# [TRAIN] Step 100 | Loss: 0.91
# [TRAIN] Epoch 1 complete (avg loss: 1.21)

print("[INFO] Saving LoRA adapter...")
# -> Saves safetensors and config locally

print(f"[UPLOAD] Pushing adapter to {hf_repo_id}")
# -> Uploads model to Hugging Face Hub
# [UPLOAD] adapter_model.safetensors (67.7 MB)
# [SUCCESS] LoRA uploaded successfully 🚀
```
---

#### **8️⃣ 🧠 What LoRA Does (A & B Injection Explained)**
When you fine-tune a large model (like Gemma or Llama), you’re adjusting **billions** of parameters in large weight matrices.
LoRA avoids this by **injecting two small low-rank matrices (A and B)** into selected layers instead of modifying the full weight.

---

##### **Step 1: Regular Linear Layer**
\[
y = W x
\]
Here, **W** is a huge matrix (e.g., 4096×4096).

---

##### **Step 2: LoRA Layer Modification**
Instead of updating W directly, LoRA adds a lightweight update:
\[
W' = W + \Delta W, \qquad \Delta W = B A
\]
Where:
- **A** ∈ ℝ^(r × d)
- **B** ∈ ℝ^(d × r)
- and **r ≪ d** (e.g., r = 8 instead of 4096)

So you’re training only a *tiny fraction* of the parameters; the shape check below makes this concrete.
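
A quick shape check of that decomposition in PyTorch (dimensions are illustrative):

```python
import torch

d, r = 4096, 8
W = torch.zeros(d, d)           # frozen base weight
A = torch.randn(r, d) * 0.01    # trainable, r x d
B = torch.zeros(d, r)           # trainable, d x r (LoRA initializes B to zero)
delta_W = B @ A                 # low-rank update, same shape as W
print(delta_W.shape, W.numel(), A.numel() + B.numel())
# torch.Size([4096, 4096]) 16777216 65536
```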
---

##### **Step 3: Where LoRA Gets Injected**
It targets critical sub-layers such as:
- **q_proj, k_proj, v_proj** → Query, Key, Value projections in attention
- **o_proj / out_proj** → Output projection
- **gate_proj, up_proj, down_proj** → Feed-forward layers

When you see:
> `Adapter (90)`

that means 90 total layers (from these modules) were wrapped with LoRA adapters.

---

##### **Step 4: Training Efficiency**
- Base weights (`W`) stay **frozen**
- Only `(A, B)` are **trainable**
- Compute and memory are drastically reduced

| Metric | Full Fine-Tune | LoRA Fine-Tune |
|---------|----------------|----------------|
| Trainable Params | 2B+ | ~3M |
| GPU Memory | 40GB+ | <6GB |
| Time | 10–20 hrs | <1 hr |
---

##### **Step 5: Inference Equation**
At inference time:
\[
y = (W + \alpha \times B A) x
\]
Where **α** controls the strength of the adapter’s influence.
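
A small numeric check of that equation (in PEFT the effective scale applied to BA is `lora_alpha / r`; the snippet just verifies that the merged and unmerged forms agree):

```python
import torch

d, r, lora_alpha = 16, 4, 8
scaling = lora_alpha / r                      # effective alpha applied to the BA update
W = torch.randn(d, d)
A, B = torch.randn(r, d), torch.randn(d, r)
x = torch.randn(d)

y_merged  = (W + scaling * (B @ A)) @ x       # merged weights, as after merge_and_unload()
y_adapter = W @ x + scaling * (B @ (A @ x))   # base path + low-rank adapter path
print(torch.allclose(y_merged, y_adapter, atol=1e-5))  # True
```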
---

##### **Step 6: Visualization**

    Base Layer:
    y = W * x

    LoRA Layer:
    y = (W + B@A) * x
             ↑ ↑
             | └── A: small rank-r matrix (trainable)
             └──── B: small rank-r matrix (trainable)
---

##### **Step 7: Example in Code**
```python
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it")
config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
    lora_dropout=0.05
)
model = get_peft_model(model, config)
model.print_trainable_parameters()
```

Expected output:

    trainable params: 3,278,848 || all params: 2,040,000,000 || trainable%: 0.16%
""")
    return demo


if __name__ == "__main__":
    run_ui().launch(server_name="0.0.0.0", server_port=7860, share=True)