Contrastive learning of general-purpose audio representations
PiperOrigin-RevId: 338221596
lienz authored and copybara-github committed Oct 21, 2020
1 parent 1a96a0c commit 23dc105
Showing 11 changed files with 860 additions and 0 deletions.
1 change: 1 addition & 0 deletions .travis.yml
@@ -21,6 +21,7 @@ env:
- PROJECT="cell_mixer"
- PROJECT="cfq"
- PROJECT="cnn_quantization"
- PROJECT="cola"
- PROJECT="dataset_analysis"
- PROJECT="dense_representations_for_entity_retrieval"
- PROJECT="depth_and_motion_learning"
56 changes: 56 additions & 0 deletions cola/README.md
@@ -0,0 +1,56 @@
# Contrastive learning of general-purpose audio representations

![Overview](./images/overview_cola.png)

This Python library allows pre-training and fine-tuning contrastive audio embeddings with the COLA method.
In particular, one can:

* Pre-train COLA embeddings with a simple contrastive learning objective
* Train a linear classifier on pre-trained embeddings
* Train a supervised neural network from scratch
* Initialize a classifier with pre-trained COLA embeddings and fine-tune on a new dataset

## Dependencies
* [TensorFlow](https://www.tensorflow.org/)
* [TensorFlow Datasets](https://www.tensorflow.org/datasets/overview)

## Quickstart
Training has three modes:

* `SSL`, to pre-train a model with self-supervised contrastive learning
* `DS`, to fine-tune a pre-trained model on a downstream task
* `SUP`, to train a simple supervised system

Pre-train a COLA embedding on a dataset from tensorflow_datasets (here, LibriSpeech):

```bash
python -m main --experiment_id=cola_pretrain --model_dir=/tmp/cola \
--training_mode=SSL --ssl_dataset=LBS --strategy=gpu
```
Note that no labels are required at this stage.
After pre-training, the model is saved in `/tmp/cola/librispeech/cola_pretrain`.
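
The saved weights can also be restored outside the provided training modes with standard TensorFlow checkpoint utilities. A minimal sketch, where the hyperparameter values are illustrative and must match the pre-training run (the `pooling_type` and `similarity_type` values shown are assumptions about the expected types):

```python
import tensorflow as tf

from cola import constants
from cola import network

# Illustrative hyperparameters; these must match the pre-training run.
model = network.get_contrastive_network(
    embedding_dim=512,
    temperature=0.2,
    pooling_type="max",
    similarity_type=constants.SimilarityMeasure.BILINEAR)

# Checkpoints are written as `ckpt_{epoch}`; pick up the most recent one.
latest = tf.train.latest_checkpoint("/tmp/cola/librispeech/cola_pretrain")
model.load_weights(latest)
```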

One can then train a linear classifier on top of these embeddings, for instance on the Speech Commands dataset, in a supervised fashion:

```bash
python -m main --experiment_id=cola_downstream --ssl_checkpoint_id=cola_pretrain \
--model_dir=/tmp/cola --training_mode=DS --ssl_dataset=LBS --ds_dataset=SPCV2 \
--strategy=gpu --freeze_encoder=true
```

The flags `--ssl_checkpoint_id` and `--ssl_dataset` indicate that the pre-trained model is stored in `/tmp/cola/librispeech/cola_pretrain`.

Note the `--freeze_encoder` flag. If set to `False`, the entire network is fine-tuned.

## Advanced usage
For simplicity, pre-training and fine-tuning only handle TFDS datasets. One can use arbitrary datasets by overriding the `get_self_supervised_data` and `get_downstream_dataset` functions in `data.py`, as sketched below.
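
For example, a drop-in replacement for `get_self_supervised_data` only needs to return a `tf.data.Dataset` of float32 waveforms keyed by `"audio"`. A minimal sketch reading local WAV files (the directory path is a placeholder):

```python
import tensorflow as tf


def get_self_supervised_data(dataset=None, shuffle_buffer=1000):
  """Loads mono WAV files from a local directory (hypothetical path)."""
  files = tf.data.Dataset.list_files("/path/to/audio/*.wav")

  def _parse(path):
    # decode_wav returns float32 in [-1, 1], matching the int16
    # normalization applied to TFDS audio in data.py.
    audio, _ = tf.audio.decode_wav(tf.io.read_file(path), desired_channels=1)
    return {"audio": tf.squeeze(audio, axis=-1)}

  return files.shuffle(shuffle_buffer).map(
      _parse, num_parallel_calls=tf.data.experimental.AUTOTUNE)
```
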
## Reference
If you use this repository, please consider citing:

```
@article{saeed2020,
  title={Contrastive learning of general purpose audio representations},
  author={Aaqib Saeed and David Grangier and Neil Zeghidour},
  journal={arXiv preprint arXiv:2010.10915},
  year={2020},
}
```
67 changes: 67 additions & 0 deletions cola/constants.py
@@ -0,0 +1,67 @@
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Defines constants used across module."""

import enum


@enum.unique
class Dataset(enum.Enum):
"""Look up for dataset names."""

LBS = "librispeech"

BSD = "birdsong_detection"

MUSAN = "musan"

AS = "audioset"

TUT = "tut_2018"

SPCV1 = "speech_commands_v1"

SPCV2 = "speech_commands"

NSYNTH_INST = "nsynth_instrument_family"

VOXCELEB = "voxceleb"

VOXFORGE = "voxforge"

CREMA_D = "crema_d"


@enum.unique
class TrainingMode(enum.Enum):
"""Look up for model training modes."""

SSL = "self_supervised"

SUP = "supervised"

RND = "random"

DS = "downstream"


@enum.unique
class SimilarityMeasure(enum.Enum):
"""Look up for similarity measure in contrastive model."""

DOT = "dot_product"

BILINEAR = "bilinear_product"
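
These two measures correspond to the scoring functions compared in the COLA paper: a plain dot product between embeddings versus a learned bilinear form. A minimal sketch of what each computes for a batch of anchor embeddings `x` and positive embeddings `y` (the actual implementation lives in `network.py`; shapes are illustrative):

```python
import tensorflow as tf

x = tf.random.normal([8, 512])  # anchor embeddings (batch, dim)
y = tf.random.normal([8, 512])  # positive embeddings (batch, dim)

# DOT: pairwise dot products between every anchor and every positive.
dot_similarity = tf.matmul(x, y, transpose_b=True)  # shape (8, 8)

# BILINEAR: x^T W y with a learned square weight matrix W.
w = tf.Variable(tf.keras.initializers.GlorotUniform()((512, 512)))
bilinear_similarity = tf.matmul(tf.matmul(x, w), y, transpose_b=True)  # (8, 8)
```
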
133 changes: 133 additions & 0 deletions cola/contrastive.py
@@ -0,0 +1,133 @@
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Self-supervised model for contrastive learning task."""

import os

import tensorflow as tf

from cola import constants
from cola import data
from cola import network


class ContrastiveModel:
"""Provides functionality for self-supervised constrastive learning model."""

def __init__(self,
strategy,
ssl_dataset_name,
ds_dataset_name,
model_path,
experiment_id,
batch_size,
epochs,
learning_rate,
embedding_dim,
temperature,
similarity_type,
pooling_type,
noise,
steps_per_epoch=1000):
"""Initializes a contrastive model object."""

self._strategy = strategy
self._ssl_dataset_name = ssl_dataset_name
self._ds_dataset_name = ds_dataset_name
self._model_path = model_path
self._experiment_id = experiment_id

self._batch_size = batch_size
self._epochs = epochs
self._learning_rate = learning_rate
self._temperature = temperature
self._embedding_dim = embedding_dim
self._similarity_type = similarity_type
self._pooling_type = pooling_type
self._noise = noise

self._steps_per_epoch = steps_per_epoch
self._shuffle_buffer = 1000
self._n_frames = None
self._n_bands = 64
self._n_channels = 1
self._input_shape = (-1, self._n_frames, self._n_bands, self._n_channels)

def _prepare_example(self, example):
"""Creates an example (anchor-positive) for instance discrimination."""
x = tf.math.l2_normalize(example["audio"], epsilon=1e-9)

waveform_a = data.extract_window(x)
mels_a = data.extract_log_mel_spectrogram(waveform_a)
frames_anchors = mels_a[Ellipsis, tf.newaxis]

waveform_p = data.extract_window(x)
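# Add Gaussian noise to the positive window so the pair differs beyond
# the two independent random crops.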
waveform_p = waveform_p + (
self._noise * tf.random.normal(tf.shape(waveform_p)))
mels_p = data.extract_log_mel_spectrogram(waveform_p)
frames_positives = mels_p[Ellipsis, tf.newaxis]

return frames_anchors, frames_positives

def _get_ssl_task_data(self):
"""Prepares a dataset for contrastive self-supervised task."""
ds = data.get_self_supervised_data(self._ssl_dataset_name).repeat()
ds = ds.shuffle(self._shuffle_buffer, reshuffle_each_iteration=True)
ds = ds.map(
self._prepare_example, num_parallel_calls=tf.data.experimental.AUTOTUNE)
ds = ds.batch(self._batch_size, drop_remainder=True)
ds = ds.prefetch(tf.data.experimental.AUTOTUNE)
return ds

def train(self):
"""Trains a self-supervised model for contrastive learning."""

train_dataset = self._get_ssl_task_data()
train_dataset = self._strategy.experimental_distribute_dataset(
train_dataset)

with self._strategy.scope():
contrastive_network = network.get_contrastive_network(
embedding_dim=self._embedding_dim,
temperature=self._temperature,
pooling_type=self._pooling_type,
similarity_type=self._similarity_type)
contrastive_network.compile(
optimizer=tf.keras.optimizers.Adam(self._learning_rate),
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

ssl_model_dir = f"{self._ssl_dataset_name.value}/{self._experiment_id}/"
ckpt_path = os.path.join(self._model_path, ssl_model_dir, "ckpt_{epoch}")
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
filepath=ckpt_path, save_weights_only=True, monitor="loss")

backup_path = os.path.join(self._model_path, ssl_model_dir, "backup")
backandrestore_callback = tf.keras.callbacks.experimental.BackupAndRestore(
backup_dir=backup_path)

log_dir = os.path.join(self._model_path, "log", self._experiment_id)
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir)

contrastive_network.fit(
train_dataset,
epochs=self._epochs,
steps_per_epoch=self._steps_per_epoch,
verbose=2,
callbacks=[
model_checkpoint_callback,
backandrestore_callback,
tensorboard_callback,
])
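
A note on the loss: `SparseCategoricalCrossentropy` works here because the contrastive network outputs, for every anchor in the batch, similarity logits against every positive; the matching positive sits on the diagonal and the positives of other examples act as negatives. A minimal standalone sketch of that objective, assuming dot-product similarity (the real logits come from `network.py`):

```python
import tensorflow as tf

batch, dim = 8, 512
anchors = tf.math.l2_normalize(tf.random.normal([batch, dim]), axis=-1)
positives = tf.math.l2_normalize(tf.random.normal([batch, dim]), axis=-1)

temperature = 0.2  # illustrative value
logits = tf.matmul(anchors, positives, transpose_b=True) / temperature

# Each anchor's true "class" is its own row index in the batch.
labels = tf.range(batch)
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)(
    labels, logits)
```
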
113 changes: 113 additions & 0 deletions cola/data.py
@@ -0,0 +1,113 @@
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Provides helper data related functions."""

import tensorflow as tf
import tensorflow_datasets as tfds

from cola import constants


def get_self_supervised_data(dataset=constants.Dataset.LBS,
shuffle_buffer=1000):
"""Reads TFDS data for self-supervised task."""

def _parse_example(audio, _):
return {"audio": tf.cast(audio, tf.float32) / float(tf.int16.max)}

if dataset == constants.Dataset.LBS:
split = "train_clean360"
else:
split = "train"

ds_train = tfds.load(
dataset.value, split=split, as_supervised=True)
ds_train = ds_train.shuffle(shuffle_buffer, reshuffle_each_iteration=True)
ds_train = ds_train.map(
_parse_example, num_parallel_calls=tf.data.experimental.AUTOTUNE)

return ds_train


def get_downstream_dataset(dataset=constants.Dataset.VOXFORGE,
shuffle_buffer=1000):
"""Reads downstream task data from TFDS."""

def _parse_example(audio, label):
audio = tf.cast(audio, tf.float32) / float(tf.int16.max)
return {"audio": audio, "label": label}

(ds_train, ds_test), ds_info = tfds.load(
dataset.value,
split=["train", "test"],
shuffle_files=True,
as_supervised=True,
with_info=True)

ds_train = ds_train.shuffle(
shuffle_buffer, reshuffle_each_iteration=True).map(
_parse_example, num_parallel_calls=tf.data.experimental.AUTOTUNE)

ds_test = ds_test.shuffle(
shuffle_buffer, reshuffle_each_iteration=True).map(
_parse_example, num_parallel_calls=tf.data.experimental.AUTOTUNE)

return (ds_train, ds_test, ds_info.features["label"].num_classes)


def extract_log_mel_spectrogram(waveform,
sample_rate=16000,
frame_length=400,
frame_step=160,
fft_length=1024,
n_mels=64,
fmin=60.0,
fmax=7800.0):
"""Extract frames of log mel spectrogram from a raw waveform."""

stfts = tf.signal.stft(
waveform,
frame_length=frame_length,
frame_step=frame_step,
fft_length=fft_length)
spectrograms = tf.abs(stfts)

num_spectrogram_bins = stfts.shape[-1]
lower_edge_hertz, upper_edge_hertz, num_mel_bins = fmin, fmax, n_mels
linear_to_mel_weight_matrix = tf.signal.linear_to_mel_weight_matrix(
num_mel_bins, num_spectrogram_bins, sample_rate, lower_edge_hertz,
upper_edge_hertz)
mel_spectrograms = tf.tensordot(spectrograms, linear_to_mel_weight_matrix, 1)
mel_spectrograms.set_shape(spectrograms.shape[:-1].concatenate(
linear_to_mel_weight_matrix.shape[-1:]))

mel_spectrograms = tf.clip_by_value(
mel_spectrograms,
clip_value_min=1e-5,
clip_value_max=1e8)

log_mel_spectrograms = tf.math.log(mel_spectrograms)

return log_mel_spectrograms


def extract_window(waveform, seg_length=16000):
"""Extracts a random segment from a waveform."""
padding = tf.maximum(seg_length - tf.shape(waveform)[0], 0)
left_pad = padding // 2
right_pad = padding - left_pad
padded_waveform = tf.pad(waveform, paddings=[[left_pad, right_pad]])
return tf.image.random_crop(padded_waveform, [seg_length])
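
Chained together, the two helpers above turn an arbitrary-length waveform into the fixed-size log-mel input the model consumes. With the default arguments, a one-second 16 kHz window yields 98 frames of 64 mel bins (a quick check, with a random waveform standing in for real audio):

```python
import tensorflow as tf

from cola import data

waveform = tf.random.normal([22050])    # arbitrary-length waveform
window = data.extract_window(waveform)  # random 1 s crop -> shape (16000,)
log_mel = data.extract_log_mel_spectrogram(window)
print(log_mel.shape)                    # (98, 64)
```
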
Binary file added cola/images/overview_cola.png