rahul7star committed
Commit d56c08f · verified · 1 Parent(s): 530abb4

Update app_flash.py

Files changed (1)
  1. app_flash.py +24 -68
app_flash.py CHANGED
@@ -12,35 +12,17 @@ from huggingface_hub import Repository
 from typing import Tuple
 
 # ============================================================
-# 🖥 Device setup (CPU-only safe)
+# 🖥 Device setup (CPU-only)
 # ============================================================
 device = torch.device("cpu")
 torch.set_num_threads(4)
-print(f"🔧 Using device: {device} (CPU-only mode)")
-# prompt_enhancer_flashpack_cpu_publish_v2.py
-import gc
-import os
-import tempfile
-from typing import Tuple
-
-import torch
-import torch.nn as nn
-import torch.optim as optim
-from datasets import load_dataset
-from transformers import AutoTokenizer, AutoModel
-from flashpack import FlashPackMixin
-from huggingface_hub import Repository
-
-device = torch.device("cpu")
-torch.set_num_threads(4)
-print(f"🔧 Using device: {device} (CPU-only mode)")
-
+print(f"🔧 Using device: {device} (CPU-only)")
 
 # ============================================================
-# 1️⃣ Define improved FlashPack model
+# 1️⃣ FlashPack model with better hidden layers
 # ============================================================
 class GemmaTrainer(nn.Module, FlashPackMixin):
-    def __init__(self, input_dim: int, hidden_dim: int = 1024, output_dim: int = 768):
+    def __init__(self, input_dim: int, hidden_dim: int = 1024, output_dim: int = 1536):
         super().__init__()
         self.fc1 = nn.Linear(input_dim, hidden_dim)
         self.relu = nn.ReLU()
@@ -55,9 +37,8 @@ class GemmaTrainer(nn.Module, FlashPackMixin):
         x = self.fc3(x)
         return x
 
-
 # ============================================================
-# 2️⃣ Encoder with mean+max pooling
+# 2️⃣ Encoder using mean+max pooling (for richer embeddings)
 # ============================================================
 def build_encoder(model_name="gpt2", max_length: int = 128):
     tokenizer = AutoTokenizer.from_pretrained(model_name)
@@ -69,48 +50,32 @@ def build_encoder(model_name="gpt2", max_length: int = 128):
 
     @torch.no_grad()
     def encode(prompt: str) -> torch.Tensor:
-        inputs = tokenizer(
-            prompt,
-            return_tensors="pt",
-            truncation=True,
-            padding="max_length",
-            max_length=max_length
-        ).to(device)
+        inputs = tokenizer(prompt, return_tensors="pt", truncation=True,
+                           padding="max_length", max_length=max_length).to(device)
         last_hidden = embed_model(**inputs).last_hidden_state
         mean_pool = last_hidden.mean(dim=1)
         max_pool, _ = last_hidden.max(dim=1)
-        return torch.cat([mean_pool, max_pool], dim=1).cpu()
+        return torch.cat([mean_pool, max_pool], dim=1).cpu()  # doubled embedding
 
     return tokenizer, embed_model, encode
 
-
 # ============================================================
-# 3️⃣ Push FlashPack model to HF
+# 3️⃣ Push FlashPack model to Hugging Face
 # ============================================================
 def push_flashpack_model_to_hf(model, hf_repo: str):
     logs = []
     with tempfile.TemporaryDirectory() as tmp_dir:
-        logs.append(f"📂 Using temporary directory: {tmp_dir}")
+        logs.append(f"📂 Temporary directory: {tmp_dir}")
         repo = Repository(local_dir=tmp_dir, clone_from=hf_repo, use_auth_token=True)
-        logs.append(f"🌐 Hugging Face repo cloned to: {tmp_dir}")
-
         pack_path = os.path.join(tmp_dir, "model.flashpack")
-        logs.append(f"💾 Saving model to: {pack_path}")
         model.save_flashpack(pack_path, target_dtype=torch.float32)
-        logs.append("✅ Model saved successfully.")
-
         readme_path = os.path.join(tmp_dir, "README.md")
         with open(readme_path, "w") as f:
             f.write("# FlashPack Model\nThis repo contains a FlashPack model.")
-        logs.append("📄 README.md added.")
-
-        logs.append("🚀 Pushing repo to Hugging Face Hub...")
         repo.push_to_hub()
-        logs.append(f"✅ Model successfully pushed to: {hf_repo}")
-
+        logs.append(f"✅ Model pushed to HF: {hf_repo}")
     return logs
 
-
 # ============================================================
 # 4️⃣ Train FlashPack model
 # ============================================================
@@ -126,9 +91,8 @@ def train_flashpack_model(
     dataset = load_dataset(dataset_name, split="train")
     limit = min(max_encode, len(dataset))
    dataset = dataset.select(range(limit))
-    print(f"⚡ Using {len(dataset)} prompts for training (max {max_encode})")
+    print(f"⚡ Using {len(dataset)} prompts for training")
 
-    # Build encoder
     tokenizer, embed_model, encode_fn = build_encoder("gpt2", max_length=128)
 
     # Encode prompts
@@ -142,14 +106,13 @@
 
     short_embeddings = torch.vstack(short_list)
     long_embeddings = torch.vstack(long_list)
-    print(f"✅ Finished encoding {short_embeddings.shape[0]} prompts")
-    input_dim = short_embeddings.shape[1]
+    print(f"✅ Encoded embeddings shape: short {short_embeddings.shape}, long {long_embeddings.shape}")
+
+    input_dim = short_embeddings.shape[1]  # should match concatenated mean+max
     output_dim = long_embeddings.shape[1]
 
-    # Build model
     model = GemmaTrainer(input_dim=input_dim, hidden_dim=hidden_dim, output_dim=output_dim).to(device)
 
-    # Loss & optimizer
     criterion = nn.CosineSimilarity(dim=1)
     optimizer = optim.Adam(model.parameters(), lr=1e-3)
     max_epochs = 50
@@ -168,7 +131,7 @@
 
         optimizer.zero_grad()
         outputs = model(inputs)
-        loss = 1 - criterion(outputs, targets).mean()  # Cosine similarity loss
+        loss = 1 - criterion(outputs, targets).mean()
         loss.backward()
         optimizer.step()
         epoch_loss += loss.item() * inputs.size(0)
@@ -179,8 +142,6 @@
 
     print("✅ Training finished!")
 
-    # Push to HF
-    logs = []
     if push_to_hub:
         logs = push_flashpack_model_to_hf(model, hf_repo)
         for log in logs:
@@ -189,31 +150,26 @@
     return model, dataset, embed_model, tokenizer, long_embeddings
 
 # ============================================================
-# 5️⃣ Load FlashPack model (train if missing)
+# 5️⃣ Load or train model
 # ============================================================
 def get_flashpack_model(hf_repo="rahul7star/FlashPack"):
     try:
         print(f"🔁 Attempting to load FlashPack model from {hf_repo}")
         model = GemmaTrainer.from_flashpack(hf_repo)
         model.eval()
-        print("✅ Loaded model successfully from HF")
-        tokenizer, embed_model, encode_fn = build_encoder("gpt2", max_length=32)
+        tokenizer, embed_model, encode_fn = build_encoder("gpt2", max_length=128)
         return model, tokenizer, embed_model
     except Exception as e:
         print(f"⚠️ Load failed: {e}")
         print("⏬ Training a new FlashPack model locally...")
         model, dataset, embed_model, tokenizer, long_embeddings = train_flashpack_model()
-        print("📤 Pushing trained model to HF...")
         push_flashpack_model_to_hf(model, hf_repo)
         return model, tokenizer, embed_model, dataset, long_embeddings
 
 # ============================================================
 # 6️⃣ Load or train
 # ============================================================
-try:
-    model, tokenizer, embed_model, dataset, long_embeddings = get_flashpack_model()
-except Exception as e:
-    raise SystemExit(f"❌ Failed to load or train FlashPack model: {e}")
+model, tokenizer, embed_model, dataset, long_embeddings = get_flashpack_model()
 
 # ============================================================
 # 7️⃣ Inference helpers
@@ -221,8 +177,11 @@ except Exception as e:
 @torch.no_grad()
 def encode_for_inference(prompt: str) -> torch.Tensor:
     inputs = tokenizer(prompt, return_tensors="pt", truncation=True,
-                       padding="max_length", max_length=32).to(device)
-    return embed_model(**inputs).last_hidden_state.mean(dim=1).cpu()
+                       padding="max_length", max_length=128).to(device)
+    last_hidden = embed_model(**inputs).last_hidden_state
+    mean_pool = last_hidden.mean(dim=1)
+    max_pool, _ = last_hidden.max(dim=1)
+    return torch.cat([mean_pool, max_pool], dim=1).cpu()
 
 def enhance_prompt(user_prompt: str, temperature: float, max_tokens: int, chat_history):
     chat_history = chat_history or []
@@ -268,9 +227,6 @@ with gr.Blocks(title="Prompt Enhancer – FlashPack (CPU)", theme=gr.themes.Soft
 
 # ============================================================
 # 9️⃣ Launch
-
-# ============================================================
-# 🏁 Launch app
 # ============================================================
 if __name__ == "__main__":
     demo.launch(show_error=True)
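
Why output_dim moves from 768 to 1536: GPT-2's hidden size is 768, and the new encoder concatenates mean- and max-pooled last hidden states, doubling the embedding width to 1536; encode_for_inference is rewritten to pool the same way so inference vectors match what the model was trained on. Below is a minimal self-contained sketch of the new encoding path (the pad_token line is an assumption: GPT-2 ships without a pad token, and the tokenizer setup is outside this diff).

import torch
from transformers import AutoTokenizer, AutoModel

device = torch.device("cpu")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # assumption: GPT-2 defines no pad token
embed_model = AutoModel.from_pretrained("gpt2").to(device).eval()

@torch.no_grad()
def encode(prompt: str) -> torch.Tensor:
    # Tokenize to a fixed length, as in the updated build_encoder
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True,
                       padding="max_length", max_length=128).to(device)
    last_hidden = embed_model(**inputs).last_hidden_state  # (1, 128, 768)
    mean_pool = last_hidden.mean(dim=1)                    # (1, 768)
    max_pool, _ = last_hidden.max(dim=1)                   # (1, 768)
    return torch.cat([mean_pool, max_pool], dim=1).cpu()   # (1, 1536)

print(encode("a sunset over the ocean").shape)  # torch.Size([1, 1536])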