Skip to content

Commit

Permalink
Add new mpc results
Browse files Browse the repository at this point in the history
  • Loading branch information
hbuurmei committed Feb 5, 2025
1 parent 4c3bce4 commit 3b3b0fa
Show file tree
Hide file tree
Showing 16 changed files with 26,798 additions and 116 deletions.
3,722 changes: 3,722 additions & 0 deletions stack/main/data/trajectories/closed_loop/test_circle_small.csv

Large diffs are not rendered by default.

1,849 changes: 1,849 additions & 0 deletions stack/main/data/trajectories/closed_loop/test_circle_small2.csv

Large diffs are not rendered by default.

1,907 changes: 1,907 additions & 0 deletions stack/main/data/trajectories/closed_loop/test_circle_small3.csv

Large diffs are not rendered by default.

1,864 changes: 1,864 additions & 0 deletions stack/main/data/trajectories/closed_loop/test_circle_small4.csv

Large diffs are not rendered by default.

1,993 changes: 1,993 additions & 0 deletions stack/main/data/trajectories/closed_loop/test_circle_small_origin.csv

Large diffs are not rendered by default.

2,000 changes: 2,000 additions & 0 deletions stack/main/data/trajectories/closed_loop/test_circle_small_origin2.csv

Large diffs are not rendered by default.

1,972 changes: 1,972 additions & 0 deletions stack/main/data/trajectories/closed_loop/test_circle_small_origin3.csv

Large diffs are not rendered by default.

3,827 changes: 3,827 additions & 0 deletions stack/main/data/trajectories/closed_loop/test_figure_eight.csv

Large diffs are not rendered by default.

1,931 changes: 1,931 additions & 0 deletions stack/main/data/trajectories/closed_loop/test_figure_eight_small.csv

Large diffs are not rendered by default.

1,666 changes: 1,666 additions & 0 deletions stack/main/data/trajectories/closed_loop/test_small_line.csv

Large diffs are not rendered by default.

1,918 changes: 1,918 additions & 0 deletions stack/main/data/trajectories/closed_loop/test_small_line10.csv

Large diffs are not rendered by default.

1,931 changes: 1,931 additions & 0 deletions stack/main/data/trajectories/closed_loop/test_small_line10_smallerR.csv

Large diffs are not rendered by default.

159 changes: 118 additions & 41 deletions stack/main/scripts/visualize_mpc_traj.ipynb

Large diffs are not rendered by default.

26 changes: 12 additions & 14 deletions stack/main/src/executor/executor/mpc_initializer_node.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import os

import jax
import jax.numpy as jnp
import logging
logging.getLogger('jax').setLevel(logging.ERROR)
jax.config.update('jax_platform_name', 'cpu')
jax.config.update("jax_enable_x64", True)

import rclpy # type: ignore
from rclpy.node import Node # type: ignore

from controller.mpc.gusto import GuSTOConfig # type: ignore
from controller.mpc_solver_node import run_mpc_solver_node # type: ignore
from .utils.models import SSMR
Expand All @@ -28,9 +31,7 @@ def __init__(self):
self.data_dir = os.getenv('TRUNK_DATA', '/home/trunk/Documents/trunk-stack/stack/main/data')

# Generate reference trajectory
z_ref, t = self._generate_ref_trajectory(15, 0.01, 'periodic_line', 0.05)
# t = jnp.arange(0, 5, 0.01)
# z_ref = jnp.zeros((len(t), 3))
z_ref, t = self._generate_ref_trajectory(10, 0.01, 'circle', 0.03)

# Load the model
self._load_model()
Expand All @@ -40,22 +41,19 @@ def __init__(self):
Qz = Qz.at[1, 1].set(0)
Qzf = 10 * jnp.eye(self.model.n_z)
Qzf = Qzf.at[1, 1].set(0)
R_base = 0.008
R = R_base * jnp.eye(self.model.n_u)
R = R.at[1, 1].set(R_base * 10)
# R = R.at[-1, -1].set(R_base * 3 / 4)
R_tip, R_mid, R_top = 0.005, 0.005, 0.005
R = jnp.diag(jnp.array([R_tip, R_mid, R_top, R_mid, R_top, R_tip]))
gusto_config = GuSTOConfig(
Qz=Qz,
Qzf=Qzf,
R=R,
x_char=0.05*jnp.ones(self.model.n_x),
f_char=0.5*jnp.ones(self.model.n_x),
N=6
N=5
)
U = HyperRectangle([0.4, 0.4, 0.4, 0.4, 0.4, 0.4],
[-0.4, -0.4, -0.4, -0.4, -0.4, -0.4])
# dU = HyperRectangle([0.1]*6, [-0.1]*6)
dU = None
U = HyperRectangle([0.4]*6, [-0.4]*6)
dU = HyperRectangle([0.1]*6, [-0.1]*6)
# dU = None
x0 = jnp.zeros(self.model.n_x)
self.mpc_solver_node = run_mpc_solver_node(self.model, gusto_config, x0, t=t, z=z_ref, U=U, dU=dU)

Expand All @@ -81,9 +79,9 @@ def _generate_ref_trajectory(self, T, dt, traj_type, size):
t = jnp.linspace(0, T, int(T/dt))
z_ref = jnp.zeros((len(t), 3))

# Note that y is up
# Note that y is vertically up
if traj_type == 'circle':
z_ref = z_ref.at[:, 0].set(size * jnp.cos(2 * jnp.pi / T * t))
z_ref = z_ref.at[:, 0].set(size * (jnp.cos(2 * jnp.pi / T * t) - 1))
z_ref = z_ref.at[:, 1].set(size / 2 * jnp.ones_like(t))
z_ref = z_ref.at[:, 2].set(size * jnp.sin(2 * jnp.pi / T * t))
elif traj_type == 'figure_eight':
Expand Down
148 changes: 87 additions & 61 deletions stack/main/src/executor/executor/mpc_node.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import os
import time
import threading

import jax
import jax.numpy as jnp
Expand All @@ -9,16 +7,19 @@
jax.config.update('jax_platform_name', 'cpu')
jax.config.update("jax_enable_x64", True)

import rclpy # type: ignore
from rclpy.node import Node # type: ignore
from rclpy.executors import MultiThreadedExecutor # type: ignore
from rclpy.qos import QoSProfile # type: ignore
import rclpy # type: ignore
from rclpy.node import Node # type: ignore
from rclpy.clock import ROSClock
from rclpy.callback_groups import ReentrantCallbackGroup # type: ignore
from rclpy.executors import MultiThreadedExecutor # type: ignore
from rclpy.qos import QoSProfile # type: ignore

from controller.mpc_solver_node import jnp2arr # type: ignore
from controller.mpc_solver_node import jnp2arr # type: ignore
from interfaces.msg import SingleMotorControl, AllMotorsControl, TrunkRigidBodies
from interfaces.srv import ControlSolver


@jax.jit
def check_control_inputs(u_opt, u_opt_previous):
"""
Check control inputs for safety constraints, rejecting vector norms that are too large.
Expand Down Expand Up @@ -52,12 +53,13 @@ def check_control_inputs(u_opt, u_opt_previous):
norm_value = jnp.linalg.norm(vector_sum)

# Check the constraint: if the constraint is met, then keep previous control command
if norm_value > 0.8:
print(f'Sample {u_opt} got rejected')
u_opt = u_opt_previous
else:
# Else the clipped command is published
u_opt = jnp.array([u1, u2, u3, u4, u5, u6])
# if norm_value > 0.8:
# print(f'Sample {u_opt} got rejected')
# u_opt = u_opt_previous
# else:
# # Else the clipped command is published
# u_opt = jnp.array([u1, u2, u3, u4, u5, u6])
u_opt = jnp.where(norm_value > 0.8, u_opt_previous, jnp.array([u1, u2, u3, u4, u5, u6]))

return u_opt

Expand All @@ -83,19 +85,23 @@ def __init__(self):
self.rest_position = jnp.array([0.10056, -0.10541, 0.10350,
0.09808, -0.20127, 0.10645,
0.09242, -0.31915, 0.09713])

self.callback_group = ReentrantCallbackGroup()

# Subscribe to current positions
self.mocap_subscription = self.create_subscription(
TrunkRigidBodies,
'/trunk_rigid_bodies',
self.mocap_listener_callback,
QoSProfile(depth=10)
QoSProfile(depth=10),
callback_group=self.callback_group
)

# Create MPC solver service client
self.mpc_client = self.create_client(
ControlSolver,
'mpc_solver'
'mpc_solver',
callback_group=self.callback_group
)
self.get_logger().info('MPC client created.')
while not self.mpc_client.wait_for_service(timeout_sec=1.0):
Expand All @@ -112,12 +118,35 @@ def __init__(self):
)

# Maintain current observations because of the delay embedding
self.lastest_y = None
self.latest_y = None

# Maintain previous control inputs
self.uopt_previous = jnp.zeros(6)

self.get_logger().info('MPC node has been started.')
self.clock = self.get_clock()

# self.start_time = self.clock.now().nanoseconds / 1e9

# Need some initialization
self.initialized = False

# Initialize by calling mpc callback function
self.mpc_executor_callback()

# JIT compile this function
check_control_inputs(jnp.zeros(6), self.uopt_previous)

self.controller_period = 0.04

self.mpc_exec_timer = self.create_timer(
self.controller_period,
self.mpc_executor_callback,
# clock=self.clock,
callback_group=self.callback_group)

self.get_logger().info(f'MPC node has been started with controller frequency: {1/self.controller_period:.2f} [Hz].')

self.start_time = self.clock.now().nanoseconds / 1e9

def mocap_listener_callback(self, msg):
"""
Expand All @@ -139,53 +168,50 @@ def mocap_listener_callback(self, msg):
self.latest_y = jnp.tile(y_centered_tip, 4)
self.start_time = self.clock.now().nanoseconds / 1e9
else:
self.latest_y = jnp.concatenate([y_centered_tip, self.y[:-3]])

def control_loop(self):
self.latest_y = jnp.concatenate([y_centered_tip, self.latest_y[:-3]])

self.t0 = self.clock.now().nanoseconds / 1e9 - self.start_time

def mpc_executor_callback(self):
if not self.initialized:
self.send_request(0.0, jnp.zeros(12), wait=True)
self.future.add_done_callback(self.service_callback)
self.initialized = True
elif self.latest_y is not None:
# self.get_logger().info(f'Sent the request at {(self.clock.now().nanoseconds / 1e9 - self.start_time):.3f}')
# self.get_logger().info(f'Sent the request at {(time.time() - self.start_time):.3f}')
self.send_request(self.t0, self.latest_y, wait=False)
self.future.add_done_callback(self.service_callback)

def send_request(self, t0, y0, wait=False):
"""
Control loop running in its own thread. It periodically sends a request
to the MPC solver and publishes control commands when the response arrives.
Send request to MPC solver.
"""
while rclpy.ok():
if self.latest_y is None:
time.sleep(0.1)
continue

# Compute the current time t0 relative to start
t0 = self.get_clock().now().nanoseconds / 1e9 - self.start_time

# Prepare the service request
self.req.t0 = t0
self.req.y0 = jnp2arr(self.latest_y)

# Call the MPC service synchronously (since the solver is fast)
future = self.mpc_client.call_async(self.req)

# Wait for the future to complete without blocking other callbacks due to multithreading
rclpy.spin_until_future_complete(self, future)
if future.done():
try:
response = future.result()

# Check if the trajectory is done
if response.done:
self.get_logger().info(f'Trajectory finished at {t0:.3f} seconds. Shutting down.')
rclpy.shutdown()

# Apply safety checks on the control inputs
safe_control = check_control_inputs(response.uopt[:6], self.uopt_previous)
self.uopt_previous = safe_control

# Publish the safe control inputs
self.publish_control_inputs(safe_control.tolist())

if self.debug:
self.get_logger().info(f'Commanded control inputs: {safe_control.tolist()}')

except Exception as e:
self.get_logger().error(f'Failed to process MPC response: {e}')
self.req.t0 = t0
self.req.y0 = jnp2arr(y0)
self.future = self.mpc_client.call_async(self.req)

if wait:
# Synchronous call, not compatible for real-time applications
rclpy.spin_until_future_complete(self, self.future)

def service_callback(self, async_response):
try:
response = async_response.result()
# self.get_logger().info(f'Received uopt at: {(self.clock.now().nanoseconds / 1e9 - self.start_time):.3f} for t0: {response.t[0]:.3f}')

if response.done:
self.get_logger().info(f'Trajectory is finished! At {(self.clock.now().nanoseconds / 1e9 - self.start_time):.3f}')
self.destroy_node()
rclpy.shutdown()
else:
self.get_logger().warn("MPC service call did not complete in time.")
safe_control_inputs = check_control_inputs(jnp.array(response.uopt[:6]), self.uopt_previous)
self.publish_control_inputs(safe_control_inputs.tolist())
# self.get_logger().info(f'We command the control inputs: {safe_control_inputs.tolist()}.')
# self.get_logger().info(f'We would command the control inputs: {response.uopt[:6]}.')
self.uopt_previous = safe_control_inputs
except Exception as e:
self.get_logger().error(f'Service call failed: {e}.')

def publish_control_inputs(self, control_inputs):
"""
Expand Down
1 change: 1 addition & 0 deletions stack/main/src/executor/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
'manual_decay_node = executor.manual_decay_node:main',
'adiabatic_manual_decay_node = executor.adiabatic_manual_decay_node:main',
'mpc_initializer_node = executor.mpc_initializer_node:main',
'mpc_node = executor.mpc_node:main',
'store_observations_node = executor.store_observations_node:main',
'test_mpc_node = executor.test_mpc_node:main',
],
Expand Down

0 comments on commit 3b3b0fa

Please sign in to comment.