diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 61eb68c..d2d6c52 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -48,6 +48,10 @@ jobs: id: status run: pytest -v . --cov=xvec --cov-append --cov-report term-missing --cov-report xml --color=yes --report-log pytest-log.jsonl + - name: run mypy + if: contains(matrix.environment-file, 'ci/312.yaml') && contains(matrix.os, 'ubuntu') + run: mypy xvec/ --install-types --ignore-missing-imports --non-interactive + - uses: codecov/codecov-action@v3 - name: Generate and publish the report diff --git a/ci/312.yaml b/ci/312.yaml index 2f325a1..08cef9e 100644 --- a/ci/312.yaml +++ b/ci/312.yaml @@ -19,4 +19,5 @@ dependencies: - geopandas-base - geodatasets - pyogrio + - mypy diff --git a/pyproject.toml b/pyproject.toml index b32baae..38684a0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,6 +49,7 @@ omit = ["xvec/tests/*"] exclude_lines = [ "except ImportError", "except PackageNotFoundError", + "if TYPE_CHECKING:" ] [tool.ruff] diff --git a/xvec/__init__.py b/xvec/__init__.py index 44c022b..b069714 100644 --- a/xvec/__init__.py +++ b/xvec/__init__.py @@ -1,7 +1,7 @@ from importlib.metadata import PackageNotFoundError, version -from .accessor import XvecAccessor # noqa -from .index import GeometryIndex # noqa +from .accessor import XvecAccessor # noqa: F401 +from .index import GeometryIndex # noqa: F401 try: __version__ = version("xvec") diff --git a/xvec/accessor.py b/xvec/accessor.py index 83c45f6..a779bb8 100644 --- a/xvec/accessor.py +++ b/xvec/accessor.py @@ -2,7 +2,7 @@ import warnings from collections.abc import Hashable, Mapping, Sequence -from typing import Any, Callable +from typing import TYPE_CHECKING, Any, Callable, cast import numpy as np import pandas as pd @@ -13,6 +13,9 @@ from .index import GeometryIndex from .zonal import _zonal_stats_iterative, _zonal_stats_rasterize +if TYPE_CHECKING: + from geopandas import GeoDataFrame + @xr.register_dataarray_accessor("xvec") @xr.register_dataset_accessor("xvec") @@ -22,7 +25,7 @@ class XvecAccessor: Currently works on coordinates with :class:`xvec.GeometryIndex`. """ - def __init__(self, xarray_obj: xr.Dataset | xr.DataArray): + def __init__(self, xarray_obj: xr.Dataset | xr.DataArray) -> None: """xvec init, nothing to be done here.""" self._obj = xarray_obj self._geom_coords_all = [ @@ -36,7 +39,9 @@ def __init__(self, xarray_obj: xr.Dataset | xr.DataArray): if self.is_geom_variable(name, has_index=True) ] - def is_geom_variable(self, name: Hashable, has_index: bool = True): + def is_geom_variable( + self, name: Hashable, has_index: bool = True + ) -> bool | np.bool_: """Check if coordinate variable is composed of :class:`shapely.Geometry`. Can return all such variables or only those using :class:`~xvec.GeometryIndex`. @@ -208,7 +213,7 @@ def to_crs( self, variable_crs: Mapping[Any, Any] | None = None, **variable_crs_kwargs: Any, - ): + ) -> xr.DataArray | xr.Dataset: """ Transform :class:`shapely.Geometry` objects of a variable to a new coordinate reference system. @@ -313,20 +318,15 @@ def to_crs( currently wraps :meth:`Dataset.assign_coords ` or :meth:`DataArray.assign_coords `. """ - if variable_crs and variable_crs_kwargs: - raise ValueError( - "Cannot specify both keyword and positional arguments to " - "'.xvec.to_crs'." - ) + variable_crs_solved = _resolve_input( + variable_crs, variable_crs_kwargs, "to_crs" + ) _obj = self._obj.copy(deep=False) - if variable_crs_kwargs: - variable_crs = variable_crs_kwargs - transformed = {} - for key, crs in variable_crs.items(): + for key, crs in variable_crs_solved.items(): if not isinstance(self._obj.xindexes[key], GeometryIndex): raise ValueError( f"The index '{key}' is not an xvec.GeometryIndex. " @@ -335,7 +335,7 @@ def to_crs( ) data = _obj[key] - data_crs = self._obj.xindexes[key].crs + data_crs = self._obj.xindexes[key].crs # type: ignore # transformation code taken from geopandas (BSD 3-clause license) if data_crs is None: @@ -374,21 +374,21 @@ def to_crs( for key, (result, _crs) in transformed.items(): _obj = _obj.assign_coords({key: result}) - _obj = _obj.drop_indexes(variable_crs.keys()) + _obj = _obj.drop_indexes(variable_crs_solved.keys()) - for key, crs in variable_crs.items(): + for key, crs in variable_crs_solved.items(): if crs: _obj[key].attrs["crs"] = CRS.from_user_input(crs) - _obj = _obj.set_xindex(key, GeometryIndex, crs=crs) + _obj = _obj.set_xindex([key], GeometryIndex, crs=crs) return _obj def set_crs( self, variable_crs: Mapping[Any, Any] | None = None, - allow_override=False, + allow_override: bool = False, **variable_crs_kwargs: Any, - ): + ) -> xr.DataArray | xr.Dataset: """Set the Coordinate Reference System (CRS) of coordinates backed by :class:`~xvec.GeometryIndex`. @@ -480,19 +480,13 @@ def set_crs( transform the geometries to a new CRS, use the :meth:`to_crs` method. """ - - if variable_crs and variable_crs_kwargs: - raise ValueError( - "Cannot specify both keyword and positional arguments to " - ".xvec.set_crs." - ) + variable_crs_solved = _resolve_input( + variable_crs, variable_crs_kwargs, "set_crs" + ) _obj = self._obj.copy(deep=False) - if variable_crs_kwargs: - variable_crs = variable_crs_kwargs - - for key, crs in variable_crs.items(): + for key, crs in variable_crs_solved.items(): if not isinstance(self._obj.xindexes[key], GeometryIndex): raise ValueError( f"The index '{key}' is not an xvec.GeometryIndex. " @@ -500,7 +494,7 @@ def set_crs( "handling projection information." ) - data_crs = self._obj.xindexes[key].crs + data_crs = self._obj.xindexes[key].crs # type: ignore if not allow_override and data_crs is not None and not data_crs == crs: raise ValueError( @@ -510,12 +504,12 @@ def set_crs( "want to transform the geometries, use '.xvec.to_crs' instead." ) - _obj = _obj.drop_indexes(variable_crs.keys()) + _obj = _obj.drop_indexes(variable_crs_solved.keys()) - for key, crs in variable_crs.items(): + for key, crs in variable_crs_solved.items(): if crs: _obj[key].attrs["crs"] = CRS.from_user_input(crs) - _obj = _obj.set_xindex(key, GeometryIndex, crs=crs) + _obj = _obj.set_xindex([key], GeometryIndex, crs=crs) return _obj @@ -523,10 +517,10 @@ def query( self, coord_name: str, geometry: shapely.Geometry | Sequence[shapely.Geometry], - predicate: str = None, - distance: float | Sequence[float] = None, - unique=False, - ): + predicate: str | None = None, + distance: float | Sequence[float] | None = None, + unique: bool = False, + ) -> xr.DataArray | xr.Dataset: """Return a subset of a DataArray/Dataset filtered using a spatial query on :class:`~xvec.GeometryIndex`. @@ -619,12 +613,12 @@ def query( """ if isinstance(geometry, shapely.Geometry): - ilocs = self._obj.xindexes[coord_name].sindex.query( + ilocs = self._obj.xindexes[coord_name].sindex.query( # type: ignore geometry, predicate=predicate, distance=distance ) else: - _, ilocs = self._obj.xindexes[coord_name].sindex.query( + _, ilocs = self._obj.xindexes[coord_name].sindex.query( # type: ignore geometry, predicate=predicate, distance=distance ) if unique: @@ -634,11 +628,11 @@ def query( def set_geom_indexes( self, - coord_names: str | Sequence[Hashable], + coord_names: str | Sequence[str], crs: Any = None, allow_override: bool = False, - **kwargs, - ): + **kwargs: dict[str, Any], + ) -> xr.DataArray | xr.Dataset: """Set a new :class:`~xvec.GeometryIndex` for one or more existing coordinate(s). One :class:`~xvec.GeometryIndex` is set per coordinate. Only 1-dimensional coordinates are supported. @@ -691,7 +685,7 @@ def set_geom_indexes( for coord in coord_names: if isinstance(self._obj.xindexes[coord], GeometryIndex): - data_crs = self._obj.xindexes[coord].crs + data_crs = self._obj.xindexes[coord].crs # type: ignore if not allow_override and data_crs is not None and not data_crs == crs: raise ValueError( @@ -710,7 +704,7 @@ def set_geom_indexes( return _obj - def to_geopandas(self): + def to_geopandas(self) -> GeoDataFrame | pd.DataFrame: """Convert this array into a GeoPandas :class:`~geopandas.GeoDataFrame` Returns a :class:`~geopandas.GeoDataFrame` with coordinates based on a @@ -762,11 +756,11 @@ def to_geopandas(self): if len(self._geom_indexes): if self._obj.ndim == 1: gdf = self._obj.to_pandas() - elif self._obj.ndim == 2: + else: gdf = self._obj.to_pandas() if gdf.columns.name == self._geom_indexes[0]: gdf = gdf.T - return gdf.reset_index().set_geometry( + return gdf.reset_index().set_geometry( # type: ignore self._geom_indexes[0], crs=self._obj.xindexes[self._geom_indexes[0]].crs, ) @@ -790,7 +784,7 @@ def to_geopandas(self): if index_name in self._geom_coords_all: return gdf.reset_index().set_geometry( index_name, crs=self._obj[index_name].attrs.get("crs", None) - ) + ) # type: ignore warnings.warn( "No active geometry column to be set. The resulting object " @@ -810,7 +804,7 @@ def to_geodataframe( dim_order: Sequence[Hashable] | None = None, geometry: Hashable | None = None, long: bool = True, - ): + ) -> GeoDataFrame | pd.DataFrame: """Convert this array and its coordinates into a tidy geopandas.GeoDataFrame. The GeoDataFrame is indexed by the Cartesian product of index coordinates @@ -884,7 +878,7 @@ def to_geodataframe( level for level in df.index.names if level not in self._geom_coords_all - ] + ] # type: ignore ) if isinstance(df.index, pd.MultiIndex): @@ -907,7 +901,7 @@ def to_geodataframe( if geometry is not None: return df.set_geometry( geometry, crs=self._obj[geometry].attrs.get("crs", None) - ) + ) # type: ignore warnings.warn( "No active geometry column to be set. The resulting object " @@ -926,12 +920,12 @@ def zonal_stats( y_coords: Hashable, stats: str | Callable | Sequence[str | Callable | tuple] = "mean", name: Hashable = "geometry", - index: bool = None, + index: bool | None = None, method: str = "rasterize", all_touched: bool = False, n_jobs: int = -1, - **kwargs, - ): + **kwargs: dict[str, Any], + ) -> xr.DataArray | xr.Dataset: """Extract the values from a dataset indexed by a set of geometries Given an object indexed by x and y coordinates (or latitude and longitude), such @@ -1121,9 +1115,9 @@ def extract_points( y_coords: Hashable, tolerance: float | None = None, name: str = "geometry", - crs: Any = None, - index: bool = None, - ): + crs: Any | None = None, + index: bool | None = None, + ) -> xr.DataArray | xr.Dataset: """Extract points from a DataArray or a Dataset indexed by spatial coordinates Given an object indexed by x and y coordinates (or latitude and longitude), such @@ -1263,3 +1257,22 @@ def extract_points( } ) return result + + +def _resolve_input( + positional: Mapping[Any, Any] | None, + keyword: Mapping[str, Any], + func_name: str, +) -> Mapping[Hashable, Any]: + """Resolve combination of positional and keyword arguments. + + Based on xarray's ``either_dict_or_kwargs``. + """ + if positional and keyword: + raise ValueError( + "Cannot specify both keyword and positional arguments to " + f"'.xvec.{func_name}'." + ) + if positional is None or positional == {}: + return cast(Mapping[Hashable, Any], keyword) + return positional diff --git a/xvec/index.py b/xvec/index.py index 04478a6..b262e71 100644 --- a/xvec/index.py +++ b/xvec/index.py @@ -22,7 +22,7 @@ def _format_crs(crs: CRS | None, max_width: int = 50) -> str: return srs if len(srs) <= max_width else " ".join([srs[:max_width], "..."]) -def _get_common_crs(crs_set: set[CRS | None]): +def _get_common_crs(crs_set: set[CRS | None]) -> CRS | None: # code taken from geopandas (BSD-3 Licence) crs_not_none = [crs for crs in crs_set if crs is not None] @@ -112,7 +112,7 @@ def _check_crs(self, other_crs: CRS | None, allow_none: bool = False) -> bool: def _crs_mismatch_raise( self, other_crs: CRS | None, warn: bool = False, stacklevel: int = 3 - ): + ) -> None: """Raise a CRS mismatch error or warning with the information on the assigned CRS. """ @@ -138,7 +138,7 @@ def from_variables( variables: Mapping[Any, Variable], *, options: Mapping[str, Any], - ): + ) -> GeometryIndex: # TODO: try getting CRS from coordinate attrs or GeometryArray or SRID index = PandasIndex.from_variables(variables, options={}) @@ -166,7 +166,7 @@ def create_variables( def to_pandas_index(self) -> pd.Index: return self._index.index - def isel(self, indexers: Mapping[Any, Any]): + def isel(self, indexers: Mapping[Any, Any]) -> GeometryIndex | None: index = self._index.isel(indexers) if index is not None: @@ -174,7 +174,7 @@ def isel(self, indexers: Mapping[Any, Any]): else: return None - def _sel_sindex(self, labels, method, tolerance): + def _sel_sindex(self, labels, method: str, tolerance) -> IndexSelResult: # only one entry expected assert len(labels) == 1 label = next(iter(labels.values())) @@ -212,7 +212,10 @@ def _sel_sindex(self, labels, method, tolerance): return IndexSelResult({self._index.dim: indices}) def sel( - self, labels: dict[Any, Any], method=None, tolerance=None + self, + labels: dict[Any, Any], + method: str | None = None, + tolerance: int | float | Iterable[int | float] | None = None, ) -> IndexSelResult: if method is None: return self._index.sel(labels) @@ -222,7 +225,7 @@ def sel( # options when `labels` is a single geometry. # Xarray currently doesn't support custom options # (see https://github.com/pydata/xarray/issues/7099) - return self._sel_sindex(labels, method, tolerance) + return self._sel_sindex(labels, method, tolerance) # type: ignore def equals(self, other: Index) -> bool: if not isinstance(other, GeometryIndex): @@ -241,7 +244,10 @@ def join( return type(self)(index, self.crs) def reindex_like( - self, other: GeometryIndex, method=None, tolerance=None + self, + other: GeometryIndex, + method: str | None = None, + tolerance: int | float | Iterable[int | float] | None = None, ) -> dict[Hashable, Any]: if not self._check_crs(other.crs, allow_none=True): self._crs_mismatch_raise(other.crs) @@ -254,11 +260,13 @@ def roll(self, shifts: Mapping[Any, int]) -> GeometryIndex: index = self._index.roll(shifts) return type(self)(index, self.crs) - def rename(self, name_dict, dims_dict): + def rename( + self, name_dict: Mapping[Any, Hashable], dims_dict: Mapping[Any, Hashable] + ) -> GeometryIndex: index = self._index.rename(name_dict, dims_dict) return type(self)(index, self.crs) - def _repr_inline_(self, max_width: int): + def _repr_inline_(self, max_width: int) -> str: # TODO: remove when fixed in XArray if max_width is None: max_width = get_options()["display_width"] diff --git a/xvec/zonal.py b/xvec/zonal.py index f4fc205..9e29d86 100644 --- a/xvec/zonal.py +++ b/xvec/zonal.py @@ -1,8 +1,8 @@ from __future__ import annotations import gc -from collections.abc import Hashable, Sequence -from typing import Callable +from collections.abc import Hashable, Iterable, Sequence +from typing import Any, Callable import numpy as np import pandas as pd @@ -30,13 +30,13 @@ def _zonal_stats_rasterize( x_coords: Hashable, y_coords: Hashable, stats: str | Callable | Sequence[str | Callable | tuple] = "mean", - name: str = "geometry", + name: Hashable = "geometry", all_touched: bool = False, **kwargs, -): +) -> xr.DataArray | xr.Dataset: try: - import rasterio import rioxarray # noqa: F401 + from rasterio import features except ImportError as err: raise ImportError( "The rioxarray package is required for `zonal_stats()`. " @@ -45,27 +45,27 @@ def _zonal_stats_rasterize( ) from err if hasattr(geometry, "crs"): - crs = geometry.crs + crs = geometry.crs # type: ignore else: crs = None transform = acc._obj.rio.transform() - labels = rasterio.features.rasterize( + labels = features.rasterize( zip(geometry, range(len(geometry))), out_shape=( acc._obj[y_coords].shape[0], acc._obj[x_coords].shape[0], ), transform=transform, - fill=np.nan, + fill=np.nan, # type: ignore all_touched=all_touched, ) groups = acc._obj.groupby(xr.DataArray(labels, dims=(y_coords, x_coords))) if pd.api.types.is_list_like(stats): agg = {} - for stat in stats: + for stat in stats: # type: ignore if isinstance(stat, str): agg[stat] = _agg_rasterize(groups, stat, **kwargs) elif callable(stat): @@ -76,19 +76,19 @@ def _zonal_stats_rasterize( else: raise ValueError(f"{stat} is not a valid aggregation.") - agg = xr.concat( + agg_array = xr.concat( agg.values(), dim=xr.DataArray( list(agg.keys()), name="zonal_statistics", dims="zonal_statistics" ), ) elif isinstance(stats, str) or callable(stats): - agg = _agg_rasterize(groups, stats, **kwargs) + agg_array = _agg_rasterize(groups, stats, **kwargs) else: raise ValueError(f"{stats} is not a valid aggregation.") vec_cube = ( - agg.reindex(group=range(len(geometry))) + agg_array.reindex(group=range(len(geometry))) .assign_coords(group=geometry) .rename(group=name) ).xvec.set_geom_indexes(name, crs=crs) @@ -105,11 +105,11 @@ def _zonal_stats_iterative( x_coords: Hashable, y_coords: Hashable, stats: str | Callable | Sequence[str | Callable | tuple] = "mean", - name: str = "geometry", + name: Hashable = "geometry", all_touched: bool = False, n_jobs: int = -1, - **kwargs, -): + **kwargs: dict[str, Any], +) -> xr.DataArray | xr.Dataset: """Extract the values from a dataset indexed by a set of geometries The CRS of the raster and that of geometry need to be equal. @@ -163,7 +163,7 @@ def _zonal_stats_iterative( ) from err try: - from joblib import Parallel, delayed + from joblib import Parallel, delayed # type: ignore except ImportError as err: raise ImportError( "The joblib package is required for `xvec._spatial_agg()`. " @@ -187,11 +187,12 @@ def _zonal_stats_iterative( for geom in geometry ) if hasattr(geometry, "crs"): - crs = geometry.crs + crs = geometry.crs # type: ignore else: crs = None vec_cube = xr.concat( - zonal, dim=xr.DataArray(geometry, name=name, dims=name) + zonal, # type: ignore + dim=xr.DataArray(geometry, name=name, dims=name), ).xvec.set_geom_indexes(name, crs=crs) gc.collect() @@ -202,10 +203,10 @@ def _agg_geom( acc, geom, trans, - x_coords: str = None, - y_coords: str = None, - stats: str | Callable | Sequence[str | Callable | tuple] = "mean", - all_touched=False, + x_coords: str | None = None, + y_coords: str | None = None, + stats: str | Callable | Iterable[str | Callable | tuple] = "mean", + all_touched: bool = False, **kwargs, ): """Aggregate the values from a dataset over a polygon geometry. @@ -239,9 +240,9 @@ def _agg_geom( Aggregated values over the geometry. """ - import rasterio + from rasterio import features - mask = rasterio.features.geometry_mask( + mask = features.geometry_mask( [geom], out_shape=( acc._obj[y_coords].shape[0], @@ -254,7 +255,7 @@ def _agg_geom( masked = acc._obj.where(xr.DataArray(mask, dims=(y_coords, x_coords))) if pd.api.types.is_list_like(stats): agg = {} - for stat in stats: + for stat in stats: # type: ignore if isinstance(stat, str): agg[stat] = _agg_iterate(masked, stat, x_coords, y_coords, **kwargs) elif callable(stat):