calender committed on
Commit
8b50960
·
verified ·
1 Parent(s): 1ba8d62

Upload 4 files

Browse files
Files changed (2) hide show
  1. README.md +1 -1
  2. app.py +37 -16
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 🫁
4
  colorFrom: blue
5
  colorTo: green
6
  sdk: gradio
7
- sdk_version: "4.0.0"
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
 
4
  colorFrom: blue
5
  colorTo: green
6
  sdk: gradio
7
+ sdk_version: "5.49.1"
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
app.py CHANGED
@@ -54,10 +54,11 @@ try:
54
  ENSEMBLE_CONFIG = json.load(f)
55
  print("✅ Loaded ensemble_config.json successfully.")
56
  except FileNotFoundError:
57
- print("⚠️ Warning: ensemble_config.json not found. Defaulting to equal weights for ensemble.")
58
  ENSEMBLE_CONFIG = {"weights": {}}
59
  except json.JSONDecodeError:
60
- print("❌ Error: Could not decode ensemble_config.json. Defaulting to equal weights.")
 
61
 
62
  class CBAM(nn.Module):
63
  """Convolutional Block Attention Module - matches training implementation"""
@@ -146,25 +147,39 @@ def load_ensemble_model(model_repo="calender/Ensemble_C"):
146
  print("💻 Running locally - checking local files first")
147
 
148
  # Map config filenames to local filenames
 
149
  config_to_local_map = {
150
- "iteration4_best.pth": "Model4.pth",
151
- "iteration5_best.pth": "Model5.pth",
152
- "iteration6_best.pth": "Model6.pth"
 
 
 
 
 
 
 
 
 
153
  }
154
 
155
  # Use filenames from config if available, otherwise default
156
  config_model_files = list(ENSEMBLE_CONFIG.get("weights", {}).keys())
157
  if not config_model_files:
158
  print("⚠️ No model weights in config, using default model files and equal weights.")
159
- config_model_files = ["iteration4_best.pth", "iteration5_best.pth", "iteration6_best.pth"]
 
 
 
 
160
 
161
- for hf_filename in config_model_files:
162
- weight = ENSEMBLE_CONFIG.get("weights", {}).get(hf_filename, 1.0)
163
 
164
  if load_locally:
165
- local_filename = config_to_local_map.get(hf_filename)
166
  if not local_filename:
167
- print(f"❌ Error: No local mapping for '{hf_filename}'. Skipping.")
168
  continue
169
  model_path = local_filename
170
  if not os.path.exists(model_path):
@@ -172,14 +187,16 @@ def load_ensemble_model(model_repo="calender/Ensemble_C"):
172
  return [], device, []
173
  print(f"Found local model: {model_path}")
174
  else:
175
- print(f"Downloading {hf_filename} from repo {model_repo}...")
 
 
176
  try:
177
  # Set HF token for private repo access
178
  hf_token = os.environ.get("HF_TOKEN")
179
  if hf_token:
180
  print(f"✅ Using HF_TOKEN for private repo access")
181
- model_path = hf_hub_download(repo_id=model_repo, filename=hf_filename, token=hf_token)
182
- print(f"✅ Downloaded {hf_filename} successfully.")
183
 
184
  # Basic validation of downloaded file
185
  if not model_path.endswith('.pth') and not model_path.endswith('.pt'):
@@ -189,7 +206,8 @@ def load_ensemble_model(model_repo="calender/Ensemble_C"):
189
  print(f"❌ Downloaded file not found at {model_path}")
190
  continue
191
  except Exception as e:
192
- print(f"❌ Failed to download {hf_filename}: {e}")
 
193
  continue
194
 
195
  # Load state dict with proper error handling
@@ -260,9 +278,9 @@ def load_ensemble_model(model_repo="calender/Ensemble_C"):
260
 
261
  try:
262
  model.load_state_dict(state_dict, strict=False)
263
- print(f"✅ Successfully loaded model for '{hf_filename}' with weight {weight:.2f}")
264
  except Exception as e:
265
- print(f"❌ Failed to load state dict for '{hf_filename}': {e}")
266
  print("This might indicate architecture mismatch. Consider retraining or using correct model architecture.")
267
  continue
268
 
@@ -278,6 +296,9 @@ def load_ensemble_model(model_repo="calender/Ensemble_C"):
278
  total_weight = sum(model_weights)
279
  normalized_weights = [w / total_weight for w in model_weights]
280
  print(f"Ensemble loaded with {len(models)} models. Normalized weights: {[f'{w:.2f}' for w in normalized_weights]}")
 
 
 
281
 
282
  return models, device, normalized_weights
283
 
 
54
  ENSEMBLE_CONFIG = json.load(f)
55
  print("✅ Loaded ensemble_config.json successfully.")
56
  except FileNotFoundError:
57
+ print("⚠️ Warning: ensemble_config.json not found. Using default configuration.")
58
  ENSEMBLE_CONFIG = {"weights": {}}
59
  except json.JSONDecodeError:
60
+ print("❌ Error: Could not decode ensemble_config.json. Using default configuration.")
61
+ ENSEMBLE_CONFIG = {"weights": {}}
62
 
63
  class CBAM(nn.Module):
64
  """Convolutional Block Attention Module - matches training implementation"""
 
147
  print("💻 Running locally - checking local files first")
148
 
149
  # Map config filenames to local filenames
150
+ # Local clean files: Model4.pth, Model5.pth, Model6.pth
151
  config_to_local_map = {
152
+ "Model4.pth": "Model4.pth",
153
+ "Model5.pth": "Model5.pth",
154
+ "Model6.pth": "Model6.pth"
155
+ }
156
+
157
+ # Map config filenames to actual repository filenames
158
+ # Repository contains clean files: Model4.pth, Model5.pth, Model6.pth
159
+ # (NOT the checkpoint files Iteration4_BEST.pth which have extra training data)
160
+ config_to_repo_map = {
161
+ "Model4.pth": "Model4.pth",
162
+ "Model5.pth": "Model5.pth",
163
+ "Model6.pth": "Model6.pth"
164
  }
165
 
166
  # Use filenames from config if available, otherwise default
167
  config_model_files = list(ENSEMBLE_CONFIG.get("weights", {}).keys())
168
  if not config_model_files:
169
  print("⚠️ No model weights in config, using default model files and equal weights.")
170
+ config_model_files = ["Model4.pth", "Model5.pth", "Model6.pth"]
171
+ print(f" Expected repository files: {config_model_files}")
172
+ else:
173
+ print(f"📋 Config specifies models: {config_model_files}")
174
+ print(f"🔍 Will download from repository: {[config_to_repo_map.get(f, f) for f in config_model_files]}")
175
 
176
+ for config_filename in config_model_files:
177
+ weight = ENSEMBLE_CONFIG.get("weights", {}).get(config_filename, 1.0)
178
 
179
  if load_locally:
180
+ local_filename = config_to_local_map.get(config_filename)
181
  if not local_filename:
182
+ print(f"❌ Error: No local mapping for '{config_filename}'. Skipping.")
183
  continue
184
  model_path = local_filename
185
  if not os.path.exists(model_path):
 
187
  return [], device, []
188
  print(f"Found local model: {model_path}")
189
  else:
190
+ # Use the actual filename in the repository
191
+ repo_filename = config_to_repo_map.get(config_filename, config_filename)
192
+ print(f"Downloading {repo_filename} from repo {model_repo}...")
193
  try:
194
  # Set HF token for private repo access
195
  hf_token = os.environ.get("HF_TOKEN")
196
  if hf_token:
197
  print(f"✅ Using HF_TOKEN for private repo access")
198
+ model_path = hf_hub_download(repo_id=model_repo, filename=repo_filename, token=hf_token)
199
+ print(f"✅ Downloaded {repo_filename} successfully.")
200
 
201
  # Basic validation of downloaded file
202
  if not model_path.endswith('.pth') and not model_path.endswith('.pt'):
 
206
  print(f"❌ Downloaded file not found at {model_path}")
207
  continue
208
  except Exception as e:
209
+ print(f"❌ Failed to download {repo_filename}: {e}")
210
+ print(f" Make sure '{repo_filename}' exists in the '{model_repo}' repository.")
211
  continue
212
 
213
  # Load state dict with proper error handling
 
278
 
279
  try:
280
  model.load_state_dict(state_dict, strict=False)
281
+ print(f"✅ Successfully loaded model for '{config_filename}' with weight {weight:.2f}")
282
  except Exception as e:
283
+ print(f"❌ Failed to load state dict for '{config_filename}': {e}")
284
  print("This might indicate architecture mismatch. Consider retraining or using correct model architecture.")
285
  continue
286
 
 
296
  total_weight = sum(model_weights)
297
  normalized_weights = [w / total_weight for w in model_weights]
298
  print(f"Ensemble loaded with {len(models)} models. Normalized weights: {[f'{w:.2f}' for w in normalized_weights]}")
299
+ print(f"📊 Models loaded: {config_model_files}")
300
+ print(f"🔗 Repository: {model_repo}")
301
+ print(f"🎯 Ready for Gradio interface!")
302
 
303
  return models, device, normalized_weights
304