File size: 4,814 Bytes
ebc7f2e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 |
import os
import sys
import torch
from lightning import seed_everything
from safetensors.torch import load_file as load_safetensors
from ldf_utils.initialize import compare_statedict_and_parameters, instantiate, load_config
# Set tokenizers parallelism to false to avoid warnings in multiprocessing
os.environ["TOKENIZERS_PARALLELISM"] = "false"
def load_model_from_config():
    """Instantiate and load the VAE and the generation model from the CLI config.

    Reads the config via ``load_config()`` (which parses ``--config`` from
    ``sys.argv``), resolves checkpoint paths relative to the config file's
    directory, loads safetensors weights (already EMA-averaged) strictly,
    and moves both modules to the available device in eval mode.

    Returns:
        tuple: ``(vae, model)`` — both on the selected device, in eval mode.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch.set_float32_matmul_precision("high")

    cfg = load_config()
    seed_everything(cfg.seed)

    # Relative checkpoint paths in the config are interpreted relative to
    # the directory containing the config file itself.
    config_dir = _find_config_dir()

    # --- VAE ---
    vae = instantiate(
        target=cfg.test_vae.target,
        cfg=None,
        hfstyle=False,
        **cfg.test_vae.params,
    )
    _load_checkpoint(vae, _resolve_path(cfg.test_vae_ckpt, config_dir), "VAE model")
    vae.to(device)
    vae.eval()

    # --- Model: fix relative paths in model params before instantiation ---
    model_params = dict(cfg.model.params)
    for key in ("checkpoint_path", "tokenizer_path"):
        # Only rewrite keys that are present and non-empty.
        if model_params.get(key):
            model_params[key] = _resolve_path(model_params[key], config_dir)

    model = instantiate(
        target=cfg.model.target, cfg=None, hfstyle=False, **model_params
    )
    _load_checkpoint(model, _resolve_path(cfg.test_ckpt, config_dir), "model")
    model.to(device)
    model.eval()

    return vae, model


def _find_config_dir():
    """Return the directory of the ``--config`` file from sys.argv, or CWD."""
    if "--config" in sys.argv:
        config_idx = sys.argv.index("--config") + 1
        return os.path.dirname(os.path.abspath(sys.argv[config_idx]))
    return os.getcwd()


def _resolve_path(path, base_dir):
    """Return *path* unchanged if absolute, else join it onto *base_dir*."""
    return path if os.path.isabs(path) else os.path.join(base_dir, path)


def _load_checkpoint(module, ckpt_path, label):
    """Load safetensors weights (already containing EMA weights) into *module*.

    Loads strictly, prints a confirmation, and runs the state-dict /
    parameter consistency check.
    """
    state_dict = load_safetensors(ckpt_path)
    module.load_state_dict(state_dict, strict=True)
    print(f"Loaded {label} from {ckpt_path}")
    compare_statedict_and_parameters(
        state_dict=module.state_dict(),
        named_parameters=module.named_parameters(),
        named_buffers=module.named_buffers(),
    )
@torch.inference_mode()
def generate_feature_stream(
    model, feature_length, text, feature_text_end=None, num_denoise_steps=None
):
    """Stream generated feature segments from the model.

    Args:
        model: Loaded model exposing a ``stream_generate`` generator method.
        feature_length: List[int], generation length for each sample.
        text: List[str] or List[List[str]], text prompts.
        feature_text_end: List[List[int]], time points where text ends
            (only meaningful when ``text`` is a list of lists).
        num_denoise_steps: Number of denoising steps.

    Yields:
        dict: Contains "generated" (current generated feature segment).
    """
    # Build the input dict expected by stream_generate: "feature_length",
    # "text", and optionally "feature_text_end".
    batch = {
        "feature_length": torch.tensor(feature_length),
        "text": text,
    }
    if feature_text_end is not None:
        batch["feature_text_end"] = feature_text_end

    # stream_generate is itself a generator of per-step dicts (each already
    # carrying the "generated" key) — forward its items unchanged.
    yield from model.stream_generate(batch, num_denoise_steps=num_denoise_steps)
if __name__ == "__main__":
    import argparse

    # Command-line entry point: parse generation options, then load the
    # VAE + model pair from the config.
    # NOTE(review): args.text / args.length / args.output /
    # args.num_denoise_steps are not consumed in this chunk — presumably
    # the generation/rendering code that uses them follows below this view.
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True, help="Path to config")
    parser.add_argument(
        "--text", type=str, default="a person walks forward", help="Text prompt"
    )
    parser.add_argument("--length", type=int, default=120, help="Motion length")
    parser.add_argument(
        "--output", type=str, default="output.mp4", help="Output video path"
    )
    parser.add_argument(
        "--num_denoise_steps", type=int, default=None, help="Number of denoising steps"
    )
    args = parser.parse_args()

    print("Loading model...")
    # load_model_from_config reads --config from sys.argv itself, so args
    # is not passed in here.
    vae, model = load_model_from_config()
|