// Patch against envpool/sokoban/sokoban_envpool.cc and sokoban_envpool.h
// (git unified diff; in this copy all hunks are collapsed onto one line).
//
// What the hunks show:
//  * The existing SokobanEnv::Reset() definition is renamed to
//    SokobanEnv::ResetWithoutWrite(); it now ends right after
//    `current_step_ = 0;`.
//  * A new SokobanEnv::Reset() is added that calls ResetWithoutWrite() and
//    then WriteState(0.0f).
//  * SokobanEnv::Step(): one line is added next to the
//    WriteState(static_cast(reward)) call. With line breaks lost it is
//    ambiguous whether the added line is a blank line or the call itself --
//    TODO confirm against the upstream patch.
//  * SokobanEnv::WriteState(float reward): after the runtime_error check and
//    before the `out` buffer is built, ResetWithoutWrite() is now called when
//    IsDone() is true, so that (per the in-patch comment) the observation is
//    the one for the next episode.
//  * sokoban_envpool.h: ResetWithoutWrite() is declared in class SokobanEnv,
//    after WorldAssignAt().
//
// NOTE(review): `static_cast(`, `std::vector out`, `public Env` and
// `AsyncEnvPool;` carry no template argument lists here -- presumably
// stripped during extraction; restore them from the upstream patch before
// applying.
// NOTE(review): ResetWithoutWrite() sets current_step_ to 0 inside
// WriteState(); IsDone() and the rest of WriteState() are not visible here.
// Confirm the done/truncated flags are captured before the reset.
diff --git a/envpool/sokoban/sokoban_envpool.cc b/envpool/sokoban/sokoban_envpool.cc index c4e6affc..8bbbea2c 100644 --- a/envpool/sokoban/sokoban_envpool.cc +++ b/envpool/sokoban/sokoban_envpool.cc @@ -24,7 +24,7 @@ namespace sokoban { -void SokobanEnv::Reset() { +void SokobanEnv::ResetWithoutWrite() { const int max_episode_steps = spec_.config["max_episode_steps"_]; const int min_episode_steps = spec_.config["min_episode_steps"_]; current_max_episode_steps_ = @@ -52,6 +52,10 @@ void SokobanEnv::Reset() { } } current_step_ = 0; +} + +void SokobanEnv::Reset() { + ResetWithoutWrite(); WriteState(0.0f); } @@ -142,6 +146,7 @@ void SokobanEnv::Step(const Action& action_dict) { reward_box_ * static_cast(prev_unmatched_boxes - unmatched_boxes_) + ((unmatched_boxes_ == 0) ? reward_finished_ : 0.0f); + WriteState(static_cast(reward)); } @@ -177,6 +182,12 @@ void SokobanEnv::WriteState(float reward) { throw std::runtime_error(msg.str()); } + if (IsDone()) { + // If this episode truncates or terminates, the observation should be the + // one for the next episode. + ResetWithoutWrite(); + } + std::vector out(3 * world_.size()); for (int rgb = 0; rgb < 3; rgb++) { for (size_t i = 0; i < world_.size(); i++) { diff --git a/envpool/sokoban/sokoban_envpool.h b/envpool/sokoban/sokoban_envpool.h index a76951ef..f0138b20 100644 --- a/envpool/sokoban/sokoban_envpool.h +++ b/envpool/sokoban/sokoban_envpool.h @@ -115,6 +115,7 @@ class SokobanEnv : public Env { [[nodiscard]] uint8_t WorldAt(int x, int y) const; void WorldAssignAt(int x, int y, uint8_t value); + void ResetWithoutWrite(); }; using SokobanEnvPool = AsyncEnvPool;