Skip to content

Commit

Permalink
Apply sensai-utils, removing copied code (#1199)
Browse files Browse the repository at this point in the history
Remove the code copied from sensAI, using the new library sensAI-utils
instead (which does not add any additional dependencies to Tianshou)

- [X] I have added the correct label(s) to this Pull Request or linked
the relevant issue(s)
- [X] I have provided a description of the changes in this Pull Request
- [X] I have added documentation for my changes and have listed relevant
changes in CHANGELOG.md
- [X] If applicable, I have added tests to cover my changes.
- [X] I have reformatted the code using `poe format` 
- [X] I have checked style and types with `poe lint` and `poe
type-check`
- [ ] (Optional) I ran tests locally with `poe test` 
(or a subset of them with `poe test-reduced`), and they pass
- [ ] (Optional) I have tested that documentation builds correctly with
`poe doc-build`
  • Loading branch information
MischaPanch authored Aug 10, 2024
2 parents d330056 + fdfefb8 commit f5d2ae6
Show file tree
Hide file tree
Showing 43 changed files with 95 additions and 862 deletions.
9 changes: 7 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
- Added new `VectorEnvType` called `SUBPROC_SHARED_MEM_AUTO` and used it for Atari and Mujoco venv creation. #1141
- Loggers can now restore the logged data into python by using the new `restore_logged_data` method. #1074
- Wandb logger extended #1183
- `utils`:
- `utils`:
- `net.continuous.Critic`:
- Add flag `apply_preprocess_net_to_obs_only` to allow the
preprocessing network to be applied to the observations only (without
Expand Down Expand Up @@ -104,7 +104,12 @@ instead of just `nn.Module`. #1032
- VectorEnvs now return an array of info-dicts on reset instead of a list. #1063
- Changed interface of `dist_fn` in `PGPolicy` and all subclasses to take a single argument in both
continuous and discrete cases. #1032
- `utils.net.common.Recurrent` now receives and returns a `RecurrentStateBatch` instead of a dict. #1077
- `utils`:
- Modules with code that was copied from sensAI have been replaced by imports from the new dependency sensAI-utils:
- `tianshou.utils.logging` is replaced with `sensai.util.logging`
- `tianshou.utils.string` is replaced with `sensai.util.string`
- `tianshou.utils.pickle` is replaced with `sensai.util.pickle`
- `utils.net.common.Recurrent` now receives and returns a `RecurrentStateBatch` instead of a dict. #1077
- `AtariEnvFactory` constructor (in examples, so not really breaking) now requires explicit train and test seeds. #1074
- `EnvFactoryRegistered` now requires an explicit `test_seed` in the constructor. #1074
- `highlevel`:
Expand Down
5 changes: 3 additions & 2 deletions examples/atari/atari_dqn_hl.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

import os

from sensai.util import logging
from sensai.util.logging import datetime_tag

from examples.atari.atari_network import (
IntermediateModuleFactoryAtariDQN,
IntermediateModuleFactoryAtariDQNFeatures,
Expand All @@ -20,8 +23,6 @@
EpochTestCallbackDQNSetEps,
EpochTrainCallbackDQNEpsLinearDecay,
)
from tianshou.utils import logging
from tianshou.utils.logging import datetime_tag


def main(
Expand Down
5 changes: 3 additions & 2 deletions examples/atari/atari_iqn_hl.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
import os
from collections.abc import Sequence

from sensai.util import logging
from sensai.util.logging import datetime_tag

from examples.atari.atari_network import (
IntermediateModuleFactoryAtariDQN,
)
Expand All @@ -17,8 +20,6 @@
EpochTestCallbackDQNSetEps,
EpochTrainCallbackDQNEpsLinearDecay,
)
from tianshou.utils import logging
from tianshou.utils.logging import datetime_tag


def main(
Expand Down
5 changes: 3 additions & 2 deletions examples/atari/atari_ppo_hl.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
import os
from collections.abc import Sequence

from sensai.util import logging
from sensai.util.logging import datetime_tag

from examples.atari.atari_network import (
ActorFactoryAtariDQN,
IntermediateModuleFactoryAtariDQNFeatures,
Expand All @@ -18,8 +21,6 @@
from tianshou.highlevel.params.policy_wrapper import (
PolicyWrapperFactoryIntrinsicCuriosity,
)
from tianshou.utils import logging
from tianshou.utils.logging import datetime_tag


def main(
Expand Down
5 changes: 3 additions & 2 deletions examples/atari/atari_sac_hl.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
import os
from collections.abc import Sequence

from sensai.util import logging
from sensai.util.logging import datetime_tag

from examples.atari.atari_network import (
ActorFactoryAtariDQN,
IntermediateModuleFactoryAtariDQNFeatures,
Expand All @@ -18,8 +21,6 @@
from tianshou.highlevel.params.policy_wrapper import (
PolicyWrapperFactoryIntrinsicCuriosity,
)
from tianshou.utils import logging
from tianshou.utils.logging import datetime_tag


def main(
Expand Down
3 changes: 2 additions & 1 deletion examples/discrete/discrete_dqn_hl.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from sensai.util.logging import run_main

from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.env import (
EnvFactoryRegistered,
Expand All @@ -10,7 +12,6 @@
EpochTestCallbackDQNSetEps,
EpochTrainCallbackDQNSetEps,
)
from tianshou.utils.logging import run_main


def main() -> None:
Expand Down
4 changes: 2 additions & 2 deletions examples/mujoco/mujoco_a2c_hl.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from collections.abc import Sequence
from typing import Literal

from sensai.util import logging
from sensai.util.logging import datetime_tag
from torch import nn

from examples.mujoco.mujoco_env import MujocoEnvFactory
Expand All @@ -15,8 +17,6 @@
from tianshou.highlevel.optim import OptimizerFactoryRMSprop
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
from tianshou.highlevel.params.policy_params import A2CParams
from tianshou.utils import logging
from tianshou.utils.logging import datetime_tag


def main(
Expand Down
5 changes: 3 additions & 2 deletions examples/mujoco/mujoco_ddpg_hl.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
import os
from collections.abc import Sequence

from sensai.util import logging
from sensai.util.logging import datetime_tag

from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.experiment import (
Expand All @@ -11,8 +14,6 @@
)
from tianshou.highlevel.params.noise import MaxActionScaledGaussian
from tianshou.highlevel.params.policy_params import DDPGParams
from tianshou.utils import logging
from tianshou.utils.logging import datetime_tag


def main(
Expand Down
4 changes: 2 additions & 2 deletions examples/mujoco/mujoco_npg_hl.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from typing import Literal

import torch
from sensai.util import logging
from sensai.util.logging import datetime_tag

from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.highlevel.config import SamplingConfig
Expand All @@ -14,8 +16,6 @@
)
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
from tianshou.highlevel.params.policy_params import NPGParams
from tianshou.utils import logging
from tianshou.utils.logging import datetime_tag


def main(
Expand Down
4 changes: 2 additions & 2 deletions examples/mujoco/mujoco_ppo_hl.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from typing import Literal

import torch
from sensai.util import logging
from sensai.util.logging import datetime_tag

from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.highlevel.config import SamplingConfig
Expand All @@ -14,8 +16,6 @@
)
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
from tianshou.highlevel.params.policy_params import PPOParams
from tianshou.utils import logging
from tianshou.utils.logging import datetime_tag


def main(
Expand Down
4 changes: 2 additions & 2 deletions examples/mujoco/mujoco_ppo_hl_multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import warnings

import torch
from sensai.util import logging
from sensai.util.logging import datetime_tag

from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.evaluation.launcher import RegisteredExpLauncher
Expand All @@ -28,8 +30,6 @@
from tianshou.highlevel.logger import LoggerFactoryDefault
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
from tianshou.highlevel.params.policy_params import PPOParams
from tianshou.utils import logging
from tianshou.utils.logging import datetime_tag

log = logging.getLogger(__name__)

Expand Down
5 changes: 3 additions & 2 deletions examples/mujoco/mujoco_redq_hl.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
from collections.abc import Sequence
from typing import Literal

from sensai.util import logging
from sensai.util.logging import datetime_tag

from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.experiment import (
Expand All @@ -12,8 +15,6 @@
)
from tianshou.highlevel.params.alpha import AutoAlphaFactoryDefault
from tianshou.highlevel.params.policy_params import REDQParams
from tianshou.utils import logging
from tianshou.utils.logging import datetime_tag


def main(
Expand Down
4 changes: 2 additions & 2 deletions examples/mujoco/mujoco_reinforce_hl.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from typing import Literal

import torch
from sensai.util import logging
from sensai.util.logging import datetime_tag

from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.highlevel.config import SamplingConfig
Expand All @@ -14,8 +16,6 @@
)
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
from tianshou.highlevel.params.policy_params import PGParams
from tianshou.utils import logging
from tianshou.utils.logging import datetime_tag


def main(
Expand Down
5 changes: 3 additions & 2 deletions examples/mujoco/mujoco_sac_hl.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
import os
from collections.abc import Sequence

from sensai.util import logging
from sensai.util.logging import datetime_tag

from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.experiment import (
Expand All @@ -11,8 +14,6 @@
)
from tianshou.highlevel.params.alpha import AutoAlphaFactoryDefault
from tianshou.highlevel.params.policy_params import SACParams
from tianshou.utils import logging
from tianshou.utils.logging import datetime_tag


def main(
Expand Down
4 changes: 2 additions & 2 deletions examples/mujoco/mujoco_td3_hl.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from collections.abc import Sequence

import torch
from sensai.util import logging
from sensai.util.logging import datetime_tag

from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.highlevel.config import SamplingConfig
Expand All @@ -16,8 +18,6 @@
MaxActionScaledGaussian,
)
from tianshou.highlevel.params.policy_params import TD3Params
from tianshou.utils import logging
from tianshou.utils.logging import datetime_tag


def main(
Expand Down
4 changes: 2 additions & 2 deletions examples/mujoco/mujoco_trpo_hl.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from typing import Literal

import torch
from sensai.util import logging
from sensai.util.logging import datetime_tag

from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.highlevel.config import SamplingConfig
Expand All @@ -14,8 +16,6 @@
)
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
from tianshou.highlevel.params.policy_params import TRPOParams
from tianshou.utils import logging
from tianshou.utils.logging import datetime_tag


def main(
Expand Down
18 changes: 16 additions & 2 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ overrides = "^7.4.0"
packaging = "*"
pandas = ">=2.0.0"
pettingzoo = "^1.22"
sensai-utils = "^1.2.1"
tensorboard = "^2.5.0"
# Torch 2.0.1 causes problems, see https://github.com/pytorch/pytorch/issues/100974
torch = "^2.0.0, !=2.0.1, !=2.1.0"
Expand Down
3 changes: 1 addition & 2 deletions tianshou/data/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,9 @@
import pandas as pd
import torch
from deepdiff import DeepDiff
from sensai.util import logging
from torch.distributions import Categorical, Distribution, Independent, Normal

from tianshou.utils import logging

_SingleIndexType = slice | int | EllipsisType
IndexType = np.ndarray | _SingleIndexType | Sequence[_SingleIndexType]
TBatch = TypeVar("TBatch", bound="BatchProtocol")
Expand Down
3 changes: 2 additions & 1 deletion tianshou/evaluation/rliable_evaluation_hl.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@
import scipy.stats as sst
from rliable import library as rly
from rliable import plot_utils
from sensai.util import logging

from tianshou.highlevel.experiment import Experiment
from tianshou.utils import TensorboardLogger, logging
from tianshou.utils import TensorboardLogger
from tianshou.utils.logger.base import DataScope

log = logging.getLogger(__name__)
Expand Down
2 changes: 1 addition & 1 deletion tianshou/highlevel/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Any, Generic, TypeVar, cast

import gymnasium
from sensai.util.string import ToStringMixin

from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
from tianshou.data.collector import BaseCollector
Expand Down Expand Up @@ -60,7 +61,6 @@
)
from tianshou.trainer import BaseTrainer, OffpolicyTrainer, OnpolicyTrainer
from tianshou.utils.net.common import ActorCritic
from tianshou.utils.string import ToStringMixin

CHECKPOINT_DICT_KEY_MODEL = "model"
CHECKPOINT_DICT_KEY_OBS_RMS = "obs_rms"
Expand Down
2 changes: 1 addition & 1 deletion tianshou/highlevel/config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import multiprocessing
from dataclasses import dataclass

from tianshou.utils.string import ToStringMixin
from sensai.util.string import ToStringMixin


@dataclass
Expand Down
Loading

0 comments on commit f5d2ae6

Please sign in to comment.