From c25428b1bed59b5617834141c6c75819431f16ec Mon Sep 17 00:00:00 2001 From: taufeeque9 <9taufeeque9@gmail.com> Date: Fri, 28 Jun 2024 03:34:24 +0530 Subject: [PATCH] add fix in level loader and test --- envpool/sokoban/level_loader.cc | 19 +++++--- envpool/sokoban/level_loader.h | 8 ++-- envpool/sokoban/sokoban_envpool.h | 1 + envpool/sokoban/sokoban_py_envpool_test.py | 55 +++++++++++++++++++++- 4 files changed, 72 insertions(+), 11 deletions(-) diff --git a/envpool/sokoban/level_loader.cc b/envpool/sokoban/level_loader.cc index 56c6fdba..faed8197 100644 --- a/envpool/sokoban/level_loader.cc +++ b/envpool/sokoban/level_loader.cc @@ -29,10 +29,11 @@ namespace sokoban { LevelLoader::LevelLoader(const std::filesystem::path& base_path, bool load_sequentially, int n_levels_to_load, - int verbose) + int env_id, int num_envs, int verbose) : load_sequentially_(load_sequentially), n_levels_to_load_(n_levels_to_load), - cur_level_(levels_.begin()), + num_envs_(num_envs), + cur_level_(env_id), verbose(verbose) { if (std::filesystem::is_regular_file(base_path)) { level_file_paths_.push_back(base_path); @@ -49,6 +50,10 @@ LevelLoader::LevelLoader(const std::filesystem::path& base_path, }); } cur_file_ = level_file_paths_.begin(); + if (n_levels_to_load_ > 0 && n_levels_to_load_ % num_envs_ != 0) { + throw std::runtime_error( + "n_levels_to_load must be a multiple of num_envs."); + } } static const std::array kPrintLevelKey{ @@ -183,15 +188,15 @@ std::vector::iterator LevelLoader::GetLevel(std::mt19937& gen) { if (n_levels_to_load_ > 0 && levels_loaded_ >= n_levels_to_load_) { throw std::runtime_error("Loaded all requested levels."); } - if (cur_level_ == levels_.end()) { + while (cur_level_ >= levels_.size()) { + cur_level_ -= levels_.size(); LoadFile(gen); - cur_level_ = levels_.begin(); - if (cur_level_ == levels_.end()) { + if (levels_.size() == 0) { throw std::runtime_error("No levels loaded."); } } - auto out = cur_level_; - cur_level_++; + auto out = levels_.begin() + cur_level_; + cur_level_ += num_envs_; levels_loaded_++; return out; } diff --git a/envpool/sokoban/level_loader.h b/envpool/sokoban/level_loader.h index ced5e60a..d8a07c16 100644 --- a/envpool/sokoban/level_loader.h +++ b/envpool/sokoban/level_loader.h @@ -39,8 +39,10 @@ class LevelLoader { bool load_sequentially_; int n_levels_to_load_; int levels_loaded_{0}; + int env_id_{0}; + int num_envs_{1}; std::vector levels_{0}; - std::vector::iterator cur_level_; + int cur_level_; std::vector level_file_paths_{0}; std::vector::iterator cur_file_; void LoadFile(std::mt19937& gen); @@ -50,8 +52,8 @@ class LevelLoader { std::vector::iterator GetLevel(std::mt19937& gen); explicit LevelLoader(const std::filesystem::path& base_path, - bool load_sequentially, int n_levels_to_load, - int verbose = 0); + bool load_sequentially, int n_levels_to_load, int env_id, + int num_envs, int verbose = 0); }; void PrintLevel(std::ostream& os, const SokobanLevel& vec); diff --git a/envpool/sokoban/sokoban_envpool.h b/envpool/sokoban/sokoban_envpool.h index f0138b20..d2cd597d 100644 --- a/envpool/sokoban/sokoban_envpool.h +++ b/envpool/sokoban/sokoban_envpool.h @@ -70,6 +70,7 @@ class SokobanEnv : public Env { levels_dir_{static_cast(spec.config["levels_dir"_])}, level_loader_(levels_dir_, spec.config["load_sequentially"_], static_cast(spec.config["n_levels_to_load"_]), + env_id, static_cast(spec.config["num_envs"_]), static_cast(spec.config["verbose"_])), world_(kWall, static_cast(dim_room_ * dim_room_)), verbose_(static_cast(spec.config["verbose"_])), diff --git a/envpool/sokoban/sokoban_py_envpool_test.py b/envpool/sokoban/sokoban_py_envpool_test.py index 198ff34a..969eb871 100644 --- a/envpool/sokoban/sokoban_py_envpool_test.py +++ b/envpool/sokoban/sokoban_py_envpool_test.py @@ -25,7 +25,8 @@ import envpool # noqa: F401 import envpool.sokoban.registration from envpool.sokoban.sokoban_envpool import _SokobanEnvSpec - +from pathlib import Path +from typing import List def test_config() -> None: ref_config_keys = [ @@ -261,6 +262,58 @@ def test_solved_level_does_not_truncate(solve_on_time: bool): _, _, term, trunc, _ = env.step(make_1d_array(wrong_action)) assert not term and not trunc, "Level should reset correctly" +def read_levels_file(fpath: Path) -> List[List[str]]: + maps = [] + current_map = [] + with open(fpath, "r") as sf: + for line in sf.readlines(): + if ";" in line and current_map: + maps.append(current_map) + current_map = [] + if "#" == line[0]: + current_map.append(line.strip()) + + maps.append(current_map) + return maps + +def test_load_sequentially_with_multiple_envs() -> None: + levels_dir = "/app/envpool/sokoban/sample_levels" + files = glob.glob(f"{levels_dir}/*.txt") + levels_by_files = [] + total_levels, num_envs = 8, 2 + for file in sorted(files): + levels = read_levels_file(file) + levels_by_files.extend(levels) + assert len(levels_by_files) == total_levels, "8 levels stored in files." + + env = envpool.make( + "Sokoban-v0", + env_type="gymnasium", + num_envs=num_envs, + batch_size=num_envs, + max_episode_steps=60, + min_episode_steps=60, + levels_dir=levels_dir, + load_sequentially=True, + n_levels_to_load=total_levels, + verbose=2, + ) + dim_room = env.spec.config.dim_room + printed_obs = [] + for _ in range(total_levels // num_envs): + obs, _ = env.reset() + assert obs.shape == ( + num_envs, + 3, + dim_room, + dim_room, + ), f"obs shape: {obs.shape}" + for idx in range(num_envs): + printed_obs.append(print_obs(obs[idx])) + for i, level in enumerate(levels_by_files): + for j, line in enumerate(level): + assert printed_obs[i][j] == line, f"Level {i} is not loaded correctly." + def test_astar_log(tmp_path) -> None: level_file_name = "/app/envpool/sokoban/sample_levels/small.txt"