diff --git a/doc/changelog.rst b/doc/changelog.rst index 4d16a788b8..e5f22a59f6 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -35,6 +35,7 @@ General MJX ^^^ - Added support for spatial tendons with internal sphere and cylinder wrapping. +- Fix a bug with box-box collisions :github:issue:`2356`. Version 3.2.7 (Jan 14, 2025) ---------------------------- diff --git a/mjx/mujoco/mjx/_src/collision_convex.py b/mjx/mujoco/mjx/_src/collision_convex.py index 465ec2bb32..31ce2057e6 100644 --- a/mjx/mujoco/mjx/_src/collision_convex.py +++ b/mjx/mujoco/mjx/_src/collision_convex.py @@ -678,7 +678,7 @@ def _create_contact_manifold( return dist, pos, normal -def _sat_bruteforce( +def _box_box_impl( faces_a: jax.Array, faces_b: jax.Array, vertices_a: jax.Array, @@ -688,17 +688,7 @@ def _sat_bruteforce( unique_edges_a: jax.Array, unique_edges_b: jax.Array, ) -> Tuple[jax.Array, jax.Array, jax.Array]: - """Runs the Separating Axis Test for a pair of hulls. - - Given two convex hulls, the Separating Axis Test finds a separating axis - between all edge pairs and face pairs. Edge pairs create a single contact - point and face pairs create a contact manifold (up to four contact points). - We return both the edge and face contacts. Valid contacts can be checked with - dist < 0. Resulting edge contacts should be preferred over face contacts. - - This method checks all separating axes via a brute force support function, and - is thus costly to run over large meshes, but is more performant for smaller - meshes (boxes, tetrahedra, etc.). + """Runs the Separating Axis Test for two boxes. Args: faces_a: Faces for hull A. @@ -713,18 +703,15 @@ def _sat_bruteforce( Returns: tuple of dist, pos, and normal """ - # get the separating axes - v_norm = jax.vmap(math.normalize) - edge_dir_a = v_norm(unique_edges_a[:, 0] - unique_edges_a[:, 1]) - edge_dir_b = v_norm(unique_edges_b[:, 0] - unique_edges_b[:, 1]) - edge_dir_a_r = jp.tile(edge_dir_a, reps=(unique_edges_b.shape[0], 1)) - edge_dir_b_r = jp.repeat(edge_dir_b, repeats=unique_edges_a.shape[0], axis=0) + edge_dir_a, edge_dir_b = unique_edges_a, unique_edges_b + edge_dir_a_r = jp.tile(edge_dir_a, reps=(edge_dir_b.shape[0], 1)) + edge_dir_b_r = jp.repeat(edge_dir_b, repeats=edge_dir_a.shape[0], axis=0) edge_axes = jax.vmap(jp.cross)(edge_dir_a_r, edge_dir_b_r) degenerate_edge_axes = (edge_axes**2).sum(axis=1) < 1e-6 edge_axes = jax.vmap(lambda x: math.normalize(x, axis=0))(edge_axes) - n_norm = normals_a.shape[0] + normals_b.shape[0] + n_face_axes = normals_a.shape[0] + normals_b.shape[0] degenerate_axes = jp.concatenate( - [jp.array([False] * n_norm), degenerate_edge_axes] + [jp.array([False] * n_face_axes), degenerate_edge_axes] ) axes = jp.concatenate([normals_a, normals_b, edge_axes]) @@ -745,11 +732,16 @@ def get_support(axis, is_degenerate): support, sign = get_support(axes, degenerate_axes) + # get the best face axis + best_face_idx = jp.argmin(support[:n_face_axes]) + best_face_axis = axes[best_face_idx] + # choose the best separating axis best_idx = jp.argmin(support) best_sign = sign[best_idx] best_axis = axes[best_idx] - is_edge_contact = best_idx >= (normals_a.shape[0] + normals_b.shape[0]) + is_edge_contact = best_idx >= n_face_axes + is_edge_contact &= jp.abs(best_face_axis.dot(best_axis)) < 0.99 # prefer face # get the (reference) face most aligned with the separating axis dist_a = normals_a @ best_axis @@ -788,6 +780,40 @@ def get_support(axis, is_degenerate): return dist, pos, normal +def _box_box(b1: ConvexInfo, b2: ConvexInfo) -> Collision: + """Calculates contacts between two boxes.""" + faces1 = b1.face + faces2 = b2.face + + to_local_pos = b2.mat.T @ (b1.pos - b2.pos) + to_local_mat = b2.mat.T @ b1.mat + + faces1 = to_local_pos + faces1 @ to_local_mat.T + normals1 = b1.face_normal @ to_local_mat.T + normals2 = b2.face_normal + + vertices1 = to_local_pos + b1.vert @ to_local_mat.T + vertices2 = b2.vert + + dist, pos, normal = _box_box_impl( + faces1, + faces2, + vertices1, + vertices2, + normals1, + normals2, + to_local_mat.T, + jp.eye(3, dtype=float), + ) + + # Go back to world frame. + pos = b2.pos + pos @ b2.mat.T + n = normal @ b2.mat.T + dist = jp.where(jp.isinf(dist), jp.finfo(float).max, dist) + + return dist, pos, n + + def _arcs_intersect( a: jax.Array, b: jax.Array, c: jax.Array, d: jax.Array ) -> jax.Array: @@ -941,44 +967,6 @@ def get_normals(a_dir, a_pt, b_dir): return dist, pos, normal -def _box_box(b1: ConvexInfo, b2: ConvexInfo) -> Collision: - """Calculates contacts between two boxes.""" - faces1 = b1.face - faces2 = b2.face - - to_local_pos = b2.mat.T @ (b1.pos - b2.pos) - to_local_mat = b2.mat.T @ b1.mat - - faces1 = to_local_pos + faces1 @ to_local_mat.T - normals1 = b1.face_normal @ to_local_mat.T - normals2 = b2.face_normal - - vertices1 = to_local_pos + b1.vert @ to_local_mat.T - vertices2 = b2.vert - - unique_edges1 = jp.take(vertices1, b1.edge_dir, axis=0) - unique_edges2 = jp.take(vertices2, b2.edge_dir, axis=0) - - # brute-force SAT is more performant for box-box - dist, pos, normal = _sat_bruteforce( - faces1, - faces2, - vertices1, - vertices2, - normals1, - normals2, - unique_edges1, - unique_edges2, - ) - - # Go back to world frame. - pos = b2.pos + pos @ b2.mat.T - n = normal @ b2.mat.T - dist = jp.where(jp.isinf(dist), jp.finfo(float).max, dist) - - return dist, pos, n - - def _convex_convex(c1: ConvexInfo, c2: ConvexInfo) -> Collision: """Calculates contacts between two convex meshes.""" # pad face vertices so that we can broadcast between geom1 and geom2 diff --git a/mjx/mujoco/mjx/_src/collision_driver_test.py b/mjx/mujoco/mjx/_src/collision_driver_test.py index 4b083223d4..6bff2e8c35 100644 --- a/mjx/mujoco/mjx/_src/collision_driver_test.py +++ b/mjx/mujoco/mjx/_src/collision_driver_test.py @@ -49,11 +49,15 @@ def _assert_attr_eq(mjx_d, mj_d, attr, name, atol): def _collide( - mjcf: str, assets: Optional[Dict[str, str]] = None + mjcf: str, + assets: Optional[Dict[str, str]] = None, + keyframe: Optional[int] = None, ) -> Tuple[mujoco.MjModel, mujoco.MjData, Model, Data]: m = mujoco.MjModel.from_xml_string(mjcf, assets or {}) mx = mjx.put_model(m) d = mujoco.MjData(m) + if keyframe is not None: + mujoco.mj_resetDataKeyframe(m, d, keyframe) dx = mjx.put_data(m, d) m.opt.enableflags |= mujoco.mjtEnableBit.mjENBL_NATIVECCD @@ -649,28 +653,30 @@ def test_flat_box_plane(self): _BOX_BOX = """ - - - - - - - - + + + + + + + + + + """ def test_box_box(self): """Tests a face contact for a box-box collision.""" - d, dx = _collide(self._BOX_BOX) + d, dx = _collide(self._BOX_BOX, keyframe=0) c = dx.contact self.assertEqual(c.pos.shape[0], 4) np.testing.assert_array_less(c.dist, 0) - np.testing.assert_array_almost_equal(c.pos[:, 2], np.array([0.39] * 4), 2) + np.testing.assert_array_almost_equal(c.pos[:, 2], np.array([0.05] * 4), 2) np.testing.assert_array_almost_equal( - c.frame[:, 0, :], np.array([[0.0, 0.0, 1.0]] * 4) + c.frame[:, 0, :], np.array([[0.0, 0.0, 1.0]] * 4), decimal=2 ) np.testing.assert_array_almost_equal( c.frame.reshape((-1, 9)), d.contact.frame[:4, :] diff --git a/mjx/mujoco/mjx/_src/collision_types.py b/mjx/mujoco/mjx/_src/collision_types.py index 4495771189..d105a8d296 100644 --- a/mjx/mujoco/mjx/_src/collision_types.py +++ b/mjx/mujoco/mjx/_src/collision_types.py @@ -15,7 +15,7 @@ """Collision base types.""" import dataclasses -from typing import Optional, Tuple +from typing import Tuple import jax from mujoco.mjx._src.dataclasses import PyTreeNode # pylint: disable=g-importing-member import numpy as np @@ -46,7 +46,6 @@ class ConvexInfo(PyTreeNode): face_normal: jax.Array edge: jax.Array edge_face_normal: jax.Array - edge_dir: Optional[jax.Array] = None class HFieldInfo(PyTreeNode): diff --git a/mjx/mujoco/mjx/_src/mesh.py b/mjx/mujoco/mjx/_src/mesh.py index 4c60ee1e63..282058f64d 100644 --- a/mjx/mujoco/mjx/_src/mesh.py +++ b/mjx/mujoco/mjx/_src/mesh.py @@ -53,42 +53,6 @@ def _get_face_norm(vert: np.ndarray, face: np.ndarray) -> np.ndarray: return face_norm -def _get_unique_edge_dir(vert: np.ndarray, face: np.ndarray) -> np.ndarray: - """Returns unique edge directions. - - Args: - vert: (n_vert, 3) vertices - face: (n_face, n_vert) face index array - - Returns: - edges: tuples of vertex indexes for each edge - """ - r_face = np.roll(face, 1, axis=1) - edges = np.concatenate(np.array([face, r_face]).T) - - # do a first pass to remove duplicates - edges.sort(axis=1) - edges = np.unique(edges, axis=0) - edges = edges[edges[:, 0] != edges[:, 1]] # get rid of edges from padded face - - # get normalized edge directions - edge_vert = vert.take(edges, axis=0) - edge_dir = edge_vert[:, 0] - edge_vert[:, 1] - norms = np.sqrt(np.sum(edge_dir**2, axis=1)) - edge_dir = edge_dir / norms.reshape((-1, 1)) - - # get the first unique edge for all pairwise comparisons - diff1 = edge_dir[:, None, :] - edge_dir[None, :, :] - diff2 = edge_dir[:, None, :] + edge_dir[None, :, :] - matches = (np.linalg.norm(diff1, axis=-1) < 1e-6) | ( - np.linalg.norm(diff2, axis=-1) < 1e-6 - ) - matches = np.tril(matches).sum(axis=-1) - unique_edge_idx = np.where(matches == 1)[0] - - return edges[unique_edge_idx] - - def _get_edge_normals( face: np.ndarray, face_norm: np.ndarray ) -> Tuple[np.ndarray, np.ndarray]: @@ -221,7 +185,6 @@ def box(info: GeomInfo) -> ConvexInfo: # pyformat: enable face_normal = _get_face_norm(vert, face) edge, edge_face_normal = _get_edge_normals(face, face_normal) - edge_dir = _get_unique_edge_dir(vert, face) face = vert[face] # materialize full nface x nvert matrix c = ConvexInfo( @@ -233,7 +196,6 @@ def box(info: GeomInfo) -> ConvexInfo: face_normal, edge, edge_face_normal, - edge_dir, ) c = jax.tree_util.tree_map(jp.array, c) vert = jax.vmap(jp.multiply, in_axes=(None, 0))(c.vert, info.size) @@ -356,7 +318,6 @@ def get_face_norm(face): face_norm, edges, face_norm[edge_face_norm], - None, ) return jax.tree_util.tree_map(jp.array, c) diff --git a/mjx/mujoco/mjx/_src/mesh_test.py b/mjx/mujoco/mjx/_src/mesh_test.py index f50bffc64c..434a4d17f0 100644 --- a/mjx/mujoco/mjx/_src/mesh_test.py +++ b/mjx/mujoco/mjx/_src/mesh_test.py @@ -61,15 +61,6 @@ def test_pyramid(self): expected_face_verts, ) - # check edges - edge_dir = mesh._get_unique_edge_dir(convex_vert, convex_face) - unique_edge = np.vectorize(map_.get)(edge_dir) - unique_edge = np.array(sorted(unique_edge.tolist())) - np.testing.assert_array_equal( - unique_edge, - np.array([[0, 2], [0, 3], [0, 4], [1, 4], [2, 4], [3, 4]]), - ) - # face normals face_normal = mesh._get_face_norm(convex_vert, convex_face) self.assertEqual(face_normal.shape, (5, 3)) @@ -143,19 +134,5 @@ def test_convex_hull_2d_axis2(self): np.testing.assert_array_almost_equal(normal, expected) -class UniqueEdgesTest(absltest.TestCase): - - def test_tetrahedron_edges(self): - """Tests unique edges for a tetrahedron.""" - vert = np.array( - [[-0.1, 0.0, -0.1], [0.0, 0.1, 0.1], [0.1, 0.0, -0.1], [0.0, -0.1, 0.1]] - ) - face = np.array([[0, 1, 2], [0, 2, 3], [0, 3, 1], [2, 1, 3]]) - idx = mesh._get_unique_edge_dir(vert, face) - np.testing.assert_array_equal( - idx, np.array([[0, 1], [0, 2], [0, 3], [1, 2], [1, 3], [2, 3]]) - ) - - if __name__ == '__main__': absltest.main()