| | import os |
| | from tqdm import tqdm |
| | import torch |
| | import numpy as np |
| | import random |
| | import scipy.io as scio |
| | import src.utils.audio as audio |
| |
|
def crop_pad_audio(wav, audio_length):
    """Force *wav* to exactly ``audio_length`` samples.

    Longer input is truncated at the end; shorter input is right-padded with
    zeros. Input that already has the target length is returned as-is (the
    very same object).
    """
    shortfall = audio_length - len(wav)
    if shortfall < 0:
        return wav[:audio_length]
    if shortfall > 0:
        return np.pad(wav, [0, shortfall], mode='constant', constant_values=0)
    return wav
| |
|
def parse_audio_length(audio_length, sr, fps, min_frames=30):
    """Convert an audio length in samples into a whole number of video frames.

    Args:
        audio_length: number of audio samples.
        sr: audio sample rate (samples per second).
        fps: video frame rate (frames per second).
        min_frames: lower bound on the returned frame count. Defaults to 30.

    Returns:
        tuple(int, int): ``(samples, num_frames)`` where ``num_frames`` is
        ``floor(audio_length * fps / sr)`` clamped to at least ``min_frames``,
        and ``samples`` is the number of audio samples covered by exactly
        ``num_frames`` frames (rounded down).
    """
    # Integer floor division instead of dividing by the float sr / fps, so
    # frame boundaries are exact even when sr is not a multiple of fps.
    num_frames = max(int(audio_length * fps // sr), min_frames)
    return int(num_frames * sr // fps), num_frames
| |
|
def generate_blink_seq(num_frames):
    """Return a ``(num_frames, 1)`` blink ratio with a fixed, periodic blink.

    The first 9-frame blink starts at frame 80, and each following one starts
    89 frames (80 idle + 9 blink frames) after the previous. A blink is
    written only if it ends before the final frame; all other entries are 0.
    """
    blink_curve = [0.5, 0.6, 0.7, 0.9, 1, 0.9, 0.7, 0.6, 0.5]
    blink_len = len(blink_curve)
    idle_len = 80
    ratio = np.zeros((num_frames, 1))
    # An onset is valid while onset + blink_len <= num_frames - 1.
    for onset in range(idle_len, num_frames - blink_len, idle_len + blink_len):
        ratio[onset:onset + blink_len, 0] = blink_curve
    return ratio
| |
|
def generate_blink_seq_randomly(num_frames):
    """Build a per-frame blink ratio with blinks at randomly spaced intervals.

    Each blink is the 5-frame curve ``[0.5, 0.9, 1.0, 0.9, 0.5]``. Before every
    blink, a fresh offset is drawn (via ``random``) from ``[10, max_start)``,
    where ``max_start = min(num_frames // 2, 70)``, widened to 15 when it would
    not exceed 10. Blinks are placed while the blink ends before the final frame.

    Args:
        num_frames: number of frames in the output.

    Returns:
        np.ndarray of shape ``(num_frames, 1)``; all zeros when
        ``num_frames <= 20``.
    """
    blink_curve = [0.5, 0.9, 1.0, 0.9, 0.5]
    blink_len = len(blink_curve)
    ratio = np.zeros((num_frames, 1))
    if num_frames <= 20:
        return ratio

    # num_frames > 20 here, so the lower bound is always 10.
    min_start = 10
    max_start = min(num_frames // 2, 70)
    if max_start <= min_start:
        # Keep the sampling range non-empty for short clips (e.g. 21 frames).
        max_start = min_start + 5

    frame_id = 0
    while frame_id < num_frames:
        # Draw a new gap for every blink; drawing it once would make the
        # "random" sequence strictly periodic.
        start = random.randrange(min_start, max_start)
        onset = frame_id + start
        if onset + blink_len > num_frames - 1:
            break
        ratio[onset:onset + blink_len, 0] = blink_curve
        frame_id = onset + blink_len
    return ratio
| |
|
def _compute_indiv_mels(audio_path, sr, fps, mel_step_size):
    """Load audio and cut its mel spectrogram into one fixed-size window per video frame.

    Returns:
        tuple(np.ndarray, int): windows of shape
        ``(num_frames, mel.shape[1], mel_step_size)`` and the frame count.

    Raises:
        ValueError: if fewer than 5 video frames are derived from the audio.
    """
    wav = audio.load_wav(audio_path, sr)
    wav_length, num_frames = parse_audio_length(len(wav), sr, fps)
    # Defensive only: parse_audio_length currently floors num_frames at 30,
    # so this fires only if that floor is lowered.
    if num_frames < 5:
        raise ValueError(f"Audio too short: only {num_frames} frames generated")

    wav = crop_pad_audio(wav, wav_length)
    # Presumably (time_steps, mel_bins) after the transpose -- confirm against audio.melspectrogram.
    mel = audio.melspectrogram(wav).T
    last_row = mel.shape[0] - 1

    windows = []
    for i in tqdm(range(num_frames), 'mel:'):
        # Window starts two video frames before frame i; 80. maps seconds to
        # mel rows (presumably 80 mel steps per second -- confirm hparams).
        start_idx = int(80. * ((i - 2) / float(fps)))
        # Clamp so windows near either end repeat the edge row.
        rows = [min(max(idx, 0), last_row) for idx in range(start_idx, start_idx + mel_step_size)]
        windows.append(mel[rows, :].T)
    return np.asarray(windows), num_frames


def _cycle_rows(coeff, num_frames):
    """Return ``num_frames`` rows of *coeff*, repeating its rows cyclically if it is shorter.

    Raises:
        ValueError: if *coeff* has no rows but more than zero are requested.
    """
    n_rows = coeff.shape[0]
    if n_rows < num_frames:
        if n_rows == 0:
            raise ValueError("reference eyeblink coefficients contain no frames")
        coeff = np.tile(coeff, (-(-num_frames // n_rows), 1))  # ceil division
    return coeff[:num_frames]


def get_data(first_coeff_path, audio_path, device, ref_eyeblink_coeff_path, still=False, idlemode=False, length_of_audio=False, use_blink=True):
    """Assemble per-clip inputs: mel windows, reference coefficients and blink ratios.

    Args:
        first_coeff_path: ``.mat`` file; the first row (first 70 columns) of its
            ``coeff_3dmm`` entry is repeated for every frame as the reference.
        audio_path: driving audio, loaded at 16 kHz. In idle mode only its
            basename is used (for ``audio_name``).
        device: torch device the returned tensors are moved to.
        ref_eyeblink_coeff_path: optional ``.mat`` file; its ``coeff_3dmm[:, :64]``
            rows (cycled to the clip length) overwrite the first 64 reference
            columns and disable the random blink ratio. If loading it fails, a
            warning is printed and the random blinks are kept.
        still: not used by this function.
        idlemode: if True, skip audio loading and use all-zero mel windows for
            ``length_of_audio`` seconds.
        length_of_audio: clip length in seconds, read only in idle mode (the
            default ``False`` yields zero frames).
        use_blink: if False, the returned blink ratio is all zeros.

    Returns:
        dict with ``'indiv_mels'`` (1, num_frames, 1, n_mel_bins, 16),
        ``'ref'`` (1, num_frames, 70 -- assuming ``coeff_3dmm`` has >= 70 columns),
        ``'num_frames'``, ``'ratio_gt'`` (1, num_frames, 1), ``'audio_name'``
        and ``'pic_name'``.

    Raises:
        RuntimeError: if audio processing, loading the source coefficients, or
            tensor conversion fails (the original error is chained).
    """
    syncnet_mel_step_size = 16
    fps = 25
    sr = 16000

    pic_name = os.path.splitext(os.path.split(first_coeff_path)[-1])[0]
    audio_name = os.path.splitext(os.path.split(audio_path)[-1])[0]

    if idlemode:
        num_frames = int(length_of_audio * fps)
        # Silent input. 80 presumably matches the mel bin count produced by
        # audio.melspectrogram -- confirm.
        indiv_mels = np.zeros((num_frames, 80, syncnet_mel_step_size))
    else:
        try:
            indiv_mels, num_frames = _compute_indiv_mels(audio_path, sr, fps, syncnet_mel_step_size)
        except Exception as e:
            raise RuntimeError(f"Audio processing failed: {str(e)}") from e

    try:
        ratio = generate_blink_seq_randomly(num_frames)
    except Exception as e:
        print(f"Warning: Blink sequence generation failed, using zeros: {str(e)}")
        ratio = np.zeros((num_frames, 1))

    try:
        source_semantics_dict = scio.loadmat(first_coeff_path)
        ref_coeff = source_semantics_dict['coeff_3dmm'][:1, :70]
        ref_coeff = np.repeat(ref_coeff, num_frames, axis=0)
    except Exception as e:
        raise RuntimeError(f"Failed to load source semantics: {str(e)}") from e

    if ref_eyeblink_coeff_path is not None:
        try:
            refeyeblink_coeff = scio.loadmat(ref_eyeblink_coeff_path)['coeff_3dmm'][:, :64]
            ref_coeff[:, :64] = _cycle_rows(refeyeblink_coeff, num_frames)
            # Zero the synthetic blinks only once the reference was applied, so
            # a failure above falls back to the random blinks instead of none.
            ratio[:] = 0
        except Exception as e:
            print(f"Warning: Eyeblink reference processing failed: {str(e)}")

    try:
        indiv_mels = torch.tensor(indiv_mels, dtype=torch.float32).unsqueeze(1).unsqueeze(0)
        ratio = torch.tensor(ratio, dtype=torch.float32).unsqueeze(0)
        if not use_blink:
            ratio.fill_(0.)
        ref_coeff = torch.tensor(ref_coeff, dtype=torch.float32).unsqueeze(0)

        indiv_mels = indiv_mels.to(device)
        ratio = ratio.to(device)
        ref_coeff = ref_coeff.to(device)
    except Exception as e:
        raise RuntimeError(f"Tensor conversion failed: {str(e)}") from e

    return {
        'indiv_mels': indiv_mels,
        'ref': ref_coeff,
        'num_frames': num_frames,
        'ratio_gt': ratio,
        'audio_name': audio_name,
        'pic_name': pic_name
    }
| |
|