From 1a06f2c0219e96b843414b0de4085d7e02e73573 Mon Sep 17 00:00:00 2001 From: hainazhu Date: Sun, 19 May 2024 21:48:49 +0800 Subject: [PATCH] music caption --- examples/music_caption/README.md | 31 ++ examples/music_caption/conf/ds_config.json | 19 ++ examples/music_caption/conf/prompt.yaml | 3 + .../music_caption/deepspeed_finetune_mir.py | 47 +++ examples/music_caption/finetune_mir.py | 45 +++ examples/music_caption/inference_mir_batch.py | 53 +++ examples/music_caption/mir_config.py | 130 +++++++ .../music_caption/model/slam_model_mir.py | 156 +++++++++ .../decode_musicfm_linear_vicuna_7b_10s.sh | 67 ++++ .../finetune_musicfm_linear_vicuna_7b_10s.sh | 84 +++++ src/slam_llm/datasets/mir_dataset.py | 320 ++++++++++++++++++ src/slam_llm/models/encoder.py | 23 +- src/slam_llm/models/slam_model.py | 5 + 13 files changed, 982 insertions(+), 1 deletion(-) create mode 100644 examples/music_caption/README.md create mode 100644 examples/music_caption/conf/ds_config.json create mode 100644 examples/music_caption/conf/prompt.yaml create mode 100644 examples/music_caption/deepspeed_finetune_mir.py create mode 100644 examples/music_caption/finetune_mir.py create mode 100644 examples/music_caption/inference_mir_batch.py create mode 100644 examples/music_caption/mir_config.py create mode 100644 examples/music_caption/model/slam_model_mir.py create mode 100644 examples/music_caption/scripts/decode_musicfm_linear_vicuna_7b_10s.sh create mode 100644 examples/music_caption/scripts/finetune_musicfm_linear_vicuna_7b_10s.sh create mode 100644 src/slam_llm/datasets/mir_dataset.py diff --git a/examples/music_caption/README.md b/examples/music_caption/README.md new file mode 100644 index 00000000..f18d26ae --- /dev/null +++ b/examples/music_caption/README.md @@ -0,0 +1,31 @@ +# Music Caption + +## Performance and checkpoints +Here is a recipe for music captioning, using MusicFM as encoder. We only train the linear projector. For more about MusicFM and its checkpoints, please refer to [this repository](https://github.com/minzwon/musicfm). + +The following results are obtained by training on the LP-MusicCaps-MC training set and evaluating on the LP-MusicCaps-MC test set. +Encoder | Projector | LLM | BLEU-1 | METEOR | SPICE | SPIDER +|---|---|---|---|---|---|--- +[MusicFM(pretrained with MSD)](https://huggingface.co/minzwon/MusicFM/resolve/main/pretrained_msd.pt) | [Linear](https://drive.google.com/file/d/1-9pob6QvJRoq5Dy-LZbiDfF6Q7QRO8Au/view?usp=sharing)(~18.88M) | [vicuna-7b-v1.5](https://huggingface.co/lmsys/vicuna-7b-v1.5) | 25.6 | 10.0 | 8.7 | 6.9 + + +## Data preparation +You need to prepare the data jsonl in this format. Note that you may need to pre-extract the sample rate and duration of audio files for better loading efficiency. +``` +{"key": "[-0Gj8-vB1q4]-[30-40]", "source": "path/to/MusicCaps/wav/[-0Gj8-vB1q4]-[30-40].wav", "target": "The low quality recording features a ballad song that contains sustained strings, mellow piano melody and soft female vocal singing over it. It sounds sad and soulful, like something you would hear at Sunday services.", "duration": 10.0, "sample_rate": 48000} +... +{"key": "[-0vPFx-wRRI]-[30-40]", "source": "path/to/MusicCaps/wav/[-0vPFx-wRRI]-[30-40].wav", "target": "a male voice is singing a melody with changing tempos while snipping his fingers rhythmically. The recording sounds like it has been recorded in an empty room. This song may be playing, practicing snipping and singing along.", "duration": 10.0, "sample_rate": 48000} +``` + +## Decode with checkpoints +``` +bash decode_musicfm_linear_vicuna_7b_10s.sh +``` +Modify the path including `music_encoder_path`, `music_encoder_stat_path`, `music_encoder_config_path`(if specified), `ckpt_path`, `val_data_path` and `decode_log` in the script when you run the shell script. + +## Train a new model + +### Use MusicFM as encoder for music modality. +``` +finetune_musicfm_linear_vicuna_7b_10s.sh +``` diff --git a/examples/music_caption/conf/ds_config.json b/examples/music_caption/conf/ds_config.json new file mode 100644 index 00000000..7ea70e4a --- /dev/null +++ b/examples/music_caption/conf/ds_config.json @@ -0,0 +1,19 @@ +{ + "train_micro_batch_size_per_gpu": 4, + "gradient_accumulation_steps": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-4 + } + }, + "fp16": { + "enabled": true + }, + "zero_optimization": { + "stage": 3, + "offload_optimizer": { + "device": "cpu" + } + } +} \ No newline at end of file diff --git a/examples/music_caption/conf/prompt.yaml b/examples/music_caption/conf/prompt.yaml new file mode 100644 index 00000000..8d134a53 --- /dev/null +++ b/examples/music_caption/conf/prompt.yaml @@ -0,0 +1,3 @@ +dataset_config: + # we put prompt here, because the hydra override in shell script only support a small subset of chars + prompt: "Describe this music." diff --git a/examples/music_caption/deepspeed_finetune_mir.py b/examples/music_caption/deepspeed_finetune_mir.py new file mode 100644 index 00000000..0edd3637 --- /dev/null +++ b/examples/music_caption/deepspeed_finetune_mir.py @@ -0,0 +1,47 @@ +from slam_llm.pipeline.finetune_deepspeed import main as train +from slam_llm.utils.deepspeed_utils import deepspeed_main_wrapper + +import logging +from dataclasses import dataclass, field +from omegaconf import DictConfig, ListConfig, OmegaConf +from asr_config import ModelConfig, TrainConfig, DataConfig, LogConfig + + +@dataclass +class RunConfig: + dataset_config: DataConfig = field(default_factory=DataConfig) + model_config: ModelConfig = field(default_factory=ModelConfig) + train_config: TrainConfig = field(default_factory=TrainConfig) + log_config: LogConfig = field(default_factory=LogConfig) + debug: bool = field(default=False, metadata={"help": "Use pdb when true"}) + metric: str = field(default="acc", metadata={"help": "The metric for evaluation"}) + deepspeed_config: str = field(default="examples/asr_librispeech/conf/ds_config.json", metadata={"help": "The metric for evaluation"}) + + +@deepspeed_main_wrapper(config_name=None, version_base=None) +def main_hydra(cfg: DictConfig): + run_config = RunConfig() + cfg = OmegaConf.merge(run_config, cfg) + def to_plain_list(cfg_item): + if isinstance(cfg_item, ListConfig): + return OmegaConf.to_container(cfg_item, resolve=True) + elif isinstance(cfg_item, DictConfig): + return {k: to_plain_list(v) for k, v in cfg_item.items()} + else: + return cfg_item + + # kwargs = to_plain_list(cfg) + kwargs = cfg + log_level = getattr(logging, kwargs.get("log_level", "INFO").upper()) + + logging.basicConfig(level=log_level) + + if kwargs.get("debug", False): + import pdb; + pdb.set_trace() + + train(kwargs) + + +if __name__ == "__main__": + main_hydra() diff --git a/examples/music_caption/finetune_mir.py b/examples/music_caption/finetune_mir.py new file mode 100644 index 00000000..8077302a --- /dev/null +++ b/examples/music_caption/finetune_mir.py @@ -0,0 +1,45 @@ +from slam_llm.pipeline.finetune import main as train + +import hydra +import logging +from dataclasses import dataclass, field +from omegaconf import DictConfig, ListConfig, OmegaConf +from mir_config import ModelConfig, TrainConfig, DataConfig, LogConfig, FSDPConfig + +@dataclass +class RunConfig: + dataset_config: DataConfig = field(default_factory=DataConfig) + model_config: ModelConfig = field(default_factory=ModelConfig) + train_config: TrainConfig = field(default_factory=TrainConfig) + log_config: LogConfig = field(default_factory=LogConfig) + fsdp_config: FSDPConfig = field(default_factory=FSDPConfig) + debug: bool = field(default=False, metadata={"help": "Use pdb when true"}) + metric: str = field(default="acc", metadata={"help": "The metric for evaluation"}) + +@hydra.main(config_name=None, version_base=None) +def main_hydra(cfg: DictConfig): + run_config = RunConfig() + cfg = OmegaConf.merge(run_config, cfg) + def to_plain_list(cfg_item): + if isinstance(cfg_item, ListConfig): + return OmegaConf.to_container(cfg_item, resolve=True) + elif isinstance(cfg_item, DictConfig): + return {k: to_plain_list(v) for k, v in cfg_item.items()} + else: + return cfg_item + + # kwargs = to_plain_list(cfg) + kwargs = cfg + log_level = getattr(logging, kwargs.get("log_level", "INFO").upper()) + + logging.basicConfig(level=log_level) + + if kwargs.get("debug", False): + import pdb; + pdb.set_trace() + + train(kwargs) + + +if __name__ == "__main__": + main_hydra() \ No newline at end of file diff --git a/examples/music_caption/inference_mir_batch.py b/examples/music_caption/inference_mir_batch.py new file mode 100644 index 00000000..2df0ef26 --- /dev/null +++ b/examples/music_caption/inference_mir_batch.py @@ -0,0 +1,53 @@ +from slam_llm.pipeline.inference_batch import main as inference + +import hydra +import logging +from dataclasses import dataclass, field +from omegaconf import DictConfig, ListConfig, OmegaConf +from typing import Optional +from mir_config import ModelConfig, TrainConfig, DataConfig, LogConfig, FSDPConfig + + +@dataclass +class RunConfig: + dataset_config: DataConfig = field(default_factory=DataConfig) + model_config: ModelConfig = field(default_factory=ModelConfig) + train_config: TrainConfig = field(default_factory=TrainConfig) + log_config: LogConfig = field(default_factory=LogConfig) + fsdp_config: FSDPConfig = field(default_factory=FSDPConfig) + debug: bool = field(default=False, metadata={"help": "Use pdb when true"}) + metric: str = field(default="acc", metadata={"help": "The metric for evaluation"}) + decode_log: str = field( + default="output/decode_log", + metadata={"help": "The prefix for the decode output"}, + ) + ckpt_path: str = field( + default="output/model.pt", metadata={"help": "The path to projector checkpoint"} + ) + peft_ckpt: Optional[str] = field( + default=None, + metadata={ + "help": "The path to peft checkpoint, should be a directory including adapter_config.json" + }, + ) + + +@hydra.main(config_name=None, version_base=None) +def main_hydra(cfg: DictConfig): + run_config = RunConfig() + cfg = OmegaConf.merge(run_config, cfg) + # kwargs = to_plain_list(cfg) + log_level = getattr(logging, cfg.get("log_level", "INFO").upper()) + + logging.basicConfig(level=log_level) + + if cfg.get("debug", False): + import pdb + + pdb.set_trace() + + inference(cfg) + + +if __name__ == "__main__": + main_hydra() diff --git a/examples/music_caption/mir_config.py b/examples/music_caption/mir_config.py new file mode 100644 index 00000000..5d2e242e --- /dev/null +++ b/examples/music_caption/mir_config.py @@ -0,0 +1,130 @@ +from dataclasses import dataclass, field +from typing import Optional, List +@dataclass +class ModelConfig: + file: str = "examples/music_caption/model/slam_model_mir.py:model_factory" + llm_name: str = "vicuna-13b-v1.5" + llm_path: str = "PATH/to/LLAMA/7B" + llm_type: str = "decoder_only" + llm_dim: int = 4096 + encoder_name: Optional[str] = "mulan" + encoder_ds_rate: int = 2 + encoder_path: Optional[str] = None + encoder_config_path: Optional[str] = None + encoder_stat_path: Optional[str] = None + encoder_layer_idx: Optional[int] = None + encoder_dim: int = 768 + encoder_projector: str = "linear" + encoder_projector_ds_rate: int = 5 + modal: str = "audio" + normalize: Optional[bool] = field(default=False, metadata={ + "help": "whether inpit is normalized, used for models such as wavlm" + }) + +@dataclass +class PeftConfig: + peft_method: str = "lora" # None , llama_adapter, prefix + r: int = 8 + lora_alpha: int = 32 + target_modules: List = field(default_factory=lambda: [ "q_proj", "v_proj" ]) + bias: str = "none" + task_type: str = "CAUSAL_LM" + lora_dropout: float = 0.05 + inference_mode: bool = False + +@dataclass +class TrainConfig: + model_name:str = "PATH/to/LLAMA/7B" + enable_ddp:bool = False + enable_deepspeed:bool = False + enable_fsdp:bool = False + low_cpu_fsdp:bool = False + run_validation:bool = True + batch_size_training:int = 4 + batching_strategy:str = field(default="packing", metadata={ + "help":"alternative: padding" + }) # + context_length:int = 4096 + gradient_accumulation_steps:int = 1 + num_epochs:int = 3 + num_workers_dataloader:int = 1 + warmup_steps:int = 1000 + total_steps:int = 100000 + validation_interval:int = 1000 + lr:float = 1e-4 + weight_decay:float = 0.0 + gamma:float = 0.85 + seed:int = 42 + use_fp16:bool = False + mixed_precision:bool = True + val_batch_size:int = 1 + + use_peft:bool = False + peft_config:PeftConfig = field(default_factory=PeftConfig) + output_dir:str = "PATH/to/save/PEFT/model" + freeze_layers:bool = False + num_freeze_layers:int = 1 + quantization:bool = False + one_gpu:bool = False + save_model:bool = True + dist_checkpoint_root_folder:str = "PATH/to/save/FSDP/model" # will be used if using FSDP + dist_checkpoint_folder:str = "fine-tuned" # will be used if using FSDP + save_optimizer:bool = False # will be used if using FSDP + use_fast_kernels:bool = False # Enable using SDPA from PyTroch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels + run_test_during_validation:bool = False + run_test_during_validation_file:str = "test.wav" + run_test_during_validation_prompt:str = "<|ASR|>" + freeze_llm:bool = field(default=False, metadata={ + "help": "whether to freeze llm when finetuning, should be true when use peft finetuning" + }) + freeze_encoder:bool = True # False + +@dataclass +class DataConfig: + dataset: str = "mir_dataset" + file: str = "src/slam_llm/datasets/mir_dataset.py:get_mir_dataset" + train_data_path: Optional[str] = None + val_data_path: Optional[str] = None + fixed_duration: float = 10.0 + audio_label_freq: int = 75 + fixed_audio_token_num: Optional[int] = None + sample_rate: int = 24000 + train_split: str = "train" + test_split:str = "validation" + prompt: Optional[str] = None + data_path: Optional[str] = None + max_words: Optional[int] = None + max_mel: Optional[float] = None + fix_length_audio: int = -1 + inference_mode:bool = False + input_type: str = field(default="raw", metadata={ + "help":"Use raw when input is wav, mel when for whisper" + }) + mel_size: int = field(default=80, metadata={ + "help": "80 for whisper large v1 and v2, 128 for v3" + }) + normalize: Optional[bool] = field(default=False, metadata={ + "help": "whether inpit is normalized, used for models such as wavlm" + }) + +@dataclass +class FSDPConfig: + mixed_precision: bool = True + use_fp16: bool = False + # sharding_strategy = "FULL_SHARD" #ShardingStrategy = ShardingStrategy.FULL_SHARD + sharding_strategy: str = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP + checkpoint_type: str = "SHARDED_STATE_DICT" # alternatively can use SHARDED_STATE_DICT save one file per rank, and can resize the world-size. + fsdp_activation_checkpointing: bool = True + fsdp_cpu_offload: bool = False + pure_bf16: bool = False + optimizer: str = "AdamW" + +@dataclass +class LogConfig: + use_wandb: bool = False + wandb_dir: str = "/root/test_wandb" + wandb_entity_name: str = "project_name" + wandb_project_name: str = "project_name" + wandb_exp_name: str = "exp_name" + log_file: str = "/root/test.log" + log_interval: int = 5 diff --git a/examples/music_caption/model/slam_model_mir.py b/examples/music_caption/model/slam_model_mir.py new file mode 100644 index 00000000..699a4b19 --- /dev/null +++ b/examples/music_caption/model/slam_model_mir.py @@ -0,0 +1,156 @@ +import torch +import os +import logging +from slam_llm.models.slam_model import ( + slam_model, + setup_tokenizer, + setup_encoder, + setup_encoder_projector, + setup_llm, +) +from slam_llm.utils.train_utils import print_model_size + +logger = logging.getLogger(__name__) + +def model_factory(train_config, model_config, **kwargs): + # return necessary components for training + tokenizer = setup_tokenizer(train_config, model_config, **kwargs) + + encoder = setup_encoder(train_config, model_config, **kwargs) + + # llm + llm = setup_llm(train_config, model_config, **kwargs) + + # projector + encoder_projector = setup_encoder_projector( + train_config, model_config, **kwargs + ) + model = slam_model_mir( + encoder, + llm, + encoder_projector, + tokenizer, + train_config, + model_config, + **kwargs, + ) + + ckpt_path = kwargs.get( + "ckpt_path", None + ) # FIX(MZY): load model ckpt(mainly projector, related to model_checkpointing/checkpoint_handler.py: save_model_checkpoint_peft) + + if ckpt_path is not None: + logger.info("loading other parts from: {}".format(ckpt_path)) + ckpt_dict = torch.load(ckpt_path, map_location="cpu") + model.load_state_dict(ckpt_dict, strict=False) + + print_model_size( + model, + train_config, + ( + int(os.environ["RANK"]) + if train_config.enable_fsdp or train_config.enable_ddp + else 0 + ), + ) + return model, tokenizer + + +class slam_model_mir(slam_model): + def __init__( + self, + encoder, + llm, + encoder_projector, + tokenizer, + train_config, + model_config, + **kwargs, + ): + super().__init__( + encoder, + llm, + encoder_projector, + tokenizer, + train_config, + model_config, + **kwargs, + ) + + + @torch.no_grad() + def inference( + self, + wav_path=None, + prompt=None, + generation_config=None, + logits_processor=None, + stopping_criteria=None, + prefix_allowed_tokens_fn=None, + synced_gpus=None, + assistant_model=None, + streamer=None, + negative_prompt_ids=None, + negative_prompt_attention_mask=None, + **kwargs, + ): + # inference for mir model + + device = kwargs.get("device", "cuda") + if os.path.exists(wav_path): # Audio-Text QA + import whisper + + audio_raw = whisper.load_audio(wav_path) + audio_raw = whisper.pad_or_trim(audio_raw) + + mel_size = getattr( + self.dataset_config, "mel_size", 80 + ) # 80 for large v1 and v2, 128 for large v3 + audio_mel = ( + whisper.log_mel_spectrogram(audio_raw, n_mels=mel_size) + .permute(1, 0)[None, :, :] + .to(device) + ) + + encoder_outs = self.encoder.extract_variable_length_features( + audio_mel.permute(0, 2, 1) + ) + + if self.model_config.encoder_projector == "q-former": + audio_mel_post_mask = torch.ones( + encoder_outs.size()[:-1], dtype=torch.long + ).to(encoder_outs.device) + encoder_outs = self.encoder_projector(encoder_outs, audio_mel_post_mask) + if self.model_config.encoder_projector == "linear": + encoder_outs = self.encoder_projector(encoder_outs) + else: # Text QA + encoder_outs = torch.empty( + 1, 0, self.llm.model.embed_tokens.embedding_dim + ).to(device) + + prompt = "USER: {}\n ASSISTANT:".format(prompt) + prompt_ids = self.tokenizer.encode(prompt) + prompt_length = len(prompt_ids) + prompt_ids = torch.tensor(prompt_ids, dtype=torch.int64).to(device) + + if hasattr(self.llm.model, "embed_tokens"): + inputs_embeds = self.llm.model.embed_tokens(prompt_ids) + elif hasattr(self.llm.model.model, "embed_tokens"): + inputs_embeds = self.llm.model.model.embed_tokens(prompt_ids) + else: + inputs_embeds = self.llm.model.model.model.embed_tokens(prompt_ids) + + inputs_embeds = torch.cat( + (encoder_outs, inputs_embeds[None, :, :]), dim=1 + ) # [audio,prompt] + + attention_mask = torch.ones(inputs_embeds.size()[:-1], dtype=torch.long).to( + inputs_embeds.device + ) + + # generate + model_outputs = self.generate( + inputs_embeds=inputs_embeds, attention_mask=attention_mask, **kwargs + ) + + return model_outputs diff --git a/examples/music_caption/scripts/decode_musicfm_linear_vicuna_7b_10s.sh b/examples/music_caption/scripts/decode_musicfm_linear_vicuna_7b_10s.sh new file mode 100644 index 00000000..a6253bfb --- /dev/null +++ b/examples/music_caption/scripts/decode_musicfm_linear_vicuna_7b_10s.sh @@ -0,0 +1,67 @@ +#!/bin/bash +#export PYTHONPATH=/root/whisper:$PYTHONPATH +export CUDA_VISIBLE_DEVICES=0 +export TOKENIZERS_PARALLELISM=false +# export CUDA_LAUNCH_BLOCKING=1 + +run_dir=$PWD +cd $run_dir +code_dir=examples/music_caption + + +music_encoder_path=path/to/pretrained/musicfm/pretrained_msd.pt +music_encoder_stat_path=path/to/pretrained/musicfm/msd_stats.json +music_encoder_config_path=facebook/wav2vec2-conformer-rope-large-960h-ft + +llm_path=lmsys/vicuna-7b-v1.5 + + +output_dir=/root/cq7_haina/save/music-caption/musicfm_vicuna7b_mc_10s_20240513_15:07:28 +ckpt_path=$output_dir/mir_epoch_3_step_900 + + +split=LP-MusicCaps-MC.test.exist +val_data_path=/root/cq7_haina/data/LP-MusicCaps-MC/${split}.jsonl +decode_log=$ckpt_path/decode_${split}_avg + + +python $code_dir/inference_mir_batch.py \ + --config-path "conf" \ + --config-name "prompt.yaml" \ + hydra.run.dir=$ckpt_path \ + ++model_config.llm_name=vicuna-7b-v1.5 \ + ++model_config.llm_dim=4096 \ + ++model_config.llm_path=$llm_path \ + ++model_config.encoder_name=musicfm \ + ++dataset_config.normalize=false \ + ++model_config.encoder_layer_idx=9 \ + ++model_config.encoder_projector_ds_rate=5 \ + ++model_config.encoder_projector_ds_rate=5 \ + ++model_config.encoder_path=$music_encoder_path \ + ++model_config.encoder_stat_path=$music_encoder_stat_path \ + ++model_config.encoder_config_path=$music_encoder_config_path \ + ++model_config.encoder_dim=1024 \ + ++model_config.encoder_projector=linear \ + ++dataset_config.dataset=mir_dataset \ + ++dataset_config.val_data_path=$val_data_path \ + ++dataset_config.input_type=raw \ + ++dataset_config.fixed_duration=10.0 \ + ++dataset_config.audio_label_freq=25 \ + ++dataset_config.inference_mode=true \ + ++train_config.model_name=mir \ + ++train_config.freeze_encoder=true \ + ++train_config.freeze_llm=true \ + ++train_config.batching_strategy=custom \ + ++train_config.num_epochs=1 \ + ++train_config.val_batch_size=1 \ + ++train_config.num_workers_dataloader=0 \ + ++train_config.output_dir=$output_dir \ + ++decode_log=$decode_log \ + ++ckpt_path=$ckpt_path/model.pt \ + # ++peft_ckpt=$ckpt_path \ + # ++train_config.use_peft=true \ + # ++train_config.peft_config.r=32 \ + # ++dataset_config.normalize=true \ + # ++model_config.encoder_projector=q-former \ + # ++dataset_config.fix_length_audio=64 \ + diff --git a/examples/music_caption/scripts/finetune_musicfm_linear_vicuna_7b_10s.sh b/examples/music_caption/scripts/finetune_musicfm_linear_vicuna_7b_10s.sh new file mode 100644 index 00000000..4567450e --- /dev/null +++ b/examples/music_caption/scripts/finetune_musicfm_linear_vicuna_7b_10s.sh @@ -0,0 +1,84 @@ +#!/bin/bash +export CUDA_VISIBLE_DEVICES=0,1 +export TOKENIZERS_PARALLELISM=false +# export CUDA_LAUNCH_BLOCKING=1 +export OMP_NUM_THREADS=1 + +# debug setting for multiple gpus +# export NCCL_DEBUG=INFO +# export NCCL_DEBUG_SUBSYS=ALL +# export TORCH_DISTRIBUTED_DEBUG=INFO + +run_dir=$PWD +cd $run_dir +code_dir=examples/music_caption + +music_encoder_path=path/to/pretrained/musicfm/pretrained_msd.pt +music_encoder_stat_path=path/to/pretrained/musicfm/msd_stats.json +music_encoder_config_path=facebook/wav2vec2-conformer-rope-large-960h-ft + +llm_path=lmsys/vicuna-7b-v1.5 + + +train_data_path=/root/cq7_haina/data/LP-MusicCaps-MC/LP-MusicCaps-MC.train.exist.jsonl +val_data_path=/root/cq7_haina/data/LP-MusicCaps-MC/LP-MusicCaps-MC.valid.exist.jsonl + +output_dir=/root/cq7_haina/save/music-caption/musicfm_vicuna7b_mc_10s_$(date +"%Y%m%d_%H:%M:%S") + + +hydra_args=" +hydra.run.dir=$output_dir \ +++model_config.llm_path=$llm_path \ +++model_config.llm_name=vicuna-7b-v1.5 \ +++model_config.llm_dim=4096 \ +++model_config.encoder_name=musicfm \ +++model_config.normalize=false \ +++model_config.encoder_layer_idx=9 \ +++dataset_config.normalize=false \ +++model_config.encoder_projector_ds_rate=5 \ +++model_config.encoder_path=$music_encoder_path \ +++model_config.encoder_stat_path=$music_encoder_stat_path \ +++model_config.encoder_config_path=$music_encoder_config_path \ +++model_config.encoder_dim=1024 \ +++model_config.encoder_projector=linear \ +++dataset_config.dataset=mir_dataset \ +++dataset_config.train_data_path=$train_data_path \ +++dataset_config.val_data_path=$val_data_path \ +++dataset_config.input_type=raw \ +++dataset_config.fixed_duration=10.0 \ +++dataset_config.audio_label_freq=25 \ +++train_config.model_name=mir \ +++train_config.num_epochs=10000 \ +++train_config.freeze_encoder=true \ +++train_config.freeze_llm=true \ +++train_config.batching_strategy=custom \ +++train_config.warmup_steps=1000 \ +++train_config.total_steps=100000 \ +++train_config.lr=1e-4 \ +++train_config.validation_interval=3000 \ +++train_config.batch_size_training=1 \ +++train_config.val_batch_size=1 \ +++train_config.num_workers_dataloader=0 \ +++train_config.output_dir=$output_dir \ +++metric=acc \ +" + +# -m debugpy --listen 5678 --wait-for-client +if [[ $CUDA_VISIBLE_DEVICES != *","* ]]; then + python -m debugpy --listen 5678 --wait-for-client $code_dir/finetune_mir.py \ + --config-path "conf" \ + --config-name "prompt.yaml" \ + $hydra_args +else + torchrun \ + --nnodes 1 \ + --nproc_per_node 2 \ + --master_port=29503 \ + $code_dir/finetune_mir.py \ + --config-path "conf" \ + --config-name "prompt.yaml" \ + ++train_config.enable_fsdp=false \ + ++train_config.enable_ddp=true \ + ++train_config.use_fp16=false \ + $hydra_args +fi diff --git a/src/slam_llm/datasets/mir_dataset.py b/src/slam_llm/datasets/mir_dataset.py new file mode 100644 index 00000000..5d12127c --- /dev/null +++ b/src/slam_llm/datasets/mir_dataset.py @@ -0,0 +1,320 @@ +import os.path as osp +import random +import json, yaml +import copy + +import numpy as np +from scipy import signal +import soundfile as sf + +import torch +import torchaudio +from torch.utils.data import Dataset +import whisper +from slam_llm.utils.compute_utils import calculate_output_length_1d +from typing import Tuple +import math + + +class RandCropReader(torch.nn.Module): + def __init__(self, n_samples: int, sample_rate: int, randomize: bool = True): + + super().__init__() + + self.n_samples = n_samples + self.sample_rate = sample_rate + self.randomize = randomize + + + def __call__(self, filename: str, duration: float, cur_sample_rate: int) -> Tuple[torch.Tensor, float, float, int, int]: + if self.n_samples is None: + chunk, cur_sample_rate = torchaudio.load(filename) + t_start = 0. + t_end = 1.0 + offset = 0 + elif(duration<(float(self.n_samples)/self.sample_rate+1)): + chunk, _ = torchaudio.load(filename, frame_offset=0, num_frames=-1) + t_start = 0. + t_end = min(1.0, float(self.n_samples) / float(self.sample_rate) / duration) + offset = 0 + else: + offset = np.random.randint(0,int(duration*cur_sample_rate)-int(float(self.n_samples)/self.sample_rate*cur_sample_rate)) + t_start = offset / float(cur_sample_rate) / duration + t_end = t_start + float(self.n_samples) / float(self.sample_rate) / duration + chunk, _ = torchaudio.load(filename, frame_offset=offset, num_frames=int(float(self.n_samples)/self.sample_rate*cur_sample_rate)) + # Pad with silence if necessary. + if(chunk.shape[0]>1): + chunk = chunk[torch.randint(chunk.shape[0], size=(1,)),:].float() + else: + chunk = chunk[[0],:].float() + if(cur_sample_rate!=self.sample_rate): + chunk = torchaudio.functional.resample(chunk, cur_sample_rate, self.sample_rate) + if self.n_samples is None: + pass + elif chunk.shape[-1] < self.n_samples: + chunk = torch.cat([chunk, torch.zeros((1, self.n_samples - chunk.shape[-1],))],-1) + else: + chunk = chunk[:,0:self.n_samples] + seconds_start = math.floor(offset / cur_sample_rate) + seconds_total = math.floor(duration) + + return ( + chunk, + t_start, + t_end, + seconds_start, + seconds_total + ) + +class MirDatasetJsonl(torch.utils.data.Dataset): + + def __init__(self, + dataset_config, + tokenizer=None, + split='train', + ): + super().__init__() + self.dataset_config = dataset_config + self.tokenizer = tokenizer + # data_parallel_size = dist.get_world_size() + data_parallel_size = 1 + + # self.data_list = contents + self.IGNORE_INDEX = -100 # The default setting in CrossEntropyLoss + self.prompt = dataset_config.get("prompt", None) + self.mel_size = dataset_config.get("mel_size", 80) # 80 for whisper large v1 and v2, 128 for large v3 + # self.prompt_library = [ + # "Begin by converting the spoken words into written text. ", + # "Can you transcribe the speech into a written format? ", + # "Focus on translating the audible content into text. ", + # "Transcribe the speech by carefully listening to it. ", + # "Would you kindly write down the content of the speech? ", + # "Analyze the speech and create a written transcription. ", + # "Engage with the speech to produce a text-based version. ", + # "Can you document the speech in written form? ", + # "Transform the spoken words into text accurately. ", + # "How about putting the speech's content into writing? " + # ] + self.prompt_template = "USER: {}\n ASSISTANT:" + self.answer_template = "{}" + self.fix_length_audio = dataset_config.get("fix_length_audio", -1) + + self.sample_rate = dataset_config.get('sample_rate', 24000) + self.fixed_duration = dataset_config.get('fixed_duration', 10.0) + self.inference_mode = dataset_config.get("inference_mode", False) + self.audio_label_freq = dataset_config.get("audio_label_freq", self.sample_rate//320) + + self.reader = RandCropReader( + int(self.fixed_duration * self.sample_rate), self.sample_rate + ) # int(self.fixed_duration * self.sample_rate) if not self.inference_mode else None + + + self.normalize = dataset_config.get("normalize", False) + self.input_type = dataset_config.get("input_type", None) + self.fixed_audio_token_num = dataset_config.get("fixed_audio_token_num", None) + assert self.input_type in ["raw", "mel"], "input_type must be one of [raw, mel]" + + self.data_list = [] + if split == "train": + with open(dataset_config.train_data_path, encoding='utf-8') as fin: + for line in fin: + data_dict = json.loads(line.strip()) + self.data_list.append(data_dict) + else: + with open(dataset_config.val_data_path, encoding='utf-8') as fin: + for line in fin: + data_dict = json.loads(line.strip()) + self.data_list.append(data_dict) + + # # debug + # with open(dataset_config.train_data_path, encoding='utf-8') as fin: + # for line in fin: + # data_dict = json.loads(line.strip()) + # self.data_list.append(data_dict) + # if split == "train": + # self.data_list = self.data_list[:80] + # else: + # self.data_list = self.data_list[80:100] + + def get_source_len(self, data_dict): + return data_dict["source_len"] + + def get_target_len(self, data_dict): + + return data_dict["target_len"] if "target_len" in data_dict else 0 + + def __len__(self): + return len(self.data_list) + + def __getitem__(self, index): + data_dict = self.data_list[index] + audio_path = data_dict.get("source") + target = data_dict.get("target", None) + task = data_dict.get("prompt", "ASR") + key = data_dict.get("key", None) + + # audio_raw = whisper.load_audio(audio_path) + if data_dict.get('duration', None) is None: + wav, sr = torchaudio.load(audio_path) + dur = wav.shape[-1] / sr + else: + dur = data_dict.get('duration') + sr = data_dict.get('sample_rate') + audio_raw, *ignored = self.reader(audio_path, dur, sr) + if self.input_type == "raw": + # audio_raw = torch.from_numpy(audio_raw) + if len(audio_raw.shape) > 1: + audio_raw = audio_raw.squeeze_(0) + if self.normalize: + audio_raw = torch.nn.functional.layer_norm(audio_raw, audio_raw.shape) + # audio_length = len(audio_raw) // 320 # ad-hoc for fairseq 320x downsample + audio_length = len(audio_raw) // (self.sample_rate // self.audio_label_freq) + audio_length = audio_length // 5 # ad-hoc for 5x fc downsample + # if self.inference_mode: + # audio_length = 150 + elif self.input_type == "mel": + audio_raw = whisper.pad_or_trim(audio_raw) + # audio_raw = np.concatenate((np.zeros(random.randint(0, 16000)), audio_raw, np.zeros(random.randint(0, 16000)))).astype(audio_raw.dtype)[:16000*30] + audio_mel = whisper.log_mel_spectrogram(audio_raw, n_mels=self.mel_size).permute(1, 0) + audio_length = (audio_mel.shape[0] + 1) // 2 # ad-hoc for whisper for 2x downsample from mel to feats + audio_length = audio_length // 5 # ad-hoc for 5x fc downsample + # audio_length = calculate_output_length_1d(audio_length, 5, 5, 0) # ad-hoc for 5x cov1d downsample + if self.fix_length_audio > 0: + audio_length = self.fix_length_audio + + if self.fixed_audio_token_num: + audio_length = self.fixed_audio_token_num + audio_pseudo = torch.full((audio_length,), -1) # placeholder + + prompt = self.prompt + if prompt is None: + # prompt = random.choice(self.prompt_library) + prompt = 'Describe this music.' + prompt = self.prompt_template.format(prompt) + prompt_ids = self.tokenizer.encode(prompt) + prompt_length = len(prompt_ids) + + if self.inference_mode: + prompt_ids = torch.tensor(prompt_ids, dtype=torch.int64) + example_ids = torch.cat((audio_pseudo, prompt_ids)) # [audio,prompt] + example_mask = example_ids.ge(-1) # [True,True] + + return { + "input_ids": example_ids, + "attention_mask": example_mask, + "audio": audio_raw if self.input_type == "raw" else None, + "audio_mel": audio_mel if self.input_type == "mel" else None, + "audio_length": audio_length, + "key": key, + "target": target, + } + + answer = self.answer_template.format(target) + example = prompt + answer # FIX(MZY): avoid putting a bos token before answer. + example_ids = self.tokenizer.encode(example) # [prompt,answer] + example_ids.append(self.tokenizer.eos_token_id) # [prompt,answer,eos] + example_ids = torch.tensor( + example_ids, dtype=torch.int64 + ) + example_ids = torch.cat((audio_pseudo, example_ids)) # [audio,prompt,answer,eos] + + labels_ids = copy.deepcopy(example_ids) # [audio,prompt,answer,eos] + labels_ids[:audio_length + prompt_length] = -1 # [-1,-1,answer,eos]; + example_mask = example_ids.ge(-1) # FIX(GZF): [True,True,True,True] + + label_mask = labels_ids.ge(0) # [False,False,True,True] + example_ids[~example_mask] = 0 # [audio,prompt,answer,eos] + labels_ids[~label_mask] = self.IGNORE_INDEX # [-100,-100,answer,eos] + + return { + "input_ids": example_ids, + "labels": labels_ids, + "attention_mask": example_mask, + "audio": audio_raw if self.input_type == "raw" else None, + "audio_mel": audio_mel if self.input_type == "mel" else None, + "audio_length": audio_length, + } + + def pad(self, sequence, max_length, padding_idx=0): + if isinstance(sequence, (int, list, tuple)): + if len(sequence) < max_length: + sequence = sequence + [padding_idx] * (max_length - len(sequence)) + else: + sequence = sequence[:max_length] + elif isinstance(sequence, torch.Tensor): + if len(sequence) < max_length: + sequence = torch.cat( + (sequence, torch.full(([max_length - len(sequence)] + list(sequence.size())[1:]), padding_idx))) + else: + sequence = sequence[:max_length] + elif isinstance(sequence, np.ndarray): + if len(sequence) < max_length: + sequence = np.concatenate( + (sequence, np.full((max_length - len(sequence),) + sequence.shape[1:], padding_idx))) + else: + sequence = sequence[:max_length] + else: + raise Exception("Type mismatch during padding!") + return sequence + + def collator(self, samples): + assert samples is not None + input_ids_max_length = max([s['input_ids'].shape[0] for s in samples]) + input_ids = torch.stack([self.pad(s['input_ids'], input_ids_max_length, self.tokenizer.pad_token_id) + for s in samples]) + attention_mask = torch.stack([self.pad(s['attention_mask'], input_ids_max_length, False) + for s in samples]) + if self.input_type == "raw": + audio_raw_max_length = max([s['audio'].shape[0] for s in samples]) + audio_raw = torch.stack([self.pad(s['audio'], audio_raw_max_length, 0) + for s in samples]) + audio_mask = torch.zeros(len(samples), audio_raw_max_length) + for line, sample in enumerate(samples): + audio_mask[line, :sample['audio'].shape[0]] = 1 + elif self.input_type == "mel": + audio_mel_max_length = max([s['audio_mel'].shape[0] for s in samples]) + audio_mel = torch.stack([self.pad(s['audio_mel'], audio_mel_max_length, 0) + for s in samples]) + audio_mel_post_mask = torch.zeros(len(samples), (audio_mel_max_length + 1) // 2) # ad-hoc for whisper for 2x downsample from mel to feats + for line, sample in enumerate(samples): + audio_mel_post_mask[line, :(sample['audio_mel'].shape[0] + 1) // 2] = 1 + + modality_mask = torch.zeros_like(attention_mask) + for line, sample in enumerate(samples): + modality_mask[line, :sample['audio_length']] = 1 + + if self.inference_mode: + keys = [s['key'] for s in samples] + targets = [s['target'] for s in samples] + + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "audio": audio_raw if self.input_type == "raw" else None, + "audio_mask": audio_mask if self.input_type == "raw" else None, + "audio_mel": audio_mel if self.input_type == "mel" else None, + "audio_mel_post_mask": audio_mel_post_mask if self.input_type == "mel" else None, + "modality_mask": modality_mask, + "keys": keys, + "targets": targets + } + + labels = torch.stack([self.pad(s['labels'], input_ids_max_length, self.IGNORE_INDEX) + for s in samples]) + return { + "input_ids": input_ids, + "labels": labels, + "attention_mask": attention_mask, + "audio": audio_raw if self.input_type == "raw" else None, + "audio_mask": audio_mask if self.input_type == "raw" else None, + "audio_mel": audio_mel if self.input_type == "mel" else None, + "audio_mel_post_mask": audio_mel_post_mask if self.input_type == "mel" else None, + "modality_mask": modality_mask + } + + + +def get_mir_dataset(dataset_config, tokenizer, split): + dataset = MirDatasetJsonl(dataset_config, tokenizer, split) + + return dataset diff --git a/src/slam_llm/models/encoder.py b/src/slam_llm/models/encoder.py index 0b2bbff3..d9fc76aa 100644 --- a/src/slam_llm/models/encoder.py +++ b/src/slam_llm/models/encoder.py @@ -120,4 +120,25 @@ class HfTextEncoder: def load(cls, model_config): from transformers import AutoModel model = AutoModel.from_pretrained(model_config.encoder_path) - return model \ No newline at end of file + return model + +class MusicFMEncoder(nn.Module): + def __init__(self, config, model): + super().__init__() + self.config = config + self.model = model + + @classmethod + def load(cls, model_config): + from .musicfm.model.musicfm_25hz import MusicFM25Hz + model = MusicFM25Hz( + stat_path = model_config.encoder_stat_path, + model_path = model_config.encoder_path, + w2v2_config_path = model_config.get('encoder_config_path', "facebook/wav2vec2-conformer-rope-large-960h-ft") + ) + return cls(model_config, model) + + def extract_features(self, source, padding_mask=None): + _, hidden_states = self.model.get_predictions(source) + out = hidden_states[self.config.encoder_layer_idx] + return out diff --git a/src/slam_llm/models/slam_model.py b/src/slam_llm/models/slam_model.py index 38d4b471..33a3106a 100644 --- a/src/slam_llm/models/slam_model.py +++ b/src/slam_llm/models/slam_model.py @@ -89,6 +89,9 @@ def setup_encoder(train_config, model_config, **kwargs): if encoder_name == "hubert": from slam_llm.models.encoder import HubertEncoder encoder = HubertEncoder.load(model_config) + if encoder_name == "musicfm": + from slam_llm.models.encoder import MusicFMEncoder + encoder = MusicFMEncoder.load(model_config) if "llama" in encoder_name.lower(): from slam_llm.models.encoder import HfTextEncoder @@ -322,6 +325,8 @@ def forward(self, encoder_outs, audio_mel_post_mask = results["encoder_out"], results["padding_mask"] encoder_outs = encoder_outs.transpose(0, 1) audio_mel_post_mask = (~audio_mel_post_mask).float() + if self.model_config.encoder_name == 'musicfm': + encoder_outs = self.encoder.extract_features(audio, padding_mask = None) # MusicFM doesn't support padding mask if self.encoder is None: encoder_outs = audio_mel if audio_mel is not None else audio