-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Refactor a star Co-authored-by: Kevin Rösch <[email protected]>
- Loading branch information
Showing
5 changed files
with
223 additions
and
160 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
249 changes: 132 additions & 117 deletions
249
src/behavior_generation_lecture_python/graph_search/a_star.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,160 +1,175 @@ | ||
import math | ||
from __future__ import annotations | ||
|
||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
|
||
from typing import Dict, List, Set | ||
|
||
class GraphNode: | ||
def __init__(self, name, x, y): | ||
self.name = name | ||
self.x = x | ||
self.y = y | ||
self.connected_to = [] | ||
|
||
def add_connected_to(self, connected_to): | ||
self.connected_to.append(connected_to) | ||
class Node: | ||
def __init__( | ||
self, name: str, position: np.ndarray, connected_to: List[str] | ||
) -> None: | ||
""" | ||
Node in a graph for A* computation. | ||
def distance_to(self, node): | ||
return math.sqrt((self.x - node.x) ** 2 + (self.y - node.y) ** 2) | ||
:param name: Name of the node. | ||
:param position: Position of the node (x,y). | ||
:param connected_to: List of the names of nodes, that this node is connected to. | ||
""" | ||
self.name = name | ||
self.position = position | ||
self.connected_to = connected_to | ||
self.predecessor = None | ||
self.cost_to_come = None | ||
self.heuristic_cost_to_go = None | ||
|
||
def compute_heuristic_cost_to_go(self, goal_node: Node) -> None: | ||
""" | ||
Computes the heuristic cost to go to the goal node based on the distance and assigns it to the node object. | ||
class AStarNode: | ||
def __init__(self, node, end_node): | ||
self.C = 0 | ||
self.G = node.distance_to(end_node) | ||
self.J = 0 | ||
self.node = node | ||
self.predecessor = None | ||
:param goal_node: The goal node. | ||
:return: | ||
""" | ||
self.heuristic_cost_to_go = np.linalg.norm(goal_node.position - self.position) | ||
|
||
def total_cost(self) -> float: | ||
""" | ||
Computes the expected total cost to reach the goal node as sum of cost to come and heuristic cost to go. | ||
class ExampleGraph: | ||
def __init__(self): | ||
HH = GraphNode("HH", 170, 620) | ||
H = GraphNode("H", 150, 520) | ||
B = GraphNode("B", 330, 540) | ||
L = GraphNode("L", 290, 420) | ||
F = GraphNode("F", 60, 270) | ||
S = GraphNode("S", 80, 120) | ||
M = GraphNode("M", 220, 20) | ||
:return: The expected total cost to reach the goal node. | ||
""" | ||
return self.cost_to_come + self.heuristic_cost_to_go | ||
|
||
self.nodes = [HH, H, B, L, F, S, M] | ||
|
||
HH.add_connected_to(H) | ||
H.add_connected_to(HH) | ||
HH.add_connected_to(B) | ||
B.add_connected_to(HH) | ||
def extract_min(node_set: Set[str], node_dict: Dict[str, Node]) -> str: | ||
""" | ||
Extract the node with minimal total cost from a set. | ||
H.add_connected_to(B) | ||
B.add_connected_to(H) | ||
H.add_connected_to(L) | ||
L.add_connected_to(H) | ||
H.add_connected_to(F) | ||
F.add_connected_to(H) | ||
:param node_set: The set of node names to be considered. | ||
:param node_dict: The node dict, containing the node information. | ||
:return: The name of the node with minimal total cost. | ||
""" | ||
min_node = min(node_set, key=lambda x: node_dict[x].total_cost()) | ||
node_set.remove(min_node) | ||
return min_node | ||
|
||
B.add_connected_to(L) | ||
L.add_connected_to(B) | ||
|
||
L.add_connected_to(S) | ||
S.add_connected_to(L) | ||
L.add_connected_to(M) | ||
M.add_connected_to(L) | ||
class Graph: | ||
def __init__(self, nodes_dict: Dict[str, Node]) -> None: | ||
""" | ||
A graph for A* computation. | ||
F.add_connected_to(S) | ||
S.add_connected_to(F) | ||
:param nodes_dict: The dictionary containing the nodes of the graph. | ||
""" | ||
self.nodes_dict = nodes_dict | ||
|
||
S.add_connected_to(M) | ||
M.add_connected_to(S) | ||
self._end_node = None | ||
|
||
fig, ax = plt.subplots() | ||
self.fig = fig | ||
self.ax = ax | ||
|
||
def draw(self): | ||
def draw_graph(self) -> None: | ||
""" | ||
Draw all nodes and their connections in the graph. | ||
:return: | ||
""" | ||
self.ax.set_xlim([0, 700]) | ||
self.ax.set_ylim([0, 700]) | ||
|
||
for node in self.nodes: | ||
self.ax.plot(node.x, node.y, marker="o", markersize=6, color="k") | ||
self.ax.annotate(node.name, (node.x + 10, node.y + 10)) | ||
for node in self.nodes_dict.values(): | ||
self.ax.plot( | ||
node.position[0], node.position[1], marker="o", markersize=6, color="k" | ||
) | ||
self.ax.annotate(node.name, (node.position[0] + 10, node.position[1] + 10)) | ||
for connected_to in node.connected_to: | ||
connected_node = self.nodes_dict[connected_to] | ||
self.ax.plot( | ||
[node.x, connected_to.x], | ||
[node.y, connected_to.y], | ||
[node.position[0], connected_node.position[0]], | ||
[node.position[1], connected_node.position[1]], | ||
linewidth=1, | ||
color="k", | ||
) | ||
|
||
def a_star(self, start, end): | ||
open_set = [] | ||
closed_set = [] | ||
def draw_result(self) -> None: | ||
""" | ||
Draw the solution to the shortest path problem. | ||
:return: | ||
""" | ||
assert ( | ||
self._end_node | ||
), "End node not defined, run a_star() before drawing the result." | ||
current_node = self._end_node | ||
while self.nodes_dict[current_node].predecessor: | ||
curr_node = self.nodes_dict[current_node] | ||
predecessor = curr_node.predecessor | ||
pred_node = self.nodes_dict[predecessor] | ||
|
||
x_0 = AStarNode(start, end) | ||
self.ax.plot( | ||
[curr_node.position[0], pred_node.position[0]], | ||
[curr_node.position[1], pred_node.position[1]], | ||
linewidth=2, | ||
color="b", | ||
) | ||
|
||
x_0.J = x_0.G | ||
distance = np.linalg.norm(curr_node.position - pred_node.position) | ||
x_mid = (curr_node.position[0] + pred_node.position[0]) / 2.0 | ||
y_mid = (curr_node.position[1] + pred_node.position[1]) / 2.0 | ||
self.ax.annotate(f"{distance:.2f}", (x_mid + 10, y_mid + 10)) | ||
current_node = predecessor | ||
|
||
open_set.append(x_0) | ||
plt.show() | ||
|
||
def extract_min(open_set): | ||
if open_set: | ||
result = open_set[0] | ||
min_J = result.J | ||
for node in open_set: | ||
if node.J < min_J: | ||
result = node | ||
min_J = result.J | ||
def a_star(self, start: str, end: str) -> bool: | ||
""" | ||
Compute the shortest path through the graph with the A* algorithm. | ||
open_set.remove(result) | ||
return result | ||
:param start: Name of the start node. | ||
:param end: Name of the end node. | ||
:return: True if shortest path found, False otherwise. | ||
""" | ||
assert start in self.nodes_dict, f"Start node '{start}' must be in graph" | ||
assert end in self.nodes_dict, f"End node '{end}' must be in graph" | ||
|
||
return None | ||
self._end_node = end | ||
open_set = set() | ||
closed_set = set() | ||
for node in self.nodes_dict.values(): | ||
node.compute_heuristic_cost_to_go(self.nodes_dict[end]) | ||
|
||
def retrace_path(node): | ||
path = [node] | ||
open_set.add(start) | ||
self.nodes_dict[start].cost_to_come = 0 | ||
|
||
while node.predecessor is not None: | ||
node = node.predecessor | ||
path.append(node) | ||
while open_set: | ||
current_node = extract_min(node_set=open_set, node_dict=self.nodes_dict) | ||
|
||
path.reverse() | ||
return path | ||
if current_node == end: | ||
return True | ||
|
||
while open_set: | ||
x = extract_min(open_set) | ||
closed_set.append(x) | ||
|
||
if x.node == end: | ||
return retrace_path(x) | ||
else: | ||
for node in x.node.connected_to: | ||
x_tilde = AStarNode(node, end) | ||
|
||
if x_tilde.node not in closed_set: | ||
cost = x.C + x_tilde.node.distance_to(x.node) | ||
|
||
if ( | ||
not x_tilde.node in [x.node for x in open_set] | ||
or cost < x_tilde.G | ||
): | ||
x_tilde.predecessor = x | ||
x_tilde.C = cost | ||
x_tilde.J = x_tilde.C + x_tilde.G | ||
|
||
if not x_tilde.node in [x.node for x in open_set]: | ||
open_set.append(x_tilde) | ||
|
||
def draw_result(self, result): | ||
for i in range(len(result) - 1): | ||
node_from = result[i].node | ||
node_to = result[i + 1].node | ||
self.ax.plot( | ||
[node_from.x, node_to.x], | ||
[node_from.y, node_to.y], | ||
linewidth=2, | ||
color="b", | ||
) | ||
closed_set.add(current_node) | ||
|
||
distance = node_from.distance_to(node_to) | ||
x_mid = (node_from.x + node_to.x) / 2 | ||
y_mid = (node_from.y + node_to.y) / 2 | ||
self.ax.annotate(f"{distance:.2f}", (x_mid + 10, y_mid + 10)) | ||
for successor_node in self.nodes_dict[current_node].connected_to: | ||
if successor_node in closed_set: | ||
continue | ||
|
||
plt.show() | ||
tentative_cost_to_come = self.nodes_dict[ | ||
current_node | ||
].cost_to_come + np.linalg.norm( | ||
self.nodes_dict[current_node].position | ||
- self.nodes_dict[successor_node].position | ||
) | ||
if ( | ||
successor_node in open_set | ||
and tentative_cost_to_come | ||
>= self.nodes_dict[successor_node].cost_to_come | ||
): | ||
continue | ||
|
||
self.nodes_dict[successor_node].predecessor = current_node | ||
self.nodes_dict[successor_node].cost_to_come = tentative_cost_to_come | ||
open_set.add(successor_node) | ||
|
||
return False |
Oops, something went wrong.