[Feature] Real2Sim Eval Digital Twins #536

Merged 56 commits from simplerenv-port into main on Sep 13, 2024

Commits (56)
87f0a88  work (StoneT2000, May 4, 2024)
d3f0514  work (StoneT2000, May 10, 2024)
de231a7  Update base_env.py (StoneT2000, May 10, 2024)
b6e022f  greenscreen trick added (StoneT2000, May 10, 2024)
35bd2df  code refactors (StoneT2000, May 16, 2024)
e217661  work (StoneT2000, May 16, 2024)
134e748  Update widowx.py (StoneT2000, May 16, 2024)
cb47d84  fixes (StoneT2000, May 16, 2024)
1154b60  fixes (StoneT2000, May 16, 2024)
42e9b99  Merge branch 'main' into simplerenv-port (StoneT2000, Aug 19, 2024)
c86d3bb  Merge branch 'main' into simplerenv-port (StoneT2000, Aug 22, 2024)
54a5c97  updates (StoneT2000, Aug 22, 2024)
cbbc49c  bug fixes (StoneT2000, Aug 22, 2024)
d92a62b  align sim configs (StoneT2000, Aug 22, 2024)
3002995  fixes (StoneT2000, Aug 22, 2024)
97a9509  Update demo_octo_eval.py (StoneT2000, Aug 22, 2024)
70f63a3  debugged (StoneT2000, Aug 23, 2024)
94a09e8  work (StoneT2000, Aug 23, 2024)
8b7b254  bug fixes (StoneT2000, Aug 23, 2024)
1183a87  attempt to support IK (StoneT2000, Aug 23, 2024)
a58f73b  work (StoneT2000, Aug 27, 2024)
039f80d  cleanup (StoneT2000, Aug 27, 2024)
6f44d5e  work (StoneT2000, Aug 27, 2024)
bb9817f  work (StoneT2000, Aug 29, 2024)
a0cb03b  Merge branch 'main' into simplerenv-port (StoneT2000, Sep 6, 2024)
62d4ec6  cleaned up code (StoneT2000, Sep 6, 2024)
e7485f2  evals (StoneT2000, Sep 6, 2024)
30eec6b  fixes (StoneT2000, Sep 6, 2024)
54eddce  spoon task (StoneT2000, Sep 6, 2024)
bfa607a  Update demo_octo_eval.py (StoneT2000, Sep 6, 2024)
1033916  work (StoneT2000, Sep 7, 2024)
33cc400  update widowx model download link and cleanup code (StoneT2000, Sep 10, 2024)
a3ca0a4  fixes (StoneT2000, Sep 10, 2024)
ebf1eae  work (StoneT2000, Sep 10, 2024)
8dc56bd  bug fixes (StoneT2000, Sep 10, 2024)
878a890  rt1 inference example (StoneT2000, Sep 10, 2024)
fa1ca12  bug fixes (StoneT2000, Sep 11, 2024)
d1f0893  less eggplant rolling (StoneT2000, Sep 11, 2024)
f424c1d  code cleanup (StoneT2000, Sep 12, 2024)
fa4fe75  Merge branch 'main' into simplerenv-port (StoneT2000, Sep 12, 2024)
cd1aebd  GPU IK no delta controller implemented (StoneT2000, Sep 12, 2024)
62e8b7a  gpu fixes (StoneT2000, Sep 12, 2024)
9d39194  bug fixes (StoneT2000, Sep 12, 2024)
00e471e  work (StoneT2000, Sep 12, 2024)
effc8dc  fixes (StoneT2000, Sep 12, 2024)
35e0a9e  work (StoneT2000, Sep 12, 2024)
e64acf5  w (StoneT2000, Sep 12, 2024)
64aab7f  cleanup (StoneT2000, Sep 12, 2024)
0f926cd  code cleanup, assets added (StoneT2000, Sep 12, 2024)
5ff9f75  docs (StoneT2000, Sep 13, 2024)
365ef84  Delete demo_real2sim_eval.py (StoneT2000, Sep 13, 2024)
607b836  f (StoneT2000, Sep 13, 2024)
8eef346  Update base_env.py (StoneT2000, Sep 13, 2024)
343f75c  Update base_env.py (StoneT2000, Sep 13, 2024)
5e65ba8  Delete README.md (StoneT2000, Sep 13, 2024)
2f2d077  Update index.md (StoneT2000, Sep 13, 2024)

48 changes: 47 additions & 1 deletion docs/source/tasks/digital_twins/index.md
@@ -1 +1,47 @@
# Digital Twins (WIP)
# Digital Twins

ManiSkill supports both training and evaluation digital twins and provides a simple framework for building them. Training digital twins are tasks designed to train a robot in simulation so that it can then be deployed in the real world (sim2real). Evaluation digital twins are tasks designed to evaluate, rather than train, policies learned from real-world data (real2sim).


## Training Digital Twins (WIP)

Coming soon.

## BridgeData v2 (Evaluation)

We currently support evaluation digital twins of some tasks in the [BridgeData v2](https://rail-berkeley.github.io/bridgedata/) environments, based on [SimplerEnv](https://simpler-env.github.io/) by Xuanlin Li, Kyle Hsu, Jiayuan Gu et al. These digital twins are GPU parallelized, enabling large-scale, fast evaluation of real-world generalist robot policies. GPU simulation and rendering make evaluation up to 60x faster than the real world and 10x faster than CPU simulation, all without human supervision. ManiSkill only provides the environments; to run policy inference with models like Octo and RT-1, see https://github.com/simpler-env/SimplerEnv/tree/maniskill3
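
For a quick sense of the workflow, the sketch below creates one of these digital twins with batched GPU simulation. The environment ID comes from the list below; the `num_envs` and `obs_mode` values are illustrative assumptions, not requirements:

```python
import gymnasium as gym

import mani_skill.envs  # noqa: F401 (importing this registers the ManiSkill environments)

# Batched GPU simulation: 16 copies of the carrot-on-plate digital twin.
env = gym.make(
    "PutCarrotOnPlateInScene-v1",
    num_envs=16,
    obs_mode="rgb",  # image observations, as consumed by policies like Octo and RT-1
)
obs, _ = env.reset(seed=0)
print(env.action_space)  # a single action space batched across all 16 environments
```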

If you use the BridgeData v2 digital twins, please cite the following in addition to ManiSkill 3:

```
@article{li24simpler,
title={Evaluating Real-World Robot Manipulation Policies in Simulation},
author={Xuanlin Li and Kyle Hsu and Jiayuan Gu and Karl Pertsch and Oier Mees and Homer Rich Walke and Chuyuan Fu and Ishikaa Lunawat and Isabel Sieh and Sean Kirmani and Sergey Levine and Jiajun Wu and Chelsea Finn and Hao Su and Quan Vuong and Ted Xiao},
journal = {arXiv preprint arXiv:2405.05941},
year={2024},
}
```

### PutCarrotOnPlateInScene-v1

<video preload="auto" controls="True" width="100%">
<source src="https://github.com/haosulab/ManiSkill/raw/main/figures/environment_demos/digital_twins/bridge_data_v2/PutCarrotOnPlateInScene-v1.mp4" type="video/mp4">
</video>

### PutSpoonOnTableClothInScene-v1

<video preload="auto" controls="True" width="100%">
<source src="https://github.com/haosulab/ManiSkill/raw/main/figures/environment_demos/digital_twins/bridge_data_v2/PutSpoonOnTableClothInScene-v1.mp4" type="video/mp4">
</video>

### StackGreenCubeOnYellowCubeBakedTexInScene-v1

<video preload="auto" controls="True" width="100%">
<source src="https://github.com/haosulab/ManiSkill/raw/main/figures/environment_demos/digital_twins/bridge_data_v2/StackGreenCubeOnYellowCubeBakedTexInScene-v1.mp4" type="video/mp4">
</video>

### PutEggplantInBasketScene-v1

<video preload="auto" controls="True" width="100%">
<source src="https://github.com/haosulab/ManiSkill/raw/main/figures/environment_demos/digital_twins/bridge_data_v2/PutEggplantInBasketScene-v1.mp4" type="video/mp4">
</video>
4 changes: 4 additions & 0 deletions docs/source/user_guide/demos/index.md
@@ -151,6 +151,10 @@ Example below shows what it looks like with the GUI:

For more details check out the [motion planning page](../data_collection/motionplanning.md)

## Real2Sim Evaluation

ManiSkill3 supports extremely fast real2sim evaluation of policies like RT-1 and Octo via GPU simulation and rendering. See [this page](../../tasks/digital_twins/index.md) for details on which environments are supported. To run inference with RT-1 and Octo, see the `maniskill3` branch of the [SimplerEnv project](https://github.com/simpler-env/SimplerEnv/tree/maniskill3).
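
As a rough sketch of what a batched evaluation loop looks like, with random actions standing in for a real policy such as RT-1 or Octo (the SimplerEnv branch above wires up the actual models); the episode length here is an illustrative assumption:

```python
import gymnasium as gym

import mani_skill.envs  # noqa: F401 (registers the environments)

env = gym.make("PutSpoonOnTableClothInScene-v1", num_envs=64, obs_mode="rgb")
obs, _ = env.reset(seed=0)
for _ in range(80):  # fixed horizon chosen arbitrarily for this sketch
    action = env.action_space.sample()  # stand-in for policy(obs)
    obs, reward, terminated, truncated, info = env.step(action)
env.close()
```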

## Visualize Pointcloud Data

You can run the following to visualize the pointcloud observations (requires a display to work)
4 binary files not shown.
13 changes: 7 additions & 6 deletions mani_skill/agents/base_agent.py
@@ -166,7 +166,7 @@ def _load_articulation(self):

if not os.path.exists(asset_path):
print(f"Robot {self.uid} definition file not found at {asset_path}")
if len(assets.DATA_GROUPS[self.uid]) > 0:
if self.uid in assets.DATA_GROUPS or len(assets.DATA_GROUPS[self.uid]) > 0:
response = download_asset.prompt_yes_no(
f"Robot {self.uid} has assets available for download. Would you like to download them now?"
)
@@ -181,13 +181,15 @@ def _load_articulation(self):
print(f"Exiting as assets for robot {self.uid} are not downloaded")
exit()
else:
print(f"Exiting as assets for robot {self.uid} are not found")
print(
f"Exiting as assets for robot {self.uid} are not found. Check that this agent is properly registered with the appropriate download asset ids"
)
exit()
self.robot: Articulation = loader.load(asset_path)
assert self.robot is not None, f"Fail to load URDF/MJCF from {asset_path}"

# Cache robot link ids
self.robot_link_ids = [link.name for link in self.robot.get_links()]
# Cache robot link names
self.robot_link_names = [link.name for link in self.robot.get_links()]

def _after_loading_articulation(self):
"""Called after loading articulation and before setting up any controllers. By default this is empty."""
@@ -337,7 +339,7 @@ def set_state(self, state: Dict, ignore_controller=False):
# -------------------------------------------------------------------------- #
def reset(self, init_qpos: torch.Tensor = None):
"""
Reset the robot to a clean state with zero velocity and forces. Furthermore it resets the current active controller.
Reset the robot to a clean state with zero velocity and forces.

Args:
init_qpos (torch.Tensor): The initial qpos to set the robot to. If None, the robot's qpos is not changed.
@@ -346,7 +348,6 @@ def reset(self, init_qpos: torch.Tensor = None):
self.robot.set_qpos(init_qpos)
self.robot.set_qvel(torch.zeros(self.robot.max_dof, device=self.device))
self.robot.set_qf(torch.zeros(self.robot.max_dof, device=self.device))
self.controller.reset()

# -------------------------------------------------------------------------- #
# Optional per-agent APIs, implemented depending on agent affordances
18 changes: 1 addition & 17 deletions mani_skill/agents/controllers/pd_ee_pose.py
@@ -30,12 +30,6 @@ def _check_gpu_sim_works(self):
assert (
self.config.frame == "root_translation"
), "currently only translation in the root frame for EE control is supported in GPU sim"
assert (
self.config.use_delta == True
), "currently only delta EE control is supported in GPU sim"
assert (
self.config.use_target == False
), "Currently cannot take actions relative to last target pose in GPU sim"

def _initialize_joints(self):
self.initial_qpos = None
@@ -111,6 +105,7 @@ def set_action(self, action: Array):
self.articulation.get_qpos(),
pos_only=pos_only,
action=action,
use_delta_ik_solver=self.config.use_delta and not self.config.use_target,
)
if self._target_qpos is None:
self._target_qpos = self._start_qpos
@@ -179,12 +174,6 @@ def _check_gpu_sim_works(self):
assert (
self.config.frame == "root_translation:root_aligned_body_rotation"
), "currently only translation in the root frame for EE control is supported in GPU sim"
assert (
self.config.use_delta == True
), "currently only delta EE control is supported in GPU sim"
assert (
self.config.use_target == False
), "Currently cannot take actions relative to last target pose in GPU sim"

def _initialize_action_space(self):
low = np.float32(
@@ -219,11 +208,6 @@ def _clip_and_scale_action(self, action):
rot_action = rot_action * self.config.rot_lower
return torch.hstack([pos_action, rot_action])

def compute_ik(self, target_pose: Pose, action: Array, max_iterations=100):
return super().compute_ik(
target_pose, action, pos_only=False, max_iterations=max_iterations
)

def compute_target_pose(self, prev_ee_pose_at_base: Pose, action):
if self.config.use_delta:
delta_pos, delta_rot = action[:, 0:3], action[:, 3:6]
71 changes: 53 additions & 18 deletions mani_skill/agents/controllers/utils/kinematics.py
@@ -91,7 +91,7 @@ def _setup_cpu(self):
def _setup_gpu(self):
"""setup the kinematics solvers on the GPU"""
self.use_gpu_ik = True
with open(self.urdf_path, "r") as f:
with open(self.urdf_path, "rb") as f:
urdf_str = f.read()

# NOTE (stao): it seems that the pk library currently always outputs some complaints if there are unknown attributes in a URDF. Hide it with this contextmanager here
@@ -107,34 +107,69 @@ def suppress_stdout_stderr():
urdf_str,
end_link_name=self.end_link.name,
).to(device=self.device)
lim = torch.tensor(self.pk_chain.get_joint_limits(), device=self.device)
self.pik = pk.PseudoInverseIK(
self.pk_chain,
joint_limits=lim.T,
early_stopping_any_converged=True,
max_iterations=200,
num_retries=1,
)

self.qmask = torch.zeros(
len(self.active_ancestor_joints), dtype=bool, device=self.device
)
self.qmask[self.controlled_joints_idx_in_qmask] = 1

def compute_ik(
self, target_pose: Pose, q0: torch.Tensor, pos_only: bool = False, action=None
self,
target_pose: Pose,
q0: torch.Tensor,
pos_only: bool = False,
action=None,
use_delta_ik_solver: bool = False,
):
"""Given a target pose, via inverse kinematics compute the target joint positions that will achieve the target pose"""
"""Given a target pose, via inverse kinematics compute the target joint positions that will achieve the target pose

Args:
target_pose (Pose): target pose of the end effector in the world frame. note this is not relative to the robot base frame!
q0 (torch.Tensor): initial joint positions of every active joint in the articulation
pos_only (bool): if True, only the position of the end link is considered in the IK computation
action (torch.Tensor): delta action to be applied to the articulation. Used for fast delta IK solutions on the GPU.
use_delta_ik_solver (bool): If true, returns the target joint positions that correspond with a delta IK solution. This is specifically
used for GPU simulation to determine which GPU IK algorithm to use.
"""
if self.use_gpu_ik:
q0 = q0[:, self.active_ancestor_joint_idxs]
jacobian = self.pk_chain.jacobian(q0)
# code commented out below is the fast kinematics method
# jacobian = (
# self.fast_kinematics_model.jacobian_mixed_frame_pytorch(
# self.articulation.get_qpos()[:, self.active_ancestor_joint_idxs]
# )
# .view(-1, len(self.active_ancestor_joints), 6)
# .permute(0, 2, 1)
# )
# jacobian = jacobian[:, :, self.qmask]
if pos_only:
jacobian = jacobian[:, 0:3]
if not use_delta_ik_solver:
tf = pk.Transform3d(
pos=target_pose.p,
rot=target_pose.q,
device=self.device,
)
self.pik.initial_config = q0 # shape (num_retries, active_ancestor_dof)
result = self.pik.solve(
tf
) # produce solutions in shape (B, num_retries/initial_configs, active_ancestor_dof)
# TODO return mask for invalid solutions. CPU returns None at the moment
return result.solutions[:, 0, :]
else:
jacobian = self.pk_chain.jacobian(q0)
# code commented out below is the fast kinematics method
# jacobian = (
# self.fast_kinematics_model.jacobian_mixed_frame_pytorch(
# self.articulation.get_qpos()[:, self.active_ancestor_joint_idxs]
# )
# .view(-1, len(self.active_ancestor_joints), 6)
# .permute(0, 2, 1)
# )
# jacobian = jacobian[:, :, self.qmask]
if pos_only:
jacobian = jacobian[:, 0:3]

# NOTE (stao): this method of IK is from https://mathweb.ucsd.edu/~sbuss/ResearchWeb/ikmethods/iksurvey.pdf by Samuel R. Buss
delta_joint_pos = torch.linalg.pinv(jacobian) @ action.unsqueeze(-1)
return q0 + delta_joint_pos.squeeze(-1)
# NOTE (stao): this method of IK is from https://mathweb.ucsd.edu/~sbuss/ResearchWeb/ikmethods/iksurvey.pdf by Samuel R. Buss
delta_joint_pos = torch.linalg.pinv(jacobian) @ action.unsqueeze(-1)
return q0 + delta_joint_pos.squeeze(-1)
else:
result, success, error = self.pmodel.compute_inverse_kinematics(
self.end_link_idx,
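
The `else` branch above is the classic Jacobian pseudo-inverse update from the Buss IK survey cited in the code. A standalone sketch of that single step, with illustrative names and shapes that are not part of ManiSkill's API:

```python
import torch

def delta_ik_step(jacobian: torch.Tensor, q0: torch.Tensor, delta_pose: torch.Tensor) -> torch.Tensor:
    """One pseudo-inverse IK step: q' = q0 + pinv(J) @ dx.

    jacobian:   (B, 6, dof) end-effector Jacobian at q0 (rows 0:3 position, 3:6 rotation);
                slice to (B, 3, dof) for position-only control
    q0:         (B, dof) current joint positions
    delta_pose: (B, 6) desired end-effector displacement, matching the Jacobian rows
    """
    delta_q = torch.linalg.pinv(jacobian) @ delta_pose.unsqueeze(-1)  # (B, dof, 1)
    return q0 + delta_q.squeeze(-1)
```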
13 changes: 13 additions & 0 deletions mani_skill/agents/robots/panda/panda.py
@@ -122,6 +122,18 @@ def _controller_configs(self):
ee_link=self.ee_link_name,
urdf_path=self.urdf_path,
)
arm_pd_ee_pose = PDEEPoseControllerConfig(
joint_names=self.arm_joint_names,
pos_lower=None,
pos_upper=None,
stiffness=self.arm_stiffness,
damping=self.arm_damping,
force_limit=self.arm_force_limit,
ee_link=self.ee_link_name,
urdf_path=self.urdf_path,
use_delta=False,
normalize_action=False,
)

arm_pd_ee_target_delta_pos = deepcopy(arm_pd_ee_delta_pos)
arm_pd_ee_target_delta_pos.use_target = True
@@ -180,6 +192,7 @@ def _controller_configs(self):
pd_ee_delta_pose=dict(
arm=arm_pd_ee_delta_pose, gripper=gripper_pd_joint_pos
),
pd_ee_pose=dict(arm=arm_pd_ee_pose, gripper=gripper_pd_joint_pos),
# TODO(jigu): how to add boundaries for the following controllers
pd_joint_target_delta_pos=dict(
arm=arm_pd_joint_target_delta_pos, gripper=gripper_pd_joint_pos
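
Once registered, the new absolute (non-delta) EE pose mode can be selected like any other control mode. A minimal sketch, assuming a standard Panda task; the environment ID and sampled action are illustrative:

```python
import gymnasium as gym

import mani_skill.envs  # noqa: F401

env = gym.make("PickCube-v1", robot_uids="panda", control_mode="pd_ee_pose")
obs, _ = env.reset(seed=0)
# Per the config above (use_delta=False, normalize_action=False), each action is an
# absolute target end-effector pose rather than a delta from the current pose.
action = env.action_space.sample()
obs, reward, terminated, truncated, info = env.step(action)
```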
50 changes: 48 additions & 2 deletions mani_skill/agents/robots/widowx/widowx.py
@@ -1,16 +1,62 @@
import numpy as np
import sapien
import torch

from mani_skill import ASSET_DIR
from mani_skill.agents.base_agent import BaseAgent
from mani_skill.agents.controllers import *
from mani_skill.agents.registration import register_agent
from mani_skill.sensors.camera import CameraConfig
from mani_skill.utils import common
from mani_skill.utils.structs.actor import Actor


# TODO (stao) (xuanlin): model it properly based on real2sim
@register_agent(asset_download_ids=["widowx250s"])
class WidowX250S(BaseAgent):
uid = "widowx250s"
urdf_path = f"{ASSET_DIR}/robots/widowx250s/wx250s.urdf"
urdf_path = f"{ASSET_DIR}/robots/widowx/wx250s.urdf"
urdf_config = dict()

arm_joint_names = [
"waist",
"shoulder",
"elbow",
"forearm_roll",
"wrist_angle",
"wrist_rotate",
]
gripper_joint_names = ["left_finger", "right_finger"]

def _after_loading_articulation(self):
self.finger1_link = self.robot.links_map["left_finger_link"]
self.finger2_link = self.robot.links_map["right_finger_link"]

def is_grasping(self, object: Actor, min_force=0.5, max_angle=85):
"""Check if the robot is grasping an object

Args:
object (Actor): The object to check if the robot is grasping
min_force (float, optional): Minimum contact force, in Newtons, before the robot is considered to be grasping the object. Defaults to 0.5.
max_angle (int, optional): Maximum angle of contact, in degrees, to consider grasping. Defaults to 85.
"""
l_contact_forces = self.scene.get_pairwise_contact_forces(
self.finger1_link, object
)
r_contact_forces = self.scene.get_pairwise_contact_forces(
self.finger2_link, object
)
lforce = torch.linalg.norm(l_contact_forces, axis=1)
rforce = torch.linalg.norm(r_contact_forces, axis=1)

# direction to open the gripper
ldirection = self.finger1_link.pose.to_transformation_matrix()[..., :3, 1]
rdirection = -self.finger2_link.pose.to_transformation_matrix()[..., :3, 1]
langle = common.compute_angle_between(ldirection, l_contact_forces)
rangle = common.compute_angle_between(rdirection, r_contact_forces)
lflag = torch.logical_and(
lforce >= min_force, torch.rad2deg(langle) <= max_angle
)
rflag = torch.logical_and(
rforce >= min_force, torch.rad2deg(rangle) <= max_angle
)
return torch.logical_and(lflag, rflag)
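
The grasp check hinges on `common.compute_angle_between`, the angle between each finger's opening direction and its measured contact force. A sketch of what that helper is assumed to compute (the implementation here is illustrative):

```python
import torch

def compute_angle_between(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Angle in radians between batched vectors a and b of shape (B, 3)."""
    a_n = a / torch.linalg.norm(a, dim=-1, keepdim=True).clamp(min=1e-8)
    b_n = b / torch.linalg.norm(b, dim=-1, keepdim=True).clamp(min=1e-8)
    dot = (a_n * b_n).sum(dim=-1).clamp(-1.0, 1.0)
    return torch.acos(dot)
```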
9 changes: 7 additions & 2 deletions mani_skill/envs/sapien_env.py
@@ -303,6 +303,7 @@ def __init__(
torch.zeros(self.num_envs, device=self.device, dtype=torch.int32)
)
obs, _ = self.reset(seed=2022, options=dict(reconfigure=True))

self._init_raw_obs = common.to_cpu_tensor(obs)
"""the raw observation returned by the env.reset (a cpu torch tensor/dict of tensors). Useful for future observation wrappers to use to auto generate observation spaces"""
self._init_raw_state = common.to_cpu_tensor(self.get_state_dict())
@@ -549,9 +550,9 @@ def _get_obs_with_sensor_data(self, info: Dict, apply_texture_transforms: bool =
)

@property
def robot_link_ids(self):
def robot_link_names(self):
"""Get link ids for the robot. This is used for segmentation observations."""
return self.agent.robot_link_ids
return self.agent.robot_link_names

# -------------------------------------------------------------------------- #
# Reward mode
@@ -810,6 +811,10 @@ def reset(self, seed=None, options=None):
self.scene._gpu_apply_all()
self.scene.px.gpu_update_articulation_kinematics()
self.scene._gpu_fetch_all()

# we reset controllers here because some controllers depend on the agent/articulation qpos/poses
self.agent.controller.reset()

obs = self.get_obs()

return obs, dict(reconfigure=reconfigure)
2 changes: 1 addition & 1 deletion mani_skill/envs/tasks/control/cartpole.py
@@ -63,7 +63,7 @@ def _load_articulation(self):
assert self.robot is not None, f"Fail to load URDF/MJCF from {asset_path}"

# Cache robot link names
self.robot_link_ids = [link.name for link in self.robot.get_links()]
self.robot_link_names = [link.name for link in self.robot.get_links()]


# @register_env("MS-CartPole-v1", max_episode_steps=500)
Expand Down