Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing 6 changed files with 521 additions and 159 deletions.
@@ -0,0 +1,17 @@
name: Security Scan

on:
  pull_request:
  push:
    branches:
      - main

jobs:
  security:
    name: OSS Security SAST
    uses: Roblox/security-workflows/.github/workflows/oss-security-sast.yaml@main
    with:
      skip-ossf: true
    secrets:
      GITLEAKS_LICENSE: ${{ secrets.GITLEAKS_KEY }}
      ROBLOX_SEMGREP_GHC_POC_APP_TOKEN: ${{ secrets.ROBLOX_SEMGREP_GHC_POC_APP_TOKEN }}
@@ -0,0 +1,38 @@
## Model Description
The model is fine-tuned from [WavLM base plus](https://arxiv.org/abs/2110.13900) on 2,374 hours of voice chat audio clips
for multilabel classification.
The audio clips are automatically labeled using a synthetic data pipeline described in [our blog post](link to blog post here).
A single output can have multiple labels.
The model outputs an n-by-6 tensor, one row per input audio segment, where the six columns correspond to the labels `Profanity`, `DatingAndSexting`, `Racist`,
`Bullying`, `Other`, and `NoViolation`. `Other` combines policy violation categories with low prevalence, such as drugs
and alcohol or self-harm, into a single category.
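
To make the output layout concrete, the sketch below turns an n-by-6 array of per-label probabilities into label names. It is an illustration only: the helper name `decode_labels` and the 0.5 threshold are assumptions that do not come from this repository, and a real deployment would choose its own per-class operating points.

```python
import numpy as np

# Label order as listed in the model description above.
LABELS = ["Profanity", "DatingAndSexting", "Racist", "Bullying", "Other", "NoViolation"]


def decode_labels(probs, threshold=0.5):
    """Return, for each row of an [n, 6] probability array, the labels at or above `threshold`.

    Hypothetical helper; the 0.5 default is an illustrative assumption.
    """
    probs = np.asarray(probs, dtype=float).reshape(-1, len(LABELS))
    return [
        [label for label, p in zip(LABELS, row) if p >= threshold]
        for row in probs
    ]


# Two made-up segments: the first triggers two labels, the second only NoViolation.
example = [
    [0.91, 0.03, 0.02, 0.64, 0.01, 0.05],
    [0.04, 0.02, 0.01, 0.03, 0.02, 0.97],
]
print(decode_labels(example))
```

Because the task is multilabel, each column is thresholded independently, so a segment can receive several labels or none at all.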

We evaluated this model on a dataset with human-annotated labels that contained a total of 9,795 samples with the class
distribution shown below. Because a sample can carry multiple labels, the per-class counts sum to more than the total. Note that we did not include the `Other` category in this evaluation dataset.

|Class|Number of examples|Duration (hours)|% of dataset|
|---|---|---|---|
|Profanity|4893|15.38|49.95%|
|DatingAndSexting|688|2.52|7.02%|
|Racist|889|3.10|9.08%|
|Bullying|1256|4.25|12.82%|
|NoViolation|4185|9.93|42.73%|


If we apply the same threshold to all classes and treat the model as a binary classifier over the four toxicity classes (`Profanity`, `DatingAndSexting`, `Racist`, `Bullying`), we get a binarized average precision of 94.48%. The precision-recall curve is shown below.
<p align="center">
  <img src="images/human_eval_pr_curve.png" alt="PR Curve" width="500"/>
</p>
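
One way to read the binarization above: with a single shared threshold, a clip is flagged exactly when its largest toxicity-class probability exceeds that threshold, so the maximum over the four classes can serve as a binary score. The sketch below follows that reading and computes average precision with NumPy. The function names, the toy data, and the exact evaluation recipe are assumptions; the evaluation code behind the reported 94.48% is not part of this commit.

```python
import numpy as np

# Columns of the four toxicity classes in the model output:
# Profanity, DatingAndSexting, Racist, Bullying.
TOXIC_IDX = [0, 1, 2, 3]


def binarize(probs, labels):
    """Collapse [n, 6] probabilities and 0/1 labels into one score and one target per clip."""
    probs = np.asarray(probs, dtype=float)
    labels = np.asarray(labels)
    scores = probs[:, TOXIC_IDX].max(axis=1)                # flagged iff any class exceeds t
    targets = labels[:, TOXIC_IDX].any(axis=1).astype(int)  # positive iff any toxic label set
    return scores, targets


def average_precision(scores, targets):
    """Mean of the precision values measured at the rank of each positive example."""
    order = np.argsort(-np.asarray(scores), kind="stable")
    hits = np.asarray(targets)[order]
    precision_at_k = np.cumsum(hits) / np.arange(1, len(hits) + 1)
    return float((precision_at_k * hits).sum() / max(hits.sum(), 1))


# Made-up probabilities and human labels for three clips.
probs = np.array([
    [0.9, 0.0, 0.0, 0.0, 0.0, 0.1],
    [0.2, 0.1, 0.0, 0.0, 0.0, 0.8],
    [0.1, 0.6, 0.0, 0.0, 0.0, 0.3],
])
labels = np.array([
    [1, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 1],
    [0, 1, 0, 0, 0, 0],
])
scores, targets = binarize(probs, labels)
print(average_precision(scores, targets))
```

Sweeping the threshold over the sorted scores traces out the precision-recall curve; average precision summarizes that curve in a single number.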

## Usage
The dependencies for the inference file can be installed as follows:
```
pip install -r requirements.txt
```
The inference file contains helper functions that preprocess the audio file for inference.
To run inference on an audio file, use the following command:
```
python inference.py --audio_file <your audio file path> --model_path <path to Hugging Face model>
```
You can get the model weights either by downloading them from the [model releases page](https://github.com/Roblox/voice-safety-classifier/releases/tag/vs-classifier-v1) or from Hugging Face under `roblox/voice-safety-classifier`.
If `model_path` isn't specified, the model is loaded directly from Hugging Face.
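
The preprocessing mentioned above splits the audio into fixed-length windows before it reaches the model. The sketch below is a simplified illustration of that idea, assuming non-overlapping windows and zero-padding of the final window; `chunk_signal` is a hypothetical name and is not the helper shipped in `inference.py`.

```python
import numpy as np


def chunk_signal(signal, sr=16_000, win_len=15.0):
    """Split a 1-D signal into non-overlapping windows of `win_len` seconds.

    Hypothetical helper: the last window is zero-padded to full length, and
    a signal no longer than one window is returned as a single unpadded row.
    """
    signal = np.asarray(signal, dtype=np.float32).reshape(-1)
    size = int(win_len * sr)  # window length in samples
    if len(signal) <= size:
        return signal[np.newaxis, :]
    n_chunks = int(np.ceil(len(signal) / size))
    padded = np.pad(signal, (0, n_chunks * size - len(signal)))
    return padded.reshape(n_chunks, size)


# 40 seconds of audio at 16 kHz is split into three 15-second windows.
batch = chunk_signal(np.ones(40 * 16_000))
print(batch.shape)
```

Each row of the result is one segment, which lines up with the one-row-per-segment layout of the model output.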
@@ -0,0 +1,110 @@
# Copyright © 2024 Roblox Corporation

"""
This file gives a sample demonstration of how to use the given functions in Python, for the Voice Safety Classifier model.
"""

import torch
import librosa
import numpy as np
import argparse
from transformers import WavLMForSequenceClassification


def feature_extract_simple(
    wav,
    sr=16_000,
    win_len=15.0,
    win_stride=15.0,
    do_normalize=False,
):
    """Simple feature extraction for WavLM.
    Parameters
    ----------
    wav : str or array-like
        path to the wav file, or array-like
    sr : int, optional
        sample rate, by default 16_000
    win_len : float, optional
        window length in seconds, by default 15.0
    win_stride : float, optional
        window stride in seconds, by default 15.0
    do_normalize: bool, optional
        whether to normalize the input, by default False.
    Returns
    -------
    np.ndarray
        batched input to WavLM
    """
    if isinstance(wav, str):
        signal, _ = librosa.load(wav, sr=sr)
    else:
        try:
            signal = np.array(wav).squeeze()
        except Exception as e:
            # chain the original error so the cause stays visible
            raise RuntimeError("could not convert `wav` to an array") from e
    batched_input = []
    stride = int(win_stride * sr)
    win_size = int(win_len * sr)
    if len(signal) / sr > win_len:
        for i in range(0, len(signal), stride):
            if i + win_size > len(signal):
                # padding the last chunk to make it the same length as others
                chunked = np.pad(signal[i:], (0, win_size - len(signal[i:])))
            else:
                chunked = signal[i : i + win_size]
            if do_normalize:
                chunked = (chunked - np.mean(chunked)) / (np.std(chunked) + 1e-7)
            batched_input.append(chunked)
            if i + win_size > len(signal):
                break
    else:
        if do_normalize:
            signal = (signal - np.mean(signal)) / (np.std(signal) + 1e-7)
        batched_input.append(signal)
    return np.stack(batched_input)  # [N, T]


def infer(model, inputs):
    with torch.no_grad():  # inference only, so skip gradient tracking
        output = model(inputs)
    return torch.sigmoid(output.logits)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--audio_file",
        type=str,
        help="Audio file to run inference on",
    )
    parser.add_argument(
        "--model_path",
        type=str,
        default="roblox/voice-safety-classifier",
        help="Local path or Hugging Face ID of the model checkpoint",
    )
    args = parser.parse_args()
    labels_name_list = [
        "Profanity",
        "DatingAndSexting",
        "Racist",
        "Bullying",
        "Other",
        "NoViolation",
    ]
    # Model is trained on only 16kHz audio
    audio, _ = librosa.load(args.audio_file, sr=16000)
    input_np = feature_extract_simple(audio, sr=16000)
    input_pt = torch.tensor(input_np, dtype=torch.float32)
    model = WavLMForSequenceClassification.from_pretrained(
        args.model_path, num_labels=len(labels_name_list)
    )
    probs = infer(model, input_pt)
    probs = probs.reshape(-1, len(labels_name_list)).tolist()
    print(f"Probabilities for {args.audio_file} are:")
    for chunk_idx, chunk_probs in enumerate(probs):
        print(f"\nSegment {chunk_idx}:")
        for label, prob in zip(labels_name_list, chunk_probs):
            print(f"{label} : {prob}")
@@ -0,0 +1,4 @@
torch
transformers
librosa
numpy