Spaces:
Sleeping
Sleeping
Create app.py
Browse files
app.py
ADDED
|
@@ -0,0 +1,339 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import io
|
| 3 |
+
import json
|
| 4 |
+
from typing import List, Dict, Tuple, Optional
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
import gradio as gr
|
| 9 |
+
from huggingface_hub import snapshot_download, HfFolder
|
| 10 |
+
|
| 11 |
+
# =========================
|
| 12 |
+
# Helpers: logging + tokens
|
| 13 |
+
# =========================
|
| 14 |
+
DEFAULT_MODEL_REPO = "wi-lab/lwm-v1.1"
|
| 15 |
+
DEFAULT_MODEL_DIR = "./LWM-v1.1"
|
| 16 |
+
|
| 17 |
+
def ensure_hf_token():
|
| 18 |
+
tok = os.getenv("HF_TOKEN", None)
|
| 19 |
+
if tok:
|
| 20 |
+
HfFolder.save_token(tok)
|
| 21 |
+
return tok
|
| 22 |
+
|
| 23 |
+
def log_md(msg: str) -> str:
|
| 24 |
+
return f"{msg}"
|
| 25 |
+
|
| 26 |
+
# =========================
|
| 27 |
+
# Dataset loading utilities
|
| 28 |
+
# =========================
|
| 29 |
+
def _load_pt_bytes(b: bytes) -> Dict[str, torch.Tensor]:
|
| 30 |
+
# Expect a dict with at least "channels". Optional: "labels"
|
| 31 |
+
buf = io.BytesIO(b)
|
| 32 |
+
obj = torch.load(buf, map_location="cpu")
|
| 33 |
+
if isinstance(obj, dict) and "channels" in obj:
|
| 34 |
+
return {
|
| 35 |
+
"channels": obj["channels"],
|
| 36 |
+
"labels": obj.get("labels", None)
|
| 37 |
+
}
|
| 38 |
+
# Fallback: if it’s a tensor
|
| 39 |
+
if torch.is_tensor(obj):
|
| 40 |
+
return {"channels": obj, "labels": None}
|
| 41 |
+
raise ValueError("PT file must contain a dict with 'channels' (and optional 'labels'), or a tensor.")
|
| 42 |
+
|
| 43 |
+
def _load_npy_bytes(b: bytes) -> Dict[str, torch.Tensor]:
|
| 44 |
+
buf = io.BytesIO(b)
|
| 45 |
+
arr = np.load(buf, allow_pickle=True)
|
| 46 |
+
# If it's an array directly
|
| 47 |
+
if isinstance(arr, np.ndarray):
|
| 48 |
+
t = torch.from_numpy(arr)
|
| 49 |
+
return {"channels": t, "labels": None}
|
| 50 |
+
# If it's a dict-like (rare for npy)
|
| 51 |
+
raise ValueError("NPY must contain a single ndarray (channels). For dict-like, use NPZ.")
|
| 52 |
+
|
| 53 |
+
def _load_npz_bytes(b: bytes) -> Dict[str, torch.Tensor]:
|
| 54 |
+
buf = io.BytesIO(b)
|
| 55 |
+
npz = np.load(buf, allow_pickle=True)
|
| 56 |
+
# Expect either keys "channels" and optional "labels", or fallback to first array
|
| 57 |
+
if "channels" in npz:
|
| 58 |
+
ch = npz["channels"]
|
| 59 |
+
labs = npz["labels"] if "labels" in npz else None
|
| 60 |
+
return {
|
| 61 |
+
"channels": torch.from_numpy(ch),
|
| 62 |
+
"labels": (torch.from_numpy(labs) if labs is not None else None)
|
| 63 |
+
}
|
| 64 |
+
# Fallback: take the first array in the file
|
| 65 |
+
keys = list(npz.keys())
|
| 66 |
+
if not keys:
|
| 67 |
+
raise ValueError("Empty NPZ.")
|
| 68 |
+
ch = npz[keys[0]]
|
| 69 |
+
return {"channels": torch.from_numpy(ch), "labels": None}
|
| 70 |
+
|
| 71 |
+
def parse_uploaded_datasets(files: List[gr.File]) -> Dict[int, Dict[str, torch.Tensor]]:
|
| 72 |
+
"""
|
| 73 |
+
Accepts multiple files. Each becomes one dataset.
|
| 74 |
+
Supported:
|
| 75 |
+
- .pt / .pth (torch.save)
|
| 76 |
+
- .npy (single array)
|
| 77 |
+
- .npz (expects 'channels' and optional 'labels', else uses first array)
|
| 78 |
+
Output: {0: {'channels': Tensor[N, ...], 'labels': Optional[Tensor[N]]}, 1: {...}, ...}
|
| 79 |
+
"""
|
| 80 |
+
datasets = {}
|
| 81 |
+
idx = 0
|
| 82 |
+
for f in files or []:
|
| 83 |
+
name = f.name or ""
|
| 84 |
+
data = f.read()
|
| 85 |
+
try:
|
| 86 |
+
if name.endswith((".pt", ".pth")):
|
| 87 |
+
ds = _load_pt_bytes(data)
|
| 88 |
+
elif name.endswith(".npy"):
|
| 89 |
+
ds = _load_npy_bytes(data)
|
| 90 |
+
elif name.endswith(".npz"):
|
| 91 |
+
ds = _load_npz_bytes(data)
|
| 92 |
+
else:
|
| 93 |
+
raise ValueError(f"Unsupported file type: {name}")
|
| 94 |
+
# Ensure tensors are float and shaped as [N, ...]
|
| 95 |
+
ch = ds["channels"]
|
| 96 |
+
if ch.ndim == 1:
|
| 97 |
+
ch = ch.unsqueeze(0) # [1, D]
|
| 98 |
+
ds["channels"] = ch
|
| 99 |
+
datasets[idx] = ds
|
| 100 |
+
idx += 1
|
| 101 |
+
except Exception as e:
|
| 102 |
+
raise ValueError(f"Failed to load '{name}': {e}")
|
| 103 |
+
return datasets
|
| 104 |
+
|
| 105 |
+
# =========================
|
| 106 |
+
# Distance backends (stubs)
|
| 107 |
+
# =========================
|
| 108 |
+
def _to_feature_matrix(chs: torch.Tensor) -> torch.Tensor:
|
| 109 |
+
"""
|
| 110 |
+
Flatten per-sample, split complex into [real, imag], return [N, D] float32
|
| 111 |
+
"""
|
| 112 |
+
if chs.ndim >= 3:
|
| 113 |
+
chs = chs.reshape(chs.shape[0], -1) # [N, ...] -> [N, D]
|
| 114 |
+
elif chs.ndim == 2:
|
| 115 |
+
pass # already [N, D]
|
| 116 |
+
else:
|
| 117 |
+
chs = chs.view(chs.shape[0], -1)
|
| 118 |
+
|
| 119 |
+
if torch.is_complex(chs):
|
| 120 |
+
chs = torch.cat([chs.real, chs.imag], dim=1)
|
| 121 |
+
return chs.to(torch.float32)
|
| 122 |
+
|
| 123 |
+
def _pad_to_same_dim(mats: List[torch.Tensor]) -> List[torch.Tensor]:
|
| 124 |
+
max_d = max(m.shape[1] for m in mats)
|
| 125 |
+
out = []
|
| 126 |
+
for m in mats:
|
| 127 |
+
if m.shape[1] < max_d:
|
| 128 |
+
pad = torch.zeros((m.shape[0], max_d - m.shape[1]), dtype=m.dtype)
|
| 129 |
+
m = torch.cat([m, pad], dim=1)
|
| 130 |
+
out.append(m)
|
| 131 |
+
return out
|
| 132 |
+
|
| 133 |
+
def compute_distance_matrix_raw(
|
| 134 |
+
datasets: Dict[int, Dict[str, torch.Tensor]],
|
| 135 |
+
n_per_dataset: int,
|
| 136 |
+
distance_mode: str,
|
| 137 |
+
sw_num_projections: int,
|
| 138 |
+
label_aware: bool,
|
| 139 |
+
label_weighting: str,
|
| 140 |
+
label_max_per_class: int
|
| 141 |
+
) -> torch.Tensor:
|
| 142 |
+
"""
|
| 143 |
+
Minimal RAW baseline: centroid L2 or cosine. SW is not implemented here (stub).
|
| 144 |
+
"""
|
| 145 |
+
mats = []
|
| 146 |
+
for i in sorted(datasets.keys()):
|
| 147 |
+
ch = datasets[i]["channels"]
|
| 148 |
+
n = min(n_per_dataset, ch.shape[0]) if n_per_dataset else ch.shape[0]
|
| 149 |
+
idxs = torch.randperm(ch.shape[0])[:n]
|
| 150 |
+
X = _to_feature_matrix(ch[idxs])
|
| 151 |
+
mats.append(X)
|
| 152 |
+
mats = _pad_to_same_dim(mats)
|
| 153 |
+
cents = [M.mean(dim=0, keepdim=True) for M in mats]
|
| 154 |
+
C = torch.cat(cents, dim=0) # [D, Df]
|
| 155 |
+
|
| 156 |
+
if distance_mode == "cosine_similarity":
|
| 157 |
+
Cn = torch.nn.functional.normalize(C, dim=1)
|
| 158 |
+
D = 1.0 - (Cn @ Cn.T)
|
| 159 |
+
else:
|
| 160 |
+
# "euclidean_centroid" and default fallback
|
| 161 |
+
D = torch.cdist(C, C, p=2)
|
| 162 |
+
return D
|
| 163 |
+
|
| 164 |
+
def compute_distance_matrix_umap(
|
| 165 |
+
datasets: Dict[int, Dict[str, torch.Tensor]],
|
| 166 |
+
umap_kwargs: dict,
|
| 167 |
+
channel_representation: str,
|
| 168 |
+
angle_delay_bins: int,
|
| 169 |
+
n_per_dataset: int,
|
| 170 |
+
distance_mode: str,
|
| 171 |
+
sw_num_projections: int,
|
| 172 |
+
label_aware: bool,
|
| 173 |
+
label_weighting: str,
|
| 174 |
+
label_max_per_class: int
|
| 175 |
+
) -> torch.Tensor:
|
| 176 |
+
"""
|
| 177 |
+
Placeholder: for now, reuse RAW. Swap in your UMAP pipeline later.
|
| 178 |
+
"""
|
| 179 |
+
return compute_distance_matrix_raw(
|
| 180 |
+
datasets, n_per_dataset, distance_mode, sw_num_projections, label_aware, label_weighting, label_max_per_class
|
| 181 |
+
)
|
| 182 |
+
|
| 183 |
+
def compute_distance_matrix_lwm(
|
| 184 |
+
datasets: Dict[int, Dict[str, torch.Tensor]],
|
| 185 |
+
model_dir: str,
|
| 186 |
+
n_per_dataset: int,
|
| 187 |
+
distance_mode: str,
|
| 188 |
+
sw_num_projections: int,
|
| 189 |
+
label_aware: bool,
|
| 190 |
+
label_weighting: str,
|
| 191 |
+
label_max_per_class: int
|
| 192 |
+
) -> torch.Tensor:
|
| 193 |
+
"""
|
| 194 |
+
Placeholder: for now, reuse RAW. Replace with your LWM-embedding code that loads
|
| 195 |
+
the backbone from model_dir and computes pairwise distances from embeddings.
|
| 196 |
+
"""
|
| 197 |
+
return compute_distance_matrix_raw(
|
| 198 |
+
datasets, n_per_dataset, distance_mode, sw_num_projections, label_aware, label_weighting, label_max_per_class
|
| 199 |
+
)
|
| 200 |
+
|
| 201 |
+
# =========================
|
| 202 |
+
# HF Model fetch (ONLY LWM)
|
| 203 |
+
# =========================
|
| 204 |
+
def fetch_lwm_model(model_repo: str, local_dir: str) -> str:
|
| 205 |
+
os.makedirs(local_dir, exist_ok=True)
|
| 206 |
+
ensure_hf_token()
|
| 207 |
+
snapshot_download(
|
| 208 |
+
repo_id=model_repo,
|
| 209 |
+
local_dir=local_dir,
|
| 210 |
+
local_dir_use_symlinks=False,
|
| 211 |
+
)
|
| 212 |
+
return f"Downloaded model repo: **{model_repo}** → `{local_dir}`"
|
| 213 |
+
|
| 214 |
+
# =========================
|
| 215 |
+
# UI callbacks
|
| 216 |
+
# =========================
|
| 217 |
+
def on_fetch_model(model_repo: str, model_dir: str):
|
| 218 |
+
try:
|
| 219 |
+
model_repo = model_repo.strip() or DEFAULT_MODEL_REPO
|
| 220 |
+
model_dir = model_dir.strip() or DEFAULT_MODEL_DIR
|
| 221 |
+
msg = fetch_lwm_model(model_repo, model_dir)
|
| 222 |
+
return gr.update(value=model_dir), log_md(msg)
|
| 223 |
+
except Exception as e:
|
| 224 |
+
return gr.update(value=model_dir), log_md(f"**Error**: {e}")
|
| 225 |
+
|
| 226 |
+
def on_compute(
|
| 227 |
+
files: List[gr.File],
|
| 228 |
+
framework: str,
|
| 229 |
+
distance_mode: str,
|
| 230 |
+
n_per_dataset: int,
|
| 231 |
+
sw_num_projections: int,
|
| 232 |
+
label_aware: bool,
|
| 233 |
+
label_weighting: str,
|
| 234 |
+
label_max_per_class: int,
|
| 235 |
+
model_dir: str,
|
| 236 |
+
umap_mode: str,
|
| 237 |
+
umap_n_components: int,
|
| 238 |
+
umap_n_neighbors: int,
|
| 239 |
+
umap_min_dist: float,
|
| 240 |
+
channel_representation: str,
|
| 241 |
+
angle_delay_bins: int
|
| 242 |
+
):
|
| 243 |
+
try:
|
| 244 |
+
datasets = parse_uploaded_datasets(files)
|
| 245 |
+
if len(datasets) < 2:
|
| 246 |
+
return None, log_md("Please upload **≥ 2** datasets.")
|
| 247 |
+
if framework == "RAW":
|
| 248 |
+
D = compute_distance_matrix_raw(
|
| 249 |
+
datasets, int(n_per_dataset), distance_mode, int(sw_num_projections),
|
| 250 |
+
label_aware, label_weighting, int(label_max_per_class)
|
| 251 |
+
)
|
| 252 |
+
elif framework == "UMAP":
|
| 253 |
+
umap_kwargs = dict(
|
| 254 |
+
n_components=int(umap_n_components),
|
| 255 |
+
n_neighbors=int(umap_n_neighbors),
|
| 256 |
+
min_dist=float(umap_min_dist),
|
| 257 |
+
metric="euclidean",
|
| 258 |
+
random_state=42,
|
| 259 |
+
)
|
| 260 |
+
D = compute_distance_matrix_umap(
|
| 261 |
+
datasets, umap_kwargs, channel_representation, int(angle_delay_bins),
|
| 262 |
+
int(n_per_dataset), distance_mode, int(sw_num_projections),
|
| 263 |
+
label_aware, label_weighting, int(label_max_per_class)
|
| 264 |
+
)
|
| 265 |
+
else: # LWM
|
| 266 |
+
if not model_dir or not os.path.isdir(model_dir):
|
| 267 |
+
return None, log_md("LWM selected but **model dir** not found. Click *Fetch LWM model* first.")
|
| 268 |
+
D = compute_distance_matrix_lwm(
|
| 269 |
+
datasets, model_dir, int(n_per_dataset), distance_mode, int(sw_num_projections),
|
| 270 |
+
label_aware, label_weighting, int(label_max_per_class)
|
| 271 |
+
)
|
| 272 |
+
|
| 273 |
+
Dnp = D.detach().cpu().numpy().astype(float)
|
| 274 |
+
headers = [f"D{i}" for i in range(Dnp.shape[0])]
|
| 275 |
+
table = [[round(x, 6) for x in row] for row in Dnp]
|
| 276 |
+
return gr.update(value=table, headers=headers, row_count=(len(table), "fixed")), log_md("Done.")
|
| 277 |
+
except Exception as e:
|
| 278 |
+
return None, log_md(f"**Error**: {e}")
|
| 279 |
+
|
| 280 |
+
# =========================
|
| 281 |
+
# Gradio App
|
| 282 |
+
# =========================
|
| 283 |
+
with gr.Blocks(title="Dataset Distancing Lab") as demo:
|
| 284 |
+
gr.Markdown("# **Dataset Distancing Lab** \nUpload multiple datasets and compute similarity via **LWM / UMAP / RAW**.")
|
| 285 |
+
|
| 286 |
+
with gr.Row():
|
| 287 |
+
with gr.Column(scale=1):
|
| 288 |
+
gr.Markdown("### 1) Upload datasets (≥ 2)")
|
| 289 |
+
files_in = gr.File(file_count="multiple", label="Upload .pt/.pth/.npy/.npz", type="binary")
|
| 290 |
+
|
| 291 |
+
gr.Markdown("### 2) Choose framework & options")
|
| 292 |
+
framework_dd = gr.Radio(choices=["RAW", "UMAP", "LWM"], value="RAW", label="Framework")
|
| 293 |
+
|
| 294 |
+
distance_mode_dd = gr.Radio(
|
| 295 |
+
choices=["sliced_wasserstein", "euclidean_centroid", "cosine_similarity"],
|
| 296 |
+
value="euclidean_centroid", label="Distance mode"
|
| 297 |
+
)
|
| 298 |
+
n_per_ds_in = gr.Number(value=1024, precision=0, label="n_per_dataset (sampling)")
|
| 299 |
+
sw_proj_in = gr.Number(value=64, precision=0, label="SW num projections")
|
| 300 |
+
label_aware_cb = gr.Checkbox(value=True, label="Label-aware")
|
| 301 |
+
label_weighting_dd = gr.Radio(choices=["uniform", "support"], value="uniform", label="Label weighting")
|
| 302 |
+
label_max_in = gr.Number(value=1e10, precision=0, label="Label max per class")
|
| 303 |
+
|
| 304 |
+
with gr.Accordion("UMAP options", open=False):
|
| 305 |
+
umap_mode_dd = gr.Radio(choices=["unsupervised", "supervised"], value="supervised", label="UMAP mode")
|
| 306 |
+
umap_dim = gr.Slider(2, 256, value=128, step=1, label="UMAP n_components")
|
| 307 |
+
umap_knn = gr.Slider(2, 100, value=32, step=1, label="UMAP n_neighbors")
|
| 308 |
+
umap_min = gr.Slider(0.0, 0.99, value=0.1, step=0.01, label="UMAP min_dist")
|
| 309 |
+
chan_repr = gr.Radio(choices=["raw", "angle_delay"], value="angle_delay", label="Channel representation")
|
| 310 |
+
ad_bins = gr.Slider(4, 64, value=16, step=1, label="Angle-delay bins")
|
| 311 |
+
|
| 312 |
+
compute_btn = gr.Button("Compute distance matrix")
|
| 313 |
+
|
| 314 |
+
gr.Markdown("---")
|
| 315 |
+
gr.Markdown("### (Optional) Fetch LWM-v1.1 model")
|
| 316 |
+
model_repo_in = gr.Textbox(label="Model repo (HF)", value=DEFAULT_MODEL_REPO)
|
| 317 |
+
model_dir_in = gr.Textbox(label="Local model dir", value=DEFAULT_MODEL_DIR)
|
| 318 |
+
fetch_btn = gr.Button("Fetch LWM model")
|
| 319 |
+
fetch_status = gr.Markdown()
|
| 320 |
+
|
| 321 |
+
with gr.Column(scale=1):
|
| 322 |
+
gr.Markdown("### Distance Matrix")
|
| 323 |
+
matrix_out = gr.Dataframe(headers=[], value=None, interactive=False, wrap=True, row_count=(0, "dynamic"))
|
| 324 |
+
run_status = gr.Markdown()
|
| 325 |
+
|
| 326 |
+
fetch_btn.click(on_fetch_model, inputs=[model_repo_in, model_dir_in], outputs=[model_dir_in, fetch_status])
|
| 327 |
+
|
| 328 |
+
compute_btn.click(
|
| 329 |
+
on_compute,
|
| 330 |
+
inputs=[
|
| 331 |
+
files_in, framework_dd, distance_mode_dd, n_per_ds_in, sw_proj_in,
|
| 332 |
+
label_aware_cb, label_weighting_dd, label_max_in, model_dir_in,
|
| 333 |
+
umap_mode_dd, umap_dim, umap_knn, umap_min, chan_repr, ad_bins
|
| 334 |
+
],
|
| 335 |
+
outputs=[matrix_out, run_status]
|
| 336 |
+
)
|
| 337 |
+
|
| 338 |
+
if __name__ == "__main__":
|
| 339 |
+
demo.launch()
|