Merge pull request #156 from xrysamuel/main
Add recipe "sec_emotioncaps"
ddlBoJack authored Nov 5, 2024
2 parents 6fb784b + c16e3a9 commit dbfcfca
Showing 12 changed files with 608 additions and 0 deletions.
42 changes: 42 additions & 0 deletions examples/sec_emotioncaps/README.md
@@ -0,0 +1,42 @@
# Speech Emotion Caption

## Model Architecture

This recipe generates high-quality, human-like speech emotion descriptions. The model pairs an **emotion2vec** speech encoder with a **q-former projector** and the **vicuna-7b-v1.5 LLM**, and is trained on **an unpublished dataset**, a large-scale corpus for speech emotion captioning.

![](docs/model.png)

## Performance and checkpoints

In this recipe we train only the q-former projector; the speech encoder and the LLM are kept frozen.

Encoder | Projector | LLM | Similarity Score
---|---|---|---
[emotion2vec_base](https://huggingface.co/emotion2vec/emotion2vec_base) | [Q-Former](to_do)| [vicuna-7b-v1.5](https://huggingface.co/lmsys/vicuna-7b-v1.5) | 71.10

> **Note**: The baseline model [SECap](https://github.com/thuhcsi/SECap), tested in our environment, achieves a similarity score of 71.52, slightly higher than ours.

## Data preparation

Prepare your data as a jsonl file in the following format:

```
{"key": "key_name", "source": "path_to_wav_file", "target": "corresponding_caption"}
...
```
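
As an illustration, such a manifest can be generated with a few lines of Python; the directory layout and file names below are hypothetical:

```
import json
from pathlib import Path

# Hypothetical layout: audio at wavs/<key>.wav and one caption per line in
# captions.tsv, formatted as "<key>\t<caption>".
wav_dir = Path("wavs")
with open("captions.tsv", encoding="utf-8") as f_in, \
     open("train.jsonl", "w", encoding="utf-8") as f_out:
    for line in f_in:
        key, caption = line.rstrip("\n").split("\t", maxsplit=1)
        entry = {
            "key": key,
            "source": str(wav_dir / f"{key}.wav"),  # path to the wav file
            "target": caption,                      # corresponding caption
        }
        f_out.write(json.dumps(entry, ensure_ascii=False) + "\n")
```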


## Decode with checkpoints

```
bash decode_emotion2vec_qformer_vicuna_7b.sh
```

Before running the shell script, modify the paths in it, including `speech_encoder_path`, `llm_path`, `output_dir`, `ckpt_path`, `val_data_path`, and `decode_log`.

## Train a new model

If you have sufficient relevant data, you can train the model yourself:

```
bash finetune_emotion2vec_qformer_vicuna_7b.sh
```
19 changes: 19 additions & 0 deletions examples/sec_emotioncaps/conf/ds_config.json
@@ -0,0 +1,19 @@
{
    "train_micro_batch_size_per_gpu": 4,
    "gradient_accumulation_steps": 1,
    "optimizer": {
        "type": "Adam",
        "params": {
            "lr": 1e-4
        }
    },
    "fp16": {
        "enabled": true
    },
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {
            "device": "cpu"
        }
    }
}
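
This config enables fp16 training with ZeRO stage 3 and CPU optimizer offload. As a minimal sketch of how DeepSpeed consumes such a file (the recipe's own trainer may wire this up differently through its launcher; the stand-in module below is purely illustrative):

```
import torch
import deepspeed

# Stand-in module for illustration; launch with `deepspeed this_script.py`
# so the distributed environment DeepSpeed expects is set up.
model = torch.nn.Linear(16, 16)
model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config="examples/sec_emotioncaps/conf/ds_config.json",
)
```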
3 changes: 3 additions & 0 deletions examples/sec_emotioncaps/conf/prompt.yaml
@@ -0,0 +1,3 @@
dataset_config:
  # We put the prompt here because the Hydra override in the shell script only supports a small subset of characters.
  prompt: "请用中文用一句话描述上面给出的音频中说话人的情感。" # "Please describe, in one Chinese sentence, the emotion of the speaker in the audio given above."
Binary file added examples/sec_emotioncaps/docs/model.png
49 changes: 49 additions & 0 deletions examples/sec_emotioncaps/finetune_sec.py
@@ -0,0 +1,49 @@
from slam_llm.pipeline.finetune import main as train

import hydra
import logging
from typing import Optional
from dataclasses import dataclass, field
from omegaconf import DictConfig, ListConfig, OmegaConf
from sec_config import ModelConfig, TrainConfig, DataConfig, LogConfig, FSDPConfig


@dataclass
class RunConfig:
    dataset_config: DataConfig = field(default_factory=DataConfig)
    model_config: ModelConfig = field(default_factory=ModelConfig)
    train_config: TrainConfig = field(default_factory=TrainConfig)
    log_config: LogConfig = field(default_factory=LogConfig)
    fsdp_config: FSDPConfig = field(default_factory=FSDPConfig)
    debug: bool = field(default=False, metadata={"help": "Use pdb when true"})
    metric: str = field(default="acc", metadata={"help": "The metric for evaluation"})
    ckpt_path: Optional[str] = field(
        default=None, metadata={"help": "The path to projector checkpoint"}
    )


@hydra.main(config_name=None, version_base=None)
def main_hydra(cfg: DictConfig):
    run_config = RunConfig()
    cfg = OmegaConf.merge(run_config, cfg)

    def to_plain_list(cfg_item):
        if isinstance(cfg_item, ListConfig):
            return OmegaConf.to_container(cfg_item, resolve=True)
        elif isinstance(cfg_item, DictConfig):
            return {k: to_plain_list(v) for k, v in cfg_item.items()}
        else:
            return cfg_item

    # kwargs = to_plain_list(cfg)
    kwargs = cfg
    log_level = getattr(logging, kwargs.get("log_level", "INFO").upper())

    logging.basicConfig(level=log_level)

    if kwargs.get("debug", False):
        import pdb

        pdb.set_trace()

    train(kwargs)


if __name__ == "__main__":
    main_hydra()
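
To see what `main_hydra` does with command-line overrides, here is a small sketch using the same OmegaConf calls; it assumes `RunConfig` and its `sec_config` dependencies are importable, and the override values are placeholders:

```
from omegaconf import OmegaConf
from finetune_sec import RunConfig  # assumes sec_config is on the path

# Defaults from the dataclass, merged with `++key=value`-style overrides.
defaults = OmegaConf.structured(RunConfig())
overrides = OmegaConf.from_dotlist(
    ["train_config.num_epochs=1", "ckpt_path=output/model.pt"]
)
cfg = OmegaConf.merge(defaults, overrides)
print(cfg.ckpt_path)  # output/model.pt
```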
53 changes: 53 additions & 0 deletions examples/sec_emotioncaps/inference_sec_batch.py
@@ -0,0 +1,53 @@
from slam_llm.pipeline.inference_batch import main as inference

import hydra
import logging
from dataclasses import dataclass, field
from omegaconf import DictConfig, ListConfig, OmegaConf
from typing import Optional
from sec_config import ModelConfig, TrainConfig, DataConfig, LogConfig, FSDPConfig


@dataclass
class RunConfig:
    dataset_config: DataConfig = field(default_factory=DataConfig)
    model_config: ModelConfig = field(default_factory=ModelConfig)
    train_config: TrainConfig = field(default_factory=TrainConfig)
    log_config: LogConfig = field(default_factory=LogConfig)
    fsdp_config: FSDPConfig = field(default_factory=FSDPConfig)
    debug: bool = field(default=False, metadata={"help": "Use pdb when true"})
    metric: str = field(default="acc", metadata={"help": "The metric for evaluation"})
    decode_log: str = field(
        default="output/decode_log",
        metadata={"help": "The prefix for the decode output"},
    )
    ckpt_path: str = field(
        default="output/model.pt", metadata={"help": "The path to projector checkpoint"}
    )
    peft_ckpt: Optional[str] = field(
        default=None,
        metadata={
            "help": "The path to peft checkpoint, should be a directory including adapter_config.json"
        },
    )


@hydra.main(config_name=None, version_base=None)
def main_hydra(cfg: DictConfig):
    run_config = RunConfig()
    cfg = OmegaConf.merge(run_config, cfg)
    # kwargs = to_plain_list(cfg)
    log_level = getattr(logging, cfg.get("log_level", "INFO").upper())

    logging.basicConfig(level=log_level)

    if cfg.get("debug", False):
        import pdb

        pdb.set_trace()

    inference(cfg)


if __name__ == "__main__":
    main_hydra()
155 changes: 155 additions & 0 deletions examples/sec_emotioncaps/model/slam_model_sec.py
@@ -0,0 +1,155 @@
import torch
import os
import logging
from slam_llm.models.slam_model import (
    slam_model,
    setup_tokenizer,
    setup_encoder,
    setup_encoder_projector,
    setup_llm,
)
from slam_llm.utils.train_utils import print_model_size

logger = logging.getLogger(__name__)


def model_factory(train_config, model_config, **kwargs):
    # return necessary components for training
    tokenizer = setup_tokenizer(train_config, model_config, **kwargs)

    encoder = setup_encoder(train_config, model_config, **kwargs)

    # llm
    llm = setup_llm(train_config, model_config, **kwargs)

    # projector
    encoder_projector = setup_encoder_projector(
        train_config, model_config, **kwargs
    )
    model = slam_model_sec(
        encoder,
        llm,
        encoder_projector,
        tokenizer,
        train_config,
        model_config,
        **kwargs,
    )

    ckpt_path = kwargs.get(
        "ckpt_path", None
    )  # FIX(MZY): load model ckpt (mainly the projector; see model_checkpointing/checkpoint_handler.py: save_model_checkpoint_peft)
    if ckpt_path is not None:
        logger.info("loading other parts from: {}".format(ckpt_path))
        ckpt_dict = torch.load(ckpt_path, map_location="cpu")
        model.load_state_dict(ckpt_dict, strict=False)

    print_model_size(
        model,
        train_config,
        (
            int(os.environ["RANK"])
            if train_config.enable_fsdp or train_config.enable_ddp
            else 0
        ),
    )
    return model, tokenizer


class slam_model_sec(slam_model):
    def __init__(
        self,
        encoder,
        llm,
        encoder_projector,
        tokenizer,
        train_config,
        model_config,
        **kwargs,
    ):
        super().__init__(
            encoder,
            llm,
            encoder_projector,
            tokenizer,
            train_config,
            model_config,
            **kwargs,
        )

    @torch.no_grad()
    def inference(
        self,
        wav_path=None,
        prompt=None,
        generation_config=None,
        logits_processor=None,
        stopping_criteria=None,
        prefix_allowed_tokens_fn=None,
        synced_gpus=None,
        assistant_model=None,
        streamer=None,
        negative_prompt_ids=None,
        negative_prompt_attention_mask=None,
        **kwargs,
    ):
        # inference for the speech emotion caption model

        device = kwargs.get("device", "cuda")
        if os.path.exists(wav_path):  # Audio-Text QA
            import whisper

            audio_raw = whisper.load_audio(wav_path)
            audio_raw = whisper.pad_or_trim(audio_raw)

            mel_size = getattr(
                self.dataset_config, "mel_size", 80
            )  # 80 for large v1 and v2, 128 for large v3
            audio_mel = (
                whisper.log_mel_spectrogram(audio_raw, n_mels=mel_size)
                .permute(1, 0)[None, :, :]
                .to(device)
            )

            encoder_outs = self.encoder.extract_variable_length_features(
                audio_mel.permute(0, 2, 1)
            )

            if self.model_config.encoder_projector == "q-former":
                audio_mel_post_mask = torch.ones(
                    encoder_outs.size()[:-1], dtype=torch.long
                ).to(encoder_outs.device)
                encoder_outs = self.encoder_projector(encoder_outs, audio_mel_post_mask)
            elif self.model_config.encoder_projector == "linear":
                encoder_outs = self.encoder_projector(encoder_outs)
        else:  # Text QA
            encoder_outs = torch.empty(
                1, 0, self.llm.model.embed_tokens.embedding_dim
            ).to(device)

        prompt = "USER: {}\n ASSISTANT:".format(prompt)
        prompt_ids = self.tokenizer.encode(prompt)
        prompt_length = len(prompt_ids)
        prompt_ids = torch.tensor(prompt_ids, dtype=torch.int64).to(device)

        # Look up the LLM's embedding table; the attribute path depends on
        # how the base model is wrapped (e.g. by PEFT).
        if hasattr(self.llm.model, "embed_tokens"):
            inputs_embeds = self.llm.model.embed_tokens(prompt_ids)
        elif hasattr(self.llm.model.model, "embed_tokens"):
            inputs_embeds = self.llm.model.model.embed_tokens(prompt_ids)
        else:
            inputs_embeds = self.llm.model.model.model.embed_tokens(prompt_ids)

        inputs_embeds = torch.cat(
            (encoder_outs, inputs_embeds[None, :, :]), dim=1
        )  # [audio, prompt]

        attention_mask = torch.ones(inputs_embeds.size()[:-1], dtype=torch.long).to(
            inputs_embeds.device
        )

        # generate
        model_outputs = self.generate(
            inputs_embeds=inputs_embeds, attention_mask=attention_mask, **kwargs
        )

        return model_outputs
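
A hedged single-sample usage sketch: `train_config` and `model_config` come from the recipe's Hydra configs, the paths are placeholders, and in practice `inference_sec_batch.py` drives decoding:

```
# Hypothetical usage; assumes train_config/model_config were built from the
# recipe's configs and that generate() returns token ids.
model, tokenizer = model_factory(train_config, model_config, ckpt_path="output/model.pt")
model = model.to("cuda").eval()

ids = model.inference(
    wav_path="example.wav",
    prompt="请用中文用一句话描述上面给出的音频中说话人的情感。",
    max_new_tokens=64,
)
print(tokenizer.decode(ids[0], skip_special_tokens=True))
```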
63 changes: 63 additions & 0 deletions examples/sec_emotioncaps/scripts/decode_emotion2vec_qformer_vicuna_7b.sh
@@ -0,0 +1,63 @@
#!/bin/bash
# export PYTHONPATH=/root/whisper:$PYTHONPATH
# export PYTHONPATH=/root/fairseq:$PYTHONPATH
export CUDA_VISIBLE_DEVICES=1
export TOKENIZERS_PARALLELISM=false
# export CUDA_LAUNCH_BLOCKING=1
export OMP_NUM_THREADS=1

# debug setting for multiple gpus
# export NCCL_DEBUG=INFO
# export NCCL_DEBUG_SUBSYS=ALL
# export TORCH_DISTRIBUTED_DEBUG=INFO

run_dir=/hpc_stor03/sjtu_home/ruiyang.xu/SLAM/SLAM-LLM
cd $run_dir
code_dir=/hpc_stor03/sjtu_home/ruiyang.xu/SLAM/SLAM-LLM/examples/sec_emotioncaps

speech_encoder_path=/hpc_stor03/sjtu_home/ruiyang.xu/SLAM/ckpt/emotion2vec_base.pt
llm_path=/hpc_stor03/sjtu_home/ruiyang.xu/SLAM/ckpt/vicuna-7b-v1.5
val_data_path=/hpc_stor03/sjtu_home/ruiyang.xu/SLAM/data/valid.jsonl

encoder_fairseq_dir=/hpc_stor03/sjtu_home/ruiyang.xu/SLAM/deps/emotion2vec/upstream

output_dir=/hpc_stor03/sjtu_home/ruiyang.xu/SLAM/out/sec-decode-$(date +"%Y%m%d-%s")

ckpt_path=/hpc_stor03/sjtu_home/ruiyang.xu/SLAM/out/sec-finetune-20241001-1727786623/sec_epoch_1_step_3000/model.pt

decode_log=$output_dir/decode_log

hydra_args="
hydra.run.dir=$output_dir \
++model_config.llm_name=vicuna-7b-v1.5 \
++model_config.llm_path=$llm_path \
++model_config.llm_dim=4096 \
++model_config.encoder_name=emotion2vec \
++model_config.encoder_projector_ds_rate=5 \
++model_config.encoder_path=$speech_encoder_path \
++model_config.encoder_fairseq_dir=$encoder_fairseq_dir \
++model_config.encoder_dim=768 \
++model_config.encoder_projector=q-former \
++dataset_config.dataset=speech_dataset \
++dataset_config.val_data_path=$val_data_path \
++dataset_config.data_path=$val_data_path \
++dataset_config.inference_mode=true \
++dataset_config.input_type=raw \
++train_config.model_name=sec \
++train_config.num_epochs=1 \
++train_config.freeze_encoder=true \
++train_config.freeze_llm=true \
++train_config.batching_strategy=custom \
++train_config.val_batch_size=4 \
++train_config.num_workers_dataloader=2 \
++train_config.output_dir=$output_dir \
++log_config.log_file=$output_dir/train.log \
++ckpt_path=$ckpt_path \
++decode_log=$decode_log
"

# -m debugpy --listen 5678 --wait-for-client
python $code_dir/inference_sec_batch.py \
--config-path "conf" \
--config-name "prompt.yaml" \
$hydra_args