herrscher0 committed on
Commit ebc7f2e · verified · 1 Parent(s): 3419384

Initial commit: FloodDiffusion text-to-motion generation model

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ ldf_deps/t5_umt5-xxl-enc-bf16/google/umt5-xxl/tokenizer.json filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,219 @@
1
+ ---
2
+ license: apache-2.0
3
+ tags:
4
+ - text-to-motion
5
+ - motion-generation
6
+ - diffusion-forcing
7
+ - humanml3d
8
+ - computer-animation
9
+ library_name: transformers
10
+ pipeline_tag: other
11
+ ---
12
+
13
+ # FloodDiffusion: Tailored Diffusion Forcing for Streaming Motion Generation
14
+
15
+ <div align="center">
16
+
17
+ **A state-of-the-art text-to-motion generation model based on Latent Diffusion Forcing**
18
+
19
+ [Paper]() | [Project Page]() | [Demo]()
20
+
21
+ </div>
22
+
23
+ ## Overview
24
+
25
+ We present FloodDiffusion, a new framework for text-driven, streaming human motion generation. Given time-varying text prompts, FloodDiffusion generates text-aligned, seamless motion sequences with real-time latency.
26
+
27
+ ## Model Architecture
28
+
29
+ The model consists of three main components:
30
+
31
+ 1. **Text Encoder**: UMT5-XXL encoder for text feature extraction
32
+ 2. **Latent Diffusion Model**: Transformer-based diffusion model operating in latent space
33
+ 3. **VAE Decoder**: 1D convolutional VAE for decoding latent features to motion sequences
34
+
35
+ **Technical Specifications:**
36
+ - Input: Natural language text
37
+ - Output: Motion sequences in two formats:
38
+ - 263-dimensional HumanML3D features (default)
39
+ - 22×3 joint coordinates (optional)
40
+ - Latent dimension: 4
41
+ - Upsampling factor: 4× (VAE decoder)
42
+ - Frame rate: 20 FPS (see the duration sketch below)
43
+
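+ The numbers above fix the relationship between latent tokens, output frames, and clip duration. A minimal sketch of the bookkeeping (plain Python, no model required; the constants are taken from the specification list above):
+
+ ```python
+ # Back-of-the-envelope duration check using the specs above
+ UPSAMPLE = 4   # VAE decoder upsampling factor
+ FPS = 20       # output frame rate
+
+ def frames_and_seconds(num_latent_tokens: int):
+     frames = num_latent_tokens * UPSAMPLE
+     return frames, frames / FPS
+
+ print(frames_and_seconds(60))  # (240, 12.0) -> ~12 seconds of motion
+ ```
+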
44
+ ## Installation
45
+
46
+ ### Prerequisites
47
+
48
+ - Python 3.8+
49
+ - CUDA-capable GPU with 16GB+ VRAM (recommended)
50
+ - 16GB+ system RAM
51
+
52
+ ### Dependencies
53
+
54
+ **Step 1: Install basic dependencies**
55
+
56
+ ```bash
57
+ pip install torch transformers huggingface_hub
58
+ pip install lightning diffusers omegaconf ftfy numpy
59
+ ```
60
+
61
+ **Step 2: Install Flash Attention (Required)**
62
+
63
+ Flash Attention requires CUDA and may need to compile from source:
64
+
65
+ ```bash
66
+ pip install flash-attn --no-build-isolation
67
+ ```
68
+
69
+ **Note:** Flash attention is **required** for this model. If installation fails, please refer to the [official flash-attention installation guide](https://github.com/Dao-AILab/flash-attention#installation-and-features).
70
+
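+ Before moving on, a quick import check confirms that Flash Attention actually built for your environment (this assumes the `flash-attn` package exposes the `flash_attn` module, as in the official distribution):
+
+ ```python
+ # Minimal sanity check that Flash Attention is importable
+ import flash_attn  # raises ImportError if installation/compilation failed
+ print("flash-attn version:", getattr(flash_attn, "__version__", "unknown"))
+ ```
+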
71
+ ## Quick Start
72
+
73
+ ### Basic Usage
74
+
75
+ ```python
76
+ from transformers import AutoModel
77
+
78
+ # Load model
79
+ model = AutoModel.from_pretrained(
80
+ "ShandaAI/FloodDiffusion",
81
+ trust_remote_code=True
82
+ )
83
+
84
+ # Generate motion from text (263-dim HumanML3D features)
85
+ motion = model("a person walking forward", length=60)
86
+ print(f"Generated motion: {motion.shape}") # (~240, 263)
87
+
88
+ # Generate motion as joint coordinates (22 joints × 3 coords)
89
+ motion_joints = model("a person walking forward", length=60, output_joints=True)
90
+ print(f"Generated joints: {motion_joints.shape}") # (~240, 22, 3)
91
+ ```
92
+
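+ The `(frames, 22, 3)` joint output can be inspected without any project-specific tooling. A minimal sketch, assuming `matplotlib` is installed (only the joint positions of a single frame are scattered; skeleton edges are omitted):
+
+ ```python
+ import matplotlib.pyplot as plt
+
+ frame = motion_joints[0]  # (22, 3) joint positions of the first frame
+ fig = plt.figure()
+ ax = fig.add_subplot(projection="3d")
+ ax.scatter(frame[:, 0], frame[:, 1], frame[:, 2])
+ ax.set_title("First frame of generated motion (22 joints)")
+ plt.show()
+ ```
+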
93
+ ### Batch Generation
94
+
95
+ ```python
96
+ # Generate multiple motions efficiently
97
+ texts = [
98
+ "a person walking forward",
99
+ "a person running quickly",
100
+ "a person jumping up and down"
101
+ ]
102
+ lengths = [60, 50, 40] # Different lengths for each motion
103
+
104
+ motions = model(texts, length=lengths)
105
+
106
+ for i, motion in enumerate(motions):
107
+ print(f"Motion {i}: {motion.shape}")
108
+ ```
109
+
110
+ ### Multi-Text Motion Transitions
111
+
112
+ ```python
113
+ # Generate a motion sequence with smooth transitions between actions
114
+ motion = model(
115
+ text=[["walk forward", "turn around", "run back"]],
116
+ length=[120],
117
+ text_end=[[40, 80, 120]] # Transition points in latent tokens
118
+ )
119
+
120
+ # Output: ~480 frames showing all three actions smoothly connected
121
+ print(f"Transition motion: {motion[0].shape}")
122
+ ```
123
+
124
+ ## API Reference
125
+
126
+ ### `model(text, length=60, text_end=None, num_denoise_steps=None, output_joints=False)`
127
+
128
+ Generate motion sequences from text descriptions.
129
+
130
+ **Parameters:**
131
+
132
+ - **text** (`str`, `List[str]`, or `List[List[str]]`): Text description(s)
133
+ - Single string: Generate one motion
134
+ - List of strings: Batch generation
135
+ - Nested list: Multiple text prompts per motion (for transitions)
136
+
137
+ - **length** (`int` or `List[int]`, default=60): Number of latent tokens to generate
138
+ - Output frames ≈ `length × 4` (due to VAE upsampling)
139
+ - Example: `length=60` → ~240 frames (~12 seconds at 20 FPS)
140
+
141
+ - **text_end** (`List[int]` or `List[List[int]]`, optional): Latent token positions for text transitions
142
+ - Only used when `text` is a nested list
143
+ - Specifies when to switch between different text descriptions
144
+ - **IMPORTANT**: Must have the same length as the corresponding text list
145
+ - Example: `text=[["walk", "turn", "sit"]]` requires `text_end=[[20, 40, 60]]` (3 endpoints for 3 texts)
146
+ - Must be in ascending order
147
+
148
+ - **num_denoise_steps** (`int`, optional): Number of denoising iterations
149
+ - Higher values produce better quality but slower generation
150
+ - Recommended range: 10-50 (a timing snippet follows the example below)
151
+
152
+ - **output_joints** (`bool`, default=False): Output format selector
153
+ - `False`: Returns 263-dimensional HumanML3D features
154
+ - `True`: Returns 22×3 joint coordinates for direct visualization
155
+
156
+ **Returns:**
157
+ - Single motion:
158
+ - `output_joints=False`: `numpy.ndarray` of shape `(frames, 263)`
159
+ - `output_joints=True`: `numpy.ndarray` of shape `(frames, 22, 3)`
160
+ - Batch: `List[numpy.ndarray]` with shapes as above
161
+
162
+ **Example:**
163
+ ```python
164
+ # Single generation (263-dim features)
165
+ motion = model("walk forward", length=60) # Returns (240, 263)
166
+
167
+ # Single generation (joint coordinates)
168
+ joints = model("walk forward", length=60, output_joints=True) # Returns (240, 22, 3)
169
+
170
+ # Batch generation
171
+ motions = model(["walk", "run"], length=[60, 50]) # Returns list of 2 arrays
172
+
173
+ # Multi-text transitions
174
+ motion = model(
175
+ [["walk", "turn"]],
176
+ length=[60],
177
+ text_end=[[30, 60]]
178
+ ) # Returns list with 1 array of shape (240, 263)
179
+ ```
180
+
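+ As noted for `num_denoise_steps`, more denoising iterations trade speed for quality. A rough timing sketch over the two ends of the recommended range (absolute numbers depend entirely on your GPU):
+
+ ```python
+ import time
+
+ for steps in (10, 50):  # low vs. high end of the recommended range
+     start = time.time()
+     motion = model("a person walking forward", length=60, num_denoise_steps=steps)
+     print(f"{steps} steps -> {motion.shape} in {time.time() - start:.1f}s")
+ ```
+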
181
+ ## Citation
182
+
183
+ If you use this model in your research, please cite:
184
+
185
+ ```bibtex
186
+ @article{flood2025,
187
+ title={FloodDiffusion: Tailored Diffusion Forcing for Streaming Motion Generation},
188
+ author={Yiyi Cai and Yuhan Wu and Kunhang Li and You Zhou and Bo Zheng and Haiyang Liu},
189
+ year={2025}
190
+ }
191
+ ```
192
+
193
+ ## Troubleshooting
194
+
195
+ ### Common Issues
196
+
197
+ **ImportError with trust_remote_code:**
198
+ ```python
199
+ # Solution: Add trust_remote_code=True
200
+ model = AutoModel.from_pretrained(
201
+ "ShandaAI/FloodDiffusion",
202
+ trust_remote_code=True # Required!
203
+ )
204
+ ```
205
+
206
+ **Out of Memory:**
207
+ ```python
208
+ # Solution: Generate shorter sequences
209
+ motion = model("walk", length=30) # Shorter = less memory
210
+ ```
211
+
212
+ **Slow first load:**
213
+ The first load downloads ~14GB of model files and may take 5-30 minutes depending on your connection speed. Subsequent loads reuse the local cache and are much faster.
214
+
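+ If you want to separate the download from the first inference run (for example on a machine that later loses internet access), the files can be pre-fetched into the local cache with `huggingface_hub`, which `from_pretrained` will then reuse:
+
+ ```python
+ from huggingface_hub import snapshot_download
+
+ # Pre-fetch ~14GB of model files into the local Hugging Face cache;
+ # subsequent from_pretrained() calls will load from this cache.
+ local_dir = snapshot_download(repo_id="ShandaAI/FloodDiffusion")
+ print("Model cached at:", local_dir)
+ ```
+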
215
+ **Module import errors:**
216
+ Ensure all dependencies are installed:
217
+ ```bash
218
+ pip install lightning diffusers omegaconf ftfy numpy
219
+ ```
__init__.py ADDED
@@ -0,0 +1,11 @@
1
+ """
2
+ FloodDiffusion - Text-to-Motion Generation
3
+
4
+ Usage:
5
+ from transformers import AutoModel
6
+
7
+ model = AutoModel.from_pretrained("your-username/FloodDiffusion", trust_remote_code=True)
8
+ motion = model("a person walking forward", length=60)
9
+ """
10
+
11
+ __version__ = "1.0.0"
config.json ADDED
@@ -0,0 +1,11 @@
1
+ {
2
+ "architectures": ["LDFModel"],
3
+ "model_type": "ldf_motion",
4
+ "auto_map": {
5
+ "AutoModel": "hf_pipeline.LDFModel",
6
+ "AutoConfig": "hf_pipeline.LDFConfig"
7
+ },
8
+ "torch_dtype": "float32",
9
+ "transformers_version": "4.30.0",
10
+ "license": "mit"
11
+ }
generate_ldf.py ADDED
@@ -0,0 +1,139 @@
1
+ import os
2
+ import sys
3
+
4
+ import torch
5
+ from lightning import seed_everything
6
+ from safetensors.torch import load_file as load_safetensors
7
+
8
+ from ldf_utils.initialize import compare_statedict_and_parameters, instantiate, load_config
9
+
10
+ # Set tokenizers parallelism to false to avoid warnings in multiprocessing
11
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
12
+
13
+
14
+ def load_model_from_config():
15
+ device = "cuda" if torch.cuda.is_available() else "cpu"
16
+ torch.set_float32_matmul_precision("high")
17
+ cfg = load_config()
18
+ seed_everything(cfg.seed)
19
+
20
+ # Get the directory containing the config file
21
+ # Try to find config directory from sys.argv or use current directory
22
+ if '--config' in sys.argv:
23
+ config_idx = sys.argv.index('--config') + 1
24
+ config_dir = os.path.dirname(os.path.abspath(sys.argv[config_idx]))
25
+ else:
26
+ config_dir = os.getcwd()
27
+
28
+ vae = instantiate(
29
+ target=cfg.test_vae.target,
30
+ cfg=None,
31
+ hfstyle=False,
32
+ **cfg.test_vae.params,
33
+ )
34
+
35
+ # Handle relative paths
36
+ vae_path = cfg.test_vae_ckpt
37
+ if not os.path.isabs(vae_path):
38
+ vae_path = os.path.join(config_dir, vae_path)
39
+
40
+ # Load from safetensors (already contains EMA weights)
41
+ vae_state_dict = load_safetensors(vae_path)
42
+ vae.load_state_dict(vae_state_dict, strict=True)
43
+ print(f"Loaded VAE model from {vae_path}")
44
+
45
+ compare_statedict_and_parameters(
46
+ state_dict=vae.state_dict(),
47
+ named_parameters=vae.named_parameters(),
48
+ named_buffers=vae.named_buffers(),
49
+ )
50
+ vae.to(device)
51
+ vae.eval()
52
+
53
+ # Model - fix relative paths in model params
54
+ model_params = dict(cfg.model.params)
55
+ # Convert relative paths to absolute paths
56
+ if 'checkpoint_path' in model_params and model_params['checkpoint_path']:
57
+ if not os.path.isabs(model_params['checkpoint_path']):
58
+ model_params['checkpoint_path'] = os.path.join(config_dir, model_params['checkpoint_path'])
59
+ if 'tokenizer_path' in model_params and model_params['tokenizer_path']:
60
+ if not os.path.isabs(model_params['tokenizer_path']):
61
+ model_params['tokenizer_path'] = os.path.join(config_dir, model_params['tokenizer_path'])
62
+
63
+ model = instantiate(
64
+ target=cfg.model.target, cfg=None, hfstyle=False, **model_params
65
+ )
66
+
67
+ # Handle relative paths
68
+ model_path = cfg.test_ckpt
69
+ if not os.path.isabs(model_path):
70
+ model_path = os.path.join(config_dir, model_path)
71
+
72
+ # Load from safetensors (already contains EMA weights)
73
+ model_state_dict = load_safetensors(model_path)
74
+ model.load_state_dict(model_state_dict, strict=True)
75
+ print(f"Loaded model from {model_path}")
76
+
77
+ compare_statedict_and_parameters(
78
+ state_dict=model.state_dict(),
79
+ named_parameters=model.named_parameters(),
80
+ named_buffers=model.named_buffers(),
81
+ )
82
+ model.to(device)
83
+ model.eval()
84
+
85
+ return vae, model
86
+
87
+
88
+ @torch.inference_mode()
89
+ def generate_feature_stream(
90
+ model, feature_length, text, feature_text_end=None, num_denoise_steps=None
91
+ ):
92
+ """
93
+ Streaming interface for feature generation
94
+ Args:
95
+ model: Loaded model
96
+ feature_length: List[int], generation length for each sample
97
+ text: List[str] or List[List[str]], text prompts
98
+ feature_text_end: List[List[int]], time points where text ends (if text is list of list)
99
+ num_denoise_steps: Number of denoising steps
100
+ Yields:
101
+ dict: Contains "generated" (current generated feature segment)
102
+ """
103
+
104
+ # Construct input dict x
105
+ # stream_generate needs x to contain "feature_length", "text", "feature_text_end" (if text is list of list)
106
+ x = {"feature_length": torch.tensor(feature_length), "text": text}
107
+
108
+ if feature_text_end is not None:
109
+ x["feature_text_end"] = feature_text_end
110
+
111
+ # Call model's stream_generate
112
+ # Note: stream_generate is a generator
113
+ generator = model.stream_generate(x, num_denoise_steps=num_denoise_steps)
114
+
115
+ for step_output in generator:
116
+ # step_output is already a dict with "generated" key
117
+ yield step_output
118
+
119
+
120
+ if __name__ == "__main__":
121
+ import argparse
122
+
123
+ parser = argparse.ArgumentParser()
124
+ parser.add_argument("--config", type=str, required=True, help="Path to config")
125
+ parser.add_argument(
126
+ "--text", type=str, default="a person walks forward", help="Text prompt"
127
+ )
128
+ parser.add_argument("--length", type=int, default=120, help="Motion length")
129
+ parser.add_argument(
130
+ "--output", type=str, default="output.mp4", help="Output video path"
131
+ )
132
+ parser.add_argument(
133
+ "--num_denoise_steps", type=int, default=None, help="Number of denoising steps"
134
+ )
135
+ args = parser.parse_args()
136
+
137
+ print("Loading model...")
138
+ vae, model = load_model_from_config()
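+
+ # NOTE: hedged sketch only -- the script above stops after loading the models.
+ # One plausible continuation, using the streaming interface defined earlier in
+ # this file, would drive generation from the parsed arguments like this
+ # (decoding with vae.decode(...) and rendering to args.output are left out):
+ print("Generating...")
+ for step_output in generate_feature_stream(
+     model,
+     feature_length=[args.length],
+     text=[args.text],
+     num_denoise_steps=args.num_denoise_steps,
+ ):
+     segment = step_output["generated"]  # current generated feature segment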
139
+
hf_pipeline.py ADDED
@@ -0,0 +1,278 @@
1
+ """
2
+ LDF Model for Hugging Face Hub
3
+
4
+ Usage:
5
+ from transformers import AutoModel
6
+
7
+ model = AutoModel.from_pretrained("ShandaAI/FloodDiffusion", trust_remote_code=True)
8
+ motion = model("a person walking forward", length=60)
9
+ """
10
+
11
+ import torch
12
+ from transformers import PretrainedConfig, PreTrainedModel
13
+ from typing import Union, List, Optional
14
+ import os
15
+ import sys
16
+
17
+
18
+ class LDFConfig(PretrainedConfig):
19
+ """Configuration for LDF Motion Generation Model"""
20
+ model_type = "ldf_motion"
21
+
22
+ def __init__(
23
+ self,
24
+ input_dim=4,
25
+ output_dim=263,
26
+ **kwargs
27
+ ):
28
+ super().__init__(**kwargs)
29
+ self.input_dim = input_dim
30
+ self.output_dim = output_dim
31
+
32
+
33
+ class LDFModel(PreTrainedModel):
34
+ """
35
+ LDF Motion Generation Model
36
+
37
+ This model generates motion sequences from text descriptions using Latent Diffusion Forcing.
38
+
39
+ Example:
40
+ >>> from transformers import AutoModel
41
+ >>> model = AutoModel.from_pretrained("ShandaAI/FloodDiffusion", trust_remote_code=True)
42
+ >>> motion = model("a person walking forward", length=60)
43
+ >>> print(motion.shape) # (~240, 263)
44
+ """
45
+
46
+ config_class = LDFConfig
47
+
48
+ def __init__(self, config):
49
+ super().__init__(config)
50
+ self.config = config
51
+
52
+ # Will be loaded in from_pretrained
53
+ self.ldf_model = None
54
+ self.vae = None
55
+ self.model_dir = None # Store model directory for later use
56
+
57
+ def _load_models(self):
58
+ """Load the actual LDF and VAE models"""
59
+ if self.ldf_model is not None:
60
+ return # Already loaded
61
+
62
+ # Get the model directory - should be set by from_pretrained
63
+ if hasattr(self, 'name_or_path') and os.path.exists(self.name_or_path):
64
+ model_dir = self.name_or_path
65
+ else:
66
+ raise RuntimeError(
67
+ "Model directory not found. Please use from_pretrained() to load the model."
68
+ )
69
+
70
+ # Save model_dir for later use (e.g., in output_joints conversion)
71
+ self.model_dir = model_dir
72
+
73
+ # Add model_dir to sys.path for imports
74
+ if model_dir not in sys.path:
75
+ sys.path.insert(0, model_dir)
76
+
77
+ # Use dynamic import to avoid HF's static import checker
78
+ import importlib
79
+ generate_ldf = importlib.import_module('generate_ldf')
80
+ load_model_from_config = generate_ldf.load_model_from_config
81
+
82
+ config_path = os.path.join(model_dir, "ldf.yaml")
83
+ old_argv = sys.argv
84
+ sys.argv = ['model', '--config', config_path]
85
+
86
+ try:
87
+ self.vae, self.ldf_model = load_model_from_config()
88
+
89
+ # Move to correct device
90
+ device = next(self.parameters()).device if list(self.parameters()) else torch.device('cuda' if torch.cuda.is_available() else 'cpu')
91
+ self.ldf_model = self.ldf_model.to(device)
92
+ self.vae = self.vae.to(device)
93
+ finally:
94
+ sys.argv = old_argv
95
+
96
+ @classmethod
97
+ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
98
+ """
99
+ Load pretrained model
100
+
101
+ Args:
102
+ pretrained_model_name_or_path: Model name or path
103
+ trust_remote_code: Must be True to load this custom model
104
+ **kwargs: Additional arguments
105
+
106
+ Returns:
107
+ LDFModel instance
108
+ """
109
+ # Check trust_remote_code
110
+ if not kwargs.get('trust_remote_code', False):
111
+ raise ValueError(
112
+ "Loading this model requires trust_remote_code=True. "
113
+ "Usage: AutoModel.from_pretrained(..., trust_remote_code=True)"
114
+ )
115
+
116
+ # Download if needed
117
+ if not os.path.exists(pretrained_model_name_or_path):
118
+ from huggingface_hub import snapshot_download
119
+ model_path = snapshot_download(repo_id=pretrained_model_name_or_path)
120
+ else:
121
+ model_path = pretrained_model_name_or_path
122
+
123
+ # Load config
124
+ config = LDFConfig.from_pretrained(model_path)
125
+
126
+ # Create model
127
+ model = cls(config)
128
+ model.name_or_path = model_path
129
+
130
+ # Load the actual models
131
+ model._load_models()
132
+
133
+ return model
134
+
135
+ def forward(
136
+ self,
137
+ text: Union[str, List[str], List[List[str]]],
138
+ length: Union[int, List[int]] = 60,
139
+ text_end: Optional[Union[List[int], List[List[int]]]] = None,
140
+ num_denoise_steps: Optional[int] = None,
141
+ **kwargs
142
+ ):
143
+ """
144
+ Generate motion from text
145
+
146
+ Args:
147
+ text: Text description(s)
148
+ length: Number of latent tokens (output frames ≈ length × 4)
149
+ text_end: Transition points for multi-text
150
+ num_denoise_steps: Number of denoising steps
151
+
152
+ Returns:
153
+ Generated motion sequence(s)
154
+ """
155
+ return self.__call__(text, length, text_end, num_denoise_steps)
156
+
157
+ @torch.no_grad()
158
+ def __call__(
159
+ self,
160
+ text: Union[str, List[str], List[List[str]]],
161
+ length: Union[int, List[int]] = 60,
162
+ text_end: Optional[Union[List[int], List[List[int]]]] = None,
163
+ num_denoise_steps: Optional[int] = None,
164
+ output_joints: bool = False
165
+ ):
166
+ """
167
+ Generate motion sequences
168
+
169
+ Args:
170
+ text: Text description
171
+ - Single string: "walk" -> single sample
172
+ - String list: ["walk", "run"] -> batch
173
+ - Nested list: [["walk", "turn"], ["run", "jump"]] -> multi-text per sample
174
+ length: Number of latent tokens (frames ≈ length × 4)
175
+ text_end: Token positions for text switching
176
+ num_denoise_steps: Number of denoising steps
177
+ output_joints: If True, output 22×3 joint coordinates; if False (default), output 263-dim HumanML3D features
178
+
179
+ Returns:
180
+ numpy.ndarray or list of arrays
181
+ - If output_joints=False: shape (frames, 263)
182
+ - If output_joints=True: shape (frames, 22, 3)
183
+ """
184
+ # Ensure models are loaded
185
+ self._load_models()
186
+
187
+ # Normalize inputs
188
+ is_single = not isinstance(length, list)
189
+ if is_single:
190
+ text_batch = [text]
191
+ length_batch = [length]
192
+ text_end_batch = [text_end] if text_end is not None else None
193
+ else:
194
+ text_batch = text
195
+ length_batch = length
196
+ text_end_batch = text_end
197
+
198
+ # Validate text_end alignment with text
199
+ if text_end_batch is not None:
200
+ for i, (txt, te) in enumerate(zip(text_batch, text_end_batch)):
201
+ if isinstance(txt, list) and te is not None:
202
+ if len(txt) != len(te):
203
+ raise ValueError(
204
+ f"Batch {i}: text has {len(txt)} segments but text_end has {len(te)} endpoints. "
205
+ f"They must match! text={txt}, text_end={te}"
206
+ )
207
+
208
+ batch_size = len(text_batch)
209
+
210
+ # Construct input dict for model
211
+ x = {"feature_length": torch.tensor(length_batch), "text": text_batch}
212
+ if text_end_batch is not None:
213
+ x["feature_text_end"] = text_end_batch
214
+
215
+ # Non-streaming generate (following generate_ldf.py 125-139)
216
+ output = self.ldf_model.generate(x, num_denoise_steps=num_denoise_steps)
217
+ generated_batch = output["generated"]
218
+
219
+ # Decode with VAE and optionally convert to joints
220
+ decoded_results = []
221
+ joints_results = [] if output_joints else None
222
+
223
+ for i, generated in enumerate(generated_batch):
224
+ if generated is not None and torch.is_tensor(generated):
225
+ # Decode with VAE (following generate_ldf.py line 130)
226
+ decoded_g = self.vae.decode(generated[None, :])[0]
227
+
228
+ if output_joints:
229
+ # Use the model_dir that was saved during _load_models
230
+ model_dir = self.model_dir
231
+
232
+ # Import convert_motion_to_joints from ldf_utils
233
+ import importlib.util
234
+ import numpy as np
235
+ utils_spec = importlib.util.spec_from_file_location(
236
+ "motion_process",
237
+ os.path.join(model_dir, "ldf_utils", "motion_process.py")
238
+ )
239
+ motion_process_module = importlib.util.module_from_spec(utils_spec)
240
+ utils_spec.loader.exec_module(motion_process_module)
241
+
242
+ # Convert to joints using convert_motion_to_joints
243
+ decoded_np = decoded_g.cpu().numpy()
244
+
245
+ joints = motion_process_module.convert_motion_to_joints(
246
+ decoded_np, dim=263
247
+ )
248
+ joints_results.append(joints)
249
+ else:
250
+ decoded_results.append(decoded_g.cpu().numpy())
251
+ else:
252
+ if output_joints:
253
+ joints_results.append(None)
254
+ else:
255
+ decoded_results.append(None)
256
+
257
+ # Return results
258
+ if output_joints:
259
+ return joints_results[0] if is_single else joints_results
260
+ else:
261
+ return decoded_results[0] if is_single else decoded_results
262
+
263
+ def generate(self, *args, **kwargs):
264
+ """Alias for __call__ to match transformers API"""
265
+ return self.__call__(*args, **kwargs)
266
+
267
+
268
+ # For backwards compatibility
269
+ LDFPipeline = LDFModel
270
+
271
+
272
+ # Register with AutoModel
273
+ try:
274
+ from transformers import AutoModel, AutoConfig
275
+ AutoConfig.register("ldf_motion", LDFConfig)
276
+ AutoModel.register(LDFConfig, LDFModel)
277
+ except Exception:
278
+ pass
ldf.yaml ADDED
@@ -0,0 +1,32 @@
1
+ exp_name: ldf
2
+ seed: 1234
3
+ debug: false
4
+ train: false
5
+
6
+ save_dir: ./outputs
7
+ resume_ckpt: null
8
+ test_ckpt: "model.safetensors"
9
+ test_vae_ckpt: "vae.safetensors"
10
+
11
+ test_vae:
12
+ target: ldf_models.vae_wan_1d.VAEWanModel
13
+ ema_decay: 0.99
14
+ params:
15
+ input_dim: 263
16
+ z_dim: 4
17
+
18
+ test_setting:
19
+ render: false
20
+ simple: true
21
+ recover_dim: 263
22
+
23
+ val_repeat: 1
24
+
25
+ model:
26
+ target: ldf_models.diffusion_forcing_wan.DiffForcingWanModel
27
+ ema_decay: 0.99
28
+ params:
29
+ checkpoint_path: "ldf_deps/t5_umt5-xxl-enc-bf16/models_t5_umt5-xxl-enc-bf16.pth"
30
+ tokenizer_path: "ldf_deps/t5_umt5-xxl-enc-bf16/google/umt5-xxl"
31
+ input_dim: 4
32
+ noise_steps: 10
ldf_deps/t5_umt5-xxl-enc-bf16/google/umt5-xxl/special_tokens_map.json ADDED
@@ -0,0 +1,308 @@
1
+ {
2
+ "additional_special_tokens": [
3
+ "<extra_id_0>",
4
+ "<extra_id_1>",
5
+ "<extra_id_2>",
6
+ "<extra_id_3>",
7
+ "<extra_id_4>",
8
+ "<extra_id_5>",
9
+ "<extra_id_6>",
10
+ "<extra_id_7>",
11
+ "<extra_id_8>",
12
+ "<extra_id_9>",
13
+ "<extra_id_10>",
14
+ "<extra_id_11>",
15
+ "<extra_id_12>",
16
+ "<extra_id_13>",
17
+ "<extra_id_14>",
18
+ "<extra_id_15>",
19
+ "<extra_id_16>",
20
+ "<extra_id_17>",
21
+ "<extra_id_18>",
22
+ "<extra_id_19>",
23
+ "<extra_id_20>",
24
+ "<extra_id_21>",
25
+ "<extra_id_22>",
26
+ "<extra_id_23>",
27
+ "<extra_id_24>",
28
+ "<extra_id_25>",
29
+ "<extra_id_26>",
30
+ "<extra_id_27>",
31
+ "<extra_id_28>",
32
+ "<extra_id_29>",
33
+ "<extra_id_30>",
34
+ "<extra_id_31>",
35
+ "<extra_id_32>",
36
+ "<extra_id_33>",
37
+ "<extra_id_34>",
38
+ "<extra_id_35>",
39
+ "<extra_id_36>",
40
+ "<extra_id_37>",
41
+ "<extra_id_38>",
42
+ "<extra_id_39>",
43
+ "<extra_id_40>",
44
+ "<extra_id_41>",
45
+ "<extra_id_42>",
46
+ "<extra_id_43>",
47
+ "<extra_id_44>",
48
+ "<extra_id_45>",
49
+ "<extra_id_46>",
50
+ "<extra_id_47>",
51
+ "<extra_id_48>",
52
+ "<extra_id_49>",
53
+ "<extra_id_50>",
54
+ "<extra_id_51>",
55
+ "<extra_id_52>",
56
+ "<extra_id_53>",
57
+ "<extra_id_54>",
58
+ "<extra_id_55>",
59
+ "<extra_id_56>",
60
+ "<extra_id_57>",
61
+ "<extra_id_58>",
62
+ "<extra_id_59>",
63
+ "<extra_id_60>",
64
+ "<extra_id_61>",
65
+ "<extra_id_62>",
66
+ "<extra_id_63>",
67
+ "<extra_id_64>",
68
+ "<extra_id_65>",
69
+ "<extra_id_66>",
70
+ "<extra_id_67>",
71
+ "<extra_id_68>",
72
+ "<extra_id_69>",
73
+ "<extra_id_70>",
74
+ "<extra_id_71>",
75
+ "<extra_id_72>",
76
+ "<extra_id_73>",
77
+ "<extra_id_74>",
78
+ "<extra_id_75>",
79
+ "<extra_id_76>",
80
+ "<extra_id_77>",
81
+ "<extra_id_78>",
82
+ "<extra_id_79>",
83
+ "<extra_id_80>",
84
+ "<extra_id_81>",
85
+ "<extra_id_82>",
86
+ "<extra_id_83>",
87
+ "<extra_id_84>",
88
+ "<extra_id_85>",
89
+ "<extra_id_86>",
90
+ "<extra_id_87>",
91
+ "<extra_id_88>",
92
+ "<extra_id_89>",
93
+ "<extra_id_90>",
94
+ "<extra_id_91>",
95
+ "<extra_id_92>",
96
+ "<extra_id_93>",
97
+ "<extra_id_94>",
98
+ "<extra_id_95>",
99
+ "<extra_id_96>",
100
+ "<extra_id_97>",
101
+ "<extra_id_98>",
102
+ "<extra_id_99>",
103
+ "<extra_id_100>",
104
+ "<extra_id_101>",
105
+ "<extra_id_102>",
106
+ "<extra_id_103>",
107
+ "<extra_id_104>",
108
+ "<extra_id_105>",
109
+ "<extra_id_106>",
110
+ "<extra_id_107>",
111
+ "<extra_id_108>",
112
+ "<extra_id_109>",
113
+ "<extra_id_110>",
114
+ "<extra_id_111>",
115
+ "<extra_id_112>",
116
+ "<extra_id_113>",
117
+ "<extra_id_114>",
118
+ "<extra_id_115>",
119
+ "<extra_id_116>",
120
+ "<extra_id_117>",
121
+ "<extra_id_118>",
122
+ "<extra_id_119>",
123
+ "<extra_id_120>",
124
+ "<extra_id_121>",
125
+ "<extra_id_122>",
126
+ "<extra_id_123>",
127
+ "<extra_id_124>",
128
+ "<extra_id_125>",
129
+ "<extra_id_126>",
130
+ "<extra_id_127>",
131
+ "<extra_id_128>",
132
+ "<extra_id_129>",
133
+ "<extra_id_130>",
134
+ "<extra_id_131>",
135
+ "<extra_id_132>",
136
+ "<extra_id_133>",
137
+ "<extra_id_134>",
138
+ "<extra_id_135>",
139
+ "<extra_id_136>",
140
+ "<extra_id_137>",
141
+ "<extra_id_138>",
142
+ "<extra_id_139>",
143
+ "<extra_id_140>",
144
+ "<extra_id_141>",
145
+ "<extra_id_142>",
146
+ "<extra_id_143>",
147
+ "<extra_id_144>",
148
+ "<extra_id_145>",
149
+ "<extra_id_146>",
150
+ "<extra_id_147>",
151
+ "<extra_id_148>",
152
+ "<extra_id_149>",
153
+ "<extra_id_150>",
154
+ "<extra_id_151>",
155
+ "<extra_id_152>",
156
+ "<extra_id_153>",
157
+ "<extra_id_154>",
158
+ "<extra_id_155>",
159
+ "<extra_id_156>",
160
+ "<extra_id_157>",
161
+ "<extra_id_158>",
162
+ "<extra_id_159>",
163
+ "<extra_id_160>",
164
+ "<extra_id_161>",
165
+ "<extra_id_162>",
166
+ "<extra_id_163>",
167
+ "<extra_id_164>",
168
+ "<extra_id_165>",
169
+ "<extra_id_166>",
170
+ "<extra_id_167>",
171
+ "<extra_id_168>",
172
+ "<extra_id_169>",
173
+ "<extra_id_170>",
174
+ "<extra_id_171>",
175
+ "<extra_id_172>",
176
+ "<extra_id_173>",
177
+ "<extra_id_174>",
178
+ "<extra_id_175>",
179
+ "<extra_id_176>",
180
+ "<extra_id_177>",
181
+ "<extra_id_178>",
182
+ "<extra_id_179>",
183
+ "<extra_id_180>",
184
+ "<extra_id_181>",
185
+ "<extra_id_182>",
186
+ "<extra_id_183>",
187
+ "<extra_id_184>",
188
+ "<extra_id_185>",
189
+ "<extra_id_186>",
190
+ "<extra_id_187>",
191
+ "<extra_id_188>",
192
+ "<extra_id_189>",
193
+ "<extra_id_190>",
194
+ "<extra_id_191>",
195
+ "<extra_id_192>",
196
+ "<extra_id_193>",
197
+ "<extra_id_194>",
198
+ "<extra_id_195>",
199
+ "<extra_id_196>",
200
+ "<extra_id_197>",
201
+ "<extra_id_198>",
202
+ "<extra_id_199>",
203
+ "<extra_id_200>",
204
+ "<extra_id_201>",
205
+ "<extra_id_202>",
206
+ "<extra_id_203>",
207
+ "<extra_id_204>",
208
+ "<extra_id_205>",
209
+ "<extra_id_206>",
210
+ "<extra_id_207>",
211
+ "<extra_id_208>",
212
+ "<extra_id_209>",
213
+ "<extra_id_210>",
214
+ "<extra_id_211>",
215
+ "<extra_id_212>",
216
+ "<extra_id_213>",
217
+ "<extra_id_214>",
218
+ "<extra_id_215>",
219
+ "<extra_id_216>",
220
+ "<extra_id_217>",
221
+ "<extra_id_218>",
222
+ "<extra_id_219>",
223
+ "<extra_id_220>",
224
+ "<extra_id_221>",
225
+ "<extra_id_222>",
226
+ "<extra_id_223>",
227
+ "<extra_id_224>",
228
+ "<extra_id_225>",
229
+ "<extra_id_226>",
230
+ "<extra_id_227>",
231
+ "<extra_id_228>",
232
+ "<extra_id_229>",
233
+ "<extra_id_230>",
234
+ "<extra_id_231>",
235
+ "<extra_id_232>",
236
+ "<extra_id_233>",
237
+ "<extra_id_234>",
238
+ "<extra_id_235>",
239
+ "<extra_id_236>",
240
+ "<extra_id_237>",
241
+ "<extra_id_238>",
242
+ "<extra_id_239>",
243
+ "<extra_id_240>",
244
+ "<extra_id_241>",
245
+ "<extra_id_242>",
246
+ "<extra_id_243>",
247
+ "<extra_id_244>",
248
+ "<extra_id_245>",
249
+ "<extra_id_246>",
250
+ "<extra_id_247>",
251
+ "<extra_id_248>",
252
+ "<extra_id_249>",
253
+ "<extra_id_250>",
254
+ "<extra_id_251>",
255
+ "<extra_id_252>",
256
+ "<extra_id_253>",
257
+ "<extra_id_254>",
258
+ "<extra_id_255>",
259
+ "<extra_id_256>",
260
+ "<extra_id_257>",
261
+ "<extra_id_258>",
262
+ "<extra_id_259>",
263
+ "<extra_id_260>",
264
+ "<extra_id_261>",
265
+ "<extra_id_262>",
266
+ "<extra_id_263>",
267
+ "<extra_id_264>",
268
+ "<extra_id_265>",
269
+ "<extra_id_266>",
270
+ "<extra_id_267>",
271
+ "<extra_id_268>",
272
+ "<extra_id_269>",
273
+ "<extra_id_270>",
274
+ "<extra_id_271>",
275
+ "<extra_id_272>",
276
+ "<extra_id_273>",
277
+ "<extra_id_274>",
278
+ "<extra_id_275>",
279
+ "<extra_id_276>",
280
+ "<extra_id_277>",
281
+ "<extra_id_278>",
282
+ "<extra_id_279>",
283
+ "<extra_id_280>",
284
+ "<extra_id_281>",
285
+ "<extra_id_282>",
286
+ "<extra_id_283>",
287
+ "<extra_id_284>",
288
+ "<extra_id_285>",
289
+ "<extra_id_286>",
290
+ "<extra_id_287>",
291
+ "<extra_id_288>",
292
+ "<extra_id_289>",
293
+ "<extra_id_290>",
294
+ "<extra_id_291>",
295
+ "<extra_id_292>",
296
+ "<extra_id_293>",
297
+ "<extra_id_294>",
298
+ "<extra_id_295>",
299
+ "<extra_id_296>",
300
+ "<extra_id_297>",
301
+ "<extra_id_298>",
302
+ "<extra_id_299>"
303
+ ],
304
+ "bos_token": "<s>",
305
+ "eos_token": "</s>",
306
+ "pad_token": "<pad>",
307
+ "unk_token": "<unk>"
308
+ }
ldf_deps/t5_umt5-xxl-enc-bf16/google/umt5-xxl/spiece.model ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e3909a67b780650b35cf529ac782ad2b6b26e6d1f849d3fbb6a872905f452458
3
+ size 4548313
ldf_deps/t5_umt5-xxl-enc-bf16/google/umt5-xxl/tokenizer.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6e197b4d3dbd71da14b4eb255f4fa91c9c1f2068b20a2de2472967ca3d22602b
3
+ size 16837417
ldf_deps/t5_umt5-xxl-enc-bf16/google/umt5-xxl/tokenizer_config.json ADDED
@@ -0,0 +1,2748 @@
1
+ {
2
+ "added_tokens_decoder": {
3
+ "0": {
4
+ "content": "<pad>",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "1": {
12
+ "content": "</s>",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ },
19
+ "2": {
20
+ "content": "<s>",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false,
25
+ "special": true
26
+ },
27
+ "3": {
28
+ "content": "<unk>",
29
+ "lstrip": false,
30
+ "normalized": false,
31
+ "rstrip": false,
32
+ "single_word": false,
33
+ "special": true
34
+ },
35
+ "256000": {
36
+ "content": "<extra_id_299>",
37
+ "lstrip": false,
38
+ "normalized": false,
39
+ "rstrip": false,
40
+ "single_word": false,
41
+ "special": true
42
+ },
43
+ "256001": {
44
+ "content": "<extra_id_298>",
45
+ "lstrip": false,
46
+ "normalized": false,
47
+ "rstrip": false,
48
+ "single_word": false,
49
+ "special": true
50
+ },
51
+ "256002": {
52
+ "content": "<extra_id_297>",
53
+ "lstrip": false,
54
+ "normalized": false,
55
+ "rstrip": false,
56
+ "single_word": false,
57
+ "special": true
58
+ },
59
+ "256003": {
60
+ "content": "<extra_id_296>",
61
+ "lstrip": false,
62
+ "normalized": false,
63
+ "rstrip": false,
64
+ "single_word": false,
65
+ "special": true
66
+ },
67
+ "256004": {
68
+ "content": "<extra_id_295>",
69
+ "lstrip": false,
70
+ "normalized": false,
71
+ "rstrip": false,
72
+ "single_word": false,
73
+ "special": true
74
+ },
75
+ "256005": {
76
+ "content": "<extra_id_294>",
77
+ "lstrip": false,
78
+ "normalized": false,
79
+ "rstrip": false,
80
+ "single_word": false,
81
+ "special": true
82
+ },
83
+ "256006": {
84
+ "content": "<extra_id_293>",
85
+ "lstrip": false,
86
+ "normalized": false,
87
+ "rstrip": false,
88
+ "single_word": false,
89
+ "special": true
90
+ },
91
+ "256007": {
92
+ "content": "<extra_id_292>",
93
+ "lstrip": false,
94
+ "normalized": false,
95
+ "rstrip": false,
96
+ "single_word": false,
97
+ "special": true
98
+ },
99
+ "256008": {
100
+ "content": "<extra_id_291>",
101
+ "lstrip": false,
102
+ "normalized": false,
103
+ "rstrip": false,
104
+ "single_word": false,
105
+ "special": true
106
+ },
107
+ "256009": {
108
+ "content": "<extra_id_290>",
109
+ "lstrip": false,
110
+ "normalized": false,
111
+ "rstrip": false,
112
+ "single_word": false,
113
+ "special": true
114
+ },
115
+ "256010": {
116
+ "content": "<extra_id_289>",
117
+ "lstrip": false,
118
+ "normalized": false,
119
+ "rstrip": false,
120
+ "single_word": false,
121
+ "special": true
122
+ },
123
+ "256011": {
124
+ "content": "<extra_id_288>",
125
+ "lstrip": false,
126
+ "normalized": false,
127
+ "rstrip": false,
128
+ "single_word": false,
129
+ "special": true
130
+ },
131
+ "256012": {
132
+ "content": "<extra_id_287>",
133
+ "lstrip": false,
134
+ "normalized": false,
135
+ "rstrip": false,
136
+ "single_word": false,
137
+ "special": true
138
+ },
139
+ "256013": {
140
+ "content": "<extra_id_286>",
141
+ "lstrip": false,
142
+ "normalized": false,
143
+ "rstrip": false,
144
+ "single_word": false,
145
+ "special": true
146
+ },
147
+ "256014": {
148
+ "content": "<extra_id_285>",
149
+ "lstrip": false,
150
+ "normalized": false,
151
+ "rstrip": false,
152
+ "single_word": false,
153
+ "special": true
154
+ },
155
+ "256015": {
156
+ "content": "<extra_id_284>",
157
+ "lstrip": false,
158
+ "normalized": false,
159
+ "rstrip": false,
160
+ "single_word": false,
161
+ "special": true
162
+ },
163
+ "256016": {
164
+ "content": "<extra_id_283>",
165
+ "lstrip": false,
166
+ "normalized": false,
167
+ "rstrip": false,
168
+ "single_word": false,
169
+ "special": true
170
+ },
171
+ "256017": {
172
+ "content": "<extra_id_282>",
173
+ "lstrip": false,
174
+ "normalized": false,
175
+ "rstrip": false,
176
+ "single_word": false,
177
+ "special": true
178
+ },
179
+ "256018": {
180
+ "content": "<extra_id_281>",
181
+ "lstrip": false,
182
+ "normalized": false,
183
+ "rstrip": false,
184
+ "single_word": false,
185
+ "special": true
186
+ },
187
+ "256019": {
188
+ "content": "<extra_id_280>",
189
+ "lstrip": false,
190
+ "normalized": false,
191
+ "rstrip": false,
192
+ "single_word": false,
193
+ "special": true
194
+ },
195
+ "256020": {
196
+ "content": "<extra_id_279>",
197
+ "lstrip": false,
198
+ "normalized": false,
199
+ "rstrip": false,
200
+ "single_word": false,
201
+ "special": true
202
+ },
203
+ "256021": {
204
+ "content": "<extra_id_278>",
205
+ "lstrip": false,
206
+ "normalized": false,
207
+ "rstrip": false,
208
+ "single_word": false,
209
+ "special": true
210
+ },
211
+ "256022": {
212
+ "content": "<extra_id_277>",
213
+ "lstrip": false,
214
+ "normalized": false,
215
+ "rstrip": false,
216
+ "single_word": false,
217
+ "special": true
218
+ },
219
+ "256023": {
220
+ "content": "<extra_id_276>",
221
+ "lstrip": false,
222
+ "normalized": false,
223
+ "rstrip": false,
224
+ "single_word": false,
225
+ "special": true
226
+ },
227
+ "256024": {
228
+ "content": "<extra_id_275>",
229
+ "lstrip": false,
230
+ "normalized": false,
231
+ "rstrip": false,
232
+ "single_word": false,
233
+ "special": true
234
+ },
235
+ "256025": {
236
+ "content": "<extra_id_274>",
237
+ "lstrip": false,
238
+ "normalized": false,
239
+ "rstrip": false,
240
+ "single_word": false,
241
+ "special": true
242
+ },
243
+ "256026": {
244
+ "content": "<extra_id_273>",
245
+ "lstrip": false,
246
+ "normalized": false,
247
+ "rstrip": false,
248
+ "single_word": false,
249
+ "special": true
250
+ },
251
+ "256027": {
252
+ "content": "<extra_id_272>",
253
+ "lstrip": false,
254
+ "normalized": false,
255
+ "rstrip": false,
256
+ "single_word": false,
257
+ "special": true
258
+ },
259
+ "256028": {
260
+ "content": "<extra_id_271>",
261
+ "lstrip": false,
262
+ "normalized": false,
263
+ "rstrip": false,
264
+ "single_word": false,
265
+ "special": true
266
+ },
267
+ "256029": {
268
+ "content": "<extra_id_270>",
269
+ "lstrip": false,
270
+ "normalized": false,
271
+ "rstrip": false,
272
+ "single_word": false,
273
+ "special": true
274
+ },
275
+ "256030": {
276
+ "content": "<extra_id_269>",
277
+ "lstrip": false,
278
+ "normalized": false,
279
+ "rstrip": false,
280
+ "single_word": false,
281
+ "special": true
282
+ },
283
+ "256031": {
284
+ "content": "<extra_id_268>",
285
+ "lstrip": false,
286
+ "normalized": false,
287
+ "rstrip": false,
288
+ "single_word": false,
289
+ "special": true
290
+ },
291
+ "256032": {
292
+ "content": "<extra_id_267>",
293
+ "lstrip": false,
294
+ "normalized": false,
295
+ "rstrip": false,
296
+ "single_word": false,
297
+ "special": true
298
+ },
299
+ "256033": {
300
+ "content": "<extra_id_266>",
301
+ "lstrip": false,
302
+ "normalized": false,
303
+ "rstrip": false,
304
+ "single_word": false,
305
+ "special": true
306
+ },
307
+ "256034": {
308
+ "content": "<extra_id_265>",
309
+ "lstrip": false,
310
+ "normalized": false,
311
+ "rstrip": false,
312
+ "single_word": false,
313
+ "special": true
314
+ },
315
+ "256035": {
316
+ "content": "<extra_id_264>",
317
+ "lstrip": false,
318
+ "normalized": false,
319
+ "rstrip": false,
320
+ "single_word": false,
321
+ "special": true
322
+ },
323
+ "256036": {
324
+ "content": "<extra_id_263>",
325
+ "lstrip": false,
326
+ "normalized": false,
327
+ "rstrip": false,
328
+ "single_word": false,
329
+ "special": true
330
+ },
331
+ "256037": {
332
+ "content": "<extra_id_262>",
333
+ "lstrip": false,
334
+ "normalized": false,
335
+ "rstrip": false,
336
+ "single_word": false,
337
+ "special": true
338
+ },
339
+ "256038": {
340
+ "content": "<extra_id_261>",
341
+ "lstrip": false,
342
+ "normalized": false,
343
+ "rstrip": false,
344
+ "single_word": false,
345
+ "special": true
346
+ },
347
+ "256039": {
348
+ "content": "<extra_id_260>",
349
+ "lstrip": false,
350
+ "normalized": false,
351
+ "rstrip": false,
352
+ "single_word": false,
353
+ "special": true
354
+ },
355
+ "256040": {
356
+ "content": "<extra_id_259>",
357
+ "lstrip": false,
358
+ "normalized": false,
359
+ "rstrip": false,
360
+ "single_word": false,
361
+ "special": true
362
+ },
363
+ "256041": {
364
+ "content": "<extra_id_258>",
365
+ "lstrip": false,
366
+ "normalized": false,
367
+ "rstrip": false,
368
+ "single_word": false,
369
+ "special": true
370
+ },
371
+ "256042": {
372
+ "content": "<extra_id_257>",
373
+ "lstrip": false,
374
+ "normalized": false,
375
+ "rstrip": false,
376
+ "single_word": false,
377
+ "special": true
378
+ },
379
+ "256043": {
380
+ "content": "<extra_id_256>",
381
+ "lstrip": false,
382
+ "normalized": false,
383
+ "rstrip": false,
384
+ "single_word": false,
385
+ "special": true
386
+ },
387
+ "256044": {
388
+ "content": "<extra_id_255>",
389
+ "lstrip": false,
390
+ "normalized": false,
391
+ "rstrip": false,
392
+ "single_word": false,
393
+ "special": true
394
+ },
395
+ "256045": {
396
+ "content": "<extra_id_254>",
397
+ "lstrip": false,
398
+ "normalized": false,
399
+ "rstrip": false,
400
+ "single_word": false,
401
+ "special": true
402
+ },
403
+ "256046": {
404
+ "content": "<extra_id_253>",
405
+ "lstrip": false,
406
+ "normalized": false,
407
+ "rstrip": false,
408
+ "single_word": false,
409
+ "special": true
410
+ },
411
+ "256047": {
412
+ "content": "<extra_id_252>",
413
+ "lstrip": false,
414
+ "normalized": false,
415
+ "rstrip": false,
416
+ "single_word": false,
417
+ "special": true
418
+ },
419
+ "256048": {
420
+ "content": "<extra_id_251>",
421
+ "lstrip": false,
422
+ "normalized": false,
423
+ "rstrip": false,
424
+ "single_word": false,
425
+ "special": true
426
+ },
427
+ "256049": {
428
+ "content": "<extra_id_250>",
429
+ "lstrip": false,
430
+ "normalized": false,
431
+ "rstrip": false,
432
+ "single_word": false,
433
+ "special": true
434
+ },
435
+ "256050": {
436
+ "content": "<extra_id_249>",
437
+ "lstrip": false,
438
+ "normalized": false,
439
+ "rstrip": false,
440
+ "single_word": false,
441
+ "special": true
442
+ },
443
+ "256051": {
444
+ "content": "<extra_id_248>",
445
+ "lstrip": false,
446
+ "normalized": false,
447
+ "rstrip": false,
448
+ "single_word": false,
449
+ "special": true
450
+ },
451
+ "256052": {
452
+ "content": "<extra_id_247>",
453
+ "lstrip": false,
454
+ "normalized": false,
455
+ "rstrip": false,
456
+ "single_word": false,
457
+ "special": true
458
+ },
459
+ "256053": {
460
+ "content": "<extra_id_246>",
461
+ "lstrip": false,
462
+ "normalized": false,
463
+ "rstrip": false,
464
+ "single_word": false,
465
+ "special": true
466
+ },
467
+ "256054": {
468
+ "content": "<extra_id_245>",
469
+ "lstrip": false,
470
+ "normalized": false,
471
+ "rstrip": false,
472
+ "single_word": false,
473
+ "special": true
474
+ },
475
+ "256055": {
476
+ "content": "<extra_id_244>",
477
+ "lstrip": false,
478
+ "normalized": false,
479
+ "rstrip": false,
480
+ "single_word": false,
481
+ "special": true
482
+ },
483
+ "256056": {
484
+ "content": "<extra_id_243>",
485
+ "lstrip": false,
486
+ "normalized": false,
487
+ "rstrip": false,
488
+ "single_word": false,
489
+ "special": true
490
+ },
491
+ "256057": {
492
+ "content": "<extra_id_242>",
493
+ "lstrip": false,
494
+ "normalized": false,
495
+ "rstrip": false,
496
+ "single_word": false,
497
+ "special": true
498
+ },
499
+ "256058": {
500
+ "content": "<extra_id_241>",
501
+ "lstrip": false,
502
+ "normalized": false,
503
+ "rstrip": false,
504
+ "single_word": false,
505
+ "special": true
506
+ },
507
+ "256059": {
508
+ "content": "<extra_id_240>",
509
+ "lstrip": false,
510
+ "normalized": false,
511
+ "rstrip": false,
512
+ "single_word": false,
513
+ "special": true
514
+ },
515
+ "256060": {
516
+ "content": "<extra_id_239>",
517
+ "lstrip": false,
518
+ "normalized": false,
519
+ "rstrip": false,
520
+ "single_word": false,
521
+ "special": true
522
+ },
523
+ "256061": {
524
+ "content": "<extra_id_238>",
525
+ "lstrip": false,
526
+ "normalized": false,
527
+ "rstrip": false,
528
+ "single_word": false,
529
+ "special": true
530
+ },
531
+ "256062": {
532
+ "content": "<extra_id_237>",
533
+ "lstrip": false,
534
+ "normalized": false,
535
+ "rstrip": false,
536
+ "single_word": false,
537
+ "special": true
538
+ },
539
+ "256063": {
540
+ "content": "<extra_id_236>",
541
+ "lstrip": false,
542
+ "normalized": false,
543
+ "rstrip": false,
544
+ "single_word": false,
545
+ "special": true
546
+ },
547
+ "256064": {
548
+ "content": "<extra_id_235>",
549
+ "lstrip": false,
550
+ "normalized": false,
551
+ "rstrip": false,
552
+ "single_word": false,
553
+ "special": true
554
+ },
555
+ "256065": {
556
+ "content": "<extra_id_234>",
557
+ "lstrip": false,
558
+ "normalized": false,
559
+ "rstrip": false,
560
+ "single_word": false,
561
+ "special": true
562
+ },
563
+ "256066": {
564
+ "content": "<extra_id_233>",
565
+ "lstrip": false,
566
+ "normalized": false,
567
+ "rstrip": false,
568
+ "single_word": false,
569
+ "special": true
570
+ },
571
+ "256067": {
572
+ "content": "<extra_id_232>",
573
+ "lstrip": false,
574
+ "normalized": false,
575
+ "rstrip": false,
576
+ "single_word": false,
577
+ "special": true
578
+ },
579
+ "256068": {
580
+ "content": "<extra_id_231>",
581
+ "lstrip": false,
582
+ "normalized": false,
583
+ "rstrip": false,
584
+ "single_word": false,
585
+ "special": true
586
+ },
587
+ "256069": {
588
+ "content": "<extra_id_230>",
589
+ "lstrip": false,
590
+ "normalized": false,
591
+ "rstrip": false,
592
+ "single_word": false,
593
+ "special": true
594
+ },
595
+ "256070": {
596
+ "content": "<extra_id_229>",
597
+ "lstrip": false,
598
+ "normalized": false,
599
+ "rstrip": false,
600
+ "single_word": false,
601
+ "special": true
602
+ },
603
+ "256071": {
604
+ "content": "<extra_id_228>",
605
+ "lstrip": false,
606
+ "normalized": false,
607
+ "rstrip": false,
608
+ "single_word": false,
609
+ "special": true
610
+ },
611
+ "256072": {
612
+ "content": "<extra_id_227>",
613
+ "lstrip": false,
614
+ "normalized": false,
615
+ "rstrip": false,
616
+ "single_word": false,
617
+ "special": true
618
+ },
619
+ "256073": {
620
+ "content": "<extra_id_226>",
621
+ "lstrip": false,
622
+ "normalized": false,
623
+ "rstrip": false,
624
+ "single_word": false,
625
+ "special": true
626
+ },
627
+ "256074": {
628
+ "content": "<extra_id_225>",
629
+ "lstrip": false,
630
+ "normalized": false,
631
+ "rstrip": false,
632
+ "single_word": false,
633
+ "special": true
634
+ },
635
+ "256075": {
636
+ "content": "<extra_id_224>",
637
+ "lstrip": false,
638
+ "normalized": false,
639
+ "rstrip": false,
640
+ "single_word": false,
641
+ "special": true
642
+ },
643
+ "256076": {
644
+ "content": "<extra_id_223>",
645
+ "lstrip": false,
646
+ "normalized": false,
647
+ "rstrip": false,
648
+ "single_word": false,
649
+ "special": true
650
+ },
651
+ "256077": {
652
+ "content": "<extra_id_222>",
653
+ "lstrip": false,
654
+ "normalized": false,
655
+ "rstrip": false,
656
+ "single_word": false,
657
+ "special": true
658
+ },
659
+ "256078": {
660
+ "content": "<extra_id_221>",
661
+ "lstrip": false,
662
+ "normalized": false,
663
+ "rstrip": false,
664
+ "single_word": false,
665
+ "special": true
666
+ },
667
+ "256079": {
668
+ "content": "<extra_id_220>",
669
+ "lstrip": false,
670
+ "normalized": false,
671
+ "rstrip": false,
672
+ "single_word": false,
673
+ "special": true
674
+ },
675
+ "256080": {
676
+ "content": "<extra_id_219>",
677
+ "lstrip": false,
678
+ "normalized": false,
679
+ "rstrip": false,
680
+ "single_word": false,
681
+ "special": true
682
+ },
683
+ "256081": {
684
+ "content": "<extra_id_218>",
685
+ "lstrip": false,
686
+ "normalized": false,
687
+ "rstrip": false,
688
+ "single_word": false,
689
+ "special": true
690
+ },
691
+ "256082": {
692
+ "content": "<extra_id_217>",
693
+ "lstrip": false,
694
+ "normalized": false,
695
+ "rstrip": false,
696
+ "single_word": false,
697
+ "special": true
698
+ },
699
+ "256083": {
700
+ "content": "<extra_id_216>",
701
+ "lstrip": false,
702
+ "normalized": false,
703
+ "rstrip": false,
704
+ "single_word": false,
705
+ "special": true
706
+ },
707
+ "256084": {
708
+ "content": "<extra_id_215>",
709
+ "lstrip": false,
710
+ "normalized": false,
711
+ "rstrip": false,
712
+ "single_word": false,
713
+ "special": true
714
+ },
715
+ "256085": {
716
+ "content": "<extra_id_214>",
717
+ "lstrip": false,
718
+ "normalized": false,
719
+ "rstrip": false,
720
+ "single_word": false,
721
+ "special": true
722
+ },
723
+ "256086": {
724
+ "content": "<extra_id_213>",
725
+ "lstrip": false,
726
+ "normalized": false,
727
+ "rstrip": false,
728
+ "single_word": false,
729
+ "special": true
730
+ },
731
+ "256087": {
732
+ "content": "<extra_id_212>",
733
+ "lstrip": false,
734
+ "normalized": false,
735
+ "rstrip": false,
736
+ "single_word": false,
737
+ "special": true
738
+ },
739
+ "256088": {
740
+ "content": "<extra_id_211>",
741
+ "lstrip": false,
742
+ "normalized": false,
743
+ "rstrip": false,
744
+ "single_word": false,
745
+ "special": true
746
+ },
747
+ "256089": {
748
+ "content": "<extra_id_210>",
749
+ "lstrip": false,
750
+ "normalized": false,
751
+ "rstrip": false,
752
+ "single_word": false,
753
+ "special": true
754
+ },
755
+ "256090": {
756
+ "content": "<extra_id_209>",
757
+ "lstrip": false,
758
+ "normalized": false,
759
+ "rstrip": false,
760
+ "single_word": false,
761
+ "special": true
762
+ },
763
+ "256091": {
764
+ "content": "<extra_id_208>",
765
+ "lstrip": false,
766
+ "normalized": false,
767
+ "rstrip": false,
768
+ "single_word": false,
769
+ "special": true
770
+ },
771
+ "256092": {
772
+ "content": "<extra_id_207>",
773
+ "lstrip": false,
774
+ "normalized": false,
775
+ "rstrip": false,
776
+ "single_word": false,
777
+ "special": true
778
+ },
779
+ "256093": {
780
+ "content": "<extra_id_206>",
781
+ "lstrip": false,
782
+ "normalized": false,
783
+ "rstrip": false,
784
+ "single_word": false,
785
+ "special": true
786
+ },
787
+ "256094": {
788
+ "content": "<extra_id_205>",
789
+ "lstrip": false,
790
+ "normalized": false,
791
+ "rstrip": false,
792
+ "single_word": false,
793
+ "special": true
794
+ },
795
+ "256095": {
796
+ "content": "<extra_id_204>",
797
+ "lstrip": false,
798
+ "normalized": false,
799
+ "rstrip": false,
800
+ "single_word": false,
801
+ "special": true
802
+ },
803
+ "256096": {
804
+ "content": "<extra_id_203>",
805
+ "lstrip": false,
806
+ "normalized": false,
807
+ "rstrip": false,
808
+ "single_word": false,
809
+ "special": true
810
+ },
811
+ "256097": {
812
+ "content": "<extra_id_202>",
813
+ "lstrip": false,
814
+ "normalized": false,
815
+ "rstrip": false,
816
+ "single_word": false,
817
+ "special": true
818
+ },
819
+ "256098": {
820
+ "content": "<extra_id_201>",
821
+ "lstrip": false,
822
+ "normalized": false,
823
+ "rstrip": false,
824
+ "single_word": false,
825
+ "special": true
826
+ },
827
+ "256099": {
828
+ "content": "<extra_id_200>",
829
+ "lstrip": false,
830
+ "normalized": false,
831
+ "rstrip": false,
832
+ "single_word": false,
833
+ "special": true
834
+ },
835
+ "256100": {
836
+ "content": "<extra_id_199>",
837
+ "lstrip": false,
838
+ "normalized": false,
839
+ "rstrip": false,
840
+ "single_word": false,
841
+ "special": true
842
+ },
843
+ "256101": {
844
+ "content": "<extra_id_198>",
845
+ "lstrip": false,
846
+ "normalized": false,
847
+ "rstrip": false,
848
+ "single_word": false,
849
+ "special": true
850
+ },
851
+ "256102": {
852
+ "content": "<extra_id_197>",
853
+ "lstrip": false,
854
+ "normalized": false,
855
+ "rstrip": false,
856
+ "single_word": false,
857
+ "special": true
858
+ },
859
+ "256103": {
860
+ "content": "<extra_id_196>",
861
+ "lstrip": false,
862
+ "normalized": false,
863
+ "rstrip": false,
864
+ "single_word": false,
865
+ "special": true
866
+ },
867
+ "256104": {
868
+ "content": "<extra_id_195>",
869
+ "lstrip": false,
870
+ "normalized": false,
871
+ "rstrip": false,
872
+ "single_word": false,
873
+ "special": true
874
+ },
875
+ "256105": {
876
+ "content": "<extra_id_194>",
877
+ "lstrip": false,
878
+ "normalized": false,
879
+ "rstrip": false,
880
+ "single_word": false,
881
+ "special": true
882
+ },
883
+ "256106": {
884
+ "content": "<extra_id_193>",
885
+ "lstrip": false,
886
+ "normalized": false,
887
+ "rstrip": false,
888
+ "single_word": false,
889
+ "special": true
890
+ },
891
+ "256107": {
892
+ "content": "<extra_id_192>",
893
+ "lstrip": false,
894
+ "normalized": false,
895
+ "rstrip": false,
896
+ "single_word": false,
897
+ "special": true
898
+ },
899
+ "256108": {
900
+ "content": "<extra_id_191>",
901
+ "lstrip": false,
902
+ "normalized": false,
903
+ "rstrip": false,
904
+ "single_word": false,
905
+ "special": true
906
+ },
907
+ "256109": {
908
+ "content": "<extra_id_190>",
909
+ "lstrip": false,
910
+ "normalized": false,
911
+ "rstrip": false,
912
+ "single_word": false,
913
+ "special": true
914
+ },
915
+ "256110": {
916
+ "content": "<extra_id_189>",
917
+ "lstrip": false,
918
+ "normalized": false,
919
+ "rstrip": false,
920
+ "single_word": false,
921
+ "special": true
922
+ },
923
+ "256111": {
924
+ "content": "<extra_id_188>",
925
+ "lstrip": false,
926
+ "normalized": false,
927
+ "rstrip": false,
928
+ "single_word": false,
929
+ "special": true
930
+ },
931
+ "256112": {
932
+ "content": "<extra_id_187>",
933
+ "lstrip": false,
934
+ "normalized": false,
935
+ "rstrip": false,
936
+ "single_word": false,
937
+ "special": true
938
+ },
939
+ "256113": {
940
+ "content": "<extra_id_186>",
941
+ "lstrip": false,
942
+ "normalized": false,
943
+ "rstrip": false,
944
+ "single_word": false,
945
+ "special": true
946
+ },
947
+ "256114": {
948
+ "content": "<extra_id_185>",
949
+ "lstrip": false,
950
+ "normalized": false,
951
+ "rstrip": false,
952
+ "single_word": false,
953
+ "special": true
954
+ },
955
+ "256115": {
956
+ "content": "<extra_id_184>",
957
+ "lstrip": false,
958
+ "normalized": false,
959
+ "rstrip": false,
960
+ "single_word": false,
961
+ "special": true
962
+ },
963
+ "256116": {
964
+ "content": "<extra_id_183>",
965
+ "lstrip": false,
966
+ "normalized": false,
967
+ "rstrip": false,
968
+ "single_word": false,
969
+ "special": true
970
+ },
971
+ "256117": {
972
+ "content": "<extra_id_182>",
973
+ "lstrip": false,
974
+ "normalized": false,
975
+ "rstrip": false,
976
+ "single_word": false,
977
+ "special": true
978
+ },
979
+ "256118": {
980
+ "content": "<extra_id_181>",
981
+ "lstrip": false,
982
+ "normalized": false,
983
+ "rstrip": false,
984
+ "single_word": false,
985
+ "special": true
986
+ },
987
+ "256119": {
988
+ "content": "<extra_id_180>",
989
+ "lstrip": false,
990
+ "normalized": false,
991
+ "rstrip": false,
992
+ "single_word": false,
993
+ "special": true
994
+ },
995
+ "256120": {
996
+ "content": "<extra_id_179>",
997
+ "lstrip": false,
998
+ "normalized": false,
999
+ "rstrip": false,
1000
+ "single_word": false,
1001
+ "special": true
1002
+ },
1003
+ "256121": {
1004
+ "content": "<extra_id_178>",
1005
+ "lstrip": false,
1006
+ "normalized": false,
1007
+ "rstrip": false,
1008
+ "single_word": false,
1009
+ "special": true
1010
+ },
1011
+ "256122": {
1012
+ "content": "<extra_id_177>",
1013
+ "lstrip": false,
1014
+ "normalized": false,
1015
+ "rstrip": false,
1016
+ "single_word": false,
1017
+ "special": true
1018
+ },
1019
+ "256123": {
1020
+ "content": "<extra_id_176>",
1021
+ "lstrip": false,
1022
+ "normalized": false,
1023
+ "rstrip": false,
1024
+ "single_word": false,
1025
+ "special": true
1026
+ },
1027
+ "256124": {
1028
+ "content": "<extra_id_175>",
1029
+ "lstrip": false,
1030
+ "normalized": false,
1031
+ "rstrip": false,
1032
+ "single_word": false,
1033
+ "special": true
1034
+ },
1035
+ "256125": {
1036
+ "content": "<extra_id_174>",
1037
+ "lstrip": false,
1038
+ "normalized": false,
1039
+ "rstrip": false,
1040
+ "single_word": false,
1041
+ "special": true
1042
+ },
1043
+ "256126": {
1044
+ "content": "<extra_id_173>",
1045
+ "lstrip": false,
1046
+ "normalized": false,
1047
+ "rstrip": false,
1048
+ "single_word": false,
1049
+ "special": true
1050
+ },
1051
+ "256127": {
1052
+ "content": "<extra_id_172>",
1053
+ "lstrip": false,
1054
+ "normalized": false,
1055
+ "rstrip": false,
1056
+ "single_word": false,
1057
+ "special": true
1058
+ },
1059
+ "256128": {
1060
+ "content": "<extra_id_171>",
1061
+ "lstrip": false,
1062
+ "normalized": false,
1063
+ "rstrip": false,
1064
+ "single_word": false,
1065
+ "special": true
1066
+ },
1067
+ "256129": {
1068
+ "content": "<extra_id_170>",
1069
+ "lstrip": false,
1070
+ "normalized": false,
1071
+ "rstrip": false,
1072
+ "single_word": false,
1073
+ "special": true
1074
+ },
1075
+ "256130": {
1076
+ "content": "<extra_id_169>",
1077
+ "lstrip": false,
1078
+ "normalized": false,
1079
+ "rstrip": false,
1080
+ "single_word": false,
1081
+ "special": true
1082
+ },
1083
+ "256131": {
1084
+ "content": "<extra_id_168>",
1085
+ "lstrip": false,
1086
+ "normalized": false,
1087
+ "rstrip": false,
1088
+ "single_word": false,
1089
+ "special": true
1090
+ },
1091
+ "256132": {
1092
+ "content": "<extra_id_167>",
1093
+ "lstrip": false,
1094
+ "normalized": false,
1095
+ "rstrip": false,
1096
+ "single_word": false,
1097
+ "special": true
1098
+ },
1099
+ "256133": {
1100
+ "content": "<extra_id_166>",
1101
+ "lstrip": false,
1102
+ "normalized": false,
1103
+ "rstrip": false,
1104
+ "single_word": false,
1105
+ "special": true
1106
+ },
1107
+ "256134": {
1108
+ "content": "<extra_id_165>",
1109
+ "lstrip": false,
1110
+ "normalized": false,
1111
+ "rstrip": false,
1112
+ "single_word": false,
1113
+ "special": true
1114
+ },
1115
+ "256135": {
1116
+ "content": "<extra_id_164>",
1117
+ "lstrip": false,
1118
+ "normalized": false,
1119
+ "rstrip": false,
1120
+ "single_word": false,
1121
+ "special": true
1122
+ },
1123
+ "256136": {
1124
+ "content": "<extra_id_163>",
1125
+ "lstrip": false,
1126
+ "normalized": false,
1127
+ "rstrip": false,
1128
+ "single_word": false,
1129
+ "special": true
1130
+ },
1131
+ "256137": {
1132
+ "content": "<extra_id_162>",
1133
+ "lstrip": false,
1134
+ "normalized": false,
1135
+ "rstrip": false,
1136
+ "single_word": false,
1137
+ "special": true
1138
+ },
1139
+ "256138": {
1140
+ "content": "<extra_id_161>",
1141
+ "lstrip": false,
1142
+ "normalized": false,
1143
+ "rstrip": false,
1144
+ "single_word": false,
1145
+ "special": true
1146
+ },
1147
+ "256139": {
1148
+ "content": "<extra_id_160>",
1149
+ "lstrip": false,
1150
+ "normalized": false,
1151
+ "rstrip": false,
1152
+ "single_word": false,
1153
+ "special": true
1154
+ },
1155
+ "256140": {
1156
+ "content": "<extra_id_159>",
1157
+ "lstrip": false,
1158
+ "normalized": false,
1159
+ "rstrip": false,
1160
+ "single_word": false,
1161
+ "special": true
1162
+ },
1163
+ "256141": {
1164
+ "content": "<extra_id_158>",
1165
+ "lstrip": false,
1166
+ "normalized": false,
1167
+ "rstrip": false,
1168
+ "single_word": false,
1169
+ "special": true
1170
+ },
1171
+ "256142": {
1172
+ "content": "<extra_id_157>",
1173
+ "lstrip": false,
1174
+ "normalized": false,
1175
+ "rstrip": false,
1176
+ "single_word": false,
1177
+ "special": true
1178
+ },
1179
+ "256143": {
1180
+ "content": "<extra_id_156>",
1181
+ "lstrip": false,
1182
+ "normalized": false,
1183
+ "rstrip": false,
1184
+ "single_word": false,
1185
+ "special": true
1186
+ },
1187
+ "256144": {
1188
+ "content": "<extra_id_155>",
1189
+ "lstrip": false,
1190
+ "normalized": false,
1191
+ "rstrip": false,
1192
+ "single_word": false,
1193
+ "special": true
1194
+ },
1195
+ "256145": {
1196
+ "content": "<extra_id_154>",
1197
+ "lstrip": false,
1198
+ "normalized": false,
1199
+ "rstrip": false,
1200
+ "single_word": false,
1201
+ "special": true
1202
+ },
1203
+ "256146": {
1204
+ "content": "<extra_id_153>",
1205
+ "lstrip": false,
1206
+ "normalized": false,
1207
+ "rstrip": false,
1208
+ "single_word": false,
1209
+ "special": true
1210
+ },
1211
+ "256147": {
1212
+ "content": "<extra_id_152>",
1213
+ "lstrip": false,
1214
+ "normalized": false,
1215
+ "rstrip": false,
1216
+ "single_word": false,
1217
+ "special": true
1218
+ },
1219
+ "256148": {
1220
+ "content": "<extra_id_151>",
1221
+ "lstrip": false,
1222
+ "normalized": false,
1223
+ "rstrip": false,
1224
+ "single_word": false,
1225
+ "special": true
1226
+ },
1227
+ "256149": {
1228
+ "content": "<extra_id_150>",
1229
+ "lstrip": false,
1230
+ "normalized": false,
1231
+ "rstrip": false,
1232
+ "single_word": false,
1233
+ "special": true
1234
+ },
1235
+ "256150": {
1236
+ "content": "<extra_id_149>",
1237
+ "lstrip": false,
1238
+ "normalized": false,
1239
+ "rstrip": false,
1240
+ "single_word": false,
1241
+ "special": true
1242
+ },
1243
+ "256151": {
1244
+ "content": "<extra_id_148>",
1245
+ "lstrip": false,
1246
+ "normalized": false,
1247
+ "rstrip": false,
1248
+ "single_word": false,
1249
+ "special": true
1250
+ },
1251
+ "256152": {
1252
+ "content": "<extra_id_147>",
1253
+ "lstrip": false,
1254
+ "normalized": false,
1255
+ "rstrip": false,
1256
+ "single_word": false,
1257
+ "special": true
1258
+ },
1259
+ "256153": {
1260
+ "content": "<extra_id_146>",
1261
+ "lstrip": false,
1262
+ "normalized": false,
1263
+ "rstrip": false,
1264
+ "single_word": false,
1265
+ "special": true
1266
+ },
1267
+ "256154": {
1268
+ "content": "<extra_id_145>",
1269
+ "lstrip": false,
1270
+ "normalized": false,
1271
+ "rstrip": false,
1272
+ "single_word": false,
1273
+ "special": true
1274
+ },
1275
+ "256155": {
1276
+ "content": "<extra_id_144>",
1277
+ "lstrip": false,
1278
+ "normalized": false,
1279
+ "rstrip": false,
1280
+ "single_word": false,
1281
+ "special": true
1282
+ },
1283
+ "256156": {
1284
+ "content": "<extra_id_143>",
1285
+ "lstrip": false,
1286
+ "normalized": false,
1287
+ "rstrip": false,
1288
+ "single_word": false,
1289
+ "special": true
1290
+ },
1291
+ "256157": {
1292
+ "content": "<extra_id_142>",
1293
+ "lstrip": false,
1294
+ "normalized": false,
1295
+ "rstrip": false,
1296
+ "single_word": false,
1297
+ "special": true
1298
+ },
1299
+ "256158": {
1300
+ "content": "<extra_id_141>",
1301
+ "lstrip": false,
1302
+ "normalized": false,
1303
+ "rstrip": false,
1304
+ "single_word": false,
1305
+ "special": true
1306
+ },
1307
+ "256159": {
1308
+ "content": "<extra_id_140>",
1309
+ "lstrip": false,
1310
+ "normalized": false,
1311
+ "rstrip": false,
1312
+ "single_word": false,
1313
+ "special": true
1314
+ },
1315
+ "256160": {
1316
+ "content": "<extra_id_139>",
1317
+ "lstrip": false,
1318
+ "normalized": false,
1319
+ "rstrip": false,
1320
+ "single_word": false,
1321
+ "special": true
1322
+ },
1323
+ "256161": {
1324
+ "content": "<extra_id_138>",
1325
+ "lstrip": false,
1326
+ "normalized": false,
1327
+ "rstrip": false,
1328
+ "single_word": false,
1329
+ "special": true
1330
+ },
1331
+ "256162": {
1332
+ "content": "<extra_id_137>",
1333
+ "lstrip": false,
1334
+ "normalized": false,
1335
+ "rstrip": false,
1336
+ "single_word": false,
1337
+ "special": true
1338
+ },
1339
+ "256163": {
1340
+ "content": "<extra_id_136>",
1341
+ "lstrip": false,
1342
+ "normalized": false,
1343
+ "rstrip": false,
1344
+ "single_word": false,
1345
+ "special": true
1346
+ },
1347
+ "256164": {
1348
+ "content": "<extra_id_135>",
1349
+ "lstrip": false,
1350
+ "normalized": false,
1351
+ "rstrip": false,
1352
+ "single_word": false,
1353
+ "special": true
1354
+ },
1355
+ "256165": {
1356
+ "content": "<extra_id_134>",
1357
+ "lstrip": false,
1358
+ "normalized": false,
1359
+ "rstrip": false,
1360
+ "single_word": false,
1361
+ "special": true
1362
+ },
1363
+ "256166": {
1364
+ "content": "<extra_id_133>",
1365
+ "lstrip": false,
1366
+ "normalized": false,
1367
+ "rstrip": false,
1368
+ "single_word": false,
1369
+ "special": true
1370
+ },
1371
+ "256167": {
1372
+ "content": "<extra_id_132>",
1373
+ "lstrip": false,
1374
+ "normalized": false,
1375
+ "rstrip": false,
1376
+ "single_word": false,
1377
+ "special": true
1378
+ },
1379
+ "256168": {
1380
+ "content": "<extra_id_131>",
1381
+ "lstrip": false,
1382
+ "normalized": false,
1383
+ "rstrip": false,
1384
+ "single_word": false,
1385
+ "special": true
1386
+ },
1387
+ "256169": {
1388
+ "content": "<extra_id_130>",
1389
+ "lstrip": false,
1390
+ "normalized": false,
1391
+ "rstrip": false,
1392
+ "single_word": false,
1393
+ "special": true
1394
+ },
1395
+ "256170": {
1396
+ "content": "<extra_id_129>",
1397
+ "lstrip": false,
1398
+ "normalized": false,
1399
+ "rstrip": false,
1400
+ "single_word": false,
1401
+ "special": true
1402
+ },
1403
+ "256171": {
1404
+ "content": "<extra_id_128>",
1405
+ "lstrip": false,
1406
+ "normalized": false,
1407
+ "rstrip": false,
1408
+ "single_word": false,
1409
+ "special": true
1410
+ },
1411
+ "256172": {
1412
+ "content": "<extra_id_127>",
1413
+ "lstrip": false,
1414
+ "normalized": false,
1415
+ "rstrip": false,
1416
+ "single_word": false,
1417
+ "special": true
1418
+ },
1419
+ "256173": {
1420
+ "content": "<extra_id_126>",
1421
+ "lstrip": false,
1422
+ "normalized": false,
1423
+ "rstrip": false,
1424
+ "single_word": false,
1425
+ "special": true
1426
+ },
1427
+ "256174": {
1428
+ "content": "<extra_id_125>",
1429
+ "lstrip": false,
1430
+ "normalized": false,
1431
+ "rstrip": false,
1432
+ "single_word": false,
1433
+ "special": true
1434
+ },
1435
+ "256175": {
1436
+ "content": "<extra_id_124>",
1437
+ "lstrip": false,
1438
+ "normalized": false,
1439
+ "rstrip": false,
1440
+ "single_word": false,
1441
+ "special": true
1442
+ },
1443
+ "256176": {
1444
+ "content": "<extra_id_123>",
1445
+ "lstrip": false,
1446
+ "normalized": false,
1447
+ "rstrip": false,
1448
+ "single_word": false,
1449
+ "special": true
1450
+ },
1451
+ "256177": {
1452
+ "content": "<extra_id_122>",
1453
+ "lstrip": false,
1454
+ "normalized": false,
1455
+ "rstrip": false,
1456
+ "single_word": false,
1457
+ "special": true
1458
+ },
1459
+ "256178": {
1460
+ "content": "<extra_id_121>",
1461
+ "lstrip": false,
1462
+ "normalized": false,
1463
+ "rstrip": false,
1464
+ "single_word": false,
1465
+ "special": true
1466
+ },
1467
+ "256179": {
1468
+ "content": "<extra_id_120>",
1469
+ "lstrip": false,
1470
+ "normalized": false,
1471
+ "rstrip": false,
1472
+ "single_word": false,
1473
+ "special": true
1474
+ },
1475
+ "256180": {
1476
+ "content": "<extra_id_119>",
1477
+ "lstrip": false,
1478
+ "normalized": false,
1479
+ "rstrip": false,
1480
+ "single_word": false,
1481
+ "special": true
1482
+ },
1483
+ "256181": {
1484
+ "content": "<extra_id_118>",
1485
+ "lstrip": false,
1486
+ "normalized": false,
1487
+ "rstrip": false,
1488
+ "single_word": false,
1489
+ "special": true
1490
+ },
1491
+ "256182": {
1492
+ "content": "<extra_id_117>",
1493
+ "lstrip": false,
1494
+ "normalized": false,
1495
+ "rstrip": false,
1496
+ "single_word": false,
1497
+ "special": true
1498
+ },
1499
+ "256183": {
1500
+ "content": "<extra_id_116>",
1501
+ "lstrip": false,
1502
+ "normalized": false,
1503
+ "rstrip": false,
1504
+ "single_word": false,
1505
+ "special": true
1506
+ },
1507
+ "256184": {
1508
+ "content": "<extra_id_115>",
1509
+ "lstrip": false,
1510
+ "normalized": false,
1511
+ "rstrip": false,
1512
+ "single_word": false,
1513
+ "special": true
1514
+ },
1515
+ "256185": {
1516
+ "content": "<extra_id_114>",
1517
+ "lstrip": false,
1518
+ "normalized": false,
1519
+ "rstrip": false,
1520
+ "single_word": false,
1521
+ "special": true
1522
+ },
1523
+ "256186": {
1524
+ "content": "<extra_id_113>",
1525
+ "lstrip": false,
1526
+ "normalized": false,
1527
+ "rstrip": false,
1528
+ "single_word": false,
1529
+ "special": true
1530
+ },
1531
+ "256187": {
1532
+ "content": "<extra_id_112>",
1533
+ "lstrip": false,
1534
+ "normalized": false,
1535
+ "rstrip": false,
1536
+ "single_word": false,
1537
+ "special": true
1538
+ },
1539
+ "256188": {
1540
+ "content": "<extra_id_111>",
1541
+ "lstrip": false,
1542
+ "normalized": false,
1543
+ "rstrip": false,
1544
+ "single_word": false,
1545
+ "special": true
1546
+ },
1547
+ "256189": {
1548
+ "content": "<extra_id_110>",
1549
+ "lstrip": false,
1550
+ "normalized": false,
1551
+ "rstrip": false,
1552
+ "single_word": false,
1553
+ "special": true
1554
+ },
1555
+ "256190": {
1556
+ "content": "<extra_id_109>",
1557
+ "lstrip": false,
1558
+ "normalized": false,
1559
+ "rstrip": false,
1560
+ "single_word": false,
1561
+ "special": true
1562
+ },
1563
+ "256191": {
1564
+ "content": "<extra_id_108>",
1565
+ "lstrip": false,
1566
+ "normalized": false,
1567
+ "rstrip": false,
1568
+ "single_word": false,
1569
+ "special": true
1570
+ },
1571
+ "256192": {
1572
+ "content": "<extra_id_107>",
1573
+ "lstrip": false,
1574
+ "normalized": false,
1575
+ "rstrip": false,
1576
+ "single_word": false,
1577
+ "special": true
1578
+ },
1579
+ "256193": {
1580
+ "content": "<extra_id_106>",
1581
+ "lstrip": false,
1582
+ "normalized": false,
1583
+ "rstrip": false,
1584
+ "single_word": false,
1585
+ "special": true
1586
+ },
1587
+ "256194": {
1588
+ "content": "<extra_id_105>",
1589
+ "lstrip": false,
1590
+ "normalized": false,
1591
+ "rstrip": false,
1592
+ "single_word": false,
1593
+ "special": true
1594
+ },
1595
+ "256195": {
1596
+ "content": "<extra_id_104>",
1597
+ "lstrip": false,
1598
+ "normalized": false,
1599
+ "rstrip": false,
1600
+ "single_word": false,
1601
+ "special": true
1602
+ },
1603
+ "256196": {
1604
+ "content": "<extra_id_103>",
1605
+ "lstrip": false,
1606
+ "normalized": false,
1607
+ "rstrip": false,
1608
+ "single_word": false,
1609
+ "special": true
1610
+ },
1611
+ "256197": {
1612
+ "content": "<extra_id_102>",
1613
+ "lstrip": false,
1614
+ "normalized": false,
1615
+ "rstrip": false,
1616
+ "single_word": false,
1617
+ "special": true
1618
+ },
1619
+ "256198": {
1620
+ "content": "<extra_id_101>",
1621
+ "lstrip": false,
1622
+ "normalized": false,
1623
+ "rstrip": false,
1624
+ "single_word": false,
1625
+ "special": true
1626
+ },
1627
+ "256199": {
1628
+ "content": "<extra_id_100>",
1629
+ "lstrip": false,
1630
+ "normalized": false,
1631
+ "rstrip": false,
1632
+ "single_word": false,
1633
+ "special": true
1634
+ },
1635
+ "256200": {
1636
+ "content": "<extra_id_99>",
1637
+ "lstrip": false,
1638
+ "normalized": false,
1639
+ "rstrip": false,
1640
+ "single_word": false,
1641
+ "special": true
1642
+ },
1643
+ "256201": {
1644
+ "content": "<extra_id_98>",
1645
+ "lstrip": false,
1646
+ "normalized": false,
1647
+ "rstrip": false,
1648
+ "single_word": false,
1649
+ "special": true
1650
+ },
1651
+ "256202": {
1652
+ "content": "<extra_id_97>",
1653
+ "lstrip": false,
1654
+ "normalized": false,
1655
+ "rstrip": false,
1656
+ "single_word": false,
1657
+ "special": true
1658
+ },
1659
+ "256203": {
1660
+ "content": "<extra_id_96>",
1661
+ "lstrip": false,
1662
+ "normalized": false,
1663
+ "rstrip": false,
1664
+ "single_word": false,
1665
+ "special": true
1666
+ },
1667
+ "256204": {
1668
+ "content": "<extra_id_95>",
1669
+ "lstrip": false,
1670
+ "normalized": false,
1671
+ "rstrip": false,
1672
+ "single_word": false,
1673
+ "special": true
1674
+ },
1675
+ "256205": {
1676
+ "content": "<extra_id_94>",
1677
+ "lstrip": false,
1678
+ "normalized": false,
1679
+ "rstrip": false,
1680
+ "single_word": false,
1681
+ "special": true
1682
+ },
1683
+ "256206": {
1684
+ "content": "<extra_id_93>",
1685
+ "lstrip": false,
1686
+ "normalized": false,
1687
+ "rstrip": false,
1688
+ "single_word": false,
1689
+ "special": true
1690
+ },
1691
+ "256207": {
1692
+ "content": "<extra_id_92>",
1693
+ "lstrip": false,
1694
+ "normalized": false,
1695
+ "rstrip": false,
1696
+ "single_word": false,
1697
+ "special": true
1698
+ },
1699
+ "256208": {
1700
+ "content": "<extra_id_91>",
1701
+ "lstrip": false,
1702
+ "normalized": false,
1703
+ "rstrip": false,
1704
+ "single_word": false,
1705
+ "special": true
1706
+ },
1707
+ "256209": {
1708
+ "content": "<extra_id_90>",
1709
+ "lstrip": false,
1710
+ "normalized": false,
1711
+ "rstrip": false,
1712
+ "single_word": false,
1713
+ "special": true
1714
+ },
1715
+ "256210": {
1716
+ "content": "<extra_id_89>",
1717
+ "lstrip": false,
1718
+ "normalized": false,
1719
+ "rstrip": false,
1720
+ "single_word": false,
1721
+ "special": true
1722
+ },
1723
+ "256211": {
1724
+ "content": "<extra_id_88>",
1725
+ "lstrip": false,
1726
+ "normalized": false,
1727
+ "rstrip": false,
1728
+ "single_word": false,
1729
+ "special": true
1730
+ },
1731
+ "256212": {
1732
+ "content": "<extra_id_87>",
1733
+ "lstrip": false,
1734
+ "normalized": false,
1735
+ "rstrip": false,
1736
+ "single_word": false,
1737
+ "special": true
1738
+ },
1739
+ "256213": {
1740
+ "content": "<extra_id_86>",
1741
+ "lstrip": false,
1742
+ "normalized": false,
1743
+ "rstrip": false,
1744
+ "single_word": false,
1745
+ "special": true
1746
+ },
1747
+ "256214": {
1748
+ "content": "<extra_id_85>",
1749
+ "lstrip": false,
1750
+ "normalized": false,
1751
+ "rstrip": false,
1752
+ "single_word": false,
1753
+ "special": true
1754
+ },
1755
+ "256215": {
1756
+ "content": "<extra_id_84>",
1757
+ "lstrip": false,
1758
+ "normalized": false,
1759
+ "rstrip": false,
1760
+ "single_word": false,
1761
+ "special": true
1762
+ },
1763
+ "256216": {
1764
+ "content": "<extra_id_83>",
1765
+ "lstrip": false,
1766
+ "normalized": false,
1767
+ "rstrip": false,
1768
+ "single_word": false,
1769
+ "special": true
1770
+ },
1771
+ "256217": {
1772
+ "content": "<extra_id_82>",
1773
+ "lstrip": false,
1774
+ "normalized": false,
1775
+ "rstrip": false,
1776
+ "single_word": false,
1777
+ "special": true
1778
+ },
1779
+ "256218": {
1780
+ "content": "<extra_id_81>",
1781
+ "lstrip": false,
1782
+ "normalized": false,
1783
+ "rstrip": false,
1784
+ "single_word": false,
1785
+ "special": true
1786
+ },
1787
+ "256219": {
1788
+ "content": "<extra_id_80>",
1789
+ "lstrip": false,
1790
+ "normalized": false,
1791
+ "rstrip": false,
1792
+ "single_word": false,
1793
+ "special": true
1794
+ },
1795
+ "256220": {
1796
+ "content": "<extra_id_79>",
1797
+ "lstrip": false,
1798
+ "normalized": false,
1799
+ "rstrip": false,
1800
+ "single_word": false,
1801
+ "special": true
1802
+ },
1803
+ "256221": {
1804
+ "content": "<extra_id_78>",
1805
+ "lstrip": false,
1806
+ "normalized": false,
1807
+ "rstrip": false,
1808
+ "single_word": false,
1809
+ "special": true
1810
+ },
1811
+ "256222": {
1812
+ "content": "<extra_id_77>",
1813
+ "lstrip": false,
1814
+ "normalized": false,
1815
+ "rstrip": false,
1816
+ "single_word": false,
1817
+ "special": true
1818
+ },
1819
+ "256223": {
1820
+ "content": "<extra_id_76>",
1821
+ "lstrip": false,
1822
+ "normalized": false,
1823
+ "rstrip": false,
1824
+ "single_word": false,
1825
+ "special": true
1826
+ },
1827
+ "256224": {
1828
+ "content": "<extra_id_75>",
1829
+ "lstrip": false,
1830
+ "normalized": false,
1831
+ "rstrip": false,
1832
+ "single_word": false,
1833
+ "special": true
1834
+ },
1835
+ "256225": {
1836
+ "content": "<extra_id_74>",
1837
+ "lstrip": false,
1838
+ "normalized": false,
1839
+ "rstrip": false,
1840
+ "single_word": false,
1841
+ "special": true
1842
+ },
1843
+ "256226": {
1844
+ "content": "<extra_id_73>",
1845
+ "lstrip": false,
1846
+ "normalized": false,
1847
+ "rstrip": false,
1848
+ "single_word": false,
1849
+ "special": true
1850
+ },
1851
+ "256227": {
1852
+ "content": "<extra_id_72>",
1853
+ "lstrip": false,
1854
+ "normalized": false,
1855
+ "rstrip": false,
1856
+ "single_word": false,
1857
+ "special": true
1858
+ },
1859
+ "256228": {
1860
+ "content": "<extra_id_71>",
1861
+ "lstrip": false,
1862
+ "normalized": false,
1863
+ "rstrip": false,
1864
+ "single_word": false,
1865
+ "special": true
1866
+ },
1867
+ "256229": {
1868
+ "content": "<extra_id_70>",
1869
+ "lstrip": false,
1870
+ "normalized": false,
1871
+ "rstrip": false,
1872
+ "single_word": false,
1873
+ "special": true
1874
+ },
1875
+ "256230": {
1876
+ "content": "<extra_id_69>",
1877
+ "lstrip": false,
1878
+ "normalized": false,
1879
+ "rstrip": false,
1880
+ "single_word": false,
1881
+ "special": true
1882
+ },
1883
+ "256231": {
1884
+ "content": "<extra_id_68>",
1885
+ "lstrip": false,
1886
+ "normalized": false,
1887
+ "rstrip": false,
1888
+ "single_word": false,
1889
+ "special": true
1890
+ },
1891
+ "256232": {
1892
+ "content": "<extra_id_67>",
1893
+ "lstrip": false,
1894
+ "normalized": false,
1895
+ "rstrip": false,
1896
+ "single_word": false,
1897
+ "special": true
1898
+ },
1899
+ "256233": {
1900
+ "content": "<extra_id_66>",
1901
+ "lstrip": false,
1902
+ "normalized": false,
1903
+ "rstrip": false,
1904
+ "single_word": false,
1905
+ "special": true
1906
+ },
1907
+ "256234": {
1908
+ "content": "<extra_id_65>",
1909
+ "lstrip": false,
1910
+ "normalized": false,
1911
+ "rstrip": false,
1912
+ "single_word": false,
1913
+ "special": true
1914
+ },
1915
+ "256235": {
1916
+ "content": "<extra_id_64>",
1917
+ "lstrip": false,
1918
+ "normalized": false,
1919
+ "rstrip": false,
1920
+ "single_word": false,
1921
+ "special": true
1922
+ },
1923
+ "256236": {
1924
+ "content": "<extra_id_63>",
1925
+ "lstrip": false,
1926
+ "normalized": false,
1927
+ "rstrip": false,
1928
+ "single_word": false,
1929
+ "special": true
1930
+ },
1931
+ "256237": {
1932
+ "content": "<extra_id_62>",
1933
+ "lstrip": false,
1934
+ "normalized": false,
1935
+ "rstrip": false,
1936
+ "single_word": false,
1937
+ "special": true
1938
+ },
1939
+ "256238": {
1940
+ "content": "<extra_id_61>",
1941
+ "lstrip": false,
1942
+ "normalized": false,
1943
+ "rstrip": false,
1944
+ "single_word": false,
1945
+ "special": true
1946
+ },
1947
+ "256239": {
1948
+ "content": "<extra_id_60>",
1949
+ "lstrip": false,
1950
+ "normalized": false,
1951
+ "rstrip": false,
1952
+ "single_word": false,
1953
+ "special": true
1954
+ },
1955
+ "256240": {
1956
+ "content": "<extra_id_59>",
1957
+ "lstrip": false,
1958
+ "normalized": false,
1959
+ "rstrip": false,
1960
+ "single_word": false,
1961
+ "special": true
1962
+ },
1963
+ "256241": {
1964
+ "content": "<extra_id_58>",
1965
+ "lstrip": false,
1966
+ "normalized": false,
1967
+ "rstrip": false,
1968
+ "single_word": false,
1969
+ "special": true
1970
+ },
1971
+ "256242": {
1972
+ "content": "<extra_id_57>",
1973
+ "lstrip": false,
1974
+ "normalized": false,
1975
+ "rstrip": false,
1976
+ "single_word": false,
1977
+ "special": true
1978
+ },
1979
+ "256243": {
1980
+ "content": "<extra_id_56>",
1981
+ "lstrip": false,
1982
+ "normalized": false,
1983
+ "rstrip": false,
1984
+ "single_word": false,
1985
+ "special": true
1986
+ },
1987
+ "256244": {
1988
+ "content": "<extra_id_55>",
1989
+ "lstrip": false,
1990
+ "normalized": false,
1991
+ "rstrip": false,
1992
+ "single_word": false,
1993
+ "special": true
1994
+ },
1995
+ "256245": {
1996
+ "content": "<extra_id_54>",
1997
+ "lstrip": false,
1998
+ "normalized": false,
1999
+ "rstrip": false,
2000
+ "single_word": false,
2001
+ "special": true
2002
+ },
2003
+ "256246": {
2004
+ "content": "<extra_id_53>",
2005
+ "lstrip": false,
2006
+ "normalized": false,
2007
+ "rstrip": false,
2008
+ "single_word": false,
2009
+ "special": true
2010
+ },
2011
+ "256247": {
2012
+ "content": "<extra_id_52>",
2013
+ "lstrip": false,
2014
+ "normalized": false,
2015
+ "rstrip": false,
2016
+ "single_word": false,
2017
+ "special": true
2018
+ },
2019
+ "256248": {
2020
+ "content": "<extra_id_51>",
2021
+ "lstrip": false,
2022
+ "normalized": false,
2023
+ "rstrip": false,
2024
+ "single_word": false,
2025
+ "special": true
2026
+ },
2027
+ "256249": {
2028
+ "content": "<extra_id_50>",
2029
+ "lstrip": false,
2030
+ "normalized": false,
2031
+ "rstrip": false,
2032
+ "single_word": false,
2033
+ "special": true
2034
+ },
2035
+ "256250": {
2036
+ "content": "<extra_id_49>",
2037
+ "lstrip": false,
2038
+ "normalized": false,
2039
+ "rstrip": false,
2040
+ "single_word": false,
2041
+ "special": true
2042
+ },
2043
+ "256251": {
2044
+ "content": "<extra_id_48>",
2045
+ "lstrip": false,
2046
+ "normalized": false,
2047
+ "rstrip": false,
2048
+ "single_word": false,
2049
+ "special": true
2050
+ },
2051
+ "256252": {
2052
+ "content": "<extra_id_47>",
2053
+ "lstrip": false,
2054
+ "normalized": false,
2055
+ "rstrip": false,
2056
+ "single_word": false,
2057
+ "special": true
2058
+ },
2059
+ "256253": {
2060
+ "content": "<extra_id_46>",
2061
+ "lstrip": false,
2062
+ "normalized": false,
2063
+ "rstrip": false,
2064
+ "single_word": false,
2065
+ "special": true
2066
+ },
2067
+ "256254": {
2068
+ "content": "<extra_id_45>",
2069
+ "lstrip": false,
2070
+ "normalized": false,
2071
+ "rstrip": false,
2072
+ "single_word": false,
2073
+ "special": true
2074
+ },
2075
+ "256255": {
2076
+ "content": "<extra_id_44>",
2077
+ "lstrip": false,
2078
+ "normalized": false,
2079
+ "rstrip": false,
2080
+ "single_word": false,
2081
+ "special": true
2082
+ },
2083
+ "256256": {
2084
+ "content": "<extra_id_43>",
2085
+ "lstrip": false,
2086
+ "normalized": false,
2087
+ "rstrip": false,
2088
+ "single_word": false,
2089
+ "special": true
2090
+ },
2091
+ "256257": {
2092
+ "content": "<extra_id_42>",
2093
+ "lstrip": false,
2094
+ "normalized": false,
2095
+ "rstrip": false,
2096
+ "single_word": false,
2097
+ "special": true
2098
+ },
2099
+ "256258": {
2100
+ "content": "<extra_id_41>",
2101
+ "lstrip": false,
2102
+ "normalized": false,
2103
+ "rstrip": false,
2104
+ "single_word": false,
2105
+ "special": true
2106
+ },
2107
+ "256259": {
2108
+ "content": "<extra_id_40>",
2109
+ "lstrip": false,
2110
+ "normalized": false,
2111
+ "rstrip": false,
2112
+ "single_word": false,
2113
+ "special": true
2114
+ },
2115
+ "256260": {
2116
+ "content": "<extra_id_39>",
2117
+ "lstrip": false,
2118
+ "normalized": false,
2119
+ "rstrip": false,
2120
+ "single_word": false,
2121
+ "special": true
2122
+ },
2123
+ "256261": {
2124
+ "content": "<extra_id_38>",
2125
+ "lstrip": false,
2126
+ "normalized": false,
2127
+ "rstrip": false,
2128
+ "single_word": false,
2129
+ "special": true
2130
+ },
2131
+ "256262": {
2132
+ "content": "<extra_id_37>",
2133
+ "lstrip": false,
2134
+ "normalized": false,
2135
+ "rstrip": false,
2136
+ "single_word": false,
2137
+ "special": true
2138
+ },
2139
+ "256263": {
2140
+ "content": "<extra_id_36>",
2141
+ "lstrip": false,
2142
+ "normalized": false,
2143
+ "rstrip": false,
2144
+ "single_word": false,
2145
+ "special": true
2146
+ },
2147
+ "256264": {
2148
+ "content": "<extra_id_35>",
2149
+ "lstrip": false,
2150
+ "normalized": false,
2151
+ "rstrip": false,
2152
+ "single_word": false,
2153
+ "special": true
2154
+ },
2155
+ "256265": {
2156
+ "content": "<extra_id_34>",
2157
+ "lstrip": false,
2158
+ "normalized": false,
2159
+ "rstrip": false,
2160
+ "single_word": false,
2161
+ "special": true
2162
+ },
2163
+ "256266": {
2164
+ "content": "<extra_id_33>",
2165
+ "lstrip": false,
2166
+ "normalized": false,
2167
+ "rstrip": false,
2168
+ "single_word": false,
2169
+ "special": true
2170
+ },
2171
+ "256267": {
2172
+ "content": "<extra_id_32>",
2173
+ "lstrip": false,
2174
+ "normalized": false,
2175
+ "rstrip": false,
2176
+ "single_word": false,
2177
+ "special": true
2178
+ },
2179
+ "256268": {
2180
+ "content": "<extra_id_31>",
2181
+ "lstrip": false,
2182
+ "normalized": false,
2183
+ "rstrip": false,
2184
+ "single_word": false,
2185
+ "special": true
2186
+ },
2187
+ "256269": {
2188
+ "content": "<extra_id_30>",
2189
+ "lstrip": false,
2190
+ "normalized": false,
2191
+ "rstrip": false,
2192
+ "single_word": false,
2193
+ "special": true
2194
+ },
2195
+ "256270": {
2196
+ "content": "<extra_id_29>",
2197
+ "lstrip": false,
2198
+ "normalized": false,
2199
+ "rstrip": false,
2200
+ "single_word": false,
2201
+ "special": true
2202
+ },
2203
+ "256271": {
2204
+ "content": "<extra_id_28>",
2205
+ "lstrip": false,
2206
+ "normalized": false,
2207
+ "rstrip": false,
2208
+ "single_word": false,
2209
+ "special": true
2210
+ },
2211
+ "256272": {
2212
+ "content": "<extra_id_27>",
2213
+ "lstrip": false,
2214
+ "normalized": false,
2215
+ "rstrip": false,
2216
+ "single_word": false,
2217
+ "special": true
2218
+ },
2219
+ "256273": {
2220
+ "content": "<extra_id_26>",
2221
+ "lstrip": false,
2222
+ "normalized": false,
2223
+ "rstrip": false,
2224
+ "single_word": false,
2225
+ "special": true
2226
+ },
2227
+ "256274": {
2228
+ "content": "<extra_id_25>",
2229
+ "lstrip": false,
2230
+ "normalized": false,
2231
+ "rstrip": false,
2232
+ "single_word": false,
2233
+ "special": true
2234
+ },
2235
+ "256275": {
2236
+ "content": "<extra_id_24>",
2237
+ "lstrip": false,
2238
+ "normalized": false,
2239
+ "rstrip": false,
2240
+ "single_word": false,
2241
+ "special": true
2242
+ },
2243
+ "256276": {
2244
+ "content": "<extra_id_23>",
2245
+ "lstrip": false,
2246
+ "normalized": false,
2247
+ "rstrip": false,
2248
+ "single_word": false,
2249
+ "special": true
2250
+ },
2251
+ "256277": {
2252
+ "content": "<extra_id_22>",
2253
+ "lstrip": false,
2254
+ "normalized": false,
2255
+ "rstrip": false,
2256
+ "single_word": false,
2257
+ "special": true
2258
+ },
2259
+ "256278": {
2260
+ "content": "<extra_id_21>",
2261
+ "lstrip": false,
2262
+ "normalized": false,
2263
+ "rstrip": false,
2264
+ "single_word": false,
2265
+ "special": true
2266
+ },
2267
+ "256279": {
2268
+ "content": "<extra_id_20>",
2269
+ "lstrip": false,
2270
+ "normalized": false,
2271
+ "rstrip": false,
2272
+ "single_word": false,
2273
+ "special": true
2274
+ },
2275
+ "256280": {
2276
+ "content": "<extra_id_19>",
2277
+ "lstrip": false,
2278
+ "normalized": false,
2279
+ "rstrip": false,
2280
+ "single_word": false,
2281
+ "special": true
2282
+ },
2283
+ "256281": {
2284
+ "content": "<extra_id_18>",
2285
+ "lstrip": false,
2286
+ "normalized": false,
2287
+ "rstrip": false,
2288
+ "single_word": false,
2289
+ "special": true
2290
+ },
2291
+ "256282": {
2292
+ "content": "<extra_id_17>",
2293
+ "lstrip": false,
2294
+ "normalized": false,
2295
+ "rstrip": false,
2296
+ "single_word": false,
2297
+ "special": true
2298
+ },
2299
+ "256283": {
2300
+ "content": "<extra_id_16>",
2301
+ "lstrip": false,
2302
+ "normalized": false,
2303
+ "rstrip": false,
2304
+ "single_word": false,
2305
+ "special": true
2306
+ },
2307
+ "256284": {
2308
+ "content": "<extra_id_15>",
2309
+ "lstrip": false,
2310
+ "normalized": false,
2311
+ "rstrip": false,
2312
+ "single_word": false,
2313
+ "special": true
2314
+ },
2315
+ "256285": {
2316
+ "content": "<extra_id_14>",
2317
+ "lstrip": false,
2318
+ "normalized": false,
2319
+ "rstrip": false,
2320
+ "single_word": false,
2321
+ "special": true
2322
+ },
2323
+ "256286": {
2324
+ "content": "<extra_id_13>",
2325
+ "lstrip": false,
2326
+ "normalized": false,
2327
+ "rstrip": false,
2328
+ "single_word": false,
2329
+ "special": true
2330
+ },
2331
+ "256287": {
2332
+ "content": "<extra_id_12>",
2333
+ "lstrip": false,
2334
+ "normalized": false,
2335
+ "rstrip": false,
2336
+ "single_word": false,
2337
+ "special": true
2338
+ },
2339
+ "256288": {
2340
+ "content": "<extra_id_11>",
2341
+ "lstrip": false,
2342
+ "normalized": false,
2343
+ "rstrip": false,
2344
+ "single_word": false,
2345
+ "special": true
2346
+ },
2347
+ "256289": {
2348
+ "content": "<extra_id_10>",
2349
+ "lstrip": false,
2350
+ "normalized": false,
2351
+ "rstrip": false,
2352
+ "single_word": false,
2353
+ "special": true
2354
+ },
2355
+ "256290": {
2356
+ "content": "<extra_id_9>",
2357
+ "lstrip": false,
2358
+ "normalized": false,
2359
+ "rstrip": false,
2360
+ "single_word": false,
2361
+ "special": true
2362
+ },
2363
+ "256291": {
2364
+ "content": "<extra_id_8>",
2365
+ "lstrip": false,
2366
+ "normalized": false,
2367
+ "rstrip": false,
2368
+ "single_word": false,
2369
+ "special": true
2370
+ },
2371
+ "256292": {
2372
+ "content": "<extra_id_7>",
2373
+ "lstrip": false,
2374
+ "normalized": false,
2375
+ "rstrip": false,
2376
+ "single_word": false,
2377
+ "special": true
2378
+ },
2379
+ "256293": {
2380
+ "content": "<extra_id_6>",
2381
+ "lstrip": false,
2382
+ "normalized": false,
2383
+ "rstrip": false,
2384
+ "single_word": false,
2385
+ "special": true
2386
+ },
2387
+ "256294": {
2388
+ "content": "<extra_id_5>",
2389
+ "lstrip": false,
2390
+ "normalized": false,
2391
+ "rstrip": false,
2392
+ "single_word": false,
2393
+ "special": true
2394
+ },
2395
+ "256295": {
2396
+ "content": "<extra_id_4>",
2397
+ "lstrip": false,
2398
+ "normalized": false,
2399
+ "rstrip": false,
2400
+ "single_word": false,
2401
+ "special": true
2402
+ },
2403
+ "256296": {
2404
+ "content": "<extra_id_3>",
2405
+ "lstrip": false,
2406
+ "normalized": false,
2407
+ "rstrip": false,
2408
+ "single_word": false,
2409
+ "special": true
2410
+ },
2411
+ "256297": {
2412
+ "content": "<extra_id_2>",
2413
+ "lstrip": false,
2414
+ "normalized": false,
2415
+ "rstrip": false,
2416
+ "single_word": false,
2417
+ "special": true
2418
+ },
2419
+ "256298": {
2420
+ "content": "<extra_id_1>",
2421
+ "lstrip": false,
2422
+ "normalized": false,
2423
+ "rstrip": false,
2424
+ "single_word": false,
2425
+ "special": true
2426
+ },
2427
+ "256299": {
2428
+ "content": "<extra_id_0>",
2429
+ "lstrip": false,
2430
+ "normalized": false,
2431
+ "rstrip": false,
2432
+ "single_word": false,
2433
+ "special": true
2434
+ }
2435
+ },
2436
+ "additional_special_tokens": [
2437
+ "<extra_id_0>",
2438
+ "<extra_id_1>",
2439
+ "<extra_id_2>",
2440
+ "<extra_id_3>",
2441
+ "<extra_id_4>",
2442
+ "<extra_id_5>",
2443
+ "<extra_id_6>",
2444
+ "<extra_id_7>",
2445
+ "<extra_id_8>",
2446
+ "<extra_id_9>",
2447
+ "<extra_id_10>",
2448
+ "<extra_id_11>",
2449
+ "<extra_id_12>",
2450
+ "<extra_id_13>",
2451
+ "<extra_id_14>",
2452
+ "<extra_id_15>",
2453
+ "<extra_id_16>",
2454
+ "<extra_id_17>",
2455
+ "<extra_id_18>",
2456
+ "<extra_id_19>",
2457
+ "<extra_id_20>",
2458
+ "<extra_id_21>",
2459
+ "<extra_id_22>",
2460
+ "<extra_id_23>",
2461
+ "<extra_id_24>",
2462
+ "<extra_id_25>",
2463
+ "<extra_id_26>",
2464
+ "<extra_id_27>",
2465
+ "<extra_id_28>",
2466
+ "<extra_id_29>",
2467
+ "<extra_id_30>",
2468
+ "<extra_id_31>",
2469
+ "<extra_id_32>",
2470
+ "<extra_id_33>",
2471
+ "<extra_id_34>",
2472
+ "<extra_id_35>",
2473
+ "<extra_id_36>",
2474
+ "<extra_id_37>",
2475
+ "<extra_id_38>",
2476
+ "<extra_id_39>",
2477
+ "<extra_id_40>",
2478
+ "<extra_id_41>",
2479
+ "<extra_id_42>",
2480
+ "<extra_id_43>",
2481
+ "<extra_id_44>",
2482
+ "<extra_id_45>",
2483
+ "<extra_id_46>",
2484
+ "<extra_id_47>",
2485
+ "<extra_id_48>",
2486
+ "<extra_id_49>",
2487
+ "<extra_id_50>",
2488
+ "<extra_id_51>",
2489
+ "<extra_id_52>",
2490
+ "<extra_id_53>",
2491
+ "<extra_id_54>",
2492
+ "<extra_id_55>",
2493
+ "<extra_id_56>",
2494
+ "<extra_id_57>",
2495
+ "<extra_id_58>",
2496
+ "<extra_id_59>",
2497
+ "<extra_id_60>",
2498
+ "<extra_id_61>",
2499
+ "<extra_id_62>",
2500
+ "<extra_id_63>",
2501
+ "<extra_id_64>",
2502
+ "<extra_id_65>",
2503
+ "<extra_id_66>",
2504
+ "<extra_id_67>",
2505
+ "<extra_id_68>",
2506
+ "<extra_id_69>",
2507
+ "<extra_id_70>",
2508
+ "<extra_id_71>",
2509
+ "<extra_id_72>",
2510
+ "<extra_id_73>",
2511
+ "<extra_id_74>",
2512
+ "<extra_id_75>",
2513
+ "<extra_id_76>",
2514
+ "<extra_id_77>",
2515
+ "<extra_id_78>",
2516
+ "<extra_id_79>",
2517
+ "<extra_id_80>",
2518
+ "<extra_id_81>",
2519
+ "<extra_id_82>",
2520
+ "<extra_id_83>",
2521
+ "<extra_id_84>",
2522
+ "<extra_id_85>",
2523
+ "<extra_id_86>",
2524
+ "<extra_id_87>",
2525
+ "<extra_id_88>",
2526
+ "<extra_id_89>",
2527
+ "<extra_id_90>",
2528
+ "<extra_id_91>",
2529
+ "<extra_id_92>",
2530
+ "<extra_id_93>",
2531
+ "<extra_id_94>",
2532
+ "<extra_id_95>",
2533
+ "<extra_id_96>",
2534
+ "<extra_id_97>",
2535
+ "<extra_id_98>",
2536
+ "<extra_id_99>",
2537
+ "<extra_id_100>",
2538
+ "<extra_id_101>",
2539
+ "<extra_id_102>",
2540
+ "<extra_id_103>",
2541
+ "<extra_id_104>",
2542
+ "<extra_id_105>",
2543
+ "<extra_id_106>",
2544
+ "<extra_id_107>",
2545
+ "<extra_id_108>",
2546
+ "<extra_id_109>",
2547
+ "<extra_id_110>",
2548
+ "<extra_id_111>",
2549
+ "<extra_id_112>",
2550
+ "<extra_id_113>",
2551
+ "<extra_id_114>",
2552
+ "<extra_id_115>",
2553
+ "<extra_id_116>",
2554
+ "<extra_id_117>",
2555
+ "<extra_id_118>",
2556
+ "<extra_id_119>",
2557
+ "<extra_id_120>",
2558
+ "<extra_id_121>",
2559
+ "<extra_id_122>",
2560
+ "<extra_id_123>",
2561
+ "<extra_id_124>",
2562
+ "<extra_id_125>",
2563
+ "<extra_id_126>",
2564
+ "<extra_id_127>",
2565
+ "<extra_id_128>",
2566
+ "<extra_id_129>",
2567
+ "<extra_id_130>",
2568
+ "<extra_id_131>",
2569
+ "<extra_id_132>",
2570
+ "<extra_id_133>",
2571
+ "<extra_id_134>",
2572
+ "<extra_id_135>",
2573
+ "<extra_id_136>",
2574
+ "<extra_id_137>",
2575
+ "<extra_id_138>",
2576
+ "<extra_id_139>",
2577
+ "<extra_id_140>",
2578
+ "<extra_id_141>",
2579
+ "<extra_id_142>",
2580
+ "<extra_id_143>",
2581
+ "<extra_id_144>",
2582
+ "<extra_id_145>",
2583
+ "<extra_id_146>",
2584
+ "<extra_id_147>",
2585
+ "<extra_id_148>",
2586
+ "<extra_id_149>",
2587
+ "<extra_id_150>",
2588
+ "<extra_id_151>",
2589
+ "<extra_id_152>",
2590
+ "<extra_id_153>",
2591
+ "<extra_id_154>",
2592
+ "<extra_id_155>",
2593
+ "<extra_id_156>",
2594
+ "<extra_id_157>",
2595
+ "<extra_id_158>",
2596
+ "<extra_id_159>",
2597
+ "<extra_id_160>",
2598
+ "<extra_id_161>",
2599
+ "<extra_id_162>",
2600
+ "<extra_id_163>",
2601
+ "<extra_id_164>",
2602
+ "<extra_id_165>",
2603
+ "<extra_id_166>",
2604
+ "<extra_id_167>",
2605
+ "<extra_id_168>",
2606
+ "<extra_id_169>",
2607
+ "<extra_id_170>",
2608
+ "<extra_id_171>",
2609
+ "<extra_id_172>",
2610
+ "<extra_id_173>",
2611
+ "<extra_id_174>",
2612
+ "<extra_id_175>",
2613
+ "<extra_id_176>",
2614
+ "<extra_id_177>",
2615
+ "<extra_id_178>",
2616
+ "<extra_id_179>",
2617
+ "<extra_id_180>",
2618
+ "<extra_id_181>",
2619
+ "<extra_id_182>",
2620
+ "<extra_id_183>",
2621
+ "<extra_id_184>",
2622
+ "<extra_id_185>",
2623
+ "<extra_id_186>",
2624
+ "<extra_id_187>",
2625
+ "<extra_id_188>",
2626
+ "<extra_id_189>",
2627
+ "<extra_id_190>",
2628
+ "<extra_id_191>",
2629
+ "<extra_id_192>",
2630
+ "<extra_id_193>",
2631
+ "<extra_id_194>",
2632
+ "<extra_id_195>",
2633
+ "<extra_id_196>",
2634
+ "<extra_id_197>",
2635
+ "<extra_id_198>",
2636
+ "<extra_id_199>",
2637
+ "<extra_id_200>",
2638
+ "<extra_id_201>",
2639
+ "<extra_id_202>",
2640
+ "<extra_id_203>",
2641
+ "<extra_id_204>",
2642
+ "<extra_id_205>",
2643
+ "<extra_id_206>",
2644
+ "<extra_id_207>",
2645
+ "<extra_id_208>",
2646
+ "<extra_id_209>",
2647
+ "<extra_id_210>",
2648
+ "<extra_id_211>",
2649
+ "<extra_id_212>",
2650
+ "<extra_id_213>",
2651
+ "<extra_id_214>",
2652
+ "<extra_id_215>",
2653
+ "<extra_id_216>",
2654
+ "<extra_id_217>",
2655
+ "<extra_id_218>",
2656
+ "<extra_id_219>",
2657
+ "<extra_id_220>",
2658
+ "<extra_id_221>",
2659
+ "<extra_id_222>",
2660
+ "<extra_id_223>",
2661
+ "<extra_id_224>",
2662
+ "<extra_id_225>",
2663
+ "<extra_id_226>",
2664
+ "<extra_id_227>",
2665
+ "<extra_id_228>",
2666
+ "<extra_id_229>",
2667
+ "<extra_id_230>",
2668
+ "<extra_id_231>",
2669
+ "<extra_id_232>",
2670
+ "<extra_id_233>",
2671
+ "<extra_id_234>",
2672
+ "<extra_id_235>",
2673
+ "<extra_id_236>",
2674
+ "<extra_id_237>",
2675
+ "<extra_id_238>",
2676
+ "<extra_id_239>",
2677
+ "<extra_id_240>",
2678
+ "<extra_id_241>",
2679
+ "<extra_id_242>",
2680
+ "<extra_id_243>",
2681
+ "<extra_id_244>",
2682
+ "<extra_id_245>",
2683
+ "<extra_id_246>",
2684
+ "<extra_id_247>",
2685
+ "<extra_id_248>",
2686
+ "<extra_id_249>",
2687
+ "<extra_id_250>",
2688
+ "<extra_id_251>",
2689
+ "<extra_id_252>",
2690
+ "<extra_id_253>",
2691
+ "<extra_id_254>",
2692
+ "<extra_id_255>",
2693
+ "<extra_id_256>",
2694
+ "<extra_id_257>",
2695
+ "<extra_id_258>",
2696
+ "<extra_id_259>",
2697
+ "<extra_id_260>",
2698
+ "<extra_id_261>",
2699
+ "<extra_id_262>",
2700
+ "<extra_id_263>",
2701
+ "<extra_id_264>",
2702
+ "<extra_id_265>",
2703
+ "<extra_id_266>",
2704
+ "<extra_id_267>",
2705
+ "<extra_id_268>",
2706
+ "<extra_id_269>",
2707
+ "<extra_id_270>",
2708
+ "<extra_id_271>",
2709
+ "<extra_id_272>",
2710
+ "<extra_id_273>",
2711
+ "<extra_id_274>",
2712
+ "<extra_id_275>",
2713
+ "<extra_id_276>",
2714
+ "<extra_id_277>",
2715
+ "<extra_id_278>",
2716
+ "<extra_id_279>",
2717
+ "<extra_id_280>",
2718
+ "<extra_id_281>",
2719
+ "<extra_id_282>",
2720
+ "<extra_id_283>",
2721
+ "<extra_id_284>",
2722
+ "<extra_id_285>",
2723
+ "<extra_id_286>",
2724
+ "<extra_id_287>",
2725
+ "<extra_id_288>",
2726
+ "<extra_id_289>",
2727
+ "<extra_id_290>",
2728
+ "<extra_id_291>",
2729
+ "<extra_id_292>",
2730
+ "<extra_id_293>",
2731
+ "<extra_id_294>",
2732
+ "<extra_id_295>",
2733
+ "<extra_id_296>",
2734
+ "<extra_id_297>",
2735
+ "<extra_id_298>",
2736
+ "<extra_id_299>"
2737
+ ],
2738
+ "bos_token": "<s>",
2739
+ "clean_up_tokenization_spaces": true,
2740
+ "eos_token": "</s>",
2741
+ "extra_ids": 300,
2742
+ "model_max_length": 1000000000000000019884624838656,
2743
+ "pad_token": "<pad>",
2744
+ "sp_model_kwargs": {},
2745
+ "spaces_between_special_tokens": false,
2746
+ "tokenizer_class": "T5Tokenizer",
2747
+ "unk_token": "<unk>"
2748
+ }
ldf_deps/t5_umt5-xxl-enc-bf16/models_t5_umt5-xxl-enc-bf16.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7cace0da2b446bbbbc57d031ab6cf163a3d59b366da94e5afe36745b746fd81d
+ size 11361920418
ldf_models/__init__.py ADDED
File without changes
ldf_models/diffusion_forcing_wan.py ADDED
@@ -0,0 +1,899 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+ from .tools.t5 import T5EncoderModel
6
+ from .tools.wan_model import WanModel
7
+
8
+
9
+ class DiffForcingWanModel(nn.Module):
10
+ def __init__(
11
+ self,
12
+ checkpoint_path="deps/t5_umt5-xxl-enc-bf16/models_t5_umt5-xxl-enc-bf16.pth",
13
+ tokenizer_path="deps/t5_umt5-xxl-enc-bf16/google/umt5-xxl",
14
+ input_dim=256,
15
+ hidden_dim=1024,
16
+ ffn_dim=2048,
17
+ freq_dim=256,
18
+ num_heads=8,
19
+ num_layers=8,
20
+ time_embedding_scale=1.0,
21
+ chunk_size=5,
22
+ noise_steps=10,
23
+ use_text_cond=True,
24
+ text_len=128,
25
+ drop_out=0.1,
26
+ cfg_scale=5.0,
27
+ prediction_type="vel", # "vel", "x0", "noise"
28
+ causal=False,
29
+ ):
30
+ super().__init__()
31
+
32
+ self.input_dim = input_dim
33
+ self.hidden_dim = hidden_dim
34
+ self.ffn_dim = ffn_dim
35
+ self.freq_dim = freq_dim
36
+ self.num_heads = num_heads
37
+ self.num_layers = num_layers
38
+ self.time_embedding_scale = time_embedding_scale
39
+ self.chunk_size = chunk_size
40
+ self.noise_steps = noise_steps
41
+ self.use_text_cond = use_text_cond
42
+ self.drop_out = drop_out
43
+ self.cfg_scale = cfg_scale
44
+ self.prediction_type = prediction_type
45
+ self.causal = causal
46
+
47
+ self.text_dim = 4096
48
+ self.text_len = text_len
49
+ self.text_encoder = T5EncoderModel(
50
+ text_len=self.text_len,
51
+ dtype=torch.bfloat16,
52
+ device=torch.device("cpu"),
53
+ checkpoint_path=checkpoint_path,
54
+ tokenizer_path=tokenizer_path,
55
+ shard_fn=None,
56
+ )
57
+
58
+ # Text encoding cache
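+ # (maps prompt string -> T5 feature tensor; features are stored on CPU and moved
+ # back to the target device on reuse, so repeated prompts are only encoded once)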
59
+ self.text_cache = {}
60
+ self.model = WanModel(
61
+ model_type="t2v",
62
+ patch_size=(1, 1, 1),
63
+ text_len=self.text_len,
64
+ in_dim=self.input_dim,
65
+ dim=self.hidden_dim,
66
+ ffn_dim=self.ffn_dim,
67
+ freq_dim=self.freq_dim,
68
+ text_dim=self.text_dim,
69
+ out_dim=self.input_dim,
70
+ num_heads=self.num_heads,
71
+ num_layers=self.num_layers,
72
+ window_size=(-1, -1),
73
+ qk_norm=True,
74
+ cross_attn_norm=True,
75
+ eps=1e-6,
76
+ causal=self.causal,
77
+ )
78
+ self.param_dtype = torch.float32
79
+
80
+ def encode_text_with_cache(self, text_list, device):
81
+ """Encode text using cache
82
+ Args:
83
+ text_list: List[str], list of texts
84
+ device: torch.device
85
+ Returns:
86
+ List[Tensor]: List of encoded text features
87
+ """
88
+ text_features = []
89
+ indices_to_encode = []
90
+ texts_to_encode = []
91
+
92
+ # Check cache
93
+ for i, text in enumerate(text_list):
94
+ if text in self.text_cache:
95
+ # Get from cache and move to correct device
96
+ cached_feature = self.text_cache[text].to(device)
97
+ text_features.append(cached_feature)
98
+ else:
99
+ # Need to encode
100
+ text_features.append(None)
101
+ indices_to_encode.append(i)
102
+ texts_to_encode.append(text)
103
+
104
+ # Batch encode uncached texts
105
+ if texts_to_encode:
106
+ self.text_encoder.model.to(device)
107
+ encoded = self.text_encoder(texts_to_encode, device)
108
+
109
+ # Store in cache and update results
110
+ for idx, text, feature in zip(indices_to_encode, texts_to_encode, encoded):
111
+ # Cache to CPU to save GPU memory
112
+ self.text_cache[text] = feature.cpu()
113
+ text_features[idx] = feature
114
+
115
+ return text_features
116
+
117
+ def preprocess(self, x):
118
+ # (bs, T, C) -> (bs, C, T, 1, 1)
119
+ x = x.permute(0, 2, 1)[:, :, :, None, None]
120
+ return x
121
+
122
+ def postprocess(self, x):
123
+ # (bs, C, T, 1, 1) -> (bs, T, C)
124
+ x = x.permute(0, 2, 1, 3, 4).contiguous().view(x.size(0), x.size(2), -1)
125
+ return x
126
+
127
+ def _get_noise_levels(self, device, seq_len, time_steps):
128
+ """Get noise levels"""
129
+ # noise_level[i] = clip(1 + i / chunk_size - time_steps, 0, 1)
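+ # i.e. frames before chunk_size * (t - 1) are fully clean (0), frames after
+ # chunk_size * t are pure noise (1), and the chunk in between ramps linearly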
130
+ noise_level = torch.clamp(
131
+ 1
132
+ + torch.arange(seq_len, device=device) / self.chunk_size
133
+ - time_steps.unsqueeze(1),
134
+ min=0.0,
135
+ max=1.0,
136
+ )
137
+ return noise_level
138
+
139
+ def add_noise(self, x, noise_level):
140
+ """Add noise
141
+ Args:
142
+ x: (B, T, D)
143
+ noise_level: (B, T)
144
+ """
145
+ noise = torch.randn_like(x)
146
+ # noise_level: (B, T) -> (B, T, 1)
147
+ noise_level = noise_level.unsqueeze(-1)
148
+ noisy_x = x * (1 - noise_level) + noise_level * noise
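+ # linear interpolation between clean data and Gaussian noise; this matches the
+ # "vel" prediction target (feature - noise) used in forward()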
149
+ return noisy_x, noise
150
+
151
+ def forward(self, x):
152
+ feature = x["feature"] # (B, T, C)
153
+ feature_length = x["feature_length"] # (B,)
154
+ batch_size, seq_len, _ = feature.shape
155
+ device = feature.device
156
+
157
+ # Randomly sample a continuous time step for each sequence in the batch
158
+ time_steps = []
159
+ for i in range(batch_size):
160
+ valid_len = feature_length[i].item()
161
+ # Random float from 0 to valid_len/chunk_size, not an integer
162
+ max_time = valid_len / self.chunk_size
163
+ # max_time = valid_len / self.chunk_size + 1
164
+ time_steps.append(torch.FloatTensor(1).uniform_(0, max_time).item())
165
+ time_steps = torch.tensor(time_steps, device=device) # (B,)
166
+ noise_level = self._get_noise_levels(device, seq_len, time_steps) # (B, T)
167
+
168
+ # # Debug: Print noise levels
169
+ # print("Time steps and corresponding noise levels:")
170
+ # for i in range(batch_size):
171
+ # t = time_steps[i].item()
172
+ # # Get noise level at each position
173
+ # start_idx = int(self.chunk_size * (t - 1))
174
+ # end_idx = int(self.chunk_size * t) + 2
175
+ # # Limit to valid range
176
+ # start_idx = max(0, start_idx)
177
+ # end_idx = min(seq_len, end_idx)
178
+ # print(time_steps[i])
179
+ # print(noise_level[i, start_idx:end_idx])
180
+
181
+ # Add noise to entire sequence
182
+ noisy_feature, noise = self.add_noise(feature, noise_level) # (B, T, D)
183
+
184
+ # Debug: Print noise addition information
185
+ # print("Added noise levels at chunk positions:")
186
+ # for i in range(batch_size):
187
+ # t = time_steps[i].item()
188
+ # start_idx = int(self.chunk_size * (t - 1))
189
+ # end_idx = int(self.chunk_size * t) + 2
190
+ # # Limit to valid range
191
+ # start_idx = max(0, start_idx)
192
+ # end_idx = min(seq_len, end_idx)
193
+ # test1 = (
194
+ # feature[i, start_idx:end_idx, :] - noisy_feature[i, start_idx:end_idx, :]
195
+ # )
196
+ # test2 = (
197
+ # noise[i, start_idx:end_idx, :] - noisy_feature[i, start_idx:end_idx, :]
198
+ # )
199
+ # # Compute length on last dimension
200
+ # print(test1.norm(dim=-1))
201
+ # print(test2.norm(dim=-1))
202
+
203
+ feature = self.preprocess(feature) # (B, C, T, 1, 1)
204
+ noisy_feature = self.preprocess(noisy_feature) # (B, C, T, 1, 1)
205
+ noise = self.preprocess(noise) # (B, C, T, 1, 1)
206
+
207
+ feature_ref = []
208
+ noise_ref = []
209
+ noisy_feature_input = []
210
+ for i in range(batch_size):
211
+ t = time_steps[i].item()
212
+ end_index = int(self.chunk_size * t) + 1
213
+ valid_len = feature_length[i].item()
214
+ end_index = min(valid_len, end_index)
215
+ feature_ref.append(feature[i, :, :end_index, ...])
216
+ noise_ref.append(noise[i, :, :end_index, ...])
217
+ noisy_feature_input.append(noisy_feature[i, :, :end_index, ...])
218
+
219
+ # Encode text condition (using cache)
220
+ if self.use_text_cond and "text" in x:
221
+ text_list = x["text"] # List[str] or List[List[str]]
222
+ if isinstance(text_list[0], list):
223
+ text_end_list = x["feature_text_end"]
224
+ all_text_context = []
225
+ for single_text_list, single_text_end_list in zip(
226
+ text_list, text_end_list
227
+ ):
228
+ if np.random.rand() < self.drop_out:  # drop the text condition with probability drop_out (classifier-free guidance), matching the single-text branch below
229
+ single_text_list = [""]
230
+ single_text_end_list = [0, seq_len]
231
+ else:
232
+ single_text_end_list = [0] + [
233
+ min(t, seq_len) for t in single_text_end_list
234
+ ]
235
+ single_text_length_list = [
236
+ t - b
237
+ for t, b in zip(
238
+ single_text_end_list[1:], single_text_end_list[:-1]
239
+ )
240
+ ]
241
+ single_text_context = self.encode_text_with_cache(
242
+ single_text_list, device
243
+ )
244
+ single_text_context = [
245
+ u.to(self.param_dtype) for u in single_text_context
246
+ ]
247
+ for u, duration in zip(
248
+ single_text_context, single_text_length_list
249
+ ):
250
+ all_text_context.extend([u for _ in range(duration)])
251
+ all_text_context.extend(
252
+ [
253
+ single_text_context[-1]
254
+ for _ in range(seq_len - single_text_end_list[-1])
255
+ ]
256
+ )
257
+ else:
258
+ all_text_context = [
259
+ (u if np.random.rand() > self.drop_out else "") for u in text_list
260
+ ]
261
+ all_text_context = self.encode_text_with_cache(all_text_context, device)
262
+ all_text_context = [u.to(self.param_dtype) for u in all_text_context]
263
+ else:
264
+ all_text_context = [""] * batch_size
265
+ all_text_context = self.encode_text_with_cache(all_text_context, device)
266
+ all_text_context = [u.to(self.param_dtype) for u in all_text_context]
267
+
268
+ # Through WanModel
269
+ predicted_result = self.model(
270
+ noisy_feature_input,
271
+ noise_level * self.time_embedding_scale,
272
+ all_text_context,
273
+ seq_len,
274
+ y=None,
275
+ ) # (B, C, T, 1, 1)
276
+
277
+ loss = 0.0
278
+ for b in range(batch_size):
279
+ if self.prediction_type == "vel":
280
+ vel = feature_ref[b] - noise_ref[b] # (C, input_length, 1, 1)
281
+ squared_error = (
282
+ predicted_result[b][:, -self.chunk_size :, ...]
283
+ - vel[:, -self.chunk_size :, ...]
284
+ ) ** 2
285
+ elif self.prediction_type == "x0":
286
+ squared_error = (
287
+ predicted_result[b][:, -self.chunk_size :, ...]
288
+ - feature_ref[b][:, -self.chunk_size :, ...]
289
+ ) ** 2
290
+ elif self.prediction_type == "noise":
291
+ squared_error = (
292
+ predicted_result[b][:, -self.chunk_size :, ...]
293
+ - noise_ref[b][:, -self.chunk_size :, ...]
294
+ ) ** 2
295
+ sample_loss = squared_error.sum().mean()
296
+ loss += sample_loss
297
+ loss = loss / batch_size
298
+
299
+ loss_dict = {"total": loss, "mse": loss}
300
+ return loss_dict
301
+
302
+ def generate(self, x, num_denoise_steps=None):
303
+ """
304
+ Generation - Diffusion Forcing inference
305
+ Uses triangular noise schedule, progressively generating from left to right
306
+
307
+ Generation process:
308
+ 1. Start from t=0, gradually increase t
309
+ 2. Each t corresponds to a noise schedule: clean on left, noisy on right, gradient in middle
310
+ 3. After each denoising step, t increases slightly and continues
311
+ """
312
+ feature_length = x["feature_length"]
313
+ batch_size = len(feature_length)
314
+ seq_len = max(feature_length).item()
315
+
316
+ # # debug
317
+ # x["text"] = [["walk forward.", "sit down.", "stand up."] for _ in range(batch_size)]
318
+ # x["feature_text_end"] = [[1, 2, 3] for _ in range(batch_size)]
319
+ # text = x["text"]
320
+ # text_end = x["feature_text_end"]
321
+ # print(text)
322
+ # print(text_end)
323
+ # print(batch_size, seq_len, self.chunk_size)
324
+
325
+ if num_denoise_steps is None:
326
+ num_denoise_steps = self.noise_steps
327
+ assert num_denoise_steps % self.chunk_size == 0
328
+
329
+ device = next(self.parameters()).device
330
+
331
+ # Initialize entire sequence as pure noise
332
+ generated = torch.randn(
333
+ batch_size, seq_len + self.chunk_size, self.input_dim, device=device
334
+ )
335
+ generated = self.preprocess(generated) # (B, C, T, 1, 1)
336
+
337
+ # Calculate total number of time steps needed
338
+ max_t = 1 + (seq_len - 1) / self.chunk_size
339
+
340
+ # Step size for each advancement
341
+ dt = 1 / num_denoise_steps
342
+ total_steps = int(max_t / dt)
343
+
344
+ # Encode text condition (using cache)
345
+ if self.use_text_cond and "text" in x:
346
+ text_list = x["text"] # List[str] or List[List[str]]
347
+ if isinstance(text_list[0], list):
348
+ generated_length = []
349
+ text_end_list = x["feature_text_end"]
350
+ full_text = []
351
+ all_text_context = []
352
+ for single_text_list, single_text_end_list in zip(
353
+ text_list, text_end_list
354
+ ):
355
+ single_text_end_list = [0] + [
356
+ min(t, seq_len) for t in single_text_end_list
357
+ ]
358
+ generated_length.append(single_text_end_list[-1])
359
+ single_text_length_list = [
360
+ t - b
361
+ for t, b in zip(
362
+ single_text_end_list[1:], single_text_end_list[:-1]
363
+ )
364
+ ]
365
+ full_text.append(
366
+ " ////////// ".join(
367
+ [
368
+ f"{u} //dur:{t}"
369
+ for u, t in zip(
370
+ single_text_list, single_text_length_list
371
+ )
372
+ ]
373
+ )
374
+ )
375
+ single_text_context = self.encode_text_with_cache(
376
+ single_text_list, device
377
+ )
378
+ single_text_context = [
379
+ u.to(self.param_dtype) for u in single_text_context
380
+ ]
381
+ for u, duration in zip(
382
+ single_text_context, single_text_length_list
383
+ ):
384
+ all_text_context.extend([u for _ in range(duration)])
385
+ all_text_context.extend(
386
+ [
387
+ single_text_context[-1]
388
+ for _ in range(
389
+ seq_len + self.chunk_size - single_text_end_list[-1]
390
+ )
391
+ ]
392
+ )
393
+ else:
394
+ generated_length = feature_length
395
+ full_text = text_list
396
+ all_text_context = self.encode_text_with_cache(text_list, device)
397
+ all_text_context = [u.to(self.param_dtype) for u in all_text_context]
398
+ else:
399
+ generated_length = feature_length
400
+ full_text = [""] * batch_size
401
+ all_text_context = [""] * batch_size
402
+ all_text_context = self.encode_text_with_cache(all_text_context, device)
403
+ all_text_context = [u.to(self.param_dtype) for u in all_text_context]
404
+
405
+ # Get empty text condition encoding (for CFG)
406
+ text_null_list = [""] * batch_size
407
+ text_null_context = self.encode_text_with_cache(text_null_list, device)
408
+ text_null_context = [u.to(self.param_dtype) for u in text_null_context]
409
+
410
+ # print(len(all_text_context), len(text_null_context))
411
+
412
+ # Progressively advance from t=0 to t=max_t
413
+ for step in range(total_steps):
414
+ # Current time step
415
+ t = step * dt
416
+ start_index = max(0, int(self.chunk_size * (t - 1)) + 1)
417
+ end_index = int(self.chunk_size * t) + 1
418
+ time_steps = torch.full((batch_size,), t, device=device)
419
+
420
+ # Calculate current noise schedule
421
+ noise_level = self._get_noise_levels(
422
+ device, seq_len + self.chunk_size, time_steps
423
+ ) # (B, T)
424
+
425
+ # Predict noise through WanModel
426
+ noisy_input = []
427
+ for i in range(batch_size):
428
+ noisy_input.append(generated[i, :, :end_index, ...])
429
+
430
+ predicted_result = self.model(
431
+ noisy_input,
432
+ noise_level * self.time_embedding_scale,
433
+ all_text_context,
434
+ seq_len + self.chunk_size,
435
+ y=None,
436
+ ) # (B, C, T, 1, 1)
437
+
438
+ # Adjust using CFG
439
+ if self.cfg_scale != 1.0:
440
+ predicted_result_null = self.model(
441
+ noisy_input,
442
+ noise_level * self.time_embedding_scale,
443
+ text_null_context,
444
+ seq_len + self.chunk_size,
445
+ y=None,
446
+ ) # (B, C, T, 1, 1)
447
+ predicted_result = [
448
+ self.cfg_scale * pv - (self.cfg_scale - 1) * pvn
449
+ for pv, pvn in zip(predicted_result, predicted_result_null)
450
+ ]
451
+
452
+ for i in range(batch_size):
453
+ predicted_result_i = predicted_result[i] # (C, input_length, 1, 1)
454
+ if self.prediction_type == "vel":
455
+ predicted_vel = predicted_result_i[:, start_index:end_index, ...]
456
+ generated[i, :, start_index:end_index, ...] += predicted_vel * dt
457
+ elif self.prediction_type == "x0":
458
+ predicted_vel = (
459
+ predicted_result_i[:, start_index:end_index, ...]
460
+ - generated[i, :, start_index:end_index, ...]
461
+ ) / (
462
+ noise_level[i, start_index:end_index]
463
+ .unsqueeze(0)
464
+ .unsqueeze(-1)
465
+ .unsqueeze(-1)
466
+ )
467
+ generated[i, :, start_index:end_index, ...] += predicted_vel * dt
468
+ elif self.prediction_type == "noise":
469
+ predicted_vel = (
470
+ generated[i, :, start_index:end_index, ...]
471
+ - predicted_result_i[:, start_index:end_index, ...]
472
+ ) / (
473
+ 1
474
+ + dt
475
+ - noise_level[i, start_index:end_index]
476
+ .unsqueeze(0)
477
+ .unsqueeze(-1)
478
+ .unsqueeze(-1)
479
+ )
480
+ generated[i, :, start_index:end_index, ...] += predicted_vel * dt
481
+
482
+ generated = self.postprocess(generated) # (B, T, C)
483
+ y_hat_out = []
484
+ for i in range(batch_size):
485
+ # cut off the padding
486
+ single_generated = generated[i, : generated_length[i], :]
487
+ y_hat_out.append(single_generated)
488
+ out = {}
489
+ out["generated"] = y_hat_out
490
+ out["text"] = full_text
491
+
492
+ return out
493
+
494
+ @torch.no_grad()
495
+ def stream_generate(self, x, num_denoise_steps=None):
496
+ """
497
+ Streaming generation - Diffusion Forcing inference
498
+ Uses triangular noise schedule, progressively generating from left to right
499
+
500
+ Generation process:
501
+ 1. Start from t=0, gradually increase t
502
+ 2. Each t corresponds to a noise schedule: clean on left, noisy on right, gradient in middle
503
+ 3. After each denoising step, t increases slightly and continues
504
+ """
505
+ feature_length = x["feature_length"]
506
+ batch_size = len(feature_length)
507
+ seq_len = max(feature_length).item()
508
+
509
+ # # debug
510
+ # x["text"] = [["walk forward.", "sit down.", "stand up."] for _ in range(batch_size)]
511
+ # x["feature_text_end"] = [[1, 2, 3] for _ in range(batch_size)]
512
+ # text = x["text"]
513
+ # text_end = x["feature_text_end"]
514
+ # print(text)
515
+ # print(text_end)
516
+ # print(batch_size, seq_len, self.chunk_size)
517
+
518
+ if num_denoise_steps is None:
519
+ num_denoise_steps = self.noise_steps
520
+ assert num_denoise_steps % self.chunk_size == 0
521
+
522
+ device = next(self.parameters()).device
523
+
524
+ # Initialize entire sequence as pure noise
525
+ generated = torch.randn(
526
+ batch_size, seq_len + self.chunk_size, self.input_dim, device=device
527
+ )
528
+ generated = self.preprocess(generated) # (B, C, T, 1, 1)
529
+
530
+ # Calculate total number of time steps needed
531
+ max_t = 1 + (seq_len - 1) / self.chunk_size
532
+
533
+ # Step size for each advancement
534
+ dt = 1 / num_denoise_steps
535
+ total_steps = int(max_t / dt)
536
+
537
+ # Encode text condition (using cache)
538
+ if self.use_text_cond and "text" in x:
539
+ text_list = x["text"] # List[str] or List[List[str]]
540
+ if isinstance(text_list[0], list):
541
+ generated_length = []
542
+ text_end_list = x["feature_text_end"]
543
+ full_text = []
544
+ all_text_context = []
545
+ for single_text_list, single_text_end_list in zip(
546
+ text_list, text_end_list
547
+ ):
548
+ single_text_end_list = [0] + [
549
+ min(t, seq_len) for t in single_text_end_list
550
+ ]
551
+ generated_length.append(single_text_end_list[-1])
552
+ single_text_length_list = [
553
+ t - b
554
+ for t, b in zip(
555
+ single_text_end_list[1:], single_text_end_list[:-1]
556
+ )
557
+ ]
558
+ full_text.append(
559
+ " ////////// ".join(
560
+ [
561
+ f"{u} //dur:{t}"
562
+ for u, t in zip(
563
+ single_text_list, single_text_length_list
564
+ )
565
+ ]
566
+ )
567
+ )
568
+ single_text_context = self.encode_text_with_cache(
569
+ single_text_list, device
570
+ )
571
+ single_text_context = [
572
+ u.to(self.param_dtype) for u in single_text_context
573
+ ]
574
+ for u, duration in zip(
575
+ single_text_context, single_text_length_list
576
+ ):
577
+ all_text_context.extend([u for _ in range(duration)])
578
+ all_text_context.extend(
579
+ [
580
+ single_text_context[-1]
581
+ for _ in range(
582
+ seq_len + self.chunk_size - single_text_end_list[-1]
583
+ )
584
+ ]
585
+ )
586
+ else:
587
+ generated_length = feature_length
588
+ full_text = text_list
589
+ all_text_context = self.encode_text_with_cache(text_list, device)
590
+ all_text_context = [u.to(self.param_dtype) for u in all_text_context]
591
+ else:
592
+ generated_length = feature_length
593
+ full_text = [""] * batch_size
594
+ all_text_context = [""] * batch_size
595
+ all_text_context = self.encode_text_with_cache(all_text_context, device)
596
+ all_text_context = [u.to(self.param_dtype) for u in all_text_context]
597
+
598
+ # Get empty text condition encoding (for CFG)
599
+ text_null_list = [""] * batch_size
600
+ text_null_context = self.encode_text_with_cache(text_null_list, device)
601
+ text_null_context = [u.to(self.param_dtype) for u in text_null_context]
602
+
603
+ # print(len(all_text_context), len(text_null_context))
604
+
605
+ commit_index = 0
606
+ # Progressively advance from t=0 to t=max_t
607
+ for step in range(total_steps):
608
+ # Current time step
609
+ t = step * dt
610
+ start_index = max(0, int(self.chunk_size * (t - 1)) + 1)
611
+ end_index = int(self.chunk_size * t) + 1
612
+ time_steps = torch.full((batch_size,), t, device=device)
613
+
614
+ # Calculate current noise schedule
615
+ noise_level = self._get_noise_levels(
616
+ device, seq_len + self.chunk_size, time_steps
617
+ ) # (B, T)
618
+
619
+ # Predict noise through WanModel
620
+ noisy_input = []
621
+ for i in range(batch_size):
622
+ noisy_input.append(generated[i, :, :end_index, ...])
623
+
624
+ predicted_result = self.model(
625
+ noisy_input,
626
+ noise_level * self.time_embedding_scale,
627
+ all_text_context,
628
+ seq_len + self.chunk_size,
629
+ y=None,
630
+ ) # (B, C, T, 1, 1)
631
+
632
+ # Adjust using CFG
633
+ if self.cfg_scale != 1.0:
634
+ predicted_result_null = self.model(
635
+ noisy_input,
636
+ noise_level * self.time_embedding_scale,
637
+ text_null_context,
638
+ seq_len + self.chunk_size,
639
+ y=None,
640
+ ) # (B, C, T, 1, 1)
641
+ predicted_result = [
642
+ self.cfg_scale * pv - (self.cfg_scale - 1) * pvn
643
+ for pv, pvn in zip(predicted_result, predicted_result_null)
644
+ ]
645
+
646
+ for i in range(batch_size):
647
+ predicted_result_i = predicted_result[i] # (C, input_length, 1, 1)
648
+ if self.prediction_type == "vel":
649
+ predicted_vel = predicted_result_i[:, start_index:end_index, ...]
650
+ generated[i, :, start_index:end_index, ...] += predicted_vel * dt
651
+ elif self.prediction_type == "x0":
652
+ predicted_vel = (
653
+ predicted_result_i[:, start_index:end_index, ...]
654
+ - generated[i, :, start_index:end_index, ...]
655
+ ) / (
656
+ noise_level[i, start_index:end_index]
657
+ .unsqueeze(0)
658
+ .unsqueeze(-1)
659
+ .unsqueeze(-1)
660
+ )
661
+ generated[i, :, start_index:end_index, ...] += predicted_vel * dt
662
+ elif self.prediction_type == "noise":
663
+ predicted_vel = (
664
+ generated[i, :, start_index:end_index, ...]
665
+ - predicted_result_i[:, start_index:end_index, ...]
666
+ ) / (
667
+ 1
668
+ + dt
669
+ - noise_level[i, start_index:end_index]
670
+ .unsqueeze(0)
671
+ .unsqueeze(-1)
672
+ .unsqueeze(-1)
673
+ )
674
+ generated[i, :, start_index:end_index, ...] += predicted_vel * dt
675
+
676
+ if commit_index < start_index:
677
+ output = generated[:, :, commit_index:start_index, ...]
678
+ output = self.postprocess(output) # (B, T, C)
679
+ y_hat_out = []
680
+ for i in range(batch_size):
681
+ if commit_index < generated_length[i]:
682
+ y_hat_out.append(
683
+ output[i, : generated_length[i] - commit_index, ...]
684
+ )
685
+ else:
686
+ y_hat_out.append(None)
687
+
688
+ out = {}
689
+ out["generated"] = y_hat_out
690
+ yield out
691
+ commit_index = start_index
692
+
693
+ output = generated[:, :, commit_index:, ...]
694
+ output = self.postprocess(output) # (B, T_remain, C)
695
+ y_hat_out = []
696
+ for i in range(batch_size):
697
+ if commit_index < generated_length[i]:
698
+ y_hat_out.append(output[i, : generated_length[i] - commit_index, ...])
699
+ else:
700
+ y_hat_out.append(None)
701
+ out = {}
702
+ out["generated"] = y_hat_out
703
+ yield out
704
+
705
+ def init_generated(self, seq_len, batch_size=1, num_denoise_steps=None):
706
+ self.seq_len = seq_len
707
+ self.batch_size = batch_size
708
+ if num_denoise_steps is None:
709
+ self.num_denoise_steps = self.noise_steps
710
+ else:
711
+ self.num_denoise_steps = num_denoise_steps
712
+ assert self.num_denoise_steps % self.chunk_size == 0
713
+ self.dt = 1 / self.num_denoise_steps
714
+ self.current_step = 0
715
+ self.text_condition_list = [[] for _ in range(self.batch_size)]
716
+ self.generated = torch.randn(
717
+ self.batch_size, self.seq_len * 2 + self.chunk_size, self.input_dim
718
+ )
719
+ self.generated = self.preprocess(self.generated) # (B, C, T, 1, 1)
720
+ self.commit_index = 0
721
+
722
+ @torch.no_grad()
723
+ def stream_generate_step(self, x, first_chunk=True):
724
+ """
725
+ Streaming generation step - Diffusion Forcing inference
726
+ Uses triangular noise schedule, progressively generating from left to right
727
+
728
+ Generation process:
729
+ 1. Start from t=0, gradually increase t
730
+ 2. Each t corresponds to a noise schedule: clean on left, noisy on right, gradient in middle
731
+ 3. After each denoising step, t increases slightly and continues
732
+ """
733
+
734
+ device = next(self.parameters()).device
735
+ if first_chunk:
736
+ self.generated = self.generated.to(device)
737
+
738
+ # Encode text condition (using cache)
739
+ if self.use_text_cond and "text" in x:
740
+ text_list = x["text"] # List[str]
741
+ new_text_context = self.encode_text_with_cache(text_list, device)
742
+ new_text_context = [u.to(self.param_dtype) for u in new_text_context]
743
+ else:
744
+ new_text_context = [""] * self.batch_size
745
+ new_text_context = self.encode_text_with_cache(new_text_context, device)
746
+ new_text_context = [u.to(self.param_dtype) for u in new_text_context]
747
+
748
+ # Get empty text condition encoding (for CFG)
749
+ text_null_list = [""] * self.batch_size
750
+ text_null_context = self.encode_text_with_cache(text_null_list, device)
751
+ text_null_context = [u.to(self.param_dtype) for u in text_null_context]
752
+
753
+ for i in range(self.batch_size):
754
+ if first_chunk:
755
+ self.text_condition_list[i].extend(
756
+ [new_text_context[i]] * self.chunk_size
757
+ )
758
+ else:
759
+ self.text_condition_list[i].extend([new_text_context[i]])
760
+
761
+ end_step = (
762
+ (self.commit_index + self.chunk_size)
763
+ * self.num_denoise_steps
764
+ / self.chunk_size
765
+ )
766
+ while self.current_step < end_step:
767
+ current_time = self.current_step * self.dt
768
+ start_index = max(0, int(self.chunk_size * (current_time - 1)) + 1)
769
+ end_index = int(self.chunk_size * current_time) + 1
770
+ time_steps = torch.full((self.batch_size,), current_time, device=device)
771
+
772
+ noise_level = self._get_noise_levels(device, end_index, time_steps)[
773
+ :, -self.seq_len :
774
+ ] # (B, T)
775
+
776
+ # Predict noise through WanModel
777
+ noisy_input = []
778
+ for i in range(self.batch_size):
779
+ noisy_input.append(
780
+ self.generated[i, :, :end_index, ...][:, -self.seq_len :]
781
+ ) # (C, T, 1, 1)
782
+
783
+ text_condition = []
784
+ for i in range(self.batch_size):
785
+ text_condition.extend(
786
+ self.text_condition_list[i][:end_index][-self.seq_len :]
787
+ ) # (T, D, 4096)
788
+
789
+ # print("////////////////////")
790
+ # print("current step: ", self.current_step)
791
+ # print("chunk size: ", self.chunk_size)
792
+ # print("start_index: ", start_index)
793
+ # print("end_index: ", end_index)
794
+ # print("noisy_input shape: ", noisy_input[0].shape)
795
+ # print("noise_level: ", noise_level[0, start_index:end_index])
796
+ # print("text_condition shape: ", len(text_condition))
797
+ # print("commit_index: ", self.commit_index)
798
+ # print("////////////////////")
799
+
800
+ predicted_result = self.model(
801
+ noisy_input,
802
+ noise_level * self.time_embedding_scale,
803
+ text_condition,
804
+ min(end_index, self.seq_len),
805
+ y=None,
806
+ ) # (B, C, T, 1, 1)
807
+
808
+ # Adjust using CFG
809
+ if self.cfg_scale != 1.0:
810
+ predicted_result_null = self.model(
811
+ noisy_input,
812
+ noise_level * self.time_embedding_scale,
813
+ text_null_context,
814
+ min(end_index, self.seq_len),
815
+ y=None,
816
+ ) # (B, C, T, 1, 1)
817
+ predicted_result = [
818
+ self.cfg_scale * pv - (self.cfg_scale - 1) * pvn
819
+ for pv, pvn in zip(predicted_result, predicted_result_null)
820
+ ]
821
+
822
+ for i in range(self.batch_size):
823
+ predicted_result_i = predicted_result[i] # (C, input_length, 1, 1)
824
+ if end_index > self.seq_len:
825
+ predicted_result_i = torch.cat(
826
+ [
827
+ torch.zeros(
828
+ predicted_result_i.shape[0],
829
+ end_index - self.seq_len,
830
+ predicted_result_i.shape[2],
831
+ predicted_result_i.shape[3],
832
+ device=device,
833
+ ),
834
+ predicted_result_i,
835
+ ],
836
+ dim=1,
837
+ )
838
+ if self.prediction_type == "vel":
839
+ predicted_vel = predicted_result_i[:, start_index:end_index, ...]
840
+ self.generated[i, :, start_index:end_index, ...] += (
841
+ predicted_vel * self.dt
842
+ )
843
+ elif self.prediction_type == "x0":
844
+ predicted_vel = (
845
+ predicted_result_i[:, start_index:end_index, ...]
846
+ - self.generated[i, :, start_index:end_index, ...]
847
+ ) / (
848
+ noise_level[i, start_index:end_index]
849
+ .unsqueeze(0)
850
+ .unsqueeze(-1)
851
+ .unsqueeze(-1)
852
+ )
853
+ self.generated[i, :, start_index:end_index, ...] += (
854
+ predicted_vel * self.dt
855
+ )
856
+ elif self.prediction_type == "noise":
857
+ predicted_vel = (
858
+ self.generated[i, :, start_index:end_index, ...]
859
+ - predicted_result_i[:, start_index:end_index, ...]
860
+ ) / (
861
+ 1
862
+ + self.dt
863
+ - noise_level[i, start_index:end_index]
864
+ .unsqueeze(0)
865
+ .unsqueeze(-1)
866
+ .unsqueeze(-1)
867
+ )
868
+ self.generated[i, :, start_index:end_index, ...] += (
869
+ predicted_vel * self.dt
870
+ )
871
+ self.current_step += 1
872
+ output = self.generated[:, :, self.commit_index : self.commit_index + 1, ...]
873
+ output = self.postprocess(output) # (B, 1, C)
874
+ out = {}
875
+ out["generated"] = output
876
+ self.commit_index += 1
877
+
878
+ if self.commit_index == self.seq_len * 2:
879
+ self.generated = torch.cat(
880
+ [
881
+ self.generated[:, :, self.seq_len :, ...],
882
+ torch.randn(
883
+ self.batch_size,
884
+ self.input_dim,
885
+ self.seq_len,
886
+ 1,
887
+ 1,
888
+ device=device,
889
+ ),
890
+ ],
891
+ dim=2,
892
+ )
893
+ self.current_step -= self.seq_len * self.num_denoise_steps / self.chunk_size
894
+ self.commit_index -= self.seq_len
895
+ for i in range(self.batch_size):
896
+ self.text_condition_list[i] = self.text_condition_list[i][
897
+ self.seq_len :
898
+ ]
899
+ return out
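
A brief note on the schedule that drives `forward`, `generate`, and `stream_generate` above: a single scalar time `t` is mapped to a per-frame noise level, so frames left of `chunk_size * (t - 1)` are already clean while frames right of `chunk_size * t` are still pure noise. The actual mapping lives in `_get_noise_levels`, which is not part of this excerpt; the snippet below is only a minimal sketch of one plausible linear-ramp (triangular) schedule consistent with the indexing used above, not the model's exact definition.

```python
import torch

def triangular_noise_levels(seq_len: int, t: float, chunk_size: int) -> torch.Tensor:
    """Hypothetical per-frame noise level at scalar time t (0 = clean, 1 = pure noise)."""
    frame_idx = torch.arange(seq_len, dtype=torch.float32)
    # Frames left of chunk_size * (t - 1) are fully denoised, frames right of
    # chunk_size * t are still pure noise, with a linear ramp in between.
    levels = frame_idx / chunk_size - (t - 1.0)
    return levels.clamp(0.0, 1.0)

# Example with chunk_size=4 at t=2.0: frames 0-4 are clean, frames 5-7 ramp
# upward, and frames 8+ are still pure noise.
print(triangular_noise_levels(12, 2.0, 4))
```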
ldf_models/tools/attention.py ADDED
@@ -0,0 +1,188 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import torch
3
+
4
+ try:
5
+ import flash_attn_interface
6
+
7
+ FLASH_ATTN_3_AVAILABLE = True
8
+ except ModuleNotFoundError:
9
+ FLASH_ATTN_3_AVAILABLE = False
10
+
11
+ try:
12
+ import flash_attn
13
+
14
+ FLASH_ATTN_2_AVAILABLE = True
15
+ except ModuleNotFoundError:
16
+ FLASH_ATTN_2_AVAILABLE = False
17
+
18
+ import warnings
19
+
20
+ __all__ = [
21
+ "flash_attention",
22
+ "attention",
23
+ ]
24
+
25
+
26
+ def flash_attention(
27
+ q,
28
+ k,
29
+ v,
30
+ q_lens=None,
31
+ k_lens=None,
32
+ dropout_p=0.0,
33
+ softmax_scale=None,
34
+ q_scale=None,
35
+ causal=False,
36
+ window_size=(-1, -1),
37
+ deterministic=False,
38
+ dtype=torch.bfloat16,
39
+ version=None,
40
+ ):
41
+ """
42
+ q: [B, Lq, Nq, C1].
43
+ k: [B, Lk, Nk, C1].
44
+ v: [B, Lk, Nk, C2]. Nq must be divisible by Nk.
45
+ q_lens: [B].
46
+ k_lens: [B].
47
+ dropout_p: float. Dropout probability.
48
+ softmax_scale: float. The scaling of QK^T before applying softmax.
49
+ causal: bool. Whether to apply causal attention mask.
50
+ window_size: (left right). If not (-1, -1), apply sliding window local attention.
51
+ deterministic: bool. If True, slightly slower and uses more memory.
52
+ dtype: torch.dtype. Apply when dtype of q/k/v is not float16/bfloat16.
53
+ """
54
+ half_dtypes = (torch.float16, torch.bfloat16)
55
+ assert dtype in half_dtypes
56
+ assert q.device.type == "cuda" and q.size(-1) <= 256
57
+
58
+ # params
59
+ b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype
60
+
61
+ def half(x):
62
+ return x if x.dtype in half_dtypes else x.to(dtype)
63
+
64
+ # preprocess query
65
+ if q_lens is None:
66
+ q = half(q.flatten(0, 1))
67
+ q_lens = torch.tensor([lq] * b, dtype=torch.int32).to(
68
+ device=q.device, non_blocking=True
69
+ )
70
+ else:
71
+ q = half(torch.cat([u[:v] for u, v in zip(q, q_lens)]))
72
+
73
+ # preprocess key, value
74
+ if k_lens is None:
75
+ k = half(k.flatten(0, 1))
76
+ v = half(v.flatten(0, 1))
77
+ k_lens = torch.tensor([lk] * b, dtype=torch.int32).to(
78
+ device=k.device, non_blocking=True
79
+ )
80
+ else:
81
+ k = half(torch.cat([u[:v] for u, v in zip(k, k_lens)]))
82
+ v = half(torch.cat([u[:v] for u, v in zip(v, k_lens)]))
83
+
84
+ q = q.to(v.dtype)
85
+ k = k.to(v.dtype)
86
+
87
+ if q_scale is not None:
88
+ q = q * q_scale
89
+
90
+ if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE:
91
+ warnings.warn(
92
+ "Flash attention 3 is not available, use flash attention 2 instead."
93
+ )
94
+
95
+ # apply attention
96
+ if (version is None or version == 3) and FLASH_ATTN_3_AVAILABLE:
97
+ # Note: dropout_p and window_size are not yet supported in FA3.
98
+ x = flash_attn_interface.flash_attn_varlen_func(
99
+ q=q,
100
+ k=k,
101
+ v=v,
102
+ cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens])
103
+ .cumsum(0, dtype=torch.int32)
104
+ .to(q.device, non_blocking=True),
105
+ cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens])
106
+ .cumsum(0, dtype=torch.int32)
107
+ .to(q.device, non_blocking=True),
108
+ seqused_q=None,
109
+ seqused_k=None,
110
+ max_seqlen_q=lq,
111
+ max_seqlen_k=lk,
112
+ softmax_scale=softmax_scale,
113
+ causal=causal,
114
+ deterministic=deterministic,
115
+ )[0].unflatten(0, (b, lq))
116
+ else:
117
+ assert FLASH_ATTN_2_AVAILABLE
118
+ x = flash_attn.flash_attn_varlen_func(
119
+ q=q,
120
+ k=k,
121
+ v=v,
122
+ cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens])
123
+ .cumsum(0, dtype=torch.int32)
124
+ .to(q.device, non_blocking=True),
125
+ cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens])
126
+ .cumsum(0, dtype=torch.int32)
127
+ .to(q.device, non_blocking=True),
128
+ max_seqlen_q=lq,
129
+ max_seqlen_k=lk,
130
+ dropout_p=dropout_p,
131
+ softmax_scale=softmax_scale,
132
+ causal=causal,
133
+ window_size=window_size,
134
+ deterministic=deterministic,
135
+ ).unflatten(0, (b, lq))
136
+
137
+ # output
138
+ return x.type(out_dtype)
139
+
140
+
141
+ def attention(
142
+ q,
143
+ k,
144
+ v,
145
+ q_lens=None,
146
+ k_lens=None,
147
+ dropout_p=0.0,
148
+ softmax_scale=None,
149
+ q_scale=None,
150
+ causal=False,
151
+ window_size=(-1, -1),
152
+ deterministic=False,
153
+ dtype=torch.bfloat16,
154
+ fa_version=None,
155
+ ):
156
+ if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE:
157
+ return flash_attention(
158
+ q=q,
159
+ k=k,
160
+ v=v,
161
+ q_lens=q_lens,
162
+ k_lens=k_lens,
163
+ dropout_p=dropout_p,
164
+ softmax_scale=softmax_scale,
165
+ q_scale=q_scale,
166
+ causal=causal,
167
+ window_size=window_size,
168
+ deterministic=deterministic,
169
+ dtype=dtype,
170
+ version=fa_version,
171
+ )
172
+ else:
173
+ if q_lens is not None or k_lens is not None:
174
+ warnings.warn(
175
+ "Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance."
176
+ )
177
+ attn_mask = None
178
+
179
+ q = q.transpose(1, 2).to(dtype)
180
+ k = k.transpose(1, 2).to(dtype)
181
+ v = v.transpose(1, 2).to(dtype)
182
+
183
+ out = torch.nn.functional.scaled_dot_product_attention(
184
+ q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p
185
+ )
186
+
187
+ out = out.transpose(1, 2).contiguous()
188
+ return out
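
A hedged usage sketch for the `attention` wrapper above. Shapes follow the `flash_attention` docstring (`[B, L, num_heads, head_dim]`); the import path mirrors the file location added in this commit and may differ in your setup. With flash-attn installed the fast path requires CUDA tensors; otherwise the wrapper falls back to `torch.nn.functional.scaled_dot_product_attention`.

```python
import torch
from ldf_models.tools.attention import attention  # path as added in this commit

device = "cuda" if torch.cuda.is_available() else "cpu"
b, lq, lk, heads, head_dim = 2, 16, 24, 8, 64

q = torch.randn(b, lq, heads, head_dim, device=device, dtype=torch.bfloat16)
k = torch.randn(b, lk, heads, head_dim, device=device, dtype=torch.bfloat16)
v = torch.randn(b, lk, heads, head_dim, device=device, dtype=torch.bfloat16)

# Dispatches to flash_attn_varlen_func when flash-attn is importable (CUDA only),
# otherwise to scaled_dot_product_attention; output keeps the [B, Lq, N, D] layout.
out = attention(q, k, v, causal=False)
print(out.shape)  # torch.Size([2, 16, 8, 64])
```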
ldf_models/tools/t5.py ADDED
@@ -0,0 +1,564 @@
1
+ # Modified from transformers.models.t5.modeling_t5
2
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
3
+ import logging
4
+ import math
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ from .tokenizers import HuggingfaceTokenizer
11
+
12
+ __all__ = [
13
+ "T5Model",
14
+ "T5Encoder",
15
+ "T5Decoder",
16
+ "T5EncoderModel",
17
+ ]
18
+
19
+
20
+ def fp16_clamp(x):
21
+ if x.dtype == torch.float16 and torch.isinf(x).any():
22
+ clamp = torch.finfo(x.dtype).max - 1000
23
+ x = torch.clamp(x, min=-clamp, max=clamp)
24
+ return x
25
+
26
+
27
+ def init_weights(m):
28
+ if isinstance(m, T5LayerNorm):
29
+ nn.init.ones_(m.weight)
30
+ elif isinstance(m, T5Model):
31
+ nn.init.normal_(m.token_embedding.weight, std=1.0)
32
+ elif isinstance(m, T5FeedForward):
33
+ nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5)
34
+ nn.init.normal_(m.fc1.weight, std=m.dim**-0.5)
35
+ nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5)
36
+ elif isinstance(m, T5Attention):
37
+ nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn) ** -0.5)
38
+ nn.init.normal_(m.k.weight, std=m.dim**-0.5)
39
+ nn.init.normal_(m.v.weight, std=m.dim**-0.5)
40
+ nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn) ** -0.5)
41
+ elif isinstance(m, T5RelativeEmbedding):
42
+ nn.init.normal_(
43
+ m.embedding.weight, std=(2 * m.num_buckets * m.num_heads) ** -0.5
44
+ )
45
+
46
+
47
+ class GELU(nn.Module):
48
+ def forward(self, x):
49
+ return (
50
+ 0.5
51
+ * x
52
+ * (
53
+ 1.0
54
+ + torch.tanh(
55
+ math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))
56
+ )
57
+ )
58
+ )
59
+
60
+
61
+ class T5LayerNorm(nn.Module):
62
+ def __init__(self, dim, eps=1e-6):
63
+ super(T5LayerNorm, self).__init__()
64
+ self.dim = dim
65
+ self.eps = eps
66
+ self.weight = nn.Parameter(torch.ones(dim))
67
+
68
+ def forward(self, x):
69
+ x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + self.eps)
70
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
71
+ x = x.type_as(self.weight)
72
+ return self.weight * x
73
+
74
+
75
+ class T5Attention(nn.Module):
76
+ def __init__(self, dim, dim_attn, num_heads, dropout=0.1):
77
+ assert dim_attn % num_heads == 0
78
+ super(T5Attention, self).__init__()
79
+ self.dim = dim
80
+ self.dim_attn = dim_attn
81
+ self.num_heads = num_heads
82
+ self.head_dim = dim_attn // num_heads
83
+
84
+ # layers
85
+ self.q = nn.Linear(dim, dim_attn, bias=False)
86
+ self.k = nn.Linear(dim, dim_attn, bias=False)
87
+ self.v = nn.Linear(dim, dim_attn, bias=False)
88
+ self.o = nn.Linear(dim_attn, dim, bias=False)
89
+ self.dropout = nn.Dropout(dropout)
90
+
91
+ def forward(self, x, context=None, mask=None, pos_bias=None):
92
+ """
93
+ x: [B, L1, C].
94
+ context: [B, L2, C] or None.
95
+ mask: [B, L2] or [B, L1, L2] or None.
96
+ """
97
+ # check inputs
98
+ context = x if context is None else context
99
+ b, n, c = x.size(0), self.num_heads, self.head_dim
100
+
101
+ # compute query, key, value
102
+ q = self.q(x).view(b, -1, n, c)
103
+ k = self.k(context).view(b, -1, n, c)
104
+ v = self.v(context).view(b, -1, n, c)
105
+
106
+ # attention bias
107
+ attn_bias = x.new_zeros(b, n, q.size(1), k.size(1))
108
+ if pos_bias is not None:
109
+ attn_bias += pos_bias
110
+ if mask is not None:
111
+ assert mask.ndim in [2, 3]
112
+ mask = mask.view(b, 1, 1, -1) if mask.ndim == 2 else mask.unsqueeze(1)
113
+ attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min)
114
+
115
+ # compute attention (T5 does not use scaling)
116
+ attn = torch.einsum("binc,bjnc->bnij", q, k) + attn_bias
117
+ attn = F.softmax(attn.float(), dim=-1).type_as(attn)
118
+ x = torch.einsum("bnij,bjnc->binc", attn, v)
119
+
120
+ # output
121
+ x = x.reshape(b, -1, n * c)
122
+ x = self.o(x)
123
+ x = self.dropout(x)
124
+ return x
125
+
126
+
127
+ class T5FeedForward(nn.Module):
128
+ def __init__(self, dim, dim_ffn, dropout=0.1):
129
+ super(T5FeedForward, self).__init__()
130
+ self.dim = dim
131
+ self.dim_ffn = dim_ffn
132
+
133
+ # layers
134
+ self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU())
135
+ self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
136
+ self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
137
+ self.dropout = nn.Dropout(dropout)
138
+
139
+ def forward(self, x):
140
+ x = self.fc1(x) * self.gate(x)
141
+ x = self.dropout(x)
142
+ x = self.fc2(x)
143
+ x = self.dropout(x)
144
+ return x
145
+
146
+
147
+ class T5SelfAttention(nn.Module):
148
+ def __init__(
149
+ self,
150
+ dim,
151
+ dim_attn,
152
+ dim_ffn,
153
+ num_heads,
154
+ num_buckets,
155
+ shared_pos=True,
156
+ dropout=0.1,
157
+ ):
158
+ super(T5SelfAttention, self).__init__()
159
+ self.dim = dim
160
+ self.dim_attn = dim_attn
161
+ self.dim_ffn = dim_ffn
162
+ self.num_heads = num_heads
163
+ self.num_buckets = num_buckets
164
+ self.shared_pos = shared_pos
165
+
166
+ # layers
167
+ self.norm1 = T5LayerNorm(dim)
168
+ self.attn = T5Attention(dim, dim_attn, num_heads, dropout)
169
+ self.norm2 = T5LayerNorm(dim)
170
+ self.ffn = T5FeedForward(dim, dim_ffn, dropout)
171
+ self.pos_embedding = (
172
+ None
173
+ if shared_pos
174
+ else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True)
175
+ )
176
+
177
+ def forward(self, x, mask=None, pos_bias=None):
178
+ e = pos_bias if self.shared_pos else self.pos_embedding(x.size(1), x.size(1))
179
+ x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
180
+ x = fp16_clamp(x + self.ffn(self.norm2(x)))
181
+ return x
182
+
183
+
184
+ class T5CrossAttention(nn.Module):
185
+ def __init__(
186
+ self,
187
+ dim,
188
+ dim_attn,
189
+ dim_ffn,
190
+ num_heads,
191
+ num_buckets,
192
+ shared_pos=True,
193
+ dropout=0.1,
194
+ ):
195
+ super(T5CrossAttention, self).__init__()
196
+ self.dim = dim
197
+ self.dim_attn = dim_attn
198
+ self.dim_ffn = dim_ffn
199
+ self.num_heads = num_heads
200
+ self.num_buckets = num_buckets
201
+ self.shared_pos = shared_pos
202
+
203
+ # layers
204
+ self.norm1 = T5LayerNorm(dim)
205
+ self.self_attn = T5Attention(dim, dim_attn, num_heads, dropout)
206
+ self.norm2 = T5LayerNorm(dim)
207
+ self.cross_attn = T5Attention(dim, dim_attn, num_heads, dropout)
208
+ self.norm3 = T5LayerNorm(dim)
209
+ self.ffn = T5FeedForward(dim, dim_ffn, dropout)
210
+ self.pos_embedding = (
211
+ None
212
+ if shared_pos
213
+ else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=False)
214
+ )
215
+
216
+ def forward(
217
+ self, x, mask=None, encoder_states=None, encoder_mask=None, pos_bias=None
218
+ ):
219
+ e = pos_bias if self.shared_pos else self.pos_embedding(x.size(1), x.size(1))
220
+ x = fp16_clamp(x + self.self_attn(self.norm1(x), mask=mask, pos_bias=e))
221
+ x = fp16_clamp(
222
+ x
223
+ + self.cross_attn(self.norm2(x), context=encoder_states, mask=encoder_mask)
224
+ )
225
+ x = fp16_clamp(x + self.ffn(self.norm3(x)))
226
+ return x
227
+
228
+
229
+ class T5RelativeEmbedding(nn.Module):
230
+ def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
231
+ super(T5RelativeEmbedding, self).__init__()
232
+ self.num_buckets = num_buckets
233
+ self.num_heads = num_heads
234
+ self.bidirectional = bidirectional
235
+ self.max_dist = max_dist
236
+
237
+ # layers
238
+ self.embedding = nn.Embedding(num_buckets, num_heads)
239
+
240
+ def forward(self, lq, lk):
241
+ device = self.embedding.weight.device
242
+ # rel_pos = torch.arange(lk).unsqueeze(0).to(device) - \
243
+ # torch.arange(lq).unsqueeze(1).to(device)
244
+ rel_pos = torch.arange(lk, device=device).unsqueeze(0) - torch.arange(
245
+ lq, device=device
246
+ ).unsqueeze(1)
247
+ rel_pos = self._relative_position_bucket(rel_pos)
248
+ rel_pos_embeds = self.embedding(rel_pos)
249
+ rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze(0) # [1, N, Lq, Lk]
250
+ return rel_pos_embeds.contiguous()
251
+
252
+ def _relative_position_bucket(self, rel_pos):
253
+ # preprocess
254
+ if self.bidirectional:
255
+ num_buckets = self.num_buckets // 2
256
+ rel_buckets = (rel_pos > 0).long() * num_buckets
257
+ rel_pos = torch.abs(rel_pos)
258
+ else:
259
+ num_buckets = self.num_buckets
260
+ rel_buckets = 0
261
+ rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos))
262
+
263
+ # embeddings for small and large positions
264
+ max_exact = num_buckets // 2
265
+ rel_pos_large = (
266
+ max_exact
267
+ + (
268
+ torch.log(rel_pos.float() / max_exact)
269
+ / math.log(self.max_dist / max_exact)
270
+ * (num_buckets - max_exact)
271
+ ).long()
272
+ )
273
+ rel_pos_large = torch.min(
274
+ rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1)
275
+ )
276
+ rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large)
277
+ return rel_buckets
278
+
279
+
280
+ class T5Encoder(nn.Module):
281
+ def __init__(
282
+ self,
283
+ vocab,
284
+ dim,
285
+ dim_attn,
286
+ dim_ffn,
287
+ num_heads,
288
+ num_layers,
289
+ num_buckets,
290
+ shared_pos=True,
291
+ dropout=0.1,
292
+ ):
293
+ super(T5Encoder, self).__init__()
294
+ self.dim = dim
295
+ self.dim_attn = dim_attn
296
+ self.dim_ffn = dim_ffn
297
+ self.num_heads = num_heads
298
+ self.num_layers = num_layers
299
+ self.num_buckets = num_buckets
300
+ self.shared_pos = shared_pos
301
+
302
+ # layers
303
+ self.token_embedding = (
304
+ vocab if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim)
305
+ )
306
+ self.pos_embedding = (
307
+ T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True)
308
+ if shared_pos
309
+ else None
310
+ )
311
+ self.dropout = nn.Dropout(dropout)
312
+ self.blocks = nn.ModuleList(
313
+ [
314
+ T5SelfAttention(
315
+ dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout
316
+ )
317
+ for _ in range(num_layers)
318
+ ]
319
+ )
320
+ self.norm = T5LayerNorm(dim)
321
+
322
+ # initialize weights
323
+ self.apply(init_weights)
324
+
325
+ def forward(self, ids, mask=None):
326
+ x = self.token_embedding(ids)
327
+ x = self.dropout(x)
328
+ e = self.pos_embedding(x.size(1), x.size(1)) if self.shared_pos else None
329
+ for block in self.blocks:
330
+ x = block(x, mask, pos_bias=e)
331
+ x = self.norm(x)
332
+ x = self.dropout(x)
333
+ return x
334
+
335
+
336
+ class T5Decoder(nn.Module):
337
+ def __init__(
338
+ self,
339
+ vocab,
340
+ dim,
341
+ dim_attn,
342
+ dim_ffn,
343
+ num_heads,
344
+ num_layers,
345
+ num_buckets,
346
+ shared_pos=True,
347
+ dropout=0.1,
348
+ ):
349
+ super(T5Decoder, self).__init__()
350
+ self.dim = dim
351
+ self.dim_attn = dim_attn
352
+ self.dim_ffn = dim_ffn
353
+ self.num_heads = num_heads
354
+ self.num_layers = num_layers
355
+ self.num_buckets = num_buckets
356
+ self.shared_pos = shared_pos
357
+
358
+ # layers
359
+ self.token_embedding = (
360
+ vocab if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim)
361
+ )
362
+ self.pos_embedding = (
363
+ T5RelativeEmbedding(num_buckets, num_heads, bidirectional=False)
364
+ if shared_pos
365
+ else None
366
+ )
367
+ self.dropout = nn.Dropout(dropout)
368
+ self.blocks = nn.ModuleList(
369
+ [
370
+ T5CrossAttention(
371
+ dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout
372
+ )
373
+ for _ in range(num_layers)
374
+ ]
375
+ )
376
+ self.norm = T5LayerNorm(dim)
377
+
378
+ # initialize weights
379
+ self.apply(init_weights)
380
+
381
+ def forward(self, ids, mask=None, encoder_states=None, encoder_mask=None):
382
+ b, s = ids.size()
383
+
384
+ # causal mask
385
+ if mask is None:
386
+ mask = torch.tril(torch.ones(1, s, s).to(ids.device))
387
+ elif mask.ndim == 2:
388
+ mask = torch.tril(mask.unsqueeze(1).expand(-1, s, -1))
389
+
390
+ # layers
391
+ x = self.token_embedding(ids)
392
+ x = self.dropout(x)
393
+ e = self.pos_embedding(x.size(1), x.size(1)) if self.shared_pos else None
394
+ for block in self.blocks:
395
+ x = block(x, mask, encoder_states, encoder_mask, pos_bias=e)
396
+ x = self.norm(x)
397
+ x = self.dropout(x)
398
+ return x
399
+
400
+
401
+ class T5Model(nn.Module):
402
+ def __init__(
403
+ self,
404
+ vocab_size,
405
+ dim,
406
+ dim_attn,
407
+ dim_ffn,
408
+ num_heads,
409
+ encoder_layers,
410
+ decoder_layers,
411
+ num_buckets,
412
+ shared_pos=True,
413
+ dropout=0.1,
414
+ ):
415
+ super(T5Model, self).__init__()
416
+ self.vocab_size = vocab_size
417
+ self.dim = dim
418
+ self.dim_attn = dim_attn
419
+ self.dim_ffn = dim_ffn
420
+ self.num_heads = num_heads
421
+ self.encoder_layers = encoder_layers
422
+ self.decoder_layers = decoder_layers
423
+ self.num_buckets = num_buckets
424
+
425
+ # layers
426
+ self.token_embedding = nn.Embedding(vocab_size, dim)
427
+ self.encoder = T5Encoder(
428
+ self.token_embedding,
429
+ dim,
430
+ dim_attn,
431
+ dim_ffn,
432
+ num_heads,
433
+ encoder_layers,
434
+ num_buckets,
435
+ shared_pos,
436
+ dropout,
437
+ )
438
+ self.decoder = T5Decoder(
439
+ self.token_embedding,
440
+ dim,
441
+ dim_attn,
442
+ dim_ffn,
443
+ num_heads,
444
+ decoder_layers,
445
+ num_buckets,
446
+ shared_pos,
447
+ dropout,
448
+ )
449
+ self.head = nn.Linear(dim, vocab_size, bias=False)
450
+
451
+ # initialize weights
452
+ self.apply(init_weights)
453
+
454
+ def forward(self, encoder_ids, encoder_mask, decoder_ids, decoder_mask):
455
+ x = self.encoder(encoder_ids, encoder_mask)
456
+ x = self.decoder(decoder_ids, decoder_mask, x, encoder_mask)
457
+ x = self.head(x)
458
+ return x
459
+
460
+
461
+ def _t5(
462
+ name,
463
+ encoder_only=False,
464
+ decoder_only=False,
465
+ return_tokenizer=False,
466
+ tokenizer_kwargs={},
467
+ dtype=torch.float32,
468
+ device="cpu",
469
+ **kwargs,
470
+ ):
471
+ # sanity check
472
+ assert not (encoder_only and decoder_only)
473
+
474
+ # params
475
+ if encoder_only:
476
+ model_cls = T5Encoder
477
+ kwargs["vocab"] = kwargs.pop("vocab_size")
478
+ kwargs["num_layers"] = kwargs.pop("encoder_layers")
479
+ _ = kwargs.pop("decoder_layers")
480
+ elif decoder_only:
481
+ model_cls = T5Decoder
482
+ kwargs["vocab"] = kwargs.pop("vocab_size")
483
+ kwargs["num_layers"] = kwargs.pop("decoder_layers")
484
+ _ = kwargs.pop("encoder_layers")
485
+ else:
486
+ model_cls = T5Model
487
+
488
+ # init model
489
+ with torch.device(device):
490
+ model = model_cls(**kwargs)
491
+
492
+ # set device
493
+ model = model.to(dtype=dtype, device=device)
494
+
495
+ # init tokenizer
496
+ if return_tokenizer:
497
+ from .tokenizers import HuggingfaceTokenizer
498
+
499
+ tokenizer = HuggingfaceTokenizer(f"google/{name}", **tokenizer_kwargs)
500
+ return model, tokenizer
501
+ else:
502
+ return model
503
+
504
+
505
+ def umt5_xxl(**kwargs):
506
+ cfg = dict(
507
+ vocab_size=256384,
508
+ dim=4096,
509
+ dim_attn=4096,
510
+ dim_ffn=10240,
511
+ num_heads=64,
512
+ encoder_layers=24,
513
+ decoder_layers=24,
514
+ num_buckets=32,
515
+ shared_pos=False,
516
+ dropout=0.1,
517
+ )
518
+ cfg.update(**kwargs)
519
+ return _t5("umt5-xxl", **cfg)
520
+
521
+
522
+ class T5EncoderModel:
523
+ def __init__(
524
+ self,
525
+ text_len,
526
+ dtype=torch.bfloat16,
527
+ device=torch.cuda.current_device(),
528
+ checkpoint_path=None,
529
+ tokenizer_path=None,
530
+ shard_fn=None,
531
+ ):
532
+ self.text_len = text_len
533
+ self.dtype = dtype
534
+ self.device = device
535
+ self.checkpoint_path = checkpoint_path
536
+ self.tokenizer_path = tokenizer_path
537
+
538
+ # init model
539
+ model = (
540
+ umt5_xxl(
541
+ encoder_only=True, return_tokenizer=False, dtype=dtype, device=device
542
+ )
543
+ .eval()
544
+ .requires_grad_(False)
545
+ )
546
+ logging.info(f"loading {checkpoint_path}")
547
+ model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
548
+ self.model = model
549
+ if shard_fn is not None:
550
+ self.model = shard_fn(self.model, sync_module_states=False)
551
+ else:
552
+ self.model.to(self.device)
553
+ # init tokenizer
554
+ self.tokenizer = HuggingfaceTokenizer(
555
+ name=tokenizer_path, seq_len=text_len, clean="whitespace"
556
+ )
557
+
558
+ def __call__(self, texts, device):
559
+ ids, mask = self.tokenizer(texts, return_mask=True, add_special_tokens=True)
560
+ ids = ids.to(device)
561
+ mask = mask.to(device)
562
+ seq_lens = mask.gt(0).sum(dim=1).long()
563
+ context = self.model(ids, mask)
564
+ return [u[:v] for u, v in zip(context, seq_lens)]
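
A hedged usage sketch for `T5EncoderModel`. The checkpoint and tokenizer paths below are placeholders, not files guaranteed to ship at these locations; substitute the UMT5-XXL encoder weights and tokenizer you actually have.

```python
import torch
from ldf_models.tools.t5 import T5EncoderModel  # path as added in this commit

encoder = T5EncoderModel(
    text_len=512,                                      # max tokens per prompt
    dtype=torch.bfloat16,
    device=torch.device("cuda"),
    checkpoint_path="/path/to/umt5-xxl-enc-bf16.pth",  # placeholder
    tokenizer_path="google/umt5-xxl",                  # placeholder
)

# Returns one tensor per prompt, trimmed to its true token length: [len_i, 4096].
contexts = encoder(["a person walks forward", "a person sits down"], torch.device("cuda"))
print([c.shape for c in contexts])
```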
ldf_models/tools/tokenizers.py ADDED
@@ -0,0 +1,84 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import html
3
+ import string
4
+
5
+ import ftfy
6
+ import regex as re
7
+ from transformers import AutoTokenizer
8
+
9
+ __all__ = ["HuggingfaceTokenizer"]
10
+
11
+
12
+ def basic_clean(text):
13
+ text = ftfy.fix_text(text)
14
+ text = html.unescape(html.unescape(text))
15
+ return text.strip()
16
+
17
+
18
+ def whitespace_clean(text):
19
+ text = re.sub(r"\s+", " ", text)
20
+ text = text.strip()
21
+ return text
22
+
23
+
24
+ def canonicalize(text, keep_punctuation_exact_string=None):
25
+ text = text.replace("_", " ")
26
+ if keep_punctuation_exact_string:
27
+ text = keep_punctuation_exact_string.join(
28
+ part.translate(str.maketrans("", "", string.punctuation))
29
+ for part in text.split(keep_punctuation_exact_string)
30
+ )
31
+ else:
32
+ text = text.translate(str.maketrans("", "", string.punctuation))
33
+ text = text.lower()
34
+ text = re.sub(r"\s+", " ", text)
35
+ return text.strip()
36
+
37
+
38
+ class HuggingfaceTokenizer:
39
+ def __init__(self, name, seq_len=None, clean=None, **kwargs):
40
+ assert clean in (None, "whitespace", "lower", "canonicalize")
41
+ self.name = name
42
+ self.seq_len = seq_len
43
+ self.clean = clean
44
+
45
+ # init tokenizer
46
+ self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs)
47
+ self.vocab_size = self.tokenizer.vocab_size
48
+
49
+ def __call__(self, sequence, **kwargs):
50
+ return_mask = kwargs.pop("return_mask", False)
51
+
52
+ # arguments
53
+ _kwargs = {"return_tensors": "pt"}
54
+ if self.seq_len is not None:
55
+ _kwargs.update(
56
+ {
57
+ "padding": "max_length",
58
+ "truncation": True,
59
+ "max_length": self.seq_len,
60
+ }
61
+ )
62
+ _kwargs.update(**kwargs)
63
+
64
+ # tokenization
65
+ if isinstance(sequence, str):
66
+ sequence = [sequence]
67
+ if self.clean:
68
+ sequence = [self._clean(u) for u in sequence]
69
+ ids = self.tokenizer(sequence, **_kwargs)
70
+
71
+ # output
72
+ if return_mask:
73
+ return ids.input_ids, ids.attention_mask
74
+ else:
75
+ return ids.input_ids
76
+
77
+ def _clean(self, text):
78
+ if self.clean == "whitespace":
79
+ text = whitespace_clean(basic_clean(text))
80
+ elif self.clean == "lower":
81
+ text = whitespace_clean(basic_clean(text)).lower()
82
+ elif self.clean == "canonicalize":
83
+ text = canonicalize(basic_clean(text))
84
+ return text
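
A hedged usage sketch for `HuggingfaceTokenizer`. The tokenizer id is only an example; any Hugging Face tokenizer name or local path works, and `clean="whitespace"` matches the mode used by `T5EncoderModel` above.

```python
from ldf_models.tools.tokenizers import HuggingfaceTokenizer  # path as added in this commit

tokenizer = HuggingfaceTokenizer(
    name="google/umt5-xxl",   # example id; any HF tokenizer name or local path
    seq_len=512,              # pad / truncate every prompt to this length
    clean="whitespace",       # same cleaning mode T5EncoderModel uses
)

ids, mask = tokenizer(["A person walks forward."], return_mask=True, add_special_tokens=True)
print(ids.shape, mask.shape)  # both torch.Size([1, 512])
```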
ldf_models/tools/wan_model.py ADDED
@@ -0,0 +1,592 @@
1
+ # This module uses modified code from Alibaba Wan Team
2
+ # Original source: https://github.com/Wan-Video/Wan2.2
3
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
4
+ # Modified to support stream mode for cross-attention.
5
+ # Added causal attention for self-attention (1d case)
6
+ # Added context length correction.
7
+
8
+ import math
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
13
+ from diffusers.models.modeling_utils import ModelMixin
14
+
15
+ from .attention import flash_attention
16
+
17
+
18
+ def sinusoidal_embedding_1d(dim, position):
19
+ # preprocess
20
+ assert dim % 2 == 0
21
+ half = dim // 2
22
+ position = position.type(torch.float64)
23
+
24
+ # calculation
25
+ sinusoid = torch.outer(
26
+ position, torch.pow(10000, -torch.arange(half).to(position).div(half))
27
+ )
28
+ x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
29
+ return x
30
+
31
+
32
+ @torch.amp.autocast("cuda", enabled=False)
33
+ def rope_params(max_seq_len, dim, theta=10000):
34
+ assert dim % 2 == 0
35
+ freqs = torch.outer(
36
+ torch.arange(max_seq_len),
37
+ 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float64).div(dim)),
38
+ )
39
+ freqs = torch.polar(torch.ones_like(freqs), freqs)
40
+ return freqs
41
+
42
+
43
+ @torch.amp.autocast("cuda", enabled=False)
44
+ def rope_apply(x, grid_sizes, freqs):
45
+ n, c = x.size(2), x.size(3) // 2
46
+
47
+ # split freqs
48
+ freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
49
+
50
+ # loop over samples
51
+ output = []
52
+ for i, (f, h, w) in enumerate(grid_sizes.tolist()):
53
+ seq_len = f * h * w
54
+
55
+ # precompute multipliers
56
+ x_i = torch.view_as_complex(
57
+ x[i, :seq_len].to(torch.float64).reshape(seq_len, n, -1, 2)
58
+ )
59
+ freqs_i = torch.cat(
60
+ [
61
+ freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
62
+ freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
63
+ freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
64
+ ],
65
+ dim=-1,
66
+ ).reshape(seq_len, 1, -1)
67
+
68
+ # apply rotary embedding
69
+ x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
70
+ x_i = torch.cat([x_i, x[i, seq_len:]])
71
+
72
+ # append to collection
73
+ output.append(x_i)
74
+ return torch.stack(output).float()
75
+
76
+
77
+ class WanRMSNorm(nn.Module):
78
+ def __init__(self, dim, eps=1e-5):
79
+ super().__init__()
80
+ self.dim = dim
81
+ self.eps = eps
82
+ self.weight = nn.Parameter(torch.ones(dim))
83
+
84
+ def forward(self, x):
85
+ r"""
86
+ Args:
87
+ x(Tensor): Shape [B, L, C]
88
+ """
89
+ return self._norm(x.float()).type_as(x) * self.weight
90
+
91
+ def _norm(self, x):
92
+ return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
93
+
94
+
95
+ class WanLayerNorm(nn.LayerNorm):
96
+ def __init__(self, dim, eps=1e-6, elementwise_affine=False):
97
+ super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps)
98
+
99
+ def forward(self, x):
100
+ r"""
101
+ Args:
102
+ x(Tensor): Shape [B, L, C]
103
+ """
104
+ return super().forward(x.float()).type_as(x)
105
+
106
+
107
+ class WanSelfAttention(nn.Module):
108
+ def __init__(
109
+ self, dim, num_heads, window_size=(-1, -1), qk_norm=True, eps=1e-6, causal=False
110
+ ):
111
+ assert dim % num_heads == 0
112
+ super().__init__()
113
+ self.dim = dim
114
+ self.num_heads = num_heads
115
+ self.head_dim = dim // num_heads
116
+ self.window_size = window_size
117
+ self.qk_norm = qk_norm
118
+ self.eps = eps
119
+ self.causal = causal
120
+ # layers
121
+ self.q = nn.Linear(dim, dim)
122
+ self.k = nn.Linear(dim, dim)
123
+ self.v = nn.Linear(dim, dim)
124
+ self.o = nn.Linear(dim, dim)
125
+ self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
126
+ self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
127
+
128
+ def forward(self, x, seq_lens, grid_sizes, freqs):
129
+ r"""
130
+ Args:
131
+ x(Tensor): Shape [B, L, num_heads, C / num_heads]
132
+ seq_lens(Tensor): Shape [B]
133
+ grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
134
+ freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
135
+ """
136
+ b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
137
+
138
+ # query, key, value function
139
+ def qkv_fn(x):
140
+ q = self.norm_q(self.q(x)).view(b, s, n, d)
141
+ k = self.norm_k(self.k(x)).view(b, s, n, d)
142
+ v = self.v(x).view(b, s, n, d)
143
+ return q, k, v
144
+
145
+ q, k, v = qkv_fn(x)
146
+
147
+ x = flash_attention(
148
+ q=rope_apply(q, grid_sizes, freqs),
149
+ k=rope_apply(k, grid_sizes, freqs),
150
+ v=v,
151
+ k_lens=seq_lens,
152
+ window_size=self.window_size,
153
+ causal=self.causal,
154
+ )
155
+
156
+ # output
157
+ x = x.flatten(2)
158
+ x = self.o(x)
159
+ return x
160
+
161
+
162
+ class WanCrossAttention(WanSelfAttention):
163
+ def forward(self, x, context, context_lens):
164
+ r"""
165
+ Args non-stream mode:
166
+ x(Tensor): Shape [B, L1, C]
167
+ context(Tensor): Shape [B, L2, C]
168
+ context_lens(Tensor): Shape [B]
169
+ Args stream mode:
170
+ x(Tensor): Shape [B, L1, C]
171
+ context(Tensor): Shape [BxL1, L2, C]
172
+ context_lens(Tensor): Shape [BxL1]
173
+ """
174
+ out_sizes = x.size()
175
+ b, n, d = context.size(0), self.num_heads, self.head_dim
176
+
177
+ # compute query, key, value
178
+ q = self.norm_q(self.q(x)).view(b, -1, n, d)
179
+ k = self.norm_k(self.k(context)).view(b, -1, n, d)
180
+ v = self.v(context).view(b, -1, n, d)
181
+
182
+ # compute attention
183
+ x = flash_attention(q, k, v, k_lens=context_lens)
184
+
185
+ # output
186
+ x = x.flatten(2).view(*out_sizes)
187
+ x = self.o(x)
188
+ return x
189
+
190
+
191
+ class WanAttentionBlock(nn.Module):
192
+ def __init__(
193
+ self,
194
+ dim,
195
+ ffn_dim,
196
+ num_heads,
197
+ window_size=(-1, -1),
198
+ qk_norm=True,
199
+ cross_attn_norm=False,
200
+ eps=1e-6,
201
+ causal=False,
202
+ ):
203
+ super().__init__()
204
+ self.dim = dim
205
+ self.ffn_dim = ffn_dim
206
+ self.num_heads = num_heads
207
+ self.window_size = window_size
208
+ self.qk_norm = qk_norm
209
+ self.cross_attn_norm = cross_attn_norm
210
+ self.eps = eps
211
+ self.causal = causal
212
+ # layers
213
+ self.norm1 = WanLayerNorm(dim, eps)
214
+ self.self_attn = WanSelfAttention(
215
+ dim, num_heads, window_size, qk_norm, eps, causal
216
+ )
217
+ self.norm3 = (
218
+ WanLayerNorm(dim, eps, elementwise_affine=True)
219
+ if cross_attn_norm
220
+ else nn.Identity()
221
+ )
222
+
223
+ self.cross_attn = WanCrossAttention(dim, num_heads, (-1, -1), qk_norm, eps)
224
+ self.norm2 = WanLayerNorm(dim, eps)
225
+ self.ffn = nn.Sequential(
226
+ nn.Linear(dim, ffn_dim),
227
+ nn.GELU(approximate="tanh"),
228
+ nn.Linear(ffn_dim, dim),
229
+ )
230
+
231
+ # modulation
232
+ self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
233
+
234
+ def forward(
235
+ self,
236
+ x,
237
+ e,
238
+ seq_lens,
239
+ grid_sizes,
240
+ freqs,
241
+ context,
242
+ context_lens,
243
+ ):
244
+ r"""
245
+ Args:
246
+ x(Tensor): Shape [B, L, C]
247
+ e(Tensor): Shape [B, L1, 6, C]
248
+ seq_lens(Tensor): Shape [B], length of each sequence in batch
249
+ grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
250
+ freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
251
+ """
252
+ assert e.dtype == torch.float32
253
+ with torch.amp.autocast("cuda", dtype=torch.float32):
254
+ e = (self.modulation.unsqueeze(0) + e).chunk(6, dim=2)
255
+ assert e[0].dtype == torch.float32
256
+
257
+ # self-attention
258
+ y = self.self_attn(
259
+ self.norm1(x).float() * (1 + e[1].squeeze(2)) + e[0].squeeze(2),
260
+ seq_lens,
261
+ grid_sizes,
262
+ freqs,
263
+ )
264
+ with torch.amp.autocast("cuda", dtype=torch.float32):
265
+ x = x + y * e[2].squeeze(2)
266
+
267
+ # cross-attention & ffn function
268
+ def cross_attn_ffn(x, context, context_lens, e):
269
+ x = x + self.cross_attn(self.norm3(x), context, context_lens)
270
+ y = self.ffn(
271
+ self.norm2(x).float() * (1 + e[4].squeeze(2)) + e[3].squeeze(2)
272
+ )
273
+ with torch.amp.autocast("cuda", dtype=torch.float32):
274
+ x = x + y * e[5].squeeze(2)
275
+ return x
276
+
277
+ x = cross_attn_ffn(x, context, context_lens, e)
278
+ return x
279
+
280
+
281
+ class Head(nn.Module):
282
+ def __init__(self, dim, out_dim, patch_size, eps=1e-6):
283
+ super().__init__()
284
+ self.dim = dim
285
+ self.out_dim = out_dim
286
+ self.patch_size = patch_size
287
+ self.eps = eps
288
+
289
+ # layers
290
+ out_dim = math.prod(patch_size) * out_dim
291
+ self.norm = WanLayerNorm(dim, eps)
292
+ self.head = nn.Linear(dim, out_dim)
293
+
294
+ # modulation
295
+ self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
296
+
297
+ def forward(self, x, e):
298
+ r"""
299
+ Args:
300
+ x(Tensor): Shape [B, L1, C]
301
+ e(Tensor): Shape [B, L1, C]
302
+ """
303
+ assert e.dtype == torch.float32
304
+ with torch.amp.autocast("cuda", dtype=torch.float32):
305
+ e = (self.modulation.unsqueeze(0) + e.unsqueeze(2)).chunk(2, dim=2)
306
+ x = self.head(self.norm(x) * (1 + e[1].squeeze(2)) + e[0].squeeze(2))
307
+ return x
308
+
309
+
310
+ class WanModel(ModelMixin, ConfigMixin):
311
+ r"""
312
+ Wan diffusion backbone supporting both text-to-video and image-to-video.
313
+ """
314
+
315
+ ignore_for_config = [
316
+ "patch_size",
317
+ "cross_attn_norm",
318
+ "qk_norm",
319
+ "text_dim",
320
+ "window_size",
321
+ ]
322
+ _no_split_modules = ["WanAttentionBlock"]
323
+
324
+ @register_to_config
325
+ def __init__(
326
+ self,
327
+ model_type="t2v",
328
+ patch_size=(1, 2, 2),
329
+ text_len=512,
330
+ in_dim=16,
331
+ dim=2048,
332
+ ffn_dim=8192,
333
+ freq_dim=256,
334
+ text_dim=4096,
335
+ out_dim=16,
336
+ num_heads=16,
337
+ num_layers=32,
338
+ window_size=(-1, -1),
339
+ qk_norm=True,
340
+ cross_attn_norm=True,
341
+ eps=1e-6,
342
+ causal=False,
343
+ ):
344
+ r"""
345
+ Initialize the diffusion model backbone.
346
+
347
+ Args:
348
+ model_type (`str`, *optional*, defaults to 't2v'):
349
+ Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video)
350
+ patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
351
+ 3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
352
+ text_len (`int`, *optional*, defaults to 512):
353
+ Fixed length for text embeddings
354
+ in_dim (`int`, *optional*, defaults to 16):
355
+ Input video channels (C_in)
356
+ dim (`int`, *optional*, defaults to 2048):
357
+ Hidden dimension of the transformer
358
+ ffn_dim (`int`, *optional*, defaults to 8192):
359
+ Intermediate dimension in feed-forward network
360
+ freq_dim (`int`, *optional*, defaults to 256):
361
+ Dimension for sinusoidal time embeddings
362
+ text_dim (`int`, *optional*, defaults to 4096):
363
+ Input dimension for text embeddings
364
+ out_dim (`int`, *optional*, defaults to 16):
365
+ Output video channels (C_out)
366
+ num_heads (`int`, *optional*, defaults to 16):
367
+ Number of attention heads
368
+ num_layers (`int`, *optional*, defaults to 32):
369
+ Number of transformer blocks
370
+ window_size (`tuple`, *optional*, defaults to (-1, -1)):
371
+ Window size for local attention (-1 indicates global attention)
372
+ qk_norm (`bool`, *optional*, defaults to True):
373
+ Enable query/key normalization
374
+ cross_attn_norm (`bool`, *optional*, defaults to True):
375
+ Enable cross-attention normalization
376
+ eps (`float`, *optional*, defaults to 1e-6):
377
+ Epsilon value for normalization layers
+ causal (`bool`, *optional*, defaults to False):
+ Use causal self-attention in the transformer blocks
378
+ """
379
+
380
+ super().__init__()
381
+
382
+ assert model_type in ["t2v", "i2v", "ti2v", "s2v"]
383
+ self.model_type = model_type
384
+
385
+ self.patch_size = patch_size
386
+ self.text_len = text_len
387
+ self.in_dim = in_dim
388
+ self.dim = dim
389
+ self.ffn_dim = ffn_dim
390
+ self.freq_dim = freq_dim
391
+ self.text_dim = text_dim
392
+ self.out_dim = out_dim
393
+ self.num_heads = num_heads
394
+ self.num_layers = num_layers
395
+ self.window_size = window_size
396
+ self.qk_norm = qk_norm
397
+ self.cross_attn_norm = cross_attn_norm
398
+ self.eps = eps
399
+ self.causal = causal
400
+ # embeddings
401
+ self.patch_embedding = nn.Conv3d(
402
+ in_dim, dim, kernel_size=patch_size, stride=patch_size
403
+ )
404
+ self.text_embedding = nn.Sequential(
405
+ nn.Linear(text_dim, dim), nn.GELU(approximate="tanh"), nn.Linear(dim, dim)
406
+ )
407
+
408
+ self.time_embedding = nn.Sequential(
409
+ nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim)
410
+ )
411
+ self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
412
+
413
+ # blocks
414
+ self.blocks = nn.ModuleList(
415
+ [
416
+ WanAttentionBlock(
417
+ dim,
418
+ ffn_dim,
419
+ num_heads,
420
+ window_size,
421
+ qk_norm,
422
+ cross_attn_norm,
423
+ eps,
424
+ causal,
425
+ )
426
+ for _ in range(num_layers)
427
+ ]
428
+ )
429
+
430
+ # head
431
+ self.head = Head(dim, out_dim, patch_size, eps)
432
+
433
+ # buffers (don't use register_buffer otherwise dtype will be changed in to())
434
+ assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
435
+ d = dim // num_heads
436
+ self.freqs = torch.cat(
437
+ [
438
+ rope_params(1024, d - 4 * (d // 6)),
439
+ rope_params(1024, 2 * (d // 6)),
440
+ rope_params(1024, 2 * (d // 6)),
441
+ ],
442
+ dim=1,
443
+ )
444
+
445
+ # initialize weights
446
+ self.init_weights()
447
+
448
+ def forward(
449
+ self,
450
+ x,
451
+ t,
452
+ context,
453
+ seq_len,
454
+ y=None,
455
+ ):
456
+ r"""
457
+ Forward pass through the diffusion model
458
+
459
+ Args:
460
+ x (List[Tensor]):
461
+ List of input video tensors, each with shape [C_in, F, H, W]
462
+ t (Tensor):
463
+ Diffusion timesteps tensor of shape [B]
464
+ context (List[Tensor]):
465
+ List of text embeddings each with shape [L, C]
466
+ seq_len (`int`):
467
+ Maximum sequence length for positional encoding
468
+ y (List[Tensor], *optional*):
469
+ Conditional video inputs for image-to-video mode, same shape as x
470
+
471
+ Returns:
472
+ List[Tensor]:
473
+ List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
474
+ """
475
+ if self.model_type == "i2v":
476
+ assert y is not None
477
+ # params
478
+ device = self.patch_embedding.weight.device
479
+ if self.freqs.device != device:
480
+ self.freqs = self.freqs.to(device)
481
+
482
+ if y is not None:
483
+ x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
484
+
485
+ # embeddings
486
+ x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
487
+ grid_sizes = torch.stack(
488
+ [torch.tensor(u.shape[2:], dtype=torch.long) for u in x]
489
+ )
490
+ x = [u.flatten(2).transpose(1, 2) for u in x]
491
+ seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
492
+ assert seq_lens.max() <= seq_len
493
+ x = torch.cat(
494
+ [
495
+ torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1)
496
+ for u in x
497
+ ]
498
+ )
499
+
500
+ # time embeddings
501
+ if t.dim() == 1: # bs
502
+ t = t.expand(t.size(0), seq_len)
503
+ with torch.amp.autocast("cuda", dtype=torch.float32):
504
+ bt = t.size(0)
505
+ t = t.flatten()
506
+ e = self.time_embedding(
507
+ sinusoidal_embedding_1d(self.freq_dim, t)
508
+ .unflatten(0, (bt, seq_len))
509
+ .float()
510
+ )
511
+ e0 = self.time_projection(e).unflatten(2, (6, self.dim))
512
+ assert e.dtype == torch.float32 and e0.dtype == torch.float32
513
+
514
+ # context
515
+ context_lens = torch.tensor([u.size(0) for u in context], dtype=torch.long)
516
+ context = self.text_embedding(
517
+ torch.stack(
518
+ [
519
+ torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
520
+ for u in context
521
+ ]
522
+ )
523
+ )
524
+
525
+ # arguments
526
+ kwargs = dict(
527
+ e=e0,
528
+ seq_lens=seq_lens,
529
+ grid_sizes=grid_sizes,
530
+ freqs=self.freqs,
531
+ context=context,
532
+ context_lens=context_lens,
533
+ )
534
+
535
+ for block in self.blocks:
536
+ x = block(x, **kwargs)
537
+
538
+ # head
539
+ x = self.head(x, e)
540
+
541
+ # unpatchify
542
+ x = self.unpatchify(x, grid_sizes)
543
+ return [u.float() for u in x]
544
+
545
+ def unpatchify(self, x, grid_sizes):
546
+ r"""
547
+ Reconstruct video tensors from patch embeddings.
548
+
549
+ Args:
550
+ x (List[Tensor]):
551
+ List of patchified features, each with shape [L, C_out * prod(patch_size)]
552
+ grid_sizes (Tensor):
553
+ Original spatial-temporal grid dimensions before patching,
554
+ shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)
555
+
556
+ Returns:
557
+ List[Tensor]:
558
+ Reconstructed video tensors with shape [C_out, F, H / 8, W / 8]
559
+ """
560
+
561
+ c = self.out_dim
562
+ out = []
563
+ for u, v in zip(x, grid_sizes.tolist()):
564
+ u = u[: math.prod(v)].view(*v, *self.patch_size, c)
565
+ u = torch.einsum("fhwpqrc->cfphqwr", u)
566
+ u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
567
+ out.append(u)
568
+ return out
569
+
570
+ def init_weights(self):
571
+ r"""
572
+ Initialize model parameters using Xavier initialization.
573
+ """
574
+
575
+ # basic init
576
+ for m in self.modules():
577
+ if isinstance(m, nn.Linear):
578
+ nn.init.xavier_uniform_(m.weight)
579
+ if m.bias is not None:
580
+ nn.init.zeros_(m.bias)
581
+
582
+ # init embeddings
583
+ nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))
584
+ for m in self.text_embedding.modules():
585
+ if isinstance(m, nn.Linear):
586
+ nn.init.normal_(m.weight, std=0.02)
587
+ for m in self.time_embedding.modules():
588
+ if isinstance(m, nn.Linear):
589
+ nn.init.normal_(m.weight, std=0.02)
590
+
591
+ # init output layer
592
+ nn.init.zeros_(self.head.head.weight)
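The attention blocks and the output `Head` above share the same adaLN-style modulation pattern: a per-token embedding `e` is chunked into shift, scale, and gate terms that modulate a LayerNorm'd activation before and after each sub-layer. The following is a minimal, self-contained sketch of that scale/shift/gate pattern only; the tensor names and toy sizes are hypothetical and not tied to the model or its weights.

```python
import torch
import torch.nn as nn

# Toy sizes, purely illustrative.
B, L, C = 2, 8, 16
x = torch.randn(B, L, C)            # block input
e = torch.randn(B, L, 6, C)         # per-token modulation, as in WanAttentionBlock.forward
norm = nn.LayerNorm(C, elementwise_affine=False)

shift, scale, gate = e[:, :, 0], e[:, :, 1], e[:, :, 2]
h = norm(x) * (1 + scale) + shift   # modulated input fed to self-attention
y = h                               # stand-in for the attention output
x = x + y * gate                    # gated residual update, mirroring `x = x + y * e[2]`
print(x.shape)                      # torch.Size([2, 8, 16])
```

The cross-attention/FFN branch in the block repeats the same pattern with the remaining three chunks of `e`.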
ldf_models/tools/wan_vae_1d.py ADDED
@@ -0,0 +1,762 @@
1
+ # This module uses modified code from Alibaba Wan Team
2
+ # Original source: https://github.com/Wan-Video/Wan2.2
3
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
4
+ # Modified to support 1d features with (B, C, T)
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ CACHE_T = 2
11
+
12
+
13
+ class CausalConv1d(nn.Conv1d):
14
+ """
15
+ Causal 1d convolution.
16
+ """
17
+
18
+ def __init__(self, *args, **kwargs):
19
+ super().__init__(*args, **kwargs)
20
+ self._padding = (
21
+ 2 * self.padding[0],
22
+ 0,
23
+ )
24
+ self.padding = (0,)
25
+
26
+ def forward(self, x, cache_x=None):
27
+ padding = list(self._padding)
28
+ if cache_x is not None and self._padding[0] > 0:
29
+ cache_x = cache_x.to(x.device)
30
+ x = torch.cat([cache_x, x], dim=2)
31
+ padding[0] -= cache_x.shape[2]
32
+ x = F.pad(x, padding)
33
+
34
+ return super().forward(x)
35
+
36
+
37
+ class RMS_norm(nn.Module):
38
+ def __init__(self, dim, channel_first=True, bias=False):
39
+ super().__init__()
40
+ broadcastable_dims = (1,)
41
+ shape = (dim, *broadcastable_dims) if channel_first else (dim,)
42
+
43
+ self.channel_first = channel_first
44
+ self.scale = dim**0.5
45
+ self.gamma = nn.Parameter(torch.ones(shape))
46
+ self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
47
+
48
+ def forward(self, x):
49
+ return (
50
+ F.normalize(x, dim=(1 if self.channel_first else -1))
51
+ * self.scale
52
+ * self.gamma
53
+ + self.bias
54
+ )
55
+
56
+
57
+ class Upsample(nn.Upsample):
58
+ def forward(self, x):
59
+ """
60
+ Fix bfloat16 support for nearest neighbor interpolation.
61
+ """
62
+ return super().forward(x.float()).type_as(x)
63
+
64
+
65
+ class Resample(nn.Module):
66
+ def __init__(self, dim, mode):
67
+ assert mode in (
68
+ "upsample1d",
69
+ "downsample1d",
70
+ )
71
+ super().__init__()
72
+ self.dim = dim
73
+ self.mode = mode
74
+
75
+ # layers
76
+ if mode == "upsample1d":
77
+ self.time_conv = CausalConv1d(dim, dim * 2, (3,), padding=(1,))
78
+ elif mode == "downsample1d":
79
+ self.time_conv = CausalConv1d(dim, dim, (3,), stride=(2,), padding=(0,))
80
+
81
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
82
+ b, c, t = x.size()
83
+ if self.mode == "upsample1d":
84
+ if feat_cache is not None:
85
+ idx = feat_idx[0]
86
+ if feat_cache[idx] is None:
87
+ feat_cache[idx] = "Rep"
88
+ feat_idx[0] += 1
89
+ else:
90
+ cache_x = x[:, :, -CACHE_T:].clone()
91
+ if (
92
+ cache_x.shape[2] < 2
93
+ and feat_cache[idx] is not None
94
+ and feat_cache[idx] != "Rep"
95
+ ):
96
+ # prepend the last frame of the previous chunk to the cache
97
+ cache_x = torch.cat(
98
+ [
99
+ feat_cache[idx][:, :, -1]
100
+ .unsqueeze(2)
101
+ .to(cache_x.device),
102
+ cache_x,
103
+ ],
104
+ dim=2,
105
+ )
106
+ if (
107
+ cache_x.shape[2] < 2
108
+ and feat_cache[idx] is not None
109
+ and feat_cache[idx] == "Rep"
110
+ ):
111
+ cache_x = torch.cat(
112
+ [torch.zeros_like(cache_x).to(cache_x.device), cache_x],
113
+ dim=2,
114
+ )
115
+ if feat_cache[idx] == "Rep":
116
+ x = self.time_conv(x)
117
+ else:
118
+ x = self.time_conv(x, feat_cache[idx])
119
+ feat_cache[idx] = cache_x
120
+ feat_idx[0] += 1
121
+ x = x.reshape(b, 2, c, t)
122
+ x = torch.stack((x[:, 0, :, :], x[:, 1, :, :]), 3)
123
+ x = x.reshape(b, c, t * 2)
124
+
125
+ if self.mode == "downsample1d":
126
+ if feat_cache is not None:
127
+ idx = feat_idx[0]
128
+ if feat_cache[idx] is None:
129
+ feat_cache[idx] = x.clone()
130
+ feat_idx[0] += 1
131
+ else:
132
+ cache_x = x[:, :, -1:].clone()
133
+ x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:], x], 2))
134
+ feat_cache[idx] = cache_x
135
+ feat_idx[0] += 1
136
+ return x
137
+
138
+
139
+ class ResidualBlock(nn.Module):
140
+ def __init__(self, in_dim, out_dim, dropout=0.0):
141
+ super().__init__()
142
+ self.in_dim = in_dim
143
+ self.out_dim = out_dim
144
+
145
+ # layers
146
+ self.residual = nn.Sequential(
147
+ RMS_norm(in_dim),
148
+ nn.SiLU(),
149
+ CausalConv1d(in_dim, out_dim, 3, padding=1),
150
+ RMS_norm(out_dim),
151
+ nn.SiLU(),
152
+ nn.Dropout(dropout),
153
+ CausalConv1d(out_dim, out_dim, 3, padding=1),
154
+ )
155
+ self.shortcut = (
156
+ CausalConv1d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity()
157
+ )
158
+
159
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
160
+ h = self.shortcut(x)
161
+ for layer in self.residual:
162
+ if isinstance(layer, CausalConv1d) and feat_cache is not None:
163
+ idx = feat_idx[0]
164
+ cache_x = x[:, :, -CACHE_T:].clone()
165
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
166
+ # prepend the last frame of the previous chunk to the cache
167
+ cache_x = torch.cat(
168
+ [
169
+ feat_cache[idx][:, :, -1].unsqueeze(2).to(cache_x.device),
170
+ cache_x,
171
+ ],
172
+ dim=2,
173
+ )
174
+ x = layer(x, feat_cache[idx])
175
+ feat_cache[idx] = cache_x
176
+ feat_idx[0] += 1
177
+ else:
178
+ x = layer(x)
179
+ return x + h
180
+
181
+
182
+ class AvgDown1D(nn.Module):
183
+ def __init__(
184
+ self,
185
+ in_channels,
186
+ out_channels,
187
+ factor_t,
188
+ ):
189
+ super().__init__()
190
+ self.in_channels = in_channels
191
+ self.out_channels = out_channels
192
+ self.factor_t = factor_t
193
+ self.factor = self.factor_t
194
+
195
+ assert in_channels * self.factor % out_channels == 0
196
+ self.group_size = in_channels * self.factor // out_channels
197
+
198
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
199
+ pad_t = (self.factor_t - x.shape[2] % self.factor_t) % self.factor_t
200
+ pad = (pad_t, 0)
201
+ x = F.pad(x, pad)
202
+ B, C, T = x.shape
203
+ x = x.view(
204
+ B,
205
+ C,
206
+ T // self.factor_t,
207
+ self.factor_t,
208
+ )
209
+ x = x.permute(0, 1, 3, 2).contiguous()
210
+ x = x.view(
211
+ B,
212
+ C * self.factor,
213
+ T // self.factor_t,
214
+ )
215
+ x = x.view(
216
+ B,
217
+ self.out_channels,
218
+ self.group_size,
219
+ T // self.factor_t,
220
+ )
221
+ x = x.mean(dim=2)
222
+ return x
223
+
224
+
225
+ class DupUp1D(nn.Module):
226
+ def __init__(
227
+ self,
228
+ in_channels: int,
229
+ out_channels: int,
230
+ factor_t,
231
+ ):
232
+ super().__init__()
233
+ self.in_channels = in_channels
234
+ self.out_channels = out_channels
235
+
236
+ self.factor_t = factor_t
237
+ self.factor = self.factor_t
238
+
239
+ assert out_channels * self.factor % in_channels == 0
240
+ self.repeats = out_channels * self.factor // in_channels
241
+
242
+ def forward(self, x: torch.Tensor, first_chunk=False) -> torch.Tensor:
243
+ x = x.repeat_interleave(self.repeats, dim=1)
244
+ x = x.view(
245
+ x.size(0),
246
+ self.out_channels,
247
+ self.factor_t,
248
+ x.size(2),
249
+ )
250
+ x = x.permute(0, 1, 3, 2).contiguous()
251
+ x = x.view(
252
+ x.size(0),
253
+ self.out_channels,
254
+ x.size(2) * self.factor_t,
255
+ )
256
+ if first_chunk:
257
+ x = x[
258
+ :,
259
+ :,
260
+ self.factor_t - 1 :,
261
+ ]
262
+ return x
263
+
264
+
265
+ class Down_ResidualBlock(nn.Module):
266
+ def __init__(self, in_dim, out_dim, dropout, mult, temperal_downsample=False):
267
+ super().__init__()
268
+
269
+ # Shortcut path with downsample
270
+ if temperal_downsample:
271
+ self.avg_shortcut = AvgDown1D(
272
+ in_dim,
273
+ out_dim,
274
+ factor_t=2,
275
+ )
276
+ else:
277
+ self.avg_shortcut = None
278
+
279
+ # Main path with residual blocks and downsample
280
+ downsamples = []
281
+ for _ in range(mult):
282
+ downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
283
+ in_dim = out_dim
284
+
285
+ # Add the final downsample block
286
+ if temperal_downsample:
287
+ downsamples.append(Resample(out_dim, mode="downsample1d"))
288
+
289
+ self.downsamples = nn.Sequential(*downsamples)
290
+
291
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
292
+ x_copy = x.clone()
293
+ for module in self.downsamples:
294
+ x = module(x, feat_cache, feat_idx)
295
+ if self.avg_shortcut is None:
296
+ return x
297
+ else:
298
+ return x + self.avg_shortcut(x_copy)
299
+
300
+
301
+ class Up_ResidualBlock(nn.Module):
302
+ def __init__(self, in_dim, out_dim, dropout, mult, temperal_upsample=False):
303
+ super().__init__()
304
+ # Shortcut path with upsample
305
+ if temperal_upsample:
306
+ self.avg_shortcut = DupUp1D(
307
+ in_dim,
308
+ out_dim,
309
+ factor_t=2,
310
+ )
311
+ else:
312
+ self.avg_shortcut = None
313
+
314
+ # Main path with residual blocks and upsample
315
+ upsamples = []
316
+ for _ in range(mult):
317
+ upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
318
+ in_dim = out_dim
319
+
320
+ # Add the final upsample block
321
+ if temperal_upsample:
322
+ upsamples.append(Resample(out_dim, mode="upsample1d"))
323
+
324
+ self.upsamples = nn.Sequential(*upsamples)
325
+
326
+ def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
327
+ x_main = x.clone()
328
+ for module in self.upsamples:
329
+ x_main = module(x_main, feat_cache, feat_idx)
330
+ if self.avg_shortcut is not None:
331
+ x_shortcut = self.avg_shortcut(x, first_chunk)
332
+ return x_main + x_shortcut
333
+ else:
334
+ return x_main
335
+
336
+
337
+ class Encoder1d(nn.Module):
338
+ def __init__(
339
+ self,
340
+ input_dim,
341
+ dim=128,
342
+ z_dim=4,
343
+ dim_mult=[1, 2, 4, 4],
344
+ num_res_blocks=2,
345
+ temperal_downsample=[True, True, False],
346
+ dropout=0.0,
347
+ ):
348
+ super().__init__()
349
+ self.dim = dim
350
+ self.z_dim = z_dim
351
+ self.dim_mult = dim_mult
352
+ self.num_res_blocks = num_res_blocks
353
+ self.temperal_downsample = temperal_downsample
354
+
355
+ # dimensions
356
+ dims = [dim * u for u in [1] + dim_mult]
357
+ scale = 1.0
358
+
359
+ # init block
360
+ self.conv1 = CausalConv1d(input_dim, dims[0], 3, padding=1)
361
+
362
+ # downsample blocks
363
+ downsamples = []
364
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
365
+ t_down_flag = (
366
+ temperal_downsample[i] if i < len(temperal_downsample) else False
367
+ )
368
+ downsamples.append(
369
+ Down_ResidualBlock(
370
+ in_dim=in_dim,
371
+ out_dim=out_dim,
372
+ dropout=dropout,
373
+ mult=num_res_blocks,
374
+ temperal_downsample=t_down_flag,
375
+ )
376
+ )
377
+ scale /= 2.0
378
+ self.downsamples = nn.Sequential(*downsamples)
379
+
380
+ # middle blocks
381
+ self.middle = nn.Sequential(
382
+ ResidualBlock(out_dim, out_dim, dropout),
383
+ RMS_norm(out_dim),
384
+ CausalConv1d(out_dim, out_dim, 1),
385
+ ResidualBlock(out_dim, out_dim, dropout),
386
+ )
387
+
388
+ # output blocks
389
+ self.head = nn.Sequential(
390
+ RMS_norm(out_dim),
391
+ nn.SiLU(),
392
+ CausalConv1d(out_dim, z_dim, 3, padding=1),
393
+ )
394
+
395
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
396
+ if feat_cache is not None:
397
+ idx = feat_idx[0]
398
+ cache_x = x[:, :, -CACHE_T:].clone()
399
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
400
+ cache_x = torch.cat(
401
+ [
402
+ feat_cache[idx][:, :, -1].unsqueeze(2).to(cache_x.device),
403
+ cache_x,
404
+ ],
405
+ dim=2,
406
+ )
407
+ x = self.conv1(x, feat_cache[idx])
408
+ feat_cache[idx] = cache_x
409
+ feat_idx[0] += 1
410
+ else:
411
+ x = self.conv1(x)
412
+
413
+ ## downsamples
414
+ for layer in self.downsamples:
415
+ if feat_cache is not None:
416
+ x = layer(x, feat_cache, feat_idx)
417
+ else:
418
+ x = layer(x)
419
+
420
+ ## middle
421
+ for layer in self.middle:
422
+ if isinstance(layer, ResidualBlock) and feat_cache is not None:
423
+ x = layer(x, feat_cache, feat_idx)
424
+ else:
425
+ x = layer(x)
426
+
427
+ ## head
428
+ for layer in self.head:
429
+ if isinstance(layer, CausalConv1d) and feat_cache is not None:
430
+ idx = feat_idx[0]
431
+ cache_x = x[:, :, -CACHE_T:].clone()
432
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
433
+ cache_x = torch.cat(
434
+ [
435
+ feat_cache[idx][:, :, -1].unsqueeze(2).to(cache_x.device),
436
+ cache_x,
437
+ ],
438
+ dim=2,
439
+ )
440
+ x = layer(x, feat_cache[idx])
441
+ feat_cache[idx] = cache_x
442
+ feat_idx[0] += 1
443
+ else:
444
+ x = layer(x)
445
+
446
+ return x
447
+
448
+
449
+ class Decoder1d(nn.Module):
450
+ def __init__(
451
+ self,
452
+ output_dim,
453
+ dim=128,
454
+ z_dim=4,
455
+ dim_mult=[1, 2, 4, 4],
456
+ num_res_blocks=2,
457
+ temperal_upsample=[False, True, True],
458
+ dropout=0.0,
459
+ ):
460
+ super().__init__()
461
+ self.dim = dim
462
+ self.z_dim = z_dim
463
+ self.dim_mult = dim_mult
464
+ self.num_res_blocks = num_res_blocks
465
+ self.temperal_upsample = temperal_upsample
466
+
467
+ # dimensions
468
+ dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
469
+ scale = 1.0 / 2 ** (len(dim_mult) - 2)
470
+ # init block
471
+ self.conv1 = CausalConv1d(z_dim, dims[0], 3, padding=1)
472
+
473
+ # middle blocks
474
+ self.middle = nn.Sequential(
475
+ ResidualBlock(dims[0], dims[0], dropout),
476
+ RMS_norm(dims[0]),
477
+ CausalConv1d(dims[0], dims[0], 1),
478
+ ResidualBlock(dims[0], dims[0], dropout),
479
+ )
480
+
481
+ # upsample blocks
482
+ upsamples = []
483
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
484
+ t_up_flag = temperal_upsample[i] if i < len(temperal_upsample) else False
485
+ upsamples.append(
486
+ Up_ResidualBlock(
487
+ in_dim=in_dim,
488
+ out_dim=out_dim,
489
+ dropout=dropout,
490
+ mult=num_res_blocks + 1,
491
+ temperal_upsample=t_up_flag,
492
+ )
493
+ )
494
+ self.upsamples = nn.Sequential(*upsamples)
495
+
496
+ # output blocks
497
+ self.head = nn.Sequential(
498
+ RMS_norm(out_dim),
499
+ nn.SiLU(),
500
+ CausalConv1d(out_dim, output_dim, 3, padding=1),
501
+ )
502
+
503
+ def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
504
+ if feat_cache is not None:
505
+ idx = feat_idx[0]
506
+ cache_x = x[:, :, -CACHE_T:].clone()
507
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
508
+ cache_x = torch.cat(
509
+ [
510
+ feat_cache[idx][:, :, -1].unsqueeze(2).to(cache_x.device),
511
+ cache_x,
512
+ ],
513
+ dim=2,
514
+ )
515
+ x = self.conv1(x, feat_cache[idx])
516
+ feat_cache[idx] = cache_x
517
+ feat_idx[0] += 1
518
+ else:
519
+ x = self.conv1(x)
520
+
521
+ for layer in self.middle:
522
+ if isinstance(layer, ResidualBlock) and feat_cache is not None:
523
+ x = layer(x, feat_cache, feat_idx)
524
+ else:
525
+ x = layer(x)
526
+
527
+ ## upsamples
528
+ for layer in self.upsamples:
529
+ if feat_cache is not None:
530
+ x = layer(x, feat_cache, feat_idx, first_chunk)
531
+ else:
532
+ x = layer(x)
533
+
534
+ ## head
535
+ for layer in self.head:
536
+ if isinstance(layer, CausalConv1d) and feat_cache is not None:
537
+ idx = feat_idx[0]
538
+ cache_x = x[:, :, -CACHE_T:].clone()
539
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
540
+ cache_x = torch.cat(
541
+ [
542
+ feat_cache[idx][:, :, -1].unsqueeze(2).to(cache_x.device),
543
+ cache_x,
544
+ ],
545
+ dim=2,
546
+ )
547
+ x = layer(x, feat_cache[idx])
548
+ feat_cache[idx] = cache_x
549
+ feat_idx[0] += 1
550
+ else:
551
+ x = layer(x)
552
+ return x
553
+
554
+
555
+ def count_conv1d(model):
556
+ count = 0
557
+ for m in model.modules():
558
+ if isinstance(m, CausalConv1d):
559
+ count += 1
560
+ return count
561
+
562
+
563
+ class WanVAE_(nn.Module):
564
+ def __init__(
565
+ self,
566
+ input_dim,
567
+ dim=160,
568
+ dec_dim=256,
569
+ z_dim=16,
570
+ dim_mult=[1, 2, 4, 4],
571
+ num_res_blocks=1,
572
+ temperal_downsample=[True, True, False],
573
+ dropout=0.0,
574
+ ):
575
+ super().__init__()
576
+ self.dim = dim
577
+ self.z_dim = z_dim
578
+ self.dim_mult = dim_mult
579
+ self.num_res_blocks = num_res_blocks
580
+ self.temperal_downsample = temperal_downsample
581
+ self.temperal_upsample = temperal_downsample[::-1]
582
+
583
+ # modules
584
+ self.encoder = Encoder1d(
585
+ input_dim,
586
+ dim,
587
+ z_dim * 2,
588
+ dim_mult,
589
+ num_res_blocks,
590
+ self.temperal_downsample,
591
+ dropout,
592
+ )
593
+ self.conv1 = CausalConv1d(z_dim * 2, z_dim * 2, 1)
594
+ self.conv2 = CausalConv1d(z_dim, z_dim, 1)
595
+ self.decoder = Decoder1d(
596
+ input_dim,
597
+ dec_dim,
598
+ z_dim,
599
+ dim_mult,
600
+ num_res_blocks,
601
+ self.temperal_upsample,
602
+ dropout,
603
+ )
604
+
605
+ def forward(self, x, scale=[0, 1]):
606
+ mu = self.encode(x, scale)
607
+ x_recon = self.decode(mu, scale)
608
+ return x_recon, mu
609
+
610
+ def encode(self, x, scale, return_dist=False):
611
+ self.clear_cache()
612
+ t = x.shape[2]
613
+ iter_ = 1 + (t - 1) // 4
614
+ for i in range(iter_):
615
+ self._enc_conv_idx = [0]
616
+ if i == 0:
617
+ out = self.encoder(
618
+ x[:, :, :1],
619
+ feat_cache=self._enc_feat_map,
620
+ feat_idx=self._enc_conv_idx,
621
+ )
622
+ else:
623
+ out_ = self.encoder(
624
+ x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i],
625
+ feat_cache=self._enc_feat_map,
626
+ feat_idx=self._enc_conv_idx,
627
+ )
628
+ out = torch.cat([out, out_], 2)
629
+ mu, log_var = self.conv1(out).chunk(2, dim=1)
630
+ if isinstance(scale[0], torch.Tensor):
631
+ mu = (mu - scale[0].view(1, self.z_dim, 1)) * scale[1].view(
632
+ 1, self.z_dim, 1
633
+ )
634
+ else:
635
+ mu = (mu - scale[0]) * scale[1]
636
+ self.clear_cache()
637
+ if return_dist:
638
+ return mu, log_var
639
+ return mu
640
+
641
+ def decode(self, z, scale):
642
+ self.clear_cache()
643
+ if isinstance(scale[0], torch.Tensor):
644
+ z = z / scale[1].view(1, self.z_dim, 1) + scale[0].view(1, self.z_dim, 1)
645
+ else:
646
+ z = z / scale[1] + scale[0]
647
+ iter_ = z.shape[2]
648
+ x = self.conv2(z)
649
+ for i in range(iter_):
650
+ self._conv_idx = [0]
651
+ if i == 0:
652
+ out = self.decoder(
653
+ x[:, :, i : i + 1],
654
+ feat_cache=self._feat_map,
655
+ feat_idx=self._conv_idx,
656
+ first_chunk=True,
657
+ )
658
+ else:
659
+ out_ = self.decoder(
660
+ x[:, :, i : i + 1],
661
+ feat_cache=self._feat_map,
662
+ feat_idx=self._conv_idx,
663
+ )
664
+ out = torch.cat([out, out_], 2)
665
+ self.clear_cache()
666
+ return out
667
+
668
+ @torch.no_grad()
669
+ def stream_encode(self, x, first_chunk, scale, return_dist=False):
670
+ t = x.shape[2]
671
+ if first_chunk:
672
+ iter_ = 1 + (t - 1) // 4
673
+ else:
674
+ iter_ = t // 4
675
+ for i in range(iter_):
676
+ self._enc_conv_idx = [0]
677
+ if i == 0:
678
+ if first_chunk:
679
+ out = self.encoder(
680
+ x[:, :, :1],
681
+ feat_cache=self._enc_feat_map,
682
+ feat_idx=self._enc_conv_idx,
683
+ )
684
+ else:
685
+ out = self.encoder(
686
+ x[:, :, :4],
687
+ feat_cache=self._enc_feat_map,
688
+ feat_idx=self._enc_conv_idx,
689
+ )
690
+ else:
691
+ if first_chunk:
692
+ out_ = self.encoder(
693
+ x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i],
694
+ feat_cache=self._enc_feat_map,
695
+ feat_idx=self._enc_conv_idx,
696
+ )
697
+ else:
698
+ out_ = self.encoder(
699
+ x[:, :, 4 * i : 4 * (i + 1)],
700
+ feat_cache=self._enc_feat_map,
701
+ feat_idx=self._enc_conv_idx,
702
+ )
703
+ out = torch.cat([out, out_], 2)
704
+ mu, log_var = self.conv1(out).chunk(2, dim=1)
705
+ if isinstance(scale[0], torch.Tensor):
706
+ mu = (mu - scale[0].view(1, self.z_dim, 1)) * scale[1].view(
707
+ 1, self.z_dim, 1
708
+ )
709
+ else:
710
+ mu = (mu - scale[0]) * scale[1]
711
+ if return_dist:
712
+ return mu, log_var
713
+ else:
714
+ return mu
715
+
716
+ @torch.no_grad()
717
+ def stream_decode(self, z, first_chunk, scale):
718
+ if isinstance(scale[0], torch.Tensor):
719
+ z = z / scale[1].view(1, self.z_dim, 1) + scale[0].view(1, self.z_dim, 1)
720
+ else:
721
+ z = z / scale[1] + scale[0]
722
+ iter_ = z.shape[2]
723
+ x = self.conv2(z)
724
+ for i in range(iter_):
725
+ self._conv_idx = [0]
726
+ if i == 0:
727
+ out = self.decoder(
728
+ x[:, :, i : i + 1],
729
+ feat_cache=self._feat_map,
730
+ feat_idx=self._conv_idx,
731
+ first_chunk=first_chunk, # Use the external first_chunk parameter
732
+ )
733
+ else:
734
+ out_ = self.decoder(
735
+ x[:, :, i : i + 1],
736
+ feat_cache=self._feat_map,
737
+ feat_idx=self._conv_idx,
738
+ first_chunk=False, # Explicitly set to False for subsequent time steps within the same chunk
739
+ )
740
+ out = torch.cat([out, out_], 2)
741
+ return out
742
+
743
+ def reparameterize(self, mu, log_var):
744
+ std = torch.exp(0.5 * log_var)
745
+ eps = torch.randn_like(std)
746
+ return eps * std + mu
747
+
748
+ def sample(self, imgs, deterministic=False):
749
+ mu, log_var = self.encode(imgs, scale=[0, 1], return_dist=True)
750
+ if deterministic:
751
+ return mu
752
+ std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
753
+ return mu + std * torch.randn_like(std)
754
+
755
+ def clear_cache(self):
756
+ self._conv_num = count_conv1d(self.decoder)
757
+ self._conv_idx = [0]
758
+ self._feat_map = [None] * self._conv_num
759
+ # cache encode
760
+ self._enc_conv_num = count_conv1d(self.encoder)
761
+ self._enc_conv_idx = [0]
762
+ self._enc_feat_map = [None] * self._enc_conv_num
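A quick way to see what the `feat_cache` machinery above relies on: a `CausalConv1d` applied chunk by chunk, with the last `CACHE_T` input steps carried over as `cache_x`, reproduces the full-sequence result. A small self-contained check (assuming the `ldf_models` package is importable from the repository root; channel count, chunk split, and values are arbitrary):

```python
import torch
from ldf_models.tools.wan_vae_1d import CausalConv1d

torch.manual_seed(0)
conv = CausalConv1d(4, 4, 3, padding=1)            # causal: pads 2 steps on the left only
x = torch.randn(1, 4, 8)                           # (B, C, T)

full = conv(x)                                     # whole sequence at once
first = conv(x[:, :, :4])                          # first chunk, implicit zero history
second = conv(x[:, :, 4:], cache_x=x[:, :, 2:4])   # next chunk, last 2 inputs as cache
chunked = torch.cat([first, second], dim=2)

print(torch.allclose(full, chunked, atol=1e-6))    # True
```

This per-layer input caching is what lets the streaming encode/decode paths below process arbitrarily long sequences in fixed-size chunks.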
ldf_models/vae_wan_1d.py ADDED
@@ -0,0 +1,212 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+ from .tools.wan_vae_1d import WanVAE_
6
+
7
+
8
+ class VAEWanModel(nn.Module):
9
+ def __init__(
10
+ self,
11
+ input_dim,
12
+ mean_path=None,
13
+ std_path=None,
14
+ z_dim=256,
15
+ dim=160,
16
+ dec_dim=512,
17
+ num_res_blocks=1,
18
+ dropout=0.0,
19
+ dim_mult=[1, 1, 1],
20
+ temperal_downsample=[True, True],
21
+ vel_window=[0, 0],
22
+ **kwargs,
23
+ ):
24
+ super().__init__()
25
+
26
+ self.mean_path = mean_path
27
+ self.std_path = std_path
28
+ self.input_dim = input_dim
29
+ self.z_dim = z_dim
30
+ self.dim = dim
31
+ self.dec_dim = dec_dim
32
+ self.num_res_blocks = num_res_blocks
33
+ self.dropout = dropout
34
+ self.dim_mult = dim_mult
35
+ self.temperal_downsample = temperal_downsample
36
+ self.vel_window = vel_window
37
+ self.RECONS_LOSS = nn.SmoothL1Loss()
38
+ self.LAMBDA_FEATURE = kwargs.get("LAMBDA_FEATURE", 1.0)
39
+ self.LAMBDA_VELOCITY = kwargs.get("LAMBDA_VELOCITY", 0.5)
40
+ self.LAMBDA_KL = kwargs.get("LAMBDA_KL", 10e-6)
41
+
42
+ if self.mean_path is not None:
43
+ self.register_buffer(
44
+ "mean", torch.from_numpy(np.load(self.mean_path)).float()
45
+ )
46
+ else:
47
+ self.register_buffer("mean", torch.zeros(input_dim))
48
+
49
+ if self.std_path is not None:
50
+ self.register_buffer(
51
+ "std", torch.from_numpy(np.load(self.std_path)).float()
52
+ )
53
+ else:
54
+ self.register_buffer("std", torch.ones(input_dim))
55
+
56
+ self.model = WanVAE_(
57
+ input_dim=self.input_dim,
58
+ dim=self.dim,
59
+ dec_dim=self.dec_dim,
60
+ z_dim=self.z_dim,
61
+ dim_mult=self.dim_mult,
62
+ num_res_blocks=self.num_res_blocks,
63
+ temperal_downsample=self.temperal_downsample,
64
+ dropout=self.dropout,
65
+ )
66
+
67
+ downsample_factor = 1
68
+ for flag in self.temperal_downsample:
69
+ if flag:
70
+ downsample_factor *= 2
71
+ self.downsample_factor = downsample_factor
72
+
73
+ def preprocess(self, x):
74
+ # (bs, T, C) -> (bs, C, T)
75
+ x = x.permute(0, 2, 1)
76
+ return x
77
+
78
+ def postprocess(self, x):
79
+ # (bs, C, T) -> (bs, T, C)
80
+ x = x.permute(0, 2, 1)
81
+ return x
82
+
83
+ def forward(self, x):
84
+ features = x["feature"]
85
+ feature_length = x["feature_length"]
86
+ features = (features - self.mean) / self.std
87
+ # create mask based on feature_length
88
+ batch_size, seq_len = features.shape[:2]
89
+ mask = torch.zeros(
90
+ batch_size, seq_len, dtype=torch.bool, device=features.device
91
+ )
92
+ for i in range(batch_size):
93
+ mask[i, : feature_length[i]] = True
94
+
95
+ x_in = self.preprocess(features) # (bs, input_dim, T)
96
+ mu, log_var = self.model.encode(
97
+ x_in, scale=[0, 1], return_dist=True
98
+ ) # (bs, z_dim, T)
99
+ z = self.model.reparameterize(mu, log_var)
100
+ x_decoder = self.model.decode(z, scale=[0, 1]) # (bs, input_dim, T)
101
+ x_out = self.postprocess(x_decoder) # (bs, T, input_dim)
102
+
103
+ if x_out.size(1) != features.size(1):
104
+ min_len = min(x_out.size(1), features.size(1))
105
+ x_out = x_out[:, :min_len, :]
106
+ features = features[:, :min_len, :]
107
+ mask = mask[:, :min_len]
108
+
109
+ mask_expanded = mask.unsqueeze(-1)
110
+ x_out_masked = x_out * mask_expanded
111
+ features_masked = features * mask_expanded
112
+ loss_recons = self.RECONS_LOSS(x_out_masked, features_masked)
113
+ vel_start = self.vel_window[0]
114
+ vel_end = self.vel_window[1]
115
+ loss_vel = self.RECONS_LOSS(
116
+ x_out_masked[..., vel_start:vel_end],
117
+ features_masked[..., vel_start:vel_end],
118
+ )
119
+
120
+ # Compute KL divergence loss
121
+ # KL(N(mu, sigma) || N(0, 1)) = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
122
+ # log_var = log(sigma^2), so we can use it directly
123
+
124
+ # Build mask for latent space
125
+ T_latent = mu.size(2)
126
+ mask_downsampled = torch.zeros(
127
+ batch_size, T_latent, dtype=torch.bool, device=features.device
128
+ )
129
+ for i in range(batch_size):
130
+ latent_length = (
131
+ feature_length[i] + self.downsample_factor - 1
132
+ ) // self.downsample_factor
133
+ mask_downsampled[i, :latent_length] = True
134
+ mask_latent = mask_downsampled.unsqueeze(1) # (B, 1, T_latent)
135
+
136
+ # Compute KL loss per element
137
+ kl_per_element = -0.5 * (1 + log_var - mu.pow(2) - log_var.exp())
138
+ # Apply mask: only compute KL loss for valid timesteps
139
+ kl_masked = kl_per_element * mask_latent
140
+ # Sum over all dimensions and normalize by the number of valid elements
141
+ kl_loss = torch.sum(kl_masked) / (
142
+ torch.sum(mask_downsampled) * mu.size(1)
143
+ ) # normalize by valid timesteps * latent_dim
144
+
145
+ # Total loss
146
+ total_loss = (
147
+ self.LAMBDA_FEATURE * loss_recons
148
+ + self.LAMBDA_VELOCITY * loss_vel
149
+ + self.LAMBDA_KL * kl_loss
150
+ )
151
+
152
+ loss_dict = {}
153
+ loss_dict["total"] = total_loss
154
+ loss_dict["recons"] = loss_recons
155
+ loss_dict["velocity"] = loss_vel
156
+ loss_dict["kl"] = kl_loss
157
+
158
+ return loss_dict
159
+
160
+ def encode(self, x):
161
+ x = (x - self.mean) / self.std
162
+ x_in = self.preprocess(x) # (bs, T, input_dim) -> (bs, input_dim, T)
163
+ mu = self.model.encode(x_in, scale=[0, 1]) # (bs, z_dim, T)
164
+ mu = self.postprocess(mu) # (bs, T, z_dim)
165
+ return mu
166
+
167
+ def decode(self, mu):
168
+ mu_in = self.preprocess(mu) # (bs, T, z_dim) -> (bs, z_dim, T)
169
+ x_decoder = self.model.decode(mu_in, scale=[0, 1]) # (bs, z_dim, T)
170
+ x_out = self.postprocess(x_decoder) # (bs, T, input_dim)
171
+ x_out = x_out * self.std + self.mean
172
+ return x_out
173
+
174
+ @torch.no_grad()
175
+ def stream_encode(self, x, first_chunk=True):
176
+ x = (x - self.mean) / self.std
177
+ x_in = self.preprocess(x) # (bs, input_dim, T)
178
+ mu = self.model.stream_encode(x_in, first_chunk=first_chunk, scale=[0, 1])
179
+ mu = self.postprocess(mu) # (bs, T, z_dim)
180
+ return mu
181
+
182
+ @torch.no_grad()
183
+ def stream_decode(self, mu, first_chunk=True):
184
+ mu_in = self.preprocess(mu) # (bs, z_dim, T)
185
+ x_decoder = self.model.stream_decode(
186
+ mu_in, first_chunk=first_chunk, scale=[0, 1]
187
+ )
188
+ x_out = self.postprocess(x_decoder) # (bs, T, input_dim)
189
+ x_out = x_out * self.std + self.mean
190
+ return x_out
191
+
192
+ def clear_cache(self):
193
+ self.model.clear_cache()
194
+
195
+ def generate(self, x):
196
+ features = x["feature"]
197
+ feature_length = x["feature_length"]
198
+ y_hat = self.decode(self.encode(features))
199
+
200
+ y_hat_out = []
201
+
202
+ for i in range(y_hat.shape[0]):
203
+ # cut off the padding and align lengths
204
+ valid_len = (
205
+ feature_length[i] - 1
206
+ ) // self.downsample_factor * self.downsample_factor + 1
207
+ # crop to the length the VAE round-trip reconstructs exactly (k * downsample_factor + 1 frames)
208
+ y_hat_out.append(y_hat[i, :valid_len, :])
209
+
210
+ out = {}
211
+ out["generated"] = y_hat_out
212
+ return out
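For orientation, a minimal sketch of how `VAEWanModel` can be driven, using random inputs and the default constructor arguments rather than the released configuration (the actual checkpoint supplies `mean_path`/`std_path` and its own dimensions):

```python
import torch
from ldf_models.vae_wan_1d import VAEWanModel

vae = VAEWanModel(input_dim=263).eval()   # 263-dim HumanML3D features, default dims
motion = torch.randn(1, 49, 263)          # (B, T, C); T = 4k + 1 plays well with the 4x factor

with torch.no_grad():
    latent = vae.encode(motion)           # (B, 1 + (T - 1) // 4, z_dim)
    recon = vae.decode(latent)            # (B, T, C) for T of the form 4k + 1

print(latent.shape, recon.shape)
```

The `stream_encode` / `stream_decode` pair follows the same shapes but keeps the causal convolution caches alive between chunks; call `clear_cache()` before starting a new sequence.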
ldf_utils/__init__.py ADDED
File without changes
ldf_utils/initialize.py ADDED
@@ -0,0 +1,286 @@
1
+ import argparse
2
+ import os
3
+ import shutil
4
+ import time
5
+ from datetime import datetime
6
+ from importlib import import_module
7
+ from pathlib import Path
8
+ from typing import Any, Dict, Optional
9
+
10
+ import torch
11
+ from lightning.pytorch.utilities import rank_zero_info
12
+ from omegaconf import OmegaConf
13
+
14
+
15
+ class Config:
16
+ def __init__(self, config_path: str = None, override_args: Dict[str, Any] = None):
17
+ self.config = OmegaConf.create({})
18
+
19
+ # Load main config if provided
20
+ if config_path:
21
+ self.load_yaml(config_path)
22
+ if override_args:
23
+ self.override_config(override_args)
24
+
25
+ def load_yaml(self, config_path: str):
26
+ """Load YAML configuration file"""
27
+ loaded_config = OmegaConf.load(config_path)
28
+ self.config = OmegaConf.merge(self.config, loaded_config)
29
+
30
+ def override_config(self, override_args: Dict[str, Any]):
31
+ """Handle command line override arguments"""
32
+ for key, value in override_args.items():
+ # Convert the string value to a basic Python type (bool, None, int, float)
+ # where possible; otherwise keep it as a plain string (e.g. file paths).
+ val = self._convert_value(value)
+ # Apply the override to the dotted key, e.g. "model.dim".
+ OmegaConf.update(self.config, key, val)
50
+
51
+ def _convert_value(self, value: str) -> Any:
52
+ """Convert string value to appropriate type"""
53
+ if value.lower() == "true":
54
+ return True
55
+ elif value.lower() == "false":
56
+ return False
57
+ elif value.lower() == "null":
58
+ return None
59
+ try:
60
+ return int(value)
61
+ except ValueError:
62
+ try:
63
+ return float(value)
64
+ except ValueError:
65
+ return value
66
+
67
+ def get(self, key: str, default: Any = None) -> Any:
68
+ """Get configuration value"""
69
+ return OmegaConf.select(self.config, key, default=default)
70
+
71
+ def __getattr__(self, name: str) -> Any:
72
+ """Support dot notation access"""
73
+ return self.config[name]
74
+
75
+ def __getitem__(self, key: str) -> Any:
76
+ """Support dictionary-like access"""
77
+ return self.config[key]
78
+
79
+ def export_config(self, path: str):
80
+ """Export current configuration to file"""
81
+ OmegaConf.save(self.config, path)
82
+
83
+
84
+ def parse_args():
85
+ """Parse command line arguments"""
86
+ parser = argparse.ArgumentParser()
87
+ parser.add_argument(
88
+ "--config", type=str, required=True, help="Path to config file"
89
+ )
90
+ parser.add_argument(
91
+ "--override", type=str, nargs="+", help="Override config values (key=value)"
92
+ )
93
+ return parser.parse_args()
94
+
95
+
96
+ def load_config(
97
+ config_path: Optional[str] = None, override_args: Optional[Dict[str, Any]] = None
98
+ ) -> Config:
99
+ """Load configuration"""
100
+ if config_path is None:
101
+ args = parse_args()
102
+ config_path = args.config
103
+ if args.override:
104
+ override_args = {}
105
+ for override in args.override:
106
+ key, value = override.split("=", 1)
107
+ override_args[key.strip()] = value.strip()
108
+
109
+ return Config(config_path, override_args)
110
+
111
+
112
+ def instantiate(target, cfg=None, hfstyle=False, **init_args):
113
+ module_name, class_name = target.rsplit(".", 1)
114
+ module = import_module(module_name)
115
+ class_ = getattr(module, class_name)
116
+ if cfg is None:
117
+ return class_(**init_args)
118
+ else:
119
+ if hfstyle:
120
+ config_class = class_.config_class
121
+ cfg = config_class(config_obj=cfg)
122
+ return class_(cfg, **init_args)
123
+
124
+
125
+ def get_function(target):
126
+ module_name, function_name = target.rsplit(".", 1)
127
+ module = import_module(module_name)
128
+ function_ = getattr(module, function_name)
129
+ return function_
130
+
131
+
132
+ def save_config_and_codes(config, save_dir):
133
+ os.makedirs(save_dir, exist_ok=True)
134
+ sanity_check_dir = os.path.join(save_dir, "sanity_check")
135
+ os.makedirs(sanity_check_dir, exist_ok=True)
136
+ with open(os.path.join(sanity_check_dir, f"{config.exp_name}.yaml"), "w") as f:
137
+ OmegaConf.save(config.config, f)
138
+ current_dir = Path.cwd()
139
+ exclude_dir = current_dir / "outputs"
140
+ for py_file in current_dir.rglob("*.py"):
141
+ if exclude_dir in py_file.parents:
142
+ continue
143
+ dest_path = Path(sanity_check_dir) / py_file.relative_to(current_dir)
144
+ dest_path.parent.mkdir(parents=True, exist_ok=True)
145
+ shutil.copy(py_file, dest_path)
146
+
147
+
148
+ def print_model_size(model):
149
+ total_params = sum(p.numel() for p in model.parameters())
150
+ trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
151
+ rank_zero_info(f"Total parameters: {total_params:,}")
152
+ rank_zero_info(f"Trainable parameters: {trainable_params:,}")
153
+ rank_zero_info(f"Non-trainable parameters: {(total_params - trainable_params):,}")
154
+
155
+
156
+ def compare_statedict_and_parameters(state_dict, named_parameters, named_buffers):
157
+ """Compare differences between state_dict and parameters"""
158
+ # Get all keys in state_dict
159
+ state_dict_keys = set(state_dict.keys())
160
+
161
+ # Get all keys in named_parameters
162
+ named_params_keys = set(name for name, _ in named_parameters)
163
+
164
+ # Find keys that only exist in state_dict
165
+ only_in_state_dict = state_dict_keys - named_params_keys
166
+
167
+ # Find keys that only exist in named_parameters
168
+ only_in_named_params = named_params_keys - state_dict_keys
169
+
170
+ # Print results
171
+ if only_in_state_dict:
172
+ print(f"Only in state_dict (not in parameters): {sorted(only_in_state_dict)}")
173
+
174
+ if only_in_named_params:
175
+ print(
176
+ f"Only in named_parameters (not in state_dict): {sorted(only_in_named_params)}"
177
+ )
178
+
179
+ if not only_in_state_dict and not only_in_named_params:
180
+ print("All parameters match between state_dict and named_parameters")
181
+
182
+ # Additionally compare buffers (non-parameter states, such as BatchNorm's running_mean)
183
+ named_buffers_keys = set(name for name, _ in named_buffers)
184
+ buffers_only = state_dict_keys - named_params_keys - named_buffers_keys
185
+
186
+ if buffers_only:
187
+ print(
188
+ f"Other items in state_dict (neither params nor buffers): {sorted(buffers_only)}"
189
+ )
190
+
191
+ print(f"Total state_dict items: {len(state_dict_keys)}")
192
+ print(f"Total named_parameters: {len(named_params_keys)}")
193
+ print(f"Total named_buffers: {len(named_buffers_keys)}")
194
+
195
+
196
+ def _resolve_global_rank() -> int:
197
+ """Resolve the global rank from environment variables."""
198
+ for key in ("GLOBAL_RANK", "RANK", "SLURM_PROCID", "LOCAL_RANK"):
199
+ if key in os.environ:
200
+ try:
201
+ return int(os.environ[key])
202
+ except ValueError:
203
+ continue
204
+ return 0
205
+
206
+
207
+ def get_shared_run_time(base_dir: str, env_key: str = "PL_RUN_TIME") -> str:
208
+ """
209
+ Get a synchronized run time across all processes.
210
+
211
+ This function ensures all processes (both in distributed training and multi-process
212
+ scenarios) use the same timestamp for output directories and experiment tracking.
213
+
214
+ Args:
215
+ base_dir: Base directory for output files
216
+ env_key: Environment variable key to cache the run time
217
+
218
+ Returns:
219
+ Synchronized timestamp string in format YYYYMMDD_HHMMSS
220
+ """
221
+ cached = os.environ.get(env_key)
222
+ if cached:
223
+ return cached
224
+
225
+ timestamp_format = "%Y%m%d_%H%M%S"
226
+
227
+ if torch.distributed.is_available() and torch.distributed.is_initialized():
228
+ if torch.distributed.get_rank() == 0:
229
+ run_time = datetime.now().strftime(timestamp_format)
230
+ else:
231
+ run_time = None
232
+ container = [run_time]
233
+ torch.distributed.broadcast_object_list(container, src=0)
234
+ run_time = container[0]
235
+ if run_time is None:
236
+ raise RuntimeError("Failed to synchronize run time across ranks.")
237
+ os.environ[env_key] = run_time
238
+ return run_time
239
+
240
+ os.makedirs(base_dir, exist_ok=True)
241
+ sync_token = (
242
+ os.environ.get("SLURM_JOB_ID")
243
+ or os.environ.get("TORCHELASTIC_RUN_ID")
244
+ or os.environ.get("JOB_ID")
245
+ or "default"
246
+ )
247
+ sync_dir = os.path.join(base_dir, ".run_time_sync")
248
+ os.makedirs(sync_dir, exist_ok=True)
249
+ sync_file = os.path.join(sync_dir, f"{sync_token}.txt")
250
+
251
+ global_rank = _resolve_global_rank()
252
+ if global_rank == 0:
253
+ # Remove the sync file if it exists to avoid stale reads by other ranks
254
+ if os.path.exists(sync_file):
255
+ try:
256
+ os.remove(sync_file)
257
+ except OSError:
258
+ pass
259
+
260
+ run_time = datetime.now().strftime(timestamp_format)
261
+ with open(sync_file, "w", encoding="utf-8") as f:
262
+ f.write(run_time)
263
+ else:
264
+ timeout = time.monotonic() + 1200.0
265
+ while True:
266
+ if os.path.exists(sync_file):
267
+ try:
268
+ with open(sync_file, "r", encoding="utf-8") as f:
269
+ run_time = f.read().strip()
270
+ # Check if the timestamp is fresh (within 60 seconds)
271
+ # This prevents reading a stale timestamp from a previous run
272
+ dt = datetime.strptime(run_time, timestamp_format)
273
+ if abs((datetime.now() - dt).total_seconds()) < 60:
274
+ break
275
+ except (ValueError, OSError):
276
+ # File might be empty or partially written, or format mismatch
277
+ pass
278
+
279
+ if time.monotonic() > timeout:
280
+ raise TimeoutError(
281
+ "Timed out waiting for rank 0 to write synchronized timestamp."
282
+ )
283
+ time.sleep(0.1)
284
+
285
+ os.environ[env_key] = run_time
286
+ return run_time
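A small hypothetical example of the `Config` / `load_config` helpers above; the YAML content and file path are made up for the sketch, and real runs pass `--config <file> --override key=value` on the command line instead:

```python
from ldf_utils.initialize import load_config

# Write a tiny config so the snippet is self-contained.
with open("/tmp/flood_demo.yaml", "w") as f:
    f.write("trainer:\n  max_epochs: 5\n  lr: 0.0001\n")

# Programmatic overrides mirror the CLI's `--override trainer.max_epochs=10`.
cfg = load_config("/tmp/flood_demo.yaml", {"trainer.max_epochs": "10"})
print(cfg.get("trainer.max_epochs"))  # 10, parsed to int by _convert_value
print(cfg.get("trainer.lr"))          # 0.0001, taken from the YAML file
```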
ldf_utils/math/__init__.py ADDED
File without changes
ldf_utils/math/quaternion.py ADDED
@@ -0,0 +1,447 @@
1
+ # Copyright (c) 2018-present, Facebook, Inc.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ #
7
+
8
+ import numpy as np
9
+ import torch
10
+
11
+ _EPS4 = np.finfo(float).eps * 4.0
12
+
13
+ _FLOAT_EPS = np.finfo(np.float64).eps
14
+
15
+ # PyTorch-backed implementations
16
+
17
+
18
+ def qinv(q):
19
+ assert q.shape[-1] == 4, "q must be a tensor of shape (*, 4)"
20
+ mask = torch.ones_like(q)
21
+ mask[..., 1:] = -mask[..., 1:]
22
+ return q * mask
23
+
24
+
25
+ def qinv_np(q):
26
+ assert q.shape[-1] == 4, "q must be a tensor of shape (*, 4)"
27
+ return qinv(torch.from_numpy(q).float()).numpy()
28
+
29
+
30
+ def qnormalize(q):
31
+ assert q.shape[-1] == 4, "q must be a tensor of shape (*, 4)"
32
+ return q / torch.norm(q, dim=-1, keepdim=True)
33
+
34
+
35
+ def qmul(q, r):
36
+ """
37
+ Multiply quaternion(s) q with quaternion(s) r.
38
+ Expects two equally-sized tensors of shape (*, 4), where * denotes any number of dimensions.
39
+ Returns q*r as a tensor of shape (*, 4).
40
+ """
41
+ assert q.shape[-1] == 4
42
+ assert r.shape[-1] == 4
43
+
44
+ original_shape = q.shape
45
+
46
+ # Compute outer product
47
+ terms = torch.bmm(r.view(-1, 4, 1), q.view(-1, 1, 4))
48
+
49
+ w = terms[:, 0, 0] - terms[:, 1, 1] - terms[:, 2, 2] - terms[:, 3, 3]
50
+ x = terms[:, 0, 1] + terms[:, 1, 0] - terms[:, 2, 3] + terms[:, 3, 2]
51
+ y = terms[:, 0, 2] + terms[:, 1, 3] + terms[:, 2, 0] - terms[:, 3, 1]
52
+ z = terms[:, 0, 3] - terms[:, 1, 2] + terms[:, 2, 1] + terms[:, 3, 0]
53
+ return torch.stack((w, x, y, z), dim=1).view(original_shape)
54
+
55
+
56
+ def qrot(q, v):
57
+ """
58
+ Rotate vector(s) v about the rotation described by quaternion(s) q.
59
+ Expects a tensor of shape (*, 4) for q and a tensor of shape (*, 3) for v,
60
+ where * denotes any number of dimensions.
61
+ Returns a tensor of shape (*, 3).
62
+ """
63
+ assert q.shape[-1] == 4
64
+ assert v.shape[-1] == 3
65
+ assert q.shape[:-1] == v.shape[:-1]
66
+
67
+ original_shape = list(v.shape)
68
+ # print(q.shape)
69
+ q = q.contiguous().view(-1, 4)
70
+ v = v.contiguous().view(-1, 3)
71
+
72
+ qvec = q[:, 1:]
73
+ uv = torch.cross(qvec, v, dim=1)
74
+ uuv = torch.cross(qvec, uv, dim=1)
75
+ return (v + 2 * (q[:, :1] * uv + uuv)).view(original_shape)
76
+
77
+
78
+ def qeuler(q, order, epsilon=0, deg=True):
79
+ """
80
+ Convert quaternion(s) q to Euler angles.
81
+ Expects a tensor of shape (*, 4), where * denotes any number of dimensions.
82
+ Returns a tensor of shape (*, 3).
83
+ """
84
+ assert q.shape[-1] == 4
85
+
86
+ original_shape = list(q.shape)
87
+ original_shape[-1] = 3
88
+ q = q.view(-1, 4)
89
+
90
+ q0 = q[:, 0]
91
+ q1 = q[:, 1]
92
+ q2 = q[:, 2]
93
+ q3 = q[:, 3]
94
+
95
+ if order == "xyz":
96
+ x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2))
97
+ y = torch.asin(torch.clamp(2 * (q1 * q3 + q0 * q2), -1 + epsilon, 1 - epsilon))
98
+ z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3))
99
+ elif order == "yzx":
100
+ x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3))
101
+ y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3))
102
+ z = torch.asin(torch.clamp(2 * (q1 * q2 + q0 * q3), -1 + epsilon, 1 - epsilon))
103
+ elif order == "zxy":
104
+ x = torch.asin(torch.clamp(2 * (q0 * q1 + q2 * q3), -1 + epsilon, 1 - epsilon))
105
+ y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q1 * q1 + q2 * q2))
106
+ z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q1 * q1 + q3 * q3))
107
+ elif order == "xzy":
108
+ x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3))
109
+ y = torch.atan2(2 * (q0 * q2 + q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3))
110
+ z = torch.asin(torch.clamp(2 * (q0 * q3 - q1 * q2), -1 + epsilon, 1 - epsilon))
111
+ elif order == "yxz":
112
+ x = torch.asin(torch.clamp(2 * (q0 * q1 - q2 * q3), -1 + epsilon, 1 - epsilon))
113
+ y = torch.atan2(2 * (q1 * q3 + q0 * q2), 1 - 2 * (q1 * q1 + q2 * q2))
114
+ z = torch.atan2(2 * (q1 * q2 + q0 * q3), 1 - 2 * (q1 * q1 + q3 * q3))
115
+ elif order == "zyx":
116
+ x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2))
117
+ y = torch.asin(torch.clamp(2 * (q0 * q2 - q1 * q3), -1 + epsilon, 1 - epsilon))
118
+ z = torch.atan2(2 * (q0 * q3 + q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3))
119
+ else:
120
+ raise ValueError(f"Unsupported rotation order: {order}")
121
+
122
+ if deg:
123
+ return torch.stack((x, y, z), dim=1).view(original_shape) * 180 / np.pi
124
+ else:
125
+ return torch.stack((x, y, z), dim=1).view(original_shape)
126
+
127
+
128
+ # Numpy-backed implementations
129
+
130
+
131
+ def qmul_np(q, r):
132
+ q = torch.from_numpy(q).contiguous().float()
133
+ r = torch.from_numpy(r).contiguous().float()
134
+ return qmul(q, r).numpy()
135
+
136
+
137
+ def qrot_np(q, v):
138
+ q = torch.from_numpy(q).contiguous().float()
139
+ v = torch.from_numpy(v).contiguous().float()
140
+ return qrot(q, v).numpy()
141
+
142
+
143
+ def qeuler_np(q, order, epsilon=0, use_gpu=False):
144
+ if use_gpu:
145
+ q = torch.from_numpy(q).cuda().float()
146
+ return qeuler(q, order, epsilon).cpu().numpy()
147
+ else:
148
+ q = torch.from_numpy(q).contiguous().float()
149
+ return qeuler(q, order, epsilon).numpy()
150
+
151
+
152
+ def qfix(q):
153
+ """
154
+ Enforce quaternion continuity across the time dimension by selecting
155
+ the representation (q or -q) with minimal distance (or, equivalently, maximal dot product)
156
+ between two consecutive frames.
157
+
158
+ Expects a tensor of shape (L, J, 4), where L is the sequence length and J is the number of joints.
159
+ Returns a tensor of the same shape.
160
+ """
161
+ assert len(q.shape) == 3
162
+ assert q.shape[-1] == 4
163
+
164
+ result = q.copy()
165
+ dot_products = np.sum(q[1:] * q[:-1], axis=2)
166
+ mask = dot_products < 0
167
+ mask = (np.cumsum(mask, axis=0) % 2).astype(bool)
168
+ result[1:][mask] *= -1
169
+ return result
170
+
171
+
172
+ def euler2quat(e, order, deg=True):
173
+ """
174
+ Convert Euler angles to quaternions (PyTorch version; input is in degrees when deg=True).
175
+ """
176
+ assert e.shape[-1] == 3
177
+
178
+ original_shape = list(e.shape)
179
+ original_shape[-1] = 4
180
+
181
+ e = e.view(-1, 3)
182
+
183
+ # if euler angles in degrees
184
+ if deg:
185
+ e = e * np.pi / 180.0
186
+
187
+ x = e[:, 0]
188
+ y = e[:, 1]
189
+ z = e[:, 2]
190
+
191
+ rx = torch.stack(
192
+ (torch.cos(x / 2), torch.sin(x / 2), torch.zeros_like(x), torch.zeros_like(x)),
193
+ dim=1,
194
+ )
195
+ ry = torch.stack(
196
+ (torch.cos(y / 2), torch.zeros_like(y), torch.sin(y / 2), torch.zeros_like(y)),
197
+ dim=1,
198
+ )
199
+ rz = torch.stack(
200
+ (torch.cos(z / 2), torch.zeros_like(z), torch.zeros_like(z), torch.sin(z / 2)),
201
+ dim=1,
202
+ )
203
+
204
+ result = None
205
+ for coord in order:
206
+ if coord == "x":
207
+ r = rx
208
+ elif coord == "y":
209
+ r = ry
210
+ elif coord == "z":
211
+ r = rz
212
+ else:
213
+ raise ValueError(f"Invalid axis '{coord}' in order '{order}'")
214
+ if result is None:
215
+ result = r
216
+ else:
217
+ result = qmul(result, r)
218
+
219
+ # Reverse antipodal representation to have a non-negative "w"
220
+ if order in ["xyz", "yzx", "zxy"]:
221
+ result *= -1
222
+
223
+ return result.view(original_shape)
224
+
225
+
226
+ def expmap_to_quaternion(e):
227
+ """
228
+ Convert axis-angle rotations (aka exponential maps) to quaternions.
229
+ Stable formula from "Practical Parameterization of Rotations Using the Exponential Map".
230
+ Expects a tensor of shape (*, 3), where * denotes any number of dimensions.
231
+ Returns a tensor of shape (*, 4).
232
+ """
233
+ assert e.shape[-1] == 3
234
+
235
+ original_shape = list(e.shape)
236
+ original_shape[-1] = 4
237
+ e = e.reshape(-1, 3)
238
+
239
+ theta = np.linalg.norm(e, axis=1).reshape(-1, 1)
240
+ w = np.cos(0.5 * theta).reshape(-1, 1)
241
+ xyz = 0.5 * np.sinc(0.5 * theta / np.pi) * e
242
+ return np.concatenate((w, xyz), axis=1).reshape(original_shape)
243
+
244
+
245
+ def euler_to_quaternion(e, order):
246
+ """
247
+ Convert Euler angles (in radians) to quaternions (NumPy version).
248
+ """
249
+ assert e.shape[-1] == 3
250
+
251
+ original_shape = list(e.shape)
252
+ original_shape[-1] = 4
253
+
254
+ e = e.reshape(-1, 3)
255
+
256
+ x = e[:, 0]
257
+ y = e[:, 1]
258
+ z = e[:, 2]
259
+
260
+ rx = np.stack(
261
+ (np.cos(x / 2), np.sin(x / 2), np.zeros_like(x), np.zeros_like(x)), axis=1
262
+ )
263
+ ry = np.stack(
264
+ (np.cos(y / 2), np.zeros_like(y), np.sin(y / 2), np.zeros_like(y)), axis=1
265
+ )
266
+ rz = np.stack(
267
+ (np.cos(z / 2), np.zeros_like(z), np.zeros_like(z), np.sin(z / 2)), axis=1
268
+ )
269
+
270
+ result = None
271
+ for coord in order:
272
+ if coord == "x":
273
+ r = rx
274
+ elif coord == "y":
275
+ r = ry
276
+ elif coord == "z":
277
+ r = rz
278
+ else:
279
+ raise ValueError(f"Invalid axis '{coord}' in order '{order}'")
280
+ if result is None:
281
+ result = r
282
+ else:
283
+ result = qmul_np(result, r)
284
+
285
+ # Reverse antipodal representation to have a non-negative "w"
286
+ if order in ["xyz", "yzx", "zxy"]:
287
+ result *= -1
288
+
289
+ return result.reshape(original_shape)
290
+
291
+
292
+ def quaternion_to_matrix(quaternions):
293
+ """
294
+ Convert rotations given as quaternions to rotation matrices.
295
+ Args:
296
+ quaternions: quaternions with real part first,
297
+ as tensor of shape (..., 4).
298
+ Returns:
299
+ Rotation matrices as tensor of shape (..., 3, 3).
300
+ """
301
+ r, i, j, k = torch.unbind(quaternions, -1)
302
+ two_s = 2.0 / (quaternions * quaternions).sum(-1)
303
+
304
+ o = torch.stack(
305
+ (
306
+ 1 - two_s * (j * j + k * k),
307
+ two_s * (i * j - k * r),
308
+ two_s * (i * k + j * r),
309
+ two_s * (i * j + k * r),
310
+ 1 - two_s * (i * i + k * k),
311
+ two_s * (j * k - i * r),
312
+ two_s * (i * k - j * r),
313
+ two_s * (j * k + i * r),
314
+ 1 - two_s * (i * i + j * j),
315
+ ),
316
+ -1,
317
+ )
318
+ return o.reshape(quaternions.shape[:-1] + (3, 3))
319
+
320
+
321
+ def quaternion_to_matrix_np(quaternions):
322
+ q = torch.from_numpy(quaternions).contiguous().float()
323
+ return quaternion_to_matrix(q).numpy()
324
+
325
+
326
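+ # The continuous 6D rotation representation keeps the first two columns of the rotation
+ # matrix (Zhou et al., "On the Continuity of Rotation Representations in Neural Networks").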
+ def quaternion_to_cont6d_np(quaternions):
327
+ rotation_mat = quaternion_to_matrix_np(quaternions)
328
+ cont_6d = np.concatenate([rotation_mat[..., 0], rotation_mat[..., 1]], axis=-1)
329
+ return cont_6d
330
+
331
+
332
+ def quaternion_to_cont6d(quaternions):
333
+ rotation_mat = quaternion_to_matrix(quaternions)
334
+ cont_6d = torch.cat([rotation_mat[..., 0], rotation_mat[..., 1]], dim=-1)
335
+ return cont_6d
336
+
337
+
338
+ def cont6d_to_matrix(cont6d):
339
+ assert cont6d.shape[-1] == 6, "The last dimension must be 6"
340
+ x_raw = cont6d[..., 0:3]
341
+ y_raw = cont6d[..., 3:6]
342
+
343
+ x = x_raw / torch.norm(x_raw, dim=-1, keepdim=True)
344
+ z = torch.cross(x, y_raw, dim=-1)
345
+ z = z / torch.norm(z, dim=-1, keepdim=True)
346
+
347
+ y = torch.cross(z, x, dim=-1)
348
+
349
+ x = x[..., None]
350
+ y = y[..., None]
351
+ z = z[..., None]
352
+
353
+ mat = torch.cat([x, y, z], dim=-1)
354
+ return mat
355
+
356
+
357
+ def cont6d_to_matrix_np(cont6d):
358
+ q = torch.from_numpy(cont6d).contiguous().float()
359
+ return cont6d_to_matrix(q).numpy()
360
+
361
+
362
+ def qpow(q0, t, dtype=torch.float):
363
+ """q0 : tensor of quaternions
364
+ t: tensor of powers
365
+ """
366
+ q0 = qnormalize(q0)
367
+ theta0 = torch.acos(q0[..., 0])
368
+
369
+ # if theta0 is close to zero, add epsilon to avoid NaNs
370
+ mask = ((theta0 <= 10e-10) & (theta0 >= -10e-10)).float()
371
+ theta0 = (1 - mask) * theta0 + mask * 10e-10
372
+ v0 = q0[..., 1:] / torch.sin(theta0).view(-1, 1)
373
+
374
+ if isinstance(t, torch.Tensor):
375
+ q = torch.zeros(t.shape + q0.shape)
376
+ theta = t.view(-1, 1) * theta0.view(1, -1)
377
+ else: # if t is a number
378
+ q = torch.zeros(q0.shape)
379
+ theta = t * theta0
380
+
381
+ q[..., 0] = torch.cos(theta)
382
+ q[..., 1:] = v0 * torch.sin(theta).unsqueeze(-1)
383
+
384
+ return q.to(dtype)
385
+
386
+
387
+ def qslerp(q0, q1, t):
388
+ """
389
+ q0: starting quaternion
390
+ q1: ending quaternion
391
+ t: array of points along the way
392
+
393
+ Returns:
394
+ Tensor of Slerps: t.shape + q0.shape
395
+ """
396
+
397
+ q0 = qnormalize(q0)
398
+ q1 = qnormalize(q1)
399
+ q_ = qpow(qmul(q1, qinv(q0)), t)
400
+
401
+ return qmul(
402
+ q_,
403
+ q0.contiguous()
404
+ .view(torch.Size([1] * len(t.shape)) + q0.shape)
405
+ .expand(t.shape + q0.shape)
406
+ .contiguous(),
407
+ )
408
+
409
+
410
+ def qbetween(v0, v1):
411
+ """
412
+ Find the quaternion that rotates v0 onto v1.
413
+ """
414
+ assert v0.shape[-1] == 3, "v0 must be of the shape (*, 3)"
415
+ assert v1.shape[-1] == 3, "v1 must be of the shape (*, 3)"
416
+
417
+ v = torch.cross(v0, v1)
418
+ w = torch.sqrt(
419
+ (v0**2).sum(dim=-1, keepdim=True) * (v1**2).sum(dim=-1, keepdim=True)
420
+ ) + (v0 * v1).sum(dim=-1, keepdim=True)
421
+ return qnormalize(torch.cat([w, v], dim=-1))
422
+
423
+
424
+ def qbetween_np(v0, v1):
425
+ """
426
+ Find the quaternion that rotates v0 onto v1 (NumPy wrapper).
427
+ """
428
+ assert v0.shape[-1] == 3, "v0 must be of the shape (*, 3)"
429
+ assert v1.shape[-1] == 3, "v1 must be of the shape (*, 3)"
430
+
431
+ v0 = torch.from_numpy(v0).float()
432
+ v1 = torch.from_numpy(v1).float()
433
+ return qbetween(v0, v1).numpy()
434
+
435
+
436
+ def lerp(p0, p1, t):
437
+ if not isinstance(t, torch.Tensor):
438
+ t = torch.Tensor([t])
439
+
440
+ new_shape = t.shape + p0.shape
441
+ new_view_t = t.shape + torch.Size([1] * len(p0.shape))
442
+ new_view_p = torch.Size([1] * len(t.shape)) + p0.shape
443
+ p0 = p0.view(new_view_p).expand(new_shape)
444
+ p1 = p1.view(new_view_p).expand(new_shape)
445
+ t = t.view(new_view_t).expand(new_shape)
446
+
447
+ return p0 + t * (p1 - p0)
ldf_utils/motion_process.py ADDED
@@ -0,0 +1,341 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn.functional as F
4
+
5
+ from ldf_utils.math.quaternion import *
6
+
7
+ """
8
+ Motion data structure:
9
+ (B: batch size)
10
+ root_rot_velocity (B, seq_len, 1)
11
+ root_linear_velocity (B, seq_len, 2)
12
+ root_y (B, seq_len, 1)
13
+ ric_data (B, seq_len, (joint_num - 1)*3)
14
+ rot_data (B, seq_len, (joint_num - 1)*6)
15
+ local_velocity (B, seq_len, joint_num*3)
16
+ foot contact (B, seq_len, 4)
17
+ """
18
+
19
+
20
+ def recover_root_rot_pos(data):
21
+ # recover root rotation and position
22
+ rot_vel = data[..., 0]
23
+ r_rot_ang = torch.zeros_like(rot_vel).to(data.device)
24
+ """Get Y-axis rotation from rotation velocity"""
25
+ r_rot_ang[..., 1:] = rot_vel[..., :-1]
26
+ r_rot_ang = torch.cumsum(r_rot_ang, dim=-1)
27
+
28
+ r_rot_quat = torch.zeros(data.shape[:-1] + (4,)).to(data.device)
29
+ r_rot_quat[..., 0] = torch.cos(r_rot_ang)
30
+ r_rot_quat[..., 2] = torch.sin(r_rot_ang)
31
+
32
+ r_pos = torch.zeros(data.shape[:-1] + (3,)).to(data.device)
33
+ r_pos[..., 1:, [0, 2]] = data[..., :-1, 1:3]
34
+ """Add Y-axis rotation to root position"""
35
+ r_pos = qrot(qinv(r_rot_quat), r_pos)
36
+
37
+ r_pos = torch.cumsum(r_pos, dim=-2)
38
+
39
+ r_pos[..., 1] = data[..., 3]
40
+ return r_rot_quat, r_pos
41
+
42
+
43
+ def recover_joint_positions_263(data: np.ndarray, joints_num) -> np.ndarray:
44
+ """
45
+ Recovers 3D joint positions from the rotation-invariant local positions (ric_data).
46
+ This is the most direct way to get the skeleton for animation.
47
+ """
48
+ feature_vec = torch.from_numpy(data).unsqueeze(0).float()
49
+ r_rot_quat, r_pos = recover_root_rot_pos(feature_vec)
50
+ positions = feature_vec[..., 4 : (joints_num - 1) * 3 + 4]
51
+ positions = positions.view(positions.shape[:-1] + (-1, 3))
52
+ """Add Y-axis rotation to local joints"""
53
+ positions = qrot(
54
+ qinv(r_rot_quat[..., None, :]).expand(positions.shape[:-1] + (4,)), positions
55
+ )
56
+ """Add root XZ to joints"""
57
+ positions[..., 0] += r_pos[..., 0:1]
58
+ positions[..., 2] += r_pos[..., 2:3]
59
+ """Concatenate root and joints"""
60
+ positions = torch.cat([r_pos.unsqueeze(-2), positions], dim=-2)
61
+ joints_np = positions.squeeze(0).detach().cpu().numpy()
62
+ return joints_np
63
+
64
+
65
+ class StreamJointRecovery263:
66
+ """
67
+ Stream version of recover_joint_positions_263 that processes one frame at a time.
68
+ Maintains cumulative state for rotation angles and positions.
69
+
70
+ Key insight: The batch version uses PREVIOUS frame's velocity for the current frame,
71
+ so we need to delay the velocity application by one frame.
72
+ """
73
+
74
+ def __init__(self, joints_num: int):
75
+ self.joints_num = joints_num
76
+ self.reset()
77
+
78
+ def reset(self):
79
+ """Reset the accumulated state"""
80
+ self.r_rot_ang_accum = 0.0
81
+ self.r_pos_accum = np.array([0.0, 0.0, 0.0])
82
+ # Store previous frame's velocities for delayed application
83
+ self.prev_rot_vel = 0.0
84
+ self.prev_linear_vel = np.array([0.0, 0.0])
85
+
86
+ def process_frame(self, frame_data: np.ndarray) -> np.ndarray:
87
+ """
88
+ Process a single frame and return joint positions for that frame.
89
+
90
+ Args:
91
+ frame_data: numpy array of shape (263,) for a single frame
92
+
93
+ Returns:
94
+ joints: numpy array of shape (joints_num, 3) representing joint positions
95
+ """
96
+ # Convert to torch tensor
97
+ feature_vec = torch.from_numpy(frame_data).float()
98
+
99
+ # Extract current frame's velocities (will be used in NEXT frame)
100
+ curr_rot_vel = feature_vec[0].item()
101
+ curr_linear_vel = feature_vec[1:3].numpy()
102
+
103
+ # Update accumulated rotation angle with PREVIOUS frame's velocity FIRST
104
+ # This matches the batch processing: r_rot_ang[i] uses rot_vel[i-1]
105
+ self.r_rot_ang_accum += self.prev_rot_vel
106
+
107
+ # Calculate current rotation quaternion using updated accumulated angle
108
+ r_rot_quat = torch.zeros(4)
109
+ r_rot_quat[0] = np.cos(self.r_rot_ang_accum)
110
+ r_rot_quat[2] = np.sin(self.r_rot_ang_accum)
111
+
112
+ # Create velocity vector with Y=0 using PREVIOUS frame's velocity
113
+ r_vel = np.array([self.prev_linear_vel[0], 0.0, self.prev_linear_vel[1]])
114
+
115
+ # Apply inverse rotation to velocity using CURRENT rotation
116
+ r_vel_torch = torch.from_numpy(r_vel).float()
117
+ r_vel_rotated = qrot(qinv(r_rot_quat).unsqueeze(0), r_vel_torch.unsqueeze(0))
118
+ r_vel_rotated = r_vel_rotated.squeeze(0).numpy()
119
+
120
+ # Update accumulated position with rotated velocity
121
+ self.r_pos_accum += r_vel_rotated
122
+
123
+ # Get Y position from data
124
+ r_pos = self.r_pos_accum.copy()
125
+ r_pos[1] = feature_vec[3].item()
126
+
127
+ # Extract local joint positions
128
+ positions = feature_vec[4 : (self.joints_num - 1) * 3 + 4]
129
+ positions = positions.view(-1, 3)
130
+
131
+ # Apply inverse rotation to local joints
132
+ r_rot_quat_expanded = (
133
+ qinv(r_rot_quat).unsqueeze(0).expand(positions.shape[0], 4)
134
+ )
135
+ positions = qrot(r_rot_quat_expanded, positions)
136
+
137
+ # Add root XZ to joints
138
+ positions[:, 0] += r_pos[0]
139
+ positions[:, 2] += r_pos[2]
140
+
141
+ # Concatenate root and joints
142
+ r_pos_torch = torch.from_numpy(r_pos).float()
143
+ positions = torch.cat([r_pos_torch.unsqueeze(0), positions], dim=0)
144
+
145
+ # Convert to numpy
146
+ joints_np = positions.detach().cpu().numpy()
147
+
148
+ # Store current velocities for next frame
149
+ self.prev_rot_vel = curr_rot_vel
150
+ self.prev_linear_vel = curr_linear_vel
151
+
152
+ return joints_np
153
+
154
+
155
+ def accumulate_rotations(relative_rotations):
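+ # Left-multiplies each per-frame delta onto the running product, so that
+ # R_total[i] = R_rel[i] @ R_rel[i-1] @ ... @ R_rel[0].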
156
+ R_total = [relative_rotations[0]]
157
+ for R_rel in relative_rotations[1:]:
158
+ R_total.append(np.matmul(R_rel, R_total[-1]))
159
+
160
+ return np.array(R_total)
161
+
162
+
163
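+ # Layout of the 272-dim features as indexed by the two recovery functions below (per frame):
+ # [:2] root XY velocity without heading, [2:8] per-frame heading delta as a 6D rotation,
+ # [8:8+3*njoint] joint positions without heading; recover_from_local_rotation additionally
+ # reads [8+6*njoint:8+12*njoint] as per-joint 6D rotations.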
+ def recover_from_local_position(final_x, njoint):
164
+ nfrm, _ = final_x.shape
165
+ positions_no_heading = final_x[:, 8 : 8 + 3 * njoint].reshape(
166
+ nfrm, -1, 3
167
+ ) # frames, njoints * 3
168
+ velocities_root_xy_no_heading = final_x[:, :2] # frames, 2
169
+ global_heading_diff_rot = final_x[:, 2:8] # frames, 6
170
+
171
+ # recover global heading
172
+ global_heading_rot = accumulate_rotations(
173
+ rotation_6d_to_matrix(torch.from_numpy(global_heading_diff_rot)).numpy()
174
+ )
175
+ inv_global_heading_rot = np.transpose(global_heading_rot, (0, 2, 1))
176
+ # add global heading to position
177
+ positions_with_heading = np.matmul(
178
+ np.repeat(inv_global_heading_rot[:, None, :, :], njoint, axis=1),
179
+ positions_no_heading[..., None],
180
+ ).squeeze(-1)
181
+
182
+ # recover root translation
183
+ # add heading to velocities_root_xy_no_heading
184
+
185
+ velocities_root_xyz_no_heading = np.zeros(
186
+ (
187
+ velocities_root_xy_no_heading.shape[0],
188
+ 3,
189
+ )
190
+ )
191
+ velocities_root_xyz_no_heading[:, 0] = velocities_root_xy_no_heading[:, 0]
192
+ velocities_root_xyz_no_heading[:, 2] = velocities_root_xy_no_heading[:, 1]
193
+ velocities_root_xyz_no_heading[1:, :] = np.matmul(
194
+ inv_global_heading_rot[:-1], velocities_root_xyz_no_heading[1:, :, None]
195
+ ).squeeze(-1)
196
+
197
+ root_translation = np.cumsum(velocities_root_xyz_no_heading, axis=0)
198
+
199
+ # add root translation
200
+ positions_with_heading[:, :, 0] += root_translation[:, 0:1]
201
+ positions_with_heading[:, :, 2] += root_translation[:, 2:]
202
+
203
+ return positions_with_heading
204
+
205
+
206
+ def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:
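+ # Gram-Schmidt recovery of a rotation matrix from the 6D representation (Zhou et al.):
+ # normalize a1, orthogonalize a2 against it, then complete the basis with a cross product.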
207
+ a1, a2 = d6[..., :3], d6[..., 3:]
208
+ b1 = F.normalize(a1, dim=-1)
209
+ b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1
210
+ b2 = F.normalize(b2, dim=-1)
211
+ b3 = torch.cross(b1, b2, dim=-1)
212
+ return torch.stack((b1, b2, b3), dim=-2)
213
+
214
+
215
+ def _copysign(a, b):
216
+ signs_differ = (a < 0) != (b < 0)
217
+ return torch.where(signs_differ, -a, a)
218
+
219
+
220
+ def _sqrt_positive_part(x):
221
+ ret = torch.zeros_like(x)
222
+ positive_mask = x > 0
223
+ ret[positive_mask] = torch.sqrt(x[positive_mask])
224
+ return ret
225
+
226
+
227
+ def matrix_to_quaternion(matrix):
228
+ if matrix.size(-1) != 3 or matrix.size(-2) != 3:
229
+ raise ValueError(f"Invalid rotation matrix shape f{matrix.shape}.")
230
+ m00 = matrix[..., 0, 0]
231
+ m11 = matrix[..., 1, 1]
232
+ m22 = matrix[..., 2, 2]
233
+ o0 = 0.5 * _sqrt_positive_part(1 + m00 + m11 + m22)
234
+ x = 0.5 * _sqrt_positive_part(1 + m00 - m11 - m22)
235
+ y = 0.5 * _sqrt_positive_part(1 - m00 + m11 - m22)
236
+ z = 0.5 * _sqrt_positive_part(1 - m00 - m11 + m22)
237
+ o1 = _copysign(x, matrix[..., 2, 1] - matrix[..., 1, 2])
238
+ o2 = _copysign(y, matrix[..., 0, 2] - matrix[..., 2, 0])
239
+ o3 = _copysign(z, matrix[..., 1, 0] - matrix[..., 0, 1])
240
+ return torch.stack((o0, o1, o2, o3), -1)
241
+
242
+
243
+ def quaternion_to_axis_angle(quaternions):
244
+ norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True)
245
+ half_angles = torch.atan2(norms, quaternions[..., :1])
246
+ angles = 2 * half_angles
247
+ eps = 1e-6
248
+ small_angles = angles.abs() < eps
249
+ sin_half_angles_over_angles = torch.empty_like(angles)
250
+ sin_half_angles_over_angles[~small_angles] = (
251
+ torch.sin(half_angles[~small_angles]) / angles[~small_angles]
252
+ )
253
+ # for x small, sin(x/2) is about x/2 - (x/2)^3/6
254
+ # so sin(x/2)/x is about 1/2 - (x*x)/48
255
+ sin_half_angles_over_angles[small_angles] = (
256
+ 0.5 - (angles[small_angles] * angles[small_angles]) / 48
257
+ )
258
+ return quaternions[..., 1:] / sin_half_angles_over_angles
259
+
260
+
261
+ def matrix_to_axis_angle(matrix):
262
+ return quaternion_to_axis_angle(matrix_to_quaternion(matrix))
263
+
264
+
265
+ def rotations_matrix_to_smpl85(rotations_matrix, translation):
266
+ nfrm, njoint, _, _ = rotations_matrix.shape
267
+ axis_angle = (
268
+ matrix_to_axis_angle(torch.from_numpy(rotations_matrix))
269
+ .numpy()
270
+ .reshape(nfrm, -1)
271
+ )
272
+ smpl_85 = np.concatenate(
273
+ [axis_angle, np.zeros((nfrm, 6)), translation, np.zeros((nfrm, 10))], axis=-1
274
+ )
275
+ return smpl_85
276
+
277
+
278
+ def recover_from_local_rotation(final_x, njoint):
279
+ nfrm, _ = final_x.shape
280
+ rotations_matrix = rotation_6d_to_matrix(
281
+ torch.from_numpy(final_x[:, 8 + 6 * njoint : 8 + 12 * njoint]).reshape(
282
+ nfrm, -1, 6
283
+ )
284
+ ).numpy()
285
+ global_heading_diff_rot = final_x[:, 2:8]
286
+ velocities_root_xy_no_heading = final_x[:, :2]
287
+ positions_no_heading = final_x[:, 8 : 8 + 3 * njoint].reshape(nfrm, -1, 3)
288
+ height = positions_no_heading[:, 0, 1]
289
+
290
+ global_heading_rot = accumulate_rotations(
291
+ rotation_6d_to_matrix(torch.from_numpy(global_heading_diff_rot)).numpy()
292
+ )
293
+ inv_global_heading_rot = np.transpose(global_heading_rot, (0, 2, 1))
294
+ # recover root rotation
295
+ rotations_matrix[:, 0, ...] = np.matmul(
296
+ inv_global_heading_rot, rotations_matrix[:, 0, ...]
297
+ )
298
+ velocities_root_xyz_no_heading = np.zeros(
299
+ (
300
+ velocities_root_xy_no_heading.shape[0],
301
+ 3,
302
+ )
303
+ )
304
+ velocities_root_xyz_no_heading[:, 0] = velocities_root_xy_no_heading[:, 0]
305
+ velocities_root_xyz_no_heading[:, 2] = velocities_root_xy_no_heading[:, 1]
306
+ velocities_root_xyz_no_heading[1:, :] = np.matmul(
307
+ inv_global_heading_rot[:-1], velocities_root_xyz_no_heading[1:, :, None]
308
+ ).squeeze(-1)
309
+ root_translation = np.cumsum(velocities_root_xyz_no_heading, axis=0)
310
+ root_translation[:, 1] = height
311
+ smpl_85 = rotations_matrix_to_smpl85(rotations_matrix, root_translation)
312
+ return smpl_85
313
+
314
+
315
+ def recover_joint_positions_272(data: np.ndarray, joints_num) -> np.ndarray:
316
+ return recover_from_local_position(data, joints_num)
317
+
318
+
319
+ def convert_motion_to_joints(
320
+ motion_data: np.ndarray,
321
+ dim: int,
322
+ mean: np.ndarray = None,
323
+ std: np.ndarray = None,
324
+ joints_num=22,
325
+ ):
326
+ """
327
+ Convert Kx263 dim or Kx272 dim motion data to Kx22x3 joint positions.
328
+ Args:
329
+ motion_data: numpy array of shape (K, 263) or (K, 272) where K is number of frames
330
+ Returns:
331
+ joints: numpy array of shape (K, 22, 3) representing joint positions
332
+ """
333
+ if mean is not None and std is not None:
334
+ motion_data = motion_data * std + mean
335
+ if dim == 263:
336
+ recovered_positions = recover_joint_positions_263(motion_data, joints_num)
337
+ elif dim == 272:
338
+ recovered_positions = recover_joint_positions_272(motion_data, joints_num)
339
+ else:
340
+ raise ValueError(f"Unsupported motion data dimension: {dim}")
341
+ return recovered_positions
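+ # Example (hypothetical file names): de-normalize saved features and decode to joints.
+ #   motion = np.load("sample_263.npy")                    # (K, 263)
+ #   mean, std = np.load("Mean.npy"), np.load("Std.npy")   # dataset statistics
+ #   joints = convert_motion_to_joints(motion, dim=263, mean=mean, std=std)  # (K, 22, 3)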
model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6034bb8fad782741caa33eb785935e2d1543e881772dacdbe9292107ec45436e
3
+ size 454897088
requirements.txt ADDED
@@ -0,0 +1,19 @@
1
+ # Core dependencies
2
+ torch>=2.0.0
3
+ transformers>=4.30.0
4
+ huggingface_hub>=0.16.0
5
+ safetensors>=0.3.0
6
+ diffusers>=0.20.0
7
+
8
+ # Inference
9
+ lightning>=2.0.0
10
+ ftfy
11
+
12
+ # Configuration
13
+ omegaconf
14
+
15
+ # Utilities
16
+ numpy
17
+
18
+ # Note: flash-attn is required but needs special installation
19
+ # See README.md for installation instructions
vae.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5a40164154c476309ff952a4b7563750b7e76fbdd8d263ec261ad877cf452e7b
3
+ size 70027220