Spaces:
Sleeping
Sleeping
Create embed_raw.py
Browse files- embed_raw.py +34 -0
embed_raw.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Tuple, Optional
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
def _flatten_complex_to_real(v: torch.Tensor) -> torch.Tensor:
    """Turn a complex tensor into a real one; pass real tensors through.

    For a complex input the real and imaginary parts are concatenated
    along the last dimension (all real parts first, then all imaginary
    parts), so that dimension doubles in size. A non-complex input is
    returned as the very same object.
    """
    # Guard clause: nothing to do for tensors that are already real-valued.
    if not torch.is_complex(v):
        return v
    parts = (v.real, v.imag)
    return torch.cat(parts, dim=-1)
|
| 8 |
+
|
| 9 |
+
@torch.no_grad()
def build_raw_embeddings(
    datasets: List[Tuple[torch.Tensor, Optional[torch.Tensor], str]],
    n_per_dataset: int,
    label_aware: bool,
) -> Tuple[torch.Tensor, Optional[List[torch.Tensor]]]:
    """Sample rows from each dataset and stack them as flat float32 vectors.

    For every ``(ch, y, name)`` triple, up to ``n_per_dataset`` rows of
    ``ch`` are drawn without replacement (via ``torch.randperm``, so the
    selection depends on the global torch RNG state). Each selected sample
    is flattened to one vector, complex values are expanded by
    ``_flatten_complex_to_real``, and the result is cast to ``float32``.

    Args:
        datasets: Triples of ``(ch, y, name)``. ``ch`` is indexed along its
            first dimension (one sample per row). ``y`` holds one label per
            row of ``ch`` or is ``None``. The third element is not used.
        n_per_dataset: Maximum number of samples taken from each dataset;
            a dataset with fewer rows contributes all of them.
        label_aware: When true, a list of label tensors is returned too.

    Returns:
        ``(embs, labels)``. ``embs`` is the per-dataset matrices stacked
        along a new leading dimension, i.e. ``[D, n, d]``. ``labels`` is
        ``None`` when ``label_aware`` is false; otherwise it has one
        ``torch.long`` tensor per dataset holding the labels of the sampled
        rows, or an empty tensor when ``y`` is ``None`` or has fewer entries
        than ``ch`` has rows.

    Raises:
        ValueError: If ``n_per_dataset`` is negative.
        RuntimeError: Raised by ``torch.stack`` when ``datasets`` is empty
            or the per-dataset matrices do not all share one shape (for
            example when one dataset has fewer than ``n_per_dataset`` rows
            and another does not).
    """
    n_per_dataset = int(n_per_dataset)
    if n_per_dataset < 0:
        raise ValueError(
            f"n_per_dataset must be non-negative, got {n_per_dataset}"
        )

    embs = []
    labels = [] if label_aware else None
    for ch, y, _ in datasets:
        # ch: [N, ...]
        N = int(ch.shape[0])
        n = min(n_per_dataset, N)
        idx = torch.randperm(N)[:n]

        # Compute the per-sample feature count explicitly: reshape(n, -1)
        # cannot infer the second dimension when n == 0.
        d = 1
        for size in ch.shape[1:]:
            d *= int(size)

        # Flatten each sample, expand complex values, and normalise dtype.
        Xi = ch[idx].reshape(n, d)
        Xi = _flatten_complex_to_real(Xi)
        Xi = Xi.to(torch.float32)
        embs.append(Xi)  # [n, d]

        if label_aware:
            # idx holds positions in [0, N), so y must cover all N rows;
            # checking only against n would let y[idx] go out of range.
            if y is not None and len(y) >= N:
                labels.append(y[idx].to(torch.long))
            else:
                labels.append(torch.empty((0,), dtype=torch.long))

    embs = torch.stack(embs, dim=0)  # [D, n, d]
    return embs, labels
|