Skip to content

Commit

Permalink
Added bit (sample-by-sample) normalization (#284)
Browse files Browse the repository at this point in the history
* Added bit (sample-by-sample) normalization and fixed normalization bug

* improved testing

---------

Co-authored-by: ariellellouch <[email protected]>
  • Loading branch information
d-chambers and ariellellouch authored Oct 16, 2023
1 parent a2d2782 commit 856364f
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 3 deletions.
17 changes: 14 additions & 3 deletions dascore/proc/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ def squeeze(self: PatchType, dim=None) -> PatchType:
def normalize(
self: PatchType,
dim: str,
norm: Literal["l1", "l2", "max"] = "l2",
norm: Literal["l1", "l2", "max", "bit"] = "l2",
) -> PatchType:
"""
Normalize a patch along a specified dimension.
Expand All @@ -311,6 +311,7 @@ def normalize(
l1 - divide each sample by the l1 of the axis.
l2 - divide each sample by the l2 of the axis.
max - divide each sample by the maximum of the absolute value of the axis.
bit - sample-by-sample normalization (-1/+1)
"""
axis = self.dims.index(dim)
data = self.data
Expand All @@ -319,13 +320,23 @@ def normalize(
norm_values = np.linalg.norm(self.data, axis=axis, ord=order)
elif norm == "max":
norm_values = np.max(data, axis=axis)
elif norm == "bit":
pass
else:
msg = (
f"Norm value of {norm} is not supported. "
f"Supported values are {('l1', 'l2', 'max')}"
f"Supported values are {('l1', 'l2', 'max', 'bit')}"
)
raise ValueError(msg)
new_data = data / np.expand_dims(norm_values, axis=axis)
if norm == "bit":
new_data = np.divide(
data, np.abs(data), out=np.zeros_like(data), where=np.abs(data) != 0
)
else:
expanded_norm = np.expand_dims(norm_values, axis=axis)
new_data = np.divide(
data, expanded_norm, out=np.zeros_like(data), where=expanded_norm != 0
)
return self.new(data=new_data)


Expand Down
22 changes: 22 additions & 0 deletions tests/test_proc/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytest

import dascore as dc
from dascore import get_example_patch
from dascore.exceptions import IncompatiblePatchError, UnitError
from dascore.proc.basic import apply_operator
from dascore.units import furlongs, get_quantity, m, s
Expand Down Expand Up @@ -142,6 +143,27 @@ def test_max(self, random_patch):
norm = np.abs(np.sum(time_norm.data, axis=axis))
assert np.allclose(norm, 1)

def test_bit(self):
"""Ensure after operation each sample is -1, 1, or 0."""
patch = get_example_patch("dispersion_event")
bit_norm = patch.normalize("time", norm="bit")
assert np.all(np.unique(bit_norm.data) == np.array([-1.0, 0, 1.0]))

def test_zero_channels(self, random_patch):
"""Ensure after operation each zero row or vector remains so."""
zeroed_data = np.copy(random_patch.data)
zeroed_data[0, :] = 0.0
zeroed_data[:, 0] = 0.0
zeroed_patch = random_patch.new(data=zeroed_data)
for norm_type in ["l1", "l2", "max", "bit"]:
norm = zeroed_patch.normalize("time", norm=norm_type)
assert np.all(norm.data[0, :] == 0.0)
assert np.all(norm.data[:, 0] == 0.0)
for norm_type in ["l1", "l2", "max", "bit"]:
norm = zeroed_patch.normalize("distance", norm=norm_type)
assert np.all(norm.data[0, :] == 0.0)
assert np.all(norm.data[:, 0] == 0.0)


class TestStandarize:
"""Tests for standardization."""
Expand Down

0 comments on commit 856364f

Please sign in to comment.