diff --git a/workshops/vanilla_policy_gradient/vanilla_policy_gradient.ipynb b/workshops/vanilla_policy_gradient/vanilla_policy_gradient.ipynb
new file mode 100644
index 00000000..c604be5d
--- /dev/null
+++ b/workshops/vanilla_policy_gradient/vanilla_policy_gradient.ipynb
@@ -0,0 +1,225 @@
+{
+ "cells": [
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Vanilla Policy Gradient\n",
+    "\n",
+    "Preliminary questions:\n",
+    "- Run the script with the default parameters from the terminal\n",
+    "- Explain `from torch.distributions.categorical import Categorical`\n",
+    "- Look up the gym / gymnasium Python library: why is it useful?\n",
+    "- Is policy gradient model-based or model-free?\n",
+    "- Is policy gradient on-policy or off-policy?\n",
+    "\n",
+    "Read all the code, then:\n",
+    "- Complete the ... in the compute_loss function.\n",
+    "- Use https://github.com/google/jaxtyping to type the functions get_policy and get_action. You can draw inspiration from the compute_loss function.\n",
+    "- Answer the questions asked in the code comments.\n",
+    "\n",
+    "\n",
+    "Don't begin working on this algorithm if you don't understand this blog post: https://spinningup.openai.com/en/latest/spinningup/rl_intro3.html\n",
+    "\n",
+    "This exercise is short, but you should aim to understand everything in this code. Simply completing the types is not sufficient. The important thing here is to have a good understanding of each line of code, as well as the policy gradient theorem that we are using."
+   ]
+  },
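+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "As a reminder (this is the result derived in the Spinning Up post linked above), the simplest policy gradient estimator weights each log-probability by the total return $R(\\tau)$ of the trajectory it came from:\n",
+    "\n",
+    "$$\\nabla_\\theta J(\\pi_\\theta) = \\mathbb{E}_{\\tau \\sim \\pi_\\theta}\\left[\\sum_{t=0}^{T} \\nabla_\\theta \\log \\pi_\\theta(a_t \\mid s_t)\\, R(\\tau)\\right], \\qquad \\hat{g} = \\frac{1}{|\\mathcal{D}|} \\sum_{\\tau \\in \\mathcal{D}} \\sum_{t=0}^{T} \\nabla_\\theta \\log \\pi_\\theta(a_t \\mid s_t)\\, R(\\tau).$$\n",
+    "\n",
+    "The `compute_loss` function below builds the corresponding surrogate loss $-\\log \\pi_\\theta(a_t \\mid s_t)\\, R(\\tau)$, averaged over every timestep in the batch; its gradient points in the same direction as $\\hat{g}$ and differs from it only by a positive scaling factor (the number of episodes divided by the number of timesteps collected)."
+   ]
+  },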
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "!pip install jaxtyping typeguard==2.13.3 gymnasium"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import torch\n",
+    "import torch.nn as nn\n",
+    "from torch import Tensor\n",
+    "from torch.distributions.categorical import Categorical\n",
+    "from torch.optim import Adam\n",
+    "import numpy as np\n",
+    "import gymnasium as gym\n",
+    "from gymnasium.spaces import Discrete, Box\n",
+    "\n",
+    "from jaxtyping import Float, Int, jaxtyped\n",
+    "from typeguard import typechecked\n",
+    "\n",
+    "\n",
+    "def mlp(sizes, activation=nn.Tanh, output_activation=nn.Identity):\n",
+    "    # Build a feedforward neural network.\n",
+    "    layers = []\n",
+    "    for j in range(len(sizes)-1):\n",
+    "        act = activation if j < len(sizes)-2 else output_activation\n",
+    "        layers += [nn.Linear(sizes[j], sizes[j+1]), act()]\n",
+    "\n",
+    "    # What does * mean here? Search for argument unpacking in Python.\n",
+    "    return nn.Sequential(*layers)\n",
+    "\n",
+    "def train(env_name='CartPole-v0', hidden_sizes=[32], lr=1e-2,\n",
+    "          epochs=50, batch_size=5000, render=False):\n",
+    "\n",
+    "    # make environment, check spaces, get obs / act dims\n",
+    "    env = gym.make(env_name)\n",
+    "    assert isinstance(env.observation_space, Box), \\\n",
+    "        \"This example only works for envs with continuous state spaces.\"\n",
+    "    assert isinstance(env.action_space, Discrete), \\\n",
+    "        \"This example only works for envs with discrete action spaces.\"\n",
+    "\n",
+    "    obs_dim = env.observation_space.shape[0]\n",
+    "    n_acts = env.action_space.n\n",
+    "\n",
+    "    # Core of the policy network\n",
+    "    # What should be the sizes of the layers of the policy network?\n",
+    "    logits_net = mlp(sizes=[obs_dim]+hidden_sizes+[n_acts])\n",
+    "\n",
+    "    # make function to compute the action distribution\n",
+    "    @jaxtyped  # Checks that the sizes are consistent between tensors\n",
+    "    @typechecked  # TODO: What is the shape of obs?\n",
+    "    def get_policy(obs: Float[Tensor, \"batch obs_dim\"]) -> Categorical:\n",
+    "        logits = logits_net(obs)\n",
+    "        # Tip: Categorical is a convenient PyTorch object that stores logits (or a batch of logits)\n",
+    "        # and lets you sample from the corresponding probability distribution with the \".sample()\" method.\n",
+    "        return Categorical(logits=logits)\n",
+    "\n",
+    "    # make action selection function (outputs int actions, sampled from policy)\n",
+    "    @jaxtyped\n",
+    "    @typechecked  # TODO: What is the shape of obs?\n",
+    "    def get_action(obs: Float[Tensor, \"obs_dim\"]) -> int:\n",
+    "        return get_policy(obs.unsqueeze(0)).sample().item()\n",
+    "\n",
+    "    # make loss function whose gradient, for the right data, is the policy gradient\n",
+    "    @jaxtyped\n",
+    "    @typechecked\n",
+    "    def compute_loss(obs: Float[Tensor, \"batch obs_dim\"], acts: Int[Tensor, \"batch\"], rewards: Float[Tensor, \"batch\"]) -> Float[Tensor, \"\"]:\n",
+    "        # TODO:\n",
+    "        # rewards: a piecewise-constant vector containing, for each timestep, the total return of the episode it belongs to.\n",
+    "\n",
+    "        # Use the get_policy function to get the Categorical object, then compute the\n",
+    "        # log-probabilities of the taken actions with its 'log_prob' method.\n",
+    "        log_probs = get_policy(obs).log_prob(acts)\n",
+    "        return -(log_probs * rewards).mean()\n",
+    "\n",
+    "\n",
+    "    # make optimizer\n",
+    "    optimizer = Adam(logits_net.parameters(), lr=lr)\n",
+    "\n",
+    "    # for training the policy\n",
+    "    def train_one_epoch():\n",
+    "        # make some empty lists for logging.\n",
+    "        batch_obs = []      # for observations\n",
+    "        batch_acts = []     # for actions\n",
+    "        batch_weights = []  # for R(tau) weighting in policy gradient\n",
+    "        batch_rets = []     # for measuring episode returns  # What is the return?\n",
+    "        batch_lens = []     # for measuring episode lengths\n",
+    "\n",
+    "        # reset episode-specific variables\n",
+    "        obs, _ = env.reset()  # first obs comes from the starting distribution\n",
+    "        ep_rews = []          # list for rewards accrued throughout the episode\n",
+    "\n",
+    "        # render first episode of each epoch\n",
+    "        finished_rendering_this_epoch = False\n",
+    "\n",
+    "        # collect experience by acting in the environment with the current policy\n",
+    "        while True:\n",
+    "\n",
+    "            # rendering\n",
+    "            if (not finished_rendering_this_epoch) and render:\n",
+    "                env.render()\n",
+    "\n",
+    "            # save obs\n",
+    "            batch_obs.append(obs.copy())\n",
+    "\n",
+    "            # act in the environment\n",
+    "            act = get_action(torch.as_tensor(obs, dtype=torch.float32))\n",
+    "            obs, rew, terminated, truncated, _ = env.step(act)\n",
+    "\n",
+    "            # save action, reward\n",
+    "            batch_acts.append(act)\n",
+    "            ep_rews.append(rew)\n",
+    "\n",
+    "            if terminated or truncated:\n",
+    "                # if the episode is over, record info about it\n",
+    "                # Is the reward discounted?\n",
+    "                ep_ret, ep_len = sum(ep_rews), len(ep_rews)\n",
+    "                batch_rets.append(ep_ret)\n",
+    "                batch_lens.append(ep_len)\n",
+    "\n",
+    "                # the weight for each logprob(a|s) is R(tau)\n",
+    "                # Why do we use a constant vector here?\n",
+    "                batch_weights += [ep_ret] * ep_len\n",
+    "\n",
+    "                # reset episode-specific variables\n",
+    "                obs, _ = env.reset()\n",
+    "                ep_rews = []\n",
+    "\n",
+    "                # won't render again this epoch\n",
+    "                finished_rendering_this_epoch = True\n",
+    "\n",
+    "                # end the experience loop if we have enough of it\n",
+    "                if len(batch_obs) > batch_size:\n",
+    "                    break\n",
+    "\n",
+    "        # take a single policy gradient update step\n",
+    "        optimizer.zero_grad()\n",
+    "\n",
+    "        batch_loss = compute_loss(\n",
+    "            obs=torch.as_tensor(np.array(batch_obs), dtype=torch.float32),\n",
+    "            acts=torch.as_tensor(batch_acts, dtype=torch.int32),\n",
+    "            rewards=torch.as_tensor(batch_weights, dtype=torch.float32),\n",
+    "        )\n",
+    "        batch_loss.backward()\n",
+    "        optimizer.step()\n",
+    "        return batch_loss, batch_rets, batch_lens\n",
+    "\n",
+    "    # training loop\n",
+    "    for i in range(epochs):\n",
+    "        batch_loss, batch_rets, batch_lens = train_one_epoch()\n",
+    "        print('epoch: %3d \\t loss: %.3f \\t return: %.3f \\t ep_len: %.3f' %\n",
+    "              (i, batch_loss, np.mean(batch_rets), np.mean(batch_lens)))\n",
+    "\n",
+    "\n",
+    "train(env_name='CartPole-v0', hidden_sizes=[32], lr=1e-2,\n",
+    "      epochs=50, batch_size=50, render=False)"
+   ]
+  },
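+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "A minimal, optional illustration of the `Categorical` object used in `get_policy` (not part of the exercise): it is constructed from unnormalised logits and exposes `.sample()` and `.log_prob()`. The logits below are arbitrary values chosen only for this demo."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import torch\n",
+    "from torch.distributions.categorical import Categorical\n",
+    "\n",
+    "# Arbitrary logits for a batch of 2 states and 3 possible actions (demo values only).\n",
+    "logits = torch.tensor([[2.0, 0.5, -1.0], [0.0, 0.0, 0.0]])\n",
+    "dist = Categorical(logits=logits)\n",
+    "\n",
+    "print(dist.probs)                  # normalised probabilities, each row sums to 1\n",
+    "acts = dist.sample()               # shape (2,): one sampled action per row\n",
+    "print(acts, dist.log_prob(acts))   # log pi(a|s) of each sampled action"
+   ]
+  },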
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Original algorithm here: https://github.com/openai/spinningup/blob/master/spinup/algos/pytorch/vpg/vpg.py"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": ".venv",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.11.3"
+  },
+  "orig_nbformat": 4
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}