Skip to content

Commit

Permalink
bandindexer improvements
Browse files Browse the repository at this point in the history
now more support for multiband arrays
  • Loading branch information
njwilson23 committed Jul 3, 2016
1 parent 66bfc47 commit cc14420
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 3 deletions.
24 changes: 21 additions & 3 deletions karta/raster/band.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,14 @@ def __getitem__(self, key):
if isinstance(key, np.ndarray):
return np.dstack([b[:,:][key] for b in self.bands])
else:
return np.dstack([b[key] for b in self.bands])
if len(key) not in (2, 3):
raise IndexError("indexing tuple must have length 2 or 3")
sr, sc = key[:2]
sb = slice(None, None, None) if len(key) == 2 else key[2]
if isinstance(sb, slice):
return np.dstack([b[sr,sc] for b in self.bands[sb]])
else:
return self.bands[sb][sr,sc]

def __setitem__(self, key, value):
if len(self.bands) == 1:
Expand All @@ -47,8 +54,19 @@ def __setitem__(self, key, value):
tmp[key] = v
b[:,:] = tmp
else:
for b, v in zip(self.bands, value):
b[key] = v
if len(key) not in (2, 3):
raise IndexError("indexing tuple must have length 2 or 3")
sr, sc = key[:2]
sb = slice(None, None, None) if len(key) == 2 else key[2]
if isinstance(sb, slice):
if len(value.shape) == 3:
for i in range(len(self.bands)):
self.bands[i][sr,sc] = value[:,:,i]
else:
for b in self.bands[sb]:
b[sr,sc] = value
else:
self.bands[sb][sr,sc] = value
return

def __iter__(self):
Expand Down
35 changes: 35 additions & 0 deletions tests/band_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,5 +82,40 @@ def test_set_masked(self):
indexer[mask] = -1
self.assertEqual(np.sum(indexer[:,:]), 32)

def test_get_multibanded(self):
values = np.ones([16, 16])
bands = [CompressedBand((16, 16), np.float32),
CompressedBand((16, 16), np.float32),
CompressedBand((16, 16), np.float32)]
bands[0][:,:] = values
bands[1][:,:] = 2*values
bands[2][:,:] = 3*values

indexer = BandIndexer(bands)
result = indexer[4:7,2:8,:]
self.assertEqual(result.shape, (3, 6, 3))
self.assertTrue(np.all(result[0,0,:] == np.array([1.0, 2.0, 3.0])))

# make sure it works with a scalar band index
result = indexer[4:7,2:8,1]
self.assertEqual(result.shape, (3, 6))
self.assertTrue(np.all(result == 2.0))
return

def test_set_multibanded(self):
values = np.ones([16, 16])
bands = [CompressedBand((16, 16), np.float32),
CompressedBand((16, 16), np.float32),
CompressedBand((16, 16), np.float32)]

indexer = BandIndexer(bands)
indexer[:,:,0] = values
indexer[:,:,1:] = 2*values

self.assertTrue(np.all(bands[0][:,:] == 1.0))
self.assertTrue(np.all(bands[1][:,:] == 2.0))
self.assertTrue(np.all(bands[2][:,:] == 2.0))
return

if __name__ == "__main__":
unittest.main()

0 comments on commit cc14420

Please sign in to comment.