diff --git a/configs/train/stage1.yaml b/configs/train/stage1.yaml new file mode 100644 index 00000000..28760ed2 --- /dev/null +++ b/configs/train/stage1.yaml @@ -0,0 +1,63 @@ +data: + train_bs: 8 + train_width: 512 + train_height: 512 + meta_paths: + - "./data/HDTF_meta.json" + # Margin of frame indexes between ref and tgt images + sample_margin: 30 + +solver: + gradient_accumulation_steps: 1 + mixed_precision: "no" + enable_xformers_memory_efficient_attention: True + gradient_checkpointing: False + max_train_steps: 30000 + max_grad_norm: 1.0 + # lr + learning_rate: 1.0e-5 + scale_lr: False + lr_warmup_steps: 1 + lr_scheduler: "constant" + + # optimizer + use_8bit_adam: False + adam_beta1: 0.9 + adam_beta2: 0.999 + adam_weight_decay: 1.0e-2 + adam_epsilon: 1.0e-8 + +val: + validation_steps: 500 + +noise_scheduler_kwargs: + num_train_timesteps: 1000 + beta_start: 0.00085 + beta_end: 0.012 + beta_schedule: "scaled_linear" + steps_offset: 1 + clip_sample: false + +base_model_path: "./pretrained_models/stable-diffusion-v1-5/" +vae_model_path: "./pretrained_models/sd-vae-ft-mse" +face_analysis_model_path: "./pretrained_models/face_analysis" + +weight_dtype: "fp16" # [fp16, fp32] +uncond_ratio: 0.1 +noise_offset: 0.05 +snr_gamma: 5.0 +enable_zero_snr: True +face_locator_pretrained: False + +seed: 42 +resume_from_checkpoint: "latest" +checkpointing_steps: 500 +exp_name: "stage1" +output_dir: "./exp_output" + +ref_image_paths: + - "examples/reference_images/1.jpg" + +mask_image_paths: + - "examples/masks/1.png" + diff --git a/configs/train/stage2.yaml b/configs/train/stage2.yaml new file mode 100644 index 00000000..bd32c0dd --- /dev/null +++ b/configs/train/stage2.yaml @@ -0,0 +1,119 @@ +data: + train_bs: 4 + val_bs: 1 + train_width: 512 + train_height: 512 + fps: 25 + sample_rate: 16000 + n_motion_frames: 2 + n_sample_frames: 14 + audio_margin: 2 + train_meta_paths: + - "./data/hdtf_split_stage2.json" + +wav2vec_config: + audio_type: "vocals" # audio vocals + model_scale: "base" # base large + features: "all" # last avg all + model_path: ./pretrained_models/wav2vec/wav2vec2-base-960h +audio_separator: + model_path: ./pretrained_models/audio_separator/Kim_Vocal_2.onnx +face_expand_ratio: 1.2 + +solver: + gradient_accumulation_steps: 1 + mixed_precision: "no" + enable_xformers_memory_efficient_attention: True + gradient_checkpointing: True + max_train_steps: 30000 + max_grad_norm: 1.0 + # lr + learning_rate: 1e-5 + scale_lr: False + lr_warmup_steps: 1 + lr_scheduler: "constant" + + # optimizer + use_8bit_adam: True + adam_beta1: 0.9 + adam_beta2: 0.999 + adam_weight_decay: 1.0e-2 + adam_epsilon: 1.0e-8 + +val: + validation_steps: 1000 + +noise_scheduler_kwargs: + num_train_timesteps: 1000 + beta_start: 0.00085 + beta_end: 0.012 + beta_schedule: "linear" + steps_offset: 1 + clip_sample: false + +unet_additional_kwargs: + use_inflated_groupnorm: true + unet_use_cross_frame_attention: false + unet_use_temporal_attention: false + use_motion_module: true + use_audio_module: true + motion_module_resolutions: + - 1 + - 2 + - 4 + - 8 + motion_module_mid_block: true + motion_module_decoder_only: false + motion_module_type: Vanilla + motion_module_kwargs: + num_attention_heads: 8 + num_transformer_block: 1 + attention_block_types: + - Temporal_Self + - Temporal_Self + temporal_position_encoding: true + temporal_position_encoding_max_len: 32 + temporal_attention_dim_div: 1 + audio_attention_dim: 768 + stack_enable_blocks_name: + - "up" + - "down" + - "mid" + stack_enable_blocks_depth: [0,1,2,3] + 
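+# Note: unet_additional_kwargs above is converted with OmegaConf.to_container and passed to
+# UNet3DConditionModel.from_pretrained_2d in scripts/train_stage2.py; the motion and audio
+# modules it enables correspond to the trainable_para entries below.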
+trainable_para: + - audio_modules + - motion_modules + +base_model_path: "./pretrained_models/stable-diffusion-v1-5/" +vae_model_path: "./pretrained_models/sd-vae-ft-mse" +face_analysis_model_path: "./pretrained_models/face_analysis" +mm_path: "./pretrained_models/motion_module/mm_sd_v15_v2.ckpt" + +weight_dtype: "fp16" # [fp16, fp32] +uncond_img_ratio: 0.05 +uncond_audio_ratio: 0.05 +uncond_ia_ratio: 0.05 +start_ratio: 0.05 +noise_offset: 0.05 +snr_gamma: 5.0 +enable_zero_snr: True +stage1_ckpt_dir: "./pretrained_models/hallo/stage1" + +single_inference_times: 10 +inference_steps: 40 +cfg_scale: 3.5 + +seed: 42 +resume_from_checkpoint: "latest" +checkpointing_steps: 500 +exp_name: "stage2_test" +output_dir: "./exp_output" + +ref_img_path: + - "examples/reference_images/1.jpg" + +audio_path: + - "examples/driving_audios/1.wav" + + diff --git a/examples/masks/1.png b/examples/masks/1.png new file mode 100644 index 00000000..c63e0757 Binary files /dev/null and b/examples/masks/1.png differ diff --git a/hallo/datasets/talk_video.py b/hallo/datasets/talk_video.py index 4f9114ba..25c3ab81 100644 --- a/hallo/datasets/talk_video.py +++ b/hallo/datasets/talk_video.py @@ -145,25 +145,29 @@ def __init__( ) self.attn_transform_64 = transforms.Compose( [ - transforms.Resize((64,64)), + transforms.Resize( + (self.img_size[0] // 8, self.img_size[0] // 8)), transforms.ToTensor(), ] ) self.attn_transform_32 = transforms.Compose( [ - transforms.Resize((32, 32)), + transforms.Resize( + (self.img_size[0] // 16, self.img_size[0] // 16)), transforms.ToTensor(), ] ) self.attn_transform_16 = transforms.Compose( [ - transforms.Resize((16, 16)), + transforms.Resize( + (self.img_size[0] // 32, self.img_size[0] // 32)), transforms.ToTensor(), ] ) self.attn_transform_8 = transforms.Compose( [ - transforms.Resize((8, 8)), + transforms.Resize( + (self.img_size[0] // 64, self.img_size[0] // 64)), transforms.ToTensor(), ] ) diff --git a/hallo/models/motion_module.py b/hallo/models/motion_module.py index 07f98454..f62877d4 100644 --- a/hallo/models/motion_module.py +++ b/hallo/models/motion_module.py @@ -507,6 +507,7 @@ def extra_repr(self): def set_use_memory_efficient_attention_xformers( self, use_memory_efficient_attention_xformers: bool, + attention_op = None, ): """ Sets the use of memory-efficient attention xformers for the VersatileAttention class. diff --git a/hallo/utils/util.py b/hallo/utils/util.py index f4b6563a..9dc61fb8 100644 --- a/hallo/utils/util.py +++ b/hallo/utils/util.py @@ -67,6 +67,7 @@ import subprocess import sys from pathlib import Path +from typing import List import av import cv2 @@ -614,3 +615,150 @@ def get_face_region(image_path: str, detector): except Exception as e: print(f"Error processing image {image_path}: {e}") return None, None + + +def save_checkpoint(model: torch.nn.Module, save_dir: str, prefix: str, ckpt_num: int, total_limit: int = -1) -> None: + """ + Save the model's state_dict to a checkpoint file. + + If `total_limit` is provided, this function will remove the oldest checkpoints + until the total number of checkpoints is less than the specified limit. + + Args: + model (nn.Module): The model whose state_dict is to be saved. + save_dir (str): The directory where the checkpoint will be saved. + prefix (str): The prefix for the checkpoint file name. + ckpt_num (int): The checkpoint number to be saved. + total_limit (int, optional): The maximum number of checkpoints to keep. + Defaults to None, in which case no checkpoints will be removed. 
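+        (In the signature the default is -1; any non-positive value disables removal.)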
+ + Raises: + FileNotFoundError: If the save directory does not exist. + ValueError: If the checkpoint number is negative. + OSError: If there is an error saving the checkpoint. + """ + + if not osp.exists(save_dir): + raise FileNotFoundError( + f"The save directory {save_dir} does not exist.") + + if ckpt_num < 0: + raise ValueError(f"Checkpoint number {ckpt_num} must be non-negative.") + + save_path = osp.join(save_dir, f"{prefix}-{ckpt_num}.pth") + + if total_limit > 0: + checkpoints = os.listdir(save_dir) + checkpoints = [d for d in checkpoints if d.startswith(prefix)] + checkpoints = sorted( + checkpoints, key=lambda x: int(x.split("-")[1].split(".")[0]) + ) + + if len(checkpoints) >= total_limit: + num_to_remove = len(checkpoints) - total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + print( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + print( + f"Removing checkpoints: {', '.join(removing_checkpoints)}" + ) + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint_path = osp.join( + save_dir, removing_checkpoint) + try: + os.remove(removing_checkpoint_path) + except OSError as e: + print( + f"Error removing checkpoint {removing_checkpoint_path}: {e}") + + state_dict = model.state_dict() + try: + torch.save(state_dict, save_path) + print(f"Checkpoint saved at {save_path}") + except OSError as e: + raise OSError(f"Error saving checkpoint at {save_path}: {e}") from e + + +def init_output_dir(dir_list: List[str]): + """ + Initialize the output directories. + + This function creates the directories specified in the `dir_list`. If a directory already exists, it does nothing. + + Args: + dir_list (List[str]): List of directory paths to create. + """ + for path in dir_list: + os.makedirs(path, exist_ok=True) + + +def load_checkpoint(cfg, save_dir, accelerator): + """ + Load the most recent checkpoint from the specified directory. + + This function loads the latest checkpoint from the `save_dir` if the `resume_from_checkpoint` parameter is set to "latest". + If a specific checkpoint is provided in `resume_from_checkpoint`, it loads that checkpoint. If no checkpoint is found, + it starts training from scratch. + + Args: + cfg: The configuration object containing training parameters. + save_dir (str): The directory where checkpoints are saved. + accelerator: The accelerator object for distributed training. + + Returns: + int: The global step at which to resume training. 
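+
+    Example:
+        # typical call from the training scripts in this diff
+        global_step = load_checkpoint(cfg, checkpoint_dir, accelerator)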
+ """ + if cfg.resume_from_checkpoint != "latest": + resume_dir = cfg.resume_from_checkpoint + else: + resume_dir = save_dir + # Get the most recent checkpoint + dirs = os.listdir(resume_dir) + + dirs = [d for d in dirs if d.startswith("checkpoint")] + if len(dirs) > 0: + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] + accelerator.load_state(os.path.join(resume_dir, path)) + accelerator.print(f"Resuming from checkpoint {path}") + global_step = int(path.split("-")[1]) + else: + accelerator.print( + f"Could not find checkpoint under {resume_dir}, start training from scratch") + global_step = 0 + + return global_step + + +def compute_snr(noise_scheduler, timesteps): + """ + Computes SNR as per + https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/ + 521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849 + """ + alphas_cumprod = noise_scheduler.alphas_cumprod + sqrt_alphas_cumprod = alphas_cumprod**0.5 + sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5 + + # Expand the tensors. + # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/ + # 521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026 + sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[ + timesteps + ].float() + while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape): + sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] + alpha = sqrt_alphas_cumprod.expand(timesteps.shape) + + sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to( + device=timesteps.device + )[timesteps].float() + while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape): + sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None] + sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape) + + # Compute SNR. + snr = (alpha / sigma) ** 2 + return snr diff --git a/scripts/train_stage1.py b/scripts/train_stage1.py new file mode 100644 index 00000000..305b9395 --- /dev/null +++ b/scripts/train_stage1.py @@ -0,0 +1,784 @@ +# pylint: disable=E1101,C0415,W0718,R0801 +# scripts/train_stage1.py +""" +This is the main training script for stage 1 of the project. +It imports necessary packages, defines necessary classes and functions, and trains the model using the provided configuration. + +The script includes the following classes and functions: + +1. Net: A PyTorch model that takes noisy latents, timesteps, reference image latents, face embeddings, + and face masks as input and returns the denoised latents. +3. log_validation: A function that logs the validation information using the given VAE, image encoder, + network, scheduler, accelerator, width, height, and configuration. +4. train_stage1_process: A function that processes the training stage 1 using the given configuration. + +The script also includes the necessary imports and a brief description of the purpose of the file. 
+""" + +import argparse +import logging +import math +import os +import random +import warnings +from datetime import datetime + +import cv2 +import diffusers +import mlflow +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import DistributedDataParallelKwargs +from diffusers import AutoencoderKL, DDIMScheduler +from diffusers.optimization import get_scheduler +from diffusers.utils import check_min_version +from diffusers.utils.import_utils import is_xformers_available +from insightface.app import FaceAnalysis +from omegaconf import OmegaConf +from PIL import Image +from torch import nn +from tqdm.auto import tqdm + +from hallo.animate.face_animate_static import StaticPipeline +from hallo.datasets.mask_image import FaceMaskDataset +from hallo.models.face_locator import FaceLocator +from hallo.models.image_proj import ImageProjModel +from hallo.models.mutual_self_attention import ReferenceAttentionControl +from hallo.models.unet_2d_condition import UNet2DConditionModel +from hallo.models.unet_3d import UNet3DConditionModel +from hallo.utils.util import (compute_snr, delete_additional_ckpt, + import_filename, init_output_dir, + load_checkpoint, save_checkpoint, + seed_everything) + +warnings.filterwarnings("ignore") + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.10.0.dev0") + +logger = get_logger(__name__, log_level="INFO") + + +class Net(nn.Module): + """ + The Net class defines a neural network model that combines a reference UNet2DConditionModel, + a denoising UNet3DConditionModel, a face locator, and other components to animate a face in a static image. + + Args: + reference_unet (UNet2DConditionModel): The reference UNet2DConditionModel used for face animation. + denoising_unet (UNet3DConditionModel): The denoising UNet3DConditionModel used for face animation. + face_locator (FaceLocator): The face locator model used for face animation. + reference_control_writer: The reference control writer component. + reference_control_reader: The reference control reader component. + imageproj: The image projection model. + + Forward method: + noisy_latents (torch.Tensor): The noisy latents tensor. + timesteps (torch.Tensor): The timesteps tensor. + ref_image_latents (torch.Tensor): The reference image latents tensor. + face_emb (torch.Tensor): The face embeddings tensor. + face_mask (torch.Tensor): The face mask tensor. + uncond_fwd (bool): A flag indicating whether to perform unconditional forward pass. + + Returns: + torch.Tensor: The output tensor of the neural network model. + """ + + def __init__( + self, + reference_unet: UNet2DConditionModel, + denoising_unet: UNet3DConditionModel, + face_locator: FaceLocator, + reference_control_writer: ReferenceAttentionControl, + reference_control_reader: ReferenceAttentionControl, + imageproj: ImageProjModel, + ): + super().__init__() + self.reference_unet = reference_unet + self.denoising_unet = denoising_unet + self.face_locator = face_locator + self.reference_control_writer = reference_control_writer + self.reference_control_reader = reference_control_reader + self.imageproj = imageproj + + def forward( + self, + noisy_latents, + timesteps, + ref_image_latents, + face_emb, + face_mask, + uncond_fwd: bool = False, + ): + """ + Forward pass of the model. + Args: + self (Net): The model instance. 
+ noisy_latents (torch.Tensor): Noisy latents. + timesteps (torch.Tensor): Timesteps. + ref_image_latents (torch.Tensor): Reference image latents. + face_emb (torch.Tensor): Face embedding. + face_mask (torch.Tensor): Face mask. + uncond_fwd (bool, optional): Unconditional forward pass. Defaults to False. + + Returns: + torch.Tensor: Model prediction. + """ + + face_emb = self.imageproj(face_emb) + face_mask = face_mask.to(device="cuda") + face_mask_feature = self.face_locator(face_mask) + + if not uncond_fwd: + ref_timesteps = torch.zeros_like(timesteps) + self.reference_unet( + ref_image_latents, + ref_timesteps, + encoder_hidden_states=face_emb, + return_dict=False, + ) + self.reference_control_reader.update(self.reference_control_writer) + model_pred = self.denoising_unet( + noisy_latents, + timesteps, + mask_cond_fea=face_mask_feature, + encoder_hidden_states=face_emb, + ).sample + + return model_pred + + +def get_noise_scheduler(cfg: argparse.Namespace): + """ + Create noise scheduler for training + + Args: + cfg (omegaconf.dictconfig.DictConfig): Configuration object. + + Returns: + train noise scheduler and val noise scheduler + """ + sched_kwargs = OmegaConf.to_container(cfg.noise_scheduler_kwargs) + if cfg.enable_zero_snr: + sched_kwargs.update( + rescale_betas_zero_snr=True, + timestep_spacing="trailing", + prediction_type="v_prediction", + ) + val_noise_scheduler = DDIMScheduler(**sched_kwargs) + sched_kwargs.update({"beta_schedule": "scaled_linear"}) + train_noise_scheduler = DDIMScheduler(**sched_kwargs) + + return train_noise_scheduler, val_noise_scheduler + + +def log_validation( + vae, + net, + scheduler, + accelerator, + width, + height, + imageproj, + cfg, + save_dir, + global_step, + face_analysis_model_path, +): + """ + Log validation generation image. + + Args: + vae (nn.Module): Variational Autoencoder model. + net (Net): Main model. + scheduler (diffusers.SchedulerMixin): Noise scheduler. + accelerator (accelerate.Accelerator): Accelerator for training. + width (int): Width of the input images. + height (int): Height of the input images. + imageproj (nn.Module): Image projection model. + cfg (omegaconf.dictconfig.DictConfig): Configuration object. + save_dir (str): directory path to save log result. + global_step (int): Global step number. + + Returns: + None + """ + logger.info("Running validation... 
") + + ori_net = accelerator.unwrap_model(net) + reference_unet = ori_net.reference_unet + denoising_unet = ori_net.denoising_unet + face_locator = ori_net.face_locator + + generator = torch.manual_seed(42) + image_enc = FaceAnalysis( + name="", + root=face_analysis_model_path, + providers=["CUDAExecutionProvider", "CPUExecutionProvider"], + ) + image_enc.prepare(ctx_id=0, det_size=(640, 640)) + + pipe = StaticPipeline( + vae=vae, + reference_unet=reference_unet, + denoising_unet=denoising_unet, + face_locator=face_locator, + scheduler=scheduler, + imageproj=imageproj, + ) + + pil_images = [] + for ref_image_path, mask_image_path in zip(cfg.ref_image_paths, cfg.mask_image_paths): + # for mask_image_path in mask_image_paths: + mask_name = os.path.splitext( + os.path.basename(mask_image_path))[0] + ref_name = os.path.splitext( + os.path.basename(ref_image_path))[0] + ref_image_pil = Image.open(ref_image_path).convert("RGB") + mask_image_pil = Image.open(mask_image_path).convert("RGB") + + # Prepare face embeds + face_info = image_enc.get( + cv2.cvtColor(np.array(ref_image_pil), cv2.COLOR_RGB2BGR)) + face_info = sorted(face_info, key=lambda x: (x['bbox'][2] - x['bbox'][0]) * ( + x['bbox'][3] - x['bbox'][1]))[-1] # only use the maximum face + face_emb = torch.tensor(face_info['embedding']) + face_emb = face_emb.to( + imageproj.device, imageproj.dtype) + + image = pipe( + ref_image_pil, + mask_image_pil, + width, + height, + 20, + 3.5, + face_emb, + generator=generator, + ).images + image = image[0, :, 0].permute(1, 2, 0).cpu().numpy() # (3, 512, 512) + res_image_pil = Image.fromarray((image * 255).astype(np.uint8)) + # Save ref_image, src_image and the generated_image + w, h = res_image_pil.size + canvas = Image.new("RGB", (w * 3, h), "white") + ref_image_pil = ref_image_pil.resize((w, h)) + mask_image_pil = mask_image_pil.resize((w, h)) + canvas.paste(ref_image_pil, (0, 0)) + canvas.paste(mask_image_pil, (w, 0)) + canvas.paste(res_image_pil, (w * 2, 0)) + + out_file = os.path.join( + save_dir, f"{global_step:06d}-{ref_name}_{mask_name}.jpg" + ) + canvas.save(out_file) + + del pipe + torch.cuda.empty_cache() + + return pil_images + + +def train_stage1_process(cfg: argparse.Namespace) -> None: + """ + Trains the model using the given configuration (cfg). + + Args: + cfg (dict): The configuration dictionary containing the parameters for training. + + Notes: + - This function trains the model using the given configuration. + - It initializes the necessary components for training, such as the pipeline, optimizer, and scheduler. + - The training progress is logged and tracked using the accelerator. + - The trained model is saved after the training is completed. + """ + kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) + accelerator = Accelerator( + gradient_accumulation_steps=cfg.solver.gradient_accumulation_steps, + mixed_precision=cfg.solver.mixed_precision, + log_with="mlflow", + project_dir="./mlruns", + kwargs_handlers=[kwargs], + ) + + # Make one log on every process with the configuration for debugging. 
+ logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if cfg.seed is not None: + seed_everything(cfg.seed) + + # create output dir for training + exp_name = cfg.exp_name + save_dir = f"{cfg.output_dir}/{exp_name}" + checkpoint_dir = os.path.join(save_dir, "checkpoints") + module_dir = os.path.join(save_dir, "modules") + validation_dir = os.path.join(save_dir, "validation") + + if accelerator.is_main_process: + init_output_dir([save_dir, checkpoint_dir, module_dir, validation_dir]) + + accelerator.wait_for_everyone() + + # create model + if cfg.weight_dtype == "fp16": + weight_dtype = torch.float16 + elif cfg.weight_dtype == "bf16": + weight_dtype = torch.bfloat16 + elif cfg.weight_dtype == "fp32": + weight_dtype = torch.float32 + else: + raise ValueError( + f"Do not support weight dtype: {cfg.weight_dtype} during training" + ) + + # create model + vae = AutoencoderKL.from_pretrained(cfg.vae_model_path).to( + "cuda", dtype=weight_dtype + ) + reference_unet = UNet2DConditionModel.from_pretrained( + cfg.base_model_path, + subfolder="unet", + ).to(device="cuda", dtype=weight_dtype) + denoising_unet = UNet3DConditionModel.from_pretrained_2d( + cfg.base_model_path, + "", + subfolder="unet", + unet_additional_kwargs={ + "use_motion_module": False, + "unet_use_temporal_attention": False, + }, + use_landmark=False + ).to(device="cuda", dtype=weight_dtype) + imageproj = ImageProjModel( + cross_attention_dim=denoising_unet.config.cross_attention_dim, + clip_embeddings_dim=512, + clip_extra_context_tokens=4, + ).to(device="cuda", dtype=weight_dtype) + + if cfg.face_locator_pretrained: + face_locator = FaceLocator( + conditioning_embedding_channels=320, block_out_channels=(16, 32, 96, 256) + ).to(device="cuda", dtype=weight_dtype) + miss, _ = face_locator.load_state_dict( + cfg.face_state_dict_path, strict=False) + logger.info(f"Missing key for face locator: {len(miss)}") + else: + face_locator = FaceLocator( + conditioning_embedding_channels=320, + ).to(device="cuda", dtype=weight_dtype) + # Freeze + vae.requires_grad_(False) + denoising_unet.requires_grad_(True) + reference_unet.requires_grad_(True) + imageproj.requires_grad_(True) + face_locator.requires_grad_(True) + + reference_control_writer = ReferenceAttentionControl( + reference_unet, + do_classifier_free_guidance=False, + mode="write", + fusion_blocks="full", + ) + reference_control_reader = ReferenceAttentionControl( + denoising_unet, + do_classifier_free_guidance=False, + mode="read", + fusion_blocks="full", + ) + + net = Net( + reference_unet, + denoising_unet, + face_locator, + reference_control_writer, + reference_control_reader, + imageproj, + ).to(dtype=weight_dtype) + + # get noise scheduler + train_noise_scheduler, val_noise_scheduler = get_noise_scheduler(cfg) + + # init optimizer + if cfg.solver.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + reference_unet.enable_xformers_memory_efficient_attention() + denoising_unet.enable_xformers_memory_efficient_attention() + else: + raise ValueError( + "xformers is not available. 
Make sure it is installed correctly" + ) + + if cfg.solver.gradient_checkpointing: + reference_unet.enable_gradient_checkpointing() + denoising_unet.enable_gradient_checkpointing() + + if cfg.solver.scale_lr: + learning_rate = ( + cfg.solver.learning_rate + * cfg.solver.gradient_accumulation_steps + * cfg.data.train_bs + * accelerator.num_processes + ) + else: + learning_rate = cfg.solver.learning_rate + + # Initialize the optimizer + if cfg.solver.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError as exc: + raise ImportError( + "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`" + ) from exc + + optimizer_cls = bnb.optim.AdamW8bit + else: + optimizer_cls = torch.optim.AdamW + + trainable_params = list( + filter(lambda p: p.requires_grad, net.parameters())) + optimizer = optimizer_cls( + trainable_params, + lr=learning_rate, + betas=(cfg.solver.adam_beta1, cfg.solver.adam_beta2), + weight_decay=cfg.solver.adam_weight_decay, + eps=cfg.solver.adam_epsilon, + ) + + # init scheduler + lr_scheduler = get_scheduler( + cfg.solver.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=cfg.solver.lr_warmup_steps + * cfg.solver.gradient_accumulation_steps, + num_training_steps=cfg.solver.max_train_steps + * cfg.solver.gradient_accumulation_steps, + ) + + # get data loader + train_dataset = FaceMaskDataset( + img_size=(cfg.data.train_width, cfg.data.train_height), + data_meta_paths=cfg.data.meta_paths, + sample_margin=cfg.data.sample_margin, + ) + train_dataloader = torch.utils.data.DataLoader( + train_dataset, batch_size=cfg.data.train_bs, shuffle=True, num_workers=4 + ) + + # Prepare everything with our `accelerator`. + ( + net, + optimizer, + train_dataloader, + lr_scheduler, + ) = accelerator.prepare( + net, + optimizer, + train_dataloader, + lr_scheduler, + ) + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil( + len(train_dataloader) / cfg.solver.gradient_accumulation_steps + ) + # Afterwards we recalculate our number of training epochs + num_train_epochs = math.ceil( + cfg.solver.max_train_steps / num_update_steps_per_epoch + ) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + run_time = datetime.now().strftime("%Y%m%d-%H%M") + accelerator.init_trackers( + cfg.exp_name, + init_kwargs={"mlflow": {"run_name": run_time}}, + ) + # dump config file + mlflow.log_dict(OmegaConf.to_container(cfg), "config.yaml") + + logger.info(f"save config to {save_dir}") + OmegaConf.save( + cfg, os.path.join(save_dir, "config.yaml") + ) + # Train! + total_batch_size = ( + cfg.data.train_bs + * accelerator.num_processes + * cfg.solver.gradient_accumulation_steps + ) + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num Epochs = {num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {cfg.data.train_bs}") + logger.info( + f" Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}" + ) + logger.info( + f" Gradient Accumulation steps = {cfg.solver.gradient_accumulation_steps}" + ) + logger.info(f" Total optimization steps = {cfg.solver.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # load checkpoint + # Potentially load in the weights and states from a previous save + if cfg.resume_from_checkpoint: + logger.info(f"Loading checkpoint from {checkpoint_dir}") + global_step = load_checkpoint(cfg, checkpoint_dir, accelerator) + first_epoch = global_step // num_update_steps_per_epoch + + # Only show the progress bar once on each machine. + progress_bar = tqdm( + range(global_step, cfg.solver.max_train_steps), + disable=not accelerator.is_main_process, + ) + progress_bar.set_description("Steps") + net.train() + for _ in range(first_epoch, num_train_epochs): + train_loss = 0.0 + for _, batch in enumerate(train_dataloader): + with accelerator.accumulate(net): + # Convert videos to latent space + pixel_values = batch["img"].to(weight_dtype) + with torch.no_grad(): + latents = vae.encode(pixel_values).latent_dist.sample() + latents = latents.unsqueeze(2) # (b, c, 1, h, w) + latents = latents * 0.18215 + + noise = torch.randn_like(latents) + if cfg.noise_offset > 0.0: + noise += cfg.noise_offset * torch.randn( + (noise.shape[0], noise.shape[1], 1, 1, 1), + device=noise.device, + ) + + bsz = latents.shape[0] + # Sample a random timestep for each video + timesteps = torch.randint( + 0, + train_noise_scheduler.num_train_timesteps, + (bsz,), + device=latents.device, + ) + timesteps = timesteps.long() + + face_mask_img = batch["tgt_mask"] + face_mask_img = face_mask_img.unsqueeze( + 2) + face_mask_img = face_mask_img.to(weight_dtype) + + uncond_fwd = random.random() < cfg.uncond_ratio + face_emb_list = [] + ref_image_list = [] + for _, (ref_img, face_emb) in enumerate( + zip(batch["ref_img"], batch["face_emb"]) + ): + if uncond_fwd: + face_emb_list.append(torch.zeros_like(face_emb)) + else: + face_emb_list.append(face_emb) + ref_image_list.append(ref_img) + + with torch.no_grad(): + ref_img = torch.stack(ref_image_list, dim=0).to( + dtype=vae.dtype, device=vae.device + ) + ref_image_latents = vae.encode( + ref_img + ).latent_dist.sample() + ref_image_latents = ref_image_latents * 0.18215 + + face_emb = torch.stack(face_emb_list, dim=0).to( + dtype=imageproj.dtype, device=imageproj.device + ) + + # add noise + noisy_latents = train_noise_scheduler.add_noise( + latents, noise, timesteps + ) + + # Get the target for loss depending on the prediction type + if train_noise_scheduler.prediction_type == "epsilon": + target = noise + elif train_noise_scheduler.prediction_type == "v_prediction": + target = train_noise_scheduler.get_velocity( + latents, noise, timesteps + ) + else: + raise ValueError( + f"Unknown prediction type {train_noise_scheduler.prediction_type}" + ) + model_pred = net( + noisy_latents, + timesteps, + ref_image_latents, + face_emb, + face_mask_img, + uncond_fwd, + ) + + if cfg.snr_gamma == 0: + loss = F.mse_loss( + model_pred.float(), target.float(), reduction="mean" + ) + else: + snr = compute_snr(train_noise_scheduler, timesteps) + if train_noise_scheduler.config.prediction_type == "v_prediction": + # Velocity objective requires that we add one to SNR values before we divide by them. 
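+ # The weight computed below is min(SNR, snr_gamma) / SNR (Min-SNR weighting; for
+ # v-prediction the SNR is first shifted by +1 as noted above). With snr_gamma = 5.0
+ # from the config, low-noise timesteps cannot dominate the MSE loss.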
+ snr = snr + 1 + mse_loss_weights = ( + torch.stack( + [snr, cfg.snr_gamma * torch.ones_like(timesteps)], dim=1 + ).min(dim=1)[0] + / snr + ) + loss = F.mse_loss( + model_pred.float(), target.float(), reduction="none" + ) + loss = ( + loss.mean(dim=list(range(1, len(loss.shape)))) + * mse_loss_weights + ) + loss = loss.mean() + + # Gather the losses across all processes for logging (if we use distributed training). + avg_loss = accelerator.gather( + loss.repeat(cfg.data.train_bs)).mean() + train_loss += avg_loss.item() / cfg.solver.gradient_accumulation_steps + + # Backpropagate + accelerator.backward(loss) + if accelerator.sync_gradients: + accelerator.clip_grad_norm_( + trainable_params, + cfg.solver.max_grad_norm, + ) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + if accelerator.sync_gradients: + reference_control_reader.clear() + reference_control_writer.clear() + progress_bar.update(1) + global_step += 1 + accelerator.log({"train_loss": train_loss}, step=global_step) + train_loss = 0.0 + if global_step % cfg.checkpointing_steps == 0 or global_step == cfg.solver.max_train_steps: + accelerator.wait_for_everyone() + save_path = os.path.join( + checkpoint_dir, f"checkpoint-{global_step}") + if accelerator.is_main_process: + delete_additional_ckpt(checkpoint_dir, 3) + accelerator.save_state(save_path) + accelerator.wait_for_everyone() + unwrap_net = accelerator.unwrap_model(net) + if accelerator.is_main_process: + save_checkpoint( + unwrap_net.reference_unet, + module_dir, + "reference_unet", + global_step, + total_limit=3, + ) + save_checkpoint( + unwrap_net.imageproj, + module_dir, + "imageproj", + global_step, + total_limit=3, + ) + save_checkpoint( + unwrap_net.denoising_unet, + module_dir, + "denoising_unet", + global_step, + total_limit=3, + ) + save_checkpoint( + unwrap_net.face_locator, + module_dir, + "face_locator", + global_step, + total_limit=3, + ) + + if global_step % cfg.val.validation_steps == 0 or global_step == 1: + if accelerator.is_main_process: + generator = torch.Generator(device=accelerator.device) + generator.manual_seed(cfg.seed) + log_validation( + vae=vae, + net=net, + scheduler=val_noise_scheduler, + accelerator=accelerator, + width=cfg.data.train_width, + height=cfg.data.train_height, + imageproj=imageproj, + cfg=cfg, + save_dir=validation_dir, + global_step=global_step, + face_analysis_model_path=cfg.face_analysis_model_path + ) + + logs = { + "step_loss": loss.detach().item(), + "lr": lr_scheduler.get_last_lr()[0], + } + progress_bar.set_postfix(**logs) + + if global_step >= cfg.solver.max_train_steps: + break + + accelerator.wait_for_everyone() + accelerator.end_training() + + +def load_config(config_path: str) -> dict: + """ + Loads the configuration file. + + Args: + config_path (str): Path to the configuration file. + + Returns: + dict: The configuration dictionary. 
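+
+    Example:
+        config = load_config("./configs/train/stage1.yaml")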
+ """ + + if config_path.endswith(".yaml"): + return OmegaConf.load(config_path) + if config_path.endswith(".py"): + return import_filename(config_path).cfg + raise ValueError("Unsupported format for config file") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--config", type=str, + default="./configs/train/stage1.yaml") + args = parser.parse_args() + + try: + config = load_config(args.config) + train_stage1_process(config) + except Exception as e: + logging.error("Failed to execute the training process: %s", e) diff --git a/scripts/train_stage2.py b/scripts/train_stage2.py new file mode 100644 index 00000000..8cff266d --- /dev/null +++ b/scripts/train_stage2.py @@ -0,0 +1,991 @@ +# pylint: disable=E1101,C0415,W0718,R0801 +# scripts/train_stage2.py +""" +This is the main training script for stage 2 of the project. +It imports necessary packages, defines necessary classes and functions, and trains the model using the provided configuration. + +The script includes the following classes and functions: + +1. Net: A PyTorch model that takes noisy latents, timesteps, reference image latents, face embeddings, + and face masks as input and returns the denoised latents. +2. get_attention_mask: A function that rearranges the mask tensors to the required format. +3. get_noise_scheduler: A function that creates and returns the noise schedulers for training and validation. +4. process_audio_emb: A function that processes the audio embeddings to concatenate with other tensors. +5. log_validation: A function that logs the validation information using the given VAE, image encoder, + network, scheduler, accelerator, width, height, and configuration. +6. train_stage2_process: A function that processes the training stage 2 using the given configuration. +7. load_config: A function that loads the configuration file from the given path. + +The script also includes the necessary imports and a brief description of the purpose of the file. 
+""" + +import argparse +import copy +import logging +import math +import os +import random +import time +import warnings +from datetime import datetime +from typing import List, Tuple + +import diffusers +import mlflow +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import DistributedDataParallelKwargs +from diffusers import AutoencoderKL, DDIMScheduler +from diffusers.optimization import get_scheduler +from diffusers.utils import check_min_version +from diffusers.utils.import_utils import is_xformers_available +from einops import rearrange, repeat +from omegaconf import OmegaConf +from torch import nn +from tqdm.auto import tqdm + +from hallo.animate.face_animate import FaceAnimatePipeline +from hallo.datasets.audio_processor import AudioProcessor +from hallo.datasets.image_processor import ImageProcessor +from hallo.datasets.talk_video import TalkingVideoDataset +from hallo.models.audio_proj import AudioProjModel +from hallo.models.face_locator import FaceLocator +from hallo.models.image_proj import ImageProjModel +from hallo.models.mutual_self_attention import ReferenceAttentionControl +from hallo.models.unet_2d_condition import UNet2DConditionModel +from hallo.models.unet_3d import UNet3DConditionModel +from hallo.utils.util import (compute_snr, delete_additional_ckpt, + import_filename, init_output_dir, + load_checkpoint, save_checkpoint, + seed_everything, tensor_to_video) + +warnings.filterwarnings("ignore") + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.10.0.dev0") + +logger = get_logger(__name__, log_level="INFO") + + +class Net(nn.Module): + """ + The Net class defines a neural network model that combines a reference UNet2DConditionModel, + a denoising UNet3DConditionModel, a face locator, and other components to animate a face in a static image. + + Args: + reference_unet (UNet2DConditionModel): The reference UNet2DConditionModel used for face animation. + denoising_unet (UNet3DConditionModel): The denoising UNet3DConditionModel used for face animation. + face_locator (FaceLocator): The face locator model used for face animation. + reference_control_writer: The reference control writer component. + reference_control_reader: The reference control reader component. + imageproj: The image projection model. + audioproj: The audio projection model. + + Forward method: + noisy_latents (torch.Tensor): The noisy latents tensor. + timesteps (torch.Tensor): The timesteps tensor. + ref_image_latents (torch.Tensor): The reference image latents tensor. + face_emb (torch.Tensor): The face embeddings tensor. + audio_emb (torch.Tensor): The audio embeddings tensor. + mask (torch.Tensor): Hard face mask for face locator. + full_mask (torch.Tensor): Pose Mask. + face_mask (torch.Tensor): Face Mask + lip_mask (torch.Tensor): Lip Mask + uncond_img_fwd (bool): A flag indicating whether to perform reference image unconditional forward pass. + uncond_audio_fwd (bool): A flag indicating whether to perform audio unconditional forward pass. + + Returns: + torch.Tensor: The output tensor of the neural network model. 
+ """ + def __init__( + self, + reference_unet: UNet2DConditionModel, + denoising_unet: UNet3DConditionModel, + face_locator: FaceLocator, + reference_control_writer, + reference_control_reader, + imageproj, + audioproj, + ): + super().__init__() + self.reference_unet = reference_unet + self.denoising_unet = denoising_unet + self.face_locator = face_locator + self.reference_control_writer = reference_control_writer + self.reference_control_reader = reference_control_reader + self.imageproj = imageproj + self.audioproj = audioproj + + def forward( + self, + noisy_latents: torch.Tensor, + timesteps: torch.Tensor, + ref_image_latents: torch.Tensor, + face_emb: torch.Tensor, + audio_emb: torch.Tensor, + mask: torch.Tensor, + full_mask: torch.Tensor, + face_mask: torch.Tensor, + lip_mask: torch.Tensor, + uncond_img_fwd: bool = False, + uncond_audio_fwd: bool = False, + ): + """ + simple docstring to prevent pylint error + """ + face_emb = self.imageproj(face_emb) + mask = mask.to(device="cuda") + mask_feature = self.face_locator(mask) + audio_emb = audio_emb.to( + device=self.audioproj.device, dtype=self.audioproj.dtype) + audio_emb = self.audioproj(audio_emb) + + # condition forward + if not uncond_img_fwd: + ref_timesteps = torch.zeros_like(timesteps) + ref_timesteps = repeat( + ref_timesteps, + "b -> (repeat b)", + repeat=ref_image_latents.size(0) // ref_timesteps.size(0), + ) + self.reference_unet( + ref_image_latents, + ref_timesteps, + encoder_hidden_states=face_emb, + return_dict=False, + ) + self.reference_control_reader.update(self.reference_control_writer) + + if uncond_audio_fwd: + audio_emb = torch.zeros_like(audio_emb).to( + device=audio_emb.device, dtype=audio_emb.dtype + ) + + model_pred = self.denoising_unet( + noisy_latents, + timesteps, + mask_cond_fea=mask_feature, + encoder_hidden_states=face_emb, + audio_embedding=audio_emb, + full_mask=full_mask, + face_mask=face_mask, + lip_mask=lip_mask + ).sample + + return model_pred + + +def get_attention_mask(mask: torch.Tensor, weight_dtype: torch.dtype) -> torch.Tensor: + """ + Rearrange the mask tensors to the required format. + + Args: + mask (torch.Tensor): The input mask tensor. + weight_dtype (torch.dtype): The data type for the mask tensor. + + Returns: + torch.Tensor: The rearranged mask tensor. + """ + if isinstance(mask, List): + _mask = [] + for m in mask: + _mask.append( + rearrange(m, "b f 1 h w -> (b f) (h w)").to(weight_dtype)) + return _mask + mask = rearrange(mask, "b f 1 h w -> (b f) (h w)").to(weight_dtype) + return mask + + +def get_noise_scheduler(cfg: argparse.Namespace) -> Tuple[DDIMScheduler, DDIMScheduler]: + """ + Create noise scheduler for training. + + Args: + cfg (argparse.Namespace): Configuration object. + + Returns: + Tuple[DDIMScheduler, DDIMScheduler]: Train noise scheduler and validation noise scheduler. + """ + + sched_kwargs = OmegaConf.to_container(cfg.noise_scheduler_kwargs) + if cfg.enable_zero_snr: + sched_kwargs.update( + rescale_betas_zero_snr=True, + timestep_spacing="trailing", + prediction_type="v_prediction", + ) + val_noise_scheduler = DDIMScheduler(**sched_kwargs) + sched_kwargs.update({"beta_schedule": "scaled_linear"}) + train_noise_scheduler = DDIMScheduler(**sched_kwargs) + + return train_noise_scheduler, val_noise_scheduler + + +def process_audio_emb(audio_emb: torch.Tensor) -> torch.Tensor: + """ + Process the audio embedding to concatenate with other tensors. + + Parameters: + audio_emb (torch.Tensor): The audio embedding tensor to process. 
+ + Returns: + concatenated_tensors (List[torch.Tensor]): The concatenated tensor list. + """ + concatenated_tensors = [] + + for i in range(audio_emb.shape[0]): + vectors_to_concat = [ + audio_emb[max(min(i + j, audio_emb.shape[0] - 1), 0)]for j in range(-2, 3)] + concatenated_tensors.append(torch.stack(vectors_to_concat, dim=0)) + + audio_emb = torch.stack(concatenated_tensors, dim=0) + + return audio_emb + + +def log_validation( + accelerator: Accelerator, + vae: AutoencoderKL, + net: Net, + scheduler: DDIMScheduler, + width: int, + height: int, + clip_length: int = 24, + generator: torch.Generator = None, + cfg: dict = None, + save_dir: str = None, + global_step: int = 0, + times: int = None, + face_analysis_model_path: str = "", +) -> None: + """ + Log validation video during the training process. + + Args: + accelerator (Accelerator): The accelerator for distributed training. + vae (AutoencoderKL): The autoencoder model. + net (Net): The main neural network model. + scheduler (DDIMScheduler): The scheduler for noise. + width (int): The width of the input images. + height (int): The height of the input images. + clip_length (int): The length of the video clips. Defaults to 24. + generator (torch.Generator): The random number generator. Defaults to None. + cfg (dict): The configuration dictionary. Defaults to None. + save_dir (str): The directory to save validation results. Defaults to None. + global_step (int): The current global step in training. Defaults to 0. + times (int): The number of inference times. Defaults to None. + face_analysis_model_path (str): The path to the face analysis model. Defaults to "". + + Returns: + torch.Tensor: The tensor result of the validation. + """ + ori_net = accelerator.unwrap_model(net) + reference_unet = ori_net.reference_unet + denoising_unet = ori_net.denoising_unet + face_locator = ori_net.face_locator + imageproj = ori_net.imageproj + audioproj = ori_net.audioproj + + generator = torch.manual_seed(42) + tmp_denoising_unet = copy.deepcopy(denoising_unet) + + pipeline = FaceAnimatePipeline( + vae=vae, + reference_unet=reference_unet, + denoising_unet=tmp_denoising_unet, + face_locator=face_locator, + image_proj=imageproj, + scheduler=scheduler, + ) + pipeline = pipeline.to("cuda") + + image_processor = ImageProcessor((width, height), face_analysis_model_path) + audio_processor = AudioProcessor( + cfg.data.sample_rate, + cfg.data.fps, + cfg.wav2vec_config.model_path, + cfg.wav2vec_config.features == "last", + os.path.dirname(cfg.audio_separator.model_path), + os.path.basename(cfg.audio_separator.model_path), + os.path.join(save_dir, '.cache', "audio_preprocess") + ) + + for idx, ref_img_path in enumerate(cfg.ref_img_path): + audio_path = cfg.audio_path[idx] + source_image_pixels, \ + source_image_face_region, \ + source_image_face_emb, \ + source_image_full_mask, \ + source_image_face_mask, \ + source_image_lip_mask = image_processor.preprocess( + ref_img_path, os.path.join(save_dir, '.cache'), cfg.face_expand_ratio) + audio_emb, audio_length = audio_processor.preprocess( + audio_path, clip_length) + + audio_emb = process_audio_emb(audio_emb) + + source_image_pixels = source_image_pixels.unsqueeze(0) + source_image_face_region = source_image_face_region.unsqueeze(0) + source_image_face_emb = source_image_face_emb.reshape(1, -1) + source_image_face_emb = torch.tensor(source_image_face_emb) + + source_image_full_mask = [ + (mask.repeat(clip_length, 1)) + for mask in source_image_full_mask + ] + source_image_face_mask = [ + (mask.repeat(clip_length, 
1)) + for mask in source_image_face_mask + ] + source_image_lip_mask = [ + (mask.repeat(clip_length, 1)) + for mask in source_image_lip_mask + ] + + times = audio_emb.shape[0] // clip_length + tensor_result = [] + generator = torch.manual_seed(42) + for t in range(times): + print(f"[{t+1}/{times}]") + + if len(tensor_result) == 0: + # The first iteration + motion_zeros = source_image_pixels.repeat( + cfg.data.n_motion_frames, 1, 1, 1) + motion_zeros = motion_zeros.to( + dtype=source_image_pixels.dtype, device=source_image_pixels.device) + pixel_values_ref_img = torch.cat( + [source_image_pixels, motion_zeros], dim=0) # concat the ref image and the first motion frames + else: + motion_frames = tensor_result[-1][0] + motion_frames = motion_frames.permute(1, 0, 2, 3) + motion_frames = motion_frames[0 - cfg.data.n_motion_frames:] + motion_frames = motion_frames * 2.0 - 1.0 + motion_frames = motion_frames.to( + dtype=source_image_pixels.dtype, device=source_image_pixels.device) + pixel_values_ref_img = torch.cat( + [source_image_pixels, motion_frames], dim=0) # concat the ref image and the motion frames + + pixel_values_ref_img = pixel_values_ref_img.unsqueeze(0) + + audio_tensor = audio_emb[ + t * clip_length: min((t + 1) * clip_length, audio_emb.shape[0]) + ] + audio_tensor = audio_tensor.unsqueeze(0) + audio_tensor = audio_tensor.to( + device=audioproj.device, dtype=audioproj.dtype) + audio_tensor = audioproj(audio_tensor) + + pipeline_output = pipeline( + ref_image=pixel_values_ref_img, + audio_tensor=audio_tensor, + face_emb=source_image_face_emb, + face_mask=source_image_face_region, + pixel_values_full_mask=source_image_full_mask, + pixel_values_face_mask=source_image_face_mask, + pixel_values_lip_mask=source_image_lip_mask, + width=cfg.data.train_width, + height=cfg.data.train_height, + video_length=clip_length, + num_inference_steps=cfg.inference_steps, + guidance_scale=cfg.cfg_scale, + generator=generator, + ) + + tensor_result.append(pipeline_output.videos) + + tensor_result = torch.cat(tensor_result, dim=2) + tensor_result = tensor_result.squeeze(0) + tensor_result = tensor_result[:, :audio_length] + audio_name = os.path.basename(audio_path).split('.')[0] + ref_name = os.path.basename(ref_img_path).split('.')[0] + output_file = os.path.join(save_dir,f"{global_step}_{ref_name}_{audio_name}.mp4") + # save the result after all iteration + tensor_to_video(tensor_result, output_file, audio_path) + + + # clean up + del tmp_denoising_unet + del pipeline + del image_processor + del audio_processor + torch.cuda.empty_cache() + + return tensor_result + + +def train_stage2_process(cfg: argparse.Namespace) -> None: + """ + Trains the model using the given configuration (cfg). + + Args: + cfg (dict): The configuration dictionary containing the parameters for training. + + Notes: + - This function trains the model using the given configuration. + - It initializes the necessary components for training, such as the pipeline, optimizer, and scheduler. + - The training progress is logged and tracked using the accelerator. + - The trained model is saved after the training is completed. + """ + kwargs = DistributedDataParallelKwargs(find_unused_parameters=False) + accelerator = Accelerator( + gradient_accumulation_steps=cfg.solver.gradient_accumulation_steps, + mixed_precision=cfg.solver.mixed_precision, + log_with="mlflow", + project_dir="./mlruns", + kwargs_handlers=[kwargs], + ) + + # Make one log on every process with the configuration for debugging. 
+ logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if cfg.seed is not None: + seed_everything(cfg.seed) + + # create output dir for training + exp_name = cfg.exp_name + save_dir = f"{cfg.output_dir}/{exp_name}" + checkpoint_dir = os.path.join(save_dir, "checkpoints") + module_dir = os.path.join(save_dir, "modules") + validation_dir = os.path.join(save_dir, "validation") + if accelerator.is_main_process: + init_output_dir([save_dir, checkpoint_dir, module_dir, validation_dir]) + + accelerator.wait_for_everyone() + + if cfg.weight_dtype == "fp16": + weight_dtype = torch.float16 + elif cfg.weight_dtype == "bf16": + weight_dtype = torch.bfloat16 + elif cfg.weight_dtype == "fp32": + weight_dtype = torch.float32 + else: + raise ValueError( + f"Do not support weight dtype: {cfg.weight_dtype} during training" + ) + + # Create Models + vae = AutoencoderKL.from_pretrained(cfg.vae_model_path).to( + "cuda", dtype=weight_dtype + ) + reference_unet = UNet2DConditionModel.from_pretrained( + cfg.base_model_path, + subfolder="unet", + ).to(device="cuda", dtype=weight_dtype) + denoising_unet = UNet3DConditionModel.from_pretrained_2d( + cfg.base_model_path, + cfg.mm_path, + subfolder="unet", + unet_additional_kwargs=OmegaConf.to_container( + cfg.unet_additional_kwargs), + use_landmark=False + ).to(device="cuda", dtype=weight_dtype) + imageproj = ImageProjModel( + cross_attention_dim=denoising_unet.config.cross_attention_dim, + clip_embeddings_dim=512, + clip_extra_context_tokens=4, + ).to(device="cuda", dtype=weight_dtype) + face_locator = FaceLocator( + conditioning_embedding_channels=320, + ).to(device="cuda", dtype=weight_dtype) + audioproj = AudioProjModel( + seq_len=5, + blocks=12, + channels=768, + intermediate_dim=512, + output_dim=768, + context_tokens=32, + ).to(device="cuda", dtype=weight_dtype) + + # load module weight from stage 1 + stage1_ckpt_dir = cfg.stage1_ckpt_dir + denoising_unet.load_state_dict( + torch.load( + os.path.join(stage1_ckpt_dir, "denoising_unet.pth"), + map_location="cpu", + ), + strict=False, + ) + reference_unet.load_state_dict( + torch.load( + os.path.join(stage1_ckpt_dir, "reference_unet.pth"), + map_location="cpu", + ), + strict=False, + ) + face_locator.load_state_dict( + torch.load( + os.path.join(stage1_ckpt_dir, "face_locator.pth"), + map_location="cpu", + ), + strict=False, + ) + imageproj.load_state_dict( + torch.load( + os.path.join(stage1_ckpt_dir, "imageproj.pth"), + map_location="cpu", + ), + strict=False, + ) + + # Freeze + vae.requires_grad_(False) + imageproj.requires_grad_(False) + reference_unet.requires_grad_(False) + denoising_unet.requires_grad_(False) + face_locator.requires_grad_(False) + audioproj.requires_grad_(True) + + # Set motion module learnable + trainable_modules = cfg.trainable_para + for name, module in denoising_unet.named_modules(): + if any(trainable_mod in name for trainable_mod in trainable_modules): + for params in module.parameters(): + params.requires_grad_(True) + + reference_control_writer = ReferenceAttentionControl( + reference_unet, + do_classifier_free_guidance=False, + mode="write", + 
fusion_blocks="full", + ) + reference_control_reader = ReferenceAttentionControl( + denoising_unet, + do_classifier_free_guidance=False, + mode="read", + fusion_blocks="full", + ) + + net = Net( + reference_unet, + denoising_unet, + face_locator, + reference_control_writer, + reference_control_reader, + imageproj, + audioproj, + ).to(dtype=weight_dtype) + + # get noise scheduler + train_noise_scheduler, val_noise_scheduler = get_noise_scheduler(cfg) + + if cfg.solver.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + reference_unet.enable_xformers_memory_efficient_attention() + denoising_unet.enable_xformers_memory_efficient_attention() + + else: + raise ValueError( + "xformers is not available. Make sure it is installed correctly" + ) + + if cfg.solver.gradient_checkpointing: + reference_unet.enable_gradient_checkpointing() + denoising_unet.enable_gradient_checkpointing() + + if cfg.solver.scale_lr: + learning_rate = ( + cfg.solver.learning_rate + * cfg.solver.gradient_accumulation_steps + * cfg.data.train_bs + * accelerator.num_processes + ) + else: + learning_rate = cfg.solver.learning_rate + + # Initialize the optimizer + if cfg.solver.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError as exc: + raise ImportError( + "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`" + ) from exc + optimizer_cls = bnb.optim.AdamW8bit + else: + optimizer_cls = torch.optim.AdamW + + trainable_params = list( + filter(lambda p: p.requires_grad, net.parameters())) + logger.info(f"Total trainable params {len(trainable_params)}") + optimizer = optimizer_cls( + trainable_params, + lr=learning_rate, + betas=(cfg.solver.adam_beta1, cfg.solver.adam_beta2), + weight_decay=cfg.solver.adam_weight_decay, + eps=cfg.solver.adam_epsilon, + ) + + # Scheduler + lr_scheduler = get_scheduler( + cfg.solver.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=cfg.solver.lr_warmup_steps + * cfg.solver.gradient_accumulation_steps, + num_training_steps=cfg.solver.max_train_steps + * cfg.solver.gradient_accumulation_steps, + ) + + # get data loader + train_dataset = TalkingVideoDataset( + img_size=(cfg.data.train_width, cfg.data.train_height), + sample_rate=cfg.data.sample_rate, + n_sample_frames=cfg.data.n_sample_frames, + n_motion_frames=cfg.data.n_motion_frames, + audio_margin=cfg.data.audio_margin, + data_meta_paths=cfg.data.train_meta_paths, + wav2vec_cfg=cfg.wav2vec_config, + ) + train_dataloader = torch.utils.data.DataLoader( + train_dataset, batch_size=cfg.data.train_bs, shuffle=True, num_workers=16 + ) + + # Prepare everything with our `accelerator`. + ( + net, + optimizer, + train_dataloader, + lr_scheduler, + ) = accelerator.prepare( + net, + optimizer, + train_dataloader, + lr_scheduler, + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil( + len(train_dataloader) / cfg.solver.gradient_accumulation_steps + ) + # Afterwards we recalculate our number of training epochs + num_train_epochs = math.ceil( + cfg.solver.max_train_steps / num_update_steps_per_epoch + ) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. 
+    if accelerator.is_main_process:
+        run_time = datetime.now().strftime("%Y%m%d-%H%M")
+        accelerator.init_trackers(
+            exp_name,
+            init_kwargs={"mlflow": {"run_name": run_time}},
+        )
+        # dump the config file
+        mlflow.log_dict(
+            OmegaConf.to_container(cfg), "config.yaml"
+        )
+        logger.info(f"save config to {save_dir}")
+        OmegaConf.save(
+            cfg, os.path.join(save_dir, "config.yaml")
+        )
+
+    # Train!
+    total_batch_size = (
+        cfg.data.train_bs
+        * accelerator.num_processes
+        * cfg.solver.gradient_accumulation_steps
+    )
+
+    logger.info("***** Running training *****")
+    logger.info(f"  Num examples = {len(train_dataset)}")
+    logger.info(f"  Num Epochs = {num_train_epochs}")
+    logger.info(f"  Instantaneous batch size per device = {cfg.data.train_bs}")
+    logger.info(
+        f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
+    )
+    logger.info(
+        f"  Gradient Accumulation steps = {cfg.solver.gradient_accumulation_steps}"
+    )
+    logger.info(f"  Total optimization steps = {cfg.solver.max_train_steps}")
+    global_step = 0
+    first_epoch = 0
+
+    # Potentially load in the weights and states from a previous save
+    if cfg.resume_from_checkpoint:
+        logger.info(f"Loading checkpoint from {checkpoint_dir}")
+        global_step = load_checkpoint(cfg, checkpoint_dir, accelerator)
+        first_epoch = global_step // num_update_steps_per_epoch
+
+    # Only show the progress bar once on each machine.
+    progress_bar = tqdm(
+        range(global_step, cfg.solver.max_train_steps),
+        disable=not accelerator.is_local_main_process,
+    )
+    progress_bar.set_description("Steps")
+
+    for _ in range(first_epoch, num_train_epochs):
+        train_loss = 0.0
+        t_data_start = time.time()
+        for _, batch in enumerate(train_dataloader):
+            t_data = time.time() - t_data_start
+            with accelerator.accumulate(net):
+                # Convert videos to latent space
+                pixel_values_vid = batch["pixel_values_vid"].to(weight_dtype)
+
+                pixel_values_face_mask = batch["pixel_values_face_mask"]
+                pixel_values_face_mask = get_attention_mask(
+                    pixel_values_face_mask, weight_dtype
+                )
+                pixel_values_lip_mask = batch["pixel_values_lip_mask"]
+                pixel_values_lip_mask = get_attention_mask(
+                    pixel_values_lip_mask, weight_dtype
+                )
+                pixel_values_full_mask = batch["pixel_values_full_mask"]
+                pixel_values_full_mask = get_attention_mask(
+                    pixel_values_full_mask, weight_dtype
+                )
+
+                with torch.no_grad():
+                    video_length = pixel_values_vid.shape[1]
+                    pixel_values_vid = rearrange(
+                        pixel_values_vid, "b f c h w -> (b f) c h w"
+                    )
+                    latents = vae.encode(pixel_values_vid).latent_dist.sample()
+                    latents = rearrange(
+                        latents, "(b f) c h w -> b c f h w", f=video_length
+                    )
+                    latents = latents * 0.18215
+
+                noise = torch.randn_like(latents)
+                if cfg.noise_offset > 0:
+                    noise += cfg.noise_offset * torch.randn(
+                        (latents.shape[0], latents.shape[1], 1, 1, 1),
+                        device=latents.device,
+                    )
+
+                bsz = latents.shape[0]
+                # Sample a random timestep for each video
+                timesteps = torch.randint(
+                    0,
+                    train_noise_scheduler.num_train_timesteps,
+                    (bsz,),
+                    device=latents.device,
+                )
+                timesteps = timesteps.long()
+
+                # mask for face locator
+                pixel_values_mask = (
+                    batch["pixel_values_mask"].unsqueeze(1).to(dtype=weight_dtype)
+                )
+                pixel_values_mask = repeat(
+                    pixel_values_mask,
+                    "b f c h w -> b (repeat f) c h w",
+                    repeat=video_length,
+                )
+                pixel_values_mask = pixel_values_mask.transpose(1, 2)
+
+                uncond_img_fwd = random.random() < cfg.uncond_img_ratio
+                uncond_audio_fwd = random.random() < cfg.uncond_audio_ratio
+
+                start_frame = random.random() < cfg.start_ratio
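+                # Classifier-free-guidance style dropout: with probability
+                # cfg.uncond_img_ratio / cfg.uncond_audio_ratio the reference-image
+                # or audio condition is dropped for this step, so the unconditional
+                # branch used with cfg_scale at inference is also trained.
+                # start_frame triggers the motion-frame zeroing right below.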
+                pixel_values_ref_img = batch["pixel_values_ref_img"].to(
+                    dtype=weight_dtype
+                )
+                # initialize the motion frames as zero maps
+                if start_frame:
+                    pixel_values_ref_img[:, 1:] = 0.0
+
+                ref_img_and_motion = rearrange(
+                    pixel_values_ref_img, "b f c h w -> (b f) c h w"
+                )
+
+                with torch.no_grad():
+                    ref_image_latents = vae.encode(
+                        ref_img_and_motion
+                    ).latent_dist.sample()
+                    ref_image_latents = ref_image_latents * 0.18215
+                    image_prompt_embeds = batch["face_emb"].to(
+                        dtype=imageproj.dtype, device=imageproj.device
+                    )
+
+                # add noise
+                noisy_latents = train_noise_scheduler.add_noise(
+                    latents, noise, timesteps
+                )
+
+                # Get the target for loss depending on the prediction type
+                if train_noise_scheduler.prediction_type == "epsilon":
+                    target = noise
+                elif train_noise_scheduler.prediction_type == "v_prediction":
+                    target = train_noise_scheduler.get_velocity(
+                        latents, noise, timesteps
+                    )
+                else:
+                    raise ValueError(
+                        f"Unknown prediction type {train_noise_scheduler.prediction_type}"
+                    )
+
+                # ---- Forward!!! -----
+                model_pred = net(
+                    noisy_latents=noisy_latents,
+                    timesteps=timesteps,
+                    ref_image_latents=ref_image_latents,
+                    face_emb=image_prompt_embeds,
+                    mask=pixel_values_mask,
+                    full_mask=pixel_values_full_mask,
+                    face_mask=pixel_values_face_mask,
+                    lip_mask=pixel_values_lip_mask,
+                    audio_emb=batch["audio_tensor"].to(dtype=weight_dtype),
+                    uncond_img_fwd=uncond_img_fwd,
+                    uncond_audio_fwd=uncond_audio_fwd,
+                )
+
+                if cfg.snr_gamma == 0:
+                    loss = F.mse_loss(
+                        model_pred.float(),
+                        target.float(),
+                        reduction="mean",
+                    )
+                else:
+                    snr = compute_snr(train_noise_scheduler, timesteps)
+                    if train_noise_scheduler.config.prediction_type == "v_prediction":
+                        # Velocity objective requires that we add one to SNR values before we divide by them.
+                        snr = snr + 1
+                    mse_loss_weights = (
+                        torch.stack(
+                            [snr, cfg.snr_gamma * torch.ones_like(timesteps)], dim=1
+                        ).min(dim=1)[0]
+                        / snr
+                    )
+                    # Keep the loss unreduced here so the per-sample Min-SNR
+                    # weights can be applied before averaging.
+                    loss = F.mse_loss(
+                        model_pred.float(),
+                        target.float(),
+                        reduction="none",
+                    )
+                    loss = (
+                        loss.mean(dim=list(range(1, len(loss.shape))))
+                        * mse_loss_weights
+                    ).mean()
+
+                # Gather the losses across all processes for logging (if we use distributed training).
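+                # `loss` is a scalar at this point; repeating it train_bs times just
+                # gives gather() a batch-shaped tensor per process, so the logged
+                # train_loss is the average loss over all processes.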
+                avg_loss = accelerator.gather(loss.repeat(cfg.data.train_bs)).mean()
+                train_loss += avg_loss.item() / cfg.solver.gradient_accumulation_steps
+
+                # Backpropagate
+                accelerator.backward(loss)
+                if accelerator.sync_gradients:
+                    accelerator.clip_grad_norm_(
+                        trainable_params,
+                        cfg.solver.max_grad_norm,
+                    )
+                optimizer.step()
+                lr_scheduler.step()
+                optimizer.zero_grad()
+
+            if accelerator.sync_gradients:
+                reference_control_reader.clear()
+                reference_control_writer.clear()
+                progress_bar.update(1)
+                global_step += 1
+                accelerator.log({"train_loss": train_loss}, step=global_step)
+                train_loss = 0.0
+
+                if global_step % cfg.val.validation_steps == 0 or global_step == 1:
+                    if accelerator.is_main_process:
+                        generator = torch.Generator(device=accelerator.device)
+                        generator.manual_seed(cfg.seed)
+
+                        log_validation(
+                            accelerator=accelerator,
+                            vae=vae,
+                            net=net,
+                            scheduler=val_noise_scheduler,
+                            width=cfg.data.train_width,
+                            height=cfg.data.train_height,
+                            clip_length=cfg.data.n_sample_frames,
+                            cfg=cfg,
+                            save_dir=validation_dir,
+                            global_step=global_step,
+                            times=cfg.single_inference_times,
+                            face_analysis_model_path=cfg.face_analysis_model_path,
+                        )
+
+            logs = {
+                "step_loss": loss.detach().item(),
+                "lr": lr_scheduler.get_last_lr()[0],
+                "td": f"{t_data:.2f}s",
+            }
+            t_data_start = time.time()
+            progress_bar.set_postfix(**logs)
+
+            if (
+                global_step % cfg.checkpointing_steps == 0
+                or global_step == cfg.solver.max_train_steps
+            ):
+                # save the full training state
+                save_path = os.path.join(checkpoint_dir, f"checkpoint-{global_step}")
+                if accelerator.is_main_process:
+                    delete_additional_ckpt(checkpoint_dir, 30)
+                accelerator.wait_for_everyone()
+                accelerator.save_state(save_path)
+
+                # save the module weights
+                unwrap_net = accelerator.unwrap_model(net)
+                if accelerator.is_main_process:
+                    save_checkpoint(
+                        unwrap_net,
+                        module_dir,
+                        "net",
+                        global_step,
+                        total_limit=30,
+                    )
+            if global_step >= cfg.solver.max_train_steps:
+                break
+
+    # Create the pipeline using the trained modules and save it.
+    accelerator.wait_for_everyone()
+    accelerator.end_training()
+
+
+def load_config(config_path: str) -> dict:
+    """
+    Loads the configuration file.
+
+    Args:
+        config_path (str): Path to the configuration file.
+
+    Returns:
+        dict: The configuration dictionary.
+    """
+
+    if config_path.endswith(".yaml"):
+        return OmegaConf.load(config_path)
+    if config_path.endswith(".py"):
+        return import_filename(config_path).cfg
+    raise ValueError("Unsupported format for config file")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--config", type=str, default="./configs/train/stage2.yaml"
+    )
+    args = parser.parse_args()
+
+    try:
+        config = load_config(args.config)
+        train_stage2_process(config)
+    except Exception as e:
+        logging.error("Failed to execute the training process: %s", e)
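+
+# Usage sketch (comment only, not executed): assuming this script lives at
+# scripts/train_stage2.py, a training run can be launched with Hugging Face
+# Accelerate via the `--config` argument defined above, e.g.:
+#
+#     accelerate launch scripts/train_stage2.py --config ./configs/train/stage2.yaml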