High-level API: Fix number of test episodes being incorrectly scaled by number of envs (#1071)
opcode81 authored Mar 7, 2024
1 parent 6746a80 commit 1714c7f
Showing 4 changed files with 4 additions and 12 deletions.
2 changes: 1 addition & 1 deletion docs/04_contributing/05_contributors.rst
@@ -2,7 +2,7 @@ Contributors
 ============
 
 We always welcome contributions to help make Tianshou better!
-Tiashou was originally created by the `THU-ML Group <https://ml.cs.tsinghua.edu.cn>`_ at Tsinghua University.
+Tianshou was originally created by the `THU-ML Group <https://ml.cs.tsinghua.edu.cn>`_ at Tsinghua University.
 
 Today, it is backed by the `appliedAI Institute for Europe <https://www.appliedai-institute.de/en/>`_,
 which is committed to making Tianshou the go-to resource for reinforcement learning research and development,
1 change: 1 addition & 0 deletions docs/spelling_wordlist.txt
@@ -253,3 +253,4 @@ Dominik
 Tsinghua
 Tianshou
 appliedAI
+Panchenko
4 changes: 2 additions & 2 deletions tianshou/highlevel/agent.py
@@ -184,7 +184,7 @@ def create_trainer(
             max_epoch=sampling_config.num_epochs,
             step_per_epoch=sampling_config.step_per_epoch,
             repeat_per_collect=sampling_config.repeat_per_collect,
-            episode_per_test=sampling_config.num_test_episodes_per_test_env,
+            episode_per_test=sampling_config.num_test_episodes,
             batch_size=sampling_config.batch_size,
             step_per_collect=sampling_config.step_per_collect,
             save_best_fn=policy_persistence.get_save_best_fn(world),
@@ -228,7 +228,7 @@ def create_trainer(
             max_epoch=sampling_config.num_epochs,
             step_per_epoch=sampling_config.step_per_epoch,
             step_per_collect=sampling_config.step_per_collect,
-            episode_per_test=sampling_config.num_test_episodes_per_test_env,
+            episode_per_test=sampling_config.num_test_episodes,
             batch_size=sampling_config.batch_size,
             save_best_fn=policy_persistence.get_save_best_fn(world),
             logger=world.logger,
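
Note on the two hunks above: the trainer's episode_per_test argument receives the total number of test episodes to collect in each test step (across all test environments, per the num_test_episodes docstring in config.py below), so the high-level API now passes the configured value through unchanged instead of the per-env quotient computed by the removed property. A minimal sketch of the numerical difference, using hypothetical values that are not part of this commit:

import math

num_test_episodes = 10  # configured total number of test episodes
num_test_envs = 4       # number of parallel test environments

# Before this commit: episode_per_test received the per-env quotient, rounded up -> 3
old_episode_per_test = math.ceil(num_test_episodes / num_test_envs)

# After this commit: episode_per_test receives the configured total unchanged -> 10
new_episode_per_test = num_test_episodes
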
9 changes: 0 additions & 9 deletions tianshou/highlevel/config.py
@@ -1,4 +1,3 @@
-import math
 import multiprocessing
 from dataclasses import dataclass

@@ -9,7 +8,6 @@
 class SamplingConfig(ToStringMixin):
     """Configuration of sampling, epochs, parallelization, buffers, collectors, and batching."""
 
-    # TODO: What are the most reasonable defaults?
     num_epochs: int = 100
     """
     the number of epochs to run training for. An epoch is the outermost iteration level and each
@@ -55,8 +53,6 @@ class SamplingConfig(ToStringMixin):

     num_test_episodes: int = 1
     """the total number of episodes to collect in each test step (across all test environments).
-    This should be a multiple of the number of test environments; if it is not, the effective
-    number of episodes collected will be the nearest multiple (rounded up).
     """
 
     buffer_size: int = 4096
@@ -129,8 +125,3 @@ class SamplingConfig(ToStringMixin):
     def __post_init__(self) -> None:
         if self.num_train_envs == -1:
             self.num_train_envs = multiprocessing.cpu_count()
-
-    @property
-    def num_test_episodes_per_test_env(self) -> int:
-        """:return: the number of episodes to collect per test environment in every test step"""
-        return math.ceil(self.num_test_episodes / self.num_test_envs)
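
With the num_test_episodes_per_test_env property removed, num_test_episodes alone determines how many episodes a test step collects, while num_test_envs affects only how that collection is parallelized. A hypothetical usage sketch (the field values are illustrative and not taken from this commit; other SamplingConfig fields keep their defaults):

from tianshou.highlevel.config import SamplingConfig

config = SamplingConfig(
    num_epochs=100,
    num_test_episodes=10,  # total episodes collected in each test step
    num_test_envs=5,       # parallel test environments; no longer rescales the episode count
)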
