import os
import sys

import torch
from lightning import seed_everything
from safetensors.torch import load_file as load_safetensors

from ldf_utils.initialize import compare_statedict_and_parameters, instantiate, load_config

# Set tokenizers parallelism to false to avoid warnings in multiprocessing
os.environ["TOKENIZERS_PARALLELISM"] = "false"


def load_model_from_config():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch.set_float32_matmul_precision("high")

    cfg = load_config()
    seed_everything(cfg.seed)

    # Get the directory containing the config file.
    # Try to find it from sys.argv; otherwise fall back to the current directory.
    if "--config" in sys.argv:
        config_idx = sys.argv.index("--config") + 1
        config_dir = os.path.dirname(os.path.abspath(sys.argv[config_idx]))
    else:
        config_dir = os.getcwd()

    # VAE
    vae = instantiate(
        target=cfg.test_vae.target,
        cfg=None,
        hfstyle=False,
        **cfg.test_vae.params,
    )

    # Resolve relative paths against the config directory
    vae_path = cfg.test_vae_ckpt
    if not os.path.isabs(vae_path):
        vae_path = os.path.join(config_dir, vae_path)

    # Load from safetensors (already contains EMA weights)
    vae_state_dict = load_safetensors(vae_path)
    vae.load_state_dict(vae_state_dict, strict=True)
    print(f"Loaded VAE model from {vae_path}")
    compare_statedict_and_parameters(
        state_dict=vae.state_dict(),
        named_parameters=vae.named_parameters(),
        named_buffers=vae.named_buffers(),
    )
    vae.to(device)
    vae.eval()

    # Model - resolve relative paths in the model params
    model_params = dict(cfg.model.params)
    if "checkpoint_path" in model_params and model_params["checkpoint_path"]:
        if not os.path.isabs(model_params["checkpoint_path"]):
            model_params["checkpoint_path"] = os.path.join(config_dir, model_params["checkpoint_path"])
    if "tokenizer_path" in model_params and model_params["tokenizer_path"]:
        if not os.path.isabs(model_params["tokenizer_path"]):
            model_params["tokenizer_path"] = os.path.join(config_dir, model_params["tokenizer_path"])

    model = instantiate(
        target=cfg.model.target,
        cfg=None,
        hfstyle=False,
        **model_params,
    )

    # Resolve relative checkpoint path
    model_path = cfg.test_ckpt
    if not os.path.isabs(model_path):
        model_path = os.path.join(config_dir, model_path)

    # Load from safetensors (already contains EMA weights)
    model_state_dict = load_safetensors(model_path)
    model.load_state_dict(model_state_dict, strict=True)
    print(f"Loaded model from {model_path}")
    compare_statedict_and_parameters(
        state_dict=model.state_dict(),
        named_parameters=model.named_parameters(),
        named_buffers=model.named_buffers(),
    )
    model.to(device)
    model.eval()

    return vae, model


@torch.inference_mode()
def generate_feature_stream(
    model, feature_length, text, feature_text_end=None, num_denoise_steps=None
):
    """
    Streaming interface for feature generation.

    Args:
        model: Loaded model
        feature_length: List[int], generation length for each sample
        text: List[str] or List[List[str]], text prompts
        feature_text_end: List[List[int]], time points where each text segment ends
            (required when text is a list of lists)
        num_denoise_steps: Number of denoising steps

    Yields:
        dict: Contains "generated" (the current generated feature segment)
    """
    # Construct the input dict x.
    # stream_generate expects x to contain "feature_length", "text",
    # and "feature_text_end" (when text is a list of lists).
    x = {"feature_length": torch.tensor(feature_length), "text": text}
    if feature_text_end is not None:
        x["feature_text_end"] = feature_text_end

    # Call the model's stream_generate; note that it is a generator.
    generator = model.stream_generate(x, num_denoise_steps=num_denoise_steps)
    for step_output in generator:
        # step_output is already a dict with a "generated" key
        yield step_output


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True, help="Path to config")
    parser.add_argument(
        "--text", type=str, default="a person walks forward", help="Text prompt"
    )
    parser.add_argument("--length", type=int, default=120, help="Motion length")
    parser.add_argument(
        "--output", type=str, default="output.mp4", help="Output video path"
    )
    parser.add_argument(
        "--num_denoise_steps", type=int, default=None, help="Number of denoising steps"
    )
    args = parser.parse_args()

    print("Loading model...")
    vae, model = load_model_from_config()
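
    # A minimal sketch of how the streaming generator above could be consumed,
    # assuming model.stream_generate yields dicts with a "generated" feature
    # segment as documented in generate_feature_stream. Decoding the accumulated
    # features with the VAE and rendering them to args.output depend on
    # project-specific APIs not shown here, so those steps are omitted.
    generated_segments = []
    for step_output in generate_feature_stream(
        model,
        feature_length=[args.length],
        text=[args.text],
        num_denoise_steps=args.num_denoise_steps,
    ):
        generated_segments.append(step_output["generated"])
    print(f"Received {len(generated_segments)} streamed feature segments")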