from typing import List, Tuple, Optional

import torch


def _flatten_complex_to_real(v: torch.Tensor) -> torch.Tensor:
    """Concatenate real and imaginary parts so complex inputs become real-valued."""
    if torch.is_complex(v):
        return torch.cat([v.real, v.imag], dim=-1)
    return v


def build_raw_embeddings(
    datasets: List[Tuple[torch.Tensor, Optional[torch.Tensor], str]],
    n_per_dataset: int,
    label_aware: bool,
):
    """Subsample each dataset, flatten samples to vectors, and stack the results.

    Assumes every dataset contributes the same number of samples and the same
    flattened dimension, since the outputs are stacked into one [D, n, d] tensor.
    """
    embs = []
    labels = [] if label_aware else None
    for ch, y, _ in datasets:
        # ch: [N, ...]
        N = ch.shape[0]
        n = min(int(n_per_dataset), int(N))
        idx = torch.randperm(N)[:n]
        Xi = ch[idx]
        # Flatten each sample to a vector, split complex values into real/imag
        # halves, and cast to float32.
        Xi = Xi.reshape(n, -1)
        Xi = _flatten_complex_to_real(Xi)
        Xi = Xi.to(torch.float32)
        embs.append(Xi)  # [n, d]
        if label_aware:
            # idx ranges over [0, N), so labels are usable only if y covers all N samples.
            if y is not None and len(y) >= N:
                labels.append(y[idx].to(torch.long))
            else:
                labels.append(torch.empty((0,), dtype=torch.long))
    embs = torch.stack(embs, dim=0)  # [D, n, d]
    return embs, labels
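

if __name__ == "__main__":
    # Minimal usage sketch (not from the source): the shapes, dataset names, and
    # parameter values below are illustrative assumptions. Both toy datasets
    # flatten to the same per-sample dimension (64), which the final torch.stack
    # in build_raw_embeddings requires.
    ds_a = (torch.randn(100, 4, 16), torch.randint(0, 3, (100,)), "ds_a")       # real, labeled
    ds_b = (torch.randn(100, 4, 8, dtype=torch.complex64), None, "ds_b")        # complex, unlabeled

    embs, labels = build_raw_embeddings([ds_a, ds_b], n_per_dataset=32, label_aware=True)
    print(embs.shape)                   # torch.Size([2, 32, 64])
    print([t.shape for t in labels])    # [torch.Size([32]), torch.Size([0])]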