rahul7star committed
Commit c4c7a5a · verified · 1 Parent(s): 0b52d90

Update app_flash.py

Files changed (1)
  1. app_flash.py +86 -163
app_flash.py CHANGED
@@ -1,11 +1,11 @@
- # prompt_enhancer_flashpack_cpu_publish.py
  import gc
+ import os
  import torch
  import torch.nn as nn
  import torch.optim as optim
- from datasets import load_dataset
  import gradio as gr
- from transformers import AutoTokenizer, AutoModel
+ from datasets import load_dataset
+ from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM, pipeline
  from flashpack import FlashPackMixin
  from typing import Tuple
 
@@ -32,8 +32,9 @@ class GemmaTrainer(nn.Module, FlashPackMixin):
          x = self.fc2(x)
          return x
 
+ 
  # ============================================================
- # 2️⃣ Utility: encode prompts (CPU-friendly)
+ # 2️⃣ Build encoder (for embedding)
  # ============================================================
  def build_encoder(model_name="gpt2", max_length: int = 32):
      tokenizer = AutoTokenizer.from_pretrained(model_name)
@@ -57,163 +58,60 @@ def build_encoder(model_name="gpt2", max_length: int = 32):
 
      return tokenizer, embed_model, encode
 
- # ============================================================
- # 3️⃣ Train and push FlashPack model
- # ============================================================
- import os
- import tempfile
- from huggingface_hub import hf_hub_download, HfApi
- 
- # ------------------------------------------------------------
- # Utility to push FlashPack model to HF using upload_file
- # ------------------------------------------------------------
- import os
- import gc
- import torch
- import torch.nn as nn
- import torch.optim as optim
- import tempfile
- from huggingface_hub import Repository
- from datasets import load_dataset
- from typing import Tuple
- 
- # -------------------------------
- # Helper: Push FlashPack model
- # -------------------------------
- def push_flashpack_model_to_hf(model, hf_repo: str):
-     """
-     Save FlashPack model locally and push as Hugging Face model repo.
-     """
-     logs = []
- 
-     with tempfile.TemporaryDirectory() as tmp_dir:
-         logs.append(f"📂 Using temporary directory: {tmp_dir}")
- 
-         # Clone or initialize HF repo locally
-         repo = Repository(local_dir=tmp_dir, clone_from=hf_repo, use_auth_token=True)
-         logs.append(f"🌐 Hugging Face repo cloned to: {tmp_dir}")
- 
-         # Save model inside repo
-         pack_path = os.path.join(tmp_dir, "model.flashpack")
-         logs.append(f"💾 Saving model to: {pack_path}")
-         model.save_flashpack(pack_path, target_dtype=torch.float32)
-         logs.append("✅ Model saved successfully.")
- 
-         # Add README
-         readme_path = os.path.join(tmp_dir, "README.md")
-         with open(readme_path, "w") as f:
-             f.write("# FlashPack Model\nThis repo contains a FlashPack model.")
-         logs.append("📄 README.md added.")
- 
-         # Push repo to HF
-         logs.append("🚀 Pushing repo to Hugging Face Hub...")
-         repo.push_to_hub()
-         logs.append(f"✅ Model successfully pushed to: {hf_repo}")
- 
-     return logs
- 
- # -------------------------------
- # Main training and push function
- # -------------------------------
- def train_and_push_flashpack(
-     dataset_name: str = "gokaygokay/prompt-enhancer-dataset",
-     hf_repo: str = "rahul7star/FlashPack",
-     max_encode: int = 1000,
-     push_to_hub: bool = True,
-     device: str = "cpu"
- ) -> Tuple[object, object, object, object, torch.Tensor]:
- 
-     print("📦 Loading dataset...")
-     dataset = load_dataset(dataset_name, split="train")
-     limit = min(max_encode, len(dataset))
-     dataset = dataset.select(range(limit))
-     print(f"⚡ Encoding only {len(dataset)} prompts (max limit {max_encode})")
- 
-     # Placeholder: build your encoder here
-     tokenizer, embed_model, encode_fn = build_encoder("gpt2", max_length=32)
- 
-     print("🔢 Encoding dataset into embeddings (CPU-friendly)...")
-     short_list, long_list = [], []
-     for i, item in enumerate(dataset):
-         short_list.append(encode_fn(item["short_prompt"]))
-         long_list.append(encode_fn(item["long_prompt"]))
- 
-         if (i + 1) % 50 == 0 or (i + 1) == len(dataset):
-             print(f" → Encoded {i+1}/{limit} prompts")
-             gc.collect()
- 
-     short_embeddings = torch.vstack(short_list)
-     long_embeddings = torch.vstack(long_list)
-     print(f"✅ Finished encoding {short_embeddings.shape[0]} prompts")
- 
-     # Build your FlashPack model (GemmaTrainer placeholder)
-     model = GemmaTrainer(
-         input_dim=short_embeddings.shape[1],
-         hidden_dim=min(512, short_embeddings.shape[1]),
-         output_dim=long_embeddings.shape[1],
-     ).to(device)
- 
-     criterion = nn.MSELoss()
-     optimizer = optim.Adam(model.parameters(), lr=1e-3)
-     max_epochs = 20
-     batch_size = 32
- 
-     print("🚀 Training model...")
-     n = short_embeddings.shape[0]
-     for epoch in range(max_epochs):
-         model.train()
-         epoch_loss = 0.0
-         perm = torch.randperm(n)
-         for start in range(0, n, batch_size):
-             idx = perm[start:start+batch_size]
-             inputs = short_embeddings[idx].to(device)
-             targets = long_embeddings[idx].to(device)
- 
-             optimizer.zero_grad()
-             outputs = model(inputs)
-             loss = criterion(outputs, targets)
-             loss.backward()
-             optimizer.step()
-             epoch_loss += loss.item() * inputs.size(0)
- 
-         epoch_loss /= n
-         if epoch % 5 == 0 or epoch == max_epochs - 1:
-             print(f"Epoch {epoch+1}/{max_epochs}, Loss={epoch_loss:.6f}")
- 
-     print("✅ Training finished!")
- 
-     logs = []
-     if push_to_hub:
-         print("📤 Pushing model to Hugging Face repo...")
-         logs = push_flashpack_model_to_hf(model, hf_repo)
-         for log in logs:
-             print(log)
- 
-     return model, dataset, embed_model, tokenizer, long_embeddings
- 
 
  # ============================================================
- # 4️⃣ Load trained model from HF repo
+ # 3️⃣ Load pretrained FlashPack model (skip training)
  # ============================================================
  def load_flashpack_model(hf_repo="rahul7star/FlashPack"):
+     print(f"🔁 Loading FlashPack model from: {hf_repo}")
      model = GemmaTrainer.load_flashpack(hf_repo)
      model.eval()
      tokenizer, embed_model, encode_fn = build_encoder("gpt2", max_length=32)
      return model, tokenizer, embed_model
 
+ 
  # ============================================================
- # 5️⃣ Run training + push, then reload
+ # 4️⃣ Load Gemma text model for prompt enhancement
  # ============================================================
- model, dataset, embed_model, tokenizer, long_embeddings = train_and_push_flashpack(
-     max_encode=1000, # CPU-safe
-     push_to_hub=True
+ MODEL_ID = "gokaygokay/prompt-enhancer-gemma-3-270m-it"
+ 
+ tokenizer_gemma = AutoTokenizer.from_pretrained(MODEL_ID)
+ model_gemma = AutoModelForCausalLM.from_pretrained(MODEL_ID)
+ 
+ pipe_gemma = pipeline(
+     "text-generation",
+     model=model_gemma,
+     tokenizer=tokenizer_gemma,
+     device=-1,  # CPU
  )
 
- # reload to ensure FlashPack workflow works
+ import re
+ 
+ def extract_later_part(user_prompt, generated_text):
+     """Cleans the model output and extracts only the enhanced (later) portion."""
+     cleaned = re.sub(r"<.*?>", "", generated_text).strip()
+     cleaned = re.sub(r"\s+", " ", cleaned)
+     user_prompt_clean = user_prompt.strip().lower()
+     cleaned_lower = cleaned.lower()
+     if cleaned_lower.startswith(user_prompt_clean):
+         cleaned = cleaned[len(user_prompt):].strip(",. ").strip()
+     return cleaned
+ 
+ 
+ # ============================================================
+ # 5️⃣ Load FlashPack + Dataset + Encoder
+ # ============================================================
  model, tokenizer, embed_model = load_flashpack_model("rahul7star/FlashPack")
+ dataset = load_dataset("gokaygokay/prompt-enhancer-dataset", split="train")
+ long_embeddings = torch.vstack(
+     [embed_model(**tokenizer(p["long_prompt"], return_tensors="pt", truncation=True, padding="max_length", max_length=32)).last_hidden_state.mean(dim=1).cpu()
+      for p in dataset.select(range(min(500, len(dataset))))]
+ )
+ print("✅ Loaded FlashPack and Gemma models.")
+ 
 
  # ============================================================
- # 6️⃣ Inference helpers
+ # 6️⃣ FlashPack inference helper
  # ============================================================
  @torch.no_grad()
  def encode_for_inference(prompt: str) -> torch.Tensor:
@@ -226,12 +124,13 @@ def encode_for_inference(prompt: str) -> torch.Tensor:
      ).to(device)
      return embed_model(**inputs).last_hidden_state.mean(dim=1).cpu()
 
- def enhance_prompt(user_prompt: str, temperature: float, max_tokens: int, chat_history):
+ 
+ @torch.no_grad()
+ def enhance_prompt_flashpack(user_prompt: str, temperature: float, max_tokens: int, chat_history):
      chat_history = chat_history or []
      short_emb = encode_for_inference(user_prompt)
      mapped = model(short_emb.to(device)).cpu()
 
-     cos = nn.CosineSimilarity(dim=1)
      sims = (long_embeddings @ mapped.t()).squeeze(1)
      long_norms = long_embeddings.norm(dim=1)
      mapped_norm = mapped.norm()
@@ -244,33 +143,57 @@ def enhance_prompt(user_prompt: str, temperature: float, max_tokens: int, chat_h
      chat_history.append({"role": "assistant", "content": enhanced_prompt})
      return chat_history
 
+ 
+ # ============================================================
+ # 7️⃣ Gemma prompt enhancer
+ # ============================================================
+ def enhance_prompt_gemma(user_prompt, temperature, max_tokens, chat_history):
+     chat_history = chat_history or []
+     messages = [
+         {"role": "system", "content": "Enhance and expand the following prompt with more details and context:"},
+         {"role": "user", "content": user_prompt}
+     ]
+     prompt = tokenizer_gemma.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+     output = pipe_gemma(
+         prompt,
+         max_new_tokens=int(max_tokens),
+         temperature=float(temperature),
+         do_sample=True,
+     )[0]["generated_text"]
+     enhanced_text = extract_later_part(user_prompt, output)
+     chat_history.append({"role": "user", "content": user_prompt})
+     chat_history.append({"role": "assistant", "content": enhanced_text})
+     return chat_history
+ 
+ 
  # ============================================================
- # 7️⃣ Gradio UI
+ # 8️⃣ Gradio UI
  # ============================================================
- with gr.Blocks(title="Prompt Enhancer – FlashPack (CPU)", theme=gr.themes.Soft()) as demo:
-     gr.Markdown(
-         """
-         # Prompt Enhancer (FlashPack mapper)
-         Enter a short prompt, and the model will **expand it with details and creative context**.
-         (CPU-only mode.)
-         """
-     )
+ with gr.Blocks(title="Prompt Enhancer – FlashPack + Gemma (CPU)", theme=gr.themes.Soft()) as demo:
+     gr.Markdown("""
+     # ✨ Prompt Enhancer (FlashPack + Gemma)
+     - **Gemma model**: Enhances prompts with natural language.
+     - **FlashPack model**: Finds similar expanded prompts from dataset.
+     - CPU-only, for reproducibility.
+     """)
 
      with gr.Row():
-         chatbot = gr.Chatbot(height=400, label="Enhanced Prompts", type="messages")
+         chatbot = gr.Chatbot(height=420, label="Enhanced Prompts", type="messages")
         with gr.Column(scale=1):
             user_prompt = gr.Textbox(placeholder="Enter a short prompt...", label="Your Prompt", lines=3)
-             temperature = gr.Slider(0.0, 1.0, value=0.7, step=0.05, label="Temperature")
-             max_tokens = gr.Slider(32, 256, value=128, step=16, label="Max Tokens")
-             send_btn = gr.Button("🚀 Enhance Prompt", variant="primary")
+             temperature = gr.Slider(0.1, 1.5, value=0.7, label="Temperature")
+             max_tokens = gr.Slider(32, 512, value=256, label="Max Tokens")
+             send_gemma = gr.Button("💬 Enhance (Gemma)")
+             send_flashpack = gr.Button("🔗 Enhance (FlashPack)")
              clear_btn = gr.Button("🧹 Clear Chat")
 
-     send_btn.click(enhance_prompt, [user_prompt, temperature, max_tokens, chatbot], chatbot)
-     user_prompt.submit(enhance_prompt, [user_prompt, temperature, max_tokens, chatbot], chatbot)
+     send_gemma.click(enhance_prompt_gemma, [user_prompt, temperature, max_tokens, chatbot], chatbot)
+     send_flashpack.click(enhance_prompt_flashpack, [user_prompt, temperature, max_tokens, chatbot], chatbot)
+     user_prompt.submit(enhance_prompt_gemma, [user_prompt, temperature, max_tokens, chatbot], chatbot)
      clear_btn.click(lambda: [], None, chatbot)
 
  # ============================================================
- # 8️⃣ Launch
+ # 9️⃣ Launch
  # ============================================================
  if __name__ == "__main__":
      demo.launch(show_error=True)
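
For a quick CPU sanity check outside the Space, the Gemma path added in this commit can be exercised on its own. The sketch below is illustrative rather than part of app_flash.py: the model ID, chat-template call, and generation flags mirror the diff above, while the example prompt and token budget are assumptions.

# Standalone sketch (illustrative, not part of the commit): run the Gemma
# prompt-enhancement path on CPU, mirroring the calls added in app_flash.py.
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

MODEL_ID = "gokaygokay/prompt-enhancer-gemma-3-270m-it"  # checkpoint used in the diff

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=-1)  # CPU

messages = [
    {"role": "system", "content": "Enhance and expand the following prompt with more details and context:"},
    {"role": "user", "content": "a cat sitting on a window sill"},  # example prompt (assumption)
]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
out = pipe(prompt, max_new_tokens=256, temperature=0.7, do_sample=True)[0]["generated_text"]
print(out)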