Skip to content

Commit

Permalink
Fix bug with box-box #2356.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 721138760
Change-Id: Ib9548b0bc505a9856176188c91858cff7c9accdf
  • Loading branch information
btaba authored and copybara-github committed Jan 30, 2025
1 parent 5b245e5 commit e0664b1
Show file tree
Hide file tree
Showing 6 changed files with 67 additions and 135 deletions.
1 change: 1 addition & 0 deletions doc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ General
MJX
^^^
- Added support for spatial tendons with internal sphere and cylinder wrapping.
- Fix a bug with box-box collisions :github:issue:`2356`.

Version 3.2.7 (Jan 14, 2025)
----------------------------
Expand Down
106 changes: 47 additions & 59 deletions mjx/mujoco/mjx/_src/collision_convex.py
Original file line number Diff line number Diff line change
Expand Up @@ -678,7 +678,7 @@ def _create_contact_manifold(
return dist, pos, normal


def _sat_bruteforce(
def _box_box_impl(
faces_a: jax.Array,
faces_b: jax.Array,
vertices_a: jax.Array,
Expand All @@ -688,17 +688,7 @@ def _sat_bruteforce(
unique_edges_a: jax.Array,
unique_edges_b: jax.Array,
) -> Tuple[jax.Array, jax.Array, jax.Array]:
"""Runs the Separating Axis Test for a pair of hulls.
Given two convex hulls, the Separating Axis Test finds a separating axis
between all edge pairs and face pairs. Edge pairs create a single contact
point and face pairs create a contact manifold (up to four contact points).
We return both the edge and face contacts. Valid contacts can be checked with
dist < 0. Resulting edge contacts should be preferred over face contacts.
This method checks all separating axes via a brute force support function, and
is thus costly to run over large meshes, but is more performant for smaller
meshes (boxes, tetrahedra, etc.).
"""Runs the Separating Axis Test for two boxes.
Args:
faces_a: Faces for hull A.
Expand All @@ -713,18 +703,15 @@ def _sat_bruteforce(
Returns:
tuple of dist, pos, and normal
"""
# get the separating axes
v_norm = jax.vmap(math.normalize)
edge_dir_a = v_norm(unique_edges_a[:, 0] - unique_edges_a[:, 1])
edge_dir_b = v_norm(unique_edges_b[:, 0] - unique_edges_b[:, 1])
edge_dir_a_r = jp.tile(edge_dir_a, reps=(unique_edges_b.shape[0], 1))
edge_dir_b_r = jp.repeat(edge_dir_b, repeats=unique_edges_a.shape[0], axis=0)
edge_dir_a, edge_dir_b = unique_edges_a, unique_edges_b
edge_dir_a_r = jp.tile(edge_dir_a, reps=(edge_dir_b.shape[0], 1))
edge_dir_b_r = jp.repeat(edge_dir_b, repeats=edge_dir_a.shape[0], axis=0)
edge_axes = jax.vmap(jp.cross)(edge_dir_a_r, edge_dir_b_r)
degenerate_edge_axes = (edge_axes**2).sum(axis=1) < 1e-6
edge_axes = jax.vmap(lambda x: math.normalize(x, axis=0))(edge_axes)
n_norm = normals_a.shape[0] + normals_b.shape[0]
n_face_axes = normals_a.shape[0] + normals_b.shape[0]
degenerate_axes = jp.concatenate(
[jp.array([False] * n_norm), degenerate_edge_axes]
[jp.array([False] * n_face_axes), degenerate_edge_axes]
)

axes = jp.concatenate([normals_a, normals_b, edge_axes])
Expand All @@ -745,11 +732,16 @@ def get_support(axis, is_degenerate):

support, sign = get_support(axes, degenerate_axes)

# get the best face axis
best_face_idx = jp.argmin(support[:n_face_axes])
best_face_axis = axes[best_face_idx]

# choose the best separating axis
best_idx = jp.argmin(support)
best_sign = sign[best_idx]
best_axis = axes[best_idx]
is_edge_contact = best_idx >= (normals_a.shape[0] + normals_b.shape[0])
is_edge_contact = best_idx >= n_face_axes
is_edge_contact &= jp.abs(best_face_axis.dot(best_axis)) < 0.99 # prefer face

# get the (reference) face most aligned with the separating axis
dist_a = normals_a @ best_axis
Expand Down Expand Up @@ -788,6 +780,40 @@ def get_support(axis, is_degenerate):
return dist, pos, normal


def _box_box(b1: ConvexInfo, b2: ConvexInfo) -> Collision:
"""Calculates contacts between two boxes."""
faces1 = b1.face
faces2 = b2.face

to_local_pos = b2.mat.T @ (b1.pos - b2.pos)
to_local_mat = b2.mat.T @ b1.mat

faces1 = to_local_pos + faces1 @ to_local_mat.T
normals1 = b1.face_normal @ to_local_mat.T
normals2 = b2.face_normal

vertices1 = to_local_pos + b1.vert @ to_local_mat.T
vertices2 = b2.vert

dist, pos, normal = _box_box_impl(
faces1,
faces2,
vertices1,
vertices2,
normals1,
normals2,
to_local_mat.T,
jp.eye(3, dtype=float),
)

# Go back to world frame.
pos = b2.pos + pos @ b2.mat.T
n = normal @ b2.mat.T
dist = jp.where(jp.isinf(dist), jp.finfo(float).max, dist)

return dist, pos, n


def _arcs_intersect(
a: jax.Array, b: jax.Array, c: jax.Array, d: jax.Array
) -> jax.Array:
Expand Down Expand Up @@ -941,44 +967,6 @@ def get_normals(a_dir, a_pt, b_dir):
return dist, pos, normal


def _box_box(b1: ConvexInfo, b2: ConvexInfo) -> Collision:
"""Calculates contacts between two boxes."""
faces1 = b1.face
faces2 = b2.face

to_local_pos = b2.mat.T @ (b1.pos - b2.pos)
to_local_mat = b2.mat.T @ b1.mat

faces1 = to_local_pos + faces1 @ to_local_mat.T
normals1 = b1.face_normal @ to_local_mat.T
normals2 = b2.face_normal

vertices1 = to_local_pos + b1.vert @ to_local_mat.T
vertices2 = b2.vert

unique_edges1 = jp.take(vertices1, b1.edge_dir, axis=0)
unique_edges2 = jp.take(vertices2, b2.edge_dir, axis=0)

# brute-force SAT is more performant for box-box
dist, pos, normal = _sat_bruteforce(
faces1,
faces2,
vertices1,
vertices2,
normals1,
normals2,
unique_edges1,
unique_edges2,
)

# Go back to world frame.
pos = b2.pos + pos @ b2.mat.T
n = normal @ b2.mat.T
dist = jp.where(jp.isinf(dist), jp.finfo(float).max, dist)

return dist, pos, n


def _convex_convex(c1: ConvexInfo, c2: ConvexInfo) -> Collision:
"""Calculates contacts between two convex meshes."""
# pad face vertices so that we can broadcast between geom1 and geom2
Expand Down
30 changes: 18 additions & 12 deletions mjx/mujoco/mjx/_src/collision_driver_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,15 @@ def _assert_attr_eq(mjx_d, mj_d, attr, name, atol):


def _collide(
mjcf: str, assets: Optional[Dict[str, str]] = None
mjcf: str,
assets: Optional[Dict[str, str]] = None,
keyframe: Optional[int] = None,
) -> Tuple[mujoco.MjModel, mujoco.MjData, Model, Data]:
m = mujoco.MjModel.from_xml_string(mjcf, assets or {})
mx = mjx.put_model(m)
d = mujoco.MjData(m)
if keyframe is not None:
mujoco.mj_resetDataKeyframe(m, d, keyframe)
dx = mjx.put_data(m, d)

m.opt.enableflags |= mujoco.mjtEnableBit.mjENBL_NATIVECCD
Expand Down Expand Up @@ -649,28 +653,30 @@ def test_flat_box_plane(self):
_BOX_BOX = """
<mujoco>
<worldbody>
<body pos="0.0 1.0 0.2">
<joint axis="1 0 0" type="free"/>
<geom size="0.2 0.2 0.2" type="box"/>
</body>
<body pos="0.1 1.0 0.495" euler="0.1 -0.1 0">
<joint axis="1 0 0" type="free"/>
<geom size="0.1 0.1 0.1" type="box"/>
</body>
<light name="top" pos="0 0 1"/>
<geom type="box" size="0.025 0.025 0.025" pos="0 0 0.025"/>
<body name="peg" pos="0 0 0.06">
<freejoint/>
<geom name="peg" size="0.048 0.01 0.01" type="box"/>
</body>
</worldbody>
<keyframe>
<!-- Boxes are penetrating with a slightly off-axis face contact -->
<key qpos='-0.00234853 0.0112999 0.0533649 0.474162 0.472141 0.524886 0.526069'/>
</keyframe>
</mujoco>
"""

def test_box_box(self):
"""Tests a face contact for a box-box collision."""
d, dx = _collide(self._BOX_BOX)
d, dx = _collide(self._BOX_BOX, keyframe=0)
c = dx.contact

self.assertEqual(c.pos.shape[0], 4)
np.testing.assert_array_less(c.dist, 0)
np.testing.assert_array_almost_equal(c.pos[:, 2], np.array([0.39] * 4), 2)
np.testing.assert_array_almost_equal(c.pos[:, 2], np.array([0.05] * 4), 2)
np.testing.assert_array_almost_equal(
c.frame[:, 0, :], np.array([[0.0, 0.0, 1.0]] * 4)
c.frame[:, 0, :], np.array([[0.0, 0.0, 1.0]] * 4), decimal=2
)
np.testing.assert_array_almost_equal(
c.frame.reshape((-1, 9)), d.contact.frame[:4, :]
Expand Down
3 changes: 1 addition & 2 deletions mjx/mujoco/mjx/_src/collision_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""Collision base types."""

import dataclasses
from typing import Optional, Tuple
from typing import Tuple
import jax
from mujoco.mjx._src.dataclasses import PyTreeNode # pylint: disable=g-importing-member
import numpy as np
Expand Down Expand Up @@ -46,7 +46,6 @@ class ConvexInfo(PyTreeNode):
face_normal: jax.Array
edge: jax.Array
edge_face_normal: jax.Array
edge_dir: Optional[jax.Array] = None


class HFieldInfo(PyTreeNode):
Expand Down
39 changes: 0 additions & 39 deletions mjx/mujoco/mjx/_src/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,42 +53,6 @@ def _get_face_norm(vert: np.ndarray, face: np.ndarray) -> np.ndarray:
return face_norm


def _get_unique_edge_dir(vert: np.ndarray, face: np.ndarray) -> np.ndarray:
"""Returns unique edge directions.
Args:
vert: (n_vert, 3) vertices
face: (n_face, n_vert) face index array
Returns:
edges: tuples of vertex indexes for each edge
"""
r_face = np.roll(face, 1, axis=1)
edges = np.concatenate(np.array([face, r_face]).T)

# do a first pass to remove duplicates
edges.sort(axis=1)
edges = np.unique(edges, axis=0)
edges = edges[edges[:, 0] != edges[:, 1]] # get rid of edges from padded face

# get normalized edge directions
edge_vert = vert.take(edges, axis=0)
edge_dir = edge_vert[:, 0] - edge_vert[:, 1]
norms = np.sqrt(np.sum(edge_dir**2, axis=1))
edge_dir = edge_dir / norms.reshape((-1, 1))

# get the first unique edge for all pairwise comparisons
diff1 = edge_dir[:, None, :] - edge_dir[None, :, :]
diff2 = edge_dir[:, None, :] + edge_dir[None, :, :]
matches = (np.linalg.norm(diff1, axis=-1) < 1e-6) | (
np.linalg.norm(diff2, axis=-1) < 1e-6
)
matches = np.tril(matches).sum(axis=-1)
unique_edge_idx = np.where(matches == 1)[0]

return edges[unique_edge_idx]


def _get_edge_normals(
face: np.ndarray, face_norm: np.ndarray
) -> Tuple[np.ndarray, np.ndarray]:
Expand Down Expand Up @@ -221,7 +185,6 @@ def box(info: GeomInfo) -> ConvexInfo:
# pyformat: enable
face_normal = _get_face_norm(vert, face)
edge, edge_face_normal = _get_edge_normals(face, face_normal)
edge_dir = _get_unique_edge_dir(vert, face)
face = vert[face] # materialize full nface x nvert matrix

c = ConvexInfo(
Expand All @@ -233,7 +196,6 @@ def box(info: GeomInfo) -> ConvexInfo:
face_normal,
edge,
edge_face_normal,
edge_dir,
)
c = jax.tree_util.tree_map(jp.array, c)
vert = jax.vmap(jp.multiply, in_axes=(None, 0))(c.vert, info.size)
Expand Down Expand Up @@ -356,7 +318,6 @@ def get_face_norm(face):
face_norm,
edges,
face_norm[edge_face_norm],
None,
)

return jax.tree_util.tree_map(jp.array, c)
Expand Down
23 changes: 0 additions & 23 deletions mjx/mujoco/mjx/_src/mesh_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,15 +61,6 @@ def test_pyramid(self):
expected_face_verts,
)

# check edges
edge_dir = mesh._get_unique_edge_dir(convex_vert, convex_face)
unique_edge = np.vectorize(map_.get)(edge_dir)
unique_edge = np.array(sorted(unique_edge.tolist()))
np.testing.assert_array_equal(
unique_edge,
np.array([[0, 2], [0, 3], [0, 4], [1, 4], [2, 4], [3, 4]]),
)

# face normals
face_normal = mesh._get_face_norm(convex_vert, convex_face)
self.assertEqual(face_normal.shape, (5, 3))
Expand Down Expand Up @@ -143,19 +134,5 @@ def test_convex_hull_2d_axis2(self):
np.testing.assert_array_almost_equal(normal, expected)


class UniqueEdgesTest(absltest.TestCase):

def test_tetrahedron_edges(self):
"""Tests unique edges for a tetrahedron."""
vert = np.array(
[[-0.1, 0.0, -0.1], [0.0, 0.1, 0.1], [0.1, 0.0, -0.1], [0.0, -0.1, 0.1]]
)
face = np.array([[0, 1, 2], [0, 2, 3], [0, 3, 1], [2, 1, 3]])
idx = mesh._get_unique_edge_dir(vert, face)
np.testing.assert_array_equal(
idx, np.array([[0, 1], [0, 2], [0, 3], [1, 2], [1, 3], [2, 3]])
)


if __name__ == '__main__':
absltest.main()

0 comments on commit e0664b1

Please sign in to comment.