Skip to content

Commit

Permalink
[feat] Add Acme integration (#2453)
Browse files Browse the repository at this point in the history
* [feat] Add Acme integration

* [fix] Change track aparam to dict. Apply flake8 formatting

* [fix] Fix typos

* [fix] Make track method's context argument none by default
  • Loading branch information
tamohannes authored Jan 13, 2023
1 parent 21bbc49 commit 81b4ece
Show file tree
Hide file tree
Showing 5 changed files with 181 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
- Add support for pre-binned distribution/histogram (YodaEmbedding)
- Fix plotly and matplotlib compatibility (tmynn)
- Add Stable-Baselines3 integration (tmynn)
- Add Acme integration (tmynn)

### Fixes

Expand Down
2 changes: 2 additions & 0 deletions aim/acme.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Alias to SDK acme interface
from aim.sdk.adapters.acme import AimCallback, AimWriter # noqa F401
73 changes: 73 additions & 0 deletions aim/sdk/adapters/acme.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from typing import Optional, Dict
from aim.sdk.run import Run
from aim.ext.resource.configs import DEFAULT_SYSTEM_TRACKING_INT
from acme.utils.loggers.base import Logger, LoggingData


class AimCallback:
    """Tracks Acme training metadata in an Aim ``Run``.

    The underlying ``Run`` is created lazily (or re-attached by hash via
    ``setup``) and exposed through the ``experiment`` property. Extra
    ``args`` passed to the constructor are stored as run parameters.
    """

    def __init__(
        self,
        repo: Optional[str] = None,
        experiment_name: Optional[str] = None,
        system_tracking_interval: Optional[int] = DEFAULT_SYSTEM_TRACKING_INT,
        log_system_params: bool = True,
        args: Optional[Dict] = None,
    ):
        # Run configuration, consumed by setup() when the Run is created.
        self.repo = repo
        self.experiment_name = experiment_name
        self.system_tracking_interval = system_tracking_interval
        self.log_system_params = log_system_params

        # Lazily-created Run and its hash (kept to allow re-attaching).
        self._run = None
        self._run_hash = None

        self.setup(args)

    @property
    def experiment(self):
        """The underlying Aim ``Run``, created on first access if needed."""
        if not self._run:
            self.setup()
        return self._run

    def setup(self, args=None):
        """Create (or re-open) the Run and record ``args`` as run params."""
        if not self._run:
            shared_kwargs = dict(
                repo=self.repo,
                system_tracking_interval=self.system_tracking_interval,
                log_system_params=self.log_system_params,
            )
            if self._run_hash:
                # Re-attach to an existing run by its hash.
                self._run = Run(self._run_hash, **shared_kwargs)
            else:
                self._run = Run(experiment=self.experiment_name, **shared_kwargs)
            self._run_hash = self._run.hash

        if args:
            # strict=False lets arbitrary (possibly non-primitive) values in.
            for name, value in args.items():
                self._run.set(name, value, strict=False)

    def track(self, logs, step=None, context=None):
        """Forward a batch of logged values to the Run."""
        self._run.track(logs, step=step, context=context)

    def close(self):
        """Close the Run if it is still active."""
        run = self._run
        if run and run.active:
            run.close()


class AimWriter(Logger):
    """Acme ``Logger`` implementation that forwards metrics to an Aim Run."""

    def __init__(self, aim_run, logger_label, steps_key, task_id):
        # Shared Run instance (typically owned by an AimCallback).
        self.aim_run = aim_run
        self.logger_label = logger_label
        self.steps_key = steps_key
        self.task_id = task_id

    def write(self, values: LoggingData):
        # Tag tracked values with the logger label so metrics coming from
        # different Acme loggers remain distinguishable in the Aim UI.
        self.aim_run.track(values, context={"logger_label": self.logger_label})

    def close(self):
        run = self.aim_run
        if run and run.active:
            run.close()
39 changes: 39 additions & 0 deletions docs/source/quick_start/integrations.md
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,45 @@ Check out a simple objective optimization example [here](https://github.com/aimh



### Integration with Acme

Aim provides a built-in callback to easily track [Acme](https://dm-acme.readthedocs.io/en/latest/) training runs.
It takes a few simple steps to integrate Aim into your training script.

Step 1: Explicitly import the `AimCallback` and `AimWriter` for tracking training metadata.

```python
from aim.acme import AimCallback, AimWriter
```

Step 2: Initialize an Aim Run via `AimCallback`, and create a log factory using the Run.

```python
aim_run = AimCallback(repo=".", experiment_name="acme_test")
def logger_factory(
name: str,
steps_key: Optional[str] = None,
task_id: Optional[int] = None,
) -> loggers.Logger:
return AimWriter(aim_run, name, steps_key, task_id)
```

Step 3: Pass the logger factory to `logger_factory` upon initiating your training.

```python
experiment_config = experiments.ExperimentConfig(
builder=d4pg_builder,
environment_factory=make_environment,
network_factory=network_factory,
logger_factory=logger_factory,
seed=0,
max_num_actor_steps=5000)
```

See `AimCallback` source [here](https://github.com/aimhubio/aim/blob/main/aim/sdk/adapters/acme.py).
Check out a simple Acme training example [here](https://github.com/aimhubio/aim/blob/main/examples/acme_track.py).


### What's next?

During the training process, you can start another terminal in the same directory, start `aim up` and you can observe
Expand Down
66 changes: 66 additions & 0 deletions examples/acme_track.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from typing import Optional

from dm_control import suite as dm_suite
import dm_env

from acme import specs
from acme import wrappers
from acme.agents.jax import d4pg
from acme.jax import experiments
from acme.utils import loggers

# Fix: the alias module shipped with this integration is ``aim/acme.py``
# (re-exporting from ``aim.sdk.adapters.acme``); ``aim.sdk.acme`` does not
# exist, so the original import raised ModuleNotFoundError.
from aim.acme import AimCallback, AimWriter


def make_environment(seed: int) -> dm_env.Environment:
    """Build a flat-observation, single-precision cartpole balance task.

    Note: ``seed`` is currently unused here; it is part of the factory
    signature expected by ``experiments.ExperimentConfig``.
    """
    environment = dm_suite.load('cartpole', 'balance')

    # Make the observations be a flat vector of all concatenated features.
    environment = wrappers.ConcatObservationWrapper(environment)

    # Wrap the environment so the expected continuous action spec is [-1, 1].
    # Note: this is a no-op on 'control' tasks.
    environment = wrappers.CanonicalSpecWrapper(environment, clip=True)

    # Make sure the environment outputs single-precision floats.
    environment = wrappers.SinglePrecisionWrapper(environment)

    return environment


def network_factory(spec: specs.EnvironmentSpec) -> d4pg.D4PGNetworks:
    """Create D4PG policy/critic networks matching the environment spec."""
    return d4pg.make_networks(
        spec,
        # These correspond to sizes of the hidden layers of an MLP.
        policy_layer_sizes=(256, 256),
        critic_layer_sizes=(256, 256),
    )


d4pg_config = d4pg.D4PGConfig(learning_rate=3e-4, sigma=0.2)
d4pg_builder = d4pg.D4PGBuilder(d4pg_config)

# One Aim Run shared by every logger the experiment creates.
aim_run = AimCallback(repo=".", experiment_name="acme_test")


def logger_factory(
    name: str,
    steps_key: Optional[str] = None,
    task_id: Optional[int] = None,
) -> loggers.Logger:
    """Return an Aim-backed Acme logger bound to the shared Run."""
    return AimWriter(aim_run, name, steps_key, task_id)


experiment_config = experiments.ExperimentConfig(
    builder=d4pg_builder,
    environment_factory=make_environment,
    network_factory=network_factory,
    logger_factory=logger_factory,
    seed=0,
    max_num_actor_steps=5000)  # Each episode is 1000 steps.


experiments.run_experiment(
    experiment=experiment_config,
    eval_every=1000,
    num_eval_episodes=1)

0 comments on commit 81b4ece

Please sign in to comment.