Public child-adult speaker diarization/classification model and code with simulated conversations. It can be used for both zero-shot inference and transfer learning.
- Clone this repo and cd to whisper-modeling
git clone https://github.com/usc-sail/child-adult-diarization.git
cd child-adult-diarization/whisper-modeling
- Install dependencies (Python 3.10.9 was used originally and is therefore recommended)
pip install -r requirements.txt
- Download whisper-base_rank8_pretrained_50k.pt from https://huggingface.co/AlexXu811/whisper-child-adult/tree/main. This model works well zero-shot with younger children (around 6 years old or younger).
- Example Python code is shown below. The model outputs one of {0: silence, 1: child, 2: adult, 3: overlap} at the frame level (one label per 20 ms). 10s audio segments are recommended as input, since the pre-trained model was trained on 10s audio.
from models.whisper import WhisperWrapper
import torch
model = WhisperWrapper()
# replace positional embedding for 10s input audio
model.backbone_model.encoder.embed_positions = model.backbone_model.encoder.embed_positions.from_pretrained(model.embed_positions[:500])
model.load_state_dict(torch.load("path/to/whisper-base_rank8_pretrained_50k.pt"))
model.cuda()
test_data = torch.zeros([1, 160000]).cuda()
output = model.forward_eval(test_data)
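- For real recordings, a minimal sketch of preparing a wav file for the model, reusing the model loaded above (torchaudio is assumed here; the file path is illustrative and not part of this repo):

import torch
import torchaudio

# Load a recording and resample to 16 kHz mono (path is illustrative)
waveform, sr = torchaudio.load("path/to/recording.wav")
waveform = torchaudio.functional.resample(waveform, sr, 16000).mean(dim=0)

# Pad or trim to a 10s window (160000 samples at 16 kHz), as the model expects
chunk = waveform[:160000]
chunk = torch.nn.functional.pad(chunk, (0, 160000 - chunk.shape[0]))
output = model.forward_eval(chunk.unsqueeze(0).cuda())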
- Example code to map the frame-level outputs to child, adult, and overlap timestamps:
from scripts.convert_output import get_timestamps, majority_filter
output = majority_filter(output)
output = get_timestamps(output)
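- For intuition, each output frame covers 20 ms, so mapping labels to timestamps is a run-length conversion over the frame sequence. A minimal sketch of the idea (the repo's get_timestamps may differ in details such as output format):

# Illustrative only: collapse per-frame labels into (start_sec, end_sec, label) segments
def frames_to_segments(labels, frame_sec=0.02):
    segments, start = [], 0
    for i in range(1, len(labels) + 1):
        if i == len(labels) or labels[i] != labels[start]:
            segments.append((round(start * frame_sec, 2), round(i * frame_sec, 2), labels[start]))
            start = i
    return segments

# e.g., frames_to_segments([1, 1, 1, 2, 2]) -> [(0.0, 0.06, 1), (0.06, 0.1, 2)]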
- To train your own model, first install dependencies (as shown in the quick start above).
- Prepare the training data. An example annotation file is shown in example_label.csv. The wav files should be 10s long, but feel free to modify the code to change this. The expected training data structure is as follows (see the segmentation sketch after this tree):
project-root/
│
├── audio_dir/
│   ├── train/
│   │   ├── train_file1.wav
│   │   ├── train_file2.wav
│   │   └── ...
│   ├── val/
│   │   ├── val_file1.wav
│   │   ├── val_file2.wav
│   │   └── ...
├── annotation_dir/
│   ├── train/
│   │   ├── train_file1.csv
│   │   ├── train_file2.csv
│   │   └── ...
│   ├── val/
│   │   ├── val_file1.csv
│   │   ├── val_file2.csv
│   │   └── ...
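- Since the model operates on 10s inputs, longer recordings can be split into 10s wav files before training. A minimal sketch (torchaudio is assumed; paths and file naming are illustrative):

import torchaudio

# Split a long recording into consecutive 10s training segments
waveform, sr = torchaudio.load("path/to/long_recording.wav")
chunk_len = 10 * sr
for i in range(waveform.shape[1] // chunk_len):
    chunk = waveform[:, i * chunk_len : (i + 1) * chunk_len]
    torchaudio.save(f"audio_dir/train/train_file{i + 1}.wav", chunk, sr)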
- Edit the config file (especially the paths).
- Run the following to start training:
python scripts/main.py --debug f --config path/to/config_file
- Install dependencies for conversation simulation
cd path/to/conversation_simulation
pip install -r requirements.txt
- Edit config_audioset.yaml and prepare AudioSet by running the following three scripts in order (download -> resample to 16 kHz -> extract speech segments). The json files contain extracted timestamps and child/adult speech probabilities from an internal pre-trained model.
python download_audioset.py
python audio_resample.py
python process_audioset.py
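- For reference, the resampling step amounts to converting each downloaded clip to 16 kHz. A minimal sketch of that idea (the actual audio_resample.py may differ; torchaudio and the paths here are illustrative assumptions):

import torchaudio

# Resample one downloaded clip to 16 kHz
waveform, sr = torchaudio.load("audioset_raw/clip.wav")
waveform = torchaudio.functional.resample(waveform, sr, 16000)
torchaudio.save("audioset_16k/clip.wav", waveform, 16000)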
- Modify config_simulated_conversation.yaml and run build_conversations.py to generate the simulated conversations.
- The extracted AudioSet child speech skews toward younger children. If you intend to use the model for older children, I recommend also training with older child speech (e.g., MyST). A conceptual sketch of the simulation idea follows.
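- Conceptually, building a simulated conversation means stitching child and adult segments together with pauses while tracking frame-level labels. A minimal sketch of that idea (not the repo's build_conversations.py; overlap handling for label 3 is omitted, and all inputs are illustrative):

import random
import torch

def simulate_conversation(child_segs, adult_segs, total_len=160000, frame=320):
    # frame = 320 samples = 20 ms at 16 kHz; labels: 0 silence, 1 child, 2 adult
    audio = torch.zeros(total_len)
    labels = torch.zeros(total_len // frame, dtype=torch.long)
    pos, spk = 0, 1  # start with a child utterance
    while pos < total_len:
        seg = random.choice(child_segs if spk == 1 else adult_segs)
        end = min(pos + seg.shape[0], total_len)
        audio[pos:end] = seg[: end - pos]
        labels[pos // frame : end // frame] = spk
        pos = end + random.randint(0, 25) * frame  # silent gap up to 0.5s
        spk = 2 if spk == 1 else 1  # alternate speakers
    return audio, labels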
@article{xu2024data,
  title={Data Efficient Child-Adult Speaker Diarization with Simulated Conversations},
  author={Anfeng Xu and Tiantian Feng and Helen Tager-Flusberg and Catherine Lord and Shrikanth Narayanan},
  year={2024},
  journal={arXiv preprint arXiv:2409.08881},
  url={https://arxiv.org/abs/2409.08881},
}

@inproceedings{xu24c_interspeech,
  title={Exploring Speech Foundation Models for Speaker Diarization in Child-Adult Dyadic Interactions},
  author={Anfeng Xu and Kevin Huang and Tiantian Feng and Lue Shen and Helen Tager-Flusberg and Shrikanth Narayanan},
  year={2024},
  booktitle={Interspeech 2024},
  pages={5193--5197},
  doi={10.21437/Interspeech.2024-717},
  issn={2958-1796},
}
Please raise an issue for any questions, or feel free to contact us at [email protected].