Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a function to enrich STM using data from another dataset #66

Merged
merged 33 commits into from
May 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
a64bb06
draft implementation of querying data
SarahAlidoost Mar 12, 2024
9a5feb4
add scipy and xarray io to dependencies
SarahAlidoost Mar 12, 2024
ba41174
refactor enrich_from_dataset to two approaches of point and raster in…
SarahAlidoost Mar 15, 2024
e133f5f
use KDTree instead of cKDTree
SarahAlidoost Mar 18, 2024
90ef61d
replace KDTree with sel method of xarray, fix a bug in fields
SarahAlidoost Mar 18, 2024
99844c8
fix tests
SarahAlidoost Mar 18, 2024
6926a4b
remove ds copy
SarahAlidoost Mar 20, 2024
f61f355
add test if operations are lazy
SarahAlidoost Mar 20, 2024
b3101c8
add util functions for cropping and unstack operations
SarahAlidoost Mar 21, 2024
605db7e
fix stm enrich function
SarahAlidoost Mar 22, 2024
dc33ecb
fix and refactor util function for cropping
SarahAlidoost Mar 22, 2024
2002854
fix an error msg
SarahAlidoost Mar 22, 2024
9113b02
fix linter errors
SarahAlidoost Mar 22, 2024
86d7b6d
fix linters
SarahAlidoost Mar 25, 2024
da020a4
remove scipy because it is included in xarray io
SarahAlidoost Mar 25, 2024
3dc1756
fix linter errors in _io
SarahAlidoost Mar 25, 2024
0691161
fix minor things
SarahAlidoost Mar 25, 2024
8c473a7
Update stmtools/stm.py
SarahAlidoost Apr 8, 2024
92d5301
add two utils functions for checking coordinates
SarahAlidoost Apr 8, 2024
70089d1
fix test unique coords in test_util
SarahAlidoost Apr 8, 2024
b85415a
add a check if coords are monotonic and unique in stm
SarahAlidoost Apr 8, 2024
c030cf9
use scipy KDTree instead of xarray unstack and sel functions
SarahAlidoost Apr 10, 2024
f123336
fix linter errors
SarahAlidoost Apr 10, 2024
73acbbd
add test for non monotonic and duplicate coords
SarahAlidoost Apr 11, 2024
3d01481
add a test for non monotonic time
SarahAlidoost Apr 11, 2024
9b3ddd8
add type to coordinates in tests
SarahAlidoost Apr 12, 2024
45e1900
fix a linter error
SarahAlidoost Apr 12, 2024
0412f10
debug: add debuging to pytest in workflow, and comment the test to ch…
SarahAlidoost Apr 12, 2024
ea580a2
debug comment the test to check on macos
SarahAlidoost Apr 12, 2024
9dc7804
fix tests comparing values instead of data arrays
SarahAlidoost Apr 12, 2024
0666878
remove util function for checking unique values
SarahAlidoost Apr 12, 2024
fc784c8
fix the test
SarahAlidoost Apr 12, 2024
90ef7be
remove -vv from action build
SarahAlidoost Apr 12, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ version = "0.1.1"
requires-python = ">=3.10"
dependencies = [
"dask[complete]",
"xarray",
"xarray[io]",
"numpy",
"rasterio",
"geopandas",
Expand Down
1 change: 1 addition & 0 deletions stmtools/_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def from_csv(
Returns:
-------
xr.Dataset: Output STM instance

"""
# Load csv as Dask DataFrame
ddf = dd.read_csv(file, blocksize=blocksize)
Expand Down
201 changes: 200 additions & 1 deletion stmtools/stm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import numpy as np
import pymorton as pm
import xarray as xr
from scipy.spatial import KDTree
from shapely.geometry import Point
from shapely.strtree import STRtree

Expand All @@ -23,7 +24,7 @@ class SpaceTimeMatrix:
"""Space-Time Matrix."""

def __init__(self, xarray_obj):
"""init."""
"""Init."""
self._obj = xarray_obj

def add_metadata(self, metadata):
Expand All @@ -38,6 +39,7 @@ def add_metadata(self, metadata):
-------
xarray.Dataset
STM with assigned attributes.

"""
self._obj = self._obj.assign_attrs(metadata)
return self._obj
Expand Down Expand Up @@ -69,6 +71,7 @@ def regulate_dims(self, space_label=None, time_label=None):
-------
xarray.Dataset
Regulated STM.

"""
if (
(space_label is None)
Expand Down Expand Up @@ -128,6 +131,7 @@ def subset(self, method: str, **kwargs):
-------
xarray.Dataset
A subset of the original STM.

"""
# Check if both "space" and "time" dimension exists
for dim in ["space", "time"]:
Expand Down Expand Up @@ -203,6 +207,7 @@ def enrich_from_polygon(self, polygon, fields, xlabel="lon", ylabel="lat"):
-------
xarray.Dataset
Enriched STM.

"""
_ = _validate_coords(self._obj, xlabel, ylabel)

Expand Down Expand Up @@ -266,6 +271,7 @@ def _in_polygon(self, polygon, xlabel="lon", ylabel="lat"):
-------
Dask.array
A boolean Dask array. True where a space entry is inside the (multi-)polygon.

"""
# Check if coords exists
_ = _validate_coords(self._obj, xlabel, ylabel)
Expand Down Expand Up @@ -311,6 +317,7 @@ def register_metadata(self, dict_meta: STMMetaData):
-------
xarray.Dataset
STM with registered metadata.

"""
ds_updated = self._obj.assign_attrs(dict_meta)

Expand All @@ -330,6 +337,7 @@ def register_datatype(self, keys: str | Iterable, datatype: DataVarTypes):
-------
xarray.Dataset
STM with registered metadata.

"""
ds_updated = self._obj

Expand Down Expand Up @@ -363,6 +371,7 @@ def get_order(self, xlabel="azimuth", ylabel="range", xscale=1.0, yscale=1.0):
Scaling multiplier to the x coordinates before truncating them to integer values.
yscale : float
Scaling multiplier to the y coordinates before truncating them to integer values.

"""
meta_arr = np.array((), dtype=np.int64)
order = da.apply_gufunc(
Expand Down Expand Up @@ -395,11 +404,99 @@ def reorder(self, xlabel="azimuth", ylabel="range", xscale=1.0, yscale=1.0):
Scaling multiplier to the x coordinates before truncating them to integer values.
yscale : float
Scaling multiplier to the y coordinates before truncating them to integer values.

"""
self._obj = self.get_order(xlabel, ylabel, xscale, yscale)
self._obj = self._obj.sortby(self._obj.order)
return self._obj

def enrich_from_dataset(self,
                        dataset: xr.Dataset | xr.DataArray,
                        fields: str | Iterable,
                        method="nearest",
                        ) -> xr.Dataset:
    """Enrich the SpaceTimeMatrix from one or more fields of a dataset.

    scipy is required. If the input dataset is a raster, it uses
    _enrich_from_raster_block to interpolate onto the STM coordinates using
    `method`. If the input dataset is a point dataset, it uses
    _enrich_from_points_block to find the nearest points in space and time
    using Euclidean distance.

    Parameters
    ----------
    dataset : xarray.Dataset | xarray.DataArray
        Input data for enrichment
    fields : str or list of str
        Field name(s) in the dataset for enrichment
    method : str, optional
        Method of interpolation, by default "nearest", see
        https://docs.xarray.dev/en/stable/generated/xarray.Dataset.interp.html

    Returns
    -------
    xarray.Dataset
        Enriched STM.

    Raises
    ------
    ValueError
        If `fields` is neither a string nor an Iterable, if a required
        coordinate/dimension is missing from the input dataset, if the time
        dtypes differ, or if a field is missing from the dataset or already
        present in the STM.

    """
    # Normalize `fields` to a list; a single string means one field.
    # (Check str first: strings are themselves Iterable.)
    if isinstance(fields, str):
        fields = [fields]
    elif not isinstance(fields, Iterable):
        raise ValueError("fields need to be a Iterable or a string")

    # If dataset is a DataArray, convert it to a Dataset so that
    # data_vars lookups below work uniformly.
    if isinstance(dataset, xr.DataArray):
        dataset = dataset.to_dataset()

    ds = self._obj
    # Every coordinate of the STM must also exist in the input dataset,
    # otherwise the spatial/temporal matching below cannot be performed.
    for coord_label in ds.coords.keys():
        if coord_label not in dataset.coords.keys():
            raise ValueError(
                f'Coordinate label "{coord_label}" was not found in the input dataset.'
            )

    # Determine whether the input dataset is a point dataset ("space" dim)
    # or a raster dataset (lat/lon or y/x dims).
    if "space" in dataset.dims:
        approach = "point"
    elif "lat" in dataset.dims and "lon" in dataset.dims:
        approach = "raster"
    elif "y" in dataset.dims and "x" in dataset.dims:
        approach = "raster"
    else:
        raise ValueError(
            "The input dataset is not a point or raster dataset."
            "The dataset should have either 'space' or 'lat/y' and 'lon/x' dimensions."
            "Consider renaming using "
            "https://docs.xarray.dev/en/latest/generated/xarray.Dataset.rename.html#xarray-dataset-rename"
        )

    # Check that the input dataset has a time dimension.
    if "time" not in dataset.dims:
        raise ValueError('Missing dimension: "time" in the input dataset.')

    # Time dtypes must match, otherwise nearest-neighbour matching in time
    # would compare incompatible values.
    if dataset.time.dtype != ds.time.dtype:
        raise ValueError("The input dataset and the STM have different time dtype.")

    # TODO: check if both ds and dataset has same coordinate system

    for field in fields:
        # Check that the dataset provides the requested field.
        if field not in dataset.data_vars.keys():
            raise ValueError(f'Field "{field}" not found in the input dataset')

        # Check that the STM does not already have the field.
        if field in ds.data_vars.keys():
            raise ValueError(f'Field "{field}" already exists in the STM.')
        # TODO: overwrite the field in the STM

    if approach == "raster":
        return _enrich_from_raster_block(ds, dataset, fields, method)
    elif approach == "point":
        return _enrich_from_points_block(ds, dataset, fields)

@property
def num_points(self):
"""Get number of space entry of the stm.
Expand All @@ -408,6 +505,7 @@ def num_points(self):
-------
int
Number of space entry.

"""
return self._obj.dims["space"]

Expand All @@ -419,6 +517,7 @@ def num_epochs(self):
-------
int
Number of epochs.

"""
return self._obj.dims["time"]

Expand Down Expand Up @@ -472,6 +571,7 @@ def _ml_str_query(xx, yy, polygon, type_polygon):
An array with two columns. The first column is the positional index into the list of
polygons being used to query the tree. The second column is the positional index into
the list of space entries for which the tree was constructed.

"""
# Crop the polygon to the bounding box of the block
xmin, ymin, xmax, ymax = [
Expand Down Expand Up @@ -537,6 +637,7 @@ def _validate_coords(ds, xlabel, ylabel):
------
ValueError
If xlabel or ylabel neither exists in coordinates, raise ValueError

"""
for clabel in [xlabel, ylabel]:
if clabel not in ds.coords.keys():
Expand Down Expand Up @@ -569,6 +670,104 @@ def _compute_morton_code(xx, yy):
-------
array_like
An array with Morton codes per coordinate pair.

"""
code = [pm.interleave(int(xi), int(yi)) for xi, yi in zip(xx, yy, strict=True)]
return code


def _enrich_from_raster_block(ds, dataraster, fields, method):
    """Enrich the ds (SpaceTimeMatrix) from one or more fields of a raster dataset.

    scipy is required. The raster dataset is interpolated onto the
    coordinates of ds with xarray interpolation, see
    https://docs.xarray.dev/en/stable/generated/xarray.Dataset.interp.html

    Parameters
    ----------
    ds : xarray.Dataset
        SpaceTimeMatrix to enrich
    dataraster : xarray.Dataset | xarray.DataArray
        Input data for enrichment
    fields : str or list of str
        Field name(s) in the dataset for enrichment
    method : str, optional
        Method of interpolation, by default "nearest", see

    Returns
    -------
    xarray.Dataset

    """
    # Resample the raster onto the STM coordinates.
    resampled = dataraster.interp(ds.coords, method=method)

    # Copy each requested field over to ds, re-wrapped on ds's dims/coords.
    for name in fields:
        values = resampled[name].data
        ds[name] = xr.DataArray(values, dims=ds.dims, coords=ds.coords)

    return ds


def _enrich_from_points_block(ds, datapoints, fields):
    """Enrich the ds (SpaceTimeMatrix) from one or more fields of a point dataset.

    Assumption is that dimensions of data are space and time.
    https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.KDTree.html#scipy.spatial.KDTree

    Parameters
    ----------
    ds : xarray.Dataset
        SpaceTimeMatrix to enrich
    datapoints : xarray.Dataset | xarray.DataArray
        Input data for enrichment
    fields : str or list of str
        Field name(s) in the dataset for enrichment

    Returns
    -------
    xarray.Dataset

    """
    # KDTree is used here instead of xarray.unstack because unstack is slow
    # for large datasets.

    # For each of the "space" and "time" dimensions, collect the coordinate
    # names that index it: the dimension itself if it is a coordinate,
    # otherwise every coordinate defined along that dimension.
    def _coords_indexing(dim):
        if dim in datapoints.coords:
            return [dim]
        return [c for c in datapoints.coords if dim in datapoints[c].dims]

    indexer = {dim: _coords_indexing(dim) for dim in ("space", "time")}

    # Stack the spatial coordinate columns of the point dataset ...
    source_xy = np.column_stack([datapoints[c] for c in indexer["space"]])
    # ... and of the STM.
    target_xy = np.column_stack([ds[c] for c in indexer["space"]])

    # Nearest spatial neighbour (Euclidean distance) in datapoints for each
    # STM point.
    space_tree = KDTree(source_xy)
    _, indices_space = space_tree.query(target_xy)

    # Nearest temporal neighbour in datapoints for each STM epoch.
    source_t = datapoints.time.values.reshape(-1, 1)
    target_t = ds.time.values.reshape(-1, 1)
    time_tree = KDTree(source_t)
    _, indices_time = time_tree.query(target_t)

    # Select the matched values and copy each field onto the STM.
    selections = datapoints.isel(space=indices_space, time=indices_time)
    for name in fields:
        ds[name] = xr.DataArray(
            selections[name].data, dims=ds.dims, coords=ds.coords
        )

    return ds
Loading
Loading