Spaces:
Running
on
Zero
Running
on
Zero
2025-07-31 21:23 π
Browse filesFixed a bug in app.py
app.py
CHANGED
|
@@ -36,7 +36,7 @@ pretrained_models = [
|
|
| 36 |
# -----------------------------
|
| 37 |
def load_model(variant: str, dataset: str = "ShanghaiTech B", metric: str = "mae"):
|
| 38 |
""" Load the model weights from the Hugging Face Hub."""
|
| 39 |
-
global loaded_model
|
| 40 |
# Build model
|
| 41 |
|
| 42 |
model_info_path = hf_hub_download(
|
|
@@ -241,31 +241,30 @@ def predict(image: Image.Image, variant_dataset: str, metric: str):
|
|
| 241 |
Given an input image, preprocess it, run the model to obtain a density map,
|
| 242 |
compute the total crowd count, and prepare the density map for display.
|
| 243 |
"""
|
| 244 |
-
global loaded_model
|
| 245 |
variant, dataset = variant_dataset.split(" @ ")
|
| 246 |
|
| 247 |
-
if
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 248 |
|
| 249 |
-
|
| 250 |
-
dataset_name = "sha"
|
| 251 |
-
elif dataset == "ShanghaiTech B":
|
| 252 |
-
dataset_name = "shb"
|
| 253 |
-
elif dataset == "UCF-QNRF":
|
| 254 |
-
dataset_name = "qnrf"
|
| 255 |
-
elif dataset == "NWPU-Crowd":
|
| 256 |
-
dataset_name = "nwpu"
|
| 257 |
-
|
| 258 |
load_model(variant=variant, dataset=dataset_name, metric=metric)
|
| 259 |
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
|
| 270 |
loaded_model.to(device)
|
| 271 |
|
|
|
|
| 36 |
# -----------------------------
|
| 37 |
def load_model(variant: str, dataset: str = "ShanghaiTech B", metric: str = "mae"):
|
| 38 |
""" Load the model weights from the Hugging Face Hub."""
|
| 39 |
+
# global loaded_model
|
| 40 |
# Build model
|
| 41 |
|
| 42 |
model_info_path = hf_hub_download(
|
|
|
|
| 241 |
Given an input image, preprocess it, run the model to obtain a density map,
|
| 242 |
compute the total crowd count, and prepare the density map for display.
|
| 243 |
"""
|
| 244 |
+
# global loaded_model
|
| 245 |
variant, dataset = variant_dataset.split(" @ ")
|
| 246 |
|
| 247 |
+
if dataset == "ShanghaiTech A":
|
| 248 |
+
dataset_name = "sha"
|
| 249 |
+
elif dataset == "ShanghaiTech B":
|
| 250 |
+
dataset_name = "shb"
|
| 251 |
+
elif dataset == "UCF-QNRF":
|
| 252 |
+
dataset_name = "qnrf"
|
| 253 |
+
elif dataset == "NWPU-Crowd":
|
| 254 |
+
dataset_name = "nwpu"
|
| 255 |
|
| 256 |
+
if loaded_model is None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 257 |
load_model(variant=variant, dataset=dataset_name, metric=metric)
|
| 258 |
|
| 259 |
+
if not hasattr(loaded_model, "input_size"):
|
| 260 |
+
if dataset_name == "sha":
|
| 261 |
+
loaded_model.input_size = 224
|
| 262 |
+
elif dataset_name == "shb":
|
| 263 |
+
loaded_model.input_size = 448
|
| 264 |
+
elif dataset_name == "qnrf":
|
| 265 |
+
loaded_model.input_size = 672
|
| 266 |
+
elif dataset_name == "nwpu":
|
| 267 |
+
loaded_model.input_size = 672
|
| 268 |
|
| 269 |
loaded_model.to(device)
|
| 270 |
|