Skip to content

Commit

Permalink
fno layer
Browse files Browse the repository at this point in the history
  • Loading branch information
bnb32 committed Nov 18, 2023
1 parent 42c914a commit 6a4f6f4
Showing 1 changed file with 95 additions and 0 deletions.
95 changes: 95 additions & 0 deletions phygnn/layers/custom_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,6 +597,101 @@ def call(self, x):
return x


class FNO(tf.keras.layers.Layer):
"""Custom layer for fourier neural operator block
Note that this is only set up to take a channels-last input
References
----------
1. FourCastNet: A Global Data-driven High-resolution Weather Model using
Adaptive Fourier Neural Operators. http://arxiv.org/abs/2202.11214
"""

def __init__(self, ratio=16, sparsity_threshold=0.5):
"""
Parameters
----------
ratio : int
Number of channels/filters divided by the number of
dense connections in the FNO block.
sparsity_threshold : float
Parameter to control sparsity and shrinkage in the softshrink
activation function.
"""

super().__init__()
self._ratio = ratio
self.fft_layer = None
self.ifft_layer = None
self.mlp_layers = None
self._n_channels = None
self._dense_units = None
self.sparsity_threshold = sparsity_threshold

def softshrink(self, x, lambd=0.5):
"""Softshrink activation function
https://pytorch.org/docs/stable/generated/torch.nn.Softshrink.html
"""
x = tf.convert_to_tensor(x)
values_below_lower = tf.where(x < -lambd, x + lambd, 0)
values_above_upper = tf.where(lambd < x, x - lambd, 0)
return values_below_lower + values_above_upper

def build(self, input_shape):
"""Build the FNO layer based on an input shape
Parameters
----------
input_shape : tuple
Shape tuple of the input tensor
"""
self._n_channels = input_shape[-1]
self._dense_units = int(np.ceil(self._n_channels / self._ratio))

if len(input_shape) == 4:
self.fft_layer = tf.signal.fft2d
self.ifft_layer = tf.signal.ifft2d
elif len(input_shape) == 5:
self.fft_layer = tf.signal.fft3d
self.ifft_layer = tf.signal.ifft3d
else:
msg = ('FourierNeuralOperator layer can only accept 4D or 5D data '
'for image or video input but received input shape: {}'
.format(input_shape))
logger.error(msg)
raise RuntimeError(msg)

self.mlp_layers = [
tf.keras.layers.Dense(self._dense_units, activation='relu'),
tf.keras.layers.Dense(self._n_channels)]

def call(self, x):
"""Call the custom FourierNeuralOperator layer
Parameters
----------
x : tf.Tensor
Input tensor.
Returns
-------
x : tf.Tensor
Output tensor, this is the FNO weights added to the original input
tensor.
"""

t_in = x
x = self.fft_layer(x)
for layer in self.mlp_layers:
x = layer(x)
x = self.softshrink(x, lambd=self.sparsity_threshold)
x = self.ifft_layer(x)

return x + t_in


class Sup3rAdder(tf.keras.layers.Layer):
"""Layer to add high-resolution data to a sup3r model in the middle of a
super resolution forward pass."""
Expand Down

0 comments on commit 6a4f6f4

Please sign in to comment.