sync #98

Merged · 15 commits · May 31, 2024
41 changes: 41 additions & 0 deletions Dockerfile
@@ -0,0 +1,41 @@
FROM pytorch/pytorch:2.1.0-cuda11.8-cudnn8-runtime

USER root

ARG DEBIAN_FRONTEND=noninteractive

LABEL github_repo="https://github.com/ddlBoJack/SLAM-LLM"

RUN set -x \
&& apt-get update \
&& apt-get -y install wget curl man git less openssl libssl-dev unzip unar build-essential aria2 tmux vim ninja-build \
&& apt-get install -y openssh-server sox libsox-fmt-all libsox-fmt-mp3 libsndfile1-dev ffmpeg \
&& rm -rf /var/lib/apt/lists/* \
&& apt-get clean

RUN pip install --no-cache-dir packaging editdistance gpustat wandb einops debugpy tqdm soundfile matplotlib scipy sentencepiece pandas \
&& pip install --no-cache-dir torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu118

WORKDIR /workspace

RUN git clone https://github.com/huggingface/transformers.git \
&& cd transformers \
&& git checkout tags/v4.35.2 \
&& pip install --no-cache-dir -e .

RUN git clone https://github.com/huggingface/peft.git \
&& cd peft \
&& git checkout tags/v0.6.0 \
&& pip install --no-cache-dir -e .

RUN git clone https://github.com/pytorch/fairseq \
&& cd fairseq \
&& pip install --no-cache-dir --editable ./

RUN git clone https://github.com/ddlBoJack/SLAM-LLM.git \
&& cd SLAM-LLM \
&& pip install --no-cache-dir -e .

ENV SHELL=/bin/bash

WORKDIR /workspace/SLAM-LLM
21 changes: 21 additions & 0 deletions LICENSE
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2024 Ziyang Ma

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
14 changes: 11 additions & 3 deletions README.md
@@ -44,22 +44,30 @@ pip install -e .
cd ..
git clone https://github.com/huggingface/peft.git
cd peft
-git checkout tags/0.6.0
+git checkout tags/v0.6.0
pip install -e .
cd ..
pip install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 --index-url https://download.pytorch.org/whl/cu118
-git clone git@github.com:ddlBoJack/SLAM-LLM.git
+git clone https://github.com/ddlBoJack/SLAM-LLM.git
cd SLAM-LLM
pip install -e .
```

For some examples you may also need `fairseq`; install it as follows:
```shell
# fairseq must be installed before SLAM-LLM
git clone https://github.com/pytorch/fairseq
cd fairseq
pip install --editable ./
```
We also provide a docker image for convenience:
```shell
# build docker image
docker build -t slam-llm:latest .

# run docker image with gpu
docker run -it --gpus all --name slam --shm-size=256g slam-llm:latest /bin/bash
```
# Usage
## List of Recipes
We provide reference implementations of various LLM-based speech, audio, and music tasks:
@@ -89,4 +97,4 @@ command-line (shell file) > Hydra configuration (yaml file) > dataclass configuration
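To make this priority concrete, here is a small illustrative sketch using OmegaConf, the configuration library underlying Hydra. The `llm_name` field mirrors the dataclass shown later in this PR, but the yaml and command-line override values are hypothetical, not the repo's actual configs:

```python
# Illustrative sketch only (not SLAM-LLM code): dataclass defaults are the
# weakest layer, a Hydra yaml overrides them, and the command line wins.
from dataclasses import dataclass
from omegaconf import OmegaConf

@dataclass
class ModelConfig:
    llm_name: str = "vicuna-13b-v1.5"  # dataclass default (lowest priority)

base = OmegaConf.structured(ModelConfig)                       # dataclass layer
yaml_layer = OmegaConf.create({"llm_name": "vicuna-7b-v1.5"})  # yaml layer (hypothetical value)
cli_layer = OmegaConf.from_dotlist(["llm_name=llama-2-7b"])    # command line (hypothetical value)

cfg = OmegaConf.merge(base, yaml_layer, cli_layer)  # later layers override earlier ones
print(cfg.llm_name)  # -> llama-2-7b
```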
# Acknowledge
- We borrow code from [Llama-Recipes](https://github.com/meta-llama/llama-recipes) for the training process.
- We borrow code from [Fairseq](https://github.com/facebookresearch/fairseq) for deepspeed configuration.
- We thank the contributors for providing diverse recipes.
2 changes: 1 addition & 1 deletion examples/mc_musiccaps/mir_config.py
@@ -2,7 +2,7 @@
from typing import Optional, List
@dataclass
class ModelConfig:
file: str = "examples/music_caption/model/slam_model_mir.py:model_factory"
file: str = "examples/mc_musiccaps/model/slam_model_mir.py:model_factory"
llm_name: str = "vicuna-13b-v1.5"
llm_path: str = "PATH/to/LLAMA/7B"
llm_type: str = "decoder_only"
@@ -6,7 +6,7 @@ export TOKENIZERS_PARALLELISM=false

run_dir=$PWD
cd $run_dir
-code_dir=examples/music_caption
+code_dir=examples/mc_musiccaps


music_encoder_path=path/to/pretrained/musicfm/pretrained_msd.pt
@@ -11,7 +11,7 @@ export OMP_NUM_THREADS=1

run_dir=$PWD
cd $run_dir
-code_dir=examples/music_caption
+code_dir=examples/mc_musiccaps

music_encoder_path=path/to/pretrained/musicfm/pretrained_msd.pt
music_encoder_stat_path=path/to/pretrained/musicfm/msd_stats.json
2 changes: 2 additions & 0 deletions src/slam_llm/models/musicfm/model/__init__.py
@@ -0,0 +1,2 @@


253 changes: 253 additions & 0 deletions src/slam_llm/models/musicfm/model/musicfm_25hz.py
@@ -0,0 +1,253 @@
# MIT License
#
# Copyright 2023 ByteDance Inc.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”),
# to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
# IN THE SOFTWARE.

import json
import random
import torch
from torch import nn
from einops import rearrange

from ..modules.random_quantizer import RandomProjectionQuantizer
from ..modules.features import MelSTFT
from ..modules.conv import Conv2dSubsampling


class MusicFM25Hz(nn.Module):
"""
MusicFM

Input: 128-band mel spectrogram
Frontend: 2-layer Residual convolution
Backend: 12-layer Conformer
Quantizer: a codebook for mel spectrogram
"""

def __init__(
self,
num_codebooks=1,
codebook_dim=16,
codebook_size=4096,
features=["melspec_2048"],
hop_length=240,
n_mels=128,
conv_dim=512,
encoder_dim=1024,
encoder_depth=12,
mask_hop=0.4,
mask_prob=0.6,
is_flash=False,
stat_path="./data/fma_stats.json",
model_path="./data/pretrained_fma.pt",
w2v2_config_path="facebook/wav2vec2-conformer-rope-large-960h-ft",
):
super(MusicFM25Hz, self).__init__()

# global variables
self.hop_length = hop_length
self.mask_hop = mask_hop
self.mask_prob = mask_prob
self.num_codebooks = num_codebooks
self.codebook_size = codebook_size
self.features = features

# load feature mean / std stats
with open(stat_path, "r") as f:
self.stat = json.load(f)

# feature extractor
self.preprocessor_melspec_2048 = MelSTFT(
n_fft=2048, hop_length=hop_length, is_db=True
)

# random quantizer
seed = 142
for feature in self.features:
for i in range(num_codebooks):
setattr(
self,
f"quantizer_{feature}_{i}",
RandomProjectionQuantizer(
n_mels * 4, codebook_dim, codebook_size, seed=seed + i
),
)

# two residual convolution layers + one projection layer
self.conv = Conv2dSubsampling(
1, conv_dim, encoder_dim, strides=[2, 2], n_bands=n_mels
)

# Conformer
        if is_flash:
            # package-relative import, consistent with the module imports above
            from ..modules.flash_conformer import (
                Wav2Vec2ConformerEncoder,
                Wav2Vec2ConformerConfig,
            )
else:
from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import (
Wav2Vec2ConformerEncoder,
Wav2Vec2ConformerConfig,
)
        config = Wav2Vec2ConformerConfig.from_pretrained(w2v2_config_path)
config.num_hidden_layers = encoder_depth
config.hidden_size = encoder_dim

self.conformer = Wav2Vec2ConformerEncoder(config)

# projection
self.linear = nn.Linear(encoder_dim, codebook_size)

# loss function
self.loss = nn.CrossEntropyLoss()

# cls token (used for sequence classification)
random.seed(seed)
self.cls_token = nn.Parameter(torch.randn(encoder_dim))

        # load pretrained weights; checkpoint keys carry a "model." prefix
        if model_path:
            S = torch.load(model_path)["state_dict"]
            SS = {k[6:]: v for k, v in S.items()}  # strip the 6-char "model." prefix
            self.load_state_dict(SS, strict=True)

def masking(self, x):
"""random masking of 400ms with given probability"""
mx = x.clone()
b, t = mx.shape
        # mask span lengths at the 24 kHz input rate: raw samples, and encoder
        # tokens after the mel hop plus two stride-2 conv layers (4x reduction)
        len_masking_raw = int(24000 * self.mask_hop)
        len_masking_token = int(24000 / self.hop_length / 2 / 2 * self.mask_hop)

# get random mask indices
start_indices = torch.rand(b, t // len_masking_raw) < self.mask_prob
time_domain_masked_indices = torch.nonzero(
start_indices.repeat_interleave(len_masking_raw, dim=1)
)
token_domain_masked_indices = torch.nonzero(
start_indices.repeat_interleave(len_masking_token, dim=1)
)

# mask with random values
masking_noise = (
torch.randn(time_domain_masked_indices.shape[0], dtype=x.dtype) * 0.1
) # 0 mean 0.1 std
mx[tuple(time_domain_masked_indices.t())] = masking_noise.to(x.device)

return mx, token_domain_masked_indices

@torch.no_grad()
def preprocessing(self, x, features):
"""extract classic audio features"""
# check precision
if x.dtype == torch.float16:
precision = 16
else:
precision = 32

out = {}
for key in features:
layer = getattr(self, "preprocessor_%s" % key)
out[key] = layer.float()(x.float())[..., :-1]
if precision == 16:
out[key] = out[key].half()
return out

def encoder(self, x):
"""2-layer conv + w2v-conformer"""
x = self.conv(x)
out = self.conformer(x, output_hidden_states=True)
hidden_emb = out["hidden_states"]
last_emb = out["last_hidden_state"]
logits = self.linear(last_emb)
logits = {
key: logits[:, :, i * self.codebook_size : (i + 1) * self.codebook_size]
for i, key in enumerate(self.features)
}
return logits, hidden_emb

@torch.no_grad()
def normalize(self, x):
"""normalize the input audio to have zero mean unit variance"""
for key in x.keys():
x[key] = (x[key] - self.stat["%s_mean" % key]) / self.stat["%s_std" % key]
return x

@torch.no_grad()
def rearrange(self, x):
"""rearrange the batch to flatten every 4 steps"""
for key in x.keys():
if key == "chromagram":
x[key] = rearrange(x[key], "b f t -> b t f")
else:
x[key] = rearrange(x[key], "b f (t s) -> b t (s f)", s=4)
return x

@torch.no_grad()
def tokenize(self, x):
out = {}
for key in x.keys():
layer = getattr(self, "quantizer_%s" % key)
out[key] = layer(x[key])
return out

def get_targets(self, x):
x = self.preprocessing(x, features=self.features)
x = self.normalize(x)
x = self.rearrange(x)
target_tokens = self.tokenize(x)
return target_tokens

def get_predictions(self, x):
# preprocessing
x = self.preprocessing(x, features=["melspec_2048"])
x = self.normalize(x)

# encoding
logits, hidden_emb = self.encoder(x["melspec_2048"])

return logits, hidden_emb

def get_latent(self, x, layer_ix=12):
_, hidden_states = self.get_predictions(x)
emb = hidden_states[layer_ix]
return emb

def get_loss(self, logits, target_tokens, masked_indices):
losses = {}
accuracies = {}
for key in logits.keys():
masked_logits = logits[key][tuple(masked_indices.t())]
masked_tokens = target_tokens[key][tuple(masked_indices.t())]
losses[key] = self.loss(masked_logits, masked_tokens)
accuracies[key] = (
torch.sum(masked_logits.argmax(-1) == masked_tokens)
/ masked_tokens.numel()
)
return losses, accuracies

def forward(self, x):
# get target feature tokens
target_tokens = self.get_targets(x)

# masking
x, masked_indices = self.masking(x)

# forward
logits, hidden_emb = self.get_predictions(x)

# get loss
losses, accuracies = self.get_loss(logits, target_tokens, masked_indices)

return logits, hidden_emb, losses, accuracies
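For orientation, a minimal usage sketch of `MusicFM25Hz` as vendored above. This is a sketch under assumptions: the `slam_llm` package is installed via `pip install -e .` as in the Dockerfile, the stat/checkpoint paths are the placeholder defaults from `__init__`, and the input is 24 kHz mono audio:

```python
# Minimal sketch, not part of the PR: extract 25 Hz latents from raw audio.
import torch
from slam_llm.models.musicfm.model.musicfm_25hz import MusicFM25Hz  # path assumed from this PR's layout

model = MusicFM25Hz(
    stat_path="./data/fma_stats.json",      # placeholder: feature mean/std stats
    model_path="./data/pretrained_fma.pt",  # placeholder: pretrained FMA checkpoint
)
model.eval()

wav = torch.randn(1, 24000 * 10)  # one 10-second clip at the model's 24 kHz rate
with torch.no_grad():
    emb = model.get_latent(wav, layer_ix=12)  # hidden states after the last conformer layer
print(emb.shape)  # roughly (1, ~250 tokens at 25 Hz, encoder_dim=1024)
```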
2 changes: 2 additions & 0 deletions src/slam_llm/models/musicfm/modules/__init__.py
@@ -0,0 +1,2 @@

