Skip to content

Commit

Permalink
Merge pull request #692 from DHI/fix-single-timestep-freq
Browse files Browse the repository at this point in the history
Single timestep dataset forgets dt
  • Loading branch information
ecomodeller authored May 10, 2024
2 parents 24537e1 + 28b2672 commit 8516188
Show file tree
Hide file tree
Showing 6 changed files with 106 additions and 26 deletions.
40 changes: 33 additions & 7 deletions mikeio/dataset/_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,12 @@ def __call__(self, tail: bool = True) -> "DataArray":
geometry = GeometryUndefined()

return DataArray(
data=Hm0, time=self.da.time, item=item, dims=dims, geometry=geometry
data=Hm0,
time=self.da.time,
item=item,
dims=dims,
geometry=geometry,
dt=self.da._dt,
)


Expand Down Expand Up @@ -162,10 +167,12 @@ def __init__(
geometry: GeometryType | None = None,
zn: np.ndarray | None = None,
dims: Sequence[str] | None = None,
dt: float = 1.0,
) -> None:
# TODO: add optional validation validate=True
self._values = self._parse_data(data)
self.time: pd.DatetimeIndex = self._parse_time(time)
self._dt = dt

geometry = GeometryUndefined() if geometry is None else geometry
self.dims = self._parse_dims(dims, geometry)
Expand Down Expand Up @@ -421,11 +428,11 @@ def is_equidistant(self) -> bool:
return len(self.time.to_series().diff().dropna().unique()) == 1

@property
def timestep(self) -> float | None:
def timestep(self) -> float:
"""Time step in seconds if equidistant (and at
least two time instances); otherwise None
least two time instances); otherwise original time step is returned.
"""
dt = None
dt = self._dt
if len(self.time) > 1 and self.is_equidistant:
first: pd.Timestamp = self.time[0]
second: pd.Timestamp = self.time[1]
Expand Down Expand Up @@ -539,6 +546,7 @@ def squeeze(self) -> "DataArray":
geometry=self.geometry,
zn=self._zn,
dims=tuple(dims),
dt=self._dt,
)

# ============= Select/interp ===========
Expand Down Expand Up @@ -718,6 +726,7 @@ def isel(
geometry=geometry,
zn=zn,
dims=dims,
dt=self._dt,
)

def sel(
Expand Down Expand Up @@ -969,7 +978,11 @@ def interp(
# )

da = DataArray(
data=dai, time=self.time, geometry=geometry, item=deepcopy(self.item)
data=dai,
time=self.time,
geometry=geometry,
item=deepcopy(self.item),
dt=self._dt,
)
else:
da = self.copy()
Expand Down Expand Up @@ -1097,6 +1110,7 @@ def interp_time(
item=deepcopy(self.item),
geometry=self.geometry,
zn=zn,
dt=self._dt,
)

def interp_na(self, axis: str = "time", **kwargs: Any) -> "DataArray":
Expand Down Expand Up @@ -1197,7 +1211,11 @@ def interp_like(
)
assert isinstance(ari, np.ndarray)
dai = DataArray(
data=ari, time=self.time, geometry=geom, item=deepcopy(self.item)
data=ari,
time=self.time,
geometry=geom,
item=deepcopy(self.item),
dt=self._dt,
)

if hasattr(other, "time"):
Expand Down Expand Up @@ -1506,6 +1524,7 @@ def aggregate(
geometry=geometry,
dims=dims,
zn=zn,
dt=self._dt,
)

@overload
Expand Down Expand Up @@ -1599,7 +1618,13 @@ def _quantile(self, q, *, axis: int | str = 0, func=np.quantile, **kwargs: Any):
dims = tuple([d for i, d in enumerate(self.dims) if i != axis])
item = deepcopy(self.item)
return DataArray(
data=qdat, time=time, item=item, geometry=geometry, dims=dims, zn=zn
data=qdat,
time=time,
item=item,
geometry=geometry,
dims=dims,
zn=zn,
dt=self._dt,
)
else:
res = []
Expand Down Expand Up @@ -1747,6 +1772,7 @@ def _boolmask_to_new_DataArray(self, bmask) -> "DataArray": # type: ignore
item=ItemInfo("Boolean"),
geometry=self.geometry,
zn=self._zn,
dt=self._dt,
)

# ============= output methods: to_xxx() ===========
Expand Down
27 changes: 17 additions & 10 deletions mikeio/dataset/_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,11 @@ def __init__(
zn: NDArray[np.floating] | None = None,
dims: Tuple[str, ...] | None = None,
validate: bool = True,
dt: float = 1.0,
):
if not self._is_DataArrays(data):
data = self._create_dataarrays(
data=data, time=time, items=items, geometry=geometry, zn=zn, dims=dims # type: ignore
data=data, time=time, items=items, geometry=geometry, zn=zn, dims=dims, dt=dt # type: ignore
)
self._data_vars: MutableMapping[str, DataArray] = self._init_from_DataArrays(data, validate=validate) # type: ignore
self.plot = _DatasetPlotter(self)
Expand All @@ -123,11 +124,12 @@ def _is_DataArrays(data: Any) -> bool:
@staticmethod
def _create_dataarrays(
data: Sequence[NDArray[np.floating]] | NDArray[np.floating],
time: pd.DatetimeIndex | None = None,
items: Sequence[ItemInfo] | None = None,
geometry: Any = None,
zn: NDArray[np.floating] | None = None,
dims: Tuple[str, ...] | None = None,
time: pd.DatetimeIndex,
items: Sequence[ItemInfo],
geometry: Any,
zn: NDArray[np.floating],
dims: Tuple[str, ...],
dt: float,
) -> Mapping[str, DataArray]:
if not isinstance(data, Iterable):
data = [data]
Expand All @@ -137,7 +139,7 @@ def _create_dataarrays(
data_vars = {}
for dd, it in zip(data, items):
data_vars[it.name] = DataArray(
data=dd, time=time, item=it, geometry=geometry, zn=zn, dims=dims
data=dd, time=time, item=it, geometry=geometry, zn=zn, dims=dims, dt=dt
)
return data_vars

Expand Down Expand Up @@ -303,6 +305,11 @@ def _check_already_present(self, new_da: DataArray) -> None:

# ============= Basic properties/methods ===========

@property
def _dt(self) -> float:
"""Original time step in seconds"""
return self[0]._dt

@property
def time(self) -> pd.DatetimeIndex:
"""Time axis"""
Expand All @@ -326,11 +333,11 @@ def end_time(self) -> datetime:
return self.time[-1].to_pydatetime() # type: ignore

@property
def timestep(self) -> float | None:
def timestep(self) -> float:
"""Time step in seconds if equidistant (and at
least two time instances); otherwise None
least two time instances); otherwise original time step is returned.
"""
dt = None
dt = self._dt
if len(self.time) > 1 and self.is_equidistant:
dt = (self.time[1] - self.time[0]).total_seconds()
return dt
Expand Down
9 changes: 8 additions & 1 deletion mikeio/dfs/_dfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,14 @@ def read(
items = _get_item_info(self._dfs.ItemInfo, item_numbers)

self._dfs.Close()
return Dataset(data_list, time, items, geometry=self.geometry, validate=False)
return Dataset(
data_list,
time,
items,
geometry=self.geometry,
validate=False,
dt=self._timestep,
)

def _open(self) -> None:
raise NotImplementedError("Should be implemented by subclass")
Expand Down
17 changes: 10 additions & 7 deletions mikeio/dfsu/_dfsu.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,10 @@ def write_dfsu(filename: str | Path, data: Dataset) -> None:
"""
filename = str(filename)

if len(data.time) == 1:
dt = 1 # TODO is there any sensible default?
else:
if not data.is_equidistant:
raise ValueError("Non-equidistant time axis is not supported.")
if not data.is_equidistant:
raise ValueError("Non-equidistant time axis is not supported.")

dt = (data.time[1] - data.time[0]).total_seconds() # type: ignore
dt = data.timestep
n_time_steps = len(data.time)

geometry = data.geometry
Expand Down Expand Up @@ -485,7 +482,13 @@ def read(
data_list = [np.squeeze(d, axis=-1) for d in data_list]

return Dataset(
data_list, time, items, geometry=geometry, dims=dims, validate=False
data_list,
time,
items,
geometry=geometry,
dims=dims,
validate=False,
dt=self.timestep,
)

def _parse_geometry_sel(self, area, x, y):
Expand Down
9 changes: 8 additions & 1 deletion mikeio/dfsu/_layered.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,10 +344,17 @@ def read(
zn=data_list[0],
dims=dims,
validate=False,
dt=self.timestep,
)
else:
return Dataset(
data_list, time, items, geometry=geometry, dims=dims, validate=False
data_list,
time,
items,
geometry=geometry,
dims=dims,
validate=False,
dt=self.timestep,
)


Expand Down
30 changes: 30 additions & 0 deletions tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1609,3 +1609,33 @@ def test_interp_na():
def test_plot_scatter():
ds = mikeio.read("tests/testdata/oresund_sigma_z.dfsu", time=0)
ds.plot.scatter(x="Salinity", y="Temperature", title="S-vs-T")


def test_select_single_timestep_preserves_dt():
ds = mikeio.read("tests/testdata/tide1.dfs1")
assert ds.timestep == pytest.approx(1800.0)
ds2 = ds.isel(time=-1)
assert ds2.timestep == pytest.approx(1800.0)
assert ds2[0].timestep == pytest.approx(1800.0)


def test_select_multiple_spaced_timesteps_uses_proper_dt(tmp_path):
ds = mikeio.read("tests/testdata/tide1.dfs1")
assert ds.timestep == pytest.approx(1800.0)
ds2 = ds.isel(time=[0, 2, 4])
assert ds2.timestep == pytest.approx(3600.0)


def test_read_write_single_timestep_preserves_dt(tmp_path):
fn = "tests/testdata/oresund_sigma_z.dfsu"
dfs = mikeio.open(fn)
assert dfs.timestep == pytest.approx(10800.0)

ds = dfs.read(time=[0])
assert ds.timestep == pytest.approx(dfs.timestep)

outfn = tmp_path / "single.dfsu"
ds.to_dfs(outfn)

dfs2 = mikeio.open(outfn)
assert dfs2.timestep == pytest.approx(10800.0)

0 comments on commit 8516188

Please sign in to comment.