Skip to content

Commit

Permalink
add fix in level loader and test
Browse files Browse the repository at this point in the history
  • Loading branch information
taufeeque9 committed Jun 27, 2024
1 parent ce439db commit c25428b
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 11 deletions.
19 changes: 12 additions & 7 deletions envpool/sokoban/level_loader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,11 @@ namespace sokoban {

LevelLoader::LevelLoader(const std::filesystem::path& base_path,
bool load_sequentially, int n_levels_to_load,
int verbose)
int env_id, int num_envs, int verbose)
: load_sequentially_(load_sequentially),
n_levels_to_load_(n_levels_to_load),
cur_level_(levels_.begin()),
num_envs_(num_envs),
cur_level_(env_id),
verbose(verbose) {
if (std::filesystem::is_regular_file(base_path)) {
level_file_paths_.push_back(base_path);
Expand All @@ -49,6 +50,10 @@ LevelLoader::LevelLoader(const std::filesystem::path& base_path,
});
}
cur_file_ = level_file_paths_.begin();
if (n_levels_to_load_ > 0 && n_levels_to_load_ % num_envs_ != 0) {
throw std::runtime_error(
"n_levels_to_load must be a multiple of num_envs.");
}
}

static const std::array<char, kMaxLevelObject + 1> kPrintLevelKey{
Expand Down Expand Up @@ -183,15 +188,15 @@ std::vector<SokobanLevel>::iterator LevelLoader::GetLevel(std::mt19937& gen) {
if (n_levels_to_load_ > 0 && levels_loaded_ >= n_levels_to_load_) {
throw std::runtime_error("Loaded all requested levels.");
}
if (cur_level_ == levels_.end()) {
while (cur_level_ >= levels_.size()) {
cur_level_ -= levels_.size();
LoadFile(gen);
cur_level_ = levels_.begin();
if (cur_level_ == levels_.end()) {
if (levels_.size() == 0) {
throw std::runtime_error("No levels loaded.");
}
}
auto out = cur_level_;
cur_level_++;
auto out = levels_.begin() + cur_level_;
cur_level_ += num_envs_;
levels_loaded_++;
return out;
}
Expand Down
8 changes: 5 additions & 3 deletions envpool/sokoban/level_loader.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,10 @@ class LevelLoader {
bool load_sequentially_;
int n_levels_to_load_;
int levels_loaded_{0};
int env_id_{0};
int num_envs_{1};
std::vector<SokobanLevel> levels_{0};
std::vector<SokobanLevel>::iterator cur_level_;
int cur_level_;
std::vector<std::filesystem::path> level_file_paths_{0};
std::vector<std::filesystem::path>::iterator cur_file_;
void LoadFile(std::mt19937& gen);
Expand All @@ -50,8 +52,8 @@ class LevelLoader {

std::vector<SokobanLevel>::iterator GetLevel(std::mt19937& gen);
explicit LevelLoader(const std::filesystem::path& base_path,
bool load_sequentially, int n_levels_to_load,
int verbose = 0);
bool load_sequentially, int n_levels_to_load, int env_id,
int num_envs, int verbose = 0);
};

void PrintLevel(std::ostream& os, const SokobanLevel& vec);
Expand Down
1 change: 1 addition & 0 deletions envpool/sokoban/sokoban_envpool.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ class SokobanEnv : public Env<SokobanEnvSpec> {
levels_dir_{static_cast<std::string>(spec.config["levels_dir"_])},
level_loader_(levels_dir_, spec.config["load_sequentially"_],
static_cast<int>(spec.config["n_levels_to_load"_]),
env_id, static_cast<int>(spec.config["num_envs"_]),
static_cast<int>(spec.config["verbose"_])),
world_(kWall, static_cast<std::size_t>(dim_room_ * dim_room_)),
verbose_(static_cast<int>(spec.config["verbose"_])),
Expand Down
55 changes: 54 additions & 1 deletion envpool/sokoban/sokoban_py_envpool_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@
import envpool # noqa: F401
import envpool.sokoban.registration
from envpool.sokoban.sokoban_envpool import _SokobanEnvSpec

from pathlib import Path
from typing import List

def test_config() -> None:
ref_config_keys = [
Expand Down Expand Up @@ -261,6 +262,58 @@ def test_solved_level_does_not_truncate(solve_on_time: bool):
_, _, term, trunc, _ = env.step(make_1d_array(wrong_action))
assert not term and not trunc, "Level should reset correctly"

def read_levels_file(fpath: Path) -> List[List[str]]:
maps = []
current_map = []
with open(fpath, "r") as sf:
for line in sf.readlines():
if ";" in line and current_map:
maps.append(current_map)
current_map = []
if "#" == line[0]:
current_map.append(line.strip())

maps.append(current_map)
return maps

def test_load_sequentially_with_multiple_envs() -> None:
levels_dir = "/app/envpool/sokoban/sample_levels"
files = glob.glob(f"{levels_dir}/*.txt")
levels_by_files = []
total_levels, num_envs = 8, 2
for file in sorted(files):
levels = read_levels_file(file)
levels_by_files.extend(levels)
assert len(levels_by_files) == total_levels, "8 levels stored in files."

env = envpool.make(
"Sokoban-v0",
env_type="gymnasium",
num_envs=num_envs,
batch_size=num_envs,
max_episode_steps=60,
min_episode_steps=60,
levels_dir=levels_dir,
load_sequentially=True,
n_levels_to_load=total_levels,
verbose=2,
)
dim_room = env.spec.config.dim_room
printed_obs = []
for _ in range(total_levels // num_envs):
obs, _ = env.reset()
assert obs.shape == (
num_envs,
3,
dim_room,
dim_room,
), f"obs shape: {obs.shape}"
for idx in range(num_envs):
printed_obs.append(print_obs(obs[idx]))
for i, level in enumerate(levels_by_files):
for j, line in enumerate(level):
assert printed_obs[i][j] == line, f"Level {i} is not loaded correctly."


def test_astar_log(tmp_path) -> None:
level_file_name = "/app/envpool/sokoban/sample_levels/small.txt"
Expand Down

0 comments on commit c25428b

Please sign in to comment.