def normalize(
    self: PatchType,
    dim: str,
    norm: Literal["l1", "l2", "max", "bit"] = "l2",
) -> PatchType:
    """
    Normalize a patch along a specified dimension.

    Samples whose divisor is exactly zero (for example an all-zero
    channel) are set to zero rather than producing NaN or inf.

    Parameters
    ----------
    dim
        Name of the dimension along which to normalize. It is looked up
        in ``self.dims`` to find the array axis.
    norm
        The type of normalization to apply. Options are:
            l1 - divide each sample by the l1 of the axis.
            l2 - divide each sample by the l2 of the axis.
            max - divide each sample by the maximum of the absolute value of
                the axis.
            bit - sample-by-sample normalization (-1/+1)

    Returns
    -------
    A new patch, created with ``self.new``, holding the normalized data.
    Inexact (floating or complex) input keeps its dtype; any other dtype
    (e.g. integers) is returned as float64.

    Raises
    ------
    ValueError
        If ``norm`` is not one of the supported values.
    """
    axis = self.dims.index(dim)
    data = self.data
    if norm in {"l1", "l2"}:
        # The digit in the name is the order handed to np.linalg.norm.
        divisor = np.linalg.norm(data, axis=axis, ord=int(norm[-1]), keepdims=True)
    elif norm == "max":
        # Use the absolute value, as documented, so a channel dominated by
        # negative samples is neither sign-flipped nor scaled past 1.
        divisor = np.max(np.abs(data), axis=axis, keepdims=True)
    elif norm == "bit":
        # Each sample is divided by its own magnitude.
        divisor = np.abs(data)
    else:
        msg = (
            f"Norm value of {norm} is not supported. "
            f"Supported values are {('l1', 'l2', 'max', 'bit')}"
        )
        raise ValueError(msg)
    # True division of non-float input yields floats, which cannot be written
    # into an ``out`` buffer of the input dtype; allocate a float buffer then.
    if np.issubdtype(data.dtype, np.inexact):
        out_dtype = data.dtype
    else:
        out_dtype = np.float64
    new_data = np.divide(
        data,
        divisor,
        out=np.zeros_like(data, dtype=out_dtype),
        where=divisor != 0,
    )
    return self.new(data=new_data)
def test_bit(self):
    """Ensure after operation each sample is -1, 1, or 0."""
    patch = get_example_patch("dispersion_event")
    bit_norm = patch.normalize("time", norm="bit")
    # Check membership element-wise; comparing np.unique output against a
    # fixed 3-element array depends on the number of distinct values.
    assert np.isin(bit_norm.data, (-1.0, 0.0, 1.0)).all()
    # Both signs must be present so an all-zero result cannot pass.
    found = set(np.unique(bit_norm.data).tolist())
    assert {-1.0, 1.0} <= found


def test_zero_channels(self, random_patch):
    """Ensure after operation each zero row or vector remains so."""
    zeroed_data = np.copy(random_patch.data)
    zeroed_data[0, :] = 0.0
    zeroed_data[:, 0] = 0.0
    zeroed_patch = random_patch.new(data=zeroed_data)
    # Same checks for every dimension/norm pair: time first, then distance.
    for dim in ("time", "distance"):
        for norm_type in ("l1", "l2", "max", "bit"):
            out = zeroed_patch.normalize(dim, norm=norm_type).data
            where = f"dim={dim}, norm={norm_type}"
            assert np.all(out[0, :] == 0.0), f"first row not zero ({where})"
            assert np.all(out[:, 0] == 0.0), f"first column not zero ({where})"