Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expanded reset logic for Simulator #2451

Merged
merged 5 commits into from
Aug 20, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions src/esp/physics/PhysicsManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -264,11 +264,17 @@ class PhysicsManager : public std::enable_shared_from_this<PhysicsManager> {

/**
* @brief Reset the simulation and physical world.
* Sets the @ref worldTime_ to 0.0, does not change physical state.
* Sets the @ref worldTime_ to 0.0, changes the physical state of all objects back to their initial states. Only changes motion_type when scene_instance specified a motion type.
*/
virtual void reset() {
/* TODO: reset object states or clear them? Other? */
// reset object states from initial values (e.g. from scene instance)
worldTime_ = 0.0;
for (const auto& bro : existingObjects_) {
bro.second->resetStateFromSceneInstanceAttr();
}
for (const auto& bao : existingArticulatedObjects_) {
bao.second->resetStateFromSceneInstanceAttr();
}
}

/** @brief Stores references to a set of drawable elements. */
Expand Down
3 changes: 1 addition & 2 deletions src/esp/physics/RigidObject.h
Original file line number Diff line number Diff line change
Expand Up @@ -207,8 +207,7 @@ class RigidObject : public RigidBase {
VelocityControl::ptr getVelocityControl() { return velControl_; };

/**
* @brief Set the object's state from a @ref
* esp::metadata::attributes::SceneObjectInstanceAttributes
* @brief Set the object's state from a @ref esp::metadata::attributes::SceneObjectInstanceAttributes
*/
void resetStateFromSceneInstanceAttr() override;

Expand Down
27 changes: 23 additions & 4 deletions src/esp/physics/bullet/BulletArticulatedObject.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -279,10 +279,24 @@ void BulletArticulatedObject::resetStateFromSceneInstanceAttr() {
if (attrObjMotionType != physics::MotionType::UNDEFINED) {
setMotionType(attrObjMotionType);
}
// set initial joint positions
// get array of existing joint dofs

// first clear all joint positions
setJointPositions(std::vector<float>(size_t(btMultiBody_->getNumPosVars())));
// then add back the ones for unit quaternions
int posCount = 0;
float quat_init[] = {0, 0, 0, 1};
for (int i = 0; i < btMultiBody_->getNumLinks(); ++i) {
auto& link = btMultiBody_->getLink(i);
if (link.m_posVarCount == 4) {
// special handling for quaternions in spherical joints
btMultiBody_->setJointPosMultiDof(i, const_cast<float*>(quat_init));
posCount += link.m_posVarCount;
}
}

// set initial joint positions from instance config if applicable
std::vector<float> aoJointPose = getJointPositions();
// get instance-specified initial joint positions
// get instance-specified initial joint velocities
const auto& initJointPos = sceneObjInstanceAttr->getInitJointPose();
for (size_t i = 0; i < initJointPos.size(); ++i) {
if (i >= aoJointPose.size()) {
Expand All @@ -295,8 +309,9 @@ void BulletArticulatedObject::resetStateFromSceneInstanceAttr() {
}
setJointPositions(aoJointPose);

// first clear all joint vels
setJointVelocities(std::vector<float>(size_t(btMultiBody_->getNumDofs())));
// set initial joint velocities
// get array of existing joint vel dofs
std::vector<float> aoJointVels = getJointVelocities();
// get instance-specified initial joint velocities
std::vector<float> initJointVels =
Expand All @@ -311,6 +326,10 @@ void BulletArticulatedObject::resetStateFromSceneInstanceAttr() {
}
aoJointVels[i] = initJointVels[i];
}
setJointVelocities(aoJointVels);

// clear any forces
setJointForces(std::vector<float>(size_t(btMultiBody_->getNumDofs())));

} // BulletArticulatedObject::resetStateFromSceneInstanceAttr

Expand Down
4 changes: 3 additions & 1 deletion src/esp/sim/Simulator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -672,7 +672,9 @@ bool Simulator::instanceArticulatedObjectsForSceneAttributes(

void Simulator::reset() {
if (physicsManager_ != nullptr) {
// Note: only resets time to 0 by default.
// Note: resets time to 0 and all existing objects set back to initial
// states. Does not add back deleted objects or delete added objects. Does
// not break ManagedObject pointers.
physicsManager_->reset();
}

Expand Down
8 changes: 8 additions & 0 deletions src/esp/sim/Simulator.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,14 @@ class Simulator {

void reconfigure(const SimulatorConfiguration& cfg);

/**
* @brief Reset the simulation state including the state of all physics
* objects, agents, and the default light setup.
* Sets the @ref worldTime_ to 0.0, changes the physical state of all objects back to their initial states.
* Does not invalidate existing ManagedObject wrappers.
* Does not add or remove object instances.
* Only changes motion_type when scene_instance specified a motion type.
*/
void reset();

void seed(uint32_t newSeed);
Expand Down
15 changes: 7 additions & 8 deletions src/tests/PhysicsTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -648,13 +648,12 @@ void PhysicsTest::testRemoveSleepingSupport() {
cubes[0]->setMotionType(esp::physics::MotionType::STATIC);

for (int testCase = 0; testCase < 2; ++testCase) {
// reset time to 0, should not otherwise modify state
physicsManager_->reset();
float currentTime = physicsManager_->getWorldTime();
CORRADE_COMPARE_AS(physicsManager_->getNumRigidObjects(), 0,
Cr::TestSuite::Compare::Greater);

// simulate to stabilize the stack and populate collision islands
while (physicsManager_->getWorldTime() < 4.0) {
while (physicsManager_->getWorldTime() < currentTime + 4.0) {
physicsManager_->stepPhysics(0.1);
}

Expand Down Expand Up @@ -924,15 +923,15 @@ void PhysicsTest::testVelocityControl() {
objectWrapper->resetTransformation();
objectWrapper->setTranslation(Magnum::Vector3{0, 2.0, 0});
physicsManager_->setGravity({}); // 0 gravity interference
physicsManager_->reset(); // reset time to 0

// should closely follow kinematic result while uninhibited in 0 gravity
float targetTime = 0.5;
float currentTime = physicsManager_->getWorldTime();
esp::core::RigidState initialObjectState(objectWrapper->getRotation(),
objectWrapper->getTranslation());
esp::core::RigidState kinematicResult =
velControl->integrateTransform(targetTime, initialObjectState);
while (physicsManager_->getWorldTime() < targetTime) {
while (physicsManager_->getWorldTime() < currentTime + targetTime) {
physicsManager_->stepPhysics(physicsManager_->getTimestep());
}
CORRADE_COMPARE_AS(
Expand All @@ -946,7 +945,7 @@ void PhysicsTest::testVelocityControl() {

// should then get blocked by ground plane collision
targetTime = 2.0;
while (physicsManager_->getWorldTime() < targetTime) {
while (physicsManager_->getWorldTime() < currentTime + targetTime) {
physicsManager_->stepPhysics(physicsManager_->getTimestep());
}
CORRADE_COMPARE_AS(objectWrapper->getTranslation()[1], 1.0 - errorEps,
Expand All @@ -964,8 +963,8 @@ void PhysicsTest::testVelocityControl() {
velControl->linVelIsLocal = true;

targetTime = 10.0;
physicsManager_->reset(); // reset time to 0
while (physicsManager_->getWorldTime() < targetTime) {
float currentTime = physicsManager_->getWorldTime();
while (physicsManager_->getWorldTime() < currentTime + targetTime) {
physicsManager_->stepPhysics(physicsManager_->getTimestep());
}

Expand Down
11 changes: 11 additions & 0 deletions src_python/habitat_sim/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,17 @@ def reset(self, agent_ids: Optional[int] = None) -> ObservationDict:
def reset(
self, agent_ids: Union[Optional[int], List[int]] = None
) -> Union[ObservationDict, Dict[int, ObservationDict],]:
"""
Reset the simulation state including the state of all physics objects, agents, and the default light setup.
Sets the world time to 0.0, changes the physical state of all objects back to their initial states.
Does not invalidate existing ManagedObject wrappers.
Does not add or remove object instances.
Only changes motion_type when scene_instance specified a motion type.

:param agent_ids: An optional list of agent ids for which to return the sensor observations. If none is provide, default agent is used.

:return: Sensor observations in the reset state.
"""
super().reset()
for i in range(len(self.agents)):
self.reset_agent(i)
Expand Down
78 changes: 73 additions & 5 deletions tests/test_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,24 +74,92 @@ def test_sim_reset(make_cfg_settings):
mm = habitat_sim.metadata.MetadataMediator(hab_cfg.sim_cfg)
hab_cfg_mm.metadata_mediator = mm

def check_isclose(val1, val2):
return np.isclose(val1, val2, rtol=1e-4).all()

test_list = [hab_cfg, hab_cfg_mm]
for ctor_arg in test_list:
with habitat_sim.Simulator(ctor_arg) as sim:
agent_config = sim.config.agents[0]
sim.initialize_agent(0)
# cache agent initial state
initial_state = sim.agents[0].initial_state
# Take random steps in the environment
for _ in range(10):
action = random.choice(list(agent_config.action_space.keys()))
sim.step(action)

# add rigid and articulated objects
sim.metadata_mediator.ao_template_manager.load_configs(
"data/test_assets/urdf/"
)
sim.metadata_mediator.object_template_manager.load_configs(
"data/test_assets/objects/"
)

chair_handle = (
sim.metadata_mediator.object_template_manager.get_template_handles(
"chair"
)[0]
)
ao_handle = sim.metadata_mediator.ao_template_manager.get_template_handles(
"prism"
)[0]

ro = sim.get_rigid_object_manager().add_object_by_template_handle(
chair_handle
)
ao = sim.get_articulated_object_manager().add_articulated_object_by_template_handle(
ao_handle
)

assert ro is not None
assert ao is not None

# cache the initial state for verification
ao_initial_state = (
ao.transformation,
ao.joint_positions,
ao.joint_velocities,
)
ro_initial_state = ro.transformation

assert check_isclose(ro.transformation, ro_initial_state)
assert check_isclose(ao.transformation, ao_initial_state[0])
assert check_isclose(ao.joint_positions, ao_initial_state[1])
assert check_isclose(ao.joint_velocities, ao_initial_state[2])

ro.translation = mn.Vector3(1, 2, 3)
ro.rotation = mn.Quaternion.rotation(
mn.Rad(0.123), mn.Vector3(0.1, 0.2, 0.3).normalized()
)
ao.translation = mn.Vector3(3, 2, 1)
ao.rotation = mn.Quaternion.rotation(
mn.Rad(0.321), mn.Vector3(0.3, 0.2, 0.1).normalized()
)
ao.joint_positions = np.array(ao_initial_state[1]) * 0.2
ao.joint_velocities = np.ones(len(ao.joint_velocities))

assert not check_isclose(ro.transformation, ro_initial_state)
assert not check_isclose(ao.transformation, ao_initial_state[0])
assert not check_isclose(ao.joint_positions, ao_initial_state[1])
assert not check_isclose(ao.joint_velocities, ao_initial_state[2])

# do the reset
sim.reset()

# validate agent state reset
new_state = sim.agents[0].get_state()
same_position = all(initial_state.position == new_state.position)
same_rotation = np.isclose(
initial_state.rotation, new_state.rotation, rtol=1e-4
) # Numerical error can cause slight deviations
assert same_position and same_rotation
# NOTE: Numerical error can cause slight deviations, use isclose
# assert same agent position and rotation
assert check_isclose(initial_state.position, new_state.position)
assert check_isclose(initial_state.rotation, new_state.rotation)

# validate object state resets
assert check_isclose(ro.transformation, ro_initial_state)
assert check_isclose(ao.transformation, ao_initial_state[0])
assert check_isclose(ao.joint_positions, ao_initial_state[1])
assert check_isclose(ao.joint_velocities, ao_initial_state[2])


def test_sim_multiagent_move_and_reset(make_cfg_settings, num_agents=10):
Expand Down