# embed_lwm.py
import os
import sys
from typing import List, Tuple, Optional

import torch
from huggingface_hub import snapshot_download

# Module-level cache for the loaded encoder and its snapshot directory.
_LWM_MODEL = None
_LWM_DIR = None


def _add_repo_to_path(path: str):
    """Make the downloaded HF snapshot importable as plain Python modules."""
    if path and os.path.isdir(path) and path not in sys.path:
        sys.path.insert(0, path)


def _load_state_dict_flex(model: torch.nn.Module, state):
    """
    Load a variety of saved formats into `model`:
      - plain state_dict
      - {"model": state_dict}
      - with or without "module." prefixes
    """
    def _try(sd, strict=False):
        try:
            model.load_state_dict(sd, strict=strict)
            return True
        except Exception:
            return False

    # Direct state dict?
    if isinstance(state, dict) and all(isinstance(k, str) for k in state.keys()) and any(
        torch.is_tensor(v) for v in state.values()
    ):
        sd = state
    elif isinstance(state, dict) and "model" in state and isinstance(state["model"], dict):
        sd = state["model"]
    else:
        raise ValueError("Unrecognized checkpoint format.")

    # Try as-is.
    if _try(sd, strict=False):
        return

    # Try adding a "module." prefix (model wrapped in DataParallel, checkpoint not).
    if not any(k.startswith("module.") for k in sd.keys()):
        sd_mod = {f"module.{k}": v for k, v in sd.items()}
        if _try(sd_mod, strict=False):
            return

    # Try stripping a "module." prefix (checkpoint saved from a DataParallel model).
    sd_strip = {k.replace("module.", "", 1): v for k, v in sd.items()}
    if _try(sd_strip, strict=False):
        return

    # Last resort: strict=False on the original again, letting any error surface.
    model.load_state_dict(sd, strict=False)


def get_lwm_encoder():
    """
    Download & load wi-lab/lwm-v1.1 and create the encoder from lwm_model.py.
    Returns a torch.nn.Module or None on failure.
    """
    global _LWM_MODEL, _LWM_DIR
    if _LWM_MODEL is not None:
        return _LWM_MODEL
    try:
        _LWM_DIR = snapshot_download(
            repo_id="wi-lab/lwm-v1.1",
            local_dir="./LWM-v1.1",
            local_dir_use_symlinks=False,
        )
        _add_repo_to_path(_LWM_DIR)

        # Import the builder from the HF repo (it's named lwm_model.py).
        from lwm_model import lwm  # type: ignore

        model = lwm()

        # Load the checkpoint from models/model.pth if it is present.
        ckpt_path = os.path.join(_LWM_DIR, "models", "model.pth")
        if os.path.isfile(ckpt_path):
            state = torch.load(ckpt_path, map_location="cpu")
            _load_state_dict_flex(model, state)

        model.eval()
        _LWM_MODEL = model
        return _LWM_MODEL
    except Exception as e:
        print(f"[WARN] Failed to load LWM encoder: {e}", flush=True)
        return None


@torch.no_grad()
def build_lwm_embeddings(
    model: torch.nn.Module,
    datasets: List[Tuple[torch.Tensor, Optional[torch.Tensor], str]],
    n_per_dataset: int,
    label_aware: bool,
):
    """
    Build embeddings with the LWM encoder.

    Strategy:
      1) Try the repo's tokenizer if available (utils.tokenizer) and feed tokens to the model.
      2) Else try feeding flattened real vectors to the model.
      3) If the forward pass fails, fall back to using flattened vectors as embeddings.

    Returns:
        embs: [D, n, d]
        labels_per_ds: Optional[List[Tensor]]
    """
    # Try the optional tokenizer shipped with the HF repo.
    try:
        from utils import tokenizer as lwm_tokenizer  # type: ignore
    except Exception:
        lwm_tokenizer = None

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device).eval()

    all_embs = []
    labels_per_ds = [] if label_aware else None

    for ch, y, _name in datasets:
        N = int(ch.shape[0])
        n = min(int(n_per_dataset), N)
        idx = torch.randperm(N)[:n]
        Xi = ch[idx]

        feats = []
        for x in Xi:
            x2 = x
            if x2.ndim > 2:
                x2 = x2.squeeze(0)

            # 1) Tokenizer path.
            if lwm_tokenizer is not None:
                try:
                    tok = lwm_tokenizer(x2)
                    tok = tok.to(device)
                    out = model(tok)
                    out = torch.as_tensor(out).reshape(1, -1).detach().cpu()
                    feats.append(out)
                    continue
                except Exception:
                    pass

            # 2) Flattened forward path.
            try:
                vec = x2.reshape(-1)
                if torch.is_complex(vec):
                    vec = torch.cat([vec.real, vec.imag], dim=0)
                vec = vec.to(torch.float32).unsqueeze(0).to(device)
                out = model(vec)
                out = torch.as_tensor(out).reshape(1, -1).detach().cpu()
                feats.append(out)
                continue
            except Exception:
                pass

            # 3) Fallback: use the flattened vector itself as the embedding.
            vec = x2.reshape(-1)
            if torch.is_complex(vec):
                vec = torch.cat([vec.real, vec.imag], dim=0)
            vec = vec.to(torch.float32).unsqueeze(0).cpu()
            feats.append(vec)

        # The three per-sample paths can produce different widths within one
        # dataset, so zero-pad to the widest before concatenating.
        d_max = max(f.shape[1] for f in feats)
        feats = [
            torch.cat([f, torch.zeros((1, d_max - f.shape[1]), dtype=f.dtype)], dim=1)
            if f.shape[1] < d_max else f
            for f in feats
        ]
        Zi = torch.cat(feats, dim=0)  # [n, d]
        all_embs.append(Zi)

        if label_aware:
            if y is not None and len(y) >= n:
                labels_per_ds.append(y[idx].clone())
            else:
                labels_per_ds.append(torch.empty((0,), dtype=torch.long))

    # Zero-pad all datasets to a common embedding dimension.
    max_d = max(t.shape[1] for t in all_embs)
    padded = []
    for t in all_embs:
        if t.shape[1] < max_d:
            pad = torch.zeros((t.shape[0], max_d - t.shape[1]), dtype=t.dtype)
            t = torch.cat([t, pad], dim=1)
        padded.append(t)

    # Note: stacking assumes every dataset contributed the same n,
    # i.e. each dataset has at least n_per_dataset samples.
    embs = torch.stack(padded, dim=0)  # [D, n, d]
    return embs, labels_per_ds
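

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original pipeline): a minimal
# smoke test showing how the two entry points compose. It assumes datasets are
# provided as (channels, labels, name) tuples of torch tensors; the random
# complex channels below are placeholders and do not reflect real LWM input
# shapes, so the forward paths will typically fall back to flattened vectors.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    encoder = get_lwm_encoder()
    if encoder is None:
        print("[INFO] LWM encoder unavailable; skipping embedding demo.")
    else:
        # Two toy "datasets" of complex channels with integer labels.
        demo_datasets = [
            (torch.randn(8, 32, 32, dtype=torch.cfloat), torch.zeros(8, dtype=torch.long), "ds_a"),
            (torch.randn(8, 32, 32, dtype=torch.cfloat), torch.ones(8, dtype=torch.long), "ds_b"),
        ]
        embs, labels = build_lwm_embeddings(
            encoder, demo_datasets, n_per_dataset=4, label_aware=True
        )
        print("embeddings:", tuple(embs.shape))  # expected (D=2, n=4, d)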