import os
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
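# Assumed runtime dependencies (not declared in this file): torch, transformers,
# gradio, and accelerate (needed for the device_map / offloading options below).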
# =========================================================
# 1️⃣ Model Configuration (optimized for HF Spaces)
# =========================================================
# Note: the Qwen2.5 family ships a 1.5B-Instruct model (there is no 1.8B variant).
MODEL_ID = "Qwen/Qwen2.5-1.5B-Instruct"
# Hugging Face Space-friendly environment tweaks
os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"  # deprecated in recent transformers in favor of HF_HOME
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"  # requires the `hf_transfer` package
print(f"🔹 Loading model: {MODEL_ID}")
# Smart device selection
if torch.cuda.is_available():
    device = "cuda"
    dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    print("⚙️ Using GPU (CUDA).")
else:
    device = "cpu"
    dtype = torch.float32
    print("⚙️ Using CPU with memory-efficient loading.")
# =========================================================
# 2️⃣ Load Model and Tokenizer
# =========================================================
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=dtype,
    device_map="auto" if device == "cuda" else {"": "cpu"},
    low_cpu_mem_usage=True,
    offload_folder="./offload" if device == "cpu" else None,
)
model.eval()
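# Optional: on a memory-constrained GPU the model could instead be loaded in
# 4-bit via bitsandbytes. A minimal sketch, assuming `bitsandbytes` is
# installed (not part of the original app):
#
#   from transformers import BitsAndBytesConfig
#   bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=dtype)
#   model = AutoModelForCausalLM.from_pretrained(
#       MODEL_ID, quantization_config=bnb_config, device_map="auto"
#   )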
# =========================================================
# 3️⃣ Inference Function
# =========================================================
def enhance_prompt(user_prompt, temperature, max_tokens, chat_history):
    """Generate an enhanced version of the user prompt."""
    if not user_prompt.strip():
        return chat_history + [["", "⚠️ Please enter a prompt."]]
    messages = [
        {"role": "system", "content": "Enhance and expand the following prompt with more detail, vivid context, and style."},
        {"role": "user", "content": user_prompt},
    ]
    # return_dict=True is required so `inputs` is a dict that can be unpacked
    # into generate() and indexed as inputs["input_ids"] below.
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=int(max_tokens),
            temperature=float(temperature),
            top_p=0.9,
            do_sample=True,
            repetition_penalty=1.05,
        )
    # Decode only the newly generated tokens, skipping the prompt portion.
    result = tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
    chat_history = chat_history + [[user_prompt, result.strip()]]
    return chat_history
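# Quick sanity check outside the UI (hypothetical values, not in the original app):
#   history = enhance_prompt("A cat sitting on a chair", 0.7, 128, [])
#   print(history[-1][1])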
# =========================================================
# 4️⃣ Gradio Interface
# =========================================================
with gr.Blocks(title="Prompt Enhancer – Qwen 1.5B", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # ✨ Prompt Enhancer (Qwen2.5-1.5B-Instruct)
        Enhance and enrich your creative prompts using **Qwen 2.5 1.5B**,
        a lightweight model optimized for reasoning and descriptive text generation.
        ---
        """
    )
    with gr.Row():
        # Pair-list chat format; newer Gradio versions prefer
        # gr.Chatbot(type="messages"), but lists of [user, bot] pairs still work.
        chatbot = gr.Chatbot(height=400, label="Prompt Enhancer Chat")
        with gr.Column(scale=1):
            user_prompt = gr.Textbox(
                placeholder="Enter a prompt to enhance (e.g., 'A cat sitting on a chair').",
                label="Your Prompt",
                lines=3,
            )
            # Minimum of 0.1: generate() rejects temperature=0.0 when do_sample=True.
            temperature = gr.Slider(0.1, 1.0, value=0.7, step=0.05, label="Temperature")
            max_tokens = gr.Slider(32, 512, value=128, step=16, label="Max Tokens")
            send_btn = gr.Button("🚀 Enhance Prompt", variant="primary")
            clear_btn = gr.Button("🧹 Clear Chat")
    send_btn.click(enhance_prompt, [user_prompt, temperature, max_tokens, chatbot], chatbot)
    user_prompt.submit(enhance_prompt, [user_prompt, temperature, max_tokens, chatbot], chatbot)
    clear_btn.click(lambda: [], None, chatbot)
    gr.Markdown(
        """
        ---
        💡 **Tips:**
        - Use short base prompts (e.g., *“a futuristic city skyline at sunset”*).
        - The model will expand and enhance them with extra creative context.
        - Works fully on CPU and is Space-friendly (<5 GB memory footprint).
        """
    )
# =========================================================
# 5️⃣ Launch
# =========================================================
if __name__ == "__main__":
    # share=True is ignored on Hugging Face Spaces; it only affects local runs.
    demo.launch(show_error=True, share=True)
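    # Optional (assumption, not in the original app): enable request queuing so
    # several users can generate concurrently, e.g.
    #   demo.queue(max_size=16).launch(show_error=True)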