From a376a86a94c075e62db8e32039b439d6e984c04b Mon Sep 17 00:00:00 2001 From: Fabian Konstantinidis Date: Fri, 2 Feb 2024 14:10:52 +0100 Subject: [PATCH] added typehints --- notebooks/mdp_policy_gradient.ipynb | 3 ++- src/behavior_generation_lecture_python/mdp/policy.py | 6 +++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/notebooks/mdp_policy_gradient.ipynb b/notebooks/mdp_policy_gradient.ipynb index 137a844..0759ff9 100644 --- a/notebooks/mdp_policy_gradient.ipynb +++ b/notebooks/mdp_policy_gradient.ipynb @@ -72,7 +72,8 @@ "outputs": [], "source": [ "policy_array = [\n", - " derive_deterministic_policy(mdp=grid_mdp, policy=model) for model in model_checkpoints\n", + " derive_deterministic_policy(mdp=grid_mdp, policy=model)\n", + " for model in model_checkpoints\n", "]" ] }, diff --git a/src/behavior_generation_lecture_python/mdp/policy.py b/src/behavior_generation_lecture_python/mdp/policy.py index d85ecde..19b712e 100644 --- a/src/behavior_generation_lecture_python/mdp/policy.py +++ b/src/behavior_generation_lecture_python/mdp/policy.py @@ -7,7 +7,11 @@ from torch.distributions.categorical import Categorical -def multi_layer_perceptron(sizes, activation=nn.ReLU, output_activation=nn.Identity): +def multi_layer_perceptron( + sizes: List[int], + activation: torch.nn.Module = nn.ReLU, + output_activation: torch.nn.Module = nn.Identity, +): """Returns a multi-layer perceptron""" mlp = nn.Sequential() for i in range(len(sizes) - 1):