Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -18,6 +18,7 @@ from distances_common import (
|
|
| 18 |
sliced_wasserstein_distance_matrix,
|
| 19 |
)
|
| 20 |
from embed_raw import build_raw_embeddings
|
|
|
|
| 21 |
from embed_umap import build_umap_embeddings
|
| 22 |
from embed_lwm import get_lwm_encoder, build_lwm_embeddings
|
| 23 |
|
|
@@ -30,8 +31,10 @@ def _log(msg: str):
|
|
| 30 |
print(msg, flush=True)
|
| 31 |
|
| 32 |
|
| 33 |
-
def _matrix_payload(np_mat: np.ndarray, labels: Optional[List[str]] = None)
|
| 34 |
-
"""
|
|
|
|
|
|
|
| 35 |
df = pd.DataFrame(np_mat)
|
| 36 |
if labels is not None and len(labels) == df.shape[0]:
|
| 37 |
df.index = labels
|
|
@@ -66,11 +69,12 @@ def _plot_heatmap(D: np.ndarray, labels: Optional[List[str]] = None) -> np.ndarr
|
|
| 66 |
return img
|
| 67 |
|
| 68 |
|
| 69 |
-
def _load_uploaded_files(file_objs: List
|
| 70 |
"""Load uploaded datasets. Expect torch files with keys: 'channels', 'labels' (optional)."""
|
| 71 |
out = []
|
| 72 |
for f in (file_objs or []):
|
| 73 |
-
|
|
|
|
| 74 |
try:
|
| 75 |
obj = torch.load(path, map_location="cpu")
|
| 76 |
ch = obj["channels"]
|
|
@@ -83,10 +87,11 @@ def _load_uploaded_files(file_objs: List[gr.File]) -> List[Tuple[torch.Tensor, O
|
|
| 83 |
|
| 84 |
def _load_demo_files(paths: List[str]) -> List[Tuple[torch.Tensor, Optional[torch.Tensor], str]]:
|
| 85 |
out = []
|
| 86 |
-
for p in paths:
|
| 87 |
try:
|
| 88 |
ch, y = load_pt_dataset(p)
|
| 89 |
-
|
|
|
|
| 90 |
except Exception as e:
|
| 91 |
_log(f"[WARN] Failed to load demo dataset {p}: {e}")
|
| 92 |
return out
|
|
@@ -179,10 +184,11 @@ def refresh_demo_tasks():
|
|
| 179 |
|
| 180 |
def refresh_demo_scenarios(task: str):
|
| 181 |
if not task:
|
| 182 |
-
return gr.update(choices=[], value=
|
| 183 |
files = list_demo_dataset_files(task)
|
| 184 |
-
# multi-select
|
| 185 |
-
|
|
|
|
| 186 |
|
| 187 |
|
| 188 |
def run_compute(
|
|
@@ -195,7 +201,7 @@ def run_compute(
|
|
| 195 |
n_eval_per_dataset: int,
|
| 196 |
demo_task: str,
|
| 197 |
demo_files: List[str],
|
| 198 |
-
uploaded_files: List
|
| 199 |
umap_mode: str,
|
| 200 |
umap_n_components: int,
|
| 201 |
umap_n_neighbors: int,
|
|
@@ -235,20 +241,20 @@ def run_compute(
|
|
| 235 |
|
| 236 |
# UMAP cfg
|
| 237 |
umap_kwargs = dict(
|
| 238 |
-
n_components=umap_n_components,
|
| 239 |
-
n_neighbors=umap_n_neighbors,
|
| 240 |
-
min_dist=umap_min_dist,
|
| 241 |
metric=umap_metric,
|
| 242 |
-
spread=umap_spread,
|
| 243 |
-
learning_rate=umap_learning_rate,
|
| 244 |
n_epochs=None if umap_n_epochs in [None, 0] else int(umap_n_epochs),
|
| 245 |
-
negative_sample_rate=umap_negative_sample_rate,
|
| 246 |
init=umap_init,
|
| 247 |
-
densmap=umap_densmap,
|
| 248 |
-
set_op_mix_ratio=umap_set_op_mix_ratio,
|
| 249 |
-
local_connectivity=umap_local_connectivity,
|
| 250 |
-
repulsion_strength=umap_repulsion_strength,
|
| 251 |
-
random_state=umap_random_state,
|
| 252 |
target_metric="categorical",
|
| 253 |
target_weight=0.5,
|
| 254 |
)
|
|
|
|
| 18 |
sliced_wasserstein_distance_matrix,
|
| 19 |
)
|
| 20 |
from embed_raw import build_raw_embeddings
|
| 21 |
+
# noqa
|
| 22 |
from embed_umap import build_umap_embeddings
|
| 23 |
from embed_lwm import get_lwm_encoder, build_lwm_embeddings
|
| 24 |
|
|
|
|
| 31 |
print(msg, flush=True)
|
| 32 |
|
| 33 |
|
| 34 |
+
def _matrix_payload(np_mat: np.ndarray, labels: Optional[List[str]] = None):
|
| 35 |
+
"""
|
| 36 |
+
Return a safe gr.update payload for a Dataframe (headers must match col_count).
|
| 37 |
+
"""
|
| 38 |
df = pd.DataFrame(np_mat)
|
| 39 |
if labels is not None and len(labels) == df.shape[0]:
|
| 40 |
df.index = labels
|
|
|
|
| 69 |
return img
|
| 70 |
|
| 71 |
|
| 72 |
+
def _load_uploaded_files(file_objs: List) -> List[Tuple[torch.Tensor, Optional[torch.Tensor], str]]:
|
| 73 |
"""Load uploaded datasets. Expect torch files with keys: 'channels', 'labels' (optional)."""
|
| 74 |
out = []
|
| 75 |
for f in (file_objs or []):
|
| 76 |
+
# gr.Files can give strings (paths) or objects with .name
|
| 77 |
+
path = getattr(f, "name", f)
|
| 78 |
try:
|
| 79 |
obj = torch.load(path, map_location="cpu")
|
| 80 |
ch = obj["channels"]
|
|
|
|
| 87 |
|
| 88 |
def _load_demo_files(paths: List[str]) -> List[Tuple[torch.Tensor, Optional[torch.Tensor], str]]:
|
| 89 |
out = []
|
| 90 |
+
for p in paths or []:
|
| 91 |
try:
|
| 92 |
ch, y = load_pt_dataset(p)
|
| 93 |
+
# scenario folder name as label
|
| 94 |
+
out.append((ch, y, os.path.basename(os.path.dirname(p))))
|
| 95 |
except Exception as e:
|
| 96 |
_log(f"[WARN] Failed to load demo dataset {p}: {e}")
|
| 97 |
return out
|
|
|
|
| 184 |
|
| 185 |
def refresh_demo_scenarios(task: str):
|
| 186 |
if not task:
|
| 187 |
+
return gr.update(choices=[], value=[])
|
| 188 |
files = list_demo_dataset_files(task)
|
| 189 |
+
# multi-select defaults
|
| 190 |
+
default = files[:3] if len(files) >= 3 else files
|
| 191 |
+
return gr.update(choices=files, value=default)
|
| 192 |
|
| 193 |
|
| 194 |
def run_compute(
|
|
|
|
| 201 |
n_eval_per_dataset: int,
|
| 202 |
demo_task: str,
|
| 203 |
demo_files: List[str],
|
| 204 |
+
uploaded_files: List,
|
| 205 |
umap_mode: str,
|
| 206 |
umap_n_components: int,
|
| 207 |
umap_n_neighbors: int,
|
|
|
|
| 241 |
|
| 242 |
# UMAP cfg
|
| 243 |
umap_kwargs = dict(
|
| 244 |
+
n_components=int(umap_n_components),
|
| 245 |
+
n_neighbors=int(umap_n_neighbors),
|
| 246 |
+
min_dist=float(umap_min_dist),
|
| 247 |
metric=umap_metric,
|
| 248 |
+
spread=float(umap_spread),
|
| 249 |
+
learning_rate=float(umap_learning_rate),
|
| 250 |
n_epochs=None if umap_n_epochs in [None, 0] else int(umap_n_epochs),
|
| 251 |
+
negative_sample_rate=int(umap_negative_sample_rate),
|
| 252 |
init=umap_init,
|
| 253 |
+
densmap=bool(umap_densmap),
|
| 254 |
+
set_op_mix_ratio=float(umap_set_op_mix_ratio),
|
| 255 |
+
local_connectivity=float(umap_local_connectivity),
|
| 256 |
+
repulsion_strength=float(umap_repulsion_strength),
|
| 257 |
+
random_state=int(umap_random_state),
|
| 258 |
target_metric="categorical",
|
| 259 |
target_weight=0.5,
|
| 260 |
)
|