-
Notifications
You must be signed in to change notification settings - Fork 316
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* [feat] Add Acme integration * [fix] Change track aparam to dict. Apply flake8 formatting * [fix] Fix typos * [fix] Make track method's context argument none by default
- Loading branch information
1 parent
21bbc49
commit 81b4ece
Showing
5 changed files
with
181 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
# Alias to SDK acme interface | ||
from aim.sdk.adapters.acme import AimCallback, AimWriter # noqa F401 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
from typing import Optional, Dict | ||
from aim.sdk.run import Run | ||
from aim.ext.resource.configs import DEFAULT_SYSTEM_TRACKING_INT | ||
from acme.utils.loggers.base import Logger, LoggingData | ||
|
||
|
||
class AimCallback: | ||
def __init__( | ||
self, | ||
repo: Optional[str] = None, | ||
experiment_name: Optional[str] = None, | ||
system_tracking_interval: Optional[int] = DEFAULT_SYSTEM_TRACKING_INT, | ||
log_system_params: bool = True, | ||
args: Optional[Dict] = None, | ||
): | ||
self.repo = repo | ||
self.experiment_name = experiment_name | ||
self.system_tracking_interval = system_tracking_interval | ||
self.log_system_params = log_system_params | ||
self._run = None | ||
self._run_hash = None | ||
|
||
self.setup(args) | ||
|
||
@property | ||
def experiment(self): | ||
if not self._run: | ||
self.setup() | ||
return self._run | ||
|
||
def setup(self, args=None): | ||
if not self._run: | ||
if self._run_hash: | ||
self._run = Run( | ||
self._run_hash, | ||
repo=self.repo, | ||
system_tracking_interval=self.system_tracking_interval, | ||
log_system_params=self.log_system_params, | ||
) | ||
else: | ||
self._run = Run( | ||
repo=self.repo, | ||
experiment=self.experiment_name, | ||
system_tracking_interval=self.system_tracking_interval, | ||
log_system_params=self.log_system_params, | ||
) | ||
self._run_hash = self._run.hash | ||
|
||
if args: | ||
for key, value in args.items(): | ||
self._run.set(key, value, strict=False) | ||
|
||
def track(self, logs, step=None, context=None): | ||
self._run.track(logs, step=step, context=context) | ||
|
||
def close(self): | ||
if self._run and self._run.active: | ||
self._run.close() | ||
|
||
|
||
class AimWriter(Logger): | ||
def __init__(self, aim_run, logger_label, steps_key, task_id): | ||
self.aim_run = aim_run | ||
self.logger_label = logger_label | ||
self.steps_key = steps_key | ||
self.task_id = task_id | ||
|
||
def write(self, values: LoggingData): | ||
self.aim_run.track(values, context={"logger_label": self.logger_label}) | ||
|
||
def close(self): | ||
if self.aim_run and self.aim_run.active: | ||
self.aim_run.close() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
from typing import Optional | ||
|
||
from dm_control import suite as dm_suite | ||
import dm_env | ||
|
||
from acme import specs | ||
from acme import wrappers | ||
from acme.agents.jax import d4pg | ||
from acme.jax import experiments | ||
from acme.utils import loggers | ||
|
||
from aim.sdk.acme import AimCallback, AimWriter | ||
|
||
|
||
def make_environment(seed: int) -> dm_env.Environment: | ||
environment = dm_suite.load('cartpole', 'balance') | ||
|
||
# Make the observations be a flat vector of all concatenated features. | ||
environment = wrappers.ConcatObservationWrapper(environment) | ||
|
||
# Wrap the environment so the expected continuous action spec is [-1, 1]. | ||
# Note: this is a no-op on 'control' tasks. | ||
environment = wrappers.CanonicalSpecWrapper(environment, clip=True) | ||
|
||
# Make sure the environment outputs single-precision floats. | ||
environment = wrappers.SinglePrecisionWrapper(environment) | ||
|
||
return environment | ||
|
||
|
||
def network_factory(spec: specs.EnvironmentSpec) -> d4pg.D4PGNetworks: | ||
return d4pg.make_networks( | ||
spec, | ||
# These correspond to sizes of the hidden layers of an MLP. | ||
policy_layer_sizes=(256, 256), | ||
critic_layer_sizes=(256, 256), | ||
) | ||
|
||
|
||
d4pg_config = d4pg.D4PGConfig(learning_rate=3e-4, sigma=0.2) | ||
d4pg_builder = d4pg.D4PGBuilder(d4pg_config) | ||
|
||
|
||
aim_run = AimCallback(repo=".", experiment_name="acme_test") | ||
|
||
def logger_factory( | ||
name: str, | ||
steps_key: Optional[str] = None, | ||
task_id: Optional[int] = None, | ||
) -> loggers.Logger: | ||
return AimWriter(aim_run, name, steps_key, task_id) | ||
|
||
|
||
experiment_config = experiments.ExperimentConfig( | ||
builder=d4pg_builder, | ||
environment_factory=make_environment, | ||
network_factory=network_factory, | ||
logger_factory=logger_factory, | ||
seed=0, | ||
max_num_actor_steps=5000) # Each episode is 1000 steps. | ||
|
||
|
||
experiments.run_experiment( | ||
experiment=experiment_config, | ||
eval_every=1000, | ||
num_eval_episodes=1) |