from typing import List, Tuple, Optional
import torch

def _flatten_complex_to_real(v: torch.Tensor) -> torch.Tensor:
    """Concatenate real and imaginary parts along the last dim; pass real tensors through."""
    if torch.is_complex(v):
        return torch.cat([v.real, v.imag], dim=-1)
    return v

@torch.no_grad()
def build_raw_embeddings(
    datasets: List[Tuple[torch.Tensor, Optional[torch.Tensor], str]],
    n_per_dataset: int,
    label_aware: bool,
) -> Tuple[torch.Tensor, Optional[List[torch.Tensor]]]:
    """Sample up to n_per_dataset items from each dataset and flatten them into raw embeddings.

    Each entry of `datasets` is a (channels, labels, name) tuple; labels may be None.
    Returns a [D, n, d] tensor of stacked embeddings and, if label_aware, a list of
    per-dataset label tensors (empty where labels are missing or misaligned).
    """
    embs = []
    labels = [] if label_aware else None
    for ch, y, _ in datasets:
        # ch: [N, ...]
        N = ch.shape[0]
        n = min(int(n_per_dataset), int(N))
        idx = torch.randperm(N)[:n]
        Xi = ch[idx]
        # flatten each sample
        Xi = Xi.reshape(n, -1)
        Xi = _flatten_complex_to_real(Xi)
        Xi = Xi.to(torch.float32)
        embs.append(Xi)  # [n, d]
        if label_aware:
            # idx indexes into all N samples, so labels must cover the full dataset
            if y is not None and len(y) >= N:
                labels.append(y[idx].to(torch.long))
            else:
                labels.append(torch.empty((0,), dtype=torch.long))
    # stacking assumes every dataset yields the same sample count n and flattened dim d
    embs = torch.stack(embs, dim=0)  # [D, n, d]
    return embs, labels
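

# Minimal usage sketch (hypothetical toy data; the shapes, seed, and dataset names
# below are assumptions for illustration, not part of the original module). Both
# datasets share the same per-sample shape so the final torch.stack succeeds; the
# complex dtype exercises _flatten_complex_to_real.
if __name__ == "__main__":
    torch.manual_seed(0)
    ds_a = (torch.randn(100, 4, 16, dtype=torch.complex64),
            torch.randint(0, 3, (100,)), "ds_a")          # labeled dataset
    ds_b = (torch.randn(80, 4, 16, dtype=torch.complex64),
            None, "ds_b")                                  # unlabeled dataset
    embs, labels = build_raw_embeddings([ds_a, ds_b], n_per_dataset=32, label_aware=True)
    print(embs.shape)                 # torch.Size([2, 32, 128]): real+imag doubles d
    print([t.shape for t in labels])  # [torch.Size([32]), torch.Size([0])]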