From 9113b0237f7bf6119699abd36d624f5a429ad851 Mon Sep 17 00:00:00 2001 From: SarahAlidoost Date: Fri, 22 Mar 2024 17:56:43 +0100 Subject: [PATCH] fix linter errors --- stmtools/stm.py | 37 ++++++++++++++++++++++++++++++------- stmtools/utils.py | 10 ++++++++-- 2 files changed, 38 insertions(+), 9 deletions(-) diff --git a/stmtools/stm.py b/stmtools/stm.py index 70462f4..8c42cae 100644 --- a/stmtools/stm.py +++ b/stmtools/stm.py @@ -12,7 +12,6 @@ from shapely.geometry import Point from shapely.strtree import STRtree -from stmtools import utils from stmtools.metadata import DataVarTypes, STMMetaData from stmtools.utils import _has_property @@ -39,6 +38,7 @@ def add_metadata(self, metadata): ------- xarray.Dataset STM with assigned attributes. + """ self._obj = self._obj.assign_attrs(metadata) return self._obj @@ -70,6 +70,7 @@ def regulate_dims(self, space_label=None, time_label=None): ------- xarray.Dataset Regulated STM. + """ if ( (space_label is None) @@ -129,6 +130,7 @@ def subset(self, method: str, **kwargs): ------- xarray.Dataset A subset of the original STM. + """ # Check if both "space" and "time" dimension exists for dim in ["space", "time"]: @@ -204,6 +206,7 @@ def enrich_from_polygon(self, polygon, fields, xlabel="lon", ylabel="lat"): ------- xarray.Dataset Enriched STM. + """ _ = _validate_coords(self._obj, xlabel, ylabel) @@ -267,6 +270,7 @@ def _in_polygon(self, polygon, xlabel="lon", ylabel="lat"): ------- Dask.array A boolean Dask array. True where a space entry is inside the (multi-)polygon. + """ # Check if coords exists _ = _validate_coords(self._obj, xlabel, ylabel) @@ -312,6 +316,7 @@ def register_metadata(self, dict_meta: STMMetaData): ------- xarray.Dataset STM with registered metadata. + """ ds_updated = self._obj.assign_attrs(dict_meta) @@ -331,6 +336,7 @@ def register_datatype(self, keys: str | Iterable, datatype: DataVarTypes): ------- xarray.Dataset STM with registered metadata. + """ ds_updated = self._obj @@ -364,6 +370,7 @@ def get_order(self, xlabel="azimuth", ylabel="range", xscale=1.0, yscale=1.0): Scaling multiplier to the x coordinates before truncating them to integer values. yscale : float Scaling multiplier to the y coordinates before truncating them to integer values. + """ meta_arr = np.array((), dtype=np.int64) order = da.apply_gufunc( @@ -396,6 +403,7 @@ def reorder(self, xlabel="azimuth", ylabel="range", xscale=1.0, yscale=1.0): Scaling multiplier to the x coordinates before truncating them to integer values. yscale : float Scaling multiplier to the y coordinates before truncating them to integer values. + """ self._obj = self.get_order(xlabel, ylabel, xscale, yscale) self._obj = self._obj.sortby(self._obj.order) @@ -422,10 +430,12 @@ def enrich_from_dataset(self, method : str, optional Method of interpolation, by default "nearest", see https://docs.xarray.dev/en/stable/generated/xarray.Dataset.interp_like.html#xarray-dataset-interp-like + Returns ------- xarray.Dataset Enriched STM. + """ # Check if fields is a Iterable or a str if isinstance(fields, str): @@ -455,7 +465,7 @@ def enrich_from_dataset(self, else: raise ValueError( "The input dataset is not a point or raster dataset." - "The dataset should have either 'space' or 'lat/y' and 'lon/x' dimensions." # give help on renaming + "The dataset should have either 'space' or 'lat/y' and 'lon/x' dimensions." "Consider renaming using " "https://docs.xarray.dev/en/latest/generated/xarray.Dataset.rename.html#xarray-dataset-rename" ) @@ -494,6 +504,7 @@ def num_points(self): ------- int Number of space entry. + """ return self._obj.dims["space"] @@ -505,6 +516,7 @@ def num_epochs(self): ------- int Number of epochs. + """ return self._obj.dims["time"] @@ -558,6 +570,7 @@ def _ml_str_query(xx, yy, polygon, type_polygon): An array with two columns. The first column is the positional index into the list of polygons being used to query the tree. The second column is the positional index into the list of space entries for which the tree was constructed. + """ # Crop the polygon to the bounding box of the block xmin, ymin, xmax, ymax = [ @@ -623,6 +636,7 @@ def _validate_coords(ds, xlabel, ylabel): ------ ValueError If xlabel or ylabel neither exists in coordinates, raise ValueError + """ for clabel in [xlabel, ylabel]: if clabel not in ds.coords.keys(): @@ -655,6 +669,7 @@ def _compute_morton_code(xx, yy): ------- array_like An array with Morton codes per coordinate pair. + """ code = [pm.interleave(int(xi), int(yi)) for xi, yi in zip(xx, yy, strict=True)] return code @@ -670,8 +685,8 @@ def _enrich_from_raster_block(ds, dataraster, fields, method): Parameters ---------- ds : xarray.Dataset - - dataset : xarray.Dataset | xarray.DataArray + SpaceTimeMatrix to enrich + dataraster : xarray.Dataset | xarray.DataArray Input data for enrichment fields : str or list of str Field name(s) in the dataset for enrichment @@ -681,6 +696,7 @@ def _enrich_from_raster_block(ds, dataraster, fields, method): Returns ------- xarray.Dataset + """ # interpolate the raster dataset to the coordinates of ds interpolated = dataraster.interp(ds.coords, method=method) @@ -699,7 +715,7 @@ def _enrich_from_points_block(ds, datapoints, fields): Parameters ---------- ds : xarray.Dataset - + SpaceTimeMatrix to enrich datapoints : xarray.Dataset | xarray.DataArray Input data for enrichment fields : str or list of str @@ -708,11 +724,16 @@ def _enrich_from_points_block(ds, datapoints, fields): Returns ------- xarray.Dataset + """ # unstak the dimensions for dim in datapoints.dims: if dim not in datapoints.coords: - indexer = {dim: [coord for coord in datapoints.coords if dim in datapoints[coord].dims]} + indexer = { + dim: [ + coord for coord in datapoints.coords if dim in datapoints[coord].dims + ] + } datapoints = datapoints.set_index(indexer) datapoints = datapoints.unstack(dim) @@ -722,6 +743,8 @@ def _enrich_from_points_block(ds, datapoints, fields): # Assign these values to the corresponding points in ds for field in fields: - ds[field] = xr.DataArray(selections[field].data.transpose(), dims=ds.dims, coords=ds.coords) + ds[field] = xr.DataArray( + selections[field].data.transpose(), dims=ds.dims, coords=ds.coords + ) return ds diff --git a/stmtools/utils.py b/stmtools/utils.py index 95bdf3f..c41e261 100644 --- a/stmtools/utils.py +++ b/stmtools/utils.py @@ -1,6 +1,7 @@ -import xarray as xr from collections.abc import Iterable +import xarray as xr + def _has_property(ds, keys: str | Iterable): if isinstance(keys, str): @@ -27,6 +28,7 @@ def crop(ds, other, buffer): ------- xarray.Dataset Cropped dataset. + """ if isinstance(ds, xr.DataArray): ds = ds.to_dataset() @@ -48,7 +50,11 @@ def crop(ds, other, buffer): indexer = {} for dim in other.dims: if dim not in other.coords.keys(): - indexer = {dim: [coord for coord in other.coords.keys() if dim in other.coords[coord].dims]} + indexer = { + dim: [ + coord for coord in other.coords.keys() if dim in other.coords[coord].dims + ] + } other = other.set_index(indexer) other = other.unstack(indexer)