# Source: app_low.py — "Update app_low.py" by rahul7star (commit 77384a1, 4.41 kB)
import os
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
# =========================================================
# 1️⃣ Model Configuration (optimized for HF Spaces)
# =========================================================
# Qwen2.5 has no 1.8B checkpoint (1.8B was a Qwen1.5 size); the closest
# Qwen2.5 instruct model is 1.5B, which keeps the "lightweight" intent.
MODEL_ID = "Qwen/Qwen2.5-1.5B-Instruct"

# Hugging Face Space-friendly environment tweaks. setdefault keeps any value
# already configured for the Space instead of overwriting it.
# NOTE(review): transformers / huggingface_hub read these variables when they
# are first imported, and transformers is imported above this point, so they
# likely have no effect here. If they are needed, move them above the imports
# (HF_HUB_ENABLE_HF_TRANSFER=1 additionally requires the `hf_transfer` package).
os.environ.setdefault("TRANSFORMERS_CACHE", "/tmp/hf_cache")
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")

print(f"🔹 Loading model: {MODEL_ID}")

# Device / dtype selection: bfloat16 on GPUs that support it, float16 on other
# GPUs, and full float32 on CPU.
if torch.cuda.is_available():
    device = "cuda"
    dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    print("⚙️ Using GPU (CUDA).")
else:
    device = "cpu"
    dtype = torch.float32
    print("⚙️ Using CPU with memory-efficient loading.")
# =========================================================
# 2️⃣ Load Model and Tokenizer
# =========================================================
# Pass the cache directory explicitly: transformers reads its cache-location
# environment variables at import time, so a TRANSFORMERS_CACHE value set after
# `import transformers` is not picked up on its own. None -> library default.
CACHE_DIR = os.environ.get("TRANSFORMERS_CACHE")

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, cache_dir=CACHE_DIR)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=dtype,
    # "auto" lets accelerate place weights on the available GPU(s) (requires
    # the `accelerate` package); on CPU every module is pinned to the CPU.
    device_map="auto" if device == "cuda" else {"": "cpu"},
    low_cpu_mem_usage=True,
    offload_folder="./offload" if device == "cpu" else None,
    cache_dir=CACHE_DIR,
)
# Inference only: put modules such as dropout into evaluation mode.
model.eval()
# =========================================================
# 3️⃣ Inference Function
# =========================================================
def enhance_prompt(user_prompt, temperature, max_tokens, chat_history):
    """Generate an enhanced version of the user prompt and append it to the chat.

    Args:
        user_prompt: Text from the prompt textbox; ``None`` is treated as empty.
        temperature: Sampling temperature from the slider. Values <= 0 switch
            to greedy decoding, because ``generate`` rejects a non-positive
            temperature when sampling is enabled.
        max_tokens: Maximum number of new tokens to generate.
        chat_history: Existing chat as a list of ``[user, bot]`` pairs
            (``None`` is treated as empty). It is not mutated.

    Returns:
        A new chat-history list with one ``[user, bot]`` pair appended. For an
        empty prompt the appended pair carries a warning instead of a result.
    """
    history = list(chat_history or [])
    prompt = (user_prompt or "").strip()
    if not prompt:
        return history + [["", "⚠️ Please enter a prompt."]]

    messages = [
        {"role": "system", "content": "Enhance and expand the following prompt with more detail, vivid context, and style."},
        {"role": "user", "content": prompt},
    ]
    # return_dict=True yields a mapping with input_ids *and* attention_mask.
    # Without it the call returns a bare tensor, which cannot be unpacked
    # with ** into generate() or indexed by "input_ids" below.
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)
    prompt_len = inputs["input_ids"].shape[-1]

    temperature = float(temperature)
    gen_kwargs = {"max_new_tokens": int(max_tokens), "repetition_penalty": 1.05}
    if temperature > 0:
        gen_kwargs.update(do_sample=True, temperature=temperature, top_p=0.9)
    else:
        # The slider allows 0.0; treat it as deterministic greedy decoding
        # rather than passing an invalid sampling temperature.
        gen_kwargs["do_sample"] = False

    with torch.no_grad():
        outputs = model.generate(**inputs, **gen_kwargs)

    # Decode only the newly generated tokens, skipping the echoed prompt.
    result = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    return history + [[prompt, result.strip()]]
# =========================================================
# 4️⃣ Gradio Interface
# =========================================================
# Display name derived from MODEL_ID so the title and description always match
# the checkpoint that is actually loaded.
MODEL_NAME = MODEL_ID.split("/")[-1]

with gr.Blocks(title=f"Prompt Enhancer – {MODEL_NAME}", theme=gr.themes.Soft()) as demo:
    # Header. The blank line before "---" is required: a "---" line directly
    # under paragraph text turns that text into a setext heading instead of
    # drawing a horizontal rule.
    gr.Markdown(
        f"# ✨ Prompt Enhancer ({MODEL_NAME})\n"
        f"Enhance and enrich your creative prompts using **{MODEL_NAME}**,\n"
        "a lightweight model optimized for reasoning and descriptive text generation.\n"
        "\n"
        "---\n"
    )

    with gr.Row():
        # Chat transcript first, controls in a column next to it.
        chatbot = gr.Chatbot(height=400, label="Prompt Enhancer Chat")
        with gr.Column(scale=1):
            user_prompt = gr.Textbox(
                placeholder="Enter a prompt to enhance (e.g., 'A cat sitting on a chair').",
                label="Your Prompt",
                lines=3,
            )
            temperature = gr.Slider(0.0, 1.0, value=0.7, step=0.05, label="Temperature")
            max_tokens = gr.Slider(32, 512, value=128, step=16, label="Max Tokens")
            send_btn = gr.Button("🚀 Enhance Prompt", variant="primary")
            clear_btn = gr.Button("🧹 Clear Chat")

    # The button and textbox submission run the same handler; its return
    # value (the updated history) replaces the chatbot contents.
    enhance_inputs = [user_prompt, temperature, max_tokens, chatbot]
    send_btn.click(enhance_prompt, enhance_inputs, chatbot)
    user_prompt.submit(enhance_prompt, enhance_inputs, chatbot)
    clear_btn.click(lambda: [], None, chatbot)

    gr.Markdown(
        "---\n"
        "💡 **Tips:**\n"
        "- Use short base prompts (e.g., *“a futuristic city skyline at sunset”*).\n"
        "- The model will expand and enhance them with extra creative context.\n"
        "- Runs on GPU when available; on CPU the model is loaded in float32, "
        "which is slower and uses more memory.\n"
    )
# =========================================================
# 5️⃣ Launch
# =========================================================
if __name__ == "__main__":
    # Start the Gradio server only when run as a script (not on import).
    # show_error surfaces handler exceptions in the UI; share requests a
    # public share link when run outside Hugging Face Spaces.
    demo.launch(share=True, show_error=True)