from typing import Dict, List, Optional, Tuple

import torch
import umap


def _prep_for_umap(chs: torch.Tensor, representation: str, angle_delay_bins: int) -> torch.Tensor:
    """
    Turn a batch of channels into 2D features for UMAP: flatten each sample and
    concatenate real and imaginary parts if the input is complex.

    Note: `representation` and `angle_delay_bins` are accepted for interface
    compatibility, but only the raw (flattened) representation is produced here.
    """
    X = chs
    if X.ndim > 2:
        # Flatten everything except the batch dimension: [N, ...] -> [N, d]
        X = X.reshape(X.shape[0], -1)
    if torch.is_complex(X):
        # Split complex features into real/imag halves: [N, d] -> [N, 2d]
        X = torch.cat([X.real, X.imag], dim=-1)
    return X.to(torch.float32)


@torch.no_grad()
def build_umap_embeddings(
    datasets: List[Tuple[torch.Tensor, Optional[torch.Tensor], str]],
    n_per_dataset: int,
    label_aware: bool,
    umap_mode: str,
    umap_kwargs: Dict,
    channel_representation: str = "raw",
    angle_delay_bins: int = 16,
):
    """
    Fit a single UMAP model on samples drawn from all datasets and return the
    per-dataset embeddings (and, when label_aware, the matching labels).
    """
    all_X = []
    all_y = []
    sizes = []
    for ch, y, _ in datasets:
        # Randomly subsample up to n_per_dataset channels from this dataset.
        N = ch.shape[0]
        n = min(int(n_per_dataset), int(N))
        idx = torch.randperm(N)[:n]
        Xi = _prep_for_umap(ch[idx], channel_representation, angle_delay_bins)  # [n, d]
        all_X.append(Xi)
        if y is not None:
            all_y.append(y[idx].to(torch.long))
        sizes.append(n)

    X = torch.cat(all_X, dim=0)  # [sum n, d]
    # Labels are only usable when every dataset provided them; otherwise the
    # concatenated labels would not align row-for-row with X.
    y_all = torch.cat(all_y, dim=0) if (label_aware and len(all_y) == len(datasets)) else None

    reducer = umap.UMAP(**umap_kwargs)
    if umap_mode == "supervised" and y_all is not None and y_all.numel() > 0:
        U = reducer.fit_transform(X.cpu().numpy(), y=y_all.cpu().numpy())
    else:
        U = reducer.fit_transform(X.cpu().numpy())

    # Split the joint embedding back into per-dataset chunks.
    start = 0
    embs = []
    labels_per_ds = [] if label_aware else None
    for n in sizes:
        Ui = torch.from_numpy(U[start:start + n]).to(torch.float32)  # [n, d_umap]
        embs.append(Ui)
        if label_aware:
            if y_all is not None and y_all.numel() >= start + n:
                labels_per_ds.append(y_all[start:start + n])
            else:
                labels_per_ds.append(torch.empty((0,), dtype=torch.long))
        start += n

    # Stacking assumes every dataset contributed the same number of samples,
    # i.e. each dataset holds at least n_per_dataset channels.
    embs = torch.stack(embs, dim=0)  # [D, n, d_umap]
    return embs, labels_per_ds
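

# --- Usage sketch (illustrative only) ---
# The shapes, label counts, and UMAP settings below are assumptions for the
# sake of the example, not values taken from any particular dataset.
if __name__ == "__main__":
    # Two synthetic "datasets" of complex channels: (channels, labels, name),
    # with channels shaped [N, antennas, subcarriers].
    ds_a = (torch.randn(128, 4, 64, dtype=torch.complex64), torch.randint(0, 3, (128,)), "scenario_a")
    ds_b = (torch.randn(200, 4, 64, dtype=torch.complex64), torch.randint(0, 3, (200,)), "scenario_b")

    embs, labels = build_umap_embeddings(
        datasets=[ds_a, ds_b],
        n_per_dataset=100,
        label_aware=True,
        umap_mode="supervised",
        umap_kwargs={"n_components": 2, "n_neighbors": 15, "min_dist": 0.1},
    )
    print(embs.shape)                   # torch.Size([2, 100, 2])
    print([lbl.shape for lbl in labels])  # [torch.Size([100]), torch.Size([100])]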