Skip to content

Commit

Permalink
fix_379 (#422)
Browse files Browse the repository at this point in the history
  • Loading branch information
d-chambers authored Aug 17, 2024
1 parent 8942801 commit ea8bf01
Show file tree
Hide file tree
Showing 6 changed files with 69 additions and 6 deletions.
7 changes: 4 additions & 3 deletions dascore/proc/taper.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,9 @@ def _get_taper_slices(patch, kwargs):
start, stop = value, value
dur = coord.max() - coord.min()
# either let units pass through or multiply by d_len
start = start if isinstance(start, Quantity) or start is None else start * dur
stop = stop if isinstance(stop, Quantity) or stop is None else stop * dur
clses = (Quantity, np.timedelta64)
start = start if isinstance(start, clses) or start is None else start * dur
stop = stop if isinstance(stop, clses) or stop is None else stop * dur
stop = -stop if stop is not None else stop
_, inds_1 = coord.select((None, start), relative=True)
_, inds_2 = coord.select((stop, None), relative=True)
Expand All @@ -61,7 +62,7 @@ def _get_window_function(window_type):


def _validate_windows(samps, start_slice, end_slice, shape, axis):
"""Validate the the windows don't overlap or exceed dim len."""
"""Validate the windows don't overlap or exceed dim len."""
max_len = shape[axis]
start_ind = start_slice.stop
end_ind = end_slice.start
Expand Down
7 changes: 6 additions & 1 deletion dascore/units.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@
import pint
from pint import DimensionalityError, Quantity, UndefinedUnitError, Unit

import dascore as dc
from dascore.exceptions import UnitError
from dascore.utils.misc import unbyte
from dascore.utils.time import dtype_time_like
from dascore.utils.time import dtype_time_like, is_datetime64, is_timedelta64, to_float

str_or_none = TypeVar("str_or_none", None, str)
numeric = TypeVar("numeric", np.ndarray, int, float)
Expand Down Expand Up @@ -76,12 +77,16 @@ def get_quantity(value: str_or_none) -> Quantity | None:
>>> import dascore as dc
>>> meters = dc.get_quantity("m")
>>> accel = dc.get_quantity("m/s^2")
>>> # This can also convert date times.
>>> many_seconds = dc.get_quantity(dc.to_timedelta64(200))
"""
value = unbyte(value)
if value is None or value is ... or value == "":
return None
if isinstance(value, Quantity):
return value
elif is_datetime64(value) | is_timedelta64(value):
return to_float(value) * dc.get_unit("s")
return _str_to_quant(value)


Expand Down
6 changes: 5 additions & 1 deletion dascore/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,8 +166,12 @@ def all_close(ar1, ar2):
ar1, ar2 = np.asarray(ar1), np.asarray(ar2)
if not ar1.shape == ar2.shape:
return False
ar1_null = pd.isnull(ar1)
ar2_null = pd.isnull(ar2)
try:
return np.allclose(ar1, ar2)
close = np.isclose(ar1, ar2)
bools = close | ar1_null | ar2_null
return np.all(bools)
except TypeError:
return np.all(ar1 == ar2)

Expand Down
9 changes: 9 additions & 0 deletions tests/test_proc/test_rolling.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import dascore as dc
from dascore.exceptions import ParameterError
from dascore.units import m
from dascore.utils.misc import all_close
from dascore.utils.pd import rolling_df


Expand Down Expand Up @@ -57,6 +58,14 @@ def test_along_time(self, range_patch):
out = trans.rolling(time=time_step * self.window).sum()
assert np.allclose(out.dropna("time").data, expected.transpose())

def test_rolling_timdelta(self, random_patch):
"""Ensure rolling works with timedeltas."""
time_step = random_patch.get_coord("time").step
time = time_step * self.window
out1 = random_patch.rolling(time=dc.to_timedelta64(time)).sum()
out2 = random_patch.rolling(time=time * dc.get_quantity("s")).sum()
assert all_close(out1, out2)

def test_apply_with_step(self, range_patch):
"""Ensure apply works with various step sizes."""
# first calculate rolling max on time axis.
Expand Down
8 changes: 8 additions & 0 deletions tests/test_proc/test_taper.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,3 +114,11 @@ def test_taper_with_units(self, patch_ones):
assert np.allclose(
data_new[mid_dim - 10 : mid_dim + 10], data_old[mid_dim - 10 : mid_dim + 10]
)

def test_timedelta_taper(self, random_patch):
"""Test that a timedelta works for the taper argument. See #379."""
time1 = dc.to_timedelta64(2)
time2 = 2 * dc.get_quantity("seconds")
patch1 = random_patch.taper(time=time1)
patch2 = random_patch.taper(time=time2)
assert patch1 == patch2
38 changes: 37 additions & 1 deletion tests/test_units.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,23 @@ def test_quantity(self):
quant = get_quantity("m/s")
out = get_quantity_str(quant)
assert out == "m / s"
# with magnitude it should be included.
# with magnitude, it should be included.
quant = get_quantity("10 m /s")
out = get_quantity_str(quant)
assert "10.0" in out

def test_timedelta_to_quantity(self):
"""Ensure a timedelta can be converted to a quantity."""
dt = dc.to_timedelta64(20)
quant = dc.get_quantity(dt)
assert quant == (20 * dc.get_unit("s"))

def test_datetime_to_quantity(self):
"""Ensure a datetime can be converted to a quantity."""
td = dc.to_datetime64("1970-01-01T00:00:20")
quant = dc.get_quantity(td)
assert quant == (20 * dc.get_unit("s"))


class TestUnitAndFactor:
"""tests for returning units and scaling factor."""
Expand All @@ -93,6 +105,20 @@ def test_none(self):
assert factor == 1
assert unit is None

def test_timedelta64(self):
"""Ensure timedeltas can be separated."""
td = dc.to_timedelta64(20)
(factor, unit) = get_factor_and_unit(td)
assert factor == 20.00
assert unit == "s"

def test_datetime64(self):
"""Ensure datetime64 can be separated."""
td = dc.to_datetime64(20)
(factor, unit) = get_factor_and_unit(td)
assert factor == 20.00
assert unit == "s"


class TestGetQuantity:
"""Tests for getting a quantity."""
Expand All @@ -111,6 +137,16 @@ def test_get_temp(self):
quant1 = get_quantity("degC")
assert "°C" in str(quant1)

def test_timedelta64(self):
"""Ensure time deltas can be converted to quantity"""
quant = get_quantity(dc.to_timedelta64(20))
assert quant == (20 * dc.get_unit("s"))

def test_datetime64(self):
"""Ensure time deltas can be converted to quantity"""
quant = get_quantity(dc.to_datetime64(20))
assert quant == (20 * dc.get_unit("s"))


class TestConvenientImport:
"""Tests for conveniently importing units for dascore.units."""
Expand Down

0 comments on commit ea8bf01

Please sign in to comment.