Skip to content

Commit

Permalink
upload musicfm
Browse files Browse the repository at this point in the history
  • Loading branch information
hainazhu committed May 21, 2024
1 parent 66eb87c commit bad097a
Show file tree
Hide file tree
Showing 7 changed files with 2,581 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/slam_llm/models/musicfm/model/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@


253 changes: 253 additions & 0 deletions src/slam_llm/models/musicfm/model/musicfm_25hz.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,253 @@
# MIT License
#
# Copyright 2023 ByteDance Inc.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”),
# to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
# IN THE SOFTWARE.

import json
import random
import torch
from torch import nn
from einops import rearrange

from ..modules.random_quantizer import RandomProjectionQuantizer
from ..modules.features import MelSTFT
from ..modules.conv import Conv2dSubsampling


class MusicFM25Hz(nn.Module):
"""
MusicFM
Input: 128-band mel spectrogram
Frontend: 2-layer Residual convolution
Backend: 12-layer Conformer
Quantizer: a codebook for mel spectrogram
"""

def __init__(
self,
num_codebooks=1,
codebook_dim=16,
codebook_size=4096,
features=["melspec_2048"],
hop_length=240,
n_mels=128,
conv_dim=512,
encoder_dim=1024,
encoder_depth=12,
mask_hop=0.4,
mask_prob=0.6,
is_flash=False,
stat_path="./data/fma_stats.json",
model_path="./data/pretrained_fma.pt",
w2v2_config_path="facebook/wav2vec2-conformer-rope-large-960h-ft",
):
super(MusicFM25Hz, self).__init__()

# global variables
self.hop_length = hop_length
self.mask_hop = mask_hop
self.mask_prob = mask_prob
self.num_codebooks = num_codebooks
self.codebook_size = codebook_size
self.features = features

# load feature mean / std stats
with open(stat_path, "r") as f:
self.stat = json.load(f)

# feature extractor
self.preprocessor_melspec_2048 = MelSTFT(
n_fft=2048, hop_length=hop_length, is_db=True
)

# random quantizer
seed = 142
for feature in self.features:
for i in range(num_codebooks):
setattr(
self,
f"quantizer_{feature}_{i}",
RandomProjectionQuantizer(
n_mels * 4, codebook_dim, codebook_size, seed=seed + i
),
)

# two residual convolution layers + one projection layer
self.conv = Conv2dSubsampling(
1, conv_dim, encoder_dim, strides=[2, 2], n_bands=n_mels
)

# Conformer
if is_flash:
from modules.flash_conformer import (
Wav2Vec2ConformerEncoder,
Wav2Vec2ConformerConfig,
)
else:
from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import (
Wav2Vec2ConformerEncoder,
Wav2Vec2ConformerConfig,
)
config = Wav2Vec2ConformerConfig.from_pretrained(
w2v2_config_path
)
config.num_hidden_layers = encoder_depth
config.hidden_size = encoder_dim

self.conformer = Wav2Vec2ConformerEncoder(config)

# projection
self.linear = nn.Linear(encoder_dim, codebook_size)

# loss function
self.loss = nn.CrossEntropyLoss()

# cls token (used for sequence classification)
random.seed(seed)
self.cls_token = nn.Parameter(torch.randn(encoder_dim))

# load model
if model_path:
S = torch.load(model_path)["state_dict"]
SS = {k[6:]: v for k, v in S.items()}
self.load_state_dict(SS, strict=True)

def masking(self, x):
"""random masking of 400ms with given probability"""
mx = x.clone()
b, t = mx.shape
len_masking_raw = int(24000 * self.mask_hop)
len_masking_token = int(24000 / self.hop_length / 2 / 2 * self.mask_hop)

# get random mask indices
start_indices = torch.rand(b, t // len_masking_raw) < self.mask_prob
time_domain_masked_indices = torch.nonzero(
start_indices.repeat_interleave(len_masking_raw, dim=1)
)
token_domain_masked_indices = torch.nonzero(
start_indices.repeat_interleave(len_masking_token, dim=1)
)

# mask with random values
masking_noise = (
torch.randn(time_domain_masked_indices.shape[0], dtype=x.dtype) * 0.1
) # 0 mean 0.1 std
mx[tuple(time_domain_masked_indices.t())] = masking_noise.to(x.device)

return mx, token_domain_masked_indices

@torch.no_grad()
def preprocessing(self, x, features):
"""extract classic audio features"""
# check precision
if x.dtype == torch.float16:
precision = 16
else:
precision = 32

out = {}
for key in features:
layer = getattr(self, "preprocessor_%s" % key)
out[key] = layer.float()(x.float())[..., :-1]
if precision == 16:
out[key] = out[key].half()
return out

def encoder(self, x):
"""2-layer conv + w2v-conformer"""
x = self.conv(x)
out = self.conformer(x, output_hidden_states=True)
hidden_emb = out["hidden_states"]
last_emb = out["last_hidden_state"]
logits = self.linear(last_emb)
logits = {
key: logits[:, :, i * self.codebook_size : (i + 1) * self.codebook_size]
for i, key in enumerate(self.features)
}
return logits, hidden_emb

@torch.no_grad()
def normalize(self, x):
"""normalize the input audio to have zero mean unit variance"""
for key in x.keys():
x[key] = (x[key] - self.stat["%s_mean" % key]) / self.stat["%s_std" % key]
return x

@torch.no_grad()
def rearrange(self, x):
"""rearrange the batch to flatten every 4 steps"""
for key in x.keys():
if key == "chromagram":
x[key] = rearrange(x[key], "b f t -> b t f")
else:
x[key] = rearrange(x[key], "b f (t s) -> b t (s f)", s=4)
return x

@torch.no_grad()
def tokenize(self, x):
out = {}
for key in x.keys():
layer = getattr(self, "quantizer_%s" % key)
out[key] = layer(x[key])
return out

def get_targets(self, x):
x = self.preprocessing(x, features=self.features)
x = self.normalize(x)
x = self.rearrange(x)
target_tokens = self.tokenize(x)
return target_tokens

def get_predictions(self, x):
# preprocessing
x = self.preprocessing(x, features=["melspec_2048"])
x = self.normalize(x)

# encoding
logits, hidden_emb = self.encoder(x["melspec_2048"])

return logits, hidden_emb

def get_latent(self, x, layer_ix=12):
_, hidden_states = self.get_predictions(x)
emb = hidden_states[layer_ix]
return emb

def get_loss(self, logits, target_tokens, masked_indices):
losses = {}
accuracies = {}
for key in logits.keys():
masked_logits = logits[key][tuple(masked_indices.t())]
masked_tokens = target_tokens[key][tuple(masked_indices.t())]
losses[key] = self.loss(masked_logits, masked_tokens)
accuracies[key] = (
torch.sum(masked_logits.argmax(-1) == masked_tokens)
/ masked_tokens.numel()
)
return losses, accuracies

def forward(self, x):
# get target feature tokens
target_tokens = self.get_targets(x)

# masking
x, masked_indices = self.masking(x)

# forward
logits, hidden_emb = self.get_predictions(x)

# get loss
losses, accuracies = self.get_loss(logits, target_tokens, masked_indices)

return logits, hidden_emb, losses, accuracies
2 changes: 2 additions & 0 deletions src/slam_llm/models/musicfm/modules/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@


82 changes: 82 additions & 0 deletions src/slam_llm/models/musicfm/modules/conv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# MIT License
#
# Copyright 2023 ByteDance Inc.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”),
# to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
# IN THE SOFTWARE.

from torch import nn
from einops import rearrange


class Res2dModule(nn.Module):
def __init__(self, idim, odim, stride=(2, 2)):
super(Res2dModule, self).__init__()
self.conv1 = nn.Conv2d(idim, odim, 3, padding=1, stride=stride)
self.bn1 = nn.BatchNorm2d(odim)
self.conv2 = nn.Conv2d(odim, odim, 3, padding=1)
self.bn2 = nn.BatchNorm2d(odim)
self.relu = nn.ReLU()

# residual
self.diff = False
if (idim != odim) or (stride[0] > 1):
self.conv3 = nn.Conv2d(idim, odim, 3, padding=1, stride=stride)
self.bn3 = nn.BatchNorm2d(odim)
self.diff = True

def forward(self, x):
out = self.bn2(self.conv2(self.relu(self.bn1(self.conv1(x)))))
if self.diff:
x = self.bn3(self.conv3(x))
out = x + out
out = self.relu(out)
return out


class Conv2dSubsampling(nn.Module):
"""Convolutional 2D subsampling (to 1/4 length).
Args:
idim (int): Input dimension.
hdim (int): Hidden dimension.
odim (int): Output dimension.
strides (list): Sizes of strides.
n_bands (int): Number of frequency bands.
"""

def __init__(self, idim, hdim, odim, strides=[2, 2], n_bands=64):
"""Construct an Conv2dSubsampling object."""
super(Conv2dSubsampling, self).__init__()

self.conv = nn.Sequential(
Res2dModule(idim, hdim, (2, strides[0])),
Res2dModule(hdim, hdim, (2, strides[1])),
)
self.linear = nn.Linear(hdim * n_bands // 2 // 2, odim)

def forward(self, x):
"""Subsample x.
Args:
x (torch.Tensor): Input tensor (#batch, idim, time).
Returns:
torch.Tensor: Subsampled tensor (#batch, time', odim),
where time' = time // 4.
"""

if x.dim() == 3:
x = x.unsqueeze(1) # (b, c, f, t)
x = self.conv(x)
x = rearrange(x, "b c f t -> b t (c f)")
x = self.linear(x)
return x
45 changes: 45 additions & 0 deletions src/slam_llm/models/musicfm/modules/features.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# MIT License
#
# Copyright 2023 ByteDance Inc.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”),
# to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
# IN THE SOFTWARE.

import torchaudio
from torch import nn


class MelSTFT(nn.Module):
def __init__(
self,
sample_rate=24000,
n_fft=2048,
hop_length=240,
n_mels=128,
is_db=False,
):
super(MelSTFT, self).__init__()

# spectrogram
self.mel_stft = torchaudio.transforms.MelSpectrogram(
sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels
)

# amplitude to decibel
self.is_db = is_db
if is_db:
self.amplitude_to_db = torchaudio.transforms.AmplitudeToDB()

def forward(self, waveform):
if self.is_db:
return self.amplitude_to_db(self.mel_stft(waveform))
else:
return self.mel_stft(waveform)
Loading

0 comments on commit bad097a

Please sign in to comment.