Upload 4 files
README.md CHANGED

```diff
@@ -4,7 +4,7 @@
 colorFrom: blue
 colorTo: green
 sdk: gradio
-sdk_version: "…
+sdk_version: "5.49.1"
 app_file: app.py
 pinned: false
 license: apache-2.0
```
app.py CHANGED

```diff
@@ -54,10 +54,11 @@ try:
         ENSEMBLE_CONFIG = json.load(f)
     print("✅ Loaded ensemble_config.json successfully.")
 except FileNotFoundError:
-    print("⚠️ Warning: ensemble_config.json not found. …
+    print("⚠️ Warning: ensemble_config.json not found. Using default configuration.")
     ENSEMBLE_CONFIG = {"weights": {}}
 except json.JSONDecodeError:
-    print("❌ Error: Could not decode ensemble_config.json. …
+    print("❌ Error: Could not decode ensemble_config.json. Using default configuration.")
+    ENSEMBLE_CONFIG = {"weights": {}}
 
 class CBAM(nn.Module):
     """Convolutional Block Attention Module - matches training implementation"""
```
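For context, the `{"weights": {}}` fallback above mirrors the structure the loader reads further down (model filenames as keys, per-model weights as values). A minimal sketch of an `ensemble_config.json` in that shape; the filenames match the loader's defaults, but the weight values are placeholders, not the repository's actual configuration:

```python
import json

# Hypothetical ensemble_config.json matching the structure read above;
# filenames follow the defaults in load_ensemble_model, weight values are placeholders.
example_config = {
    "weights": {
        "Model4.pth": 0.4,
        "Model5.pth": 0.3,
        "Model6.pth": 0.3,
    }
}

with open("ensemble_config.json", "w") as f:
    json.dump(example_config, f, indent=2)
```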
```diff
@@ -146,25 +147,39 @@ def load_ensemble_model(model_repo="calender/Ensemble_C"):
     print("💻 Running locally - checking local files first")
 
     # Map config filenames to local filenames
+    # Local clean files: Model4.pth, Model5.pth, Model6.pth
     config_to_local_map = {
-        "…
-        "…
-        "…
+        "Model4.pth": "Model4.pth",
+        "Model5.pth": "Model5.pth",
+        "Model6.pth": "Model6.pth"
+    }
+
+    # Map config filenames to actual repository filenames
+    # Repository contains clean files: Model4.pth, Model5.pth, Model6.pth
+    # (NOT the checkpoint files Iteration4_BEST.pth which have extra training data)
+    config_to_repo_map = {
+        "Model4.pth": "Model4.pth",
+        "Model5.pth": "Model5.pth",
+        "Model6.pth": "Model6.pth"
     }
 
     # Use filenames from config if available, otherwise default
     config_model_files = list(ENSEMBLE_CONFIG.get("weights", {}).keys())
     if not config_model_files:
         print("⚠️ No model weights in config, using default model files and equal weights.")
-        config_model_files = ["…
+        config_model_files = ["Model4.pth", "Model5.pth", "Model6.pth"]
+        print(f" Expected repository files: {config_model_files}")
+    else:
+        print(f"📋 Config specifies models: {config_model_files}")
+        print(f"📋 Will download from repository: {[config_to_repo_map.get(f, f) for f in config_model_files]}")
 
-    for …
-        weight = ENSEMBLE_CONFIG.get("weights", {}).get(…
+    for config_filename in config_model_files:
+        weight = ENSEMBLE_CONFIG.get("weights", {}).get(config_filename, 1.0)
 
         if load_locally:
-            local_filename = config_to_local_map.get(…
+            local_filename = config_to_local_map.get(config_filename)
             if not local_filename:
-                print(f"❌ Error: No local mapping for '{…
+                print(f"❌ Error: No local mapping for '{config_filename}'. Skipping.")
                 continue
             model_path = local_filename
             if not os.path.exists(model_path):
```
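The new comments above distinguish the clean weight files (Model4.pth, Model5.pth, Model6.pth) from training checkpoints such as Iteration4_BEST.pth that carry extra training state. A hedged sketch of how such a checkpoint could be stripped down to a weights-only file; the "model_state_dict" key is an assumption about how the checkpoint was saved, not something this diff confirms:

```python
import torch

# Assumed conversion from a full training checkpoint to a clean weights-only file.
# "model_state_dict" is a common but unconfirmed key name for this checkpoint format.
checkpoint = torch.load("Iteration4_BEST.pth", map_location="cpu")
state_dict = checkpoint.get("model_state_dict", checkpoint)
torch.save(state_dict, "Model4.pth")
```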
```diff
@@ -172,14 +187,16 @@ def load_ensemble_model(model_repo="calender/Ensemble_C"):
                 return [], device, []
             print(f"Found local model: {model_path}")
         else:
-            …
+            # Use the actual filename in the repository
+            repo_filename = config_to_repo_map.get(config_filename, config_filename)
+            print(f"Downloading {repo_filename} from repo {model_repo}...")
             try:
                 # Set HF token for private repo access
                 hf_token = os.environ.get("HF_TOKEN")
                 if hf_token:
                     print(f"✅ Using HF_TOKEN for private repo access")
-                model_path = hf_hub_download(repo_id=model_repo, filename=…
-                print(f"✅ Downloaded {…
+                model_path = hf_hub_download(repo_id=model_repo, filename=repo_filename, token=hf_token)
+                print(f"✅ Downloaded {repo_filename} successfully.")
 
                 # Basic validation of downloaded file
                 if not model_path.endswith('.pth') and not model_path.endswith('.pt'):
```
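In isolation, the download path above reduces to a single `hf_hub_download` call; a minimal sketch, assuming `HF_TOKEN` is configured as a Space secret for the private `calender/Ensemble_C` repo:

```python
import os
from huggingface_hub import hf_hub_download

# Token for private-repo access; None is fine for public repositories.
hf_token = os.environ.get("HF_TOKEN")

# Downloads (or reuses the cached copy of) one weights file and returns its local path.
model_path = hf_hub_download(
    repo_id="calender/Ensemble_C",
    filename="Model4.pth",
    token=hf_token,
)
print(model_path)
```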
```diff
@@ -189,7 +206,8 @@ def load_ensemble_model(model_repo="calender/Ensemble_C"):
                     print(f"❌ Downloaded file not found at {model_path}")
                     continue
             except Exception as e:
-                print(f"❌ Failed to download {…
+                print(f"❌ Failed to download {repo_filename}: {e}")
+                print(f" Make sure '{repo_filename}' exists in the '{model_repo}' repository.")
                 continue
 
         # Load state dict with proper error handling
```
```diff
@@ -260,9 +278,9 @@ def load_ensemble_model(model_repo="calender/Ensemble_C"):
 
         try:
             model.load_state_dict(state_dict, strict=False)
-            print(f"✅ Successfully loaded model for '{…
+            print(f"✅ Successfully loaded model for '{config_filename}' with weight {weight:.2f}")
         except Exception as e:
-            print(f"❌ Failed to load state dict for '{…
+            print(f"❌ Failed to load state dict for '{config_filename}': {e}")
             print("This might indicate architecture mismatch. Consider retraining or using correct model architecture.")
             continue
 
```
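The `strict=False` flag above means mismatched keys are skipped instead of raising; a self-contained illustration (using a throwaway `nn.Linear`, not the ensemble architecture) of how the skipped keys can be inspected:

```python
import torch
import torch.nn as nn

# Toy model standing in for an ensemble member; not the CBAM-based architecture.
model = nn.Linear(4, 2)
state_dict = {"weight": torch.zeros(2, 4), "extra.bias": torch.zeros(2)}

# strict=False loads the matching keys and reports the rest instead of raising.
result = model.load_state_dict(state_dict, strict=False)
print("missing:", result.missing_keys)        # ['bias']
print("unexpected:", result.unexpected_keys)  # ['extra.bias']
```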
```diff
@@ -278,6 +296,9 @@ def load_ensemble_model(model_repo="calender/Ensemble_C"):
     total_weight = sum(model_weights)
     normalized_weights = [w / total_weight for w in model_weights]
     print(f"Ensemble loaded with {len(models)} models. Normalized weights: {[f'{w:.2f}' for w in normalized_weights]}")
+    print(f"📋 Models loaded: {config_model_files}")
+    print(f"📋 Repository: {model_repo}")
+    print(f"🎯 Ready for Gradio interface!")
 
     return models, device, normalized_weights
 
```
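The normalized weights returned here are presumably applied when combining the ensemble's outputs; the prediction code is not part of this diff, so the following is only a sketch of one common weighting scheme (per-model softmax probabilities averaged with the normalized weights):

```python
import torch

def ensemble_predict(models, normalized_weights, x):
    """Weighted average of per-model class probabilities (illustrative only)."""
    with torch.no_grad():
        probs = [torch.softmax(model(x), dim=1) for model in models]
    return sum(w * p for w, p in zip(normalized_weights, probs))
```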