Housekeeping and static code analysis (#15)
* Move to pyproject.toml and update CODEOWNERS

* Increase build requirements to ensure editable build

* Add basic pylint checks

* Add mypy and make it happy

* Add and run isort
m-naumann authored Jan 18, 2024
1 parent 24c3b13 commit 2dea7b8
Showing 12 changed files with 107 additions and 67 deletions.
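The checks introduced here (black, isort, pylint, mypy) can also be run locally before pushing. The snippet below is a minimal sketch of such a helper, assuming the package was installed with the new dev extra (python -m pip install -e .[dev]); the script name run_checks.py and the use of subprocess are illustrative and not part of this commit. Note that CI runs plain "isort ." (which rewrites files), while the sketch uses --check-only so a local run fails instead of modifying files.

    # run_checks.py -- hypothetical local helper mirroring the CI steps below.
    import subprocess
    import sys

    CHECKS = [
        ["black", "--check", "."],
        ["isort", "--check-only", "."],
        ["pylint", "src/behavior_generation_lecture_python/mdp", "--errors-only"],
        ["mypy", "src/behavior_generation_lecture_python/mdp"],
        ["pytest"],
    ]


    def main() -> int:
        failed = False
        for cmd in CHECKS:
            print("Running:", " ".join(cmd))
            # Each tool signals failure via a non-zero exit code, as in CI.
            if subprocess.run(cmd, check=False).returncode != 0:
                failed = True
        return 1 if failed else 0


    if __name__ == "__main__":
        sys.exit(main())
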
13 changes: 11 additions & 2 deletions .github/workflows/ci.yml
@@ -18,11 +18,20 @@ jobs:
python-version: 3.7

- name: set up env
run: python -m pip install -e .
run: python -m pip install -e .[docs,dev]

- name: run black
run: black --check .

- name: run isort
run: isort .

- name: run pylint for mdp folder
run: pylint src/behavior_generation_lecture_python/mdp --errors-only

- name: run mypy for mdp folder
run: mypy src/behavior_generation_lecture_python/mdp

- name: test
run: pytest

@@ -49,7 +58,7 @@ jobs:
python-version: 3.7

- name: set up env
run: python -m pip install -e .
run: python -m pip install -e .[docs]

- name: copy notebooks to docs folder
run: cp -r notebooks/* docs/notebooks
6 changes: 3 additions & 3 deletions CODEOWNERS
@@ -1,3 +1,3 @@
* m-naumann@naumail.de
* [email protected]
* [email protected]
* @m-naumann
* @jtruetsch
* @keroe
51 changes: 51 additions & 0 deletions pyproject.toml
@@ -0,0 +1,51 @@
[project]
name = "behavior_generation_lecture_python"
version = "0.0.2"
description = "Python code for the respective lecture at KIT"
readme = "README.md"
requires-python = ">=3.7"
license = {file = "LICENSE"}
authors = [
{name = "Organizers of the lecture 'Verhaltensgenerierung für Fahrzeuge' at KIT" }
]
maintainers = [
{name = "Maximilian Naumann", email = "[email protected]" }
]

dependencies = [
"numpy",
"matplotlib>=2.2.4",
"scipy",
"jupyter",
"python-statemachine"
]

[project.optional-dependencies]
dev = [
"black[jupyter]==22.3.0",
"pytest",
"pytest-cov>=3.0.0",
"pylint",
"mypy",
"isort"
]
docs = [
"mkdocs-material",
"mkdocs-jupyter",
"mkdocstrings[python]>=0.18",
"mkdocs-gen-files",
"mkdocs-literate-nav",
"mkdocs-section-index"
]

[project.urls] # Optional
"Homepage" = "https://kit-mrt.github.io/behavior_generation_lecture_python/"
"Bug Reports" = "hhttps://github.com/KIT-MRT/behavior_generation_lecture_python/issues"
"Source" = "https://github.com/KIT-MRT/behavior_generation_lecture_python"

[build-system]
requires = ["setuptools>=64.0.0", "wheel", "pip>=21.3.0"]
build-backend = "setuptools.build_meta"

[tool.isort]
profile = "black"
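The [tool.isort] section sets profile = "black", which makes isort's output compatible with black (line length, trailing commas, parenthesized multi-line imports). As an illustrative before/after sketch (a hypothetical module, not taken from this commit), isort groups imports into standard-library, third-party and first-party blocks and sorts names alphabetically within each import, which is exactly the reordering visible in the a_star.py and test diffs further down:

    # Before running isort (hypothetical module):
    from behavior_generation_lecture_python.graph_search.a_star import Node, Graph
    import numpy as np
    from typing import Dict, List

    # After "isort ." with profile = "black":
    from typing import Dict, List

    import numpy as np

    from behavior_generation_lecture_python.graph_search.a_star import Graph, Node
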
14 changes: 0 additions & 14 deletions requirements.txt

This file was deleted.

2 changes: 1 addition & 1 deletion scripts/run_a_star.py
@@ -1,6 +1,6 @@
import numpy as np

from behavior_generation_lecture_python.graph_search.a_star import Node, Graph
from behavior_generation_lecture_python.graph_search.a_star import Graph, Node


def main():
16 changes: 0 additions & 16 deletions setup.py

This file was deleted.

4 changes: 2 additions & 2 deletions src/behavior_generation_lecture_python/graph_search/a_star.py
@@ -1,10 +1,10 @@
from __future__ import annotations

from typing import Dict, List, Set

import matplotlib.pyplot as plt
import numpy as np

from typing import Dict, List, Set


class Node:
def __init__(
51 changes: 30 additions & 21 deletions src/behavior_generation_lecture_python/mdp/mdp.py
@@ -93,7 +93,7 @@ def __init__(
for action in self.actions:
if (state, action) not in transition_probabilities:
continue
total_prob = 0
total_prob = 0.0
for prob, next_state in transition_probabilities[(state, action)]:
assert (
next_state in self.states
@@ -151,14 +151,17 @@ def sample_next_state(self, state, action) -> Any:
return prob_per_transition[choice][1]


GridState = Tuple[int, int]


class GridMDP(MDP):
def __init__(
self,
grid: List[List[Union[float, None]]],
initial_state: Tuple[int, int],
terminal_states: Set[Tuple[int, int]],
initial_state: GridState,
terminal_states: Set[GridState],
transition_probabilities_per_action: Dict[
Tuple[int, int], List[Tuple[float, Tuple[int, int]]]
GridState, List[Tuple[float, GridState]]
],
restrict_actions_to_available_states: Optional[bool] = False,
) -> None:
@@ -186,7 +189,9 @@ def __init__(
for y in range(rows):
if grid[y][x] is not None:
states.add((x, y))
reward[(x, y)] = grid[y][x]
reward_xy = grid[y][x]
assert reward_xy is not None
reward[(x, y)] = reward_xy

transition_probabilities = {}
for state in states:
@@ -260,8 +265,11 @@ def _next_state_deterministic(
return state


StateValueTable = Dict[Any, float]


def expected_utility_of_action(
mdp: MDP, state: Any, action: Any, utility_of_states: Dict[Any, float]
mdp: MDP, state: Any, action: Any, utility_of_states: StateValueTable
) -> float:
"""Compute the expected utility of taking an action in a state.
@@ -283,7 +291,7 @@ def expected_utility_of_action(
)


def derive_policy(mdp: MDP, utility_of_states: Dict[Any, float]) -> Dict[Any, Any]:
def derive_policy(mdp: MDP, utility_of_states: StateValueTable) -> Dict[Any, Any]:
"""Compute the best policy for an MDP given the utility of the states.
Args:
@@ -310,7 +318,7 @@ def value_iteration(
epsilon: float,
max_iterations: int,
return_history: Optional[bool] = False,
) -> Union[Dict[Any, float], List[Dict[Any, float]]]:
) -> Union[StateValueTable, List[StateValueTable]]:
"""Derive a utility estimate by means of value iteration.
Args:
@@ -326,11 +334,11 @@
The final utility estimate, if return_history is false. The
history of utility estimates as list, if return_history is true.
"""
utility = {state: 0 for state in mdp.get_states()}
utility = {state: 0.0 for state in mdp.get_states()}
utility_history = [utility.copy()]
for _ in range(max_iterations):
utility_old = utility.copy()
max_delta = 0
max_delta = 0.0
for state in mdp.get_states():
utility[state] = max(
expected_utility_of_action(
@@ -348,8 +356,11 @@
raise RuntimeError(f"Did not converge in {max_iterations} iterations")


QTable = Dict[Tuple[Any, Any], float]


def best_action_from_q_table(
*, state: Any, available_actions: Set[Any], q_table: Dict[Tuple[Any, Any], float]
*, state: Any, available_actions: Set[Any], q_table: QTable
) -> Any:
"""Derive the best action from a Q table.
@@ -361,9 +372,9 @@
Returns:
The best action according to the Q table.
"""
available_actions = list(available_actions)
values = np.array([q_table[(state, action)] for action in available_actions])
action = available_actions[np.argmax(values)]
available_actions_list = list(available_actions)
values = np.array([q_table[(state, action)] for action in available_actions_list])
action = available_actions_list[np.argmax(values)]
return action


@@ -376,15 +387,13 @@
Returns:
A random action.
"""
available_actions = list(available_actions)
num_actions = len(available_actions)
available_actions_list = list(available_actions)
num_actions = len(available_actions_list)
choice = np.random.choice(num_actions)
return available_actions[choice]
return available_actions_list[choice]


def greedy_value_estimate_for_state(
*, q_table: Dict[Tuple[Any, Any], float], state: Any
) -> float:
def greedy_value_estimate_for_state(*, q_table: QTable, state: Any) -> float:
"""Compute the greedy (best possible) value estimate for a state from the Q table.
Args:
@@ -407,7 +416,7 @@ def q_learning(
epsilon: float,
iterations: int,
return_history: Optional[bool] = False,
) -> Dict[Tuple[Any, Any], float]:
) -> Union[QTable, List[QTable]]:
"""Derive a value estimate for state-action pairs by means of Q learning.
Args:
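Most of the mdp.py changes above exist to satisfy mypy rather than to change behavior: numeric accumulators are initialized with float literals so mypy infers float, a Set parameter is no longer re-bound to a list (available_actions_list instead of reassigning available_actions), possibly-None grid entries are narrowed with an assert before use, and recurring annotations are collapsed into the GridState, StateValueTable and QTable aliases. A minimal sketch of the pattern, assuming hypothetical names and not code from the repository:

    from typing import Any, Dict, Set, Tuple

    QTable = Dict[Tuple[Any, Any], float]  # alias keeps long signatures readable


    def greedy_value(*, state: Any, available_actions: Set[Any], q_table: QTable) -> float:
        # Re-binding available_actions to list(available_actions) would change its
        # declared type from Set[Any] to List[Any], which mypy reports as an
        # incompatible assignment -- hence a new name.
        available_actions_list = list(available_actions)
        best = 0.0  # float literal so mypy infers float, not int
        for action in available_actions_list:
            best = max(best, q_table[(state, action)])
        return best
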
@@ -1,4 +1,5 @@
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np

2 changes: 1 addition & 1 deletion tests/test_a_star.py
@@ -1,7 +1,7 @@
import matplotlib
import numpy as np

from behavior_generation_lecture_python.graph_search.a_star import Node, Graph
from behavior_generation_lecture_python.graph_search.a_star import Graph, Node


def test_example_graph():
6 changes: 3 additions & 3 deletions tests/test_mdp.py
@@ -5,13 +5,13 @@
MDP,
SIMPLE_MDP_DICT,
GridMDP,
best_action_from_q_table,
derive_policy,
expected_utility_of_action,
value_iteration,
best_action_from_q_table,
random_action,
greedy_value_estimate_for_state,
q_learning,
random_action,
value_iteration,
)


8 changes: 4 additions & 4 deletions tests/utils/test_grid_plotting.py
@@ -1,14 +1,14 @@
import matplotlib

from behavior_generation_lecture_python.utils.grid_plotting import (
make_plot_grid_step_function,
make_plot_policy_step_function,
)
from behavior_generation_lecture_python.mdp.mdp import (
GRID_MDP_DICT,
GridMDP,
derive_policy,
)
from behavior_generation_lecture_python.utils.grid_plotting import (
make_plot_grid_step_function,
make_plot_policy_step_function,
)

TRUE_UTILITY_GRID_MDP = {
(0, 0): 0.705,
