Skip to content

Commit

Permalink
refactor enrich_from_dataset to two approaches of point and raster in…
Browse files Browse the repository at this point in the history
…put data
  • Loading branch information
SarahAlidoost committed Mar 15, 2024
1 parent 9a5feb4 commit ba41174
Showing 1 changed file with 126 additions and 47 deletions.
173 changes: 126 additions & 47 deletions stmtools/stm.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import xarray as xr
from shapely.geometry import Point
from shapely.strtree import STRtree
from scipy.spatial import cKDTree

from stmtools.metadata import DataVarTypes, STMMetaData
from stmtools.utils import _has_property
Expand Down Expand Up @@ -400,20 +401,25 @@ def reorder(self, xlabel="azimuth", ylabel="range", xscale=1.0, yscale=1.0):
self._obj = self._obj.sortby(self._obj.order)
return self._obj

def enrich_from_dataset(self, dataset: xr.Dataset | xr.DataArray, fields: str | Iterable, method="linear") -> xr.Dataset:
def enrich_from_dataset(self,
dataset: xr.Dataset | xr.DataArray,
fields: str | Iterable,
method="nearest") -> xr.Dataset:
"""Enrich the SpaceTimeMatrix from one or more fields of a dataset.
scipy is required. Each field will be assigned as a data variable to the
STM using interpolation in time and space.
scipy is required. if dataset is raster, it uses
_enrich_from_raster_block to do interpolation using method. if dataset
is point, it uses _enrich_from_points_block to find the nearest points
in space and time using Euclidean distance.
Parameters
----------
dataset : xarray.Dataset | xarray.DataArray
dataset : xarray.Dataset | xarray.DataArray
Input data for enrichment
fields : str or list of str
Field name(s) in the dataset for enrichment
method : str, optional
Method of interpolation, by default "linear", see
Method of interpolation, by default "nearest", see
https://docs.xarray.dev/en/stable/generated/xarray.Dataset.interp_like.html#xarray-dataset-interp-like
Returns
Expand All @@ -432,32 +438,37 @@ def enrich_from_dataset(self, dataset: xr.Dataset | xr.DataArray, fields: str |
dataset = dataset.to_dataset()

ds = self._obj
# check if both dataset and ds have coords_labels keys
for coord_label in ds.coords.keys():
if coord_label not in dataset.coords.keys():
raise ValueError(
f'Coordinate label "{coord_label}" was not found in the input dataset.'
)

# check if dataset is point or raster if 'space' in dataset.dims:
if "space" in dataset.dims:
approch = "point"
elif "lat" in dataset.dims and "lon" in dataset.dims:
approch = "raster"
elif "y" in dataset.dims and "x" in dataset.dims:
approch = "raster"
else:
raise ValueError(
"The input dataset is not a point or raster dataset."
"The dataset should have either 'space' or 'lat/y' and 'lon/x' dimensions."
)

# TODO: add utility to preprocess the dataset
# check if dataset has space and time dimensions
if "space" not in dataset.dims:
raise ValueError('Missing dimension: "space" in the input dataset.')
# check if dataset has time dimensions
if "time" not in dataset.dims:
raise ValueError('Missing dimension: "time" in the input dataset.')

# check if dtype of time is the same
if dataset.time.dtype != ds.time.dtype:
raise ValueError("The input dataset and the STM have different time dtype.")

# check if dataset and ds has the same space and time shapes, required
# for interpolation
if dataset.space.shape != ds.space.shape:
raise ValueError("The input dataset and the STM have different space shapes.")
if dataset.time.shape != ds.time.shape:
raise ValueError("The input dataset and the STM have different time shapes.")

# check if the keys of dataset coordinates are the same as the STM
for key in ds.coords.keys():
if key not in dataset.coords.keys():
raise ValueError(f'Coordinate label "{key}" was not found in the input dataset.')
# TODO: check if both ds and dataset has same coordinate system

chunks = (ds.chunksizes["space"][0], ds.chunksizes["time"][0])
for field in fields:
for i, field in enumerate(fields):

# check if dataset has the fields
if field not in dataset.data_vars.keys():
Expand All @@ -469,29 +480,22 @@ def enrich_from_dataset(self, dataset: xr.Dataset | xr.DataArray, fields: str |
f'"{field}" was found in the data variables of the STM. '
f'"We will proceed with the data variable from the input dataset as "{field}_other".'
)
field = f"{field}_other"

ds = ds.assign(
{
field: (
["space", "time"],
da.from_array(np.full(ds.space.shape + ds.time.shape, None), chunks=chunks),
)
}
fields[i] = f"{field}_other"

if approch == "raster":
return xr.map_blocks(
_enrich_from_raster_block,
ds,
args=(fields, method),
kwargs={"dataset": dataset}, #TODD: block still not working, refactor
)
elif approch == "point":
return xr.map_blocks(
_enrich_from_points_block,
ds,
args=(fields),
kwargs={"dataset": dataset},
)
# spatial interpolation and map_blocks does not work if coordinates are not same
# ds = xr.map_blocks(
# _enrich_from_dataset_block,
# ds,
# args=(dataset, fields, method),
# template=ds,
# )
_ds = ds.copy(deep=True)
for field in fields:
_ds[field].data = dataset[field].interp_like(ds, method=method)
ds = _ds

return ds

@property
def num_points(self):
Expand Down Expand Up @@ -667,9 +671,84 @@ def _compute_morton_code(xx, yy):
return code


def _enrich_from_dataset_block(ds, dataset, fields, method):
"""Block-wise function for "enrich_from_dataset"."""
def _enrich_from_raster_block(ds, dataraster, fields, method):
"""Enrich the ds (SpaceTimeMatrix) from one or more fields of a raster dataset.
scipy is required. It uses xarray.Dataset.interp_like to interpolate the
raster dataset to the coordinates of ds.
https://docs.xarray.dev/en/stable/generated/xarray.Dataset.interp_like.html#xarray-dataset-interp-like
Parameters
----------
ds : xarray.Dataset
dataset : xarray.Dataset | xarray.DataArray
Input data for enrichment
fields : str or list of str
Field name(s) in the dataset for enrichment
method : str, optional
Method of interpolation, by default "linear", see
Returns
-------
xarray.Dataset
"""
# interpolate the raster dataset to the coordinates of ds
interpolated = dataraster.interp(ds.coords, method=method)

# Assign these values to the corresponding points in ds
_ds = ds.copy(deep=True)
for field in fields:
_ds[field].data = dataset[field].interp_like(ds, method=method)
_ds[field] = xr.DataArray(interpolated[field].data, dims=ds.dims, coords=ds.coords)
return _ds


def _enrich_from_points_block(ds, datapoints, fields):
"""Enrich the ds (SpaceTimeMatrix) from one or more fields of a point dataset.
scipy is required. It uses cKDTree to find the nearest points in space and
time using Euclidean distance.
https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.cKDTree.html#scipy-spatial-ckdtree
Parameters
----------
ds : xarray.Dataset
datapoints : xarray.Dataset | xarray.DataArray
Input data for enrichment
fields : str or list of str
Field name(s) in the dataset for enrichment
Returns
-------
xarray.Dataset
"""
_ds = ds.copy(deep=True)

# create tuple of spatial coordinates
spatial_coords = list(_ds.coords.keys())[:-1] # assuming the last coordinate is time
ds_coords = np.column_stack([_ds[coord].values.flatten() for coord in spatial_coords])

spatial_coords = list(datapoints.coords.keys())[:-1] # assuming the last coordinate is time
dataset_points_coords = np.column_stack([datapoints[coord].values.flatten() for coord in spatial_coords])

# Create a cKDTree object for the spatial coordinates of datapoints
# Find the indices of the nearest points in space in datapoints for each point in _ds
# it uses Euclidean distance
tree = cKDTree(dataset_points_coords)
_, indices_space = tree.query(ds_coords)

# Create a cKDTree object for the temporal coordinates of datapoints
# Find the indices of the nearest points in time in datapoints for each point in _ds
datapoints_times = datapoints.time.values.reshape(-1, 1)
ds_times = _ds.time.values.reshape(-1, 1)
tree = cKDTree(datapoints_times)
_, indices_time = tree.query(ds_times)

selections = datapoints.isel(time=indices_time, space=indices_space)

# Assign these values to the corresponding points in _ds
for field in fields:
_ds[field] = xr.DataArray(selections[field].data, dims=ds.dims, coords=ds.coords)

return _ds

0 comments on commit ba41174

Please sign in to comment.