diff --git a/faninsar/datasets/base.py b/faninsar/datasets/base.py index eb1fb92..eeef8bf 100644 --- a/faninsar/datasets/base.py +++ b/faninsar/datasets/base.py @@ -696,10 +696,23 @@ def __getitem__( return result + def _ensure_bands_idx(self, vrt_fh) -> list[int] | int: + """Return the proper band indexes to use for the dataset. The band indexes + is a list of integers if multiple bands are requested, otherwise it is an + integer. + """ + bands = self.band_indexes or vrt_fh.indexes + # If only one band is requested, return a 2D array + if len(bands) == 1: + bands = bands[0] + return bands + def _points_query(self, points: Points, vrt_fh) -> np.ndarray: - """Return the values of dataset at given points. Points that outside the dataset will be masked.""" + """Return the values of dataset at given points. Points that outside the + dataset will be masked.""" points = self._ensure_query_crs(points) - data = np.ma.hstack(list(vrt_fh.sample(points.values, masked=True))) + bands_idx = self._ensure_bands_idx(vrt_fh) + data = np.ma.hstack(list(vrt_fh.sample(points.values, bands_idx, masked=True))) return data def _bbox_query(self, bbox: BoundingBox, vrt_fh) -> np.ndarray: @@ -707,15 +720,15 @@ def _bbox_query(self, bbox: BoundingBox, vrt_fh) -> np.ndarray: bbox = self._ensure_query_crs(bbox) win = vrt_fh.window(*bbox) - bands = self.band_indexes or vrt_fh.indexes + bands_idx = self._ensure_bands_idx(vrt_fh) data = vrt_fh.read( out_shape=( - len(bands), + len(bands_idx), round((bbox.top - bbox.bottom) / self.res[1]), round((bbox.right - bbox.left) / self.res[0]), ), resampling=self.resampling, - indexes=self.band_indexes, + indexes=bands_idx, window=win, masked=True, boundless=self.same_crs, # boundless=True if self.same_crs else False, @@ -730,13 +743,14 @@ def _bbox_query(self, bbox: BoundingBox, vrt_fh) -> np.ndarray: def _polygons_query(self, polygons: Polygons, vrt_fh) -> np.ndarray: """Return the values of the dataset at the given polygons.""" polygons = self._ensure_query_crs(polygons) - + bands_idx = self._ensure_bands_idx(vrt_fh) mask_params = { "filled": False, "pad": polygons.pad, "all_touched": polygons.all_touched, "invert": False, "crop": True, + "indexes": bands_idx, } rasterize_params = { "all_touched": polygons.all_touched, @@ -1590,7 +1604,7 @@ def _sample_data( if query.points is not None: points_values = self._points_query(query.points, variable) if query.boxes is not None: - + bbox_values = self._bbox_query(query.boxes, variable) if query.polygons is not None: polygons_values = self._polygons_query(query.polygons, variable)