Skip to content

Commit

Permalink
Feature/refactor a star (#5)
Browse files Browse the repository at this point in the history
* Refactor a star

Co-authored-by: Kevin Rösch <[email protected]>
  • Loading branch information
m-naumann and keroe authored Nov 15, 2022
1 parent 01473e7 commit f80cced
Show file tree
Hide file tree
Showing 5 changed files with 223 additions and 160 deletions.
2 changes: 1 addition & 1 deletion docs/gen_ref_pages.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
full_doc_path = Path("reference", doc_path)

parts = tuple(module_path.parts)
if not "mdp" in parts:
if not "mdp" in parts and not "graph_search" in parts:
continue # todo: add other modules here once docstrings added

if parts[-1] == "__init__":
Expand Down
31 changes: 23 additions & 8 deletions notebooks/a_star_notebook.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
"metadata": {},
"outputs": [],
"source": [
"from behavior_generation_lecture_python.graph_search.a_star import ExampleGraph\n",
"import numpy as np\n",
"from behavior_generation_lecture_python.graph_search.a_star import Node, Graph\n",
"\n",
"%matplotlib inline"
]
Expand All @@ -18,16 +19,30 @@
"outputs": [],
"source": [
"def main():\n",
" graph = ExampleGraph()\n",
" graph.draw()\n",
"\n",
" HH = graph.nodes[0]\n",
" M = graph.nodes[6]\n",
" nodes_list = [\n",
" [\"HH\", 170, 620, [\"H\", \"B\"]],\n",
" [\"H\", 150, 520, [\"B\", \"L\", \"F\", \"HH\"]],\n",
" [\"B\", 330, 540, [\"HH\", \"H\", \"L\"]],\n",
" [\"L\", 290, 420, [\"B\", \"H\", \"S\", \"M\"]],\n",
" [\"F\", 60, 270, [\"H\", \"S\"]],\n",
" [\"S\", 80, 120, [\"F\", \"L\", \"M\"]],\n",
" [\"M\", 220, 20, [\"S\", \"L\"]],\n",
" ]\n",
" nodes_dict = {}\n",
" for entry in nodes_list:\n",
" nodes_dict[entry[0]] = Node(\n",
" name=entry[0],\n",
" position=np.array([entry[1], entry[2]]),\n",
" connected_to=entry[3],\n",
" )\n",
"\n",
" result = graph.a_star(M, HH)\n",
" print([x.node.name for x in result])\n",
" graph = Graph(nodes_dict=nodes_dict)\n",
" graph.draw_graph()\n",
"\n",
" graph.draw_result(result)"
" graph.a_star(start=\"M\", end=\"HH\")\n",
"\n",
" graph.draw_result()"
]
},
{
Expand Down
34 changes: 26 additions & 8 deletions scripts/run_a_star.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,35 @@
from behavior_generation_lecture_python.graph_search.a_star import ExampleGraph
import numpy as np

from behavior_generation_lecture_python.graph_search.a_star import Node, Graph


def main():
graph = ExampleGraph()
graph.draw()

HH = graph.nodes[0]
M = graph.nodes[6]
nodes_list = [
["HH", 170, 620, ["H", "B"]],
["H", 150, 520, ["B", "L", "F", "HH"]],
["B", 330, 540, ["HH", "H", "L"]],
["L", 290, 420, ["B", "H", "S", "M"]],
["F", 60, 270, ["H", "S"]],
["S", 80, 120, ["F", "L", "M"]],
["M", 220, 20, ["S", "L"]],
]
nodes_dict = {}
for entry in nodes_list:
nodes_dict[entry[0]] = Node(
name=entry[0],
position=np.array([entry[1], entry[2]]),
connected_to=entry[3],
)

graph = Graph(nodes_dict=nodes_dict)
graph.draw_graph()

result = graph.a_star(M, HH)
print([x.node.name for x in result])
success = graph.a_star(start="M", end="HH")
if not success:
raise RuntimeError("A star algorithm did not find a path.")

graph.draw_result(result)
graph.draw_result()


if __name__ == "__main__":
Expand Down
249 changes: 132 additions & 117 deletions src/behavior_generation_lecture_python/graph_search/a_star.py
Original file line number Diff line number Diff line change
@@ -1,160 +1,175 @@
import math
from __future__ import annotations

import matplotlib.pyplot as plt
import numpy as np

from typing import Dict, List, Set

class GraphNode:
def __init__(self, name, x, y):
self.name = name
self.x = x
self.y = y
self.connected_to = []

def add_connected_to(self, connected_to):
self.connected_to.append(connected_to)
class Node:
def __init__(
self, name: str, position: np.ndarray, connected_to: List[str]
) -> None:
"""
Node in a graph for A* computation.
def distance_to(self, node):
return math.sqrt((self.x - node.x) ** 2 + (self.y - node.y) ** 2)
:param name: Name of the node.
:param position: Position of the node (x,y).
:param connected_to: List of the names of nodes, that this node is connected to.
"""
self.name = name
self.position = position
self.connected_to = connected_to
self.predecessor = None
self.cost_to_come = None
self.heuristic_cost_to_go = None

def compute_heuristic_cost_to_go(self, goal_node: Node) -> None:
"""
Computes the heuristic cost to go to the goal node based on the distance and assigns it to the node object.
class AStarNode:
def __init__(self, node, end_node):
self.C = 0
self.G = node.distance_to(end_node)
self.J = 0
self.node = node
self.predecessor = None
:param goal_node: The goal node.
:return:
"""
self.heuristic_cost_to_go = np.linalg.norm(goal_node.position - self.position)

def total_cost(self) -> float:
"""
Computes the expected total cost to reach the goal node as sum of cost to come and heuristic cost to go.
class ExampleGraph:
def __init__(self):
HH = GraphNode("HH", 170, 620)
H = GraphNode("H", 150, 520)
B = GraphNode("B", 330, 540)
L = GraphNode("L", 290, 420)
F = GraphNode("F", 60, 270)
S = GraphNode("S", 80, 120)
M = GraphNode("M", 220, 20)
:return: The expected total cost to reach the goal node.
"""
return self.cost_to_come + self.heuristic_cost_to_go

self.nodes = [HH, H, B, L, F, S, M]

HH.add_connected_to(H)
H.add_connected_to(HH)
HH.add_connected_to(B)
B.add_connected_to(HH)
def extract_min(node_set: Set[str], node_dict: Dict[str, Node]) -> str:
"""
Extract the node with minimal total cost from a set.
H.add_connected_to(B)
B.add_connected_to(H)
H.add_connected_to(L)
L.add_connected_to(H)
H.add_connected_to(F)
F.add_connected_to(H)
:param node_set: The set of node names to be considered.
:param node_dict: The node dict, containing the node information.
:return: The name of the node with minimal total cost.
"""
min_node = min(node_set, key=lambda x: node_dict[x].total_cost())
node_set.remove(min_node)
return min_node

B.add_connected_to(L)
L.add_connected_to(B)

L.add_connected_to(S)
S.add_connected_to(L)
L.add_connected_to(M)
M.add_connected_to(L)
class Graph:
def __init__(self, nodes_dict: Dict[str, Node]) -> None:
"""
A graph for A* computation.
F.add_connected_to(S)
S.add_connected_to(F)
:param nodes_dict: The dictionary containing the nodes of the graph.
"""
self.nodes_dict = nodes_dict

S.add_connected_to(M)
M.add_connected_to(S)
self._end_node = None

fig, ax = plt.subplots()
self.fig = fig
self.ax = ax

def draw(self):
def draw_graph(self) -> None:
"""
Draw all nodes and their connections in the graph.
:return:
"""
self.ax.set_xlim([0, 700])
self.ax.set_ylim([0, 700])

for node in self.nodes:
self.ax.plot(node.x, node.y, marker="o", markersize=6, color="k")
self.ax.annotate(node.name, (node.x + 10, node.y + 10))
for node in self.nodes_dict.values():
self.ax.plot(
node.position[0], node.position[1], marker="o", markersize=6, color="k"
)
self.ax.annotate(node.name, (node.position[0] + 10, node.position[1] + 10))
for connected_to in node.connected_to:
connected_node = self.nodes_dict[connected_to]
self.ax.plot(
[node.x, connected_to.x],
[node.y, connected_to.y],
[node.position[0], connected_node.position[0]],
[node.position[1], connected_node.position[1]],
linewidth=1,
color="k",
)

def a_star(self, start, end):
open_set = []
closed_set = []
def draw_result(self) -> None:
"""
Draw the solution to the shortest path problem.
:return:
"""
assert (
self._end_node
), "End node not defined, run a_star() before drawing the result."
current_node = self._end_node
while self.nodes_dict[current_node].predecessor:
curr_node = self.nodes_dict[current_node]
predecessor = curr_node.predecessor
pred_node = self.nodes_dict[predecessor]

x_0 = AStarNode(start, end)
self.ax.plot(
[curr_node.position[0], pred_node.position[0]],
[curr_node.position[1], pred_node.position[1]],
linewidth=2,
color="b",
)

x_0.J = x_0.G
distance = np.linalg.norm(curr_node.position - pred_node.position)
x_mid = (curr_node.position[0] + pred_node.position[0]) / 2.0
y_mid = (curr_node.position[1] + pred_node.position[1]) / 2.0
self.ax.annotate(f"{distance:.2f}", (x_mid + 10, y_mid + 10))
current_node = predecessor

open_set.append(x_0)
plt.show()

def extract_min(open_set):
if open_set:
result = open_set[0]
min_J = result.J
for node in open_set:
if node.J < min_J:
result = node
min_J = result.J
def a_star(self, start: str, end: str) -> bool:
"""
Compute the shortest path through the graph with the A* algorithm.
open_set.remove(result)
return result
:param start: Name of the start node.
:param end: Name of the end node.
:return: True if shortest path found, False otherwise.
"""
assert start in self.nodes_dict, f"Start node '{start}' must be in graph"
assert end in self.nodes_dict, f"End node '{end}' must be in graph"

return None
self._end_node = end
open_set = set()
closed_set = set()
for node in self.nodes_dict.values():
node.compute_heuristic_cost_to_go(self.nodes_dict[end])

def retrace_path(node):
path = [node]
open_set.add(start)
self.nodes_dict[start].cost_to_come = 0

while node.predecessor is not None:
node = node.predecessor
path.append(node)
while open_set:
current_node = extract_min(node_set=open_set, node_dict=self.nodes_dict)

path.reverse()
return path
if current_node == end:
return True

while open_set:
x = extract_min(open_set)
closed_set.append(x)

if x.node == end:
return retrace_path(x)
else:
for node in x.node.connected_to:
x_tilde = AStarNode(node, end)

if x_tilde.node not in closed_set:
cost = x.C + x_tilde.node.distance_to(x.node)

if (
not x_tilde.node in [x.node for x in open_set]
or cost < x_tilde.G
):
x_tilde.predecessor = x
x_tilde.C = cost
x_tilde.J = x_tilde.C + x_tilde.G

if not x_tilde.node in [x.node for x in open_set]:
open_set.append(x_tilde)

def draw_result(self, result):
for i in range(len(result) - 1):
node_from = result[i].node
node_to = result[i + 1].node
self.ax.plot(
[node_from.x, node_to.x],
[node_from.y, node_to.y],
linewidth=2,
color="b",
)
closed_set.add(current_node)

distance = node_from.distance_to(node_to)
x_mid = (node_from.x + node_to.x) / 2
y_mid = (node_from.y + node_to.y) / 2
self.ax.annotate(f"{distance:.2f}", (x_mid + 10, y_mid + 10))
for successor_node in self.nodes_dict[current_node].connected_to:
if successor_node in closed_set:
continue

plt.show()
tentative_cost_to_come = self.nodes_dict[
current_node
].cost_to_come + np.linalg.norm(
self.nodes_dict[current_node].position
- self.nodes_dict[successor_node].position
)
if (
successor_node in open_set
and tentative_cost_to_come
>= self.nodes_dict[successor_node].cost_to_come
):
continue

self.nodes_dict[successor_node].predecessor = current_node
self.nodes_dict[successor_node].cost_to_come = tentative_cost_to_come
open_set.add(successor_node)

return False
Loading

0 comments on commit f80cced

Please sign in to comment.