Skip to content

Commit

Permalink
Add visualization code for MPC test
Browse files Browse the repository at this point in the history
  • Loading branch information
hbuurmei committed Feb 13, 2025
1 parent aee3e4d commit 5323958
Show file tree
Hide file tree
Showing 2 changed files with 201 additions and 5 deletions.
172 changes: 172 additions & 0 deletions stack/main/scripts/visualize_mpc_test.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def latex_label(text):\n",
" \"\"\"\n",
" Converts regular text to a LaTeX-compatible label, handling LaTeX math mode correctly.\n",
" \"\"\"\n",
" # Split text into math ($...$) and non-math segments\n",
" parts = re.split(r'(\\$.+?\\$)', text)\n",
" formatted_parts = []\n",
"\n",
" for part in parts:\n",
" if part.startswith('$') and part.endswith('$'):\n",
" # It's a math mode segment, add it as is\n",
" formatted_parts.append(part)\n",
" else:\n",
" # Replace spaces with ~ and wrap non-math text with \\mathrm{}\n",
" clean_fragment = part.replace(' ', '~')\n",
" if clean_fragment:\n",
" latex_command = r\"\\mathrm{\" + clean_fragment + \"}\"\n",
" formatted_parts.append(f'${latex_command}$')\n",
" # formatted_parts.append(f'${r\"\\mathrm{\" + clean_fragment + \"}\"}$')\n",
"\n",
" return ''.join(formatted_parts)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def plot_mpc_trajectory(ts, z_ref, z_mpc, z_true, u_mpc, N, plot_controls=True):\n",
" \"\"\"\n",
" Plot the MPC trajectory for 2D or 3D data.\n",
" \"\"\"\n",
" # Determine the dimension (2D or 3D)\n",
" dim = z_true.shape[1]\n",
"\n",
" # Set up the plotting environment\n",
" if plot_controls:\n",
" if dim == 2:\n",
" fig, ax = plt.subplots(1, 2, figsize=(10, 4))\n",
" else:\n",
" fig = plt.figure(figsize=(12, 5))\n",
" ax = [fig.add_subplot(1, 2, 1, projection='3d'), fig.add_subplot(1, 2, 2)]\n",
" else:\n",
" if dim == 2:\n",
" fig, ax = plt.subplots(1, 1, figsize=(6, 4))\n",
" ax = [ax]\n",
" else:\n",
" fig = plt.figure(figsize=(6, 5))\n",
" ax = [fig.add_subplot(1, 1, 1, projection='3d')]\n",
"\n",
" # Plot the MPC trajectories\n",
" for t_idx in range(len(ts) - 2 * N):\n",
" if dim == 2:\n",
" ax[0].plot(z_mpc[t_idx, :, 0], z_mpc[t_idx, :, 1], '--*', color='k', markersize=3,\n",
" label=latex_label('MPC') if t_idx == 0 else None)\n",
" else:\n",
" ax[0].plot3D(z_mpc[t_idx, :, 0], z_mpc[t_idx, :, 1], z_mpc[t_idx, :, 2], '--*', color='k', markersize=3,\n",
" label=latex_label('MPC') if t_idx == 0 else None)\n",
"\n",
" # Plot the true trajectory, start point, and reference\n",
" if dim == 2:\n",
" ax[0].plot(z_true[:-N, 0], z_true[:-N, 1], '-o', label=latex_label('True'), markersize=3)\n",
" ax[0].plot(z_true[0, 0], z_true[0, 1], 'ro', label=latex_label('Start'), markersize=6)\n",
" ax[0].plot(z_ref[:-N, 0], z_ref[:-N, 1], 'y--', label=latex_label('Reference'))\n",
" ax[0].set_xlabel(latex_label('X [m]'))\n",
" ax[0].set_ylabel(latex_label('Y [m]'))\n",
" ax[0].axis('equal')\n",
" else:\n",
" ax[0].plot3D(z_true[:-N, 0], z_true[:-N, 1], z_true[:-N, 2], '-o', label=latex_label('True'), markersize=3)\n",
" ax[0].scatter(z_true[0, 0], z_true[0, 1], z_true[0, 2], color='r', label=latex_label('Start'), s=36)\n",
" ax[0].plot3D(z_ref[:-N, 0], z_ref[:-N, 1], z_ref[:-N, 2], 'y--', label=latex_label('Reference'))\n",
" ax[0].set_xlabel(latex_label('X [m]'))\n",
" ax[0].set_ylabel(latex_label('Y [m]'))\n",
" ax[0].set_zlabel(latex_label('Z [m]'))\n",
" ax[0].set_box_aspect([1, 1, 1])\n",
" ax[0].view_init(elev=115, azim=-115, roll=-25)\n",
" ax[0].legend()\n",
"\n",
" # Plot the control inputs\n",
" if plot_controls:\n",
" for u_idx in range(u_mpc.shape[-1]):\n",
" ax[1].plot(ts, u_mpc[:, 0, u_idx], label=latex_label(f'u_{u_idx+1}(t)'))\n",
" ax[1].set_xlabel(latex_label('t [s]'))\n",
" ax[1].set_ylabel(latex_label('U'))\n",
" ax[1].legend()\n",
" plt.tight_layout()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def generate_ref_trajectory(T, dt, traj_type, size):\n",
" \"\"\"\n",
" Generate a 3D reference trajectory for the system to track.\n",
" \"\"\"\n",
" t = np.linspace(0, T, int(T/dt))\n",
" z_ref = np.zeros((len(t), 3))\n",
"\n",
" # Note that y is up\n",
" if traj_type == 'circle':\n",
" z_ref[:, 0] = size * (np.cos(2 * np.pi / T * t) - 1)\n",
" z_ref[:, 1] = size / 2 * np.ones_like(t)\n",
" z_ref[:, 2] = size * np.sin(2 * np.pi / T * t)\n",
" elif traj_type == 'figure_eight':\n",
" z_ref[:, 0] = size * np.sin(2 * np.pi / T * t)\n",
" z_ref[:, 1] = size / 2 * np.ones_like(t)\n",
" z_ref[:, 2] = size * np.sin(4 * np.pi / T * t)\n",
" elif traj_type == 'periodic_line':\n",
" m = -1\n",
" z_ref[:, 0] = size * np.sin(2 * np.pi / T * t)\n",
" z_ref[:, 1] = np.zeros_like(t)\n",
" z_ref[:, 2] = m * size * np.sin(2 * np.pi / T * t)\n",
" elif traj_type == 'arc':\n",
" m = 1\n",
" l_trunk = 0.35\n",
" R = l_trunk / 2\n",
" z_ref[:, 0] = size * np.sin(2 * np.pi / T * t)\n",
" z_ref[:, 2] = m * size * np.sin(2 * np.pi / T * t)\n",
" z_ref[:, 1] = R - np.sqrt(R**2 - z_ref[:, 0]**2 - z_ref[:, 0]**2)\n",
" else:\n",
" raise ValueError('Invalid trajectory type: ' + traj_type + '. Valid options are: \"circle\" or \"figure_eight\".')\n",
" return z_ref, t"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Generate reference trajectory (check that it's the same as used in mpc_initializer_node.py)\n",
"z_ref, t = generate_ref_trajectory(10, 0.01, 'figure_eight', 0.1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
34 changes: 29 additions & 5 deletions stack/main/src/executor/executor/test_mpc_node.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import csv
import jax
import jax.numpy as jnp
import logging
Expand Down Expand Up @@ -65,6 +66,10 @@ def __init__(self):
self.results_name = self.get_parameter('results_name').value
self.data_dir = os.getenv('TRUNK_DATA', '/home/trunk/Documents/trunk-stack/stack/main/data')

# Initialize the CSV file
self.results_file = os.path.join(self.data_dir, f"trajectories/test_mpc/{self.results_name}.csv")
self.initialize_csv()

# Key to control randomness in added noise
self.rnd_key = jax.random.key(seed=0)

Expand Down Expand Up @@ -118,9 +123,9 @@ def mpc_executor_callback(self):
self.future.add_done_callback(self.service_callback)
self.initialized = True
else:
t0 = self.clock.now().nanoseconds / 1e9 - self.start_time
self.update_observations(t0)
self.send_request(t0, self.latest_y, wait=False)
self.t0 = self.clock.now().nanoseconds / 1e9 - self.start_time
self.update_observations(self.t0)
self.send_request(self.t0, self.latest_y, wait=False)
self.future.add_done_callback(self.service_callback)

def send_request(self, t0, y0, wait=False):
Expand Down Expand Up @@ -150,8 +155,11 @@ def service_callback(self, async_response):
safe_control_inputs = check_control_inputs(jnp.array(response.uopt[:6]), self.uopt_previous)
self.uopt_previous = safe_control_inputs

# Save the predicted observations
self.topt, self.zopt = arr2jnp(response.t, 1, squeeze=True), arr2jnp(response.zopt, 3)
# Save the predicted observations and control inputs
topt, zopt, uopt = response.t, response.zopt, response.uopt
y0 = self.latest_y[:3].tolist()
self.save_to_csv(self.t0, y0, topt, zopt, uopt)
self.topt, self.zopt = arr2jnp(topt, 1, squeeze=True), arr2jnp(zopt, 3)

except Exception as e:
self.get_logger().error(f'Service call failed: {e}.')
Expand Down Expand Up @@ -182,6 +190,22 @@ def update_observations(self, t0, eps_noise=1e-4):
# Otherwise we concatenate the new observations with the old ones
self.latest_y = jnp.concatenate([jnp.flip(y_centered_tip.T, 1).T.flatten(), self.latest_y[:(4-N_new_obs)*3]])

def initialize_csv(self):
"""
Initialize the CSV file with headers.
"""
with open(self.results_file, mode='w', newline='') as file:
writer = csv.writer(file)
writer.writerow(['t0', 'y_latest', 'topt', 'zopt', 'uopt'])

def save_to_csv(self, t0, y0, topt, zopt, uopt):
"""
Save data to the CSV file.
"""
with open(self.results_file, mode='a', newline='') as file:
writer = csv.writer(file)
writer.writerow([t0, y0, topt, zopt, uopt])


def main(args=None):
"""
Expand Down

0 comments on commit 5323958

Please sign in to comment.