Skip to content

Commit

Permalink
Add tests, fix bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
adamjstewart committed Mar 4, 2022
1 parent c27cef2 commit d399165
Show file tree
Hide file tree
Showing 2 changed files with 143 additions and 15 deletions.
44 changes: 29 additions & 15 deletions rtree/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,26 +676,37 @@ def __and__(self, other: Index) -> Index:
i = 0
new_idx = Index(interleaved=self.interleaved, properties=self.properties)

# For each Item in self...
for item1 in self.intersection(self.bounds, objects=True):
# For each Item in other that intersects...
for item2 in other.intersection(item1.bounds, objects=True):
# Compute the intersection bounding box
bounds = []
for j in range(len(item1.bounds)):
if self.interleaved:
if j < len(item1.bounds) // 2:
bounds.append(max(item1.bounds[j], item2.bounds[j]))
if self.interleaved:
# For each Item in self...
for item1 in self.intersection(self.bounds, objects=True):
# For each Item in other that intersects...
for item2 in other.intersection(item1.bbox, objects=True):
# Compute the intersection bounding box
bbox = []
for j in range(len(item1.bbox)):
if j < len(item1.bbox) // 2:
bbox.append(max(item1.bbox[j], item2.bbox[j]))
else:
bounds.append(min(item1.bounds[j], item2.bounds[j]))
else:
bbox.append(min(item1.bbox[j], item2.bbox[j]))

new_idx.insert(i, bbox, (item1.object, item2.object))
i += 1

else:
# For each Item in self...
for item1 in self.intersection(self.bounds, objects=True):
# For each Item in other that intersects...
for item2 in other.intersection(item1.bounds, objects=True):
# Compute the intersection bounding box
bounds = []
for j in range(len(item1.bounds)):
if j % 2 == 0:
bounds.append(max(item1.bounds[j], item2.bounds[j]))
else:
bounds.append(min(item1.bounds[j], item2.bounds[j]))

new_idx.insert(i, bounds, (item1.object, item2.object))
i += 1
new_idx.insert(i, bounds, (item1.object, item2.object))
i += 1

return new_idx

Expand All @@ -715,7 +726,10 @@ def __or__(self, other: Index) -> Index:
for old_idx in [self, other]:
# For each item...
for item in old_idx.intersection(old_idx.bounds, objects=True):
new_idx.insert(item.id, item.bounds, item.object)
if self.interleaved:
new_idx.insert(item.id, item.bbox, item.object)
else:
new_idx.insert(item.id, item.bounds, item.object)

return new_idx

Expand Down
114 changes: 114 additions & 0 deletions tests/test_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,120 @@ def test_double_insertion(self) -> None:
self.assertEqual([1, 1], list(idx.intersection((0, 0, 5, 5))))


class TestIndexIntersectionUnion:
@pytest.fixture(scope="class")
def index_a_interleaved(self) -> index.Index:
idx = index.Index(interleaved=True)
idx.insert(1, (3, 3, 5, 5), "a_1")
idx.insert(2, (4, 2, 6, 4), "a_2")
return idx

@pytest.fixture(scope="class")
def index_a_uninterleaved(self) -> index.Index:
idx = index.Index(interleaved=False)
idx.insert(1, (3, 5, 3, 5), "a_1")
idx.insert(2, (4, 6, 2, 4), "a_2")
return idx

@pytest.fixture(scope="class")
def index_b_interleaved(self) -> index.Index:
idx = index.Index(interleaved=True)
idx.insert(3, (2, 1, 7, 6), "b_3")
idx.insert(4, (8, 7, 9, 8), "b_4")
return idx

@pytest.fixture(scope="class")
def index_b_uninterleaved(self) -> index.Index:
idx = index.Index(interleaved=False)
idx.insert(3, (2, 7, 1, 6), "b_3")
idx.insert(4, (8, 9, 7, 8), "b_4")
return idx

def test_intersection_interleaved(
self, index_a_interleaved: index.Index, index_b_interleaved: index.Index
) -> None:
index_c_interleaved = index_a_interleaved & index_b_interleaved
assert index_c_interleaved.interleaved
assert len(index_c_interleaved) == 2
for hit in index_c_interleaved.intersection(
index_c_interleaved.bounds, objects=True
):
if hit.bbox == [3.0, 3.0, 5.0, 5.0]:
assert hit.object == ("a_1", "b_3")
elif hit.bbox == [4.0, 2.0, 6.0, 4.0]:
assert hit.object == ("a_2", "b_3")
else:
assert False

def test_intersection_uninterleaved(
self, index_a_uninterleaved: index.Index, index_b_uninterleaved: index.Index
) -> None:
index_c_uninterleaved = index_a_uninterleaved & index_b_uninterleaved
assert not index_c_uninterleaved.interleaved
assert len(index_c_uninterleaved) == 2
for hit in index_c_uninterleaved.intersection(
index_c_uninterleaved.bounds, objects=True
):
if hit.bounds == [3.0, 5.0, 3.0, 5.0]:
assert hit.object == ("a_1", "b_3")
elif hit.bounds == [4.0, 6.0, 2.0, 4.0]:
assert hit.object == ("a_2", "b_3")
else:
assert False

def test_intersection_mismatch(
self, index_a_interleaved: index.Index, index_b_uninterleaved: index.Index
) -> None:
with pytest.raises(AssertionError):
index_a_interleaved & index_b_uninterleaved

def test_union_interleaved(
self, index_a_interleaved: index.Index, index_b_interleaved: index.Index
) -> None:
index_c_interleaved = index_a_interleaved | index_b_interleaved
assert index_c_interleaved.interleaved
assert len(index_c_interleaved) == 4
for hit in index_c_interleaved.intersection(
index_c_interleaved.bounds, objects=True
):
if hit.bbox == [3.0, 3.0, 5.0, 5.0]:
assert hit.object == "a_1"
elif hit.bbox == [4.0, 2.0, 6.0, 4.0]:
assert hit.object == "a_2"
elif hit.bbox == [2.0, 1.0, 7.0, 6.0]:
assert hit.object == "b_3"
elif hit.bbox == [8.0, 7.0, 9.0, 8.0]:
assert hit.object == "b_4"
else:
assert False

def test_union_uninterleaved(
self, index_a_uninterleaved: index.Index, index_b_uninterleaved: index.Index
) -> None:
index_c_uninterleaved = index_a_uninterleaved | index_b_uninterleaved
assert not index_c_uninterleaved.interleaved
assert len(index_c_uninterleaved) == 4
for hit in index_c_uninterleaved.intersection(
index_c_uninterleaved.bounds, objects=True
):
if hit.bounds == [3.0, 5.0, 3.0, 5.0]:
assert hit.object == "a_1"
elif hit.bounds == [4.0, 6.0, 2.0, 4.0]:
assert hit.object == "a_2"
elif hit.bounds == [2.0, 7.0, 1.0, 6.0]:
assert hit.object == "b_3"
elif hit.bounds == [8.0, 9.0, 7.0, 8.0]:
assert hit.object == "b_4"
else:
assert False

def test_union_mismatch(
self, index_a_interleaved: index.Index, index_b_uninterleaved: index.Index
) -> None:
with pytest.raises(AssertionError):
index_a_interleaved | index_b_uninterleaved


class IndexSerialization(unittest.TestCase):
def setUp(self) -> None:
self.boxes15 = np.genfromtxt("boxes_15x15.data")
Expand Down

0 comments on commit d399165

Please sign in to comment.