Skip to content

Commit

Permalink
[feat] Add Stable-Baselines3 integration (#2452)
Browse files Browse the repository at this point in the history
* [feat] Add Stable-Baselines3 integration

* [fix] Remove redundant if statement
  • Loading branch information
tamohannes authored Jan 13, 2023
1 parent e46afb3 commit 21bbc49
Show file tree
Hide file tree
Showing 6 changed files with 166 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
- Add other x-axis alignment and system logs tracking to cli convert wandb (hjoonjang)
- Add support for pre-binned distribution/histogram (YodaEmbedding)
- Fix plotly and matplotlib compatibility (tmynn)
- Add Stable-Baselines3 integration (tmynn)

### Fixes

Expand Down
2 changes: 2 additions & 0 deletions aim/sb3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Alias to SDK sb3 interface
from aim.sdk.adapters.sb3 import AimCallback # noqa F401
134 changes: 134 additions & 0 deletions aim/sdk/adapters/sb3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
import logging
import numpy as np
from typing import Optional, Dict, Any, Union, Tuple

from stable_baselines3.common.logger import KVWriter, Logger
from stable_baselines3.common.callbacks import BaseCallback # type: ignore

from aim.sdk.run import Run
from aim.ext.resource.configs import DEFAULT_SYSTEM_TRACKING_INT


logger = logging.getLogger(__name__)


class AimOutputFormat(KVWriter):
"""
Track key/value pairs into Aim run.
"""

def __init__(
self,
aim_callback
):
self.aim_callback = aim_callback

def write(
self,
key_values: Dict[str, Any],
key_excluded: Dict[str, Union[str, Tuple[str, ...]]],
step: int = 0,
) -> None:
for (key, value), (_, excluded) in zip(
sorted(key_values.items()), sorted(key_excluded.items())
):

if excluded is not None:
continue

if isinstance(value, np.ScalarType):
if not isinstance(value, str):
tag, key = key.split('/')
if tag in ['train', 'valid']:
context = {'subset': tag}
else:
context = {'tag': tag}

self.aim_callback.experiment.track(value, key, step=step, context=context)


class AimCallback(BaseCallback):
"""
AimCallback callback function.
Args:
repo (:obj:`str`, optional): Aim repository path or Repo object to which Run object is bound.
If skipped, default Repo is used.
experiment_name (:obj:`str`, optional): Sets Run's `experiment` property. 'default' if not specified.
Can be used later to query runs/sequences.
system_tracking_interval (:obj:`int`, optional): Sets the tracking interval in seconds for system usage
metrics (CPU, Memory, etc.). Set to `None` to disable system metrics tracking.
log_system_params (:obj:`bool`, optional): Enable/Disable logging of system params such as installed packages,
git info, environment variables, etc.
"""

def __init__(
self,
repo: Optional[str] = None,
experiment_name: Optional[str] = None,
system_tracking_interval: Optional[int] = DEFAULT_SYSTEM_TRACKING_INT,
log_system_params: bool = True,
verbose: int = 0,
) -> None:

super().__init__(verbose)

self.repo = repo
self.experiment_name = experiment_name
self.system_tracking_interval = system_tracking_interval
self.log_system_params = log_system_params
self._run = None
self._run_hash = None

def _init_callback(self) -> None:
args = {"algo": type(self.model).__name__}
for key in self.model.__dict__:
if type(self.model.__dict__[key]) in [float, int, str]:
args[key] = self.model.__dict__[key]
else:
args[key] = str(self.model.__dict__[key])

self.setup(args=args)

loggers = Logger(
folder=None,
output_formats=[AimOutputFormat(self)],
)

self.model.set_logger(loggers)

def _on_step(self) -> bool:
return True

@property
def experiment(self):
if not self._run:
self.setup()
return self._run

def setup(self, args=None):
if not self._run:
if self._run_hash:
self._run = Run(
self._run_hash,
repo=self.repo,
system_tracking_interval=self.system_tracking_interval,
log_system_params=self.log_system_params,
)
else:
self._run = Run(
repo=self.repo,
experiment=self.experiment_name,
system_tracking_interval=self.system_tracking_interval,
log_system_params=self.log_system_params,
)
self._run_hash = self._run.hash

# Log config parameters
if args:
for key in args:
self._run.set(key, args[key], strict=False)

def __del__(self):
if self._run and self._run.active:
self._run.close()
22 changes: 22 additions & 0 deletions docs/source/quick_start/integrations.md
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,28 @@ Check out a simple objective optimization example [here](https://github.com/aimh



### Integration with Stable-Baselines3

Aim provides a callback to easily track one of the reliable Reinforcement Learning implementations [Stable-Baselines3](https://stable-baselines3.readthedocs.io/en/master/) trainings.
It takes two steps to integrate Aim into your training script.

Step 1: Explicitly import the `AimCallback` for tracking training metadata.

```python
from aim.sb3 import AimCallback
```

Step 2: Pass the callback to `callback` upon initiating your training.

```python
model.learn(total_timesteps=10_000, callback=AimCallback(repo='.', experiment_name='sb3_test'))
```

See `AimCallback` source [here](https://github.com/aimhubio/aim/blob/main/aim/sdk/adapters/sb3.py).
Check out a simple objective optimization example [here](https://github.com/aimhubio/aim/blob/main/examples/sb3_track.py).



### What's next?

During the training process, you can start another terminal in the same directory, start `aim up` and you can observe
Expand Down
2 changes: 1 addition & 1 deletion examples/paddle_track.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,5 @@
loss=paddle.nn.CrossEntropyLoss(),
metrics=paddle.metric.Accuracy())

callback = AimCallback(repo='.', experiment='paddle_test')
callback = AimCallback(repo='.', experiment_name='paddle_test')
model.fit(train_dataset, eval_dataset, batch_size=64, callbacks=callback)
6 changes: 6 additions & 0 deletions examples/sb3_track.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from aim.sb3 import AimCallback
from stable_baselines3 import A2C


model = A2C("MlpPolicy", "CartPole-v1", verbose=2)
model.learn(total_timesteps=10_000, callback=AimCallback(repo='.', experiment_name='sb3_test'))

0 comments on commit 21bbc49

Please sign in to comment.