Skip to content

Commit

Permalink
add MANO to SMPL model
Browse files Browse the repository at this point in the history
  • Loading branch information
christsa committed Sep 23, 2024
1 parent f33a7d0 commit 8014e7d
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 15 deletions.
62 changes: 61 additions & 1 deletion aitviewer/models/smpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,10 +234,70 @@ def fk(
expression=expression,
)

return output.vertices, output.joints

def fk_mano(
self,
hand_pose,
betas,
global_orient=None,
trans=None,
normalize_root=False,
mano=True,
):
"""
Convert mano pose data (joint angles and shape parameters) to positional data (joint and mesh vertex positions).
:param hand_pose: A tensor of shape (N, N_JOINTS*3), i.e. joint angles in angle-axis format or PCA format (N, N_PCA_COMPONENTS). This contains all
body joints which are not the root.
:param betas: A tensor of shape (N, N_BETAS) containing the betas/shape parameters.
:param global_orient: Orientation of the root or None. If specified expected shape is (N, 3).
:param trans: translation that is applied to vertices and joints or None, this is the 'transl' parameter
of the MANO Model. If specified expected shape is (N, 3).
:param normalize_root: If set, it will normalize the root such that its orientation is the identity in the
first frame and its position starts at the origin.
:return: The resulting vertices and joints.
"""

batch_size = hand_pose.shape[0]
device = hand_pose.device

if global_orient is None:
global_orient = torch.zeros([batch_size, 3]).to(dtype=hand_pose.dtype, device=device)
if trans is None:
trans = torch.zeros([batch_size, 3]).to(dtype=hand_pose.dtype, device=device)


# Batch shapes if they don't match batch dimension.
if len(betas.shape) == 1 or betas.shape[0] == 1:
betas = betas.repeat(hand_pose.shape[0], 1)
betas = betas[:, : self.num_betas]

if normalize_root:
# Make everything relative to the first root orientation.
root_ori = aa2rot(global_orient)
first_root_ori = torch.inverse(root_ori[0:1])
root_ori = torch.matmul(first_root_ori, root_ori)
global_orient = rot2aa(root_ori)
trans = torch.matmul(first_root_ori.unsqueeze(0), trans.unsqueeze(-1)).squeeze()
trans = trans - trans[0:1]


output = self.bm(
hand_pose=hand_pose,
betas=betas,
global_orient=global_orient,
transl=trans,
)


return output.vertices, output.joints

def forward(self, *args, **kwargs):
"""
Forward pass using forward kinematics
"""
return self.fk(*args, **kwargs)

if 'mano' in kwargs.keys():
return self.fk_mano(*args, **kwargs)
else:
return self.fk(*args, **kwargs)
47 changes: 33 additions & 14 deletions aitviewer/renderables/smpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ def __init__(
# Nodes
self.vertices, self.joints, self.faces, self.skeleton = self.fk()


if self._is_rigged:
self.skeleton_seq = Skeletons(
self.joints,
Expand All @@ -157,20 +158,29 @@ def __init__(

# First convert the relative joint angles to global joint angles in rotation matrix form.
if self.smpl_layer.model_type != "flame":
global_oris = local_to_global(
torch.cat([self.poses_root, self.poses_body], dim=-1),
self.skeleton[:, 0],
output_format="rotmat",
)
if self.smpl_layer.model_type != "mano":
global_oris = local_to_global(
torch.cat([self.poses_root, self.poses_body, self.poses_left_hand, self.poses_right_hand], dim=-1),
self.skeleton[:, 0],
output_format="rotmat",
)
else:
global_oris = local_to_global(
torch.cat([self.poses_root, self.poses_body], dim=-1),
self.skeleton[:, 0],
output_format="rotmat",
)
global_oris = c2c(global_oris.reshape((self.n_frames, -1, 3, 3)))
else:
global_oris = np.tile(np.eye(3), self.joints.shape[:-1])[np.newaxis]

if self._z_up and not C.z_up:
self.rotation = np.matmul(np.array([[1, 0, 0], [0, 0, 1], [0, -1, 0]]), self.rotation)

self.rbs = RigidBodies(self.joints, global_oris, length=0.1, gui_affine=False, name="Joint Angles")
self._add_node(self.rbs, enabled=self._show_joint_angles)

if self.smpl_layer.model_type != "mano":
self.rbs = RigidBodies(self.joints, global_oris, length=0.1, gui_affine=False, name="Joint Angles")
self._add_node(self.rbs, enabled=self._show_joint_angles)

self.mesh_seq = Meshes(
self.vertices,
Expand Down Expand Up @@ -396,21 +406,30 @@ def fk(self, current_frame_only=False):
poses_right_hand = self.poses_right_hand
trans = self.trans
betas = self.betas

verts, joints = self.smpl_layer(
poses_root=poses_root,
poses_body=poses_body,
poses_left_hand=poses_left_hand,
poses_right_hand=poses_right_hand,

if self.smpl_layer.model_type == "mano":
verts, joints = self.smpl_layer(
hand_pose=poses_body,
betas=betas,
global_orient=poses_root,
trans=trans,
mano=True,
)
else:
verts, joints = self.smpl_layer(
poses_root=poses_root,
poses_body=poses_body,
poses_left_hand=poses_left_hand,
poses_right_hand=poses_right_hand,
betas=betas,
trans=trans,
)

# Apply post_fk_func if specified.
if self.post_fk_func:
verts, joints = self.post_fk_func(self, verts, joints, current_frame_only)

skeleton = self.smpl_layer.skeletons()["body"].T
skeleton = self.smpl_layer.skeletons()["body"].T if not self.smpl_layer.model_type == "mano" else self.smpl_layer.skeletons()["all"].T
faces = self.smpl_layer.bm.faces.astype(np.int64)
joints = joints[:, : skeleton.shape[0]]

Expand Down

0 comments on commit 8014e7d

Please sign in to comment.