From 81b4ece34fea970f3f620265afb318c613907c87 Mon Sep 17 00:00:00 2001
From: Hovhannes Tamoyan
Date: Fri, 13 Jan 2023 15:02:01 +0400
Subject: [PATCH] [feat] Add Acme integration (#2453)

* [feat] Add Acme integration

* [fix] Change track param to dict. Apply flake8 formatting

* [fix] Fix typos

* [fix] Make track method's context argument None by default
---
 CHANGELOG.md                            |  1 +
 aim/acme.py                             |  2 +
 aim/sdk/adapters/acme.py                | 73 +++++++++++++++++++++++++
 docs/source/quick_start/integrations.md | 39 +++++++++++++
 examples/acme_track.py                  | 68 ++++++++++++++++++++++
 5 files changed, 183 insertions(+)
 create mode 100644 aim/acme.py
 create mode 100644 aim/sdk/adapters/acme.py
 create mode 100644 examples/acme_track.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 0c179d0eb0..2e2abe0770 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,6 +8,7 @@
 - Add support for pre-binned distribution/histogram (YodaEmbedding)
 - Fix plotly and matplotlib compatibility (tmynn)
 - Add Stable-Baselines3 integration (tmynn)
+- Add Acme integration (tmynn)
 
 ### Fixes
 
diff --git a/aim/acme.py b/aim/acme.py
new file mode 100644
index 0000000000..44884fd7d1
--- /dev/null
+++ b/aim/acme.py
@@ -0,0 +1,2 @@
+# Alias to SDK acme interface
+from aim.sdk.adapters.acme import AimCallback, AimWriter  # noqa: F401
diff --git a/aim/sdk/adapters/acme.py b/aim/sdk/adapters/acme.py
new file mode 100644
index 0000000000..0529415bfa
--- /dev/null
+++ b/aim/sdk/adapters/acme.py
@@ -0,0 +1,73 @@
+from typing import Optional, Dict
+from aim.sdk.run import Run
+from aim.ext.resource.configs import DEFAULT_SYSTEM_TRACKING_INT
+from acme.utils.loggers.base import Logger, LoggingData
+
+
+class AimCallback:
+    def __init__(
+        self,
+        repo: Optional[str] = None,
+        experiment_name: Optional[str] = None,
+        system_tracking_interval: Optional[int] = DEFAULT_SYSTEM_TRACKING_INT,
+        log_system_params: bool = True,
+        args: Optional[Dict] = None,
+    ):
+        self.repo = repo
+        self.experiment_name = experiment_name
+        self.system_tracking_interval = system_tracking_interval
+        self.log_system_params = log_system_params
+        self._run = None
+        self._run_hash = None
+
+        self.setup(args)
+
+    @property
+    def experiment(self):
+        if not self._run:
+            self.setup()
+        return self._run
+
+    def setup(self, args=None):
+        if not self._run:
+            if self._run_hash:
+                self._run = Run(
+                    self._run_hash,
+                    repo=self.repo,
+                    system_tracking_interval=self.system_tracking_interval,
+                    log_system_params=self.log_system_params,
+                )
+            else:
+                self._run = Run(
+                    repo=self.repo,
+                    experiment=self.experiment_name,
+                    system_tracking_interval=self.system_tracking_interval,
+                    log_system_params=self.log_system_params,
+                )
+                self._run_hash = self._run.hash
+
+        if args:
+            for key, value in args.items():
+                self._run.set(key, value, strict=False)
+
+    def track(self, logs, step=None, context=None):
+        self._run.track(logs, step=step, context=context)
+
+    def close(self):
+        if self._run and self._run.active:
+            self._run.close()
+
+
+class AimWriter(Logger):
+    def __init__(self, aim_run, logger_label, steps_key, task_id):
+        self.aim_run = aim_run
+        self.logger_label = logger_label
+        self.steps_key = steps_key
+        self.task_id = task_id
+
+    def write(self, values: LoggingData):
+        self.aim_run.track(values, context={"logger_label": self.logger_label})
+
+    def close(self):
+        if self.aim_run and self.aim_run.active:
+            self.aim_run.close()
diff --git a/docs/source/quick_start/integrations.md b/docs/source/quick_start/integrations.md
index c7d6a10773..4c23484921 100644
--- a/docs/source/quick_start/integrations.md
+++ b/docs/source/quick_start/integrations.md
@@ -378,6 +378,45 @@ Check out a simple objective optimization example [here](https://github.com/aimh
 
 
 
+### Integration with Acme
+
+Aim provides a built-in callback to easily track [Acme](https://dm-acme.readthedocs.io/en/latest/) trainings.
+It takes a few simple steps to integrate Aim into your training script.
+
+Step 1: Explicitly import `AimCallback` and `AimWriter` for tracking training metadata.
+
+```python
+from aim.acme import AimCallback, AimWriter
+```
+
+Step 2: Initialize an Aim Run via `AimCallback`, and create a logger factory that wraps the Run in an `AimWriter` (the snippet assumes `Optional` is imported from `typing` and `loggers` from `acme.utils`).
+
+```python
+aim_run = AimCallback(repo=".", experiment_name="acme_test")
+def logger_factory(
+    name: str,
+    steps_key: Optional[str] = None,
+    task_id: Optional[int] = None,
+) -> loggers.Logger:
+    return AimWriter(aim_run, name, steps_key, task_id)
+```
+
+Step 3: Pass the factory via the `logger_factory` argument when configuring your training experiment.
+
+```python
+experiment_config = experiments.ExperimentConfig(
+    builder=d4pg_builder,
+    environment_factory=make_environment,
+    network_factory=network_factory,
+    logger_factory=logger_factory,
+    seed=0,
+    max_num_actor_steps=5000)
+```
+
+See `AimCallback` source [here](https://github.com/aimhubio/aim/blob/main/aim/sdk/adapters/acme.py).
+Check out a simple Acme training example [here](https://github.com/aimhubio/aim/blob/main/examples/acme_track.py).
+
+
 ### What's next?
 
 During the training process, you can start another terminal in the same directory, start `aim up` and you can observe
diff --git a/examples/acme_track.py b/examples/acme_track.py
new file mode 100644
index 0000000000..1eb0e050ef
--- /dev/null
+++ b/examples/acme_track.py
@@ -0,0 +1,68 @@
+from typing import Optional
+
+from dm_control import suite as dm_suite
+import dm_env
+
+from acme import specs
+from acme import wrappers
+from acme.agents.jax import d4pg
+from acme.jax import experiments
+from acme.utils import loggers
+
+from aim.acme import AimCallback, AimWriter
+
+
+def make_environment(seed: int) -> dm_env.Environment:
+    environment = dm_suite.load('cartpole', 'balance')
+
+    # Make the observations be a flat vector of all concatenated features.
+    environment = wrappers.ConcatObservationWrapper(environment)
+
+    # Wrap the environment so the expected continuous action spec is [-1, 1].
+    # Note: this is a no-op on 'control' tasks.
+    environment = wrappers.CanonicalSpecWrapper(environment, clip=True)
+
+    # Make sure the environment outputs single-precision floats.
+    environment = wrappers.SinglePrecisionWrapper(environment)
+
+    return environment
+
+
+def network_factory(spec: specs.EnvironmentSpec) -> d4pg.D4PGNetworks:
+    return d4pg.make_networks(
+        spec,
+        # These correspond to sizes of the hidden layers of an MLP.
+        policy_layer_sizes=(256, 256),
+        critic_layer_sizes=(256, 256),
+    )
+
+
+d4pg_config = d4pg.D4PGConfig(learning_rate=3e-4, sigma=0.2)
+d4pg_builder = d4pg.D4PGBuilder(d4pg_config)
+
+
+aim_run = AimCallback(repo=".", experiment_name="acme_test")
+
+def logger_factory(
+    name: str,
+    steps_key: Optional[str] = None,
+    task_id: Optional[int] = None,
+) -> loggers.Logger:
+    return AimWriter(aim_run, name, steps_key, task_id)
+
+
+experiment_config = experiments.ExperimentConfig(
+    builder=d4pg_builder,
+    environment_factory=make_environment,
+    network_factory=network_factory,
+    logger_factory=logger_factory,
+    seed=0,
+    max_num_actor_steps=5000)  # Each episode is 1000 steps.
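+
+
+# Launch the training loop; every Acme logger created by logger_factory
+# streams its metrics into the shared Aim run initialized above.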
+experiments.run_experiment(
+    experiment=experiment_config,
+    eval_every=1000,
+    num_eval_episodes=1)