added custom unit conversion layer with test
grantbuster committed Jun 24, 2024
1 parent 594549e commit cd3adcf
Showing 2 changed files with 115 additions and 0 deletions.
87 changes: 87 additions & 0 deletions phygnn/layers/custom_layers.py
@@ -1012,3 +1012,90 @@ def call(self, x):
            return tf.math.log(x + self.adder) * self.scalar
        else:
            return tf.math.exp(x / self.scalar) - self.adder


class UnitConversion(tf.keras.layers.Layer):
    """Layer to convert units per feature channel using the linear transform:
    ``y = x * scalar + adder``

    Be sure to check how this will interact with normalization factors.
    """

    def __init__(self, name=None, adder=0, scalar=1):
        """
        Parameters
        ----------
        name : str | None
            Name of the tensorflow layer
        adder : float | list
            Adder term for ``y = x * scalar + adder``. If this is a float, the
            same value will be used for all feature channels. If this is a
            list, each value will be used for the corresponding feature channel
            and the length must match the number of feature channels
        scalar : float | list
            Scalar term for ``y = x * scalar + adder``. If this is a float, the
            same value will be used for all feature channels. If this is a
            list, each value will be used for the corresponding feature channel
            and the length must match the number of feature channels
        """

        super().__init__(name=name)
        self.adder = adder
        self.scalar = scalar
        self.rank = None

    def build(self, input_shape):
        """Custom implementation of the tf layer build method.

        Parameters
        ----------
        input_shape : tuple
            Shape tuple of the input
        """
        self.rank = len(input_shape)
        nfeat = input_shape[-1]

        dtypes = (int, np.int64, np.int32, float, np.float32, np.float64)

        if isinstance(self.adder, dtypes):
            self.adder = np.ones(nfeat) * self.adder
            self.adder = tf.convert_to_tensor(self.adder, dtype=tf.float32)
        else:
            msg = (f'UnitConversion layer `adder` array has length '
                   f'{len(self.adder)} but input shape has last dimension '
                   f'as {input_shape[-1]}')
            assert len(self.adder) == input_shape[-1], msg

        if isinstance(self.scalar, dtypes):
            self.scalar = np.ones(nfeat) * self.scalar
            self.scalar = tf.convert_to_tensor(self.scalar, dtype=tf.float32)
        else:
            msg = (f'UnitConversion layer `scalar` array has length '
                   f'{len(self.scalar)} but input shape has last dimension '
                   f'as {input_shape[-1]}')
            assert len(self.scalar) == input_shape[-1], msg

    def call(self, x):
        """Convert units

        Parameters
        ----------
        x : tf.Tensor
            Input tensor

        Returns
        -------
        y : tf.Tensor
            Unit-converted x tensor
        """

        if self.rank is None:
            self.build(x.shape)

        out = []
        for idf, (adder, scalar) in enumerate(zip(self.adder, self.scalar)):
            out.append(x[..., idf:idf + 1] * scalar + adder)

        out = tf.concat(out, -1, name='concat')

        return out
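
For reference, a minimal usage sketch of the new layer (not part of the diff above): it exercises the per-channel `adder`/`scalar` lists. The two-channel setup and the unit choices (Celsius to Kelvin, kPa to Pa) are illustrative only, and the import path follows the file location shown in this commit.

import numpy as np
from phygnn.layers.custom_layers import UnitConversion

# Two feature channels: channel 0 converted from Celsius to Kelvin
# (adder=273.15), channel 1 converted from kPa to Pa (scalar=1000).
x = np.random.uniform(0, 1, (1, 8, 8, 2)).astype(np.float32)
layer = UnitConversion(adder=[273.15, 0], scalar=[1, 1000])
y = layer(x).numpy()

assert np.allclose(y[..., 0], x[..., 0] + 273.15)
assert np.allclose(y[..., 1], x[..., 1] * 1000)
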
28 changes: 28 additions & 0 deletions tests/test_layers.py
@@ -18,6 +18,7 @@
    GaussianAveragePooling2D,
    SigLin,
    LogTransform,
    UnitConversion,
)
from phygnn.layers.handlers import HiddenLayers, Layers
from phygnn import TfModel
@@ -529,3 +530,30 @@ def test_logtransform():
    assert not np.isnan(y).any()
    assert np.allclose(y, np.log(x + 1))
    assert np.allclose(x, xinv)


def test_unit_conversion():
    """Test the custom unit conversion layer"""
    x = np.random.uniform(0, 1, (1, 10, 10, 4))  # 4 features

    layer = UnitConversion(adder=0, scalar=1)
    y = layer(x).numpy()
    assert np.allclose(x, y)

    layer = UnitConversion(adder=1, scalar=1)
    y = layer(x).numpy()
    assert (y >= 1).all() and (y <= 2).all()

    layer = UnitConversion(adder=1, scalar=100)
    y = layer(x).numpy()
    assert (y >= 1).all() and (y > 90).any() and (y <= 101).all()

    layer = UnitConversion(adder=0, scalar=[100, 1, 1, 1])
    y = layer(x).numpy()
    assert (y[..., 0] > 90).any() and (y[..., 0] <= 100).all()
    assert (y[..., 1:] >= 0).all() and (y[..., 1:] <= 1).all()

    with pytest.raises(AssertionError):
        # bad number of scalar values
        layer = UnitConversion(adder=0, scalar=[100, 1, 1])
        y = layer(x)

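The docstring's caution about normalization factors is easiest to see when the layer sits at the front of a model, so raw inputs are converted before any trainable layers. Below is a sketch only, assuming a recent TensorFlow 2.x install; the layer sizes and conversion factors are made up for illustration.

import numpy as np
import tensorflow as tf
from phygnn.layers.custom_layers import UnitConversion

# Rescale channel 0 by 100x before any convolutions; channels 1-3 pass
# through unchanged. Any separate data normalization applied during
# training needs to account for this conversion.
model = tf.keras.Sequential([
    UnitConversion(adder=[0, 0, 0, 0], scalar=[100, 1, 1, 1]),
    tf.keras.layers.Conv2D(8, 3, padding='same', activation='relu'),
    tf.keras.layers.Conv2D(4, 3, padding='same'),
])

x = np.random.uniform(0, 1, (1, 10, 10, 4)).astype(np.float32)
y = model(x)  # shape (1, 10, 10, 4)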