Skip to content

Commit

Permalink
added log transform layer
Browse files Browse the repository at this point in the history
  • Loading branch information
grantbuster committed Jun 18, 2024
1 parent 1e2f462 commit b5d4893
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 0 deletions.
41 changes: 41 additions & 0 deletions phygnn/layers/custom_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -955,3 +955,44 @@ def call(self, x):
"""

return tf.math.maximum(tf.math.sigmoid(x), x + 0.5)


class LogTransform(tf.keras.layers.Layer):
"""Log transform or inverse transform of data
``y = log(x + adder)`` or ``y = exp(x) - adder`` for inverse
"""

def __init__(self, name=None, adder=0, inverse=False):
"""
Parameters
----------
name : str | None
Name of the tensorflow layer
adder : float
Adder for ``y = log(x + adder)``
inverse : bool
Option to perform the inverse operation e.g. ``y = exp(x) - adder``
"""

super().__init__(name=name)
self.adder = adder
self.inverse = inverse

def call(self, x):
"""Operates on x with (inverse) log transform
Parameters
----------
x : tf.Tensor
Input tensor
Returns
-------
y : tf.Tensor
Output ``y = log(x + adder)`` or ``y = exp(x) - adder`` if inverse
"""
if not self.inverse:
return tf.math.log(x + self.adder)
else:
return tf.math.exp(x) - self.adder
19 changes: 19 additions & 0 deletions tests/test_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
FunctionalLayer,
GaussianAveragePooling2D,
SigLin,
LogTransform,
)
from phygnn.layers.handlers import HiddenLayers, Layers
from phygnn import TfModel
Expand Down Expand Up @@ -510,3 +511,21 @@ def test_siglin():
assert x.shape == y.shape
assert (y > 0).all()
assert np.allclose(y[mid:], x[mid:] + 0.5)


def test_logtransform():
"""Test the log transform layer"""
n_points = 1000
lt = LogTransform(adder=0)
x = np.linspace(0, 10, n_points + 1)
y = lt(x).numpy()
assert x.shape == y.shape
assert y[0] == -np.inf
lt = LogTransform(adder=1)
ilt = LogTransform(adder=1, inverse=True)
x = np.linspace(0, 10, n_points + 1)
y = lt(x).numpy()
xinv = ilt(y).numpy()
assert not np.isnan(y).any()
assert np.allclose(y, np.log(x + 1))
assert np.allclose(x, xinv)

0 comments on commit b5d4893

Please sign in to comment.