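"""Prompt Enhancer (FlashPack) – CPU-only Gradio app.

Trains (or loads) a small MLP that maps GPT-2 embeddings of short prompts to
embeddings of long prompts, then serves enhanced prompts by nearest-neighbour
retrieval over the dataset's long-prompt embeddings.
"""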
import gc
import os
import torch
import torch.nn as nn
import torch.optim as optim
import tempfile
import gradio as gr
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModel
from flashpack import FlashPackMixin
from huggingface_hub import Repository
from typing import Tuple
# ============================================================
# 🖥 Device setup (CPU-only)
# ============================================================
device = torch.device("cpu")
torch.set_num_threads(4)
print(f"🔧 Using device: {device} (CPU-only)")
# ============================================================
# 1️⃣ FlashPack model with better hidden layers
# ============================================================
class GemmaTrainer(nn.Module, FlashPackMixin):
def __init__(self, input_dim: int, hidden_dim: int = 1024, output_dim: int = 1536):
super().__init__()
self.fc1 = nn.Linear(input_dim, hidden_dim)
self.relu = nn.ReLU()
self.fc2 = nn.Linear(hidden_dim, hidden_dim)
self.fc3 = nn.Linear(hidden_dim, output_dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.relu(x)
x = self.fc3(x)
return x
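
# Shape sketch (illustrative: with GPT-2's 768-dim hidden states, the mean+max
# pooling below yields 1536-dim vectors, matching the defaults above):
#   mapper = GemmaTrainer(input_dim=1536)      # hidden_dim=1024, output_dim=1536
#   out = mapper(torch.randn(4, 1536))         # -> torch.Size([4, 1536])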
# ============================================================
# 2️⃣ Encoder using mean+max pooling (for richer embeddings)
# ============================================================
def build_encoder(model_name: str = "gpt2", max_length: int = 128):
    """Return (tokenizer, embed_model, encode_fn), where encode_fn maps a prompt
    to a mean+max-pooled embedding of shape (1, 2 * hidden_size)."""
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token  # GPT-2 ships without a pad token
    embed_model = AutoModel.from_pretrained(model_name).to(device)
    embed_model.eval()
@torch.no_grad()
def encode(prompt: str) -> torch.Tensor:
inputs = tokenizer(prompt, return_tensors="pt", truncation=True,
padding="max_length", max_length=max_length).to(device)
last_hidden = embed_model(**inputs).last_hidden_state
mean_pool = last_hidden.mean(dim=1)
max_pool, _ = last_hidden.max(dim=1)
return torch.cat([mean_pool, max_pool], dim=1).cpu() # doubled embedding
return tokenizer, embed_model, encode
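
# Usage sketch (assuming GPT-2, hidden size 768):
#   tok, enc_model, encode = build_encoder("gpt2")
#   vec = encode("a cat on a mat")  # -> shape (1, 1536): [mean pool | max pool]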
# ============================================================
# 3️⃣ Push FlashPack model to Hugging Face
# ============================================================
def push_flashpack_model_to_hf(model, hf_repo: str):
    """Serialize the model with FlashPack and push it to the given HF repo."""
    logs = []
    with tempfile.TemporaryDirectory() as tmp_dir:
        logs.append(f"📂 Temporary directory: {tmp_dir}")
        # Repository is deprecated in recent huggingface_hub releases; see the
        # HfApi sketch after this function for the current upload path.
        repo = Repository(local_dir=tmp_dir, clone_from=hf_repo, use_auth_token=True)
pack_path = os.path.join(tmp_dir, "model.flashpack")
model.save_flashpack(pack_path, target_dtype=torch.float32)
readme_path = os.path.join(tmp_dir, "README.md")
with open(readme_path, "w") as f:
f.write("# FlashPack Model\nThis repo contains a FlashPack model.")
repo.push_to_hub()
logs.append(f"✅ Model pushed to HF: {hf_repo}")
return logs
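
# Alternative upload path (a sketch, not what the function above uses): newer
# huggingface_hub releases replace Repository with HfApi. Assuming a token is
# available via HF_TOKEN or `huggingface-cli login`:
#   from huggingface_hub import HfApi
#   api = HfApi()
#   api.create_repo(hf_repo, exist_ok=True)
#   api.upload_file(path_or_fileobj=pack_path, path_in_repo="model.flashpack",
#                   repo_id=hf_repo)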
# ============================================================
# 4️⃣ Train FlashPack model
# ============================================================
def train_flashpack_model(
dataset_name: str = "rahul7star/prompt-enhancer-dataset",
max_encode: int = 1000,
hidden_dim: int = 1024,
push_to_hub: bool = True,
hf_repo: str = "rahul7star/FlashPack"
) -> Tuple[GemmaTrainer, object, object, object, torch.Tensor]:
print("📦 Loading dataset...")
dataset = load_dataset(dataset_name, split="train")
limit = min(max_encode, len(dataset))
dataset = dataset.select(range(limit))
print(f"⚡ Using {len(dataset)} prompts for training")
tokenizer, embed_model, encode_fn = build_encoder("gpt2", max_length=128)
# Encode prompts
short_list, long_list = [], []
for i, item in enumerate(dataset):
short_list.append(encode_fn(item["short_prompt"]))
long_list.append(encode_fn(item["long_prompt"]))
if (i+1) % 50 == 0 or (i+1) == len(dataset):
print(f" → Encoded {i+1}/{limit} prompts")
gc.collect()
short_embeddings = torch.vstack(short_list)
long_embeddings = torch.vstack(long_list)
print(f"✅ Encoded embeddings shape: short {short_embeddings.shape}, long {long_embeddings.shape}")
input_dim = short_embeddings.shape[1] # should match concatenated mean+max
output_dim = long_embeddings.shape[1]
model = GemmaTrainer(input_dim=input_dim, hidden_dim=hidden_dim, output_dim=output_dim).to(device)
    criterion = nn.CosineSimilarity(dim=1)  # similarity, converted to a loss below
optimizer = optim.Adam(model.parameters(), lr=1e-3)
max_epochs = 50
batch_size = 32
n = short_embeddings.shape[0]
print("🚀 Training model...")
for epoch in range(max_epochs):
model.train()
epoch_loss = 0.0
perm = torch.randperm(n)
for start in range(0, n, batch_size):
idx = perm[start:start+batch_size]
inputs = short_embeddings[idx].to(device)
targets = long_embeddings[idx].to(device)
optimizer.zero_grad()
outputs = model(inputs)
            loss = 1 - criterion(outputs, targets).mean()  # mean cosine distance
loss.backward()
optimizer.step()
epoch_loss += loss.item() * inputs.size(0)
epoch_loss /= n
        if (epoch + 1) % 5 == 0 or (epoch + 1) == max_epochs:
            print(f"Epoch {epoch+1}/{max_epochs}, Loss={epoch_loss:.6f}")
print("✅ Training finished!")
if push_to_hub:
logs = push_flashpack_model_to_hf(model, hf_repo)
for log in logs:
print(log)
return model, dataset, embed_model, tokenizer, long_embeddings
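
# Quick smoke-test sketch (hypothetical settings, skipping the hub push):
#   model, ds, enc_model, tok, long_embs = train_flashpack_model(
#       max_encode=100, push_to_hub=False)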
# ============================================================
# 5️⃣ Load or train model
# ============================================================
def get_flashpack_model(hf_repo="rahul7star/FlashPack"):
    """Load the mapper from HF if possible, otherwise train it from scratch.

    Both branches return the same 5-tuple, so the module-level unpacking
    below works regardless of which path ran.
    """
    try:
        print(f"🔁 Attempting to load FlashPack model from {hf_repo}")
        model = GemmaTrainer.from_flashpack(hf_repo)
        model.eval()
        tokenizer, embed_model, encode_fn = build_encoder("gpt2", max_length=128)
        # The retrieval index (dataset + long-prompt embeddings) is still needed
        # at inference time, so rebuild it here, mirroring the training defaults.
        dataset = load_dataset("rahul7star/prompt-enhancer-dataset", split="train")
        dataset = dataset.select(range(min(1000, len(dataset))))
        long_embeddings = torch.vstack([encode_fn(item["long_prompt"]) for item in dataset])
        return model, tokenizer, embed_model, dataset, long_embeddings
    except Exception as e:
        print(f"⚠️ Load failed: {e}")
        print("⏬ Training a new FlashPack model locally...")
        # train_flashpack_model pushes to the hub itself (push_to_hub=True by
        # default), so no second push is needed here.
        model, dataset, embed_model, tokenizer, long_embeddings = train_flashpack_model(hf_repo=hf_repo)
        return model, tokenizer, embed_model, dataset, long_embeddings
# ============================================================
# 6️⃣ Bootstrap: load (or train) the model and retrieval index
# ============================================================
model, tokenizer, embed_model, dataset, long_embeddings = get_flashpack_model()
# ============================================================
# 7️⃣ Inference helpers
# ============================================================
@torch.no_grad()
def encode_for_inference(prompt: str) -> torch.Tensor:
    # Mirrors the encode fn returned by build_encoder, using the module-level
    # tokenizer and embed_model unpacked above.
inputs = tokenizer(prompt, return_tensors="pt", truncation=True,
padding="max_length", max_length=128).to(device)
last_hidden = embed_model(**inputs).last_hidden_state
mean_pool = last_hidden.mean(dim=1)
max_pool, _ = last_hidden.max(dim=1)
return torch.cat([mean_pool, max_pool], dim=1).cpu()
def enhance_prompt(user_prompt: str, temperature: float, max_tokens: int, chat_history):
    # temperature and max_tokens are accepted for UI parity but unused here:
    # this is a nearest-neighbour lookup, not generative decoding.
    chat_history = chat_history or []
    short_emb = encode_for_inference(user_prompt)
    mapped = model(short_emb.to(device)).cpu()
    # Cosine similarity between the mapped query and every long-prompt embedding.
    sims = (long_embeddings @ mapped.t()).squeeze(1)
    sims = sims / (long_embeddings.norm(dim=1) * mapped.norm() + 1e-12)
best_idx = int(sims.argmax().item())
enhanced_prompt = dataset[best_idx]["long_prompt"]
chat_history.append({"role": "user", "content": user_prompt})
chat_history.append({"role": "assistant", "content": enhanced_prompt})
return chat_history
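
# The manual normalization above is equivalent to (a sketch):
#   import torch.nn.functional as F
#   sims = F.cosine_similarity(long_embeddings, mapped, dim=1)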
# ============================================================
# 8️⃣ Gradio UI
# ============================================================
with gr.Blocks(title="Prompt Enhancer – FlashPack (CPU)", theme=gr.themes.Soft()) as demo:
gr.Markdown(
"""
# ✨ Prompt Enhancer (FlashPack mapper)
Enter a short prompt, and the model will **expand it with details and creative context**.
(CPU-only mode.)
"""
)
with gr.Row():
chatbot = gr.Chatbot(height=400, label="Enhanced Prompts", type="messages")
with gr.Column(scale=1):
user_prompt = gr.Textbox(placeholder="Enter a short prompt...", label="Your Prompt", lines=3)
temperature = gr.Slider(0.0, 1.0, value=0.7, step=0.05, label="Temperature")
max_tokens = gr.Slider(32, 256, value=128, step=16, label="Max Tokens")
send_btn = gr.Button("🚀 Enhance Prompt", variant="primary")
clear_btn = gr.Button("🧹 Clear Chat")
send_btn.click(enhance_prompt, [user_prompt, temperature, max_tokens, chatbot], chatbot)
user_prompt.submit(enhance_prompt, [user_prompt, temperature, max_tokens, chatbot], chatbot)
clear_btn.click(lambda: [], None, chatbot)
# ============================================================
# 9️⃣ Launch
# ============================================================
if __name__ == "__main__":
demo.launch(show_error=True)