diff --git a/src/imitation/algorithms/bc.py b/src/imitation/algorithms/bc.py
index 29f12588d..2ea47a77d 100644
--- a/src/imitation/algorithms/bc.py
+++ b/src/imitation/algorithms/bc.py
@@ -325,7 +325,7 @@ def __init__(
         self._demo_data_loader: Optional[Iterable[types.TransitionMapping]] = None
         self.batch_size = batch_size
         self.minibatch_size = minibatch_size or batch_size
-        if self.batch_size % self.minibatch_size != 0:
+        if self.batch_size % self.minibatch_size != 0:  # pragma: no cover
             raise ValueError("Batch size must be a multiple of minibatch size.")
         super().__init__(
             demonstrations=demonstrations,
@@ -358,7 +358,7 @@ def __init__(
         assert self.policy.action_space == self.action_space

         if optimizer_kwargs:
-            if "weight_decay" in optimizer_kwargs:
+            if "weight_decay" in optimizer_kwargs:  # pragma: no cover
                 raise ValueError("Use the parameter l2_weight instead of weight_decay.")
         optimizer_kwargs = optimizer_kwargs or {}
         self.optimizer = optimizer_cls(
diff --git a/tests/algorithms/test_bc.py b/tests/algorithms/test_bc.py
index 8de49c66e..21349d72a 100644
--- a/tests/algorithms/test_bc.py
+++ b/tests/algorithms/test_bc.py
@@ -78,7 +78,6 @@ def make_bc_train_args(
     rng=rngs,
 )
 batch_sizes = st.integers(min_value=1, max_value=50)
-loggers = st.sampled_from([None, logger.configure()])
 expert_data_types = st.sampled_from(
     ["data_loader", "ducktyped_data_loader", "transitions"],
 )
@@ -95,17 +94,17 @@ def make_bc_train_args(
     log_rollouts_venv=st.one_of(rollout_envs, st.just(None)),
 )
 bc_args = st.builds(
-    lambda env, batch_size, custom_logger, rng: dict(
+    lambda env, minibatch_size, rng, minibatch_fraction: dict(
         observation_space=env.observation_space,
         action_space=env.action_space,
-        batch_size=batch_size,
-        custom_logger=custom_logger,
+        batch_size=minibatch_size * minibatch_fraction,
+        minibatch_size=minibatch_size,
         rng=rng,
     ),
     env=envs,
-    batch_size=batch_sizes,
-    custom_logger=loggers,
+    minibatch_size=batch_sizes,
     rng=rngs,
+    minibatch_fraction=st.integers(1, 10),
 )


@@ -136,7 +135,7 @@ def test_smoke_bc_creation(
         **bc_args,
         demonstrations=make_expert_transition_loader(
             cache_dir=cache.mkdir("experts"),
-            batch_size=bc_args["batch_size"],
+            batch_size=bc_args["minibatch_size"],
             expert_data_type=expert_data_type,
             env_name=env_name,
             rng=rng,
@@ -152,7 +151,21 @@
     expert_data_type=expert_data_types,
     rng=rngs,
 )
-@hypothesis.settings(deadline=20000, max_examples=15)
+@hypothesis.settings(
+    deadline=20000,
+    max_examples=15,
+    # TODO: one day consider removing this. For now we are good.
+    # Note: Hypothesis automatically generates input examples. The "size" of
+    # the examples is determined by the number of decisions it has to make when
+    # generating each example. E.g., a list of 100 random integers has a size of
+    # 100, but choosing one of three different lists of length 100 has a size of 1.
+    # If the number of choices becomes too large, we risk not properly covering the
+    # search space, and Hypothesis will complain. In this particular case we are not
+    # too concerned with covering the entire search space, so we suppress the warning.
+    # See here for more info:
+    # https://hypothesis.readthedocs.io/en/latest/settings.html#hypothesis.HealthCheck.data_too_large
+    suppress_health_check=[hypothesis.HealthCheck.data_too_large],
+)
 def test_smoke_bc_training(
     env_name: str,
     bc_args: dict,
@@ -168,7 +181,7 @@ def test_smoke_bc_training(
         **bc_args,
         demonstrations=make_expert_transition_loader(
             cache_dir=cache.mkdir("experts"),
-            batch_size=bc_args["batch_size"],
+            batch_size=bc_args["minibatch_size"],
             expert_data_type=expert_data_type,
             env_name=env_name,
             rng=rng,
@@ -246,7 +259,6 @@ def make_trainer(**kwargs: Any) -> bc.BC:
         action_space=cartpole_venv.action_space,
         batch_size=batch_size,
         demonstrations=demonstrations,
-        custom_logger=None,
         rng=rng,
         **kwargs,
    )
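Note: the `minibatch_size` parameter added here enables gradient accumulation: BC does a forward/backward pass per `minibatch_size` samples but only steps the optimizer once per `batch_size` samples, which is why `batch_size` must be an exact multiple. Below is a minimal sketch of the resulting constructor contract, assuming the `bc.BC` API on this branch; the spaces are illustrative stand-ins for a real environment's, and the demonstrations are omitted for brevity:

```python
import gym.spaces
import numpy as np

from imitation.algorithms import bc

rng = np.random.default_rng(0)

# Illustrative spaces standing in for a real environment's spaces.
obs_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(4,))
act_space = gym.spaces.Discrete(2)

trainer = bc.BC(
    observation_space=obs_space,
    action_space=act_space,
    demonstrations=None,  # demonstrations can also be supplied later
    batch_size=32,     # samples consumed per optimizer step
    minibatch_size=8,  # gradients accumulated over 32 // 8 = 4 passes
    rng=rng,
)

# By contrast, batch_size=32 with minibatch_size=5 would raise
# ValueError at construction time, since 32 % 5 != 0.
```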
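Similarly, the `data_too_large` suppression in the test decorator uses Hypothesis's standard settings mechanism rather than anything project-specific. A self-contained sketch of the same pattern, where the test body and strategy are arbitrary stand-ins:

```python
import hypothesis
import hypothesis.strategies as st


@hypothesis.settings(
    max_examples=15,
    # Without this, Hypothesis may flag tests whose examples require many
    # random choices to generate, as large composite strategies tend to.
    suppress_health_check=[hypothesis.HealthCheck.data_too_large],
)
@hypothesis.given(xs=st.lists(st.integers(), min_size=100, max_size=100))
def test_handles_large_examples(xs):
    assert len(xs) == 100
```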