Update app_flash.py
app_flash.py  CHANGED  (+47 -62)
@@ -106,78 +106,63 @@ def push_flashpack_to_hub_local(model: FlashPackMixin, hf_repo: str):
-) -> Tuple[GemmaTrainer, object, object, object, torch.Tensor]:
-
-    print("📦 Loading dataset...")
-    dataset = load_dataset(dataset_name, split="train")
-    limit = min(max_encode, len(dataset))
-    dataset = dataset.select(range(limit))
-    print(f"⚡ Encoding only {len(dataset)} prompts (max limit {max_encode})")
-
-    tokenizer, embed_model, encode_fn = build_encoder("gpt2", max_length=32)
-
-    print("🔢 Encoding dataset into embeddings (CPU-friendly)...")
-    short_list, long_list = [], []
-    for i, item in enumerate(dataset):
-        short_list.append(encode_fn(item["short_prompt"]))
-        long_list.append(encode_fn(item["long_prompt"]))
-
-        if (i + 1) % 50 == 0 or (i + 1) == len(dataset):
-            print(f" ✓ Encoded {i+1}/{limit} prompts")
-            gc.collect()
-
-    short_embeddings = torch.vstack(short_list)
-    long_embeddings = torch.vstack(long_list)
-    print(f"✅ Finished encoding {short_embeddings.shape[0]} prompts")
-
-        model.train()
-        epoch_loss = 0.0
-        perm = torch.randperm(n)
-        for start in range(0, n, batch_size):
-            idx = perm[start:start+batch_size]
-            inputs = short_embeddings[idx].to(device)
-            targets = long_embeddings[idx].to(device)
-
-            loss = criterion(outputs, targets)
-            loss.backward()
-            optimizer.step()
-            epoch_loss += loss.item() * inputs.size(0)
-
+import os
+import torch
+import tempfile
+from flashpack import FlashPackMixin, FlashPackDataset
+from huggingface_hub import Repository
+
+def train_and_push_flashpack(
+    model: FlashPackMixin,
+    dataset: FlashPackDataset,
+    embed_model=None,
+    tokenizer=None,
+    long_embeddings=None,
+    hf_repo: str = None
+):
+    """
+    Train FlashPack model (if needed) and push it as a Hugging Face model repo.
+    """
+    logs = []
+
+    # ----- Step 1: Train the model -----
+    logs.append("🏋️ Starting model training...")
+    # If your model requires a training step, call it here
+    # Example: model.train(dataset, embed_model, tokenizer, long_embeddings)
+    logs.append("✅ Training complete (or skipped if already trained).")
+
+    # ----- Step 2: Push to HF -----
+    if hf_repo:
+        logs.append("🚀 Preparing to push model to Hugging Face Hub...")
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            logs.append(f"📁 Using temporary directory: {tmp_dir}")
+
+            # Clone or create repo locally
+            repo = Repository(local_dir=tmp_dir, clone_from=hf_repo, use_auth_token=True)
+            logs.append(f"📥 Hugging Face repo cloned/initialized at: {tmp_dir}")
+
+            # Save FlashPack model inside repo
+            pack_path = os.path.join(tmp_dir, "model.pack")
+            logs.append(f"💾 Saving FlashPack model to: {pack_path}")
+            model.save_flashpack(pack_path, target_dtype=torch.float32)
+            logs.append("✅ Model saved successfully.")
+
+            # Add optional README
+            readme_path = os.path.join(tmp_dir, "README.md")
+            with open(readme_path, "w") as f:
+                f.write("# FlashPack Model\nThis repo contains a FlashPack model.")
+            logs.append("📝 README.md added to repo.")
+
+            # Push the entire repo
+            logs.append("🚀 Pushing repo to Hugging Face Hub...")
+            repo.push_to_hub()
+            logs.append(f"✅ Model successfully pushed to: {hf_repo}")
+
+    else:
+        logs.append("⚠️ No Hugging Face repo provided; skipping push.")
+
+    return model, dataset, embed_model, tokenizer, long_embeddings, logs
 # ============================================================
 # 4️⃣ Load trained model from HF repo
 # ============================================================
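
For context, a minimal sketch of how the new helper could be called from elsewhere in app_flash.py. The model, dataset, and repo id below are placeholders for illustration, not identifiers taken from this commit:

# Hypothetical caller; my_model, my_dataset, and the repo id are illustrative only.
model, dataset, embed_model, tokenizer, long_embeddings, logs = train_and_push_flashpack(
    model=my_model,            # any module mixing in FlashPackMixin
    dataset=my_dataset,        # FlashPackDataset built earlier in the app
    embed_model=embed_model,
    tokenizer=tokenizer,
    long_embeddings=long_embeddings,
    hf_repo="your-username/flashpack-demo",  # push is skipped when this is None
)
for line in logs:
    print(line)

Note that Repository is deprecated in recent huggingface_hub releases; if the Space pins a newer version, the same push could be done with HfApi().upload_folder(...) instead, but this commit keeps the Repository-based workflow.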