wi-lab commited on
Commit
d065796
·
1 Parent(s): cea1e88

Create embed_raw.py

Browse files
Files changed (1) hide show
  1. embed_raw.py +34 -0
embed_raw.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Tuple, Optional
2
+ import torch
3
+
4
+ def _flatten_complex_to_real(v: torch.Tensor) -> torch.Tensor:
5
+ if torch.is_complex(v):
6
+ return torch.cat([v.real, v.imag], dim=-1)
7
+ return v
8
+
9
+ @torch.no_grad()
10
+ def build_raw_embeddings(
11
+ datasets: List[Tuple[torch.Tensor, Optional[torch.Tensor], str]],
12
+ n_per_dataset: int,
13
+ label_aware: bool
14
+ ):
15
+ embs = []
16
+ labels = [] if label_aware else None
17
+ for ch, y, _ in datasets:
18
+ # ch: [N, ...]
19
+ N = ch.shape[0]
20
+ n = min(int(n_per_dataset), int(N))
21
+ idx = torch.randperm(N)[:n]
22
+ Xi = ch[idx]
23
+ # flatten each sample
24
+ Xi = Xi.reshape(n, -1)
25
+ Xi = _flatten_complex_to_real(Xi)
26
+ Xi = Xi.to(torch.float32)
27
+ embs.append(Xi) # [n, d]
28
+ if label_aware:
29
+ if y is not None and len(y) >= n:
30
+ labels.append(y[idx].to(torch.long))
31
+ else:
32
+ labels.append(torch.empty((0,), dtype=torch.long))
33
+ embs = torch.stack(embs, dim=0) # [D, n, d]
34
+ return embs, labels