
Commit be67304

BAT update: fix type, upload checkpoint; finish inference code; add demo code for inference
zszheng147 committed Oct 12, 2024
1 parent db55bea commit be67304
Showing 16 changed files with 952 additions and 106 deletions.
52 changes: 39 additions & 13 deletions examples/seld_spatialsoundqa/README.md
@@ -1,39 +1,65 @@
 # <img src="assets/bat.png" alt="SELD_SpatialSoundQA" width="25" height="25"> SELD_SpatialSoundQA
 
-This repo hosts the code and models of "[BAT: Learning to Reason about Spatial Sounds with Large Language Models](https://arxiv.org/abs/2402.01591)" [ICML 2024 [bib](https://github.com/zszheng147/Spatial-AST#citation)].
+This repo hosts the code and models of "[BAT: Learning to Reason about Spatial Sounds with Large Language Models](https://arxiv.org/abs/2402.01591)" [ICML 2024 [bib](https://github.com/X-LANCE/SLAM-LLM/tree/main/examples/seld_spatialsoundqa#citation)].
 
 Check out our [demo page](https://zhishengzheng.com/BAT/) and enjoy a QA game with spatial audio.
 
-## Performance and checkpoints
-Encoder | Projector | PEFT | LLM
-|---|---|---|---|
-[Spatial-AST](https://huggingface.co/zhisheng01/Bat/blob/main/spatial-ast.pth) | Q-Former | adapter | [llama-2-7b](https://huggingface.co/meta-llama/Llama-2-7b)
+## Performance evaluation on **SpatialSoundQA**
+We use [Spatial-AST](https://huggingface.co/datasets/zhisheng01/SpatialAudio/blob/main/SpatialAST/finetuned.pth) as the audio encoder and [llama-2-7b](https://huggingface.co/meta-llama/Llama-2-7b) as the LLM backbone, finetuning by adding a Q-Former and LoRA. To calculate mAP, refer to [calculate_map.py](https://github.com/X-LANCE/SLAM-LLM/blob/main/examples/seld_spatialsoundqa/scripts/calculate_map.py).
+<img src="assets/performance.png" alt="BAT performance on SpatialSoundQA">
+
+## Demo (Spatial Audio Inference)
+Try [`inference.ipynb`](https://github.com/X-LANCE/SLAM-LLM/blob/main/examples/seld_spatialsoundqa/inference.ipynb).
 
 ## Data preparation
-You need to prepare the data jsonl in this format. Below is an example.
-You can download the SpatialSoundQA dataset from [huggingface](https://huggingface.co/datasets/zhisheng01/SpatialSoundQA).
-```
-{"audio_id": "eval/audio/YI-HlrcP6Qg4", "reverb_id": "q9vSo1VnCiC/0.npy", "audio_id2": null, "reverb_id2": null, "question_id": 0, "question_type": "CLASSIFICATION", "question": "Enumerate the sound occurrences in the audio clip.", "answer": "accelerating, revving, vroom; car; vehicle"}
-...
-{"audio_id": "eval/audio/YZX2fVPmUidA", "reverb_id": "q9vSo1VnCiC/32.npy", "audio_id2": "eval/audio/YjNjUU01quLs", "reverb_id2": "q9vSo1VnCiC/31.npy", "question_id": 58, "question_type": "MIXUP_NONBINARY_DISTANCE", "question": "How far away is the sound of the banjo from the sound of the whack, thwack?", "answer": "2m"}
+You can download the SpatialSoundQA dataset from [SpatialAudio](https://huggingface.co/datasets/zhisheng01/SpatialAudio).
+```json
+{
+    "audio_id": "eval/audio/YI-HlrcP6Qg4",
+    "reverb_id": "q9vSo1VnCiC/0.npy",
+    "audio_id2": null,
+    "reverb_id2": null,
+    "question_id": 0,
+    "question_type": "CLASSIFICATION",
+    "question": "Enumerate the sound occurrences in the audio clip.",
+    "answer": "accelerating, revving, vroom; car; vehicle"
+}
+
+...
+
+{
+    "audio_id": "eval/audio/YZX2fVPmUidA",
+    "reverb_id": "q9vSo1VnCiC/32.npy",
+    "audio_id2": "eval/audio/YjNjUU01quLs",
+    "reverb_id2": "q9vSo1VnCiC/31.npy",
+    "question_id": 58,
+    "question_type": "MIXUP_NONBINARY_DISTANCE",
+    "question": "How far away is the sound of the banjo from the sound of the whack, thwack?",
+    "answer": "2m"
+}
 ```
 
 ## Train a new model
 ```bash
-bash examples/seld_spatialsoundqa/scripts/finetune_spatial-ast_qformer_llama_2_7b.sh
+cd examples/seld_spatialsoundqa/
+bash scripts/finetune_spatial-ast_qformer_llama_2_7b.sh
 ```
 
 ## Decoding with checkpoints
 ```bash
-bash examples/seld_spatialsoundqa/scripts/decode_spatial-ast_qformer_llama_2_7b.sh
+cd examples/seld_spatialsoundqa/
+bash scripts/decode_spatial-ast_qformer_llama_2_7b.sh
 ```
 
 ## TODO
 - [x] Decode with checkpoints
 - [x] Upload SpatialSoundQA dataset
-- [ ] Upload pretrained checkpoints
-- [ ] Update model performance
+- [x] Upload pretrained checkpoints
+- [x] Update model performance
 
 ## Citation
 ```
Binary file added examples/seld_spatialsoundqa/assets/74.npy
Binary file not shown.
Binary file added examples/seld_spatialsoundqa/assets/75.npy
Binary file not shown.
5 changes: 2 additions & 3 deletions examples/seld_spatialsoundqa/dataset/spatial_audio_dataset.py
@@ -37,9 +37,8 @@ def __init__(
         split,
     ):
         super().__init__()
-        dataset_path = os.path.join(dataset_config['qa_data_root'], dataset_config['stage'], split + '.jsonl')
-        with open(dataset_path) as f:
-            self.data = [json.loads(line) for line in f.readlines()]
+        dataset_path = os.path.join(dataset_config['qa_data_root'], dataset_config['stage'], split + '.json')
+        self.data = json.load(open(dataset_path))["data"]
 
         self.anechoic_data_root = dataset_config['anechoic_data_root'] # which is AudioSet in this case
         self.reverb_data_root = dataset_config['reverb_data_root']
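The loader change above swaps per-line `.jsonl` parsing for a single `.json` file whose top-level `data` key holds the list of QA entries. A minimal sketch of the two conventions (the path is a placeholder):

```python
# Hedged sketch: old vs. new QA-file loading, matching the diff above.
import json

# old: one JSON object per line (.jsonl)
def load_jsonl(path):
    with open(path) as f:
        return [json.loads(line) for line in f]

# new: a single .json file with the entries under a top-level "data" key
def load_json(path):
    with open(path) as f:
        return json.load(f)["data"]

data = load_json("stage1-clsdoa/eval.json")  # placeholder path
print(data[0]["question"], "->", data[0]["answer"])
```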
25 changes: 7 additions & 18 deletions examples/seld_spatialsoundqa/finetune_seld.py
@@ -1,5 +1,6 @@
 import hydra
 import logging
+from typing import Optional
 from dataclasses import dataclass, field
 from omegaconf import DictConfig, ListConfig, OmegaConf
 
@@ -16,32 +17,20 @@ class RunConfig:
     peft_config: PeftConfig = field(default_factory=PeftConfig)
     debug: bool = field(default=False, metadata={"help": "Use pdb when true"})
     metric: str = field(default="acc", metadata={"help": "The metric for evaluation"})
-    ckpt_path: str = field(
-        default="output/model.pt", metadata={"help": "The path to projector checkpoint"}
+    ckpt_path: Optional[str] = field(
+        default=None, metadata={"help": "The path to projector checkpoint"}
     )
 
 @hydra.main(config_name=None, version_base=None)
 def main_hydra(cfg: DictConfig):
     run_config = RunConfig()
     cfg = OmegaConf.merge(run_config, cfg)
-    def to_plain_list(cfg_item):
-        if isinstance(cfg_item, ListConfig):
-            return OmegaConf.to_container(cfg_item, resolve=True)
-        elif isinstance(cfg_item, DictConfig):
-            return {k: to_plain_list(v) for k, v in cfg_item.items()}
-        else:
-            return cfg_item
-
-    # kwargs = to_plain_list(cfg)
-    kwargs = cfg
-    log_level = getattr(logging, kwargs.get("log_level", "INFO").upper())
+    cfg.train_config.peft_config = cfg.peft_config
+
+    log_level = getattr(logging, cfg.get("log_level", "INFO").upper())
     logging.basicConfig(level=log_level)
 
-    if kwargs.get("debug", False):
-        import pdb;
-        pdb.set_trace()
-
-    train(kwargs)
+    train(cfg)
 
 
 if __name__ == "__main__":
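The refactor above drops the unused `to_plain_list` helper and the pdb hook, and makes `ckpt_path` optional so a run can start without a projector checkpoint. A minimal sketch of the structured-defaults-plus-overrides merge pattern, with `OmegaConf.from_dotlist` standing in for hydra's CLI parsing:

```python
# Hedged sketch of the RunConfig/OmegaConf merge used in finetune_seld.py.
from dataclasses import dataclass, field
from typing import Optional
from omegaconf import OmegaConf

@dataclass
class RunConfig:
    metric: str = field(default="acc", metadata={"help": "The metric for evaluation"})
    ckpt_path: Optional[str] = field(default=None, metadata={"help": "The path to projector checkpoint"})

# hydra hands main_hydra a DictConfig of CLI overrides; from_dotlist mimics that here
overrides = OmegaConf.from_dotlist(["metric=acc", "ckpt_path=/tmp/model.pt"])
cfg = OmegaConf.merge(OmegaConf.structured(RunConfig()), overrides)
print(cfg.ckpt_path)  # /tmp/model.pt; stays None when no override is given
```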
786 changes: 786 additions & 0 deletions examples/seld_spatialsoundqa/inference.ipynb

Large diffs are not rendered by default.
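The notebook itself is not rendered here. Its core spatial-audio step, pairing an anechoic AudioSet clip with a room impulse response (`reverb_id` → `.npy`) as in the dataset entries above, can be sketched roughly as follows; the paths are placeholders and the `(channels, taps)` RIR layout is an assumption, not something this diff confirms:

```python
# Hedged sketch: convolving an anechoic clip with a binaural room impulse response.
import numpy as np
import soundfile as sf
from scipy.signal import fftconvolve

audio, sr = sf.read("anechoic_clip.wav")  # mono AudioSet clip (placeholder path)
rir = np.load("q9vSo1VnCiC/0.npy")        # reverb_id from a QA entry; assumed (channels, taps)
spatial = np.stack([
    fftconvolve(audio, rir[ch])[: len(audio)]  # one convolution per channel/ear
    for ch in range(rir.shape[0])
])
print(spatial.shape)  # (channels, samples), the spatialized input to the encoder
```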

9 changes: 2 additions & 7 deletions examples/seld_spatialsoundqa/inference_seld_batch.py
@@ -36,16 +36,11 @@ class RunConfig:
 def main_hydra(cfg: DictConfig):
     run_config = RunConfig()
     cfg = OmegaConf.merge(run_config, cfg)
-    # kwargs = to_plain_list(cfg)
-    log_level = getattr(logging, cfg.get("log_level", "INFO").upper())
+    cfg.train_config.peft_config = cfg.peft_config
 
+    log_level = getattr(logging, cfg.get("log_level", "INFO").upper())
     logging.basicConfig(level=log_level)
 
-    if cfg.get("debug", False):
-        import pdb
-
-        pdb.set_trace()
-
     inference(cfg)
 
 
25 changes: 2 additions & 23 deletions examples/seld_spatialsoundqa/model/slam_model_seld.py
@@ -74,26 +74,5 @@ def __init__(
             tokenizer,
             train_config,
             model_config,
-            **kwargs,
-        )
-
-    @torch.no_grad()
-    def inference(
-        self,
-        wav_path=None,
-        reverb_path=None,
-        prompt=None,
-        generation_config=None,
-        logits_processor=None,
-        stopping_criteria=None,
-        prefix_allowed_tokens_fn=None,
-        synced_gpus=None,
-        assistant_model=None,
-        streamer=None,
-        negative_prompt_ids=None,
-        negative_prompt_attention_mask=None,
-        **kwargs,
-    ):
-        #!TODO:
-        # inference for SELD model
-        pass
+            **kwargs
+        )
73 changes: 73 additions & 0 deletions examples/seld_spatialsoundqa/scripts/calculate_map.py
@@ -0,0 +1,73 @@
import os

import numpy as np
from sklearn import metrics

from tqdm import tqdm
import openai

openai.api_key = "your-openai-api-key"

def cosine_similarity(A, B):
    dot_product = np.dot(A, B)
    norm_A = np.linalg.norm(A)
    norm_B = np.linalg.norm(B)
    return dot_product / (norm_A * norm_B)

def get_embedding(text, model="text-embedding-ada-002"):
    # note: uses the legacy Embedding API from openai<1.0
    text = text.replace("\n", " ")
    return np.array(openai.Embedding.create(input=[text], model=model)['data'][0]['embedding'])

def calculate_stats(output, target):
    classes_num = target.shape[-1]
    stats = []

    for k in range(classes_num):
        # per-class average precision over all samples
        avg_precision = metrics.average_precision_score(target[:, k], output[:, k], average=None)
        stats.append({'AP': avg_precision})

    return stats

# Download these two files locally first; open()/np.load cannot read the HF blob URLs directly.
labels_path = 'class_labels_indices_subset.csv'  # https://huggingface.co/datasets/zhisheng01/SpatialAudio/blob/main/SpatialSoundQA/AudioSet/metadata/class_labels_indices_subset.csv
embeds_npy_path = 'audioset_class_embeds.npy'  # https://huggingface.co/datasets/zhisheng01/SpatialAudio/blob/main/SpatialSoundQA/AudioSet/metadata/audioset_class_embeds.npy

label2id = {}
with open(labels_path) as f:
    for idx, line in enumerate(f.readlines()[1:]):
        label = line.strip().split(',', 2)[-1]
        label2id[label.lower()] = idx
        # label2emb.append(get_embedding(label))

# label2emb = np.stack(label2emb)
# np.save(embeds_npy_path, label2emb)

total_labels_embeddings = np.load(embeds_npy_path)

one_hot_embeds = np.eye(355)

with open("decode_eval-stage2-classification_beam4_gt") as gt_f:
gt_lines = gt_f.readlines()
targets = []
for line in gt_lines:
target = np.array([one_hot_embeds[label2id[i]] for i in line.strip().split('\t', 1)[1].split("; ")]).sum(axis=0)
targets.append(target)
targets = np.stack(targets)


with open("decode_eval-stage2-classification_beam4_pred") as pred_f:
pred_lines = pred_f.readlines()
preds = []
for line in tqdm(pred_lines):
pred = line.strip().split('\t', 1)[1]
pred = get_embedding(pred)
pred = np.array([cosine_similarity(pred, embed) for embed in total_labels_embeddings])
preds.append(pred)

preds = np.stack(preds)

stats = calculate_stats(preds, targets)

AP = [stat['AP'] for stat in stats]
mAP = np.mean([stat['AP'] for stat in stats])
print("mAP: {:.6f}".format(mAP))
examples/seld_spatialsoundqa/scripts/decode_spatial-ast_qformer_llama_2_7b.sh
@@ -1,24 +1,24 @@
 #!/bin/bash
 #export PYTHONPATH=/root/whisper:$PYTHONPATH
-export CUDA_VISIBLE_DEVICES=2
+export CUDA_VISIBLE_DEVICES=0
 export TOKENIZERS_PARALLELISM=false
 # export CUDA_LAUNCH_BLOCKING=1
 
-SLAM_DIR=/mnt/lustre/hpc_stor03/sjtu_home/zhisheng.zheng/SLAM-LLM
+SLAM_DIR=/path/to/SLAM-LLM
 cd $SLAM_DIR
 code_dir=examples/seld_spatialsoundqa
 
-stage=stage1-clsdoa
-qa_data_root=/mnt/lustre/hpc_stor03/sjtu_home/zhisheng.zheng/data/SpatialAudio/closed-end
-reverb_data_root=/mnt/lustre/hpc_stor03/sjtu_home/zhisheng.zheng/data/SpatialAudio/reverb/mp3d
-anechoic_data_root=/mnt/lustre/hpc_stor03/sjtu_home/zhisheng.zheng/data/AudioSet
-
-audio_encoder_path=/mnt/lustre/hpc_stor03/sjtu_home/zhisheng.zheng/models/SpatialAST/SpatialAST.pth
-llm_path=/mnt/lustre/hpc_stor03/sjtu_pub/cxgroup/model/Llama-2-7b-hf
+audio_encoder_path=/data1/scratch/zhisheng/models/SpatialAST/SpatialAST.pth # https://huggingface.co/datasets/zhisheng01/SpatialAudio/blob/main/SpatialAST/finetuned.pth
+llm_path=/home/zhisheng/models/llama-2-hf # https://huggingface.co/meta-llama/Llama-2-7b-hf
+
+stage=stage2-single
+qa_data_root=/data3/scratch/zhisheng/SpatialAudio/SpatialSoundQA/closed-end # https://huggingface.co/datasets/zhisheng01/SpatialAudio/tree/main/SpatialSoundQA/closed-end
+reverb_data_root=/data3/scratch/zhisheng/SpatialAudio/SpatialSoundQA/mp3d_reverb # https://huggingface.co/datasets/zhisheng01/SpatialAudio/blob/main/SpatialSoundQA/mp3d_reverb.zip
+anechoic_data_root=/data3/scratch/zhisheng/SpatialAudio/SpatialSoundQA/AudioSet # https://huggingface.co/datasets/zhisheng01/SpatialAudio/tree/main/SpatialSoundQA/AudioSet
 
-split=eval
-output_dir=/mnt/lustre/hpc_stor03/sjtu_home/zhisheng.zheng/SLAM-LLM/outputs/bat-llama-2-spatialAST-8qformer-steplrwarmupkeep1e-4-stage1-clsdoa-20240519/
-ckpt_path=$output_dir/bat_epoch_3_step_3288
+split=eval-stage2-classification
+output_dir=?? # be same as in finetune script
+ckpt_path=$output_dir/bat_epoch_4_step_18223
 decode_log=$ckpt_path/decode_${split}_beam4
 
 # -m debugpy --listen 5678 --wait-for-client
@@ -30,29 +30,25 @@ python -u $code_dir/inference_seld_batch.py \
         ++model_config.llm_dim=4096 \
         ++model_config.encoder_name=SpatialAST \
         ++model_config.encoder_projector=q-former \
+        ++model_config.qformer_layers=8 \
         ++model_config.encoder_ckpt=$audio_encoder_path \
         ++dataset_config.test_split=${split} \
         ++dataset_config.stage=$stage \
         ++dataset_config.qa_data_root=$qa_data_root \
         ++dataset_config.anechoic_data_root=$anechoic_data_root \
         ++dataset_config.reverb_data_root=$reverb_data_root \
         ++dataset_config.fix_length_audio=64 \
         ++dataset_config.inference_mode=true \
-        ++train_config.model_name=bat \
+        ++train_config.model_name=BAT \
         ++train_config.freeze_encoder=true \
         ++train_config.freeze_llm=true \
         ++train_config.batching_strategy=custom \
         ++train_config.num_epochs=1 \
-        ++train_config.val_batch_size=8 \
-        ++train_config.num_workers_dataloader=2 \
+        ++train_config.val_batch_size=1 \
+        ++train_config.num_workers_dataloader=1 \
         ++train_config.output_dir=$output_dir \
-        ++peft_config.peft_method=llama_adapter \
+        ++train_config.use_peft=true \
+        ++peft_config.peft_method=lora \
         ++log_config.log_file=$output_dir/test.log \
         ++decode_log=$decode_log \
-        ++ckpt_path=$ckpt_path/model.pt \
-        # ++peft_ckpt=$ckpt_path \
-        # ++train_config.use_peft=true \
-        # ++train_config.peft_config.r=32 \
-        # ++dataset_config.normalize=true \
-        # ++model_config.encoder_projector=q-former \
-        # ++dataset_config.fix_length_audio=64 \
+        ++ckpt_path=$ckpt_path/model.pt
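The decode run writes parallel `${decode_log}_gt` and `${decode_log}_pred` files with one `key<TAB>text` pair per line; `calculate_map.py` above consumes exactly this layout. A minimal reader sketch (file names follow the script's defaults):

```python
# Hedged sketch: read the tab-separated decode logs produced by inference_seld_batch.py.
def read_decode_log(path):
    with open(path) as f:
        return [line.strip().split("\t", 1)[1] for line in f]

gts = read_decode_log("decode_eval-stage2-classification_beam4_gt")
preds = read_decode_log("decode_eval-stage2-classification_beam4_pred")
assert len(gts) == len(preds)  # lines are paired by position
```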
