import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

# ============================================================
# 1️⃣ Load model and tokenizer
# ============================================================
MODEL_ID = "gokaygokay/prompt-enhancer-gemma-3-270m-it"

# Use GPU if available
device = 0 if torch.cuda.is_available() else -1

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID)

pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    device=device,  # 0 for GPU, -1 for CPU
)
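
# Optional smoke test (commented out): a quick sanity check that the pipeline
# loaded correctly. The [0]["generated_text"] indexing follows the
# text-generation pipeline's return format; exact output will vary with sampling.
# print(pipe("a cat sitting on a chair", max_new_tokens=32)[0]["generated_text"])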

# ============================================================
# 2️⃣ Define the generation function (chat-template style)
# ============================================================
def enhance_prompt(user_prompt, temperature, max_tokens, chat_history):
    chat_history = chat_history or []

    # Build messages using proper roles
    messages = [
        {"role": "system", "content": "Enhance and expand the following prompt with more details and context:"},
        {"role": "user", "content": user_prompt}
    ]

    # Use tokenizer chat template to build the input
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
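    # With Gemma-style templates, this typically produces turn-delimited text
    # roughly of the form (exact markers depend on the model's bundled template):
    #   <start_of_turn>user\n{system + user text}<end_of_turn>\n<start_of_turn>model\n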

    # Generate output; return_full_text=False keeps only the newly generated
    # text instead of echoing the templated prompt back into the chat
    output = pipe(
        prompt,
        max_new_tokens=int(max_tokens),
        temperature=float(temperature),
        do_sample=True,
        return_full_text=False,
    )[0]["generated_text"].strip()

    # Append conversation to history
    chat_history.append({"role": "user", "content": user_prompt})
    chat_history.append({"role": "assistant", "content": output})

    return chat_history
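
# Example (hypothetical values): enhance_prompt("a cat sitting on a chair", 0.7, 128, [])
# returns a messages-format history list such as
#   [{"role": "user", ...}, {"role": "assistant", ...}]
# which is the structure gr.Chatbot(type="messages") renders.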


# ============================================================
# 3️⃣ Gradio UI
# ============================================================
with gr.Blocks(title="Prompt Enhancer – Gemma 3 270M", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # ✨ Prompt Enhancer (Gemma 3 270M)
        Enter a short prompt, and the model will **expand it with details and creative context**  
        using the Gemma chat-template interface.
        """
    )

    with gr.Row():
        chatbot = gr.Chatbot(height=400, label="Enhanced Prompts", type="messages")
        with gr.Column(scale=1):
            user_prompt = gr.Textbox(
                placeholder="Enter a short prompt...",
                label="Your Prompt",
                lines=3,
            )
            temperature = gr.Slider(0.1, 1.0, value=0.7, step=0.05, label="Temperature")  # must be > 0 when do_sample=True
            max_tokens = gr.Slider(32, 256, value=128, step=16, label="Max Tokens")
            send_btn = gr.Button("🚀 Enhance Prompt", variant="primary")
            clear_btn = gr.Button("🧹 Clear Chat")

    # Bind UI actions
    send_btn.click(enhance_prompt, [user_prompt, temperature, max_tokens, chatbot], chatbot)
    user_prompt.submit(enhance_prompt, [user_prompt, temperature, max_tokens, chatbot], chatbot)
    clear_btn.click(lambda: [], None, chatbot)
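    # Note: the input box is left populated after sending; to clear it as well,
    # the click event could be chained, e.g.
    #   send_btn.click(...).then(lambda: "", None, user_prompt)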

    gr.Markdown(
        """
        ---
        💡 **Tips:**
        - Works best with short, descriptive prompts (e.g., "a cat sitting on a chair").
        - Increase *Temperature* for more creative output.
        """
    )

# ============================================================
# 4️⃣ Launch
# ============================================================
if __name__ == "__main__":
    demo.launch(show_error=True)