diff --git a/phygnn/layers/custom_layers.py b/phygnn/layers/custom_layers.py
index e7358ca..01bc1cb 100644
--- a/phygnn/layers/custom_layers.py
+++ b/phygnn/layers/custom_layers.py
@@ -785,3 +785,61 @@ def call(self, x, hi_res_feature):
         Output tensor with the hi_res_feature added to x.
         """
         return tf.concat((x, hi_res_feature), axis=-1)
+
+
+class FunctionalLayer(tf.keras.layers.Layer):
+    """Custom layer to implement the tensorflow layer functions (e.g., add,
+    subtract, multiply, maximum, and minimum) with a constant value. These
+    cannot be implemented in phygnn as normal layers because they need to
+    operate on two tensors of equal shape."""
+
+    def __init__(self, name, value):
+        """
+        Parameters
+        ----------
+        name : str
+            Name of the tensorflow layer function to be implemented, options
+            are (all lower-case): add, subtract, multiply, maximum, and minimum
+        value : float
+            Constant value to use in the function operation
+        """
+
+        options = ('add', 'subtract', 'multiply', 'maximum', 'minimum')
+        msg = (f'FunctionalLayer input `name` must be one of "{options}" '
+               f'but received "{name}"')
+        assert name in options, msg
+
+        super().__init__(name=name)
+        self.value = value
+        self.fun = getattr(tf.keras.layers, self.name)
+
+    def call(self, x):
+        """Operates on x with the specified function
+
+        Parameters
+        ----------
+        x : tf.Tensor
+            Input tensor
+
+        Returns
+        -------
+        x : tf.Tensor
+            Output tensor operated on by the specified function
+        """
+        # zeros_like broadcasts the scalar to x's shape and also supports
+        # dynamic (None) dims, unlike tf.constant(shape=x.shape, ...)
+        const = tf.zeros_like(x) + self.value
+        return self.fun((x, const))
+
+    def get_config(self):
+        """Implementation of get_config to make this layer serializable.
+
+        Returns
+        -------
+        config : dict
+            Base tf.keras.layers.Layer config (including the function name)
+            plus the constant operation value.
+        """
+        config = super().get_config()
+        config.update({'value': self.value})
+        return config
diff --git a/phygnn/version.py b/phygnn/version.py
index 7c48dcb..9b6286e 100644
--- a/phygnn/version.py
+++ b/phygnn/version.py
@@ -1,4 +1,4 @@
 # -*- coding: utf-8 -*-
 """Physics Guided Neural Network version."""
 
-__version__ = '0.0.25'
+__version__ = '0.0.26'
diff --git a/tests/test_layers.py b/tests/test_layers.py
index 54e4ded..97f372a 100644
--- a/tests/test_layers.py
+++ b/tests/test_layers.py
@@ -12,6 +12,7 @@
     SkipConnection,
     SpatioTemporalExpansion,
     TileLayer,
+    FunctionalLayer,
 )
 
 from phygnn.layers.handlers import HiddenLayers, Layers
@@ -423,3 +424,27 @@ def test_fno_3d():
     x = layer(x)
     with pytest.raises(tf.errors.InvalidArgumentError):
         tf.assert_equal(x_in, x)
+
+
+def test_functional_layer():
+    """Test the generic functional layer"""
+
+    layer = FunctionalLayer('maximum', 1)
+    x = np.random.normal(0.5, 3, size=(1, 4, 4, 6, 3))
+    assert layer(x).numpy().min() == 1.0
+
+    # make sure layer works with input of arbitrary shape
+    x = np.random.normal(0.5, 3, size=(2, 8, 8, 4, 1))
+    assert layer(x).numpy().min() == 1.0
+
+    layer = FunctionalLayer('multiply', 1.5)
+    x = np.random.normal(0.5, 3, size=(1, 4, 4, 6, 3))
+    assert np.allclose(layer(x).numpy(), x * 1.5)
+
+    # layer config must roundtrip for model save/load serialization
+    config = layer.get_config()
+    assert config['name'] == 'multiply' and config['value'] == 1.5
+
+    with pytest.raises(AssertionError) as excinfo:
+        FunctionalLayer('bad_arg', 0)
+    assert "must be one of" in str(excinfo.value)