diff --git a/CHANGELOG.md b/CHANGELOG.md
index 1806cd3c79..0c179d0eb0 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,7 @@
 - Add other x-axis alignment and system logs tracking to cli convert wandb (hjoonjang)
 - Add support for pre-binned distribution/histogram (YodaEmbedding)
 - Fix plotly and matplotlib compatibility (tmynn)
+- Add Stable-Baselines3 integration (tmynn)
 
 ### Fixes
 
diff --git a/aim/sb3.py b/aim/sb3.py
new file mode 100644
index 0000000000..43fd7899eb
--- /dev/null
+++ b/aim/sb3.py
@@ -0,0 +1,2 @@
+# Alias to SDK sb3 interface
+from aim.sdk.adapters.sb3 import AimCallback  # noqa F401
diff --git a/aim/sdk/adapters/sb3.py b/aim/sdk/adapters/sb3.py
new file mode 100644
index 0000000000..821dfb8aa6
--- /dev/null
+++ b/aim/sdk/adapters/sb3.py
@@ -0,0 +1,154 @@
+import logging
+import numpy as np
+from typing import Optional, Dict, Any, Union, Tuple
+
+from stable_baselines3.common.logger import KVWriter, Logger
+from stable_baselines3.common.callbacks import BaseCallback  # type: ignore
+
+from aim.sdk.run import Run
+from aim.ext.resource.configs import DEFAULT_SYSTEM_TRACKING_INT
+
+
+logger = logging.getLogger(__name__)
+
+
+class AimOutputFormat(KVWriter):
+    """
+    Track key/value pairs into Aim run.
+
+    Args:
+        aim_callback (:obj:`AimCallback`): Callback whose `experiment` Run receives the tracked values.
+    """
+
+    def __init__(
+        self,
+        aim_callback
+    ):
+        self.aim_callback = aim_callback
+
+    def write(
+        self,
+        key_values: Dict[str, Any],
+        key_excluded: Dict[str, Union[str, Tuple[str, ...]]],
+        step: int = 0,
+    ) -> None:
+        """
+        Track the scalar values of `key_values` as Aim metrics (strings and bytes are skipped).
+
+        Keys shaped like `<tag>/<name>` are tracked as `<name>`: the `train` and `valid` tags are stored
+        in the `subset` context, any other tag in the `tag` context. Keys without a tag get an empty context.
+        """
+        for key, value in sorted(key_values.items()):
+            # The exclusion entry may be `None`, a single output name or a tuple of names;
+            # skip only the keys that are explicitly excluded from the `aim` output.
+            excluded = key_excluded.get(key) or ()
+            if isinstance(excluded, str):
+                excluded = (excluded,)
+            if 'aim' in excluded:
+                continue
+
+            # np.ScalarType also lists str and bytes, which are not trackable metric values.
+            if not isinstance(value, np.ScalarType) or isinstance(value, (str, bytes)):
+                continue
+
+            # Split on the first separator only so untagged or nested keys do not raise ValueError.
+            tag, separator, name = key.partition('/')
+            if not separator:
+                name, context = key, {}
+            elif tag in ['train', 'valid']:
+                context = {'subset': tag}
+            else:
+                context = {'tag': tag}
+
+            self.aim_callback.experiment.track(value, name, step=step, context=context)
+
+
+class AimCallback(BaseCallback):
+    """
+    AimCallback callback function.
+
+    Args:
+        repo (:obj:`str`, optional): Aim repository path or Repo object to which Run object is bound.
+            If skipped, default Repo is used.
+        experiment_name (:obj:`str`, optional): Sets Run's `experiment` property. 'default' if not specified.
+            Can be used later to query runs/sequences.
+        system_tracking_interval (:obj:`int`, optional): Sets the tracking interval in seconds for system usage
+            metrics (CPU, Memory, etc.). Set to `None` to disable system metrics tracking.
+        log_system_params (:obj:`bool`, optional): Enable/Disable logging of system params such as installed packages,
+            git info, environment variables, etc.
+    """
+
+    def __init__(
+        self,
+        repo: Optional[str] = None,
+        experiment_name: Optional[str] = None,
+        system_tracking_interval: Optional[int] = DEFAULT_SYSTEM_TRACKING_INT,
+        log_system_params: bool = True,
+        verbose: int = 0,
+    ) -> None:
+
+        super().__init__(verbose)
+
+        self.repo = repo
+        self.experiment_name = experiment_name
+        self.system_tracking_interval = system_tracking_interval
+        self.log_system_params = log_system_params
+        self._run = None
+        self._run_hash = None
+
+    def _init_callback(self) -> None:
+        # Record the algorithm name and the model attributes as run parameters.
+        args = {"algo": type(self.model).__name__}
+        for key, value in vars(self.model).items():
+            # Plain float/int/str values are stored as-is, anything else by its string form.
+            args[key] = value if type(value) in [float, int, str] else str(value)
+
+        self.setup(args=args)
+
+        # Replace the model logger with one whose only output format writes to Aim.
+        loggers = Logger(
+            folder=None,
+            output_formats=[AimOutputFormat(self)],
+        )
+
+        self.model.set_logger(loggers)
+
+    def _on_step(self) -> bool:
+        return True
+
+    @property
+    def experiment(self):
+        """Run used for tracking, created on first access."""
+        if not self._run:
+            self.setup()
+        return self._run
+
+    def setup(self, args=None):
+        """Create the Run (or re-open it by its stored hash) and store `args` as run parameters."""
+        if not self._run:
+            if self._run_hash:
+                self._run = Run(
+                    self._run_hash,
+                    repo=self.repo,
+                    system_tracking_interval=self.system_tracking_interval,
+                    log_system_params=self.log_system_params,
+                )
+            else:
+                self._run = Run(
+                    repo=self.repo,
+                    experiment=self.experiment_name,
+                    system_tracking_interval=self.system_tracking_interval,
+                    log_system_params=self.log_system_params,
+                )
+                self._run_hash = self._run.hash
+
+        # Log config parameters
+        if args:
+            for key in args:
+                self._run.set(key, args[key], strict=False)
+
+    def __del__(self):
+        # `_run` is missing when `__init__` raised before assigning it.
+        run = getattr(self, '_run', None)
+        if run and run.active:
+            run.close()
diff --git a/docs/source/quick_start/integrations.md b/docs/source/quick_start/integrations.md
index 54bb7f05f9..c7d6a10773 100644
--- a/docs/source/quick_start/integrations.md
+++ b/docs/source/quick_start/integrations.md
@@ -356,6 +356,28 @@ Check out a simple objective optimization example [here](https://github.com/aimh
 
 
 
+### Integration with Stable-Baselines3
+
+Aim provides a callback to easily track trainings run with [Stable-Baselines3](https://stable-baselines3.readthedocs.io/en/master/), a set of reliable Reinforcement Learning implementations.
+It takes two steps to integrate Aim into your training script.
+
+Step 1: Explicitly import the `AimCallback` for tracking training metadata.
+
+```python
+from aim.sb3 import AimCallback
+```
+
+Step 2: Pass the callback to `callback` upon initiating your training.
+
+```python
+model.learn(total_timesteps=10_000, callback=AimCallback(repo='.', experiment_name='sb3_test'))
+```
+
+See `AimCallback` source [here](https://github.com/aimhubio/aim/blob/main/aim/sdk/adapters/sb3.py).
+Check out a simple training example [here](https://github.com/aimhubio/aim/blob/main/examples/sb3_track.py).
+
+
+
 ### What's next?
 
 During the training process, you can start another terminal in the same directory, start `aim up` and you can observe
diff --git a/examples/paddle_track.py b/examples/paddle_track.py
index fe4174b89c..53f2f0afe5 100644
--- a/examples/paddle_track.py
+++ b/examples/paddle_track.py
@@ -22,5 +22,5 @@
               loss=paddle.nn.CrossEntropyLoss(),
               metrics=paddle.metric.Accuracy())
 
-callback = AimCallback(repo='.', experiment='paddle_test')
+callback = AimCallback(repo='.', experiment_name='paddle_test')
 model.fit(train_dataset, eval_dataset, batch_size=64, callbacks=callback)
diff --git a/examples/sb3_track.py b/examples/sb3_track.py
new file mode 100644
index 0000000000..5a093ffa04
--- /dev/null
+++ b/examples/sb3_track.py
@@ -0,0 +1,6 @@
+from aim.sb3 import AimCallback
+from stable_baselines3 import A2C
+
+
+model = A2C("MlpPolicy", "CartPole-v1", verbose=2)
+model.learn(total_timesteps=10_000, callback=AimCallback(repo='.', experiment_name='sb3_test'))