Merge pull request #147 from X-LANCE/cwx_slam_aac
SLAM-AAC open-source
ddlBoJack authored Oct 12, 2024
2 parents d599ce4 + b8dbc12 commit 2ab898e
Showing 18 changed files with 1,479 additions and 2 deletions.
88 changes: 88 additions & 0 deletions examples/slam_aac/README.md
@@ -0,0 +1,88 @@
# SLAM-AAC

SLAM-AAC is an LLM-based model for the Automated Audio Captioning (AAC) task. Inspired by techniques from machine translation and ASR, it enhances audio captioning with paraphrasing augmentation and a plug-and-play CLAP-Refine decoding strategy.
<!-- For more details, please refer to the [paper](). -->

## Model Architecture
SLAM-AAC uses EAT as the audio encoder and Vicuna-7B as the LLM decoder. During training, only the Linear Projector and LoRA modules are trainable. For inference, multiple candidates are generated using different beam sizes, which are then refined using the CLAP-Refine strategy.

![](./docs/model.png)
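
To make the data flow concrete, below is a minimal PyTorch-style sketch of the trainable linear projector that bridges the frozen audio encoder and the LLM. The class name, downsample-by-concatenation scheme, and dimensions (taken from the defaults in `aac_config.py`) are illustrative assumptions, not the repository's exact implementation.

```python
import torch
import torch.nn as nn

class LinearProjector(nn.Module):
    """Illustrative sketch: map audio-encoder features into the LLM embedding space
    by concatenating every `ds_rate` frames and applying a linear layer."""
    def __init__(self, encoder_dim: int = 1280, llm_dim: int = 4096, ds_rate: int = 5):
        super().__init__()
        self.ds_rate = ds_rate
        self.proj = nn.Linear(encoder_dim * ds_rate, llm_dim)

    def forward(self, feats: torch.Tensor) -> torch.Tensor:
        # feats: (batch, time, encoder_dim) from the frozen audio encoder
        b, t, d = feats.shape
        t = t - t % self.ds_rate                      # drop trailing frames
        feats = feats[:, :t, :].reshape(b, t // self.ds_rate, d * self.ds_rate)
        return self.proj(feats)                       # (batch, time // ds_rate, llm_dim)

# During training, only this projector and the LoRA adapters inside the LLM receive
# gradients; the EAT encoder and the base Vicuna weights stay frozen.
```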

## Performance and checkpoints
We have released the pre-trained checkpoint of SLAM-AAC, as well as checkpoints fine-tuned on the Clotho and AudioCaps datasets. Each checkpoint contains the model's Linear Projector and LoRA modules. When using each component, make sure to set up the corresponding environment according to the instructions in its repository (e.g., [EAT](https://github.com/cwx-worst-one/EAT)).

### Pre-training
SLAM-AAC was pre-trained on a combination of AudioCaps, Clotho, WavCaps, and MACS datasets. For more information on these datasets, you can refer to [this repository](https://github.com/Labbeti/aac-datasets). Additionally, the Clotho dataset was augmented using a back-translation-based paraphrasing technique.

| Audio Encoder | LLM | Checkpoint | Pre-training Dataset |
|:---:|:---:|:---:|:---:|
| [EAT-base (fine-tuned)](https://drive.google.com/file/d/1aCYiQmoZv_Gh1FxnR-CCWpNAp6DIJzn6/view?usp=sharing) | [vicuna-7b-v1.5](https://huggingface.co/lmsys/vicuna-7b-v1.5) | [link](https://drive.google.com/drive/folders/10kOjB112AeGYA_0mIUr8f1-i5rSg08_O?usp=sharing) | AudioCaps, Clotho, WavCaps, MACS |

### Fine-tuning
We fine-tuned the pre-trained model on the Clotho and AudioCaps datasets, respectively. The final evaluation was conducted using audio captions generated with the CLAP-Refine decoding strategy.

| Dataset | Audio Encoder | LLM | Checkpoint | METEOR | CIDEr | SPICE | SPIDEr | SPIDEr-FL | FENSE |
|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
| Clotho | [EAT-base (fine-tuned)](https://drive.google.com/file/d/1aCYiQmoZv_Gh1FxnR-CCWpNAp6DIJzn6/view?usp=sharing) | [vicuna-7b-v1.5](https://huggingface.co/lmsys/vicuna-7b-v1.5) | [link](https://drive.google.com/drive/folders/1QX7CM9YAddPi02_NRChI5mzsNmBBtA63?usp=sharing) | 19.7 | 51.5 | 14.8 | 33.2 | 33.0 | 54.0 |
| AudioCaps | [EAT-base (fine-tuned)](https://drive.google.com/file/d/1aCYiQmoZv_Gh1FxnR-CCWpNAp6DIJzn6/view?usp=sharing) | [vicuna-7b-v1.5](https://huggingface.co/lmsys/vicuna-7b-v1.5) | [link](https://drive.google.com/drive/folders/1GhFPiSVmBE9BvBhYWCEqkFuH-avKl-4g?usp=sharing) | 26.8 | 84.1 | 19.4 | 51.8 | 51.5 | 66.8 |


## Data preparation
Ensure your `jsonl` data follows the structure outlined below:
```json
{"key": "Y7fmOlUlwoNg_1", "source": "/root/data/AudioCaps/waveforms/test/Y7fmOlUlwoNg.wav", "target": "Constant rattling noise and sharp vibrations"}
{"key": "Y6BJ455B1aAs_1", "source": "/root/data/AudioCaps/waveforms/test/Y6BJ455B1aAs.wav", "target": "A rocket flies by followed by a loud explosion and fire crackling as a truck engine runs idle"}
```
In addition, you can refer to the [manifest](https://drive.google.com/drive/folders/1NJinoWg3yXKSPm-pRrhqKLvCD9dtDuDG?usp=sharing) file we've provided, which includes the Clotho dataset enhanced with **paraphrasing augmentation** as a bonus.
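
As a quick illustration, the snippet below writes a manifest in this format (a minimal sketch; the keys, paths, and captions shown are placeholders to be replaced with your own data):

```python
import json

# Placeholder entries; substitute your own audio paths and reference captions.
entries = [
    {
        "key": "Y7fmOlUlwoNg_1",
        "source": "/root/data/AudioCaps/waveforms/test/Y7fmOlUlwoNg.wav",
        "target": "Constant rattling noise and sharp vibrations",
    },
]

with open("audiocaps_test.jsonl", "w") as f:
    for entry in entries:
        f.write(json.dumps(entry) + "\n")  # one JSON object per line
```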

## Model Training
To pre-train the SLAM-AAC model on the pre-training datasets listed above, run the following command:
```bash
# Pre-train the model
bash scripts/pretrain.sh
```

You can fine-tune the model on the AudioCaps or Clotho datasets using the [provided checkpoint](https://drive.google.com/drive/folders/10kOjB112AeGYA_0mIUr8f1-i5rSg08_O?usp=sharing) or your own pre-trained model by running the following commands:

```bash
# Fine-tune on AudioCaps
bash scripts/finetune_audiocaps.sh

# Fine-tune on Clotho
bash scripts/finetune_clotho.sh
```

You can also fine-tune the model without loading any pre-trained weights, though this may result in reduced performance.


### Note
- In the current version of SLAM-LLM, the `peft_ckpt` parameter is no longer required. However, if you are using the checkpoint provided by us, which was trained with an earlier version, please keep the `peft_ckpt` parameter in your configuration to ensure compatibility.
- Due to differences in dependency versions, there may be slight variations in the performance of the SLAM-AAC model.

## Inference
To perform inference with the trained models, you can use the following commands to decode with standard beam search:
```bash
# Inference on AudioCaps (Beam Search)
bash scripts/inference_audiocaps_bs.sh

# Inference on Clotho (Beam Search)
bash scripts/inference_clotho_bs.sh
```

For improved inference results, you can use the CLAP-Refine strategy, which generates candidates with multiple beam-search runs (using different beam sizes) and reranks them with a CLAP model. To use this method, download our pre-trained [CLAP](https://drive.google.com/drive/folders/1X4NYE08N-kbOy6s_Itb0wBR_3X8oZF56?usp=sharing) model. Note that CLAP-Refine takes longer to run, but it generally produces higher-quality captions. You can execute the following commands:
```bash
# Inference on AudioCaps (CLAP-Refine)
bash scripts/inference_audiocaps_CLAP_Refine.sh

# Inference on Clotho (CLAP-Refine)
bash scripts/inference_clotho_CLAP_Refine.sh
```

If you already have the generated candidates and want to directly refine them using the CLAP-Refine strategy, you can run the following command:
```bash
bash scripts/clap_refine.sh
```
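
Conceptually, CLAP-Refine scores each candidate caption against the input audio with the CLAP model and keeps the best-matching one. The sketch below shows that reranking step under an assumed generic interface (`encode_audio` / `encode_text`); the actual API used by `scripts/clap_refine.sh` may differ.

```python
import torch
import torch.nn.functional as F

def clap_refine(clap_model, audio: torch.Tensor, candidates: list) -> str:
    """Return the candidate caption whose CLAP text embedding has the highest
    cosine similarity to the CLAP audio embedding (illustrative interface)."""
    with torch.no_grad():
        audio_emb = F.normalize(clap_model.encode_audio(audio), dim=-1)     # (1, d)
        text_emb = F.normalize(clap_model.encode_text(candidates), dim=-1)  # (n, d)
    scores = text_emb @ audio_emb.squeeze(0)  # cosine similarity per candidate
    return candidates[int(scores.argmax())]
```

In practice the candidates come from several beam-search runs with different beam sizes, which is why this decoding path is slower than plain beam search.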

<!-- ## Citation
You can refer to the paper for more results.
```
``` -->
143 changes: 143 additions & 0 deletions examples/slam_aac/aac_config.py
@@ -0,0 +1,143 @@
from dataclasses import dataclass, field
from typing import Optional, List
@dataclass
class ModelConfig:
    file: str = "examples/slam_aac/model/slam_model_aac.py:model_factory"
    llm_name: str = "vicuna-13b-v1.5"
    llm_path: str = "PATH/to/LLAMA/7B"
    llm_type: str = "decoder_only"
    llm_dim: int = 4096
    encoder_name: Optional[str] = None
    encoder_ds_rate: int = 2
    encoder_path: Optional[str] = None
    encoder_dim: int = 1280
    encoder_projector: str = "linear"
    encoder_projector_ds_rate: int = 5
    encoder_fairseq_dir: str = "/fairseq/EAT"
    modal: str = "audio"
    normalize: Optional[bool] = field(default=False, metadata={
        "help": "whether input is normalized, used for models such as wavlm"
    })
    do_sample: bool = False
    top_p: float = 1.0
    temperature: float = 1.0
    num_beams: int = 4
    num_return_sequences: int = 1
    length_penalty: float = 1.0
    repetition_penalty: float = 1.0
    max_new_tokens: int = 200
    min_length: int = 1

@dataclass
class PeftConfig:
    peft_method: str = "lora"  # None, llama_adapter, prefix
    r: int = 8
    lora_alpha: int = 32
    target_modules: List = field(default_factory=lambda: ["q_proj", "v_proj"])
    bias: str = "none"
    task_type: str = "CAUSAL_LM"
    lora_dropout: float = 0.05
    inference_mode: bool = False

@dataclass
class TrainConfig:
    model_name: str = "PATH/to/LLAMA/7B"
    enable_ddp: bool = False
    enable_deepspeed: bool = False
    enable_fsdp: bool = False
    low_cpu_fsdp: bool = False
    run_validation: bool = True
    batch_size_training: int = 4
    batching_strategy: str = field(default="packing", metadata={
        "help": "alternative: padding"
    })
    context_length: int = 4096
    gradient_accumulation_steps: int = 1
    num_epochs: int = 3
    num_workers_dataloader: int = 1
    warmup_steps: int = 1000
    total_steps: int = 100000
    validation_interval: int = 1000
    lr: float = 1e-4
    weight_decay: float = 0.0
    gamma: float = 0.85
    seed: int = 42
    use_fp16: bool = False
    mixed_precision: bool = True
    val_batch_size: int = 1

    use_peft: bool = False
    peft_config: PeftConfig = field(default_factory=PeftConfig)
    output_dir: str = "PATH/to/save/PEFT/model"
    freeze_layers: bool = False
    num_freeze_layers: int = 1
    quantization: bool = False
    one_gpu: bool = False
    save_model: bool = True
    dist_checkpoint_root_folder: str = "PATH/to/save/FSDP/model"  # used if training with FSDP
    dist_checkpoint_folder: str = "fine-tuned"  # used if training with FSDP
    save_optimizer: bool = False  # used if training with FSDP
    use_fast_kernels: bool = False  # Enable SDPA from PyTorch Accelerated Transformers, making use of Flash Attention and xFormers memory-efficient kernels
    run_test_during_validation: bool = False
    run_test_during_validation_file: str = "test.wav"
    run_test_during_validation_prompt: str = "<|ASR|>"
    freeze_llm: bool = field(default=False, metadata={
        "help": "whether to freeze the LLM when fine-tuning; should be true when using PEFT"
    })
    freeze_encoder: bool = False
    specaug: bool = False
    noise_aug: bool = False

@dataclass
class DataConfig:
    dataset: str = "audio_dataset"
    file: str = "src/slam_llm/datasets/audio_dataset.py:get_audio_dataset"
    train_data_path: Optional[str] = None
    val_data_path: Optional[str] = None
    train_split: str = "train"
    test_split: str = "validation"
    prompt: Optional[str] = None
    data_path: Optional[str] = None
    max_words: Optional[int] = None
    max_mel: Optional[float] = None
    fix_length_audio: int = -1
    inference_mode: bool = False
    model_name: str = 'eat'
    fbank_mean: float = -4.268
    fbank_std: float = 4.569
    target_length: int = 1024
    fixed_length: bool = False
    prompt: str = "Describe the audio you hear."  # re-declares the prompt field above; this is the effective default
    random_crop: bool = False
    encoder_projector_ds_rate: int = 5
    input_type: str = field(default="raw", metadata={
        "help": "Use raw when the input is a wav file, mel for Whisper"
    })
    mel_size: int = field(default=80, metadata={
        "help": "80 for whisper large v1 and v2, 128 for v3"
    })
    normalize: Optional[bool] = field(default=False, metadata={
        "help": "whether input is normalized, used for models such as wavlm"
    })

@dataclass
class FSDPConfig:
    mixed_precision: bool = True
    use_fp16: bool = False
    # sharding_strategy: str = "FULL_SHARD"  # ShardingStrategy.FULL_SHARD
    sharding_strategy: str = "NO_SHARD"  # ShardingStrategy.NO_SHARD; set NO_SHARD when using DDP
    checkpoint_type: str = "SHARDED_STATE_DICT"  # SHARDED_STATE_DICT saves one file per rank and allows resizing the world size
    fsdp_activation_checkpointing: bool = True
    fsdp_cpu_offload: bool = False
    pure_bf16: bool = False
    optimizer: str = "AdamW"

@dataclass
class LogConfig:
    use_wandb: bool = False
    wandb_dir: str = "/root/test_wandb"
    wandb_entity_name: str = "project_name"
    wandb_project_name: str = "project_name"
    wandb_exp_name: str = "exp_name"
    log_file: str = "/root/test.log"
    log_interval: int = 5
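
For reference, the fields of `PeftConfig` line up with the arguments of Hugging Face `peft.LoraConfig`, so a configured instance can be mapped roughly as follows. This is a hedged sketch of how such a config is typically consumed, not necessarily how SLAM-LLM wires it internally.

```python
from peft import LoraConfig             # requires the `peft` package
from aac_config import PeftConfig       # run from examples/slam_aac so this import resolves

peft_cfg = PeftConfig()  # defaults: r=8, lora_alpha=32, q_proj/v_proj, dropout=0.05
lora_config = LoraConfig(
    r=peft_cfg.r,
    lora_alpha=peft_cfg.lora_alpha,
    target_modules=peft_cfg.target_modules,
    bias=peft_cfg.bias,
    task_type=peft_cfg.task_type,
    lora_dropout=peft_cfg.lora_dropout,
    inference_mode=peft_cfg.inference_mode,
)
# peft.get_peft_model(llm, lora_config) would then wrap a frozen LLM with trainable LoRA adapters.
```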
19 changes: 19 additions & 0 deletions examples/slam_aac/conf/ds_config.json
@@ -0,0 +1,19 @@
{
    "train_micro_batch_size_per_gpu": 4,
    "gradient_accumulation_steps": 1,
    "optimizer": {
        "type": "Adam",
        "params": {
            "lr": 1e-4
        }
    },
    "fp16": {
        "enabled": true
    },
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {
            "device": "cpu"
        }
    }
}
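
This DeepSpeed config enables fp16 training with ZeRO stage 3 and offloads optimizer states to CPU. Below is a rough sketch of how such a file is usually consumed (presumably when `enable_deepspeed` is set); the stand-in model is purely illustrative, and the script must be launched with the `deepspeed` launcher so the distributed environment is initialized.

```python
import deepspeed
import torch.nn as nn

# Stand-in module; in practice this would be the SLAM-AAC model whose trainable
# parameters are the linear projector and the LoRA adapters.
model = nn.Linear(1280, 4096)

engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=[p for p in model.parameters() if p.requires_grad],
    config="examples/slam_aac/conf/ds_config.json",  # fp16 + ZeRO-3 + CPU optimizer offload
)
```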
3 changes: 3 additions & 0 deletions examples/slam_aac/conf/prompt.yaml
@@ -0,0 +1,3 @@
dataset_config:
    # we put the prompt here because the hydra override in the shell script only supports a small subset of characters
    prompt: "Describe the audio you hear."
Binary file added examples/slam_aac/docs/model.png
52 changes: 52 additions & 0 deletions examples/slam_aac/finetune_aac.py
@@ -0,0 +1,52 @@
from slam_llm.pipeline.finetune import main as train

import hydra
import logging
from typing import Optional
from dataclasses import dataclass, field
from omegaconf import DictConfig, ListConfig, OmegaConf
from aac_config import ModelConfig, TrainConfig, DataConfig, LogConfig, FSDPConfig

@dataclass
class RunConfig:
    dataset_config: DataConfig = field(default_factory=DataConfig)
    model_config: ModelConfig = field(default_factory=ModelConfig)
    train_config: TrainConfig = field(default_factory=TrainConfig)
    log_config: LogConfig = field(default_factory=LogConfig)
    fsdp_config: FSDPConfig = field(default_factory=FSDPConfig)
    debug: bool = field(default=False, metadata={"help": "Use pdb when true"})
    metric: str = field(default="acc", metadata={"help": "The metric for evaluation"})
    ckpt_path: Optional[str] = field(
        default=None, metadata={"help": "The path to projector checkpoint"}
    )
    peft_ckpt: Optional[str] = field(
        default=None, metadata={"help": "The path to peft checkpoint"}
    )

@hydra.main(config_name=None, version_base=None)
def main_hydra(cfg: DictConfig):
    run_config = RunConfig()
    cfg = OmegaConf.merge(run_config, cfg)

    def to_plain_list(cfg_item):
        if isinstance(cfg_item, ListConfig):
            return OmegaConf.to_container(cfg_item, resolve=True)
        elif isinstance(cfg_item, DictConfig):
            return {k: to_plain_list(v) for k, v in cfg_item.items()}
        else:
            return cfg_item

    # kwargs = to_plain_list(cfg)
    kwargs = cfg
    log_level = getattr(logging, kwargs.get("log_level", "INFO").upper())

    logging.basicConfig(level=log_level)

    if kwargs.get("debug", False):
        import pdb
        pdb.set_trace()

    train(kwargs)


if __name__ == "__main__":
main_hydra()
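
The entry point above merges the structured `RunConfig` defaults with whatever Hydra overrides are passed on the command line. The same merge can be reproduced standalone with OmegaConf, which is handy for checking that an override key exists before launching a run (a sketch; the override values below are placeholders):

```python
from omegaconf import OmegaConf
from aac_config import ModelConfig, TrainConfig  # run from examples/slam_aac so the import resolves

base = OmegaConf.create({
    "model_config": OmegaConf.structured(ModelConfig()),
    "train_config": OmegaConf.structured(TrainConfig()),
})
overrides = OmegaConf.from_dotlist([
    "train_config.use_peft=true",                     # enable LoRA fine-tuning
    "model_config.llm_path=/path/to/vicuna-7b-v1.5",  # placeholder path
])
cfg = OmegaConf.merge(base, overrides)  # unknown keys or wrong types raise here
print(cfg.train_config.use_peft, cfg.model_config.llm_path)
```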
53 changes: 53 additions & 0 deletions examples/slam_aac/inference_aac_batch.py
@@ -0,0 +1,53 @@
from slam_llm.pipeline.inference_batch import main as inference

import hydra
import logging
from dataclasses import dataclass, field
from omegaconf import DictConfig, ListConfig, OmegaConf
from typing import Optional
from aac_config import ModelConfig, TrainConfig, DataConfig, LogConfig, FSDPConfig


@dataclass
class RunConfig:
    dataset_config: DataConfig = field(default_factory=DataConfig)
    model_config: ModelConfig = field(default_factory=ModelConfig)
    train_config: TrainConfig = field(default_factory=TrainConfig)
    log_config: LogConfig = field(default_factory=LogConfig)
    fsdp_config: FSDPConfig = field(default_factory=FSDPConfig)
    debug: bool = field(default=False, metadata={"help": "Use pdb when true"})
    metric: str = field(default="acc", metadata={"help": "The metric for evaluation"})
    decode_log: str = field(
        default="output/decode_log",
        metadata={"help": "The prefix for the decode output"},
    )
    ckpt_path: str = field(
        default="output/model.pt", metadata={"help": "The path to projector checkpoint"}
    )
    peft_ckpt: Optional[str] = field(
        default=None,
        metadata={
            "help": "The path to peft checkpoint, should be a directory including adapter_config.json"
        },
    )


@hydra.main(config_name=None, version_base=None)
def main_hydra(cfg: DictConfig):
    run_config = RunConfig()
    cfg = OmegaConf.merge(run_config, cfg)
    # kwargs = to_plain_list(cfg)
    log_level = getattr(logging, cfg.get("log_level", "INFO").upper())

    logging.basicConfig(level=log_level)

    if cfg.get("debug", False):
        import pdb
        pdb.set_trace()

    inference(cfg)


if __name__ == "__main__":
main_hydra()