Merge pull request #72 from ddlBoJack/main
sync
Showing 13 changed files with 982 additions and 1 deletion.
@@ -0,0 +1,31 @@
# Music Caption

## Performance and checkpoints
Here is a recipe for music captioning, using MusicFM as the encoder. We only train the linear projector. For more about MusicFM and its checkpoints, please refer to [this repository](https://github.com/minzwon/musicfm).

The following results are obtained by training on the LP-MusicCaps-MC training set and evaluating on the LP-MusicCaps-MC test set.

Encoder | Projector | LLM | BLEU-1 | METEOR | SPICE | SPIDER
---|---|---|---|---|---|---
[MusicFM (pretrained with MSD)](https://huggingface.co/minzwon/MusicFM/resolve/main/pretrained_msd.pt) | [Linear](https://drive.google.com/file/d/1-9pob6QvJRoq5Dy-LZbiDfF6Q7QRO8Au/view?usp=sharing) (~18.88M) | [vicuna-7b-v1.5](https://huggingface.co/lmsys/vicuna-7b-v1.5) | 25.6 | 10.0 | 8.7 | 6.9

## Data preparation
You need to prepare your data as a jsonl file in the following format. Note that you may need to pre-extract the sample rate and duration of the audio files for better loading efficiency, as sketched below.
```
{"key": "[-0Gj8-vB1q4]-[30-40]", "source": "path/to/MusicCaps/wav/[-0Gj8-vB1q4]-[30-40].wav", "target": "The low quality recording features a ballad song that contains sustained strings, mellow piano melody and soft female vocal singing over it. It sounds sad and soulful, like something you would hear at Sunday services.", "duration": 10.0, "sample_rate": 48000}
...
{"key": "[-0vPFx-wRRI]-[30-40]", "source": "path/to/MusicCaps/wav/[-0vPFx-wRRI]-[30-40].wav", "target": "a male voice is singing a melody with changing tempos while snipping his fingers rhythmically. The recording sounds like it has been recorded in an empty room. This song may be playing, practicing snipping and singing along.", "duration": 10.0, "sample_rate": 48000}
```
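
Below is a minimal sketch of how such a jsonl could be generated with the duration and sample rate pre-extracted. It assumes the `soundfile` package is installed and that a hypothetical `captions.json` maps clip keys to target captions; all paths and file names are placeholders, not part of this recipe.
```python
# Minimal sketch for building the jsonl above. Assumptions: `soundfile` is
# installed, "captions.json" is a hypothetical {key: caption} mapping, and
# all paths are placeholders.
import json
from pathlib import Path

import soundfile as sf

wav_dir = Path("path/to/MusicCaps/wav")          # placeholder directory
captions = json.load(open("captions.json"))      # hypothetical caption lookup

with open("musiccaps_train.jsonl", "w") as f:
    for wav_path in sorted(wav_dir.glob("*.wav")):
        info = sf.info(str(wav_path))            # reads the header only, no full decode
        entry = {
            "key": wav_path.stem,
            "source": str(wav_path),
            "target": captions[wav_path.stem],
            "duration": round(info.frames / info.samplerate, 2),
            "sample_rate": info.samplerate,
        }
        f.write(json.dumps(entry) + "\n")
```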

## Decode with checkpoints
```
bash decode_musicfm_linear_vicuna_7b_10s.sh
```
Before running the shell script, modify the paths in it, including `music_encoder_path`, `music_encoder_stat_path`, `music_encoder_config_path` (if specified), `ckpt_path`, `val_data_path` and `decode_log`.

## Train a new model

### Use MusicFM as the encoder for the music modality
```
bash finetune_musicfm_linear_vicuna_7b_10s.sh
```
@@ -0,0 +1,19 @@
{
    "train_micro_batch_size_per_gpu": 4,
    "gradient_accumulation_steps": 1,
    "optimizer": {
        "type": "Adam",
        "params": {
            "lr": 1e-4
        }
    },
    "fp16": {
        "enabled": true
    },
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {
            "device": "cpu"
        }
    }
}
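
For orientation, the sketch below shows how a DeepSpeed JSON config like this one is typically consumed. It is a generic illustration, not this repository's `finetune_deepspeed` pipeline; the tiny linear model and random batch are stand-ins for the real SLAM-LLM model, and it assumes a properly launched DeepSpeed/distributed environment.
```python
# Generic DeepSpeed usage sketch (not this repository's training pipeline).
# The JSON above enables fp16 and ZeRO stage 3 with optimizer states
# offloaded to CPU memory.
import deepspeed
import torch

model = torch.nn.Linear(16, 1)          # stand-in for the real model

model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config="ds_config.json",             # the JSON shown above
)

x = torch.randn(4, 16, device=model_engine.device, dtype=torch.half)  # fp16 inputs
loss = model_engine(x).float().mean()
model_engine.backward(loss)              # DeepSpeed handles fp16 loss scaling
model_engine.step()                      # optimizer step plus ZeRO bookkeeping
```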
@@ -0,0 +1,3 @@
dataset_config:
  # we put the prompt here because the hydra override in the shell script only supports a small subset of characters
  prompt: "Describe this music."
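
As a rough illustration of why the prompt lives in this YAML file (the shell-side hydra override only supports a limited character set), the snippet below merges the file into a structured config with OmegaConf. The file location `conf/prompt.yaml` and the bare stand-in for `RunConfig` are assumptions, not the repository's exact loading code.
```python
# Illustration only: merge the prompt YAML into a (simplified) structured config.
# "conf/prompt.yaml" is an assumed location for the file shown above.
from omegaconf import OmegaConf

base = OmegaConf.create({"dataset_config": {"prompt": None}})  # stand-in for RunConfig
cfg = OmegaConf.merge(base, OmegaConf.load("conf/prompt.yaml"))
print(cfg.dataset_config.prompt)  # -> Describe this music.
```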
@@ -0,0 +1,47 @@
from slam_llm.pipeline.finetune_deepspeed import main as train
from slam_llm.utils.deepspeed_utils import deepspeed_main_wrapper

import logging
from dataclasses import dataclass, field
from omegaconf import DictConfig, ListConfig, OmegaConf
from asr_config import ModelConfig, TrainConfig, DataConfig, LogConfig


@dataclass
class RunConfig:
    dataset_config: DataConfig = field(default_factory=DataConfig)
    model_config: ModelConfig = field(default_factory=ModelConfig)
    train_config: TrainConfig = field(default_factory=TrainConfig)
    log_config: LogConfig = field(default_factory=LogConfig)
    debug: bool = field(default=False, metadata={"help": "Use pdb when true"})
    metric: str = field(default="acc", metadata={"help": "The metric for evaluation"})
    deepspeed_config: str = field(default="examples/asr_librispeech/conf/ds_config.json", metadata={"help": "The path to the DeepSpeed config"})


@deepspeed_main_wrapper(config_name=None, version_base=None)
def main_hydra(cfg: DictConfig):
    run_config = RunConfig()
    cfg = OmegaConf.merge(run_config, cfg)

    def to_plain_list(cfg_item):
        if isinstance(cfg_item, ListConfig):
            return OmegaConf.to_container(cfg_item, resolve=True)
        elif isinstance(cfg_item, DictConfig):
            return {k: to_plain_list(v) for k, v in cfg_item.items()}
        else:
            return cfg_item

    # kwargs = to_plain_list(cfg)
    kwargs = cfg
    log_level = getattr(logging, kwargs.get("log_level", "INFO").upper())

    logging.basicConfig(level=log_level)

    if kwargs.get("debug", False):
        import pdb
        pdb.set_trace()

    train(kwargs)


if __name__ == "__main__":
    main_hydra()
@@ -0,0 +1,45 @@
from slam_llm.pipeline.finetune import main as train

import hydra
import logging
from dataclasses import dataclass, field
from omegaconf import DictConfig, ListConfig, OmegaConf
from mir_config import ModelConfig, TrainConfig, DataConfig, LogConfig, FSDPConfig


@dataclass
class RunConfig:
    dataset_config: DataConfig = field(default_factory=DataConfig)
    model_config: ModelConfig = field(default_factory=ModelConfig)
    train_config: TrainConfig = field(default_factory=TrainConfig)
    log_config: LogConfig = field(default_factory=LogConfig)
    fsdp_config: FSDPConfig = field(default_factory=FSDPConfig)
    debug: bool = field(default=False, metadata={"help": "Use pdb when true"})
    metric: str = field(default="acc", metadata={"help": "The metric for evaluation"})


@hydra.main(config_name=None, version_base=None)
def main_hydra(cfg: DictConfig):
    run_config = RunConfig()
    cfg = OmegaConf.merge(run_config, cfg)

    def to_plain_list(cfg_item):
        if isinstance(cfg_item, ListConfig):
            return OmegaConf.to_container(cfg_item, resolve=True)
        elif isinstance(cfg_item, DictConfig):
            return {k: to_plain_list(v) for k, v in cfg_item.items()}
        else:
            return cfg_item

    # kwargs = to_plain_list(cfg)
    kwargs = cfg
    log_level = getattr(logging, kwargs.get("log_level", "INFO").upper())

    logging.basicConfig(level=log_level)

    if kwargs.get("debug", False):
        import pdb
        pdb.set_trace()

    train(kwargs)


if __name__ == "__main__":
    main_hydra()
@@ -0,0 +1,53 @@
from slam_llm.pipeline.inference_batch import main as inference

import hydra
import logging
from dataclasses import dataclass, field
from omegaconf import DictConfig, ListConfig, OmegaConf
from typing import Optional
from mir_config import ModelConfig, TrainConfig, DataConfig, LogConfig, FSDPConfig


@dataclass
class RunConfig:
    dataset_config: DataConfig = field(default_factory=DataConfig)
    model_config: ModelConfig = field(default_factory=ModelConfig)
    train_config: TrainConfig = field(default_factory=TrainConfig)
    log_config: LogConfig = field(default_factory=LogConfig)
    fsdp_config: FSDPConfig = field(default_factory=FSDPConfig)
    debug: bool = field(default=False, metadata={"help": "Use pdb when true"})
    metric: str = field(default="acc", metadata={"help": "The metric for evaluation"})
    decode_log: str = field(
        default="output/decode_log",
        metadata={"help": "The prefix for the decode output"},
    )
    ckpt_path: str = field(
        default="output/model.pt", metadata={"help": "The path to the projector checkpoint"}
    )
    peft_ckpt: Optional[str] = field(
        default=None,
        metadata={
            "help": "The path to the peft checkpoint; should be a directory including adapter_config.json"
        },
    )


@hydra.main(config_name=None, version_base=None)
def main_hydra(cfg: DictConfig):
    run_config = RunConfig()
    cfg = OmegaConf.merge(run_config, cfg)
    # kwargs = to_plain_list(cfg)
    log_level = getattr(logging, cfg.get("log_level", "INFO").upper())

    logging.basicConfig(level=log_level)

    if cfg.get("debug", False):
        import pdb
        pdb.set_trace()

    inference(cfg)


if __name__ == "__main__":
    main_hydra()
@@ -0,0 +1,130 @@
from dataclasses import dataclass, field
from typing import Optional, List


@dataclass
class ModelConfig:
    file: str = "examples/music_caption/model/slam_model_mir.py:model_factory"
    llm_name: str = "vicuna-13b-v1.5"
    llm_path: str = "PATH/to/LLAMA/7B"
    llm_type: str = "decoder_only"
    llm_dim: int = 4096
    encoder_name: Optional[str] = "mulan"
    encoder_ds_rate: int = 2
    encoder_path: Optional[str] = None
    encoder_config_path: Optional[str] = None
    encoder_stat_path: Optional[str] = None
    encoder_layer_idx: Optional[int] = None
    encoder_dim: int = 768
    encoder_projector: str = "linear"
    encoder_projector_ds_rate: int = 5
    modal: str = "audio"
    normalize: Optional[bool] = field(default=False, metadata={
        "help": "whether the input is normalized, used for models such as wavlm"
    })


@dataclass
class PeftConfig:
    peft_method: str = "lora"  # None, llama_adapter, prefix
    r: int = 8
    lora_alpha: int = 32
    target_modules: List = field(default_factory=lambda: ["q_proj", "v_proj"])
    bias: str = "none"
    task_type: str = "CAUSAL_LM"
    lora_dropout: float = 0.05
    inference_mode: bool = False


@dataclass
class TrainConfig:
    model_name: str = "PATH/to/LLAMA/7B"
    enable_ddp: bool = False
    enable_deepspeed: bool = False
    enable_fsdp: bool = False
    low_cpu_fsdp: bool = False
    run_validation: bool = True
    batch_size_training: int = 4
    batching_strategy: str = field(default="packing", metadata={
        "help": "alternative: padding"
    })
    context_length: int = 4096
    gradient_accumulation_steps: int = 1
    num_epochs: int = 3
    num_workers_dataloader: int = 1
    warmup_steps: int = 1000
    total_steps: int = 100000
    validation_interval: int = 1000
    lr: float = 1e-4
    weight_decay: float = 0.0
    gamma: float = 0.85
    seed: int = 42
    use_fp16: bool = False
    mixed_precision: bool = True
    val_batch_size: int = 1

    use_peft: bool = False
    peft_config: PeftConfig = field(default_factory=PeftConfig)
    output_dir: str = "PATH/to/save/PEFT/model"
    freeze_layers: bool = False
    num_freeze_layers: int = 1
    quantization: bool = False
    one_gpu: bool = False
    save_model: bool = True
    dist_checkpoint_root_folder: str = "PATH/to/save/FSDP/model"  # will be used if using FSDP
    dist_checkpoint_folder: str = "fine-tuned"  # will be used if using FSDP
    save_optimizer: bool = False  # will be used if using FSDP
    use_fast_kernels: bool = False  # Enable using SDPA from PyTorch Accelerated Transformers, which makes use of Flash Attention and xFormers memory-efficient kernels
    run_test_during_validation: bool = False
    run_test_during_validation_file: str = "test.wav"
    run_test_during_validation_prompt: str = "<|ASR|>"
    freeze_llm: bool = field(default=False, metadata={
        "help": "whether to freeze the llm when finetuning; should be true when using peft finetuning"
    })
    freeze_encoder: bool = True  # False


@dataclass
class DataConfig:
    dataset: str = "mir_dataset"
    file: str = "src/slam_llm/datasets/mir_dataset.py:get_mir_dataset"
    train_data_path: Optional[str] = None
    val_data_path: Optional[str] = None
    fixed_duration: float = 10.0
    audio_label_freq: int = 75
    fixed_audio_token_num: Optional[int] = None
    sample_rate: int = 24000
    train_split: str = "train"
    test_split: str = "validation"
    prompt: Optional[str] = None
    data_path: Optional[str] = None
    max_words: Optional[int] = None
    max_mel: Optional[float] = None
    fix_length_audio: int = -1
    inference_mode: bool = False
    input_type: str = field(default="raw", metadata={
        "help": "Use raw when the input is wav, mel when using whisper"
    })
    mel_size: int = field(default=80, metadata={
        "help": "80 for whisper large v1 and v2, 128 for v3"
    })
    normalize: Optional[bool] = field(default=False, metadata={
        "help": "whether the input is normalized, used for models such as wavlm"
    })


@dataclass
class FSDPConfig:
    mixed_precision: bool = True
    use_fp16: bool = False
    # sharding_strategy = "FULL_SHARD"  # ShardingStrategy.FULL_SHARD
    sharding_strategy: str = "NO_SHARD"  # ShardingStrategy.NO_SHARD  # MZY: set NO_SHARD when using DDP
    checkpoint_type: str = "SHARDED_STATE_DICT"  # SHARDED_STATE_DICT saves one file per rank and allows resizing the world size.
    fsdp_activation_checkpointing: bool = True
    fsdp_cpu_offload: bool = False
    pure_bf16: bool = False
    optimizer: str = "AdamW"


@dataclass
class LogConfig:
    use_wandb: bool = False
    wandb_dir: str = "/root/test_wandb"
    wandb_entity_name: str = "project_name"
    wandb_project_name: str = "project_name"
    wandb_exp_name: str = "exp_name"
    log_file: str = "/root/test.log"
    log_interval: int = 5
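
These dataclasses feed the hydra run configs shown earlier in this diff, so individual fields can be overridden at run time. Below is a small sketch of that override flow using only the OmegaConf API; it assumes the dataclasses above are the `mir_config` module imported by those scripts, and the override values are examples.
```python
# Sketch of overriding the structured config at run time (illustration only).
# Assumes the dataclasses above live in mir_config.py, as the finetune and
# inference scripts earlier in this diff import them from `mir_config`.
from omegaconf import OmegaConf

from mir_config import TrainConfig

base = OmegaConf.structured(TrainConfig)                        # schema + defaults
override = OmegaConf.from_dotlist(["lr=5e-5", "freeze_encoder=false"])
cfg = OmegaConf.merge(base, override)                           # typed merge
print(cfg.lr, cfg.freeze_encoder)                               # 5e-05 False
```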