Skip to content

Commit

Permalink
Merge pull request #275 from ami-iit/feature/joint_limits
Browse files Browse the repository at this point in the history
Enforce joint position limits
  • Loading branch information
flferretti authored Nov 5, 2024
2 parents 35cf336 + f46e618 commit b274aab
Show file tree
Hide file tree
Showing 5 changed files with 208 additions and 5 deletions.
7 changes: 5 additions & 2 deletions src/jaxsim/api/joint.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,11 @@ def position_limit(
The position limits of the joint.
"""

if model.number_of_joints() <= 1:
return jnp.empty(0).astype(float), jnp.empty(0).astype(float)
if model.number_of_joints() == 0:
s_min = model.kin_dyn_parameters.joint_parameters.position_limits_min
s_max = model.kin_dyn_parameters.joint_parameters.position_limits_max

return jnp.atleast_1d(s_min).astype(float), jnp.atleast_1d(s_max).astype(float)

exceptions.raise_value_error_if(
condition=jnp.array(
Expand Down
33 changes: 32 additions & 1 deletion src/jaxsim/api/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,9 +243,40 @@ def system_acceleration(
# Enforce joint limits
# ====================

# TODO: enforce joint limits
τ_position_limit = jnp.zeros_like(τ_references).astype(float)

if model.dofs() > 0:

# Stiffness and damper parameters for the joint position limits.
k_j = jnp.array(
model.kin_dyn_parameters.joint_parameters.position_limit_spring
).astype(float)
d_j = jnp.array(
model.kin_dyn_parameters.joint_parameters.position_limit_damper
).astype(float)

# Compute the joint position limit violations.
lower_violation = jnp.clip(
data.state.physics_model.joint_positions
- model.kin_dyn_parameters.joint_parameters.position_limits_min,
max=0.0,
)

upper_violation = jnp.clip(
data.state.physics_model.joint_positions
- model.kin_dyn_parameters.joint_parameters.position_limits_max,
min=0.0,
)

# Compute the joint position limit torque.
τ_position_limit -= jnp.diag(k_j) @ (lower_violation + upper_violation)

τ_position_limit -= (
jnp.positive(τ_position_limit)
* jnp.diag(d_j)
@ data.state.physics_model.joint_velocities
)

# ====================
# Joint friction model
# ====================
Expand Down
5 changes: 3 additions & 2 deletions src/jaxsim/parsers/rod/parser.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import dataclasses
import os
import pathlib
from typing import NamedTuple

Expand Down Expand Up @@ -273,14 +274,14 @@ def extract_model_data(
if j.axis is not None
and j.axis.limit is not None
and j.axis.limit.dissipation is not None
else 0.0
else os.environ.get("JAXSIM_JOINT_POSITION_LIMIT_DAMPER", 0.0)
),
position_limit_spring=float(
j.axis.limit.stiffness
if j.axis is not None
and j.axis.limit is not None
and j.axis.limit.stiffness is not None
else 0.0
else os.environ.get("JAXSIM_JOINT_POSITION_LIMIT_SPRING", 0.0)
),
)
for j in sdf_model.joints()
Expand Down
112 changes: 112 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,116 @@ def jaxsim_model_ur10() -> js.model.JaxSimModel:
return build_jaxsim_model(model_description=model_urdf_path)


@pytest.fixture(scope="session")
def jaxsim_model_single_pendulum() -> js.model.JaxSimModel:
"""
Fixture providing the JaxSim model of a single pendulum.
Returns:
The JaxSim model of a single pendulum.
"""

import numpy as np
import rod.builder.primitives

base_height = 2.15
upper_height = 1.0

# ===================
# Create the builders
# ===================

base_builder = rod.builder.primitives.BoxBuilder(
name="base",
mass=1.0,
x=0.15,
y=0.15,
z=base_height,
)

upper_builder = rod.builder.primitives.BoxBuilder(
name="upper",
mass=0.5,
x=0.15,
y=0.15,
z=upper_height,
)

# =================
# Create the joints
# =================

fixed = rod.Joint(
name="fixed_joint",
type="fixed",
parent="world",
child=base_builder.name,
)

pivot = rod.Joint(
name="upper_joint",
type="continuous",
parent=base_builder.name,
child=upper_builder.name,
axis=rod.Axis(
xyz=rod.Xyz([1, 0, 0]),
limit=rod.Limit(),
),
)

# ================
# Create the links
# ================

base = (
base_builder.build_link(
name=base_builder.name,
pose=rod.builder.primitives.PrimitiveBuilder.build_pose(
pos=np.array([0, 0, base_height / 2])
),
)
.add_inertial()
.add_visual()
.add_collision()
.build()
)

upper_pose = rod.builder.primitives.PrimitiveBuilder.build_pose(
pos=np.array([0, 0, upper_height / 2])
)

upper = (
upper_builder.build_link(
name=upper_builder.name,
pose=rod.builder.primitives.PrimitiveBuilder.build_pose(
relative_to=base.name, pos=np.array([0, 0, upper_height])
),
)
.add_inertial(pose=upper_pose)
.add_visual(pose=upper_pose)
.add_collision(pose=upper_pose)
.build()
)

rod_model = rod.Sdf(
version="1.10",
model=rod.Model(
name="single_pendulum",
link=[base, upper],
joint=[fixed, pivot],
),
)

rod_model.model.resolve_frames()

urdf_string = rod.urdf.exporter.UrdfExporter.sdf_to_urdf_string(
sdf=rod_model.models()[0]
)

model = build_jaxsim_model(model_description=urdf_string)

return model


# ============================
# Collections of JaxSim models
# ============================
Expand Down Expand Up @@ -280,6 +390,8 @@ def get_jaxsim_model_fixture(
return request.getfixturevalue(jaxsim_model_ergocub_reduced.__name__)
case "ur10":
return request.getfixturevalue(jaxsim_model_ur10.__name__)
case "single_pendulum":
return request.getfixturevalue(jaxsim_model_single_pendulum.__name__)
case _:
raise ValueError(model_name)

Expand Down
56 changes: 56 additions & 0 deletions tests/test_simulations.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import jax
import jax.numpy as jnp
import numpy as np
import pytest

import jaxsim.api as js
Expand Down Expand Up @@ -370,3 +371,58 @@ def test_simulation_with_relaxed_rigid_contacts(
assert data_tf.base_position()[2] + max_penetration == pytest.approx(
box_height / 2, abs=0.000_100
)


def test_joint_limits(
jaxsim_model_single_pendulum: js.model.JaxSimModel,
):

model = jaxsim_model_single_pendulum

with model.editable(validate=False) as model:
model.kin_dyn_parameters.joint_parameters.position_limits_max = jnp.atleast_1d(
jnp.array(1.5708)
)
model.kin_dyn_parameters.joint_parameters.position_limits_min = jnp.atleast_1d(
jnp.array(-1.5708)
)
model.kin_dyn_parameters.joint_parameters.position_limit_spring = (
jnp.atleast_1d(jnp.array(75.0))
)
model.kin_dyn_parameters.joint_parameters.position_limit_damper = (
jnp.atleast_1d(jnp.array(0.1))
)

position_limits_min, position_limits_max = js.joint.position_limits(model=model)

data = js.data.JaxSimModelData.build(
model=model,
velocity_representation=VelRepr.Inertial,
)

theta = 10 * np.pi / 180

# Define a tolerance since the spring-damper model does
# not guarantee that the joint position will be exactly
# below the limit.
tolerance = theta * 0.10

# Test minimum joint position limits.
data_t0 = data.reset_joint_positions(positions=position_limits_min - theta)

data_tf = run_simulation(model=model, data_t0=data_t0, dt=0.005, tf=3.0)

assert (
np.min(np.array(data_tf.joint_positions()), axis=0) + tolerance
>= position_limits_min
)

# Test maximum joint position limits.
data_t0 = data.reset_joint_positions(positions=position_limits_max - theta)

data_tf = run_simulation(model=model, data_t0=data_t0, dt=0.001, tf=3.0)

assert (
np.max(np.array(data_tf.joint_positions()), axis=0) - tolerance
<= position_limits_max
)

0 comments on commit b274aab

Please sign in to comment.