Parallel-envs-friendly ppo_continuous_action.py #348

Draft · wants to merge 2 commits into master
Conversation

@vwxyzjn (Owner) commented Jan 13, 2023

Description

This PR modifies ppo_continuous_action.py to make it more parallel-envs-friendly. CC @kevinzakka.

The version of ppo_continuous_action.py in this PR is different from that in the master branch in the following ways:

  1. use a different set of hyperparameters that leverage more simulation environments (e.g., 64 parallel environments); the resulting batch-size arithmetic is worked out just after this list
    # Algorithm specific arguments
    parser.add_argument("--env-id", type=str, default="Ant-v4",
        help="the id of the environment")
    parser.add_argument("--total-timesteps", type=int, default=10000000,
        help="total timesteps of the experiments")
    parser.add_argument("--learning-rate", type=float, default=0.00295,
        help="the learning rate of the optimizer")
    parser.add_argument("--num-envs", type=int, default=64,
        help="the number of parallel game environments")
    parser.add_argument("--num-steps", type=int, default=64,
        help="the number of steps to run in each environment per policy rollout")
    parser.add_argument("--anneal-lr", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
        help="Toggle learning rate annealing for policy and value networks")
    parser.add_argument("--gamma", type=float, default=0.99,
        help="the discount factor gamma")
    parser.add_argument("--gae-lambda", type=float, default=0.95,
        help="the lambda for the general advantage estimation")
    parser.add_argument("--num-minibatches", type=int, default=4,
        help="the number of mini-batches")
    parser.add_argument("--update-epochs", type=int, default=2,
        help="the K epochs to update the policy")
    parser.add_argument("--norm-adv", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
        help="Toggles advantages normalization")
    parser.add_argument("--clip-coef", type=float, default=0.2,
        help="the surrogate clipping coefficient")
    parser.add_argument("--clip-vloss", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
        help="Toggles whether or not to use a clipped loss for the value function, as per the paper.")
    parser.add_argument("--ent-coef", type=float, default=0.0,
        help="coefficient of the entropy")
    parser.add_argument("--vf-coef", type=float, default=1.3,
        help="coefficient of the value function")
    parser.add_argument("--max-grad-norm", type=float, default=3.5,
        help="the maximum norm for the gradient clipping")
    parser.add_argument("--target-kl", type=float, default=None,
        help="the target KL divergence threshold")
  2. use gym.vector.AsyncVectorEnv instead of gym.vector.SyncVectorEnv to speed things up further
    envs = gym.vector.AsyncVectorEnv(
        [make_env(args.env_id, i, args.capture_video, run_name) for i in range(args.num_envs)]
    )
  3. apply the normalize wrappers at the vectorized-envs level instead of the individual-env level, so the running mean and std for the observations and returns are computed over the whole batch of observations and rewards. In my experience, this is usually preferable to maintaining the normalize wrappers in each sub-env. When num_envs=1, it should not cause any performance difference
    envs = gym.wrappers.ClipAction(envs)
    envs = gym.wrappers.NormalizeObservation(envs)
    envs = gym.wrappers.TransformObservation(envs, lambda obs: np.clip(obs, -10, 10))
    envs = gym.wrappers.NormalizeReward(envs, gamma=args.gamma)
    envs = gym.wrappers.TransformReward(envs, lambda reward: np.clip(reward, -10, 10))
    • one thing that would be worth trying is removing the normalize wrappers, which should improve SPS. Or, in the case of JAX, rewriting and jitting the normalize wrappers may improve SPS as well; a rough sketch of that idea follows this list.
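
For a quick sense of how the defaults in item 1 combine, here is the batch-size arithmetic (illustrative only; it mirrors the batch-size bookkeeping the script already does with args.batch_size and args.minibatch_size):

    num_envs = 64        # --num-envs
    num_steps = 64       # --num-steps
    num_minibatches = 4  # --num-minibatches
    total_timesteps = 10_000_000  # --total-timesteps

    batch_size = num_envs * num_steps               # 4096 transitions per rollout
    minibatch_size = batch_size // num_minibatches  # 1024 transitions per gradient step
    num_updates = total_timesteps // batch_size     # 2441 policy updates over training
    print(batch_size, minibatch_size, num_updates)  # 4096 1024 2441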
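
On the jitting note in item 3: a minimal sketch, assuming the running mean/var is carried as explicit state with a Welford-style update, of what a jitted observation normalizer could look like in JAX. The RunningStats and normalize_obs names are hypothetical and not part of this PR; a reward normalizer would additionally need a discounted-return accumulator, so this only covers the observation side.

    import jax
    import jax.numpy as jnp
    from typing import NamedTuple

    class RunningStats(NamedTuple):
        mean: jnp.ndarray   # running mean per observation dimension
        var: jnp.ndarray    # running variance per observation dimension
        count: jnp.ndarray  # number of observations seen so far

    @jax.jit
    def normalize_obs(stats: RunningStats, obs: jnp.ndarray):
        # fold the incoming batch (num_envs, obs_dim) into the running statistics
        batch_mean, batch_var = obs.mean(axis=0), obs.var(axis=0)
        batch_count = obs.shape[0]
        delta = batch_mean - stats.mean
        total = stats.count + batch_count
        new_mean = stats.mean + delta * batch_count / total
        new_var = (stats.var * stats.count + batch_var * batch_count
                   + jnp.square(delta) * stats.count * batch_count / total) / total
        new_stats = RunningStats(new_mean, new_var, total)
        # normalize and clip, mirroring NormalizeObservation + TransformObservation
        norm = jnp.clip((obs - new_mean) / jnp.sqrt(new_var + 1e-8), -10.0, 10.0)
        return new_stats, norm

    # usage (obs_dim assumed):
    #   stats = RunningStats(jnp.zeros(obs_dim), jnp.ones(obs_dim), jnp.array(1e-4))
    #   stats, norm_obs = normalize_obs(stats, next_obs)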

I also added a JAX variant that reached the same level of performance.

[image]

Types of changes

  • Bug fix
  • New feature
  • New algorithm
  • Documentation

Checklist:

  • I've read the CONTRIBUTION guide (required).
  • I have ensured pre-commit run --all-files passes (required).
  • I have updated the documentation and previewed the changes via mkdocs serve.
  • I have updated the tests accordingly (if applicable).

If you are adding new algorithm variants or your change could result in a performance difference, you may need to (re-)run tracked experiments. See #137 as an example PR.

  • I have contacted vwxyzjn to obtain access to the openrlbenchmark W&B team (required).
  • I have tracked applicable experiments in openrlbenchmark/cleanrl with --capture-video flag toggled on (required).
  • I have added additional documentation and previewed the changes via mkdocs serve.
    • I have explained note-worthy implementation details.
    • I have explained the logged metrics.
    • I have added links to the original paper and related papers (if applicable).
    • I have added links to the PR related to the algorithm variant.
    • I have created a table comparing my results against those from reputable sources (i.e., the original paper or other reference implementation).
    • I have added the learning curves (in PNG format).
    • I have added links to the tracked experiments.
    • I have updated the overview sections in the docs and the repo.
  • I have updated the tests accordingly (if applicable).

@kevinzakka

Thank you @vwxyzjn! I'll give this a spin.

@vwxyzjn mentioned this pull request Jan 31, 2023
@varadVaidya

Hello. Thanks a lot for implementing PPO in JAX in such a clean fashion. However, while reproducing the results, I am facing the following issue.

Traceback (most recent call last):
  File "/scratch/vaidya/mujoco_sims/gym_mujoco_drones/gym_mujoco_drones/cleanrl_jax_ppo.py", line 199, in <module>
    agent_state = TrainState.create(
  File "/scratch/vaidya/miniconda3/envs/sbx-gpu/lib/python3.10/site-packages/flax/training/train_state.py", line 127, in create
    params['params'] if OVERWRITE_WITH_GRADIENT in params else params
TypeError: argument of type 'AgentParams' is not iterable
Exception ignored in: <function AsyncVectorEnv.__del__ at 0x7f6aa6d89630>
Traceback (most recent call last):
  File "/scratch/vaidya/miniconda3/envs/sbx-gpu/lib/python3.10/site-packages/gymnasium/vector/async_vector_env.py", line 549, in __del__
  File "/scratch/vaidya/miniconda3/envs/sbx-gpu/lib/python3.10/site-packages/gymnasium/vector/vector_env.py", line 272, in close
  File "/scratch/vaidya/miniconda3/envs/sbx-gpu/lib/python3.10/site-packages/gymnasium/vector/async_vector_env.py", line 465, in close_extras
AttributeError: 'NoneType' object has no attribute 'TimeoutError'

Since I am currently new to JAX, I am unable to debug the issue of AgentParams not being iterable on my own. I understand that this is a work in progress, but I would appreciate any pointers to solve this.
Thanks
