File size: 6,179 Bytes
5bac7a5 a8678a6 ce929d0 a4c1e96 a8678a6 c4c4bdc a8678a6 c4c4bdc ce929d0 a8678a6 ce929d0 a8678a6 ce929d0 a8678a6 a4c1e96 c4c4bdc a4c1e96 1f97d4a c4c4bdc a4c1e96 c4c4bdc a8678a6 c4c4bdc a8678a6 c4c4bdc 30bd2c9 c4c4bdc ff12e01 ce929d0 d071e42 c4c4bdc d071e42 c4c4bdc ce929d0 c4c4bdc ce929d0 c4c4bdc ce929d0 a8678a6 c4c4bdc a8678a6 c4c4bdc ce929d0 c4c4bdc a8678a6 c4c4bdc a8678a6 ce929d0 a8678a6 ce929d0 c4c4bdc ce929d0 c4c4bdc ce929d0 a4c1e96 ce929d0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 |
import spaces
import torch
import torch.nn as nn
import torch.optim as optim
from flashpack import FlashPackMixin
from datasets import load_dataset
import gradio as gr
from transformers import AutoTokenizer, AutoModel
# ============================================================
# 🧠 Device setup
# ============================================================
# Pick the compute device once at import time: CUDA when torch reports it
# as available, otherwise the CPU. Everything below moves tensors here.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"🔧 Using device: {device}")
# ============================================================
# 1️⃣ Define FlashPack model
# ============================================================
class GemmaTrainer(nn.Module, FlashPackMixin):
    """Small feed-forward network: Linear -> ReLU -> Linear.

    Combines ``nn.Module`` with ``FlashPackMixin``; the mixin's
    ``save_flashpack`` method is what the training code in this file calls
    to persist the trained weights.

    Args:
        input_dim: Feature size accepted by the first linear layer.
        hidden_dim: Width of the hidden layer between the two linears.
        output_dim: Feature size produced by the second linear layer.
    """

    def __init__(self, input_dim=768, hidden_dim=1024, output_dim=768):
        super().__init__()
        # Attribute names (fc1 / relu / fc2) are kept exactly as they were:
        # they determine the parameter names of the module.
        self.fc1 = nn.Linear(in_features=input_dim, out_features=hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(in_features=hidden_dim, out_features=output_dim)

    def forward(self, x):
        """Run ``x`` through fc1, the ReLU, then fc2 and return the result."""
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)
# ============================================================
# 2️⃣ Encode and train using GPU
# ============================================================
def train_flashpack_model(
    max_epochs=500,
    tolerance=1e-4,
    learning_rate=1e-3,
    repo_id="rahul7star/FlashPack",
):
    """Embed the prompt pairs, fit ``GemmaTrainer`` on them, and publish it.

    Steps, in order:

    1. Load the ``train`` split of ``gokaygokay/prompt-enhancer-dataset``.
    2. Encode every ``short_prompt`` / ``long_prompt`` with the ``gpt2``
       tokenizer and model (mean of ``last_hidden_state`` over dim 1).
    3. Train ``GemmaTrainer`` with Adam + MSE to map short-prompt embeddings
       onto long-prompt embeddings. Every step uses the full embedding
       matrix (no mini-batching).
    4. Save the trained model with ``save_flashpack(..., push_to_hub=True)``.

    Args:
        max_epochs: Upper bound on the number of optimisation steps.
        tolerance: Training stops early once the MSE loss drops below this.
        learning_rate: Learning rate handed to Adam.
        repo_id: Destination passed to ``save_flashpack``.

    Returns:
        Tuple ``(model, dataset, embed_model, tokenizer, long_embeddings)``:
        the trained network, the loaded dataset, the gpt2 embedding model
        and tokenizer, and the stacked long-prompt embeddings.

    Raises:
        ValueError: If the dataset split contains no rows.
    """
    # Load dataset
    print("📦 Loading dataset...")
    dataset = load_dataset("gokaygokay/prompt-enhancer-dataset", split="train")

    # Tokenizer setup. Padding is requested below, so a pad token must be
    # set; the EOS token is reused for that.
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token  # ✅ Fix padding issue

    # Base embedding model, used for inference only (eval mode, no grads).
    embed_model = AutoModel.from_pretrained("gpt2").to(device)
    embed_model.eval()

    def encode_prompt(prompt):
        """Return the mean-pooled gpt2 hidden state for one prompt."""
        # NOTE: max_length=32 and the plain mean over dim 1 must stay in
        # sync with the module-level encode_prompt used at inference time,
        # otherwise query embeddings would not match the training ones.
        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            padding="max_length",
            max_length=32,
        ).to(device)
        with torch.no_grad():
            return embed_model(**inputs).last_hidden_state.mean(dim=1)

    # Encode dataset prompts in a single pass over the rows, so every row
    # is fetched from the dataset once instead of twice.
    print("🔢 Encoding dataset into embeddings...")
    short_rows = []
    long_rows = []
    for pair in dataset:
        short_rows.append(encode_prompt(pair["short_prompt"]))
        long_rows.append(encode_prompt(pair["long_prompt"]))
    if not short_rows:
        raise ValueError("Dataset is empty: there are no prompt pairs to train on")
    # encode_prompt already produces tensors on `device`, so the stacked
    # matrices need no further transfer.
    short_embeddings = torch.vstack(short_rows)
    long_embeddings = torch.vstack(long_rows)
    print(f"✅ Encoded {len(dataset)} pairs")

    # Train FlashPack model; layer sizes follow the embedding widths.
    model = GemmaTrainer(
        input_dim=short_embeddings.shape[1],
        output_dim=long_embeddings.shape[1],
    ).to(device)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    for epoch in range(max_epochs):
        optimizer.zero_grad()
        outputs = model(short_embeddings)
        loss = criterion(outputs, long_embeddings)
        loss.backward()
        optimizer.step()

        # Read the scalar once per epoch and reuse it for both checks.
        loss_value = loss.item()
        if loss_value < tolerance:
            print(f"✅ Converged at epoch {epoch+1}, Loss={loss_value:.6f}")
            break
        if (epoch + 1) % 50 == 0:
            print(f"Epoch {epoch+1}, Loss={loss_value:.6f}")

    # Save to Hugging Face Hub
    model.save_flashpack(repo_id, target_dtype=torch.float32, push_to_hub=True)
    print(f"✅ Model saved to FlashPack Hub: {repo_id}")

    return model, dataset, embed_model, tokenizer, long_embeddings
# ============================================================
# 3️⃣ Run training once and load for inference
# ============================================================
# Training runs once, at import time. The returned objects are kept as
# module-level globals because the inference helpers below read them.
(
    model,
    dataset,
    embed_model,
    tokenizer,
    long_embeddings,
) = train_flashpack_model()

# From here on the network is only used for inference.
model.eval()
# ============================================================
# 4️⃣ Inference function for Gradio
# ============================================================
def encode_prompt(prompt):
    """Turn ``prompt`` into a mean-pooled embedding from ``embed_model``.

    The text is tokenized with the module-level ``tokenizer`` (truncated
    and padded to exactly 32 tokens), moved to ``device``, and run through
    ``embed_model`` without gradient tracking.

    Args:
        prompt: Text to embed.

    Returns:
        ``last_hidden_state`` averaged over dim 1 (presumably the token
        axis, in which case padded positions are part of the average).
    """
    encoded = tokenizer(
        prompt,
        max_length=32,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    encoded = encoded.to(device)

    with torch.no_grad():
        hidden_states = embed_model(**encoded).last_hidden_state
        return hidden_states.mean(dim=1)
def enhance_prompt(user_prompt, temperature, max_tokens, chat_history):
    """Answer a short prompt with the closest ``long_prompt`` in the dataset.

    The prompt is embedded with ``encode_prompt``, mapped through the
    trained ``model``, and compared by cosine similarity against every row
    of ``long_embeddings``; the dataset row with the highest similarity
    supplies the reply.

    Args:
        user_prompt: Text typed by the user.
        temperature: Accepted because the UI passes it; not read here.
        max_tokens: Accepted because the UI passes it; not read here.
        chat_history: Existing list of ``{"role", "content"}`` messages,
            or ``None`` / empty for a fresh conversation.

    Returns:
        A new list holding the previous messages plus the user message and
        the assistant reply. For a blank or ``None`` prompt the history is
        returned with nothing appended.
    """
    # Work on a copy so the list handed in by the caller is never mutated.
    chat_history = list(chat_history or [])

    # Nothing to look up for a missing / whitespace-only prompt.
    if not user_prompt or not user_prompt.strip():
        return chat_history

    short_emb = encode_prompt(user_prompt)
    with torch.no_grad():
        long_emb = model(short_emb)

    # Nearest match search. The single query row broadcasts against all
    # rows of long_embeddings, so no repeated copy of it is needed.
    sims = nn.functional.cosine_similarity(long_emb, long_embeddings, dim=1)

    # argmax() yields a 0-d tensor; index the dataset with a plain int.
    best_idx = int(sims.argmax().item())
    enhanced_prompt = dataset[best_idx]["long_prompt"]

    chat_history.append({"role": "user", "content": user_prompt})
    chat_history.append({"role": "assistant", "content": enhanced_prompt})
    return chat_history
# ============================================================
# 5️⃣ Gradio UI
# ============================================================
# Layout: header text, then a row holding the chat panel next to a column
# of controls, then the event wiring and a footer with usage tips.
# NOTE(review): nesting was reconstructed from syntax (the file's original
# indentation is not available) — confirm the Row/Column layout.
with gr.Blocks(
    title="Prompt Enhancer – Gemma 3 270M",
    theme=gr.themes.Soft(),
) as demo:
    gr.Markdown(
        """
# ✨ Prompt Enhancer (Gemma 3 270M)
Enter a short prompt, and the model will **expand it with details and creative context**
"""
    )

    with gr.Row():
        chatbot = gr.Chatbot(
            height=400,
            label="Enhanced Prompts",
            type="messages",
        )
        with gr.Column(scale=1):
            user_prompt = gr.Textbox(
                placeholder="Enter a short prompt...",
                label="Your Prompt",
                lines=3,
            )
            temperature = gr.Slider(
                0.0, 1.0, value=0.7, step=0.05, label="Temperature"
            )
            max_tokens = gr.Slider(
                32, 256, value=128, step=16, label="Max Tokens"
            )
            send_btn = gr.Button("🚀 Enhance Prompt", variant="primary")
            clear_btn = gr.Button("🧹 Clear Chat")

    # The send button and the textbox's submit event run the same handler
    # with the same inputs; the returned history replaces the chat panel.
    send_btn.click(
        fn=enhance_prompt,
        inputs=[user_prompt, temperature, max_tokens, chatbot],
        outputs=chatbot,
    )
    user_prompt.submit(
        fn=enhance_prompt,
        inputs=[user_prompt, temperature, max_tokens, chatbot],
        outputs=chatbot,
    )
    # Clearing sets the chat panel to an empty message list.
    clear_btn.click(fn=lambda: [], inputs=None, outputs=chatbot)

    gr.Markdown(
        """
---
💡 **Tips:**
- Works best with short, descriptive prompts (e.g., "a cat sitting on a chair")
- Increase *Temperature* for more creative output.
"""
    )

# Start the web app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch(show_error=True)
|