import gc
import os
import torch
import torch.nn as nn
import torch.optim as optim
import tempfile
import gradio as gr
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModel
from flashpack import FlashPackMixin
from huggingface_hub import Repository
from typing import Tuple

# ============================================================
# 🖥 Device setup (CPU-only)
# ============================================================
device = torch.device("cpu")
torch.set_num_threads(4)
print(f"🔧 Using device: {device} (CPU-only)")

# ============================================================
# 1️⃣ FlashPack mapper: a small MLP from short-prompt to long-prompt embeddings
# ============================================================
class GemmaTrainer(nn.Module, FlashPackMixin):
    def __init__(self, input_dim: int, hidden_dim: int = 1024, output_dim: int = 1536):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x
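
# Illustrative sketch (not called anywhere): with GPT-2's 768-dim hidden
# states, the mean+max pooling below yields 1536-dim vectors, so in practice
# the mapper is 1536 -> 1536. The dims here are assumptions for the check only.
def _mapper_shape_check():
    m = GemmaTrainer(input_dim=1536, hidden_dim=1024, output_dim=1536)
    out = m(torch.randn(2, 1536))
    assert out.shape == (2, 1536)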

# ============================================================
# 2️⃣ Encoder using mean+max pooling (for richer embeddings)
# ============================================================
def build_encoder(model_name="gpt2", max_length: int = 128):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    embed_model = AutoModel.from_pretrained(model_name).to(device)
    embed_model.eval()

    @torch.no_grad()
    def encode(prompt: str) -> torch.Tensor:
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True,
                           padding="max_length", max_length=max_length).to(device)
        last_hidden = embed_model(**inputs).last_hidden_state
        mean_pool = last_hidden.mean(dim=1)
        max_pool, _ = last_hidden.max(dim=1)
        return torch.cat([mean_pool, max_pool], dim=1).cpu()  # concat doubles the dim (768 -> 1536 for gpt2)

    return tokenizer, embed_model, encode
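
# Example usage (sketch): "gpt2" has hidden size 768, so encode() returns a
# [1, 1536] tensor after the mean+max concatenation:
#   tokenizer, embed_model, encode = build_encoder("gpt2")
#   emb = encode("a cat on a mat")  # torch.Size([1, 1536])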

# ============================================================
# 3️⃣ Push FlashPack model to Hugging Face
# ============================================================
def push_flashpack_model_to_hf(model, hf_repo: str):
    logs = []
    with tempfile.TemporaryDirectory() as tmp_dir:
        logs.append(f"📂 Temporary directory: {tmp_dir}")
        repo = Repository(local_dir=tmp_dir, clone_from=hf_repo, use_auth_token=True)
        pack_path = os.path.join(tmp_dir, "model.flashpack")
        model.save_flashpack(pack_path, target_dtype=torch.float32)
        readme_path = os.path.join(tmp_dir, "README.md")
        with open(readme_path, "w") as f:
            f.write("# FlashPack Model\nThis repo contains a FlashPack model.")
        repo.push_to_hub()
        logs.append(f"✅ Model pushed to HF: {hf_repo}")
    return logs
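
# Note: `Repository` is deprecated in recent huggingface_hub releases. An
# equivalent sketch with the current API (assumes the repo already exists and
# an HF token is configured):
#   from huggingface_hub import HfApi
#   HfApi().upload_folder(folder_path=tmp_dir, repo_id=hf_repo)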

# ============================================================
# 4️⃣ Train FlashPack model
# ============================================================
def train_flashpack_model(
    dataset_name: str = "rahul7star/prompt-enhancer-dataset",
    max_encode: int = 1000,
    hidden_dim: int = 1024,
    push_to_hub: bool = True,
    hf_repo: str = "rahul7star/FlashPack"
) -> Tuple[GemmaTrainer, object, object, object, torch.Tensor]:

    print("📦 Loading dataset...")
    dataset = load_dataset(dataset_name, split="train")
    limit = min(max_encode, len(dataset))
    dataset = dataset.select(range(limit))
    print(f"⚡ Using {len(dataset)} prompts for training")

    tokenizer, embed_model, encode_fn = build_encoder("gpt2", max_length=128)

    # Encode prompts
    short_list, long_list = [], []
    for i, item in enumerate(dataset):
        short_list.append(encode_fn(item["short_prompt"]))
        long_list.append(encode_fn(item["long_prompt"]))
        if (i+1) % 50 == 0 or (i+1) == len(dataset):
            print(f"  → Encoded {i+1}/{limit} prompts")
            gc.collect()

    short_embeddings = torch.vstack(short_list)
    long_embeddings = torch.vstack(long_list)
    print(f"✅ Encoded embeddings shape: short {short_embeddings.shape}, long {long_embeddings.shape}")

    input_dim = short_embeddings.shape[1]   # should match concatenated mean+max
    output_dim = long_embeddings.shape[1]

    model = GemmaTrainer(input_dim=input_dim, hidden_dim=hidden_dim, output_dim=output_dim).to(device)

    criterion = nn.CosineSimilarity(dim=1)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    max_epochs = 50
    batch_size = 32
    n = short_embeddings.shape[0]

    print("🚀 Training model...")
    for epoch in range(max_epochs):
        model.train()
        epoch_loss = 0.0
        perm = torch.randperm(n)
        for start in range(0, n, batch_size):
            idx = perm[start:start+batch_size]
            inputs = short_embeddings[idx].to(device)
            targets = long_embeddings[idx].to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = 1 - criterion(outputs, targets).mean()
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item() * inputs.size(0)

        epoch_loss /= n
        if epoch % 5 == 0 or epoch == max_epochs-1:
            print(f"Epoch {epoch+1}/{max_epochs}, Loss={epoch_loss:.6f}")

    print("✅ Training finished!")

    if push_to_hub:
        logs = push_flashpack_model_to_hf(model, hf_repo)
        for log in logs:
            print(log)

    return model, dataset, embed_model, tokenizer, long_embeddings
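
# Example call (sketch; encodes up to 1000 pairs and trains locally without
# pushing to the Hub):
#   model, ds, embedder, tok, long_emb = train_flashpack_model(push_to_hub=False)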

# ============================================================
# 5️⃣ Load or train model
# ============================================================
def get_flashpack_model(hf_repo="rahul7star/FlashPack"):
    try:
        print(f"🔁 Attempting to load FlashPack model from {hf_repo}")
        model = GemmaTrainer.from_flashpack(hf_repo)
        model.eval()
        tokenizer, embed_model, encode_fn = build_encoder("gpt2", max_length=128)
        # Rebuild the retrieval corpus so this branch returns the same five
        # values as the training branch below.
        dataset = load_dataset("rahul7star/prompt-enhancer-dataset", split="train")
        dataset = dataset.select(range(min(1000, len(dataset))))
        long_embeddings = torch.vstack([encode_fn(item["long_prompt"]) for item in dataset])
        return model, tokenizer, embed_model, dataset, long_embeddings
    except Exception as e:
        print(f"⚠️ Load failed: {e}")
        print("⏬ Training a new FlashPack model locally...")
        # train_flashpack_model already pushes to the Hub (push_to_hub=True),
        # so a second push here would be redundant.
        model, dataset, embed_model, tokenizer, long_embeddings = train_flashpack_model(hf_repo=hf_repo)
        return model, tokenizer, embed_model, dataset, long_embeddings

# ============================================================
# 6️⃣ Initialize model, tokenizer, dataset, and embeddings at startup
# ============================================================
model, tokenizer, embed_model, dataset, long_embeddings = get_flashpack_model()

# ============================================================
# 7️⃣ Inference helpers
# ============================================================
@torch.no_grad()
def encode_for_inference(prompt: str) -> torch.Tensor:
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True,
                       padding="max_length", max_length=128).to(device)
    last_hidden = embed_model(**inputs).last_hidden_state
    mean_pool = last_hidden.mean(dim=1)
    max_pool, _ = last_hidden.max(dim=1)
    return torch.cat([mean_pool, max_pool], dim=1).cpu()
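
# This mirrors the encode() closure from build_encoder(); it is kept as a
# module-level function so inference works whether the model was loaded from
# the Hub or freshly trained.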

def enhance_prompt(user_prompt: str, temperature: float, max_tokens: int, chat_history):
    # temperature and max_tokens come from the UI but are unused here:
    # this is nearest-neighbour retrieval over embeddings, not generation.
    chat_history = chat_history or []
    short_emb = encode_for_inference(user_prompt)
    with torch.no_grad():
        mapped = model(short_emb.to(device)).cpu()

    # Cosine similarity between the mapped embedding and every long-prompt
    # embedding; the closest long prompt is returned as the enhancement.
    sims = (long_embeddings @ mapped.t()).squeeze(1)
    long_norms = long_embeddings.norm(dim=1)
    mapped_norm = mapped.norm()
    sims = sims / (long_norms * mapped_norm + 1e-12)

    best_idx = int(sims.argmax().item())
    enhanced_prompt = dataset[best_idx]["long_prompt"]

    chat_history.append({"role": "user", "content": user_prompt})
    chat_history.append({"role": "assistant", "content": enhanced_prompt})
    return chat_history
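
# Equivalent ranking sketch using torch's built-in cosine similarity
# (same argmax as the manual normalization above):
#   import torch.nn.functional as F
#   sims = F.cosine_similarity(long_embeddings, mapped.expand_as(long_embeddings), dim=1)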

# ============================================================
# 8️⃣ Gradio UI
# ============================================================
with gr.Blocks(title="Prompt Enhancer – FlashPack (CPU)", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # ✨ Prompt Enhancer (FlashPack mapper)
        Enter a short prompt, and the model will **expand it with details and creative context**.
        (CPU-only mode.)
        """
    )

    with gr.Row():
        chatbot = gr.Chatbot(height=400, label="Enhanced Prompts", type="messages")
        with gr.Column(scale=1):
            user_prompt = gr.Textbox(placeholder="Enter a short prompt...", label="Your Prompt", lines=3)
            temperature = gr.Slider(0.0, 1.0, value=0.7, step=0.05, label="Temperature")
            max_tokens = gr.Slider(32, 256, value=128, step=16, label="Max Tokens")
            send_btn = gr.Button("🚀 Enhance Prompt", variant="primary")
            clear_btn = gr.Button("🧹 Clear Chat")

    send_btn.click(enhance_prompt, [user_prompt, temperature, max_tokens, chatbot], chatbot)
    user_prompt.submit(enhance_prompt, [user_prompt, temperature, max_tokens, chatbot], chatbot)
    clear_btn.click(lambda: [], None, chatbot)

# ============================================================
# 9️⃣ Launch
# ============================================================
if __name__ == "__main__":
    demo.launch(show_error=True)