From 48cdbc8d908c4a7bbe345e45b0499407e2a335eb Mon Sep 17 00:00:00 2001
From: Nils Lehmann
Date: Fri, 15 Dec 2023 15:35:52 +0100
Subject: [PATCH] store changes

---
 torchgeo/datasets/rioxr.py | 68 ++++++++++++++++++++++++--------------
 1 file changed, 44 insertions(+), 24 deletions(-)

diff --git a/torchgeo/datasets/rioxr.py b/torchgeo/datasets/rioxr.py
index b419f8f2a15..a9d6cdb871b 100644
--- a/torchgeo/datasets/rioxr.py
+++ b/torchgeo/datasets/rioxr.py
@@ -50,6 +50,32 @@ def dtype(self) -> torch.dtype:
             return torch.float32
         else:
             return torch.long
+
+    def harmonize_format(self, ds: Any) -> Any:
+        """Convert the dataset to the standard format.
+
+        Args:
+            ds: dataset or array to harmonize
+
+        Returns:
+            the harmonized dataset or array with ascending x/y coordinates
+        """
+        # rioxarray expects spatial dimensions to be named x and y
+        ds.rio.set_spatial_dims(
+            self.spatial_x_name, self.spatial_y_name, inplace=True
+        )
+
+        # only a 0..360 axis is wrapped to -180..180, larger (projected) values are left alone
+        if ds[self.spatial_x_name].min() >= 0 and 180 < ds[self.spatial_x_name].max() <= 360:
+            ds = ds.assign_coords({self.spatial_x_name: (ds[self.spatial_x_name] + 180) % 360 - 180})
+
+        # only a 0..180 axis is shifted to -90..90, larger (projected) values are left alone
+        if ds[self.spatial_y_name].min() >= 0 and 90 < ds[self.spatial_y_name].max() <= 180:
+            ds = ds.assign_coords({self.spatial_y_name: ds[self.spatial_y_name] - 90})
+        # expect ascending coordinate values (wrapping x can leave it unsorted)
+        ds = ds.sortby(self.spatial_x_name, ascending=True)
+        ds = ds.sortby(self.spatial_y_name, ascending=True)
+        return ds
 
     def __init__(
         self,
@@ -94,10 +120,7 @@ def __init__(
             match = re.match(filename_regex, os.path.basename(filepath))
             if match is not None:
                 with xr.open_dataset(filepath, decode_times=True) as ds:
-                    # rioxarray expects spatial dimensions to be named x and y
-                    ds.rio.set_spatial_dims(
-                        self.spatial_x_name, self.spatial_y_name, inplace=True
-                    )
+                    ds = self.harmonize_format(ds)
 
                     if crs is None:
                         crs = ds.rio.crs
@@ -196,27 +219,15 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
         data_arrays: list["np.typing.NDArray"] = []
         for item in items:
             with xr.open_dataset(item, decode_cf=True) as ds:
-                # rioxarray expects spatial dimensions to be named x and y
-                ds.rio.set_spatial_dims(
-                    self.spatial_x_name, self.spatial_y_name, inplace=True
-                )
-
-                if not ds.rio.crs:
-                    ds.rio.write_crs(self._crs, inplace=True)
-                elif ds.rio.crs != self._crs:
-                    ds = ds.rio.reproject(self._crs)
-
-                # clip box ignores time dimension
-                clipped = ds.rio.clip_box(
-                    minx=query.minx, miny=query.miny, maxx=query.maxx, maxy=query.maxy
-                )
+
+                ds = self.harmonize_format(ds)
                 # select time dimension
                 if hasattr(ds, "time"):
                     try:
-                        clipped["time"] = clipped.indexes["time"].to_datetimeindex()
+                        ds["time"] = ds.indexes["time"].to_datetimeindex()
                     except AttributeError:
-                        clipped["time"] = clipped.indexes["time"]
-                    clipped = clipped.sel(
+                        ds["time"] = ds.indexes["time"]
+                    ds = ds.sel(
                         time=slice(
                             datetime.fromtimestamp(query.mint),
                             datetime.fromtimestamp(query.maxt),
@@ -224,16 +235,25 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
                     )
 
                 for variable in self.data_variables:
-                    if hasattr(clipped, variable):
+                    if hasattr(ds, variable):
+                        da = ds[variable]
+                        if not da.rio.crs:
+                            da.rio.write_crs(self._crs, inplace=True)
+                        elif da.rio.crs != self._crs:
+                            da = da.rio.reproject(self._crs)
+                        # clip box ignores time dimension
+                        clipped = da.rio.clip_box(
+                            minx=query.minx, miny=query.miny, maxx=query.maxx, maxy=query.maxy
+                        )
                         # rioxarray expects this order
                         clipped = clipped.transpose(
                             "time", self.spatial_y_name, self.spatial_x_name, ...
                         )
 
                         # set proper transform # TODO not working
-                        # clipped = clipped.rio.write_transform(self.transform)
-                        data_arrays.append(clipped[variable].squeeze())
+                        clipped.rio.write_transform(self.transform)
+                        data_arrays.append(clipped.squeeze())
 
         merged_data = torch.from_numpy(
             merge_arrays(
                 data_arrays, bounds=(query.minx, query.miny, query.maxx, query.maxy)