from typing import List, Optional, Tuple

import torch


def _flatten_complex_to_real(v: torch.Tensor) -> torch.Tensor:
    """Turn a complex tensor into a real one by concatenating the real and
    imaginary parts along the last dimension; pass real tensors through."""
    if torch.is_complex(v):
        return torch.cat([v.real, v.imag], dim=-1)
    return v


@torch.no_grad()
def build_raw_embeddings(
    datasets: List[Tuple[torch.Tensor, Optional[torch.Tensor], str]],
    n_per_dataset: int,
    label_aware: bool,
):
    """Sample up to n_per_dataset items from each dataset and flatten them
    into real-valued float32 embeddings of shape [n, d]."""
    embs = []
    labels = [] if label_aware else None
    for ch, y, _ in datasets:
        # ch: [N, ...]
        N = ch.shape[0]
        n = min(int(n_per_dataset), int(N))
        idx = torch.randperm(N)[:n]
        Xi = ch[idx]
        # Flatten each sample to a 1-D feature vector.
        Xi = Xi.reshape(n, -1)
        Xi = _flatten_complex_to_real(Xi)
        Xi = Xi.to(torch.float32)
        embs.append(Xi)  # [n, d]
        if label_aware:
            # idx can contain any index in [0, N), so labels are only safe
            # to gather when y covers the full dataset.
            if y is not None and len(y) >= N:
                labels.append(y[idx].to(torch.long))
            else:
                labels.append(torch.empty((0,), dtype=torch.long))
    # Stacking requires every dataset to yield the same n and d.
    embs = torch.stack(embs, dim=0)  # [D, n, d]
    return embs, labels
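
# A minimal usage sketch (assumption: the shapes, names, and values below are
# illustrative, not taken from the original Space). It builds three synthetic
# datasets whose flattened feature dimension is the same (32), so torch.stack
# succeeds; ds_c is complex, showing how _flatten_complex_to_real doubles the
# per-sample feature count from 16 to 32.
if __name__ == "__main__":
    torch.manual_seed(0)
    ds_a = (torch.randn(100, 4, 8), torch.randint(0, 3, (100,)), "ds_a")
    ds_b = (torch.randn(50, 4, 8), None, "ds_b")
    ds_c = (torch.randn(60, 16, dtype=torch.complex64), None, "ds_c")

    embs, labels = build_raw_embeddings(
        [ds_a, ds_b, ds_c], n_per_dataset=50, label_aware=True
    )
    print(embs.shape)                 # torch.Size([3, 50, 32])
    print([t.shape for t in labels])  # ds_a has labels; ds_b and ds_c are empty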