diff --git a/docs/api/tianshou.data.rst b/docs/api/tianshou.data.rst index eea262a76..77c69aa15 100644 --- a/docs/api/tianshou.data.rst +++ b/docs/api/tianshou.data.rst @@ -88,3 +88,30 @@ AsyncCollector :members: :undoc-members: :show-inheritance: + + +Utils +----- + +to_numpy +~~~~~~~~ + +.. autofunction:: tianshou.data.to_numpy + +to_torch +~~~~~~~~ + +.. autofunction:: tianshou.data.to_torch + +to_torch_as +~~~~~~~~~~~ + +.. autofunction:: tianshou.data.to_torch_as + +SegmentTree +~~~~~~~~~~~ + +.. autoclass:: tianshou.data.SegmentTree + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/api/tianshou.env.rst b/docs/api/tianshou.env.rst index 04848a778..77713f411 100644 --- a/docs/api/tianshou.env.rst +++ b/docs/api/tianshou.env.rst @@ -46,6 +46,26 @@ RayVectorEnv :show-inheritance: +Wrapper +------- + +VectorEnvWrapper +~~~~~~~~~~~~~~~~ + +.. autoclass:: tianshou.env.VectorEnvWrapper + :members: + :undoc-members: + :show-inheritance: + +VectorEnvNormObs +~~~~~~~~~~~~~~~~ + +.. autoclass:: tianshou.env.VectorEnvNormObs + :members: + :undoc-members: + :show-inheritance: + + Worker ------ @@ -80,3 +100,15 @@ RayEnvWorker :members: :undoc-members: :show-inheritance: + + +Utils +----- + +PettingZooEnv +~~~~~~~~~~~~~ + +.. autoclass:: tianshou.env.PettingZooEnv + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 471f262b5..e5a2d5345 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -158,3 +158,4 @@ Enduro Qbert Seaquest subnets +subprocesses diff --git a/docs/tutorials/cheatsheet.rst b/docs/tutorials/cheatsheet.rst index 6273c06af..08aac4451 100644 --- a/docs/tutorials/cheatsheet.rst +++ b/docs/tutorials/cheatsheet.rst @@ -123,7 +123,11 @@ EnvPool Integration `EnvPool `_ is a C++-based vectorized environment implementation and is way faster than the above solutions. The APIs are almost the same as above four classes, so that means you can directly switch the vectorized environment to envpool and get immediate speed-up. -Currently it supports Atari, VizDoom, toy_text and classic_control environments. For more information, please refer to `EnvPool's documentation `_. +Currently it supports +`Atari `_, +`Mujoco `_, +`VizDoom `_, +toy_text and classic_control environments. For more information, please refer to `EnvPool's documentation `_. :: @@ -133,7 +137,7 @@ Currently it supports Atari, VizDoom, toy_text and classic_control environments. envs = envpool.make_gym("CartPole-v0", num_envs=10) collector = Collector(policy, envs, buffer) -Here are some examples: https://github.com/sail-sg/envpool/tree/master/examples/tianshou_examples +Here are some other `examples `_. .. _preprocess_fn: @@ -177,7 +181,7 @@ For example, you can write your hook as: self.episode_log[i].append(kwargs['rew'][i]) kwargs['rew'][i] -= self.baseline for i in range(n): - if kwargs['done']: + if kwargs['done'][i]: self.main_log.append(np.mean(self.episode_log[i])) self.episode_log[i] = [] self.baseline = np.mean(self.main_log) @@ -191,6 +195,40 @@ And finally, Some examples are in `test/base/test_collector.py `_. +Another solution is to create a vector environment wrapper through :class:`~tianshou.env.VectorEnvWrapper`, e.g. 
+:: + + import numpy as np + from collections import deque + from tianshou.env import VectorEnvWrapper + + class MyWrapper(VectorEnvWrapper): + def __init__(self, venv, size=100): + self.episode_log = None + self.main_log = deque(maxlen=size) + self.main_log.append(0) + self.baseline = 0 + + def step(self, action, env_id): + obs, rew, done, info = self.venv.step(action, env_id) + n = len(rew) + if self.episode_log is None: + self.episode_log = [[] for i in range(n)] + for i in range(n): + self.episode_log[i].append(rew[i]) + rew[i] -= self.baseline + for i in range(n): + if done[i]: + self.main_log.append(np.mean(self.episode_log[i])) + self.episode_log[i] = [] + self.baseline = np.mean(self.main_log) + return obs, rew, done, info + + env = MyWrapper(env, size=100) + collector = Collector(policy, env, buffer) + +We provide an observation normalization vector env wrapper: :class:`~tianshou.env.VectorEnvNormObs`. + .. _rnn_training: diff --git a/examples/vizdoom/README.md b/examples/vizdoom/README.md index ab01b7a77..ca151f19b 100644 --- a/examples/vizdoom/README.md +++ b/examples/vizdoom/README.md @@ -2,12 +2,24 @@ [ViZDoom](https://github.com/mwydmuch/ViZDoom) is a popular RL env for a famous first-person shooting game Doom. Here we provide some results and intuitions for this scenario. +## EnvPool + +We highly recommend using envpool to run the following experiments. To install, in a linux machine, type: + +```bash +pip install envpool +``` + +After that, `make_vizdoom_env` will automatically switch to envpool's ViZDoom env. EnvPool's implementation is much faster (about 2\~3x faster for pure execution speed, 1.5x for overall RL training pipeline) than python vectorized env implementation. + +For more information, please refer to EnvPool's [GitHub](https://github.com/sail-sg/envpool/) and [Docs](https://envpool.readthedocs.io/en/latest/api/vizdoom.html). 
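+
+A minimal usage sketch of the helper (keyword names follow the `make_vizdoom_env` signature in `env.py`; the concrete values are only illustrative):
+
+```python
+from env import make_vizdoom_env
+
+# res is (stack_num, height, width); when envpool is not installed,
+# the helper falls back to ShmemVectorEnv over the local Env class.
+env, train_envs, test_envs = make_vizdoom_env(
+    task="D1_basic", frame_skip=4, res=(4, 84, 84),
+    save_lmp=False, seed=0, training_num=10, test_num=10,
+)
+```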
+ ## Train To train an agent: ```bash -python3 vizdoom_c51.py --task {D1_basic|D3_battle|D4_battle2} +python3 vizdoom_c51.py --task {D1_basic|D2_navigation|D3_battle|D4_battle2} ``` D1 (health gathering) should finish training (no death) in less than 500k env step (5 epochs); diff --git a/examples/vizdoom/env.py b/examples/vizdoom/env.py index 290cb92e5..63555f733 100644 --- a/examples/vizdoom/env.py +++ b/examples/vizdoom/env.py @@ -5,6 +5,13 @@ import numpy as np import vizdoom as vzd +from tianshou.env import ShmemVectorEnv + +try: + import envpool +except ImportError: + envpool = None + def normal_button_comb(): actions = [] @@ -112,6 +119,58 @@ def close(self): self.game.close() +def make_vizdoom_env(task, frame_skip, res, save_lmp, seed, training_num, test_num): + test_num = min(os.cpu_count() - 1, test_num) + if envpool is not None: + task_id = "".join([i.capitalize() for i in task.split("_")]) + "-v1" + lmp_save_dir = "lmps/" if save_lmp else "" + reward_config = { + "KILLCOUNT": [20.0, -20.0], + "HEALTH": [1.0, 0.0], + "AMMO2": [1.0, -1.0], + } + if "battle" in task: + reward_config["HEALTH"] = [1.0, -1.0] + env = train_envs = envpool.make_gym( + task_id, + frame_skip=frame_skip, + stack_num=res[0], + seed=seed, + num_envs=training_num, + reward_config=reward_config, + use_combined_action=True, + max_episode_steps=2625, + use_inter_area_resize=False, + ) + test_envs = envpool.make_gym( + task_id, + frame_skip=frame_skip, + stack_num=res[0], + lmp_save_dir=lmp_save_dir, + seed=seed, + num_envs=test_num, + reward_config=reward_config, + use_combined_action=True, + max_episode_steps=2625, + use_inter_area_resize=False, + ) + else: + cfg_path = f"maps/{task}.cfg" + env = Env(cfg_path, frame_skip, res) + train_envs = ShmemVectorEnv( + [lambda: Env(cfg_path, frame_skip, res) for _ in range(training_num)] + ) + test_envs = ShmemVectorEnv( + [ + lambda: Env(cfg_path, frame_skip, res, save_lmp) + for _ in range(test_num) + ] + ) + train_envs.seed(seed) + test_envs.seed(seed) + return env, train_envs, test_envs + + if __name__ == '__main__': # env = Env("maps/D1_basic.cfg", 4, (4, 84, 84)) env = Env("maps/D3_battle.cfg", 4, (4, 84, 84)) diff --git a/examples/vizdoom/vizdoom_c51.py b/examples/vizdoom/vizdoom_c51.py index 9f5aab835..5b41aa775 100644 --- a/examples/vizdoom/vizdoom_c51.py +++ b/examples/vizdoom/vizdoom_c51.py @@ -1,94 +1,88 @@ import argparse +import datetime import os import pprint import numpy as np import torch -from env import Env +from env import make_vizdoom_env from network import C51 from torch.utils.tensorboard import SummaryWriter from tianshou.data import Collector, VectorReplayBuffer -from tianshou.env import ShmemVectorEnv from tianshou.policy import C51Policy from tianshou.trainer import offpolicy_trainer -from tianshou.utils import TensorboardLogger +from tianshou.utils import TensorboardLogger, WandbLogger def get_args(): parser = argparse.ArgumentParser() - parser.add_argument('--task', type=str, default='D1_basic') - parser.add_argument('--seed', type=int, default=0) - parser.add_argument('--eps-test', type=float, default=0.005) - parser.add_argument('--eps-train', type=float, default=1.) - parser.add_argument('--eps-train-final', type=float, default=0.05) - parser.add_argument('--buffer-size', type=int, default=2000000) - parser.add_argument('--lr', type=float, default=0.0001) - parser.add_argument('--gamma', type=float, default=0.99) - parser.add_argument('--num-atoms', type=int, default=51) - parser.add_argument('--v-min', type=float, default=-10.) 
- parser.add_argument('--v-max', type=float, default=10.) - parser.add_argument('--n-step', type=int, default=3) - parser.add_argument('--target-update-freq', type=int, default=500) - parser.add_argument('--epoch', type=int, default=300) - parser.add_argument('--step-per-epoch', type=int, default=100000) - parser.add_argument('--step-per-collect', type=int, default=10) - parser.add_argument('--update-per-step', type=float, default=0.1) - parser.add_argument('--batch-size', type=int, default=64) - parser.add_argument('--training-num', type=int, default=10) - parser.add_argument('--test-num', type=int, default=100) - parser.add_argument('--logdir', type=str, default='log') - parser.add_argument('--render', type=float, default=0.) + parser.add_argument("--task", type=str, default="D1_basic") + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--eps-test", type=float, default=0.005) + parser.add_argument("--eps-train", type=float, default=1.) + parser.add_argument("--eps-train-final", type=float, default=0.05) + parser.add_argument("--buffer-size", type=int, default=2000000) + parser.add_argument("--lr", type=float, default=0.0001) + parser.add_argument("--gamma", type=float, default=0.99) + parser.add_argument("--num-atoms", type=int, default=51) + parser.add_argument("--v-min", type=float, default=-10.) + parser.add_argument("--v-max", type=float, default=10.) + parser.add_argument("--n-step", type=int, default=3) + parser.add_argument("--target-update-freq", type=int, default=500) + parser.add_argument("--epoch", type=int, default=300) + parser.add_argument("--step-per-epoch", type=int, default=100000) + parser.add_argument("--step-per-collect", type=int, default=10) + parser.add_argument("--update-per-step", type=float, default=0.1) + parser.add_argument("--batch-size", type=int, default=64) + parser.add_argument("--training-num", type=int, default=10) + parser.add_argument("--test-num", type=int, default=100) + parser.add_argument("--logdir", type=str, default="log") + parser.add_argument("--render", type=float, default=0.) 
parser.add_argument( - '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu" ) - parser.add_argument('--frames-stack', type=int, default=4) - parser.add_argument('--skip-num', type=int, default=4) - parser.add_argument('--resume-path', type=str, default=None) + parser.add_argument("--frames-stack", type=int, default=4) + parser.add_argument("--skip-num", type=int, default=4) + parser.add_argument("--resume-path", type=str, default=None) + parser.add_argument("--resume-id", type=str, default=None) parser.add_argument( - '--watch', + "--logger", + type=str, + default="tensorboard", + choices=["tensorboard", "wandb"], + ) + parser.add_argument("--wandb-project", type=str, default="vizdoom.benchmark") + parser.add_argument( + "--watch", default=False, - action='store_true', - help='watch the play of pre-trained policy only' + action="store_true", + help="watch the play of pre-trained policy only", ) parser.add_argument( - '--save-lmp', + "--save-lmp", default=False, - action='store_true', - help='save lmp file for replay whole episode' + action="store_true", + help="save lmp file for replay whole episode", ) - parser.add_argument('--save-buffer-name', type=str, default=None) + parser.add_argument("--save-buffer-name", type=str, default=None) return parser.parse_args() def test_c51(args=get_args()): - args.cfg_path = f"maps/{args.task}.cfg" - args.wad_path = f"maps/{args.task}.wad" - args.res = (args.skip_num, 84, 84) - env = Env(args.cfg_path, args.frames_stack, args.res) - args.state_shape = args.res + # make environments + env, train_envs, test_envs = make_vizdoom_env( + args.task, args.skip_num, (args.frames_stack, 84, 84), args.save_lmp, + args.seed, args.training_num, args.test_num + ) + args.state_shape = env.observation_space.shape args.action_shape = env.action_space.shape or env.action_space.n # should be N_FRAMES x H x W print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) - # make environments - train_envs = ShmemVectorEnv( - [ - lambda: Env(args.cfg_path, args.frames_stack, args.res) - for _ in range(args.training_num) - ] - ) - test_envs = ShmemVectorEnv( - [ - lambda: Env(args.cfg_path, args.frames_stack, args.res, args.save_lmp) - for _ in range(min(os.cpu_count() - 1, args.test_num)) - ] - ) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) - train_envs.seed(args.seed) - test_envs.seed(args.seed) # define model net = C51(*args.state_shape, args.action_shape, args.num_atoms, args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) @@ -101,7 +95,7 @@ def test_c51(args=get_args()): args.v_min, args.v_max, args.n_step, - target_update_freq=args.target_update_freq + target_update_freq=args.target_update_freq, ).to(args.device) # load a previous policy if args.resume_path: @@ -114,25 +108,40 @@ def test_c51(args=get_args()): buffer_num=len(train_envs), ignore_obs_next=True, save_only_last_obs=True, - stack_num=args.frames_stack + stack_num=args.frames_stack, ) # collector train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) test_collector = Collector(policy, test_envs, exploration_noise=True) + # log - log_path = os.path.join(args.logdir, args.task, 'c51') + now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") + args.algo_name = "c51" + log_name = os.path.join(args.task, args.algo_name, str(args.seed), now) + log_path = os.path.join(args.logdir, log_name) + + # logger + if args.logger == 
"wandb": + logger = WandbLogger( + save_interval=1, + name=log_name.replace(os.path.sep, "__"), + run_id=args.resume_id, + config=args, + project=args.wandb_project, + ) writer = SummaryWriter(log_path) writer.add_text("args", str(args)) - logger = TensorboardLogger(writer) + if args.logger == "tensorboard": + logger = TensorboardLogger(writer) + else: # wandb + logger.load(writer) def save_best_fn(policy): - torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) + torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards): if env.spec.reward_threshold: return mean_rewards >= env.spec.reward_threshold - elif 'Pong' in args.task: - return mean_rewards >= 20 else: return False @@ -163,7 +172,7 @@ def watch(): buffer_num=len(test_envs), ignore_obs_next=True, save_only_last_obs=True, - stack_num=args.frames_stack + stack_num=args.frames_stack, ) collector = Collector(policy, test_envs, buffer, exploration_noise=True) result = collector.collect(n_step=args.buffer_size) @@ -203,12 +212,12 @@ def watch(): save_best_fn=save_best_fn, logger=logger, update_per_step=args.update_per_step, - test_in_train=False + test_in_train=False, ) pprint.pprint(result) watch() -if __name__ == '__main__': +if __name__ == "__main__": test_c51(get_args()) diff --git a/examples/vizdoom/vizdoom_ppo.py b/examples/vizdoom/vizdoom_ppo.py index 9a219a3ee..26a5885e7 100644 --- a/examples/vizdoom/vizdoom_ppo.py +++ b/examples/vizdoom/vizdoom_ppo.py @@ -1,126 +1,120 @@ import argparse +import datetime import os import pprint import numpy as np import torch -from env import Env +from env import make_vizdoom_env from network import DQN from torch.optim.lr_scheduler import LambdaLR from torch.utils.tensorboard import SummaryWriter from tianshou.data import Collector, VectorReplayBuffer -from tianshou.env import ShmemVectorEnv from tianshou.policy import ICMPolicy, PPOPolicy from tianshou.trainer import onpolicy_trainer -from tianshou.utils import TensorboardLogger +from tianshou.utils import TensorboardLogger, WandbLogger from tianshou.utils.net.common import ActorCritic from tianshou.utils.net.discrete import Actor, Critic, IntrinsicCuriosityModule def get_args(): parser = argparse.ArgumentParser() - parser.add_argument('--task', type=str, default='D2_navigation') - parser.add_argument('--seed', type=int, default=0) - parser.add_argument('--buffer-size', type=int, default=100000) - parser.add_argument('--lr', type=float, default=0.00002) - parser.add_argument('--gamma', type=float, default=0.99) - parser.add_argument('--epoch', type=int, default=300) - parser.add_argument('--step-per-epoch', type=int, default=100000) - parser.add_argument('--step-per-collect', type=int, default=1000) - parser.add_argument('--repeat-per-collect', type=int, default=4) - parser.add_argument('--batch-size', type=int, default=256) - parser.add_argument('--hidden-size', type=int, default=512) - parser.add_argument('--training-num', type=int, default=10) - parser.add_argument('--test-num', type=int, default=100) - parser.add_argument('--rew-norm', type=int, default=False) - parser.add_argument('--vf-coef', type=float, default=0.5) - parser.add_argument('--ent-coef', type=float, default=0.01) - parser.add_argument('--gae-lambda', type=float, default=0.95) - parser.add_argument('--lr-decay', type=int, default=True) - parser.add_argument('--max-grad-norm', type=float, default=0.5) - parser.add_argument('--eps-clip', type=float, default=0.2) - parser.add_argument('--dual-clip', type=float, default=None) 
- parser.add_argument('--value-clip', type=int, default=0) - parser.add_argument('--norm-adv', type=int, default=1) - parser.add_argument('--recompute-adv', type=int, default=0) - parser.add_argument('--logdir', type=str, default='log') - parser.add_argument('--render', type=float, default=0.) + parser.add_argument("--task", type=str, default="D1_basic") + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--buffer-size", type=int, default=100000) + parser.add_argument("--lr", type=float, default=0.00002) + parser.add_argument("--gamma", type=float, default=0.99) + parser.add_argument("--epoch", type=int, default=300) + parser.add_argument("--step-per-epoch", type=int, default=100000) + parser.add_argument("--step-per-collect", type=int, default=1000) + parser.add_argument("--repeat-per-collect", type=int, default=4) + parser.add_argument("--batch-size", type=int, default=256) + parser.add_argument("--hidden-size", type=int, default=512) + parser.add_argument("--training-num", type=int, default=10) + parser.add_argument("--test-num", type=int, default=100) + parser.add_argument("--rew-norm", type=int, default=False) + parser.add_argument("--vf-coef", type=float, default=0.5) + parser.add_argument("--ent-coef", type=float, default=0.01) + parser.add_argument("--gae-lambda", type=float, default=0.95) + parser.add_argument("--lr-decay", type=int, default=True) + parser.add_argument("--max-grad-norm", type=float, default=0.5) + parser.add_argument("--eps-clip", type=float, default=0.2) + parser.add_argument("--dual-clip", type=float, default=None) + parser.add_argument("--value-clip", type=int, default=0) + parser.add_argument("--norm-adv", type=int, default=1) + parser.add_argument("--recompute-adv", type=int, default=0) + parser.add_argument("--logdir", type=str, default="log") + parser.add_argument("--render", type=float, default=0.) 
parser.add_argument( - '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu" ) - parser.add_argument('--frames-stack', type=int, default=4) - parser.add_argument('--skip-num', type=int, default=4) - parser.add_argument('--resume-path', type=str, default=None) + parser.add_argument("--frames-stack", type=int, default=4) + parser.add_argument("--skip-num", type=int, default=4) + parser.add_argument("--resume-path", type=str, default=None) + parser.add_argument("--resume-id", type=str, default=None) parser.add_argument( - '--watch', + "--logger", + type=str, + default="tensorboard", + choices=["tensorboard", "wandb"], + ) + parser.add_argument("--wandb-project", type=str, default="vizdoom.benchmark") + parser.add_argument( + "--watch", default=False, - action='store_true', - help='watch the play of pre-trained policy only' + action="store_true", + help="watch the play of pre-trained policy only", ) parser.add_argument( - '--save-lmp', + "--save-lmp", default=False, - action='store_true', - help='save lmp file for replay whole episode' + action="store_true", + help="save lmp file for replay whole episode", ) - parser.add_argument('--save-buffer-name', type=str, default=None) + parser.add_argument("--save-buffer-name", type=str, default=None) parser.add_argument( - '--icm-lr-scale', + "--icm-lr-scale", type=float, default=0., - help='use intrinsic curiosity module with this lr scale' + help="use intrinsic curiosity module with this lr scale", ) parser.add_argument( - '--icm-reward-scale', + "--icm-reward-scale", type=float, default=0.01, - help='scaling factor for intrinsic curiosity reward' + help="scaling factor for intrinsic curiosity reward", ) parser.add_argument( - '--icm-forward-loss-weight', + "--icm-forward-loss-weight", type=float, default=0.2, - help='weight for the forward model loss in ICM' + help="weight for the forward model loss in ICM", ) return parser.parse_args() def test_ppo(args=get_args()): - args.cfg_path = f"maps/{args.task}.cfg" - args.wad_path = f"maps/{args.task}.wad" - args.res = (args.skip_num, 84, 84) - env = Env(args.cfg_path, args.frames_stack, args.res) - args.state_shape = args.res + # make environments + env, train_envs, test_envs = make_vizdoom_env( + args.task, args.skip_num, (args.frames_stack, 84, 84), args.save_lmp, + args.seed, args.training_num, args.test_num + ) + args.state_shape = env.observation_space.shape args.action_shape = env.action_space.shape or env.action_space.n # should be N_FRAMES x H x W print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) - # make environments - train_envs = ShmemVectorEnv( - [ - lambda: Env(args.cfg_path, args.frames_stack, args.res) - for _ in range(args.training_num) - ] - ) - test_envs = ShmemVectorEnv( - [ - lambda: Env(args.cfg_path, args.frames_stack, args.res, args.save_lmp) - for _ in range(min(os.cpu_count() - 1, args.test_num)) - ] - ) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) - train_envs.seed(args.seed) - test_envs.seed(args.seed) # define model net = DQN( *args.state_shape, args.action_shape, device=args.device, features_only=True, - output_dim=args.hidden_size + output_dim=args.hidden_size, ) actor = Actor(net, args.action_shape, device=args.device, softmax_output=False) critic = Critic(net, device=args.device) @@ -159,7 +153,7 @@ def dist(p): value_clip=args.value_clip, dual_clip=args.dual_clip, advantage_normalization=args.norm_adv, - 
recompute_advantage=args.recompute_adv + recompute_advantage=args.recompute_adv, ).to(args.device) if args.icm_lr_scale > 0: feature_net = DQN( @@ -167,7 +161,7 @@ def dist(p): args.action_shape, device=args.device, features_only=True, - output_dim=args.hidden_size + output_dim=args.hidden_size, ) action_dim = np.prod(args.action_shape) feature_dim = feature_net.output_dim @@ -190,26 +184,40 @@ def dist(p): buffer_num=len(train_envs), ignore_obs_next=True, save_only_last_obs=True, - stack_num=args.frames_stack + stack_num=args.frames_stack, ) # collector train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) test_collector = Collector(policy, test_envs, exploration_noise=True) + # log - log_name = 'ppo_icm' if args.icm_lr_scale > 0 else 'ppo' - log_path = os.path.join(args.logdir, args.task, log_name) + now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") + args.algo_name = "ppo_icm" if args.icm_lr_scale > 0 else "ppo" + log_name = os.path.join(args.task, args.algo_name, str(args.seed), now) + log_path = os.path.join(args.logdir, log_name) + + # logger + if args.logger == "wandb": + logger = WandbLogger( + save_interval=1, + name=log_name.replace(os.path.sep, "__"), + run_id=args.resume_id, + config=args, + project=args.wandb_project, + ) writer = SummaryWriter(log_path) writer.add_text("args", str(args)) - logger = TensorboardLogger(writer) + if args.logger == "tensorboard": + logger = TensorboardLogger(writer) + else: # wandb + logger.load(writer) def save_best_fn(policy): - torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) + torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards): if env.spec.reward_threshold: return mean_rewards >= env.spec.reward_threshold - elif 'Pong' in args.task: - return mean_rewards >= 20 else: return False @@ -225,7 +233,7 @@ def watch(): buffer_num=len(test_envs), ignore_obs_next=True, save_only_last_obs=True, - stack_num=args.frames_stack + stack_num=args.frames_stack, ) collector = Collector(policy, test_envs, buffer, exploration_noise=True) result = collector.collect(n_step=args.buffer_size) @@ -263,12 +271,12 @@ def watch(): stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, - test_in_train=False + test_in_train=False, ) pprint.pprint(result) watch() -if __name__ == '__main__': +if __name__ == "__main__": test_ppo(get_args())
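
Both training scripts above wire up logging identically. For reference, here is that shared pattern as a standalone helper; this is only a sketch (`make_logger` itself is hypothetical and not part of the diff), but each call it makes mirrors `vizdoom_c51.py` / `vizdoom_ppo.py` above, and `args` is expected to carry the `logger`, `resume_id`, and `wandb_project` fields added to the argument parsers:

```python
import os

from torch.utils.tensorboard import SummaryWriter

from tianshou.utils import TensorboardLogger, WandbLogger


def make_logger(args, log_name, log_path):
    """Return a TensorboardLogger or WandbLogger, mirroring the two scripts above."""
    if args.logger == "wandb":
        logger = WandbLogger(
            save_interval=1,
            name=log_name.replace(os.path.sep, "__"),
            run_id=args.resume_id,
            config=args,
            project=args.wandb_project,
        )
    # The TensorBoard writer is created in both cases: WandbLogger reuses it via load().
    writer = SummaryWriter(log_path)
    writer.add_text("args", str(args))
    if args.logger == "tensorboard":
        logger = TensorboardLogger(writer)
    else:  # wandb
        logger.load(writer)
    return logger
```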