import spaces  # Hugging Face Spaces SDK; kept for ZeroGPU compatibility even though no @spaces.GPU decorator is used here
import torch
import torch.nn as nn
import torch.optim as optim
from flashpack import FlashPackMixin
from datasets import load_dataset
import gradio as gr
from transformers import AutoTokenizer, AutoModel

# ============================================================
# 🧠 Device setup
# ============================================================
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"🔧 Using device: {device}")

# ============================================================
# 1️⃣ Define FlashPack model
# ============================================================
class GemmaTrainer(nn.Module, FlashPackMixin):
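    """Small MLP that maps short-prompt embeddings to long-prompt embeddings.

    FlashPackMixin supplies the fast save/load methods used further below.
    """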
    def __init__(self, input_dim=768, hidden_dim=1024, output_dim=768):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x


# ============================================================
# 2️⃣ Encode and train using GPU
# ============================================================

def train_flashpack_model():
    # Load dataset
    print("📦 Loading dataset...")
    dataset = load_dataset("gokaygokay/prompt-enhancer-dataset", split="train")
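    # Each row pairs a `short_prompt` with its expanded `long_prompt`.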

    # Tokenizer setup
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2 ships without a pad token; reuse EOS

    # Base embedding model
    embed_model = AutoModel.from_pretrained("gpt2").to(device)
    embed_model.eval()

    def encode_prompt(prompt):
        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            padding="max_length",
            max_length=32
        ).to(device)
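        # Mean-pool GPT-2's final hidden states into a single fixed-size
        # vector per prompt (shape: [1, hidden_size]).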
        with torch.no_grad():
            return embed_model(**inputs).last_hidden_state.mean(dim=1)

    # Encode dataset prompts
    print("🔢 Encoding dataset into embeddings...")
    short_embeddings = torch.vstack([encode_prompt(p["short_prompt"]) for p in dataset]).to(device)
    long_embeddings = torch.vstack([encode_prompt(p["long_prompt"]) for p in dataset]).to(device)
    print(f"✅ Encoded {len(dataset)} pairs")

    # Train FlashPack model
    model = GemmaTrainer(
        input_dim=short_embeddings.shape[1],
        output_dim=long_embeddings.shape[1]
    ).to(device)

    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    max_epochs = 500
    tolerance = 1e-4

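    # Full-batch training: each epoch pushes every pair through the MLP at
    # once and stops early when the MSE falls below `tolerance`.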
    for epoch in range(max_epochs):
        optimizer.zero_grad()
        outputs = model(short_embeddings)
        loss = criterion(outputs, long_embeddings)
        loss.backward()
        optimizer.step()
        if loss.item() < tolerance:
            print(f"✅ Converged at epoch {epoch+1}, Loss={loss.item():.6f}")
            break
        if (epoch + 1) % 50 == 0:
            print(f"Epoch {epoch+1}, Loss={loss.item():.6f}")

    # Save to Hugging Face Hub
    FLASHPACK_REPO = "rahul7star/FlashPack"
    model.save_flashpack(FLASHPACK_REPO, target_dtype=torch.float32, push_to_hub=True)
    print(f"✅ Model saved to FlashPack Hub: {FLASHPACK_REPO}")

    return model, dataset, embed_model, tokenizer, long_embeddings


# ============================================================
# 3️⃣ Run training once and load for inference
# ============================================================
model, dataset, embed_model, tokenizer, long_embeddings = train_flashpack_model()
model.eval()

# ============================================================
# 4️⃣ Inference function for Gradio
# ============================================================
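# Same encoding as inside train_flashpack_model(), redefined at module scope
# against the tokenizer and embed_model returned by training.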
def encode_prompt(prompt):
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=32
    ).to(device)
    with torch.no_grad():
        return embed_model(**inputs).last_hidden_state.mean(dim=1)

def enhance_prompt(user_prompt, temperature, max_tokens, chat_history):
    # temperature and max_tokens come from the UI sliders but are not used by
    # the retrieval step below; they are kept so the Gradio wiring stays valid.
    chat_history = chat_history or []
    short_emb = encode_prompt(user_prompt)

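    # Project the short-prompt embedding into the long-prompt embedding space.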
    with torch.no_grad():
        long_emb = model(short_emb)

    # Nearest-neighbour search: pick the training pair whose long-prompt
    # embedding is most similar to the prediction, and return that pair's
    # long prompt as the enhanced result.
    cos = nn.CosineSimilarity(dim=1)
    sims = cos(long_emb.repeat(len(long_embeddings), 1), long_embeddings)
    best_idx = sims.argmax().item()  # plain int so the HF dataset accepts it as an index
    enhanced_prompt = dataset[best_idx]["long_prompt"]

    chat_history.append({"role": "user", "content": user_prompt})
    chat_history.append({"role": "assistant", "content": enhanced_prompt})
    return chat_history


# ============================================================
# 5️⃣ Gradio UI
# ============================================================
with gr.Blocks(title="Prompt Enhancer – Gemma 3 270M", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # ✨ Prompt Enhancer (Gemma 3 270M)
        Enter a short prompt, and the app will **expand it with details and creative context**
        by retrieving the closest matching enhanced prompt from its training data.
        """
    )

    with gr.Row():
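        # type="messages" renders the role/content dicts built by enhance_prompt().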
        chatbot = gr.Chatbot(height=400, label="Enhanced Prompts", type="messages")
        with gr.Column(scale=1):
            user_prompt = gr.Textbox(
                placeholder="Enter a short prompt...",
                label="Your Prompt",
                lines=3,
            )
            temperature = gr.Slider(0.0, 1.0, value=0.7, step=0.05, label="Temperature")
            max_tokens = gr.Slider(32, 256, value=128, step=16, label="Max Tokens")
            send_btn = gr.Button("🚀 Enhance Prompt", variant="primary")
            clear_btn = gr.Button("🧹 Clear Chat")

    send_btn.click(enhance_prompt, [user_prompt, temperature, max_tokens, chatbot], chatbot)
    user_prompt.submit(enhance_prompt, [user_prompt, temperature, max_tokens, chatbot], chatbot)
    clear_btn.click(lambda: [], None, chatbot)

    gr.Markdown(
        """
        ---
        💡 **Tips:**
        - Works best with short, descriptive prompts (e.g., "a cat sitting on a chair")
        - *Temperature* and *Max Tokens* are currently placeholders; the retrieval step does not use them.
        """
    )

if __name__ == "__main__":
    demo.launch(show_error=True)