from typing import List, Tuple, Optional, Dict

import torch
import umap


def _prep_for_umap(chs: torch.Tensor, representation: str, angle_delay_bins: int) -> torch.Tensor:
    """
    Turn a batch of channels into 2D features for UMAP; concatenate real+imag if complex.

    `representation` and `angle_delay_bins` are accepted but currently unused by this
    raw-flattening path.
    """
    X = chs
    if X.ndim > 2:
        # Flatten everything after the batch dimension: [N, ...] -> [N, d]
        X = X.reshape(X.shape[0], -1)
    if torch.is_complex(X):
        # UMAP works on real-valued features, so split complex values into real/imag parts.
        X = torch.cat([X.real, X.imag], dim=-1)
    return X.to(torch.float32)
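
# Illustrative shape bookkeeping for _prep_for_umap (the example sizes below are
# assumptions, not taken from any particular dataset): a complex batch of shape
# [N, n_rx, n_sub] = [N, 4, 64] is flattened to [N, 256] complex, and the real/imag
# concatenation then yields [N, 512] float32 features.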
def build_umap_embeddings(
    datasets: List[Tuple[torch.Tensor, Optional[torch.Tensor], str]],
    n_per_dataset: int,
    label_aware: bool,
    umap_mode: str,
    umap_kwargs: Dict,
    channel_representation: str = "raw",
    angle_delay_bins: int = 16,
):
    """
    Sample up to n_per_dataset channels from each dataset, embed all samples with a single
    UMAP fit, and split the result back into one embedding tensor per dataset.
    """
    all_X = []
    per_ds_y = []  # labels per dataset (empty tensor when a dataset has none), kept aligned with sizes
    sizes = []
    for ch, y, _ in datasets:
        N = ch.shape[0]
        n = min(int(n_per_dataset), int(N))
        idx = torch.randperm(N)[:n]
        Xi = _prep_for_umap(ch[idx], channel_representation, angle_delay_bins)  # [n, d]
        all_X.append(Xi)
        per_ds_y.append(y[idx].to(torch.long) if y is not None else torch.empty((0,), dtype=torch.long))
        sizes.append(n)

    X = torch.cat(all_X, dim=0)  # [sum(sizes), d]

    # Supervised UMAP needs one label per row, so only pass labels when every dataset provided them.
    have_all_labels = all(yi.numel() == n for yi, n in zip(per_ds_y, sizes))
    y_all = torch.cat(per_ds_y, dim=0) if (label_aware and have_all_labels) else None

    reducer = umap.UMAP(**umap_kwargs)
    if umap_mode == "supervised" and y_all is not None and y_all.numel() > 0:
        U = reducer.fit_transform(X.cpu().numpy(), y=y_all.cpu().numpy())
    else:
        U = reducer.fit_transform(X.cpu().numpy())

    start = 0
    embs = []
    labels_per_ds = [] if label_aware else None
    for i, n in enumerate(sizes):
        Ui = torch.from_numpy(U[start:start + n]).to(torch.float32)  # [n, d_umap]
        embs.append(Ui)
        if label_aware:
            labels_per_ds.append(per_ds_y[i])
        start += n

    # Stacking assumes every dataset contributed the same number of samples.
    embs = torch.stack(embs, dim=0)  # [D, n, d_umap]
    return embs, labels_per_ds
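

# Minimal usage sketch (illustrative only: the shapes, labels, and umap_kwargs below are
# assumptions, not values from the original code). umap_kwargs is forwarded unchanged to
# umap.UMAP; one dataset is deliberately left unlabeled to show the empty-label output.
if __name__ == "__main__":
    torch.manual_seed(0)
    ds_a = (torch.randn(200, 4, 64, dtype=torch.complex64), torch.randint(0, 3, (200,)), "dataset_a")
    ds_b = (torch.randn(150, 4, 64, dtype=torch.complex64), None, "dataset_b")
    embs, labels = build_umap_embeddings(
        datasets=[ds_a, ds_b],
        n_per_dataset=100,
        label_aware=True,
        umap_mode="unsupervised",
        umap_kwargs={"n_neighbors": 15, "min_dist": 0.1, "n_components": 2},
    )
    print(embs.shape)                 # torch.Size([2, 100, 2])
    print([t.shape for t in labels])  # [torch.Size([100]), torch.Size([0])]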