Skip to content

Commit

Permalink
refine bands strategy of RasterDataset
Browse files Browse the repository at this point in the history
  • Loading branch information
Fanchengyan committed Jul 10, 2024
1 parent 9245109 commit 51c0282
Showing 1 changed file with 21 additions and 7 deletions.
28 changes: 21 additions & 7 deletions faninsar/datasets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -696,26 +696,39 @@ def __getitem__(

return result

def _ensure_bands_idx(self, vrt_fh) -> list[int] | int:
"""Return the proper band indexes to use for the dataset. The band indexes
is a list of integers if multiple bands are requested, otherwise it is an
integer.
"""
bands = self.band_indexes or vrt_fh.indexes
# If only one band is requested, return a 2D array
if len(bands) == 1:
bands = bands[0]
return bands

def _points_query(self, points: Points, vrt_fh) -> np.ndarray:
"""Return the values of dataset at given points. Points that outside the dataset will be masked."""
"""Return the values of dataset at given points. Points that outside the
dataset will be masked."""
points = self._ensure_query_crs(points)
data = np.ma.hstack(list(vrt_fh.sample(points.values, masked=True)))
bands_idx = self._ensure_bands_idx(vrt_fh)
data = np.ma.hstack(list(vrt_fh.sample(points.values, bands_idx, masked=True)))
return data

def _bbox_query(self, bbox: BoundingBox, vrt_fh) -> np.ndarray:
"""Return the values of the dataset at the given bounding box."""
bbox = self._ensure_query_crs(bbox)

win = vrt_fh.window(*bbox)
bands = self.band_indexes or vrt_fh.indexes
bands_idx = self._ensure_bands_idx(vrt_fh)
data = vrt_fh.read(
out_shape=(
len(bands),
len(bands_idx),
round((bbox.top - bbox.bottom) / self.res[1]),
round((bbox.right - bbox.left) / self.res[0]),
),
resampling=self.resampling,
indexes=self.band_indexes,
indexes=bands_idx,
window=win,
masked=True,
boundless=self.same_crs, # boundless=True if self.same_crs else False,
Expand All @@ -730,13 +743,14 @@ def _bbox_query(self, bbox: BoundingBox, vrt_fh) -> np.ndarray:
def _polygons_query(self, polygons: Polygons, vrt_fh) -> np.ndarray:
"""Return the values of the dataset at the given polygons."""
polygons = self._ensure_query_crs(polygons)

bands_idx = self._ensure_bands_idx(vrt_fh)
mask_params = {
"filled": False,
"pad": polygons.pad,
"all_touched": polygons.all_touched,
"invert": False,
"crop": True,
"indexes": bands_idx,
}
rasterize_params = {
"all_touched": polygons.all_touched,
Expand Down Expand Up @@ -1590,7 +1604,7 @@ def _sample_data(
if query.points is not None:
points_values = self._points_query(query.points, variable)
if query.boxes is not None:

bbox_values = self._bbox_query(query.boxes, variable)
if query.polygons is not None:
polygons_values = self._polygons_query(query.polygons, variable)
Expand Down

0 comments on commit 51c0282

Please sign in to comment.