Update Value-Iteration and Q-Learning to be in line with latest version of the script. (#13)

Adapted value iteration and Q-learning to the latest version of Russell and Norvig, Artificial Intelligence: A Modern Approach, 4th Edition (2020).
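For orientation, the change swaps the Bellman-update convention: the previous (3rd-edition style) code added the reward of the current state outside the maximization, while the 4th-edition convention folds the reward of the successor state into the expectation. Below is a minimal, illustrative sketch of the new convention using the `get_reward` and `get_transitions_with_probabilities` accessors that appear in the diff; it is not the repository code itself.

```python
def expected_utility_of_action_sketch(mdp, state, action, utility):
    # 3rd-edition style:  U(s) = R(s) + max_a sum_s' P(s'|s,a) * U(s')
    # 4th-edition style:   U(s) = max_a sum_s' P(s'|s,a) * (R(s') + U(s'))
    # (state-based rewards, no explicit discount factor, as in this repository)
    return sum(
        p * (mdp.get_reward(next_state) + utility[next_state])
        for p, next_state in mdp.get_transitions_with_probabilities(
            state=state, action=action
        )
    )
```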

Co-authored-by: Fabian Konstantinidis <[email protected]>
fabikonsti and Fabian Konstantinidis authored Jan 11, 2024
1 parent a849059 commit 24c3b13
Showing 3 changed files with 106 additions and 100 deletions.
12 changes: 6 additions & 6 deletions src/behavior_generation_lecture_python/mdp/mdp.py
@@ -276,7 +276,7 @@ def expected_utility_of_action(
         Expected utility
     """
     return sum(
-        p * utility_of_states[next_state]
+        p * (mdp.get_reward(next_state) + utility_of_states[next_state])
         for (p, next_state) in mdp.get_transitions_with_probabilities(
             state=state, action=action
         )
@@ -327,12 +327,12 @@ def value_iteration(
         history of utility estimates as list, if return_history is true.
     """
     utility = {state: 0 for state in mdp.get_states()}
-    utility_history = []
+    utility_history = [utility.copy()]
     for _ in range(max_iterations):
         utility_old = utility.copy()
         max_delta = 0
         for state in mdp.get_states():
-            utility[state] = mdp.get_reward(state) + max(
+            utility[state] = max(
                 expected_utility_of_action(
                     mdp, state=state, action=action, utility_of_states=utility_old
                 )
@@ -426,8 +426,8 @@ def q_learning(
     q_table = {}
     for state in mdp.get_states():
         for action in mdp.get_actions(state):
-            q_table[(state, action)] = mdp.get_reward(state)
-    q_table_history = []
+            q_table[(state, action)] = 0.0
+    q_table_history = [q_table.copy()]
     state = mdp.initial_state
 
     np.random.seed(1337)
@@ -455,7 +455,7 @@
         )
         q_table[(state, chosen_action)] = (1 - alpha) * q_table[
             (state, chosen_action)
-        ] + alpha * (mdp.get_reward(state) + greedy_value_estimate_next_state)
+        ] + alpha * (mdp.get_reward(next_state) + greedy_value_estimate_next_state)
 
         if return_history:
             q_table_history.append(q_table.copy())
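Under the same convention, the tabular Q-learning backup above now targets the reward of the sampled successor state. The following is a hedged sketch of one such update step, illustrative only; `q_update_sketch` and `best_value_next` are hypothetical names standing in for the repository's greedy value estimate, and the update is undiscounted as in the diff.

```python
def q_update_sketch(q_table, mdp, state, action, next_state, alpha):
    """One tabular Q-learning backup (sketch):
    Q(s, a) <- (1 - alpha) * Q(s, a) + alpha * (R(s') + max_a' Q(s', a'))."""
    next_actions = mdp.get_actions(next_state)
    # Greedy value estimate of the successor state (0.0 if it has no actions).
    best_value_next = max(
        (q_table[(next_state, a)] for a in next_actions), default=0.0
    )
    sample = mdp.get_reward(next_state) + best_value_next
    q_table[(state, action)] = (1 - alpha) * q_table[(state, action)] + alpha * sample
```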
62 changes: 39 additions & 23 deletions tests/test_mdp.py
@@ -31,7 +31,7 @@ def test_init_grid_mdp():
 
 def test_expected_utility():
     mdp = MDP(**SIMPLE_MDP_DICT)
-    assert 0.8 * 1 + 0.2 * 0.01 == expected_utility_of_action(
+    assert 0.8 * (-0.5 + 1) + 0.2 * (-0.1 + 0.01) == expected_utility_of_action(
         mdp=mdp, state=1, action="A", utility_of_states={1: 0.01, 2: 1}
     )
 
@@ -44,19 +44,19 @@ def test_derive_policy():
 
 def test_value_iteration():
     grid_mdp = GridMDP(**GRID_MDP_DICT)
-    epsilon = 0.001
+    epsilon = 0.0005
     true_utility = {
-        (0, 0): 0.705,
-        (0, 1): 0.762,
-        (0, 2): 0.812,
-        (1, 0): 0.655,
-        (1, 2): 0.868,
-        (2, 0): 0.611,
-        (2, 1): 0.660,
-        (2, 2): 0.918,
-        (3, 0): 0.388,
-        (3, 1): -1.0,
-        (3, 2): 1.0,
+        (0, 0): 0.745,
+        (0, 1): 0.802,
+        (0, 2): 0.852,
+        (1, 0): 0.695,
+        (1, 2): 0.908,
+        (2, 0): 0.651,
+        (2, 1): 0.700,
+        (2, 2): 0.958,
+        (3, 0): 0.428,
+        (3, 1): 0,
+        (3, 2): 0,
     }
 
     computed_utility = value_iteration(mdp=grid_mdp, epsilon=epsilon, max_iterations=30)
@@ -66,32 +66,45 @@
 
 def test_value_iteration_history():
     grid_mdp = GridMDP(**GRID_MDP_DICT)
-    epsilon = 0.001
+    epsilon = 0.0005
     true_utility_0 = {
+        (0, 0): 0,
+        (0, 1): 0,
+        (0, 2): 0,
+        (1, 0): 0,
+        (1, 2): 0,
+        (2, 0): 0,
+        (2, 1): 0,
+        (2, 2): 0,
+        (3, 0): 0,
+        (3, 1): 0,
+        (3, 2): 0,
+    }
+    true_utility_1 = {
         (0, 0): -0.04,
         (0, 1): -0.04,
         (0, 2): -0.04,
         (1, 0): -0.04,
         (1, 2): -0.04,
         (2, 0): -0.04,
         (2, 1): -0.04,
-        (2, 2): -0.04,
+        (2, 2): 0.792,
         (3, 0): -0.04,
-        (3, 1): -1.0,
-        (3, 2): 1.0,
+        (3, 1): 0,
+        (3, 2): 0,
     }
-    true_utility_1 = {
+    true_utility_2 = {
         (0, 0): -0.08,
         (0, 1): -0.08,
         (0, 2): -0.08,
         (1, 0): -0.08,
-        (1, 2): -0.08,
+        (1, 2): 0.586,
         (2, 0): -0.08,
-        (2, 1): -0.08,
-        (2, 2): 0.752,
+        (2, 1): 0.494,
+        (2, 2): 0.867,
         (3, 0): -0.08,
-        (3, 1): -1.0,
-        (3, 2): 1.0,
+        (3, 1): 0,
+        (3, 2): 0,
     }
     computed_utility_history = value_iteration(
         mdp=grid_mdp, epsilon=epsilon, max_iterations=30, return_history=True
@@ -102,6 +115,9 @@ def test_value_iteration_history():
     for state in true_utility_1.keys():
         assert abs(true_utility_1[state] - computed_utility_history[1][state]) < epsilon
 
+    for state in true_utility_2.keys():
+        assert abs(true_utility_2[state] - computed_utility_history[2][state]) < epsilon
+
 
 def test_best_action_from_q_table():
     q_table = {("A", 1): 0.5, ("A", 2): 0.6, ("B", 1): 0.7, ("B", 2): 0.8}
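As a quick check of the updated `test_expected_utility` assertion above: with the utilities `{1: 0.01, 2: 1}` passed in the test, the asserted expression implies state rewards of roughly -0.5 for state 2 and -0.1 for state 1 in `SIMPLE_MDP_DICT` (an assumption read off the expression, not verified against the fixture).

```python
# p = 0.8 transition to state 2: 0.8 * (R(2) + U(2)) = 0.8 * (-0.5 + 1.00) =  0.400
# p = 0.2 transition to state 1: 0.2 * (R(1) + U(1)) = 0.2 * (-0.1 + 0.01) = -0.018
expected = 0.8 * (-0.5 + 1) + 0.2 * (-0.1 + 0.01)  # = 0.382
```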
132 changes: 61 additions & 71 deletions tests/test_mdp_lecture_tex.py
@@ -93,17 +93,15 @@ def best_policy_to_tex_arrows(mdp: MDP, utility: dict):
 \node at (3.5, 1.5) {$-1$};
 \node at (3.5, 2.5) {$1$};"""
 
-GRID_MDP_TRUE_UTILITY_TEX = r"""\node at (0.5, 0.5) {$0.705$};
-\node at (0.5, 1.5) {$0.762$};
-\node at (0.5, 2.5) {$0.812$};
-\node at (1.5, 0.5) {$0.655$};
-\node at (1.5, 2.5) {$0.868$};
-\node at (2.5, 0.5) {$0.611$};
-\node at (2.5, 1.5) {$0.66$};
-\node at (2.5, 2.5) {$0.918$};
-\node at (3.5, 0.5) {$0.388$};
-\node at (3.5, 1.5) {$-1$};
-\node at (3.5, 2.5) {$1$};"""
+GRID_MDP_TRUE_UTILITY_TEX = r"""\node at (0.5, 0.5) {$0.745$};
+\node at (0.5, 1.5) {$0.802$};
+\node at (0.5, 2.5) {$0.852$};
+\node at (1.5, 0.5) {$0.695$};
+\node at (1.5, 2.5) {$0.908$};
+\node at (2.5, 0.5) {$0.651$};
+\node at (2.5, 1.5) {$0.7$};
+\node at (2.5, 2.5) {$0.958$};
+\node at (3.5, 0.5) {$0.428$};"""
 
 
 def test_latex_value_over_time():
@@ -228,36 +226,32 @@ def test_latex_policy_as_arrows():
 HIGHWAY_MDP_UTILITY_TEX = r"""\node at (0.5, 1.5) {$-18.6$};
 \node at (0.5, 2.5) {$-16.6$};
 \node at (0.5, 3.5) {$-15.8$};
-\node at (1.5, 1.5) {$-19.8$};
-\node at (1.5, 2.5) {$-17.4$};
-\node at (1.5, 3.5) {$-15.8$};
-\node at (2.5, 1.5) {$-17.7$};
-\node at (2.5, 2.5) {$-16$};
-\node at (2.5, 3.5) {$-14.8$};
-\node at (3.5, 1.5) {$-15.3$};
-\node at (3.5, 2.5) {$-14.1$};
-\node at (3.5, 3.5) {$-13.8$};
-\node at (4.5, 1.5) {$-12.5$};
-\node at (4.5, 2.5) {$-12.1$};
-\node at (4.5, 3.5) {$-15$};
-\node at (5.5, 1.5) {$-9.52$};
-\node at (5.5, 2.5) {$-11.8$};
-\node at (5.5, 3.5) {$-20.7$};
-\node at (6.5, 0.5) {$-6$};
-\node at (6.5, 1.5) {$-8.09$};
-\node at (6.5, 2.5) {$-14.9$};
-\node at (6.5, 3.5) {$-34$};
-\node at (7.5, 0.5) {$-4$};
-\node at (7.5, 1.5) {$-8.38$};
-\node at (7.5, 2.5) {$-26.6$};
-\node at (7.5, 3.5) {$-52$};
-\node at (8.5, 0.5) {$-2$};
-\node at (8.5, 1.5) {$-15.5$};
-\node at (8.5, 2.5) {$-52$};
-\node at (8.5, 3.5) {$-51$};
-\node at (9.5, 1.5) {$-50$};
-\node at (9.5, 2.5) {$-50$};
-\node at (9.5, 3.5) {$-50$};"""
+\node at (1.5, 1.5) {$-16.8$};
+\node at (1.5, 2.5) {$-15.4$};
+\node at (1.5, 3.5) {$-14.8$};
+\node at (2.5, 1.5) {$-14.7$};
+\node at (2.5, 2.5) {$-14$};
+\node at (2.5, 3.5) {$-13.8$};
+\node at (3.5, 1.5) {$-12.3$};
+\node at (3.5, 2.5) {$-12.1$};
+\node at (3.5, 3.5) {$-12.8$};
+\node at (4.5, 1.5) {$-9.52$};
+\node at (4.5, 2.5) {$-10.1$};
+\node at (4.5, 3.5) {$-14$};
+\node at (5.5, 1.5) {$-6.52$};
+\node at (5.5, 2.5) {$-9.8$};
+\node at (5.5, 3.5) {$-19.7$};
+\node at (6.5, 0.5) {$-4$};
+\node at (6.5, 1.5) {$-5.09$};
+\node at (6.5, 2.5) {$-12.9$};
+\node at (6.5, 3.5) {$-33$};
+\node at (7.5, 0.5) {$-2$};
+\node at (7.5, 1.5) {$-5.38$};
+\node at (7.5, 2.5) {$-24.6$};
+\node at (7.5, 3.5) {$-51$};
+\node at (8.5, 1.5) {$-12.5$};
+\node at (8.5, 2.5) {$-50$};
+\node at (8.5, 3.5) {$-50$};"""
 
 HIGHWAY_MDP_OPTIMAL_POLICY_ARROWS_TEX = r"""\draw[->,color=blue] (0.3, 1.3) -- (0.7, 1.7);
 \draw[->,color=blue] (0.3, 2.3) -- (0.7, 2.7);
@@ -293,36 +287,32 @@ def test_latex_policy_as_arrows():
 HIGHWAY_MDP_LC_R_0DOT4_UTILITY_TEX = r"""\node at (0.5, 1.5) {$-28.6$};
 \node at (0.5, 2.5) {$-28.6$};
 \node at (0.5, 3.5) {$-31$};
-\node at (1.5, 1.5) {$-28.7$};
-\node at (1.5, 2.5) {$-28.6$};
-\node at (1.5, 3.5) {$-32.6$};
-\node at (2.5, 1.5) {$-25.7$};
-\node at (2.5, 2.5) {$-27.2$};
-\node at (2.5, 3.5) {$-34.5$};
-\node at (3.5, 1.5) {$-22.7$};
-\node at (3.5, 2.5) {$-27$};
-\node at (3.5, 3.5) {$-37.9$};
-\node at (4.5, 1.5) {$-19.7$};
-\node at (4.5, 2.5) {$-28.5$};
-\node at (4.5, 3.5) {$-42.5$};
-\node at (5.5, 1.5) {$-16.7$};
-\node at (5.5, 2.5) {$-33.1$};
-\node at (5.5, 3.5) {$-47.2$};
-\node at (6.5, 0.5) {$-6$};
-\node at (6.5, 1.5) {$-18.8$};
-\node at (6.5, 2.5) {$-39.3$};
-\node at (6.5, 3.5) {$-50.8$};
-\node at (7.5, 0.5) {$-4$};
-\node at (7.5, 1.5) {$-23.6$};
-\node at (7.5, 2.5) {$-46.4$};
-\node at (7.5, 3.5) {$-52$};
-\node at (8.5, 0.5) {$-2$};
-\node at (8.5, 1.5) {$-33$};
-\node at (8.5, 2.5) {$-52$};
-\node at (8.5, 3.5) {$-51$};
-\node at (9.5, 1.5) {$-50$};
-\node at (9.5, 2.5) {$-50$};
-\node at (9.5, 3.5) {$-50$};"""
+\node at (1.5, 1.5) {$-25.7$};
+\node at (1.5, 2.5) {$-26.6$};
+\node at (1.5, 3.5) {$-31.6$};
+\node at (2.5, 1.5) {$-22.7$};
+\node at (2.5, 2.5) {$-25.2$};
+\node at (2.5, 3.5) {$-33.5$};
+\node at (3.5, 1.5) {$-19.7$};
+\node at (3.5, 2.5) {$-25$};
+\node at (3.5, 3.5) {$-36.9$};
+\node at (4.5, 1.5) {$-16.7$};
+\node at (4.5, 2.5) {$-26.5$};
+\node at (4.5, 3.5) {$-41.5$};
+\node at (5.5, 1.5) {$-13.7$};
+\node at (5.5, 2.5) {$-31.1$};
+\node at (5.5, 3.5) {$-46.2$};
+\node at (6.5, 0.5) {$-4$};
+\node at (6.5, 1.5) {$-15.8$};
+\node at (6.5, 2.5) {$-37.3$};
+\node at (6.5, 3.5) {$-49.8$};
+\node at (7.5, 0.5) {$-2$};
+\node at (7.5, 1.5) {$-20.6$};
+\node at (7.5, 2.5) {$-44.4$};
+\node at (7.5, 3.5) {$-51$};
+\node at (8.5, 1.5) {$-30$};
+\node at (8.5, 2.5) {$-50$};
+\node at (8.5, 3.5) {$-50$};"""
 
 HIGHWAY_MDP_LC_R_0DOT4_OPTIMAL_POLICY_ARROWS_TEX = r"""\draw[->,color=blue] (0.3, 1.3) -- (0.7, 1.7);
 \draw[->,color=blue] (0.3, 2.5) -- (0.7, 2.5);
