diff --git a/README.md b/README.md
index a26462f..3fd0ca9 100644
--- a/README.md
+++ b/README.md
@@ -5,9 +5,11 @@ This repo is the official implementation of "Generating Human Motion in 3D Scene
 ![pipeline](doc/pipeline.png)
 ## News
+[2024/11/02] We release the training code.
+
 [2024/10/21] We release the visualization code.
-[2024/06/09] We first release the test & evaluation code.
+[2024/06/09] We first release the test & evaluation code.
 ## Installation
 ```bash
 conda create -n most python=3.9
@@ -47,6 +49,15 @@ mkdir data
 ln -s /path/to/humanise data/HUMANISE
 ```
+### AMASS dataset
+(Only needed if you want to train the models yourself.)
+1. Please follow [HUMOR](https://github.com/davrempe/humor/tree/main?tab=readme-ov-file#datasets) to download and preprocess the AMASS dataset.
+2. Link it to ```data/```:
+```bash
+ln -s /path/to/amass_processed data/amass_preprocess
+```
+
+
 ### SMPLX models
 1. Download SMPLX models from [link](https://smpl-x.is.tue.mpg.de/).
 2. Put the smplx folder under ```data/smpl_models``` folder:
@@ -91,7 +102,7 @@ The generated results are shared in [link](https://drive.google.com/file/d/1zrpz
 We use [wis3d](https://pypi.org/project/wis3d/) lib to visualize the results. To prepare for the visualization:
 ```bash
-python tools/visualizae_results.py -c configs/test/visualize.yaml
+python tools/visualize_results.py -c configs/test/visualize.yaml
 ```
 Then, in terminal:
 ```bash
@@ -100,6 +111,26 @@ wis3d --vis_dir out/vis3d --host ${HOST} --port ${PORT}
 You can then visualize the results in ```${HOST}:${PORT}```.
+# Train the models yourself
+## Pretrain on the AMASS dataset
+Train the trajectory model:
+```bash
+python tools/train_net.py -c configs/train/trajgen/traj_amass.yaml task amass_traj
+```
+Train the motion model:
+```bash
+python tools/train_net.py -c configs/train/motiongen/motion_amass.yaml task amass_motion
+```
+The outputs and models will be saved in ```out/train/```.
+## Finetune on the HUMANISE dataset
+Train the trajectory model:
+```bash
+python tools/train_net.py -c configs/train/trajgen/traj_humanise.yaml task humanise_traj resume True resume_model_dir out/train/amass_traj/model
+```
+Train the motion model:
+```bash
+python tools/train_net.py -c configs/train/motiongen/motion_humanise.yaml task humanise_motion resume True resume_model_dir out/train/amass_motion/model
+```
 # Citation
 ```
diff --git a/configs/dataset/amass.yaml b/configs/dataset/amass.yaml
new file mode 100644
index 0000000..f4de331
--- /dev/null
+++ b/configs/dataset/amass.yaml
@@ -0,0 +1,22 @@
+train_dat:
+  name: AMASS
+  limit_size: -1
+  split: train
+
+val_dat:
+  name: AMASS
+  limit_size: -1
+  split: test
+
+test_dat:
+  name: AMASS
+  limit_size: -1
+  split: test
+
+dat_cfg:
+  amass_root: data/amass_preprocess
+  max_motion_len: 120
+  num_scene_points: 1024
+  interval: 30
+  sample_data_interval: 1
+  preload: True
\ No newline at end of file
diff --git a/configs/train/default.yaml b/configs/train/default.yaml
new file mode 100644
index 0000000..1a93ba3
--- /dev/null
+++ b/configs/train/default.yaml
@@ -0,0 +1,45 @@
+auto_config: []
+
+trainer:
+  name: default
+wrapper_cfg:
+  name: default
+  pre_methods: []
+  post_methods: []
+  sup_methods: []
+  vis_methods: []
+
+loss_weights:
+  place_holder: 0
+
+metrics: []
+
+# ====== Train/Val/Test dataset/dataloader settings ====== #
+train:
+  epoch: 200
+  batch_size: &batch_size 256
+  shuffle: True
+  num_workers: 2
+  optimizer:
+    optim: adam
+    lr: 0.
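+    # lr: 0. means "derive the learning rate": make_optimizer in
+    # lib/train/optimizer.py then uses canonical_lr * (batch_size / canonical_bs),
+    # i.e. the linear scaling rule, so these defaults hold across batch sizes.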
+    canonical_lr: 1.e-4
+    canonical_bs: *batch_size
+    weight_decay: 0.0
+    adamw_weight_decay: 0.01
+  scheduler:
+    type: multi_step
+    milestones: [50, 100, 150] # if epoch == 200
+    gamma: 0.5
+
+val:
+  epoch: 1
+  batch_size: 64
+  shuffle: False
+  num_workers: 2
+
+test:
+  epoch: 1
+  batch_size: 64
+  shuffle: False
+  num_workers: 2
diff --git a/configs/train/motiongen/motion_amass.yaml b/configs/train/motiongen/motion_amass.yaml
new file mode 100644
index 0000000..a010861
--- /dev/null
+++ b/configs/train/motiongen/motion_amass.yaml
@@ -0,0 +1,57 @@
+task: auto
+
+dataset_cfg_path: "configs/dataset/amass.yaml"
+coord: &coord az
+
+wrapper_cfg:
+  name: MotionDiffuserWrapper
+  smplx_model_type: amass
+  pre_methods: []
+  normalizer:
+    name: NormalizerPoseMotion
+    file: out/release/normalize/amass_az_humanise_oc_transl_orient6d_pose6d.pkl
+
+eval_ep: 20
+metrics: ['recon_localpose', 'recon_trans', 'recon_orient']
+
+loss_weights:
+  recon_trans: 1.0
+  recon_orient_6d: 1.0
+  recon_pose_6d: 10.0
+
+train:
+  epoch: 200
+  batch_size: &batch_size 256
+  optimizer:
+    canonical_bs: *batch_size
+
+net_cfg:
+  coord: *coord
+  repr: motion # traj
+  name: DiffuserNetwork
+  k_sample: 10
+  diffuser:
+    name: ObserConditionalDDPM
+    timesteps: 200
+    pred_type: pred_x0
+    obser: false
+    schedule_cfg:
+      beta: [0.0001, 0.01]
+      beta_schedule: cosine
+      s: 0.008
+  model:
+    name: MotionFromSceneTextTrajVoxelV0
+    d_l: 120
+    d_x: 135
+    d_betas: 10
+    env_sensor:
+      name: EnvSensor
+      voxel_dim: 8
+      radius: 2.0
+    target_sensor:
+      name: TargetSensor
+      voxel_dim: 8
+    traj_sensor:
+      name: TrajSensor
+      voxel_dim: 8
+      radius: 1.0
diff --git a/configs/train/motiongen/motion_humanise.yaml b/configs/train/motiongen/motion_humanise.yaml
new file mode 100644
index 0000000..3003e5c
--- /dev/null
+++ b/configs/train/motiongen/motion_humanise.yaml
@@ -0,0 +1,68 @@
+task: auto
+
+# resume: auto
+coord: &coord oc
+resume_model_dir: auto
+
+dataset_cfg_path: configs/dataset/humanise_motion.yaml
+
+wrapper_cfg:
+  name: MotionDiffuserWrapper
+  smplx_model_type: humanise
+  pre_methods: ['clip_text']
+  normalizer:
+    name: NormalizerPoseMotion
+    file: out/release/normalize/amass_az_humanise_oc_transl_orient6d_pose6d.pkl
+
+metrics: ['recon_localpose', 'recon_trans', 'recon_orient']
+
+loss_weights:
+  recon_trans: 1.0
+  recon_orient_6d: 1.0
+  recon_pose_6d: 10.0
+
+train:
+  epoch: 400
+  batch_size: 128
+  optimizer:
+    canonical_bs: 128
+  scheduler:
+    type: multi_step
+    milestones: [50, 100, 150, 250, 300, 350] # if epoch == 400
+
+net_cfg:
+  coord: *coord
+  repr: motion
+  name: DiffuserNetwork
+  k_sample: 10
+  diffuser:
+    name: ObserConditionalDDPM
+    timesteps: 200
+    pred_type: pred_x0
+    obser: False
+    obser_type: start_motion
+    schedule_cfg:
+      beta: [0.0001, 0.01]
+      beta_schedule: cosine
+      s: 0.008
+  model:
+    name: MotionFromSceneTextTrajVoxelV0
+    d_l: 120
+    d_x: 135
+    d_betas: 10
+    env_sensor:
+      name: EnvSensor
+      voxel_dim: 8
+      radius: 2.0
+    target_sensor:
+      name: TargetSensor
+      voxel_dim: 8
+    traj_sensor:
+      name: TrajSensor
+      voxel_dim: 8
+      radius: 1.0
+
+  optimizer:
+    name: default
+  planner:
+    name: default
diff --git a/configs/train/trajgen/traj_amass.yaml b/configs/train/trajgen/traj_amass.yaml
new file mode 100644
index 0000000..416ba06
--- /dev/null
+++ b/configs/train/trajgen/traj_amass.yaml
@@ -0,0 +1,52 @@
+task: auto
+
+dataset_cfg_path: "configs/dataset/amass.yaml"
+
+wrapper_cfg:
+  name: MotionDiffuserWrapper
+  smplx_model_type: amass
+  pre_methods: []
+  normalizer:
+    name: NormalizerPoseMotion
+    file: 
out/release/normalize/amass_az_humanise_oc_transl_orient6d_pose6d.pkl + +eval_ep: 20 +metrics: ['recon_trans', 'recon_orient'] + +loss_weights: + recon_trans: 1.0 + recon_orient_6d: 1.0 + +train: + epoch: 200 + batch_size: 256 + optimizer: + canonical_bs: 256 + +net_cfg: + coord: az + repr: traj + name: DiffuserNetwork + k_sample: 10 + diffuser: + name: ObserConditionalDDPM + timesteps: 200 + pred_type: pred_x0 + obser: false + obser_type: start_traj + schedule_cfg: + beta: [0.0001, 0.01] + beta_schedule: cosine + s: 0.008 + model: + name: TrajFromSceneTextVoxelV0 + d_l: 120 + d_x: 9 + d_betas: 10 + env_sensor: + name: EnvSensor + voxel_dim: 8 + radius: 2.0 + target_sensor: + name: TargetSensor + voxel_dim: 8 diff --git a/configs/train/trajgen/traj_humanise.yaml b/configs/train/trajgen/traj_humanise.yaml new file mode 100644 index 0000000..a45dce8 --- /dev/null +++ b/configs/train/trajgen/traj_humanise.yaml @@ -0,0 +1,56 @@ +task: auto + +# resume: auto +resume_model_dir: auto +dataset_cfg_path: configs/dataset/humanise_motion.yaml + +wrapper_cfg: + name: MotionDiffuserWrapper + smplx_model_type: humanise + pre_methods: ['clip_text'] + normalizer: + name: NormalizerPoseMotion + file: out/release/normalize/amass_az_humanise_oc_transl_orient6d_pose6d.pkl + +metrics: ['recon_trans', 'recon_orient'] + +loss_weights: + recon_trans: 1.0 + recon_orient_6d: 1.0 + +train: + epoch: 400 + batch_size: 256 + optimizer: + canonical_bs: 256 + scheduler: + type: multi_step + milestones: [50, 100, 150, 250, 300, 350] + +net_cfg: + coord: oc + repr: traj + name: DiffuserNetwork + k_sample: 10 + diffuser: + name: ObserConditionalDDPM + timesteps: 200 + pred_type: pred_x0 + obser: false + obser_type: start_traj + schedule_cfg: + beta: [0.0001, 0.01] + beta_schedule: cosine + s: 0.008 + model: + name: TrajFromSceneTextVoxelV0 + d_l: 120 + d_x: 9 + d_betas: 10 + env_sensor: + name: EnvSensor + voxel_dim: 8 + radius: 2.0 + target_sensor: + name: TargetSensor + voxel_dim: 8 \ No newline at end of file diff --git a/lib/config/config.py b/lib/config/config.py index 1eb41dd..a81c1bd 100644 --- a/lib/config/config.py +++ b/lib/config/config.py @@ -26,9 +26,11 @@ def make_cfg(args): default_cfg_path = f"configs/{args.cfg_file.split('/')[1]}/default.yaml" if os.path.exists(default_cfg_path): cfg.merge_from_file(default_cfg_path) + cfg.merge_from_file(args.cfg_file) if 'dataset_cfg_path' in cfg.keys(): cfg.merge_from_file(cfg.dataset_cfg_path) - cfg.merge_from_file(args.cfg_file) + cfg.merge_from_file(args.cfg_file) + cfg.merge_from_list(getattr(args, 'opts', [])) # dirs if cfg.record_dir == 'auto': cfg.record_dir = f'out/train/{cfg.task}' @@ -41,7 +43,7 @@ def make_cfg(args): logger.warning('overwrite gpus and resume!') cfg.gpus = [0] cfg.resume = True - cfg.merge_from_list(getattr(args, 'opts', [])) + cfg.merge_from_list(getattr(args, 'opts', [])) cfg.is_train = not args.is_test # 1. 
Auto config devices diff --git a/lib/datasets/amass/amass_base.py b/lib/datasets/amass/amass_base.py new file mode 100644 index 0000000..e172ae3 --- /dev/null +++ b/lib/datasets/amass/amass_base.py @@ -0,0 +1,154 @@ +from pathlib import Path +import numpy as np +import torch +from torch.utils import data +from tqdm import tqdm + +from lib.utils import logger +from lib.datasets.make_dataset import DATASET + + +@DATASET.register() +class AMASS(data.Dataset): + def __init__(self, dat_cfg, split='train'): + super().__init__() + self.split = split + self.max_motion_len = dat_cfg.get('max_motion_len', 120) + self.num_scene_points = dat_cfg.get('num_scene_points', 1024) + self.interval = dat_cfg.get('interval', 60) + self.sample_data_interval = dat_cfg.get('sample_data_interval', 1) + + # file path + self.amass_root = Path(dat_cfg.amass_root) + + # load data path + self.datapaths = self._load_datapaths() + self._datapaths2meta() + + self.preload = dat_cfg.get('preload', True) + if self.preload: + self._preload_npz_datas() + + self._load_scene() + self.idx2meta = self.idx2meta[::self.sample_data_interval] + + + def _load_datapaths(self): + test_splits = ['TotalCapture'] + # test_splits = [] + data_paths = [] + for dir in self.amass_root.iterdir(): + if self.split == 'train' and dir.name not in test_splits: + data_paths += dir.glob('*/*.npz') + elif self.split == 'test' and dir.name in test_splits: + data_paths += dir.glob('*/*.npz') + return data_paths + + + def _datapaths2meta(self): + self.idx2meta = [] + for datapath in self.datapaths: + nframes = int(datapath.stem.split('_')[-4]) + for start in range(0, max(0, nframes - self.max_motion_len) + 1, self.interval): + end = min(start + self.max_motion_len, nframes) + if end - start < 20: + continue + meta = { + 'start': start, + 'end': end, + 'datapath': str(datapath), + } + self.idx2meta.append(meta) + + def _preload_npz_datas(self): + self.npz_datas = {} + for datapath in tqdm(self.datapaths): + npz_data = np.load(datapath, allow_pickle=True) + npz_dict = { + 'joints': npz_data['joints'], + 'trans': npz_data['trans'], + 'root_orient': npz_data['root_orient'], + 'betas': npz_data['betas'], + 'pose_body': npz_data['pose_body'], + 'floor_height': npz_data['floor_height'], + } + self.npz_datas[str(datapath)] = npz_dict + + + def _load_scene(self): + self.floor_normal = np.zeros((self.num_scene_points, 3)) + self.floor_normal[:, 2] = 1.0 + + + def __getitem__(self, index): + meta = self.idx2meta[index] + data = {} + meta.update({ + 'idx': index, + 'split': self.split, + 'dataname': 'amass', + }) + smplx_params, motion_mask, joints, floor_height = self.get_motion(meta) + smplx_params = {k: torch.FloatTensor(v) for k, v in smplx_params.items()} + xyz_az = self.get_scene(joints, floor_height) + + data.update({ + 'meta': meta, + 'smplx_params_az': smplx_params, + 'motion_mask': torch.BoolTensor(motion_mask), + 'xyz_az': torch.FloatTensor(xyz_az), + 'normal': torch.FloatTensor(self.floor_normal), + 'betas': smplx_params['betas'][0], + }) + return data + + + def __len__(self): + return len(self.idx2meta) + + + def get_motion(self, meta): + if self.preload: + npz_data = self.npz_datas[meta['datapath']] + else: + npz_data = np.load(meta['datapath'], allow_pickle=True) + st, ed = meta['start'], meta['end'] + joints = npz_data['joints'][st:ed] + + smplx_params = { + 'transl': npz_data['trans'][st:ed], + 'global_orient': npz_data['root_orient'][st:ed], + 'betas': npz_data['betas'], + 'body_pose': npz_data['pose_body'][st:ed], + } + # pad motion + S = len(joints) 
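+        # S is the true clip length; shorter clips are padded below by
+        # repeating the last frame, and motion_mask flags the padded frames
+        # (True = padding) so losses and metrics can ignore them.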
+ meta.update({ + 'm_len': S, + }) + smplx_params['betas'] = np.tile(smplx_params['betas'], (S, 1)) + if S == self.max_motion_len: + pass + elif S > self.max_motion_len: + for k, d in smplx_params.items(): + smplx_params[k] = d[: self.max_motion_len] + joints = joints[: self.max_motion_len] + else: + for k, d in smplx_params.items(): + padding = np.tile(d[-1], (self.max_motion_len - S, 1)) + smplx_params[k] = np.concatenate([d, padding]) + padding = np.tile(joints[-1], (self.max_motion_len - S, 1, 1)) + joints = np.concatenate([joints, padding]) + motion_mask = np.zeros(self.max_motion_len, dtype=bool) + motion_mask[S:] = 1 + + return smplx_params, motion_mask, joints, npz_data['floor_height'] + + + def get_scene(self, joints, floor_height): + joints = joints.reshape(-1, 3) + joints_mean = joints.mean(axis=0) + radius = np.linalg.norm(joints - joints_mean, axis=1).max() + floor = (np.random.rand(self.num_scene_points, 3) * 2.0 - 1.0) * radius + joints_mean + floor[:, 2] = - floor_height + return floor \ No newline at end of file diff --git a/lib/datasets/make_dataset.py b/lib/datasets/make_dataset.py index 7c322d7..65e1c32 100644 --- a/lib/datasets/make_dataset.py +++ b/lib/datasets/make_dataset.py @@ -8,6 +8,7 @@ from lib.utils.registry import Registry DATASET = Registry('dataset') from .humanise.humanise_motion import HumaniseMotion +from .amass.amass_base import AMASS def make_dataset(cfg, split='train'): diff --git a/lib/evaluators/__init__.py b/lib/evaluators/__init__.py new file mode 100644 index 0000000..d4902ab --- /dev/null +++ b/lib/evaluators/__init__.py @@ -0,0 +1 @@ +from .make_evaluator import make_evaluator \ No newline at end of file diff --git a/lib/evaluators/make_evaluator.py b/lib/evaluators/make_evaluator.py new file mode 100644 index 0000000..57fa13a --- /dev/null +++ b/lib/evaluators/make_evaluator.py @@ -0,0 +1,54 @@ +import numpy as np +import torch +from lib.utils import logger +from lib.utils.comm import all_gather +from lib.utils.smplx_utils import load_smpl_faces +from .metrics import METRIC + + +class Evaluator: + def __init__(self, cfg): + self.cfg = cfg + self.local_rank = cfg.local_rank + self.metric_func_names = cfg.metrics + self.body_faces = torch.from_numpy(load_smpl_faces()).cuda() + self.coord = cfg.net_cfg.coord + + logger.info(f"Metrics Functions: {self.metric_func_names}") + self.init_metric_stats() + + def init_metric_stats(self): + """ Call at initialization and end """ + self.metric_stats = {} + + def update(self, k, v_list: list): + """ v_list need to be List of simple scalars """ + if k in self.metric_stats: + self.metric_stats[k].extend(v_list) + else: + self.metric_stats[k] = v_list + + def evaluate(self, batch): + for k in self.metric_func_names: + METRIC.get(f'{k}_metric')(self, batch) + + def summarize(self): + if len(self.metric_stats) == 0: + return {}, {} + + values = [np.array(all_gather(self.metric_stats[k])).flatten() for k in self.metric_stats] + metrics_raw = {k: vs for k, vs in zip(self.metric_stats, values)} + metrics = {k: np.mean(vs) for k, vs in zip(self.metric_stats, values)} + + message = f"Avg-over {len(values[0])}. 
Metrics: " + for k, v in metrics.items(): + message += f'{k}: {v:.4f} ; ' + if self.local_rank == 0: + logger.info(message) + + self.init_metric_stats() + return metrics, metrics_raw + + +def make_evaluator(cfg): + return Evaluator(cfg) diff --git a/lib/evaluators/metrics.py b/lib/evaluators/metrics.py new file mode 100644 index 0000000..3953063 --- /dev/null +++ b/lib/evaluators/metrics.py @@ -0,0 +1,33 @@ +from lib.utils.net_utils import L1_loss, to_list +from lib.utils.registry import Registry +METRIC = Registry('metric') + + +@METRIC.register() +def recon_trans_metric(evaluator, batch): + motion_mask = batch['motion_mask'] + # transl + coord = evaluator.coord + l1_loss = L1_loss(batch[f'smplx_params_{coord}']['transl'], batch['recon_transl']).mean(-1) * (~motion_mask) + mean = l1_loss.sum(-1) / (~motion_mask).sum(-1) + evaluator.update('R-Trans', to_list(mean)) + + +@METRIC.register() +def recon_orient_metric(evaluator, batch): + motion_mask = batch['motion_mask'] + # global orient + coord = evaluator.coord + l1_loss = L1_loss(batch[f'smplx_params_{coord}']['global_orient'], batch['recon_orient']).mean(-1) * (~motion_mask) + mean = l1_loss.sum(-1) / (~motion_mask).sum(-1) + evaluator.update('R-Orient', to_list(mean)) + + +@METRIC.register() +def recon_localpose_metric(evaluator, batch): + motion_mask = batch['motion_mask'] + coord = evaluator.coord + # body pose + l1_loss = L1_loss(batch[f'smplx_params_{coord}']['body_pose'], batch['recon_body_pose']).mean(-1) * (~motion_mask) + mean = l1_loss.sum(-1) / (~motion_mask).sum(-1) + evaluator.update('R-BodyPose', to_list(mean)) \ No newline at end of file diff --git a/lib/train/__init__.py b/lib/train/__init__.py new file mode 100644 index 0000000..a11d853 --- /dev/null +++ b/lib/train/__init__.py @@ -0,0 +1,4 @@ +from .make_trainer import make_trainer +from .optimizer import make_optimizer +from .scheduler import make_lr_scheduler, set_lr_scheduler +from .recorder import make_recorder \ No newline at end of file diff --git a/lib/train/make_trainer.py b/lib/train/make_trainer.py new file mode 100644 index 0000000..9facebf --- /dev/null +++ b/lib/train/make_trainer.py @@ -0,0 +1,14 @@ +from lib.utils import logger +from lib.wrapper.wrapper import WRAPPER +from .trainer import Trainer + +def _wrapper_factory(cfg, network): + wrapper_cfg = cfg.wrapper_cfg + logger.info(f'Making network wrapper @ {wrapper_cfg.name}') + network_wrapper = WRAPPER.get(wrapper_cfg.name)(network, cfg, wrapper_cfg) + return network_wrapper + +def make_trainer(cfg, network): + logger.info(f"Making trainer @ {cfg.trainer.name}") + network = _wrapper_factory(cfg, network) + return Trainer(network, cfg) diff --git a/lib/train/optimizer.py b/lib/train/optimizer.py new file mode 100644 index 0000000..9aaa9c6 --- /dev/null +++ b/lib/train/optimizer.py @@ -0,0 +1,38 @@ +import os +import torch +from .optimizers.radam import RAdam +from lib.utils import logger + +_optimizer_factory = { + 'adam': torch.optim.Adam, + 'adamw': torch.optim.AdamW, + 'radam': RAdam, + 'sgd': torch.optim.SGD +} + + +def make_optimizer(cfg, net, is_train, net_params=None): + cfg_optimizer = cfg.train.optimizer if is_train else cfg.test.optimizer + optim = cfg_optimizer.optim + + # set-up learning rate + lr = cfg_optimizer.lr + if lr == 0: + world_bs = cfg.train.batch_size + lr = cfg_optimizer.canonical_lr * (world_bs / cfg_optimizer.canonical_bs) + logger.info(f"lr {lr:.2e}, world batchsize {world_bs}") + + adam_decay = cfg_optimizer.weight_decay + adamw_decay = cfg_optimizer.adamw_weight_decay + + 
parameters = net_params if net_params else net.parameters() + if optim == 'adam': + optimizer = _optimizer_factory[optim]( + parameters, lr=lr, weight_decay=adam_decay) + elif optim == 'adamw': + optimizer = _optimizer_factory[optim]( + parameters, lr=lr, weight_decay=adamw_decay) + else: + raise ValueError + + return optimizer diff --git a/lib/train/optimizers/lr_scheduler.py b/lib/train/optimizers/lr_scheduler.py new file mode 100644 index 0000000..4e6f0cf --- /dev/null +++ b/lib/train/optimizers/lr_scheduler.py @@ -0,0 +1,75 @@ +from bisect import bisect_right +from collections import Counter + +import torch + + +class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler): + def __init__( + self, + optimizer, + milestones, + gamma=0.1, + warmup_factor=1.0 / 3, + warmup_iters=5, + warmup_method="linear", + last_epoch=-1, + ): + if not list(milestones) == sorted(milestones): + raise ValueError( + "Milestones should be a list of" " increasing integers. Got {}", + milestones, + ) + + if warmup_method not in ("constant", "linear"): + raise ValueError( + "Only 'constant' or 'linear' warmup_method accepted" + "got {}".format(warmup_method) + ) + self.milestones = milestones + self.gamma = gamma + self.warmup_factor = warmup_factor + self.warmup_iters = warmup_iters + self.warmup_method = warmup_method + super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch) + + def get_lr(self): + warmup_factor = 1 + if self.last_epoch < self.warmup_iters: + if self.warmup_method == "constant": + warmup_factor = self.warmup_factor + elif self.warmup_method == "linear": + alpha = float(self.last_epoch) / self.warmup_iters + warmup_factor = self.warmup_factor * (1 - alpha) + alpha + return [ + base_lr + * warmup_factor + * self.gamma ** bisect_right(self.milestones, self.last_epoch) + for base_lr in self.base_lrs + ] + + +class MultiStepLR(torch.optim.lr_scheduler._LRScheduler): + + def __init__(self, optimizer, milestones, gamma=0.1, last_epoch=-1): + self.milestones = Counter(milestones) + self.gamma = gamma + super(MultiStepLR, self).__init__(optimizer, last_epoch) + + def get_lr(self): + if self.last_epoch not in self.milestones: + return [group['lr'] for group in self.optimizer.param_groups] + return [group['lr'] * self.gamma ** self.milestones[self.last_epoch] + for group in self.optimizer.param_groups] + + +class ExponentialLR(torch.optim.lr_scheduler._LRScheduler): + + def __init__(self, optimizer, decay_epochs, gamma=0.1, last_epoch=-1): + self.decay_epochs = decay_epochs + self.gamma = gamma + super(ExponentialLR, self).__init__(optimizer, last_epoch) + + def get_lr(self): + return [base_lr * self.gamma ** (self.last_epoch / self.decay_epochs) + for base_lr in self.base_lrs] diff --git a/lib/train/optimizers/radam.py b/lib/train/optimizers/radam.py new file mode 100644 index 0000000..2934e8e --- /dev/null +++ b/lib/train/optimizers/radam.py @@ -0,0 +1,246 @@ +import math +import torch +from torch.optim.optimizer import Optimizer, required + + +class RAdam(Optimizer): + + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + + self.degenerated_to_sgd = 
degenerated_to_sgd + if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict): + for param in params: + if 'betas' in param and (param['betas'][0] != betas[0] or param['betas'][1] != betas[1]): + param['buffer'] = [[None, None, None] for _ in range(10)] + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, buffer=[[None, None, None] for _ in range(10)]) + super(RAdam, self).__init__(params, defaults) + + def __setstate__(self, state): + super(RAdam, self).__setstate__(state) + + def step(self, closure=None): + + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.data.float() + if grad.is_sparse: + raise RuntimeError('RAdam does not support sparse gradients') + + p_data_fp32 = p.data.float() + + state = self.state[p] + + if len(state) == 0: + state['step'] = 0 + state['exp_avg'] = torch.zeros_like(p_data_fp32) + state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) + else: + state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) + state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + beta1, beta2 = group['betas'] + + exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) + exp_avg.mul_(beta1).add_(1 - beta1, grad) + + state['step'] += 1 + buffered = group['buffer'][int(state['step'] % 10)] + if state['step'] == buffered[0]: + N_sma, step_size = buffered[1], buffered[2] + else: + buffered[0] = state['step'] + beta2_t = beta2 ** state['step'] + N_sma_max = 2 / (1 - beta2) - 1 + N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) + buffered[1] = N_sma + + # more conservative since it's an approximated value + if N_sma >= 5: + step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step']) + elif self.degenerated_to_sgd: + step_size = 1.0 / (1 - beta1 ** state['step']) + else: + step_size = -1 + buffered[2] = step_size + + # more conservative since it's an approximated value + if N_sma >= 5: + if group['weight_decay'] != 0: + p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) + denom = exp_avg_sq.sqrt().add_(group['eps']) + p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg, denom) + p.data.copy_(p_data_fp32) + elif step_size > 0: + if group['weight_decay'] != 0: + p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) + p_data_fp32.add_(-step_size * group['lr'], exp_avg) + p.data.copy_(p_data_fp32) + + return loss + +class PlainRAdam(Optimizer): + + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + + self.degenerated_to_sgd = degenerated_to_sgd + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + + super(PlainRAdam, self).__init__(params, defaults) + + def __setstate__(self, state): + super(PlainRAdam, self).__setstate__(state) + + def step(self, closure=None): + + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + + for p 
in group['params']: + if p.grad is None: + continue + grad = p.grad.data.float() + if grad.is_sparse: + raise RuntimeError('RAdam does not support sparse gradients') + + p_data_fp32 = p.data.float() + + state = self.state[p] + + if len(state) == 0: + state['step'] = 0 + state['exp_avg'] = torch.zeros_like(p_data_fp32) + state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) + else: + state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) + state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + beta1, beta2 = group['betas'] + + exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) + exp_avg.mul_(beta1).add_(1 - beta1, grad) + + state['step'] += 1 + beta2_t = beta2 ** state['step'] + N_sma_max = 2 / (1 - beta2) - 1 + N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) + + + # more conservative since it's an approximated value + if N_sma >= 5: + if group['weight_decay'] != 0: + p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) + step_size = group['lr'] * math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step']) + denom = exp_avg_sq.sqrt().add_(group['eps']) + p_data_fp32.addcdiv_(-step_size, exp_avg, denom) + p.data.copy_(p_data_fp32) + elif self.degenerated_to_sgd: + if group['weight_decay'] != 0: + p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) + step_size = group['lr'] / (1 - beta1 ** state['step']) + p_data_fp32.add_(-step_size, exp_avg) + p.data.copy_(p_data_fp32) + + return loss + + +class AdamW(Optimizer): + + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, warmup = 0): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + + defaults = dict(lr=lr, betas=betas, eps=eps, + weight_decay=weight_decay, warmup = warmup) + super(AdamW, self).__init__(params, defaults) + + def __setstate__(self, state): + super(AdamW, self).__setstate__(state) + + def step(self, closure=None): + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.data.float() + if grad.is_sparse: + raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') + + p_data_fp32 = p.data.float() + + state = self.state[p] + + if len(state) == 0: + state['step'] = 0 + state['exp_avg'] = torch.zeros_like(p_data_fp32) + state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) + else: + state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) + state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + beta1, beta2 = group['betas'] + + state['step'] += 1 + + exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) + exp_avg.mul_(beta1).add_(1 - beta1, grad) + + denom = exp_avg_sq.sqrt().add_(group['eps']) + bias_correction1 = 1 - beta1 ** state['step'] + bias_correction2 = 1 - beta2 ** state['step'] + + if group['warmup'] > state['step']: + scheduled_lr = 1e-8 + state['step'] * group['lr'] / group['warmup'] + else: + scheduled_lr = group['lr'] + + step_size = 
scheduled_lr * math.sqrt(bias_correction2) / bias_correction1 + + if group['weight_decay'] != 0: + p_data_fp32.add_(-group['weight_decay'] * scheduled_lr, p_data_fp32) + + p_data_fp32.addcdiv_(-step_size, exp_avg, denom) + + p.data.copy_(p_data_fp32) + + return loss + diff --git a/lib/train/recorder.py b/lib/train/recorder.py new file mode 100644 index 0000000..746347d --- /dev/null +++ b/lib/train/recorder.py @@ -0,0 +1,104 @@ +from collections import deque, defaultdict +import torch +from tensorboardX import SummaryWriter +import os + +from termcolor import colored +from lib.utils import logger +from pathlib import Path + + +class SmoothedValue(object): + """Track a series of values and provide access to smoothed values over a + window or the global series average. + """ + + def __init__(self, window_size=20): + self.deque = deque(maxlen=window_size) + self.total = 0.0 + self.count = 0 + + def update(self, value): + self.deque.append(value) + self.count += 1 + self.total += value + + @property + def median(self): + d = torch.tensor(list(self.deque)) + return d.median().item() + + @property + def avg(self): + d = torch.tensor(list(self.deque)) + return d.mean().item() + + @property + def global_avg(self): + return self.total / self.count + + +class Recorder(object): + def __init__(self, cfg): + self.cfg = cfg + self.local_rank = cfg.local_rank + + if cfg.local_rank > 0: + return + + self.writer = SummaryWriter(log_dir=cfg.record_dir) + logger.info(f"Record at {cfg.record_dir}") + + # scalars + self.epoch = 0 + self.step = 0 + self.loss_stats = defaultdict(SmoothedValue) + self.batch_time = SmoothedValue() + self.data_time = SmoothedValue() + + def update_loss_stats(self, loss_dict): + if self.local_rank > 0: + return + for k, v in loss_dict.items(): + self.loss_stats[k].update(v.detach().cpu()) + + def record(self, prefix, step=-1, loss_stats=None, image_stats=None): + if self.local_rank > 0: + return + + pattern = prefix + '/{}' + step = step if step >= 0 else self.step + loss_stats = loss_stats if loss_stats else self.loss_stats + + for k, v in loss_stats.items(): + if isinstance(v, SmoothedValue): + self.writer.add_scalar(pattern.format(k), v.median, step) + else: + self.writer.add_scalar(pattern.format(k), v, step) + + def state_dict(self): + if self.local_rank > 0: + return + scalar_dict = {} + scalar_dict['step'] = self.step + return scalar_dict + + def load_state_dict(self, scalar_dict): + if self.local_rank > 0: + return + self.step = scalar_dict['step'] + + def __str__(self): + if self.local_rank > 0: + return + loss_state = [] + for k, v in self.loss_stats.items(): + loss_state.append('{}: {:.4f}'.format(k, v.avg)) + loss_state = ' '.join(loss_state) + + recording_state = ' '.join(['epoch: {}', 'step: {}', '{}', 'data: {:.4f}', 'batch: {:.4f}']) + return recording_state.format(self.epoch, self.step, loss_state, self.data_time.avg, self.batch_time.avg) + + +def make_recorder(cfg): + return Recorder(cfg) diff --git a/lib/train/scheduler.py b/lib/train/scheduler.py new file mode 100644 index 0000000..f97c143 --- /dev/null +++ b/lib/train/scheduler.py @@ -0,0 +1,24 @@ +from collections import Counter +from .optimizers.lr_scheduler import WarmupMultiStepLR, MultiStepLR, ExponentialLR + + +def make_lr_scheduler(cfg, optimizer): + cfg_scheduler = cfg.train.scheduler + if cfg_scheduler.type == 'multi_step': + scheduler = MultiStepLR(optimizer, + milestones=cfg_scheduler.milestones, + gamma=cfg_scheduler.gamma) + elif cfg_scheduler.type == 'exponential': + scheduler = 
ExponentialLR(optimizer, + decay_epochs=cfg_scheduler.decay_epochs, + gamma=cfg_scheduler.gamma) + return scheduler + + +def set_lr_scheduler(cfg, scheduler): + cfg_scheduler = cfg.train.scheduler + if cfg_scheduler.type == 'multi_step': + scheduler.milestones = Counter(cfg_scheduler.milestones) + elif cfg_scheduler.type == 'exponential': + scheduler.decay_epochs = cfg_scheduler.decay_epochs + scheduler.gamma = cfg_scheduler.gamma diff --git a/lib/train/trainer.py b/lib/train/trainer.py new file mode 100644 index 0000000..3cba2fc --- /dev/null +++ b/lib/train/trainer.py @@ -0,0 +1,109 @@ +import time +import datetime +import torch +from torch.nn.parallel import DistributedDataParallel +import tqdm + +from lib.utils import logger +from lib.utils.net_utils import to_cuda + + +class Trainer(object): + def __init__(self, network, cfg): + self.cfg = cfg.clone() + network.cuda(cfg.local_rank) + if cfg.distributed: + network = DistributedDataParallel( + network, device_ids=[cfg.local_rank], output_device=cfg.local_rank, find_unused_parameters=True) + self.network = network + self.local_rank = cfg.local_rank + + def reduce_loss_stats(self, loss_stats): + reduced_losses = {k: torch.mean(v) for k, v in loss_stats.items()} + return reduced_losses + + def train(self, epoch, data_loader, optimizer, recorder): + if self.local_rank == 0: + logger.info(f"Training: Epoch {epoch}, {self.cfg.task}") + self.network.train() + end = time.time() + + if self.cfg.distributed: + data_loader.sampler.set_epoch(epoch) + + max_iter = len(data_loader) + for iteration, batch in enumerate(data_loader): + data_time = time.time() - end + iteration = iteration + 1 + + batch = to_cuda(batch) + batch['cur_epoch'] = epoch + batch = self.network(batch) + + # training stage: loss; optimizer; scheduler + optimizer.zero_grad() + loss = batch["loss"].mean() + loss.backward() + torch.nn.utils.clip_grad_value_(self.network.parameters(), 0.5) + optimizer.step() + + if self.local_rank > 0: + continue + + # data recording stage: loss_stats, time, image_stats + recorder.step += 1 + loss_stats = self.reduce_loss_stats(batch["loss_stats"]) + recorder.update_loss_stats(loss_stats) + batch_time = time.time() - end + end = time.time() + recorder.batch_time.update(batch_time) + recorder.data_time.update(data_time) + + if iteration % self.cfg.log_interval == 0 or iteration == (max_iter - 1): + # print training state + eta_seconds = recorder.batch_time.global_avg * (max_iter - iteration) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + lr = optimizer.param_groups[0]['lr'] + memory = torch.cuda.max_memory_allocated() / 1024.0 / 1024.0 + + training_state = ' '.join(['eta: {}', '{}', 'lr: {:.6f}', 'max_mem: {:.0f}']) + training_state = training_state.format(eta_string, str(recorder), lr, memory) + print(training_state) + + if iteration % self.cfg.rec_interval == 0 or iteration == (max_iter - 1): + recorder.record('train') + + @torch.no_grad() + def val(self, epoch, data_loader, evaluator=None, recorder=None): + if self.cfg.local_rank == 0: + logger.info(f"Validation / Testing: Epoch {epoch}") + self.network.eval() + torch.cuda.empty_cache() + val_loss_stats = {} + data_size = len(data_loader) + + for batch in tqdm.tqdm(data_loader): + batch = to_cuda(batch) + batch = self.network(batch) + + loss_stats = self.reduce_loss_stats(batch["loss_stats"]) + if evaluator is not None: + evaluator.evaluate(batch) + + for k, v in loss_stats.items(): + val_loss_stats.setdefault(k, 0) + val_loss_stats[k] += v + + loss_state = [] + for k in 
val_loss_stats.keys(): + val_loss_stats[k] /= data_size + loss_state.append('{}: {:.4f}'.format(k, val_loss_stats[k])) + + if evaluator is not None: + result, result_raw = evaluator.summarize() + if recorder: + recorder.record('val_metric', epoch, result) + + if recorder: + recorder.record('val', epoch, val_loss_stats) + diff --git a/lib/utils/smplx_utils.py b/lib/utils/smplx_utils.py index 76b13b6..0b40439 100644 --- a/lib/utils/smplx_utils.py +++ b/lib/utils/smplx_utils.py @@ -3,6 +3,9 @@ import pickle import torch.nn as nn import smplx +from smplx import SMPL, SMPLH, SMPLX +from smplx.utils import Struct +from smplx.vertex_ids import vertex_ids SMPLH_PARENTS = torch.tensor([-1, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 12, 13, 14, 16, 17, 18, 19, 20, 22, 23, 20, 25, 26, 20, 28, 29, 20, 31, 32, 20, 34, @@ -95,7 +98,149 @@ def forward(self, ) return bm_out - + + +class BodyModel(nn.Module): + """ + Wrapper around SMPLX body model class. + modified by Zehong Shen + """ + + def __init__(self, + model_path, + num_betas=16, + use_vtx_selector=False, + model_type='smplh'): + super().__init__() + ''' + Creates the body model object at the given path. + + :param bm_path: path to the body model pkl file + :param model_type: one of [smpl, smplh, smplx] + :param use_vtx_selector: if true, returns additional vertices as joints that correspond to OpenPose joints + ''' + self.use_vtx_selector = use_vtx_selector + cur_vertex_ids = None + if self.use_vtx_selector: + cur_vertex_ids = vertex_ids[model_type] + data_struct = None + if '.npz' in model_path: + # smplx does not support .npz by default, so have to load in manually + smpl_dict = np.load(model_path, encoding='latin1') + data_struct = Struct(**smpl_dict) + # print(smpl_dict.files) + if model_type == 'smplh': + data_struct.hands_componentsl = np.zeros((0)) + data_struct.hands_componentsr = np.zeros((0)) + data_struct.hands_meanl = np.zeros((15 * 3)) + data_struct.hands_meanr = np.zeros((15 * 3)) + V, D, B = data_struct.shapedirs.shape + data_struct.shapedirs = np.concatenate([data_struct.shapedirs, np.zeros( + (V, D, SMPL.SHAPE_SPACE_DIM-B))], axis=-1) # super hacky way to let smplh use 16-size beta + kwargs = { + 'model_type': model_type, + 'data_struct': data_struct, + 'num_betas': num_betas, + 'vertex_ids': cur_vertex_ids, + 'use_pca': False, + 'flat_hand_mean': True, + # - enable variable batchsize, since we don't need module variable - # + 'create_body_pose': False, + 'create_betas': False, + 'create_global_orient': False, + 'create_transl': False, + 'create_left_hand_pose': False, + 'create_right_hand_pose': False, + } + assert(model_type in ['smpl', 'smplh', 'smplx']) + if model_type == 'smpl': + self.bm = SMPL(model_path, **kwargs) + self.num_joints = SMPL.NUM_JOINTS + elif model_type == 'smplh': + self.bm = SMPLH(model_path, **kwargs) + self.num_joints = SMPLH.NUM_JOINTS + elif model_type == 'smplx': + self.bm = SMPLX(model_path, **kwargs) + self.num_joints = SMPLX.NUM_JOINTS + + self.model_type = model_type + + self.hand_pose_dim = self.bm.num_pca_comps if self.bm.use_pca else 3 * self.bm.NUM_HAND_JOINTS + + def forward(self, + betas=None, + global_orient=None, + transl=None, + body_pose=None, + left_hand_pose=None, + right_hand_pose=None, + expression=None, + jaw_pose=None, + leye_pose=None, + reye_pose=None, + **kwargs): + + device, dtype = self.bm.shapedirs.device, self.bm.shapedirs.dtype + + model_vars = [betas, global_orient, body_pose, transl, + expression, left_hand_pose, right_hand_pose, jaw_pose, leye_pose, reye_pose] + 
batch_size = 1 + for var in model_vars: + if var is None: + continue + batch_size = max(batch_size, len(var)) + + if global_orient is None: + global_orient = torch.zeros([batch_size, 3], dtype=dtype, device=device) + if body_pose is None: + body_pose = torch.zeros(3 * self.bm.NUM_BODY_JOINTS, device=device, + dtype=dtype)[None].expand(batch_size, -1).contiguous() + if left_hand_pose is None: + left_hand_pose = torch.zeros(self.hand_pose_dim, device=device, dtype=dtype)[ + None].expand(batch_size, -1).contiguous() + if right_hand_pose is None: + right_hand_pose = torch.zeros(self.hand_pose_dim, device=device, dtype=dtype)[ + None].expand(batch_size, -1).contiguous() + if jaw_pose is None: + jaw_pose = torch.zeros([batch_size, 3], dtype=dtype, device=device) + if leye_pose is None: + leye_pose = torch.zeros([batch_size, 3], dtype=dtype, device=device) + if reye_pose is None: + reye_pose = torch.zeros([batch_size, 3], dtype=dtype, device=device) + if expression is None: + expression = torch.zeros([batch_size, self.bm.num_expression_coeffs], dtype=dtype, device=device) + if betas is None: + betas = torch.zeros([batch_size, self.bm.num_betas], dtype=dtype, device=device) + if transl is None: + transl = torch.zeros([batch_size, 3], dtype=dtype, device=device) + + bm_out = self.bm( + betas=betas, + global_orient=global_orient, + body_pose=body_pose, + left_hand_pose=left_hand_pose, + right_hand_pose=right_hand_pose, + transl=transl, + expression=expression, + jaw_pose=jaw_pose, + leye_pose=leye_pose, + reye_pose=reye_pose, + **kwargs + ) + + return bm_out + + def forward_motion(self, **kwargs): + B, W, _ = kwargs['pose_body'].shape + kwargs = {k: v.reshape(B*W, v.shape[-1]) for k, v in kwargs.items()} + + smpl_opt = self.forward(**kwargs) + smpl_opt.vertices = smpl_opt.vertices.reshape(B, W, -1, 3) + smpl_opt.joints = smpl_opt.joints.reshape(B, W, -1, 3) + + return smpl_opt + + def make_smplx(type='humanise', **kwargs,): if type == 'humanise': @@ -107,6 +252,11 @@ def make_smplx(type='humanise', **kwargs,): num_betas=10, use_pca=False, ) + elif type == 'amass': + gender = kwargs.get('gender', 'neutral') + num_betas = kwargs.get('num_betas', 16) + model_path = f'data/smpl_models/smplh/{gender}/model.npz' + model = BodyModel(model_path=model_path, num_betas=num_betas) else: raise NotImplementedError diff --git a/lib/wrapper/loss.py b/lib/wrapper/loss.py new file mode 100644 index 0000000..16748ff --- /dev/null +++ b/lib/wrapper/loss.py @@ -0,0 +1,39 @@ +import torch + +from lib.utils.net_utils import L1_loss, L2_loss, cross_entropy +from lib.utils.registry import Registry +LOSS = Registry('loss') + + + +@LOSS.register() +def cal_recon_trans_loss(wrapper, batch, loss_stats): + coord = wrapper.coord + l1_loss = cal_L1_with_mask(wrapper, batch[f'smplx_params_{coord}']['transl'], batch['recon_transl'], batch['motion_mask']) + loss_stats['recon_trans'] = l1_loss.detach().cpu() + return l1_loss + +@LOSS.register() +def cal_recon_orient_6d_loss(wrapper, batch, loss_stats): + l1_loss = cal_L1_with_mask(wrapper, batch['orient_6d'], batch['recon_orient_6d'], batch['motion_mask']) + loss_stats['recon_orient_6d'] = l1_loss.detach().cpu() + return l1_loss + +@LOSS.register() +def cal_recon_pose_6d_loss(wrapper, batch, loss_stats): + l1_loss = cal_L1_with_mask(wrapper, batch['body_pose_6d'], batch['recon_body_pose_6d'], batch['motion_mask']) + loss_stats['recon_body_pose_6d'] = l1_loss.detach().cpu() + return l1_loss + + +def cal_L1_with_mask(wrapper, x, y, mask): + ''' + Input: + x, y: (B, S, D) + mask: (B, S) 
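+            True marks padded frames, which are excluded from the loss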
+ Output: + l1_loss: (B) + ''' + l1_loss = L1_loss(x, y).mean(-1) * (~mask) + l1_loss = l1_loss.sum(-1) / (~mask).sum(-1) + return l1_loss \ No newline at end of file diff --git a/lib/wrapper/supervision.py b/lib/wrapper/supervision.py new file mode 100644 index 0000000..204e967 --- /dev/null +++ b/lib/wrapper/supervision.py @@ -0,0 +1,67 @@ +from lib.utils.registry import Registry +import torch +from lib.wrapper.preprocess import axis_angle_to_rot_6d +SUP = Registry('supervision') + + +@SUP.register() +def get_gt_smplx_params(wrapper, batch): + smplx_params = batch['smplx_params'] = batch.get('smplx_params_oc', batch['smplx_params_hm']) + batch['orient_6d'] = axis_angle_to_rot_6d(smplx_params['global_orient']) + batch['body_pose_6d'] = axis_angle_to_rot_6d(smplx_params['body_pose']) + + +@SUP.register() +def get_object2pelvis(wrapper, batch): + target_pelvis_xy = [] + for b in range(len(batch['meta'])): + anchor = batch['meta'][b]['anchor'] + target_pelvis_xy.append(batch['pelvis'][b, anchor, :2]) + + target_pelvis_xy = torch.stack(target_pelvis_xy, dim=0) + batch['gt_anchor_pelvis_xy'] = target_pelvis_xy + obj2pelivs_dist = (target_pelvis_xy[:, None] - batch['object_verts'][:, :, :2]).pow(2).sum(-1) # (B, N) + obj2pelivs_dist = obj2pelivs_dist ** 2 + dmin = obj2pelivs_dist.min(-1)[0][:, None] + dmax = obj2pelivs_dist.max(-1)[0][:, None] + norm_dist = (obj2pelivs_dist - dmin) / (dmax - dmin) + obj_verts_prob = 1 - norm_dist + batch['gt_pointheat'] = obj_verts_prob + # mask_9 = batch['gt_pointheat'][0] > 0.9 + # mask_8 = batch['gt_pointheat'][0] > 0.8 + # from lib.utils.vis3d_utils import make_vis3d + # vis3d = make_vis3d(None, 'debug-heat') + # vis3d.add_point_cloud(batch['object_verts'][0], name='all') + # vis3d.add_point_cloud(batch['object_verts'][0][mask_9], name='9') + # vis3d.add_point_cloud(batch['object_verts'][0][mask_8], name='8') + pass + +@SUP.register() +def get_gt_zero_vertices(wrapper, batch): + B, S, _ = batch['smplx_params']['transl'].shape + smplx_params = { + 'betas': batch['smplx_params']['betas'], + 'body_pose': batch['smplx_params']['body_pose'], + } + smplx_params = {k: v.reshape(B * S, -1) for k, v in smplx_params.items()} + body_opt = wrapper.smplx_model(**smplx_params, return_verts=True) + body_vertices = body_opt.vertices + batch.update({ + 'gt_zero_verts': body_vertices.reshape(B, S, -1, 3), + }) + +@SUP.register() +def get_gt_vertices(wrapper, batch): + B, S, _ = batch['smplx_params']['transl'].shape + smplx_params = { + 'betas': batch['smplx_params']['betas'], + 'transl': batch['smplx_params']['transl'], + 'global_orient': batch['smplx_params']['global_orient'], + 'body_pose': batch['smplx_params']['body_pose'], + } + smplx_params = {k: v.reshape(B * S, -1) for k, v in smplx_params.items()} + body_opt = wrapper.smplx_model(**smplx_params, return_verts=True) + batch.update({ + 'gt_verts': body_opt.vertices.reshape(B, S, -1, 3), + 'gt_joints': body_opt.joints.reshape(B, S, -1, 3), + }) \ No newline at end of file diff --git a/lib/wrapper/wrapper.py b/lib/wrapper/wrapper.py index 83e8606..099e1d3 100644 --- a/lib/wrapper/wrapper.py +++ b/lib/wrapper/wrapper.py @@ -4,14 +4,116 @@ import pickle import clip +from lib.utils import logger from lib.utils.smplx_utils import make_smplx, load_smpl_faces from lib.utils.normalize import make_normalizer from lib.utils.registry import Registry WRAPPER = Registry('wrapper') +from .loss import LOSS +from .supervision import SUP from .preprocess import PRE from .postprocess import POST from .two_stage import * + 
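+# A wrapper runs the full per-batch pipeline around the core network:
+#   preprocess (PRE) -> net / net.inference -> postprocess (POST)
+#   -> compute_supervision (SUP) -> compute_loss (LOSS, weighted sum over
+#   loss_weights). Each stage is a list of method names from wrapper_cfg,
+# resolved through the corresponding registry.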
+@WRAPPER.register() +class Wrapper(nn.Module): + def __init__(self, net, cfg, + wrapper_cfg): + super().__init__() + # config + self.cfg = cfg + self.vis3d = None + self.loss_weights = {k: v for k, v in cfg.loss_weights.items() if v > 0} + self.pre_methods = wrapper_cfg.pre_methods + self.post_methods = wrapper_cfg.post_methods + self.sup_methods = wrapper_cfg.sup_methods + self.vis_methods = wrapper_cfg.vis_methods + logger.info(f'Preprocess method: {self.pre_methods}') + logger.info(f'Postprocess method: {self.post_methods}') + logger.info(f'Supervision method: {self.sup_methods}') + logger.info(f'Visualization method: {self.vis_methods}') + + # main network + self.net = net + self.inference = False + + # smplx model + self.smplx_model = make_smplx(wrapper_cfg.smplx_model_type) + self.smplx_faces = torch.from_numpy(load_smpl_faces()) + + self._build_text_model(wrapper_cfg) + + def _build_text_model(self, cfg): + # clip + self.clip_model, clip_preprocess = clip.load('ViT-B/32', device='cpu', + jit=False) # Must set jit=False for training + self.clip_model = self.clip_model.float() + self.clip_model.eval() + for p in self.clip_model.parameters(): + p.requires_grad_(False) + + def forward(self, batch, inference=False, compute_supervision=True, compute_loss=True): + self.inference = inference + self.preprocess(batch) + if inference: + self.net.inference(batch) + else: + self.net(batch) + self.postprocess(batch) + + if compute_supervision: + self.compute_supervision(batch) + if compute_loss: + self.compute_loss(batch) + + return batch # important for DDP + + def preprocess(self, batch): + for name in self.pre_methods: + PRE.get(f'process_{name}')(self, batch) + + def postprocess(self, batch): + for name in self.post_methods: + POST.get(f'get_{name}')(self, batch) + + def compute_supervision(self, batch): + for name in self.sup_methods: + SUP.get(f'get_{name}')(self, batch) + + def compute_loss(self, batch): + B = len(batch['meta']) + loss = 0. 
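+        # each cal_*_loss returns a per-sample (B,) tensor; the batch mean is
+        # taken later in Trainer.train, so the assert below guards the shape.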
+ loss_stats = {} + for k, v in self.loss_weights.items(): + cur_loss = v * LOSS.get(f'cal_{k}_loss')(self, batch, loss_stats) + assert cur_loss.shape[0] == B + loss += cur_loss + loss_stats.update({'loss_weighted_sum': loss.detach().cpu()}) + batch.update({"loss": loss, "loss_stats": loss_stats}) + + +@WRAPPER.register() +class MotionDiffuserWrapper(Wrapper): + def __init__(self, net, cfg, wrapper_cfg) -> None: + super().__init__(net, cfg, wrapper_cfg) + self.coord = cfg.net_cfg.coord + self.normalizer = None + if wrapper_cfg.get('normalizer', None) is not None: + self.normalizer = self._build_normalizer(wrapper_cfg.normalizer) + + def _build_normalizer(self, cfg): + with open(cfg.file, 'rb') as f: + fdata = pickle.load(f) + xmin, xmax = fdata['xmin'], fdata['xmax'] + return make_normalizer(cfg.name, (xmin, xmax)) + + def forward(self, batch, inference=False, compute_supervision=True, compute_loss=True): + if self.normalizer is not None: + batch['normalizer'] = self.normalizer + return super().forward(batch, inference, compute_supervision, compute_loss) + + @WRAPPER.register() class TwoStageWrapper(nn.Module): def __init__(self, path_net, motion_net, cfg, diff --git a/tools/train_net.py b/tools/train_net.py new file mode 100644 index 0000000..3aa5d3d --- /dev/null +++ b/tools/train_net.py @@ -0,0 +1,59 @@ +from lib.utils import offscreen_flag +import argparse +from lib.config import make_cfg, save_cfg +from lib.datasets import make_data_loader +from lib.networks import make_network +from lib.evaluators import make_evaluator +from lib.train import make_trainer, make_recorder, make_lr_scheduler, set_lr_scheduler, make_optimizer +from lib.utils.comm import setup_distributed, clear_directory_for_training +from lib.utils.net_utils import save_network, load_network + +import torch +torch.autograd.set_detect_anomaly(True) + + +def train(cfg): + data_loader = make_data_loader(cfg, split='train') + val_loader = make_data_loader(cfg, split='test') + + resume = cfg.resume + if cfg.local_rank == 0: + clear_directory_for_training(cfg.record_dir, resume) + save_cfg(cfg, resume) + + network = make_network(cfg) + trainer = make_trainer(cfg, network) + recorder = make_recorder(cfg) + evaluator = make_evaluator(cfg) + # optimizer + optimizer = make_optimizer(cfg, network, is_train=True) + # scheduler + if cfg.train.get('scheduler', None) is not None: + scheduler = make_lr_scheduler(cfg, optimizer) + set_lr_scheduler(cfg, scheduler) + + epoch_start = load_network(network, resume, cfg, epoch=-1) + for epoch in range(epoch_start+1, cfg.train.epoch+2): + recorder.epoch = epoch + trainer.train(epoch, data_loader, optimizer, recorder) + if cfg.train.get('scheduler', None) is not None: + scheduler.step() + # save + if epoch % cfg.save_ep == 0 and cfg.local_rank == 0: + save_network(network, cfg.model_dir, epoch) + # eval + if epoch % cfg.eval_ep == 0 and cfg.local_rank == 0: + trainer.val(epoch, val_loader, evaluator, recorder) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--type', type=str, default='main') + parser.add_argument('--cfg_file', '-c', type=str, required=True) + parser.add_argument('--is_test', action='store_true', default=False) + parser.add_argument('opts', default=None, nargs=argparse.REMAINDER) + args = parser.parse_args() + cfg = make_cfg(args) + if cfg.distributed: + setup_distributed() + train(cfg) \ No newline at end of file diff --git a/tools/visualizae_results.py b/tools/visualize_results.py similarity index 100% rename from 
tools/visualizae_results.py rename to tools/visualize_results.py