wi-lab commited on
Commit
c6177bf
·
1 Parent(s): 417cb75

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -21
app.py CHANGED
@@ -18,6 +18,7 @@ from distances_common import (
18
  sliced_wasserstein_distance_matrix,
19
  )
20
  from embed_raw import build_raw_embeddings
 
21
  from embed_umap import build_umap_embeddings
22
  from embed_lwm import get_lwm_encoder, build_lwm_embeddings
23
 
@@ -30,8 +31,10 @@ def _log(msg: str):
30
  print(msg, flush=True)
31
 
32
 
33
- def _matrix_payload(np_mat: np.ndarray, labels: Optional[List[str]] = None) -> gr.Update:
34
- """Return a safe gr.Dataframe update (headers match col_count)."""
 
 
35
  df = pd.DataFrame(np_mat)
36
  if labels is not None and len(labels) == df.shape[0]:
37
  df.index = labels
@@ -66,11 +69,12 @@ def _plot_heatmap(D: np.ndarray, labels: Optional[List[str]] = None) -> np.ndarr
66
  return img
67
 
68
 
69
- def _load_uploaded_files(file_objs: List[gr.File]) -> List[Tuple[torch.Tensor, Optional[torch.Tensor], str]]:
70
  """Load uploaded datasets. Expect torch files with keys: 'channels', 'labels' (optional)."""
71
  out = []
72
  for f in (file_objs or []):
73
- path = f.name if hasattr(f, "name") else f
 
74
  try:
75
  obj = torch.load(path, map_location="cpu")
76
  ch = obj["channels"]
@@ -83,10 +87,11 @@ def _load_uploaded_files(file_objs: List[gr.File]) -> List[Tuple[torch.Tensor, O
83
 
84
  def _load_demo_files(paths: List[str]) -> List[Tuple[torch.Tensor, Optional[torch.Tensor], str]]:
85
  out = []
86
- for p in paths:
87
  try:
88
  ch, y = load_pt_dataset(p)
89
- out.append((ch, y, os.path.basename(os.path.dirname(p)))) # scenario folder as name
 
90
  except Exception as e:
91
  _log(f"[WARN] Failed to load demo dataset {p}: {e}")
92
  return out
@@ -179,10 +184,11 @@ def refresh_demo_tasks():
179
 
180
  def refresh_demo_scenarios(task: str):
181
  if not task:
182
- return gr.update(choices=[], value=None)
183
  files = list_demo_dataset_files(task)
184
- # multi-select
185
- return gr.update(choices=files, value=files[:3] if len(files) >= 3 else files)
 
186
 
187
 
188
  def run_compute(
@@ -195,7 +201,7 @@ def run_compute(
195
  n_eval_per_dataset: int,
196
  demo_task: str,
197
  demo_files: List[str],
198
- uploaded_files: List[gr.File],
199
  umap_mode: str,
200
  umap_n_components: int,
201
  umap_n_neighbors: int,
@@ -235,20 +241,20 @@ def run_compute(
235
 
236
  # UMAP cfg
237
  umap_kwargs = dict(
238
- n_components=umap_n_components,
239
- n_neighbors=umap_n_neighbors,
240
- min_dist=umap_min_dist,
241
  metric=umap_metric,
242
- spread=umap_spread,
243
- learning_rate=umap_learning_rate,
244
  n_epochs=None if umap_n_epochs in [None, 0] else int(umap_n_epochs),
245
- negative_sample_rate=umap_negative_sample_rate,
246
  init=umap_init,
247
- densmap=umap_densmap,
248
- set_op_mix_ratio=umap_set_op_mix_ratio,
249
- local_connectivity=umap_local_connectivity,
250
- repulsion_strength=umap_repulsion_strength,
251
- random_state=umap_random_state,
252
  target_metric="categorical",
253
  target_weight=0.5,
254
  )
 
18
  sliced_wasserstein_distance_matrix,
19
  )
20
  from embed_raw import build_raw_embeddings
21
+ # noqa
22
  from embed_umap import build_umap_embeddings
23
  from embed_lwm import get_lwm_encoder, build_lwm_embeddings
24
 
 
31
  print(msg, flush=True)
32
 
33
 
34
+ def _matrix_payload(np_mat: np.ndarray, labels: Optional[List[str]] = None):
35
+ """
36
+ Return a safe gr.update payload for a Dataframe (headers must match col_count).
37
+ """
38
  df = pd.DataFrame(np_mat)
39
  if labels is not None and len(labels) == df.shape[0]:
40
  df.index = labels
 
69
  return img
70
 
71
 
72
+ def _load_uploaded_files(file_objs: List) -> List[Tuple[torch.Tensor, Optional[torch.Tensor], str]]:
73
  """Load uploaded datasets. Expect torch files with keys: 'channels', 'labels' (optional)."""
74
  out = []
75
  for f in (file_objs or []):
76
+ # gr.Files can give strings (paths) or objects with .name
77
+ path = getattr(f, "name", f)
78
  try:
79
  obj = torch.load(path, map_location="cpu")
80
  ch = obj["channels"]
 
87
 
88
  def _load_demo_files(paths: List[str]) -> List[Tuple[torch.Tensor, Optional[torch.Tensor], str]]:
89
  out = []
90
+ for p in paths or []:
91
  try:
92
  ch, y = load_pt_dataset(p)
93
+ # scenario folder name as label
94
+ out.append((ch, y, os.path.basename(os.path.dirname(p))))
95
  except Exception as e:
96
  _log(f"[WARN] Failed to load demo dataset {p}: {e}")
97
  return out
 
184
 
185
  def refresh_demo_scenarios(task: str):
186
  if not task:
187
+ return gr.update(choices=[], value=[])
188
  files = list_demo_dataset_files(task)
189
+ # multi-select defaults
190
+ default = files[:3] if len(files) >= 3 else files
191
+ return gr.update(choices=files, value=default)
192
 
193
 
194
  def run_compute(
 
201
  n_eval_per_dataset: int,
202
  demo_task: str,
203
  demo_files: List[str],
204
+ uploaded_files: List,
205
  umap_mode: str,
206
  umap_n_components: int,
207
  umap_n_neighbors: int,
 
241
 
242
  # UMAP cfg
243
  umap_kwargs = dict(
244
+ n_components=int(umap_n_components),
245
+ n_neighbors=int(umap_n_neighbors),
246
+ min_dist=float(umap_min_dist),
247
  metric=umap_metric,
248
+ spread=float(umap_spread),
249
+ learning_rate=float(umap_learning_rate),
250
  n_epochs=None if umap_n_epochs in [None, 0] else int(umap_n_epochs),
251
+ negative_sample_rate=int(umap_negative_sample_rate),
252
  init=umap_init,
253
+ densmap=bool(umap_densmap),
254
+ set_op_mix_ratio=float(umap_set_op_mix_ratio),
255
+ local_connectivity=float(umap_local_connectivity),
256
+ repulsion_strength=float(umap_repulsion_strength),
257
+ random_state=int(umap_random_state),
258
  target_metric="categorical",
259
  target_weight=0.5,
260
  )