
Mismatch between data.base_transform and jaxsim.api.model.forward_kinematics for fixed-base models with non trivial base-to-world transform #337

Open
xela-95 opened this issue Jan 11, 2025 · 2 comments
Labels
bug Something isn't working


@xela-95
Member

xela-95 commented Jan 11, 2025

I've encountered a strange behavior when simulating a fixed-base model, i.e. a model whose description contains a fixed joint between a world link and the model's base link.

In particular, when I set a base pose ${}^W H_B \neq I_4$ in the model URDF through the <origin> tag of the fixed joint, data.base_transform() and jaxsim.api.model.forward_kinematics return different results for the base link.
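As a quick illustration of why the fixed joint's <origin> alone already gives $^W H_B \neq I_4$, here is a minimal NumPy sketch (not JaxSim or rod code) that turns an `xyz`/`rpy` origin into a homogeneous transform using the URDF convention $R = R_z(\text{yaw})\,R_y(\text{pitch})\,R_x(\text{roll})$. The helper name `origin_to_matrix` is hypothetical, and the numbers assume that `rod.Pose(pose=[...])` in the repro below follows the SDF ordering `[x, y, z, roll, pitch, yaw]`.

```python
import numpy as np


def origin_to_matrix(xyz, rpy):
    """Hypothetical helper: 4x4 transform of a URDF <origin xyz="..." rpy="..."/>."""
    roll, pitch, yaw = rpy
    cr, sr = np.cos(roll), np.sin(roll)
    cp, sp = np.cos(pitch), np.sin(pitch)
    cy, sy = np.cos(yaw), np.sin(yaw)
    # Elementary rotations about the fixed x, y and z axes.
    Rx = np.array([[1.0, 0.0, 0.0], [0.0, cr, -sr], [0.0, sr, cr]])
    Ry = np.array([[cp, 0.0, sp], [0.0, 1.0, 0.0], [-sp, 0.0, cp]])
    Rz = np.array([[cy, -sy, 0.0], [sy, cy, 0.0], [0.0, 0.0, 1.0]])
    H = np.eye(4)
    # URDF convention: roll, then pitch, then yaw, all about fixed axes.
    H[:3, :3] = Rz @ Ry @ Rx
    H[:3, 3] = np.asarray(xyz, dtype=float)
    return H


# Origin of 'fixed_joint' in the repro (assumed ordering: x, y, z, roll, pitch, yaw).
W_H_joint = origin_to_matrix(xyz=[0.0, 0.0, 0.0], rpy=[1.075, 1.57, 0.0])
print(np.round(W_H_joint, 5))
print(np.allclose(W_H_joint, np.eye(4)))  # False
```

This is only the joint's own origin: the base transform printed by `forward_kinematics` further down also depends on how rod and JaxSim combine it with the base link pose (see the "Combining the pose of base link 'base' with the pose of joint 'fixed_joint'" log line), so the two matrices need not coincide.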

Repro steps

Using a modified version of the single_pendulum model defined in the unit tests, with a non-trivial base-to-world pose:

import jax.numpy as jnp
import jaxsim.api as js
import numpy as np
import rod
from jaxsim import VelRepr
from jaxsim.integrators.fixed_step import ForwardEuler
from rod.urdf.exporter import UrdfExporter


def jaxsim_model_single_pendulum() -> js.model.JaxSimModel:
    """
    Fixture providing the JaxSim model of a single pendulum.

    Returns:
        The JaxSim model of a single pendulum.
    """

    import rod.builder.primitives

    base_height = 2.15
    upper_height = 1.0

    # ===================
    # Create the builders
    # ===================

    base_builder = rod.builder.primitives.BoxBuilder(
        name="base",
        mass=1.0,
        x=0.15,
        y=0.15,
        z=base_height,
    )

    upper_builder = rod.builder.primitives.BoxBuilder(
        name="upper",
        mass=0.5,
        x=0.15,
        y=0.15,
        z=upper_height,
    )

    # =================
    # Create the joints
    # =================

    fixed = rod.Joint(
        name="fixed_joint",
        type="fixed",
        parent="world",
        child=base_builder.name,
        pose=rod.Pose(pose=[0, 0, 0, 1.075, 1.57, 0]),
    )

    pivot = rod.Joint(
        name="upper_joint",
        type="continuous",
        parent=base_builder.name,
        child=upper_builder.name,
        axis=rod.Axis(
            xyz=rod.Xyz([1, 0, 0]),
            limit=rod.Limit(),
        ),
    )

    # ================
    # Create the links
    # ================

    base = (
        base_builder.build_link(
            name=base_builder.name,
            pose=rod.builder.primitives.PrimitiveBuilder.build_pose(
                pos=np.array([0, 0, base_height / 2])
            ),
        )
        .add_inertial()
        .add_visual()
        .add_collision()
        .build()
    )

    upper_pose = rod.builder.primitives.PrimitiveBuilder.build_pose(
        pos=np.array([0, 0, upper_height / 2])
    )

    upper = (
        upper_builder.build_link(
            name=upper_builder.name,
            pose=rod.builder.primitives.PrimitiveBuilder.build_pose(
                relative_to=base.name, pos=np.array([0, 0, upper_height])
            ),
        )
        .add_inertial(pose=upper_pose)
        .add_visual(pose=upper_pose)
        .add_collision(pose=upper_pose)
        .build()
    )

    rod_model = rod.Sdf(
        version="1.10",
        model=rod.Model(
            name="single_pendulum",
            link=[base, upper],
            joint=[fixed, pivot],
        ),
    )

    rod_model.model.resolve_frames()

    urdf_string = UrdfExporter(pretty=True).to_urdf_string(sdf=rod_model.models()[0])

    model = js.model.JaxSimModel.build_from_model_description(
        model_description=urdf_string,
        integrator=ForwardEuler,
        time_step=0.001,
        contact_model=js.contact.contacts.SoftContacts.build(),
    )

    return model


model = jaxsim_model_single_pendulum()
data: js.data.JaxSimModelData = js.data.JaxSimModelData.build(
    model=model,
    velocity_representation=VelRepr.Inertial,
    joint_positions=jnp.array([-np.pi / 2]),
)

W_H_B_from_data = data.base_transform()
W_H_B_from_fk = js.model.forward_kinematics(model, data)

print(W_H_B_from_data)

print(W_H_B_from_fk[0])

assert jnp.allclose(W_H_B_from_data, W_H_B_from_fk[0])

This outputs:

jaxsim[41566] INFO Enabling JAX to use 64-bit precision
rod[41566] INFO Calling sdformat through '/home/acroci/repos/component_darwin/.pixi/envs/default/bin/gz sdf'
rod[41566] DEBUG Converting model 'single_pendulum' to URDF
rod[41566] DEBUG Detected 'base' as root link
rod[41566] DEBUG Building kinematic tree of model 'single_pendulum'
rod[41566] DEBUG Selecting 'base' as canonical link
rod[41566] DEBUG Edge 'fixed_joint' became a frame attached to 'base'
rod[41566] DEBUG Node 'world' became a frame attached to 'fixed_joint'
rod[41566] WARNING Ignoring non-trivial pose of link 'base'
jaxsim[41566] DEBUG Found model 'single_pendulum' in SDF resource
rod[41566] DEBUG Building kinematic tree of model 'single_pendulum'
rod[41566] DEBUG Selecting 'base' as canonical link
rod[41566] DEBUG Edge 'fixed_joint' became a frame attached to 'base'
rod[41566] DEBUG Node 'world' became a frame attached to 'fixed_joint'
jaxsim[41566] DEBUG Model 'single_pendulum' is fixed-base
jaxsim[41566] DEBUG Considering 'base' as base link
jaxsim[41566] DEBUG Found joints connecting to world: ['fixed_joint']
jaxsim[41566] INFO Combining the pose of base link 'base' with the pose of joint 'fixed_joint'
jaxsim[41566] INFO The kinematic graph doesn't need to be reduced
[[1. 0. 0. 0.]
 [0. 1. 0. 0.]
 [0. 0. 1. 0.]
 [0. 0. 0. 1.]]
[[-0.47573  0.41948 -0.77312  0.51141]
 [ 0.87959  0.22571 -0.41878 -0.94556]
 [-0.00118 -0.87926 -0.47635  1.07541]
 [ 0.       0.       0.       1.     ]]
Traceback (most recent call last):
  File "/home/acroci/repos/component_darwin/src/test_base_transform.py", line 137, in <module>
    assert jnp.allclose(W_H_B_from_data, W_H_B_from_fk[0])
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError
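To quantify how far apart the two printed transforms are, a small NumPy helper (hypothetical, not part of JaxSim) can report the translation distance and the angle of the relative rotation, $\theta = \arccos\big((\operatorname{tr}(R_a^\top R_b) - 1)/2\big)$. The matrices below are the rounded values printed above, so the result is only approximate.

```python
import numpy as np


def se3_discrepancy(H_a, H_b):
    """Hypothetical helper: (translation distance, rotation angle in rad) between two 4x4 transforms."""
    H_a = np.asarray(H_a, dtype=float)
    H_b = np.asarray(H_b, dtype=float)
    translation_error = float(np.linalg.norm(H_a[:3, 3] - H_b[:3, 3]))
    # Angle of the relative rotation R_a^T R_b; clipping guards against rounding.
    R_rel = H_a[:3, :3].T @ H_b[:3, :3]
    cos_angle = np.clip((np.trace(R_rel) - 1.0) / 2.0, -1.0, 1.0)
    return translation_error, float(np.arccos(cos_angle))


# Rounded matrices copied from the output above.
W_H_B_from_data = np.eye(4)
W_H_B_from_fk = np.array(
    [
        [-0.47573, 0.41948, -0.77312, 0.51141],
        [0.87959, 0.22571, -0.41878, -0.94556],
        [-0.00118, -0.87926, -0.47635, 1.07541],
        [0.0, 0.0, 0.0, 1.0],
    ]
)
dp, angle = se3_discrepancy(W_H_B_from_data, W_H_B_from_fk)
print(f"translation error: {dp:.3f} m, rotation error: {angle:.3f} rad")
```

A metric like this is more informative than a bare `jnp.allclose` assertion when checking whether a candidate fix actually brings the two quantities together.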


By contrast, the two quantities match if the base link is defined without the position offset in its pose:

    base = (
        base_builder.build_link(
            name=base_builder.name,
            pose=rod.builder.primitives.PrimitiveBuilder.build_pose(
                # pos=np.array([0, 0, base_height / 2])
            ),
        )
        .add_inertial()
        .add_visual()
        .add_collision()
        .build()
    )

Output:

[[1. 0. 0. 0.]
 [0. 1. 0. 0.]
 [0. 0. 1. 0.]
 [0. 0. 0. 1.]]
[[1. 0. 0. 0.]
 [0. 1. 0. 0.]
 [0. 0. 1. 0.]
 [0. 0. 0. 1.]]
xela-95 added the bug label on Jan 11, 2025
xela-95 changed the title from "Mismatch between data.base_transform and modelforward_kinematics for fixed-base models with non trivial base-to-world transform" to "Mismatch between data.base_transform and jaxsim.api.model.forward_kinematics for fixed-base models with non trivial base-to-world transform" on Jan 11, 2025
@xela-95
Member Author

xela-95 commented Jan 13, 2025

OK, I may have found something.

The forward kinematics is computed through a recursive method that propagates the kinematics along the joint model, taking into account the parent-to-child transforms as well as the world-to-base transform:

def forward_kinematics_model(
    model: js.model.JaxSimModel,
    *,
    base_position: jtp.VectorLike,
    base_quaternion: jtp.VectorLike,
    joint_positions: jtp.VectorLike,
) -> jtp.Array:
    """
    Compute the forward kinematics.

    Args:
        model: The model to consider.
        base_position: The position of the base link.
        base_quaternion: The quaternion of the base link.
        joint_positions: The positions of the joints.

    Returns:
        A 3D array containing the SE(3) transforms of all links belonging to the model.
    """

    W_p_B, W_Q_B, s, _, _, _, _, _, _, _ = utils.process_inputs(
        model=model,
        base_position=base_position,
        base_quaternion=base_quaternion,
        joint_positions=joint_positions,
    )

    # Get the parent array λ(i).
    # Note: λ(0) must not be used, it's initialized to -1.
    λ = model.kin_dyn_parameters.parent_array

    # Compute the base transform.
    W_H_B = jaxlie.SE3.from_rotation_and_translation(
        rotation=jaxlie.SO3(wxyz=W_Q_B),
        translation=W_p_B,
    )

    # Compute the parent-to-child adjoints and the motion subspaces of the joints.
    # These transforms define the relative kinematics of the entire model, including
    # the base transform for both floating-base and fixed-base models.
    i_X_λi, _ = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces(
        joint_positions=s, base_transform=W_H_B.as_matrix()
    )

    # Allocate the buffer of transforms world -> link and initialize the base pose.
    W_X_i = jnp.zeros(shape=(model.number_of_links(), 6, 6))
    W_X_i = W_X_i.at[0].set(Adjoint.inverse(i_X_λi[0]))
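The recursion can be summarized as ${}^W H_i = {}^W H_{\lambda(i)} \, {}^{\lambda(i)} H_i$, seeded at the base. Below is a minimal NumPy sketch of that idea, written with 4x4 homogeneous transforms instead of the 6x6 adjoints used by JaxSim; `forward_kinematics_sketch` and its arguments are hypothetical. It assumes, as the "Combining the pose of base link 'base' with the pose of joint 'fixed_joint'" log line suggests, that the fixed joint's origin ends up folded into the base's parent-to-child transform. Under that assumption, the base pose produced by the recursion differs from the state's $^W H_B$ whenever that offset is not the identity.

```python
import numpy as np


def forward_kinematics_sketch(parent_array, W_H_B, parent_H_child):
    """
    Hypothetical helper illustrating the recursive propagation.

    parent_array[i] is λ(i); parent_array[0] is -1 and never used.
    parent_H_child[i] is the transform λ(i) -> i evaluated at the current
    joint positions; parent_H_child[0] is an extra offset applied to the base
    (e.g. a fixed-joint origin folded into the model).
    """
    n = len(parent_array)
    W_H_i = np.zeros((n, 4, 4))
    # Seed the recursion with the base: state base transform times base offset.
    W_H_i[0] = W_H_B @ parent_H_child[0]
    for i in range(1, n):
        # Assumes links are ordered so that λ(i) < i.
        W_H_i[i] = W_H_i[parent_array[i]] @ parent_H_child[i]
    return W_H_i


# Toy example: identity base transform in the state, 1 m base offset along z,
# and a child link 2 m above its parent.
offset = np.eye(4)
offset[2, 3] = 1.0
child = np.eye(4)
child[2, 3] = 2.0
W_H_i = forward_kinematics_sketch([-1, 0], np.eye(4), np.stack([offset, child]))
print(np.allclose(W_H_i[0], np.eye(4)))  # False: FK base pose != state base pose
print(W_H_i[1][2, 3])  # 3.0
```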

The base position and orientation returned by data.base_transform(), instead, are extracted from the state.physics_model attribute, of type PhysicsModelState.

This attribute is initialized in the JaxSimModelData.build method, where, if no base position or orientation is explicitly passed, they default to (0, 0, 0) and (1, 0, 0, 0) respectively:

def build(
    model: js.model.JaxSimModel,
    base_position: jtp.VectorLike | None = None,
    base_quaternion: jtp.VectorLike | None = None,
    joint_positions: jtp.VectorLike | None = None,
    base_linear_velocity: jtp.VectorLike | None = None,
    base_angular_velocity: jtp.VectorLike | None = None,
    joint_velocities: jtp.VectorLike | None = None,
    standard_gravity: jtp.FloatLike = jaxsim.math.StandardGravity,
    contacts_params: jaxsim.rbda.contacts.ContactsParams | None = None,
    velocity_representation: VelRepr = VelRepr.Inertial,
    extended_ode_state: dict[str, jtp.PyTree] | None = None,
) -> JaxSimModelData:
    """
    Create a `JaxSimModelData` object with the given state.

    Args:
        model: The model for which to create the state.
        base_position: The base position.
        base_quaternion: The base orientation as a quaternion.
        joint_positions: The joint positions.
        base_linear_velocity:
            The base linear velocity in the selected representation.
        base_angular_velocity:
            The base angular velocity in the selected representation.
        joint_velocities: The joint velocities.
        standard_gravity: The standard gravity constant.
        contacts_params: The parameters of the soft contacts.
        velocity_representation: The velocity representation to use.
        extended_ode_state:
            Additional user-defined state variables that are not part of the
            standard `ODEState` object. Useful to extend the system dynamics
            considered by default in JaxSim.

    Returns:
        A `JaxSimModelData` initialized with the given state.
    """

    base_position = jnp.array(
        base_position if base_position is not None else jnp.zeros(3),
        dtype=float,
    ).squeeze()

    base_quaternion = jnp.array(
        (
            base_quaternion
            if base_quaternion is not None
            else jnp.array([1.0, 0, 0, 0])
        ),
        dtype=float,
    ).squeeze()

These defaults do not take into account the world-to-base pose specified in the model description, and this is what creates the "corrupted" state.
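To see why data.base_transform() prints the identity in the repro, here is a NumPy sketch that mirrors only the defaults shown above (position (0, 0, 0) and quaternion (1, 0, 0, 0) in wxyz order) and converts them into a homogeneous transform. `base_transform_from_state` is a hypothetical helper, not JaxSim's implementation. Nothing in it reads the model description, so the fixed joint's world-to-base pose never enters the result.

```python
import numpy as np


def quaternion_wxyz_to_matrix(q):
    """Rotation matrix of a quaternion stored in (w, x, y, z) order."""
    w, x, y, z = np.asarray(q, dtype=float) / np.linalg.norm(q)
    return np.array(
        [
            [1 - 2 * (y**2 + z**2), 2 * (x * y - w * z), 2 * (x * z + w * y)],
            [2 * (x * y + w * z), 1 - 2 * (x**2 + z**2), 2 * (y * z - w * x)],
            [2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x**2 + y**2)],
        ]
    )


def base_transform_from_state(base_position=None, base_quaternion=None):
    """Hypothetical helper: same defaults as JaxSimModelData.build, no model input."""
    p = np.zeros(3) if base_position is None else np.asarray(base_position, dtype=float)
    q = (
        np.array([1.0, 0.0, 0.0, 0.0])
        if base_quaternion is None
        else np.asarray(base_quaternion, dtype=float)
    )
    H = np.eye(4)
    H[:3, :3] = quaternion_wxyz_to_matrix(q)
    H[:3, 3] = p
    return H


# With no explicit base pose the result is I_4, whatever the URDF says.
print(np.allclose(base_transform_from_state(), np.eye(4)))  # True
```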

@xela-95
Member Author

xela-95 commented Jan 13, 2025

I think the mismatch ultimately arises because some attribute(s) of data.state.physics_model and model.kin_dyn_parameters are not consistent with each other.

This also means that calling the data.reset* methods does not work around the issue.

CC @CarlottaSartore

xela-95 self-assigned this on Feb 4, 2025