WaveTF is a TensorFlow library that implements 1D and 2D wavelet transforms and exposes them as Keras layers, so they can easily be plugged into machine learning workflows.
WaveTF can also be used outside machine learning contexts as a parallel wavelet computation tool, running on CPUs, GPUs or Google Cloud TPUs and transparently supporting both 32- and 64-bit floats at runtime.
It accepts batched, multichannel inputs, e.g., inputs of shape [batch_size, dim_x, dim_y, channels] in the 2D case. Currently, the Haar and Daubechies-N=2 wavelet kernels are supported, and the input signal is extended via anti-symmetric-reflect padding, which preserves its first-order finite difference at the border.
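To make the padding rule concrete, here is a minimal NumPy sketch of anti-symmetric-reflect extension of a 1D signal. It assumes the usual definition, in which each padded sample is an interior sample mirrored through the border value (on the left, x[-k] = 2*x[0] - x[k]); this is what NumPy's `mode="reflect", reflect_type="odd"` computes and appears to correspond to PyWavelets' `antireflect` mode. The helper name `antireflect_pad` is hypothetical, and this is an illustration of the rule, not WaveTF's own padding code.

```python
import numpy as np

def antireflect_pad(x: np.ndarray, n: int) -> np.ndarray:
    """Hypothetical helper: extend a 1D signal by n samples on each side.

    Each new sample mirrors an interior sample through the border value,
    e.g. on the left: x[-k] = 2 * x[0] - x[k].
    """
    return np.pad(x, n, mode="reflect", reflect_type="odd")

x = np.array([1.0, 2.0, 4.0, 7.0])
p = antireflect_pad(x, 2)  # [-2, 0, 1, 2, 4, 7, 10, 12]

# The first-order finite difference across each border equals the one just inside it.
left_ok = np.isclose(p[2] - p[1], p[3] - p[2])
right_ok = np.isclose(p[-2] - p[-3], p[-3] - p[-4])
print(p, left_ok, right_ok)
```

Unlike plain (even) reflection, which flattens the signal's slope at the border, the odd reflection continues the local trend, which is why the first-order difference is preserved.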
Install with pip:
$ pip3 install .
WaveTF requires TensorFlow 2 to be installed. If you want to run the tests you will also need pytest, numpy and PyWavelets.
API documentation for the latest WaveTF version is available via ReadTheDocs.
Alternatively, it can be generated locally via Sphinx.
To install Sphinx:
$ pip3 install sphinx sphinx_rtd_theme
To generate the HTML documentation (which will be available at docs/build/html/index.html):
$ make -C docs/ html
An article describing WaveTF's implementation and performance in detail was presented at the CADL workshop at ICPR 2020 and is available either via the Springer website or the CRS4 publications repository (direct link to PDF).
@InProceedings{wavetf,
author="Versaci, Francesco",
title="WaveTF: A Fast 2D Wavelet Transform for Machine Learning in Keras",
booktitle="Pattern Recognition. ICPR International Workshops and Challenges",
year="2021",
publisher="Springer International Publishing",
pages="605--618",
isbn="978-3-030-68763-2"
}
WaveTF exposes a single class, WaveTFFactory, which builds Keras layers implementing the Haar and Daubechies-N=2 wavelet transforms and anti-transforms. Its use is straightforward:
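For intuition about what a transform/anti-transform pair computes, the following NumPy sketch implements one level of the textbook orthonormal 1D Haar transform along the last axis, together with its inverse, assuming an even-length input. The function names are hypothetical and this is not WaveTF's implementation; in particular, the way WaveTF lays out approximation and detail coefficients in its output tensors may differ.

```python
import math
import numpy as np

SQRT2 = math.sqrt(2.0)  # Python float, so NumPy keeps the input dtype

def haar_1d(x: np.ndarray):
    """One level of the orthonormal Haar transform along the last axis.

    Returns (approx, detail), each half the length of x.
    Assumes the last dimension of x has even length.
    """
    even, odd = x[..., 0::2], x[..., 1::2]
    approx = (even + odd) / SQRT2  # low-pass: scaled pairwise sums
    detail = (even - odd) / SQRT2  # high-pass: scaled pairwise differences
    return approx, detail

def inverse_haar_1d(approx: np.ndarray, detail: np.ndarray) -> np.ndarray:
    """Invert haar_1d by rebuilding and interleaving the even and odd samples."""
    out = np.empty(approx.shape[:-1] + (2 * approx.shape[-1],), dtype=approx.dtype)
    out[..., 0::2] = (approx + detail) / SQRT2
    out[..., 1::2] = (approx - detail) / SQRT2
    return out

x = np.array([4.0, 2.0, 5.0, 5.0])
a, d = haar_1d(x)  # a = [6, 10] / sqrt(2), d = [2, 0] / sqrt(2)
x_rec = inverse_haar_1d(a, d)
print(np.max(np.abs(x_rec - x)))
```

Because the transform is orthonormal, the inverse is simply the transpose of the forward step, so reconstruction is exact up to floating-point rounding.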
import tensorflow as tf
from wavetf import WaveTFFactory
# input tensor: a batch of 32 three-channel 300x200 signals
t0 = tf.random.uniform([32, 300, 200, 3])
# forward 2D Daubechies-2 transform
w = WaveTFFactory().build('db2', dim=2)
t1 = w.call(t0)
# inverse 2D transform (anti-transform)
w_i = WaveTFFactory().build('db2', dim=2, inverse=True)
t2 = w_i.call(t1)
# maximum absolute reconstruction error
delta = tf.math.abs(t2 - t0)
print(f'Precision error: {tf.math.reduce_max(delta)}')
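A 2D wavelet transform of this kind is separable: a 1D transform is applied along dim_x and then along dim_y. The NumPy sketch below illustrates this with the Haar kernel on a [batch_size, dim_x, dim_y, channels] array, assuming even spatial sizes, and stacks the four resulting sub-bands along the channel axis, giving shape [batch_size, dim_x/2, dim_y/2, 4*channels]. That channel layout and all function names are assumptions for illustration and may not match WaveTF's output. The loop mimics the precision check above in both float32 and float64.

```python
import math
import numpy as np

S = 1.0 / math.sqrt(2.0)  # Python float, so NumPy keeps the input dtype

def haar_axis(x, axis):
    """Orthonormal Haar step along a non-negative `axis`; returns (low, high)."""
    even = np.take(x, np.arange(0, x.shape[axis], 2), axis=axis)
    odd = np.take(x, np.arange(1, x.shape[axis], 2), axis=axis)
    return (even + odd) * S, (even - odd) * S

def ihaar_axis(low, high, axis):
    """Inverse of haar_axis: rebuild even/odd samples and interleave them."""
    even, odd = (low + high) * S, (low - high) * S
    stacked = np.stack([even, odd], axis=axis + 1)  # (even_i, odd_i) side by side
    shape = list(low.shape)
    shape[axis] *= 2
    return stacked.reshape(shape)  # merging the two axes interleaves samples

def haar_2d(x):
    """[b, nx, ny, c] -> [b, nx/2, ny/2, 4c], sub-bands stacked on the channel axis."""
    lo, hi = haar_axis(x, 1)   # along dim_x
    ll, lh = haar_axis(lo, 2)  # along dim_y
    hl, hh = haar_axis(hi, 2)
    return np.concatenate([ll, lh, hl, hh], axis=3)

def ihaar_2d(y):
    """Inverse of haar_2d."""
    ll, lh, hl, hh = np.split(y, 4, axis=3)
    lo = ihaar_axis(ll, lh, 2)
    hi = ihaar_axis(hl, hh, 2)
    return ihaar_axis(lo, hi, 1)

rng = np.random.default_rng(0)
for dtype in (np.float32, np.float64):
    t0 = rng.uniform(size=(8, 30, 20, 3)).astype(dtype)
    t1 = haar_2d(t0)
    t2 = ihaar_2d(t1)
    print(dtype.__name__, t1.shape, t1.dtype, np.max(np.abs(t2 - t0)))
```

The float32 round trip should show an error on the order of single-precision rounding, while float64 should be many orders of magnitude smaller.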
Some basic examples, including a simple wavelet-enriched Convolutional Neural Network (CNN), are available in the examples directory.
WaveTF is developed by:
- Francesco Versaci, CRS4
WaveTF is licensed under the Apache License, Version 2.0. See LICENSE for further details.