# dataset-distancing-lab / embed_raw.py
# Author: wi-lab — "Create embed_raw.py" (commit d065796)
from typing import List, Tuple, Optional
import torch
def _flatten_complex_to_real(v: torch.Tensor) -> torch.Tensor:
if torch.is_complex(v):
return torch.cat([v.real, v.imag], dim=-1)
return v
@torch.no_grad()
def build_raw_embeddings(
datasets: List[Tuple[torch.Tensor, Optional[torch.Tensor], str]],
n_per_dataset: int,
label_aware: bool
):
embs = []
labels = [] if label_aware else None
for ch, y, _ in datasets:
# ch: [N, ...]
N = ch.shape[0]
n = min(int(n_per_dataset), int(N))
idx = torch.randperm(N)[:n]
Xi = ch[idx]
# flatten each sample
Xi = Xi.reshape(n, -1)
Xi = _flatten_complex_to_real(Xi)
Xi = Xi.to(torch.float32)
embs.append(Xi) # [n, d]
if label_aware:
if y is not None and len(y) >= n:
labels.append(y[idx].to(torch.long))
else:
labels.append(torch.empty((0,), dtype=torch.long))
embs = torch.stack(embs, dim=0) # [D, n, d]
return embs, labels