diff --git a/docs/tutorials/temporal_convolutional_network.ipynb b/docs/tutorials/temporal_convolutional_network.ipynb
new file mode 100644
index 0000000000..52a0c36c41
--- /dev/null
+++ b/docs/tutorials/temporal_convolutional_network.ipynb
@@ -0,0 +1,310 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "##### Copyright 2019 The TensorFlow Authors."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
+ "# you may not use this file except in compliance with the License.\n",
+ "# You may obtain a copy of the License at\n",
+ "#\n",
+ "# https://www.apache.org/licenses/LICENSE-2.0\n",
+ "#\n",
+ "# Unless required by applicable law or agreed to in writing, software\n",
+ "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
+ "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
+ "# See the License for the specific language governing permissions and\n",
+ "# limitations under the License."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# TensorFlow Addons Layers: Temporal Convolutional Network"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Overview\n",
+ "This notebook will demonstrate how to use Temporal Convolutional Network in TensorFlow Addons."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Setup"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from __future__ import absolute_import, division, print_function"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "try:\n",
+ " # %tensorflow_version only exists in Colab.\n",
+ " %tensorflow_version 2.x\n",
+ "except Exception:\n",
+ " pass"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import numpy as np\n",
+ "import tensorflow as tf\n",
+ "import tensorflow_addons as tfa\n",
+ "\n",
+ "import tensorflow.keras as keras\n",
+ "from tensorflow.keras.datasets import imdb\n",
+ "from tensorflow.keras.models import Model\n",
+ "from tensorflow.keras import Input\n",
+ "from tensorflow.keras.layers import Dense, Dropout, Embedding\n",
+ "from tensorflow.keras.preprocessing import sequence"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Global Configurations"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "max_features = 20000\n",
+ "# cut texts after this number of words\n",
+ "# (among top max_features most common words)\n",
+ "maxlen = 100\n",
+ "batch_size = 32"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Import IMDB data"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# the data, split between train and test sets\n",
+ "(x_train, y_train), (x_test, y_test) = imdb.load_data(num_words=max_features)\n",
+ "\n",
+ "# pad the training data\n",
+ "x_train = sequence.pad_sequences(x_train, maxlen=maxlen)\n",
+ "x_test = sequence.pad_sequences(x_test, maxlen=maxlen)\n",
+ "y_train = np.array(y_train)\n",
+ "y_test = np.array(y_test)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Build Simple TCN Model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# build the model using functional API\n",
+ "i = Input(shape=(maxlen,))\n",
+ "x = Embedding(max_features, 128)(i)\n",
+ "x = tfa.layers.TCN()(x)\n",
+ "x = Dropout(0.2)(x)\n",
+ "x = Dense(1, activation='sigmoid')(x)\n",
+ "\n",
+ "model = Model(inputs=[i], outputs=[x])"
+ ]
+ },
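+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "With the default `return_sequences=False`, the `TCN` layer returns only the output at the last timestep, so its output shape is `(None, 64)` for the default 64 filters, as the model summary below shows."
+   ]
+  },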
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Summary of the model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Model: \"model\"\n",
+ "_________________________________________________________________\n",
+ "Layer (type) Output Shape Param # \n",
+ "=================================================================\n",
+ "input_1 (InputLayer) [(None, 100)] 0 \n",
+ "_________________________________________________________________\n",
+ "embedding (Embedding) (None, 100, 128) 2560000 \n",
+ "_________________________________________________________________\n",
+ "tcn (TCN) (None, 64) 148800 \n",
+ "_________________________________________________________________\n",
+ "dropout (Dropout) (None, 64) 0 \n",
+ "_________________________________________________________________\n",
+ "dense (Dense) (None, 1) 65 \n",
+ "=================================================================\n",
+ "Total params: 2,708,865\n",
+ "Trainable params: 2,708,865\n",
+ "Non-trainable params: 0\n",
+ "_________________________________________________________________\n"
+ ]
+ }
+ ],
+ "source": [
+ "model.summary()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Compile the model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model.compile(optimizer='adam',\n",
+ " loss = 'binary_crossentropy',\n",
+ " metrics=['accuracy'])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Fit the model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Train on 25000 samples, validate on 25000 samples\n",
+ "Epoch 1/3\n",
+ "25000/25000 [==============================] - 100s 4ms/sample - loss: 0.5734 - accuracy: 0.6722 - val_loss: 0.4147 - val_accuracy: 0.8041\n",
+ "Epoch 2/3\n",
+ "25000/25000 [==============================] - 95s 4ms/sample - loss: 0.2852 - accuracy: 0.8811 - val_loss: 0.4136 - val_accuracy: 0.8128\n",
+ "Epoch 3/3\n",
+ "25000/25000 [==============================] - 97s 4ms/sample - loss: 0.1372 - accuracy: 0.9485 - val_loss: 0.4949 - val_accuracy: 0.8190\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 11,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "model.fit(x_train, y_train,\n",
+ " batch_size=batch_size,\n",
+ " epochs=3,\n",
+ " validation_data=(x_test, y_test))"
+ ]
+ },
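+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Evaluate the model\n",
+    "As a quick check, we can score the trained model on the held-out test set. This is an illustrative step; the exact numbers will vary from run to run."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "score, acc = model.evaluate(x_test, y_test, batch_size=batch_size)\n",
+    "print('Test loss:', score)\n",
+    "print('Test accuracy:', acc)"
+   ]
+  }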
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.6.2"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/tensorflow_addons/layers/BUILD b/tensorflow_addons/layers/BUILD
index 01f475a2fb..ba31922c5a 100644
--- a/tensorflow_addons/layers/BUILD
+++ b/tensorflow_addons/layers/BUILD
@@ -12,6 +12,7 @@ py_library(
"optical_flow.py",
"poincare.py",
"sparsemax.py",
+ "tcn.py",
"wrappers.py",
],
data = [
@@ -101,3 +102,16 @@ py_test(
":layers",
],
)
+
+py_test(
+ name = "tcn_test",
+ size = "small",
+ srcs = [
+ "tcn_test.py",
+ ],
+ main = "tcn_test.py",
+ srcs_version = "PY2AND3",
+ deps = [
+ ":layers",
+ ],
+)
diff --git a/tensorflow_addons/layers/README.md b/tensorflow_addons/layers/README.md
index 087fcfebbd..04235d8248 100644
--- a/tensorflow_addons/layers/README.md
+++ b/tensorflow_addons/layers/README.md
@@ -9,6 +9,7 @@
| opticalflow | @fsx950223 | fsx950223@gmail.com |
| poincare | @rahulunair | rahulunair@gmail.com |
| sparsemax | @AndreasMadsen | amwwebdk+github@gmail.com |
+| tcn | @shun-lin | shunlin@google.com |
| wrappers | @seanpmorgan | seanmorgan@outlook.com |
## Components
@@ -20,7 +21,8 @@
| normalizations | InstanceNormalization | https://arxiv.org/abs/1607.08022 |
| opticalflow | CorrelationCost | https://arxiv.org/abs/1504.06852 |
| poincare | PoincareNormalize | https://arxiv.org/abs/1705.08039 |
-| sparsemax| Sparsemax | https://arxiv.org/abs/1602.02068 |
+| sparsemax | Sparsemax | https://arxiv.org/abs/1602.02068 |
+| tcn | TCN (Temporal Convolutional Network) | https://arxiv.org/abs/1803.01271 |
| wrappers | WeightNormalization | https://arxiv.org/abs/1602.07868 |
## Contribution Guidelines
diff --git a/tensorflow_addons/layers/__init__.py b/tensorflow_addons/layers/__init__.py
index d527e16362..7b7b85ef76 100644
--- a/tensorflow_addons/layers/__init__.py
+++ b/tensorflow_addons/layers/__init__.py
@@ -14,15 +14,14 @@
# ==============================================================================
"""Additional layers that conform to Keras API."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
+from __future__ import absolute_import, division, print_function
from tensorflow_addons.layers.gelu import GeLU
from tensorflow_addons.layers.maxout import Maxout
-from tensorflow_addons.layers.normalizations import GroupNormalization
-from tensorflow_addons.layers.normalizations import InstanceNormalization
+from tensorflow_addons.layers.normalizations import (GroupNormalization,
+ InstanceNormalization)
from tensorflow_addons.layers.optical_flow import CorrelationCost
from tensorflow_addons.layers.poincare import PoincareNormalize
from tensorflow_addons.layers.sparsemax import Sparsemax
-from tensorflow_addons.layers.wrappers import WeightNormalization
\ No newline at end of file
+from tensorflow_addons.layers.tcn import TCN
+from tensorflow_addons.layers.wrappers import WeightNormalization
diff --git a/tensorflow_addons/layers/tcn.py b/tensorflow_addons/layers/tcn.py
new file mode 100644
index 0000000000..9519b09a56
--- /dev/null
+++ b/tensorflow_addons/layers/tcn.py
@@ -0,0 +1,325 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Implements TCN layer."""
+
+from __future__ import absolute_import, division, print_function
+
+import tensorflow as tf
+
+
+@tf.keras.utils.register_keras_serializable(package='Addons')
+class ResidualBlock(tf.keras.layers.Layer):
+ """Defines the residual block for the WaveNet TCN.
+
+ Arguments:
+        dilation_rate (int): The dilation rate used by the dilated
+            convolution layers in this residual block. Defaults to 1.
+        filters (int): The number of convolutional
+            filters to use in this block. Defaults to 64.
+        kernel_size (int): The size of the convolutional kernel. Defaults
+            to 2.
+        padding (str): The padding used in the convolutional layers,
+            'same' or 'causal'. Defaults to 'same'.
+        activation (str): The final activation used
+            in o = Activation(x + F(x)). Defaults to 'relu'.
+        dropout_rate (float): Fraction of the input units to drop,
+            between 0 and 1. Defaults to 0.0.
+        kernel_initializer (str): Initializer for the kernel weights
+            matrix (Conv1D). Defaults to 'he_normal'.
+        use_batch_norm (bool): Whether to use batch normalization in the
+            residual layers or not. Defaults to False.
+        last_block (bool): Whether or not this block is the last residual
+            block of the network. Defaults to False.
+        kwargs: Any other arguments for configuring the parent Layer class.
+
+    Returns:
+        A residual block.
+ """
+
+ def __init__(self,
+ dilation_rate=1,
+ filters=64,
+ kernel_size=2,
+ padding='same',
+ activation='relu',
+ dropout_rate=0.0,
+ kernel_initializer='he_normal',
+ use_batch_norm=False,
+ last_block=False,
+ **kwargs):
+
+ self.dilation_rate = dilation_rate
+ self.filters = filters
+ self.kernel_size = kernel_size
+ self.padding = padding
+ self.activation = activation
+ self.dropout_rate = dropout_rate
+ self.use_batch_norm = use_batch_norm
+ self.kernel_initializer = kernel_initializer
+ self.last_block = last_block
+
+ super(ResidualBlock, self).__init__(**kwargs)
+
+ self.residual_layers = list()
+
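+        # each residual block stacks two sub-blocks of
+        # dilated conv -> (optional) batch norm -> relu -> spatial dropout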
+ with tf.name_scope(self.name):
+ for k in range(2):
+                name = 'conv1D_{}'.format(k)
+
+ with tf.name_scope(name):
+ conv_layer = tf.keras.layers.Conv1D(
+ filters=self.filters,
+ kernel_size=self.kernel_size,
+ dilation_rate=self.dilation_rate,
+ padding=self.padding,
+ name=name,
+ kernel_initializer=self.kernel_initializer)
+ self.residual_layers.append(conv_layer)
+
+ if self.use_batch_norm:
+ batch_norm_layer = tf.keras.layers.BatchNormalization()
+ self.residual_layers.append(batch_norm_layer)
+
+ self.residual_layers.append(tf.keras.layers.Activation('relu'))
+ self.residual_layers.append(
+ tf.keras.layers.SpatialDropout1D(
+ rate=self.dropout_rate))
+
+ if not self.last_block:
+ # 1x1 conv to match the shapes (channel dimension).
+ name = 'conv1D_{}'.format(k + 1)
+ with tf.name_scope(name):
+ self.shape_match_conv = tf.keras.layers.Conv1D(
+ filters=self.filters,
+ kernel_size=1,
+ padding='same',
+ name=name,
+ kernel_initializer=self.kernel_initializer)
+
+ else:
+ self.shape_match_conv = tf.keras.layers.Lambda(
+ lambda x: x, name='identity')
+
+ self.final_activation = tf.keras.layers.Activation(self.activation)
+
+ def build(self, input_shape):
+
+ # build residual layers
+ self.res_output_shape = input_shape
+ for layer in self.residual_layers:
+ layer.build(self.res_output_shape)
+ self.res_output_shape = layer.compute_output_shape(
+ self.res_output_shape)
+
+ # build shape matching convolutional layer
+ self.shape_match_conv.build(input_shape)
+ self.res_output_shape = self.shape_match_conv.compute_output_shape(
+ input_shape)
+
+ # build final activation layer
+ self.final_activation.build(self.res_output_shape)
+
+ super(ResidualBlock, self).build(input_shape)
+
+ def call(self, inputs, training=None):
+        """Forward pass of the residual block.
+
+        Returns:
+            A list where the first element is the residual output tensor
+            and the second is the skip connection tensor.
+        """
+ x = inputs
+ for layer in self.residual_layers:
+ x = layer(x, training=training)
+
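+        # 1x1 convolution (identity for the last block) so the input
+        # matches the residual branch's channel dimension before the add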
+ x2 = self.shape_match_conv(inputs)
+ res_x = x2 + x
+ return [self.final_activation(res_x), x]
+
+ def compute_output_shape(self, input_shape):
+ return [self.res_output_shape, self.res_output_shape]
+
+ def get_config(self):
+ config = dict()
+
+ config['dilation_rate'] = self.dilation_rate
+ config['filters'] = self.filters
+ config['kernel_size'] = self.kernel_size
+ config['padding'] = self.padding
+ config['activation'] = self.activation
+ config['dropout_rate'] = self.dropout_rate
+ config['use_batch_norm'] = self.use_batch_norm
+ config['kernel_initializer'] = self.kernel_initializer
+ config['last_block'] = self.last_block
+
+ base_config = super(ResidualBlock, self).get_config()
+ return dict(list(base_config.items()) + list(config.items()))
+
+
+@tf.keras.utils.register_keras_serializable(package='Addons')
+class TCN(tf.keras.layers.Layer):
+ """Creates a TCN layer.
+
+ Input shape:
+ A tensor of shape (batch_size, timesteps, input_dim).
+
+ Arguments:
+        filters: The number of filters to use in the convolutional layers.
+            Defaults to 64.
+        kernel_size: The size of the kernel to use in each
+            convolutional layer. Defaults to 2.
+        dilations: An array-like of the dilation rates, one per residual
+            block. Defaults to [1, 2, 4, 8, 16, 32, 64].
+        stacks: The number of stacks of residual blocks to use. Defaults
+            to 1.
+        padding: The padding to use in the convolutional layers,
+            'causal' or 'same'. Defaults to 'causal'.
+        use_skip_connections: Boolean. Whether to add skip connections
+            from each residual block to the layer output.
+            Defaults to True.
+        return_sequences: Boolean. Whether to return the full sequence
+            (when True) or the last output in the output sequence
+            (when False). Defaults to False.
+        activation: The activation used in the residual
+            blocks o = Activation(x + F(x)). Defaults to 'linear'.
+        dropout_rate: Float between 0 and 1. Fraction of the input
+            units to drop. Defaults to 0.0.
+        kernel_initializer: Initializer for the kernel weights
+            matrix (Conv1D). Defaults to 'he_normal'.
+        use_batch_norm: Whether to use batch normalization in the
+            residual layers or not. Defaults to False.
+        kwargs: Any other arguments for configuring the parent Layer
+            class, for example `name`. Use unique names when using
+            multiple TCN layers in one model.
+ Returns:
+ A TCN layer.
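+
+    Example (an illustrative sketch; the input shape is arbitrary):
+
+    >>> inputs = tf.keras.Input(shape=(100, 128))
+    >>> outputs = TCN(filters=64)(inputs)  # output shape: (batch_size, 64)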
+ """
+
+ def __init__(self,
+ filters=64,
+ kernel_size=2,
+ stacks=1,
+ dilations=[1, 2, 4, 8, 16, 32, 64],
+ padding='causal',
+ use_skip_connections=True,
+ dropout_rate=0.0,
+ return_sequences=False,
+ activation='linear',
+ kernel_initializer='he_normal',
+ use_batch_norm=False,
+ **kwargs):
+
+ super(TCN, self).__init__(**kwargs)
+
+ self.return_sequences = return_sequences
+ self.dropout_rate = dropout_rate
+ self.use_skip_connections = use_skip_connections
+ self.stacks = stacks
+ self.kernel_size = kernel_size
+ self.filters = filters
+ self.dilations = dilations
+ self.activation = activation
+ self.padding = padding
+ self.kernel_initializer = kernel_initializer
+ self.use_batch_norm = use_batch_norm
+
+ # validate paddings
+ validate_paddings = ['causal', 'same']
+ if padding not in validate_paddings:
+ raise ValueError(
+                "Only 'causal' or 'same' padding is compatible with "
+                "this layer.")
+
+ self.main_conv1D = tf.keras.layers.Conv1D(
+ filters=self.filters,
+ kernel_size=1,
+ padding=self.padding,
+ kernel_initializer=self.kernel_initializer)
+
+ # list to hold all the member ResidualBlocks
+ self.residual_blocks = list()
+ total_num_blocks = self.stacks * len(self.dilations)
+ if not self.use_skip_connections:
+            # bump the count so that `last_block` below is never True and
+            # every block keeps its 1x1 shape-matching convolution
+            total_num_blocks += 1
+
+ for _ in range(self.stacks):
+ for d in self.dilations:
+ self.residual_blocks.append(
+ ResidualBlock(
+ dilation_rate=d,
+ filters=self.filters,
+ kernel_size=self.kernel_size,
+ padding=self.padding,
+ activation=self.activation,
+ dropout_rate=self.dropout_rate,
+ use_batch_norm=self.use_batch_norm,
+ kernel_initializer=self.kernel_initializer,
+ last_block=len(self.residual_blocks) +
+ 1 == total_num_blocks,
+ name='residual_block_{}'.format(
+ len(self.residual_blocks))))
+
+ if not self.return_sequences:
+            self.last_output_layer = tf.keras.layers.Lambda(
+                lambda x: x[:, -1, :])
+
+ def build(self, input_shape):
+ self.main_conv1D.build(input_shape)
+
+ self.build_output_shape = self.main_conv1D.compute_output_shape(
+ input_shape)
+
+ for residual_block in self.residual_blocks:
+ residual_block.build(self.build_output_shape)
+ self.build_output_shape = residual_block.res_output_shape
+
+ def compute_output_shape(self, input_shape):
+ if not self.built:
+ self.build(input_shape)
+ if not self.return_sequences:
+ return self.last_output_layer.compute_output_shape(
+ self.build_output_shape)
+ else:
+ return self.build_output_shape
+
+ def call(self, inputs, training=None):
+ x = inputs
+ x = self.main_conv1D(x)
+ skip_connections = list()
+ for layer in self.residual_blocks:
+ x, skip_out = layer(x, training=training)
+ skip_connections.append(skip_out)
+
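+        # sum the skip connections from all residual blocks, as in the
+        # WaveNet-style TCN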
+ if self.use_skip_connections:
+ x = tf.keras.layers.add(skip_connections)
+ if not self.return_sequences:
+ x = self.last_output_layer(x)
+ return x
+
+ def get_config(self):
+ config = dict()
+ config['filters'] = self.filters
+ config['kernel_size'] = self.kernel_size
+ config['stacks'] = self.stacks
+ config['dilations'] = self.dilations
+ config['padding'] = self.padding
+ config['use_skip_connections'] = self.use_skip_connections
+ config['dropout_rate'] = self.dropout_rate
+ config['return_sequences'] = self.return_sequences
+ config['activation'] = self.activation
+ config['use_batch_norm'] = self.use_batch_norm
+ config['kernel_initializer'] = self.kernel_initializer
+
+ base_config = super(TCN, self).get_config()
+
+ return dict(list(base_config.items()) + list(config.items()))
diff --git a/tensorflow_addons/layers/tcn_test.py b/tensorflow_addons/layers/tcn_test.py
new file mode 100644
index 0000000000..ee587724e5
--- /dev/null
+++ b/tensorflow_addons/layers/tcn_test.py
@@ -0,0 +1,92 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for TCN layer."""
+
+from __future__ import absolute_import, division, print_function
+
+import tensorflow as tf
+
+from tensorflow_addons.utils import test_utils
+from tensorflow_addons.layers import TCN
+from tensorflow_addons.layers.tcn import ResidualBlock
+
+
+@test_utils.run_all_in_graph_and_eager_modes
+class TCNTest(tf.test.TestCase):
+ def test_tcn(self):
+ test_utils.layer_test(TCN, input_shape=(2, 4, 4))
+
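+    def test_tcn_return_sequences(self):
+        # with return_sequences=True the layer should preserve the time
+        # dimension, so a (2, 4, 4) input yields a (2, 4, filters) output
+        tcn = TCN(filters=8, dilations=[1, 2], return_sequences=True)
+        x = tf.random.uniform((2, 4, 4))
+        y = tcn(x)
+        self.assertAllEqual(y.shape.as_list(), [2, 4, 8])
+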
+ def test_config_tcn(self):
+
+ # test default config
+ tcn = TCN()
+ self.assertEqual(tcn.filters, 64)
+ self.assertEqual(tcn.kernel_size, 2)
+ self.assertEqual(tcn.stacks, 1)
+ self.assertEqual(tcn.dilations, [1, 2, 4, 8, 16, 32, 64])
+ self.assertEqual(tcn.padding, 'causal')
+ self.assertEqual(tcn.use_skip_connections, True)
+ self.assertEqual(tcn.dropout_rate, 0.0)
+ self.assertEqual(tcn.return_sequences, False)
+ self.assertEqual(tcn.activation, 'linear')
+ self.assertEqual(tcn.kernel_initializer, 'he_normal')
+ self.assertEqual(tcn.use_batch_norm, False)
+
+ # Check save and restore config
+ tcn_2 = TCN.from_config(tcn.get_config())
+ self.assertEqual(tcn_2.filters, 64)
+ self.assertEqual(tcn_2.kernel_size, 2)
+ self.assertEqual(tcn_2.stacks, 1)
+ self.assertEqual(tcn_2.dilations, [1, 2, 4, 8, 16, 32, 64])
+ self.assertEqual(tcn_2.padding, 'causal')
+ self.assertEqual(tcn_2.use_skip_connections, True)
+ self.assertEqual(tcn_2.dropout_rate, 0.0)
+ self.assertEqual(tcn_2.return_sequences, False)
+ self.assertEqual(tcn_2.activation, 'linear')
+ self.assertEqual(tcn_2.kernel_initializer, 'he_normal')
+ self.assertEqual(tcn_2.use_batch_norm, False)
+
+ def test_config_residual_block(self):
+
+ # test default config
+ residual_block = ResidualBlock()
+ self.assertEqual(residual_block.dilation_rate, 1)
+ self.assertEqual(residual_block.filters, 64)
+ self.assertEqual(residual_block.kernel_size, 2)
+ self.assertEqual(residual_block.padding, 'same')
+ self.assertEqual(residual_block.activation, 'relu')
+ self.assertEqual(residual_block.dropout_rate, 0.0)
+ self.assertEqual(residual_block.kernel_initializer, 'he_normal')
+ self.assertEqual(residual_block.last_block, False)
+ self.assertEqual(residual_block.use_batch_norm, False)
+
+ # Check save and restore config
+ residual_block_2 = ResidualBlock.from_config(
+ residual_block.get_config())
+ self.assertEqual(residual_block_2.dilation_rate, 1)
+ self.assertEqual(residual_block_2.filters, 64)
+ self.assertEqual(residual_block_2.kernel_size, 2)
+ self.assertEqual(residual_block_2.padding, 'same')
+ self.assertEqual(residual_block_2.activation, 'relu')
+ self.assertEqual(residual_block_2.dropout_rate, 0.0)
+ self.assertEqual(residual_block_2.kernel_initializer, 'he_normal')
+ self.assertEqual(residual_block_2.last_block, False)
+ self.assertEqual(residual_block_2.use_batch_norm, False)
+
+
+if __name__ == "__main__":
+ tf.test.main()