diff --git a/examples/seld_spatialsoundqa/README.md b/examples/seld_spatialsoundqa/README.md
index 512d7490..2cfcb63e 100644
--- a/examples/seld_spatialsoundqa/README.md
+++ b/examples/seld_spatialsoundqa/README.md
@@ -1,39 +1,70 @@
 # SELD_SpatialSoundQA
 
-This repo hosts the code and models of "[BAT: Learning to Reason about Spatial Sounds with Large Language Models](https://arxiv.org/abs/2402.01591)" [ICML 2024 [bib](https://github.com/zszheng147/Spatial-AST#citation)].
+This repo hosts the code and models of "[BAT: Learning to Reason about Spatial Sounds with Large Language Models](https://arxiv.org/abs/2402.01591)" [ICML 2024 [bib](https://github.com/X-LANCE/SLAM-LLM/tree/main/examples/seld_spatialsoundqa#citation)].
 
 Checkout our [demo page](https://zhishengzheng.com/BAT/) and enjoy a QA game with spatial audio.
 
-## Performance and checkpoints
-Encoder | Projector | PEFT | LLM
-|---|---|---|---|
-[Spatial-AST](https://huggingface.co/zhisheng01/Bat/blob/main/spatial-ast.pth) | Q-Former | adapter |[llama-2-7b](https://huggingface.co/meta-llama/Llama-2-7b)
+## Performance evaluation on **SpatialSoundQA**
+We use [Spatial-AST](https://huggingface.co/datasets/zhisheng01/SpatialAudio/blob/main/SpatialAST/finetuned.pth) as the audio encoder and [llama-2-7b](https://huggingface.co/meta-llama/Llama-2-7b) as the LLM backbone. We fine-tune the model by adding a Q-Former and LoRA. To calculate mAP, you can refer to [calculate_map.py](https://github.com/X-LANCE/SLAM-LLM/blob/main/examples/seld_spatialsoundqa/scripts/calculate_map.py).
+![](./assets/performance.png)
+
+## Checkpoints
+Encoder | Projector | LLM |
+|---|---|---|
+[Spatial-AST](https://huggingface.co/datasets/zhisheng01/SpatialAudio/blob/main/SpatialAST/finetuned.pth) | [Q-Former](https://huggingface.co/datasets/zhisheng01/SpatialAudio/blob/main/BAT/model.pt) (~73.56M) | [llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b) |
+
+## Demo (Spatial Audio Inference)
+Try [`inference.ipynb`](https://github.com/X-LANCE/SLAM-LLM/blob/main/examples/seld_spatialsoundqa/inference.ipynb).
+
 
 ## Data preparation
 You need to prepare the data jsonl in this format. Below is an example.
-You can download the SpatialSoundQA dataset from [huggingface](https://huggingface.co/datasets/zhisheng01/SpatialSoundQA).
-```
-{"audio_id": "eval/audio/YI-HlrcP6Qg4", "reverb_id": "q9vSo1VnCiC/0.npy", "audio_id2": null, "reverb_id2": null, "question_id": 0, "question_type": "CLASSIFICATION", "question": "Enumerate the sound occurrences in the audio clip.", "answer": "accelerating, revving, vroom; car; vehicle"}
+You can download the SpatialSoundQA dataset from [SpatialAudio](https://huggingface.co/datasets/zhisheng01/SpatialAudio).
+```json
+{
+    "audio_id": "eval/audio/YI-HlrcP6Qg4",
+    "reverb_id": "q9vSo1VnCiC/0.npy",
+    "audio_id2": null,
+    "reverb_id2": null,
+    "question_id": 0,
+    "question_type": "CLASSIFICATION",
+    "question": "Enumerate the sound occurrences in the audio clip.",
+    "answer": "accelerating, revving, vroom; car; vehicle"
+}
+
 ...
-{"audio_id": "eval/audio/YZX2fVPmUidA", "reverb_id": "q9vSo1VnCiC/32.npy", "audio_id2": "eval/audio/YjNjUU01quLs", "reverb_id2": "q9vSo1VnCiC/31.npy", "question_id": 58, "question_type": "MIXUP_NONBINARY_DISTANCE", "question": "How far away is the sound of the banjo from the sound of the whack, thwack?", "answer": "2m"} + +{ + "audio_id": "eval/audio/YZX2fVPmUidA", + "reverb_id": "q9vSo1VnCiC/32.npy", + "audio_id2": "eval/audio/YjNjUU01quLs", + "reverb_id2": "q9vSo1VnCiC/31.npy", + "question_id": 58, + "question_type": "MIXUP_NONBINARY_DISTANCE", + "question": "How far away is the sound of the banjo from the sound of the whack, thwack?", + "answer": "2m" +} ``` ## Train a new model ```bash -bash examples/seld_spatialsoundqa/scripts/finetune_spatial-ast_qformer_llama_2_7b.sh +cd examples/seld_spatialsoundqa/ +bash scripts/finetune_spatial-ast_qformer_llama_2_7b.sh ``` ## Decoding with checkpoints ```bash -bash examples/seld_spatialsoundqa/scripts/decode_spatial-ast_qformer_llama_2_7b.sh +cd examples/seld_spatialsoundqa/ +bash scripts/decode_spatial-ast_qformer_llama_2_7b.sh ``` ## TODO - [x] Decode with checkpoints - [x] Upload SpatialSoundQA dataset -- [ ] Upload pretrained checkpoints -- [ ] Update model performance +- [x] Upload pretrained checkpoints +- [x] Update model performance ## Citation ``` diff --git a/examples/seld_spatialsoundqa/assets/74.npy b/examples/seld_spatialsoundqa/assets/74.npy new file mode 100644 index 00000000..cc16faaa Binary files /dev/null and b/examples/seld_spatialsoundqa/assets/74.npy differ diff --git a/examples/seld_spatialsoundqa/assets/75.npy b/examples/seld_spatialsoundqa/assets/75.npy new file mode 100644 index 00000000..f1644629 Binary files /dev/null and b/examples/seld_spatialsoundqa/assets/75.npy differ diff --git a/examples/seld_spatialsoundqa/assets/YCqvbWnTBfTk.wav b/examples/seld_spatialsoundqa/assets/YCqvbWnTBfTk.wav new file mode 100644 index 00000000..e84270b7 Binary files /dev/null and b/examples/seld_spatialsoundqa/assets/YCqvbWnTBfTk.wav differ diff --git a/examples/seld_spatialsoundqa/assets/Yq4Z8j3IalYs.wav b/examples/seld_spatialsoundqa/assets/Yq4Z8j3IalYs.wav new file mode 100644 index 00000000..a5ec3f01 Binary files /dev/null and b/examples/seld_spatialsoundqa/assets/Yq4Z8j3IalYs.wav differ diff --git a/examples/seld_spatialsoundqa/assets/performance.png b/examples/seld_spatialsoundqa/assets/performance.png new file mode 100644 index 00000000..0393d3b9 Binary files /dev/null and b/examples/seld_spatialsoundqa/assets/performance.png differ diff --git a/examples/seld_spatialsoundqa/dataset/spatial_audio_dataset.py b/examples/seld_spatialsoundqa/dataset/spatial_audio_dataset.py index f986250f..c3906258 100644 --- a/examples/seld_spatialsoundqa/dataset/spatial_audio_dataset.py +++ b/examples/seld_spatialsoundqa/dataset/spatial_audio_dataset.py @@ -37,9 +37,8 @@ def __init__( split, ): super().__init__() - dataset_path = os.path.join(dataset_config['qa_data_root'], dataset_config['stage'], split + '.jsonl') - with open(dataset_path) as f: - self.data = [json.loads(line) for line in f.readlines()] + dataset_path = os.path.join(dataset_config['qa_data_root'], dataset_config['stage'], split + '.json') + self.data = json.load(open(dataset_path))["data"] self.anechoic_data_root = dataset_config['anechoic_data_root'] # which is AudioSet in this case self.reverb_data_root = dataset_config['reverb_data_root'] diff --git a/examples/seld_spatialsoundqa/finetune_seld.py b/examples/seld_spatialsoundqa/finetune_seld.py index 3b7a959d..605f3582 100644 
diff --git a/examples/seld_spatialsoundqa/finetune_seld.py b/examples/seld_spatialsoundqa/finetune_seld.py
index 3b7a959d..605f3582 100644
--- a/examples/seld_spatialsoundqa/finetune_seld.py
+++ b/examples/seld_spatialsoundqa/finetune_seld.py
@@ -1,5 +1,6 @@
 import hydra
 import logging
+from typing import Optional
 from dataclasses import dataclass, field
 from omegaconf import DictConfig, ListConfig, OmegaConf
 
@@ -16,32 +17,20 @@ class RunConfig:
     peft_config: PeftConfig = field(default_factory=PeftConfig)
     debug: bool = field(default=False, metadata={"help": "Use pdb when true"})
     metric: str = field(default="acc", metadata={"help": "The metric for evaluation"})
-    ckpt_path: str = field(
-        default="output/model.pt", metadata={"help": "The path to projector checkpoint"}
+    ckpt_path: Optional[str] = field(
+        default=None, metadata={"help": "The path to projector checkpoint"}
     )
 
 @hydra.main(config_name=None, version_base=None)
 def main_hydra(cfg: DictConfig):
     run_config = RunConfig()
     cfg = OmegaConf.merge(run_config, cfg)
-    def to_plain_list(cfg_item):
-        if isinstance(cfg_item, ListConfig):
-            return OmegaConf.to_container(cfg_item, resolve=True)
-        elif isinstance(cfg_item, DictConfig):
-            return {k: to_plain_list(v) for k, v in cfg_item.items()}
-        else:
-            return cfg_item
-
-    # kwargs = to_plain_list(cfg)
-    kwargs = cfg
-    log_level = getattr(logging, kwargs.get("log_level", "INFO").upper())
+    cfg.train_config.peft_config = cfg.peft_config
+
+    log_level = getattr(logging, cfg.get("log_level", "INFO").upper())
     logging.basicConfig(level=log_level)
-
-    if kwargs.get("debug", False):
-        import pdb;
-        pdb.set_trace()
-    train(kwargs)
+    train(cfg)
 
 if __name__ == "__main__":
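The `finetune_seld.py` change above leans on OmegaConf's structured configs: merging the `RunConfig` dataclass with the Hydra CLI config keeps the dataclass types attached, which is presumably why `ckpt_path` must be declared `Optional[str]` before it can default to `None`. Below is a standalone sketch of that behavior using a hypothetical stand-in dataclass, not the repo's real config classes.

```python
from dataclasses import dataclass
from typing import Optional

from omegaconf import OmegaConf

@dataclass
class RunConfigSketch:
    # Stand-in for RunConfig; only two fields, for illustration.
    metric: str = "acc"
    ckpt_path: Optional[str] = None  # None = start without a projector checkpoint

# Merging a structured (dataclass-backed) config with override values
# validates types against the dataclass annotations.
base = OmegaConf.structured(RunConfigSketch)
overrides = OmegaConf.from_dotlist(["ckpt_path=output/model.pt"])
cfg = OmegaConf.merge(base, overrides)
print(cfg.ckpt_path)  # -> output/model.pt

# With a plain `str` annotation, assigning None would raise a ValidationError;
# the Optional[str] annotation is what makes a missing checkpoint legal.
cfg.ckpt_path = None
print(cfg.ckpt_path)  # -> None
```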
diff --git a/examples/seld_spatialsoundqa/inference.ipynb b/examples/seld_spatialsoundqa/inference.ipynb
new file mode 100644
index 00000000..03e29bbe
--- /dev/null
+++ b/examples/seld_spatialsoundqa/inference.ipynb
@@ -0,0 +1,786 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Let's dive into a spatial sound world."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "\n",
+    "import numpy as np\n",
+    "import soundfile as sf\n",
+    "from scipy import signal\n",
+    "from IPython.display import Audio\n",
+    "\n",
+    "import torch\n",
+    "\n",
+    "from dataset.spatial_audio_dataset import format_prompt, SpatialAudioDatasetJsonl"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Let's load and listen to anechoic audio...\n",
+      "Audio 1: Drum; Percussion\n"
+     ]
+    },
+    {
+     "data": {
+      "text/html": [
+       "\n",
+       " \n",
+       " "
+      ],
+      "text/plain": [
+       ""
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Audio 2: Emergency vehicle; Fire engine, fire truck (siren); Siren\n"
+     ]
+    },
+    {
+     "data": {
+      "text/html": [
+       "\n",
+       " \n",
+       " "
+      ],
+      "text/plain": [
+       ""
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "audio_path1 = \"./assets/YCqvbWnTBfTk.wav\"\n",
+    "audio_path2 = \"./assets/Yq4Z8j3IalYs.wav\"\n",
+    "\n",
+    "print(\"Let's load and listen to anechoic audio...\")\n",
+    "print(\"Audio 1: Drum; Percussion\")\n",
+    "display(Audio(audio_path1))\n",
+    "\n",
+    "print(\"Audio 2: Emergency vehicle; Fire engine, fire truck (siren); Siren\")\n",
+    "display(Audio(audio_path2))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Let's load and listen to reverb audio 1 (w/o mixup)...\n"
+     ]
+    },
+    {
+     "data": {
+      "text/html": [
+       "\n",
+       " \n",
+       " "
+      ],
+      "text/plain": [
+       ""
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Let's load and listen to reverb audio 2 (w/o mixup)...\n"
+     ]
+    },
+    {
+     "data": {
+      "text/html": [
+       "\n",
+       " \n",
+       " "
+      ],
+      "text/plain": [
+       ""
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "def reverb_waveform(audio_path, reverb_path):\n",
+    "    waveform, sr = sf.read(audio_path)\n",
+    "    if len(waveform.shape) > 1:\n",
+    "        waveform = waveform[:, 0]\n",
+    "    if sr != 32000: #! Please make sure the audio is 32000 Hz\n",
+    "        waveform = signal.resample_poly(waveform, 32000, sr)\n",
+    "        sr = 32000\n",
+    "    waveform = SpatialAudioDatasetJsonl.normalize_audio(waveform, -14.0).reshape(1, -1)\n",
+    "    reverb = np.load(reverb_path)\n",
+    "    waveform = signal.fftconvolve(waveform, reverb, mode='full')\n",
+    "    return waveform, sr\n",
+    "\n",
+    "reverb_path1 = \"./assets/74.npy\"\n",
+    "reverb_path2 = \"./assets/75.npy\"\n",
+    "\"\"\"\n",
+    "{\"fname\": \"q9vSo1VnCiC/74.npy\", \"agent_position\": \"-12.8775,0.0801,8.415\", \"sensor_position\": \"-12.8775,1.5801,8.415\", \"source_position\": \"-13.4677,1.2183,8.6525\",},\n",
+    "{\"fname\": \"q9vSo1VnCiC/75.npy\", \"agent_position\": \"-12.8775,0.0801,8.415\", \"sensor_position\": \"-12.8775,1.5801,8.415\", \"source_position\": \"-11.8976,1.0163,8.9789\",}\n",
+    "\"\"\"\n",
+    "\n",
+    "print(\"Let's load and listen to reverb audio 1 (w/o mixup)...\")\n",
+    "waveform1, _ = reverb_waveform(audio_path1, reverb_path1)\n",
+    "display(Audio(waveform1, rate=32000))\n",
+    "\n",
+    "print(\"Let's load and listen to reverb audio 2 (w/o mixup)...\")\n",
+    "waveform2, _ = reverb_waveform(audio_path2, reverb_path2)\n",
+    "display(Audio(waveform2, rate=32000))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Now let's mix them up!\n"
+     ]
+    },
+    {
+     "data": {
+      "text/html": [
+       "\n",
+       " \n",
+       " "
+      ],
+      "text/plain": [
+       ""
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "print(\"Now let's mix them up!\")\n",
+    "if waveform1.shape[1] < waveform2.shape[1]:\n",
+    "    waveform2 = waveform2[:, :waveform1.shape[1]]\n",
+    "else:\n",
+    "    waveform1 = waveform1[:, :waveform2.shape[1]]\n",
+    "waveform = (waveform1 + waveform2) / 2\n",
+    "display(Audio(waveform, rate=32000))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[\"Based on the audio you've heard, refer to the instruction and provide a response.\\n\\n### Instruction:\\nWhat is the distance between the sound of the drum and the sound of the siren?\\n\\n### Response:\", \"Based on the audio you've heard, refer to the instruction and provide a response.\\n\\n### Instruction:\\nWhat is the sound on the right side of the sound of the drum?\\n\\n### Response:\", \"Based on the audio you've heard, refer to the instruction and provide a response.\\n\\n### Instruction:\\nAre you able to detect the percussion's sound coming from the left and the emergency vehicle's sounds from the right?\\n\\n### Response:\"]\n"
+     ]
+    }
+   ],
+   "source": [
+    "prompts = [\n",
+    "    \"What is the distance between the sound of the drum and the sound of the siren?\",\n",
+    "    \"What is the sound on the right side of the sound of the drum?\",\n",
+    "    \"Are you able to detect the percussion's sound coming from the left and the emergency vehicle's sounds from the right?\",\n",
+    "]\n",
+    "\n",
+    "gt_answers = [\n",
+    "    \"1.5m\",\n",
+    "    \"emergency vehicle; fire engine, fire truck (siren); siren\",\n",
+    "    \"Yes\"\n",
+    "]\n",
+    "\n",
+    "prompts = [format_prompt(prompt) for prompt in prompts]\n",
+    "print(prompts)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/home/zhisheng/miniconda3/envs/speech/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+      "  from .autonotebook import tqdm as notebook_tqdm\n",
+      "Loading checkpoint shards: 0%| | 0/2 [00:00
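For the mAP number referenced in the README's performance section, the repo points to `scripts/calculate_map.py`. As a rough illustration only (not the repo's actual script), multi-label answers such as "car; siren" can be scored with macro-averaged average precision; the label vocabulary and scores below are made up for the example.

```python
import numpy as np
from sklearn.metrics import average_precision_score

classes = ["car", "drum", "siren", "speech"]  # assumed label space, for illustration
# Ground truth: sample 0 contains car + siren, sample 1 contains drum.
y_true = np.array([[1, 0, 1, 0],
                   [0, 1, 0, 0]])
# Per-class model scores (e.g., derived from the generated answer text).
y_score = np.array([[0.9, 0.1, 0.7, 0.2],
                    [0.2, 0.8, 0.1, 0.1]])

# Macro-averaged average precision over the label set.
mAP = average_precision_score(y_true, y_score, average="macro")
print(f"mAP over {len(classes)} classes: {mAP:.3f}")
```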