import spaces  # Hugging Face Spaces GPU helper (not used directly below)
import torch
import torch.nn as nn
import torch.optim as optim
from flashpack import FlashPackMixin
from datasets import load_dataset
import gradio as gr
from transformers import AutoTokenizer, AutoModel

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"🔧 Using device: {device}")
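# Pipeline overview: a frozen GPT-2 encoder turns prompts into 768-dim embeddings,
# a small MLP (GemmaTrainer) is trained to map short-prompt embeddings toward
# long-prompt embeddings, and at inference the closest long prompt from the
# dataset (by cosine similarity) is returned through a Gradio chat UI.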


class GemmaTrainer(nn.Module, FlashPackMixin):
    def __init__(self, input_dim=768, hidden_dim=1024, output_dim=768):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x
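
# Quick shape check (illustrative, not part of the app flow): with the default
# dims the mapper takes a batch of GPT-2-sized embeddings and returns the same width.
#   _probe = GemmaTrainer()
#   assert _probe(torch.randn(4, 768)).shape == (4, 768)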


def train_flashpack_model():
    print("📦 Loading dataset...")
    dataset = load_dataset("gokaygokay/prompt-enhancer-dataset", split="train")

    # GPT-2 has no pad token by default, so reuse the EOS token for padding.
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token

    # Frozen GPT-2 backbone used only to produce prompt embeddings.
    embed_model = AutoModel.from_pretrained("gpt2").to(device)
    embed_model.eval()

    def encode_prompt(prompt):
        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            padding="max_length",
            max_length=32
        ).to(device)
        with torch.no_grad():
            # Mean-pool the last hidden state into a single (1, 768) embedding.
            return embed_model(**inputs).last_hidden_state.mean(dim=1)

    print("🔢 Encoding dataset into embeddings...")
    short_embeddings = torch.vstack([encode_prompt(p["short_prompt"]) for p in dataset]).to(device)
    long_embeddings = torch.vstack([encode_prompt(p["long_prompt"]) for p in dataset]).to(device)
    print(f"✅ Encoded {len(dataset)} pairs")
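
    # Both prompt columns are embedded once up front, so training below runs
    # full-batch on two fixed (num_pairs, 768) tensors instead of re-encoding
    # text every epoch.
    assert short_embeddings.shape == long_embeddings.shape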

    model = GemmaTrainer(
        input_dim=short_embeddings.shape[1],
        output_dim=long_embeddings.shape[1]
    ).to(device)

    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    max_epochs = 500
    tolerance = 1e-4

    for epoch in range(max_epochs):
        optimizer.zero_grad()
        outputs = model(short_embeddings)
        loss = criterion(outputs, long_embeddings)
        loss.backward()
        optimizer.step()
        if loss.item() < tolerance:
            print(f"✅ Converged at epoch {epoch+1}, Loss={loss.item():.6f}")
            break
        if (epoch + 1) % 50 == 0:
            print(f"Epoch {epoch+1}, Loss={loss.item():.6f}")

    FLASHPACK_REPO = "rahul7star/FlashPack"
    model.save_flashpack(FLASHPACK_REPO, target_dtype=torch.float32, push_to_hub=True)
    print(f"✅ Model saved to FlashPack Hub: {FLASHPACK_REPO}")
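    # To reuse the trained mapper later without retraining, FlashPackMixin should
    # expose a matching loader (assumed to be GemmaTrainer.from_flashpack(FLASHPACK_REPO);
    # check the flashpack docs for the exact call and arguments).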

    return model, dataset, embed_model, tokenizer, long_embeddings


model, dataset, embed_model, tokenizer, long_embeddings = train_flashpack_model()
model.eval()
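
# Training (and the FlashPack push) happens once at startup; the helper below
# mirrors the encode_prompt used during training so user queries are embedded
# exactly the same way as the dataset prompts.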


def encode_prompt(prompt):
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=32
    ).to(device)
    with torch.no_grad():
        return embed_model(**inputs).last_hidden_state.mean(dim=1)


def enhance_prompt(user_prompt, temperature, max_tokens, chat_history):
    # temperature and max_tokens come from the UI sliders; the retrieval step
    # below does not use them.
    chat_history = chat_history or []
    short_emb = encode_prompt(user_prompt)

    with torch.no_grad():
        long_emb = model(short_emb)

    # Pick the dataset long prompt whose embedding is closest (cosine similarity)
    # to the mapper's predicted embedding.
    cos = nn.CosineSimilarity(dim=1)
    sims = cos(long_emb.repeat(len(long_embeddings), 1), long_embeddings)
    best_idx = sims.argmax().item()
    enhanced_prompt = dataset[best_idx]["long_prompt"]

    chat_history.append({"role": "user", "content": user_prompt})
    chat_history.append({"role": "assistant", "content": enhanced_prompt})
    return chat_history
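

# Illustrative use outside Gradio (slider values are placeholders here):
#   history = enhance_prompt("a cat sitting on a chair", 0.7, 128, [])
#   print(history[-1]["content"])  # nearest long prompt from the dataset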


with gr.Blocks(title="Prompt Enhancer – Gemma 3 270M", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # ✨ Prompt Enhancer (Gemma 3 270M)
        Enter a short prompt, and the model will **expand it with details and creative context**.
        """
    )

    with gr.Row():
        chatbot = gr.Chatbot(height=400, label="Enhanced Prompts", type="messages")
        with gr.Column(scale=1):
            user_prompt = gr.Textbox(
                placeholder="Enter a short prompt...",
                label="Your Prompt",
                lines=3,
            )
            temperature = gr.Slider(0.0, 1.0, value=0.7, step=0.05, label="Temperature")
            max_tokens = gr.Slider(32, 256, value=128, step=16, label="Max Tokens")
            send_btn = gr.Button("🚀 Enhance Prompt", variant="primary")
            clear_btn = gr.Button("🧹 Clear Chat")

    send_btn.click(enhance_prompt, [user_prompt, temperature, max_tokens, chatbot], chatbot)
    user_prompt.submit(enhance_prompt, [user_prompt, temperature, max_tokens, chatbot], chatbot)
    clear_btn.click(lambda: [], None, chatbot)
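
    # Both the send button and pressing Enter in the textbox route through
    # enhance_prompt; Gradio passes the listed components as inputs and writes
    # the returned message list back into the Chatbot. The clear button resets
    # the history to an empty list.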

    gr.Markdown(
        """
        ---
        💡 **Tips:**
        - Works best with short, descriptive prompts (e.g., "a cat sitting on a chair").
        - Increase *Temperature* for more creative output.
        """
    )


if __name__ == "__main__":
    demo.launch(show_error=True)
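
# Running this file directly (e.g. `python app.py`) starts the Gradio server;
# on a Hugging Face Space the same script is typically served as app.py
# (the filename is an assumption about the deployment setup).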