diff --git a/pathprocessing/paths.py b/pathprocessing/paths.py index 4024f55..3c8729e 100644 --- a/pathprocessing/paths.py +++ b/pathprocessing/paths.py @@ -2,6 +2,7 @@ from rdp import rdp import matplotlib.pyplot as plt from svgpathtools import svg2paths +from sklearn.cluster import KMeans import cairo import qrcode import os @@ -11,6 +12,11 @@ from typing import Union +def _path_length(path): + pairwise_distance = np.sqrt(np.sum(np.square(path[1:] - path[:-1]), axis=1)) + return np.sum(pairwise_distance) + + class LinearPaths2D: """A class for a collection of linear paths in 2D. @@ -62,7 +68,7 @@ def viz(self, color: str = None) -> None: Use standard matplotlib color names. """ for path in self._paths: - plt.plot(path[:, 0], path[:, 1], color = color) + plt.plot(path[:, 0], path[:, 1], color=color) plt.axis("equal") @@ -229,16 +235,85 @@ def minimum_length(self, minimum_length: float = 0.0) -> "LinearPaths2D": as long as the minimum_length. """ - def path_length(path): - pairwise_distance = np.sqrt( - np.sum(np.square(path[1:] - path[:-1]), axis=1) - ) - return np.sum(pairwise_distance) - return LinearPaths2D( - list(filter(lambda path: path_length(path) >= minimum_length, self._paths)) + list(filter(lambda path: _path_length(path) >= minimum_length, self._paths)) + ) + + def sorted( + self, + number_of_groups: int = 1, + reference_point: npt.NDArray = np.array([0, 0]), + path_reversal: bool = False, + ) -> "LinearPaths2D": + """Groups path by length and sorts them by connecting distance. + + Groups the paths by length via k means clustering. + Then sorts the paths by connecting distance measured between the + end of the current path and the start of the next path. + + Args: + number_of_groups: Number of groups to sort into. + reference_point: Sorts first path by distance to this point. + path_reverse: Reverses paths if the end point of the next path + is closer to the current path than it's start point. + + Returns: + A sorted LinearPaths2D object. + """ + ## cluster paths by length. + group_ids = KMeans(n_clusters=number_of_groups).fit_predict( + np.array(list(map(_path_length, self._paths))).reshape(-1, 1) + ) + groups = [[] for _ in range(number_of_groups)] + for i, id in enumerate(group_ids): + groups[id].append(self._paths[i]) + + # sort groups descending by average group length. + groups = sorted( + groups, + key=lambda group: np.mean(list(map(_path_length, group))), + reverse=True, ) + sorted_groups = [] + for i, group in enumerate(groups): + sorted_groups.append([]) + + # sort paths by distance to end point of previous path. + start_points = np.array(list(map(lambda path: path[0], group))) + end_points = np.array(list(map(lambda path: path[-1], group))) + + queue = list(range(len(group))) + while queue: + # sort by distance to reference point. + distances = np.linalg.norm(start_points - reference_point, axis=1) + end_distances = np.linalg.norm(end_points - reference_point, axis=1) + next_path_idx = np.argmin(distances) + next_end_path_idx = np.argmin(end_distances) + + if ( + path_reversal + and end_distances[next_end_path_idx] < distances[next_path_idx] + ): + # reverse path if end point is closer to reference point. + sorted_groups[-1].append(group[next_end_path_idx][::-1]) + # update reference point. + reference_point = group[next_end_path_idx][-1] + # remove path from queue. + start_points[next_end_path_idx] = np.array([np.inf, np.inf]) + end_points[next_end_path_idx] = np.array([np.inf, np.inf]) + queue.remove(next_end_path_idx) + else: + sorted_groups[-1].append(group[next_path_idx]) + # update reference point. + reference_point = group[next_path_idx][0] + # remove path from queue. + start_points[next_path_idx] = np.array([np.inf, np.inf]) + end_points[next_path_idx] = np.array([np.inf, np.inf]) + queue.remove(next_path_idx) + + return LinearPaths2D(sum(sorted_groups, [])) + def unique(self) -> "LinearPaths2D": """Prunes duplicate paths. diff --git a/setup.py b/setup.py index a509921..a1e9baf 100644 --- a/setup.py +++ b/setup.py @@ -26,7 +26,8 @@ "rdp >= 0.8", "matplotlib >= 3.5.1", "pycairo >= 1.20.1", - "qrcode >= 7.3.1" + "qrcode >= 7.3.1", + "scikit-learn >= 1.2.1" ], classifiers=[ "Programming Language :: Python :: 3.8", diff --git a/tests/paths_test.py b/tests/paths_test.py index 9b42e41..c85db99 100644 --- a/tests/paths_test.py +++ b/tests/paths_test.py @@ -289,3 +289,21 @@ def test_save_load(self): result = LinearPaths2D.load(file_name).tolist() os.remove(file_name) self.assertEqual(result, expected_value) + + def test_sorted(self): + simple_squares = self.SQUARE.shift(1000) + self.SQUARE + self.SQUARE.shift(-0.5) + result = LinearPaths2D.sorted(simple_squares).tolist() + expected_value = ( + self.SQUARE + self.SQUARE.shift(-0.5) + self.SQUARE.shift(1000) + ).tolist() + self.assertEqual(result, expected_value) + + mini_square = self.SQUARE.scale_to(0.1) + min_big_square = ( + mini_square + mini_square + mini_square.shift(1000) + self.SQUARE + ) + result = LinearPaths2D.sorted(min_big_square).tolist() + expected_value = ( + mini_square + mini_square + self.SQUARE + mini_square.shift(1000) + ).tolist() + self.assertEqual(result, expected_value)