# embed_lwm.py
import os
import sys
from typing import List, Tuple, Optional

import torch
from huggingface_hub import snapshot_download

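# Module-level cache: the loaded LWM encoder and its local snapshot directory,
# populated lazily by get_lwm_encoder().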
_LWM_MODEL = None
_LWM_DIR = None


def _add_repo_to_path(path: str):
    if path and os.path.isdir(path) and path not in sys.path:
        sys.path.insert(0, path)


def _load_state_dict_flex(model: torch.nn.Module, state):
    """
    Load a variety of saved formats into `model`:
      - plain state_dict
      - {"model": state_dict}
      - with or without "module." prefixes
    """
    def _try(sd, strict=False):
        try:
            model.load_state_dict(sd, strict=strict)
            return True
        except Exception:
            return False

    # direct state dict?
    if isinstance(state, dict) and all(isinstance(k, str) for k in state.keys()) and any(
        torch.is_tensor(v) for v in state.values()
    ):
        sd = state
    elif isinstance(state, dict) and "model" in state and isinstance(state["model"], dict):
        sd = state["model"]
    else:
        raise ValueError("Unrecognized checkpoint format.")

    # Try as-is
    if _try(sd, strict=False):
        return

    # Try to add "module." prefix
    if not any(k.startswith("module.") for k in sd.keys()):
        sd_mod = {f"module.{k}": v for k, v in sd.items()}
        if _try(sd_mod, strict=False):
            return

    # Try to strip "module." prefix
    sd_strip = {
        (k[len("module."):] if k.startswith("module.") else k): v for k, v in sd.items()
    }
    if _try(sd_strip, strict=False):
        return

    # Last resort: retry the original non-strictly and let any remaining error surface
    model.load_state_dict(sd, strict=False)


def get_lwm_encoder():
    """
    Download & load wi-lab/lwm-v1.1 and create the encoder from lwm_model.py.
    Returns a torch.nn.Module or None on failure.
    """
    global _LWM_MODEL, _LWM_DIR
    if _LWM_MODEL is not None:
        return _LWM_MODEL
    try:
        _LWM_DIR = snapshot_download(
            repo_id="wi-lab/lwm-v1.1",
            local_dir="./LWM-v1.1",
            local_dir_use_symlinks=False,
        )
        _add_repo_to_path(_LWM_DIR)

        # Import builder from the HF repo (it's named lwm_model.py)
        from lwm_model import lwm  # type: ignore
        model = lwm()

        # Load checkpoint from models/model.pth
        ckpt_path = os.path.join(_LWM_DIR, "models", "model.pth")
        if os.path.isfile(ckpt_path):
            state = torch.load(ckpt_path, map_location="cpu")
            _load_state_dict_flex(model, state)

        model.eval()
        _LWM_MODEL = model
        return _LWM_MODEL
    except Exception as e:
        print(f"[WARN] Failed to load LWM encoder: {e}", flush=True)
        return None


@torch.no_grad()
def build_lwm_embeddings(
    model: torch.nn.Module,
    datasets: List[Tuple[torch.Tensor, Optional[torch.Tensor], str]],
    n_per_dataset: int,
    label_aware: bool
):
    """
    Build embeddings with the LWM encoder.
    Strategy:
      1) Try repo's tokenizer if available (utils.tokenizer), feed to model.
      2) Else try feeding flattened real vectors to the model.
      3) If forward fails, fall back to using flattened vectors as embeddings.

    Returns:
      embs: [D, n, d] tensor, where D is the number of datasets
      labels_per_ds: Optional[List[Tensor]], one label tensor per dataset (None if label_aware is False)
    """
    # Try optional tokenizer
    try:
        from utils import tokenizer as lwm_tokenizer  # type: ignore
    except Exception:
        lwm_tokenizer = None

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device).eval()

    all_embs = []
    labels_per_ds = [] if label_aware else None

    for ch, y, _name in datasets:
        N = int(ch.shape[0])
        n = min(int(n_per_dataset), N)
        idx = torch.randperm(N)[:n]
        Xi = ch[idx]

        feats = []
        for x in Xi:
            x2 = x
            if x2.ndim > 2:
                x2 = x2.squeeze(0)

            # 1) tokenizer path
            if lwm_tokenizer is not None:
                try:
                    tok = lwm_tokenizer(x2)
                    tok = tok.to(device)
                    out = model(tok)
                    out = torch.as_tensor(out).reshape(1, -1).detach().cpu()
                    feats.append(out)
                    continue
                except Exception:
                    pass

            # 2) flattened forward path
            try:
                vec = x2.reshape(-1)
                if torch.is_complex(vec):
                    vec = torch.cat([vec.real, vec.imag], dim=0)
                vec = vec.to(torch.float32).unsqueeze(0).to(device)
                out = model(vec)
                out = torch.as_tensor(out).reshape(1, -1).detach().cpu()
                feats.append(out)
                continue
            except Exception:
                pass

            # 3) fallback: use flattened vector directly
            vec = x2.reshape(-1)
            if torch.is_complex(vec):
                vec = torch.cat([vec.real, vec.imag], dim=0)
            vec = vec.to(torch.float32).unsqueeze(0).cpu()
            feats.append(vec)

        Zi = torch.cat(feats, dim=0)  # [n, d]
        all_embs.append(Zi)

        if label_aware:
            if y is not None and len(y) >= N:
                labels_per_ds.append(y[idx].clone())
            else:
                labels_per_ds.append(torch.empty((0,), dtype=torch.long))

    # Pad to common dim
    max_d = max(t.shape[1] for t in all_embs)
    padded = []
    for t in all_embs:
        if t.shape[1] < max_d:
            pad = torch.zeros((t.shape[0], max_d - t.shape[1]), dtype=t.dtype)
            t = torch.cat([t, pad], dim=1)
        padded.append(t)

    embs = torch.stack(padded, dim=0)  # [D, n, d]
    return embs, labels_per_ds
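

# --- Usage sketch (illustrative, not part of the original pipeline) ---
# A minimal example of wiring the two entry points together. The dataset
# shapes, names, and label counts below are hypothetical placeholders;
# real channels would come from the caller's own loaders.
if __name__ == "__main__":
    encoder = get_lwm_encoder()
    if encoder is None:
        sys.exit("LWM encoder unavailable; see warning above.")

    # Two synthetic datasets: complex channel tensors plus optional labels.
    ch_a = torch.randn(64, 32, 32, dtype=torch.cfloat)
    y_a = torch.randint(0, 4, (64,))
    ch_b = torch.randn(64, 32, 32, dtype=torch.cfloat)
    datasets = [(ch_a, y_a, "scenario_a"), (ch_b, None, "scenario_b")]

    embs, labels = build_lwm_embeddings(
        encoder, datasets, n_per_dataset=16, label_aware=True
    )
    print("embeddings:", tuple(embs.shape))  # expected: (2, 16, d)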