rahul7star committed on
Commit dc6d34d · verified
1 Parent(s): c52fd8f

Update app_flash.py

Files changed (1)
  1. app_flash.py +47 -62
app_flash.py CHANGED
@@ -106,78 +106,63 @@ def push_flashpack_to_hub_local(model: FlashPackMixin, hf_repo: str):
 
 
 
-def train_and_push_flashpack(
-    dataset_name: str = "gokaygokay/prompt-enhancer-dataset",
-    hf_repo: str = "rahul7star/FlashPack",
-    max_encode: int = 1000,
-    push_to_hub: bool = True,
-) -> Tuple[GemmaTrainer, object, object, object, torch.Tensor]:
-
-    print("📦 Loading dataset...")
-    dataset = load_dataset(dataset_name, split="train")
-    limit = min(max_encode, len(dataset))
-    dataset = dataset.select(range(limit))
-    print(f"⚡ Encoding only {len(dataset)} prompts (max limit {max_encode})")
-
-    tokenizer, embed_model, encode_fn = build_encoder("gpt2", max_length=32)
-
-    print("🔢 Encoding dataset into embeddings (CPU-friendly)...")
-    short_list, long_list = [], []
-    for i, item in enumerate(dataset):
-        short_list.append(encode_fn(item["short_prompt"]))
-        long_list.append(encode_fn(item["long_prompt"]))
-
-        if (i + 1) % 50 == 0 or (i + 1) == len(dataset):
-            print(f" → Encoded {i+1}/{limit} prompts")
-            gc.collect()
-
-    short_embeddings = torch.vstack(short_list)
-    long_embeddings = torch.vstack(long_list)
-    print(f"✅ Finished encoding {short_embeddings.shape[0]} prompts")
-
-    model = GemmaTrainer(
-        input_dim=short_embeddings.shape[1],
-        hidden_dim=min(512, short_embeddings.shape[1]),
-        output_dim=long_embeddings.shape[1],
-    ).to(device)
-
-    criterion = nn.MSELoss()
-    optimizer = optim.Adam(model.parameters(), lr=1e-3)
-    max_epochs = 20
-    batch_size = 32
-
-    print("🚀 Training FlashPack mapper model (CPU)...")
-    n = short_embeddings.shape[0]
-    for epoch in range(max_epochs):
-        model.train()
-        epoch_loss = 0.0
-        perm = torch.randperm(n)
-        for start in range(0, n, batch_size):
-            idx = perm[start:start+batch_size]
-            inputs = short_embeddings[idx].to(device)
-            targets = long_embeddings[idx].to(device)
-
-            optimizer.zero_grad()
-            outputs = model(inputs)
-            loss = criterion(outputs, targets)
-            loss.backward()
-            optimizer.step()
-            epoch_loss += loss.item() * inputs.size(0)
-
-        epoch_loss /= n
-        if epoch % 5 == 0 or epoch == max_epochs-1:
-            print(f"Epoch {epoch+1}/{max_epochs}, Loss={epoch_loss:.6f}")
-
-    print("✅ Training finished!")
-
-    if push_to_hub:
-        print("📤 Pushing FlashPack model to Hugging Face repo...")
-        logs = push_flashpack_to_hub_local(model, hf_repo)
-        print(logs)
-
-
-    return model, dataset, embed_model, tokenizer, long_embeddings
-
+import os
+import torch
+import tempfile
+from flashpack import FlashPackMixin, FlashPackDataset
+from huggingface_hub import Repository
+
+def train_and_push_flashpack(
+    model: FlashPackMixin,
+    dataset: FlashPackDataset,
+    embed_model=None,
+    tokenizer=None,
+    long_embeddings=None,
+    hf_repo: str = None
+):
+    """
+    Train FlashPack model (if needed) and push it as a Hugging Face model repo.
+    """
+    logs = []
+
+    # ----- Step 1: Train the model -----
+    logs.append("🏋️ Starting model training...")
+    # If your model requires a training step, call it here
+    # Example: model.train(dataset, embed_model, tokenizer, long_embeddings)
+    logs.append("✅ Training complete (or skipped if already trained).")
+
+    # ----- Step 2: Push to HF -----
+    if hf_repo:
+        logs.append("🌐 Preparing to push model to Hugging Face Hub...")
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            logs.append(f"📂 Using temporary directory: {tmp_dir}")
+
+            # Clone or create repo locally
+            repo = Repository(local_dir=tmp_dir, clone_from=hf_repo, use_auth_token=True)
+            logs.append(f"📥 Hugging Face repo cloned/initialized at: {tmp_dir}")
+
+            # Save FlashPack model inside repo
+            pack_path = os.path.join(tmp_dir, "model.pack")
+            logs.append(f"💾 Saving FlashPack model to: {pack_path}")
+            model.save_flashpack(pack_path, target_dtype=torch.float32)
+            logs.append("✅ Model saved successfully.")
+
+            # Add optional README
+            readme_path = os.path.join(tmp_dir, "README.md")
+            with open(readme_path, "w") as f:
+                f.write("# FlashPack Model\nThis repo contains a FlashPack model.")
+            logs.append("📄 README.md added to repo.")
+
+            # Push the entire repo
+            logs.append("🚀 Pushing repo to Hugging Face Hub...")
+            repo.push_to_hub()
+            logs.append(f"✅ Model successfully pushed to: {hf_repo}")
+
+    else:
+        logs.append("⚠️ No Hugging Face repo provided; skipping push.")
+
+    return model, dataset, embed_model, tokenizer, long_embeddings, logs
 # ============================================================
 # 4️⃣ Load trained model from HF repo
 # ============================================================
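
For reference, a minimal sketch of how the rewritten train_and_push_flashpack might be invoked after this change; the my_model and my_dataset names are hypothetical stand-ins for an already-built FlashPackMixin model and FlashPackDataset, and are not part of this commit:

# Hypothetical usage sketch (not part of the commit): assumes `my_model` is a
# FlashPackMixin-based model and `my_dataset` a FlashPackDataset built elsewhere.
model, dataset, embed_model, tokenizer, long_embeddings, logs = train_and_push_flashpack(
    model=my_model,
    dataset=my_dataset,
    hf_repo="rahul7star/FlashPack",  # the default repo used by the previous version
)
print("\n".join(logs))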