Skip to content

Commit

Permalink
added transposes to fft the correct dims
Browse files Browse the repository at this point in the history
  • Loading branch information
bnb32 committed Nov 18, 2023
1 parent 1c791fa commit 1df0d28
Showing 1 changed file with 24 additions and 7 deletions.
31 changes: 24 additions & 7 deletions phygnn/layers/custom_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,17 +629,31 @@ def __init__(self, filters, sparsity_threshold=0.5, activation='relu'):
self._activation = activation
self._n_channels = None
self._dense_units = None
self._sparsity_threshold = sparsity_threshold
self._lambd = sparsity_threshold

def _softshrink(self, x, lambd=0.5):
def _softshrink(self, x):
"""Softshrink activation function
https://pytorch.org/docs/stable/generated/torch.nn.Softshrink.html
"""
values_below_lower = tf.where(x < -lambd, x + lambd, 0)
values_above_upper = tf.where(lambd < x, x - lambd, 0)
values_below_lower = tf.where(x < -self._lambd, x + self._lambd, 0)
values_above_upper = tf.where(self._lambd < x, x - self._lambd, 0)
return values_below_lower + values_above_upper

def _fft(self, x):
"""Apply needed transpositions and fft operation."""
x = tf.transpose(x, perm=self._perms_in)
x = self._fft_layer(tf.cast(x, tf.complex64))
x = tf.transpose(x, perm=self._perms_out)
return x

def _ifft(self, x):
"""Apply needed transpositions and ifft operation."""
x = tf.transpose(x, perm=self._perms_in)
x = self._ifft_layer(tf.cast(x, tf.complex64))
x = tf.transpose(x, perm=self._perms_out)
return x

def build(self, input_shape):
"""Build the FNO layer based on an input shape
Expand All @@ -649,6 +663,9 @@ def build(self, input_shape):
Shape tuple of the input tensor
"""
self._n_channels = input_shape[-1]
dims = list(range(len(input_shape)))
self._perms_in = [dims[-1], *dims[:-1]]
self._perms_out = [*dims[1:], dims[0]]

if len(input_shape) == 4:
self._fft_layer = tf.signal.fft2d
Expand Down Expand Up @@ -682,11 +699,11 @@ def call(self, x):
tensor.
"""
t_in = x
x = self._fft_layer(tf.cast(x, tf.complex64))
x = self._fft(x)
for layer in self._mlp_layers:
x = layer(x)
x = self._softshrink(x, lambd=self._sparsity_threshold)
x = self._ifft_layer(tf.cast(x, tf.complex64))
x = self._softshrink(x)
x = self._ifft(x)
x = tf.cast(x, dtype=t_in.dtype)

return x + t_in
Expand Down

0 comments on commit 1df0d28

Please sign in to comment.