diff --git a/mjx/mujoco/mjx/_src/support_test.py b/mjx/mujoco/mjx/_src/support_test.py index 572e767984..5bba0df169 100644 --- a/mjx/mujoco/mjx/_src/support_test.py +++ b/mjx/mujoco/mjx/_src/support_test.py @@ -200,6 +200,10 @@ def test_bind(self): # test getting np.testing.assert_array_equal(mx.bind(s.bodies).pos, m.body_pos) np.testing.assert_array_equal(dx.bind(mx, s.bodies).xpos, d.xpos) + np.testing.assert_array_equal(m.bind(s.bodies[0]).mass, m.body_mass[0]) + np.testing.assert_array_equal(m.bind(s.bodies[0:1]).mass, [m.body_mass[0]]) + np.testing.assert_array_equal(mx.bind(s.bodies[0]).mass, m.body_mass[0]) + np.testing.assert_array_equal(mx.bind(s.bodies[0:1]).mass, [m.body_mass[0]]) for i in range(m.nbody): np.testing.assert_array_equal(m.bind(s.bodies[i]).pos, m.body_pos[i, :]) np.testing.assert_array_equal(mx.bind(s.bodies[i]).pos, m.body_pos[i, :]) @@ -224,6 +228,8 @@ def test_bind(self): np.testing.assert_array_equal(mx.bind(s.joints).axis, m.jnt_axis) np.testing.assert_array_equal(mx.bind(s.joints).qposadr, m.jnt_qposadr) np.testing.assert_array_equal(mx.bind(s.joints).dofadr, m.jnt_dofadr) + np.testing.assert_array_equal(dx.bind(mx, s.joints[1]).id, 1) + np.testing.assert_array_equal(dx.bind(mx, s.joints[1:2]).id, [1]) qposnum = [4, 1, 1] # one ball joint (4) and two slide joints (1) dofnum = [3, 1, 1] # one ball joint (3) and two slide joints (1) for i in range(m.njnt):