diff --git a/aitviewer/models/smpl.py b/aitviewer/models/smpl.py index 711423c..afa1cb6 100644 --- a/aitviewer/models/smpl.py +++ b/aitviewer/models/smpl.py @@ -234,10 +234,70 @@ def fk( expression=expression, ) + return output.vertices, output.joints + + def fk_mano( + self, + hand_pose, + betas, + global_orient=None, + trans=None, + normalize_root=False, + mano=True, + ): + """ + Convert mano pose data (joint angles and shape parameters) to positional data (joint and mesh vertex positions). + :param hand_pose: A tensor of shape (N, N_JOINTS*3), i.e. joint angles in angle-axis format or PCA format (N, N_PCA_COMPONENTS). This contains all + body joints which are not the root. + :param betas: A tensor of shape (N, N_BETAS) containing the betas/shape parameters. + :param global_orient: Orientation of the root or None. If specified expected shape is (N, 3). + :param trans: translation that is applied to vertices and joints or None, this is the 'transl' parameter + of the MANO Model. If specified expected shape is (N, 3). + :param normalize_root: If set, it will normalize the root such that its orientation is the identity in the + first frame and its position starts at the origin. + :return: The resulting vertices and joints. + """ + + batch_size = hand_pose.shape[0] + device = hand_pose.device + + if global_orient is None: + global_orient = torch.zeros([batch_size, 3]).to(dtype=hand_pose.dtype, device=device) + if trans is None: + trans = torch.zeros([batch_size, 3]).to(dtype=hand_pose.dtype, device=device) + + + # Batch shapes if they don't match batch dimension. + if len(betas.shape) == 1 or betas.shape[0] == 1: + betas = betas.repeat(hand_pose.shape[0], 1) + betas = betas[:, : self.num_betas] + + if normalize_root: + # Make everything relative to the first root orientation. + root_ori = aa2rot(global_orient) + first_root_ori = torch.inverse(root_ori[0:1]) + root_ori = torch.matmul(first_root_ori, root_ori) + global_orient = rot2aa(root_ori) + trans = torch.matmul(first_root_ori.unsqueeze(0), trans.unsqueeze(-1)).squeeze() + trans = trans - trans[0:1] + + + output = self.bm( + hand_pose=hand_pose, + betas=betas, + global_orient=global_orient, + transl=trans, + ) + + return output.vertices, output.joints def forward(self, *args, **kwargs): """ Forward pass using forward kinematics """ - return self.fk(*args, **kwargs) + + if 'mano' in kwargs.keys(): + return self.fk_mano(*args, **kwargs) + else: + return self.fk(*args, **kwargs) \ No newline at end of file diff --git a/aitviewer/renderables/smpl.py b/aitviewer/renderables/smpl.py index 4020ff0..2abedbe 100644 --- a/aitviewer/renderables/smpl.py +++ b/aitviewer/renderables/smpl.py @@ -145,6 +145,7 @@ def __init__( # Nodes self.vertices, self.joints, self.faces, self.skeleton = self.fk() + if self._is_rigged: self.skeleton_seq = Skeletons( self.joints, @@ -157,11 +158,18 @@ def __init__( # First convert the relative joint angles to global joint angles in rotation matrix form. if self.smpl_layer.model_type != "flame": - global_oris = local_to_global( - torch.cat([self.poses_root, self.poses_body], dim=-1), - self.skeleton[:, 0], - output_format="rotmat", - ) + if self.smpl_layer.model_type != "mano": + global_oris = local_to_global( + torch.cat([self.poses_root, self.poses_body, self.poses_left_hand, self.poses_right_hand], dim=-1), + self.skeleton[:, 0], + output_format="rotmat", + ) + else: + global_oris = local_to_global( + torch.cat([self.poses_root, self.poses_body], dim=-1), + self.skeleton[:, 0], + output_format="rotmat", + ) global_oris = c2c(global_oris.reshape((self.n_frames, -1, 3, 3))) else: global_oris = np.tile(np.eye(3), self.joints.shape[:-1])[np.newaxis] @@ -169,8 +177,10 @@ def __init__( if self._z_up and not C.z_up: self.rotation = np.matmul(np.array([[1, 0, 0], [0, 0, 1], [0, -1, 0]]), self.rotation) - self.rbs = RigidBodies(self.joints, global_oris, length=0.1, gui_affine=False, name="Joint Angles") - self._add_node(self.rbs, enabled=self._show_joint_angles) + + if self.smpl_layer.model_type != "mano": + self.rbs = RigidBodies(self.joints, global_oris, length=0.1, gui_affine=False, name="Joint Angles") + self._add_node(self.rbs, enabled=self._show_joint_angles) self.mesh_seq = Meshes( self.vertices, @@ -396,21 +406,30 @@ def fk(self, current_frame_only=False): poses_right_hand = self.poses_right_hand trans = self.trans betas = self.betas - - verts, joints = self.smpl_layer( - poses_root=poses_root, - poses_body=poses_body, - poses_left_hand=poses_left_hand, - poses_right_hand=poses_right_hand, + + if self.smpl_layer.model_type == "mano": + verts, joints = self.smpl_layer( + hand_pose=poses_body, betas=betas, + global_orient=poses_root, trans=trans, + mano=True, ) + else: + verts, joints = self.smpl_layer( + poses_root=poses_root, + poses_body=poses_body, + poses_left_hand=poses_left_hand, + poses_right_hand=poses_right_hand, + betas=betas, + trans=trans, + ) # Apply post_fk_func if specified. if self.post_fk_func: verts, joints = self.post_fk_func(self, verts, joints, current_frame_only) - skeleton = self.smpl_layer.skeletons()["body"].T + skeleton = self.smpl_layer.skeletons()["body"].T if not self.smpl_layer.model_type == "mano" else self.smpl_layer.skeletons()["all"].T faces = self.smpl_layer.bm.faces.astype(np.int64) joints = joints[:, : skeleton.shape[0]]