Skip to content

Commit

Permalink
Merge branch 'mdp_policy_gradient' into fabikonsti-mdp_policy_gradient
Browse files Browse the repository at this point in the history
  • Loading branch information
m-naumann committed Feb 3, 2024
2 parents 3d88f57 + 202a2f2 commit f1a9fbf
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 10 deletions.
12 changes: 4 additions & 8 deletions src/behavior_generation_lecture_python/mdp/mdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,9 +597,7 @@ def policy_gradient(
buffer.states.append(deepcopy(state))

# call model to get next action
action = policy.get_action(
state=torch.as_tensor(state, dtype=torch.float32)
)
action = policy.get_action(state=torch.tensor(state, dtype=torch.float32))

# execute action in the environment
state, reward, done = mdp.execute_action(state=state, action=action)
Expand Down Expand Up @@ -632,12 +630,10 @@ def policy_gradient(

# compute the loss
logp = policy.get_log_prob(
states=torch.as_tensor(buffer.states, dtype=torch.float32),
actions=torch.as_tensor(buffer.actions, dtype=torch.int32),
states=torch.tensor(buffer.states, dtype=torch.float),
actions=torch.tensor(buffer.actions, dtype=torch.long),
)
batch_loss = -(
logp * torch.as_tensor(buffer.weights, dtype=torch.float32)
).mean()
batch_loss = -(logp * torch.tensor(buffer.weights, dtype=torch.float)).mean()

# take a single policy gradient update step
optimizer.zero_grad()
Expand Down
8 changes: 6 additions & 2 deletions src/behavior_generation_lecture_python/mdp/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@
from torch.distributions.categorical import Categorical


def multi_layer_perceptron(sizes, activation=nn.ReLU, output_activation=nn.Identity):
def multi_layer_perceptron(
sizes: List[int],
activation: torch.nn.Module = nn.ReLU,
output_activation: torch.nn.Module = nn.Identity,
):
"""Returns a multi-layer perceptron"""
mlp = nn.Sequential()
for i in range(len(sizes) - 1):
Expand All @@ -25,7 +29,7 @@ def __init__(self, sizes: List[int], actions: List):
torch.manual_seed(1337)
self.net = multi_layer_perceptron(sizes=sizes)
self.actions = actions
self._actions_tensor = torch.as_tensor(actions, dtype=torch.float32).view(
self._actions_tensor = torch.tensor(actions, dtype=torch.long).view(
len(actions), -1
)

Expand Down

0 comments on commit f1a9fbf

Please sign in to comment.