# embed_lwm.py
import os
import sys
from typing import List, Tuple, Optional
import torch
from huggingface_hub import snapshot_download
_LWM_MODEL = None
_LWM_DIR = None
def _add_repo_to_path(path: str):
if path and os.path.isdir(path) and path not in sys.path:
sys.path.insert(0, path)
def _load_state_dict_flex(model: torch.nn.Module, state):
"""
Load a variety of saved formats into `model`:
- plain state_dict
- {"model": state_dict}
- with or without "module." prefixes
"""
def _try(sd, strict=False):
try:
model.load_state_dict(sd, strict=strict)
return True
except Exception:
return False
# direct state dict?
if isinstance(state, dict) and all(isinstance(k, str) for k in state.keys()) and any(
torch.is_tensor(v) for v in state.values()
):
sd = state
elif isinstance(state, dict) and "model" in state and isinstance(state["model"], dict):
sd = state["model"]
else:
raise ValueError("Unrecognized checkpoint format.")
# Try as-is
if _try(sd, strict=False):
return
# Try to add "module." prefix
if not any(k.startswith("module.") for k in sd.keys()):
sd_mod = {f"module.{k}": v for k, v in sd.items()}
if _try(sd_mod, strict=False):
return
    # Try to strip a leading "module." prefix (only where it really is a prefix,
    # so keys such as "submodule.weight" are left untouched)
    sd_strip = {(k[len("module."):] if k.startswith("module.") else k): v for k, v in sd.items()}
    if _try(sd_strip, strict=False):
        return
    # Last resort: load the original dict outside the try/except so the real error surfaces
model.load_state_dict(sd, strict=False)
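# Illustrative note (not executed): the checkpoint layouts the loader above is meant to
# accept. These save calls are hypothetical examples, not how the wi-lab checkpoint was
# necessarily produced.
#   torch.save(net.state_dict(), "model.pth")                          # plain state_dict
#   torch.save({"model": net.state_dict()}, "model.pth")               # wrapped as {"model": ...}
#   torch.save(torch.nn.DataParallel(net).state_dict(), "model.pth")   # keys prefixed with "module."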
def get_lwm_encoder():
"""
Download & load wi-lab/lwm-v1.1 and create the encoder from lwm_model.py.
Returns a torch.nn.Module or None on failure.
"""
global _LWM_MODEL, _LWM_DIR
if _LWM_MODEL is not None:
return _LWM_MODEL
try:
_LWM_DIR = snapshot_download(
repo_id="wi-lab/lwm-v1.1",
local_dir="./LWM-v1.1",
local_dir_use_symlinks=False,
)
_add_repo_to_path(_LWM_DIR)
# Import builder from the HF repo (it's named lwm_model.py)
from lwm_model import lwm # type: ignore
model = lwm()
        # Load checkpoint from models/model.pth
        ckpt_path = os.path.join(_LWM_DIR, "models", "model.pth")
        if os.path.isfile(ckpt_path):
            state = torch.load(ckpt_path, map_location="cpu")
            _load_state_dict_flex(model, state)
        else:
            print(f"[WARN] Checkpoint not found at {ckpt_path}; using randomly initialized weights.", flush=True)
model.eval()
_LWM_MODEL = model
return _LWM_MODEL
except Exception as e:
print(f"[WARN] Failed to load LWM encoder: {e}", flush=True)
return None
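# Usage sketch (assumes network access and that the HF repo keeps its current layout:
# lwm_model.py exporting `lwm`, checkpoint under models/model.pth):
#   encoder = get_lwm_encoder()
#   if encoder is None:
#       ...  # caller should fall back to raw flattened vectors
# Thanks to the module-level cache, repeated calls reuse the loaded model instead of
# re-downloading the snapshot.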
@torch.no_grad()
def build_lwm_embeddings(
model: torch.nn.Module,
datasets: List[Tuple[torch.Tensor, Optional[torch.Tensor], str]],
n_per_dataset: int,
label_aware: bool
):
"""
Build embeddings with the LWM encoder.
Strategy:
1) Try repo's tokenizer if available (utils.tokenizer), feed to model.
2) Else try feeding flattened real vectors to the model.
3) If forward fails, fall back to using flattened vectors as embeddings.
Returns:
embs: [D, n, d]
labels_per_ds: Optional[List[Tensor]]
"""
# Try optional tokenizer
try:
from utils import tokenizer as lwm_tokenizer # type: ignore
except Exception:
lwm_tokenizer = None
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device).eval()
all_embs = []
labels_per_ds = [] if label_aware else None
for ch, y, _name in datasets:
N = int(ch.shape[0])
n = min(int(n_per_dataset), N)
idx = torch.randperm(N)[:n]
Xi = ch[idx]
feats = []
for x in Xi:
x2 = x
if x2.ndim > 2:
x2 = x2.squeeze(0)
# 1) tokenizer path
if lwm_tokenizer is not None:
try:
tok = lwm_tokenizer(x2)
tok = tok.to(device)
out = model(tok)
out = torch.as_tensor(out).reshape(1, -1).detach().cpu()
feats.append(out)
continue
except Exception:
pass
# 2) flattened forward path
try:
vec = x2.reshape(-1)
if torch.is_complex(vec):
vec = torch.cat([vec.real, vec.imag], dim=0)
vec = vec.to(torch.float32).unsqueeze(0).to(device)
out = model(vec)
out = torch.as_tensor(out).reshape(1, -1).detach().cpu()
feats.append(out)
continue
except Exception:
pass
# 3) fallback: use flattened vector directly
vec = x2.reshape(-1)
if torch.is_complex(vec):
vec = torch.cat([vec.real, vec.imag], dim=0)
vec = vec.to(torch.float32).unsqueeze(0).cpu()
feats.append(vec)
        Zi = torch.cat(feats, dim=0)  # [n, d]; assumes all rows from this dataset share one feature width
all_embs.append(Zi)
        if label_aware:
            # idx holds indices in [0, N), so y must cover the whole dataset
            if y is not None and len(y) >= N:
                labels_per_ds.append(y[idx].clone())
            else:
                labels_per_ds.append(torch.empty((0,), dtype=torch.long))
    # Zero-pad all embeddings to a common feature dim. (torch.stack below assumes every
    # dataset yielded the same n, i.e. each dataset has at least n_per_dataset samples.)
    max_d = max(t.shape[1] for t in all_embs)
padded = []
for t in all_embs:
if t.shape[1] < max_d:
pad = torch.zeros((t.shape[0], max_d - t.shape[1]), dtype=t.dtype)
t = torch.cat([t, pad], dim=1)
padded.append(t)
embs = torch.stack(padded, dim=0) # [D, n, d]
return embs, labels_per_ds
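# Minimal end-to-end sketch. The synthetic complex "channels" below are stand-ins for
# real datasets (shapes, names, and label counts are made up for the demo). If the
# encoder cannot be downloaded, torch.nn.Identity is substituted so the fallback
# flattened-vector path still exercises the full pipeline.
if __name__ == "__main__":
    torch.manual_seed(0)
    demo_sets = [
        (torch.randn(20, 32, 32, dtype=torch.complex64), torch.randint(0, 4, (20,)), "scenario_a"),
        (torch.randn(20, 32, 32, dtype=torch.complex64), None, "scenario_b"),
    ]
    encoder = get_lwm_encoder()
    if encoder is None:
        encoder = torch.nn.Identity()  # placeholder so build_lwm_embeddings still runs
    embs, labels = build_lwm_embeddings(encoder, demo_sets, n_per_dataset=8, label_aware=True)
    print("embeddings:", tuple(embs.shape))
    if labels is not None:
        print("labels per dataset:", [tuple(t.shape) for t in labels])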