From 69bb20851d11761760432dec776378c0e5c85ed4 Mon Sep 17 00:00:00 2001 From: Andrew Champion Date: Fri, 8 May 2020 09:42:24 +0100 Subject: [PATCH] Add all-nearest-neighbors iterator for tree-to-tree lookup Add the ability to efficiently find the nearest neighbor in a target tree for each point in a query tree. This works by traversing the query tree depth-first and at each query tree node pruning nodes from the set of candidate subtrees of the target tree that can not potentially hold the nearest neighbors for any point in the query tree node. This results in speedups on the order of 1.3x versus individual lookup of the query points, and this speedup increases with the size and dimensionality of the trees. --- rstar-benches/benches/benchmarks.rs | 24 +++ rstar/src/aabb.rs | 73 +++++++- rstar/src/algorithm/nearest_neighbor.rs | 237 +++++++++++++++++++++++- rstar/src/envelope.rs | 10 + rstar/src/rtree.rs | 34 ++++ 5 files changed, 368 insertions(+), 10 deletions(-) diff --git a/rstar-benches/benches/benchmarks.rs b/rstar-benches/benches/benchmarks.rs index b02fc12..bf4aa20 100644 --- a/rstar-benches/benches/benchmarks.rs +++ b/rstar-benches/benches/benchmarks.rs @@ -16,6 +16,7 @@ use criterion::{Bencher, Criterion, Fun}; const SEED_1: &[u8; 32] = b"Gv0aHMtHkBGsUXNspGU9fLRuCWkZWHZx"; const SEED_2: &[u8; 32] = b"km7DO4GeaFZfTcDXVpnO7ZJlgUY7hZiS"; +#[derive(Clone)] struct Params; impl RTreeParams for Params { @@ -92,6 +93,28 @@ fn tree_creation_quality(c: &mut Criterion) { }); } +fn all_to_all_neighbors(c: &mut Criterion) { + const SIZE: usize = 1_000; + let points: Vec<_> = create_random_points(SIZE, SEED_1); + let tree_target = RTree::<_, Params>::bulk_load_with_params(points.clone()); + let query_points = create_random_points(SIZE, SEED_2); + let tree_query = RTree::<_, Params>::bulk_load_with_params(query_points.clone()); + + let tree_target_cloned = tree_target.clone(); + c.bench_function("all to all tree lookup", move |b| { + b.iter(|| { + tree_target.all_nearest_neighbors(&tree_query).count(); + }); + }) + .bench_function("all to all point lookup", move |b| { + b.iter(|| { + for query_point in &query_points { + tree_target_cloned.nearest_neighbor(&query_point).unwrap(); + } + }); + }); +} + fn locate_successful(c: &mut Criterion) { let points: Vec<_> = create_random_points(100_000, SEED_1); let query_point = points[500]; @@ -115,6 +138,7 @@ criterion_group!( bulk_load_baseline, bulk_load_comparison, tree_creation_quality, + all_to_all_neighbors, locate_successful, locate_unsuccessful ); diff --git a/rstar/src/aabb.rs b/rstar/src/aabb.rs index 3129c19..72fdb76 100755 --- a/rstar/src/aabb.rs +++ b/rstar/src/aabb.rs @@ -1,4 +1,4 @@ -use crate::point::{max_inline, Point, PointExt}; +use crate::point::{max_inline, min_inline, Point, PointExt}; use crate::{Envelope, RTreeObject}; use num_traits::{Bounded, One, Zero}; @@ -94,6 +94,64 @@ where self.min_point(point).sub(point).length_2() } } + + /// Return an iterator over each corner vertex of the AABB. + /// + /// # Example + /// ``` + /// use rstar::AABB; + /// + /// let aabb = AABB::from_corners([1.0, 2.0], [3.0, 4.0]); + /// let mut corners = aabb.iter_corners(); + /// assert_eq!(corners.next(), Some([1.0, 2.0])); + /// assert_eq!(corners.next(), Some([3.0, 2.0])); + /// assert_eq!(corners.next(), Some([1.0, 4.0])); + /// assert_eq!(corners.next(), Some([3.0, 4.0])); + /// assert_eq!(corners.next(), None); + /// ``` + pub fn iter_corners(&self) -> impl Iterator { + CornerIterator::new(self) + } +} + +struct CornerIterator { + lower: P, + upper: P, + idx: usize, +} + +impl

CornerIterator

+where + P: Point, +{ + fn new(aabb: &AABB

) -> Self { + Self { + lower: aabb.lower, + upper: aabb.upper, + idx: 0, + } + } +} + +impl

Iterator for CornerIterator

+where + P: Point, +{ + type Item = P; + + fn next(&mut self) -> Option { + if self.idx & (1 << P::DIMENSIONS) != 0 { + None + } else { + let corner = P::generate(|i| if self.idx & (1 << i) != 0 { + self.upper.nth(i) + } else { + self.lower.nth(i) + }); + self.idx += 1; + Some(corner) + } + } } impl

Envelope for AABB

@@ -144,6 +202,12 @@ where self.distance_2(point) } + fn min_dist_2(&self, other: &Self) -> P::Scalar { + let l = self.min_point(&other.lower); + let u = self.min_point(&other.upper); + min_inline(other.distance_2(&l), other.distance_2(&u)) + } + fn min_max_dist_2(&self, point: &P) ->

::Scalar { let l = self.lower.sub(point); let u = self.upper.sub(point); @@ -169,6 +233,13 @@ where result - max_diff } + fn max_min_max_dist_2(&self, other: &Self) -> P::Scalar { + self.iter_corners() + .map(|corner| other.min_max_dist_2(&corner)) + .max_by(|a, b| a.partial_cmp(b).unwrap_or_else(|| std::cmp::Ordering::Equal)) + .unwrap_or_else(P::Scalar::zero) + } + fn center(&self) -> Self::Point { let one = ::Scalar::one(); let two = one + one; diff --git a/rstar/src/algorithm/nearest_neighbor.rs b/rstar/src/algorithm/nearest_neighbor.rs index 506ca30..16e40d0 100755 --- a/rstar/src/algorithm/nearest_neighbor.rs +++ b/rstar/src/algorithm/nearest_neighbor.rs @@ -180,19 +180,29 @@ pub fn nearest_neighbor<'a, T>( node: &'a ParentNode, query_point: ::Point, ) -> Option<&'a T> +where + T: PointDistance, +{ + nearest_neighbor_inner(&node.children, query_point).map(|(node, _)| node) +} + +fn nearest_neighbor_inner<'a, T>( + seed_nodes: impl IntoIterator>>, + query_point: ::Point, +) -> Option<(&'a T, <::Point as Point>::Scalar)> where T: PointDistance, { fn extend_heap<'a, T>( nodes: &mut SmallHeap>, - node: &'a ParentNode, + source: impl IntoIterator>>, query_point: ::Point, min_max_distance: &mut <::Point as Point>::Scalar, ) where T: PointDistance + 'a, { - for child in &node.children { - let distance_if_less_or_equal = match child { + for child in source { + let distance_if_less_or_equal = match child.borrow() { RTreeNode::Parent(ref data) => { let distance = data.envelope.distance_2(&query_point); if distance <= *min_max_distance { @@ -208,10 +218,10 @@ where if let Some(distance) = distance_if_less_or_equal { *min_max_distance = min_inline( *min_max_distance, - child.envelope().min_max_dist_2(&query_point), + child.borrow().envelope().min_max_dist_2(&query_point), ); nodes.push(RTreeNodeDistanceWrapper { - node: child, + node: child.borrow(), distance, }); } @@ -222,26 +232,204 @@ where let mut smallest_min_max: <::Point as Point>::Scalar = Bounded::max_value(); let mut nodes = SmallHeap::new(); - extend_heap(&mut nodes, node, query_point, &mut smallest_min_max); + extend_heap(&mut nodes, seed_nodes, query_point, &mut smallest_min_max); while let Some(current) = nodes.pop() { match current { RTreeNodeDistanceWrapper { node: RTreeNode::Parent(ref data), .. } => { - extend_heap(&mut nodes, data, query_point, &mut smallest_min_max); + extend_heap(&mut nodes, &data.children, query_point, &mut smallest_min_max); } RTreeNodeDistanceWrapper { node: RTreeNode::Leaf(ref t), - .. + distance } => { - return Some(t); + return Some((t, distance)); } } } None } +/// The maximum number of subtrees to track when doing tree-to-tree +/// all-nearest-neighbors. +const MAX_AKNN_SUBTREES: usize = 16; + +/// A nearest neighbor pair with the squared euclidean distance between them. +pub struct NearestNeighbors<'a, 'b, T: PointDistance> { + /// The nearest neighbor found to `query`'s location. + pub target: &'a T, + /// The node whose location was used for the query. + pub query: &'b T, + /// Squared euclidean distance between the nodes. + pub distance_2: <::Point as Point>::Scalar, +} + +/// Yield an iterator over nearest neighbors between a pair of trees. +/// +/// Note this is note symmetric. Neighbors are found in `target_node`'s tree for +/// each node in `query_node`'s tree. +pub fn all_nearest_neighbors<'a, T>( + target_node: &'a ParentNode, + query_node: &'a ParentNode, +) -> impl Iterator> + 'a +where + T: PointDistance + 'a, +{ + AllNearestNeighborsIterator::new(target_node, query_node) +} + + +pub struct AllNearestNeighborsIterator<'a, T> +where + T: PointDistance + 'a, +{ + /// Stack of subtrees of the target tree that are candidate nearest nodes + /// for each depth of the current location in the query tree. + stack: Vec>, + /// LIFO queue of query nodes whose nearest neighbors or neighest neighbor + /// covering subtrees are to be found in depth-first order. + queue: Vec<(&'a RTreeNode, usize)>, +} + +impl<'a, T> AllNearestNeighborsIterator<'a, T> +where + T: PointDistance + 'a, +{ + fn new( + target_node: &'a ParentNode, + query_node: &'a ParentNode, + ) -> Self { + Self { + stack: vec![NeighborSubtrees::root(target_node, query_node)], + queue: query_node.children.iter().map(|child| (child, 0)).collect(), + } + } +} + +impl<'a, T> Iterator for AllNearestNeighborsIterator<'a, T> +where + T: PointDistance, +{ + type Item = NearestNeighbors<'a, 'a, T>; + + fn next(&mut self) -> Option { + // Fetch the next query node from the queue. + while let Some((node, depth)) = self.queue.pop() { + match node { + RTreeNode::Parent(ref node) => { + // If a cached subtrees struct to hold the child subtrees + // doesn't already exist, create it. + if self.stack.len() < depth + 2 { + self.stack.push(NeighborSubtrees::empty()); + } + let pair = &mut self.stack[depth..depth+2]; + let (parent, child) = pair.split_at_mut(1); + child[0].child_subtrees(&parent[0], node); + + // Add children of the query node to the end of the LIFO queue + // for depth-first traversal. + self.queue.extend(node.children().iter().map(|child| (child, depth + 1))); + }, + RTreeNode::Leaf(ref leaf) => { + // Find the nearest neighbor for `leaf`. The stack at the + // node's depth contains subtrees of the target tree that + // cover any potential nearest neighbor matches. + let subtrees = &self.stack[depth]; + + return nearest_neighbor_inner( + &subtrees.target_nodes, + // FIXME: inelegant solution to recover leaf's point. + leaf.envelope().center() + ).map(|(target, distance_2)| NearestNeighbors { + query: leaf, + target, + distance_2 + }) + } + } + } + + None + } +} + + +struct NeighborSubtrees<'a, T> +where + T: PointDistance + 'a +{ + /// Nodes comprising a subtree of the target tree that cover all potential + /// neighest neighbor matches. + target_nodes: Vec<&'a RTreeNode>, +} + +impl<'a, T> NeighborSubtrees<'a, T> +where + T: PointDistance + 'a, +{ + fn empty() -> Self { + Self { + target_nodes: vec![] + } + } + + fn root( + target_node: &'a ParentNode, + query_node: &'a ParentNode, + ) -> Self { + let target_nodes = target_node.children().iter().collect(); + let preroot = Self { + target_nodes, + }; + let mut root = Self::empty(); + root.child_subtrees(&preroot, query_node); + root + } + + /// Replace the contents of this subtree with subtrees of `parent`'s target + /// subtrees that are guaranteed to cover any nearest neighbor queries from + /// `query_node`. + fn child_subtrees( + &mut self, + parent: &Self, + query_node: &'a ParentNode, + ) { + self.target_nodes.clear(); + if parent.target_nodes.len() < MAX_AKNN_SUBTREES { + // If the set of target subtrees is not too large, subdivide each + // subtree into its children so they can be individually pruned + // by distance to the query. + parent.target_nodes.iter().for_each(|node| { + match *node { + RTreeNode::Parent(ref parent) => self.target_nodes.extend(&parent.children), + leaf @ RTreeNode::Leaf(..) => self.target_nodes.push(leaf), + } + }); + } else { + // If the set of target subtrees is already large, retain it rather + // than further subdividing. + self.target_nodes.extend(&parent.target_nodes); + }; + + // For each target subtree, find the distance which guarantees any + // potential elements in the query node's envelope have a match with + // the target subtree's envelope. Find the minimal such distance. + let min_max_dist = self.target_nodes.iter().fold(Bounded::max_value(), |min_max_dist, node| { + let dist = query_node.envelope.max_min_max_dist_2(&node.envelope()); + min_inline(min_max_dist, dist) + }); + + // Only retain subtrees that potentially have a match with the query + // node nearer than the min max distance computed above. + self.target_nodes.retain(|node| { + let distance = node.envelope().min_dist_2(&query_node.envelope); + distance <= min_max_dist + }); + } +} + #[cfg(test)] mod test { use crate::object::PointDistance; @@ -275,6 +463,37 @@ mod test { } } + #[test] + fn test_all_nearest_neighbors() { + let points = create_random_points(1_000, SEED_1); + let tree = RTree::bulk_load(points.clone()); + + let mut tree_sequential = RTree::new(); + for point in &points { + tree_sequential.insert(*point); + } + + // Test that in identical trees, all-nearest-neighbors match the + // identical nodes with themselves. + for neighbors in super::all_nearest_neighbors(tree.root(), tree_sequential.root()) { + assert_eq!(neighbors.query, neighbors.target); + assert_eq!(neighbors.distance_2, 0.0); + } + + assert_eq!(super::all_nearest_neighbors(tree.root(), tree_sequential.root()).count(), points.len()); + + // For different trees, test that the all-nearest-neighbor results match + // individual nearest neighbors. + // From random testing, the large number of points is necessary to catch + // errors in the pruning algorithm. + let sample_points = create_random_points(10_000, SEED_2); + let sample_tree = RTree::bulk_load(sample_points.clone()); + for neighbors in super::all_nearest_neighbors(tree.root(), sample_tree.root()) { + let single_neighbor = tree.nearest_neighbor(neighbors.query); + assert_eq!(Some(neighbors.target), single_neighbor); + } + } + #[test] fn test_nearest_neighbor_iterator() { let mut points = create_random_points(1000, SEED_1); diff --git a/rstar/src/envelope.rs b/rstar/src/envelope.rs index 56e2a2e..6a23899 100755 --- a/rstar/src/envelope.rs +++ b/rstar/src/envelope.rs @@ -35,6 +35,9 @@ pub trait Envelope: Clone + Copy + PartialEq + ::std::fmt::Debug { /// Returns the euclidean distance to the envelope's border. fn distance_2(&self, point: &Self::Point) -> ::Scalar; + /// Returns the minimum euclidean distance between this envelope and another. + fn min_dist_2(&self, other: &Self) -> ::Scalar; + /// Returns the squared min-max distance, a concept that helps to find nearest neighbors efficiently. /// /// Visually, if an AABB and a point are given, the min-max distance returns the distance at which we @@ -44,6 +47,13 @@ pub trait Envelope: Clone + Copy + PartialEq + ::std::fmt::Debug { /// Roussopoulos, Nick, Stephen Kelley, and Frédéric Vincent. "Nearest neighbor queries." ACM sigmod record. Vol. 24. No. 2. ACM, 1995. fn min_max_dist_2(&self, point: &Self::Point) -> ::Scalar; + /// Returns the squared euclidean distance such that for *any* point in this envelope, + /// we surely know that *an* element must be present in `other` envelope within + /// that distance. + /// + /// Note that this is not necessarily symmetric. + fn max_min_max_dist_2(&self, other: &Self) -> ::Scalar; + /// Returns the envelope's center point. fn center(&self) -> Self::Point; diff --git a/rstar/src/rtree.rs b/rstar/src/rtree.rs index 7fdcc56..75c2648 100755 --- a/rstar/src/rtree.rs +++ b/rstar/src/rtree.rs @@ -611,6 +611,40 @@ where } } + /// Returns an iterator of the nearest neighbor in this tree for each node in a query tree. + /// + /// This method is more efficient than querying each node individually. + /// + /// # Example + /// ``` + /// use rstar::RTree; + /// let target_tree = RTree::bulk_load(vec![ + /// [0.0, 0.0], + /// [0.0, 1.0], + /// ]); + /// let query_tree = RTree::bulk_load(vec![ + /// [0.2, 0.7], + /// [0.5, 0.0], + /// [0.5, 1.0], + /// ]); + /// let mut ann = target_tree.all_nearest_neighbors(&query_tree) + /// .map(|nn| (nn.query, nn.target)) + /// .collect::>(); + /// ann.sort_by(|(a, _), (b, _)| a.partial_cmp(b).unwrap()); + /// assert_eq!(ann, &[ + /// // Query point, nearest neighbor + /// (&[0.2, 0.7], &[0.0, 1.0]), + /// (&[0.5, 0.0], &[0.0, 0.0]), + /// (&[0.5, 1.0], &[0.0, 1.0]), + /// ]); + /// ``` + pub fn all_nearest_neighbors<'a, P2: RTreeParams>( + &'a self, + query_tree: &'a RTree, + ) -> impl Iterator> { + nearest_neighbor::all_nearest_neighbors(&self.root, query_tree.root()) + } + /// Returns all elements of the tree within a certain distance. /// /// The elements may be returned in any order. Each returned element