import os

import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM

MODEL_ID = "Qwen/Qwen2.5-1.5B-Instruct"

# Cache downloads under /tmp (writable on Spaces) and enable accelerated Hub transfers.
# Note: HF_HUB_ENABLE_HF_TRANSFER=1 requires the `hf_transfer` package to be installed.
os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

print(f"🔹 Loading model: {MODEL_ID}")

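# Choose device and precision: bf16/fp16 on the GPU, full fp32 on the CPU.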
if torch.cuda.is_available():
    device = "cuda"
    dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    print("⚙️ Using GPU (CUDA).")
else:
    device = "cpu"
    dtype = torch.float32
    print("⚙️ Using CPU with memory-efficient loading.")

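# Load the tokenizer and model. On GPU, device_map="auto" places the weights automatically;
# on CPU, low_cpu_mem_usage keeps peak RAM down while the checkpoint is loaded.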
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=dtype,
    device_map="auto" if device == "cuda" else {"": "cpu"},
    low_cpu_mem_usage=True,
    offload_folder="./offload" if device == "cpu" else None,
)
model.eval()


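# Chat callback: enhances the user's prompt with the requested sampling settings and
# appends the (prompt, enhanced prompt) pair to the Gradio chat history.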
def enhance_prompt(user_prompt, temperature, max_tokens, chat_history):
    """Generate an enhanced version of the user prompt."""
    if not user_prompt.strip():
        return chat_history + [["", "⚠️ Please enter a prompt."]]

    messages = [
        {"role": "system", "content": "Enhance and expand the following prompt with more detail, vivid context, and style."},
        {"role": "user", "content": user_prompt},
    ]

    # return_dict=True makes `inputs` a mapping with input_ids and attention_mask;
    # with return_tensors alone, apply_chat_template returns a bare tensor and
    # both **inputs and inputs["input_ids"] below would fail.
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=int(max_tokens),
            temperature=float(temperature),
            top_p=0.9,
            do_sample=True,
            repetition_penalty=1.05,
        )

    # Decode only the newly generated tokens, skipping the prompt portion.
    result = tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
    chat_history = chat_history + [[user_prompt, result.strip()]]
    return chat_history


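# Build the Gradio UI: a chat panel alongside the prompt box, sampling controls, and buttons.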
with gr.Blocks(title="Prompt Enhancer – Qwen 1.5B", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # ✨ Prompt Enhancer (Qwen2.5-1.5B-Instruct)
        Enhance and enrich your creative prompts using **Qwen 2.5 1.5B**,
        a lightweight model optimized for reasoning and descriptive text generation.
        ---
        """
    )

    with gr.Row():
        chatbot = gr.Chatbot(height=400, label="Prompt Enhancer Chat")
        with gr.Column(scale=1):
            user_prompt = gr.Textbox(
                placeholder="Enter a prompt to enhance (e.g., 'A cat sitting on a chair').",
                label="Your Prompt",
                lines=3,
            )
            # Minimum of 0.1: generation uses do_sample=True, and a temperature of
            # exactly 0.0 is rejected by transformers.
            temperature = gr.Slider(0.1, 1.0, value=0.7, step=0.05, label="Temperature")
            max_tokens = gr.Slider(32, 512, value=128, step=16, label="Max Tokens")
            send_btn = gr.Button("🚀 Enhance Prompt", variant="primary")
            clear_btn = gr.Button("🧹 Clear Chat")

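    # Wire the events: both the button and pressing Enter in the textbox call enhance_prompt;
    # the clear button resets the chat history.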
    send_btn.click(enhance_prompt, [user_prompt, temperature, max_tokens, chatbot], chatbot)
    user_prompt.submit(enhance_prompt, [user_prompt, temperature, max_tokens, chatbot], chatbot)
    clear_btn.click(lambda: [], None, chatbot)

    gr.Markdown(
        """
        ---
        💡 **Tips:**
        - Use short base prompts (e.g., *“a futuristic city skyline at sunset”*).
        - The model will expand and enhance them with extra creative context.
        - Works fully on CPU and is Space-friendly (<5 GB memory footprint).
        """
    )


if __name__ == "__main__":
    demo.launch(show_error=True, share=True)