diff --git a/CHANGELOG.md b/CHANGELOG.md index aa29e70a6..7e12b0991 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -42,7 +42,7 @@ - Added new `VectorEnvType` called `SUBPROC_SHARED_MEM_AUTO` and used in for Atari and Mujoco venv creation. #1141 - Loggers can now restore the logged data into python by using the new `restore_logged_data` method. #1074 - Wandb logger extended #1183 -- `utils`: +- `utils`: - `net.continuous.Critic`: - Add flag `apply_preprocess_net_to_obs_only` to allow the preprocessing network to be applied to the observations only (without @@ -104,7 +104,12 @@ instead of just `nn.Module`. #1032 - VectorEnvs now return an array of info-dicts on reset instead of a list. #1063 - Changed interface of `dist_fn` in `PGPolicy` and all subclasses to take a single argument in both continuous and discrete cases. #1032 -- `utils.net.common.Recurrent` now receives and returns a `RecurrentStateBatch` instead of a dict. #1077 +- `utils`: + - Modules with code that was copied from sensAI have been replaced by imports from new dependency sensAI-utils: + - `tianshou.utils.logging` is replaced with `sensai.util.logging` + - `tianshou.utils.string` is replaced with `sensai.util.string` + - `tianshou.utils.pickle` is replaced with `sensai.util.pickle` + - `utils.net.common.Recurrent` now receives and returns a `RecurrentStateBatch` instead of a dict. #1077 - `AtariEnvFactory` constructor (in examples, so not really breaking) now requires explicit train and test seeds. #1074 - `EnvFactoryRegistered` now requires an explicit `test_seed` in the constructor. #1074 - `highlevel`: diff --git a/examples/atari/atari_dqn_hl.py b/examples/atari/atari_dqn_hl.py index aa76983be..601481523 100644 --- a/examples/atari/atari_dqn_hl.py +++ b/examples/atari/atari_dqn_hl.py @@ -2,6 +2,9 @@ import os +from sensai.util import logging +from sensai.util.logging import datetime_tag + from examples.atari.atari_network import ( IntermediateModuleFactoryAtariDQN, IntermediateModuleFactoryAtariDQNFeatures, @@ -20,8 +23,6 @@ EpochTestCallbackDQNSetEps, EpochTrainCallbackDQNEpsLinearDecay, ) -from tianshou.utils import logging -from tianshou.utils.logging import datetime_tag def main( diff --git a/examples/atari/atari_iqn_hl.py b/examples/atari/atari_iqn_hl.py index 23df1cd25..b71b0eef3 100644 --- a/examples/atari/atari_iqn_hl.py +++ b/examples/atari/atari_iqn_hl.py @@ -3,6 +3,9 @@ import os from collections.abc import Sequence +from sensai.util import logging +from sensai.util.logging import datetime_tag + from examples.atari.atari_network import ( IntermediateModuleFactoryAtariDQN, ) @@ -17,8 +20,6 @@ EpochTestCallbackDQNSetEps, EpochTrainCallbackDQNEpsLinearDecay, ) -from tianshou.utils import logging -from tianshou.utils.logging import datetime_tag def main( diff --git a/examples/atari/atari_ppo_hl.py b/examples/atari/atari_ppo_hl.py index 10dcd0a7e..26ebaba08 100644 --- a/examples/atari/atari_ppo_hl.py +++ b/examples/atari/atari_ppo_hl.py @@ -3,6 +3,9 @@ import os from collections.abc import Sequence +from sensai.util import logging +from sensai.util.logging import datetime_tag + from examples.atari.atari_network import ( ActorFactoryAtariDQN, IntermediateModuleFactoryAtariDQNFeatures, @@ -18,8 +21,6 @@ from tianshou.highlevel.params.policy_wrapper import ( PolicyWrapperFactoryIntrinsicCuriosity, ) -from tianshou.utils import logging -from tianshou.utils.logging import datetime_tag def main( diff --git a/examples/atari/atari_sac_hl.py b/examples/atari/atari_sac_hl.py index cf09b40ea..124def768 100644 --- a/examples/atari/atari_sac_hl.py +++ b/examples/atari/atari_sac_hl.py @@ -3,6 +3,9 @@ import os from collections.abc import Sequence +from sensai.util import logging +from sensai.util.logging import datetime_tag + from examples.atari.atari_network import ( ActorFactoryAtariDQN, IntermediateModuleFactoryAtariDQNFeatures, @@ -18,8 +21,6 @@ from tianshou.highlevel.params.policy_wrapper import ( PolicyWrapperFactoryIntrinsicCuriosity, ) -from tianshou.utils import logging -from tianshou.utils.logging import datetime_tag def main( diff --git a/examples/discrete/discrete_dqn_hl.py b/examples/discrete/discrete_dqn_hl.py index eacf4c78f..464d06e71 100644 --- a/examples/discrete/discrete_dqn_hl.py +++ b/examples/discrete/discrete_dqn_hl.py @@ -1,3 +1,5 @@ +from sensai.util.logging import run_main + from tianshou.highlevel.config import SamplingConfig from tianshou.highlevel.env import ( EnvFactoryRegistered, @@ -10,7 +12,6 @@ EpochTestCallbackDQNSetEps, EpochTrainCallbackDQNSetEps, ) -from tianshou.utils.logging import run_main def main() -> None: diff --git a/examples/mujoco/mujoco_a2c_hl.py b/examples/mujoco/mujoco_a2c_hl.py index bce02e9c0..c804d6c26 100644 --- a/examples/mujoco/mujoco_a2c_hl.py +++ b/examples/mujoco/mujoco_a2c_hl.py @@ -4,6 +4,8 @@ from collections.abc import Sequence from typing import Literal +from sensai.util import logging +from sensai.util.logging import datetime_tag from torch import nn from examples.mujoco.mujoco_env import MujocoEnvFactory @@ -15,8 +17,6 @@ from tianshou.highlevel.optim import OptimizerFactoryRMSprop from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear from tianshou.highlevel.params.policy_params import A2CParams -from tianshou.utils import logging -from tianshou.utils.logging import datetime_tag def main( diff --git a/examples/mujoco/mujoco_ddpg_hl.py b/examples/mujoco/mujoco_ddpg_hl.py index db9c4e3e2..27dbfc8d9 100644 --- a/examples/mujoco/mujoco_ddpg_hl.py +++ b/examples/mujoco/mujoco_ddpg_hl.py @@ -3,6 +3,9 @@ import os from collections.abc import Sequence +from sensai.util import logging +from sensai.util.logging import datetime_tag + from examples.mujoco.mujoco_env import MujocoEnvFactory from tianshou.highlevel.config import SamplingConfig from tianshou.highlevel.experiment import ( @@ -11,8 +14,6 @@ ) from tianshou.highlevel.params.noise import MaxActionScaledGaussian from tianshou.highlevel.params.policy_params import DDPGParams -from tianshou.utils import logging -from tianshou.utils.logging import datetime_tag def main( diff --git a/examples/mujoco/mujoco_npg_hl.py b/examples/mujoco/mujoco_npg_hl.py index 231e735c9..387f87c6e 100644 --- a/examples/mujoco/mujoco_npg_hl.py +++ b/examples/mujoco/mujoco_npg_hl.py @@ -5,6 +5,8 @@ from typing import Literal import torch +from sensai.util import logging +from sensai.util.logging import datetime_tag from examples.mujoco.mujoco_env import MujocoEnvFactory from tianshou.highlevel.config import SamplingConfig @@ -14,8 +16,6 @@ ) from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear from tianshou.highlevel.params.policy_params import NPGParams -from tianshou.utils import logging -from tianshou.utils.logging import datetime_tag def main( diff --git a/examples/mujoco/mujoco_ppo_hl.py b/examples/mujoco/mujoco_ppo_hl.py index af0c5ab8f..b10d4cf26 100644 --- a/examples/mujoco/mujoco_ppo_hl.py +++ b/examples/mujoco/mujoco_ppo_hl.py @@ -5,6 +5,8 @@ from typing import Literal import torch +from sensai.util import logging +from sensai.util.logging import datetime_tag from examples.mujoco.mujoco_env import MujocoEnvFactory from tianshou.highlevel.config import SamplingConfig @@ -14,8 +16,6 @@ ) from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear from tianshou.highlevel.params.policy_params import PPOParams -from tianshou.utils import logging -from tianshou.utils.logging import datetime_tag def main( diff --git a/examples/mujoco/mujoco_ppo_hl_multi.py b/examples/mujoco/mujoco_ppo_hl_multi.py index c9d5a8fde..333870809 100644 --- a/examples/mujoco/mujoco_ppo_hl_multi.py +++ b/examples/mujoco/mujoco_ppo_hl_multi.py @@ -16,6 +16,8 @@ import warnings import torch +from sensai.util import logging +from sensai.util.logging import datetime_tag from examples.mujoco.mujoco_env import MujocoEnvFactory from tianshou.evaluation.launcher import RegisteredExpLauncher @@ -28,8 +30,6 @@ from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear from tianshou.highlevel.params.policy_params import PPOParams -from tianshou.utils import logging -from tianshou.utils.logging import datetime_tag log = logging.getLogger(__name__) diff --git a/examples/mujoco/mujoco_redq_hl.py b/examples/mujoco/mujoco_redq_hl.py index f52372906..c0c63279a 100644 --- a/examples/mujoco/mujoco_redq_hl.py +++ b/examples/mujoco/mujoco_redq_hl.py @@ -4,6 +4,9 @@ from collections.abc import Sequence from typing import Literal +from sensai.util import logging +from sensai.util.logging import datetime_tag + from examples.mujoco.mujoco_env import MujocoEnvFactory from tianshou.highlevel.config import SamplingConfig from tianshou.highlevel.experiment import ( @@ -12,8 +15,6 @@ ) from tianshou.highlevel.params.alpha import AutoAlphaFactoryDefault from tianshou.highlevel.params.policy_params import REDQParams -from tianshou.utils import logging -from tianshou.utils.logging import datetime_tag def main( diff --git a/examples/mujoco/mujoco_reinforce_hl.py b/examples/mujoco/mujoco_reinforce_hl.py index 46eb64fa2..59a600568 100644 --- a/examples/mujoco/mujoco_reinforce_hl.py +++ b/examples/mujoco/mujoco_reinforce_hl.py @@ -5,6 +5,8 @@ from typing import Literal import torch +from sensai.util import logging +from sensai.util.logging import datetime_tag from examples.mujoco.mujoco_env import MujocoEnvFactory from tianshou.highlevel.config import SamplingConfig @@ -14,8 +16,6 @@ ) from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear from tianshou.highlevel.params.policy_params import PGParams -from tianshou.utils import logging -from tianshou.utils.logging import datetime_tag def main( diff --git a/examples/mujoco/mujoco_sac_hl.py b/examples/mujoco/mujoco_sac_hl.py index 5ca731868..a150f5571 100644 --- a/examples/mujoco/mujoco_sac_hl.py +++ b/examples/mujoco/mujoco_sac_hl.py @@ -3,6 +3,9 @@ import os from collections.abc import Sequence +from sensai.util import logging +from sensai.util.logging import datetime_tag + from examples.mujoco.mujoco_env import MujocoEnvFactory from tianshou.highlevel.config import SamplingConfig from tianshou.highlevel.experiment import ( @@ -11,8 +14,6 @@ ) from tianshou.highlevel.params.alpha import AutoAlphaFactoryDefault from tianshou.highlevel.params.policy_params import SACParams -from tianshou.utils import logging -from tianshou.utils.logging import datetime_tag def main( diff --git a/examples/mujoco/mujoco_td3_hl.py b/examples/mujoco/mujoco_td3_hl.py index 3a32c7f42..5ec9cc17b 100644 --- a/examples/mujoco/mujoco_td3_hl.py +++ b/examples/mujoco/mujoco_td3_hl.py @@ -4,6 +4,8 @@ from collections.abc import Sequence import torch +from sensai.util import logging +from sensai.util.logging import datetime_tag from examples.mujoco.mujoco_env import MujocoEnvFactory from tianshou.highlevel.config import SamplingConfig @@ -16,8 +18,6 @@ MaxActionScaledGaussian, ) from tianshou.highlevel.params.policy_params import TD3Params -from tianshou.utils import logging -from tianshou.utils.logging import datetime_tag def main( diff --git a/examples/mujoco/mujoco_trpo_hl.py b/examples/mujoco/mujoco_trpo_hl.py index 59929529e..1ec26bad2 100644 --- a/examples/mujoco/mujoco_trpo_hl.py +++ b/examples/mujoco/mujoco_trpo_hl.py @@ -5,6 +5,8 @@ from typing import Literal import torch +from sensai.util import logging +from sensai.util.logging import datetime_tag from examples.mujoco.mujoco_env import MujocoEnvFactory from tianshou.highlevel.config import SamplingConfig @@ -14,8 +16,6 @@ ) from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear from tianshou.highlevel.params.policy_params import TRPOParams -from tianshou.utils import logging -from tianshou.utils.logging import datetime_tag def main( diff --git a/poetry.lock b/poetry.lock index 2c62d0461..cdae2222c 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. [[package]] name = "absl-py" @@ -5324,6 +5324,20 @@ nativelib = ["pyobjc-framework-Cocoa", "pywin32"] objc = ["pyobjc-framework-Cocoa"] win32 = ["pywin32"] +[[package]] +name = "sensai-utils" +version = "1.2.1" +description = "Utilities from sensAI, the Python library for sensible AI" +optional = false +python-versions = "*" +files = [ + {file = "sensai_utils-1.2.1-py3-none-any.whl", hash = "sha256:222e60d9f9d371c9d62ffcd1e6def1186f0d5243588b0b5af57e983beecc95bb"}, + {file = "sensai_utils-1.2.1.tar.gz", hash = "sha256:4d8ca94179931798cef5f920fb042cbf9e7d806c0026b02afb58d0f72211bf27"}, +] + +[package.dependencies] +typing-extensions = ">=4.6" + [[package]] name = "sentry-sdk" version = "2.8.0" @@ -6826,4 +6840,4 @@ vizdoom = ["vizdoom"] [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "bb0adb17849ce0ea79ab46c0d7f1d3fbfd92cc414c3bd90ce37200413c3e74a9" +content-hash = "200077246f10046fe1d0494977e5565420e0c166ef905a1d22608e84fcfb3459" diff --git a/pyproject.toml b/pyproject.toml index 48aee5fcf..66c0740ab 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,7 @@ overrides = "^7.4.0" packaging = "*" pandas = ">=2.0.0" pettingzoo = "^1.22" +sensai-utils = "^1.2.1" tensorboard = "^2.5.0" # Torch 2.0.1 causes problems, see https://github.com/pytorch/pytorch/issues/100974 torch = "^2.0.0, !=2.0.1, !=2.1.0" diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 4ea99dbcb..38c7d0465 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -64,10 +64,9 @@ import pandas as pd import torch from deepdiff import DeepDiff +from sensai.util import logging from torch.distributions import Categorical, Distribution, Independent, Normal -from tianshou.utils import logging - _SingleIndexType = slice | int | EllipsisType IndexType = np.ndarray | _SingleIndexType | Sequence[_SingleIndexType] TBatch = TypeVar("TBatch", bound="BatchProtocol") diff --git a/tianshou/evaluation/rliable_evaluation_hl.py b/tianshou/evaluation/rliable_evaluation_hl.py index e5fdf8eb4..2b8ff5131 100644 --- a/tianshou/evaluation/rliable_evaluation_hl.py +++ b/tianshou/evaluation/rliable_evaluation_hl.py @@ -11,9 +11,10 @@ import scipy.stats as sst from rliable import library as rly from rliable import plot_utils +from sensai.util import logging from tianshou.highlevel.experiment import Experiment -from tianshou.utils import TensorboardLogger, logging +from tianshou.utils import TensorboardLogger from tianshou.utils.logger.base import DataScope log = logging.getLogger(__name__) diff --git a/tianshou/highlevel/agent.py b/tianshou/highlevel/agent.py index 03b4e1463..81141a8a6 100644 --- a/tianshou/highlevel/agent.py +++ b/tianshou/highlevel/agent.py @@ -4,6 +4,7 @@ from typing import Any, Generic, TypeVar, cast import gymnasium +from sensai.util.string import ToStringMixin from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer from tianshou.data.collector import BaseCollector @@ -60,7 +61,6 @@ ) from tianshou.trainer import BaseTrainer, OffpolicyTrainer, OnpolicyTrainer from tianshou.utils.net.common import ActorCritic -from tianshou.utils.string import ToStringMixin CHECKPOINT_DICT_KEY_MODEL = "model" CHECKPOINT_DICT_KEY_OBS_RMS = "obs_rms" diff --git a/tianshou/highlevel/config.py b/tianshou/highlevel/config.py index 951f2f3af..1d704cfc3 100644 --- a/tianshou/highlevel/config.py +++ b/tianshou/highlevel/config.py @@ -1,7 +1,7 @@ import multiprocessing from dataclasses import dataclass -from tianshou.utils.string import ToStringMixin +from sensai.util.string import ToStringMixin @dataclass diff --git a/tianshou/highlevel/env.py b/tianshou/highlevel/env.py index 3cb9fadb2..b69e9c0b8 100644 --- a/tianshou/highlevel/env.py +++ b/tianshou/highlevel/env.py @@ -8,6 +8,7 @@ import gymnasium as gym import gymnasium.spaces from gymnasium import Env +from sensai.util.string import ToStringMixin from tianshou.env import ( BaseVectorEnv, @@ -17,7 +18,6 @@ ) from tianshou.highlevel.persistence import Persistence from tianshou.utils.net.common import TActionShape -from tianshou.utils.string import ToStringMixin TObservationShape: TypeAlias = int | Sequence[int] diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py index 8a08f9a8f..74413997b 100644 --- a/tianshou/highlevel/experiment.py +++ b/tianshou/highlevel/experiment.py @@ -32,6 +32,9 @@ import numpy as np import torch +from sensai.util import logging +from sensai.util.logging import datetime_tag +from sensai.util.string import ToStringMixin from tianshou.data import Collector, InfoStats from tianshou.env import BaseVectorEnv @@ -105,10 +108,8 @@ ) from tianshou.highlevel.world import World from tianshou.policy import BasePolicy -from tianshou.utils import LazyLogger, logging -from tianshou.utils.logging import datetime_tag +from tianshou.utils import LazyLogger from tianshou.utils.net.common import ModuleType -from tianshou.utils.string import ToStringMixin from tianshou.utils.warning import deprecation log = logging.getLogger(__name__) @@ -315,7 +316,7 @@ def create_experiment_world( full_config["experiment_config"] = asdict(self.config) full_config["sampling_config"] = asdict(self.sampling_config) with suppress(AttributeError): - full_config["policy_params"] = asdict(self.agent_factory.params) # type: ignore + full_config["policy_params"] = asdict(self.agent_factory.params) logger: TLogger if use_persistence: diff --git a/tianshou/highlevel/logger.py b/tianshou/highlevel/logger.py index 39b5fc2e7..f0e3e59d9 100644 --- a/tianshou/highlevel/logger.py +++ b/tianshou/highlevel/logger.py @@ -2,10 +2,10 @@ from abc import ABC, abstractmethod from typing import Literal, TypeAlias +from sensai.util.string import ToStringMixin from torch.utils.tensorboard import SummaryWriter from tianshou.utils import BaseLogger, TensorboardLogger, WandbLogger -from tianshou.utils.string import ToStringMixin TLogger: TypeAlias = BaseLogger diff --git a/tianshou/highlevel/module/actor.py b/tianshou/highlevel/module/actor.py index 5c1387e35..4a1fe5c2e 100644 --- a/tianshou/highlevel/module/actor.py +++ b/tianshou/highlevel/module/actor.py @@ -5,6 +5,7 @@ from typing import Protocol import torch +from sensai.util.string import ToStringMixin from torch import nn from tianshou.highlevel.env import Environments, EnvType @@ -26,7 +27,6 @@ from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont from tianshou.utils.net import continuous, discrete from tianshou.utils.net.common import BaseActor, ModuleType, Net -from tianshou.utils.string import ToStringMixin class ContinuousActorType(Enum): diff --git a/tianshou/highlevel/module/critic.py b/tianshou/highlevel/module/critic.py index 9e0fd070b..0352fd132 100644 --- a/tianshou/highlevel/module/critic.py +++ b/tianshou/highlevel/module/critic.py @@ -2,6 +2,7 @@ from collections.abc import Sequence import numpy as np +from sensai.util.string import ToStringMixin from torch import nn from tianshou.highlevel.env import Environments, EnvType @@ -11,7 +12,6 @@ from tianshou.highlevel.optim import OptimizerFactory from tianshou.utils.net import continuous, discrete from tianshou.utils.net.common import BaseActor, EnsembleLinear, ModuleType, Net -from tianshou.utils.string import ToStringMixin class CriticFactory(ToStringMixin, ABC): diff --git a/tianshou/highlevel/module/intermediate.py b/tianshou/highlevel/module/intermediate.py index a008935af..62bf3843f 100644 --- a/tianshou/highlevel/module/intermediate.py +++ b/tianshou/highlevel/module/intermediate.py @@ -2,10 +2,10 @@ from dataclasses import dataclass import torch +from sensai.util.string import ToStringMixin from tianshou.highlevel.env import Environments from tianshou.highlevel.module.core import ModuleFactory, TDevice -from tianshou.utils.string import ToStringMixin @dataclass diff --git a/tianshou/highlevel/module/special.py b/tianshou/highlevel/module/special.py index 8c3c568f0..de572d7a1 100644 --- a/tianshou/highlevel/module/special.py +++ b/tianshou/highlevel/module/special.py @@ -1,10 +1,11 @@ from collections.abc import Sequence +from sensai.util.string import ToStringMixin + from tianshou.highlevel.env import Environments from tianshou.highlevel.module.core import ModuleFactory, TDevice from tianshou.highlevel.module.intermediate import IntermediateModuleFactory from tianshou.utils.net.discrete import ImplicitQuantileNetwork -from tianshou.utils.string import ToStringMixin class ImplicitQuantileNetworkFactory(ModuleFactory, ToStringMixin): diff --git a/tianshou/highlevel/optim.py b/tianshou/highlevel/optim.py index bdb01fbf0..d480978fb 100644 --- a/tianshou/highlevel/optim.py +++ b/tianshou/highlevel/optim.py @@ -3,10 +3,9 @@ from typing import Any, Protocol, TypeAlias import torch +from sensai.util.string import ToStringMixin from torch.optim import Adam, RMSprop -from tianshou.utils.string import ToStringMixin - TParams: TypeAlias = Iterable[torch.Tensor] | Iterable[dict[str, Any]] diff --git a/tianshou/highlevel/params/alpha.py b/tianshou/highlevel/params/alpha.py index 4e8490de8..2b662eb44 100644 --- a/tianshou/highlevel/params/alpha.py +++ b/tianshou/highlevel/params/alpha.py @@ -2,11 +2,11 @@ import numpy as np import torch +from sensai.util.string import ToStringMixin from tianshou.highlevel.env import Environments from tianshou.highlevel.module.core import TDevice from tianshou.highlevel.optim import OptimizerFactory -from tianshou.utils.string import ToStringMixin class AutoAlphaFactory(ToStringMixin, ABC): diff --git a/tianshou/highlevel/params/dist_fn.py b/tianshou/highlevel/params/dist_fn.py index d28a2166c..6cb436185 100644 --- a/tianshou/highlevel/params/dist_fn.py +++ b/tianshou/highlevel/params/dist_fn.py @@ -3,10 +3,10 @@ from typing import Any import torch +from sensai.util.string import ToStringMixin from tianshou.highlevel.env import Environments from tianshou.policy.modelfree.pg import TDistFnDiscrete, TDistFnDiscrOrCont -from tianshou.utils.string import ToStringMixin class DistributionFunctionFactory(ToStringMixin, ABC): diff --git a/tianshou/highlevel/params/env_param.py b/tianshou/highlevel/params/env_param.py index 8696518dc..2b444bbd6 100644 --- a/tianshou/highlevel/params/env_param.py +++ b/tianshou/highlevel/params/env_param.py @@ -2,8 +2,9 @@ from abc import ABC, abstractmethod from typing import Generic, TypeVar +from sensai.util.string import ToStringMixin + from tianshou.highlevel.env import ContinuousEnvironments, Environments -from tianshou.utils.string import ToStringMixin TValue = TypeVar("TValue") TEnvs = TypeVar("TEnvs", bound=Environments) diff --git a/tianshou/highlevel/params/lr_scheduler.py b/tianshou/highlevel/params/lr_scheduler.py index 0b0cf359a..6a2778bab 100644 --- a/tianshou/highlevel/params/lr_scheduler.py +++ b/tianshou/highlevel/params/lr_scheduler.py @@ -2,10 +2,10 @@ import numpy as np import torch +from sensai.util.string import ToStringMixin from torch.optim.lr_scheduler import LambdaLR, LRScheduler from tianshou.highlevel.config import SamplingConfig -from tianshou.utils.string import ToStringMixin class LRSchedulerFactory(ToStringMixin, ABC): diff --git a/tianshou/highlevel/params/noise.py b/tianshou/highlevel/params/noise.py index ca4ca1fdd..fce3dba52 100644 --- a/tianshou/highlevel/params/noise.py +++ b/tianshou/highlevel/params/noise.py @@ -1,8 +1,9 @@ from abc import ABC, abstractmethod +from sensai.util.string import ToStringMixin + from tianshou.exploration import BaseNoise, GaussianNoise from tianshou.highlevel.env import ContinuousEnvironments, Environments -from tianshou.utils.string import ToStringMixin class NoiseFactory(ToStringMixin, ABC): diff --git a/tianshou/highlevel/params/policy_params.py b/tianshou/highlevel/params/policy_params.py index 5f521d418..d20bbe44b 100644 --- a/tianshou/highlevel/params/policy_params.py +++ b/tianshou/highlevel/params/policy_params.py @@ -4,6 +4,8 @@ from typing import Any, Literal, Protocol import torch +from sensai.util.pickle import setstate +from sensai.util.string import ToStringMixin from torch.optim.lr_scheduler import LRScheduler from tianshou.exploration import BaseNoise @@ -16,8 +18,6 @@ from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactory from tianshou.highlevel.params.noise import NoiseFactory from tianshou.utils import MultipleLRSchedulers -from tianshou.utils.pickle import setstate -from tianshou.utils.string import ToStringMixin @dataclass(kw_only=True) diff --git a/tianshou/highlevel/params/policy_wrapper.py b/tianshou/highlevel/params/policy_wrapper.py index e7958224d..43cbfed1e 100644 --- a/tianshou/highlevel/params/policy_wrapper.py +++ b/tianshou/highlevel/params/policy_wrapper.py @@ -2,13 +2,14 @@ from collections.abc import Sequence from typing import Generic, TypeVar +from sensai.util.string import ToStringMixin + from tianshou.highlevel.env import Environments from tianshou.highlevel.module.core import TDevice from tianshou.highlevel.module.intermediate import IntermediateModuleFactory from tianshou.highlevel.optim import OptimizerFactory from tianshou.policy import BasePolicy, ICMPolicy from tianshou.utils.net.discrete import IntrinsicCuriosityModule -from tianshou.utils.string import ToStringMixin TPolicyOut = TypeVar("TPolicyOut", bound=BasePolicy) diff --git a/tianshou/highlevel/trainer.py b/tianshou/highlevel/trainer.py index 1083017d4..498cc3173 100644 --- a/tianshou/highlevel/trainer.py +++ b/tianshou/highlevel/trainer.py @@ -4,10 +4,11 @@ from dataclasses import dataclass from typing import TypeVar, cast +from sensai.util.string import ToStringMixin + from tianshou.highlevel.env import Environments from tianshou.highlevel.logger import TLogger from tianshou.policy import BasePolicy, DQNPolicy -from tianshou.utils.string import ToStringMixin TPolicy = TypeVar("TPolicy", bound=BasePolicy) log = logging.getLogger(__name__) diff --git a/tianshou/utils/logging.py b/tianshou/utils/logging.py index b2eaf3ffc..f9ab4066e 100644 --- a/tianshou/utils/logging.py +++ b/tianshou/utils/logging.py @@ -1,23 +1,4 @@ -""" -Partial copy of sensai.util.logging -""" -# ruff: noqa -import atexit -import logging as lg -import sys -from collections.abc import Callable -from datetime import datetime -from io import StringIO -from logging import * -from typing import Any, TypeVar, cast - -log = getLogger(__name__) # type: ignore - -LOG_DEFAULT_FORMAT = "%(levelname)-5s %(asctime)-15s %(name)s:%(funcName)s - %(message)s" - -# Holds the log format that is configured by the user (using function `configure`), such -# that it can be reused in other places -_logFormat = LOG_DEFAULT_FORMAT +from typing import Any def set_numerical_fields_to_precision(data: dict[str, Any], precision: int = 3) -> dict[str, Any]: @@ -34,150 +15,3 @@ def set_numerical_fields_to_precision(data: dict[str, Any], precision: int = 3) v = round(v, precision) result[k] = v return result - - -def remove_log_handlers() -> None: - """Removes all current log handlers.""" - logger = getLogger() - while logger.hasHandlers(): - logger.removeHandler(logger.handlers[0]) - - -def remove_log_handler(handler: Handler) -> None: - getLogger().removeHandler(handler) - - -def is_log_handler_active(handler: Handler) -> bool: - """Checks whether the given handler is active. - - :param handler: a log handler - :return: True if the handler is active, False otherwise - """ - return handler in getLogger().handlers - - -# noinspection PyShadowingBuiltins -def configure(format: str = LOG_DEFAULT_FORMAT, level: int = lg.DEBUG) -> None: - """Configures logging to stdout with the given format and log level, - also configuring the default log levels of some overly verbose libraries as well as some pandas output options. - - :param format: the log format - :param level: the minimum log level - """ - global _logFormat - _logFormat = format - remove_log_handlers() - basicConfig(level=level, format=format, stream=sys.stdout) - # set log levels of third-party libraries - getLogger("numba").setLevel(INFO) - - -T = TypeVar("T") - - -# noinspection PyShadowingBuiltins -def run_main( - main_fn: Callable[[], T], format: str = LOG_DEFAULT_FORMAT, level: int = lg.DEBUG -) -> T | None: - """Configures logging with the given parameters, ensuring that any exceptions that occur during - the execution of the given function are logged. - Logs two additional messages, one before the execution of the function, and one upon its completion. - - :param main_fn: the function to be executed - :param format: the log message format - :param level: the minimum log level - :return: the result of `main_fn` - """ - configure(format=format, level=level) - log.info("Starting") # type: ignore - try: - result = main_fn() - log.info("Done") # type: ignore - return result - except Exception as e: - log.error("Exception during script execution", exc_info=e) # type: ignore - return None - - -def run_cli( - main_fn: Callable[..., T], format: str = LOG_DEFAULT_FORMAT, level: int = lg.DEBUG -) -> T | None: - """ - Configures logging with the given parameters and runs the given main function as a - CLI using `jsonargparse` (which is configured to also parse attribute docstrings, such - that dataclasses can be used as function arguments). - Using this function requires that `jsonargparse` and `docstring_parser` be available. - Like `run_main`, two additional log messages will be logged (at the beginning and end - of the execution), and it is ensured that all exceptions will be logged. - - :param main_fn: the function to be executed - :param format: the log message format - :param level: the minimum log level - :return: the result of `main_fn` - """ - from jsonargparse import set_docstring_parse_options, CLI - - set_docstring_parse_options(attribute_docstrings=True) - return run_main(lambda: CLI(main_fn), format=format, level=level) - - -def datetime_tag() -> str: - """:return: a string tag for use in log file names which contains the current date and time (compact but readable)""" - return datetime.now().strftime("%Y%m%d-%H%M%S") - - -_fileLoggerPaths: list[str] = [] -_isAtExitReportFileLoggerRegistered = False -_memoryLogStream: StringIO | None = None - - -def _at_exit_report_file_logger() -> None: - for path in _fileLoggerPaths: - print(f"A log file was saved to {path}") - - -def add_file_logger(path: str, register_atexit: bool = True) -> FileHandler: - global _isAtExitReportFileLoggerRegistered - log.info(f"Logging to {path} ...") # type: ignore - handler = FileHandler(path) - handler.setFormatter(Formatter(_logFormat)) - Logger.root.addHandler(handler) - _fileLoggerPaths.append(path) - if not _isAtExitReportFileLoggerRegistered and register_atexit: - atexit.register(_at_exit_report_file_logger) - _isAtExitReportFileLoggerRegistered = True - return handler - - -def add_memory_logger() -> None: - """Enables in-memory logging (if it is not already enabled), i.e. all log statements are written to a memory buffer and can later be - read via function `get_memory_log()`. - """ - global _memoryLogStream - if _memoryLogStream is not None: - return - _memoryLogStream = StringIO() - handler = StreamHandler(_memoryLogStream) - handler.setFormatter(Formatter(_logFormat)) - Logger.root.addHandler(handler) - - -def get_memory_log() -> Any: - """:return: the in-memory log (provided that `add_memory_logger` was called beforehand)""" - assert _memoryLogStream is not None, "This should not have happened and might be a bug." - return _memoryLogStream.getvalue() - - -class FileLoggerContext: - def __init__(self, path: str, enabled: bool = True): - self.enabled = enabled - self.path = path - self._log_handler: Handler | None = None - - def __enter__(self) -> None: - if self.enabled: - self._log_handler = add_file_logger(self.path, register_atexit=False) - - def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: - if self._log_handler is not None: - remove_log_handler(self._log_handler) diff --git a/tianshou/utils/net/continuous.py b/tianshou/utils/net/continuous.py index 0b28f98f9..a0b85ede2 100644 --- a/tianshou/utils/net/continuous.py +++ b/tianshou/utils/net/continuous.py @@ -5,6 +5,7 @@ import numpy as np import torch +from sensai.util.pickle import setstate from torch import nn from tianshou.utils.net.common import ( @@ -15,7 +16,6 @@ TLinearLayer, get_output_dim, ) -from tianshou.utils.pickle import setstate SIGMA_MIN = -20 SIGMA_MAX = 2 diff --git a/tianshou/utils/pickle.py b/tianshou/utils/pickle.py deleted file mode 100644 index 924716222..000000000 --- a/tianshou/utils/pickle.py +++ /dev/null @@ -1,97 +0,0 @@ -"""Helper functions for persistence/pickling, which have been copied from sensAI (specifically `sensai.util.pickle`).""" - -from collections.abc import Iterable -from copy import copy -from typing import Any - - -def setstate( - cls: type, - obj: Any, - state: dict[str, Any], - renamed_properties: dict[str, str] | None = None, - new_optional_properties: list[str] | None = None, - new_default_properties: dict[str, Any] | None = None, - removed_properties: list[str] | None = None, -) -> None: - """Helper function for safe implementations of `__setstate__` in classes, which appropriately handles the cases where - a parent class already implements `__setstate__` and where it does not. Call this function whenever you would actually - like to call the super-class' implementation. - Unfortunately, `__setstate__` is not implemented in `object`, rendering `super().__setstate__(state)` invalid in the general case. - - :param cls: the class in which you are implementing `__setstate__` - :param obj: the instance of `cls` - :param state: the state dictionary - :param renamed_properties: a mapping from old property names to new property names - :param new_optional_properties: a list of names of new property names, which, if not present, shall be initialized with None - :param new_default_properties: a dictionary mapping property names to their default values, which shall be added if they are not present - :param removed_properties: a list of names of properties that are no longer being used - """ - # handle new/changed properties - if renamed_properties is not None: - for mOld, mNew in renamed_properties.items(): - if mOld in state: - state[mNew] = state[mOld] - del state[mOld] - if new_optional_properties is not None: - for mNew in new_optional_properties: - if mNew not in state: - state[mNew] = None - if new_default_properties is not None: - for mNew, mValue in new_default_properties.items(): - if mNew not in state: - state[mNew] = mValue - if removed_properties is not None: - for p in removed_properties: - if p in state: - del state[p] - # call super implementation, if any - s = super(cls, obj) - if hasattr(s, "__setstate__"): - s.__setstate__(state) - else: - obj.__dict__ = state - - -def getstate( - cls: type, - obj: Any, - transient_properties: Iterable[str] | None = None, - excluded_properties: Iterable[str] | None = None, - override_properties: dict[str, Any] | None = None, - excluded_default_properties: dict[str, Any] | None = None, -) -> dict[str, Any]: - """Helper function for safe implementations of `__getstate__` in classes, which appropriately handles the cases where - a parent class already implements `__getstate__` and where it does not. Call this function whenever you would actually - like to call the super-class' implementation. - Unfortunately, `__getstate__` is not implemented in `object`, rendering `super().__getstate__()` invalid in the general case. - - :param cls: the class in which you are implementing `__getstate__` - :param obj: the instance of `cls` - :param transient_properties: transient properties which shall be set to None in serializations - :param excluded_properties: properties which shall be completely removed from serializations - :param override_properties: a mapping from property names to values specifying (new or existing) properties which are to be set; - use this to set a fixed value for an existing property or to add a completely new property - :param excluded_default_properties: properties which shall be completely removed from serializations, if they are set - to the given default value - :return: the state dictionary, which may be modified by the receiver - """ - s = super(cls, obj) - d = s.__getstate__() if hasattr(s, "__getstate__") else obj.__dict__ - d = copy(d) - if transient_properties is not None: - for p in transient_properties: - if p in d: - d[p] = None - if excluded_properties is not None: - for p in excluded_properties: - if p in d: - del d[p] - if override_properties is not None: - for k, v in override_properties.items(): - d[k] = v - if excluded_default_properties is not None: - for p, v in excluded_default_properties.items(): - if p in d and d[p] == v: - del d[p] - return d diff --git a/tianshou/utils/space_info.py b/tianshou/utils/space_info.py index f8b99053f..6943e4f11 100644 --- a/tianshou/utils/space_info.py +++ b/tianshou/utils/space_info.py @@ -5,8 +5,7 @@ import gymnasium as gym import numpy as np from gymnasium import spaces - -from tianshou.utils.string import ToStringMixin +from sensai.util.string import ToStringMixin @dataclass(kw_only=True) diff --git a/tianshou/utils/string.py b/tianshou/utils/string.py deleted file mode 100644 index a6d479236..000000000 --- a/tianshou/utils/string.py +++ /dev/null @@ -1,536 +0,0 @@ -"""Copy of sensai.util.string from sensAI """ -# From commit commit d7b4afcc89b4d2e922a816cb07dffde27f297354 - - -import functools -import logging -import re -import sys -import types -from abc import ABC, abstractmethod -from collections.abc import Callable, Iterable, Mapping, Sequence -from typing import ( - Any, - Self, -) - -reCommaWhitespacePotentiallyBreaks = re.compile(r",\s+") - -log = logging.getLogger(__name__) - -# ruff: noqa - - -class StringConverter(ABC): - """Abstraction for a string conversion mechanism.""" - - @abstractmethod - def to_string(self, x: Any) -> str: - pass - - -def dict_string( - d: Mapping, brackets: str | None = None, converter: StringConverter | None = None -) -> str: - """Converts a dictionary to a string of the form "=, =, ...", optionally enclosed - by brackets. - - :param d: the dictionary - :param brackets: a two-character string containing the opening and closing bracket to use, e.g. ``"{}"``; - if None, do not use enclosing brackets - :param converter: the string converter to use for values - :return: the string representation - """ - s = ", ".join([f"{k}={to_string(v, converter=converter, context=k)}" for k, v in d.items()]) - if brackets is not None: - return brackets[:1] + s + brackets[-1:] - else: - return s - - -def list_string( - l: Iterable[Any], - brackets: str | None = "[]", - quote: str | None = None, - converter: StringConverter | None = None, -) -> str: - """Converts a list or any other iterable to a string of the form "[, , ...]", optionally enclosed - by different brackets or with the values quoted. - - :param l: the list - :param brackets: a two-character string containing the opening and closing bracket to use, e.g. ``"[]"``; - if None, do not use enclosing brackets - :param quote: a 1-character string defining the quote to use around each value, e.g. ``"'"``. - :param converter: the string converter to use for values - :return: the string representation - """ - - def item(x: Any) -> str: - x = to_string(x, converter=converter, context="list") - if quote is not None: - return quote + x + quote - else: - return x - - s = ", ".join(item(x) for x in l) - if brackets is not None: - return brackets[:1] + s + brackets[-1:] - else: - return s - - -def to_string( - x: Any, - converter: StringConverter | None = None, - apply_converter_to_non_complex_objects: bool = True, - context: Any = None, -) -> str: - """Converts the given object to a string, with proper handling of lists, tuples and dictionaries, optionally using a converter. - The conversion also removes unwanted line breaks (as present, in particular, in sklearn's string representations). - - :param x: the object to convert - :param converter: the converter with which to convert objects to strings - :param apply_converter_to_non_complex_objects: whether to apply/pass on the converter (if any) not only when converting complex objects - but also non-complex, primitive objects; use of this flag enables converters to implement their conversion functionality using this - function for complex objects without causing an infinite recursion. - :param context: context in which the object is being converted (e.g. dictionary key for case where x is the corresponding - dictionary value), only for debugging purposes (will be reported in log messages upon recursion exception) - :return: the string representation - """ - try: - if isinstance(x, list): - return list_string(x, converter=converter) - elif isinstance(x, tuple): - return list_string(x, brackets="()", converter=converter) - elif isinstance(x, dict): - return dict_string(x, brackets="{}", converter=converter) - elif isinstance(x, types.MethodType): - # could be bound method of a ToStringMixin instance (which would print the repr of the instance, which can potentially cause - # an infinite recursion) - return f"Method[{x.__name__}]" - else: - if converter and apply_converter_to_non_complex_objects: - s = converter.to_string(x) - else: - s = str(x) - - # remove any unwanted line breaks and indentation after commas (as generated, for example, by sklearn objects) - return reCommaWhitespacePotentiallyBreaks.sub(", ", s) - - except RecursionError: - log.error(f"Recursion in string conversion detected; context={context}") - raise - - -def object_repr(obj: Any, member_names_or_dict: list[str] | dict[str, Any]) -> str: - """Creates a string representation for the given object based on the given members. - - The string takes the form "ClassName[attr1=value1, attr2=value2, ...]" - """ - if isinstance(member_names_or_dict, dict): - members_dict = member_names_or_dict - else: - members_dict = {m: to_string(getattr(obj, m)) for m in member_names_or_dict} - return f"{obj.__class__.__name__}[{dict_string(members_dict)}]" - - -def or_regex_group(allowed_names: Sequence[str]) -> str: - """:param allowed_names: strings to include as literals in the regex - :return: a regular expression string of the form `(| ...|)`, which any of the given names - """ - allowed_names = [re.escape(name) for name in allowed_names] - return r"(%s)" % "|".join(allowed_names) - - -def function_name(x: Callable) -> str: - """Attempts to retrieve the name of the given function/callable object, taking the possibility - of the function being defined via functools.partial into account. - - :param x: a callable object - :return: name of the function or str(x) as a fallback - """ - if isinstance(x, functools.partial): - return function_name(x.func) - elif hasattr(x, "__name__"): - return x.__name__ - else: - return str(x) - - -class ToStringMixin: - """Provides implementations for ``__str__`` and ``__repr__`` which are based on the format ``"[]"`` and - ``"[id=, ]"`` respectively, where ```` is usually a list of entries of the - form ``"=, ..."``. - - By default, ```` will be the qualified name of the class, and ```` will include all properties - of the class, including private ones starting with an underscore (though the underscore will be dropped in the string - representation). - - * To exclude private properties, override :meth:`_toStringExcludePrivate` to return True. If there are exceptions - (and some private properties shall be retained), additionally override :meth:`_toStringExcludeExceptions`. - * To exclude a particular set of properties, override :meth:`_toStringExcludes`. - * To include only select properties (introducing inclusion semantics), override :meth:`_toStringIncludes`. - * To add values to the properties list that aren't actually properties of the object (i.e. derived properties), - override :meth:`_toStringAdditionalEntries`. - * To define a fully custom representation for ```` which is not based on the above principles, override - :meth:`_toStringObjectInfo`. - - For well-defined string conversions within a class hierarchy, it can be a good practice to define additional - inclusions/exclusions by overriding the respective method once more and basing the return value on an extended - version of the value returned by superclass. - In some cases, the requirements of a subclass can be at odds with the definitions in the superclass: The superclass - may make use of exclusion semantics, but the subclass may want to use inclusion semantics (and include - only some of the many properties it adds). In this case, if the subclass used :meth:`_toStringInclude`, the exclusion semantics - of the superclass would be void and none of its properties would actually be included. - In such cases, override :meth:`_toStringIncludesForced` to add inclusions regardless of the semantics otherwise used along - the class hierarchy. - - """ - - _TOSTRING_INCLUDE_ALL = "__all__" - - def _tostring_class_name(self) -> str: - """:return: the string use for in the string representation ``"[ str: - """Creates a string of the class attributes, with optional exclusions/inclusions/additions. - Exclusions take precedence over inclusions. - - :param exclude: attributes to be excluded - :param include: attributes to be included; if non-empty, only the specified attributes will be printed (bar the ones - excluded by ``exclude``) - :param include_forced: additional attributes to be included - :param additional_entries: additional key-value entries to be added - :param converter: the string converter to use; if None, use default (which avoids infinite recursions) - :return: a string containing entry/property names and values - """ - - def mklist(x: Any) -> list[str]: - if x is None: - return [] - if isinstance(x, str): - return [x] - return x - - exclude = mklist(exclude) - include = mklist(include) - include_forced = mklist(include_forced) - exclude_exceptions = mklist(exclude_exceptions) - - def is_excluded(k: Any) -> bool: - if k in include_forced or k in exclude_exceptions: - return False - if k in exclude: - return True - if self._tostring_exclude_private(): - return k.startswith("_") - else: - return False - - # determine relevant attribute dictionary - if ( - len(include) == 1 and include[0] == self._TOSTRING_INCLUDE_ALL - ): # exclude semantics (include everything by default) - attribute_dict = self.__dict__ - else: # include semantics (include only inclusions) - attribute_dict = { - k: getattr(self, k) - for k in set(include + include_forced) - if hasattr(self, k) and k != self._TOSTRING_INCLUDE_ALL - } - - # apply exclusions and remove underscores from attribute names - d = {k.strip("_"): v for k, v in attribute_dict.items() if not is_excluded(k)} - - if additional_entries is not None: - d.update(additional_entries) - - if converter is None: - converter = self._StringConverterAvoidToStringMixinRecursion(self) - return dict_string(d, converter=converter) - - def _tostring_object_info(self) -> str: - """Override this method to use a fully custom definition of the ```` part in the full string - representation ``"[]"`` to be generated. - As soon as this method is overridden, any property-based exclusions, inclusions, etc. will have no effect - (unless the implementation is specifically designed to make use of them - as is the default - implementation). - NOTE: Overrides must not internally use super() because of a technical limitation in the proxy - object that is used for nested object structures. - - :return: a string containing the string to use for ```` - """ - return self._tostring_properties( - exclude=self._tostring_excludes(), - include=self._tostring_includes(), - exclude_exceptions=self._tostring_exclude_exceptions(), - include_forced=self._tostring_includes_forced(), - additional_entries=self._tostring_additional_entries(), - ) - - def _tostring_excludes(self) -> list[str]: - """Makes the string representation exclude the returned attributes. - This method can be conveniently overridden by subclasses which can call super and extend the list returned. - - This method will only have no effect if :meth:`_toStringObjectInfo` is overridden to not use its result. - - :return: a list of attribute names - """ - return [] - - def _tostring_includes(self) -> list[str]: - """Makes the string representation include only the returned attributes (i.e. introduces inclusion semantics); - By default, the list contains only a marker element, which is interpreted as "all attributes included". - - This method can be conveniently overridden by sub-classes which can call super and extend the list returned. - Note that it is not a problem for a list containing the aforementioned marker element (which stands for all attributes) - to be extended; the marker element will be ignored and only the user-added elements will be considered as included. - - Note: To add an included attribute in a sub-class, regardless of any super-classes using exclusion or inclusion semantics, - use _toStringIncludesForced instead. - - This method will have no effect if :meth:`_toStringObjectInfo` is overridden to not use its result. - - :return: a list of attribute names to be included in the string representation - """ - return [self._TOSTRING_INCLUDE_ALL] - - # noinspection PyMethodMayBeStatic - def _tostring_includes_forced(self) -> list[str]: - """Defines a list of attribute names that are required to be present in the string representation, regardless of the - instance using include semantics or exclude semantics, thus facilitating added inclusions in sub-classes. - - This method will have no effect if :meth:`_toStringObjectInfo` is overridden to not use its result. - - :return: a list of attribute names - """ - return [] - - def _tostring_additional_entries(self) -> dict[str, Any]: - """:return: a dictionary of entries to be included in the ```` part of the string representation""" - return {} - - def _tostring_exclude_private(self) -> bool: - """:return: whether to exclude properties that are private (start with an underscore); explicitly included attributes - will still be considered - as will properties exempt from the rule via :meth:`toStringExcludeException`. - """ - return False - - def _tostring_exclude_exceptions(self) -> list[str]: - """Defines attribute names which should not be excluded even though other rules (particularly the exclusion of private members - via :meth:`_toStringExcludePrivate`) would otherwise exclude them. - - :return: a list of attribute names - """ - return [] - - def __str__(self) -> str: - return f"{self._tostring_class_name()}[{self._tostring_object_info()}]" - - def __repr__(self) -> str: - info = f"id={id(self)}" - property_info = self._tostring_object_info() - if len(property_info) > 0: - info += ", " + property_info - return f"{self._tostring_class_name()}[{info}]" - - def pprint(self, file: Any = sys.stdout) -> None: - """Prints a prettily formatted string representation of the object (with line breaks and indentations) - to ``stdout`` or the given file. - - :param file: the file to print to - """ - print(self.pprints(), file=file) - - def pprints(self) -> str: - """:return: a prettily formatted string representation with line breaks and indentations""" - return pretty_string_repr(self) - - class _StringConverterAvoidToStringMixinRecursion(StringConverter): - """Avoids recursions when converting objects implementing :class:`ToStringMixin` which may contain themselves to strings. - Use of this object prevents infinite recursions caused by a :class:`ToStringMixin` instance recursively containing itself in - either a property of another :class:`ToStringMixin`, a list or a tuple. - It handles all :class:`ToStringMixin` instances recursively encountered. - - A previously handled instance is converted to a string of the form "[<<]". - """ - - def __init__(self, *handled_objects: "ToStringMixin"): - """:param handled_objects: objects which are initially assumed to have been handled already""" - self._handled_to_string_mixin_ids = {id(o) for o in handled_objects} - - def to_string(self, x: Any) -> str: - if isinstance(x, ToStringMixin): - oid = id(x) - if oid in self._handled_to_string_mixin_ids: - return f"{x._tostring_class_name()}[<<]" - self._handled_to_string_mixin_ids.add(oid) - return str(self._ToStringMixinProxy(x, self)) - else: - return to_string( - x, - converter=self, - apply_converter_to_non_complex_objects=False, - context=x.__class__, - ) - - class _ToStringMixinProxy: - """A proxy object which wraps a ToStringMixin to ensure that the converter is applied when creating the properties string. - The proxy is to achieve that all ToStringMixin methods that aren't explicitly overwritten are bound to this proxy - (rather than the original object), such that the transitive call to _toStringProperties will call the new - implementation. - """ - - # methods where we assume that they could transitively call _toStringProperties (others are assumed not to) - TOSTRING_METHODS_TRANSITIVELY_CALLING_TOSTRINGPROPERTIES = {"_tostring_object_info"} - - def __init__(self, x: "ToStringMixin", converter: Any) -> None: - self.x = x - self.converter = converter - - def _tostring_properties(self, *args: Any, **kwargs: Any) -> str: - return self.x._tostring_properties(*args, **kwargs, converter=self.converter) # type: ignore[misc] - - def _tostring_class_name(self) -> str: - return self.x._tostring_class_name() - - def __getattr__(self, attr: str) -> Any: - if attr.startswith( - "_tostring", - ): # ToStringMixin method which we may bind to use this proxy to ensure correct transitive call - method = getattr(self.x.__class__, attr) - obj = ( - self - if attr in self.TOSTRING_METHODS_TRANSITIVELY_CALLING_TOSTRINGPROPERTIES - else self.x - ) - return lambda *args, **kwargs: method(obj, *args, **kwargs) - else: - return getattr(self.x, attr) - - def __str__(self) -> str: - return ToStringMixin.__str__(self) # type: ignore[arg-type] - - -def pretty_string_repr( - s: Any, initial_indentation_level: int = 0, indentation_string: str = " " -) -> str: - """Creates a pretty string representation (using indentations) from the given object/string representation (as generated, for example, via - ToStringMixin). An indentation level is added for every opening bracket. - - :param s: an object or object string representation - :param initial_indentation_level: the initial indentation level - :param indentation_string: the string which corresponds to a single indentation level - :return: a reformatted version of the input string with added indentations and line breaks - """ - if not isinstance(s, str): - s = str(s) - indent = initial_indentation_level - result = indentation_string * indent - i = 0 - - def nl() -> None: - nonlocal result - result += "\n" + (indentation_string * indent) - - def take(cnt: int = 1) -> None: - nonlocal result, i - result += s[i : i + cnt] - i += cnt - - def find_matching(j: int) -> int | None: - start = j - op = s[j] - cl = {"[": "]", "(": ")", "'": "'"}[s[j]] - is_bracket = cl != s[j] - stack = 0 - while j < len(s): - if s[j] == op and (is_bracket or j == start): - stack += 1 - elif s[j] == cl: - stack -= 1 - if stack == 0: - return j - j += 1 - return None - - brackets = "[(" - quotes = "'" - while i < len(s): - is_bracket = s[i] in brackets - is_quote = s[i] in quotes - if is_bracket or is_quote: - i_match = find_matching(i) - take_full_match_without_break = False - if i_match is not None: - k = i_match + 1 - full_match = s[i:k] - take_full_match_without_break = is_quote or not ( - "=" in full_match and "," in full_match - ) - if take_full_match_without_break: - take(k - i) - if not take_full_match_without_break: - take(1) - indent += 1 - nl() - elif s[i] in "])": - take(1) - indent -= 1 - elif s[i : i + 2] == ", ": - take(2) - nl() - else: - take(1) - - return result - - -class TagBuilder: - """Assists in building strings made up of components that are joined via a glue string.""" - - def __init__(self, *initial_components: str, glue: str = "_"): - """:param initial_components: initial components to always include at the beginning - :param glue: the glue string which joins components - """ - self.glue = glue - self.components = list(initial_components) - - def with_component(self, component: str) -> Self: - self.components.append(component) - return self - - def with_conditional(self, cond: bool, component: str) -> Self: - """Conditionally adds the given component. - - :param cond: the condition - :param component: the component to add if the condition holds - :return: the builder - """ - if cond: - self.components.append(component) - return self - - def with_alternative(self, cond: bool, true_component: str, false_component: str) -> Self: - """Adds a component depending on a condition. - - :param cond: the condition - :param true_component: the component to add if the condition holds - :param false_component: the component to add if the condition does not hold - :return: the builder - """ - self.components.append(true_component if cond else false_component) - return self - - def build(self) -> str: - """:return: the string (with all components joined)""" - return self.glue.join(self.components)