rahul7star committed on
Commit 222699e · verified · 1 Parent(s): 1f97d4a

Update app_flash.py

Files changed (1)
  1. app_flash.py +171 -69
app_flash.py CHANGED
@@ -1,146 +1,245 @@
- import spaces
  import torch
  import torch.nn as nn
  import torch.optim as optim
- from flashpack import FlashPackMixin
  from datasets import load_dataset
  import gradio as gr
  from transformers import AutoTokenizer, AutoModel

  # ============================================================
- # 🧠 Device setup
  # ============================================================
- device = "cuda" if torch.cuda.is_available() else "cpu"
- print(f"🔧 Using device: {device}")

  # ============================================================
  # 1️⃣ Define FlashPack model
  # ============================================================
  class GemmaTrainer(nn.Module, FlashPackMixin):
-     def __init__(self, input_dim=768, hidden_dim=1024, output_dim=768):
          super().__init__()
          self.fc1 = nn.Linear(input_dim, hidden_dim)
          self.relu = nn.ReLU()
          self.fc2 = nn.Linear(hidden_dim, output_dim)

-     def forward(self, x):
          x = self.fc1(x)
          x = self.relu(x)
          x = self.fc2(x)
          return x

-
  # ============================================================
- # 2️⃣ Encode and train using GPU
  # ============================================================

- def train_flashpack_model():
-     # Load dataset
-     print("📦 Loading dataset...")
-     dataset = load_dataset("gokaygokay/prompt-enhancer-dataset", split="train")
-
-     # Tokenizer setup
-     tokenizer = AutoTokenizer.from_pretrained("gpt2")
-     tokenizer.pad_token = tokenizer.eos_token  # ✅ Fix padding issue
-
-     # Base embedding model
-     embed_model = AutoModel.from_pretrained("gpt2").to(device)
      embed_model.eval()

-     def encode_prompt(prompt):
          inputs = tokenizer(
              prompt,
              return_tensors="pt",
              truncation=True,
              padding="max_length",
-             max_length=32
          ).to(device)
-         with torch.no_grad():
-             return embed_model(**inputs).last_hidden_state.mean(dim=1)

-     # Encode dataset prompts
-     print("🔢 Encoding dataset into embeddings...")
-     short_embeddings = torch.vstack([encode_prompt(p["short_prompt"]) for p in dataset]).to(device)
-     long_embeddings = torch.vstack([encode_prompt(p["long_prompt"]) for p in dataset]).to(device)
-     print(f"✅ Encoded {len(dataset)} pairs")

-     # Train FlashPack model
      model = GemmaTrainer(
          input_dim=short_embeddings.shape[1],
-         output_dim=long_embeddings.shape[1]
-     ).to(device)

      criterion = nn.MSELoss()
      optimizer = optim.Adam(model.parameters(), lr=1e-3)
      max_epochs = 500
      tolerance = 1e-4

-     for epoch in range(max_epochs):
-         optimizer.zero_grad()
-         outputs = model(short_embeddings)
-         loss = criterion(outputs, long_embeddings)
-         loss.backward()
-         optimizer.step()
-         if loss.item() < tolerance:
-             print(f"✅ Converged at epoch {epoch+1}, Loss={loss.item():.6f}")
              break
-         if (epoch + 1) % 50 == 0:
-             print(f"Epoch {epoch+1}, Loss={loss.item():.6f}")

-     # Save to Hugging Face Hub
-     FLASHPACK_REPO = "rahul7star/FlashPack"
-     model.save_flashpack(FLASHPACK_REPO, target_dtype=torch.float32, push_to_hub=True)
-     print(f"✅ Model saved to FlashPack Hub: {FLASHPACK_REPO}")

      return model, dataset, embed_model, tokenizer, long_embeddings

-
  # ============================================================
- # 3️⃣ Run training once and load for inference
  # ============================================================
- model, dataset, embed_model, tokenizer, long_embeddings = train_flashpack_model()
  model.eval()

- # ============================================================
- # 4️⃣ Inference function for Gradio
- # ============================================================
- def encode_prompt(prompt):
      inputs = tokenizer(
          prompt,
          return_tensors="pt",
          truncation=True,
          padding="max_length",
-         max_length=32
      ).to(device)
-     with torch.no_grad():
-         return embed_model(**inputs).last_hidden_state.mean(dim=1)

- def enhance_prompt(user_prompt, temperature, max_tokens, chat_history):
      chat_history = chat_history or []
-     short_emb = encode_prompt(user_prompt)

      with torch.no_grad():
-         long_emb = model(short_emb)

-     # Nearest match search
      cos = nn.CosineSimilarity(dim=1)
-     sims = cos(long_emb.repeat(len(long_embeddings), 1), long_embeddings)
-     best_idx = sims.argmax()
      enhanced_prompt = dataset[best_idx]["long_prompt"]

      chat_history.append({"role": "user", "content": user_prompt})
      chat_history.append({"role": "assistant", "content": enhanced_prompt})
      return chat_history

-
  # ============================================================
- # 5️⃣ Gradio UI
  # ============================================================
- with gr.Blocks(title="Prompt Enhancer – Gemma 3 270M", theme=gr.themes.Soft()) as demo:
      gr.Markdown(
          """
-         # ✨ Prompt Enhancer (Gemma 3 270M)
-         Enter a short prompt, and the model will **expand it with details and creative context**
          """
      )

@@ -165,10 +264,13 @@ with gr.Blocks(title="Prompt Enhancer – Gemma 3 270M", theme=gr.themes.Soft())
          """
          ---
          💡 **Tips:**
-         - Works best with short, descriptive prompts (e.g., "a cat sitting on a chair")
-         - Increase *Temperature* for more creative output.
          """
      )

  if __name__ == "__main__":
      demo.launch(show_error=True)

+ # prompt_enhancer_flashpack_cpu.py
+ import gc
  import torch
  import torch.nn as nn
  import torch.optim as optim
  from datasets import load_dataset
  import gradio as gr
  from transformers import AutoTokenizer, AutoModel
+ from flashpack import FlashPackMixin  # keep if your mixin provides save_flashpack
+ from typing import Tuple

  # ============================================================
+ # 🖥 Force CPU mode (safe for HF Spaces / Kaggle)
  # ============================================================
+ device = torch.device("cpu")
+ torch.set_num_threads(4)  # reduce CPU contention in shared environments
+ print(f"🔧 Forcing device: {device} (CPU-only mode)")

  # ============================================================
  # 1️⃣ Define FlashPack model
  # ============================================================
  class GemmaTrainer(nn.Module, FlashPackMixin):
+     def __init__(self, input_dim: int = 768, hidden_dim: int = 1024, output_dim: int = 768):
          super().__init__()
          self.fc1 = nn.Linear(input_dim, hidden_dim)
          self.relu = nn.ReLU()
          self.fc2 = nn.Linear(hidden_dim, output_dim)

+     def forward(self, x: torch.Tensor) -> torch.Tensor:
          x = self.fc1(x)
          x = self.relu(x)
          x = self.fc2(x)
          return x

  # ============================================================
+ # 2️⃣ Utility: encode prompts (CPU-friendly)
  # ============================================================
+ def build_encoder(model_name="gpt2", max_length: int = 32):
+     tokenizer = AutoTokenizer.from_pretrained(model_name)
+     # Some GPT2 tokenizers have no pad token — set eos as pad
+     if tokenizer.pad_token is None:
+         tokenizer.pad_token = tokenizer.eos_token

+     embed_model = AutoModel.from_pretrained(model_name).to(device)
      embed_model.eval()

+     @torch.no_grad()
+     def encode(prompt: str) -> torch.Tensor:
+         """
+         Encodes a single prompt and returns a CPU tensor of shape (1, hidden_size).
+         Always returns a CPU tensor to avoid device juggling in downstream code.
+         """
          inputs = tokenizer(
              prompt,
              return_tensors="pt",
              truncation=True,
              padding="max_length",
+             max_length=max_length,
          ).to(device)

+         outputs = embed_model(**inputs).last_hidden_state.mean(dim=1)  # (1, hidden)
+         return outputs.cpu()
+
+     return tokenizer, embed_model, encode
+
+ # ============================================================
+ # 3️⃣ Train FlashPack mapping (CPU-optimized)
+ # ============================================================
+ def train_flashpack_model(
+     dataset_name: str = "gokaygokay/prompt-enhancer-dataset",
+     model_name: str = "gpt2",
+     max_length: int = 32,
+     subset_limit: int | None = None,  # set to int to train on subset for quick runs
+     push_to_hub: bool = True,
+     hf_repo: str = "rahul7star/FlashPack",
+ ) -> Tuple[GemmaTrainer, object, AutoModel, AutoTokenizer, torch.Tensor]:
+     """
+     Returns: (trained_model, dataset, embed_model, tokenizer, long_embeddings)
+     All tensors remain on CPU to be safe in CPU-only environments.
+     """
+
+     # 1) Load dataset
+     print("📦 Loading dataset...")
+     dataset = load_dataset(dataset_name, split="train")
+
+     if subset_limit is not None and subset_limit > 0:
+         print(f"⚠️ Using subset of dataset: first {subset_limit} examples for fast iteration")
+         dataset = dataset.select(range(min(subset_limit, len(dataset))))
+
+     # 2) Build tokenizer + encoder
+     print("🔧 Setting up tokenizer & encoder...")
+     tokenizer, embed_model, encode_fn = build_encoder(model_name=model_name, max_length=max_length)

+     # 3) Encode dataset in a memory-friendly loop (returns CPU tensors)
+     print("🔢 Encoding dataset into embeddings (CPU-friendly)...")
+     short_list = []
+     long_list = []
+     for i, item in enumerate(dataset):
+         short_list.append(encode_fn(item["short_prompt"]))
+         long_list.append(encode_fn(item["long_prompt"]))
+
+         # logging & GC every 100 items
+         if (i + 1) % 100 == 0 or (i + 1) == len(dataset):
+             print(f"   → Encoded {i+1}/{len(dataset)} prompts")
+             gc.collect()
+
+     # Stack to single tensors on CPU
+     short_embeddings = torch.vstack(short_list)  # shape (N, hidden)
+     long_embeddings = torch.vstack(long_list)
+     print(f"✅ Finished encoding: {short_embeddings.shape[0]} pairs, dim={short_embeddings.shape[1]}")
+
+     # 4) Initialize GemmaTrainer (on CPU)
      model = GemmaTrainer(
          input_dim=short_embeddings.shape[1],
+         hidden_dim=min(2048, int(short_embeddings.shape[1] * 2)),
+         output_dim=long_embeddings.shape[1],
+     ).to(device)  # device is cpu

+     # 5) Training loop (small-batch style to reduce memory pressure)
      criterion = nn.MSELoss()
      optimizer = optim.Adam(model.parameters(), lr=1e-3)
+
      max_epochs = 500
      tolerance = 1e-4
+     batch_size = 64  # small batches on CPU
+
+     n = short_embeddings.shape[0]
+     print("🚀 Training FlashPack mapper model (CPU). This may take some time...")
+
+     for epoch in range(1, max_epochs + 1):
+         model.train()
+         epoch_loss = 0.0
+
+         # Shuffle indices each epoch
+         perm = torch.randperm(n)
+         for start in range(0, n, batch_size):
+             idx = perm[start : start + batch_size]
+             inputs = short_embeddings[idx].to(device)
+             targets = long_embeddings[idx].to(device)
+
+             optimizer.zero_grad()
+             outputs = model(inputs)
+             loss = criterion(outputs, targets)
+             loss.backward()
+             optimizer.step()

+             epoch_loss += loss.item() * inputs.size(0)
+
+         epoch_loss /= n
+         if epoch % 10 == 0 or epoch == 1:
+             print(f"Epoch {epoch:03d}/{max_epochs}, Loss={epoch_loss:.6f}")
+
+         if epoch_loss < tolerance:
+             print(f"✅ Converged at epoch {epoch}, Loss={epoch_loss:.6f}")
              break

+     # 6) Save model locally and optionally push to HF hub (robust)
+     try:
+         # If FlashPackMixin provides save_flashpack, use it:
+         if hasattr(model, "save_flashpack"):
+             print("💾 Saving model with FlashPackMixin.save_flashpack()")
+             model.save_flashpack(hf_repo, target_dtype=torch.float32, push_to_hub=push_to_hub)
+         else:
+             # Fallback: simple torch.save
+             path = "flashpack_model.pt"
+             torch.save(model.state_dict(), path)
+             print(f"💾 Saved locally to {path}")
+             if push_to_hub:
+                 try:
+                     from huggingface_hub import HfApi, HfFolder
+                     api = HfApi()
+                     token = HfFolder.get_token()
+                     api.upload_file(path_or_fileobj=path, path_in_repo=path, repo_id=hf_repo, token=token)
+                     print(f"🚀 Uploaded model file to HF: {hf_repo}")
+                 except Exception as e:
+                     print("⚠️ Could not push to HF Hub:", e)
+     except Exception as e:
+         print("⚠️ Error while saving/pushing model:", e)

+     print("✅ Training done — returning model and artifacts.")
      return model, dataset, embed_model, tokenizer, long_embeddings

  # ============================================================
+ # 4️⃣ Build everything and prepare for inference
  # ============================================================
+ # For demo speed in CPU mode, you might want a subset_limit (e.g., 1000).
+ # Set subset_limit=None to use full dataset.
+ model, dataset, embed_model, tokenizer, long_embeddings = train_flashpack_model(
+     subset_limit=None,  # change to a small int for faster testing
+     push_to_hub=False,  # toggle when you want to actually push
+ )
+
  model.eval()

+ # Reusable encode function for inference (returns CPU tensor)
+ @torch.no_grad()
+ def encode_for_inference(prompt: str) -> torch.Tensor:
      inputs = tokenizer(
          prompt,
          return_tensors="pt",
          truncation=True,
          padding="max_length",
+         max_length=32,
      ).to(device)
+     return embed_model(**inputs).last_hidden_state.mean(dim=1).cpu()

+ # ============================================================
+ # 5️⃣ Enhance prompt function (nearest neighbor via cosine)
+ # ============================================================
+ def enhance_prompt(user_prompt: str, temperature: float, max_tokens: int, chat_history):
      chat_history = chat_history or []

+     # encode user prompt (CPU tensor)
+     short_emb = encode_for_inference(user_prompt)  # (1, dim)
      with torch.no_grad():
+         mapped = model(short_emb.to(device)).cpu()  # (1, dim)

+     # cosine similarity against dataset long embeddings
      cos = nn.CosineSimilarity(dim=1)
+     # mapped.repeat(len(long_embeddings), 1) is heavy; do efficient matmul similarity:
+     sims = (long_embeddings @ mapped.t()).squeeze(1)
+     # normalize: sims / (||long|| * ||mapped||)
+     long_norms = long_embeddings.norm(dim=1)
+     mapped_norm = mapped.norm()
+     sims = sims / (long_norms * (mapped_norm + 1e-12))
+
+     best_idx = int(sims.argmax().item())
      enhanced_prompt = dataset[best_idx]["long_prompt"]

      chat_history.append({"role": "user", "content": user_prompt})
      chat_history.append({"role": "assistant", "content": enhanced_prompt})
      return chat_history

  # ============================================================
+ # 6️⃣ Gradio UI
  # ============================================================
+ with gr.Blocks(title="Prompt Enhancer – FlashPack (CPU)", theme=gr.themes.Soft()) as demo:
      gr.Markdown(
          """
+         # ✨ Prompt Enhancer (FlashPack mapper)
+         Enter a short prompt, and the model will **expand it with details and creative context**.
+         (This demo runs on CPU — expect slower inference/training than GPU.)
          """
      )

          """
          ---
          💡 **Tips:**
+         - CPU mode: training and large-batch encodes can take a while. Use `subset_limit` in the training call for quick tests.
+         - Increase *Temperature* for more creative outputs (not used in the nearest-neighbour mapper but kept for UI parity).
          """
      )

+ # ============================================================
+ # 7️⃣ Launch
+ # ============================================================
  if __name__ == "__main__":
      demo.launch(show_error=True)
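
For reference, here is a minimal standalone inference sketch. It is not part of this commit: it assumes the fallback `torch.save` branch ran (producing `flashpack_model.pt`) and that the long-prompt embeddings and texts were exported separately; `long_embeddings.pt`, `long_prompts.json`, the `Mapper` class, and the `enhance()` helper are illustrative names, not files or functions this commit creates. It uses `torch.nn.functional.cosine_similarity`, which computes the same normalized dot product that `enhance_prompt` builds by hand with a matmul.

```python
# standalone_infer.py (illustrative sketch; see assumptions above)
import json

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel

device = torch.device("cpu")

class Mapper(nn.Module):
    # Same layers as GemmaTrainer in app_flash.py (FlashPack mixin omitted here),
    # so the saved fc1/fc2 weights can be loaded into it.
    def __init__(self, input_dim=768, hidden_dim=1536, output_dim=768):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))

tokenizer = AutoTokenizer.from_pretrained("gpt2")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
embed_model = AutoModel.from_pretrained("gpt2").to(device).eval()

# GPT-2's hidden size is 768, and the training call above sets
# hidden_dim = min(2048, 2 * 768) = 1536; these must match the saved weights.
mapper = Mapper(input_dim=768, hidden_dim=1536, output_dim=768)
mapper.load_state_dict(torch.load("flashpack_model.pt", map_location=device))
mapper.eval()

long_embeddings = torch.load("long_embeddings.pt", map_location=device)  # (N, 768), hypothetical export
long_prompts = json.load(open("long_prompts.json"))                      # list of N strings, hypothetical export

@torch.no_grad()
def enhance(prompt: str) -> str:
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True,
                       padding="max_length", max_length=32).to(device)
    short_emb = embed_model(**inputs).last_hidden_state.mean(dim=1)  # (1, 768)
    mapped = mapper(short_emb)                                       # (1, 768)
    # Broadcasting cosine similarity over all stored long-prompt embeddings.
    sims = F.cosine_similarity(mapped, long_embeddings, dim=1)       # (N,)
    return long_prompts[int(sims.argmax())]

if __name__ == "__main__":
    print(enhance("a cat sitting on a chair"))
```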