Skip to content

Commit

Permalink
store changes
Browse files Browse the repository at this point in the history
  • Loading branch information
nilsleh committed Dec 15, 2023
1 parent 5c99a2f commit 48cdbc8
Showing 1 changed file with 46 additions and 24 deletions.
70 changes: 46 additions & 24 deletions torchgeo/datasets/rioxr.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,32 @@ def dtype(self) -> torch.dtype:
return torch.float32
else:
return torch.long

def harmonize_format(self, ds):
    """Convert the dataset to the standard format.

    The standard format is: longitudes in [-180, 180), latitudes in
    [-90, 90], both spatial coordinates sorted in ascending order, and
    the rioxarray spatial dimensions registered on the returned object.

    Args:
        ds: dataset or array to harmonize; must expose the coordinates
            named by ``self.spatial_x_name`` and ``self.spatial_y_name``

    Returns:
        the harmonized dataset or array
    """
    x_name = self.spatial_x_name
    y_name = self.spatial_y_name

    # If x coords go from 0 to 360, convert to -180 to 180.
    # The *maximum* has to be tested: a 0..360 grid has a minimum of 0, so
    # a check on the minimum would never fire. Adding 180 before the modulo
    # wraps the values instead of shifting every meridian by 180 degrees.
    # NOTE(review): assumes geographic (degree) coordinates; projected
    # coordinates in metres would also exceed 180 -- confirm with callers.
    if ds[x_name].max() > 180:
        ds = ds.assign_coords({x_name: (ds[x_name] + 180) % 360 - 180})

    # If y coords go from 0 to 180, convert to -90 to 90.
    # The check must look at the y coordinate (not x). A plain shift is
    # used because a modulo by 180 would fold 180 onto -90, colliding with
    # the value that 0 maps to.
    # NOTE(review): keeps the original orientation (0 maps to -90);
    # confirm whether 0 denotes the south or the north pole in the data.
    if ds[y_name].max() > 90:
        ds = ds.assign_coords({y_name: ds[y_name] - 90})

    # Expect ascending coordinate values.
    ds = ds.sortby(x_name, ascending=True)
    ds = ds.sortby(y_name, ascending=True)

    # rioxarray needs to be told which dimensions are spatial when they
    # are not named x and y. This is done last, on the object that is
    # actually returned, because the operations above return new objects
    # and state set on the input object may not carry over to them.
    ds.rio.set_spatial_dims(x_name, y_name, inplace=True)
    return ds

def __init__(
self,
Expand Down Expand Up @@ -94,10 +120,7 @@ def __init__(
match = re.match(filename_regex, os.path.basename(filepath))
if match is not None:
with xr.open_dataset(filepath, decode_times=True) as ds:
# rioxarray expects spatial dimensions to be named x and y
ds.rio.set_spatial_dims(
self.spatial_x_name, self.spatial_y_name, inplace=True
)
ds = self.harmonize_format(ds)

if crs is None:
crs = ds.rio.crs
Expand Down Expand Up @@ -196,44 +219,43 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
data_arrays: list["np.typing.NDArray"] = []
for item in items:
with xr.open_dataset(item, decode_cf=True) as ds:
# rioxarray expects spatial dimensions to be named x and y
ds.rio.set_spatial_dims(
self.spatial_x_name, self.spatial_y_name, inplace=True
)

if not ds.rio.crs:
ds.rio.write_crs(self._crs, inplace=True)
elif ds.rio.crs != self._crs:
ds = ds.rio.reproject(self._crs)

# clip box ignores time dimension
clipped = ds.rio.clip_box(
minx=query.minx, miny=query.miny, maxx=query.maxx, maxy=query.maxy
)

ds = self.harmonize_format(ds)
# select time dimension
if hasattr(ds, "time"):
try:
clipped["time"] = clipped.indexes["time"].to_datetimeindex()
ds["time"] = ds.indexes["time"].to_datetimeindex()
except AttributeError:
clipped["time"] = clipped.indexes["time"]
clipped = clipped.sel(
ds["time"] = ds.indexes["time"]
ds = ds.sel(
time=slice(
datetime.fromtimestamp(query.mint),
datetime.fromtimestamp(query.maxt),
)
)

for variable in self.data_variables:
if hasattr(clipped, variable):
if hasattr(ds, variable):
da = ds[variable]
if not da.rio.crs:
da.rio.write_crs(self._crs, inplace=True)
elif da.rio.crs != self._crs:
da = da.rio.reproject(self._crs)
# clip box ignores time dimension
clipped = da.rio.clip_box(
minx=query.minx, miny=query.miny, maxx=query.maxx, maxy=query.maxy
)
# rioxarray expects this order
clipped = clipped.transpose(
"time", self.spatial_y_name, self.spatial_x_name, ...
)

# set proper transform # TODO not working
# clipped = clipped.rio.write_transform(self.transform)
data_arrays.append(clipped[variable].squeeze())
clipped.rio.write_transform(self.transform)
data_arrays.append(clipped.squeeze())

import pdb
pdb.set_trace()
merged_data = torch.from_numpy(
merge_arrays(
data_arrays, bounds=(query.minx, query.miny, query.maxx, query.maxy)
Expand Down

0 comments on commit 48cdbc8

Please sign in to comment.