import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
# ============================================================
# 1️⃣ Load model and tokenizer
# ============================================================
MODEL_ID = "gokaygokay/prompt-enhancer-gemma-3-270m-it"
# Use GPU if available
device = 0 if torch.cuda.is_available() else -1
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    device=device,  # 0 for GPU, -1 for CPU
)
# ============================================================
# 2️⃣ Define the generation function (chat-template style)
# ============================================================
def enhance_prompt(user_prompt, temperature, max_tokens, chat_history):
    chat_history = chat_history or []
    # Build messages using proper roles
    messages = [
        {"role": "system", "content": "Enhance and expand the following prompt with more details and context:"},
        {"role": "user", "content": user_prompt},
    ]
    # Use the tokenizer's chat template to build the model input
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # Generate output; return_full_text=False keeps only the newly generated text,
    # not the chat-template prompt itself
    output = pipe(
        prompt,
        max_new_tokens=int(max_tokens),
        temperature=float(temperature),
        do_sample=True,
        return_full_text=False,
    )[0]["generated_text"].strip()
    # Append the exchange to the chat history (messages format)
    chat_history.append({"role": "user", "content": user_prompt})
    chat_history.append({"role": "assistant", "content": output})
    return chat_history
# ============================================================
# 3️⃣ Gradio UI
# ============================================================
with gr.Blocks(title="Prompt Enhancer – Gemma 3 270M", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # ✨ Prompt Enhancer (Gemma 3 270M)
        Enter a short prompt, and the model will **expand it with details and creative context**
        using the Gemma chat-template interface.
        """
    )
    with gr.Row():
        chatbot = gr.Chatbot(height=400, label="Enhanced Prompts", type="messages")
        with gr.Column(scale=1):
            user_prompt = gr.Textbox(
                placeholder="Enter a short prompt...",
                label="Your Prompt",
                lines=3,
            )
            temperature = gr.Slider(0.1, 1.0, value=0.7, step=0.05, label="Temperature")  # keep > 0 so sampling stays valid
            max_tokens = gr.Slider(32, 256, value=128, step=16, label="Max Tokens")
            send_btn = gr.Button("🚀 Enhance Prompt", variant="primary")
            clear_btn = gr.Button("🧹 Clear Chat")

    # Bind UI actions
    send_btn.click(enhance_prompt, [user_prompt, temperature, max_tokens, chatbot], chatbot)
    user_prompt.submit(enhance_prompt, [user_prompt, temperature, max_tokens, chatbot], chatbot)
    clear_btn.click(lambda: [], None, chatbot)

    gr.Markdown(
        """
        ---
        💡 **Tips:**
        - Works best with short, descriptive prompts (e.g., "a cat sitting on a chair")
        - Increase *Temperature* for more creative output.
        """
    )
# ============================================================
# 4️⃣ Launch
# ============================================================
if __name__ == "__main__":
    demo.launch(show_error=True)