from typing import List, Tuple, Optional, Dict

import torch
import umap


def _prep_for_umap(chs: torch.Tensor, representation: str, angle_delay_bins: int) -> torch.Tensor:
    """
    Turn a batch of channels into 2D features for UMAP; concatenate real+imag if complex.
    """
    X = chs
    if X.ndim > 2:
        # Flatten everything past the batch dimension: [N, ...] -> [N, d]
        X = X.reshape(X.shape[0], -1)
    if torch.is_complex(X):
        # UMAP works on real-valued features; split complex values into real/imag halves.
        X = torch.cat([X.real, X.imag], dim=-1)
    X = X.to(torch.float32)
    return X
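
# Illustrative example (not part of the original file): for a complex batch of
# shape [N, n_rx, n_tx], _prep_for_umap flattens to [N, n_rx * n_tx] and then
# doubles the feature dimension by concatenating real and imaginary parts.
# >>> x = torch.randn(8, 4, 8, dtype=torch.cfloat)
# >>> _prep_for_umap(x, "raw", 16).shape
# torch.Size([8, 64])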
@torch.no_grad()
def build_umap_embeddings(
    datasets: List[Tuple[torch.Tensor, Optional[torch.Tensor], str]],
    n_per_dataset: int,
    label_aware: bool,
    umap_mode: str,
    umap_kwargs: Dict,
    channel_representation: str = "raw",
    angle_delay_bins: int = 16,
):
    """
    Subsample each dataset, fit one UMAP over the pooled samples, and return
    per-dataset embeddings (plus per-dataset labels when label_aware is True).
    """
    all_X = []
    all_y = []
    sizes = []
    for ch, y, _ in datasets:
        # Randomly subsample up to n_per_dataset channels from this dataset.
        N = ch.shape[0]
        n = min(int(n_per_dataset), int(N))
        idx = torch.randperm(N)[:n]
        Xi = ch[idx]
        Xi = _prep_for_umap(Xi, channel_representation, angle_delay_bins)  # [n, d]
        all_X.append(Xi)
        if y is not None:
            all_y.append(y[idx].to(torch.long))
        sizes.append(n)
    X = torch.cat(all_X, dim=0)  # [sum n, d]
    y_all = torch.cat(all_y, dim=0) if (label_aware and all_y) else None

    # Fit a single shared UMAP over all datasets; use labels only in supervised mode.
    reducer = umap.UMAP(**umap_kwargs)
    if umap_mode == "supervised" and y_all is not None and y_all.numel() > 0:
        U = reducer.fit_transform(X.numpy(), y=y_all.numpy())
    else:
        U = reducer.fit_transform(X.numpy())

    # Split the pooled embedding back into per-dataset chunks.
    start = 0
    embs = []
    labels_per_ds = [] if label_aware else None
    for n in sizes:
        Ui = torch.from_numpy(U[start:start + n]).to(torch.float32)  # [n, d_umap]
        embs.append(Ui)
        if label_aware:
            if y_all is not None and y_all.numel() >= start + n:
                labels_per_ds.append(y_all[start:start + n])
            else:
                labels_per_ds.append(torch.empty((0,), dtype=torch.long))
        start += n
    embs = torch.stack(embs, dim=0)  # [D, n, d_umap]
    return embs, labels_per_ds
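

# --- Usage sketch (illustrative, not part of the original module) ---
# Shows one way to call build_umap_embeddings on synthetic complex channels.
# Dataset shapes, label counts, and the UMAP kwargs below are assumptions
# chosen for this demo, not values taken from the original code.
if __name__ == "__main__":
    torch.manual_seed(0)
    # Two synthetic datasets of complex channels [N, n_rx, n_tx] with class labels.
    ds_a = (torch.randn(32, 4, 8, dtype=torch.cfloat), torch.randint(0, 3, (32,)), "dataset_a")
    ds_b = (torch.randn(48, 4, 8, dtype=torch.cfloat), torch.randint(0, 3, (48,)), "dataset_b")

    embs, labels = build_umap_embeddings(
        datasets=[ds_a, ds_b],
        n_per_dataset=32,
        label_aware=True,
        umap_mode="supervised",
        umap_kwargs={"n_neighbors": 10, "n_components": 2, "random_state": 0},
        channel_representation="raw",
    )
    print(embs.shape)                     # torch.Size([2, 32, 2]): dataset x sample x UMAP dim
    print([lbl.shape for lbl in labels])  # [torch.Size([32]), torch.Size([32])]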