Initial commit: FloodDiffusion text-to-motion generation model
- .gitattributes +1 -0
- README.md +219 -0
- __init__.py +11 -0
- config.json +11 -0
- generate_ldf.py +139 -0
- hf_pipeline.py +278 -0
- ldf.yaml +32 -0
- ldf_deps/t5_umt5-xxl-enc-bf16/google/umt5-xxl/special_tokens_map.json +308 -0
- ldf_deps/t5_umt5-xxl-enc-bf16/google/umt5-xxl/spiece.model +3 -0
- ldf_deps/t5_umt5-xxl-enc-bf16/google/umt5-xxl/tokenizer.json +3 -0
- ldf_deps/t5_umt5-xxl-enc-bf16/google/umt5-xxl/tokenizer_config.json +2748 -0
- ldf_deps/t5_umt5-xxl-enc-bf16/models_t5_umt5-xxl-enc-bf16.pth +3 -0
- ldf_models/__init__.py +0 -0
- ldf_models/diffusion_forcing_wan.py +899 -0
- ldf_models/tools/attention.py +188 -0
- ldf_models/tools/t5.py +564 -0
- ldf_models/tools/tokenizers.py +84 -0
- ldf_models/tools/wan_model.py +592 -0
- ldf_models/tools/wan_vae_1d.py +762 -0
- ldf_models/vae_wan_1d.py +212 -0
- ldf_utils/__init__.py +0 -0
- ldf_utils/initialize.py +286 -0
- ldf_utils/math/__init__.py +0 -0
- ldf_utils/math/quaternion.py +447 -0
- ldf_utils/motion_process.py +341 -0
- model.safetensors +3 -0
- requirements.txt +19 -0
- vae.safetensors +3 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+ldf_deps/t5_umt5-xxl-enc-bf16/google/umt5-xxl/tokenizer.json filter=lfs diff=lfs merge=lfs -text
README.md
ADDED
@@ -0,0 +1,219 @@
---
license: apache-2.0
tags:
- text-to-motion
- motion-generation
- diffusion-forcing
- humanml3d
- computer-animation
library_name: transformers
pipeline_tag: other
---

# FloodDiffusion: Tailored Diffusion Forcing for Streaming Motion Generation

<div align="center">

**A state-of-the-art text-to-motion generation model based on Latent Diffusion Forcing**

[Paper]() | [Project Page]() | [Demo]()

</div>

## Overview

We present FloodDiffusion, a new framework for text-driven, streaming human motion generation. Given time-varying text prompts, FloodDiffusion generates text-aligned, seamless motion sequences with real-time latency.

## Model Architecture

The model consists of three main components:

1. **Text Encoder**: UMT5-XXL encoder for text feature extraction
2. **Latent Diffusion Model**: Transformer-based diffusion model operating in latent space
3. **VAE Decoder**: 1D convolutional VAE for decoding latent features to motion sequences

**Technical Specifications:**
- Input: Natural language text
- Output: Motion sequences in two formats:
  - 263-dimensional HumanML3D features (default)
  - 22×3 joint coordinates (optional)
- Latent dimension: 4
- Upsampling factor: 4× (VAE decoder)
- Frame rate: 20 FPS
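
For orientation, the specifications above imply a simple token → frame → seconds mapping. A minimal sketch of the arithmetic (editorial illustration, not part of the model API):

```python
UPSAMPLE = 4  # VAE decoder upsampling factor (see specs above)
FPS = 20      # output frame rate

def approx_duration_seconds(num_latent_tokens: int) -> float:
    """Approximate clip length for a given number of latent tokens."""
    frames = num_latent_tokens * UPSAMPLE  # e.g. 60 tokens -> ~240 frames
    return frames / FPS                    # e.g. 240 frames -> ~12 s

print(approx_duration_seconds(60))  # 12.0
```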

## Installation

### Prerequisites

- Python 3.8+
- CUDA-capable GPU with 16GB+ VRAM (recommended)
- 16GB+ system RAM

### Dependencies

**Step 1: Install basic dependencies**

```bash
pip install torch transformers huggingface_hub
pip install lightning diffusers omegaconf ftfy numpy
```

**Step 2: Install Flash Attention (required)**

Flash Attention requires CUDA and may need to compile from source:

```bash
pip install flash-attn --no-build-isolation
```

**Note:** Flash Attention is **required** for this model. If installation fails, refer to the [official flash-attention installation guide](https://github.com/Dao-AILab/flash-attention#installation-and-features).

## Quick Start

### Basic Usage

```python
from transformers import AutoModel

# Load model
model = AutoModel.from_pretrained(
    "ShandaAI/FloodDiffusion",
    trust_remote_code=True
)

# Generate motion from text (263-dim HumanML3D features)
motion = model("a person walking forward", length=60)
print(f"Generated motion: {motion.shape}")  # (~240, 263)

# Generate motion as joint coordinates (22 joints × 3 coords)
motion_joints = model("a person walking forward", length=60, output_joints=True)
print(f"Generated joints: {motion_joints.shape}")  # (~240, 22, 3)
```

### Batch Generation

```python
# Generate multiple motions efficiently
texts = [
    "a person walking forward",
    "a person running quickly",
    "a person jumping up and down"
]
lengths = [60, 50, 40]  # Different lengths for each motion

motions = model(texts, length=lengths)

for i, motion in enumerate(motions):
    print(f"Motion {i}: {motion.shape}")
```

### Multi-Text Motion Transitions

```python
# Generate a motion sequence with smooth transitions between actions
motion = model(
    text=[["walk forward", "turn around", "run back"]],
    length=[120],
    text_end=[[40, 80, 120]]  # Transition points in latent tokens
)

# Output: ~480 frames showing all three actions smoothly connected
print(f"Transition motion: {motion[0].shape}")
```

## API Reference

### `model(text, length=60, text_end=None, num_denoise_steps=None, output_joints=False)`

Generate motion sequences from text descriptions.

**Parameters:**

- **text** (`str`, `List[str]`, or `List[List[str]]`): Text description(s)
  - Single string: generate one motion
  - List of strings: batch generation
  - Nested list: multiple text prompts per motion (for transitions)

- **length** (`int` or `List[int]`, default=60): Number of latent tokens to generate
  - Output frames ≈ `length × 4` (due to VAE upsampling)
  - Example: `length=60` → ~240 frames (~12 seconds at 20 FPS)

- **text_end** (`List[int]` or `List[List[int]]`, optional): Latent token positions for text transitions
  - Only used when `text` is a nested list
  - Specifies when to switch between different text descriptions
  - **IMPORTANT**: Must have the same length as the corresponding text list
  - Example: `text=[["walk", "turn", "sit"]]` requires `text_end=[[20, 40, 60]]` (3 endpoints for 3 texts)
  - Must be in ascending order

- **num_denoise_steps** (`int`, optional): Number of denoising iterations
  - Higher values produce better quality but slower generation
  - Recommended range: 10-50

- **output_joints** (`bool`, default=False): Output format selector
  - `False`: returns 263-dimensional HumanML3D features
  - `True`: returns 22×3 joint coordinates for direct visualization

**Returns:**
- Single motion:
  - `output_joints=False`: `numpy.ndarray` of shape `(frames, 263)`
  - `output_joints=True`: `numpy.ndarray` of shape `(frames, 22, 3)`
- Batch: `List[numpy.ndarray]` with shapes as above

**Example:**
```python
# Single generation (263-dim features)
motion = model("walk forward", length=60)  # Returns (240, 263)

# Single generation (joint coordinates)
joints = model("walk forward", length=60, output_joints=True)  # Returns (240, 22, 3)

# Batch generation
motions = model(["walk", "run"], length=[60, 50])  # Returns list of 2 arrays

# Multi-text transitions
motion = model(
    [["walk", "turn"]],
    length=[60],
    text_end=[[30, 60]]
)  # Returns list with 1 array of shape (240, 263)
```

## Citation

If you use this model in your research, please cite:

```bibtex
@article{flood2025,
  title={FloodDiffusion: Tailored Diffusion Forcing for Streaming Motion Generation},
  author={Yiyi Cai and Yuhan Wu and Kunhang Li and You Zhou and Bo Zheng and Haiyang Liu},
  year={2025}
}
```

## Troubleshooting

### Common Issues

**ImportError with trust_remote_code:**
```python
# Solution: add trust_remote_code=True
model = AutoModel.from_pretrained(
    "ShandaAI/FloodDiffusion",
    trust_remote_code=True  # Required!
)
```

**Out of memory:**
```python
# Solution: generate shorter sequences
motion = model("walk", length=30)  # Shorter = less memory
```

**Slow first load:**
The first load downloads ~14GB of model files and may take 5-30 minutes depending on internet speed. Subsequent loads use cached files and load almost instantly.

**Module import errors:**
Ensure all dependencies are installed:
```bash
pip install lightning diffusers omegaconf ftfy numpy
```
__init__.py
ADDED
@@ -0,0 +1,11 @@
"""
FloodDiffusion - Text-to-Motion Generation

Usage:
    from transformers import AutoModel

    model = AutoModel.from_pretrained("your-username/FloodDiffusion", trust_remote_code=True)
    motion = model("a person walking forward", length=60)
"""

__version__ = "1.0.0"
config.json
ADDED
@@ -0,0 +1,11 @@
{
  "architectures": ["LDFModel"],
  "model_type": "ldf_motion",
  "auto_map": {
    "AutoModel": "hf_pipeline.LDFModel",
    "AutoConfig": "hf_pipeline.LDFConfig"
  },
  "torch_dtype": "float32",
  "transformers_version": "4.30.0",
  "license": "mit"
}
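
The `auto_map` entries above are what route `AutoModel` and `AutoConfig` to the custom classes in `hf_pipeline.py` once `trust_remote_code=True` is passed. A quick way to see the mapping in action (assumes Hub access; editorial illustration, not part of the commit):

```python
from transformers import AutoConfig

# auto_map sends AutoConfig to hf_pipeline.LDFConfig in the downloaded repo;
# without trust_remote_code=True, transformers refuses to run the custom code.
cfg = AutoConfig.from_pretrained("ShandaAI/FloodDiffusion", trust_remote_code=True)
print(type(cfg).__name__)  # LDFConfig
print(cfg.model_type)      # ldf_motion
```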
generate_ldf.py
ADDED
@@ -0,0 +1,139 @@
import os
import sys

import torch
from lightning import seed_everything
from safetensors.torch import load_file as load_safetensors

from ldf_utils.initialize import compare_statedict_and_parameters, instantiate, load_config

# Set tokenizers parallelism to false to avoid warnings in multiprocessing
os.environ["TOKENIZERS_PARALLELISM"] = "false"


def load_model_from_config():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch.set_float32_matmul_precision("high")
    cfg = load_config()
    seed_everything(cfg.seed)

    # Get the directory containing the config file:
    # try to find it from sys.argv, or fall back to the current directory.
    if '--config' in sys.argv:
        config_idx = sys.argv.index('--config') + 1
        config_dir = os.path.dirname(os.path.abspath(sys.argv[config_idx]))
    else:
        config_dir = os.getcwd()

    vae = instantiate(
        target=cfg.test_vae.target,
        cfg=None,
        hfstyle=False,
        **cfg.test_vae.params,
    )

    # Handle relative paths
    vae_path = cfg.test_vae_ckpt
    if not os.path.isabs(vae_path):
        vae_path = os.path.join(config_dir, vae_path)

    # Load from safetensors (already contains EMA weights)
    vae_state_dict = load_safetensors(vae_path)
    vae.load_state_dict(vae_state_dict, strict=True)
    print(f"Loaded VAE model from {vae_path}")

    compare_statedict_and_parameters(
        state_dict=vae.state_dict(),
        named_parameters=vae.named_parameters(),
        named_buffers=vae.named_buffers(),
    )
    vae.to(device)
    vae.eval()

    # Model - fix relative paths in model params
    model_params = dict(cfg.model.params)
    # Convert relative paths to absolute paths
    if 'checkpoint_path' in model_params and model_params['checkpoint_path']:
        if not os.path.isabs(model_params['checkpoint_path']):
            model_params['checkpoint_path'] = os.path.join(config_dir, model_params['checkpoint_path'])
    if 'tokenizer_path' in model_params and model_params['tokenizer_path']:
        if not os.path.isabs(model_params['tokenizer_path']):
            model_params['tokenizer_path'] = os.path.join(config_dir, model_params['tokenizer_path'])

    model = instantiate(
        target=cfg.model.target, cfg=None, hfstyle=False, **model_params
    )

    # Handle relative paths
    model_path = cfg.test_ckpt
    if not os.path.isabs(model_path):
        model_path = os.path.join(config_dir, model_path)

    # Load from safetensors (already contains EMA weights)
    model_state_dict = load_safetensors(model_path)
    model.load_state_dict(model_state_dict, strict=True)
    print(f"Loaded model from {model_path}")

    compare_statedict_and_parameters(
        state_dict=model.state_dict(),
        named_parameters=model.named_parameters(),
        named_buffers=model.named_buffers(),
    )
    model.to(device)
    model.eval()

    return vae, model


@torch.inference_mode()
def generate_feature_stream(
    model, feature_length, text, feature_text_end=None, num_denoise_steps=None
):
    """
    Streaming interface for feature generation.

    Args:
        model: Loaded model
        feature_length: List[int], generation length for each sample
        text: List[str] or List[List[str]], text prompts
        feature_text_end: List[List[int]], time points where each text segment ends
            (used when text is a list of lists)
        num_denoise_steps: Number of denoising steps

    Yields:
        dict: Contains "generated" (the current generated feature segment)
    """

    # Construct the input dict x.
    # stream_generate needs x to contain "feature_length", "text", and
    # "feature_text_end" (when text is a list of lists).
    x = {"feature_length": torch.tensor(feature_length), "text": text}

    if feature_text_end is not None:
        x["feature_text_end"] = feature_text_end

    # Call the model's stream_generate.
    # Note: stream_generate is a generator.
    generator = model.stream_generate(x, num_denoise_steps=num_denoise_steps)

    for step_output in generator:
        # step_output is already a dict with a "generated" key
        yield step_output


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True, help="Path to config")
    parser.add_argument(
        "--text", type=str, default="a person walks forward", help="Text prompt"
    )
    parser.add_argument("--length", type=int, default=120, help="Motion length")
    parser.add_argument(
        "--output", type=str, default="output.mp4", help="Output video path"
    )
    parser.add_argument(
        "--num_denoise_steps", type=int, default=None, help="Number of denoising steps"
    )
    args = parser.parse_args()

    print("Loading model...")
    vae, model = load_model_from_config()
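
`generate_feature_stream` above yields latent-feature segments rather than decoded motion. A sketch of how a caller might consume the stream and decode each segment with the VAE, assuming the per-step output mirrors the batch layout of the non-streaming path in `hf_pipeline.py` (editorial illustration, not part of the file):

```python
# Hypothetical consumer of the streaming interface defined above.
vae, model = load_model_from_config()

stream = generate_feature_stream(
    model,
    feature_length=[120],
    text=[["walk forward", "turn around", "run back"]],
    feature_text_end=[[40, 80, 120]],
    num_denoise_steps=10,
)

for step, out in enumerate(stream):
    latent = out["generated"][0]  # assumed: segment for sample 0 at this step
    if latent is not None:
        motion = vae.decode(latent[None, :])[0]  # 263-dim HumanML3D features
        print(f"step {step}: decoded segment {tuple(motion.shape)}")
```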
hf_pipeline.py
ADDED
@@ -0,0 +1,278 @@
"""
LDF Model for Hugging Face Hub

Usage:
    from transformers import AutoModel

    model = AutoModel.from_pretrained("ShandaAI/FloodDiffusion", trust_remote_code=True)
    motion = model("a person walking forward", length=60)
"""

import torch
from transformers import PretrainedConfig, PreTrainedModel
from typing import Union, List, Optional
import os
import sys


class LDFConfig(PretrainedConfig):
    """Configuration for the LDF Motion Generation Model"""
    model_type = "ldf_motion"

    def __init__(
        self,
        input_dim=4,
        output_dim=263,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.input_dim = input_dim
        self.output_dim = output_dim


class LDFModel(PreTrainedModel):
    """
    LDF Motion Generation Model

    This model generates motion sequences from text descriptions using Latent Diffusion Forcing.

    Example:
        >>> from transformers import AutoModel
        >>> model = AutoModel.from_pretrained("ShandaAI/FloodDiffusion", trust_remote_code=True)
        >>> motion = model("a person walking forward", length=60)
        >>> print(motion.shape)  # (~240, 263)
    """

    config_class = LDFConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # Will be loaded in from_pretrained
        self.ldf_model = None
        self.vae = None
        self.model_dir = None  # Store the model directory for later use

    def _load_models(self):
        """Load the actual LDF and VAE models"""
        if self.ldf_model is not None:
            return  # Already loaded

        # Get the model directory - should be set by from_pretrained
        if hasattr(self, 'name_or_path') and os.path.exists(self.name_or_path):
            model_dir = self.name_or_path
        else:
            raise RuntimeError(
                "Model directory not found. Please use from_pretrained() to load the model."
            )

        # Save model_dir for later use (e.g., in output_joints conversion)
        self.model_dir = model_dir

        # Add model_dir to sys.path for imports
        if model_dir not in sys.path:
            sys.path.insert(0, model_dir)

        # Use a dynamic import to avoid HF's static import checker
        import importlib
        generate_ldf = importlib.import_module('generate_ldf')
        load_model_from_config = generate_ldf.load_model_from_config

        config_path = os.path.join(model_dir, "ldf.yaml")
        old_argv = sys.argv
        sys.argv = ['model', '--config', config_path]

        try:
            self.vae, self.ldf_model = load_model_from_config()

            # Move to the correct device
            device = next(self.parameters()).device if list(self.parameters()) else torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            self.ldf_model = self.ldf_model.to(device)
            self.vae = self.vae.to(device)
        finally:
            sys.argv = old_argv

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """
        Load a pretrained model.

        Args:
            pretrained_model_name_or_path: Model name or path
            trust_remote_code: Must be True to load this custom model
            **kwargs: Additional arguments

        Returns:
            LDFModel instance
        """
        # Check trust_remote_code
        if not kwargs.get('trust_remote_code', False):
            raise ValueError(
                "Loading this model requires trust_remote_code=True. "
                "Usage: AutoModel.from_pretrained(..., trust_remote_code=True)"
            )

        # Download if needed
        if not os.path.exists(pretrained_model_name_or_path):
            from huggingface_hub import snapshot_download
            model_path = snapshot_download(repo_id=pretrained_model_name_or_path)
        else:
            model_path = pretrained_model_name_or_path

        # Load config
        config = LDFConfig.from_pretrained(model_path)

        # Create model
        model = cls(config)
        model.name_or_path = model_path

        # Load the actual models
        model._load_models()

        return model

    def forward(
        self,
        text: Union[str, List[str], List[List[str]]],
        length: Union[int, List[int]] = 60,
        text_end: Optional[Union[List[int], List[List[int]]]] = None,
        num_denoise_steps: Optional[int] = None,
        **kwargs
    ):
        """
        Generate motion from text.

        Args:
            text: Text description(s)
            length: Number of latent tokens (output frames ≈ length × 4)
            text_end: Transition points for multi-text generation
            num_denoise_steps: Number of denoising steps

        Returns:
            Generated motion sequence(s)
        """
        return self.__call__(text, length, text_end, num_denoise_steps)

    @torch.no_grad()
    def __call__(
        self,
        text: Union[str, List[str], List[List[str]]],
        length: Union[int, List[int]] = 60,
        text_end: Optional[Union[List[int], List[List[int]]]] = None,
        num_denoise_steps: Optional[int] = None,
        output_joints: bool = False
    ):
        """
        Generate motion sequences.

        Args:
            text: Text description
                - Single string: "walk" -> single sample
                - String list: ["walk", "run"] -> batch
                - Nested list: [["walk", "turn"], ["run", "jump"]] -> multi-text per sample
            length: Number of latent tokens (frames ≈ length × 4)
            text_end: Token positions for text switching
            num_denoise_steps: Number of denoising steps
            output_joints: If True, output 22×3 joint coordinates; if False (default), output 263-dim HumanML3D features

        Returns:
            numpy.ndarray or list of arrays
                - If output_joints=False: shape (frames, 263)
                - If output_joints=True: shape (frames, 22, 3)
        """
        # Ensure models are loaded
        self._load_models()

        # Normalize inputs
        is_single = not isinstance(length, list)
        if is_single:
            text_batch = [text]
            length_batch = [length]
            text_end_batch = [text_end] if text_end is not None else None
        else:
            text_batch = text
            length_batch = length
            text_end_batch = text_end

        # Validate text_end alignment with text
        if text_end_batch is not None:
            for i, (txt, te) in enumerate(zip(text_batch, text_end_batch)):
                if isinstance(txt, list) and te is not None:
                    if len(txt) != len(te):
                        raise ValueError(
                            f"Batch {i}: text has {len(txt)} segments but text_end has {len(te)} endpoints. "
                            f"They must match! text={txt}, text_end={te}"
                        )

        batch_size = len(text_batch)

        # Construct the input dict for the model
        x = {"feature_length": torch.tensor(length_batch), "text": text_batch}
        if text_end_batch is not None:
            x["feature_text_end"] = text_end_batch

        # Non-streaming generate (following generate_ldf.py lines 125-139)
        output = self.ldf_model.generate(x, num_denoise_steps=num_denoise_steps)
        generated_batch = output["generated"]

        # Decode with the VAE and optionally convert to joints
        decoded_results = []
        joints_results = [] if output_joints else None

        for i, generated in enumerate(generated_batch):
            if generated is not None and torch.is_tensor(generated):
                # Decode with the VAE (following generate_ldf.py line 130)
                decoded_g = self.vae.decode(generated[None, :])[0]

                if output_joints:
                    # Use the model_dir that was saved during _load_models
                    model_dir = self.model_dir

                    # Import convert_motion_to_joints from ldf_utils
                    import importlib.util
                    import numpy as np
                    utils_spec = importlib.util.spec_from_file_location(
                        "motion_process",
                        os.path.join(model_dir, "ldf_utils", "motion_process.py")
                    )
                    motion_process_module = importlib.util.module_from_spec(utils_spec)
                    utils_spec.loader.exec_module(motion_process_module)

                    # Convert to joints using convert_motion_to_joints
                    decoded_np = decoded_g.cpu().numpy()

                    joints = motion_process_module.convert_motion_to_joints(
                        decoded_np, dim=263
                    )
                    joints_results.append(joints)
                else:
                    decoded_results.append(decoded_g.cpu().numpy())
            else:
                if output_joints:
                    joints_results.append(None)
                else:
                    decoded_results.append(None)

        # Return results
        if output_joints:
            return joints_results[0] if is_single else joints_results
        else:
            return decoded_results[0] if is_single else decoded_results

    def generate(self, *args, **kwargs):
        """Alias for __call__ to match the transformers API"""
        return self.__call__(*args, **kwargs)


# For backwards compatibility
LDFPipeline = LDFModel


# Register with AutoModel
try:
    from transformers import AutoModel, AutoConfig
    AutoConfig.register("ldf_motion", LDFConfig)
    AutoModel.register(LDFConfig, LDFModel)
except:
    pass
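
The normalization block in `__call__` above means all three documented input shapes reduce to one batched dict, and misaligned multi-text inputs fail before any generation starts. For illustration (assumes `model` was loaded via `AutoModel.from_pretrained` as in the README; editorial sketch, not part of the file):

```python
# Single sample: scalar length -> wrapped into a one-element batch internally,
# and the single ndarray is unwrapped on return.
single = model("a person walks forward", length=60)   # ndarray, ~(240, 263)

# Batch: list-valued length keeps the list-of-arrays return shape.
batch = model(["walk", "run"], length=[60, 50])       # list of 2 ndarrays

# Multi-text: text_end needs exactly one endpoint per text segment...
ok = model([["walk", "turn"]], length=[60], text_end=[[30, 60]])

# ...otherwise the alignment check raises before generation happens.
try:
    model([["walk", "turn"]], length=[60], text_end=[[60]])
except ValueError as err:
    print(err)  # "Batch 0: text has 2 segments but text_end has 1 endpoints. ..."
```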
ldf.yaml
ADDED
@@ -0,0 +1,32 @@
exp_name: ldf
seed: 1234
debug: false
train: false

save_dir: ./outputs
resume_ckpt: null
test_ckpt: "model.safetensors"
test_vae_ckpt: "vae.safetensors"

test_vae:
  target: ldf_models.vae_wan_1d.VAEWanModel
  ema_decay: 0.99
  params:
    input_dim: 263
    z_dim: 4

test_setting:
  render: false
  simple: true
  recover_dim: 263

val_repeat: 1

model:
  target: ldf_models.diffusion_forcing_wan.DiffForcingWanModel
  ema_decay: 0.99
  params:
    checkpoint_path: "ldf_deps/t5_umt5-xxl-enc-bf16/models_t5_umt5-xxl-enc-bf16.pth"
    tokenizer_path: "ldf_deps/t5_umt5-xxl-enc-bf16/google/umt5-xxl"
    input_dim: 4
    noise_steps: 10
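
The `target:` strings above are dotted import paths that `ldf_utils.initialize.instantiate` resolves to classes (that module is listed in this commit but its body is not shown here). A minimal sketch of how such a resolver typically works, as an assumption about its behavior rather than the actual implementation:

```python
import importlib

def instantiate_from_target(target: str, **params):
    """Hypothetical resolver for a dotted target string such as
    'ldf_models.vae_wan_1d.VAEWanModel' (mirrors, but is not,
    ldf_utils.initialize.instantiate)."""
    module_path, class_name = target.rsplit(".", 1)
    module = importlib.import_module(module_path)
    return getattr(module, class_name)(**params)

# Under that assumption, the test_vae entry above would resolve as:
# vae = instantiate_from_target("ldf_models.vae_wan_1d.VAEWanModel",
#                               input_dim=263, z_dim=4)
```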
ldf_deps/t5_umt5-xxl-enc-bf16/google/umt5-xxl/special_tokens_map.json
ADDED
@@ -0,0 +1,308 @@
{
  "additional_special_tokens": [
    "<extra_id_0>", "<extra_id_1>", "<extra_id_2>", "<extra_id_3>", "<extra_id_4>", "<extra_id_5>", "<extra_id_6>", "<extra_id_7>", "<extra_id_8>", "<extra_id_9>",
    "<extra_id_10>", "<extra_id_11>", "<extra_id_12>", "<extra_id_13>", "<extra_id_14>", "<extra_id_15>", "<extra_id_16>", "<extra_id_17>", "<extra_id_18>", "<extra_id_19>",
    "<extra_id_20>", "<extra_id_21>", "<extra_id_22>", "<extra_id_23>", "<extra_id_24>", "<extra_id_25>", "<extra_id_26>", "<extra_id_27>", "<extra_id_28>", "<extra_id_29>",
    "<extra_id_30>", "<extra_id_31>", "<extra_id_32>", "<extra_id_33>", "<extra_id_34>", "<extra_id_35>", "<extra_id_36>", "<extra_id_37>", "<extra_id_38>", "<extra_id_39>",
    "<extra_id_40>", "<extra_id_41>", "<extra_id_42>", "<extra_id_43>", "<extra_id_44>", "<extra_id_45>", "<extra_id_46>", "<extra_id_47>", "<extra_id_48>", "<extra_id_49>",
    "<extra_id_50>", "<extra_id_51>", "<extra_id_52>", "<extra_id_53>", "<extra_id_54>", "<extra_id_55>", "<extra_id_56>", "<extra_id_57>", "<extra_id_58>", "<extra_id_59>",
    "<extra_id_60>", "<extra_id_61>", "<extra_id_62>", "<extra_id_63>", "<extra_id_64>", "<extra_id_65>", "<extra_id_66>", "<extra_id_67>", "<extra_id_68>", "<extra_id_69>",
    "<extra_id_70>", "<extra_id_71>", "<extra_id_72>", "<extra_id_73>", "<extra_id_74>", "<extra_id_75>", "<extra_id_76>", "<extra_id_77>", "<extra_id_78>", "<extra_id_79>",
    "<extra_id_80>", "<extra_id_81>", "<extra_id_82>", "<extra_id_83>", "<extra_id_84>", "<extra_id_85>", "<extra_id_86>", "<extra_id_87>", "<extra_id_88>", "<extra_id_89>",
    "<extra_id_90>", "<extra_id_91>", "<extra_id_92>", "<extra_id_93>", "<extra_id_94>", "<extra_id_95>", "<extra_id_96>", "<extra_id_97>", "<extra_id_98>", "<extra_id_99>",
    "<extra_id_100>", "<extra_id_101>", "<extra_id_102>", "<extra_id_103>", "<extra_id_104>", "<extra_id_105>", "<extra_id_106>", "<extra_id_107>", "<extra_id_108>", "<extra_id_109>",
    "<extra_id_110>", "<extra_id_111>", "<extra_id_112>", "<extra_id_113>", "<extra_id_114>", "<extra_id_115>", "<extra_id_116>", "<extra_id_117>", "<extra_id_118>", "<extra_id_119>",
    "<extra_id_120>", "<extra_id_121>", "<extra_id_122>", "<extra_id_123>", "<extra_id_124>", "<extra_id_125>", "<extra_id_126>", "<extra_id_127>", "<extra_id_128>", "<extra_id_129>",
    "<extra_id_130>", "<extra_id_131>", "<extra_id_132>", "<extra_id_133>", "<extra_id_134>", "<extra_id_135>", "<extra_id_136>", "<extra_id_137>", "<extra_id_138>", "<extra_id_139>",
    "<extra_id_140>", "<extra_id_141>", "<extra_id_142>", "<extra_id_143>", "<extra_id_144>", "<extra_id_145>", "<extra_id_146>", "<extra_id_147>", "<extra_id_148>", "<extra_id_149>",
    "<extra_id_150>", "<extra_id_151>", "<extra_id_152>", "<extra_id_153>", "<extra_id_154>", "<extra_id_155>", "<extra_id_156>", "<extra_id_157>", "<extra_id_158>", "<extra_id_159>",
    "<extra_id_160>", "<extra_id_161>", "<extra_id_162>", "<extra_id_163>", "<extra_id_164>", "<extra_id_165>", "<extra_id_166>", "<extra_id_167>", "<extra_id_168>", "<extra_id_169>",
    "<extra_id_170>", "<extra_id_171>", "<extra_id_172>", "<extra_id_173>", "<extra_id_174>", "<extra_id_175>", "<extra_id_176>", "<extra_id_177>", "<extra_id_178>", "<extra_id_179>",
    "<extra_id_180>", "<extra_id_181>", "<extra_id_182>", "<extra_id_183>", "<extra_id_184>", "<extra_id_185>", "<extra_id_186>", "<extra_id_187>", "<extra_id_188>", "<extra_id_189>",
    "<extra_id_190>", "<extra_id_191>", "<extra_id_192>", "<extra_id_193>", "<extra_id_194>", "<extra_id_195>", "<extra_id_196>", "<extra_id_197>", "<extra_id_198>", "<extra_id_199>",
    "<extra_id_200>", "<extra_id_201>", "<extra_id_202>", "<extra_id_203>", "<extra_id_204>", "<extra_id_205>", "<extra_id_206>", "<extra_id_207>", "<extra_id_208>", "<extra_id_209>",
    "<extra_id_210>", "<extra_id_211>", "<extra_id_212>", "<extra_id_213>", "<extra_id_214>", "<extra_id_215>", "<extra_id_216>", "<extra_id_217>", "<extra_id_218>", "<extra_id_219>",
    "<extra_id_220>", "<extra_id_221>", "<extra_id_222>", "<extra_id_223>", "<extra_id_224>", "<extra_id_225>", "<extra_id_226>", "<extra_id_227>", "<extra_id_228>", "<extra_id_229>",
    "<extra_id_230>", "<extra_id_231>", "<extra_id_232>", "<extra_id_233>", "<extra_id_234>", "<extra_id_235>", "<extra_id_236>", "<extra_id_237>", "<extra_id_238>", "<extra_id_239>",
    "<extra_id_240>", "<extra_id_241>", "<extra_id_242>", "<extra_id_243>", "<extra_id_244>", "<extra_id_245>", "<extra_id_246>", "<extra_id_247>", "<extra_id_248>", "<extra_id_249>",
    "<extra_id_250>", "<extra_id_251>", "<extra_id_252>", "<extra_id_253>", "<extra_id_254>", "<extra_id_255>", "<extra_id_256>", "<extra_id_257>", "<extra_id_258>", "<extra_id_259>",
    "<extra_id_260>", "<extra_id_261>", "<extra_id_262>", "<extra_id_263>", "<extra_id_264>", "<extra_id_265>", "<extra_id_266>", "<extra_id_267>", "<extra_id_268>", "<extra_id_269>",
    "<extra_id_270>", "<extra_id_271>", "<extra_id_272>", "<extra_id_273>", "<extra_id_274>", "<extra_id_275>", "<extra_id_276>", "<extra_id_277>", "<extra_id_278>", "<extra_id_279>",
    "<extra_id_280>", "<extra_id_281>", "<extra_id_282>", "<extra_id_283>", "<extra_id_284>", "<extra_id_285>", "<extra_id_286>", "<extra_id_287>", "<extra_id_288>", "<extra_id_289>",
    "<extra_id_290>", "<extra_id_291>", "<extra_id_292>", "<extra_id_293>", "<extra_id_294>", "<extra_id_295>", "<extra_id_296>", "<extra_id_297>", "<extra_id_298>", "<extra_id_299>"
  ],
  "bos_token": "<s>",
  "eos_token": "</s>",
  "pad_token": "<pad>",
  "unk_token": "<unk>"
}
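
The 300 sentinel tokens above follow a single T5-style pattern; for reference, the list can be regenerated programmatically:

```python
# The additional_special_tokens array is exactly the sentinels 0..299.
extra_ids = [f"<extra_id_{i}>" for i in range(300)]
assert extra_ids[0] == "<extra_id_0>" and extra_ids[-1] == "<extra_id_299>"
```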
ldf_deps/t5_umt5-xxl-enc-bf16/google/umt5-xxl/spiece.model
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e3909a67b780650b35cf529ac782ad2b6b26e6d1f849d3fbb6a872905f452458
size 4548313
ldf_deps/t5_umt5-xxl-enc-bf16/google/umt5-xxl/tokenizer.json
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6e197b4d3dbd71da14b4eb255f4fa91c9c1f2068b20a2de2472967ca3d22602b
size 16837417
ldf_deps/t5_umt5-xxl-enc-bf16/google/umt5-xxl/tokenizer_config.json
ADDED
@@ -0,0 +1,2748 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
{
  "added_tokens_decoder": {
    "0": {
      "content": "<pad>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "1": {
      "content": "</s>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "2": {
      "content": "<s>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "3": {
      "content": "<unk>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256000": {
      "content": "<extra_id_299>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256001": {
      "content": "<extra_id_298>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256002": {
      "content": "<extra_id_297>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256003": {
      "content": "<extra_id_296>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256004": {
      "content": "<extra_id_295>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256005": {
      "content": "<extra_id_294>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256006": {
      "content": "<extra_id_293>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256007": {
      "content": "<extra_id_292>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256008": {
      "content": "<extra_id_291>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256009": {
      "content": "<extra_id_290>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256010": {
      "content": "<extra_id_289>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256011": {
      "content": "<extra_id_288>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256012": {
      "content": "<extra_id_287>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256013": {
      "content": "<extra_id_286>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256014": {
      "content": "<extra_id_285>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256015": {
      "content": "<extra_id_284>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256016": {
      "content": "<extra_id_283>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256017": {
      "content": "<extra_id_282>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256018": {
      "content": "<extra_id_281>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256019": {
      "content": "<extra_id_280>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256020": {
      "content": "<extra_id_279>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256021": {
      "content": "<extra_id_278>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256022": {
      "content": "<extra_id_277>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256023": {
      "content": "<extra_id_276>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256024": {
      "content": "<extra_id_275>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256025": {
      "content": "<extra_id_274>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256026": {
      "content": "<extra_id_273>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256027": {
      "content": "<extra_id_272>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256028": {
      "content": "<extra_id_271>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256029": {
      "content": "<extra_id_270>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256030": {
      "content": "<extra_id_269>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256031": {
      "content": "<extra_id_268>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256032": {
      "content": "<extra_id_267>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256033": {
      "content": "<extra_id_266>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256034": {
      "content": "<extra_id_265>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256035": {
      "content": "<extra_id_264>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256036": {
      "content": "<extra_id_263>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256037": {
      "content": "<extra_id_262>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256038": {
      "content": "<extra_id_261>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256039": {
      "content": "<extra_id_260>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256040": {
      "content": "<extra_id_259>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256041": {
      "content": "<extra_id_258>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256042": {
      "content": "<extra_id_257>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256043": {
      "content": "<extra_id_256>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256044": {
      "content": "<extra_id_255>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256045": {
      "content": "<extra_id_254>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256046": {
      "content": "<extra_id_253>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256047": {
      "content": "<extra_id_252>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256048": {
      "content": "<extra_id_251>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256049": {
      "content": "<extra_id_250>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256050": {
      "content": "<extra_id_249>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256051": {
      "content": "<extra_id_248>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256052": {
      "content": "<extra_id_247>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256053": {
      "content": "<extra_id_246>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256054": {
      "content": "<extra_id_245>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256055": {
      "content": "<extra_id_244>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256056": {
      "content": "<extra_id_243>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256057": {
      "content": "<extra_id_242>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256058": {
      "content": "<extra_id_241>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256059": {
      "content": "<extra_id_240>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256060": {
      "content": "<extra_id_239>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256061": {
      "content": "<extra_id_238>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256062": {
      "content": "<extra_id_237>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256063": {
      "content": "<extra_id_236>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256064": {
      "content": "<extra_id_235>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256065": {
      "content": "<extra_id_234>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256066": {
      "content": "<extra_id_233>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256067": {
      "content": "<extra_id_232>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256068": {
      "content": "<extra_id_231>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256069": {
      "content": "<extra_id_230>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256070": {
      "content": "<extra_id_229>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256071": {
      "content": "<extra_id_228>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256072": {
      "content": "<extra_id_227>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256073": {
      "content": "<extra_id_226>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256074": {
      "content": "<extra_id_225>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256075": {
      "content": "<extra_id_224>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256076": {
      "content": "<extra_id_223>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256077": {
      "content": "<extra_id_222>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256078": {
      "content": "<extra_id_221>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256079": {
      "content": "<extra_id_220>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256080": {
      "content": "<extra_id_219>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256081": {
      "content": "<extra_id_218>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256082": {
      "content": "<extra_id_217>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256083": {
      "content": "<extra_id_216>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256084": {
      "content": "<extra_id_215>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256085": {
      "content": "<extra_id_214>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256086": {
      "content": "<extra_id_213>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256087": {
      "content": "<extra_id_212>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256088": {
      "content": "<extra_id_211>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256089": {
      "content": "<extra_id_210>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256090": {
      "content": "<extra_id_209>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256091": {
      "content": "<extra_id_208>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256092": {
      "content": "<extra_id_207>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256093": {
      "content": "<extra_id_206>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256094": {
      "content": "<extra_id_205>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256095": {
      "content": "<extra_id_204>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256096": {
      "content": "<extra_id_203>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256097": {
      "content": "<extra_id_202>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256098": {
      "content": "<extra_id_201>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256099": {
      "content": "<extra_id_200>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256100": {
      "content": "<extra_id_199>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256101": {
      "content": "<extra_id_198>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256102": {
      "content": "<extra_id_197>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256103": {
      "content": "<extra_id_196>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256104": {
      "content": "<extra_id_195>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256105": {
      "content": "<extra_id_194>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256106": {
      "content": "<extra_id_193>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256107": {
      "content": "<extra_id_192>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256108": {
      "content": "<extra_id_191>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256109": {
      "content": "<extra_id_190>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256110": {
      "content": "<extra_id_189>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256111": {
      "content": "<extra_id_188>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256112": {
      "content": "<extra_id_187>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256113": {
      "content": "<extra_id_186>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256114": {
      "content": "<extra_id_185>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256115": {
      "content": "<extra_id_184>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256116": {
      "content": "<extra_id_183>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256117": {
      "content": "<extra_id_182>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256118": {
      "content": "<extra_id_181>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256119": {
      "content": "<extra_id_180>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256120": {
      "content": "<extra_id_179>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256121": {
      "content": "<extra_id_178>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256122": {
      "content": "<extra_id_177>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256123": {
      "content": "<extra_id_176>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256124": {
      "content": "<extra_id_175>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256125": {
      "content": "<extra_id_174>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256126": {
      "content": "<extra_id_173>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256127": {
      "content": "<extra_id_172>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256128": {
      "content": "<extra_id_171>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256129": {
      "content": "<extra_id_170>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256130": {
      "content": "<extra_id_169>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256131": {
      "content": "<extra_id_168>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256132": {
      "content": "<extra_id_167>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256133": {
      "content": "<extra_id_166>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256134": {
      "content": "<extra_id_165>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256135": {
      "content": "<extra_id_164>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256136": {
      "content": "<extra_id_163>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256137": {
      "content": "<extra_id_162>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256138": {
      "content": "<extra_id_161>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256139": {
      "content": "<extra_id_160>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256140": {
      "content": "<extra_id_159>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256141": {
      "content": "<extra_id_158>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256142": {
      "content": "<extra_id_157>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256143": {
      "content": "<extra_id_156>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256144": {
      "content": "<extra_id_155>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256145": {
      "content": "<extra_id_154>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256146": {
      "content": "<extra_id_153>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256147": {
      "content": "<extra_id_152>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256148": {
      "content": "<extra_id_151>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256149": {
      "content": "<extra_id_150>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256150": {
      "content": "<extra_id_149>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256151": {
      "content": "<extra_id_148>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256152": {
      "content": "<extra_id_147>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256153": {
      "content": "<extra_id_146>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256154": {
      "content": "<extra_id_145>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256155": {
      "content": "<extra_id_144>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256156": {
      "content": "<extra_id_143>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256157": {
      "content": "<extra_id_142>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256158": {
      "content": "<extra_id_141>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256159": {
      "content": "<extra_id_140>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256160": {
      "content": "<extra_id_139>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256161": {
      "content": "<extra_id_138>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256162": {
      "content": "<extra_id_137>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256163": {
      "content": "<extra_id_136>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256164": {
      "content": "<extra_id_135>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256165": {
      "content": "<extra_id_134>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256166": {
      "content": "<extra_id_133>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256167": {
      "content": "<extra_id_132>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256168": {
      "content": "<extra_id_131>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256169": {
      "content": "<extra_id_130>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256170": {
      "content": "<extra_id_129>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256171": {
      "content": "<extra_id_128>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256172": {
      "content": "<extra_id_127>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256173": {
      "content": "<extra_id_126>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256174": {
      "content": "<extra_id_125>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256175": {
      "content": "<extra_id_124>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256176": {
      "content": "<extra_id_123>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256177": {
      "content": "<extra_id_122>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256178": {
      "content": "<extra_id_121>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256179": {
      "content": "<extra_id_120>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256180": {
      "content": "<extra_id_119>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256181": {
      "content": "<extra_id_118>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256182": {
      "content": "<extra_id_117>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256183": {
      "content": "<extra_id_116>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256184": {
      "content": "<extra_id_115>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256185": {
      "content": "<extra_id_114>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256186": {
      "content": "<extra_id_113>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256187": {
      "content": "<extra_id_112>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256188": {
      "content": "<extra_id_111>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256189": {
      "content": "<extra_id_110>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256190": {
      "content": "<extra_id_109>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256191": {
      "content": "<extra_id_108>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256192": {
      "content": "<extra_id_107>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256193": {
      "content": "<extra_id_106>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256194": {
      "content": "<extra_id_105>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256195": {
      "content": "<extra_id_104>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256196": {
      "content": "<extra_id_103>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256197": {
      "content": "<extra_id_102>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256198": {
      "content": "<extra_id_101>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256199": {
      "content": "<extra_id_100>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256200": {
      "content": "<extra_id_99>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256201": {
      "content": "<extra_id_98>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256202": {
      "content": "<extra_id_97>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256203": {
      "content": "<extra_id_96>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256204": {
      "content": "<extra_id_95>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256205": {
      "content": "<extra_id_94>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256206": {
      "content": "<extra_id_93>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256207": {
      "content": "<extra_id_92>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256208": {
      "content": "<extra_id_91>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256209": {
      "content": "<extra_id_90>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256210": {
      "content": "<extra_id_89>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256211": {
      "content": "<extra_id_88>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256212": {
      "content": "<extra_id_87>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256213": {
      "content": "<extra_id_86>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256214": {
      "content": "<extra_id_85>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256215": {
      "content": "<extra_id_84>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256216": {
      "content": "<extra_id_83>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256217": {
      "content": "<extra_id_82>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256218": {
      "content": "<extra_id_81>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256219": {
      "content": "<extra_id_80>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256220": {
      "content": "<extra_id_79>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256221": {
      "content": "<extra_id_78>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256222": {
      "content": "<extra_id_77>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256223": {
      "content": "<extra_id_76>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256224": {
      "content": "<extra_id_75>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256225": {
      "content": "<extra_id_74>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256226": {
      "content": "<extra_id_73>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256227": {
      "content": "<extra_id_72>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256228": {
      "content": "<extra_id_71>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256229": {
      "content": "<extra_id_70>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256230": {
      "content": "<extra_id_69>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256231": {
      "content": "<extra_id_68>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256232": {
      "content": "<extra_id_67>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256233": {
      "content": "<extra_id_66>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "256234": {
      "content": "<extra_id_65>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
|
| 1912 |
+
"single_word": false,
|
| 1913 |
+
"special": true
|
| 1914 |
+
},
|
| 1915 |
+
"256235": {
|
| 1916 |
+
"content": "<extra_id_64>",
|
| 1917 |
+
"lstrip": false,
|
| 1918 |
+
"normalized": false,
|
| 1919 |
+
"rstrip": false,
|
| 1920 |
+
"single_word": false,
|
| 1921 |
+
"special": true
|
| 1922 |
+
},
|
| 1923 |
+
"256236": {
|
| 1924 |
+
"content": "<extra_id_63>",
|
| 1925 |
+
"lstrip": false,
|
| 1926 |
+
"normalized": false,
|
| 1927 |
+
"rstrip": false,
|
| 1928 |
+
"single_word": false,
|
| 1929 |
+
"special": true
|
| 1930 |
+
},
|
| 1931 |
+
"256237": {
|
| 1932 |
+
"content": "<extra_id_62>",
|
| 1933 |
+
"lstrip": false,
|
| 1934 |
+
"normalized": false,
|
| 1935 |
+
"rstrip": false,
|
| 1936 |
+
"single_word": false,
|
| 1937 |
+
"special": true
|
| 1938 |
+
},
|
| 1939 |
+
"256238": {
|
| 1940 |
+
"content": "<extra_id_61>",
|
| 1941 |
+
"lstrip": false,
|
| 1942 |
+
"normalized": false,
|
| 1943 |
+
"rstrip": false,
|
| 1944 |
+
"single_word": false,
|
| 1945 |
+
"special": true
|
| 1946 |
+
},
|
| 1947 |
+
"256239": {
|
| 1948 |
+
"content": "<extra_id_60>",
|
| 1949 |
+
"lstrip": false,
|
| 1950 |
+
"normalized": false,
|
| 1951 |
+
"rstrip": false,
|
| 1952 |
+
"single_word": false,
|
| 1953 |
+
"special": true
|
| 1954 |
+
},
|
| 1955 |
+
"256240": {
|
| 1956 |
+
"content": "<extra_id_59>",
|
| 1957 |
+
"lstrip": false,
|
| 1958 |
+
"normalized": false,
|
| 1959 |
+
"rstrip": false,
|
| 1960 |
+
"single_word": false,
|
| 1961 |
+
"special": true
|
| 1962 |
+
},
|
| 1963 |
+
"256241": {
|
| 1964 |
+
"content": "<extra_id_58>",
|
| 1965 |
+
"lstrip": false,
|
| 1966 |
+
"normalized": false,
|
| 1967 |
+
"rstrip": false,
|
| 1968 |
+
"single_word": false,
|
| 1969 |
+
"special": true
|
| 1970 |
+
},
|
| 1971 |
+
"256242": {
|
| 1972 |
+
"content": "<extra_id_57>",
|
| 1973 |
+
"lstrip": false,
|
| 1974 |
+
"normalized": false,
|
| 1975 |
+
"rstrip": false,
|
| 1976 |
+
"single_word": false,
|
| 1977 |
+
"special": true
|
| 1978 |
+
},
|
| 1979 |
+
"256243": {
|
| 1980 |
+
"content": "<extra_id_56>",
|
| 1981 |
+
"lstrip": false,
|
| 1982 |
+
"normalized": false,
|
| 1983 |
+
"rstrip": false,
|
| 1984 |
+
"single_word": false,
|
| 1985 |
+
"special": true
|
| 1986 |
+
},
|
| 1987 |
+
"256244": {
|
| 1988 |
+
"content": "<extra_id_55>",
|
| 1989 |
+
"lstrip": false,
|
| 1990 |
+
"normalized": false,
|
| 1991 |
+
"rstrip": false,
|
| 1992 |
+
"single_word": false,
|
| 1993 |
+
"special": true
|
| 1994 |
+
},
|
| 1995 |
+
"256245": {
|
| 1996 |
+
"content": "<extra_id_54>",
|
| 1997 |
+
"lstrip": false,
|
| 1998 |
+
"normalized": false,
|
| 1999 |
+
"rstrip": false,
|
| 2000 |
+
"single_word": false,
|
| 2001 |
+
"special": true
|
| 2002 |
+
},
|
| 2003 |
+
"256246": {
|
| 2004 |
+
"content": "<extra_id_53>",
|
| 2005 |
+
"lstrip": false,
|
| 2006 |
+
"normalized": false,
|
| 2007 |
+
"rstrip": false,
|
| 2008 |
+
"single_word": false,
|
| 2009 |
+
"special": true
|
| 2010 |
+
},
|
| 2011 |
+
"256247": {
|
| 2012 |
+
"content": "<extra_id_52>",
|
| 2013 |
+
"lstrip": false,
|
| 2014 |
+
"normalized": false,
|
| 2015 |
+
"rstrip": false,
|
| 2016 |
+
"single_word": false,
|
| 2017 |
+
"special": true
|
| 2018 |
+
},
|
| 2019 |
+
"256248": {
|
| 2020 |
+
"content": "<extra_id_51>",
|
| 2021 |
+
"lstrip": false,
|
| 2022 |
+
"normalized": false,
|
| 2023 |
+
"rstrip": false,
|
| 2024 |
+
"single_word": false,
|
| 2025 |
+
"special": true
|
| 2026 |
+
},
|
| 2027 |
+
"256249": {
|
| 2028 |
+
"content": "<extra_id_50>",
|
| 2029 |
+
"lstrip": false,
|
| 2030 |
+
"normalized": false,
|
| 2031 |
+
"rstrip": false,
|
| 2032 |
+
"single_word": false,
|
| 2033 |
+
"special": true
|
| 2034 |
+
},
|
| 2035 |
+
"256250": {
|
| 2036 |
+
"content": "<extra_id_49>",
|
| 2037 |
+
"lstrip": false,
|
| 2038 |
+
"normalized": false,
|
| 2039 |
+
"rstrip": false,
|
| 2040 |
+
"single_word": false,
|
| 2041 |
+
"special": true
|
| 2042 |
+
},
|
| 2043 |
+
"256251": {
|
| 2044 |
+
"content": "<extra_id_48>",
|
| 2045 |
+
"lstrip": false,
|
| 2046 |
+
"normalized": false,
|
| 2047 |
+
"rstrip": false,
|
| 2048 |
+
"single_word": false,
|
| 2049 |
+
"special": true
|
| 2050 |
+
},
|
| 2051 |
+
"256252": {
|
| 2052 |
+
"content": "<extra_id_47>",
|
| 2053 |
+
"lstrip": false,
|
| 2054 |
+
"normalized": false,
|
| 2055 |
+
"rstrip": false,
|
| 2056 |
+
"single_word": false,
|
| 2057 |
+
"special": true
|
| 2058 |
+
},
|
| 2059 |
+
"256253": {
|
| 2060 |
+
"content": "<extra_id_46>",
|
| 2061 |
+
"lstrip": false,
|
| 2062 |
+
"normalized": false,
|
| 2063 |
+
"rstrip": false,
|
| 2064 |
+
"single_word": false,
|
| 2065 |
+
"special": true
|
| 2066 |
+
},
|
| 2067 |
+
"256254": {
|
| 2068 |
+
"content": "<extra_id_45>",
|
| 2069 |
+
"lstrip": false,
|
| 2070 |
+
"normalized": false,
|
| 2071 |
+
"rstrip": false,
|
| 2072 |
+
"single_word": false,
|
| 2073 |
+
"special": true
|
| 2074 |
+
},
|
| 2075 |
+
"256255": {
|
| 2076 |
+
"content": "<extra_id_44>",
|
| 2077 |
+
"lstrip": false,
|
| 2078 |
+
"normalized": false,
|
| 2079 |
+
"rstrip": false,
|
| 2080 |
+
"single_word": false,
|
| 2081 |
+
"special": true
|
| 2082 |
+
},
|
| 2083 |
+
"256256": {
|
| 2084 |
+
"content": "<extra_id_43>",
|
| 2085 |
+
"lstrip": false,
|
| 2086 |
+
"normalized": false,
|
| 2087 |
+
"rstrip": false,
|
| 2088 |
+
"single_word": false,
|
| 2089 |
+
"special": true
|
| 2090 |
+
},
|
| 2091 |
+
"256257": {
|
| 2092 |
+
"content": "<extra_id_42>",
|
| 2093 |
+
"lstrip": false,
|
| 2094 |
+
"normalized": false,
|
| 2095 |
+
"rstrip": false,
|
| 2096 |
+
"single_word": false,
|
| 2097 |
+
"special": true
|
| 2098 |
+
},
|
| 2099 |
+
"256258": {
|
| 2100 |
+
"content": "<extra_id_41>",
|
| 2101 |
+
"lstrip": false,
|
| 2102 |
+
"normalized": false,
|
| 2103 |
+
"rstrip": false,
|
| 2104 |
+
"single_word": false,
|
| 2105 |
+
"special": true
|
| 2106 |
+
},
|
| 2107 |
+
"256259": {
|
| 2108 |
+
"content": "<extra_id_40>",
|
| 2109 |
+
"lstrip": false,
|
| 2110 |
+
"normalized": false,
|
| 2111 |
+
"rstrip": false,
|
| 2112 |
+
"single_word": false,
|
| 2113 |
+
"special": true
|
| 2114 |
+
},
|
| 2115 |
+
"256260": {
|
| 2116 |
+
"content": "<extra_id_39>",
|
| 2117 |
+
"lstrip": false,
|
| 2118 |
+
"normalized": false,
|
| 2119 |
+
"rstrip": false,
|
| 2120 |
+
"single_word": false,
|
| 2121 |
+
"special": true
|
| 2122 |
+
},
|
| 2123 |
+
"256261": {
|
| 2124 |
+
"content": "<extra_id_38>",
|
| 2125 |
+
"lstrip": false,
|
| 2126 |
+
"normalized": false,
|
| 2127 |
+
"rstrip": false,
|
| 2128 |
+
"single_word": false,
|
| 2129 |
+
"special": true
|
| 2130 |
+
},
|
| 2131 |
+
"256262": {
|
| 2132 |
+
"content": "<extra_id_37>",
|
| 2133 |
+
"lstrip": false,
|
| 2134 |
+
"normalized": false,
|
| 2135 |
+
"rstrip": false,
|
| 2136 |
+
"single_word": false,
|
| 2137 |
+
"special": true
|
| 2138 |
+
},
|
| 2139 |
+
"256263": {
|
| 2140 |
+
"content": "<extra_id_36>",
|
| 2141 |
+
"lstrip": false,
|
| 2142 |
+
"normalized": false,
|
| 2143 |
+
"rstrip": false,
|
| 2144 |
+
"single_word": false,
|
| 2145 |
+
"special": true
|
| 2146 |
+
},
|
| 2147 |
+
"256264": {
|
| 2148 |
+
"content": "<extra_id_35>",
|
| 2149 |
+
"lstrip": false,
|
| 2150 |
+
"normalized": false,
|
| 2151 |
+
"rstrip": false,
|
| 2152 |
+
"single_word": false,
|
| 2153 |
+
"special": true
|
| 2154 |
+
},
|
| 2155 |
+
"256265": {
|
| 2156 |
+
"content": "<extra_id_34>",
|
| 2157 |
+
"lstrip": false,
|
| 2158 |
+
"normalized": false,
|
| 2159 |
+
"rstrip": false,
|
| 2160 |
+
"single_word": false,
|
| 2161 |
+
"special": true
|
| 2162 |
+
},
|
| 2163 |
+
"256266": {
|
| 2164 |
+
"content": "<extra_id_33>",
|
| 2165 |
+
"lstrip": false,
|
| 2166 |
+
"normalized": false,
|
| 2167 |
+
"rstrip": false,
|
| 2168 |
+
"single_word": false,
|
| 2169 |
+
"special": true
|
| 2170 |
+
},
|
| 2171 |
+
"256267": {
|
| 2172 |
+
"content": "<extra_id_32>",
|
| 2173 |
+
"lstrip": false,
|
| 2174 |
+
"normalized": false,
|
| 2175 |
+
"rstrip": false,
|
| 2176 |
+
"single_word": false,
|
| 2177 |
+
"special": true
|
| 2178 |
+
},
|
| 2179 |
+
"256268": {
|
| 2180 |
+
"content": "<extra_id_31>",
|
| 2181 |
+
"lstrip": false,
|
| 2182 |
+
"normalized": false,
|
| 2183 |
+
"rstrip": false,
|
| 2184 |
+
"single_word": false,
|
| 2185 |
+
"special": true
|
| 2186 |
+
},
|
| 2187 |
+
"256269": {
|
| 2188 |
+
"content": "<extra_id_30>",
|
| 2189 |
+
"lstrip": false,
|
| 2190 |
+
"normalized": false,
|
| 2191 |
+
"rstrip": false,
|
| 2192 |
+
"single_word": false,
|
| 2193 |
+
"special": true
|
| 2194 |
+
},
|
| 2195 |
+
"256270": {
|
| 2196 |
+
"content": "<extra_id_29>",
|
| 2197 |
+
"lstrip": false,
|
| 2198 |
+
"normalized": false,
|
| 2199 |
+
"rstrip": false,
|
| 2200 |
+
"single_word": false,
|
| 2201 |
+
"special": true
|
| 2202 |
+
},
|
| 2203 |
+
"256271": {
|
| 2204 |
+
"content": "<extra_id_28>",
|
| 2205 |
+
"lstrip": false,
|
| 2206 |
+
"normalized": false,
|
| 2207 |
+
"rstrip": false,
|
| 2208 |
+
"single_word": false,
|
| 2209 |
+
"special": true
|
| 2210 |
+
},
|
| 2211 |
+
"256272": {
|
| 2212 |
+
"content": "<extra_id_27>",
|
| 2213 |
+
"lstrip": false,
|
| 2214 |
+
"normalized": false,
|
| 2215 |
+
"rstrip": false,
|
| 2216 |
+
"single_word": false,
|
| 2217 |
+
"special": true
|
| 2218 |
+
},
|
| 2219 |
+
"256273": {
|
| 2220 |
+
"content": "<extra_id_26>",
|
| 2221 |
+
"lstrip": false,
|
| 2222 |
+
"normalized": false,
|
| 2223 |
+
"rstrip": false,
|
| 2224 |
+
"single_word": false,
|
| 2225 |
+
"special": true
|
| 2226 |
+
},
|
| 2227 |
+
"256274": {
|
| 2228 |
+
"content": "<extra_id_25>",
|
| 2229 |
+
"lstrip": false,
|
| 2230 |
+
"normalized": false,
|
| 2231 |
+
"rstrip": false,
|
| 2232 |
+
"single_word": false,
|
| 2233 |
+
"special": true
|
| 2234 |
+
},
|
| 2235 |
+
"256275": {
|
| 2236 |
+
"content": "<extra_id_24>",
|
| 2237 |
+
"lstrip": false,
|
| 2238 |
+
"normalized": false,
|
| 2239 |
+
"rstrip": false,
|
| 2240 |
+
"single_word": false,
|
| 2241 |
+
"special": true
|
| 2242 |
+
},
|
| 2243 |
+
"256276": {
|
| 2244 |
+
"content": "<extra_id_23>",
|
| 2245 |
+
"lstrip": false,
|
| 2246 |
+
"normalized": false,
|
| 2247 |
+
"rstrip": false,
|
| 2248 |
+
"single_word": false,
|
| 2249 |
+
"special": true
|
| 2250 |
+
},
|
| 2251 |
+
"256277": {
|
| 2252 |
+
"content": "<extra_id_22>",
|
| 2253 |
+
"lstrip": false,
|
| 2254 |
+
"normalized": false,
|
| 2255 |
+
"rstrip": false,
|
| 2256 |
+
"single_word": false,
|
| 2257 |
+
"special": true
|
| 2258 |
+
},
|
| 2259 |
+
"256278": {
|
| 2260 |
+
"content": "<extra_id_21>",
|
| 2261 |
+
"lstrip": false,
|
| 2262 |
+
"normalized": false,
|
| 2263 |
+
"rstrip": false,
|
| 2264 |
+
"single_word": false,
|
| 2265 |
+
"special": true
|
| 2266 |
+
},
|
| 2267 |
+
"256279": {
|
| 2268 |
+
"content": "<extra_id_20>",
|
| 2269 |
+
"lstrip": false,
|
| 2270 |
+
"normalized": false,
|
| 2271 |
+
"rstrip": false,
|
| 2272 |
+
"single_word": false,
|
| 2273 |
+
"special": true
|
| 2274 |
+
},
|
| 2275 |
+
"256280": {
|
| 2276 |
+
"content": "<extra_id_19>",
|
| 2277 |
+
"lstrip": false,
|
| 2278 |
+
"normalized": false,
|
| 2279 |
+
"rstrip": false,
|
| 2280 |
+
"single_word": false,
|
| 2281 |
+
"special": true
|
| 2282 |
+
},
|
| 2283 |
+
"256281": {
|
| 2284 |
+
"content": "<extra_id_18>",
|
| 2285 |
+
"lstrip": false,
|
| 2286 |
+
"normalized": false,
|
| 2287 |
+
"rstrip": false,
|
| 2288 |
+
"single_word": false,
|
| 2289 |
+
"special": true
|
| 2290 |
+
},
|
| 2291 |
+
"256282": {
|
| 2292 |
+
"content": "<extra_id_17>",
|
| 2293 |
+
"lstrip": false,
|
| 2294 |
+
"normalized": false,
|
| 2295 |
+
"rstrip": false,
|
| 2296 |
+
"single_word": false,
|
| 2297 |
+
"special": true
|
| 2298 |
+
},
|
| 2299 |
+
"256283": {
|
| 2300 |
+
"content": "<extra_id_16>",
|
| 2301 |
+
"lstrip": false,
|
| 2302 |
+
"normalized": false,
|
| 2303 |
+
"rstrip": false,
|
| 2304 |
+
"single_word": false,
|
| 2305 |
+
"special": true
|
| 2306 |
+
},
|
| 2307 |
+
"256284": {
|
| 2308 |
+
"content": "<extra_id_15>",
|
| 2309 |
+
"lstrip": false,
|
| 2310 |
+
"normalized": false,
|
| 2311 |
+
"rstrip": false,
|
| 2312 |
+
"single_word": false,
|
| 2313 |
+
"special": true
|
| 2314 |
+
},
|
| 2315 |
+
"256285": {
|
| 2316 |
+
"content": "<extra_id_14>",
|
| 2317 |
+
"lstrip": false,
|
| 2318 |
+
"normalized": false,
|
| 2319 |
+
"rstrip": false,
|
| 2320 |
+
"single_word": false,
|
| 2321 |
+
"special": true
|
| 2322 |
+
},
|
| 2323 |
+
"256286": {
|
| 2324 |
+
"content": "<extra_id_13>",
|
| 2325 |
+
"lstrip": false,
|
| 2326 |
+
"normalized": false,
|
| 2327 |
+
"rstrip": false,
|
| 2328 |
+
"single_word": false,
|
| 2329 |
+
"special": true
|
| 2330 |
+
},
|
| 2331 |
+
"256287": {
|
| 2332 |
+
"content": "<extra_id_12>",
|
| 2333 |
+
"lstrip": false,
|
| 2334 |
+
"normalized": false,
|
| 2335 |
+
"rstrip": false,
|
| 2336 |
+
"single_word": false,
|
| 2337 |
+
"special": true
|
| 2338 |
+
},
|
| 2339 |
+
"256288": {
|
| 2340 |
+
"content": "<extra_id_11>",
|
| 2341 |
+
"lstrip": false,
|
| 2342 |
+
"normalized": false,
|
| 2343 |
+
"rstrip": false,
|
| 2344 |
+
"single_word": false,
|
| 2345 |
+
"special": true
|
| 2346 |
+
},
|
| 2347 |
+
"256289": {
|
| 2348 |
+
"content": "<extra_id_10>",
|
| 2349 |
+
"lstrip": false,
|
| 2350 |
+
"normalized": false,
|
| 2351 |
+
"rstrip": false,
|
| 2352 |
+
"single_word": false,
|
| 2353 |
+
"special": true
|
| 2354 |
+
},
|
| 2355 |
+
"256290": {
|
| 2356 |
+
"content": "<extra_id_9>",
|
| 2357 |
+
"lstrip": false,
|
| 2358 |
+
"normalized": false,
|
| 2359 |
+
"rstrip": false,
|
| 2360 |
+
"single_word": false,
|
| 2361 |
+
"special": true
|
| 2362 |
+
},
|
| 2363 |
+
"256291": {
|
| 2364 |
+
"content": "<extra_id_8>",
|
| 2365 |
+
"lstrip": false,
|
| 2366 |
+
"normalized": false,
|
| 2367 |
+
"rstrip": false,
|
| 2368 |
+
"single_word": false,
|
| 2369 |
+
"special": true
|
| 2370 |
+
},
|
| 2371 |
+
"256292": {
|
| 2372 |
+
"content": "<extra_id_7>",
|
| 2373 |
+
"lstrip": false,
|
| 2374 |
+
"normalized": false,
|
| 2375 |
+
"rstrip": false,
|
| 2376 |
+
"single_word": false,
|
| 2377 |
+
"special": true
|
| 2378 |
+
},
|
| 2379 |
+
"256293": {
|
| 2380 |
+
"content": "<extra_id_6>",
|
| 2381 |
+
"lstrip": false,
|
| 2382 |
+
"normalized": false,
|
| 2383 |
+
"rstrip": false,
|
| 2384 |
+
"single_word": false,
|
| 2385 |
+
"special": true
|
| 2386 |
+
},
|
| 2387 |
+
"256294": {
|
| 2388 |
+
"content": "<extra_id_5>",
|
| 2389 |
+
"lstrip": false,
|
| 2390 |
+
"normalized": false,
|
| 2391 |
+
"rstrip": false,
|
| 2392 |
+
"single_word": false,
|
| 2393 |
+
"special": true
|
| 2394 |
+
},
|
| 2395 |
+
"256295": {
|
| 2396 |
+
"content": "<extra_id_4>",
|
| 2397 |
+
"lstrip": false,
|
| 2398 |
+
"normalized": false,
|
| 2399 |
+
"rstrip": false,
|
| 2400 |
+
"single_word": false,
|
| 2401 |
+
"special": true
|
| 2402 |
+
},
|
| 2403 |
+
"256296": {
|
| 2404 |
+
"content": "<extra_id_3>",
|
| 2405 |
+
"lstrip": false,
|
| 2406 |
+
"normalized": false,
|
| 2407 |
+
"rstrip": false,
|
| 2408 |
+
"single_word": false,
|
| 2409 |
+
"special": true
|
| 2410 |
+
},
|
| 2411 |
+
"256297": {
|
| 2412 |
+
"content": "<extra_id_2>",
|
| 2413 |
+
"lstrip": false,
|
| 2414 |
+
"normalized": false,
|
| 2415 |
+
"rstrip": false,
|
| 2416 |
+
"single_word": false,
|
| 2417 |
+
"special": true
|
| 2418 |
+
},
|
| 2419 |
+
"256298": {
|
| 2420 |
+
"content": "<extra_id_1>",
|
| 2421 |
+
"lstrip": false,
|
| 2422 |
+
"normalized": false,
|
| 2423 |
+
"rstrip": false,
|
| 2424 |
+
"single_word": false,
|
| 2425 |
+
"special": true
|
| 2426 |
+
},
|
| 2427 |
+
"256299": {
|
| 2428 |
+
"content": "<extra_id_0>",
|
| 2429 |
+
"lstrip": false,
|
| 2430 |
+
"normalized": false,
|
| 2431 |
+
"rstrip": false,
|
| 2432 |
+
"single_word": false,
|
| 2433 |
+
"special": true
|
| 2434 |
+
}
|
| 2435 |
+
},
|
| 2436 |
+
"additional_special_tokens": [
|
| 2437 |
+
"<extra_id_0>",
|
| 2438 |
+
"<extra_id_1>",
|
| 2439 |
+
"<extra_id_2>",
|
| 2440 |
+
"<extra_id_3>",
|
| 2441 |
+
"<extra_id_4>",
|
| 2442 |
+
"<extra_id_5>",
|
| 2443 |
+
"<extra_id_6>",
|
| 2444 |
+
"<extra_id_7>",
|
| 2445 |
+
"<extra_id_8>",
|
| 2446 |
+
"<extra_id_9>",
|
| 2447 |
+
"<extra_id_10>",
|
| 2448 |
+
"<extra_id_11>",
|
| 2449 |
+
"<extra_id_12>",
|
| 2450 |
+
"<extra_id_13>",
|
| 2451 |
+
"<extra_id_14>",
|
| 2452 |
+
"<extra_id_15>",
|
| 2453 |
+
"<extra_id_16>",
|
| 2454 |
+
"<extra_id_17>",
|
| 2455 |
+
"<extra_id_18>",
|
| 2456 |
+
"<extra_id_19>",
|
| 2457 |
+
"<extra_id_20>",
|
| 2458 |
+
"<extra_id_21>",
|
| 2459 |
+
"<extra_id_22>",
|
| 2460 |
+
"<extra_id_23>",
|
| 2461 |
+
"<extra_id_24>",
|
| 2462 |
+
"<extra_id_25>",
|
| 2463 |
+
"<extra_id_26>",
|
| 2464 |
+
"<extra_id_27>",
|
| 2465 |
+
"<extra_id_28>",
|
| 2466 |
+
"<extra_id_29>",
|
| 2467 |
+
"<extra_id_30>",
|
| 2468 |
+
"<extra_id_31>",
|
| 2469 |
+
"<extra_id_32>",
|
| 2470 |
+
"<extra_id_33>",
|
| 2471 |
+
"<extra_id_34>",
|
| 2472 |
+
"<extra_id_35>",
|
| 2473 |
+
"<extra_id_36>",
|
| 2474 |
+
"<extra_id_37>",
|
| 2475 |
+
"<extra_id_38>",
|
| 2476 |
+
"<extra_id_39>",
|
| 2477 |
+
"<extra_id_40>",
|
| 2478 |
+
"<extra_id_41>",
|
| 2479 |
+
"<extra_id_42>",
|
| 2480 |
+
"<extra_id_43>",
|
| 2481 |
+
"<extra_id_44>",
|
| 2482 |
+
"<extra_id_45>",
|
| 2483 |
+
"<extra_id_46>",
|
| 2484 |
+
"<extra_id_47>",
|
| 2485 |
+
"<extra_id_48>",
|
| 2486 |
+
"<extra_id_49>",
|
| 2487 |
+
"<extra_id_50>",
|
| 2488 |
+
"<extra_id_51>",
|
| 2489 |
+
"<extra_id_52>",
|
| 2490 |
+
"<extra_id_53>",
|
| 2491 |
+
"<extra_id_54>",
|
| 2492 |
+
"<extra_id_55>",
|
| 2493 |
+
"<extra_id_56>",
|
| 2494 |
+
"<extra_id_57>",
|
| 2495 |
+
"<extra_id_58>",
|
| 2496 |
+
"<extra_id_59>",
|
| 2497 |
+
"<extra_id_60>",
|
| 2498 |
+
"<extra_id_61>",
|
| 2499 |
+
"<extra_id_62>",
|
| 2500 |
+
"<extra_id_63>",
|
| 2501 |
+
"<extra_id_64>",
|
| 2502 |
+
"<extra_id_65>",
|
| 2503 |
+
"<extra_id_66>",
|
| 2504 |
+
"<extra_id_67>",
|
| 2505 |
+
"<extra_id_68>",
|
| 2506 |
+
"<extra_id_69>",
|
| 2507 |
+
"<extra_id_70>",
|
| 2508 |
+
"<extra_id_71>",
|
| 2509 |
+
"<extra_id_72>",
|
| 2510 |
+
"<extra_id_73>",
|
| 2511 |
+
"<extra_id_74>",
|
| 2512 |
+
"<extra_id_75>",
|
| 2513 |
+
"<extra_id_76>",
|
| 2514 |
+
"<extra_id_77>",
|
| 2515 |
+
"<extra_id_78>",
|
| 2516 |
+
"<extra_id_79>",
|
| 2517 |
+
"<extra_id_80>",
|
| 2518 |
+
"<extra_id_81>",
|
| 2519 |
+
"<extra_id_82>",
|
| 2520 |
+
"<extra_id_83>",
|
| 2521 |
+
"<extra_id_84>",
|
| 2522 |
+
"<extra_id_85>",
|
| 2523 |
+
"<extra_id_86>",
|
| 2524 |
+
"<extra_id_87>",
|
| 2525 |
+
"<extra_id_88>",
|
| 2526 |
+
"<extra_id_89>",
|
| 2527 |
+
"<extra_id_90>",
|
| 2528 |
+
"<extra_id_91>",
|
| 2529 |
+
"<extra_id_92>",
|
| 2530 |
+
"<extra_id_93>",
|
| 2531 |
+
"<extra_id_94>",
|
| 2532 |
+
"<extra_id_95>",
|
| 2533 |
+
"<extra_id_96>",
|
| 2534 |
+
"<extra_id_97>",
|
| 2535 |
+
"<extra_id_98>",
|
| 2536 |
+
"<extra_id_99>",
|
| 2537 |
+
"<extra_id_100>",
|
| 2538 |
+
"<extra_id_101>",
|
| 2539 |
+
"<extra_id_102>",
|
| 2540 |
+
"<extra_id_103>",
|
| 2541 |
+
"<extra_id_104>",
|
| 2542 |
+
"<extra_id_105>",
|
| 2543 |
+
"<extra_id_106>",
|
| 2544 |
+
"<extra_id_107>",
|
| 2545 |
+
"<extra_id_108>",
|
| 2546 |
+
"<extra_id_109>",
|
| 2547 |
+
"<extra_id_110>",
|
| 2548 |
+
"<extra_id_111>",
|
| 2549 |
+
"<extra_id_112>",
|
| 2550 |
+
"<extra_id_113>",
|
| 2551 |
+
"<extra_id_114>",
|
| 2552 |
+
"<extra_id_115>",
|
| 2553 |
+
"<extra_id_116>",
|
| 2554 |
+
"<extra_id_117>",
|
| 2555 |
+
"<extra_id_118>",
|
| 2556 |
+
"<extra_id_119>",
|
| 2557 |
+
"<extra_id_120>",
|
| 2558 |
+
"<extra_id_121>",
|
| 2559 |
+
"<extra_id_122>",
|
| 2560 |
+
"<extra_id_123>",
|
| 2561 |
+
"<extra_id_124>",
|
| 2562 |
+
"<extra_id_125>",
|
| 2563 |
+
"<extra_id_126>",
|
| 2564 |
+
"<extra_id_127>",
|
| 2565 |
+
"<extra_id_128>",
|
| 2566 |
+
"<extra_id_129>",
|
| 2567 |
+
"<extra_id_130>",
|
| 2568 |
+
"<extra_id_131>",
|
| 2569 |
+
"<extra_id_132>",
|
| 2570 |
+
"<extra_id_133>",
|
| 2571 |
+
"<extra_id_134>",
|
| 2572 |
+
"<extra_id_135>",
|
| 2573 |
+
"<extra_id_136>",
|
| 2574 |
+
"<extra_id_137>",
|
| 2575 |
+
"<extra_id_138>",
|
| 2576 |
+
"<extra_id_139>",
|
| 2577 |
+
"<extra_id_140>",
|
| 2578 |
+
"<extra_id_141>",
|
| 2579 |
+
"<extra_id_142>",
|
| 2580 |
+
"<extra_id_143>",
|
| 2581 |
+
"<extra_id_144>",
|
| 2582 |
+
"<extra_id_145>",
|
| 2583 |
+
"<extra_id_146>",
|
| 2584 |
+
"<extra_id_147>",
|
| 2585 |
+
"<extra_id_148>",
|
| 2586 |
+
"<extra_id_149>",
|
| 2587 |
+
"<extra_id_150>",
|
| 2588 |
+
"<extra_id_151>",
|
| 2589 |
+
"<extra_id_152>",
|
| 2590 |
+
"<extra_id_153>",
|
| 2591 |
+
"<extra_id_154>",
|
| 2592 |
+
"<extra_id_155>",
|
| 2593 |
+
"<extra_id_156>",
|
| 2594 |
+
"<extra_id_157>",
|
| 2595 |
+
"<extra_id_158>",
|
| 2596 |
+
"<extra_id_159>",
|
| 2597 |
+
"<extra_id_160>",
|
| 2598 |
+
"<extra_id_161>",
|
| 2599 |
+
"<extra_id_162>",
|
| 2600 |
+
"<extra_id_163>",
|
| 2601 |
+
"<extra_id_164>",
|
| 2602 |
+
"<extra_id_165>",
|
| 2603 |
+
"<extra_id_166>",
|
| 2604 |
+
"<extra_id_167>",
|
| 2605 |
+
"<extra_id_168>",
|
| 2606 |
+
"<extra_id_169>",
|
| 2607 |
+
"<extra_id_170>",
|
| 2608 |
+
"<extra_id_171>",
|
| 2609 |
+
"<extra_id_172>",
|
| 2610 |
+
"<extra_id_173>",
|
| 2611 |
+
"<extra_id_174>",
|
| 2612 |
+
"<extra_id_175>",
|
| 2613 |
+
"<extra_id_176>",
|
| 2614 |
+
"<extra_id_177>",
|
| 2615 |
+
"<extra_id_178>",
|
| 2616 |
+
"<extra_id_179>",
|
| 2617 |
+
"<extra_id_180>",
|
| 2618 |
+
"<extra_id_181>",
|
| 2619 |
+
"<extra_id_182>",
|
| 2620 |
+
"<extra_id_183>",
|
| 2621 |
+
"<extra_id_184>",
|
| 2622 |
+
"<extra_id_185>",
|
| 2623 |
+
"<extra_id_186>",
|
| 2624 |
+
"<extra_id_187>",
|
| 2625 |
+
"<extra_id_188>",
|
| 2626 |
+
"<extra_id_189>",
|
| 2627 |
+
"<extra_id_190>",
|
| 2628 |
+
"<extra_id_191>",
|
| 2629 |
+
"<extra_id_192>",
|
| 2630 |
+
"<extra_id_193>",
|
| 2631 |
+
"<extra_id_194>",
|
| 2632 |
+
"<extra_id_195>",
|
| 2633 |
+
"<extra_id_196>",
|
| 2634 |
+
"<extra_id_197>",
|
| 2635 |
+
"<extra_id_198>",
|
| 2636 |
+
"<extra_id_199>",
|
| 2637 |
+
"<extra_id_200>",
|
| 2638 |
+
"<extra_id_201>",
|
| 2639 |
+
"<extra_id_202>",
|
| 2640 |
+
"<extra_id_203>",
|
| 2641 |
+
"<extra_id_204>",
|
| 2642 |
+
"<extra_id_205>",
|
| 2643 |
+
"<extra_id_206>",
|
| 2644 |
+
"<extra_id_207>",
|
| 2645 |
+
"<extra_id_208>",
|
| 2646 |
+
"<extra_id_209>",
|
| 2647 |
+
"<extra_id_210>",
|
| 2648 |
+
"<extra_id_211>",
|
| 2649 |
+
"<extra_id_212>",
|
| 2650 |
+
"<extra_id_213>",
|
| 2651 |
+
"<extra_id_214>",
|
| 2652 |
+
"<extra_id_215>",
|
| 2653 |
+
"<extra_id_216>",
|
| 2654 |
+
"<extra_id_217>",
|
| 2655 |
+
"<extra_id_218>",
|
| 2656 |
+
"<extra_id_219>",
|
| 2657 |
+
"<extra_id_220>",
|
| 2658 |
+
"<extra_id_221>",
|
| 2659 |
+
"<extra_id_222>",
|
| 2660 |
+
"<extra_id_223>",
|
| 2661 |
+
"<extra_id_224>",
|
| 2662 |
+
"<extra_id_225>",
|
| 2663 |
+
"<extra_id_226>",
|
| 2664 |
+
"<extra_id_227>",
|
| 2665 |
+
"<extra_id_228>",
|
| 2666 |
+
"<extra_id_229>",
|
| 2667 |
+
"<extra_id_230>",
|
| 2668 |
+
"<extra_id_231>",
|
| 2669 |
+
"<extra_id_232>",
|
| 2670 |
+
"<extra_id_233>",
|
| 2671 |
+
"<extra_id_234>",
|
| 2672 |
+
"<extra_id_235>",
|
| 2673 |
+
"<extra_id_236>",
|
| 2674 |
+
"<extra_id_237>",
|
| 2675 |
+
"<extra_id_238>",
|
| 2676 |
+
"<extra_id_239>",
|
| 2677 |
+
"<extra_id_240>",
|
| 2678 |
+
"<extra_id_241>",
|
| 2679 |
+
"<extra_id_242>",
|
| 2680 |
+
"<extra_id_243>",
|
| 2681 |
+
"<extra_id_244>",
|
| 2682 |
+
"<extra_id_245>",
|
| 2683 |
+
"<extra_id_246>",
|
| 2684 |
+
"<extra_id_247>",
|
| 2685 |
+
"<extra_id_248>",
|
| 2686 |
+
"<extra_id_249>",
|
| 2687 |
+
"<extra_id_250>",
|
| 2688 |
+
"<extra_id_251>",
|
| 2689 |
+
"<extra_id_252>",
|
| 2690 |
+
"<extra_id_253>",
|
| 2691 |
+
"<extra_id_254>",
|
| 2692 |
+
"<extra_id_255>",
|
| 2693 |
+
"<extra_id_256>",
|
| 2694 |
+
"<extra_id_257>",
|
| 2695 |
+
"<extra_id_258>",
|
| 2696 |
+
"<extra_id_259>",
|
| 2697 |
+
"<extra_id_260>",
|
| 2698 |
+
"<extra_id_261>",
|
| 2699 |
+
"<extra_id_262>",
|
| 2700 |
+
"<extra_id_263>",
|
| 2701 |
+
"<extra_id_264>",
|
| 2702 |
+
"<extra_id_265>",
|
| 2703 |
+
"<extra_id_266>",
|
| 2704 |
+
"<extra_id_267>",
|
| 2705 |
+
"<extra_id_268>",
|
| 2706 |
+
"<extra_id_269>",
|
| 2707 |
+
"<extra_id_270>",
|
| 2708 |
+
"<extra_id_271>",
|
| 2709 |
+
"<extra_id_272>",
|
| 2710 |
+
"<extra_id_273>",
|
| 2711 |
+
"<extra_id_274>",
|
| 2712 |
+
"<extra_id_275>",
|
| 2713 |
+
"<extra_id_276>",
|
| 2714 |
+
"<extra_id_277>",
|
| 2715 |
+
"<extra_id_278>",
|
| 2716 |
+
"<extra_id_279>",
|
| 2717 |
+
"<extra_id_280>",
|
| 2718 |
+
"<extra_id_281>",
|
| 2719 |
+
"<extra_id_282>",
|
| 2720 |
+
"<extra_id_283>",
|
| 2721 |
+
"<extra_id_284>",
|
| 2722 |
+
"<extra_id_285>",
|
| 2723 |
+
"<extra_id_286>",
|
| 2724 |
+
"<extra_id_287>",
|
| 2725 |
+
"<extra_id_288>",
|
| 2726 |
+
"<extra_id_289>",
|
| 2727 |
+
"<extra_id_290>",
|
| 2728 |
+
"<extra_id_291>",
|
| 2729 |
+
"<extra_id_292>",
|
| 2730 |
+
"<extra_id_293>",
|
| 2731 |
+
"<extra_id_294>",
|
| 2732 |
+
"<extra_id_295>",
|
| 2733 |
+
"<extra_id_296>",
|
| 2734 |
+
"<extra_id_297>",
|
| 2735 |
+
"<extra_id_298>",
|
| 2736 |
+
"<extra_id_299>"
|
| 2737 |
+
],
|
| 2738 |
+
"bos_token": "<s>",
|
| 2739 |
+
"clean_up_tokenization_spaces": true,
|
| 2740 |
+
"eos_token": "</s>",
|
| 2741 |
+
"extra_ids": 300,
|
| 2742 |
+
"model_max_length": 1000000000000000019884624838656,
|
| 2743 |
+
"pad_token": "<pad>",
|
| 2744 |
+
"sp_model_kwargs": {},
|
| 2745 |
+
"spaces_between_special_tokens": false,
|
| 2746 |
+
"tokenizer_class": "T5Tokenizer",
|
| 2747 |
+
"unk_token": "<unk>"
|
| 2748 |
+
}
|
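The remainder of the config is standard umT5 sentinel-token bookkeeping. For reference, a minimal loading sketch (assuming a local checkout of this repository with `transformers` and `sentencepiece` installed; the path below is this repo's own layout):

```python
from transformers import T5Tokenizer

# "tokenizer_class" above names T5Tokenizer, and extra_ids=300 matches the
# <extra_id_0> ... <extra_id_299> sentinels registered in this config.
tok = T5Tokenizer.from_pretrained("ldf_deps/t5_umt5-xxl-enc-bf16/google/umt5-xxl")

ids = tok("a person walks forward and waves", return_tensors="pt").input_ids
print(ids.shape)  # (1, num_tokens)
```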
ldf_deps/t5_umt5-xxl-enc-bf16/models_t5_umt5-xxl-enc-bf16.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7cace0da2b446bbbbc57d031ab6cf163a3d59b366da94e5afe36745b746fd81d
size 11361920418
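The three lines above are a Git LFS pointer, not the ~11.4 GB weights file itself; `git lfs pull` materializes the real file in place. A quick integrity check against the pointer's recorded size and digest (a sketch, assuming the weights have already been pulled):

```python
import hashlib
import os

path = "ldf_deps/t5_umt5-xxl-enc-bf16/models_t5_umt5-xxl-enc-bf16.pth"

h = hashlib.sha256()
with open(path, "rb") as f:
    for block in iter(lambda: f.read(1 << 20), b""):  # stream in 1 MiB chunks
        h.update(block)

assert os.path.getsize(path) == 11361920418
assert h.hexdigest() == (
    "7cace0da2b446bbbbc57d031ab6cf163a3d59b366da94e5afe36745b746fd81d"
)
```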
ldf_models/__init__.py
ADDED
File without changes
ldf_models/diffusion_forcing_wan.py
ADDED
@@ -0,0 +1,899 @@
import numpy as np
import torch
import torch.nn as nn

from .tools.t5 import T5EncoderModel
from .tools.wan_model import WanModel


class DiffForcingWanModel(nn.Module):
    def __init__(
        self,
        checkpoint_path="deps/t5_umt5-xxl-enc-bf16/models_t5_umt5-xxl-enc-bf16.pth",
        tokenizer_path="deps/t5_umt5-xxl-enc-bf16/google/umt5-xxl",
        input_dim=256,
        hidden_dim=1024,
        ffn_dim=2048,
        freq_dim=256,
        num_heads=8,
        num_layers=8,
        time_embedding_scale=1.0,
        chunk_size=5,
        noise_steps=10,
        use_text_cond=True,
        text_len=128,
        drop_out=0.1,
        cfg_scale=5.0,
        prediction_type="vel",  # "vel", "x0", "noise"
        causal=False,
    ):
        super().__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.ffn_dim = ffn_dim
        self.freq_dim = freq_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.time_embedding_scale = time_embedding_scale
        self.chunk_size = chunk_size
        self.noise_steps = noise_steps
        self.use_text_cond = use_text_cond
        self.drop_out = drop_out
        self.cfg_scale = cfg_scale
        self.prediction_type = prediction_type
        self.causal = causal

        self.text_dim = 4096
        self.text_len = text_len
        self.text_encoder = T5EncoderModel(
            text_len=self.text_len,
            dtype=torch.bfloat16,
            device=torch.device("cpu"),
            checkpoint_path=checkpoint_path,
            tokenizer_path=tokenizer_path,
            shard_fn=None,
        )

        # Text encoding cache
        self.text_cache = {}
        self.model = WanModel(
            model_type="t2v",
            patch_size=(1, 1, 1),
            text_len=self.text_len,
            in_dim=self.input_dim,
            dim=self.hidden_dim,
            ffn_dim=self.ffn_dim,
            freq_dim=self.freq_dim,
            text_dim=self.text_dim,
            out_dim=self.input_dim,
            num_heads=self.num_heads,
            num_layers=self.num_layers,
            window_size=(-1, -1),
            qk_norm=True,
            cross_attn_norm=True,
            eps=1e-6,
            causal=self.causal,
        )
        self.param_dtype = torch.float32

    def encode_text_with_cache(self, text_list, device):
        """Encode text using cache

        Args:
            text_list: List[str], list of texts
            device: torch.device
        Returns:
            List[Tensor]: List of encoded text features
        """
        text_features = []
        indices_to_encode = []
        texts_to_encode = []

        # Check cache
        for i, text in enumerate(text_list):
            if text in self.text_cache:
                # Get from cache and move to correct device
                cached_feature = self.text_cache[text].to(device)
                text_features.append(cached_feature)
            else:
                # Need to encode
                text_features.append(None)
                indices_to_encode.append(i)
                texts_to_encode.append(text)

        # Batch encode uncached texts
        if texts_to_encode:
            self.text_encoder.model.to(device)
            encoded = self.text_encoder(texts_to_encode, device)

            # Store in cache and update results
            for idx, text, feature in zip(indices_to_encode, texts_to_encode, encoded):
                # Cache to CPU to save GPU memory
                self.text_cache[text] = feature.cpu()
                text_features[idx] = feature

        return text_features

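    # Illustrative note (hypothetical call, not part of the original file):
    # repeated prompts are served from the CPU-side cache instead of re-running
    # the umt5-xxl encoder, so only the first occurrence of a string is slow.
    #
    #   feats = model.encode_text_with_cache(["walk forward."], device)  # encodes
    #   feats = model.encode_text_with_cache(["walk forward."], device)  # cache hit
    #
    # Cached features are kept on CPU (`feature.cpu()`), so a hit costs one
    # host-to-device copy rather than a full text-encoder forward pass.
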
    def preprocess(self, x):
        # (bs, T, C) -> (bs, C, T, 1, 1)
        x = x.permute(0, 2, 1)[:, :, :, None, None]
        return x

    def postprocess(self, x):
        # (bs, C, T, 1, 1) -> (bs, T, C)
        x = x.permute(0, 2, 1, 3, 4).contiguous().view(x.size(0), x.size(2), -1)
        return x

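    # Shape check (illustrative, not part of the original file): preprocess and
    # postprocess are exact inverses between the (B, T, C) motion-latent layout
    # and the (B, C, T, 1, 1) pseudo-video layout that WanModel expects.
    #
    #   x = torch.randn(2, 16, 256)      # (B, T, C)
    #   y = model.preprocess(x)          # -> (2, 256, 16, 1, 1)
    #   assert torch.equal(model.postprocess(y), x)
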
    def _get_noise_levels(self, device, seq_len, time_steps):
        """Get noise levels"""
        # noise_level[i] = clip(1 + i / chunk_size - time_steps, 0, 1)
        noise_level = torch.clamp(
            1
            + torch.arange(seq_len, device=device) / self.chunk_size
            - time_steps.unsqueeze(1),
            min=0.0,
            max=1.0,
        )
        return noise_level

    def add_noise(self, x, noise_level):
        """Add noise

        Args:
            x: (B, T, D)
            noise_level: (B, T)
        """
        noise = torch.randn_like(x)
        # noise_level: (B, T) -> (B, T, 1)
        noise_level = noise_level.unsqueeze(-1)
        noisy_x = x * (1 - noise_level) + noise_level * noise
        return noisy_x, noise

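    # Worked example (illustrative, not part of the original file): with
    # chunk_size=5 the schedule noise_level[i] = clip(1 + i/5 - t, 0, 1) at
    # t = 1.0 assigns frames 0..9 the levels
    #   [0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 1.0, 1.0, 1.0, 1.0]
    # i.e. the leftmost chunk is (nearly) clean while everything to its right is
    # still pure noise. add_noise then interpolates linearly between data and
    # Gaussian noise, rectified-flow style: x_noisy = (1 - level)*x + level*noise.
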
    def forward(self, x):
        feature = x["feature"]  # (B, T, C)
        feature_length = x["feature_length"]  # (B,)
        batch_size, seq_len, _ = feature.shape
        device = feature.device

        # Randomly sample a time step per sample
        time_steps = []
        for i in range(batch_size):
            valid_len = feature_length[i].item()
            # Random float from 0 to valid_len/chunk_size, not an integer
            max_time = valid_len / self.chunk_size
            # max_time = valid_len / self.chunk_size + 1
            time_steps.append(torch.FloatTensor(1).uniform_(0, max_time).item())
        time_steps = torch.tensor(time_steps, device=device)  # (B,)
        noise_level = self._get_noise_levels(device, seq_len, time_steps)  # (B, T)

        # # Debug: Print noise levels
        # print("Time steps and corresponding noise levels:")
        # for i in range(batch_size):
        #     t = time_steps[i].item()
        #     # Get noise level at each position
        #     start_idx = int(self.chunk_size * (t - 1))
        #     end_idx = int(self.chunk_size * t) + 2
        #     # Limit to valid range
        #     start_idx = max(0, start_idx)
        #     end_idx = min(seq_len, end_idx)
        #     print(time_steps[i])
        #     print(noise_level[i, start_idx:end_idx])

        # Add noise to entire sequence
        noisy_feature, noise = self.add_noise(feature, noise_level)  # (B, T, D)

        # # Debug: Print noise addition information
        # print("Added noise levels at chunk positions:")
        # for i in range(batch_size):
        #     t = time_steps[i].item()
        #     start_idx = int(self.chunk_size * (t - 1))
        #     end_idx = int(self.chunk_size * t) + 2
        #     # Limit to valid range
        #     start_idx = max(0, start_idx)
        #     end_idx = min(seq_len, end_idx)
        #     test1 = (
        #         feature[i, start_idx:end_idx, :] - noisy_feature[i, start_idx:end_idx, :]
        #     )
        #     test2 = (
        #         noise[i, start_idx:end_idx, :] - noisy_feature[i, start_idx:end_idx, :]
        #     )
        #     # Compute length on last dimension
        #     print(test1.norm(dim=-1))
        #     print(test2.norm(dim=-1))

        feature = self.preprocess(feature)  # (B, C, T, 1, 1)
        noisy_feature = self.preprocess(noisy_feature)  # (B, C, T, 1, 1)
        noise = self.preprocess(noise)  # (B, C, T, 1, 1)

        feature_ref = []
        noise_ref = []
        noisy_feature_input = []
        for i in range(batch_size):
            t = time_steps[i].item()
            end_index = int(self.chunk_size * t) + 1
            valid_len = feature_length[i].item()
            end_index = min(valid_len, end_index)
            feature_ref.append(feature[i, :, :end_index, ...])
            noise_ref.append(noise[i, :, :end_index, ...])
            noisy_feature_input.append(noisy_feature[i, :, :end_index, ...])

        # Encode text condition (using cache)
        if self.use_text_cond and "text" in x:
            text_list = x["text"]  # List[str] or List[List[str]]
            if isinstance(text_list[0], list):
                text_end_list = x["feature_text_end"]
                all_text_context = []
                for single_text_list, single_text_end_list in zip(
                    text_list, text_end_list
                ):
                    # Drop the text condition with probability drop_out
                    # (classifier-free guidance); the comparison direction here
                    # mirrors the single-text branch below.
                    if np.random.rand() < self.drop_out:
                        single_text_list = [""]
                        single_text_end_list = [0, seq_len]
                    else:
                        single_text_end_list = [0] + [
                            min(t, seq_len) for t in single_text_end_list
                        ]
                    single_text_length_list = [
                        t - b
                        for t, b in zip(
                            single_text_end_list[1:], single_text_end_list[:-1]
                        )
                    ]
                    single_text_context = self.encode_text_with_cache(
                        single_text_list, device
                    )
                    single_text_context = [
                        u.to(self.param_dtype) for u in single_text_context
                    ]
                    for u, duration in zip(
                        single_text_context, single_text_length_list
                    ):
                        all_text_context.extend([u for _ in range(duration)])
                    all_text_context.extend(
                        [
                            single_text_context[-1]
                            for _ in range(seq_len - single_text_end_list[-1])
                        ]
                    )
            else:
                all_text_context = [
                    (u if np.random.rand() > self.drop_out else "") for u in text_list
                ]
                all_text_context = self.encode_text_with_cache(all_text_context, device)
                all_text_context = [u.to(self.param_dtype) for u in all_text_context]
        else:
            all_text_context = [""] * batch_size
            all_text_context = self.encode_text_with_cache(all_text_context, device)
            all_text_context = [u.to(self.param_dtype) for u in all_text_context]

        # Forward through WanModel
        predicted_result = self.model(
            noisy_feature_input,
            noise_level * self.time_embedding_scale,
            all_text_context,
            seq_len,
            y=None,
        )  # (B, C, T, 1, 1)

        loss = 0.0
        for b in range(batch_size):
            if self.prediction_type == "vel":
                vel = feature_ref[b] - noise_ref[b]  # (C, input_length, 1, 1)
                squared_error = (
                    predicted_result[b][:, -self.chunk_size :, ...]
                    - vel[:, -self.chunk_size :, ...]
                ) ** 2
            elif self.prediction_type == "x0":
                squared_error = (
                    predicted_result[b][:, -self.chunk_size :, ...]
                    - feature_ref[b][:, -self.chunk_size :, ...]
                ) ** 2
            elif self.prediction_type == "noise":
                squared_error = (
                    predicted_result[b][:, -self.chunk_size :, ...]
                    - noise_ref[b][:, -self.chunk_size :, ...]
                ) ** 2
            sample_loss = squared_error.sum().mean()
            loss += sample_loss
        loss = loss / batch_size

        loss_dict = {"total": loss, "mse": loss}
        return loss_dict

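    # Expected batch layout for forward() (illustrative shapes, not part of the
    # original file):
    #
    #   batch = {
    #       "feature": torch.randn(2, 40, 256),        # (B, T, C) motion latents
    #       "feature_length": torch.tensor([40, 32]),  # valid frames per sample
    #       "text": ["a person walks forward."] * 2,   # List[str]; the
    #   }                                              # List[List[str]] variant
    #                                                  # also needs "feature_text_end"
    #   loss_dict = model(batch)  # {"total": ..., "mse": ...}
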
    def generate(self, x, num_denoise_steps=None):
        """
        Generation - Diffusion Forcing inference
        Uses triangular noise schedule, progressively generating from left to right

        Generation process:
        1. Start from t=0, gradually increase t
        2. Each t corresponds to a noise schedule: clean on left, noisy on right, gradient in middle
        3. After each denoising step, t increases slightly and continues
        """
        feature_length = x["feature_length"]
        batch_size = len(feature_length)
        seq_len = max(feature_length).item()

        # # debug
        # x["text"] = [["walk forward.", "sit down.", "stand up."] for _ in range(batch_size)]
        # x["feature_text_end"] = [[1, 2, 3] for _ in range(batch_size)]
        # text = x["text"]
        # text_end = x["feature_text_end"]
        # print(text)
        # print(text_end)
        # print(batch_size, seq_len, self.chunk_size)

        if num_denoise_steps is None:
            num_denoise_steps = self.noise_steps
        assert num_denoise_steps % self.chunk_size == 0

        device = next(self.parameters()).device

        # Initialize entire sequence as pure noise
        generated = torch.randn(
            batch_size, seq_len + self.chunk_size, self.input_dim, device=device
        )
        generated = self.preprocess(generated)  # (B, C, T, 1, 1)

        # Calculate total number of time steps needed
        max_t = 1 + (seq_len - 1) / self.chunk_size

        # Step size for each advancement
        dt = 1 / num_denoise_steps
        total_steps = int(max_t / dt)

        # Encode text condition (using cache)
        if self.use_text_cond and "text" in x:
            text_list = x["text"]  # List[str] or List[List[str]]
            if isinstance(text_list[0], list):
                generated_length = []
                text_end_list = x["feature_text_end"]
                full_text = []
                all_text_context = []
                for single_text_list, single_text_end_list in zip(
                    text_list, text_end_list
                ):
                    single_text_end_list = [0] + [
                        min(t, seq_len) for t in single_text_end_list
                    ]
                    generated_length.append(single_text_end_list[-1])
                    single_text_length_list = [
                        t - b
                        for t, b in zip(
                            single_text_end_list[1:], single_text_end_list[:-1]
                        )
                    ]
                    full_text.append(
                        " ////////// ".join(
                            [
                                f"{u} //dur:{t}"
                                for u, t in zip(
                                    single_text_list, single_text_length_list
                                )
                            ]
                        )
                    )
                    single_text_context = self.encode_text_with_cache(
                        single_text_list, device
                    )
                    single_text_context = [
                        u.to(self.param_dtype) for u in single_text_context
                    ]
                    for u, duration in zip(
                        single_text_context, single_text_length_list
                    ):
                        all_text_context.extend([u for _ in range(duration)])
                    all_text_context.extend(
                        [
                            single_text_context[-1]
                            for _ in range(
                                seq_len + self.chunk_size - single_text_end_list[-1]
                            )
                        ]
                    )
            else:
                generated_length = feature_length
                full_text = text_list
                all_text_context = self.encode_text_with_cache(text_list, device)
                all_text_context = [u.to(self.param_dtype) for u in all_text_context]
        else:
            generated_length = feature_length
            full_text = [""] * batch_size
            all_text_context = [""] * batch_size
            all_text_context = self.encode_text_with_cache(all_text_context, device)
            all_text_context = [u.to(self.param_dtype) for u in all_text_context]

        # Get empty text condition encoding (for CFG)
        text_null_list = [""] * batch_size
        text_null_context = self.encode_text_with_cache(text_null_list, device)
        text_null_context = [u.to(self.param_dtype) for u in text_null_context]

        # print(len(all_text_context), len(text_null_context))

        # Progressively advance from t=0 to t=max_t
        for step in range(total_steps):
            # Current time step
            t = step * dt
            start_index = max(0, int(self.chunk_size * (t - 1)) + 1)
            end_index = int(self.chunk_size * t) + 1
            time_steps = torch.full((batch_size,), t, device=device)

            # Calculate current noise schedule
            noise_level = self._get_noise_levels(
                device, seq_len + self.chunk_size, time_steps
            )  # (B, T)

            # Predict noise through WanModel
            noisy_input = []
            for i in range(batch_size):
                noisy_input.append(generated[i, :, :end_index, ...])

            predicted_result = self.model(
                noisy_input,
                noise_level * self.time_embedding_scale,
                all_text_context,
                seq_len + self.chunk_size,
                y=None,
            )  # (B, C, T, 1, 1)

            # Adjust using CFG
            if self.cfg_scale != 1.0:
                predicted_result_null = self.model(
                    noisy_input,
                    noise_level * self.time_embedding_scale,
                    text_null_context,
                    seq_len + self.chunk_size,
                    y=None,
                )  # (B, C, T, 1, 1)
                predicted_result = [
                    self.cfg_scale * pv - (self.cfg_scale - 1) * pvn
                    for pv, pvn in zip(predicted_result, predicted_result_null)
                ]

            for i in range(batch_size):
                predicted_result_i = predicted_result[i]  # (C, input_length, 1, 1)
                if self.prediction_type == "vel":
                    predicted_vel = predicted_result_i[:, start_index:end_index, ...]
                    generated[i, :, start_index:end_index, ...] += predicted_vel * dt
                elif self.prediction_type == "x0":
                    predicted_vel = (
                        predicted_result_i[:, start_index:end_index, ...]
                        - generated[i, :, start_index:end_index, ...]
                    ) / (
                        noise_level[i, start_index:end_index]
                        .unsqueeze(0)
                        .unsqueeze(-1)
                        .unsqueeze(-1)
                    )
                    generated[i, :, start_index:end_index, ...] += predicted_vel * dt
                elif self.prediction_type == "noise":
                    predicted_vel = (
                        generated[i, :, start_index:end_index, ...]
                        - predicted_result_i[:, start_index:end_index, ...]
                    ) / (
                        1
                        + dt
                        - noise_level[i, start_index:end_index]
                        .unsqueeze(0)
                        .unsqueeze(-1)
                        .unsqueeze(-1)
                    )
                    generated[i, :, start_index:end_index, ...] += predicted_vel * dt

        generated = self.postprocess(generated)  # (B, T, C)
        y_hat_out = []
        for i in range(batch_size):
            # cut off the padding
            single_generated = generated[i, : generated_length[i], :]
            y_hat_out.append(single_generated)
        out = {}
        out["generated"] = y_hat_out
        out["text"] = full_text

        return out

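    # Usage sketch (illustrative, not part of the original file): offline
    # generation of 40 latent frames for two prompts; decoding the returned
    # latents with the motion VAE happens outside this module.
    #
    #   out = model.generate({
    #       "feature_length": torch.tensor([40, 40]),
    #       "text": ["a person walks forward.", "a person waves."],
    #   })
    #   latents = out["generated"]  # list of (T_i, C) tensors, padding trimmed
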
    @torch.no_grad()
    def stream_generate(self, x, num_denoise_steps=None):
        """
        Streaming generation - Diffusion Forcing inference
        Uses triangular noise schedule, progressively generating from left to right

        Generation process:
        1. Start from t=0, gradually increase t
        2. Each t corresponds to a noise schedule: clean on left, noisy on right, gradient in middle
        3. After each denoising step, t increases slightly and continues
        """
        feature_length = x["feature_length"]
        batch_size = len(feature_length)
        seq_len = max(feature_length).item()

        # # debug
        # x["text"] = [["walk forward.", "sit down.", "stand up."] for _ in range(batch_size)]
        # x["feature_text_end"] = [[1, 2, 3] for _ in range(batch_size)]
        # text = x["text"]
        # text_end = x["feature_text_end"]
        # print(text)
        # print(text_end)
        # print(batch_size, seq_len, self.chunk_size)

        if num_denoise_steps is None:
            num_denoise_steps = self.noise_steps
        assert num_denoise_steps % self.chunk_size == 0

        device = next(self.parameters()).device

        # Initialize entire sequence as pure noise
        generated = torch.randn(
            batch_size, seq_len + self.chunk_size, self.input_dim, device=device
        )
        generated = self.preprocess(generated)  # (B, C, T, 1, 1)

        # Calculate total number of time steps needed
        max_t = 1 + (seq_len - 1) / self.chunk_size

        # Step size for each advancement
        dt = 1 / num_denoise_steps
        total_steps = int(max_t / dt)

        # Encode text condition (using cache)
        if self.use_text_cond and "text" in x:
            text_list = x["text"]  # List[str] or List[List[str]]
            if isinstance(text_list[0], list):
                generated_length = []
                text_end_list = x["feature_text_end"]
                full_text = []
                all_text_context = []
                for single_text_list, single_text_end_list in zip(
                    text_list, text_end_list
                ):
                    single_text_end_list = [0] + [
                        min(t, seq_len) for t in single_text_end_list
                    ]
                    generated_length.append(single_text_end_list[-1])
                    single_text_length_list = [
                        t - b
                        for t, b in zip(
                            single_text_end_list[1:], single_text_end_list[:-1]
                        )
                    ]
                    full_text.append(
                        " ////////// ".join(
                            [
                                f"{u} //dur:{t}"
                                for u, t in zip(
                                    single_text_list, single_text_length_list
                                )
                            ]
                        )
                    )
                    single_text_context = self.encode_text_with_cache(
                        single_text_list, device
                    )
                    single_text_context = [
                        u.to(self.param_dtype) for u in single_text_context
                    ]
                    for u, duration in zip(
                        single_text_context, single_text_length_list
                    ):
                        all_text_context.extend([u for _ in range(duration)])
                    all_text_context.extend(
                        [
                            single_text_context[-1]
                            for _ in range(
                                seq_len + self.chunk_size - single_text_end_list[-1]
                            )
                        ]
                    )
            else:
                generated_length = feature_length
                full_text = text_list
                all_text_context = self.encode_text_with_cache(text_list, device)
                all_text_context = [u.to(self.param_dtype) for u in all_text_context]
        else:
            generated_length = feature_length
            full_text = [""] * batch_size
            all_text_context = [""] * batch_size
            all_text_context = self.encode_text_with_cache(all_text_context, device)
            all_text_context = [u.to(self.param_dtype) for u in all_text_context]

        # Get empty text condition encoding (for CFG)
        text_null_list = [""] * batch_size
        text_null_context = self.encode_text_with_cache(text_null_list, device)
        text_null_context = [u.to(self.param_dtype) for u in text_null_context]

        # print(len(all_text_context), len(text_null_context))

        commit_index = 0
        # Progressively advance from t=0 to t=max_t
        for step in range(total_steps):
            # Current time step
            t = step * dt
            start_index = max(0, int(self.chunk_size * (t - 1)) + 1)
            end_index = int(self.chunk_size * t) + 1
            time_steps = torch.full((batch_size,), t, device=device)

            # Calculate current noise schedule
            noise_level = self._get_noise_levels(
                device, seq_len + self.chunk_size, time_steps
            )  # (B, T)

            # Predict noise through WanModel
            noisy_input = []
            for i in range(batch_size):
                noisy_input.append(generated[i, :, :end_index, ...])

            predicted_result = self.model(
                noisy_input,
                noise_level * self.time_embedding_scale,
                all_text_context,
                seq_len + self.chunk_size,
                y=None,
            )  # (B, C, T, 1, 1)

            # Adjust using CFG
            if self.cfg_scale != 1.0:
                predicted_result_null = self.model(
                    noisy_input,
                    noise_level * self.time_embedding_scale,
                    text_null_context,
                    seq_len + self.chunk_size,
                    y=None,
                )  # (B, C, T, 1, 1)
                predicted_result = [
                    self.cfg_scale * pv - (self.cfg_scale - 1) * pvn
                    for pv, pvn in zip(predicted_result, predicted_result_null)
                ]

            for i in range(batch_size):
                predicted_result_i = predicted_result[i]  # (C, input_length, 1, 1)
                if self.prediction_type == "vel":
                    predicted_vel = predicted_result_i[:, start_index:end_index, ...]
                    generated[i, :, start_index:end_index, ...] += predicted_vel * dt
                elif self.prediction_type == "x0":
                    predicted_vel = (
                        predicted_result_i[:, start_index:end_index, ...]
|
| 654 |
+
- generated[i, :, start_index:end_index, ...]
|
| 655 |
+
) / (
|
| 656 |
+
noise_level[i, start_index:end_index]
|
| 657 |
+
.unsqueeze(0)
|
| 658 |
+
.unsqueeze(-1)
|
| 659 |
+
.unsqueeze(-1)
|
| 660 |
+
)
|
| 661 |
+
generated[i, :, start_index:end_index, ...] += predicted_vel * dt
|
| 662 |
+
elif self.prediction_type == "noise":
|
| 663 |
+
predicted_vel = (
|
| 664 |
+
generated[i, :, start_index:end_index, ...]
|
| 665 |
+
- predicted_result_i[:, start_index:end_index, ...]
|
| 666 |
+
) / (
|
| 667 |
+
1
|
| 668 |
+
+ dt
|
| 669 |
+
- noise_level[i, start_index:end_index]
|
| 670 |
+
.unsqueeze(0)
|
| 671 |
+
.unsqueeze(-1)
|
| 672 |
+
.unsqueeze(-1)
|
| 673 |
+
)
|
| 674 |
+
generated[i, :, start_index:end_index, ...] += predicted_vel * dt
|
| 675 |
+
|
| 676 |
+
if commit_index < start_index:
|
| 677 |
+
output = generated[:, :, commit_index:start_index, ...]
|
| 678 |
+
output = self.postprocess(output) # (B, T, C)
|
| 679 |
+
y_hat_out = []
|
| 680 |
+
for i in range(batch_size):
|
| 681 |
+
if commit_index < generated_length[i]:
|
| 682 |
+
y_hat_out.append(
|
| 683 |
+
output[i, : generated_length[i] - commit_index, ...]
|
| 684 |
+
)
|
| 685 |
+
else:
|
| 686 |
+
y_hat_out.append(None)
|
| 687 |
+
|
| 688 |
+
out = {}
|
| 689 |
+
out["generated"] = y_hat_out
|
| 690 |
+
yield out
|
| 691 |
+
commit_index = start_index
|
| 692 |
+
|
| 693 |
+
output = generated[:, :, commit_index:, ...]
|
| 694 |
+
output = self.postprocess(output) # (B, T_remain, C)
|
| 695 |
+
y_hat_out = []
|
| 696 |
+
for i in range(batch_size):
|
| 697 |
+
if commit_index < generated_length[i]:
|
| 698 |
+
y_hat_out.append(output[i, : generated_length[i] - commit_index, ...])
|
| 699 |
+
else:
|
| 700 |
+
y_hat_out.append(None)
|
| 701 |
+
out = {}
|
| 702 |
+
out["generated"] = y_hat_out
|
| 703 |
+
yield out
|
| 704 |
+
|
| 705 |
+
def init_generated(self, seq_len, batch_size=1, num_denoise_steps=None):
|
| 706 |
+
self.seq_len = seq_len
|
| 707 |
+
self.batch_size = batch_size
|
| 708 |
+
if num_denoise_steps is None:
|
| 709 |
+
self.num_denoise_steps = self.noise_steps
|
| 710 |
+
else:
|
| 711 |
+
self.num_denoise_steps = num_denoise_steps
|
| 712 |
+
assert self.num_denoise_steps % self.chunk_size == 0
|
| 713 |
+
self.dt = 1 / self.num_denoise_steps
|
| 714 |
+
self.current_step = 0
|
| 715 |
+
self.text_condition_list = [[] for _ in range(self.batch_size)]
|
| 716 |
+
self.generated = torch.randn(
|
| 717 |
+
self.batch_size, self.seq_len * 2 + self.chunk_size, self.input_dim
|
| 718 |
+
)
|
| 719 |
+
self.generated = self.preprocess(self.generated) # (B, C, T, 1, 1)
|
| 720 |
+
self.commit_index = 0
|
| 721 |
+
|
| 722 |
+
@torch.no_grad()
|
| 723 |
+
def stream_generate_step(self, x, first_chunk=True):
|
| 724 |
+
"""
|
| 725 |
+
Streaming generation step - Diffusion Forcing inference
|
| 726 |
+
Uses triangular noise schedule, progressively generating from left to right
|
| 727 |
+
|
| 728 |
+
Generation process:
|
| 729 |
+
1. Start from t=0, gradually increase t
|
| 730 |
+
2. Each t corresponds to a noise schedule: clean on left, noisy on right, gradient in middle
|
| 731 |
+
3. After each denoising step, t increases slightly and continues
|
| 732 |
+
"""
|
| 733 |
+
|
| 734 |
+
device = next(self.parameters()).device
|
| 735 |
+
if first_chunk:
|
| 736 |
+
self.generated = self.generated.to(device)
|
| 737 |
+
|
| 738 |
+
# Encode text condition (using cache)
|
| 739 |
+
if self.use_text_cond and "text" in x:
|
| 740 |
+
text_list = x["text"] # List[str]
|
| 741 |
+
new_text_context = self.encode_text_with_cache(text_list, device)
|
| 742 |
+
new_text_context = [u.to(self.param_dtype) for u in new_text_context]
|
| 743 |
+
else:
|
| 744 |
+
new_text_context = [""] * self.batch_size
|
| 745 |
+
new_text_context = self.encode_text_with_cache(new_text_context, device)
|
| 746 |
+
new_text_context = [u.to(self.param_dtype) for u in new_text_context]
|
| 747 |
+
|
| 748 |
+
# Get empty text condition encoding (for CFG)
|
| 749 |
+
text_null_list = [""] * self.batch_size
|
| 750 |
+
text_null_context = self.encode_text_with_cache(text_null_list, device)
|
| 751 |
+
text_null_context = [u.to(self.param_dtype) for u in text_null_context]
|
| 752 |
+
|
| 753 |
+
for i in range(self.batch_size):
|
| 754 |
+
if first_chunk:
|
| 755 |
+
self.text_condition_list[i].extend(
|
| 756 |
+
[new_text_context[i]] * self.chunk_size
|
| 757 |
+
)
|
| 758 |
+
else:
|
| 759 |
+
self.text_condition_list[i].extend([new_text_context[i]])
|
| 760 |
+
|
| 761 |
+
end_step = (
|
| 762 |
+
(self.commit_index + self.chunk_size)
|
| 763 |
+
* self.num_denoise_steps
|
| 764 |
+
/ self.chunk_size
|
| 765 |
+
)
|
| 766 |
+
while self.current_step < end_step:
|
| 767 |
+
current_time = self.current_step * self.dt
|
| 768 |
+
start_index = max(0, int(self.chunk_size * (current_time - 1)) + 1)
|
| 769 |
+
end_index = int(self.chunk_size * current_time) + 1
|
| 770 |
+
time_steps = torch.full((self.batch_size,), current_time, device=device)
|
| 771 |
+
|
| 772 |
+
noise_level = self._get_noise_levels(device, end_index, time_steps)[
|
| 773 |
+
:, -self.seq_len :
|
| 774 |
+
] # (B, T)
|
| 775 |
+
|
| 776 |
+
# Predict noise through WanModel
|
| 777 |
+
noisy_input = []
|
| 778 |
+
for i in range(self.batch_size):
|
| 779 |
+
noisy_input.append(
|
| 780 |
+
self.generated[i, :, :end_index, ...][:, -self.seq_len :]
|
| 781 |
+
) # (C, T, 1, 1)
|
| 782 |
+
|
| 783 |
+
text_condition = []
|
| 784 |
+
for i in range(self.batch_size):
|
| 785 |
+
text_condition.extend(
|
| 786 |
+
self.text_condition_list[i][:end_index][-self.seq_len :]
|
| 787 |
+
) # (T, D, 4096)
|
| 788 |
+
|
| 789 |
+
# print("////////////////////")
|
| 790 |
+
# print("current step: ", self.current_step)
|
| 791 |
+
# print("chunk size: ", self.chunk_size)
|
| 792 |
+
# print("start_index: ", start_index)
|
| 793 |
+
# print("end_index: ", end_index)
|
| 794 |
+
# print("noisy_input shape: ", noisy_input[0].shape)
|
| 795 |
+
# print("noise_level: ", noise_level[0, start_index:end_index])
|
| 796 |
+
# print("text_condition shape: ", len(text_condition))
|
| 797 |
+
# print("commit_index: ", self.commit_index)
|
| 798 |
+
# print("////////////////////")
|
| 799 |
+
|
| 800 |
+
predicted_result = self.model(
|
| 801 |
+
noisy_input,
|
| 802 |
+
noise_level * self.time_embedding_scale,
|
| 803 |
+
text_condition,
|
| 804 |
+
min(end_index, self.seq_len),
|
| 805 |
+
y=None,
|
| 806 |
+
) # (B, C, T, 1, 1)
|
| 807 |
+
|
| 808 |
+
# Adjust using CFG
|
| 809 |
+
if self.cfg_scale != 1.0:
|
| 810 |
+
predicted_result_null = self.model(
|
| 811 |
+
noisy_input,
|
| 812 |
+
noise_level * self.time_embedding_scale,
|
| 813 |
+
text_null_context,
|
| 814 |
+
min(end_index, self.seq_len),
|
| 815 |
+
y=None,
|
| 816 |
+
) # (B, C, T, 1, 1)
|
| 817 |
+
predicted_result = [
|
| 818 |
+
self.cfg_scale * pv - (self.cfg_scale - 1) * pvn
|
| 819 |
+
for pv, pvn in zip(predicted_result, predicted_result_null)
|
| 820 |
+
]
|
| 821 |
+
|
| 822 |
+
for i in range(self.batch_size):
|
| 823 |
+
predicted_result_i = predicted_result[i] # (C, input_length, 1, 1)
|
| 824 |
+
if end_index > self.seq_len:
|
| 825 |
+
predicted_result_i = torch.cat(
|
| 826 |
+
[
|
| 827 |
+
torch.zeros(
|
| 828 |
+
predicted_result_i.shape[0],
|
| 829 |
+
end_index - self.seq_len,
|
| 830 |
+
predicted_result_i.shape[2],
|
| 831 |
+
predicted_result_i.shape[3],
|
| 832 |
+
device=device,
|
| 833 |
+
),
|
| 834 |
+
predicted_result_i,
|
| 835 |
+
],
|
| 836 |
+
dim=1,
|
| 837 |
+
)
|
| 838 |
+
if self.prediction_type == "vel":
|
| 839 |
+
predicted_vel = predicted_result_i[:, start_index:end_index, ...]
|
| 840 |
+
self.generated[i, :, start_index:end_index, ...] += (
|
| 841 |
+
predicted_vel * self.dt
|
| 842 |
+
)
|
| 843 |
+
elif self.prediction_type == "x0":
|
| 844 |
+
predicted_vel = (
|
| 845 |
+
predicted_result_i[:, start_index:end_index, ...]
|
| 846 |
+
- self.generated[i, :, start_index:end_index, ...]
|
| 847 |
+
) / (
|
| 848 |
+
noise_level[i, start_index:end_index]
|
| 849 |
+
.unsqueeze(0)
|
| 850 |
+
.unsqueeze(-1)
|
| 851 |
+
.unsqueeze(-1)
|
| 852 |
+
)
|
| 853 |
+
self.generated[i, :, start_index:end_index, ...] += (
|
| 854 |
+
predicted_vel * self.dt
|
| 855 |
+
)
|
| 856 |
+
elif self.prediction_type == "noise":
|
| 857 |
+
predicted_vel = (
|
| 858 |
+
self.generated[i, :, start_index:end_index, ...]
|
| 859 |
+
- predicted_result_i[:, start_index:end_index, ...]
|
| 860 |
+
) / (
|
| 861 |
+
1
|
| 862 |
+
+ self.dt
|
| 863 |
+
- noise_level[i, start_index:end_index]
|
| 864 |
+
.unsqueeze(0)
|
| 865 |
+
.unsqueeze(-1)
|
| 866 |
+
.unsqueeze(-1)
|
| 867 |
+
)
|
| 868 |
+
self.generated[i, :, start_index:end_index, ...] += (
|
| 869 |
+
predicted_vel * self.dt
|
| 870 |
+
)
|
| 871 |
+
self.current_step += 1
|
| 872 |
+
output = self.generated[:, :, self.commit_index : self.commit_index + 1, ...]
|
| 873 |
+
output = self.postprocess(output) # (B, 1, C)
|
| 874 |
+
out = {}
|
| 875 |
+
out["generated"] = output
|
| 876 |
+
self.commit_index += 1
|
| 877 |
+
|
| 878 |
+
if self.commit_index == self.seq_len * 2:
|
| 879 |
+
self.generated = torch.cat(
|
| 880 |
+
[
|
| 881 |
+
self.generated[:, :, self.seq_len :, ...],
|
| 882 |
+
torch.randn(
|
| 883 |
+
self.batch_size,
|
| 884 |
+
self.input_dim,
|
| 885 |
+
self.seq_len,
|
| 886 |
+
1,
|
| 887 |
+
1,
|
| 888 |
+
device=device,
|
| 889 |
+
),
|
| 890 |
+
],
|
| 891 |
+
dim=2,
|
| 892 |
+
)
|
| 893 |
+
self.current_step -= self.seq_len * self.num_denoise_steps / self.chunk_size
|
| 894 |
+
self.commit_index -= self.seq_len
|
| 895 |
+
for i in range(self.batch_size):
|
| 896 |
+
self.text_condition_list[i] = self.text_condition_list[i][
|
| 897 |
+
self.seq_len :
|
| 898 |
+
]
|
| 899 |
+
return out
|
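The two entry points above implement the same triangular-schedule inference in two forms: `stream_generate` is a generator that yields committed (fully denoised) frames as the schedule sweeps left to right, while `init_generated` / `stream_generate_step` keep a rolling latent buffer for frame-by-frame use. Below is a minimal consumption sketch, assuming `model` is an already-loaded instance of this class in eval mode on GPU; the prompt strings, lengths, and frame counts are illustrative only.

```python
import torch

# Minimal sketch, assuming `model` is an instance of the diffusion-forcing
# class above (weights loaded, .eval(), on GPU).
batch = {
    "feature_length": torch.tensor([120]),   # latent frames to generate
    "text": [["walk forward.", "sit down."]],  # per-sample prompt list
    "feature_text_end": [[60, 120]],           # frame where each prompt ends
}

# Generator form: chunks of committed frames arrive as soon as the
# sliding noise schedule has fully denoised them.
for chunk in model.stream_generate(batch):
    for sample in chunk["generated"]:
        if sample is not None:
            print(sample.shape)  # (T_committed, C)

# Step form: one committed frame per call, over a rolling latent buffer.
model.init_generated(seq_len=120, batch_size=1)
out = model.stream_generate_step({"text": ["walk forward."]}, first_chunk=True)
frame = out["generated"]  # (B, 1, C)
```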
ldf_models/tools/attention.py
ADDED
@@ -0,0 +1,188 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch

try:
    import flash_attn_interface

    FLASH_ATTN_3_AVAILABLE = True
except ModuleNotFoundError:
    FLASH_ATTN_3_AVAILABLE = False

try:
    import flash_attn

    FLASH_ATTN_2_AVAILABLE = True
except ModuleNotFoundError:
    FLASH_ATTN_2_AVAILABLE = False

import warnings

__all__ = [
    "flash_attention",
    "attention",
]


def flash_attention(
    q,
    k,
    v,
    q_lens=None,
    k_lens=None,
    dropout_p=0.0,
    softmax_scale=None,
    q_scale=None,
    causal=False,
    window_size=(-1, -1),
    deterministic=False,
    dtype=torch.bfloat16,
    version=None,
):
    """
    q: [B, Lq, Nq, C1].
    k: [B, Lk, Nk, C1].
    v: [B, Lk, Nk, C2]. Nq must be divisible by Nk.
    q_lens: [B].
    k_lens: [B].
    dropout_p: float. Dropout probability.
    softmax_scale: float. The scaling of QK^T before applying softmax.
    causal: bool. Whether to apply causal attention mask.
    window_size: (left, right). If not (-1, -1), apply sliding window local attention.
    deterministic: bool. If True, slightly slower and uses more memory.
    dtype: torch.dtype. Applied when the dtype of q/k/v is not float16/bfloat16.
    """
    half_dtypes = (torch.float16, torch.bfloat16)
    assert dtype in half_dtypes
    assert q.device.type == "cuda" and q.size(-1) <= 256

    # params
    b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype

    def half(x):
        return x if x.dtype in half_dtypes else x.to(dtype)

    # preprocess query
    if q_lens is None:
        q = half(q.flatten(0, 1))
        q_lens = torch.tensor([lq] * b, dtype=torch.int32).to(
            device=q.device, non_blocking=True
        )
    else:
        q = half(torch.cat([u[:v] for u, v in zip(q, q_lens)]))

    # preprocess key, value
    if k_lens is None:
        k = half(k.flatten(0, 1))
        v = half(v.flatten(0, 1))
        k_lens = torch.tensor([lk] * b, dtype=torch.int32).to(
            device=k.device, non_blocking=True
        )
    else:
        k = half(torch.cat([u[:v] for u, v in zip(k, k_lens)]))
        v = half(torch.cat([u[:v] for u, v in zip(v, k_lens)]))

    q = q.to(v.dtype)
    k = k.to(v.dtype)

    if q_scale is not None:
        q = q * q_scale

    if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE:
        warnings.warn(
            "Flash attention 3 is not available, use flash attention 2 instead."
        )

    # apply attention
    if (version is None or version == 3) and FLASH_ATTN_3_AVAILABLE:
        # Note: dropout_p, window_size are not supported in FA3 now.
        x = flash_attn_interface.flash_attn_varlen_func(
            q=q,
            k=k,
            v=v,
            cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens])
            .cumsum(0, dtype=torch.int32)
            .to(q.device, non_blocking=True),
            cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens])
            .cumsum(0, dtype=torch.int32)
            .to(q.device, non_blocking=True),
            seqused_q=None,
            seqused_k=None,
            max_seqlen_q=lq,
            max_seqlen_k=lk,
            softmax_scale=softmax_scale,
            causal=causal,
            deterministic=deterministic,
        )[0].unflatten(0, (b, lq))
    else:
        assert FLASH_ATTN_2_AVAILABLE
        x = flash_attn.flash_attn_varlen_func(
            q=q,
            k=k,
            v=v,
            cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens])
            .cumsum(0, dtype=torch.int32)
            .to(q.device, non_blocking=True),
            cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens])
            .cumsum(0, dtype=torch.int32)
            .to(q.device, non_blocking=True),
            max_seqlen_q=lq,
            max_seqlen_k=lk,
            dropout_p=dropout_p,
            softmax_scale=softmax_scale,
            causal=causal,
            window_size=window_size,
            deterministic=deterministic,
        ).unflatten(0, (b, lq))

    # output
    return x.type(out_dtype)


def attention(
    q,
    k,
    v,
    q_lens=None,
    k_lens=None,
    dropout_p=0.0,
    softmax_scale=None,
    q_scale=None,
    causal=False,
    window_size=(-1, -1),
    deterministic=False,
    dtype=torch.bfloat16,
    fa_version=None,
):
    if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE:
        return flash_attention(
            q=q,
            k=k,
            v=v,
            q_lens=q_lens,
            k_lens=k_lens,
            dropout_p=dropout_p,
            softmax_scale=softmax_scale,
            q_scale=q_scale,
            causal=causal,
            window_size=window_size,
            deterministic=deterministic,
            dtype=dtype,
            version=fa_version,
        )
    else:
        if q_lens is not None or k_lens is not None:
            warnings.warn(
                "Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance."
            )
        attn_mask = None

        q = q.transpose(1, 2).to(dtype)
        k = k.transpose(1, 2).to(dtype)
        v = v.transpose(1, 2).to(dtype)

        out = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p
        )

        out = out.transpose(1, 2).contiguous()
        return out
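`attention` keeps a single call signature across three backends: FlashAttention-3, FlashAttention-2, and a plain `scaled_dot_product_attention` fallback when neither is installed. All of them take `[B, L, N, C]` tensors (batch, length, heads, head dim). A CPU-only sanity check of the fallback path might look like the sketch below; the shapes are illustrative, and it exercises the SDPA branch only on a machine without flash-attn installed.

```python
import torch
from ldf_models.tools.attention import attention

# Fallback-path sketch: with no flash-attn package present this routes to
# torch.nn.functional.scaled_dot_product_attention. Inputs use the
# [B, L, N, C] layout assumed throughout this repo.
b, l, n, c = 2, 16, 8, 64
q = torch.randn(b, l, n, c)
k = torch.randn(b, l, n, c)
v = torch.randn(b, l, n, c)

out = attention(q, k, v, causal=True, dtype=torch.bfloat16)
print(out.shape)  # torch.Size([2, 16, 8, 64])
```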
ldf_models/tools/t5.py
ADDED
@@ -0,0 +1,564 @@
# Modified from transformers.models.t5.modeling_t5
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import logging
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from .tokenizers import HuggingfaceTokenizer

__all__ = [
    "T5Model",
    "T5Encoder",
    "T5Decoder",
    "T5EncoderModel",
]


def fp16_clamp(x):
    if x.dtype == torch.float16 and torch.isinf(x).any():
        clamp = torch.finfo(x.dtype).max - 1000
        x = torch.clamp(x, min=-clamp, max=clamp)
    return x


def init_weights(m):
    if isinstance(m, T5LayerNorm):
        nn.init.ones_(m.weight)
    elif isinstance(m, T5Model):
        nn.init.normal_(m.token_embedding.weight, std=1.0)
    elif isinstance(m, T5FeedForward):
        nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5)
        nn.init.normal_(m.fc1.weight, std=m.dim**-0.5)
        nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5)
    elif isinstance(m, T5Attention):
        nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn) ** -0.5)
        nn.init.normal_(m.k.weight, std=m.dim**-0.5)
        nn.init.normal_(m.v.weight, std=m.dim**-0.5)
        nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn) ** -0.5)
    elif isinstance(m, T5RelativeEmbedding):
        nn.init.normal_(
            m.embedding.weight, std=(2 * m.num_buckets * m.num_heads) ** -0.5
        )


class GELU(nn.Module):
    def forward(self, x):
        return (
            0.5
            * x
            * (
                1.0
                + torch.tanh(
                    math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))
                )
            )
        )


class T5LayerNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super(T5LayerNorm, self).__init__()
        self.dim = dim
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + self.eps)
        if self.weight.dtype in [torch.float16, torch.bfloat16]:
            x = x.type_as(self.weight)
        return self.weight * x


class T5Attention(nn.Module):
    def __init__(self, dim, dim_attn, num_heads, dropout=0.1):
        assert dim_attn % num_heads == 0
        super(T5Attention, self).__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.num_heads = num_heads
        self.head_dim = dim_attn // num_heads

        # layers
        self.q = nn.Linear(dim, dim_attn, bias=False)
        self.k = nn.Linear(dim, dim_attn, bias=False)
        self.v = nn.Linear(dim, dim_attn, bias=False)
        self.o = nn.Linear(dim_attn, dim, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, context=None, mask=None, pos_bias=None):
        """
        x: [B, L1, C].
        context: [B, L2, C] or None.
        mask: [B, L2] or [B, L1, L2] or None.
        """
        # check inputs
        context = x if context is None else context
        b, n, c = x.size(0), self.num_heads, self.head_dim

        # compute query, key, value
        q = self.q(x).view(b, -1, n, c)
        k = self.k(context).view(b, -1, n, c)
        v = self.v(context).view(b, -1, n, c)

        # attention bias
        attn_bias = x.new_zeros(b, n, q.size(1), k.size(1))
        if pos_bias is not None:
            attn_bias += pos_bias
        if mask is not None:
            assert mask.ndim in [2, 3]
            mask = mask.view(b, 1, 1, -1) if mask.ndim == 2 else mask.unsqueeze(1)
            attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min)

        # compute attention (T5 does not use scaling)
        attn = torch.einsum("binc,bjnc->bnij", q, k) + attn_bias
        attn = F.softmax(attn.float(), dim=-1).type_as(attn)
        x = torch.einsum("bnij,bjnc->binc", attn, v)

        # output
        x = x.reshape(b, -1, n * c)
        x = self.o(x)
        x = self.dropout(x)
        return x


class T5FeedForward(nn.Module):
    def __init__(self, dim, dim_ffn, dropout=0.1):
        super(T5FeedForward, self).__init__()
        self.dim = dim
        self.dim_ffn = dim_ffn

        # layers
        self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU())
        self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
        self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.fc1(x) * self.gate(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x


class T5SelfAttention(nn.Module):
    def __init__(
        self,
        dim,
        dim_attn,
        dim_ffn,
        num_heads,
        num_buckets,
        shared_pos=True,
        dropout=0.1,
    ):
        super(T5SelfAttention, self).__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.num_buckets = num_buckets
        self.shared_pos = shared_pos

        # layers
        self.norm1 = T5LayerNorm(dim)
        self.attn = T5Attention(dim, dim_attn, num_heads, dropout)
        self.norm2 = T5LayerNorm(dim)
        self.ffn = T5FeedForward(dim, dim_ffn, dropout)
        self.pos_embedding = (
            None
            if shared_pos
            else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True)
        )

    def forward(self, x, mask=None, pos_bias=None):
        e = pos_bias if self.shared_pos else self.pos_embedding(x.size(1), x.size(1))
        x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
        x = fp16_clamp(x + self.ffn(self.norm2(x)))
        return x


class T5CrossAttention(nn.Module):
    def __init__(
        self,
        dim,
        dim_attn,
        dim_ffn,
        num_heads,
        num_buckets,
        shared_pos=True,
        dropout=0.1,
    ):
        super(T5CrossAttention, self).__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.num_buckets = num_buckets
        self.shared_pos = shared_pos

        # layers
        self.norm1 = T5LayerNorm(dim)
        self.self_attn = T5Attention(dim, dim_attn, num_heads, dropout)
        self.norm2 = T5LayerNorm(dim)
        self.cross_attn = T5Attention(dim, dim_attn, num_heads, dropout)
        self.norm3 = T5LayerNorm(dim)
        self.ffn = T5FeedForward(dim, dim_ffn, dropout)
        self.pos_embedding = (
            None
            if shared_pos
            else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=False)
        )

    def forward(
        self, x, mask=None, encoder_states=None, encoder_mask=None, pos_bias=None
    ):
        e = pos_bias if self.shared_pos else self.pos_embedding(x.size(1), x.size(1))
        x = fp16_clamp(x + self.self_attn(self.norm1(x), mask=mask, pos_bias=e))
        x = fp16_clamp(
            x
            + self.cross_attn(self.norm2(x), context=encoder_states, mask=encoder_mask)
        )
        x = fp16_clamp(x + self.ffn(self.norm3(x)))
        return x


class T5RelativeEmbedding(nn.Module):
    def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
        super(T5RelativeEmbedding, self).__init__()
        self.num_buckets = num_buckets
        self.num_heads = num_heads
        self.bidirectional = bidirectional
        self.max_dist = max_dist

        # layers
        self.embedding = nn.Embedding(num_buckets, num_heads)

    def forward(self, lq, lk):
        device = self.embedding.weight.device
        # rel_pos = torch.arange(lk).unsqueeze(0).to(device) - \
        #     torch.arange(lq).unsqueeze(1).to(device)
        rel_pos = torch.arange(lk, device=device).unsqueeze(0) - torch.arange(
            lq, device=device
        ).unsqueeze(1)
        rel_pos = self._relative_position_bucket(rel_pos)
        rel_pos_embeds = self.embedding(rel_pos)
        rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze(0)  # [1, N, Lq, Lk]
        return rel_pos_embeds.contiguous()

    def _relative_position_bucket(self, rel_pos):
        # preprocess
        if self.bidirectional:
            num_buckets = self.num_buckets // 2
            rel_buckets = (rel_pos > 0).long() * num_buckets
            rel_pos = torch.abs(rel_pos)
        else:
            num_buckets = self.num_buckets
            rel_buckets = 0
            rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos))

        # embeddings for small and large positions
        max_exact = num_buckets // 2
        rel_pos_large = (
            max_exact
            + (
                torch.log(rel_pos.float() / max_exact)
                / math.log(self.max_dist / max_exact)
                * (num_buckets - max_exact)
            ).long()
        )
        rel_pos_large = torch.min(
            rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1)
        )
        rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large)
        return rel_buckets


class T5Encoder(nn.Module):
    def __init__(
        self,
        vocab,
        dim,
        dim_attn,
        dim_ffn,
        num_heads,
        num_layers,
        num_buckets,
        shared_pos=True,
        dropout=0.1,
    ):
        super(T5Encoder, self).__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.num_buckets = num_buckets
        self.shared_pos = shared_pos

        # layers
        self.token_embedding = (
            vocab if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim)
        )
        self.pos_embedding = (
            T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True)
            if shared_pos
            else None
        )
        self.dropout = nn.Dropout(dropout)
        self.blocks = nn.ModuleList(
            [
                T5SelfAttention(
                    dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout
                )
                for _ in range(num_layers)
            ]
        )
        self.norm = T5LayerNorm(dim)

        # initialize weights
        self.apply(init_weights)

    def forward(self, ids, mask=None):
        x = self.token_embedding(ids)
        x = self.dropout(x)
        e = self.pos_embedding(x.size(1), x.size(1)) if self.shared_pos else None
        for block in self.blocks:
            x = block(x, mask, pos_bias=e)
        x = self.norm(x)
        x = self.dropout(x)
        return x


class T5Decoder(nn.Module):
    def __init__(
        self,
        vocab,
        dim,
        dim_attn,
        dim_ffn,
        num_heads,
        num_layers,
        num_buckets,
        shared_pos=True,
        dropout=0.1,
    ):
        super(T5Decoder, self).__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.num_buckets = num_buckets
        self.shared_pos = shared_pos

        # layers
        self.token_embedding = (
            vocab if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim)
        )
        self.pos_embedding = (
            T5RelativeEmbedding(num_buckets, num_heads, bidirectional=False)
            if shared_pos
            else None
        )
        self.dropout = nn.Dropout(dropout)
        self.blocks = nn.ModuleList(
            [
                T5CrossAttention(
                    dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout
                )
                for _ in range(num_layers)
            ]
        )
        self.norm = T5LayerNorm(dim)

        # initialize weights
        self.apply(init_weights)

    def forward(self, ids, mask=None, encoder_states=None, encoder_mask=None):
        b, s = ids.size()

        # causal mask
        if mask is None:
            mask = torch.tril(torch.ones(1, s, s).to(ids.device))
        elif mask.ndim == 2:
            mask = torch.tril(mask.unsqueeze(1).expand(-1, s, -1))

        # layers
        x = self.token_embedding(ids)
        x = self.dropout(x)
        e = self.pos_embedding(x.size(1), x.size(1)) if self.shared_pos else None
        for block in self.blocks:
            x = block(x, mask, encoder_states, encoder_mask, pos_bias=e)
        x = self.norm(x)
        x = self.dropout(x)
        return x


class T5Model(nn.Module):
    def __init__(
        self,
        vocab_size,
        dim,
        dim_attn,
        dim_ffn,
        num_heads,
        encoder_layers,
        decoder_layers,
        num_buckets,
        shared_pos=True,
        dropout=0.1,
    ):
        super(T5Model, self).__init__()
        self.vocab_size = vocab_size
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.encoder_layers = encoder_layers
        self.decoder_layers = decoder_layers
        self.num_buckets = num_buckets

        # layers
        self.token_embedding = nn.Embedding(vocab_size, dim)
        self.encoder = T5Encoder(
            self.token_embedding,
            dim,
            dim_attn,
            dim_ffn,
            num_heads,
            encoder_layers,
            num_buckets,
            shared_pos,
            dropout,
        )
        self.decoder = T5Decoder(
            self.token_embedding,
            dim,
            dim_attn,
            dim_ffn,
            num_heads,
            decoder_layers,
            num_buckets,
            shared_pos,
            dropout,
        )
        self.head = nn.Linear(dim, vocab_size, bias=False)

        # initialize weights
        self.apply(init_weights)

    def forward(self, encoder_ids, encoder_mask, decoder_ids, decoder_mask):
        x = self.encoder(encoder_ids, encoder_mask)
        x = self.decoder(decoder_ids, decoder_mask, x, encoder_mask)
        x = self.head(x)
        return x


def _t5(
    name,
    encoder_only=False,
    decoder_only=False,
    return_tokenizer=False,
    tokenizer_kwargs={},
    dtype=torch.float32,
    device="cpu",
    **kwargs,
):
    # sanity check
    assert not (encoder_only and decoder_only)

    # params
    if encoder_only:
        model_cls = T5Encoder
        kwargs["vocab"] = kwargs.pop("vocab_size")
        kwargs["num_layers"] = kwargs.pop("encoder_layers")
        _ = kwargs.pop("decoder_layers")
    elif decoder_only:
        model_cls = T5Decoder
        kwargs["vocab"] = kwargs.pop("vocab_size")
        kwargs["num_layers"] = kwargs.pop("decoder_layers")
        _ = kwargs.pop("encoder_layers")
    else:
        model_cls = T5Model

    # init model
    with torch.device(device):
        model = model_cls(**kwargs)

    # set device
    model = model.to(dtype=dtype, device=device)

    # init tokenizer
    if return_tokenizer:
        from .tokenizers import HuggingfaceTokenizer

        tokenizer = HuggingfaceTokenizer(f"google/{name}", **tokenizer_kwargs)
        return model, tokenizer
    else:
        return model


def umt5_xxl(**kwargs):
    cfg = dict(
        vocab_size=256384,
        dim=4096,
        dim_attn=4096,
        dim_ffn=10240,
        num_heads=64,
        encoder_layers=24,
        decoder_layers=24,
        num_buckets=32,
        shared_pos=False,
        dropout=0.1,
    )
    cfg.update(**kwargs)
    return _t5("umt5-xxl", **cfg)


class T5EncoderModel:
    def __init__(
        self,
        text_len,
        dtype=torch.bfloat16,
        device=torch.cuda.current_device(),
        checkpoint_path=None,
        tokenizer_path=None,
        shard_fn=None,
    ):
        self.text_len = text_len
        self.dtype = dtype
        self.device = device
        self.checkpoint_path = checkpoint_path
        self.tokenizer_path = tokenizer_path

        # init model
        model = (
            umt5_xxl(
                encoder_only=True, return_tokenizer=False, dtype=dtype, device=device
            )
            .eval()
            .requires_grad_(False)
        )
        logging.info(f"loading {checkpoint_path}")
        model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
        self.model = model
        if shard_fn is not None:
            self.model = shard_fn(self.model, sync_module_states=False)
        else:
            self.model.to(self.device)
        # init tokenizer
        self.tokenizer = HuggingfaceTokenizer(
            name=tokenizer_path, seq_len=text_len, clean="whitespace"
        )

    def __call__(self, texts, device):
        ids, mask = self.tokenizer(texts, return_mask=True, add_special_tokens=True)
        ids = ids.to(device)
        mask = mask.to(device)
        seq_lens = mask.gt(0).sum(dim=1).long()
        context = self.model(ids, mask)
        return [u[:v] for u, v in zip(context, seq_lens)]
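`T5EncoderModel` is a thin wrapper that builds the encoder-only umT5-XXL, loads the frozen weights, and returns one variable-length context tensor per prompt, trimmed to each prompt's true token count. A usage sketch, assuming a CUDA device and the checkpoint/tokenizer files committed under `ldf_deps/` in this repo (the `text_len` value here is illustrative):

```python
import torch
from ldf_models.tools.t5 import T5EncoderModel

# Sketch only: needs a GPU and the multi-GB umT5-XXL weights from ldf_deps/.
text_encoder = T5EncoderModel(
    text_len=512,
    dtype=torch.bfloat16,
    device=torch.device("cuda"),
    checkpoint_path="ldf_deps/t5_umt5-xxl-enc-bf16/models_t5_umt5-xxl-enc-bf16.pth",
    tokenizer_path="ldf_deps/t5_umt5-xxl-enc-bf16/google/umt5-xxl",
)

contexts = text_encoder(["a person walks forward.", "a person sits down."], "cuda")
for ctx in contexts:
    print(ctx.shape)  # (num_tokens_i, 4096), one trimmed tensor per prompt
```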
ldf_models/tools/tokenizers.py
ADDED
@@ -0,0 +1,84 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import html
import string

import ftfy
import regex as re
from transformers import AutoTokenizer

__all__ = ["HuggingfaceTokenizer"]


def basic_clean(text):
    text = ftfy.fix_text(text)
    text = html.unescape(html.unescape(text))
    return text.strip()


def whitespace_clean(text):
    text = re.sub(r"\s+", " ", text)
    text = text.strip()
    return text


def canonicalize(text, keep_punctuation_exact_string=None):
    text = text.replace("_", " ")
    if keep_punctuation_exact_string:
        text = keep_punctuation_exact_string.join(
            part.translate(str.maketrans("", "", string.punctuation))
            for part in text.split(keep_punctuation_exact_string)
        )
    else:
        text = text.translate(str.maketrans("", "", string.punctuation))
    text = text.lower()
    text = re.sub(r"\s+", " ", text)
    return text.strip()


class HuggingfaceTokenizer:
    def __init__(self, name, seq_len=None, clean=None, **kwargs):
        assert clean in (None, "whitespace", "lower", "canonicalize")
        self.name = name
        self.seq_len = seq_len
        self.clean = clean

        # init tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs)
        self.vocab_size = self.tokenizer.vocab_size

    def __call__(self, sequence, **kwargs):
        return_mask = kwargs.pop("return_mask", False)

        # arguments
        _kwargs = {"return_tensors": "pt"}
        if self.seq_len is not None:
            _kwargs.update(
                {
                    "padding": "max_length",
                    "truncation": True,
                    "max_length": self.seq_len,
                }
            )
        _kwargs.update(**kwargs)

        # tokenization
        if isinstance(sequence, str):
            sequence = [sequence]
        if self.clean:
            sequence = [self._clean(u) for u in sequence]
        ids = self.tokenizer(sequence, **_kwargs)

        # output
        if return_mask:
            return ids.input_ids, ids.attention_mask
        else:
            return ids.input_ids

    def _clean(self, text):
        if self.clean == "whitespace":
            text = whitespace_clean(basic_clean(text))
        elif self.clean == "lower":
            text = whitespace_clean(basic_clean(text)).lower()
        elif self.clean == "canonicalize":
            text = canonicalize(basic_clean(text))
        return text
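`HuggingfaceTokenizer` wraps `AutoTokenizer` with optional `ftfy`-based text cleaning and fixed-length padding/truncation when `seq_len` is set. A small sketch of the call convention `T5EncoderModel` uses above; the tokenizer path matches the files committed under `ldf_deps/`, and the `seq_len` value is illustrative:

```python
from ldf_models.tools.tokenizers import HuggingfaceTokenizer

# Sketch: pads/truncates to seq_len and returns (ids, attention_mask)
# when return_mask=True, mirroring how T5EncoderModel calls it.
tokenizer = HuggingfaceTokenizer(
    name="ldf_deps/t5_umt5-xxl-enc-bf16/google/umt5-xxl",
    seq_len=512,
    clean="whitespace",
)

ids, mask = tokenizer(
    ["A person   walks\nforward."], return_mask=True, add_special_tokens=True
)
print(ids.shape, mask.shape)  # torch.Size([1, 512]) torch.Size([1, 512])
print(int(mask.sum()))        # number of real (non-padding) tokens
```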
ldf_models/tools/wan_model.py
ADDED
@@ -0,0 +1,592 @@
# This module uses modified code from Alibaba Wan Team
# Original source: https://github.com/Wan-Video/Wan2.2
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
# Modified to support stream mode for cross-attention.
# Added causal attention for self-attention (1d case).
# Added context length correction.

import math

import torch
import torch.nn as nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin

from .attention import flash_attention


def sinusoidal_embedding_1d(dim, position):
    # preprocess
    assert dim % 2 == 0
    half = dim // 2
    position = position.type(torch.float64)

    # calculation
    sinusoid = torch.outer(
        position, torch.pow(10000, -torch.arange(half).to(position).div(half))
    )
    x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
    return x


@torch.amp.autocast("cuda", enabled=False)
def rope_params(max_seq_len, dim, theta=10000):
    assert dim % 2 == 0
    freqs = torch.outer(
        torch.arange(max_seq_len),
        1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float64).div(dim)),
    )
    freqs = torch.polar(torch.ones_like(freqs), freqs)
    return freqs


@torch.amp.autocast("cuda", enabled=False)
def rope_apply(x, grid_sizes, freqs):
    n, c = x.size(2), x.size(3) // 2

    # split freqs
    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)

    # loop over samples
    output = []
    for i, (f, h, w) in enumerate(grid_sizes.tolist()):
        seq_len = f * h * w

        # precompute multipliers
        x_i = torch.view_as_complex(
            x[i, :seq_len].to(torch.float64).reshape(seq_len, n, -1, 2)
        )
        freqs_i = torch.cat(
            [
                freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
                freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
                freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
            ],
            dim=-1,
        ).reshape(seq_len, 1, -1)

        # apply rotary embedding
        x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
        x_i = torch.cat([x_i, x[i, seq_len:]])

        # append to collection
        output.append(x_i)
    return torch.stack(output).float()


class WanRMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        r"""
        Args:
            x(Tensor): Shape [B, L, C]
        """
        return self._norm(x.float()).type_as(x) * self.weight

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)


class WanLayerNorm(nn.LayerNorm):
    def __init__(self, dim, eps=1e-6, elementwise_affine=False):
        super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps)

    def forward(self, x):
        r"""
        Args:
            x(Tensor): Shape [B, L, C]
        """
        return super().forward(x.float()).type_as(x)


class WanSelfAttention(nn.Module):
    def __init__(
        self, dim, num_heads, window_size=(-1, -1), qk_norm=True, eps=1e-6, causal=False
    ):
        assert dim % num_heads == 0
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.window_size = window_size
        self.qk_norm = qk_norm
        self.eps = eps
        self.causal = causal
        # layers
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.o = nn.Linear(dim, dim)
        self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
        self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()

    def forward(self, x, seq_lens, grid_sizes, freqs):
        r"""
        Args:
            x(Tensor): Shape [B, L, num_heads, C / num_heads]
            seq_lens(Tensor): Shape [B]
            grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
            freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
        """
        b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim

        # query, key, value function
        def qkv_fn(x):
            q = self.norm_q(self.q(x)).view(b, s, n, d)
            k = self.norm_k(self.k(x)).view(b, s, n, d)
            v = self.v(x).view(b, s, n, d)
            return q, k, v

        q, k, v = qkv_fn(x)

        x = flash_attention(
            q=rope_apply(q, grid_sizes, freqs),
            k=rope_apply(k, grid_sizes, freqs),
            v=v,
            k_lens=seq_lens,
            window_size=self.window_size,
            causal=self.causal,
        )

        # output
        x = x.flatten(2)
        x = self.o(x)
        return x


class WanCrossAttention(WanSelfAttention):
    def forward(self, x, context, context_lens):
        r"""
        Args non-stream mode:
            x(Tensor): Shape [B, L1, C]
            context(Tensor): Shape [B, L2, C]
            context_lens(Tensor): Shape [B]
        Args stream mode:
            x(Tensor): Shape [B, L1, C]
            context(Tensor): Shape [BxL1, L2, C]
            context_lens(Tensor): Shape [BxL1]
        """
        out_sizes = x.size()
        b, n, d = context.size(0), self.num_heads, self.head_dim

        # compute query, key, value
        q = self.norm_q(self.q(x)).view(b, -1, n, d)
        k = self.norm_k(self.k(context)).view(b, -1, n, d)
        v = self.v(context).view(b, -1, n, d)

        # compute attention
        x = flash_attention(q, k, v, k_lens=context_lens)

        # output
        x = x.flatten(2).view(*out_sizes)
        x = self.o(x)
        return x


class WanAttentionBlock(nn.Module):
    def __init__(
        self,
        dim,
        ffn_dim,
        num_heads,
        window_size=(-1, -1),
        qk_norm=True,
        cross_attn_norm=False,
        eps=1e-6,
        causal=False,
    ):
        super().__init__()
        self.dim = dim
        self.ffn_dim = ffn_dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.qk_norm = qk_norm
        self.cross_attn_norm = cross_attn_norm
        self.eps = eps
        self.causal = causal
        # layers
        self.norm1 = WanLayerNorm(dim, eps)
        self.self_attn = WanSelfAttention(
            dim, num_heads, window_size, qk_norm, eps, causal
        )
        self.norm3 = (
            WanLayerNorm(dim, eps, elementwise_affine=True)
            if cross_attn_norm
            else nn.Identity()
        )

        self.cross_attn = WanCrossAttention(dim, num_heads, (-1, -1), qk_norm, eps)
        self.norm2 = WanLayerNorm(dim, eps)
        self.ffn = nn.Sequential(
            nn.Linear(dim, ffn_dim),
            nn.GELU(approximate="tanh"),
            nn.Linear(ffn_dim, dim),
        )

        # modulation
        self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)

    def forward(
        self,
        x,
        e,
        seq_lens,
        grid_sizes,
        freqs,
        context,
        context_lens,
    ):
        r"""
        Args:
            x(Tensor): Shape [B, L, C]
            e(Tensor): Shape [B, L1, 6, C]
            seq_lens(Tensor): Shape [B], length of each sequence in batch
            grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
            freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
        """
        assert e.dtype == torch.float32
        with torch.amp.autocast("cuda", dtype=torch.float32):
            e = (self.modulation.unsqueeze(0) + e).chunk(6, dim=2)
        assert e[0].dtype == torch.float32

        # self-attention
        y = self.self_attn(
            self.norm1(x).float() * (1 + e[1].squeeze(2)) + e[0].squeeze(2),
            seq_lens,
            grid_sizes,
            freqs,
        )
        with torch.amp.autocast("cuda", dtype=torch.float32):
            x = x + y * e[2].squeeze(2)
+
# cross-attention & ffn function
|
| 268 |
+
def cross_attn_ffn(x, context, context_lens, e):
|
| 269 |
+
x = x + self.cross_attn(self.norm3(x), context, context_lens)
|
| 270 |
+
y = self.ffn(
|
| 271 |
+
self.norm2(x).float() * (1 + e[4].squeeze(2)) + e[3].squeeze(2)
|
| 272 |
+
)
|
| 273 |
+
with torch.amp.autocast("cuda", dtype=torch.float32):
|
| 274 |
+
x = x + y * e[5].squeeze(2)
|
| 275 |
+
return x
|
| 276 |
+
|
| 277 |
+
x = cross_attn_ffn(x, context, context_lens, e)
|
| 278 |
+
return x
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
class Head(nn.Module):
|
| 282 |
+
def __init__(self, dim, out_dim, patch_size, eps=1e-6):
|
| 283 |
+
super().__init__()
|
| 284 |
+
self.dim = dim
|
| 285 |
+
self.out_dim = out_dim
|
| 286 |
+
self.patch_size = patch_size
|
| 287 |
+
self.eps = eps
|
| 288 |
+
|
| 289 |
+
# layers
|
| 290 |
+
out_dim = math.prod(patch_size) * out_dim
|
| 291 |
+
self.norm = WanLayerNorm(dim, eps)
|
| 292 |
+
self.head = nn.Linear(dim, out_dim)
|
| 293 |
+
|
| 294 |
+
# modulation
|
| 295 |
+
self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
|
| 296 |
+
|
| 297 |
+
def forward(self, x, e):
|
| 298 |
+
r"""
|
| 299 |
+
Args:
|
| 300 |
+
x(Tensor): Shape [B, L1, C]
|
| 301 |
+
e(Tensor): Shape [B, L1, C]
|
| 302 |
+
"""
|
| 303 |
+
assert e.dtype == torch.float32
|
| 304 |
+
with torch.amp.autocast("cuda", dtype=torch.float32):
|
| 305 |
+
e = (self.modulation.unsqueeze(0) + e.unsqueeze(2)).chunk(2, dim=2)
|
| 306 |
+
x = self.head(self.norm(x) * (1 + e[1].squeeze(2)) + e[0].squeeze(2))
|
| 307 |
+
return x
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
class WanModel(ModelMixin, ConfigMixin):
|
| 311 |
+
r"""
|
| 312 |
+
Wan diffusion backbone supporting both text-to-video and image-to-video.
|
| 313 |
+
"""
|
| 314 |
+
|
| 315 |
+
ignore_for_config = [
|
| 316 |
+
"patch_size",
|
| 317 |
+
"cross_attn_norm",
|
| 318 |
+
"qk_norm",
|
| 319 |
+
"text_dim",
|
| 320 |
+
"window_size",
|
| 321 |
+
]
|
| 322 |
+
_no_split_modules = ["WanAttentionBlock"]
|
| 323 |
+
|
| 324 |
+
@register_to_config
|
| 325 |
+
def __init__(
|
| 326 |
+
self,
|
| 327 |
+
model_type="t2v",
|
| 328 |
+
patch_size=(1, 2, 2),
|
| 329 |
+
text_len=512,
|
| 330 |
+
in_dim=16,
|
| 331 |
+
dim=2048,
|
| 332 |
+
ffn_dim=8192,
|
| 333 |
+
freq_dim=256,
|
| 334 |
+
text_dim=4096,
|
| 335 |
+
out_dim=16,
|
| 336 |
+
num_heads=16,
|
| 337 |
+
num_layers=32,
|
| 338 |
+
window_size=(-1, -1),
|
| 339 |
+
qk_norm=True,
|
| 340 |
+
cross_attn_norm=True,
|
| 341 |
+
eps=1e-6,
|
| 342 |
+
causal=False,
|
| 343 |
+
):
|
| 344 |
+
r"""
|
| 345 |
+
Initialize the diffusion model backbone.
|
| 346 |
+
|
| 347 |
+
Args:
|
| 348 |
+
model_type (`str`, *optional*, defaults to 't2v'):
|
| 349 |
+
Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video)
|
| 350 |
+
patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
|
| 351 |
+
3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
|
| 352 |
+
text_len (`int`, *optional*, defaults to 512):
|
| 353 |
+
Fixed length for text embeddings
|
| 354 |
+
in_dim (`int`, *optional*, defaults to 16):
|
| 355 |
+
Input video channels (C_in)
|
| 356 |
+
dim (`int`, *optional*, defaults to 2048):
|
| 357 |
+
Hidden dimension of the transformer
|
| 358 |
+
ffn_dim (`int`, *optional*, defaults to 8192):
|
| 359 |
+
Intermediate dimension in feed-forward network
|
| 360 |
+
freq_dim (`int`, *optional*, defaults to 256):
|
| 361 |
+
Dimension for sinusoidal time embeddings
|
| 362 |
+
text_dim (`int`, *optional*, defaults to 4096):
|
| 363 |
+
Input dimension for text embeddings
|
| 364 |
+
out_dim (`int`, *optional*, defaults to 16):
|
| 365 |
+
Output video channels (C_out)
|
| 366 |
+
num_heads (`int`, *optional*, defaults to 16):
|
| 367 |
+
Number of attention heads
|
| 368 |
+
num_layers (`int`, *optional*, defaults to 32):
|
| 369 |
+
Number of transformer blocks
|
| 370 |
+
window_size (`tuple`, *optional*, defaults to (-1, -1)):
|
| 371 |
+
Window size for local attention (-1 indicates global attention)
|
| 372 |
+
qk_norm (`bool`, *optional*, defaults to True):
|
| 373 |
+
Enable query/key normalization
|
| 374 |
+
cross_attn_norm (`bool`, *optional*, defaults to False):
|
| 375 |
+
Enable cross-attention normalization
|
| 376 |
+
eps (`float`, *optional*, defaults to 1e-6):
|
| 377 |
+
Epsilon value for normalization layers
|
| 378 |
+
"""
|
| 379 |
+
|
| 380 |
+
super().__init__()
|
| 381 |
+
|
| 382 |
+
assert model_type in ["t2v", "i2v", "ti2v", "s2v"]
|
| 383 |
+
self.model_type = model_type
|
| 384 |
+
|
| 385 |
+
self.patch_size = patch_size
|
| 386 |
+
self.text_len = text_len
|
| 387 |
+
self.in_dim = in_dim
|
| 388 |
+
self.dim = dim
|
| 389 |
+
self.ffn_dim = ffn_dim
|
| 390 |
+
self.freq_dim = freq_dim
|
| 391 |
+
self.text_dim = text_dim
|
| 392 |
+
self.out_dim = out_dim
|
| 393 |
+
self.num_heads = num_heads
|
| 394 |
+
self.num_layers = num_layers
|
| 395 |
+
self.window_size = window_size
|
| 396 |
+
self.qk_norm = qk_norm
|
| 397 |
+
self.cross_attn_norm = cross_attn_norm
|
| 398 |
+
self.eps = eps
|
| 399 |
+
self.causal = causal
|
| 400 |
+
# embeddings
|
| 401 |
+
self.patch_embedding = nn.Conv3d(
|
| 402 |
+
in_dim, dim, kernel_size=patch_size, stride=patch_size
|
| 403 |
+
)
|
| 404 |
+
self.text_embedding = nn.Sequential(
|
| 405 |
+
nn.Linear(text_dim, dim), nn.GELU(approximate="tanh"), nn.Linear(dim, dim)
|
| 406 |
+
)
|
| 407 |
+
|
| 408 |
+
self.time_embedding = nn.Sequential(
|
| 409 |
+
nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim)
|
| 410 |
+
)
|
| 411 |
+
self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
|
| 412 |
+
|
| 413 |
+
# blocks
|
| 414 |
+
self.blocks = nn.ModuleList(
|
| 415 |
+
[
|
| 416 |
+
WanAttentionBlock(
|
| 417 |
+
dim,
|
| 418 |
+
ffn_dim,
|
| 419 |
+
num_heads,
|
| 420 |
+
window_size,
|
| 421 |
+
qk_norm,
|
| 422 |
+
cross_attn_norm,
|
| 423 |
+
eps,
|
| 424 |
+
causal,
|
| 425 |
+
)
|
| 426 |
+
for _ in range(num_layers)
|
| 427 |
+
]
|
| 428 |
+
)
|
| 429 |
+
|
| 430 |
+
# head
|
| 431 |
+
self.head = Head(dim, out_dim, patch_size, eps)
|
| 432 |
+
|
| 433 |
+
# buffers (don't use register_buffer otherwise dtype will be changed in to())
|
| 434 |
+
assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
|
| 435 |
+
d = dim // num_heads
|
| 436 |
+
self.freqs = torch.cat(
|
| 437 |
+
[
|
| 438 |
+
rope_params(1024, d - 4 * (d // 6)),
|
| 439 |
+
rope_params(1024, 2 * (d // 6)),
|
| 440 |
+
rope_params(1024, 2 * (d // 6)),
|
| 441 |
+
],
|
| 442 |
+
dim=1,
|
| 443 |
+
)
|
| 444 |
+
|
| 445 |
+
# initialize weights
|
| 446 |
+
self.init_weights()
|
| 447 |
+
|
| 448 |
+
def forward(
|
| 449 |
+
self,
|
| 450 |
+
x,
|
| 451 |
+
t,
|
| 452 |
+
context,
|
| 453 |
+
seq_len,
|
| 454 |
+
y=None,
|
| 455 |
+
):
|
| 456 |
+
r"""
|
| 457 |
+
Forward pass through the diffusion model
|
| 458 |
+
|
| 459 |
+
Args:
|
| 460 |
+
x (List[Tensor]):
|
| 461 |
+
List of input video tensors, each with shape [C_in, F, H, W]
|
| 462 |
+
t (Tensor):
|
| 463 |
+
Diffusion timesteps tensor of shape [B]
|
| 464 |
+
context (List[Tensor]):
|
| 465 |
+
List of text embeddings each with shape [L, C]
|
| 466 |
+
seq_len (`int`):
|
| 467 |
+
Maximum sequence length for positional encoding
|
| 468 |
+
y (List[Tensor], *optional*):
|
| 469 |
+
Conditional video inputs for image-to-video mode, same shape as x
|
| 470 |
+
|
| 471 |
+
Returns:
|
| 472 |
+
List[Tensor]:
|
| 473 |
+
List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
|
| 474 |
+
"""
|
| 475 |
+
if self.model_type == "i2v":
|
| 476 |
+
assert y is not None
|
| 477 |
+
# params
|
| 478 |
+
device = self.patch_embedding.weight.device
|
| 479 |
+
if self.freqs.device != device:
|
| 480 |
+
self.freqs = self.freqs.to(device)
|
| 481 |
+
|
| 482 |
+
if y is not None:
|
| 483 |
+
x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
|
| 484 |
+
|
| 485 |
+
# embeddings
|
| 486 |
+
x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
|
| 487 |
+
grid_sizes = torch.stack(
|
| 488 |
+
[torch.tensor(u.shape[2:], dtype=torch.long) for u in x]
|
| 489 |
+
)
|
| 490 |
+
x = [u.flatten(2).transpose(1, 2) for u in x]
|
| 491 |
+
seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
|
| 492 |
+
assert seq_lens.max() <= seq_len
|
| 493 |
+
x = torch.cat(
|
| 494 |
+
[
|
| 495 |
+
torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1)
|
| 496 |
+
for u in x
|
| 497 |
+
]
|
| 498 |
+
)
|
| 499 |
+
|
| 500 |
+
# time embeddings
|
| 501 |
+
if t.dim() == 1: # bs
|
| 502 |
+
t = t.expand(t.size(0), seq_len)
|
| 503 |
+
with torch.amp.autocast("cuda", dtype=torch.float32):
|
| 504 |
+
bt = t.size(0)
|
| 505 |
+
t = t.flatten()
|
| 506 |
+
e = self.time_embedding(
|
| 507 |
+
sinusoidal_embedding_1d(self.freq_dim, t)
|
| 508 |
+
.unflatten(0, (bt, seq_len))
|
| 509 |
+
.float()
|
| 510 |
+
)
|
| 511 |
+
e0 = self.time_projection(e).unflatten(2, (6, self.dim))
|
| 512 |
+
assert e.dtype == torch.float32 and e0.dtype == torch.float32
|
| 513 |
+
|
| 514 |
+
# context
|
| 515 |
+
context_lens = torch.tensor([u.size(0) for u in context], dtype=torch.long)
|
| 516 |
+
context = self.text_embedding(
|
| 517 |
+
torch.stack(
|
| 518 |
+
[
|
| 519 |
+
torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
|
| 520 |
+
for u in context
|
| 521 |
+
]
|
| 522 |
+
)
|
| 523 |
+
)
|
| 524 |
+
|
| 525 |
+
# arguments
|
| 526 |
+
kwargs = dict(
|
| 527 |
+
e=e0,
|
| 528 |
+
seq_lens=seq_lens,
|
| 529 |
+
grid_sizes=grid_sizes,
|
| 530 |
+
freqs=self.freqs,
|
| 531 |
+
context=context,
|
| 532 |
+
context_lens=context_lens,
|
| 533 |
+
)
|
| 534 |
+
|
| 535 |
+
for block in self.blocks:
|
| 536 |
+
x = block(x, **kwargs)
|
| 537 |
+
|
| 538 |
+
# head
|
| 539 |
+
x = self.head(x, e)
|
| 540 |
+
|
| 541 |
+
# unpatchify
|
| 542 |
+
x = self.unpatchify(x, grid_sizes)
|
| 543 |
+
return [u.float() for u in x]
|
| 544 |
+
|
| 545 |
+
def unpatchify(self, x, grid_sizes):
|
| 546 |
+
r"""
|
| 547 |
+
Reconstruct video tensors from patch embeddings.
|
| 548 |
+
|
| 549 |
+
Args:
|
| 550 |
+
x (List[Tensor]):
|
| 551 |
+
List of patchified features, each with shape [L, C_out * prod(patch_size)]
|
| 552 |
+
grid_sizes (Tensor):
|
| 553 |
+
Original spatial-temporal grid dimensions before patching,
|
| 554 |
+
shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)
|
| 555 |
+
|
| 556 |
+
Returns:
|
| 557 |
+
List[Tensor]:
|
| 558 |
+
Reconstructed video tensors with shape [C_out, F, H / 8, W / 8]
|
| 559 |
+
"""
|
| 560 |
+
|
| 561 |
+
c = self.out_dim
|
| 562 |
+
out = []
|
| 563 |
+
for u, v in zip(x, grid_sizes.tolist()):
|
| 564 |
+
u = u[: math.prod(v)].view(*v, *self.patch_size, c)
|
| 565 |
+
u = torch.einsum("fhwpqrc->cfphqwr", u)
|
| 566 |
+
u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
|
| 567 |
+
out.append(u)
|
| 568 |
+
return out
|
| 569 |
+
|
| 570 |
+
def init_weights(self):
|
| 571 |
+
r"""
|
| 572 |
+
Initialize model parameters using Xavier initialization.
|
| 573 |
+
"""
|
| 574 |
+
|
| 575 |
+
# basic init
|
| 576 |
+
for m in self.modules():
|
| 577 |
+
if isinstance(m, nn.Linear):
|
| 578 |
+
nn.init.xavier_uniform_(m.weight)
|
| 579 |
+
if m.bias is not None:
|
| 580 |
+
nn.init.zeros_(m.bias)
|
| 581 |
+
|
| 582 |
+
# init embeddings
|
| 583 |
+
nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))
|
| 584 |
+
for m in self.text_embedding.modules():
|
| 585 |
+
if isinstance(m, nn.Linear):
|
| 586 |
+
nn.init.normal_(m.weight, std=0.02)
|
| 587 |
+
for m in self.time_embedding.modules():
|
| 588 |
+
if isinstance(m, nn.Linear):
|
| 589 |
+
nn.init.normal_(m.weight, std=0.02)
|
| 590 |
+
|
| 591 |
+
# init output layer
|
| 592 |
+
nn.init.zeros_(self.head.head.weight)
|
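The three `rope_params` calls at the end of `WanModel.__init__` split each attention head's channels into one temporal and two spatial rotary bands, which `rope_apply` then indexes by frame, height, and width. A small standalone check of that arithmetic for the default configuration (`dim=2048`, `num_heads=16`); the numbers below are computed purely for illustration and are not part of the committed code:

```python
# Sanity-check the RoPE band split used in WanModel.__init__.
d = 2048 // 16  # per-head channel dim with the default config: 128

# Band widths as built by the three rope_params(...) calls: temporal, height, width.
bands = [d - 4 * (d // 6), 2 * (d // 6), 2 * (d // 6)]

assert sum(bands) == d  # the bands exactly tile the per-head dimension
print(bands)            # [44, 42, 42] -> the frame axis gets the widest band
```

This matches the per-axis split inside `rope_apply`, where the complex-valued frequency table is divided as `[c - 2*(c//3), c//3, c//3]` over `c = d // 2` rotation pairs.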
ldf_models/tools/wan_vae_1d.py
ADDED
@@ -0,0 +1,762 @@
# This module uses modified code from Alibaba Wan Team
# Original source: https://github.com/Wan-Video/Wan2.2
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
# Modified to support 1d features with (B, C, T)

import torch
import torch.nn as nn
import torch.nn.functional as F

CACHE_T = 2


class CausalConv1d(nn.Conv1d):
    """
    Causal 1d convolution.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._padding = (
            2 * self.padding[0],
            0,
        )
        self.padding = (0,)

    def forward(self, x, cache_x=None):
        padding = list(self._padding)
        if cache_x is not None and self._padding[0] > 0:
            cache_x = cache_x.to(x.device)
            x = torch.cat([cache_x, x], dim=2)
            padding[0] -= cache_x.shape[2]
        x = F.pad(x, padding)

        return super().forward(x)


class RMS_norm(nn.Module):
    def __init__(self, dim, channel_first=True, bias=False):
        super().__init__()
        broadcastable_dims = (1,)
        shape = (dim, *broadcastable_dims) if channel_first else (dim,)

        self.channel_first = channel_first
        self.scale = dim**0.5
        self.gamma = nn.Parameter(torch.ones(shape))
        self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0

    def forward(self, x):
        return (
            F.normalize(x, dim=(1 if self.channel_first else -1))
            * self.scale
            * self.gamma
            + self.bias
        )


class Upsample(nn.Upsample):
    def forward(self, x):
        """
        Fix bfloat16 support for nearest neighbor interpolation.
        """
        return super().forward(x.float()).type_as(x)


class Resample(nn.Module):
    def __init__(self, dim, mode):
        assert mode in (
            "upsample1d",
            "downsample1d",
        )
        super().__init__()
        self.dim = dim
        self.mode = mode

        # layers
        if mode == "upsample1d":
            self.time_conv = CausalConv1d(dim, dim * 2, (3,), padding=(1,))
        elif mode == "downsample1d":
            self.time_conv = CausalConv1d(dim, dim, (3,), stride=(2,), padding=(0,))

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        b, c, t = x.size()
        if self.mode == "upsample1d":
            if feat_cache is not None:
                idx = feat_idx[0]
                if feat_cache[idx] is None:
                    feat_cache[idx] = "Rep"
                    feat_idx[0] += 1
                else:
                    cache_x = x[:, :, -CACHE_T:].clone()
                    if (
                        cache_x.shape[2] < 2
                        and feat_cache[idx] is not None
                        and feat_cache[idx] != "Rep"
                    ):
                        # cache the last frame of the last two chunks
                        cache_x = torch.cat(
                            [
                                feat_cache[idx][:, :, -1]
                                .unsqueeze(2)
                                .to(cache_x.device),
                                cache_x,
                            ],
                            dim=2,
                        )
                    if (
                        cache_x.shape[2] < 2
                        and feat_cache[idx] is not None
                        and feat_cache[idx] == "Rep"
                    ):
                        cache_x = torch.cat(
                            [torch.zeros_like(cache_x).to(cache_x.device), cache_x],
                            dim=2,
                        )
                    if feat_cache[idx] == "Rep":
                        x = self.time_conv(x)
                    else:
                        x = self.time_conv(x, feat_cache[idx])
                    feat_cache[idx] = cache_x
                    feat_idx[0] += 1
                    x = x.reshape(b, 2, c, t)
                    x = torch.stack((x[:, 0, :, :], x[:, 1, :, :]), 3)
                    x = x.reshape(b, c, t * 2)

        if self.mode == "downsample1d":
            if feat_cache is not None:
                idx = feat_idx[0]
                if feat_cache[idx] is None:
                    feat_cache[idx] = x.clone()
                    feat_idx[0] += 1
                else:
                    cache_x = x[:, :, -1:].clone()
                    x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:], x], 2))
                    feat_cache[idx] = cache_x
                    feat_idx[0] += 1
        return x


class ResidualBlock(nn.Module):
    def __init__(self, in_dim, out_dim, dropout=0.0):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim

        # layers
        self.residual = nn.Sequential(
            RMS_norm(in_dim),
            nn.SiLU(),
            CausalConv1d(in_dim, out_dim, 3, padding=1),
            RMS_norm(out_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            CausalConv1d(out_dim, out_dim, 3, padding=1),
        )
        self.shortcut = (
            CausalConv1d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity()
        )

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        h = self.shortcut(x)
        for layer in self.residual:
            if isinstance(layer, CausalConv1d) and feat_cache is not None:
                idx = feat_idx[0]
                cache_x = x[:, :, -CACHE_T:].clone()
                if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                    # cache the last frame of the last two chunks
                    cache_x = torch.cat(
                        [
                            feat_cache[idx][:, :, -1].unsqueeze(2).to(cache_x.device),
                            cache_x,
                        ],
                        dim=2,
                    )
                x = layer(x, feat_cache[idx])
                feat_cache[idx] = cache_x
                feat_idx[0] += 1
            else:
                x = layer(x)
        return x + h


class AvgDown1D(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        factor_t,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.factor_t = factor_t
        self.factor = self.factor_t

        assert in_channels * self.factor % out_channels == 0
        self.group_size = in_channels * self.factor // out_channels

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        pad_t = (self.factor_t - x.shape[2] % self.factor_t) % self.factor_t
        pad = (pad_t, 0)
        x = F.pad(x, pad)
        B, C, T = x.shape
        x = x.view(
            B,
            C,
            T // self.factor_t,
            self.factor_t,
        )
        x = x.permute(0, 1, 3, 2).contiguous()
        x = x.view(
            B,
            C * self.factor,
            T // self.factor_t,
        )
        x = x.view(
            B,
            self.out_channels,
            self.group_size,
            T // self.factor_t,
        )
        x = x.mean(dim=2)
        return x


class DupUp1D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        factor_t,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.factor_t = factor_t
        self.factor = self.factor_t

        assert out_channels * self.factor % in_channels == 0
        self.repeats = out_channels * self.factor // in_channels

    def forward(self, x: torch.Tensor, first_chunk=False) -> torch.Tensor:
        x = x.repeat_interleave(self.repeats, dim=1)
        x = x.view(
            x.size(0),
            self.out_channels,
            self.factor_t,
            x.size(2),
        )
        x = x.permute(0, 1, 3, 2).contiguous()
        x = x.view(
            x.size(0),
            self.out_channels,
            x.size(2) * self.factor_t,
        )
        if first_chunk:
            x = x[
                :,
                :,
                self.factor_t - 1 :,
            ]
        return x


class Down_ResidualBlock(nn.Module):
    def __init__(self, in_dim, out_dim, dropout, mult, temperal_downsample=False):
        super().__init__()

        # Shortcut path with downsample
        if temperal_downsample:
            self.avg_shortcut = AvgDown1D(
                in_dim,
                out_dim,
                factor_t=2,
            )
        else:
            self.avg_shortcut = None

        # Main path with residual blocks and downsample
        downsamples = []
        for _ in range(mult):
            downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
            in_dim = out_dim

        # Add the final downsample block
        if temperal_downsample:
            downsamples.append(Resample(out_dim, mode="downsample1d"))

        self.downsamples = nn.Sequential(*downsamples)

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        x_copy = x.clone()
        for module in self.downsamples:
            x = module(x, feat_cache, feat_idx)
        if self.avg_shortcut is None:
            return x
        else:
            return x + self.avg_shortcut(x_copy)


class Up_ResidualBlock(nn.Module):
    def __init__(self, in_dim, out_dim, dropout, mult, temperal_upsample=False):
        super().__init__()
        # Shortcut path with upsample
        if temperal_upsample:
            self.avg_shortcut = DupUp1D(
                in_dim,
                out_dim,
                factor_t=2,
            )
        else:
            self.avg_shortcut = None

        # Main path with residual blocks and upsample
        upsamples = []
        for _ in range(mult):
            upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
            in_dim = out_dim

        # Add the final upsample block
        if temperal_upsample:
            upsamples.append(Resample(out_dim, mode="upsample1d"))

        self.upsamples = nn.Sequential(*upsamples)

    def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
        x_main = x.clone()
        for module in self.upsamples:
            x_main = module(x_main, feat_cache, feat_idx)
        if self.avg_shortcut is not None:
            x_shortcut = self.avg_shortcut(x, first_chunk)
            return x_main + x_shortcut
        else:
            return x_main


class Encoder1d(nn.Module):
    def __init__(
        self,
        input_dim,
        dim=128,
        z_dim=4,
        dim_mult=[1, 2, 4, 4],
        num_res_blocks=2,
        temperal_downsample=[True, True, False],
        dropout=0.0,
    ):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.temperal_downsample = temperal_downsample

        # dimensions
        dims = [dim * u for u in [1] + dim_mult]
        scale = 1.0

        # init block
        self.conv1 = CausalConv1d(input_dim, dims[0], 3, padding=1)

        # downsample blocks
        downsamples = []
        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
            t_down_flag = (
                temperal_downsample[i] if i < len(temperal_downsample) else False
            )
            downsamples.append(
                Down_ResidualBlock(
                    in_dim=in_dim,
                    out_dim=out_dim,
                    dropout=dropout,
                    mult=num_res_blocks,
                    temperal_downsample=t_down_flag,
                )
            )
            scale /= 2.0
        self.downsamples = nn.Sequential(*downsamples)

        # middle blocks
        self.middle = nn.Sequential(
            ResidualBlock(out_dim, out_dim, dropout),
            RMS_norm(out_dim),
            CausalConv1d(out_dim, out_dim, 1),
            ResidualBlock(out_dim, out_dim, dropout),
        )

        # output blocks
        self.head = nn.Sequential(
            RMS_norm(out_dim),
            nn.SiLU(),
            CausalConv1d(out_dim, z_dim, 3, padding=1),
        )

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        if feat_cache is not None:
            idx = feat_idx[0]
            cache_x = x[:, :, -CACHE_T:].clone()
            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                cache_x = torch.cat(
                    [
                        feat_cache[idx][:, :, -1].unsqueeze(2).to(cache_x.device),
                        cache_x,
                    ],
                    dim=2,
                )
            x = self.conv1(x, feat_cache[idx])
            feat_cache[idx] = cache_x
            feat_idx[0] += 1
        else:
            x = self.conv1(x)

        ## downsamples
        for layer in self.downsamples:
            if feat_cache is not None:
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        ## middle
        for layer in self.middle:
            if isinstance(layer, ResidualBlock) and feat_cache is not None:
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        ## head
        for layer in self.head:
            if isinstance(layer, CausalConv1d) and feat_cache is not None:
                idx = feat_idx[0]
                cache_x = x[:, :, -CACHE_T:].clone()
                if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                    cache_x = torch.cat(
                        [
                            feat_cache[idx][:, :, -1].unsqueeze(2).to(cache_x.device),
                            cache_x,
                        ],
                        dim=2,
                    )
                x = layer(x, feat_cache[idx])
                feat_cache[idx] = cache_x
                feat_idx[0] += 1
            else:
                x = layer(x)

        return x


class Decoder1d(nn.Module):
    def __init__(
        self,
        output_dim,
        dim=128,
        z_dim=4,
        dim_mult=[1, 2, 4, 4],
        num_res_blocks=2,
        temperal_upsample=[False, True, True],
        dropout=0.0,
    ):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.temperal_upsample = temperal_upsample

        # dimensions
        dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
        scale = 1.0 / 2 ** (len(dim_mult) - 2)
        # init block
        self.conv1 = CausalConv1d(z_dim, dims[0], 3, padding=1)

        # middle blocks
        self.middle = nn.Sequential(
            ResidualBlock(dims[0], dims[0], dropout),
            RMS_norm(dims[0]),
            CausalConv1d(dims[0], dims[0], 1),
            ResidualBlock(dims[0], dims[0], dropout),
        )

        # upsample blocks
        upsamples = []
        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
            t_up_flag = temperal_upsample[i] if i < len(temperal_upsample) else False
            upsamples.append(
                Up_ResidualBlock(
                    in_dim=in_dim,
                    out_dim=out_dim,
                    dropout=dropout,
                    mult=num_res_blocks + 1,
                    temperal_upsample=t_up_flag,
                )
            )
        self.upsamples = nn.Sequential(*upsamples)

        # output blocks
        self.head = nn.Sequential(
            RMS_norm(out_dim),
            nn.SiLU(),
            CausalConv1d(out_dim, output_dim, 3, padding=1),
        )

    def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
        if feat_cache is not None:
            idx = feat_idx[0]
            cache_x = x[:, :, -CACHE_T:].clone()
            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                cache_x = torch.cat(
                    [
                        feat_cache[idx][:, :, -1].unsqueeze(2).to(cache_x.device),
                        cache_x,
                    ],
                    dim=2,
                )
            x = self.conv1(x, feat_cache[idx])
            feat_cache[idx] = cache_x
            feat_idx[0] += 1
        else:
            x = self.conv1(x)

        for layer in self.middle:
            if isinstance(layer, ResidualBlock) and feat_cache is not None:
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        ## upsamples
        for layer in self.upsamples:
            if feat_cache is not None:
                x = layer(x, feat_cache, feat_idx, first_chunk)
            else:
                x = layer(x)

        ## head
        for layer in self.head:
            if isinstance(layer, CausalConv1d) and feat_cache is not None:
                idx = feat_idx[0]
                cache_x = x[:, :, -CACHE_T:].clone()
                if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                    cache_x = torch.cat(
                        [
                            feat_cache[idx][:, :, -1].unsqueeze(2).to(cache_x.device),
                            cache_x,
                        ],
                        dim=2,
                    )
                x = layer(x, feat_cache[idx])
                feat_cache[idx] = cache_x
                feat_idx[0] += 1
            else:
                x = layer(x)
        return x


def count_conv1d(model):
    count = 0
    for m in model.modules():
        if isinstance(m, CausalConv1d):
            count += 1
    return count


class WanVAE_(nn.Module):
    def __init__(
        self,
        input_dim,
        dim=160,
        dec_dim=256,
        z_dim=16,
        dim_mult=[1, 2, 4, 4],
        num_res_blocks=1,
        temperal_downsample=[True, True, False],
        dropout=0.0,
    ):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.temperal_downsample = temperal_downsample
        self.temperal_upsample = temperal_downsample[::-1]

        # modules
        self.encoder = Encoder1d(
            input_dim,
            dim,
            z_dim * 2,
            dim_mult,
            num_res_blocks,
            self.temperal_downsample,
            dropout,
        )
        self.conv1 = CausalConv1d(z_dim * 2, z_dim * 2, 1)
        self.conv2 = CausalConv1d(z_dim, z_dim, 1)
        self.decoder = Decoder1d(
            input_dim,
            dec_dim,
            z_dim,
            dim_mult,
            num_res_blocks,
            self.temperal_upsample,
            dropout,
        )

    def forward(self, x, scale=[0, 1]):
        mu = self.encode(x, scale)
        x_recon = self.decode(mu, scale)
        return x_recon, mu

    def encode(self, x, scale, return_dist=False):
        self.clear_cache()
        t = x.shape[2]
        iter_ = 1 + (t - 1) // 4
        for i in range(iter_):
            self._enc_conv_idx = [0]
            if i == 0:
                out = self.encoder(
                    x[:, :, :1],
                    feat_cache=self._enc_feat_map,
                    feat_idx=self._enc_conv_idx,
                )
            else:
                out_ = self.encoder(
                    x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i],
                    feat_cache=self._enc_feat_map,
                    feat_idx=self._enc_conv_idx,
                )
                out = torch.cat([out, out_], 2)
        mu, log_var = self.conv1(out).chunk(2, dim=1)
        if isinstance(scale[0], torch.Tensor):
            mu = (mu - scale[0].view(1, self.z_dim, 1)) * scale[1].view(
                1, self.z_dim, 1
            )
        else:
            mu = (mu - scale[0]) * scale[1]
        self.clear_cache()
        if return_dist:
            return mu, log_var
        return mu

    def decode(self, z, scale):
        self.clear_cache()
        if isinstance(scale[0], torch.Tensor):
            z = z / scale[1].view(1, self.z_dim, 1) + scale[0].view(1, self.z_dim, 1)
        else:
            z = z / scale[1] + scale[0]
        iter_ = z.shape[2]
        x = self.conv2(z)
        for i in range(iter_):
            self._conv_idx = [0]
            if i == 0:
                out = self.decoder(
                    x[:, :, i : i + 1],
                    feat_cache=self._feat_map,
                    feat_idx=self._conv_idx,
                    first_chunk=True,
                )
            else:
                out_ = self.decoder(
                    x[:, :, i : i + 1],
                    feat_cache=self._feat_map,
                    feat_idx=self._conv_idx,
                )
                out = torch.cat([out, out_], 2)
        self.clear_cache()
        return out

    @torch.no_grad()
    def stream_encode(self, x, first_chunk, scale, return_dist=False):
        t = x.shape[2]
        if first_chunk:
            iter_ = 1 + (t - 1) // 4
        else:
            iter_ = t // 4
        for i in range(iter_):
            self._enc_conv_idx = [0]
            if i == 0:
                if first_chunk:
                    out = self.encoder(
                        x[:, :, :1],
                        feat_cache=self._enc_feat_map,
                        feat_idx=self._enc_conv_idx,
                    )
                else:
                    out = self.encoder(
                        x[:, :, :4],
                        feat_cache=self._enc_feat_map,
                        feat_idx=self._enc_conv_idx,
                    )
            else:
                if first_chunk:
                    out_ = self.encoder(
                        x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i],
                        feat_cache=self._enc_feat_map,
                        feat_idx=self._enc_conv_idx,
                    )
                else:
                    out_ = self.encoder(
                        x[:, :, 4 * i : 4 * (i + 1)],
                        feat_cache=self._enc_feat_map,
                        feat_idx=self._enc_conv_idx,
                    )
                out = torch.cat([out, out_], 2)
        mu, log_var = self.conv1(out).chunk(2, dim=1)
        if isinstance(scale[0], torch.Tensor):
            mu = (mu - scale[0].view(1, self.z_dim, 1)) * scale[1].view(
                1, self.z_dim, 1
            )
        else:
            mu = (mu - scale[0]) * scale[1]
        if return_dist:
            return mu, log_var
        else:
            return mu

    @torch.no_grad()
    def stream_decode(self, z, first_chunk, scale):
        if isinstance(scale[0], torch.Tensor):
            z = z / scale[1].view(1, self.z_dim, 1) + scale[0].view(1, self.z_dim, 1)
        else:
            z = z / scale[1] + scale[0]
        iter_ = z.shape[2]
        x = self.conv2(z)
        for i in range(iter_):
            self._conv_idx = [0]
            if i == 0:
                out = self.decoder(
                    x[:, :, i : i + 1],
                    feat_cache=self._feat_map,
                    feat_idx=self._conv_idx,
                    first_chunk=first_chunk,  # use the external first_chunk parameter
                )
            else:
                out_ = self.decoder(
                    x[:, :, i : i + 1],
                    feat_cache=self._feat_map,
                    feat_idx=self._conv_idx,
                    first_chunk=False,  # explicitly False for subsequent steps within the same chunk
                )
                out = torch.cat([out, out_], 2)
        return out

    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return eps * std + mu

    def sample(self, imgs, deterministic=False):
        # encode() requires a scale and only returns (mu, log_var) when
        # return_dist=True; pass both so this helper actually runs.
        mu, log_var = self.encode(imgs, scale=[0, 1], return_dist=True)
        if deterministic:
            return mu
        std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
        return mu + std * torch.randn_like(std)

    def clear_cache(self):
        self._conv_num = count_conv1d(self.decoder)
        self._conv_idx = [0]
        self._feat_map = [None] * self._conv_num
        # cache encode
        self._enc_conv_num = count_conv1d(self.encoder)
        self._enc_conv_idx = [0]
        self._enc_feat_map = [None] * self._enc_conv_num
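`CausalConv1d` is what makes chunked streaming possible: it front-pads by twice the nominal padding and lets the caller substitute cached frames for that pad. A minimal equivalence check, assuming this repo's `ldf_models.tools.wan_vae_1d` import path resolves from the repo root; the chunk split at frame 8 is an arbitrary choice for the demo:

```python
import torch
from ldf_models.tools.wan_vae_1d import CausalConv1d  # path as laid out in this repo

torch.manual_seed(0)
conv = CausalConv1d(4, 4, 3, padding=1).eval()
x = torch.randn(1, 4, 12)  # (B, C, T)

with torch.no_grad():
    full = conv(x)  # one pass over the whole sequence

    # streaming: the second chunk reuses the last CACHE_T (= 2) frames
    # of the first chunk in place of zero padding
    y1 = conv(x[:, :, :8])
    y2 = conv(x[:, :, 8:], cache_x=x[:, :, 8 - 2 : 8])
    chunked = torch.cat([y1, y2], dim=2)

assert torch.allclose(full, chunked, atol=1e-6)  # chunked == full-sequence output
```

The same cache mechanism, threaded through `feat_cache`/`feat_idx`, is what `Encoder1d`, `Decoder1d`, and the `Resample` blocks above rely on to keep `stream_encode`/`stream_decode` consistent with the offline pass.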
ldf_models/vae_wan_1d.py
ADDED
@@ -0,0 +1,212 @@
import numpy as np
import torch
import torch.nn as nn

from .tools.wan_vae_1d import WanVAE_


class VAEWanModel(nn.Module):
    def __init__(
        self,
        input_dim,
        mean_path=None,
        std_path=None,
        z_dim=256,
        dim=160,
        dec_dim=512,
        num_res_blocks=1,
        dropout=0.0,
        dim_mult=[1, 1, 1],
        temperal_downsample=[True, True],
        vel_window=[0, 0],
        **kwargs,
    ):
        super().__init__()

        self.mean_path = mean_path
        self.std_path = std_path
        self.input_dim = input_dim
        self.z_dim = z_dim
        self.dim = dim
        self.dec_dim = dec_dim
        self.num_res_blocks = num_res_blocks
        self.dropout = dropout
        self.dim_mult = dim_mult
        self.temperal_downsample = temperal_downsample
        self.vel_window = vel_window
        self.RECONS_LOSS = nn.SmoothL1Loss()
        self.LAMBDA_FEATURE = kwargs.get("LAMBDA_FEATURE", 1.0)
        self.LAMBDA_VELOCITY = kwargs.get("LAMBDA_VELOCITY", 0.5)
        self.LAMBDA_KL = kwargs.get("LAMBDA_KL", 10e-6)

        if self.mean_path is not None:
            self.register_buffer(
                "mean", torch.from_numpy(np.load(self.mean_path)).float()
            )
        else:
            self.register_buffer("mean", torch.zeros(input_dim))

        if self.std_path is not None:
            self.register_buffer(
                "std", torch.from_numpy(np.load(self.std_path)).float()
            )
        else:
            self.register_buffer("std", torch.ones(input_dim))

        self.model = WanVAE_(
            input_dim=self.input_dim,
            dim=self.dim,
            dec_dim=self.dec_dim,
            z_dim=self.z_dim,
            dim_mult=self.dim_mult,
            num_res_blocks=self.num_res_blocks,
            temperal_downsample=self.temperal_downsample,
            dropout=self.dropout,
        )

        downsample_factor = 1
        for flag in self.temperal_downsample:
            if flag:
                downsample_factor *= 2
        self.downsample_factor = downsample_factor

    def preprocess(self, x):
        # (bs, T, C) -> (bs, C, T)
        x = x.permute(0, 2, 1)
        return x

    def postprocess(self, x):
        # (bs, C, T) -> (bs, T, C)
        x = x.permute(0, 2, 1)
        return x

    def forward(self, x):
        features = x["feature"]
        feature_length = x["feature_length"]
        features = (features - self.mean) / self.std
        # create mask based on feature_length
        batch_size, seq_len = features.shape[:2]
        mask = torch.zeros(
            batch_size, seq_len, dtype=torch.bool, device=features.device
        )
        for i in range(batch_size):
            mask[i, : feature_length[i]] = True

        x_in = self.preprocess(features)  # (bs, input_dim, T)
        mu, log_var = self.model.encode(
            x_in, scale=[0, 1], return_dist=True
        )  # (bs, z_dim, T)
        z = self.model.reparameterize(mu, log_var)
        x_decoder = self.model.decode(z, scale=[0, 1])  # (bs, input_dim, T)
        x_out = self.postprocess(x_decoder)  # (bs, T, input_dim)

        if x_out.size(1) != features.size(1):
            min_len = min(x_out.size(1), features.size(1))
            x_out = x_out[:, :min_len, :]
            features = features[:, :min_len, :]
            mask = mask[:, :min_len]

        mask_expanded = mask.unsqueeze(-1)
        x_out_masked = x_out * mask_expanded
        features_masked = features * mask_expanded
        loss_recons = self.RECONS_LOSS(x_out_masked, features_masked)
        vel_start = self.vel_window[0]
        vel_end = self.vel_window[1]
        loss_vel = self.RECONS_LOSS(
            x_out_masked[..., vel_start:vel_end],
            features_masked[..., vel_start:vel_end],
        )

        # Compute KL divergence loss
        # KL(N(mu, sigma) || N(0, 1)) = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
        # log_var = log(sigma^2), so we can use it directly

        # Build mask for latent space
        T_latent = mu.size(2)
        mask_downsampled = torch.zeros(
            batch_size, T_latent, dtype=torch.bool, device=features.device
        )
        for i in range(batch_size):
            latent_length = (
                feature_length[i] + self.downsample_factor - 1
            ) // self.downsample_factor
            mask_downsampled[i, :latent_length] = True
        mask_latent = mask_downsampled.unsqueeze(1)  # (B, 1, T_latent)

        # Compute KL loss per element
        kl_per_element = -0.5 * (1 + log_var - mu.pow(2) - log_var.exp())
        # Apply mask: only compute KL loss for valid timesteps
        kl_masked = kl_per_element * mask_latent
        # Sum over all dimensions and normalize by the number of valid elements
        kl_loss = torch.sum(kl_masked) / (
            torch.sum(mask_downsampled) * mu.size(1)
        )  # normalize by valid timesteps * latent_dim

        # Total loss
        total_loss = (
            self.LAMBDA_FEATURE * loss_recons
            + self.LAMBDA_VELOCITY * loss_vel
            + self.LAMBDA_KL * kl_loss
        )

        loss_dict = {}
        loss_dict["total"] = total_loss
        loss_dict["recons"] = loss_recons
        loss_dict["velocity"] = loss_vel
        loss_dict["kl"] = kl_loss

        return loss_dict

    def encode(self, x):
        x = (x - self.mean) / self.std
        x_in = self.preprocess(x)  # (bs, T, input_dim) -> (bs, input_dim, T)
        mu = self.model.encode(x_in, scale=[0, 1])  # (bs, z_dim, T)
        mu = self.postprocess(mu)  # (bs, T, z_dim)
        return mu

    def decode(self, mu):
        mu_in = self.preprocess(mu)  # (bs, T, z_dim) -> (bs, z_dim, T)
        x_decoder = self.model.decode(mu_in, scale=[0, 1])  # (bs, z_dim, T)
        x_out = self.postprocess(x_decoder)  # (bs, T, input_dim)
        x_out = x_out * self.std + self.mean
        return x_out

    @torch.no_grad()
    def stream_encode(self, x, first_chunk=True):
        x = (x - self.mean) / self.std
        x_in = self.preprocess(x)  # (bs, input_dim, T)
        mu = self.model.stream_encode(x_in, first_chunk=first_chunk, scale=[0, 1])
        mu = self.postprocess(mu)  # (bs, T, z_dim)
        return mu

    @torch.no_grad()
    def stream_decode(self, mu, first_chunk=True):
        mu_in = self.preprocess(mu)  # (bs, z_dim, T)
        x_decoder = self.model.stream_decode(
            mu_in, first_chunk=first_chunk, scale=[0, 1]
        )
        x_out = self.postprocess(x_decoder)  # (bs, T, input_dim)
        x_out = x_out * self.std + self.mean
        return x_out

    def clear_cache(self):
        self.model.clear_cache()

    def generate(self, x):
        features = x["feature"]
        feature_length = x["feature_length"]
        y_hat = self.decode(self.encode(features))

        y_hat_out = []

        for i in range(y_hat.shape[0]):
            # cut off the padding and align lengths to the downsample grid
            valid_len = (
                feature_length[i] - 1
            ) // self.downsample_factor * self.downsample_factor + 1
            y_hat_out.append(y_hat[i, :valid_len, :])

        out = {}
        out["generated"] = y_hat_out
        return out
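Putting the pieces together, `VAEWanModel` wraps the same causal backbone in an offline path (`encode`/`decode`) and a streaming path (`stream_encode`/`stream_decode`). A hypothetical shape walk-through with random weights; `input_dim=64` and the chunk sizes are placeholders (the real feature width comes from `ldf.yaml`), and sequence lengths of the form 1 + 4k line up with the 4x temporal downsample implied by `temperal_downsample=[True, True]`:

```python
import torch
from ldf_models.vae_wan_1d import VAEWanModel

vae = VAEWanModel(input_dim=64).eval()  # 64 is a placeholder feature width

# offline round trip: 17 = 1 + 4*4 frames -> 5 latent frames and back
motion = torch.randn(1, 17, 64)             # (B, T, C)
latent = vae.encode(motion)                 # expected (1, 5, 256)
recon = vae.decode(latent)                  # expected (1, 17, 64)

# streaming: reset the conv caches, then feed chunks; only the first
# chunk carries the extra leading frame and sets first_chunk=True
vae.clear_cache()
z0 = vae.stream_encode(motion[:, :9], first_chunk=True)    # 1 + 4*2 frames -> 3 latents
z1 = vae.stream_encode(motion[:, 9:], first_chunk=False)   # 4*2 frames -> 2 latents
```

The `clear_cache()` call matters: the streaming methods deliberately do not reset `feat_cache` between calls, which is what lets later chunks see the causal context of earlier ones.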
ldf_utils/__init__.py
ADDED
File without changes
ldf_utils/initialize.py
ADDED
@@ -0,0 +1,286 @@
import argparse
import os
import shutil
import time
from datetime import datetime
from importlib import import_module
from pathlib import Path
from typing import Any, Dict, Optional

import torch
from lightning.pytorch.utilities import rank_zero_info
from omegaconf import OmegaConf


class Config:
    def __init__(self, config_path: str = None, override_args: Dict[str, Any] = None):
        self.config = OmegaConf.create({})

        # Load main config if provided
        if config_path:
            self.load_yaml(config_path)
        if override_args:
            self.override_config(override_args)

    def load_yaml(self, config_path: str):
        """Load YAML configuration file"""
        loaded_config = OmegaConf.load(config_path)
        self.config = OmegaConf.merge(self.config, loaded_config)

    def override_config(self, override_args: Dict[str, Any]):
        """Handle command line override arguments"""
        for key, value in override_args.items():
            # Convert the raw string once, then apply it with OmegaConf.update,
            # which resolves dotted keys (e.g. "model.dim") without re-parsing
            # the value. This avoids the dotlist round-trip mis-typing values
            # such as paths ending in ".yaml".
            val = self._convert_value(value)
            OmegaConf.update(self.config, key, val)

    def _convert_value(self, value: str) -> Any:
        """Convert string value to appropriate type"""
        if value.lower() == "true":
            return True
        elif value.lower() == "false":
            return False
        elif value.lower() == "null":
            return None
        try:
            return int(value)
        except ValueError:
            try:
                return float(value)
            except ValueError:
                return value

    def get(self, key: str, default: Any = None) -> Any:
        """Get configuration value"""
        return OmegaConf.select(self.config, key, default=default)

    def __getattr__(self, name: str) -> Any:
        """Support dot notation access"""
        return self.config[name]

    def __getitem__(self, key: str) -> Any:
        """Support dictionary-like access"""
        return self.config[key]

    def export_config(self, path: str):
        """Export current configuration to file"""
        OmegaConf.save(self.config, path)


def parse_args():
    """Parse command line arguments"""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config", type=str, required=True, help="Path to config file"
    )
    parser.add_argument(
        "--override", type=str, nargs="+", help="Override config values (key=value)"
    )
    return parser.parse_args()


def load_config(
    config_path: Optional[str] = None, override_args: Optional[Dict[str, Any]] = None
) -> Config:
    """Load configuration"""
    if config_path is None:
        args = parse_args()
        config_path = args.config
        if args.override:
            override_args = {}
for override in args.override:
|
| 106 |
+
key, value = override.split("=", 1)
|
| 107 |
+
override_args[key.strip()] = value.strip()
|
| 108 |
+
|
| 109 |
+
return Config(config_path, override_args)
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def instantiate(target, cfg=None, hfstyle=False, **init_args):
|
| 113 |
+
module_name, class_name = target.rsplit(".", 1)
|
| 114 |
+
module = import_module(module_name)
|
| 115 |
+
class_ = getattr(module, class_name)
|
| 116 |
+
if cfg is None:
|
| 117 |
+
return class_(**init_args)
|
| 118 |
+
else:
|
| 119 |
+
if hfstyle:
|
| 120 |
+
config_class = class_.config_class
|
| 121 |
+
cfg = config_class(config_obj=cfg)
|
| 122 |
+
return class_(cfg, **init_args)
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def get_function(target):
|
| 126 |
+
module_name, function_name = target.rsplit(".", 1)
|
| 127 |
+
module = import_module(module_name)
|
| 128 |
+
function_ = getattr(module, function_name)
|
| 129 |
+
return function_
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def save_config_and_codes(config, save_dir):
|
| 133 |
+
os.makedirs(save_dir, exist_ok=True)
|
| 134 |
+
sanity_check_dir = os.path.join(save_dir, "sanity_check")
|
| 135 |
+
os.makedirs(sanity_check_dir, exist_ok=True)
|
| 136 |
+
with open(os.path.join(sanity_check_dir, f"{config.exp_name}.yaml"), "w") as f:
|
| 137 |
+
OmegaConf.save(config.config, f)
|
| 138 |
+
current_dir = Path.cwd()
|
| 139 |
+
exclude_dir = current_dir / "outputs"
|
| 140 |
+
for py_file in current_dir.rglob("*.py"):
|
| 141 |
+
if exclude_dir in py_file.parents:
|
| 142 |
+
continue
|
| 143 |
+
dest_path = Path(sanity_check_dir) / py_file.relative_to(current_dir)
|
| 144 |
+
dest_path.parent.mkdir(parents=True, exist_ok=True)
|
| 145 |
+
shutil.copy(py_file, dest_path)
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def print_model_size(model):
|
| 149 |
+
total_params = sum(p.numel() for p in model.parameters())
|
| 150 |
+
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
| 151 |
+
rank_zero_info(f"Total parameters: {total_params:,}")
|
| 152 |
+
rank_zero_info(f"Trainable parameters: {trainable_params:,}")
|
| 153 |
+
rank_zero_info(f"Non-trainable parameters: {(total_params - trainable_params):,}")
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def compare_statedict_and_parameters(state_dict, named_parameters, named_buffers):
|
| 157 |
+
"""Compare differences between state_dict and parameters"""
|
| 158 |
+
# Get all keys in state_dict
|
| 159 |
+
state_dict_keys = set(state_dict.keys())
|
| 160 |
+
|
| 161 |
+
# Get all keys in named_parameters
|
| 162 |
+
named_params_keys = set(name for name, _ in named_parameters)
|
| 163 |
+
|
| 164 |
+
# Find keys that only exist in state_dict
|
| 165 |
+
only_in_state_dict = state_dict_keys - named_params_keys
|
| 166 |
+
|
| 167 |
+
# Find keys that only exist in named_parameters
|
| 168 |
+
only_in_named_params = named_params_keys - state_dict_keys
|
| 169 |
+
|
| 170 |
+
# Print results
|
| 171 |
+
if only_in_state_dict:
|
| 172 |
+
print(f"Only in state_dict (not in parameters): {sorted(only_in_state_dict)}")
|
| 173 |
+
|
| 174 |
+
if only_in_named_params:
|
| 175 |
+
print(
|
| 176 |
+
f"Only in named_parameters (not in state_dict): {sorted(only_in_named_params)}"
|
| 177 |
+
)
|
| 178 |
+
|
| 179 |
+
if not only_in_state_dict and not only_in_named_params:
|
| 180 |
+
print("All parameters match between state_dict and named_parameters")
|
| 181 |
+
|
| 182 |
+
# Additionally compare buffers (non-parameter states, such as BatchNorm's running_mean)
|
| 183 |
+
named_buffers_keys = set(name for name, _ in named_buffers)
|
| 184 |
+
buffers_only = state_dict_keys - named_params_keys - named_buffers_keys
|
| 185 |
+
|
| 186 |
+
if buffers_only:
|
| 187 |
+
print(
|
| 188 |
+
f"Other items in state_dict (neither params nor buffers): {sorted(buffers_only)}"
|
| 189 |
+
)
|
| 190 |
+
|
| 191 |
+
print(f"Total state_dict items: {len(state_dict_keys)}")
|
| 192 |
+
print(f"Total named_parameters: {len(named_params_keys)}")
|
| 193 |
+
print(f"Total named_buffers: {len(named_buffers_keys)}")
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
def _resolve_global_rank() -> int:
|
| 197 |
+
"""Resolve the global rank from environment variables."""
|
| 198 |
+
for key in ("GLOBAL_RANK", "RANK", "SLURM_PROCID", "LOCAL_RANK"):
|
| 199 |
+
if key in os.environ:
|
| 200 |
+
try:
|
| 201 |
+
return int(os.environ[key])
|
| 202 |
+
except ValueError:
|
| 203 |
+
continue
|
| 204 |
+
return 0
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def get_shared_run_time(base_dir: str, env_key: str = "PL_RUN_TIME") -> str:
|
| 208 |
+
"""
|
| 209 |
+
Get a synchronized run time across all processes.
|
| 210 |
+
|
| 211 |
+
This function ensures all processes (both in distributed training and multi-process
|
| 212 |
+
scenarios) use the same timestamp for output directories and experiment tracking.
|
| 213 |
+
|
| 214 |
+
Args:
|
| 215 |
+
base_dir: Base directory for output files
|
| 216 |
+
env_key: Environment variable key to cache the run time
|
| 217 |
+
|
| 218 |
+
Returns:
|
| 219 |
+
Synchronized timestamp string in format YYYYMMDD_HHMMSS
|
| 220 |
+
"""
|
| 221 |
+
cached = os.environ.get(env_key)
|
| 222 |
+
if cached:
|
| 223 |
+
return cached
|
| 224 |
+
|
| 225 |
+
timestamp_format = "%Y%m%d_%H%M%S"
|
| 226 |
+
|
| 227 |
+
if torch.distributed.is_available() and torch.distributed.is_initialized():
|
| 228 |
+
if torch.distributed.get_rank() == 0:
|
| 229 |
+
run_time = datetime.now().strftime(timestamp_format)
|
| 230 |
+
else:
|
| 231 |
+
run_time = None
|
| 232 |
+
container = [run_time]
|
| 233 |
+
torch.distributed.broadcast_object_list(container, src=0)
|
| 234 |
+
run_time = container[0]
|
| 235 |
+
if run_time is None:
|
| 236 |
+
raise RuntimeError("Failed to synchronize run time across ranks.")
|
| 237 |
+
os.environ[env_key] = run_time
|
| 238 |
+
return run_time
|
| 239 |
+
|
| 240 |
+
os.makedirs(base_dir, exist_ok=True)
|
| 241 |
+
sync_token = (
|
| 242 |
+
os.environ.get("SLURM_JOB_ID")
|
| 243 |
+
or os.environ.get("TORCHELASTIC_RUN_ID")
|
| 244 |
+
or os.environ.get("JOB_ID")
|
| 245 |
+
or "default"
|
| 246 |
+
)
|
| 247 |
+
sync_dir = os.path.join(base_dir, ".run_time_sync")
|
| 248 |
+
os.makedirs(sync_dir, exist_ok=True)
|
| 249 |
+
sync_file = os.path.join(sync_dir, f"{sync_token}.txt")
|
| 250 |
+
|
| 251 |
+
global_rank = _resolve_global_rank()
|
| 252 |
+
if global_rank == 0:
|
| 253 |
+
# Remove the sync file if it exists to avoid stale reads by other ranks
|
| 254 |
+
if os.path.exists(sync_file):
|
| 255 |
+
try:
|
| 256 |
+
os.remove(sync_file)
|
| 257 |
+
except OSError:
|
| 258 |
+
pass
|
| 259 |
+
|
| 260 |
+
run_time = datetime.now().strftime(timestamp_format)
|
| 261 |
+
with open(sync_file, "w", encoding="utf-8") as f:
|
| 262 |
+
f.write(run_time)
|
| 263 |
+
else:
|
| 264 |
+
timeout = time.monotonic() + 1200.0
|
| 265 |
+
while True:
|
| 266 |
+
if os.path.exists(sync_file):
|
| 267 |
+
try:
|
| 268 |
+
with open(sync_file, "r", encoding="utf-8") as f:
|
| 269 |
+
run_time = f.read().strip()
|
| 270 |
+
# Check if the timestamp is fresh (within 60 seconds)
|
| 271 |
+
# This prevents reading a stale timestamp from a previous run
|
| 272 |
+
dt = datetime.strptime(run_time, timestamp_format)
|
| 273 |
+
if abs((datetime.now() - dt).total_seconds()) < 60:
|
| 274 |
+
break
|
| 275 |
+
except (ValueError, OSError):
|
| 276 |
+
# File might be empty or partially written, or format mismatch
|
| 277 |
+
pass
|
| 278 |
+
|
| 279 |
+
if time.monotonic() > timeout:
|
| 280 |
+
raise TimeoutError(
|
| 281 |
+
"Timed out waiting for rank 0 to write synchronized timestamp."
|
| 282 |
+
)
|
| 283 |
+
time.sleep(0.1)
|
| 284 |
+
|
| 285 |
+
os.environ[env_key] = run_time
|
| 286 |
+
return run_time
|
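A usage sketch for the helpers above; the YAML path, the override key, and the model.target field are illustrative assumptions about the config layout, not something this commit pins down:

# Hypothetical driver script.
cfg = load_config("ldf.yaml", {"exp_name": "debug_run"})
# instantiate() expects a dotted class path; model.target is assumed here.
model = instantiate(cfg.model.target, cfg=cfg.model, hfstyle=True)
print_model_size(model)
run_time = get_shared_run_time(base_dir="outputs")  # shared across ranks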
ldf_utils/math/__init__.py
ADDED
File without changes
ldf_utils/math/quaternion.py
ADDED
@@ -0,0 +1,447 @@
# Copyright (c) 2018-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import numpy as np
import torch

_EPS4 = np.finfo(float).eps * 4.0

_FLOAT_EPS = np.finfo(np.float64).eps

# PyTorch-backed implementations


def qinv(q):
    assert q.shape[-1] == 4, "q must be a tensor of shape (*, 4)"
    mask = torch.ones_like(q)
    mask[..., 1:] = -mask[..., 1:]
    return q * mask


def qinv_np(q):
    assert q.shape[-1] == 4, "q must be a tensor of shape (*, 4)"
    return qinv(torch.from_numpy(q).float()).numpy()


def qnormalize(q):
    assert q.shape[-1] == 4, "q must be a tensor of shape (*, 4)"
    return q / torch.norm(q, dim=-1, keepdim=True)


def qmul(q, r):
    """
    Multiply quaternion(s) q with quaternion(s) r.
    Expects two equally-sized tensors of shape (*, 4), where * denotes any number of dimensions.
    Returns q*r as a tensor of shape (*, 4).
    """
    assert q.shape[-1] == 4
    assert r.shape[-1] == 4

    original_shape = q.shape

    # Compute outer product
    terms = torch.bmm(r.view(-1, 4, 1), q.view(-1, 1, 4))

    w = terms[:, 0, 0] - terms[:, 1, 1] - terms[:, 2, 2] - terms[:, 3, 3]
    x = terms[:, 0, 1] + terms[:, 1, 0] - terms[:, 2, 3] + terms[:, 3, 2]
    y = terms[:, 0, 2] + terms[:, 1, 3] + terms[:, 2, 0] - terms[:, 3, 1]
    z = terms[:, 0, 3] - terms[:, 1, 2] + terms[:, 2, 1] + terms[:, 3, 0]
    return torch.stack((w, x, y, z), dim=1).view(original_shape)


def qrot(q, v):
    """
    Rotate vector(s) v about the rotation described by quaternion(s) q.
    Expects a tensor of shape (*, 4) for q and a tensor of shape (*, 3) for v,
    where * denotes any number of dimensions.
    Returns a tensor of shape (*, 3).
    """
    assert q.shape[-1] == 4
    assert v.shape[-1] == 3
    assert q.shape[:-1] == v.shape[:-1]

    original_shape = list(v.shape)
    q = q.contiguous().view(-1, 4)
    v = v.contiguous().view(-1, 3)

    qvec = q[:, 1:]
    uv = torch.cross(qvec, v, dim=1)
    uuv = torch.cross(qvec, uv, dim=1)
    return (v + 2 * (q[:, :1] * uv + uuv)).view(original_shape)


def qeuler(q, order, epsilon=0, deg=True):
    """
    Convert quaternion(s) q to Euler angles.
    Expects a tensor of shape (*, 4), where * denotes any number of dimensions.
    Returns a tensor of shape (*, 3).
    """
    assert q.shape[-1] == 4

    original_shape = list(q.shape)
    original_shape[-1] = 3
    q = q.view(-1, 4)

    q0 = q[:, 0]
    q1 = q[:, 1]
    q2 = q[:, 2]
    q3 = q[:, 3]

    if order == "xyz":
        x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2))
        y = torch.asin(torch.clamp(2 * (q1 * q3 + q0 * q2), -1 + epsilon, 1 - epsilon))
        z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3))
    elif order == "yzx":
        x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3))
        y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3))
        z = torch.asin(torch.clamp(2 * (q1 * q2 + q0 * q3), -1 + epsilon, 1 - epsilon))
    elif order == "zxy":
        x = torch.asin(torch.clamp(2 * (q0 * q1 + q2 * q3), -1 + epsilon, 1 - epsilon))
        y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q1 * q1 + q2 * q2))
        z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q1 * q1 + q3 * q3))
    elif order == "xzy":
        x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3))
        y = torch.atan2(2 * (q0 * q2 + q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3))
        z = torch.asin(torch.clamp(2 * (q0 * q3 - q1 * q2), -1 + epsilon, 1 - epsilon))
    elif order == "yxz":
        x = torch.asin(torch.clamp(2 * (q0 * q1 - q2 * q3), -1 + epsilon, 1 - epsilon))
        y = torch.atan2(2 * (q1 * q3 + q0 * q2), 1 - 2 * (q1 * q1 + q2 * q2))
        z = torch.atan2(2 * (q1 * q2 + q0 * q3), 1 - 2 * (q1 * q1 + q3 * q3))
    elif order == "zyx":
        x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2))
        y = torch.asin(torch.clamp(2 * (q0 * q2 - q1 * q3), -1 + epsilon, 1 - epsilon))
        z = torch.atan2(2 * (q0 * q3 + q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3))
    else:
        raise ValueError(f"Unknown rotation order: {order}")

    if deg:
        return torch.stack((x, y, z), dim=1).view(original_shape) * 180 / np.pi
    else:
        return torch.stack((x, y, z), dim=1).view(original_shape)


# Numpy-backed implementations


def qmul_np(q, r):
    q = torch.from_numpy(q).contiguous().float()
    r = torch.from_numpy(r).contiguous().float()
    return qmul(q, r).numpy()


def qrot_np(q, v):
    q = torch.from_numpy(q).contiguous().float()
    v = torch.from_numpy(v).contiguous().float()
    return qrot(q, v).numpy()


def qeuler_np(q, order, epsilon=0, use_gpu=False):
    if use_gpu:
        q = torch.from_numpy(q).cuda().float()
        return qeuler(q, order, epsilon).cpu().numpy()
    else:
        q = torch.from_numpy(q).contiguous().float()
        return qeuler(q, order, epsilon).numpy()


def qfix(q):
    """
    Enforce quaternion continuity across the time dimension by selecting
    the representation (q or -q) with minimal distance (or, equivalently, maximal dot product)
    between two consecutive frames.

    Expects a tensor of shape (L, J, 4), where L is the sequence length and J is the number of joints.
    Returns a tensor of the same shape.
    """
    assert len(q.shape) == 3
    assert q.shape[-1] == 4

    result = q.copy()
    dot_products = np.sum(q[1:] * q[:-1], axis=2)
    mask = dot_products < 0
    mask = (np.cumsum(mask, axis=0) % 2).astype(bool)
    result[1:][mask] *= -1
    return result


def euler2quat(e, order, deg=True):
    """
    Convert Euler angles to quaternions.
    """
    assert e.shape[-1] == 3

    original_shape = list(e.shape)
    original_shape[-1] = 4

    e = e.view(-1, 3)

    # if euler angles in degrees
    if deg:
        e = e * np.pi / 180.0

    x = e[:, 0]
    y = e[:, 1]
    z = e[:, 2]

    rx = torch.stack(
        (torch.cos(x / 2), torch.sin(x / 2), torch.zeros_like(x), torch.zeros_like(x)),
        dim=1,
    )
    ry = torch.stack(
        (torch.cos(y / 2), torch.zeros_like(y), torch.sin(y / 2), torch.zeros_like(y)),
        dim=1,
    )
    rz = torch.stack(
        (torch.cos(z / 2), torch.zeros_like(z), torch.zeros_like(z), torch.sin(z / 2)),
        dim=1,
    )

    result = None
    for coord in order:
        if coord == "x":
            r = rx
        elif coord == "y":
            r = ry
        elif coord == "z":
            r = rz
        else:
            raise ValueError(f"Unknown axis in rotation order: {coord}")
        if result is None:
            result = r
        else:
            result = qmul(result, r)

    # Reverse antipodal representation to have a non-negative "w"
    if order in ["xyz", "yzx", "zxy"]:
        result *= -1

    return result.view(original_shape)


def expmap_to_quaternion(e):
    """
    Convert axis-angle rotations (aka exponential maps) to quaternions.
    Stable formula from "Practical Parameterization of Rotations Using the Exponential Map".
    Expects a tensor of shape (*, 3), where * denotes any number of dimensions.
    Returns a tensor of shape (*, 4).
    """
    assert e.shape[-1] == 3

    original_shape = list(e.shape)
    original_shape[-1] = 4
    e = e.reshape(-1, 3)

    theta = np.linalg.norm(e, axis=1).reshape(-1, 1)
    w = np.cos(0.5 * theta).reshape(-1, 1)
    xyz = 0.5 * np.sinc(0.5 * theta / np.pi) * e
    return np.concatenate((w, xyz), axis=1).reshape(original_shape)


def euler_to_quaternion(e, order):
    """
    Convert Euler angles to quaternions.
    """
    assert e.shape[-1] == 3

    original_shape = list(e.shape)
    original_shape[-1] = 4

    e = e.reshape(-1, 3)

    x = e[:, 0]
    y = e[:, 1]
    z = e[:, 2]

    rx = np.stack(
        (np.cos(x / 2), np.sin(x / 2), np.zeros_like(x), np.zeros_like(x)), axis=1
    )
    ry = np.stack(
        (np.cos(y / 2), np.zeros_like(y), np.sin(y / 2), np.zeros_like(y)), axis=1
    )
    rz = np.stack(
        (np.cos(z / 2), np.zeros_like(z), np.zeros_like(z), np.sin(z / 2)), axis=1
    )

    result = None
    for coord in order:
        if coord == "x":
            r = rx
        elif coord == "y":
            r = ry
        elif coord == "z":
            r = rz
        else:
            raise ValueError(f"Unknown axis in rotation order: {coord}")
        if result is None:
            result = r
        else:
            result = qmul_np(result, r)

    # Reverse antipodal representation to have a non-negative "w"
    if order in ["xyz", "yzx", "zxy"]:
        result *= -1

    return result.reshape(original_shape)


def quaternion_to_matrix(quaternions):
    """
    Convert rotations given as quaternions to rotation matrices.
    Args:
        quaternions: quaternions with real part first,
            as tensor of shape (..., 4).
    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).
    """
    r, i, j, k = torch.unbind(quaternions, -1)
    two_s = 2.0 / (quaternions * quaternions).sum(-1)

    o = torch.stack(
        (
            1 - two_s * (j * j + k * k),
            two_s * (i * j - k * r),
            two_s * (i * k + j * r),
            two_s * (i * j + k * r),
            1 - two_s * (i * i + k * k),
            two_s * (j * k - i * r),
            two_s * (i * k - j * r),
            two_s * (j * k + i * r),
            1 - two_s * (i * i + j * j),
        ),
        -1,
    )
    return o.reshape(quaternions.shape[:-1] + (3, 3))


def quaternion_to_matrix_np(quaternions):
    q = torch.from_numpy(quaternions).contiguous().float()
    return quaternion_to_matrix(q).numpy()


def quaternion_to_cont6d_np(quaternions):
    rotation_mat = quaternion_to_matrix_np(quaternions)
    cont_6d = np.concatenate([rotation_mat[..., 0], rotation_mat[..., 1]], axis=-1)
    return cont_6d


def quaternion_to_cont6d(quaternions):
    rotation_mat = quaternion_to_matrix(quaternions)
    cont_6d = torch.cat([rotation_mat[..., 0], rotation_mat[..., 1]], dim=-1)
    return cont_6d


def cont6d_to_matrix(cont6d):
    assert cont6d.shape[-1] == 6, "The last dimension must be 6"
    x_raw = cont6d[..., 0:3]
    y_raw = cont6d[..., 3:6]

    x = x_raw / torch.norm(x_raw, dim=-1, keepdim=True)
    z = torch.cross(x, y_raw, dim=-1)
    z = z / torch.norm(z, dim=-1, keepdim=True)

    y = torch.cross(z, x, dim=-1)

    x = x[..., None]
    y = y[..., None]
    z = z[..., None]

    mat = torch.cat([x, y, z], dim=-1)
    return mat


def cont6d_to_matrix_np(cont6d):
    q = torch.from_numpy(cont6d).contiguous().float()
    return cont6d_to_matrix(q).numpy()


def qpow(q0, t, dtype=torch.float):
    """q0 : tensor of quaternions
    t: tensor of powers
    """
    q0 = qnormalize(q0)
    theta0 = torch.acos(q0[..., 0])

    # if theta0 is close to zero, add epsilon to avoid NaNs
    mask = ((theta0 <= 10e-10) & (theta0 >= -10e-10)).to(theta0.dtype)
    theta0 = (1 - mask) * theta0 + mask * 10e-10
    v0 = q0[..., 1:] / torch.sin(theta0).view(-1, 1)

    if isinstance(t, torch.Tensor):
        q = torch.zeros(t.shape + q0.shape)
        theta = t.view(-1, 1) * theta0.view(1, -1)
    else:  # if t is a number
        q = torch.zeros(q0.shape)
        theta = t * theta0

    q[..., 0] = torch.cos(theta)
    q[..., 1:] = v0 * torch.sin(theta).unsqueeze(-1)

    return q.to(dtype)


def qslerp(q0, q1, t):
    """
    q0: starting quaternion
    q1: ending quaternion
    t: array of points along the way

    Returns:
        Tensor of Slerps: t.shape + q0.shape
    """

    q0 = qnormalize(q0)
    q1 = qnormalize(q1)
    q_ = qpow(qmul(q1, qinv(q0)), t)

    return qmul(
        q_,
        q0.contiguous()
        .view(torch.Size([1] * len(t.shape)) + q0.shape)
        .expand(t.shape + q0.shape)
        .contiguous(),
    )


def qbetween(v0, v1):
    """
    find the quaternion used to rotate v0 to v1
    """
    assert v0.shape[-1] == 3, "v0 must be of the shape (*, 3)"
    assert v1.shape[-1] == 3, "v1 must be of the shape (*, 3)"

    v = torch.cross(v0, v1, dim=-1)
    w = torch.sqrt(
        (v0**2).sum(dim=-1, keepdim=True) * (v1**2).sum(dim=-1, keepdim=True)
    ) + (v0 * v1).sum(dim=-1, keepdim=True)
    return qnormalize(torch.cat([w, v], dim=-1))


def qbetween_np(v0, v1):
    """
    find the quaternion used to rotate v0 to v1
    """
    assert v0.shape[-1] == 3, "v0 must be of the shape (*, 3)"
    assert v1.shape[-1] == 3, "v1 must be of the shape (*, 3)"

    v0 = torch.from_numpy(v0).float()
    v1 = torch.from_numpy(v1).float()
    return qbetween(v0, v1).numpy()


def lerp(p0, p1, t):
    if not isinstance(t, torch.Tensor):
        t = torch.Tensor([t])

    new_shape = t.shape + p0.shape
    new_view_t = t.shape + torch.Size([1] * len(p0.shape))
    new_view_p = torch.Size([1] * len(t.shape)) + p0.shape
    p0 = p0.view(new_view_p).expand(new_shape)
    p1 = p1.view(new_view_p).expand(new_shape)
    t = t.view(new_view_t).expand(new_shape)

    return p0 + t * (p1 - p0)
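As a quick sanity check of the conventions above (w-first quaternions, with qrot applying the rotation described by q), the quaternion returned by qbetween should rotate its first argument onto its second. An illustrative snippet:

import torch

v0 = torch.tensor([[1.0, 0.0, 0.0]])
v1 = torch.tensor([[0.0, 1.0, 0.0]])
q = qbetween(v0, v1)  # unit quaternion taking v0 to v1 (90 deg about +z)
assert torch.allclose(qrot(q, v0), v1, atol=1e-5)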
ldf_utils/motion_process.py
ADDED
@@ -0,0 +1,341 @@
import numpy as np
import torch
import torch.nn.functional as F

from ldf_utils.math.quaternion import *

"""
Motion data structure:
(B: batch size)
root_rot_velocity (B, seq_len, 1)
root_linear_velocity (B, seq_len, 2)
root_y (B, seq_len, 1)
ric_data (B, seq_len, (joint_num - 1)*3)
rot_data (B, seq_len, (joint_num - 1)*6)
local_velocity (B, seq_len, joint_num*3)
foot contact (B, seq_len, 4)
"""


def recover_root_rot_pos(data):
    # recover root rotation and position
    rot_vel = data[..., 0]
    r_rot_ang = torch.zeros_like(rot_vel).to(data.device)
    """Get Y-axis rotation from rotation velocity"""
    r_rot_ang[..., 1:] = rot_vel[..., :-1]
    r_rot_ang = torch.cumsum(r_rot_ang, dim=-1)

    r_rot_quat = torch.zeros(data.shape[:-1] + (4,)).to(data.device)
    r_rot_quat[..., 0] = torch.cos(r_rot_ang)
    r_rot_quat[..., 2] = torch.sin(r_rot_ang)

    r_pos = torch.zeros(data.shape[:-1] + (3,)).to(data.device)
    r_pos[..., 1:, [0, 2]] = data[..., :-1, 1:3]
    """Add Y-axis rotation to root position"""
    r_pos = qrot(qinv(r_rot_quat), r_pos)

    r_pos = torch.cumsum(r_pos, dim=-2)

    r_pos[..., 1] = data[..., 3]
    return r_rot_quat, r_pos


def recover_joint_positions_263(data: np.ndarray, joints_num) -> np.ndarray:
    """
    Recovers 3D joint positions from the rotation-invariant local positions (ric_data).
    This is the most direct way to get the skeleton for animation.
    """
    feature_vec = torch.from_numpy(data).unsqueeze(0).float()
    r_rot_quat, r_pos = recover_root_rot_pos(feature_vec)
    positions = feature_vec[..., 4 : (joints_num - 1) * 3 + 4]
    positions = positions.view(positions.shape[:-1] + (-1, 3))
    """Add Y-axis rotation to local joints"""
    positions = qrot(
        qinv(r_rot_quat[..., None, :]).expand(positions.shape[:-1] + (4,)), positions
    )
    """Add root XZ to joints"""
    positions[..., 0] += r_pos[..., 0:1]
    positions[..., 2] += r_pos[..., 2:3]
    """Concatenate root and joints"""
    positions = torch.cat([r_pos.unsqueeze(-2), positions], dim=-2)
    joints_np = positions.squeeze(0).detach().cpu().numpy()
    return joints_np


class StreamJointRecovery263:
    """
    Stream version of recover_joint_positions_263 that processes one frame at a time.
    Maintains cumulative state for rotation angles and positions.

    Key insight: the batch version uses the PREVIOUS frame's velocity for the
    current frame, so the velocity application is delayed by one frame here.
    """

    def __init__(self, joints_num: int):
        self.joints_num = joints_num
        self.reset()

    def reset(self):
        """Reset the accumulated state"""
        self.r_rot_ang_accum = 0.0
        self.r_pos_accum = np.array([0.0, 0.0, 0.0])
        # Store previous frame's velocities for delayed application
        self.prev_rot_vel = 0.0
        self.prev_linear_vel = np.array([0.0, 0.0])

    def process_frame(self, frame_data: np.ndarray) -> np.ndarray:
        """
        Process a single frame and return joint positions for that frame.

        Args:
            frame_data: numpy array of shape (263,) for a single frame

        Returns:
            joints: numpy array of shape (joints_num, 3) representing joint positions
        """
        # Convert to torch tensor
        feature_vec = torch.from_numpy(frame_data).float()

        # Extract current frame's velocities (will be used in NEXT frame)
        curr_rot_vel = feature_vec[0].item()
        curr_linear_vel = feature_vec[1:3].numpy()

        # Update accumulated rotation angle with PREVIOUS frame's velocity FIRST
        # This matches the batch processing: r_rot_ang[i] uses rot_vel[i-1]
        self.r_rot_ang_accum += self.prev_rot_vel

        # Calculate current rotation quaternion using updated accumulated angle
        r_rot_quat = torch.zeros(4)
        r_rot_quat[0] = np.cos(self.r_rot_ang_accum)
        r_rot_quat[2] = np.sin(self.r_rot_ang_accum)

        # Create velocity vector with Y=0 using PREVIOUS frame's velocity
        r_vel = np.array([self.prev_linear_vel[0], 0.0, self.prev_linear_vel[1]])

        # Apply inverse rotation to velocity using CURRENT rotation
        r_vel_torch = torch.from_numpy(r_vel).float()
        r_vel_rotated = qrot(qinv(r_rot_quat).unsqueeze(0), r_vel_torch.unsqueeze(0))
        r_vel_rotated = r_vel_rotated.squeeze(0).numpy()

        # Update accumulated position with rotated velocity
        self.r_pos_accum += r_vel_rotated

        # Get Y position from data
        r_pos = self.r_pos_accum.copy()
        r_pos[1] = feature_vec[3].item()

        # Extract local joint positions
        positions = feature_vec[4 : (self.joints_num - 1) * 3 + 4]
        positions = positions.view(-1, 3)

        # Apply inverse rotation to local joints
        r_rot_quat_expanded = (
            qinv(r_rot_quat).unsqueeze(0).expand(positions.shape[0], 4)
        )
        positions = qrot(r_rot_quat_expanded, positions)

        # Add root XZ to joints
        positions[:, 0] += r_pos[0]
        positions[:, 2] += r_pos[2]

        # Concatenate root and joints
        r_pos_torch = torch.from_numpy(r_pos).float()
        positions = torch.cat([r_pos_torch.unsqueeze(0), positions], dim=0)

        # Convert to numpy
        joints_np = positions.detach().cpu().numpy()

        # Store current velocities for next frame
        self.prev_rot_vel = curr_rot_vel
        self.prev_linear_vel = curr_linear_vel

        return joints_np


def accumulate_rotations(relative_rotations):
    R_total = [relative_rotations[0]]
    for R_rel in relative_rotations[1:]:
        R_total.append(np.matmul(R_rel, R_total[-1]))

    return np.array(R_total)


def recover_from_local_position(final_x, njoint):
    nfrm, _ = final_x.shape
    positions_no_heading = final_x[:, 8 : 8 + 3 * njoint].reshape(
        nfrm, -1, 3
    )  # (frames, njoint, 3)
    velocities_root_xy_no_heading = final_x[:, :2]  # (frames, 2)
    global_heading_diff_rot = final_x[:, 2:8]  # (frames, 6)

    # recover global heading
    global_heading_rot = accumulate_rotations(
        rotation_6d_to_matrix(torch.from_numpy(global_heading_diff_rot)).numpy()
    )
    inv_global_heading_rot = np.transpose(global_heading_rot, (0, 2, 1))
    # add global heading to positions
    positions_with_heading = np.matmul(
        np.repeat(inv_global_heading_rot[:, None, :, :], njoint, axis=1),
        positions_no_heading[..., None],
    ).squeeze(-1)

    # recover root translation:
    # add heading to velocities_root_xy_no_heading
    velocities_root_xyz_no_heading = np.zeros(
        (
            velocities_root_xy_no_heading.shape[0],
            3,
        )
    )
    velocities_root_xyz_no_heading[:, 0] = velocities_root_xy_no_heading[:, 0]
    velocities_root_xyz_no_heading[:, 2] = velocities_root_xy_no_heading[:, 1]
    velocities_root_xyz_no_heading[1:, :] = np.matmul(
        inv_global_heading_rot[:-1], velocities_root_xyz_no_heading[1:, :, None]
    ).squeeze(-1)

    root_translation = np.cumsum(velocities_root_xyz_no_heading, axis=0)

    # add root translation
    positions_with_heading[:, :, 0] += root_translation[:, 0:1]
    positions_with_heading[:, :, 2] += root_translation[:, 2:]

    return positions_with_heading


def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:
    a1, a2 = d6[..., :3], d6[..., 3:]
    b1 = F.normalize(a1, dim=-1)
    b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1
    b2 = F.normalize(b2, dim=-1)
    b3 = torch.cross(b1, b2, dim=-1)
    return torch.stack((b1, b2, b3), dim=-2)


def _copysign(a, b):
    signs_differ = (a < 0) != (b < 0)
    return torch.where(signs_differ, -a, a)


def _sqrt_positive_part(x):
    ret = torch.zeros_like(x)
    positive_mask = x > 0
    ret[positive_mask] = torch.sqrt(x[positive_mask])
    return ret


def matrix_to_quaternion(matrix):
    if matrix.size(-1) != 3 or matrix.size(-2) != 3:
        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
    m00 = matrix[..., 0, 0]
    m11 = matrix[..., 1, 1]
    m22 = matrix[..., 2, 2]
    o0 = 0.5 * _sqrt_positive_part(1 + m00 + m11 + m22)
    x = 0.5 * _sqrt_positive_part(1 + m00 - m11 - m22)
    y = 0.5 * _sqrt_positive_part(1 - m00 + m11 - m22)
    z = 0.5 * _sqrt_positive_part(1 - m00 - m11 + m22)
    o1 = _copysign(x, matrix[..., 2, 1] - matrix[..., 1, 2])
    o2 = _copysign(y, matrix[..., 0, 2] - matrix[..., 2, 0])
    o3 = _copysign(z, matrix[..., 1, 0] - matrix[..., 0, 1])
    return torch.stack((o0, o1, o2, o3), -1)


def quaternion_to_axis_angle(quaternions):
    norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True)
    half_angles = torch.atan2(norms, quaternions[..., :1])
    angles = 2 * half_angles
    eps = 1e-6
    small_angles = angles.abs() < eps
    sin_half_angles_over_angles = torch.empty_like(angles)
    sin_half_angles_over_angles[~small_angles] = (
        torch.sin(half_angles[~small_angles]) / angles[~small_angles]
    )
    # for x small, sin(x/2) is about x/2 - (x/2)^3/6
    # so sin(x/2)/x is about 1/2 - (x*x)/48
    sin_half_angles_over_angles[small_angles] = (
        0.5 - (angles[small_angles] * angles[small_angles]) / 48
    )
    return quaternions[..., 1:] / sin_half_angles_over_angles


def matrix_to_axis_angle(matrix):
    return quaternion_to_axis_angle(matrix_to_quaternion(matrix))


def rotations_matrix_to_smpl85(rotations_matrix, translation):
    nfrm, njoint, _, _ = rotations_matrix.shape
    axis_angle = (
        matrix_to_axis_angle(torch.from_numpy(rotations_matrix))
        .numpy()
        .reshape(nfrm, -1)
    )
    smpl_85 = np.concatenate(
        [axis_angle, np.zeros((nfrm, 6)), translation, np.zeros((nfrm, 10))], axis=-1
    )
    return smpl_85


def recover_from_local_rotation(final_x, njoint):
    nfrm, _ = final_x.shape
    rotations_matrix = rotation_6d_to_matrix(
        torch.from_numpy(final_x[:, 8 + 6 * njoint : 8 + 12 * njoint]).reshape(
            nfrm, -1, 6
        )
    ).numpy()
    global_heading_diff_rot = final_x[:, 2:8]
    velocities_root_xy_no_heading = final_x[:, :2]
    positions_no_heading = final_x[:, 8 : 8 + 3 * njoint].reshape(nfrm, -1, 3)
    height = positions_no_heading[:, 0, 1]

    global_heading_rot = accumulate_rotations(
        rotation_6d_to_matrix(torch.from_numpy(global_heading_diff_rot)).numpy()
    )
    inv_global_heading_rot = np.transpose(global_heading_rot, (0, 2, 1))
    # recover root rotation
    rotations_matrix[:, 0, ...] = np.matmul(
        inv_global_heading_rot, rotations_matrix[:, 0, ...]
    )
    velocities_root_xyz_no_heading = np.zeros(
        (
            velocities_root_xy_no_heading.shape[0],
            3,
        )
    )
    velocities_root_xyz_no_heading[:, 0] = velocities_root_xy_no_heading[:, 0]
    velocities_root_xyz_no_heading[:, 2] = velocities_root_xy_no_heading[:, 1]
    velocities_root_xyz_no_heading[1:, :] = np.matmul(
        inv_global_heading_rot[:-1], velocities_root_xyz_no_heading[1:, :, None]
    ).squeeze(-1)
    root_translation = np.cumsum(velocities_root_xyz_no_heading, axis=0)
    root_translation[:, 1] = height
    smpl_85 = rotations_matrix_to_smpl85(rotations_matrix, root_translation)
    return smpl_85


def recover_joint_positions_272(data: np.ndarray, joints_num) -> np.ndarray:
    return recover_from_local_position(data, joints_num)


def convert_motion_to_joints(
    motion_data: np.ndarray,
    dim: int,
    mean: np.ndarray = None,
    std: np.ndarray = None,
    joints_num=22,
):
    """
    Convert Kx263 dim or Kx272 dim motion data to Kx22x3 joint positions.
    Args:
        motion_data: numpy array of shape (K, 263) or (K, 272) where K is number of frames
    Returns:
        joints: numpy array of shape (K, 22, 3) representing joint positions
    """
    if mean is not None and std is not None:
        motion_data = motion_data * std + mean
    if dim == 263:
        recovered_positions = recover_joint_positions_263(motion_data, joints_num)
    elif dim == 272:
        recovered_positions = recover_joint_positions_272(motion_data, joints_num)
    else:
        raise ValueError(f"Unsupported motion data dimension: {dim}")
    return recovered_positions
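A minimal sketch, with stand-in random features, showing that the frame-by-frame recovery tracks the batch path; real inputs would come from the model and be denormalized with the dataset Mean/Std, as convert_motion_to_joints does:

import numpy as np

motion = np.random.randn(16, 263).astype(np.float32)  # stand-in features

batch_joints = recover_joint_positions_263(motion, joints_num=22)  # (16, 22, 3)

stream = StreamJointRecovery263(joints_num=22)
stream_joints = np.stack([stream.process_frame(frame) for frame in motion])

print(np.abs(batch_joints - stream_joints).max())  # expected ~0 (float error)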
model.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6034bb8fad782741caa33eb785935e2d1543e881772dacdbe9292107ec45436e
size 454897088
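This entry (like vae.safetensors below) is a Git LFS pointer; once the real file is pulled, the weights load with the safetensors API. A minimal sketch; the state-dict key names depend on the checkpoint and are not shown in this commit:

from safetensors.torch import load_file

state_dict = load_file("model.safetensors")  # plain dict of name -> tensor
print(sum(t.numel() for t in state_dict.values()))  # total parameter count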
requirements.txt
ADDED
@@ -0,0 +1,19 @@
# Core dependencies
torch>=2.0.0
transformers>=4.30.0
huggingface_hub>=0.16.0
safetensors>=0.3.0
diffusers>=0.20.0

# Inference
lightning>=2.0.0
ftfy

# Configuration
omegaconf

# Utilities
numpy

# Note: flash-attn is required but needs special installation
# See README.md for installation instructions
vae.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5a40164154c476309ff952a4b7563750b7e76fbdd8d263ec261ad877cf452e7b
size 70027220