Skip to content

Commit

Permalink
Spiral Spanning Tree Coverage Path Planning (AtsushiSakai#355)
Browse files Browse the repository at this point in the history
* First commit of Spiral Spanning Tree Coverage

* Modify followed by first code review

* fix pycodestyle error

* modifies following 2nd code review
  • Loading branch information
reso1 authored Jul 12, 2020
1 parent 9fe14e6 commit 0c23ebe
Show file tree
Hide file tree
Showing 5 changed files with 371 additions and 0 deletions.
Binary file added PathPlanning/SpiralSpanningTreeCPP/map/test.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added PathPlanning/SpiralSpanningTreeCPP/map/test_2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added PathPlanning/SpiralSpanningTreeCPP/map/test_3.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
@@ -0,0 +1,313 @@
"""
Spiral Spanning Tree Coverage Path Planner
author: Todd Tang
paper: Spiral-STC: An On-Line Coverage Algorithm of Grid Environments
by a Mobile Robot - Gabriely et.al.
link: https://ieeexplore.ieee.org/abstract/document/1013479
"""

import os
import sys
import math

import numpy as np
import matplotlib.pyplot as plt

do_animation = True


class SpiralSpanningTreeCoveragePlanner:
def __init__(self, occ_map):
self.origin_map_height = occ_map.shape[0]
self.origin_map_width = occ_map.shape[1]

# original map resolution must be even
if self.origin_map_height % 2 == 1 or self.origin_map_width % 2 == 1:
sys.exit('original map width/height must be even \
in grayscale .png format')

self.occ_map = occ_map
self.merged_map_height = self.origin_map_height // 2
self.merged_map_width = self.origin_map_width // 2

self.edge = []

def plan(self, start):
"""plan
performing Spiral Spanning Tree Coverage path planning
:param start: the start node of Spiral Spanning Tree Coverage
"""

visit_times = np.zeros(
(self.merged_map_height, self.merged_map_width), dtype=np.int)
visit_times[start[0]][start[1]] = 1

# generate route by
# recusively call perform_spanning_tree_coverage() from start node
route = []
self.perform_spanning_tree_coverage(start, visit_times, route)

path = []
# generate path from route
for idx in range(len(route)-1):
dp = abs(route[idx][0] - route[idx+1][0]) + \
abs(route[idx][1] - route[idx+1][1])
if dp == 0:
# special handle for round-trip path
path.append(self.get_round_trip_path(route[idx-1], route[idx]))
elif dp == 1:
path.append(self.move(route[idx], route[idx+1]))
elif dp == 2:
# special handle for non-adjacent route nodes
mid_node = self.get_intermediate_node(route[idx], route[idx+1])
path.append(self.move(route[idx], mid_node))
path.append(self.move(mid_node, route[idx+1]))
else:
sys.exit('adjacent path node distance larger than 2')

return self.edge, route, path

def perform_spanning_tree_coverage(self, current_node, visit_times, route):
"""perform_spanning_tree_coverage
recursive function for function <plan>
:param current_node: current node
"""

def is_valid_node(i, j):
is_i_valid_bounded = 0 <= i < self.merged_map_height
is_j_valid_bounded = 0 <= j < self.merged_map_width
if is_i_valid_bounded and is_j_valid_bounded:
# free only when the 4 sub-cells are all free
return bool(
self.occ_map[2*i][2*j]
and self.occ_map[2*i+1][2*j]
and self.occ_map[2*i][2*j+1]
and self.occ_map[2*i+1][2*j+1])

return False

# counter-clockwise neighbor finding order
order = [[1, 0], [0, 1], [-1, 0], [0, -1]]

found = False
route.append(current_node)
for inc in order:
ni, nj = current_node[0] + inc[0], current_node[1] + inc[1]
if is_valid_node(ni, nj) and visit_times[ni][nj] == 0:
neighbor_node = (ni, nj)
self.edge.append((current_node, neighbor_node))
found = True
visit_times[ni][nj] += 1
self.perform_spanning_tree_coverage(
neighbor_node, visit_times, route)

# backtrace route from node with neighbors all visited
# to first node with unvisited neighbor
if not found:
has_node_with_unvisited_ngb = False
for node in reversed(route):
# drop nodes that have been visited twice
if visit_times[node[0]][node[1]] == 2:
continue

visit_times[node[0]][node[1]] += 1
route.append(node)

for inc in order:
ni, nj = node[0] + inc[0], node[1] + inc[1]
if is_valid_node(ni, nj) and visit_times[ni][nj] == 0:
has_node_with_unvisited_ngb = True
break

if has_node_with_unvisited_ngb:
break

return route

def move(self, p, q):
direction = self.get_vector_direction(p, q)
# move east
if direction == 'E':
p = self.get_sub_node(p, 'SE')
q = self.get_sub_node(q, 'SW')
# move west
elif direction == 'W':
p = self.get_sub_node(p, 'NW')
q = self.get_sub_node(q, 'NE')
# move south
elif direction == 'S':
p = self.get_sub_node(p, 'SW')
q = self.get_sub_node(q, 'NW')
# move north
elif direction == 'N':
p = self.get_sub_node(p, 'NE')
q = self.get_sub_node(q, 'SE')
else:
sys.exit('move direction error...')
return [p, q]

def get_round_trip_path(self, last, pivot):
direction = self.get_vector_direction(last, pivot)
if direction == 'E':
return [self.get_sub_node(pivot, 'SE'),
self.get_sub_node(pivot, 'NE')]
elif direction == 'S':
return [self.get_sub_node(pivot, 'SW'),
self.get_sub_node(pivot, 'SE')]
elif direction == 'W':
return [self.get_sub_node(pivot, 'NW'),
self.get_sub_node(pivot, 'SW')]
elif direction == 'N':
return [self.get_sub_node(pivot, 'NE'),
self.get_sub_node(pivot, 'NW')]
else:
sys.exit('get_round_trip_path: last->pivot direction error.')

def get_vector_direction(self, p, q):
# east
if p[0] == q[0] and p[1] < q[1]:
return 'E'
# west
elif p[0] == q[0] and p[1] > q[1]:
return 'W'
# south
elif p[0] < q[0] and p[1] == q[1]:
return 'S'
# north
elif p[0] > q[0] and p[1] == q[1]:
return 'N'
else:
sys.exit('get_vector_direction: Only E/W/S/N direction supported.')

def get_sub_node(self, node, direction):
if direction == 'SE':
return [2*node[0]+1, 2*node[1]+1]
elif direction == 'SW':
return [2*node[0]+1, 2*node[1]]
elif direction == 'NE':
return [2*node[0], 2*node[1]+1]
elif direction == 'NW':
return [2*node[0], 2*node[1]]
else:
sys.exit('get_sub_node: sub-node direction error.')

def get_interpolated_path(self, p, q):
# direction p->q: southwest / northeast
if (p[0] < q[0]) ^ (p[1] < q[1]):
ipx = [p[0], p[0], q[0]]
ipy = [p[1], q[1], q[1]]
# direction p->q: southeast / northwest
else:
ipx = [p[0], q[0], q[0]]
ipy = [p[1], p[1], q[1]]
return ipx, ipy

def get_intermediate_node(self, p, q):
p_ngb, q_ngb = set(), set()

for m, n in self.edge:
if m == p:
p_ngb.add(n)
if n == p:
p_ngb.add(m)
if m == q:
q_ngb.add(n)
if n == q:
q_ngb.add(m)

itsc = p_ngb.intersection(q_ngb)
if len(itsc) == 0:
sys.exit('get_intermediate_node: \
no intermediate node between', p, q)
elif len(itsc) == 1:
return list(itsc)[0]
else:
sys.exit('get_intermediate_node: \
more than 1 intermediate node between', p, q)

def visualize_path(self, edge, path, start):
def coord_transform(p):
return [2*p[1] + 0.5, 2*p[0] + 0.5]

if do_animation:
last = path[0][0]
trajectory = [[last[1]], [last[0]]]
for p, q in path:
distance = math.hypot(p[0]-last[0], p[1]-last[1])
if distance <= 1.0:
trajectory[0].append(p[1])
trajectory[1].append(p[0])
else:
ipx, ipy = self.get_interpolated_path(last, p)
trajectory[0].extend(ipy)
trajectory[1].extend(ipx)

last = q

trajectory[0].append(last[1])
trajectory[1].append(last[0])

for idx, state in enumerate(np.transpose(trajectory)):
plt.cla()
# for stopping simulation with the esc key.
plt.gcf().canvas.mpl_connect(
'key_release_event',
lambda event: [exit(0) if event.key == 'escape' else None])

# draw spanning tree
plt.imshow(self.occ_map, 'gray')
for p, q in edge:
p = coord_transform(p)
q = coord_transform(q)
plt.plot([p[0], q[0]], [p[1], q[1]], '-oc')
sx, sy = coord_transform(start)
plt.plot([sx], [sy], 'pr', markersize=10)

# draw move path
plt.plot(trajectory[0][:idx+1], trajectory[1][:idx+1], '-k')
plt.plot(state[0], state[1], 'or')
plt.axis('equal')
plt.grid(True)
plt.pause(0.01)

else:
# draw spanning tree
plt.imshow(self.occ_map, 'gray')
for p, q in edge:
p = coord_transform(p)
q = coord_transform(q)
plt.plot([p[0], q[0]], [p[1], q[1]], '-oc')
sx, sy = coord_transform(start)
plt.plot([sx], [sy], 'pr', markersize=10)

# draw move path
last = path[0][0]
for p, q in path:
distance = math.hypot(p[0]-last[0], p[1]-last[1])
if distance == 1.0:
plt.plot([last[1], p[1]], [last[0], p[0]], '-k')
else:
ipx, ipy = self.get_interpolated_path(last, p)
plt.plot(ipy, ipx, '-k')
plt.arrow(p[1], p[0], q[1]-p[1], q[0]-p[0], head_width=0.2)
last = q

plt.show()


def main():
dir_path = os.path.dirname(os.path.realpath(__file__))
img = plt.imread(os.path.join(dir_path, 'map', 'test_2.png'))
STC_planner = SpiralSpanningTreeCoveragePlanner(img)
start = (10, 0)
edge, route, path = STC_planner.plan(start)
STC_planner.visualize_path(edge, path, start)


if __name__ == "__main__":
main()
58 changes: 58 additions & 0 deletions tests/test_spiral_spanning_tree_coverage_path_planner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import os
import sys
import matplotlib.pyplot as plt
from unittest import TestCase

sys.path.append(os.path.dirname(
os.path.abspath(__file__)) + "/../PathPlanning/SpiralSpanningTreeCPP")
try:
import spiral_spanning_tree_coverage_path_planner
except ImportError:
raise

spiral_spanning_tree_coverage_path_planner.do_animation = True


class TestPlanning(TestCase):
def spiral_stc_cpp(self, img, start):
num_free = 0
for i in range(img.shape[0]):
for j in range(img.shape[1]):
num_free += img[i][j]

STC_planner = spiral_spanning_tree_coverage_path_planner.\
SpiralSpanningTreeCoveragePlanner(img)

edge, route, path = STC_planner.plan(start)

covered_nodes = set()
for p, q in edge:
covered_nodes.add(p)
covered_nodes.add(q)

# assert complete coverage
self.assertEqual(len(covered_nodes), num_free / 4)

def test_spiral_stc_cpp_1(self):
img_dir = os.path.dirname(
os.path.abspath(__file__)) + \
"/../PathPlanning/SpiralSpanningTreeCPP"
img = plt.imread(os.path.join(img_dir, 'map', 'test.png'))
start = (0, 0)
self.spiral_stc_cpp(img, start)

def test_spiral_stc_cpp_2(self):
img_dir = os.path.dirname(
os.path.abspath(__file__)) + \
"/../PathPlanning/SpiralSpanningTreeCPP"
img = plt.imread(os.path.join(img_dir, 'map', 'test_2.png'))
start = (10, 0)
self.spiral_stc_cpp(img, start)

def test_spiral_stc_cpp_3(self):
img_dir = os.path.dirname(
os.path.abspath(__file__)) + \
"/../PathPlanning/SpiralSpanningTreeCPP"
img = plt.imread(os.path.join(img_dir, 'map', 'test_3.png'))
start = (0, 0)
self.spiral_stc_cpp(img, start)

0 comments on commit 0c23ebe

Please sign in to comment.