Update app_flash.py
Browse files- app_flash.py +22 -6
app_flash.py
CHANGED
|
@@ -106,12 +106,28 @@ def train_flashpack_model(hf_repo=HF_REPO):
|
|
| 106 |
# ============================================================
|
| 107 |
# π¦ Load FlashPack from Hub
|
| 108 |
# ============================================================
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
print("
|
| 114 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
|
| 116 |
|
| 117 |
# ============================================================
|
|
|
|
| 106 |
# ============================================================
|
| 107 |
# π¦ Load FlashPack from Hub
|
| 108 |
# ============================================================
|
| 109 |
+
from huggingface_hub import snapshot_download
|
| 110 |
+
import os
|
| 111 |
+
|
| 112 |
+
def load_flashpack_model(hf_repo="rahul7star/FlashPack"):
|
| 113 |
+
print(f"π Loading FlashPack model from: {hf_repo}")
|
| 114 |
+
|
| 115 |
+
# Try local first, then Hugging Face Hub
|
| 116 |
+
if os.path.isdir(hf_repo):
|
| 117 |
+
local_dir = hf_repo
|
| 118 |
+
print(f"π Using local FlashPack model at: {local_dir}")
|
| 119 |
+
else:
|
| 120 |
+
print("βοΈ Downloading FlashPack model from Hugging Face Hub...")
|
| 121 |
+
local_dir = snapshot_download(repo_id=hf_repo)
|
| 122 |
+
print(f"π₯ Model snapshot downloaded to: {local_dir}")
|
| 123 |
+
|
| 124 |
+
# Load from local directory
|
| 125 |
+
model = GemmaTrainer.from_flashpack(local_dir)
|
| 126 |
+
model.eval()
|
| 127 |
+
print("β
FlashPack model loaded successfully.")
|
| 128 |
+
return model
|
| 129 |
+
|
| 130 |
+
|
| 131 |
|
| 132 |
|
| 133 |
# ============================================================
|