diff --git a/examples/seld_spatialsoundqa/README.md b/examples/seld_spatialsoundqa/README.md
index 512d7490..2cfcb63e 100644
--- a/examples/seld_spatialsoundqa/README.md
+++ b/examples/seld_spatialsoundqa/README.md
@@ -1,39 +1,70 @@
# SELD_SpatialSoundQA
-This repo hosts the code and models of "[BAT: Learning to Reason about Spatial Sounds with Large Language Models](https://arxiv.org/abs/2402.01591)" [ICML 2024 [bib](https://github.com/zszheng147/Spatial-AST#citation)].
+This repo hosts the code and models of "[BAT: Learning to Reason about Spatial Sounds with Large Language Models](https://arxiv.org/abs/2402.01591)" [ICML 2024 [bib](https://github.com/X-LANCE/SLAM-LLM/tree/main/examples/seld_spatialsoundqa#citation)].
Checkout our [demo page](https://zhishengzheng.com/BAT/) and enjoy a QA game with spatial audio.
-## Performance and checkpoints
-Encoder | Projector | PEFT | LLM
-|---|---|---|---|
-[Spatial-AST](https://huggingface.co/zhisheng01/Bat/blob/main/spatial-ast.pth) | Q-Former | adapter |[llama-2-7b](https://huggingface.co/meta-llama/Llama-2-7b)
+## Performance evaluation on **SpatialSoundQA**
+We use [Spatial-AST](https://huggingface.co/datasets/zhisheng01/SpatialAudio/blob/main/SpatialAST/finetuned.pth) as audio encoder, [llama-2-7b](https://huggingface.co/meta-llama/Llama-2-7b) as LLM backbone. We finetune the model by adding Q-Former and LoRA. To calculate MAP, you can refer to [calculate_map.py](https://github.com/X-LANCE/SLAM-LLM/blob/main/examples/seld_spatialsoundqa/scripts/calculate_map.py)
+
+
+
+## Checkpoints
+Encoder | Projector | LLM |
+|---|---|---|
+[Spatial-AST](https://huggingface.co/datasets/zhisheng01/SpatialAudio/blob/main/SpatialAST/finetuned.pth) | [Q-former](https://huggingface.co/datasets/zhisheng01/SpatialAudio/blob/main/BAT/model.pt)(~73.56M) | [llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b) |
+
+## Demo (Spatial Audio Inference)
+Try [`inference.ipynb`](https://github.com/X-LANCE/SLAM-LLM/blob/main/examples/seld_spatialsoundqa/inference.ipynb).
+
## Data preparation
You need to prepare the data jsonl in this format. Below is an example.
-You can download the SpatialSoundQA dataset from [huggingface](https://huggingface.co/datasets/zhisheng01/SpatialSoundQA).
-```
-{"audio_id": "eval/audio/YI-HlrcP6Qg4", "reverb_id": "q9vSo1VnCiC/0.npy", "audio_id2": null, "reverb_id2": null, "question_id": 0, "question_type": "CLASSIFICATION", "question": "Enumerate the sound occurrences in the audio clip.", "answer": "accelerating, revving, vroom; car; vehicle"}
+You can download the SpatialSoundQA dataset from [SpatialAudio](https://huggingface.co/datasets/zhisheng01/SpatialAudio).
+```json
+{
+ "audio_id": "eval/audio/YI-HlrcP6Qg4",
+ "reverb_id": "q9vSo1VnCiC/0.npy",
+ "audio_id2": null,
+ "reverb_id2": null,
+ "question_id": 0,
+ "question_type": "CLASSIFICATION",
+ "question": "Enumerate the sound occurrences in the audio clip.",
+ "answer": "accelerating, revving, vroom; car; vehicle"
+}
+
...
-{"audio_id": "eval/audio/YZX2fVPmUidA", "reverb_id": "q9vSo1VnCiC/32.npy", "audio_id2": "eval/audio/YjNjUU01quLs", "reverb_id2": "q9vSo1VnCiC/31.npy", "question_id": 58, "question_type": "MIXUP_NONBINARY_DISTANCE", "question": "How far away is the sound of the banjo from the sound of the whack, thwack?", "answer": "2m"}
+
+{
+ "audio_id": "eval/audio/YZX2fVPmUidA",
+ "reverb_id": "q9vSo1VnCiC/32.npy",
+ "audio_id2": "eval/audio/YjNjUU01quLs",
+ "reverb_id2": "q9vSo1VnCiC/31.npy",
+ "question_id": 58,
+ "question_type": "MIXUP_NONBINARY_DISTANCE",
+ "question": "How far away is the sound of the banjo from the sound of the whack, thwack?",
+ "answer": "2m"
+}
```
## Train a new model
```bash
-bash examples/seld_spatialsoundqa/scripts/finetune_spatial-ast_qformer_llama_2_7b.sh
+cd examples/seld_spatialsoundqa/
+bash scripts/finetune_spatial-ast_qformer_llama_2_7b.sh
```
## Decoding with checkpoints
```bash
-bash examples/seld_spatialsoundqa/scripts/decode_spatial-ast_qformer_llama_2_7b.sh
+cd examples/seld_spatialsoundqa/
+bash scripts/decode_spatial-ast_qformer_llama_2_7b.sh
```
## TODO
- [x] Decode with checkpoints
- [x] Upload SpatialSoundQA dataset
-- [ ] Upload pretrained checkpoints
-- [ ] Update model performance
+- [x] Upload pretrained checkpoints
+- [x] Update model performance
## Citation
```
diff --git a/examples/seld_spatialsoundqa/assets/74.npy b/examples/seld_spatialsoundqa/assets/74.npy
new file mode 100644
index 00000000..cc16faaa
Binary files /dev/null and b/examples/seld_spatialsoundqa/assets/74.npy differ
diff --git a/examples/seld_spatialsoundqa/assets/75.npy b/examples/seld_spatialsoundqa/assets/75.npy
new file mode 100644
index 00000000..f1644629
Binary files /dev/null and b/examples/seld_spatialsoundqa/assets/75.npy differ
diff --git a/examples/seld_spatialsoundqa/assets/YCqvbWnTBfTk.wav b/examples/seld_spatialsoundqa/assets/YCqvbWnTBfTk.wav
new file mode 100644
index 00000000..e84270b7
Binary files /dev/null and b/examples/seld_spatialsoundqa/assets/YCqvbWnTBfTk.wav differ
diff --git a/examples/seld_spatialsoundqa/assets/Yq4Z8j3IalYs.wav b/examples/seld_spatialsoundqa/assets/Yq4Z8j3IalYs.wav
new file mode 100644
index 00000000..a5ec3f01
Binary files /dev/null and b/examples/seld_spatialsoundqa/assets/Yq4Z8j3IalYs.wav differ
diff --git a/examples/seld_spatialsoundqa/assets/performance.png b/examples/seld_spatialsoundqa/assets/performance.png
new file mode 100644
index 00000000..0393d3b9
Binary files /dev/null and b/examples/seld_spatialsoundqa/assets/performance.png differ
diff --git a/examples/seld_spatialsoundqa/dataset/spatial_audio_dataset.py b/examples/seld_spatialsoundqa/dataset/spatial_audio_dataset.py
index f986250f..c3906258 100644
--- a/examples/seld_spatialsoundqa/dataset/spatial_audio_dataset.py
+++ b/examples/seld_spatialsoundqa/dataset/spatial_audio_dataset.py
@@ -37,9 +37,8 @@ def __init__(
split,
):
super().__init__()
- dataset_path = os.path.join(dataset_config['qa_data_root'], dataset_config['stage'], split + '.jsonl')
- with open(dataset_path) as f:
- self.data = [json.loads(line) for line in f.readlines()]
+ dataset_path = os.path.join(dataset_config['qa_data_root'], dataset_config['stage'], split + '.json')
+ self.data = json.load(open(dataset_path))["data"]
self.anechoic_data_root = dataset_config['anechoic_data_root'] # which is AudioSet in this case
self.reverb_data_root = dataset_config['reverb_data_root']
diff --git a/examples/seld_spatialsoundqa/finetune_seld.py b/examples/seld_spatialsoundqa/finetune_seld.py
index 3b7a959d..605f3582 100644
--- a/examples/seld_spatialsoundqa/finetune_seld.py
+++ b/examples/seld_spatialsoundqa/finetune_seld.py
@@ -1,5 +1,6 @@
import hydra
import logging
+from typing import Optional
from dataclasses import dataclass, field
from omegaconf import DictConfig, ListConfig, OmegaConf
@@ -16,32 +17,20 @@ class RunConfig:
peft_config: PeftConfig = field(default_factory=PeftConfig)
debug: bool = field(default=False, metadata={"help": "Use pdb when true"})
metric: str = field(default="acc", metadata={"help": "The metric for evaluation"})
- ckpt_path: str = field(
- default="output/model.pt", metadata={"help": "The path to projector checkpoint"}
+ ckpt_path: Optional[str] = field(
+ default=None, metadata={"help": "The path to projector checkpoint"}
)
@hydra.main(config_name=None, version_base=None)
def main_hydra(cfg: DictConfig):
run_config = RunConfig()
cfg = OmegaConf.merge(run_config, cfg)
- def to_plain_list(cfg_item):
- if isinstance(cfg_item, ListConfig):
- return OmegaConf.to_container(cfg_item, resolve=True)
- elif isinstance(cfg_item, DictConfig):
- return {k: to_plain_list(v) for k, v in cfg_item.items()}
- else:
- return cfg_item
-
- # kwargs = to_plain_list(cfg)
- kwargs = cfg
- log_level = getattr(logging, kwargs.get("log_level", "INFO").upper())
+ cfg.train_config.peft_config = cfg.peft_config
+
+ log_level = getattr(logging, cfg.get("log_level", "INFO").upper())
logging.basicConfig(level=log_level)
-
- if kwargs.get("debug", False):
- import pdb;
- pdb.set_trace()
- train(kwargs)
+ train(cfg)
if __name__ == "__main__":
diff --git a/examples/seld_spatialsoundqa/inference.ipynb b/examples/seld_spatialsoundqa/inference.ipynb
new file mode 100644
index 00000000..03e29bbe
--- /dev/null
+++ b/examples/seld_spatialsoundqa/inference.ipynb
@@ -0,0 +1,786 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Let's dive into a spatial sound world."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "\n",
+ "import numpy as np\n",
+ "import soundfile as sf\n",
+ "from scipy import signal\n",
+ "from IPython.display import Audio\n",
+ "\n",
+ "import torch\n",
+ "\n",
+ "from dataset.spatial_audio_dataset import format_prompt, SpatialAudioDatasetJsonl"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Let's load and listen to anechoic audio...\n",
+ "Audio 1: Drum; Percussion\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Audio 2: Emergency vehicle; Fire engine, fire truck (siren); Siren\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "audio_path1 = \"./assets/YCqvbWnTBfTk.wav\"\n",
+ "audio_path2 = \"./assets/Yq4Z8j3IalYs.wav\"\n",
+ "\n",
+ "print(\"Let's load and listen to anechoic audio...\")\n",
+ "print(\"Audio 1: Drum; Percussion\")\n",
+ "display(Audio(audio_path1))\n",
+ "\n",
+ "print(\"Audio 2: Emergency vehicle; Fire engine, fire truck (siren); Siren\")\n",
+ "display(Audio(audio_path2))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Let's load and listen to reverb audio 1 (w/o mixup)...\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Let's load and listen to reverb audio 2 (w/o mixup)...\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "def reverb_waveform(audio_path, reverb_path):\n",
+ " waveform, sr = sf.read(audio_path)\n",
+ " if len(waveform.shape) > 1:\n",
+ " waveform = waveform[:, 0]\n",
+ " if sr != 32000: #! Please make sure the audio is 32000 Hz\n",
+ " waveform = signal.resample_poly(waveform, 32000, sr)\n",
+ " sr = 32000\n",
+ " waveform = SpatialAudioDatasetJsonl.normalize_audio(waveform, -14.0).reshape(1, -1)\n",
+ " reverb = np.load(reverb_path)\n",
+ " waveform = signal.fftconvolve(waveform, reverb, mode='full')\n",
+ " return waveform, sr\n",
+ "\n",
+ "reverb_path1 = \"./assets/74.npy\"\n",
+ "reverb_path2 = \"./assets/75.npy\"\n",
+ "\"\"\"\n",
+ "{\"fname\": \"q9vSo1VnCiC/74.npy\", \"agent_position\": \"-12.8775,0.0801,8.415\", \"sensor_position\": \"-12.8775,1.5801,8.415\", \"source_position\": \"-13.4677,1.2183,8.6525\",},\n",
+ "{\"fname\": \"q9vSo1VnCiC/75.npy\", \"agent_position\": \"-12.8775,0.0801,8.415\", \"sensor_position\": \"-12.8775,1.5801,8.415\", \"source_position\": \"-11.8976,1.0163,8.9789\",}\n",
+ "\"\"\"\n",
+ "\n",
+ "print(\"Let's load and listen to reverb audio 1 (w/o mixup)...\")\n",
+ "waveform1, _ = reverb_waveform(audio_path1, reverb_path1)\n",
+ "display(Audio(waveform1, rate=32000))\n",
+ "\n",
+ "print(\"Let's load and listen to reverb audio 2 (w/o mixup)...\")\n",
+ "waveform2, _ = reverb_waveform(audio_path2, reverb_path2)\n",
+ "display(Audio(waveform2, rate=32000))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Now let's mix them up!\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "print(\"Now let's mix them up!\")\n",
+ "if waveform1.shape[1] < waveform2.shape[1]:\n",
+ " waveform2 = waveform2[:, :waveform1.shape[1]]\n",
+ "else:\n",
+ " waveform1 = waveform1[:, :waveform2.shape[1]]\n",
+ "waveform = (waveform1 + waveform2) / 2\n",
+ "display(Audio(waveform, rate=32000))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[\"Based on the audio you've heard, refer to the instruction and provide a response.\\n\\n### Instruction:\\nWhat is the distance between the sound of the drum and the sound of the siren?\\n\\n### Response:\", \"Based on the audio you've heard, refer to the instruction and provide a response.\\n\\n### Instruction:\\nWhat is the sound on the right side of the sound of the drum?\\n\\n### Response:\", \"Based on the audio you've heard, refer to the instruction and provide a response.\\n\\n### Instruction:\\nAre you able to detect the percussion's sound coming from the left and the emergency vehicle's sounds from the right?\\n\\n### Response:\"]\n"
+ ]
+ }
+ ],
+ "source": [
+ "prompts = [\n",
+ " \"What is the distance between the sound of the drum and the sound of the siren?\",\n",
+ " \"What is the sound on the right side of the sound of the drum?\",\n",
+ " \"Are you able to detect the percussion's sound coming from the left and the emergency vehicle's sounds from the right?\",\n",
+ "]\n",
+ "\n",
+ "gt_answers = [\n",
+ " \"1.5m\",\n",
+ " \"emergency vehicle; fire engine, fire truck (siren); siren\",\n",
+ " \"Yes\"\n",
+ "]\n",
+ "\n",
+ "prompts = [format_prompt(prompt) for prompt in prompts]\n",
+ "print(prompts)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/zhisheng/miniconda3/envs/speech/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+ " from .autonotebook import tqdm as notebook_tqdm\n",
+ "Loading checkpoint shards: 0%| | 0/2 [00:00, ?it/s]/home/zhisheng/miniconda3/envs/speech/lib/python3.8/site-packages/torch/_utils.py:776: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()\n",
+ " return self.fget.__get__(instance, owner)()\n",
+ "Loading checkpoint shards: 50%|█████ | 1/2 [00:07<00:07, 7.04s/it]"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Loading checkpoint shards: 100%|██████████| 2/2 [00:09<00:00, 4.79s/it]\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "trainable params: 4,194,304 || all params: 6,742,609,920 || trainable%: 0.0622\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "slam_model_seld(\n",
+ " (encoder): BinauralEncoder(\n",
+ " (patch_embed): PatchEmbed_new(\n",
+ " (proj): Conv2d(1, 768, kernel_size=(16, 16), stride=(16, 16))\n",
+ " )\n",
+ " (pos_drop): Dropout(p=0.0, inplace=False)\n",
+ " (blocks): ModuleList(\n",
+ " (0): Block(\n",
+ " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
+ " (attn): Attention(\n",
+ " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n",
+ " (attn_drop): Dropout(p=0.0, inplace=False)\n",
+ " (proj): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (proj_drop): Dropout(p=0.0, inplace=False)\n",
+ " )\n",
+ " (drop_path): Identity()\n",
+ " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
+ " (mlp): Mlp(\n",
+ " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " (act): GELU(approximate='none')\n",
+ " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (drop): Dropout(p=0.0, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (1-11): 11 x Block(\n",
+ " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
+ " (attn): Attention(\n",
+ " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n",
+ " (attn_drop): Dropout(p=0.0, inplace=False)\n",
+ " (proj): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (proj_drop): Dropout(p=0.0, inplace=False)\n",
+ " )\n",
+ " (drop_path): DropPath()\n",
+ " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
+ " (mlp): Mlp(\n",
+ " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " (act): GELU(approximate='none')\n",
+ " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (drop): Dropout(p=0.0, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " (head): Linear(in_features=768, out_features=355, bias=True)\n",
+ " (spectrogram_extractor): STFT(\n",
+ " (conv_real): Conv1d(1, 513, kernel_size=(1024,), stride=(320,), bias=False)\n",
+ " (conv_imag): Conv1d(1, 513, kernel_size=(1024,), stride=(320,), bias=False)\n",
+ " )\n",
+ " (logmel_extractor): LogmelFilterBank()\n",
+ " (conv_downsample): Sequential(\n",
+ " (0): Conv2d(4, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (2): GELU(approximate='none')\n",
+ " )\n",
+ " (bn): BatchNorm2d(2, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)\n",
+ " )\n",
+ " (llm): PeftModelForCausalLM(\n",
+ " (base_model): LoraModel(\n",
+ " (model): LlamaForCausalLM(\n",
+ " (model): LlamaModel(\n",
+ " (embed_tokens): Embedding(32000, 4096, padding_idx=0)\n",
+ " (layers): ModuleList(\n",
+ " (0-31): 32 x LlamaDecoderLayer(\n",
+ " (self_attn): LlamaAttention(\n",
+ " (q_proj): lora.Linear(\n",
+ " (base_layer): Linear(in_features=4096, out_features=4096, bias=False)\n",
+ " (lora_dropout): ModuleDict(\n",
+ " (default): Dropout(p=0.05, inplace=False)\n",
+ " )\n",
+ " (lora_A): ModuleDict(\n",
+ " (default): Linear(in_features=4096, out_features=8, bias=False)\n",
+ " )\n",
+ " (lora_B): ModuleDict(\n",
+ " (default): Linear(in_features=8, out_features=4096, bias=False)\n",
+ " )\n",
+ " (lora_embedding_A): ParameterDict()\n",
+ " (lora_embedding_B): ParameterDict()\n",
+ " (lora_magnitude_vector): ModuleDict()\n",
+ " )\n",
+ " (k_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
+ " (v_proj): lora.Linear(\n",
+ " (base_layer): Linear(in_features=4096, out_features=4096, bias=False)\n",
+ " (lora_dropout): ModuleDict(\n",
+ " (default): Dropout(p=0.05, inplace=False)\n",
+ " )\n",
+ " (lora_A): ModuleDict(\n",
+ " (default): Linear(in_features=4096, out_features=8, bias=False)\n",
+ " )\n",
+ " (lora_B): ModuleDict(\n",
+ " (default): Linear(in_features=8, out_features=4096, bias=False)\n",
+ " )\n",
+ " (lora_embedding_A): ParameterDict()\n",
+ " (lora_embedding_B): ParameterDict()\n",
+ " (lora_magnitude_vector): ModuleDict()\n",
+ " )\n",
+ " (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
+ " (rotary_emb): LlamaRotaryEmbedding()\n",
+ " )\n",
+ " (mlp): LlamaMLP(\n",
+ " (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)\n",
+ " (up_proj): Linear(in_features=4096, out_features=11008, bias=False)\n",
+ " (down_proj): Linear(in_features=11008, out_features=4096, bias=False)\n",
+ " (act_fn): SiLU()\n",
+ " )\n",
+ " (input_layernorm): LlamaRMSNorm()\n",
+ " (post_attention_layernorm): LlamaRMSNorm()\n",
+ " )\n",
+ " )\n",
+ " (norm): LlamaRMSNorm()\n",
+ " )\n",
+ " (lm_head): Linear(in_features=4096, out_features=32000, bias=False)\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " (encoder_projector): EncoderProjectorQFormer(\n",
+ " (qformer): Blip2QFormerModel(\n",
+ " (layernorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " (encoder): Blip2QFormerEncoder(\n",
+ " (layer): ModuleList(\n",
+ " (0): Blip2QFormerLayer(\n",
+ " (attention): Blip2QFormerAttention(\n",
+ " (attention): Blip2QFormerMultiHeadAttention(\n",
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (output): Blip2QFormerSelfOutput(\n",
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (crossattention): Blip2QFormerAttention(\n",
+ " (attention): Blip2QFormerMultiHeadAttention(\n",
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (output): Blip2QFormerSelfOutput(\n",
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (intermediate_query): Blip2QFormerIntermediate(\n",
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " (intermediate_act_fn): GELUActivation()\n",
+ " )\n",
+ " (output_query): Blip2QFormerOutput(\n",
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (1): Blip2QFormerLayer(\n",
+ " (attention): Blip2QFormerAttention(\n",
+ " (attention): Blip2QFormerMultiHeadAttention(\n",
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (output): Blip2QFormerSelfOutput(\n",
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (intermediate_query): Blip2QFormerIntermediate(\n",
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " (intermediate_act_fn): GELUActivation()\n",
+ " )\n",
+ " (output_query): Blip2QFormerOutput(\n",
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (2): Blip2QFormerLayer(\n",
+ " (attention): Blip2QFormerAttention(\n",
+ " (attention): Blip2QFormerMultiHeadAttention(\n",
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (output): Blip2QFormerSelfOutput(\n",
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (crossattention): Blip2QFormerAttention(\n",
+ " (attention): Blip2QFormerMultiHeadAttention(\n",
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (output): Blip2QFormerSelfOutput(\n",
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (intermediate_query): Blip2QFormerIntermediate(\n",
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " (intermediate_act_fn): GELUActivation()\n",
+ " )\n",
+ " (output_query): Blip2QFormerOutput(\n",
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (3): Blip2QFormerLayer(\n",
+ " (attention): Blip2QFormerAttention(\n",
+ " (attention): Blip2QFormerMultiHeadAttention(\n",
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (output): Blip2QFormerSelfOutput(\n",
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (intermediate_query): Blip2QFormerIntermediate(\n",
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " (intermediate_act_fn): GELUActivation()\n",
+ " )\n",
+ " (output_query): Blip2QFormerOutput(\n",
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (4): Blip2QFormerLayer(\n",
+ " (attention): Blip2QFormerAttention(\n",
+ " (attention): Blip2QFormerMultiHeadAttention(\n",
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (output): Blip2QFormerSelfOutput(\n",
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (crossattention): Blip2QFormerAttention(\n",
+ " (attention): Blip2QFormerMultiHeadAttention(\n",
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (output): Blip2QFormerSelfOutput(\n",
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (intermediate_query): Blip2QFormerIntermediate(\n",
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " (intermediate_act_fn): GELUActivation()\n",
+ " )\n",
+ " (output_query): Blip2QFormerOutput(\n",
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (5): Blip2QFormerLayer(\n",
+ " (attention): Blip2QFormerAttention(\n",
+ " (attention): Blip2QFormerMultiHeadAttention(\n",
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (output): Blip2QFormerSelfOutput(\n",
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (intermediate_query): Blip2QFormerIntermediate(\n",
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " (intermediate_act_fn): GELUActivation()\n",
+ " )\n",
+ " (output_query): Blip2QFormerOutput(\n",
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (6): Blip2QFormerLayer(\n",
+ " (attention): Blip2QFormerAttention(\n",
+ " (attention): Blip2QFormerMultiHeadAttention(\n",
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (output): Blip2QFormerSelfOutput(\n",
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (crossattention): Blip2QFormerAttention(\n",
+ " (attention): Blip2QFormerMultiHeadAttention(\n",
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (output): Blip2QFormerSelfOutput(\n",
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (intermediate_query): Blip2QFormerIntermediate(\n",
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " (intermediate_act_fn): GELUActivation()\n",
+ " )\n",
+ " (output_query): Blip2QFormerOutput(\n",
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (7): Blip2QFormerLayer(\n",
+ " (attention): Blip2QFormerAttention(\n",
+ " (attention): Blip2QFormerMultiHeadAttention(\n",
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (output): Blip2QFormerSelfOutput(\n",
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (intermediate_query): Blip2QFormerIntermediate(\n",
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " (intermediate_act_fn): GELUActivation()\n",
+ " )\n",
+ " (output_query): Blip2QFormerOutput(\n",
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " (linear): Linear(in_features=768, out_features=4096, bias=True)\n",
+ " (norm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)\n",
+ " )\n",
+ ")"
+ ]
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "from omegaconf import OmegaConf\n",
+ "from seld_config import TrainConfig, ModelConfig\n",
+ "from model.slam_model_seld import model_factory\n",
+ "\n",
+ "train_config = TrainConfig(\n",
+ " model_name=\"BAT\",\n",
+ " batching_strategy=\"custom\",\n",
+ " num_epochs=1,\n",
+ " num_workers_dataloader=2,\n",
+ " use_peft=True,\n",
+ " freeze_encoder=True,\n",
+ " freeze_llm=True\n",
+ ")\n",
+ "train_config = OmegaConf.merge(train_config)\n",
+ "\n",
+ "model_config = ModelConfig(\n",
+ " llm_name=\"llama-2-7b\",\n",
+ " llm_path=\"https://huggingface.co/meta-llama/Llama-2-7b-hf\", # \n",
+ " encoder_name=\"SpatialAST\",\n",
+ " encoder_ckpt=\"https://huggingface.co/datasets/zhisheng01/SpatialAudio/blob/main/SpatialAST/finetuned.pth\", # \n",
+ ")\n",
+ "\n",
+ "kwargs = {\n",
+ " \"decode_log\": None,\n",
+ " \"ckpt_path\": \"https://huggingface.co/datasets/zhisheng01/SpatialAudio/blob/main/BAT/model.pt\", # Download it from huggingface\n",
+ "}\n",
+ "model, tokenizer = model_factory(train_config, model_config, **kwargs)\n",
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\") # FIX(MZY): put the whole model to device.\n",
+ "model.to(device)\n",
+ "model.eval()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "audio_length = 64 # We use 64 learnable tokens as audio representation\n",
+ "audio_pseudo = torch.full((audio_length,), -1)\n",
+ "waveform = torch.from_numpy(waveform).float()\n",
+ "waveform = SpatialAudioDatasetJsonl.padding(waveform, padding_length=10*32000-waveform.shape[1])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Question: Based on the audio you've heard, refer to the instruction and provide a response.\n",
+ "\n",
+ "### Instruction:\n",
+ "What is the distance between the sound of the drum and the sound of the siren?\n",
+ "\n",
+ "### Response:\n",
+ "Pred: 1.5m\n",
+ "Ground Truth: 1.5m\n",
+ "-------------------------\n",
+ "Question: Based on the audio you've heard, refer to the instruction and provide a response.\n",
+ "\n",
+ "### Instruction:\n",
+ "What is the sound on the right side of the sound of the drum?\n",
+ "\n",
+ "### Response:\n",
+ "Pred: fire engine, fire truck (siren); emergency vehicle; siren; police car (siren)\n",
+ "Ground Truth: emergency vehicle; fire engine, fire truck (siren); siren\n",
+ "-------------------------\n",
+ "Question: Based on the audio you've heard, refer to the instruction and provide a response.\n",
+ "\n",
+ "### Instruction:\n",
+ "Are you able to detect the percussion's sound coming from the left and the emergency vehicle's sounds from the right?\n",
+ "\n",
+ "### Response:\n",
+ "Pred: Yes\n",
+ "Ground Truth: Yes\n",
+ "-------------------------\n"
+ ]
+ }
+ ],
+ "source": [
+ "for prompt, answer in zip(prompts, gt_answers):\n",
+ " input_ids = torch.tensor(tokenizer.encode(prompt), dtype=torch.int64)\n",
+ " input_ids = torch.cat((audio_pseudo, input_ids)) # [audio, prompt]\n",
+ " input_ids = input_ids.unsqueeze(0) # [batch, seq]\n",
+ " attention_mask = input_ids.ge(-1)\n",
+ " modality_mask = input_ids.eq(-1)\n",
+ "\n",
+ " model_outputs = model.generate(\n",
+ " input_ids=input_ids.to(device),\n",
+ " attention_mask=attention_mask.to(device),\n",
+ " modality_mask=modality_mask.to(device),\n",
+ " audio=waveform.unsqueeze(0).to(device),\n",
+ " )\n",
+ " output_text = model.tokenizer.batch_decode(model_outputs, add_special_tokens=False, skip_special_tokens=True)\n",
+ " output_text = output_text[0]\n",
+ " print(f\"Question: {prompt}\")\n",
+ " print(f\"Pred: {output_text}\")\n",
+ " print(f\"Ground Truth: {answer}\")\n",
+ " print(\"-------------------------\")"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "speech",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.0"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/examples/seld_spatialsoundqa/inference_seld_batch.py b/examples/seld_spatialsoundqa/inference_seld_batch.py
index 4834d182..1d75ed64 100644
--- a/examples/seld_spatialsoundqa/inference_seld_batch.py
+++ b/examples/seld_spatialsoundqa/inference_seld_batch.py
@@ -36,16 +36,11 @@ class RunConfig:
def main_hydra(cfg: DictConfig):
run_config = RunConfig()
cfg = OmegaConf.merge(run_config, cfg)
- # kwargs = to_plain_list(cfg)
- log_level = getattr(logging, cfg.get("log_level", "INFO").upper())
+ cfg.train_config.peft_config = cfg.peft_config
+ log_level = getattr(logging, cfg.get("log_level", "INFO").upper())
logging.basicConfig(level=log_level)
- if cfg.get("debug", False):
- import pdb
-
- pdb.set_trace()
-
inference(cfg)
diff --git a/examples/seld_spatialsoundqa/model/slam_model_seld.py b/examples/seld_spatialsoundqa/model/slam_model_seld.py
index 32da226f..9c396c9f 100644
--- a/examples/seld_spatialsoundqa/model/slam_model_seld.py
+++ b/examples/seld_spatialsoundqa/model/slam_model_seld.py
@@ -74,26 +74,5 @@ def __init__(
tokenizer,
train_config,
model_config,
- **kwargs,
- )
-
- @torch.no_grad()
- def inference(
- self,
- wav_path=None,
- reverb_path=None,
- prompt=None,
- generation_config=None,
- logits_processor=None,
- stopping_criteria=None,
- prefix_allowed_tokens_fn=None,
- synced_gpus=None,
- assistant_model=None,
- streamer=None,
- negative_prompt_ids=None,
- negative_prompt_attention_mask=None,
- **kwargs,
- ):
- #!TODO:
- # inference for SELD model
- pass
\ No newline at end of file
+ **kwargs
+ )
\ No newline at end of file
diff --git a/examples/seld_spatialsoundqa/scripts/calculate_map.py b/examples/seld_spatialsoundqa/scripts/calculate_map.py
new file mode 100644
index 00000000..71945d11
--- /dev/null
+++ b/examples/seld_spatialsoundqa/scripts/calculate_map.py
@@ -0,0 +1,73 @@
+import os
+
+import numpy as np
+from sklearn import metrics
+
+from tqdm import tqdm
+import openai
+
+openai.api_key = "your-openai-api-key"
+
+def cosine_similarity(A, B):
+ dot_product = np.dot(A, B)
+ norm_A = np.linalg.norm(A)
+ norm_B = np.linalg.norm(B)
+ return dot_product / (norm_A * norm_B)
+
+def get_embedding(text, model="text-embedding-ada-002"):
+ text = text.replace("\n", " ")
+ return np.array(openai.Embedding.create(input = [text], model=model)['data'][0]['embedding'])
+
+def calculate_stats(output, target):
+ classes_num = target.shape[-1]
+ stats = []
+
+ for k in range(classes_num):
+ avg_precision = metrics.average_precision_score(target[:, k], output[:, k], average=None)
+ dict = {'AP': avg_precision}
+ stats.append(dict)
+
+ return stats
+
+labels_path = 'https://huggingface.co/datasets/zhisheng01/SpatialAudio/blob/main/SpatialSoundQA/AudioSet/metadata/class_labels_indices_subset.csv'
+embeds_npy_path = 'https://huggingface.co/datasets/zhisheng01/SpatialAudio/blob/main/SpatialSoundQA/AudioSet/metadata/audioset_class_embeds.npy'
+
+label2id = {}
+with open(labels_path) as f:
+ for idx, line in enumerate(f.readlines()[1:]):
+ label = line.strip().split(',', 2)[-1]
+ label2id[label.lower()] = idx
+# label2emb.append(get_embedding(label))
+
+# label2emb = np.stack(label2emb)
+# np.save(embeds_npy_path, label2emb)
+
+total_labels_embeddings = np.load(embeds_npy_path)
+
+one_hot_embeds = np.eye(355)
+
+with open("decode_eval-stage2-classification_beam4_gt") as gt_f:
+ gt_lines = gt_f.readlines()
+ targets = []
+ for line in gt_lines:
+ target = np.array([one_hot_embeds[label2id[i]] for i in line.strip().split('\t', 1)[1].split("; ")]).sum(axis=0)
+ targets.append(target)
+ targets = np.stack(targets)
+
+
+with open("decode_eval-stage2-classification_beam4_pred") as pred_f:
+ pred_lines = pred_f.readlines()
+ preds = []
+ for line in tqdm(pred_lines):
+ pred = line.strip().split('\t', 1)[1]
+ pred = get_embedding(pred)
+ pred = np.array([cosine_similarity(pred, embed) for embed in total_labels_embeddings])
+ preds.append(pred)
+
+ preds = np.stack(preds)
+
+stats = calculate_stats(preds, targets)
+
+AP = [stat['AP'] for stat in stats]
+mAP = np.mean([stat['AP'] for stat in stats])
+print("mAP: {:.6f}".format(mAP))
diff --git a/examples/seld_spatialsoundqa/scripts/decode_spatial-ast_qformer_llama_2_7b.sh b/examples/seld_spatialsoundqa/scripts/decode_spatial-ast_qformer_llama_2_7b.sh
index 54d71c34..c764499e 100644
--- a/examples/seld_spatialsoundqa/scripts/decode_spatial-ast_qformer_llama_2_7b.sh
+++ b/examples/seld_spatialsoundqa/scripts/decode_spatial-ast_qformer_llama_2_7b.sh
@@ -1,24 +1,24 @@
#!/bin/bash
-#export PYTHONPATH=/root/whisper:$PYTHONPATH
-export CUDA_VISIBLE_DEVICES=2
+
+export CUDA_VISIBLE_DEVICES=0
export TOKENIZERS_PARALLELISM=false
# export CUDA_LAUNCH_BLOCKING=1
-SLAM_DIR=/mnt/lustre/hpc_stor03/sjtu_home/zhisheng.zheng/SLAM-LLM
+SLAM_DIR=/path/to/SLAM-LLM
cd $SLAM_DIR
code_dir=examples/seld_spatialsoundqa
-stage=stage1-clsdoa
-qa_data_root=/mnt/lustre/hpc_stor03/sjtu_home/zhisheng.zheng/data/SpatialAudio/closed-end
-reverb_data_root=/mnt/lustre/hpc_stor03/sjtu_home/zhisheng.zheng/data/SpatialAudio/reverb/mp3d
-anechoic_data_root=/mnt/lustre/hpc_stor03/sjtu_home/zhisheng.zheng/data/AudioSet
+audio_encoder_path=/data1/scratch/zhisheng/models/SpatialAST/SpatialAST.pth # https://huggingface.co/datasets/zhisheng01/SpatialAudio/blob/main/SpatialAST/finetuned.pth
+llm_path=/home/zhisheng/models/llama-2-hf # https://huggingface.co/meta-llama/Llama-2-7b-hf
-audio_encoder_path=/mnt/lustre/hpc_stor03/sjtu_home/zhisheng.zheng/models/SpatialAST/SpatialAST.pth
-llm_path=/mnt/lustre/hpc_stor03/sjtu_pub/cxgroup/model/Llama-2-7b-hf
+stage=stage2-single
+qa_data_root=/data3/scratch/zhisheng/SpatialAudio/SpatialSoundQA/closed-end # https://huggingface.co/datasets/zhisheng01/SpatialAudio/tree/main/SpatialSoundQA/closed-end
+reverb_data_root=/data3/scratch/zhisheng/SpatialAudio/SpatialSoundQA/mp3d_reverb # https://huggingface.co/datasets/zhisheng01/SpatialAudio/blob/main/SpatialSoundQA/mp3d_reverb.zip
+anechoic_data_root=/data3/scratch/zhisheng/SpatialAudio/SpatialSoundQA/AudioSet # https://huggingface.co/datasets/zhisheng01/SpatialAudio/tree/main/SpatialSoundQA/AudioSet
-split=eval
-output_dir=/mnt/lustre/hpc_stor03/sjtu_home/zhisheng.zheng/SLAM-LLM/outputs/bat-llama-2-spatialAST-8qformer-steplrwarmupkeep1e-4-stage1-clsdoa-20240519/
-ckpt_path=$output_dir/bat_epoch_3_step_3288
+split=eval-stage2-classification
+output_dir=?? # be same as in finetune script
+ckpt_path=$output_dir/bat_epoch_4_step_18223
decode_log=$ckpt_path/decode_${split}_beam4
# -m debugpy --listen 5678 --wait-for-client
@@ -30,29 +30,25 @@ python -u $code_dir/inference_seld_batch.py \
++model_config.llm_dim=4096 \
++model_config.encoder_name=SpatialAST \
++model_config.encoder_projector=q-former \
+ ++model_config.qformer_layers=8 \
++model_config.encoder_ckpt=$audio_encoder_path \
+ ++dataset_config.test_split=${split} \
++dataset_config.stage=$stage \
++dataset_config.qa_data_root=$qa_data_root \
++dataset_config.anechoic_data_root=$anechoic_data_root \
++dataset_config.reverb_data_root=$reverb_data_root \
++dataset_config.fix_length_audio=64 \
++dataset_config.inference_mode=true \
- ++train_config.model_name=bat \
+ ++train_config.model_name=BAT \
++train_config.freeze_encoder=true \
++train_config.freeze_llm=true \
++train_config.batching_strategy=custom \
++train_config.num_epochs=1 \
- ++train_config.val_batch_size=8 \
- ++train_config.num_workers_dataloader=2 \
+ ++train_config.val_batch_size=1 \
+ ++train_config.num_workers_dataloader=1 \
++train_config.output_dir=$output_dir \
++train_config.use_peft=true \
- ++peft_config.peft_method=llama_adapter \
+ ++peft_config.peft_method=lora \
++log_config.log_file=$output_dir/test.log \
++decode_log=$decode_log \
- ++ckpt_path=$ckpt_path/model.pt \
- # ++peft_ckpt=$ckpt_path \
- # ++train_config.use_peft=true \
- # ++train_config.peft_config.r=32 \
- # ++dataset_config.normalize=true \
- # ++model_config.encoder_projector=q-former \
- # ++dataset_config.fix_length_audio=64 \
+ ++ckpt_path=$ckpt_path/model.pt
\ No newline at end of file
diff --git a/examples/seld_spatialsoundqa/scripts/finetune_spatial-ast_qformer_llama_2_7b.sh b/examples/seld_spatialsoundqa/scripts/finetune_spatial-ast_qformer_llama_2_7b.sh
index 66e08945..5b1b0e69 100644
--- a/examples/seld_spatialsoundqa/scripts/finetune_spatial-ast_qformer_llama_2_7b.sh
+++ b/examples/seld_spatialsoundqa/scripts/finetune_spatial-ast_qformer_llama_2_7b.sh
@@ -10,20 +10,20 @@ export OMP_NUM_THREADS=1
# export NCCL_DEBUG_SUBSYS=ALL
# export TORCH_DISTRIBUTED_DEBUG=INFO
-SLAM_DIR=/mnt/lustre/hpc_stor03/sjtu_home/zhisheng.zheng/SLAM-LLM
+SLAM_DIR=/path/to/SLAM-LLM
cd $SLAM_DIR
code_dir=examples/seld_spatialsoundqa
-audio_encoder_path=/mnt/lustre/hpc_stor03/sjtu_home/zhisheng.zheng/models/SpatialAST/SpatialAST.pth
-llm_path=/mnt/lustre/hpc_stor03/sjtu_pub/cxgroup/model/Llama-2-7b-hf
+audio_encoder_path=/data1/scratch/zhisheng/models/SpatialAST/SpatialAST.pth # https://huggingface.co/datasets/zhisheng01/SpatialAudio/blob/main/SpatialAST/finetuned.pth
+llm_path=/home/zhisheng/models/llama-2-hf # https://huggingface.co/meta-llama/Llama-2-7b-hf
-stage=stage2-single
-qa_data_root=/mnt/lustre/hpc_stor03/sjtu_home/zhisheng.zheng/data/SpatialAudio/closed-end
-reverb_data_root=/mnt/lustre/hpc_stor03/sjtu_home/zhisheng.zheng/data/SpatialAudio/reverb/mp3d
-anechoic_data_root=/mnt/lustre/hpc_stor03/sjtu_home/zhisheng.zheng/data/AudioSet
+stage=stage3-mixup
+qa_data_root=/data3/scratch/zhisheng/SpatialAudio/SpatialSoundQA/closed-end # https://huggingface.co/datasets/zhisheng01/SpatialAudio/tree/main/SpatialSoundQA/closed-end
+reverb_data_root=/data3/scratch/zhisheng/SpatialAudio/SpatialSoundQA/mp3d_reverb # https://huggingface.co/datasets/zhisheng01/SpatialAudio/blob/main/SpatialSoundQA/mp3d_reverb.zip
+anechoic_data_root=/data3/scratch/zhisheng/SpatialAudio/SpatialSoundQA/AudioSet # https://huggingface.co/datasets/zhisheng01/SpatialAudio/tree/main/SpatialSoundQA/AudioSet
-ckpt_path=/mnt/lustre/hpc_stor03/sjtu_home/zhisheng.zheng/SLAM-LLM/outputs/bat-llama-2-spatialAST-8qformer-steplrwarmupkeep1e-4-stage1-clsdoa-20240519/bat_epoch_3_step_3288
-output_dir=${SLAM_DIR}/outputs/bat-llama-2-spatialAST-8qformer-steplrwarmupkeep1e-4-${stage}-$(date +"%Y%m%d")
+split=eval-stage3-distance-direction
+output_dir=./outputs/bat-llama-2-spatialAST-8qformer-steplrwarmupkeep1e-4-${stage}
hydra_args="
hydra.run.dir=$output_dir \
@@ -32,30 +32,32 @@ hydra.run.dir=$output_dir \
++model_config.llm_dim=4096 \
++model_config.encoder_name=SpatialAST \
++model_config.encoder_projector=q-former \
+++model_config.qformer_layers=8 \
++model_config.encoder_ckpt=$audio_encoder_path \
+++dataset_config.test_split=${split} \
++dataset_config.stage=$stage \
++dataset_config.qa_data_root=$qa_data_root \
++dataset_config.anechoic_data_root=$anechoic_data_root \
++dataset_config.reverb_data_root=$reverb_data_root \
+++dataset_config.max_words=96 \
++dataset_config.fix_length_audio=64 \
++train_config.model_name=bat \
-++train_config.num_epochs=3 \
+++train_config.num_epochs=5 \
++train_config.freeze_encoder=true \
++train_config.freeze_llm=true \
++train_config.batching_strategy=custom \
-++train_config.warmup_steps=10000 \
-++train_config.total_steps=100000 \
+++train_config.warmup_steps=20000 \
+++train_config.total_steps=200000 \
++train_config.lr=1e-4 \
++train_config.validation_interval=2000 \
-++train_config.batch_size_training=16 \
-++train_config.val_batch_size=16 \
+++train_config.batch_size_training=8 \
+++train_config.val_batch_size=8 \
++train_config.num_workers_dataloader=4 \
++train_config.output_dir=$output_dir \
++train_config.use_peft=true \
-++peft_config.peft_method=llama_adapter \
+++peft_config.peft_method=lora \
++metric=acc \
++log_config.log_file=$output_dir/log.txt \
-++ckpt_path=$ckpt_path/model.pt \
"
# -m debugpy --listen 5678 --wait-for-client
@@ -66,8 +68,8 @@ if [[ $CUDA_VISIBLE_DEVICES != *","* ]]; then
else
torchrun \
--nnodes 1 \
- --nproc_per_node 4 \
- --master_port=29503 \
+ --nproc_per_node 8 \
+ --master_port=39503 \
$code_dir/finetune_seld.py \
--config-path "conf" \
++train_config.enable_fsdp=false \
diff --git a/examples/seld_spatialsoundqa/seld_config.py b/examples/seld_spatialsoundqa/seld_config.py
index d6e7e739..439e5819 100644
--- a/examples/seld_spatialsoundqa/seld_config.py
+++ b/examples/seld_spatialsoundqa/seld_config.py
@@ -13,6 +13,7 @@ class ModelConfig:
encoder_ckpt: Optional[str] = None
encoder_projector: str = "q-former"
encoder_dim: int = 768
+ qformer_layers: int = 8
@dataclass
class PeftConfig:
diff --git a/src/slam_llm/models/projector.py b/src/slam_llm/models/projector.py
index c9863a55..b6d6037c 100644
--- a/src/slam_llm/models/projector.py
+++ b/src/slam_llm/models/projector.py
@@ -56,7 +56,7 @@ def __init__(self, config):
from transformers import Blip2QFormerConfig, Blip2QFormerModel
configuration = Blip2QFormerConfig()
configuration.encoder_hidden_size = self.encoder_dim
- configuration.num_hidden_layers = 8
+ configuration.num_hidden_layers = config.qformer_layers
self.query_len = int(config.get("query_len", 64))
self.query = nn.Parameter(torch.zeros(1, self.query_len, configuration.hidden_size))