import spaces
import torch
import torch.nn as nn
import torch.optim as optim
from flashpack import FlashPackMixin
from datasets import load_dataset
import gradio as gr
from transformers import AutoTokenizer, AutoModel
# ============================================================
# 🧠 Device setup
# ============================================================
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"🔧 Using device: {device}")
# ============================================================
# 1️⃣ Define FlashPack model
# ============================================================
class GemmaTrainer(nn.Module, FlashPackMixin):
    def __init__(self, input_dim=768, hidden_dim=1024, output_dim=768):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x
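# FlashPackMixin layers fast save/load helpers onto this plain MLP. A minimal
# round-trip sketch (assuming the library pairs the save_flashpack call used
# below with a from_flashpack loader):
#   model.save_flashpack("user/some-repo", push_to_hub=True)  # hypothetical repo id
#   model = GemmaTrainer.from_flashpack("user/some-repo")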
# ============================================================
# 2️⃣ Encode and train using GPU
# ============================================================
# Assumption: this runs as a ZeroGPU Space, where CUDA is only available inside
# functions decorated with @spaces.GPU (otherwise the `spaces` import is unused).
@spaces.GPU
def train_flashpack_model():
    # Load dataset
    print("📦 Loading dataset...")
    dataset = load_dataset("gokaygokay/prompt-enhancer-dataset", split="train")

    # Tokenizer setup
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token, so reuse EOS

    # Base embedding model (frozen; used only to produce prompt embeddings)
    embed_model = AutoModel.from_pretrained("gpt2").to(device)
    embed_model.eval()
    def encode_prompt(prompt):
        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            padding="max_length",
            max_length=32,
        ).to(device)
        with torch.no_grad():
            hidden = embed_model(**inputs).last_hidden_state  # (1, 32, 768)
            # Mean-pool over real tokens only; padded positions are masked out
            mask = inputs["attention_mask"].unsqueeze(-1)
            return (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
    # Encode dataset prompts
    print("🔢 Encoding dataset into embeddings...")
    short_embeddings = torch.vstack([encode_prompt(p["short_prompt"]) for p in dataset]).to(device)
    long_embeddings = torch.vstack([encode_prompt(p["long_prompt"]) for p in dataset]).to(device)
    print(f"✅ Encoded {len(dataset)} pairs")
    # Train FlashPack model
    model = GemmaTrainer(
        input_dim=short_embeddings.shape[1],
        output_dim=long_embeddings.shape[1],
    ).to(device)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    max_epochs = 500
    tolerance = 1e-4
    for epoch in range(max_epochs):
        optimizer.zero_grad()
        outputs = model(short_embeddings)
        loss = criterion(outputs, long_embeddings)
        loss.backward()
        optimizer.step()

        if loss.item() < tolerance:
            print(f"✅ Converged at epoch {epoch+1}, Loss={loss.item():.6f}")
            break
        if (epoch + 1) % 50 == 0:
            print(f"Epoch {epoch+1}, Loss={loss.item():.6f}")
    # Save to Hugging Face Hub
    FLASHPACK_REPO = "rahul7star/FlashPack"
    model.save_flashpack(FLASHPACK_REPO, target_dtype=torch.float32, push_to_hub=True)
    print(f"✅ Model saved to FlashPack Hub: {FLASHPACK_REPO}")

    return model, dataset, embed_model, tokenizer, long_embeddings
# ============================================================
# 3️⃣ Run training once and load for inference
# ============================================================
model, dataset, embed_model, tokenizer, long_embeddings = train_flashpack_model()
model.eval()
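# Note: training runs at import time, so every app restart retrains from
# scratch. If the checkpoint pushed to the Hub above is current, a shortcut
# (assuming flashpack provides a from_flashpack loader) would be:
#   model = GemmaTrainer.from_flashpack("rahul7star/FlashPack").to(device)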
# ============================================================
# 4️⃣ Inference function for Gradio
# ============================================================
def encode_prompt(prompt):
    # Mirrors the training-time encoder above, so query embeddings live in the
    # same space as the precomputed dataset embeddings.
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=32,
    ).to(device)
    with torch.no_grad():
        hidden = embed_model(**inputs).last_hidden_state
        mask = inputs["attention_mask"].unsqueeze(-1)
        return (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
def enhance_prompt(user_prompt, temperature, max_tokens, chat_history):
    # temperature and max_tokens come from the UI sliders but are unused by
    # this nearest-neighbour lookup; they are kept so the Gradio wiring works.
    chat_history = chat_history or []
    short_emb = encode_prompt(user_prompt)
    with torch.no_grad():
        long_emb = model(short_emb)

    # Nearest match search: return the dataset long prompt whose embedding is
    # closest (by cosine similarity) to the predicted embedding
    cos = nn.CosineSimilarity(dim=1)
    sims = cos(long_emb.repeat(len(long_embeddings), 1), long_embeddings)
    best_idx = int(sims.argmax())  # plain int: datasets indexing rejects tensors
    enhanced_prompt = dataset[best_idx]["long_prompt"]

    chat_history.append({"role": "user", "content": user_prompt})
    chat_history.append({"role": "assistant", "content": enhanced_prompt})
    return chat_history
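# Example (outside the UI):
#   enhance_prompt("a cat sitting on a chair", 0.7, 128, [])
# returns a messages-style history whose last entry is the retrieved long prompt.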
# ============================================================
# 5️⃣ Gradio UI
# ============================================================
with gr.Blocks(title="Prompt Enhancer – Gemma 3 270M", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # ✨ Prompt Enhancer (Gemma 3 270M)
        Enter a short prompt, and the model will **expand it with details and creative context**.
        """
    )
    with gr.Row():
        chatbot = gr.Chatbot(height=400, label="Enhanced Prompts", type="messages")
        with gr.Column(scale=1):
            user_prompt = gr.Textbox(
                placeholder="Enter a short prompt...",
                label="Your Prompt",
                lines=3,
            )
            temperature = gr.Slider(0.0, 1.0, value=0.7, step=0.05, label="Temperature")
            max_tokens = gr.Slider(32, 256, value=128, step=16, label="Max Tokens")
            send_btn = gr.Button("🚀 Enhance Prompt", variant="primary")
            clear_btn = gr.Button("🧹 Clear Chat")

    send_btn.click(enhance_prompt, [user_prompt, temperature, max_tokens, chatbot], chatbot)
    user_prompt.submit(enhance_prompt, [user_prompt, temperature, max_tokens, chatbot], chatbot)
    clear_btn.click(lambda: [], None, chatbot)
    gr.Markdown(
        """
        ---
        💡 **Tips:**
        - Works best with short, descriptive prompts (e.g., "a cat sitting on a chair")
        - Increase *Temperature* for more creative output.
        """
    )
if __name__ == "__main__":
    demo.launch(show_error=True)