# embed_lwm.py
import os
import sys
from typing import List, Tuple, Optional

import torch
from huggingface_hub import snapshot_download

_LWM_MODEL = None
_LWM_DIR = None


def _add_repo_to_path(path: str):
    if path and os.path.isdir(path) and path not in sys.path:
        sys.path.insert(0, path)


def _load_state_dict_flex(model: torch.nn.Module, state):
    """
    Load a variety of saved checkpoint formats into `model`:
      - a plain state_dict
      - {"model": state_dict}
      - with or without "module." prefixes
    """
    def _try(sd, strict=False):
        try:
            model.load_state_dict(sd, strict=strict)
            return True
        except Exception:
            return False

    # Direct state dict?
    if isinstance(state, dict) and all(isinstance(k, str) for k in state.keys()) and any(
        torch.is_tensor(v) for v in state.values()
    ):
        sd = state
    elif isinstance(state, dict) and "model" in state and isinstance(state["model"], dict):
        sd = state["model"]
    else:
        raise ValueError("Unrecognized checkpoint format.")

    # Try as-is.
    if _try(sd, strict=False):
        return

    # Try adding a "module." prefix.
    if not any(k.startswith("module.") for k in sd.keys()):
        sd_mod = {f"module.{k}": v for k, v in sd.items()}
        if _try(sd_mod, strict=False):
            return

    # Try stripping a leading "module." prefix.
    sd_strip = {(k[len("module."):] if k.startswith("module.") else k): v for k, v in sd.items()}
    if _try(sd_strip, strict=False):
        return

    # Last resort: strict=False on the original again.
    model.load_state_dict(sd, strict=False)
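
# Note on _load_state_dict_flex (illustrative example, not code from the LWM repo):
# a checkpoint saved as
#   torch.save({"model": torch.nn.DataParallel(net).state_dict()}, "model.pth")
# stores keys like "module.layer.weight" inside a {"model": ...} container;
# the helper above unwraps the container and adds or strips the "module."
# prefix as needed, so the same file also loads into a bare, unwrapped module.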


def get_lwm_encoder():
    """
    Download & load wi-lab/lwm-v1.1 and create the encoder from lwm_model.py.
    Returns a torch.nn.Module, or None on failure.
    """
    global _LWM_MODEL, _LWM_DIR
    if _LWM_MODEL is not None:
        return _LWM_MODEL
    try:
        _LWM_DIR = snapshot_download(
            repo_id="wi-lab/lwm-v1.1",
            local_dir="./LWM-v1.1",
            local_dir_use_symlinks=False,
        )
        _add_repo_to_path(_LWM_DIR)

        # Import the model builder from the HF repo (it's named lwm_model.py).
        from lwm_model import lwm  # type: ignore
        model = lwm()

        # Load the checkpoint from models/model.pth.
        ckpt_path = os.path.join(_LWM_DIR, "models", "model.pth")
        if os.path.isfile(ckpt_path):
            state = torch.load(ckpt_path, map_location="cpu")
            _load_state_dict_flex(model, state)

        model.eval()
        _LWM_MODEL = model
        return _LWM_MODEL
    except Exception as e:
        print(f"[WARN] Failed to load LWM encoder: {e}", flush=True)
        return None


@torch.no_grad()
def build_lwm_embeddings(
    model: torch.nn.Module,
    datasets: List[Tuple[torch.Tensor, Optional[torch.Tensor], str]],
    n_per_dataset: int,
    label_aware: bool,
):
    """
    Build embeddings with the LWM encoder.

    Strategy:
      1) Try the repo's tokenizer if available (utils.tokenizer) and feed the tokens to the model.
      2) Otherwise, feed flattened real-valued vectors to the model.
      3) If the forward pass fails, fall back to using the flattened vectors themselves as embeddings.

    Returns:
      embs: [D, n, d]
      labels_per_ds: Optional[List[Tensor]]
    """
    # Optional tokenizer shipped with the HF repo.
    try:
        from utils import tokenizer as lwm_tokenizer  # type: ignore
    except Exception:
        lwm_tokenizer = None

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device).eval()

    all_embs = []
    labels_per_ds = [] if label_aware else None
    for ch, y, _name in datasets:
        N = int(ch.shape[0])
        n = min(int(n_per_dataset), N)
        idx = torch.randperm(N)[:n]
        Xi = ch[idx]

        feats = []
        for x in Xi:
            x2 = x
            if x2.ndim > 2:
                x2 = x2.squeeze(0)

            # 1) Tokenizer path.
            if lwm_tokenizer is not None:
                try:
                    tok = lwm_tokenizer(x2)
                    tok = tok.to(device)
                    out = model(tok)
                    out = torch.as_tensor(out).reshape(1, -1).detach().cpu()
                    feats.append(out)
                    continue
                except Exception:
                    pass

            # 2) Flattened forward path.
            try:
                vec = x2.reshape(-1)
                if torch.is_complex(vec):
                    vec = torch.cat([vec.real, vec.imag], dim=0)
                vec = vec.to(torch.float32).unsqueeze(0).to(device)
                out = model(vec)
                out = torch.as_tensor(out).reshape(1, -1).detach().cpu()
                feats.append(out)
                continue
            except Exception:
                pass

            # 3) Fallback: use the flattened vector directly as the embedding.
            vec = x2.reshape(-1)
            if torch.is_complex(vec):
                vec = torch.cat([vec.real, vec.imag], dim=0)
            vec = vec.to(torch.float32).unsqueeze(0).cpu()
            feats.append(vec)
        Zi = torch.cat(feats, dim=0)  # [n, d]
        all_embs.append(Zi)

        if label_aware:
            # idx indexes the full dataset, so labels must cover all N samples.
            if y is not None and len(y) >= N:
                labels_per_ds.append(y[idx].clone())
            else:
                labels_per_ds.append(torch.empty((0,), dtype=torch.long))
    # Pad embeddings to a common feature dimension across datasets.
    # (Stacking assumes every dataset yields the same n, i.e. each has at least n_per_dataset samples.)
    max_d = max(t.shape[1] for t in all_embs)
    padded = []
    for t in all_embs:
        if t.shape[1] < max_d:
            pad = torch.zeros((t.shape[0], max_d - t.shape[1]), dtype=t.dtype)
            t = torch.cat([t, pad], dim=1)
        padded.append(t)

    embs = torch.stack(padded, dim=0)  # [D, n, d]
    return embs, labels_per_ds
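

if __name__ == "__main__":
    # Illustrative usage sketch only (not part of the original module). It builds
    # embeddings for two synthetic complex "channel" datasets; the [N, ant, subc]
    # shape and the sizes below are assumptions for demonstration, not a format
    # required by the LWM repo.
    torch.manual_seed(0)
    fake_datasets = [
        (torch.randn(32, 32, 64, dtype=torch.complex64), torch.randint(0, 4, (32,)), "scenario_a"),
        (torch.randn(32, 32, 64, dtype=torch.complex64), None, "scenario_b"),
    ]

    encoder = get_lwm_encoder()
    if encoder is None:
        # Offline / import failure: use an identity module so build_lwm_embeddings
        # still runs and exercises its flattened-vector path.
        encoder = torch.nn.Identity()

    embs, labels = build_lwm_embeddings(
        encoder, fake_datasets, n_per_dataset=16, label_aware=True
    )
    print("embeddings:", tuple(embs.shape))  # [D, n, d], here (2, 16, d)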