import os
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM

# =========================================================
# 1️⃣ Model Configuration (optimized for HF Spaces)
# =========================================================
MODEL_ID = "Qwen/Qwen2.5-1.5B-Instruct"

# Hugging Face Space-friendly environment tweaks
os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

print(f"🔹 Loading model: {MODEL_ID}")

# Smart device selection
if torch.cuda.is_available():
    device = "cuda"
    dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    print("⚙️ Using GPU (CUDA).")
else:
    device = "cpu"
    dtype = torch.float32
    print("⚙️ Using CPU with memory-efficient loading.")

# =========================================================
# 2️⃣ Load Model and Tokenizer
# =========================================================
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

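# NOTE (assumption): device_map and low_cpu_mem_usage in from_pretrained rely on the
# `accelerate` package being available in the Space's requirements.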
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=dtype,
    device_map="auto" if device == "cuda" else {"": "cpu"},
    low_cpu_mem_usage=True,
    offload_folder="./offload" if device == "cpu" else None,
)
model.eval()

# =========================================================
# 3️⃣ Inference Function
# =========================================================
def enhance_prompt(user_prompt, temperature, max_tokens, chat_history):
    """Generate an enhanced version of the user prompt."""
    if not user_prompt.strip():
        return chat_history + [["", "⚠️ Please enter a prompt."]]

    messages = [
        {"role": "system", "content": "Enhance and expand the following prompt with more detail, vivid context, and style."},
        {"role": "user", "content": user_prompt},
    ]

    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,  # return input_ids + attention_mask so **inputs works below
        return_tensors="pt",
    ).to(model.device)

    gen_kwargs = {
        "max_new_tokens": int(max_tokens),
        "repetition_penalty": 1.05,
    }
    # temperature == 0 is invalid when sampling, so fall back to greedy decoding.
    if float(temperature) > 0:
        gen_kwargs.update(do_sample=True, temperature=float(temperature), top_p=0.9)
    else:
        gen_kwargs["do_sample"] = False

    with torch.no_grad():
        outputs = model.generate(**inputs, **gen_kwargs)

    result = tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
    chat_history = chat_history + [[user_prompt, result.strip()]]
    return chat_history

# =========================================================
# 4️⃣ Gradio Interface
# =========================================================
with gr.Blocks(title="Prompt Enhancer – Qwen2.5-1.5B", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # ✨ Prompt Enhancer (Qwen2.5-1.5B-Instruct)
        Enhance and enrich your creative prompts using **Qwen2.5-1.5B-Instruct**,  
        a lightweight model optimized for reasoning and descriptive text generation.
        ---
        """
    )

    with gr.Row():
        chatbot = gr.Chatbot(height=400, label="Prompt Enhancer Chat")
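        # NOTE (assumption): the Chatbot uses the default list-of-[user, bot]-pairs
        # history format; newer Gradio releases may suggest type="messages" instead.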
        with gr.Column(scale=1):
            user_prompt = gr.Textbox(
                placeholder="Enter a prompt to enhance (e.g., 'A cat sitting on a chair').",
                label="Your Prompt",
                lines=3,
            )
            temperature = gr.Slider(0.0, 1.0, value=0.7, step=0.05, label="Temperature")
            max_tokens = gr.Slider(32, 512, value=128, step=16, label="Max Tokens")
            send_btn = gr.Button("🚀 Enhance Prompt", variant="primary")
            clear_btn = gr.Button("🧹 Clear Chat")

    send_btn.click(enhance_prompt, [user_prompt, temperature, max_tokens, chatbot], chatbot)
    user_prompt.submit(enhance_prompt, [user_prompt, temperature, max_tokens, chatbot], chatbot)
    clear_btn.click(lambda: [], None, chatbot)

    gr.Markdown(
        """
        ---
        💡 **Tips:**  
        - Use short base prompts (e.g., *“a futuristic city skyline at sunset”*).  
        - The model will expand and enhance them with extra creative context.  
        - Works fully on CPU and is Space-friendly (<5 GB memory footprint).
        """
    )

# =========================================================
# 5️⃣ Launch
# =========================================================
if __name__ == "__main__":
    demo.launch(show_error=True, share=True)
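# NOTE (assumption): when this script runs inside a Hugging Face Space, the public URL
# is provided by the platform itself, so share=True mainly matters for local runs.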