Commit fa79e87: "ready for checkout"
konstntokas committed Dec 18, 2024 · 1 parent ba484bf
Showing 1 changed file with 379 additions and 0 deletions: xcube/core/resampling/rectify.py
@@ -811,3 +811,382 @@ def _iclamp(x: int, x_min: int, x_max: int) -> int:

def _millis(seconds: float) -> int:
return round(1000 * seconds)


# ------------------------------
# ------------------------------
# new implementation

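# Mapping from interpolation name to the integer codes consumed by
# _interpolate() below. The names and codes are inferred from the
# branches of _interpolate(); the constant itself is not part of the
# original diff.
_INTERPOLATIONS = {"nearest": 0, "triangular": 1, "bilinear": 2}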

def rectify_dataset_new(
source_ds: xr.Dataset,
/,
source_gm: Optional[GridMapping] = None,
target_gm: Optional[GridMapping] = None,
ref_ds: Optional[xr.Dataset] = None,
var_names: Optional[Union[str, Sequence[str]]] = None,
encode_cf: bool = True,
gm_name: Optional[str] = None,
tile_size: Optional[Union[int, tuple[int, int]]] = None,
is_j_axis_up: Optional[bool] = None,
output_ij_names: Optional[tuple[str, str]] = None,
compute_subset: bool = True,
uv_delta: float = 1e-3,
interpolation: Optional[str] = None,
xy_var_names: Optional[tuple[str, str]] = None,
) -> Optional[xr.Dataset]:
"""Reproject dataset *source_ds* using its per-pixel
x,y coordinates or the given *source_gm*.
The function expects *source_ds* or the given
*source_gm* to have either one- or two-dimensional
coordinate variables that provide spatial x,y coordinates
for every data variable with the same spatial dimensions.
For example, a dataset may comprise variables with
spatial dimensions ``var(..., y_dim, x_dim)``,
then one the function expects coordinates to be provided
in two forms:
1. One-dimensional ``x_var(x_dim)``
and ``y_var(y_dim)`` (coordinate) variables.
2. Two-dimensional ``x_var(y_dim, x_dim)``
and ``y_var(y_dim, x_dim)`` (coordinate) variables.
If *target_gm* is given and it defines a tile size,
or *tile_size* is given and the number of tiles is
greater than one in the output's x- or y-direction, then the
returned dataset will be composed of lazy, chunked dask
arrays. Otherwise, the returned dataset will be composed
of ordinary numpy arrays.
New in 1.6: If *target_ds* is given, its coordinate
variables are copied by reference into the returned
dataset.
Args:
source_ds: Source dataset.
source_gm: Source dataset grid mapping.
        target_gm: Optional target geometry. If not given, the output
            geometry will be computed to spatially fit *source_ds* and
            to retain its spatial resolution.
ref_ds: An optional dataset that provides the
target grid mapping if *target_gm* is not provided.
If *ref_ds* is given, its coordinate variables are copied
by reference into the returned dataset.
var_names: Optional variable name or sequence of variable names.
encode_cf: Whether to encode the target grid mapping into the
resampled dataset in a CF-compliant way. Defaults to
``True``.
gm_name: Name for the grid mapping variable. Defaults to "crs".
Used only if *encode_cf* is ``True``.
tile_size: Optional tile size for the output.
is_j_axis_up: Whether y coordinates are increasing with positive
image j axis.
output_ij_names: If given, a tuple of variable names in which to
store the computed source pixel coordinates in the returned
output.
compute_subset: Whether to compute a spatial subset from
*source_ds* using the boundary of the target grid mapping.
If set, the function may return ``None`` in case there is no
overlap.
uv_delta: A normalized value that is used to determine whether
x,y coordinates in the output are contained in the triangles
defined by the input x,y coordinates. The higher this value,
the more inaccurate the rectification will be.
        interpolation: Interpolation method for computing output pixels.
            If given, must be "nearest", "triangular", or "bilinear".
            The default is "nearest". "Triangular" interpolation is
            performed between 3 adjacent source pixels, "bilinear"
            between 4. Both are applied only to variables of
            floating point type. If you need to interpolate between
            integer data, you should cast it to float first.
xy_var_names: Deprecated. No longer used since 1.0.0,
no replacement.
    Returns:
        A reprojected dataset, or ``None`` if the requested output
        does not intersect with *source_ds*.
"""
if xy_var_names:
warnings.warn(
"argument 'xy_var_names' has been deprecated in 1.4.2"
" and may be removed anytime.",
category=DeprecationWarning,
)

if source_gm is None:
source_gm = GridMapping.from_dataset(source_ds)

src_attrs = dict(source_ds.attrs)

if target_gm is None and ref_ds is not None:
target_gm = GridMapping.from_dataset(ref_ds)

if target_gm is None:
target_gm = source_gm.to_regular(tile_size=tile_size)
elif compute_subset:
source_ds_subset = select_spatial_subset(
source_ds,
xy_bbox=target_gm.xy_bbox,
ij_border=1,
xy_border=0.5 * (target_gm.x_res + target_gm.y_res),
grid_mapping=source_gm,
)
if source_ds_subset is None:
return None
if source_ds_subset is not source_ds:
source_gm = GridMapping.from_dataset(source_ds_subset)
source_ds = source_ds_subset

if tile_size is not None or is_j_axis_up is not None:
target_gm = target_gm.derive(tile_size=tile_size, is_j_axis_up=is_j_axis_up)

src_vars = _select_variables(source_ds, source_gm, var_names)

interpolation_mode = _INTERPOLATIONS.get(interpolation or "nearest")
if interpolation_mode is None:
raise ValueError(f"invalid interpolation: {interpolation!r}")

if target_gm.is_tiled:
compute_dst_src_ij_images = _compute_ij_images_xarray_dask
compute_dst_var_image = _compute_var_image_xarray_dask
else:
compute_dst_src_ij_images = _compute_ij_images_xarray_numpy
compute_dst_var_image = _compute_var_image_xarray_numpy

dst_src_ij_array = compute_dst_src_ij_images(source_gm, target_gm, uv_delta)

dst_x_dim, dst_y_dim = target_gm.xy_dim_names
dst_dims = dst_y_dim, dst_x_dim
dst_ds_coords = target_gm.to_coords()
dst_vars = dict()
for src_var_name, src_var in src_vars.items():
dst_var_dims = src_var.dims[0:-2] + dst_dims
dst_var_coords = {
d: src_var.coords[d] for d in dst_var_dims if d in src_var.coords
}
# noinspection PyTypeChecker
dst_var_coords.update(
{d: dst_ds_coords[d] for d in dst_var_dims if d in dst_ds_coords}
)
dst_var_array = compute_dst_var_image(
src_var,
dst_src_ij_array,
fill_value=np.nan,
interpolation=interpolation_mode,
)
dst_var = xr.DataArray(
dst_var_array,
dims=dst_var_dims,
coords=dst_var_coords,
attrs=src_var.attrs,
)
dst_vars[src_var_name] = dst_var

if output_ij_names:
output_i_name, output_j_name = output_ij_names
dst_ij_coords = {d: dst_ds_coords[d] for d in dst_dims if d in dst_ds_coords}
dst_vars[output_i_name] = xr.DataArray(
dst_src_ij_array[0], dims=dst_dims, coords=dst_ij_coords
)
dst_vars[output_j_name] = xr.DataArray(
dst_src_ij_array[1], dims=dst_dims, coords=dst_ij_coords
)

return complete_resampled_dataset(
encode_cf,
xr.Dataset(dst_vars, coords=dst_ds_coords, attrs=src_attrs),
target_gm,
gm_name,
        ref_ds.coords if ref_ds is not None else None,
)
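
# Usage sketch (illustrative only): rectify a dataset whose variables
# carry two-dimensional "lon"/"lat" coordinates. The variable and
# coordinate names below are hypothetical; any dataset from which
# GridMapping.from_dataset() can derive a grid mapping will do.
#
#     import numpy as np
#     import xarray as xr
#
#     y, x = np.mgrid[0:100, 0:100]
#     lon = xr.DataArray(10.0 + 0.01 * x + 0.001 * y, dims=("y", "x"))
#     lat = xr.DataArray(50.0 - 0.01 * y + 0.001 * x, dims=("y", "x"))
#     sst = xr.DataArray(
#         np.random.random((100, 100)),
#         dims=("y", "x"),
#         coords={"lon": lon, "lat": lat},
#     )
#     source_ds = xr.Dataset({"sst": sst})
#     target_ds = rectify_dataset_new(source_ds, interpolation="bilinear")
#     # target_ds is None if the computed target geometry does not
#     # intersect the source.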


def _generate_index_map(
    source_gm: GridMapping, target_gm: GridMapping, uv_delta: float = 1e-3
) -> da.Array:
    """Generate a dask.array.Array index map of the size defined by
    *target_gm*, filled with the fractional pixel indices of the
    source image defined by *source_gm*.
    """
    target_shape = 2, target_gm.height, target_gm.width
    target_chunks = 2, target_gm.tile_height, target_gm.tile_width

dst_x_min, dst_y_min, dst_x_max, dst_y_max = target_gm.xy_bbox

    # Compute an empirical xy_border as a function of the
    # number of tiles, because the more tiles we have
    # the smaller the destination xy-bboxes and the higher
    # the risk of not finding any source ij-bbox for a given
    # xy-bbox. xy_border is capped at half of the extent of
    # the target's bounding box in either direction.
num_tiles_x = target_gm.width / target_gm.tile_width
num_tiles_y = target_gm.height / target_gm.tile_height
xy_border = min(
min(2 * num_tiles_x * target_gm.x_res, 2 * num_tiles_y * target_gm.y_res),
min(0.5 * (dst_x_max - dst_x_min), 0.5 * (dst_y_max - dst_y_min)),
)
src_ij_bboxes = source_gm.ij_bboxes_from_xy_bboxes(
target_gm.xy_bboxes, xy_border=xy_border, ij_border=1
)
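
    # Worked example of the heuristic above (illustrative numbers only):
    # for a 1000 x 1000 target split into 250 x 250 tiles, num_tiles_x and
    # num_tiles_y are both 4, so the border candidate is
    # min(8 * x_res, 8 * y_res); with x_res = y_res = 0.01 and a
    # 10 x 10 degree bbox, the cap is min(5.0, 5.0), giving
    # xy_border = min(0.08, 5.0) = 0.08.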

    # Destination raster geometry needed below
    dst_width = target_gm.width
    dst_height = target_gm.height
    dst_x_res = target_gm.x_res
    dst_y_res = target_gm.y_res
    dst_is_j_axis_up = target_gm.is_j_axis_up
    dst_x_offset = dst_x_min
    dst_y_offset = dst_y_min if dst_is_j_axis_up else dst_y_max
    dst_x_scale = dst_x_res
    dst_y_scale = dst_y_res if dst_is_j_axis_up else -dst_y_res

    if not target_gm.is_tiled:
        # Non-tiled target: compute the index map eagerly by visiting
        # each source quad, i.e. the cell spanned by the four adjacent
        # source pixels (j0,i0), (j0,i1), (j1,i0), (j1,i1).
        index_map = np.full(target_shape, np.nan, dtype=np.float64)
        src_x_values, src_y_values = source_gm.xy_coords.values

        u_min = v_min = -uv_delta
        uv_max = 1.0 + 2 * uv_delta

        for src_j0 in range(source_gm.height - 1):
            src_j1 = src_j0 + 1
            for src_i0 in range(source_gm.width - 1):
                src_i1 = src_i0 + 1

                # x,y coordinates of the four quad corners
                dst_p0x = src_x_values[src_j0, src_i0]
                dst_p1x = src_x_values[src_j0, src_i1]
                dst_p2x = src_x_values[src_j1, src_i0]
                dst_p3x = src_x_values[src_j1, src_i1]

                dst_p0y = src_y_values[src_j0, src_i0]
                dst_p1y = src_y_values[src_j0, src_i1]
                dst_p2y = src_y_values[src_j1, src_i0]
                dst_p3y = src_y_values[src_j1, src_i1]

                # Destination pixel bounding box covered by the quad
                dst_pi = np.floor(
                    (np.array([dst_p0x, dst_p1x, dst_p2x, dst_p3x]) - dst_x_offset)
                    / dst_x_scale
                ).astype(np.int64)
                dst_pj = np.floor(
                    (np.array([dst_p0y, dst_p1y, dst_p2y, dst_p3y]) - dst_y_offset)
                    / dst_y_scale
                ).astype(np.int64)
                dst_i_min = int(dst_pi.min())
                dst_i_max = int(dst_pi.max())
                dst_j_min = int(dst_pj.min())
                dst_j_max = int(dst_pj.max())

                if (
                    dst_i_max < 0
                    or dst_j_max < 0
                    or dst_i_min >= dst_width
                    or dst_j_min >= dst_height
                ):
                    continue

                # Clip the bounding box to the destination raster
                if dst_i_min < 0:
                    dst_i_min = 0
                if dst_i_max >= dst_width:
                    dst_i_max = dst_width - 1
                if dst_j_min < 0:
                    dst_j_min = 0
                if dst_j_max >= dst_height:
                    dst_j_max = dst_height - 1

                # u from p0 right to p1, v from p0 down to p2
                det_a = _fdet(dst_p0x, dst_p0y, dst_p1x, dst_p1y, dst_p2x, dst_p2y)
                if np.isnan(det_a):
                    det_a = 0.0

                # u from p3 left to p2, v from p3 up to p1
                det_b = _fdet(dst_p3x, dst_p3y, dst_p2x, dst_p2y, dst_p1x, dst_p1y)
                if np.isnan(det_b):
                    det_b = 0.0

                if det_a == 0.0 and det_b == 0.0:
                    # Both triangles are degenerate.
                    continue

                for dst_j in range(dst_j_min, dst_j_max + 1):
                    dst_y = dst_y_offset + (dst_j + 0.5) * dst_y_scale
                    for dst_i in range(dst_i_min, dst_i_max + 1):
                        sentinel = index_map[0, dst_j, dst_i]
                        if not np.isnan(sentinel):
                            # If we have a source pixel in dst_i, dst_j
                            # already, there is no need to compute
                            # another one. One is as good as the other.
                            continue

                        dst_x = dst_x_offset + (dst_i + 0.5) * dst_x_scale

                        src_i = src_j = -1

                        if det_a != 0.0:
                            u = (
                                _fu(dst_x, dst_y, dst_p0x, dst_p0y, dst_p2x, dst_p2y)
                                / det_a
                            )
                            v = (
                                _fv(dst_x, dst_y, dst_p0x, dst_p0y, dst_p1x, dst_p1y)
                                / det_a
                            )
                            if u >= u_min and v >= v_min and u + v <= uv_max:
                                src_i = src_i0 + _fclamp(u, 0.0, 1.0)
                                src_j = src_j0 + _fclamp(v, 0.0, 1.0)
                        if src_i == -1 and det_b != 0.0:
                            u = (
                                _fu(dst_x, dst_y, dst_p3x, dst_p3y, dst_p1x, dst_p1y)
                                / det_b
                            )
                            v = (
                                _fv(dst_x, dst_y, dst_p3x, dst_p3y, dst_p2x, dst_p2y)
                                / det_b
                            )
                            if u >= u_min and v >= v_min and u + v <= uv_max:
                                src_i = src_i1 - _fclamp(u, 0.0, 1.0)
                                src_j = src_j1 - _fclamp(v, 0.0, 1.0)
                        if src_i != -1:
                            index_map[0, dst_j, dst_i] = src_i
                            index_map[1, dst_j, dst_i] = src_j

        return da.from_array(index_map, chunks=target_chunks)

    # Tiled target: delegate the per-block computation to the existing
    # block-wise dask kernel, passing the per-tile source ij-bboxes
    # computed above.
    return compute_array_from_func(
        _compute_ij_images_xarray_dask_block,
        target_shape,
        target_chunks,
        np.float64,
        ctx_arg_names=[
            "dtype",
            "block_id",
            "block_shape",
            "block_slices",
        ],
        args=(
            source_gm.xy_coords,
            src_ij_bboxes,
            dst_x_min,
            dst_y_min,
            dst_y_max,
            dst_x_res,
            dst_y_res,
            dst_is_j_axis_up,
            uv_delta,
        ),
        name="ij_pixels",
    )
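
# Geometry note (hedged sketch, assuming the pre-existing helpers
# _fdet/_fu/_fv implement the usual 2-D determinant and barycentric
# numerators): for an axis-aligned unit quad with p0=(0,0), p1=(1,0),
# p2=(0,1), p3=(1,1), det_a is 1 and (u, v) of a point (x, y) inside
# triangle (p0, p1, p2) are simply (x, y). A destination pixel centre
# at (0.25, 0.25) thus yields the fractional source coordinates
# src_i = src_i0 + 0.25 and src_j = src_j0 + 0.25.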


def _interpolate(
    source: da.Array,
    index_map: da.Array,
    interpolation: int,
) -> da.Array:
    """Extract pixels from the dask array *source* using *index_map*,
    where ``index_map[0]`` holds fractional source i (column) indices
    and ``index_map[1]`` holds fractional source j (row) indices.
    """
    # Inclusive index bounds for (i, j): i is limited by the number
    # of source columns, j by the number of source rows.
    bounds = [(0, source.shape[-1] - 1), (0, source.shape[-2] - 1)]

    if interpolation == 0:
        # interpolation == "nearest"
        index_map0 = _check_bounds(da.round(index_map).astype(np.int64), bounds)
        destination = source[..., index_map0[1], index_map0[0]]
    else:
        # Fractional offsets u (along i) and v (along j) are measured
        # from the floor of the index, so both lie in [0, 1).
        index_map0 = da.floor(index_map).astype(np.int64)
        u = index_map[0] - index_map0[0]
        v = index_map[1] - index_map0[1]
        index_map0 = _check_bounds(index_map0, bounds)
        index_map1 = _check_bounds(index_map0 + 1, bounds)
        # value_JI: source value at j-offset J, i-offset I within the quad
        value_00 = source[..., index_map0[1], index_map0[0]]
        value_01 = source[..., index_map0[1], index_map1[0]]
        value_10 = source[..., index_map1[1], index_map0[0]]
        value_11 = source[..., index_map1[1], index_map1[0]]
if interpolation == 1:
# interpolation == "triangular"
destination = da.where(
u + v < 1.0,
value_00 + u * (value_01 - value_00) + v * (value_10 - value_00),
value_11
+ (1.0 - u) * (value_10 - value_11)
+ (1.0 - v) * (value_01 - value_11),
)
else:
# interpolation == "bilinear
value_u0 = value_00 + u * (value_01 - value_00)
value_u1 = value_10 + u * (value_11 - value_10)
destination = value_u0 + v * (value_u1 - value_u0)
return destination
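
# Example (illustrative): for a 2 x 2 source [[0., 1.], [2., 3.]] and a
# target pixel with fractional indices i = 0.5, j = 0.5, the bilinear
# branch yields u = v = 0.5, value_00 = 0, value_01 = 1, value_10 = 2,
# value_11 = 3, and therefore
#     value_u0 = 0 + 0.5 * (1 - 0) = 0.5
#     value_u1 = 2 + 0.5 * (3 - 2) = 2.5
#     destination = 0.5 + 0.5 * (2.5 - 0.5) = 1.5
# i.e. the mean of the four neighbours, as expected at the quad centre.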


def _check_bounds(arr: da.Array, bounds: list[tuple[int, int]]) -> da.Array:
    """Clamp each layer ``arr[idx]`` into the inclusive range ``bounds[idx]``."""
    return da.stack(
        [da.clip(arr[idx], lower, upper) for idx, (lower, upper) in enumerate(bounds)]
    )
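
# Example (illustrative): with bounds = [(0, 1), (0, 2)], an index array
# [[-1, 5], [3, -2]] is clamped layer-wise to [[0, 1], [2, 0]]: the
# first layer is limited to [0, 1], the second to [0, 2].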
