From d3991655b56bff3a07185d7a3dada9ffaf8f60a9 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Fri, 4 Mar 2022 15:57:27 -0600 Subject: [PATCH] Add tests, fix bugs --- rtree/index.py | 44 +++++++++++------ tests/test_index.py | 114 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 143 insertions(+), 15 deletions(-) diff --git a/rtree/index.py b/rtree/index.py index c124a51a..f96fb8e5 100644 --- a/rtree/index.py +++ b/rtree/index.py @@ -676,26 +676,37 @@ def __and__(self, other: Index) -> Index: i = 0 new_idx = Index(interleaved=self.interleaved, properties=self.properties) - # For each Item in self... - for item1 in self.intersection(self.bounds, objects=True): - # For each Item in other that intersects... - for item2 in other.intersection(item1.bounds, objects=True): - # Compute the intersection bounding box - bounds = [] - for j in range(len(item1.bounds)): - if self.interleaved: - if j < len(item1.bounds) // 2: - bounds.append(max(item1.bounds[j], item2.bounds[j])) + if self.interleaved: + # For each Item in self... + for item1 in self.intersection(self.bounds, objects=True): + # For each Item in other that intersects... + for item2 in other.intersection(item1.bbox, objects=True): + # Compute the intersection bounding box + bbox = [] + for j in range(len(item1.bbox)): + if j < len(item1.bbox) // 2: + bbox.append(max(item1.bbox[j], item2.bbox[j])) else: - bounds.append(min(item1.bounds[j], item2.bounds[j])) - else: + bbox.append(min(item1.bbox[j], item2.bbox[j])) + + new_idx.insert(i, bbox, (item1.object, item2.object)) + i += 1 + + else: + # For each Item in self... + for item1 in self.intersection(self.bounds, objects=True): + # For each Item in other that intersects... + for item2 in other.intersection(item1.bounds, objects=True): + # Compute the intersection bounding box + bounds = [] + for j in range(len(item1.bounds)): if j % 2 == 0: bounds.append(max(item1.bounds[j], item2.bounds[j])) else: bounds.append(min(item1.bounds[j], item2.bounds[j])) - new_idx.insert(i, bounds, (item1.object, item2.object)) - i += 1 + new_idx.insert(i, bounds, (item1.object, item2.object)) + i += 1 return new_idx @@ -715,7 +726,10 @@ def __or__(self, other: Index) -> Index: for old_idx in [self, other]: # For each item... for item in old_idx.intersection(old_idx.bounds, objects=True): - new_idx.insert(item.id, item.bounds, item.object) + if self.interleaved: + new_idx.insert(item.id, item.bbox, item.object) + else: + new_idx.insert(item.id, item.bounds, item.object) return new_idx diff --git a/tests/test_index.py b/tests/test_index.py index 943b6828..e89d7b08 100644 --- a/tests/test_index.py +++ b/tests/test_index.py @@ -264,6 +264,120 @@ def test_double_insertion(self) -> None: self.assertEqual([1, 1], list(idx.intersection((0, 0, 5, 5)))) +class TestIndexIntersectionUnion: + @pytest.fixture(scope="class") + def index_a_interleaved(self) -> index.Index: + idx = index.Index(interleaved=True) + idx.insert(1, (3, 3, 5, 5), "a_1") + idx.insert(2, (4, 2, 6, 4), "a_2") + return idx + + @pytest.fixture(scope="class") + def index_a_uninterleaved(self) -> index.Index: + idx = index.Index(interleaved=False) + idx.insert(1, (3, 5, 3, 5), "a_1") + idx.insert(2, (4, 6, 2, 4), "a_2") + return idx + + @pytest.fixture(scope="class") + def index_b_interleaved(self) -> index.Index: + idx = index.Index(interleaved=True) + idx.insert(3, (2, 1, 7, 6), "b_3") + idx.insert(4, (8, 7, 9, 8), "b_4") + return idx + + @pytest.fixture(scope="class") + def index_b_uninterleaved(self) -> index.Index: + idx = index.Index(interleaved=False) + idx.insert(3, (2, 7, 1, 6), "b_3") + idx.insert(4, (8, 9, 7, 8), "b_4") + return idx + + def test_intersection_interleaved( + self, index_a_interleaved: index.Index, index_b_interleaved: index.Index + ) -> None: + index_c_interleaved = index_a_interleaved & index_b_interleaved + assert index_c_interleaved.interleaved + assert len(index_c_interleaved) == 2 + for hit in index_c_interleaved.intersection( + index_c_interleaved.bounds, objects=True + ): + if hit.bbox == [3.0, 3.0, 5.0, 5.0]: + assert hit.object == ("a_1", "b_3") + elif hit.bbox == [4.0, 2.0, 6.0, 4.0]: + assert hit.object == ("a_2", "b_3") + else: + assert False + + def test_intersection_uninterleaved( + self, index_a_uninterleaved: index.Index, index_b_uninterleaved: index.Index + ) -> None: + index_c_uninterleaved = index_a_uninterleaved & index_b_uninterleaved + assert not index_c_uninterleaved.interleaved + assert len(index_c_uninterleaved) == 2 + for hit in index_c_uninterleaved.intersection( + index_c_uninterleaved.bounds, objects=True + ): + if hit.bounds == [3.0, 5.0, 3.0, 5.0]: + assert hit.object == ("a_1", "b_3") + elif hit.bounds == [4.0, 6.0, 2.0, 4.0]: + assert hit.object == ("a_2", "b_3") + else: + assert False + + def test_intersection_mismatch( + self, index_a_interleaved: index.Index, index_b_uninterleaved: index.Index + ) -> None: + with pytest.raises(AssertionError): + index_a_interleaved & index_b_uninterleaved + + def test_union_interleaved( + self, index_a_interleaved: index.Index, index_b_interleaved: index.Index + ) -> None: + index_c_interleaved = index_a_interleaved | index_b_interleaved + assert index_c_interleaved.interleaved + assert len(index_c_interleaved) == 4 + for hit in index_c_interleaved.intersection( + index_c_interleaved.bounds, objects=True + ): + if hit.bbox == [3.0, 3.0, 5.0, 5.0]: + assert hit.object == "a_1" + elif hit.bbox == [4.0, 2.0, 6.0, 4.0]: + assert hit.object == "a_2" + elif hit.bbox == [2.0, 1.0, 7.0, 6.0]: + assert hit.object == "b_3" + elif hit.bbox == [8.0, 7.0, 9.0, 8.0]: + assert hit.object == "b_4" + else: + assert False + + def test_union_uninterleaved( + self, index_a_uninterleaved: index.Index, index_b_uninterleaved: index.Index + ) -> None: + index_c_uninterleaved = index_a_uninterleaved | index_b_uninterleaved + assert not index_c_uninterleaved.interleaved + assert len(index_c_uninterleaved) == 4 + for hit in index_c_uninterleaved.intersection( + index_c_uninterleaved.bounds, objects=True + ): + if hit.bounds == [3.0, 5.0, 3.0, 5.0]: + assert hit.object == "a_1" + elif hit.bounds == [4.0, 6.0, 2.0, 4.0]: + assert hit.object == "a_2" + elif hit.bounds == [2.0, 7.0, 1.0, 6.0]: + assert hit.object == "b_3" + elif hit.bounds == [8.0, 9.0, 7.0, 8.0]: + assert hit.object == "b_4" + else: + assert False + + def test_union_mismatch( + self, index_a_interleaved: index.Index, index_b_uninterleaved: index.Index + ) -> None: + with pytest.raises(AssertionError): + index_a_interleaved | index_b_uninterleaved + + class IndexSerialization(unittest.TestCase): def setUp(self) -> None: self.boxes15 = np.genfromtxt("boxes_15x15.data")