Update src/facerender/animate.py
Browse files
src/facerender/animate.py
CHANGED
|
@@ -27,7 +27,7 @@ from src.utils.videoio import save_video_with_watermark
|
|
| 27 |
class AnimateFromCoeff():
|
| 28 |
|
| 29 |
def __init__(self, free_view_checkpoint, mapping_checkpoint,
|
| 30 |
-
config_path, device):
|
| 31 |
|
| 32 |
with open(config_path) as f:
|
| 33 |
config = yaml.safe_load(f)
|
|
@@ -88,7 +88,7 @@ class AnimateFromCoeff():
|
|
| 88 |
def load_cpk_facevid2vid(self, checkpoint_path, generator=None, discriminator=None,
|
| 89 |
kp_detector=None, he_estimator=None, optimizer_generator=None,
|
| 90 |
optimizer_discriminator=None, optimizer_kp_detector=None,
|
| 91 |
-
optimizer_he_estimator=None, device="
|
| 92 |
checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
|
| 93 |
|
| 94 |
def adjust_state_dict(state_dict, model):
|
|
@@ -135,7 +135,7 @@ class AnimateFromCoeff():
|
|
| 135 |
return checkpoint['epoch']
|
| 136 |
|
| 137 |
def load_cpk_mapping(self, checkpoint_path, mapping=None, discriminator=None,
|
| 138 |
-
optimizer_mapping=None, optimizer_discriminator=None, device='
|
| 139 |
checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
|
| 140 |
|
| 141 |
def adjust_state_dict(state_dict, model):
|
|
|
|
| 27 |
class AnimateFromCoeff():
|
| 28 |
|
| 29 |
def __init__(self, free_view_checkpoint, mapping_checkpoint,
|
| 30 |
+
config_path, device='cuda'):
|
| 31 |
|
| 32 |
with open(config_path) as f:
|
| 33 |
config = yaml.safe_load(f)
|
|
|
|
| 88 |
def load_cpk_facevid2vid(self, checkpoint_path, generator=None, discriminator=None,
|
| 89 |
kp_detector=None, he_estimator=None, optimizer_generator=None,
|
| 90 |
optimizer_discriminator=None, optimizer_kp_detector=None,
|
| 91 |
+
optimizer_he_estimator=None, device="cuda"):
|
| 92 |
checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
|
| 93 |
|
| 94 |
def adjust_state_dict(state_dict, model):
|
|
|
|
| 135 |
return checkpoint['epoch']
|
| 136 |
|
| 137 |
def load_cpk_mapping(self, checkpoint_path, mapping=None, discriminator=None,
|
| 138 |
+
optimizer_mapping=None, optimizer_discriminator=None, device='cuda'):
|
| 139 |
checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
|
| 140 |
|
| 141 |
def adjust_state_dict(state_dict, model):
|