forked from google-research/google-research
Commit 23dc105 (1 parent: 1a96a0c)

Contrastive learning of general-purpose audio representations

PiperOrigin-RevId: 338221596

Showing 11 changed files with 860 additions and 0 deletions.
@@ -0,0 +1,56 @@
# Contrastive learning of general-purpose audio representations

This Python library allows pre-training and fine-tuning contrastive embeddings of audio with the COLA method. In particular, one can:

* Pre-train COLA embeddings, which use a simple contrastive learning method
* Train a linear classifier on pre-trained embeddings
* Train a supervised neural network from scratch
* Initialize a classifier with pre-trained COLA embeddings and fine-tune it on a new dataset
## Dependencies
* [TensorFlow](https://www.tensorflow.org/)
* [TensorFlow Datasets](https://www.tensorflow.org/datasets/overview)

## Quickstart
Training has three modes:

* `SSL`, to pre-train a model with self-supervised contrastive learning
* `DS`, to fine-tune a pre-trained model on a downstream task
* `SUP`, to train a simple supervised system

Pre-train a COLA embedding on a dataset from tensorflow_datasets (here, LibriSpeech):

```bash
python -m main --experiment_id=cola_pretrain --model_dir=/tmp/cola \
  --training_mode=SSL --ssl_dataset=LBS --strategy=gpu
```
Note that no labels are needed at this stage. After pre-training, the model is saved in `/tmp/cola/librispeech/cola_pretrain`.

One can then train a linear classifier on these embeddings, on the Speech Commands dataset, in a supervised fashion:

```bash
python -m main --experiment_id=cola_downstream --ssl_checkpoint_id=cola_pretrain \
  --model_dir=/tmp/cola --training_mode=DS --ssl_dataset=LBS --ds_dataset=SPCV2 \
  --strategy=gpu --freeze_encoder=true
```

The flags `--ssl_checkpoint_id` and `--ssl_dataset` indicate that the pre-trained model is stored in `/tmp/cola/librispeech/cola_pretrain`.

Note the `--freeze_encoder` flag: when set to `true`, only a linear classifier is trained on top of the frozen pre-trained encoder; when set to `false`, the entire network is fine-tuned.

## Advanced usage
For simplicity, pre-training and fine-tuning only handle TFDS datasets. One can easily use arbitrary datasets by overriding the `get_self_supervised_data` and `get_downstream_dataset` functions in `data.py`, as sketched below.
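
As a sketch of such an override, a replacement `get_self_supervised_data` only needs to return a `tf.data.Dataset` of `{"audio": ...}` dictionaries holding float32 waveforms, since that is what the pre-training pipeline consumes. A minimal example, assuming a hypothetical local directory of 16 kHz mono WAV files:

```python
import tensorflow as tf


def get_self_supervised_data(shuffle_buffer=1000):
  """Yields {"audio": waveform} dicts from a local directory of WAV files."""
  # Hypothetical data location; any source of 16 kHz mono audio works.
  files = tf.data.Dataset.list_files("/data/my_audio/*.wav")

  def _load_wav(path):
    audio, _ = tf.audio.decode_wav(tf.io.read_file(path), desired_channels=1)
    # Match the library's convention: a rank-1 float32 waveform in [-1, 1].
    return {"audio": tf.squeeze(audio, axis=-1)}

  ds = files.shuffle(shuffle_buffer, reshuffle_each_iteration=True)
  return ds.map(_load_wav, num_parallel_calls=tf.data.experimental.AUTOTUNE)
```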

## Reference
If you use this repository, please consider citing:

```
@article{saeed2020,
  title={Contrastive learning of general purpose audio representations},
  author={Aaqib Saeed and David Grangier and Neil Zeghidour},
  year={2020},
}
```
@@ -0,0 +1,67 @@
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Defines constants used across the module."""

import enum


@enum.unique
class Dataset(enum.Enum):
  """Lookup of dataset names."""

  LBS = "librispeech"
  BSD = "birdsong_detection"
  MUSAN = "musan"
  AS = "audioset"
  TUT = "tut_2018"
  SPCV1 = "speech_commands_v1"
  SPCV2 = "speech_commands"
  NSYNTH_INST = "nsynth_instrument_family"
  VOXCELEB = "voxceleb"
  VOXFORGE = "voxforge"
  CREMA_D = "crema_d"


@enum.unique
class TrainingMode(enum.Enum):
  """Lookup of model training modes."""

  SSL = "self_supervised"
  SUP = "supervised"
  RND = "random"
  DS = "downstream"


@enum.unique
class SimilarityMeasure(enum.Enum):
  """Lookup of similarity measures in the contrastive model."""

  DOT = "dot_product"
  BILINEAR = "bilinear_product"
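
These are plain `enum.Enum` members, so the flag values used in the README (for example `--ssl_dataset=LBS`) presumably resolve to TFDS dataset names by standard name lookup. A quick illustration, using nothing beyond stock `enum` semantics:

```python
from cola import constants

# The flag value "LBS" selects the LibriSpeech TFDS dataset.
assert constants.Dataset["LBS"].value == "librispeech"

# Training modes resolve the same way, e.g. --training_mode=SSL.
assert constants.TrainingMode["SSL"].value == "self_supervised"
```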
@@ -0,0 +1,133 @@
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Self-supervised model for the contrastive learning task."""

import os

import tensorflow as tf

from cola import constants
from cola import data
from cola import network


class ContrastiveModel:
  """Provides functionality for a self-supervised contrastive learning model."""

  def __init__(self,
               strategy,
               ssl_dataset_name,
               ds_dataset_name,
               model_path,
               experiment_id,
               batch_size,
               epochs,
               learning_rate,
               embedding_dim,
               temperature,
               similarity_type,
               pooling_type,
               noise,
               steps_per_epoch=1000):
    """Initializes a contrastive model object."""

    self._strategy = strategy
    self._ssl_dataset_name = ssl_dataset_name
    self._ds_dataset_name = ds_dataset_name
    self._model_path = model_path
    self._experiment_id = experiment_id

    self._batch_size = batch_size
    self._epochs = epochs
    self._learning_rate = learning_rate
    self._temperature = temperature
    self._embedding_dim = embedding_dim
    self._similarity_type = similarity_type
    self._pooling_type = pooling_type
    self._noise = noise

    self._steps_per_epoch = steps_per_epoch
    self._shuffle_buffer = 1000
    self._n_frames = None
    self._n_bands = 64
    self._n_channels = 1
    self._input_shape = (-1, self._n_frames, self._n_bands, self._n_channels)

  def _prepare_example(self, example):
    """Creates an (anchor, positive) example pair for instance discrimination."""
    x = tf.math.l2_normalize(example["audio"], epsilon=1e-9)

    # Anchor: a random window of the clip, as log-mel spectrogram frames.
    waveform_a = data.extract_window(x)
    mels_a = data.extract_log_mel_spectrogram(waveform_a)
    frames_anchors = mels_a[Ellipsis, tf.newaxis]

    # Positive: another random window of the same clip, with Gaussian noise.
    waveform_p = data.extract_window(x)
    waveform_p = waveform_p + (
        self._noise * tf.random.normal(tf.shape(waveform_p)))
    mels_p = data.extract_log_mel_spectrogram(waveform_p)
    frames_positives = mels_p[Ellipsis, tf.newaxis]

    return frames_anchors, frames_positives

  def _get_ssl_task_data(self):
    """Prepares a dataset for the contrastive self-supervised task."""
    ds = data.get_self_supervised_data(self._ssl_dataset_name).repeat()
    ds = ds.shuffle(self._shuffle_buffer, reshuffle_each_iteration=True)
    ds = ds.map(
        self._prepare_example, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    ds = ds.batch(self._batch_size, drop_remainder=True)
    ds = ds.prefetch(tf.data.experimental.AUTOTUNE)
    return ds

  def train(self):
    """Trains a self-supervised model for contrastive learning."""

    train_dataset = self._get_ssl_task_data()
    train_dataset = self._strategy.experimental_distribute_dataset(
        train_dataset)

    with self._strategy.scope():
      contrastive_network = network.get_contrastive_network(
          embedding_dim=self._embedding_dim,
          temperature=self._temperature,
          pooling_type=self._pooling_type,
          similarity_type=self._similarity_type)
      contrastive_network.compile(
          optimizer=tf.keras.optimizers.Adam(self._learning_rate),
          loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
          metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

    ssl_model_dir = f"{self._ssl_dataset_name.value}/{self._experiment_id}/"
    ckpt_path = os.path.join(self._model_path, ssl_model_dir, "ckpt_{epoch}")
    model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=ckpt_path, save_weights_only=True, monitor="loss")

    backup_path = os.path.join(self._model_path, ssl_model_dir, "backup")
    backandrestore_callback = tf.keras.callbacks.experimental.BackupAndRestore(
        backup_dir=backup_path)

    log_dir = os.path.join(self._model_path, "log", self._experiment_id)
    tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir)

    contrastive_network.fit(
        train_dataset,
        epochs=self._epochs,
        steps_per_epoch=self._steps_per_epoch,
        verbose=2,
        callbacks=[
            model_checkpoint_callback,
            backandrestore_callback,
            tensorboard_callback,
        ])
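
For reference, a sketch of driving `ContrastiveModel` directly rather than through `main`. The module name `cola.contrastive` and every hyperparameter value below are illustrative assumptions, not the repository's documented defaults:

```python
import tensorflow as tf

from cola import constants
from cola import contrastive  # assumed module name for this file

strategy = tf.distribute.get_strategy()  # or tf.distribute.MirroredStrategy()
model = contrastive.ContrastiveModel(
    strategy=strategy,
    ssl_dataset_name=constants.Dataset.LBS,
    ds_dataset_name=constants.Dataset.SPCV2,
    model_path="/tmp/cola",
    experiment_id="cola_pretrain",
    batch_size=64,
    epochs=10,
    learning_rate=1e-4,
    embedding_dim=512,
    temperature=0.2,
    similarity_type=constants.SimilarityMeasure.BILINEAR,
    pooling_type="max",  # illustrative; see network.get_contrastive_network
    noise=0.001)
model.train()
```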
@@ -0,0 +1,113 @@
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Provides data-related helper functions."""

import tensorflow as tf
import tensorflow_datasets as tfds

from cola import constants


def get_self_supervised_data(dataset=constants.Dataset.LBS,
                             shuffle_buffer=1000):
  """Reads TFDS data for the self-supervised task."""

  def _parse_example(audio, _):
    return {"audio": tf.cast(audio, tf.float32) / float(tf.int16.max)}

  if dataset == constants.Dataset.LBS:
    split = "train_clean360"
  else:
    split = "train"

  ds_train = tfds.load(
      dataset.value, split=split, as_supervised=True)
  ds_train = ds_train.shuffle(shuffle_buffer, reshuffle_each_iteration=True)
  ds_train = ds_train.map(
      _parse_example, num_parallel_calls=tf.data.experimental.AUTOTUNE)

  return ds_train


def get_downstream_dataset(dataset=constants.Dataset.VOXFORGE,
                           shuffle_buffer=1000):
  """Reads downstream task data from TFDS."""

  def _parse_example(audio, label):
    audio = tf.cast(audio, tf.float32) / float(tf.int16.max)
    return {"audio": audio, "label": label}

  (ds_train, ds_test), ds_info = tfds.load(
      dataset.value,
      split=["train", "test"],
      shuffle_files=True,
      as_supervised=True,
      with_info=True)

  ds_train = ds_train.shuffle(
      shuffle_buffer, reshuffle_each_iteration=True).map(
          _parse_example, num_parallel_calls=tf.data.experimental.AUTOTUNE)

  ds_test = ds_test.shuffle(
      shuffle_buffer, reshuffle_each_iteration=True).map(
          _parse_example, num_parallel_calls=tf.data.experimental.AUTOTUNE)

  return (ds_train, ds_test, ds_info.features["label"].num_classes)


def extract_log_mel_spectrogram(waveform,
                                sample_rate=16000,
                                frame_length=400,
                                frame_step=160,
                                fft_length=1024,
                                n_mels=64,
                                fmin=60.0,
                                fmax=7800.0):
  """Extracts frames of log mel spectrogram from a raw waveform."""

  stfts = tf.signal.stft(
      waveform,
      frame_length=frame_length,
      frame_step=frame_step,
      fft_length=fft_length)
  spectrograms = tf.abs(stfts)

  num_spectrogram_bins = stfts.shape[-1]
  lower_edge_hertz, upper_edge_hertz, num_mel_bins = fmin, fmax, n_mels
  linear_to_mel_weight_matrix = tf.signal.linear_to_mel_weight_matrix(
      num_mel_bins, num_spectrogram_bins, sample_rate, lower_edge_hertz,
      upper_edge_hertz)
  mel_spectrograms = tf.tensordot(spectrograms, linear_to_mel_weight_matrix, 1)
  mel_spectrograms.set_shape(spectrograms.shape[:-1].concatenate(
      linear_to_mel_weight_matrix.shape[-1:]))

  # Clip to avoid log(0); the upper bound is a loose numerical safeguard.
  mel_spectrograms = tf.clip_by_value(
      mel_spectrograms,
      clip_value_min=1e-5,
      clip_value_max=1e8)

  log_mel_spectrograms = tf.math.log(mel_spectrograms)

  return log_mel_spectrograms


def extract_window(waveform, seg_length=16000):
  """Extracts a random segment from a waveform, padding it if too short."""
  padding = tf.maximum(seg_length - tf.shape(waveform)[0], 0)
  left_pad = padding // 2
  right_pad = padding - left_pad
  padded_waveform = tf.pad(waveform, paddings=[[left_pad, right_pad]])
  return tf.image.random_crop(padded_waveform, [seg_length])
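
The two waveform helpers compose exactly as in `_prepare_example` above. A small self-contained check on a dummy waveform; the expected shapes follow from the default STFT parameters (`frame_length=400`, `frame_step=160`):

```python
import tensorflow as tf

from cola import data

waveform = tf.random.normal([22050])  # ~1.4 s of fake 16 kHz audio
window = data.extract_window(waveform)  # random 16000-sample crop
log_mel = data.extract_log_mel_spectrogram(window)
# 1 + (16000 - 400) // 160 = 98 frames, each with 64 log-mel bins.
print(window.shape, log_mel.shape)  # (16000,) and (98, 64)
```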