diff --git a/tests/test_layers.py b/tests/test_layers.py index 64a7a31..54e4ded 100644 --- a/tests/test_layers.py +++ b/tests/test_layers.py @@ -5,13 +5,15 @@ import pytest import tensorflow as tf -from phygnn.layers.custom_layers import (SkipConnection, - SpatioTemporalExpansion, - FlattenAxis, - ExpandDims, - TileLayer, - GaussianNoiseAxis) -from phygnn.layers.handlers import Layers, HiddenLayers +from phygnn.layers.custom_layers import ( + ExpandDims, + FlattenAxis, + GaussianNoiseAxis, + SkipConnection, + SpatioTemporalExpansion, + TileLayer, +) +from phygnn.layers.handlers import HiddenLayers, Layers @pytest.mark.parametrize( @@ -208,7 +210,7 @@ def test_temporal_depth_to_time(t_mult, s_mult, t_roll): n_filters = 2 * s_mult**2 * t_mult shape = (1, 4, 4, 3, n_filters) n = np.product(shape) - x = np.arange(n).reshape((shape)) + x = np.arange(n).reshape(shape) y = layer(x) assert y.shape[0] == x.shape[0] assert y.shape[1] == s_mult * x.shape[1] @@ -387,3 +389,37 @@ def test_squeeze_excite_3d(): x = layer(x) with pytest.raises(tf.errors.InvalidArgumentError): tf.assert_equal(x_in, x) + + +def test_fno_2d(): + """Test the FNO layer with 2D data (4D tensor input)""" + hidden_layers = [ + {'class': 'FNO', 'filters': 8, 'sparsity_threshold': 0.01, + 'activation': 'relu'}] + layers = HiddenLayers(hidden_layers) + assert len(layers.layers) == 1 + + x = np.random.normal(0, 1, size=(1, 4, 4, 3)) + + for layer in layers: + x_in = x + x = layer(x) + with pytest.raises(tf.errors.InvalidArgumentError): + tf.assert_equal(x_in, x) + + +def test_fno_3d(): + """Test the FNO layer with 3D data (5D tensor input)""" + hidden_layers = [ + {'class': 'FNO', 'filters': 8, 'sparsity_threshold': 0.01, + 'activation': 'relu'}] + layers = HiddenLayers(hidden_layers) + assert len(layers.layers) == 1 + + x = np.random.normal(0, 1, size=(1, 4, 4, 6, 3)) + + for layer in layers: + x_in = x + x = layer(x) + with pytest.raises(tf.errors.InvalidArgumentError): + tf.assert_equal(x_in, x)