From 234a3d8ac99faae7b85ded278566da5f987c1e23 Mon Sep 17 00:00:00 2001 From: konstntokas Date: Fri, 20 Dec 2024 10:47:14 +0100 Subject: [PATCH] allow rectifying dataset for nd --- xcube/core/gridmapping/regular.py | 4 +- xcube/core/resampling/rectify.py | 429 +++--------------------------- 2 files changed, 42 insertions(+), 391 deletions(-) diff --git a/xcube/core/gridmapping/regular.py b/xcube/core/gridmapping/regular.py index 2c8679e69..9643bf602 100644 --- a/xcube/core/gridmapping/regular.py +++ b/xcube/core/gridmapping/regular.py @@ -55,7 +55,9 @@ def _new_xy_coords(self) -> xr.DataArray: xy_coords = da.concatenate( [da.expand_dims(x_coords_2d, 0), da.expand_dims(y_coords_2d, 0)] ) - xy_coords = da.rechunk(xy_coords, chunks=(2, 512, 512)) + xy_coords = da.rechunk( + xy_coords, chunks=(2, xy_coords.chunksize[1], xy_coords.chunksize[2]) + ) xy_coords = xr.DataArray( xy_coords, dims=("coord", self.y_coords.dims[0], self.x_coords.dims[0]), diff --git a/xcube/core/resampling/rectify.py b/xcube/core/resampling/rectify.py index 2db64e96b..f94ec585e 100644 --- a/xcube/core/resampling/rectify.py +++ b/xcube/core/resampling/rectify.py @@ -602,22 +602,30 @@ def _compute_var_image_xarray_numpy( def _compute_var_image_xarray_dask( src_var: xr.DataArray, - dst_src_ij_images: np.ndarray, + dst_src_ij_images: da.Array, fill_value: Union[int, float, complex] = np.nan, interpolation: int = 0, ) -> da.Array: """Extract source pixels from xarray.DataArray source with dask.array.Array data. """ - return da.map_blocks( + if src_var.ndim == 2: + src_var = src_var.expand_dims(dim={"dummy": 1}) + chunksize = src_var.shape[:-2] + dst_src_ij_images.chunksize[-2:] + arr = da.map_blocks( _compute_var_image_xarray_dask_block, - src_var.values, dst_src_ij_images, + src_var, fill_value, interpolation, + chunksize, dtype=src_var.dtype, - drop_axis=0, + chunks=chunksize, ) + arr = arr[..., : dst_src_ij_images.shape[-2], : dst_src_ij_images.shape[-1]] + if arr.shape[0] == 1: + arr = arr[0, :, :] + return arr @nb.njit(nogil=True, cache=True) @@ -640,12 +648,12 @@ def _compute_var_image_numpy( return dst_values -@nb.njit(nogil=True, cache=True) def _compute_var_image_xarray_dask_block( - src_var_image: np.ndarray, dst_src_ij_images: np.ndarray, + src_var_image: xr.DataArray, fill_value: Union[int, float, complex], interpolation: int, + chunksize: tuple[int], ) -> np.ndarray: """Extract source pixels from np.ndarray source and return a block of a dask array. @@ -653,11 +661,24 @@ def _compute_var_image_xarray_dask_block( dst_width = dst_src_ij_images.shape[-1] dst_height = dst_src_ij_images.shape[-2] dst_shape = src_var_image.shape[:-2] + (dst_height, dst_width) + dst_out = np.full(chunksize, fill_value, dtype=src_var_image.dtype) + if np.all(np.isnan(dst_src_ij_images[0])): + return dst_out dst_values = np.full(dst_shape, fill_value, dtype=src_var_image.dtype) + src_bbox = [ + int(np.nanmin(dst_src_ij_images[0])), + int(np.nanmin(dst_src_ij_images[1])), + min(int(np.nanmax(dst_src_ij_images[0])) + 1, src_var_image.shape[-1]), + min(int(np.nanmax(dst_src_ij_images[1])) + 1, src_var_image.shape[-2]), + ] + src_var_image = src_var_image[ + ..., src_bbox[1] : src_bbox[3], src_bbox[0] : src_bbox[2] + ].values.astype(np.float64) _compute_var_image_numpy_sequential( - src_var_image, dst_src_ij_images, dst_values, interpolation + src_var_image, dst_src_ij_images, dst_values, src_bbox, interpolation ) - return dst_values + dst_out[..., :dst_height, :dst_width] = dst_values + return dst_out @nb.njit(nogil=True, parallel=True, cache=True) @@ -688,6 +709,7 @@ def _compute_var_image_numpy_sequential( src_var_image: np.ndarray, dst_src_ij_images: np.ndarray, dst_var_image: np.ndarray, + src_bbox: list[int], interpolation: int, ): """Extract source pixels from np.ndarray source @@ -696,7 +718,12 @@ def _compute_var_image_numpy_sequential( dst_height = dst_var_image.shape[-2] for dst_j in range(dst_height): _compute_var_image_for_dest_line( - dst_j, src_var_image, dst_src_ij_images, dst_var_image, interpolation + dst_j, + src_var_image, + dst_src_ij_images, + dst_var_image, + src_bbox, + interpolation, ) @@ -706,6 +733,7 @@ def _compute_var_image_for_dest_line( src_var_image: np.ndarray, dst_src_ij_images: np.ndarray, dst_var_image: np.ndarray, + src_bbox: list[int], interpolation: int, ): """Extract source pixels from *src_values* np.ndarray @@ -719,8 +747,8 @@ def _compute_var_image_for_dest_line( src_i_max = src_width - 1 src_j_max = src_height - 1 for dst_i in range(dst_width): - src_i_f = dst_src_ij_images[0, dst_j, dst_i] - src_j_f = dst_src_ij_images[1, dst_j, dst_i] + src_i_f = dst_src_ij_images[0, dst_j, dst_i] - src_bbox[0] + src_j_f = dst_src_ij_images[1, dst_j, dst_i] - src_bbox[1] if np.isnan(src_i_f) or np.isnan(src_j_f): continue # Note int() is 2x faster than math.floor() and @@ -811,382 +839,3 @@ def _iclamp(x: int, x_min: int, x_max: int) -> int: def _millis(seconds: float) -> int: return round(1000 * seconds) - - -# ------------------------------ -# ------------------------------ -# new implementation - - -def rectify_dataset_new( - source_ds: xr.Dataset, - /, - source_gm: Optional[GridMapping] = None, - target_gm: Optional[GridMapping] = None, - ref_ds: Optional[xr.Dataset] = None, - var_names: Optional[Union[str, Sequence[str]]] = None, - encode_cf: bool = True, - gm_name: Optional[str] = None, - tile_size: Optional[Union[int, tuple[int, int]]] = None, - is_j_axis_up: Optional[bool] = None, - output_ij_names: Optional[tuple[str, str]] = None, - compute_subset: bool = True, - uv_delta: float = 1e-3, - interpolation: Optional[str] = None, - xy_var_names: Optional[tuple[str, str]] = None, -) -> Optional[xr.Dataset]: - """Reproject dataset *source_ds* using its per-pixel - x,y coordinates or the given *source_gm*. - - The function expects *source_ds* or the given - *source_gm* to have either one- or two-dimensional - coordinate variables that provide spatial x,y coordinates - for every data variable with the same spatial dimensions. - - For example, a dataset may comprise variables with - spatial dimensions ``var(..., y_dim, x_dim)``, - then one the function expects coordinates to be provided - in two forms: - - 1. One-dimensional ``x_var(x_dim)`` - and ``y_var(y_dim)`` (coordinate) variables. - 2. Two-dimensional ``x_var(y_dim, x_dim)`` - and ``y_var(y_dim, x_dim)`` (coordinate) variables. - - If *target_gm* is given and it defines a tile size, - or *tile_size* is given and the number of tiles is - greater than one in the output's x- or y-direction, then the - returned dataset will be composed of lazy, chunked dask - arrays. Otherwise, the returned dataset will be composed - of ordinary numpy arrays. - - New in 1.6: If *target_ds* is given, its coordinate - variables are copied by reference into the returned - dataset. - - Args: - source_ds: Source dataset. - source_gm: Source dataset grid mapping. - target_gm: Optional target geometry. If not given, output - geometry will be computed to spatially fit *dataset* and to - retain its spatial resolution. - ref_ds: An optional dataset that provides the - target grid mapping if *target_gm* is not provided. - If *ref_ds* is given, its coordinate variables are copied - by reference into the returned dataset. - var_names: Optional variable name or sequence of variable names. - encode_cf: Whether to encode the target grid mapping into the - resampled dataset in a CF-compliant way. Defaults to - ``True``. - gm_name: Name for the grid mapping variable. Defaults to "crs". - Used only if *encode_cf* is ``True``. - tile_size: Optional tile size for the output. - is_j_axis_up: Whether y coordinates are increasing with positive - image j axis. - output_ij_names: If given, a tuple of variable names in which to - store the computed source pixel coordinates in the returned - output. - compute_subset: Whether to compute a spatial subset from - *source_ds* using the boundary of the target grid mapping. - If set, the function may return ``None`` in case there is no - overlap. - uv_delta: A normalized value that is used to determine whether - x,y coordinates in the output are contained in the triangles - defined by the input x,y coordinates. The higher this value, - the more inaccurate the rectification will be. - interpolation: Interpolation method for computing output pixels. - If given, must be "nearest", "triangular", or "bilinear". - The default is "nearest". The "triangular" interpolation is - performed between 3 and "bilinear" between 4 adjacent source - pixels. Both are applied only to variables of - floating point type. If you need to interpolate between - integer data you should cast it to float first. - xy_var_names: Deprecated. No longer used since 1.0.0, - no replacement. - - Returns: - A reprojected dataset, or None if the requested output does not - intersect with *dataset*. - """ - if xy_var_names: - warnings.warn( - "argument 'xy_var_names' has been deprecated in 1.4.2" - " and may be removed anytime.", - category=DeprecationWarning, - ) - - if source_gm is None: - source_gm = GridMapping.from_dataset(source_ds) - - src_attrs = dict(source_ds.attrs) - - if target_gm is None and ref_ds is not None: - target_gm = GridMapping.from_dataset(ref_ds) - - if target_gm is None: - target_gm = source_gm.to_regular(tile_size=tile_size) - elif compute_subset: - source_ds_subset = select_spatial_subset( - source_ds, - xy_bbox=target_gm.xy_bbox, - ij_border=1, - xy_border=0.5 * (target_gm.x_res + target_gm.y_res), - grid_mapping=source_gm, - ) - if source_ds_subset is None: - return None - if source_ds_subset is not source_ds: - source_gm = GridMapping.from_dataset(source_ds_subset) - source_ds = source_ds_subset - - if tile_size is not None or is_j_axis_up is not None: - target_gm = target_gm.derive(tile_size=tile_size, is_j_axis_up=is_j_axis_up) - - src_vars = _select_variables(source_ds, source_gm, var_names) - - interpolation_mode = _INTERPOLATIONS.get(interpolation or "nearest") - if interpolation_mode is None: - raise ValueError(f"invalid interpolation: {interpolation!r}") - - if target_gm.is_tiled: - compute_dst_src_ij_images = _compute_ij_images_xarray_dask - compute_dst_var_image = _compute_var_image_xarray_dask - else: - compute_dst_src_ij_images = _compute_ij_images_xarray_numpy - compute_dst_var_image = _compute_var_image_xarray_numpy - - dst_src_ij_array = compute_dst_src_ij_images(source_gm, target_gm, uv_delta) - - dst_x_dim, dst_y_dim = target_gm.xy_dim_names - dst_dims = dst_y_dim, dst_x_dim - dst_ds_coords = target_gm.to_coords() - dst_vars = dict() - for src_var_name, src_var in src_vars.items(): - dst_var_dims = src_var.dims[0:-2] + dst_dims - dst_var_coords = { - d: src_var.coords[d] for d in dst_var_dims if d in src_var.coords - } - # noinspection PyTypeChecker - dst_var_coords.update( - {d: dst_ds_coords[d] for d in dst_var_dims if d in dst_ds_coords} - ) - dst_var_array = compute_dst_var_image( - src_var, - dst_src_ij_array, - fill_value=np.nan, - interpolation=interpolation_mode, - ) - dst_var = xr.DataArray( - dst_var_array, - dims=dst_var_dims, - coords=dst_var_coords, - attrs=src_var.attrs, - ) - dst_vars[src_var_name] = dst_var - - if output_ij_names: - output_i_name, output_j_name = output_ij_names - dst_ij_coords = {d: dst_ds_coords[d] for d in dst_dims if d in dst_ds_coords} - dst_vars[output_i_name] = xr.DataArray( - dst_src_ij_array[0], dims=dst_dims, coords=dst_ij_coords - ) - dst_vars[output_j_name] = xr.DataArray( - dst_src_ij_array[1], dims=dst_dims, coords=dst_ij_coords - ) - - return complete_resampled_dataset( - encode_cf, - xr.Dataset(dst_vars, coords=dst_ds_coords, attrs=src_attrs), - target_gm, - gm_name, - ref_ds.coords if ref_ds else None, - ) - - -def _generate_index_map(source_gm: GridMapping, target_gm: GridMapping) -> da.Array: - """Generate dask.array.Array index map with size defined by *target_gm* filled - with pixel indices of source image. - """ - target_shape = 2, target_gm.height, target_gm.width - target_chunks = 2, target_gm.tile_width, target_gm.tile_height - - dst_x_min, dst_y_min, dst_x_max, dst_y_max = target_gm.xy_bbox - - # Compute an empirical xy_border as a function of the - # number of tiles, because the more tiles we have - # the smaller the destination xy-bboxes and the higher - # the risk to not find any source ij-bbox for a given xy-bbox. - # xy_border will not be larger than half of the - # coverage of a tile. - num_tiles_x = target_gm.width / target_gm.tile_width - num_tiles_y = target_gm.height / target_gm.tile_height - xy_border = min( - min(2 * num_tiles_x * target_gm.x_res, 2 * num_tiles_y * target_gm.y_res), - min(0.5 * (dst_x_max - dst_x_min), 0.5 * (dst_y_max - dst_y_min)), - ) - src_ij_bboxes = source_gm.ij_bboxes_from_xy_bboxes( - target_gm.xy_bboxes, xy_border=xy_border, ij_border=1 - ) - - index_map = da.full(target_shape, np.nan, chunk=target_chunks, dtype=np.float64) - src_x_values, src_y_values = source_gm.xy_coords - - u_min = v_min = -uv_delta - uv_max = 1.0 + 2 * uv_delta - - src_j0 = da.arange(source_gm.height - 1) - src_i0 = da.arange(source_gm.width - 1) - src_j0, src_i0 = da.meshgrid(src_j0, src_i0, indexing="ij") - src_j1 = src_j0 + 1 - src_i1 = src_i0 + 1 - - dst_p0x = source_gm.xy_coords[0, src_j0, src_i0] - dst_p1x = source_gm.xy_coords[0, src_j0, src_i1] - dst_p2x = source_gm.xy_coords[0, src_j1, src_i0] - dst_p3x = source_gm.xy_coords[0, src_j1, src_i1] - - dst_p0y = source_gm.xy_coords[1, src_j0, src_i0] - dst_p1y = source_gm.xy_coords[1, src_j0, src_i1] - dst_p2y = source_gm.xy_coords[1, src_j1, src_i0] - dst_p3y = source_gm.xy_coords[1, src_j1, src_i1] - - if ( - dst_i_max < 0 - or dst_j_max < 0 - or dst_i_min >= dst_width - or dst_j_min >= dst_height - ): - continue - - if dst_i_min < 0: - dst_i_min = 0 - - if dst_i_max >= dst_width: - dst_i_max = dst_width - 1 - - if dst_j_min < 0: - dst_j_min = 0 - - if dst_j_max >= dst_height: - dst_j_max = dst_height - 1 - - # u from p0 right to p1, v from p0 down to p2 - # noinspection PyTypeChecker - det_a = _fdet(dst_p0x, dst_p0y, dst_p1x, dst_p1y, dst_p2x, dst_p2y) - if np.isnan(det_a): - det_a = 0.0 - - # u from p3 left to p2, v from p3 up to p1 - # noinspection PyTypeChecker - det_b = _fdet(dst_p3x, dst_p3y, dst_p2x, dst_p2y, dst_p1x, dst_p1y) - if np.isnan(det_b): - det_b = 0.0 - - if det_a == 0.0 and det_b == 0.0: - # Both the triangles do not exist. - continue - - for dst_j in range(dst_j_min, dst_j_max + 1): - dst_y = dst_y_offset + (dst_j + 0.5) * dst_y_scale - for dst_i in range(dst_i_min, dst_i_max + 1): - sentinel = dst_src_ij_images[0, dst_j, dst_i] - if not np.isnan(sentinel): - # If we have a source pixel in dst_i, dst_j already, - # there is no need to compute another one. - # One is as good as the other. - continue - - dst_x = dst_x_offset + (dst_i + 0.5) * dst_x_scale - - src_i = src_j = -1 - - if det_a != 0.0: - # noinspection PyTypeChecker - u = _fu(dst_x, dst_y, dst_p0x, dst_p0y, dst_p2x, dst_p2y) / det_a - # noinspection PyTypeChecker - v = _fv(dst_x, dst_y, dst_p0x, dst_p0y, dst_p1x, dst_p1y) / det_a - if u >= u_min and v >= v_min and u + v <= uv_max: - src_i = src_i0 + _fclamp(u, 0.0, 1.0) - src_j = src_j0 + _fclamp(v, 0.0, 1.0) - if src_i == -1 and det_b != 0.0: - # noinspection PyTypeChecker - u = _fu(dst_x, dst_y, dst_p3x, dst_p3y, dst_p1x, dst_p1y) / det_b - # noinspection PyTypeChecker - v = _fv(dst_x, dst_y, dst_p3x, dst_p3y, dst_p2x, dst_p2y) / det_b - if u >= u_min and v >= v_min and u + v <= uv_max: - src_i = src_i1 - _fclamp(u, 0.0, 1.0) - src_j = src_j1 - _fclamp(v, 0.0, 1.0) - if src_i != -1: - dst_src_ij_images[0, dst_j, dst_i] = src_i_min + src_i - dst_src_ij_images[1, dst_j, dst_i] = src_j_min + src_j - - return compute_array_from_func( - _compute_ij_images_xarray_dask_block, - dst_var_shape, - dst_var_chunks, - np.float64, - ctx_arg_names=[ - "dtype", - "block_id", - "block_shape", - "block_slices", - ], - args=( - src_geo_coding.xy_coords, - src_ij_bboxes, - dst_x_min, - dst_y_min, - dst_y_max, - dst_x_res, - dst_y_res, - dst_is_j_axis_up, - uv_delta, - ), - name="ij_pixels", - ) - - -def _interpolate( - source: da.array, - index_map: da.array, - interpolation: int, -): - """Extract pixels from *source* da.ndarray using *index_map*.""" - bounds = [(0, source.shape[-2] - 1), (0, source.shape[-1] - 1)] - - index_map0 = da.round(index_map).astype(int) - - if interpolation == 0: - # interpolation == "nearest" - index_map0 = _check_bounds(index_map0, bounds) - destination = source[..., index_map0[0], index_map0[1]] - else: - u, v = index_map - index_map0 - index_map1 = index_map0 + 1 - index_map1 = _check_bounds(index_map1, bounds) - value_00 = source[..., index_map0[0], index_map0[1]] - value_01 = source[..., index_map0[0], index_map1[1]] - value_10 = source[..., index_map1[0], index_map0[1]] - value_11 = source[..., index_map1[0], index_map1[1]] - if interpolation == 1: - # interpolation == "triangular" - destination = da.where( - u + v < 1.0, - value_00 + u * (value_01 - value_00) + v * (value_10 - value_00), - value_11 - + (1.0 - u) * (value_10 - value_11) - + (1.0 - v) * (value_01 - value_11), - ) - else: - # interpolation == "bilinear - value_u0 = value_00 + u * (value_01 - value_00) - value_u1 = value_10 + u * (value_11 - value_10) - destination = value_u0 + v * (value_u1 - value_u0) - return destination - - -def _check_bounds(arr: da.array, bounds: list[tuple[int]]) -> da.array: - for idx, bound in enumerate(bounds): - arr[idx] = da.where(arr[idx] < bound[1], arr[idx], bound[1]) - arr[idx] = da.where(arr[idx] > bound[0], arr[idx], bound[0]) - return arr