wi-lab commited on
Commit
f0949fa
·
1 Parent(s): b6e8a86

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +339 -0
app.py ADDED
@@ -0,0 +1,339 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import io
3
+ import json
4
+ from typing import List, Dict, Tuple, Optional
5
+
6
+ import numpy as np
7
+ import torch
8
+ import gradio as gr
9
+ from huggingface_hub import snapshot_download, HfFolder
10
+
11
+ # =========================
12
+ # Helpers: logging + tokens
13
+ # =========================
14
+ DEFAULT_MODEL_REPO = "wi-lab/lwm-v1.1"
15
+ DEFAULT_MODEL_DIR = "./LWM-v1.1"
16
+
17
+ def ensure_hf_token():
18
+ tok = os.getenv("HF_TOKEN", None)
19
+ if tok:
20
+ HfFolder.save_token(tok)
21
+ return tok
22
+
23
+ def log_md(msg: str) -> str:
24
+ return f"{msg}"
25
+
26
+ # =========================
27
+ # Dataset loading utilities
28
+ # =========================
29
+ def _load_pt_bytes(b: bytes) -> Dict[str, torch.Tensor]:
30
+ # Expect a dict with at least "channels". Optional: "labels"
31
+ buf = io.BytesIO(b)
32
+ obj = torch.load(buf, map_location="cpu")
33
+ if isinstance(obj, dict) and "channels" in obj:
34
+ return {
35
+ "channels": obj["channels"],
36
+ "labels": obj.get("labels", None)
37
+ }
38
+ # Fallback: if it’s a tensor
39
+ if torch.is_tensor(obj):
40
+ return {"channels": obj, "labels": None}
41
+ raise ValueError("PT file must contain a dict with 'channels' (and optional 'labels'), or a tensor.")
42
+
43
+ def _load_npy_bytes(b: bytes) -> Dict[str, torch.Tensor]:
44
+ buf = io.BytesIO(b)
45
+ arr = np.load(buf, allow_pickle=True)
46
+ # If it's an array directly
47
+ if isinstance(arr, np.ndarray):
48
+ t = torch.from_numpy(arr)
49
+ return {"channels": t, "labels": None}
50
+ # If it's a dict-like (rare for npy)
51
+ raise ValueError("NPY must contain a single ndarray (channels). For dict-like, use NPZ.")
52
+
53
+ def _load_npz_bytes(b: bytes) -> Dict[str, torch.Tensor]:
54
+ buf = io.BytesIO(b)
55
+ npz = np.load(buf, allow_pickle=True)
56
+ # Expect either keys "channels" and optional "labels", or fallback to first array
57
+ if "channels" in npz:
58
+ ch = npz["channels"]
59
+ labs = npz["labels"] if "labels" in npz else None
60
+ return {
61
+ "channels": torch.from_numpy(ch),
62
+ "labels": (torch.from_numpy(labs) if labs is not None else None)
63
+ }
64
+ # Fallback: take the first array in the file
65
+ keys = list(npz.keys())
66
+ if not keys:
67
+ raise ValueError("Empty NPZ.")
68
+ ch = npz[keys[0]]
69
+ return {"channels": torch.from_numpy(ch), "labels": None}
70
+
71
+ def parse_uploaded_datasets(files: List[gr.File]) -> Dict[int, Dict[str, torch.Tensor]]:
72
+ """
73
+ Accepts multiple files. Each becomes one dataset.
74
+ Supported:
75
+ - .pt / .pth (torch.save)
76
+ - .npy (single array)
77
+ - .npz (expects 'channels' and optional 'labels', else uses first array)
78
+ Output: {0: {'channels': Tensor[N, ...], 'labels': Optional[Tensor[N]]}, 1: {...}, ...}
79
+ """
80
+ datasets = {}
81
+ idx = 0
82
+ for f in files or []:
83
+ name = f.name or ""
84
+ data = f.read()
85
+ try:
86
+ if name.endswith((".pt", ".pth")):
87
+ ds = _load_pt_bytes(data)
88
+ elif name.endswith(".npy"):
89
+ ds = _load_npy_bytes(data)
90
+ elif name.endswith(".npz"):
91
+ ds = _load_npz_bytes(data)
92
+ else:
93
+ raise ValueError(f"Unsupported file type: {name}")
94
+ # Ensure tensors are float and shaped as [N, ...]
95
+ ch = ds["channels"]
96
+ if ch.ndim == 1:
97
+ ch = ch.unsqueeze(0) # [1, D]
98
+ ds["channels"] = ch
99
+ datasets[idx] = ds
100
+ idx += 1
101
+ except Exception as e:
102
+ raise ValueError(f"Failed to load '{name}': {e}")
103
+ return datasets
104
+
105
+ # =========================
106
+ # Distance backends (stubs)
107
+ # =========================
108
+ def _to_feature_matrix(chs: torch.Tensor) -> torch.Tensor:
109
+ """
110
+ Flatten per-sample, split complex into [real, imag], return [N, D] float32
111
+ """
112
+ if chs.ndim >= 3:
113
+ chs = chs.reshape(chs.shape[0], -1) # [N, ...] -> [N, D]
114
+ elif chs.ndim == 2:
115
+ pass # already [N, D]
116
+ else:
117
+ chs = chs.view(chs.shape[0], -1)
118
+
119
+ if torch.is_complex(chs):
120
+ chs = torch.cat([chs.real, chs.imag], dim=1)
121
+ return chs.to(torch.float32)
122
+
123
+ def _pad_to_same_dim(mats: List[torch.Tensor]) -> List[torch.Tensor]:
124
+ max_d = max(m.shape[1] for m in mats)
125
+ out = []
126
+ for m in mats:
127
+ if m.shape[1] < max_d:
128
+ pad = torch.zeros((m.shape[0], max_d - m.shape[1]), dtype=m.dtype)
129
+ m = torch.cat([m, pad], dim=1)
130
+ out.append(m)
131
+ return out
132
+
133
+ def compute_distance_matrix_raw(
134
+ datasets: Dict[int, Dict[str, torch.Tensor]],
135
+ n_per_dataset: int,
136
+ distance_mode: str,
137
+ sw_num_projections: int,
138
+ label_aware: bool,
139
+ label_weighting: str,
140
+ label_max_per_class: int
141
+ ) -> torch.Tensor:
142
+ """
143
+ Minimal RAW baseline: centroid L2 or cosine. SW is not implemented here (stub).
144
+ """
145
+ mats = []
146
+ for i in sorted(datasets.keys()):
147
+ ch = datasets[i]["channels"]
148
+ n = min(n_per_dataset, ch.shape[0]) if n_per_dataset else ch.shape[0]
149
+ idxs = torch.randperm(ch.shape[0])[:n]
150
+ X = _to_feature_matrix(ch[idxs])
151
+ mats.append(X)
152
+ mats = _pad_to_same_dim(mats)
153
+ cents = [M.mean(dim=0, keepdim=True) for M in mats]
154
+ C = torch.cat(cents, dim=0) # [D, Df]
155
+
156
+ if distance_mode == "cosine_similarity":
157
+ Cn = torch.nn.functional.normalize(C, dim=1)
158
+ D = 1.0 - (Cn @ Cn.T)
159
+ else:
160
+ # "euclidean_centroid" and default fallback
161
+ D = torch.cdist(C, C, p=2)
162
+ return D
163
+
164
+ def compute_distance_matrix_umap(
165
+ datasets: Dict[int, Dict[str, torch.Tensor]],
166
+ umap_kwargs: dict,
167
+ channel_representation: str,
168
+ angle_delay_bins: int,
169
+ n_per_dataset: int,
170
+ distance_mode: str,
171
+ sw_num_projections: int,
172
+ label_aware: bool,
173
+ label_weighting: str,
174
+ label_max_per_class: int
175
+ ) -> torch.Tensor:
176
+ """
177
+ Placeholder: for now, reuse RAW. Swap in your UMAP pipeline later.
178
+ """
179
+ return compute_distance_matrix_raw(
180
+ datasets, n_per_dataset, distance_mode, sw_num_projections, label_aware, label_weighting, label_max_per_class
181
+ )
182
+
183
+ def compute_distance_matrix_lwm(
184
+ datasets: Dict[int, Dict[str, torch.Tensor]],
185
+ model_dir: str,
186
+ n_per_dataset: int,
187
+ distance_mode: str,
188
+ sw_num_projections: int,
189
+ label_aware: bool,
190
+ label_weighting: str,
191
+ label_max_per_class: int
192
+ ) -> torch.Tensor:
193
+ """
194
+ Placeholder: for now, reuse RAW. Replace with your LWM-embedding code that loads
195
+ the backbone from model_dir and computes pairwise distances from embeddings.
196
+ """
197
+ return compute_distance_matrix_raw(
198
+ datasets, n_per_dataset, distance_mode, sw_num_projections, label_aware, label_weighting, label_max_per_class
199
+ )
200
+
201
+ # =========================
202
+ # HF Model fetch (ONLY LWM)
203
+ # =========================
204
+ def fetch_lwm_model(model_repo: str, local_dir: str) -> str:
205
+ os.makedirs(local_dir, exist_ok=True)
206
+ ensure_hf_token()
207
+ snapshot_download(
208
+ repo_id=model_repo,
209
+ local_dir=local_dir,
210
+ local_dir_use_symlinks=False,
211
+ )
212
+ return f"Downloaded model repo: **{model_repo}** → `{local_dir}`"
213
+
214
+ # =========================
215
+ # UI callbacks
216
+ # =========================
217
+ def on_fetch_model(model_repo: str, model_dir: str):
218
+ try:
219
+ model_repo = model_repo.strip() or DEFAULT_MODEL_REPO
220
+ model_dir = model_dir.strip() or DEFAULT_MODEL_DIR
221
+ msg = fetch_lwm_model(model_repo, model_dir)
222
+ return gr.update(value=model_dir), log_md(msg)
223
+ except Exception as e:
224
+ return gr.update(value=model_dir), log_md(f"**Error**: {e}")
225
+
226
+ def on_compute(
227
+ files: List[gr.File],
228
+ framework: str,
229
+ distance_mode: str,
230
+ n_per_dataset: int,
231
+ sw_num_projections: int,
232
+ label_aware: bool,
233
+ label_weighting: str,
234
+ label_max_per_class: int,
235
+ model_dir: str,
236
+ umap_mode: str,
237
+ umap_n_components: int,
238
+ umap_n_neighbors: int,
239
+ umap_min_dist: float,
240
+ channel_representation: str,
241
+ angle_delay_bins: int
242
+ ):
243
+ try:
244
+ datasets = parse_uploaded_datasets(files)
245
+ if len(datasets) < 2:
246
+ return None, log_md("Please upload **≥ 2** datasets.")
247
+ if framework == "RAW":
248
+ D = compute_distance_matrix_raw(
249
+ datasets, int(n_per_dataset), distance_mode, int(sw_num_projections),
250
+ label_aware, label_weighting, int(label_max_per_class)
251
+ )
252
+ elif framework == "UMAP":
253
+ umap_kwargs = dict(
254
+ n_components=int(umap_n_components),
255
+ n_neighbors=int(umap_n_neighbors),
256
+ min_dist=float(umap_min_dist),
257
+ metric="euclidean",
258
+ random_state=42,
259
+ )
260
+ D = compute_distance_matrix_umap(
261
+ datasets, umap_kwargs, channel_representation, int(angle_delay_bins),
262
+ int(n_per_dataset), distance_mode, int(sw_num_projections),
263
+ label_aware, label_weighting, int(label_max_per_class)
264
+ )
265
+ else: # LWM
266
+ if not model_dir or not os.path.isdir(model_dir):
267
+ return None, log_md("LWM selected but **model dir** not found. Click *Fetch LWM model* first.")
268
+ D = compute_distance_matrix_lwm(
269
+ datasets, model_dir, int(n_per_dataset), distance_mode, int(sw_num_projections),
270
+ label_aware, label_weighting, int(label_max_per_class)
271
+ )
272
+
273
+ Dnp = D.detach().cpu().numpy().astype(float)
274
+ headers = [f"D{i}" for i in range(Dnp.shape[0])]
275
+ table = [[round(x, 6) for x in row] for row in Dnp]
276
+ return gr.update(value=table, headers=headers, row_count=(len(table), "fixed")), log_md("Done.")
277
+ except Exception as e:
278
+ return None, log_md(f"**Error**: {e}")
279
+
280
+ # =========================
281
+ # Gradio App
282
+ # =========================
283
+ with gr.Blocks(title="Dataset Distancing Lab") as demo:
284
+ gr.Markdown("# **Dataset Distancing Lab** \nUpload multiple datasets and compute similarity via **LWM / UMAP / RAW**.")
285
+
286
+ with gr.Row():
287
+ with gr.Column(scale=1):
288
+ gr.Markdown("### 1) Upload datasets (≥ 2)")
289
+ files_in = gr.File(file_count="multiple", label="Upload .pt/.pth/.npy/.npz", type="binary")
290
+
291
+ gr.Markdown("### 2) Choose framework & options")
292
+ framework_dd = gr.Radio(choices=["RAW", "UMAP", "LWM"], value="RAW", label="Framework")
293
+
294
+ distance_mode_dd = gr.Radio(
295
+ choices=["sliced_wasserstein", "euclidean_centroid", "cosine_similarity"],
296
+ value="euclidean_centroid", label="Distance mode"
297
+ )
298
+ n_per_ds_in = gr.Number(value=1024, precision=0, label="n_per_dataset (sampling)")
299
+ sw_proj_in = gr.Number(value=64, precision=0, label="SW num projections")
300
+ label_aware_cb = gr.Checkbox(value=True, label="Label-aware")
301
+ label_weighting_dd = gr.Radio(choices=["uniform", "support"], value="uniform", label="Label weighting")
302
+ label_max_in = gr.Number(value=1e10, precision=0, label="Label max per class")
303
+
304
+ with gr.Accordion("UMAP options", open=False):
305
+ umap_mode_dd = gr.Radio(choices=["unsupervised", "supervised"], value="supervised", label="UMAP mode")
306
+ umap_dim = gr.Slider(2, 256, value=128, step=1, label="UMAP n_components")
307
+ umap_knn = gr.Slider(2, 100, value=32, step=1, label="UMAP n_neighbors")
308
+ umap_min = gr.Slider(0.0, 0.99, value=0.1, step=0.01, label="UMAP min_dist")
309
+ chan_repr = gr.Radio(choices=["raw", "angle_delay"], value="angle_delay", label="Channel representation")
310
+ ad_bins = gr.Slider(4, 64, value=16, step=1, label="Angle-delay bins")
311
+
312
+ compute_btn = gr.Button("Compute distance matrix")
313
+
314
+ gr.Markdown("---")
315
+ gr.Markdown("### (Optional) Fetch LWM-v1.1 model")
316
+ model_repo_in = gr.Textbox(label="Model repo (HF)", value=DEFAULT_MODEL_REPO)
317
+ model_dir_in = gr.Textbox(label="Local model dir", value=DEFAULT_MODEL_DIR)
318
+ fetch_btn = gr.Button("Fetch LWM model")
319
+ fetch_status = gr.Markdown()
320
+
321
+ with gr.Column(scale=1):
322
+ gr.Markdown("### Distance Matrix")
323
+ matrix_out = gr.Dataframe(headers=[], value=None, interactive=False, wrap=True, row_count=(0, "dynamic"))
324
+ run_status = gr.Markdown()
325
+
326
+ fetch_btn.click(on_fetch_model, inputs=[model_repo_in, model_dir_in], outputs=[model_dir_in, fetch_status])
327
+
328
+ compute_btn.click(
329
+ on_compute,
330
+ inputs=[
331
+ files_in, framework_dd, distance_mode_dd, n_per_ds_in, sw_proj_in,
332
+ label_aware_cb, label_weighting_dd, label_max_in, model_dir_in,
333
+ umap_mode_dd, umap_dim, umap_knn, umap_min, chan_repr, ad_bins
334
+ ],
335
+ outputs=[matrix_out, run_status]
336
+ )
337
+
338
+ if __name__ == "__main__":
339
+ demo.launch()