from typing import List, Optional, Tuple

import torch


def _flatten_complex_to_real(v: torch.Tensor) -> torch.Tensor:
    """Return *v* as a real-valued tensor.

    Complex input has its real and imaginary parts concatenated along the
    last dimension (so that dimension doubles in size). Non-complex input
    is returned unchanged.
    """
    if torch.is_complex(v):
        return torch.cat([v.real, v.imag], dim=-1)
    return v


@torch.no_grad()
def build_raw_embeddings(
    datasets: List[Tuple[torch.Tensor, Optional[torch.Tensor], str]],
    n_per_dataset: int,
    label_aware: bool,
) -> Tuple[torch.Tensor, Optional[List[torch.Tensor]]]:
    """Sample rows from each dataset and stack them as flat float32 vectors.

    For every ``(ch, y, name)`` entry, up to ``n_per_dataset`` rows of ``ch``
    are drawn without replacement (via ``torch.randperm``), each row is
    flattened to 1-D, complex values are expanded into real/imag halves, and
    the result is cast to ``float32``.

    Args:
        datasets: Sequence of ``(ch, y, name)`` tuples. ``ch`` is indexed
            along its first dimension; ``y`` is an optional per-row label
            tensor; the third element is not used here.
        n_per_dataset: Maximum number of rows to sample from each dataset.
        label_aware: When true, also collect the labels of the sampled rows.

    Returns:
        ``(embs, labels)`` where ``embs`` is the stack of the per-dataset
        ``[n, d]`` matrices along a new leading dimension, and ``labels`` is
        ``None`` when ``label_aware`` is false, otherwise a list with one
        ``long`` tensor per dataset (empty when that dataset has no usable
        labels).

    Raises:
        RuntimeError: Propagated from ``torch.stack`` when the per-dataset
            matrices do not all share the same ``[n, d]`` shape (or when
            ``datasets`` is empty).
    """
    embs = []
    labels = [] if label_aware else None

    for ch, y, _ in datasets:
        num_rows = ch.shape[0]
        n = min(int(n_per_dataset), int(num_rows))
        idx = torch.randperm(num_rows)[:n]

        # Flatten each sampled row, then make it real-valued float32.
        xi = ch[idx].reshape(n, -1)
        xi = _flatten_complex_to_real(xi).to(torch.float32)
        embs.append(xi)  # [n, d]

        if label_aware:
            # idx holds values in [0, num_rows), so y must cover the whole
            # dataset; checking only against n would allow out-of-range
            # indexing whenever n <= len(y) < num_rows.
            if y is not None and len(y) >= num_rows:
                labels.append(y[idx].to(torch.long))
            else:
                labels.append(torch.empty((0,), dtype=torch.long))

    embs = torch.stack(embs, dim=0)  # [D, n, d]
    return embs, labels