Vanilla policy gradient: no more optional batch size
+ make names consistent with the solution
+ remove trailing whitespace in the solution (apparently)


Former-commit-id: 9dc4aea
ddorn committed Aug 28, 2023
1 parent 186bfd9 commit 201be50
Showing 2 changed files with 9 additions and 11 deletions.
days/w1d5/1_simple_pg_correction.py (17 changes: 8 additions & 9 deletions)
@@ -58,25 +58,24 @@ def train(env_name='CartPole-v0', hidden_sizes=[32], lr=1e-2,
# make function to compute action distribution
# What is the shape of obs?
@typechecked
- def get_policy(obs: TensorType[..., obs_dim]):
-     # Warning: obs does not always have the same shape.
+ def get_policy(obs: TensorType["b", obs_dim]):
logits = logits_net(obs)
return Categorical(logits=logits)

# make action selection function (outputs int actions, sampled from policy)
# What is the shape of obs?
@typechecked
def get_action(obs: TensorType[obs_dim]) -> int:
- return get_policy(obs).sample().item()
+ return get_policy(obs.unsqueeze(0)).sample().item()

# make loss function whose gradient, for the right data, is policy gradient
# What does the weights parameter represent here?
# What is the shape of obs?
# Answer: b here is the sum of the lengths of all the episodes.
@typechecked
def compute_loss(obs: TensorType["b", obs_dim], act: TensorType["b"], weights: TensorType["b"]):
logp : TensorType["b"] = get_policy(obs).log_prob(act)
return -(logp * weights).mean()
def compute_loss(obs: TensorType["b", obs_dim], acts: TensorType["b"], rewards: TensorType["b"]):
logp : TensorType["b"] = get_policy(obs).log_prob(acts)
return -(logp * rewards).mean()

# make optimizer
optimizer = Adam(logits_net.parameters(), lr=lr)
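
For reference, here is a minimal self-contained sketch of the pattern in the hunk above, with assumed CartPole-sized dimensions and a plain linear policy network (not the repo's exact code): get_policy now always expects a leading batch dimension, get_action unsqueezes a single observation before sampling, and compute_loss builds the pseudo-loss whose gradient is the policy-gradient estimate.

import torch
from torch.distributions import Categorical

obs_dim, n_acts = 4, 2                        # CartPole-v0 sizes, assumed for this sketch
logits_net = torch.nn.Linear(obs_dim, n_acts)

def get_policy(obs):                          # obs: (b, obs_dim); the batch dim is mandatory
    return Categorical(logits=logits_net(obs))

def get_action(obs):                          # obs: (obs_dim,), a single observation
    return get_policy(obs.unsqueeze(0)).sample().item()

def compute_loss(obs, acts, rewards):         # each batched over b = sum of episode lengths
    logp = get_policy(obs).log_prob(acts)     # shape (b,)
    return -(logp * rewards).mean()           # gradient of this is the policy-gradient estimate

b = 6                                         # made-up batch of 6 timesteps
loss = compute_loss(torch.randn(b, obs_dim), torch.randint(n_acts, (b,)), torch.randn(b))
loss.backward()                               # populates grads on logits_net's parameters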
@@ -91,7 +90,7 @@ def train_one_epoch():
batch_lens = [] # for measuring episode lengths

# reset episode-specific variables
- obs = env.reset() # first obs comes from starting distribution
+ obs = env.reset()       # first obs comes from starting distribution
done = False # signal from environment that episode is over
ep_rews = [] # list for rewards accrued throughout ep

@@ -139,8 +138,8 @@ def train_one_epoch():
# take a single policy gradient update step
optimizer.zero_grad()
batch_loss = compute_loss(obs=torch.as_tensor(batch_obs, dtype=torch.float32),
- act=torch.as_tensor(batch_acts, dtype=torch.int32),
- weights=torch.as_tensor(batch_weights, dtype=torch.float32)
+ acts=torch.as_tensor(batch_acts, dtype=torch.int32),
+ rewards=torch.as_tensor(batch_weights, dtype=torch.float32)
)
batch_loss.backward()
optimizer.step()
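
As a side note, here is a tiny hypothetical illustration of the conversion in the hunk above: the per-timestep Python lists collected during the epoch become the batched tensors that compute_loss expects, with b equal to the total number of timesteps across all episodes.

import torch

batch_obs = [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]]      # made-up observations (b=2)
batch_acts = [0, 1]                                            # made-up actions
batch_weights = [1.0, 1.0]                                     # placeholder weights

obs = torch.as_tensor(batch_obs, dtype=torch.float32)          # shape (b, obs_dim)
acts = torch.as_tensor(batch_acts, dtype=torch.int32)          # shape (b,)
rewards = torch.as_tensor(batch_weights, dtype=torch.float32)  # shape (b,)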
days/w1d5/vanilla_policy_gradient.ipynb (3 changes: 1 addition & 2 deletions)
@@ -88,7 +88,6 @@
" # What is the shape of obs?\n",
" @typechecked # To be typed\n",
" def get_policy(obs):\n",
" # Warning: obs sometimes has a batch dimension, sometimes there is no such dimension\n",
" logits = logits_net(obs)\n",
" # Tip: Categorical is a convenient pytorch object which enable register logits (or a batch of logits)\n",
" # and then being able to sample from this pseudo-probability distribution with the \".sample()\" method.\n",
@@ -98,7 +97,7 @@
" # What is the shape of obs?\n",
" @typechecked # To be typed\n",
" def get_action(obs):\n",
" return get_policy(obs).sample().item()\n",
" return get_policy(obs.unsquezze(0)).sample().item()\n",
"\n",
" # make loss function whose gradient, for the right data, is policy gradient\n",
" # What is the shape of obs?\n",
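
To illustrate the Categorical tip quoted above, here is a small standalone example with made-up logits (not from the notebook):

import torch
from torch.distributions import Categorical

logits = torch.tensor([[2.0, 0.5], [0.1, 1.0]])   # a batch of two sets of action logits
dist = Categorical(logits=logits)                  # softmax over the last dim gives the probabilities
actions = dist.sample()                            # shape (2,), one sampled action per row
log_probs = dist.log_prob(actions)                 # shape (2,), log-probability of each sampled action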
