Commit relevant files
kelpabc123 committed Jun 24, 2024
2 parents b0cdaa1 + d3a8c4b commit 9cf623d
Showing 6 changed files with 521 additions and 159 deletions.
17 changes: 17 additions & 0 deletions .github/workflows/security-scan.yml
@@ -0,0 +1,17 @@
name: Security Scan

on:
  pull_request:
  push:
    branches:
      - main

jobs:
  security:
    name: OSS Security SAST
    uses: Roblox/security-workflows/.github/workflows/oss-security-sast.yaml@main
    with:
      skip-ossf: true
    secrets:
      GITLEAKS_LICENSE: ${{ secrets.GITLEAKS_KEY }}
      ROBLOX_SEMGREP_GHC_POC_APP_TOKEN: ${{ secrets.ROBLOX_SEMGREP_GHC_POC_APP_TOKEN }}
511 changes: 352 additions & 159 deletions LICENSE.md

Large diffs are not rendered by default.

38 changes: 38 additions & 0 deletions README.md
@@ -0,0 +1,38 @@
## Model Description
The model is fine-tuned from [WavLM base plus](https://arxiv.org/abs/2110.13900) on 2,374 hours of voice chat audio clips
for multilabel classification.
The audio clips are automatically labeled using a synthetic data pipeline described in [our blog post](link to blog post here).
A single clip can have multiple labels.
The model outputs an n-by-6 tensor, where n is the number of input segments and the six columns correspond to the labels
`Profanity`, `DatingAndSexting`, `Racist`, `Bullying`, `Other`, and `NoViolation`. `Other` combines low-prevalence policy
violation categories, such as drugs and alcohol or self-harm, into a single category.
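
As a rough illustration of reading that output, the sketch below turns per-segment probabilities into label names. The helper `labels_above_threshold`, the 0.5 threshold, and the example probabilities are all assumptions made for illustration; the repository does not prescribe a threshold.

```python
# Label order as listed above (and in inference.py).
LABELS = ["Profanity", "DatingAndSexting", "Racist", "Bullying", "Other", "NoViolation"]


def labels_above_threshold(probs, threshold=0.5):
    """Return, for each segment, the labels whose probability meets the threshold.

    `probs` is a list of rows, one per segment, each with six probabilities.
    The 0.5 default is an illustrative assumption, not a recommended setting.
    """
    results = []
    for row in probs:
        if len(row) != len(LABELS):
            raise ValueError(f"expected {len(LABELS)} probabilities, got {len(row)}")
        results.append([label for label, p in zip(LABELS, row) if p >= threshold])
    return results


# Made-up probabilities for two segments, for illustration only.
example_probs = [
    [0.91, 0.05, 0.02, 0.64, 0.01, 0.03],
    [0.02, 0.01, 0.01, 0.03, 0.02, 0.97],
]
print(labels_above_threshold(example_probs))
```

Because the labels are scored independently, a segment can return several labels, as the first row does here.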

We evaluated this model on a dataset of 9,795 samples with human-annotated labels and the class distribution shown below.
The `Other` category is not included in this evaluation dataset. Because a sample can carry more than one label, the
per-class counts and percentages sum to more than the total.

|Class|Number of examples| Duration (hours)|% of dataset|
|---|---|---|---|
|Profanity | 4893| 15.38 | 49.95%|
|DatingAndSexting | 688 | 2.52 | 7.02% |
|Racist | 889 | 3.10 | 9.08% |
|Bullying | 1256 | 4.25 | 12.82% |
|NoViolation | 4185 | 9.93 | 42.73% |


If we apply the same threshold to all classes and treat the model as a binary classifier over the 4 toxicity classes (`Profanity`, `DatingAndSexting`, `Racist`, `Bullying`), we get a binarized average precision of 94.48%. The precision-recall curve is shown below.
<p align="center">
<img src="images/human_eval_pr_curve.png" alt="PR Curve" width="500"/>
</p>
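
The README does not spell out how the binarized score is computed. One plausible reading, sketched below purely as an assumption, is to score each sample by its highest probability among the four toxicity classes and to count a sample as positive if it carries any toxicity label. The helper names and the example numbers are hypothetical.

```python
# Column indices of Profanity, DatingAndSexting, Racist, Bullying.
TOXIC_IDX = [0, 1, 2, 3]


def binarized_score(probs_row):
    """Assumed binarization: the highest probability among the toxicity classes."""
    return max(probs_row[i] for i in TOXIC_IDX)


def average_precision(scores, labels):
    """Non-interpolated average precision: mean of precision at each positive."""
    ranked = sorted(zip(scores, labels), key=lambda pair: pair[0], reverse=True)
    hits = 0
    precision_sum = 0.0
    for rank, (_, is_positive) in enumerate(ranked, start=1):
        if is_positive:
            hits += 1
            precision_sum += hits / rank
    return precision_sum / hits if hits else 0.0


# Made-up model outputs and ground truth, for illustration only.
rows = [
    [0.9, 0.1, 0.0, 0.2, 0.0, 0.1],
    [0.1, 0.2, 0.1, 0.0, 0.0, 0.9],
    [0.3, 0.7, 0.1, 0.1, 0.0, 0.4],
    [0.1, 0.0, 0.4, 0.2, 0.0, 0.6],
]
is_toxic = [True, False, False, True]

ap = average_precision([binarized_score(r) for r in rows], is_toxic)
print(round(ap, 4))  # 0.8333
```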

## Usage
The dependencies for the inference file can be installed as follows:
```
pip install -r requirements.txt
```
The inference file contains helper functions that preprocess the audio file for inference.
To run it, use the following command:
```
python inference.py --audio_file <your audio file path> --model_path <path to Huggingface model>
```
You can get the model weights either from the [model releases page](https://github.com/Roblox/voice-safety-classifier/releases/tag/vs-classifier-v1) or from HuggingFace under `roblox/voice-safety-classifier`.
If `model_path` isn't specified, the model is loaded directly from HuggingFace.
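
The script prints one set of probabilities per segment. As a sketch of how many segments to expect, the hypothetical helper below mirrors the windowing logic of `feature_extract_simple` in `inference.py` (15-second windows and stride at 16 kHz by default). `count_windows` is not part of the repository.

```python
def count_windows(num_samples, sr=16_000, win_len=15.0, win_stride=15.0):
    """Count the segments produced for a clip of `num_samples` samples.

    Mirrors the chunking loop in feature_extract_simple: clips no longer than
    one window are passed through as a single segment, and the final partial
    window of a longer clip is zero-padded.
    """
    window = int(win_len * sr)
    stride = int(win_stride * sr)
    if num_samples / sr <= win_len:
        return 1  # short clip: one segment, no padding
    count = 0
    for start in range(0, num_samples, stride):
        count += 1
        if start + window > num_samples:
            break  # this window ran past the end and was padded
    return count


for seconds in (10, 30, 40):
    print(seconds, "s ->", count_windows(seconds * 16_000), "segment(s)")
```

With the defaults, a 40-second clip yields three segments, the last of which is padded from 10 to 15 seconds.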
Binary file added images/human_eval_pr_curve.png
110 changes: 110 additions & 0 deletions inference.py
@@ -0,0 +1,110 @@
# Copyright © 2024 Roblox Corporation

"""
This file gives a sample demonstration of how to use the given functions in Python, for the Voice Safety Classifier model.
"""

import torch
import librosa
import numpy as np
import argparse
from transformers import WavLMForSequenceClassification


def feature_extract_simple(
    wav,
    sr=16_000,
    win_len=15.0,
    win_stride=15.0,
    do_normalize=False,
):
    """Simple feature extraction for WavLM.

    Parameters
    ----------
    wav : str or array-like
        path to the wav file, or array-like audio samples
    sr : int, optional
        sample rate, by default 16_000
    win_len : float, optional
        window length in seconds, by default 15.0
    win_stride : float, optional
        window stride in seconds, by default 15.0
    do_normalize : bool, optional
        whether to normalize the input, by default False.

    Returns
    -------
    np.ndarray
        batched input to WavLM, shape [N, T]
    """
    if isinstance(wav, str):
        signal, _ = librosa.core.load(wav, sr=sr)
    else:
        try:
            signal = np.array(wav).squeeze()
        except Exception as e:
            raise RuntimeError("could not convert input to a NumPy array") from e
    batched_input = []
    stride = int(win_stride * sr)
    win_samples = int(win_len * sr)
    if len(signal) / sr > win_len:
        for i in range(0, len(signal), stride):
            if i + win_samples > len(signal):
                # pad the last chunk to the same length as the others
                chunked = np.pad(signal[i:], (0, win_samples - len(signal[i:])))
            else:
                chunked = signal[i : i + win_samples]
            if do_normalize:
                chunked = (chunked - np.mean(chunked)) / (np.std(chunked) + 1e-7)
            batched_input.append(chunked)
            if i + win_samples > len(signal):
                break
    else:
        # clip fits in a single window: use it as-is, without padding
        if do_normalize:
            signal = (signal - np.mean(signal)) / (np.std(signal) + 1e-7)
        batched_input.append(signal)
    return np.stack(batched_input)  # [N, T]


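def infer(model, inputs):
    """Run the model and return per-label sigmoid probabilities."""
    with torch.no_grad():
        output = model(inputs)
    # logits are already a tensor; sigmoid scores each label independently
    probs = torch.sigmoid(output.logits)
    return probs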


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--audio_file",
        type=str,
        required=True,
        help="Audio file to run inference on",
    )
    parser.add_argument(
        "--model_path",
        type=str,
        default="roblox/voice-safety-classifier",
        help="HuggingFace model name or local path to the model checkpoint",
    )
    args = parser.parse_args()
    labels_name_list = [
        "Profanity",
        "DatingAndSexting",
        "Racist",
        "Bullying",
        "Other",
        "NoViolation",
    ]
    # Model is trained on only 16kHz audio
    audio, _ = librosa.core.load(args.audio_file, sr=16000)
    input_np = feature_extract_simple(audio, sr=16000)
    input_pt = torch.Tensor(input_np)
    model = WavLMForSequenceClassification.from_pretrained(
        args.model_path, num_labels=len(labels_name_list)
    )
    probs = infer(model, input_pt)
    # one row of per-label probabilities per audio segment
    probs = probs.reshape(-1, len(labels_name_list)).detach().tolist()
    print(f"Probabilities for {args.audio_file} are:")
    for chunk_idx in range(len(probs)):
        print(f"\nSegment {chunk_idx}:")
        for label_idx, label in enumerate(labels_name_list):
            print(f"{label} : {probs[chunk_idx][label_idx]}")
4 changes: 4 additions & 0 deletions requirements.txt
@@ -0,0 +1,4 @@
torch
transformers
librosa
numpy
