From 1abbeb821b1ce4426f3ee2a552f094a2df0cff4e Mon Sep 17 00:00:00 2001 From: taufeeque9 <9taufeeque9@gmail.com> Date: Tue, 28 May 2024 04:55:55 +0530 Subject: [PATCH 1/4] fix delayed reset bug --- envpool/sokoban/sokoban_envpool.cc | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/envpool/sokoban/sokoban_envpool.cc b/envpool/sokoban/sokoban_envpool.cc index c4e6affc..c7e6219b 100644 --- a/envpool/sokoban/sokoban_envpool.cc +++ b/envpool/sokoban/sokoban_envpool.cc @@ -142,7 +142,11 @@ void SokobanEnv::Step(const Action& action_dict) { reward_box_ * static_cast(prev_unmatched_boxes - unmatched_boxes_) + ((unmatched_boxes_ == 0) ? reward_finished_ : 0.0f); - WriteState(static_cast(reward)); + if (IsDone()) { + Reset(); + } else { + WriteState(static_cast(reward)); + } } constexpr std::array, kPlayerOnTarget + 1> kTinyColors{{ From 268c93dfdda32fedb9dc671a84617c4de1985577 Mon Sep 17 00:00:00 2001 From: taufeeque9 <9taufeeque9@gmail.com> Date: Tue, 28 May 2024 05:49:18 +0530 Subject: [PATCH 2/4] update the reset function --- envpool/sokoban/sokoban_envpool.cc | 17 +++++++++++------ envpool/sokoban/sokoban_envpool.h | 1 + 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/envpool/sokoban/sokoban_envpool.cc b/envpool/sokoban/sokoban_envpool.cc index c7e6219b..5e5db8c6 100644 --- a/envpool/sokoban/sokoban_envpool.cc +++ b/envpool/sokoban/sokoban_envpool.cc @@ -24,7 +24,7 @@ namespace sokoban { -void SokobanEnv::Reset() { +void SokobanEnv::ResetWithoutWrite() { const int max_episode_steps = spec_.config["max_episode_steps"_]; const int min_episode_steps = spec_.config["min_episode_steps"_]; current_max_episode_steps_ = @@ -52,6 +52,10 @@ void SokobanEnv::Reset() { } } current_step_ = 0; +} + +void SokobanEnv::Reset() { + ResetWithoutWrite(); WriteState(0.0f); } @@ -142,11 +146,8 @@ void SokobanEnv::Step(const Action& action_dict) { reward_box_ * static_cast(prev_unmatched_boxes - unmatched_boxes_) + ((unmatched_boxes_ == 0) ? reward_finished_ : 0.0f); - if (IsDone()) { - Reset(); - } else { - WriteState(static_cast(reward)); - } + + WriteState(static_cast(reward)); } constexpr std::array, kPlayerOnTarget + 1> kTinyColors{{ @@ -181,6 +182,10 @@ void SokobanEnv::WriteState(float reward) { throw std::runtime_error(msg.str()); } + if (IsDone()) { + ResetWithoutWrite(); + } + std::vector out(3 * world_.size()); for (int rgb = 0; rgb < 3; rgb++) { for (size_t i = 0; i < world_.size(); i++) { diff --git a/envpool/sokoban/sokoban_envpool.h b/envpool/sokoban/sokoban_envpool.h index a76951ef..f0138b20 100644 --- a/envpool/sokoban/sokoban_envpool.h +++ b/envpool/sokoban/sokoban_envpool.h @@ -115,6 +115,7 @@ class SokobanEnv : public Env { [[nodiscard]] uint8_t WorldAt(int x, int y) const; void WorldAssignAt(int x, int y, uint8_t value); + void ResetWithoutWrite(); }; using SokobanEnvPool = AsyncEnvPool; From 4098670966fa1e2b886943434c3ad119166102bb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A0=20Garriga-Alonso?= Date: Mon, 27 May 2024 22:38:52 -0400 Subject: [PATCH 3/4] Explain why ResetWithoutWrite with comment --- envpool/sokoban/sokoban_envpool.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/envpool/sokoban/sokoban_envpool.cc b/envpool/sokoban/sokoban_envpool.cc index 5e5db8c6..94e874b6 100644 --- a/envpool/sokoban/sokoban_envpool.cc +++ b/envpool/sokoban/sokoban_envpool.cc @@ -183,6 +183,7 @@ void SokobanEnv::WriteState(float reward) { } if (IsDone()) { + // If this episode truncates or terminates, the observation should be the one for the next episode. ResetWithoutWrite(); } From 6b1b577d883ec50acebda31a5166906f3b449f81 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A0=20Garriga-Alonso?= Date: Mon, 27 May 2024 23:45:06 -0400 Subject: [PATCH 4/4] Cap line to 80 --- envpool/sokoban/sokoban_envpool.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/envpool/sokoban/sokoban_envpool.cc b/envpool/sokoban/sokoban_envpool.cc index 94e874b6..8bbbea2c 100644 --- a/envpool/sokoban/sokoban_envpool.cc +++ b/envpool/sokoban/sokoban_envpool.cc @@ -183,7 +183,8 @@ void SokobanEnv::WriteState(float reward) { } if (IsDone()) { - // If this episode truncates or terminates, the observation should be the one for the next episode. + // If this episode truncates or terminates, the observation should be the + // one for the next episode. ResetWithoutWrite(); }