wi-lab committed on
Commit
7589a7e
·
1 Parent(s): c5f5103

Update embed_lwm.py

Browse files
Files changed (1) hide show
  1. embed_lwm.py +151 -59
embed_lwm.py CHANGED
@@ -1,48 +1,114 @@
1
  import os
2
  import sys
3
- from typing import List, Tuple, Optional
4
 
5
  import torch
6
- from huggingface_hub import snapshot_download
7
 
 
 
8
 
9
- _LWM_MODEL = None
10
- _LWM_DIR = None
 
 
 
 
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  def get_lwm_encoder():
14
  """
15
- Try to download & load wi-lab/lwm-v1.1 and create the encoder.
16
- Returns a torch.nn.Module or None on failure.
17
  """
18
- global _LWM_MODEL, _LWM_DIR
19
- if _LWM_MODEL is not None:
20
- return _LWM_MODEL
 
 
 
 
 
 
21
  try:
22
- _LWM_DIR = snapshot_download(repo_id="wi-lab/lwm-v1.1", local_dir="./LWM-v1.1", local_dir_use_symlinks=False)
23
- if _LWM_DIR not in sys.path:
24
- sys.path.append(_LWM_DIR)
25
- from pretrained_model import lwm # type: ignore
26
- model = lwm()
27
- # Try common ckpt filenames
28
- cand = None
29
- for fn in ["model_checkpoint.pth", "checkpoint.pth", "lwm_v1.1.pth"]:
30
- p = os.path.join(_LWM_DIR, fn)
31
- if os.path.exists(p):
32
- cand = p
33
- break
34
- if cand:
35
- state = torch.load(cand, map_location="cpu")
36
- # handle optional "module." prefix
37
- if any(k.startswith("module.") for k in state.keys()):
38
- model.load_state_dict(state)
39
  else:
40
- model.load_state_dict({f"module.{k}": v for k, v in state.items()}, strict=False)
 
41
  model.eval()
42
- _LWM_MODEL = model
43
- return _LWM_MODEL
44
  except Exception as e:
45
- print(f"[WARN] Failed to load LWM encoder: {e}", flush=True)
46
  return None
47
 
48
 
@@ -52,36 +118,62 @@ def build_lwm_embeddings(
52
  datasets: List[Tuple[torch.Tensor, Optional[torch.Tensor], str]],
53
  n_per_dataset: int,
54
  label_aware: bool
55
- ):
56
  """
57
- Minimal: flatten inputs and pass through model if it accepts tensors directly.
58
- If the repo expects a tokenizer or different forward, adapt here.
 
 
 
59
  """
60
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
61
- model = model.to(device)
62
-
63
- embs = []
64
  labels_per_ds = [] if label_aware else None
65
 
66
- for ch, y, _ in datasets:
67
- N = ch.shape[0]
68
- n = min(int(n_per_dataset), int(N))
69
- idx = torch.randperm(N)[:n]
70
- Xi = ch[idx]
71
- Xi = Xi.reshape(n, -1)
72
- if torch.is_complex(Xi):
73
- Xi = torch.cat([Xi.real, Xi.imag], dim=-1)
74
- Xi = Xi.to(torch.float32).to(device)
75
- # naive forward: assume model(X)->[n, d]
76
- try:
77
- Zi = model(Xi)
78
- except Exception:
79
- # fallback: identity if forward signature differs
80
- Zi = Xi
81
- Zi = Zi.detach().to("cpu")
82
- embs.append(Zi)
83
- if label_aware:
84
- labels_per_ds.append(y[idx].to(torch.long) if (y is not None and len(y) >= n) else torch.empty((0,), dtype=torch.long))
85
-
86
- embs = torch.stack(embs, dim=0) # [D, n, d]
87
- return embs, labels_per_ds
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  import sys
3
+ from typing import List, Optional, Tuple
4
 
5
  import torch
 
6
 
7
def _log(msg: str):
    """Write *msg* to standard output as one line and flush right away."""
    # flush=True pushes the line out immediately instead of waiting for the
    # stream buffer to fill.
    print(msg, flush=True)
9
 
10
def _candidate_repo_dirs():
    """Return the directories to probe for the LWM checkout, in priority order.

    The first entry is the whitespace-stripped value of the ``LWM_REPO_DIR``
    environment variable (an empty string when it is unset); the remaining
    entries are two fixed locations.
    """
    env_override = os.environ.get("LWM_REPO_DIR", "").strip()
    fixed_locations = ["./LWM-v1.1", "/home/user/app/LWM-v1.1"]
    return [env_override, *fixed_locations]
16
 
17
def _ensure_repo_on_path() -> Optional[str]:
    """Locate the LWM repo directory and make it importable.

    Walks the candidates from ``_candidate_repo_dirs()`` in order and picks the
    first one that is an existing directory. That directory is converted to an
    absolute path, inserted at the front of ``sys.path`` (unless already
    present) and returned.

    Returns:
        The absolute path of the repo directory, or ``None`` when no candidate
        exists on disk.
    """
    for candidate in _candidate_repo_dirs():
        # Empty strings come from an unset LWM_REPO_DIR; skip them.
        if not candidate or not os.path.isdir(candidate):
            continue
        # Resolve now: a relative sys.path entry is interpreted against the
        # working directory at import time, so it would stop working if the
        # process later changes directory.
        repo_dir = os.path.abspath(candidate)
        if repo_dir not in sys.path:
            sys.path.insert(0, repo_dir)
        return repo_dir
    return None
24
+
25
def _ensure_pretrained_model_shim(repo_dir: str) -> None:
    """
    Make `from pretrained_model import lwm` importable from `repo_dir`.

    When `repo_dir` has no `pretrained_model.py` but does contain
    `lwm_model.py`, a small `pretrained_model.py` is written there. It exposes
    `lwm()`, which returns `LWM()` imported from `lwm_model`.
    Nothing is written if the shim file already exists or if `lwm_model.py`
    is absent. A failed write is logged rather than raised.
    """
    target = os.path.join(repo_dir, "pretrained_model.py")
    backing_module = os.path.join(repo_dir, "lwm_model.py")

    # An existing file is left untouched; without lwm_model.py there is
    # nothing for the shim to wrap.
    if os.path.isfile(target) or not os.path.isfile(backing_module):
        return

    # Text of the generated module, one entry per line; the trailing empty
    # entry yields the final newline.
    shim_lines = (
        "# Auto-generated shim to satisfy `from pretrained_model import lwm`",
        "import torch",
        "try:",
        "    from lwm_model import LWM",
        "except Exception as e:",
        '    raise ImportError(f"Shim could not import LWM from lwm_model.py: {e}")',
        "",
        "def lwm():",
        "    # Build a default LWM encoder (adjust constructor args if your repo requires them)",
        "    return LWM()",
        "",
    )
    shim_code = "\n".join(shim_lines)

    try:
        with open(target, "w", encoding="utf-8") as handle:
            handle.write(shim_code)
        _log(f"[INFO] Created shim: {target}")
    except Exception as err:
        # Best effort only: e.g. a read-only checkout must not abort the caller.
        _log(f"[WARN] Could not create pretrained_model shim: {err}")
56
+
57
def _maybe_load_weights(model, repo_dir: str):
    """Best-effort load of pretrained weights into *model* from *repo_dir*.

    Probes ``<repo_dir>/models/model.pth`` and then ``<repo_dir>/model.pth``.
    The first file that loads successfully wins and the function returns.
    A checkpoint wrapped as ``{"state_dict": ...}`` or ``{"model": ...}`` is
    unwrapped first. If the checkpoint keys carry a ``module.`` prefix that the
    model's own keys do not, the prefix is stripped before loading; otherwise
    the non-strict load would match no keys and leave the model untouched
    without raising.

    Args:
        model: Object exposing ``state_dict()`` and ``load_state_dict()``.
        repo_dir: Directory that may contain the weight files.

    Returns:
        None. Outcomes are reported through ``_log``; failures never raise.
    """
    prefix = "module."
    candidates = (
        os.path.join(repo_dir, "models", "model.pth"),
        os.path.join(repo_dir, "model.pth"),
    )
    found_any = False

    for path in candidates:
        if not os.path.isfile(path):
            continue
        found_any = True
        try:
            # NOTE(security): torch.load unpickles the file, which can execute
            # arbitrary code. Only point this at trusted checkpoints; consider
            # weights_only=True where the installed torch version supports it.
            sd = torch.load(path, map_location="cpu")

            # Unwrap common containers, preferring "state_dict" over "model".
            if isinstance(sd, dict):
                for wrapper_key in ("state_dict", "model"):
                    if wrapper_key in sd:
                        sd = sd[wrapper_key]
                        break

            # Reconcile the "module." key prefix with what the model expects.
            if isinstance(sd, dict):
                ckpt_prefixed = any(
                    isinstance(k, str) and k.startswith(prefix) for k in sd
                )
                model_prefixed = any(
                    k.startswith(prefix) for k in model.state_dict()
                )
                if ckpt_prefixed and not model_prefixed:
                    sd = {
                        (
                            k[len(prefix):]
                            if isinstance(k, str) and k.startswith(prefix)
                            else k
                        ): v
                        for k, v in sd.items()
                    }

            result = model.load_state_dict(sd, strict=False)

            # strict=False hides key mismatches, so surface them explicitly.
            missing = getattr(result, "missing_keys", None) or []
            unexpected = getattr(result, "unexpected_keys", None) or []
            if missing or unexpected:
                _log(
                    f"[WARN] Partial weight load from {path}: "
                    f"{len(missing)} missing, {len(unexpected)} unexpected keys"
                )
            _log(f"[INFO] Loaded LWM weights from {path}")
            return
        except Exception as e:
            # Deliberately broad: a bad checkpoint must not abort the caller;
            # move on to the next candidate.
            _log(f"[WARN] Failed to load weights from {path}: {e}")

    if found_any:
        _log("[WARN] No weights file could be loaded; using randomly-initialized LWM.")
    else:
        _log("[WARN] No weights file found; using randomly-initialized LWM.")
78
 
79
def get_lwm_encoder():
    """
    Try to build an LWM encoder using the cloned repo.
    Returns a torch.nn.Module or None.

    Steps: locate the repo and put it on ``sys.path``, make sure a
    ``pretrained_model`` module exists (creating a shim if possible), build the
    model, load weights on a best-effort basis, and switch to eval mode.
    The builder is looked up in this order: ``pretrained_model.lwm``,
    ``lwm_model.LWM``, ``lwm_model.build_model``. Any failure is logged and
    results in ``None`` instead of an exception.
    """
    repo_dir = _ensure_repo_on_path()
    if not repo_dir:
        _log("[WARN] LWM repo not found; set LWM_REPO_DIR or clone to ./LWM-v1.1")
        return None

    # If the repo's modules expect `pretrained_model`, make sure it exists
    _ensure_pretrained_model_shim(repo_dir)

    try:
        builder = None

        # Preferred entry point. An ImportError here (module missing because
        # the shim could not be written, or the shim itself failing) must not
        # end the attempt; the lwm_model fallback below can still succeed.
        try:
            import pretrained_model  # type: ignore
            if hasattr(pretrained_model, "lwm"):
                builder = pretrained_model.lwm
        except ImportError as e:
            _log(f"[WARN] pretrained_model unavailable ({e}); trying lwm_model directly")

        if builder is None:
            # Fallback: try lwm_model directly
            import lwm_model  # type: ignore
            for attr in ("LWM", "build_model"):
                if hasattr(lwm_model, attr):
                    builder = getattr(lwm_model, attr)
                    break
            if builder is None:
                raise ImportError("No LWM builder found in lwm_model or pretrained_model")

        model = builder()
        _maybe_load_weights(model, repo_dir)
        model.eval()
        return model
    except Exception as e:
        # Top-level boundary: callers treat None as "encoder unavailable".
        _log(f"[WARN] Failed to load LWM encoder: {e}")
        return None
113
 
114
 
 
118
  datasets: List[Tuple[torch.Tensor, Optional[torch.Tensor], str]],
119
  n_per_dataset: int,
120
  label_aware: bool
121
+ ) -> Tuple[torch.Tensor, Optional[List[torch.Tensor]]]:
122
  """
123
+ Generic embedding builder:
124
+ - Flattens each complex channel (concat real/imag),
125
+ - Forwards through the model if it accepts a flat vector,
126
+ - Pads to a common embedding dim.
127
+ If forward fails, falls back to the raw flattened vector.
128
  """
129
+ all_feats = []
 
 
 
130
  labels_per_ds = [] if label_aware else None
131
 
132
+ try:
133
+ device = next(model.parameters()).device
134
+ except StopIteration:
135
+ device = torch.device("cpu")
136
+ model = model.to(device).eval()
137
+
138
+ for chs, y, _name in datasets:
139
+ n = min(int(n_per_dataset), int(chs.shape[0]))
140
+ idx = torch.randperm(chs.shape[0])[:n]
141
+ sub = chs[idx]
142
+ feats_this = []
143
+
144
+ for x in sub:
145
+ if x.ndim > 2:
146
+ x = x.squeeze(0)
147
+ vec = x.reshape(-1)
148
+ if torch.is_complex(vec):
149
+ vec = torch.cat([vec.real, vec.imag], dim=0)
150
+ vec = vec.to(torch.float32).unsqueeze(0).to(device) # [1, d]
151
+
152
+ try:
153
+ out = model(vec) # adapt here if your model expects another shape
154
+ out = out.reshape(1, -1).detach().cpu()
155
+ except Exception:
156
+ # If the model forward signature mismatches, use the raw vector
157
+ out = vec.detach().cpu()
158
+
159
+ feats_this.append(out)
160
+
161
+ embs_this = torch.cat(feats_this, dim=0) # [n, d’]
162
+ all_feats.append(embs_this)
163
+
164
+ if label_aware and y is not None and y.numel() > 0:
165
+ labels_per_ds.append(y[idx].clone())
166
+
167
+ # Pad to common dim
168
+ max_d = max(t.shape[1] for t in all_feats)
169
+ padded = []
170
+ for t in all_feats:
171
+ if t.shape[1] < max_d:
172
+ pad = torch.zeros((t.shape[0], max_d - t.shape[1]), dtype=t.dtype)
173
+ t = torch.cat([t, pad], dim=1)
174
+ padded.append(t)
175
+
176
+ embs = torch.stack(padded, dim=0) # [D, n, d]
177
+ if label_aware:
178
+ return embs, labels_per_ds if labels_per_ds is not None else []
179
+ return embs, None