Skip to content

Commit

Permalink
added typehints
Browse files Browse the repository at this point in the history
  • Loading branch information
Fabian Konstantinidis authored and Fabian Konstantinidis committed Feb 2, 2024
1 parent f98c6c2 commit a376a86
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 2 deletions.
3 changes: 2 additions & 1 deletion notebooks/mdp_policy_gradient.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@
"outputs": [],
"source": [
"policy_array = [\n",
" derive_deterministic_policy(mdp=grid_mdp, policy=model) for model in model_checkpoints\n",
" derive_deterministic_policy(mdp=grid_mdp, policy=model)\n",
" for model in model_checkpoints\n",
"]"
]
},
Expand Down
6 changes: 5 additions & 1 deletion src/behavior_generation_lecture_python/mdp/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@
from torch.distributions.categorical import Categorical


def multi_layer_perceptron(sizes, activation=nn.ReLU, output_activation=nn.Identity):
def multi_layer_perceptron(
sizes: List[int],
activation: torch.nn.Module = nn.ReLU,
output_activation: torch.nn.Module = nn.Identity,
):
"""Returns a multi-layer perceptron"""
mlp = nn.Sequential()
for i in range(len(sizes) - 1):
Expand Down

0 comments on commit a376a86

Please sign in to comment.