File size: 2,065 Bytes
5064e79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from typing import List, Tuple, Optional, Dict
import torch
import umap


def _prep_for_umap(chs: torch.Tensor, representation: str, angle_delay_bins: int) -> torch.Tensor:
    """
    Turn a batch of channels into 2D features for UMAP; concatenate real+imag if complex.
    """
    X = chs
    if X.ndim > 2:
        X = X.reshape(X.shape[0], -1)
    if torch.is_complex(X):
        X = torch.cat([X.real, X.imag], dim=-1)
    X = X.to(torch.float32)
    return X


@torch.no_grad()
def build_umap_embeddings(
    datasets: List[Tuple[torch.Tensor, Optional[torch.Tensor], str]],
    n_per_dataset: int,
    label_aware: bool,
    umap_mode: str,
    umap_kwargs: Dict,
    channel_representation: str = "raw",
    angle_delay_bins: int = 16,
):
    """
    Fit one UMAP on a random subsample of every dataset and split the result per dataset.

    Args:
        datasets: list of ``(channels, labels_or_None, name)`` tuples. ``channels`` is
            sampled along dim 0, and ``labels`` (when given) is indexed with the same
            sample indices. The name is not used here.
        n_per_dataset: requested samples per dataset. The count actually used is
            ``min(n_per_dataset, smallest dataset size)``. Every dataset therefore
            contributes the same number of points, so the per-dataset embeddings can
            be stacked into one tensor.
        label_aware: if True, collect sampled labels, return them per dataset, and
            allow a supervised fit.
        umap_mode: ``"supervised"`` passes labels to ``fit_transform`` (only when
            ``label_aware`` and at least one dataset has labels). Any other value
            fits without labels.
        umap_kwargs: keyword arguments forwarded to ``umap.UMAP``.
        channel_representation: forwarded to ``_prep_for_umap``.
        angle_delay_bins: forwarded to ``_prep_for_umap``.

    Returns:
        ``(embs, labels_per_ds)``:
        - ``embs``: float32 CPU tensor of shape ``[D, n, d_umap]``.
        - ``labels_per_ds``: ``None`` when ``label_aware`` is False. Otherwise a
          list of D long tensors; a dataset without labels gets an empty tensor.

    Raises:
        ValueError: if ``datasets`` is empty.
    """
    if not datasets:
        raise ValueError("datasets must contain at least one entry")

    # One common sample count: torch.stack/reshape below need equal-length slices,
    # and a per-dataset min(n_per_dataset, N) would break that.
    n = min(int(n_per_dataset), *(int(ch.shape[0]) for ch, _, _ in datasets))

    feats = []
    labels: List[Optional[torch.Tensor]] = []  # per dataset: sampled labels or None
    for ch, y, _ in datasets:
        idx = torch.randperm(ch.shape[0])[:n]
        # UMAP consumes NumPy arrays, which requires CPU tensors.
        feats.append(_prep_for_umap(ch[idx], channel_representation, angle_delay_bins).cpu())
        labels.append(None if y is None else y[idx].to(torch.long).cpu())

    X = torch.cat(feats, dim=0).numpy()  # [D * n, d]

    reducer = umap.UMAP(**umap_kwargs)
    has_labels = label_aware and any(lab is not None for lab in labels)
    if umap_mode == "supervised" and has_labels and n > 0:
        # Keep targets row-aligned with X: unlabelled datasets get -1, which UMAP
        # treats as "unlabelled" (semi-supervised fit).
        # NOTE(review): real labels equal to -1 would also be read as unlabelled.
        target = torch.cat(
            [
                lab if lab is not None else torch.full((n,), -1, dtype=torch.long)
                for lab in labels
            ]
        )
        U = reducer.fit_transform(X, y=target.numpy())
    else:
        U = reducer.fit_transform(X)

    # Rows are laid out dataset by dataset, n rows each, so a reshape splits them.
    embs = torch.from_numpy(U).to(torch.float32).reshape(len(datasets), n, U.shape[1])

    labels_per_ds = None
    if label_aware:
        labels_per_ds = [
            lab if lab is not None else torch.empty((0,), dtype=torch.long)
            for lab in labels
        ]
    return embs, labels_per_ds