# embed_lwm.py
import os
import sys
from typing import List, Tuple, Optional

import torch
from huggingface_hub import snapshot_download

# Module-level cache for the loaded encoder and the local snapshot directory.
_LWM_MODEL = None
_LWM_DIR = None


def _add_repo_to_path(path: str):
    if path and os.path.isdir(path) and path not in sys.path:
        sys.path.insert(0, path)


def _load_state_dict_flex(model: torch.nn.Module, state):
    """
    Load a variety of saved formats into `model`:
      - plain state_dict
      - {"model": state_dict}
      - with or without "module." prefixes
    """
    def _try(sd, strict=False):
        try:
            model.load_state_dict(sd, strict=strict)
            return True
        except Exception:
            return False

    # Direct state dict?
    if isinstance(state, dict) and all(isinstance(k, str) for k in state.keys()) and any(
        torch.is_tensor(v) for v in state.values()
    ):
        sd = state
    elif isinstance(state, dict) and "model" in state and isinstance(state["model"], dict):
        sd = state["model"]
    else:
        raise ValueError("Unrecognized checkpoint format.")

    # Try as-is.
    if _try(sd, strict=False):
        return

    # Try adding the "module." prefix.
    if not any(k.startswith("module.") for k in sd.keys()):
        sd_mod = {f"module.{k}": v for k, v in sd.items()}
        if _try(sd_mod, strict=False):
            return

    # Try stripping the "module." prefix.
    sd_strip = {k.replace("module.", "", 1): v for k, v in sd.items()}
    if _try(sd_strip, strict=False):
        return

    # Last resort: strict=False on the original again.
    model.load_state_dict(sd, strict=False)
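

# Illustrative sketch (defined here but never called): the checkpoint layouts that
# _load_state_dict_flex is intended to accept. The tiny Linear model and its shape
# are assumptions chosen only to keep the example self-contained.
def _checkpoint_format_examples():
    ref = torch.nn.Linear(4, 2)
    plain = ref.state_dict()                                            # plain state_dict
    wrapped = {"model": ref.state_dict()}                               # {"model": state_dict}
    prefixed = {f"module.{k}": v for k, v in ref.state_dict().items()}  # DataParallel-style keys
    for state in (plain, wrapped, prefixed):
        _load_state_dict_flex(torch.nn.Linear(4, 2), state)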


def get_lwm_encoder():
    """
    Download & load wi-lab/lwm-v1.1 and create the encoder from lwm_model.py.
    Returns a torch.nn.Module or None on failure.
    """
    global _LWM_MODEL, _LWM_DIR
    if _LWM_MODEL is not None:
        return _LWM_MODEL
    try:
        _LWM_DIR = snapshot_download(
            repo_id="wi-lab/lwm-v1.1",
            local_dir="./LWM-v1.1",
            local_dir_use_symlinks=False,
        )
        _add_repo_to_path(_LWM_DIR)

        # Import the model builder from the downloaded HF repo (it's named lwm_model.py).
        from lwm_model import lwm  # type: ignore
        model = lwm()

        # Load the checkpoint from models/model.pth if it is present.
        ckpt_path = os.path.join(_LWM_DIR, "models", "model.pth")
        if os.path.isfile(ckpt_path):
            state = torch.load(ckpt_path, map_location="cpu")
            _load_state_dict_flex(model, state)
        else:
            print(f"[WARN] Checkpoint not found at {ckpt_path}; using randomly initialized weights.", flush=True)

        model.eval()
        _LWM_MODEL = model
        return _LWM_MODEL
    except Exception as e:
        print(f"[WARN] Failed to load LWM encoder: {e}", flush=True)
        return None
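

# Minimal usage sketch for get_lwm_encoder(): it needs network access to the Hugging
# Face Hub on first use and returns None when the download, import, or checkpoint
# load fails, so callers should keep a fallback path. The helper below is only an
# illustration and is not used elsewhere in this module.
def _example_encoder_usage():
    encoder = get_lwm_encoder()
    if encoder is None:
        print("LWM encoder unavailable; downstream code should fall back to raw features.")
    return encoder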


@torch.no_grad()
def build_lwm_embeddings(
    model: torch.nn.Module,
    datasets: List[Tuple[torch.Tensor, Optional[torch.Tensor], str]],
    n_per_dataset: int,
    label_aware: bool,
):
    """
    Build embeddings with the LWM encoder.

    Strategy:
      1) Try the repo's tokenizer if available (utils.tokenizer) and feed tokens to the model.
      2) Else try feeding flattened real vectors to the model.
      3) If the forward pass fails, fall back to using flattened vectors as embeddings.

    Returns:
      embs: [D, n, d]
      labels_per_ds: Optional[List[Tensor]]
    """
    # Optional tokenizer shipped with the LWM repo.
    try:
        from utils import tokenizer as lwm_tokenizer  # type: ignore
    except Exception:
        lwm_tokenizer = None

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device).eval()

    all_embs = []
    labels_per_ds = [] if label_aware else None
    for ch, y, _name in datasets:
        N = int(ch.shape[0])
        n = min(int(n_per_dataset), N)
        idx = torch.randperm(N)[:n]
        Xi = ch[idx]

        feats = []
        for x in Xi:
            x2 = x
            if x2.ndim > 2:
                x2 = x2.squeeze(0)

            # 1) Tokenizer path.
            if lwm_tokenizer is not None:
                try:
                    tok = lwm_tokenizer(x2)
                    tok = tok.to(device)
                    out = model(tok)
                    out = torch.as_tensor(out).reshape(1, -1).detach().cpu()
                    feats.append(out)
                    continue
                except Exception:
                    pass

            # 2) Flattened forward path.
            try:
                vec = x2.reshape(-1)
                if torch.is_complex(vec):
                    vec = torch.cat([vec.real, vec.imag], dim=0)
                vec = vec.to(torch.float32).unsqueeze(0).to(device)
                out = model(vec)
                out = torch.as_tensor(out).reshape(1, -1).detach().cpu()
                feats.append(out)
                continue
            except Exception:
                pass

            # 3) Fallback: use the flattened vector directly as the embedding.
            vec = x2.reshape(-1)
            if torch.is_complex(vec):
                vec = torch.cat([vec.real, vec.imag], dim=0)
            vec = vec.to(torch.float32).unsqueeze(0).cpu()
            feats.append(vec)

        Zi = torch.cat(feats, dim=0)  # [n, d]
        all_embs.append(Zi)

        if label_aware:
            # Labels must cover every sample so that y[idx] is always valid.
            if y is not None and len(y) >= N:
                labels_per_ds.append(y[idx].clone())
            else:
                labels_per_ds.append(torch.empty((0,), dtype=torch.long))
    # Pad every dataset's embeddings to a common dimensionality with zeros.
    max_d = max(t.shape[1] for t in all_embs)
    padded = []
    for t in all_embs:
        if t.shape[1] < max_d:
            pad = torch.zeros((t.shape[0], max_d - t.shape[1]), dtype=t.dtype)
            t = torch.cat([t, pad], dim=1)
        padded.append(t)

    embs = torch.stack(padded, dim=0)  # [D, n, d]
    return embs, labels_per_ds
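

# Hedged end-to-end sketch: embed two synthetic complex channel datasets. The tensor
# shapes, dataset names, and label values below are made up for illustration; nothing
# here is specific to the real LWM input format, which is why build_lwm_embeddings
# keeps its tokenizer / flattened / fallback paths.
if __name__ == "__main__":
    torch.manual_seed(0)
    # (channels [N, antennas, subcarriers], labels [N] or None, name)
    ds_a = (torch.randn(32, 32, 32, dtype=torch.complex64), torch.randint(0, 4, (32,)), "toy_a")
    ds_b = (torch.randn(32, 32, 32, dtype=torch.complex64), None, "toy_b")

    encoder = get_lwm_encoder()
    if encoder is None:
        print("Encoder unavailable; nothing to embed.")
    else:
        embs, labels = build_lwm_embeddings(
            encoder,
            datasets=[ds_a, ds_b],
            n_per_dataset=16,
            label_aware=True,
        )
        print("Embedding tensor shape [D, n, d]:", tuple(embs.shape))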