# generate_ldf.py: FloodDiffusion text-to-motion generation model
import os
import sys
import torch
from lightning import seed_everything
from safetensors.torch import load_file as load_safetensors
from ldf_utils.initialize import compare_statedict_and_parameters, instantiate, load_config
# Set tokenizers parallelism to false to avoid warnings in multiprocessing
os.environ["TOKENIZERS_PARALLELISM"] = "false"

def load_model_from_config():
    """Instantiate the VAE and diffusion model from the --config file and load their checkpoints."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch.set_float32_matmul_precision("high")
    cfg = load_config()
    seed_everything(cfg.seed)

    # Find the directory containing the config file (from sys.argv, else the
    # current directory) so relative paths in the config can be resolved against it.
    if "--config" in sys.argv:
        config_idx = sys.argv.index("--config") + 1
        config_dir = os.path.dirname(os.path.abspath(sys.argv[config_idx]))
    else:
        config_dir = os.getcwd()

    vae = instantiate(
        target=cfg.test_vae.target,
        cfg=None,
        hfstyle=False,
        **cfg.test_vae.params,
    )
    # Resolve a relative VAE checkpoint path against the config directory
    vae_path = cfg.test_vae_ckpt
    if not os.path.isabs(vae_path):
        vae_path = os.path.join(config_dir, vae_path)
    # Load from safetensors (already contains EMA weights)
    vae_state_dict = load_safetensors(vae_path)
    vae.load_state_dict(vae_state_dict, strict=True)
    print(f"Loaded VAE model from {vae_path}")
    compare_statedict_and_parameters(
        state_dict=vae.state_dict(),
        named_parameters=vae.named_parameters(),
        named_buffers=vae.named_buffers(),
    )
    vae.to(device)
    vae.eval()

    # Model: convert relative paths in the model params to absolute paths
    model_params = dict(cfg.model.params)
    if "checkpoint_path" in model_params and model_params["checkpoint_path"]:
        if not os.path.isabs(model_params["checkpoint_path"]):
            model_params["checkpoint_path"] = os.path.join(
                config_dir, model_params["checkpoint_path"]
            )
    if "tokenizer_path" in model_params and model_params["tokenizer_path"]:
        if not os.path.isabs(model_params["tokenizer_path"]):
            model_params["tokenizer_path"] = os.path.join(
                config_dir, model_params["tokenizer_path"]
            )
    model = instantiate(
        target=cfg.model.target, cfg=None, hfstyle=False, **model_params
    )
    # Resolve a relative model checkpoint path against the config directory
    model_path = cfg.test_ckpt
    if not os.path.isabs(model_path):
        model_path = os.path.join(config_dir, model_path)
    # Load from safetensors (already contains EMA weights)
    model_state_dict = load_safetensors(model_path)
    model.load_state_dict(model_state_dict, strict=True)
    print(f"Loaded model from {model_path}")
    compare_statedict_and_parameters(
        state_dict=model.state_dict(),
        named_parameters=model.named_parameters(),
        named_buffers=model.named_buffers(),
    )
    model.to(device)
    model.eval()
    return vae, model
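
# Note (usage sketch, not from the original file): load_model_from_config()
# reads --config from sys.argv directly instead of taking a path argument,
# so it is meant to be invoked from a CLI entry point, e.g.
#     python generate_ldf.py --config <path-to-config>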

@torch.inference_mode()
def generate_feature_stream(
    model, feature_length, text, feature_text_end=None, num_denoise_steps=None
):
    """
    Streaming interface for feature generation.

    Args:
        model: Loaded model.
        feature_length: List[int], generation length for each sample.
        text: List[str] or List[List[str]], text prompts.
        feature_text_end: List[List[int]], time points where each text ends
            (required if text is a list of lists).
        num_denoise_steps: Number of denoising steps.

    Yields:
        dict: Contains "generated" (the feature segment produced at this step).
    """
    # Build the input dict x. stream_generate expects x to contain
    # "feature_length", "text", and "feature_text_end" (if text is a list of lists).
    x = {"feature_length": torch.tensor(feature_length), "text": text}
    if feature_text_end is not None:
        x["feature_text_end"] = feature_text_end
    # stream_generate is a generator; each step_output is already a dict
    # with a "generated" key, so forward it as-is.
    generator = model.stream_generate(x, num_denoise_steps=num_denoise_steps)
    for step_output in generator:
        yield step_output
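
# Usage sketch (illustrative addition, not from the original file): consume the
# stream and concatenate the per-step segments. The shape of
# step_output["generated"] and the concatenation axis are assumptions.
#
#     segments = []
#     for step_output in generate_feature_stream(
#         model, feature_length=[120], text=["a person walks forward"]
#     ):
#         segments.append(step_output["generated"])
#     features = torch.cat(segments, dim=1)  # assumed time axis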

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True, help="Path to config")
    parser.add_argument(
        "--text", type=str, default="a person walks forward", help="Text prompt"
    )
    parser.add_argument("--length", type=int, default=120, help="Motion length")
    parser.add_argument(
        "--output", type=str, default="output.mp4", help="Output video path"
    )
    parser.add_argument(
        "--num_denoise_steps", type=int, default=None, help="Number of denoising steps"
    )
    args = parser.parse_args()

    print("Loading model...")
    vae, model = load_model_from_config()
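
    # Sketch (hypothetical continuation, not from the original file): the parsed
    # arguments would drive generation roughly like this; decoding the features
    # with the VAE and writing args.output depend on APIs not shown here.
    #     for step_output in generate_feature_stream(
    #         model,
    #         feature_length=[args.length],
    #         text=[args.text],
    #         num_denoise_steps=args.num_denoise_steps,
    #     ):
    #         ...  # decode step_output["generated"] with vae, render to args.output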