Skip to content

Commit

Permalink
vanilla_policy_gradient: fix typo
Browse files Browse the repository at this point in the history
Former-commit-id: fc5e3eb
  • Loading branch information
ddorn committed Aug 28, 2023
1 parent 201be50 commit ab14e18
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion days/w1d5/1_simple_pg_correction.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def mlp(sizes, activation=nn.Tanh, output_activation=nn.Identity):
layers += [nn.Linear(sizes[j], sizes[j+1]), act()]
return nn.Sequential(*layers)

def train(env_name='CartPole-v0', hidden_sizes=[32], lr=1e-2,
def train(env_name='CartPole-v0', hidden_sizes=[32], lr=1e-2,
epochs=50, batch_size=5000, render=False):

# make environment, check spaces, get obs / act dims
Expand Down
2 changes: 1 addition & 1 deletion days/w1d5/vanilla_policy_gradient.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@
" # What is the shape of obs?\n",
" @typechecked # To be typed\n",
" def get_action(obs):\n",
" return get_policy(obs.unsquezze(0)).sample().item()\n",
" return get_policy(obs.unsqueeze(0)).sample().item()\n",
"\n",
" # make loss function whose gradient, for the right data, is policy gradient\n",
" # What is the shape of obs?\n",
Expand Down

0 comments on commit ab14e18

Please sign in to comment.