Skip to content

Commit

Permalink
dft to be idempotent for already transformed coords (#397)
Browse files Browse the repository at this point in the history
* fix_395

* fix #390
  • Loading branch information
d-chambers authored Jun 11, 2024
1 parent 0b3cb2f commit c72907f
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 5 deletions.
26 changes: 24 additions & 2 deletions dascore/transform/fourier.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,24 @@ def _get_dft_attrs(patch, dims, new_coords):
new = dict(patch.attrs)
new["dims"] = new_coords.dims
new["data_units"] = _get_data_units_from_dims(patch, dims, mul)
# As per #390, we also want to remove data_type (eg the patch is no
# longer in strain rate after the dft)
new["_pre_dft_data_type"] = new.pop("data_type", None)
return PatchAttrs(**new)


def _get_untransformed_dims(patch, dims):
"""Return dimensions which have not been transformed."""
dim_set = set(patch.dims)
out = []
for dim in dims:
# This dim has already been transformed.
if (dim not in dim_set) and f"ft_{dim}" in dim_set:
continue
out.append(dim)
return out


@patch_function()
def dft(
patch: PatchType, dim: str | None | Sequence[str], *, real: str | bool | None = None
Expand Down Expand Up @@ -112,7 +127,7 @@ def dft(
- Non-dimensional coordiantes associated with transformed coordinates
will be dropped in the output.
- See the [FFT note](dascore.org/notes/fft_notes.html) in the Notes section
- See the [FFT note](`notes/dft_notes.qmd`) in the Notes section
of DASCore's documentation for more details.
See Also
Expand All @@ -132,11 +147,15 @@ def dft(
"""
dims = list(iterate(dim if dim is not None else patch.dims))
patch.assert_has_coords(dims)
real = dims[-1] if real is True else real # if true grab last dim
dims = _get_untransformed_dims(patch, dims)
real = real if real in dims else None # may need to reset real
if not dims: # no transformation needed.
return patch
# re-arrange list so real dim is last (if provided)
if isinstance(real, str):
assert real in dims, "real must be in provided dimensions."
dims.append(dims.pop(dims.index(real)))
real = dims[-1] if real is True else real # if true grab last dim
# get axes and spacing along desired dimensions.
dxs, axes = _get_dx_or_spacing_and_axes(patch, dims, require_evenly_spaced=True)
func = nft.rfftn if real is not None else nft.fftn
Expand Down Expand Up @@ -209,6 +228,9 @@ def _get_idft_attrs(patch, dims, new_coords):
new = dict(patch.attrs)
new["dims"] = new_coords.dims
new["data_units"] = _get_data_units_from_dims(patch, dims, mul)
# Restore the pre-dft datatype.
if "_pre_dft_data_type" in new:
new["data_type"] = new.pop("_pre_dft_data_type", None)
return PatchAttrs(**new)


Expand Down
44 changes: 41 additions & 3 deletions tests/test_transform/test_fourier.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,12 @@
@pytest.fixture(scope="session")
def sin_patch():
"""Get the sine wave patch, set units for testing."""
patch = dc.get_example_patch("sin_wav", sample_rate=100, duration=3, frequency=F_0)
out = patch.set_units(get_quantity("1.0 V"), time="s", distance="m")
return out
patch = (
dc.get_example_patch("sin_wav", sample_rate=100, duration=3, frequency=F_0)
.set_units(get_quantity("1.0 V"), time="s", distance="m")
.update_attrs(data_type="strain_rate")
)
return patch


@pytest.fixture(scope="session")
Expand Down Expand Up @@ -124,6 +127,36 @@ def test_parseval(self, sin_patch, fft_sin_patch_time):
vals2 = (pa2.abs() ** 2).integrate("ft_time", definite=True)
assert np.allclose(vals1.data, vals2.data)

def test_idempotent_single_dim(self, fft_sin_patch_time):
"""
Ensure dft is idempotent for a single dimension.
"""
out = fft_sin_patch_time.dft("time")
assert out.equals(fft_sin_patch_time)

def test_idempotent_all_dims(self, fft_sin_patch_all):
"""
Ensure dft is idempotent for transforms applied to all dims.
"""
out = fft_sin_patch_all.dft(dim=("time", "distance"))
assert out.equals(fft_sin_patch_all)

def test_transform_single_dim(
self, sin_patch, fft_sin_patch_time, fft_sin_patch_all
):
"""
Ensure dft is idempotent for time, but untransformed axis still gets
transformed.
"""
out = fft_sin_patch_time.dft(dim=("time", "distance"))
assert not out.equals(fft_sin_patch_time)
assert np.allclose(out.data, fft_sin_patch_all.data)

def test_datatype_removed(self, fft_sin_patch_time, sin_patch):
"""Ensure the data_type attr is removed after transform."""
assert sin_patch.attrs.data_type == "strain_rate"
assert fft_sin_patch_time.attrs.data_type == ""


class TestInverseDiscreteFourierTransform:
"""Inverse DFT suite."""
Expand Down Expand Up @@ -168,3 +201,8 @@ def test_partial_inverse(self, fft_sin_patch_all, sin_patch):
# and then if we reverse distance it should be the same as original
full_inverse = ift.idft("distance")
self._patches_about_equal(full_inverse, sin_patch)

def test_data_type_restored(self, fft_sin_patch_time, sin_patch):
"""Ensure data_type attr is restored."""
out = fft_sin_patch_time.idft("time")
assert out.attrs.data_type == sin_patch.attrs.data_type

0 comments on commit c72907f

Please sign in to comment.