Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor RL platoon example #2141

Merged
merged 1 commit into from
Feb 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/e10_drive/inference/contrib_policy/filter_obs.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import math
from typing import Any, Dict, Sequence, Tuple
from typing import Any, Dict, Tuple

import gymnasium as gym
import numpy as np
Expand Down
124 changes: 14 additions & 110 deletions examples/e11_platoon/inference/contrib_policy/filter_obs.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import math
from typing import Any, Dict, Sequence, Tuple
from typing import Any, Dict, Tuple

import gymnasium as gym
import numpy as np

from smarts.core.agent_interface import RGB
from smarts.core.colors import Colors, SceneColors
from smarts.core.utils.observations import points_to_pixels, replace_rgb_image_color


class FilterObs:
Expand Down Expand Up @@ -72,19 +73,19 @@ def filter(self, obs: Dict[str, Any]) -> Dict[str, Any]:
# Get rgb image, remove road, and replace other egos (if any) as background vehicles
rgb = obs["top_down_rgb"]
h, w, _ = rgb.shape
rgb_noroad = replace_color(rgb=rgb, old_color=[self._road_color, self._lane_divider_color, self._edge_divider_color], new_color=self._no_color)
rgb_ego = replace_color(rgb=rgb_noroad, old_color=[self._ego_color], new_color=self._traffic_color, mask=self._rgb_mask)
rgb_noroad = replace_rgb_image_color(rgb=rgb, old_color=[self._road_color, self._lane_divider_color, self._edge_divider_color], new_color=self._no_color)
rgb_ego = replace_rgb_image_color(rgb=rgb_noroad, old_color=[self._ego_color], new_color=self._traffic_color, mask=self._rgb_mask)

# Superimpose waypoints onto rgb image
wps = obs["waypoint_paths"]["position"][0:11, 3:, 0:3]
for path in wps[:]:
wps_valid = points_to_pixels(
points=path,
ego_pos=ego_pos,
ego_heading=ego_heading,
w=w,
h=h,
res=self._res,
center_position=ego_pos,
heading=ego_heading,
width=w,
height=h,
resolution=self._res,
)
for point in wps_valid:
img_x, img_y = point[0], point[1]
Expand All @@ -95,11 +96,11 @@ def filter(self, obs: Dict[str, Any]) -> Dict[str, Any]:
if not all((goal:=obs["ego_vehicle_state"]["mission"]["goal_position"]) == np.zeros((3,))):
goal_pixel = points_to_pixels(
points=np.expand_dims(goal,axis=0),
ego_pos=ego_pos,
ego_heading=ego_heading,
w=w,
h=h,
res=self._res,
center_position=ego_pos,
heading=ego_heading,
width=w,
height=h,
resolution=self._res,
)
if len(goal_pixel) != 0:
img_x, img_y = goal_pixel[0][0], goal_pixel[0][1]
Expand All @@ -121,100 +122,3 @@ def filter(self, obs: Dict[str, Any]) -> Dict[str, Any]:
return filtered_obs
# fmt: on


def replace_color(
rgb: np.ndarray,
old_color: Sequence[np.ndarray],
new_color: np.ndarray,
mask: np.ndarray = np.ma.nomask,
) -> np.ndarray:
"""Convert pixels of value `old_color` to `new_color` within the masked
region in the received RGB image.

Args:
rgb (np.ndarray): RGB image. Shape = (m,n,3).
old_color (Sequence[np.ndarray]): List of old colors to be removed from the RGB image. Shape = (3,).
new_color (np.ndarray): New color to be added to the RGB image. Shape = (3,).
mask (np.ndarray, optional): Valid regions for color replacement. Shape = (m,n,3).
Defaults to np.ma.nomask .

Returns:
np.ndarray: RGB image with `old_color` pixels changed to `new_color`
within the masked region. Shape = (m,n,3).
"""
# fmt: off
assert all(color.shape == (3,) for color in old_color), (
f"Expected old_color to be of shape (3,), but got {[color.shape for color in old_color]}.")
assert new_color.shape == (3,), (
f"Expected new_color to be of shape (3,), but got {new_color.shape}.")

nc = new_color.reshape((1, 1, 3))
nc_array = np.full_like(rgb, nc)
rgb_masked = np.ma.MaskedArray(data=rgb, mask=mask)

rgb_condition = rgb_masked
result = rgb
for color in old_color:
result = np.ma.where((rgb_condition == color.reshape((1, 1, 3))).all(axis=-1)[..., None], nc_array, result)

return result
# fmt: on


def points_to_pixels(
points: np.ndarray,
ego_pos: np.ndarray,
ego_heading: float,
w: int,
h: int,
res: float,
) -> np.ndarray:
"""Converts points into pixel coordinates in order to superimpose the
points onto the RGB image.

Args:
points (np.ndarray): Array of points. Shape (n,3).
ego_pos (np.ndarray): Ego position. Shape = (3,).
ego_heading (float): Ego heading in radians.
w (int): Width of RGB image
h (int): Height of RGB image.
res (float): Resolution of RGB image in meters/pixels. Computed as
ground_size/image_size.

Returns:
np.ndarray: Array of point coordinates on the RGB image. Shape = (m,3).
"""
# fmt: off
mask = [False if all(point == np.zeros(3,)) else True for point in points]
points_nonzero = points[mask]
points_delta = points_nonzero - ego_pos
points_rotated = rotate_axes(points_delta, theta=ego_heading)
points_pixels = points_rotated / np.array([res, res, res])
points_overlay = np.array([w / 2, h / 2, 0]) + points_pixels * np.array([1, -1, 1])
points_rfloat = np.rint(points_overlay)
points_valid = points_rfloat[(points_rfloat[:,0] >= 0) & (points_rfloat[:,0] < w) & (points_rfloat[:,1] >= 0) & (points_rfloat[:,1] < h)]
points_rint = points_valid.astype(int)
return points_rint
# fmt: on


def rotate_axes(points: np.ndarray, theta: float) -> np.ndarray:
"""A counterclockwise rotation of the x-y axes by an angle theta θ about
the z-axis.

Args:
points (np.ndarray): x,y,z coordinates in original axes. Shape = (n,3).
theta (np.float): Axes rotation angle in radians.

Returns:
np.ndarray: x,y,z coordinates in rotated axes. Shape = (n,3).
"""
# fmt: off
theta = (theta + np.pi) % (2 * np.pi) - np.pi
ct, st = np.cos(theta), np.sin(theta)
R = np.array([[ ct, st, 0],
[-st, ct, 0],
[ 0, 0, 1]])
rotated_points = (R.dot(points.T)).T
return rotated_points
# fmt: on
14 changes: 8 additions & 6 deletions smarts/core/utils/observations.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,14 @@ def points_to_pixels(

Args:
points (np.ndarray): Array of points. Shape (n,3).
ego_pos (np.ndarray): Ego position. Shape = (3,).
ego_heading (float): Ego heading in radians.
w (int): Width of RGB image
h (int): Height of RGB image.
res (float): Resolution of RGB image in meters/pixels. Computed as
ground_size/image_size.
center_position (np.ndarray): Center position of image. Generally, this
is equivalent to ego position. Shape = (3,).
heading (float): Heading of image in radians. Generally, this is
equivalent to ego heading.
width (int): Width of RGB image
height (int): Height of RGB image.
resolution (float): Resolution of RGB image in meters/pixels. Computed
as ground_size/image_size.

Returns:
np.ndarray: Array of point coordinates on the RGB image. Shape = (m,3).
Expand Down
Loading