Skip to content

Commit

Permalink
support whisper v3 (#84)
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj authored Nov 9, 2023
1 parent 2037944 commit 2624da8
Show file tree
Hide file tree
Showing 14 changed files with 3,834 additions and 17 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ project(kaldifeat)
# remember to change the version in
# scripts/conda/kaldifeat/meta.yaml
# scripts/conda-cpu/kaldifeat/meta.yaml
set(kaldifeat_VERSION "1.25.2")
set(kaldifeat_VERSION "1.25.3")

set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
Expand Down
15 changes: 15 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,21 @@ See <a href="https://github.com/csukuangfj/kaldifeat/pull/82">#82</a>
</td>
</tr>

<tr>
<td>Fbank for <a href="https://github.com/openai/whisper">Whisper v3</a></td>
<td><code>kaldifeat.WhisperFbankOptions</code></td>
<td><code>kaldifeat.WhisperFbank</code></td>
<td>
<pre lang="python">
opts = kaldifeat.WhisperFbankOptions()
opts.num_mels = 128
opts.device = torch.device('cuda', 0)
fbank = kaldifeat.WhisperFbank(opts)
features = fbank(wave)
</pre>
</td>
</tr>

<tr>
<td>FBANK</td>
<td><code>kaldifeat.FbankOptions</code></td>
Expand Down
2 changes: 1 addition & 1 deletion kaldifeat/csrc/CPPLINT.cfg
Original file line number Diff line number Diff line change
@@ -1 +1 @@
exclude_files=whisper-mel-bank.h
exclude_files=whisper-mel-bank.h,whisper-v3-mel-bank.h
39 changes: 39 additions & 0 deletions kaldifeat/csrc/generate-whisper-melbank-v3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#!/usr/bin/env python3

# Copyright (c) 2023 Xiaomi Corporation (authors: Fangjun Kuang)

import librosa
import numpy as np


def main():
    """Generate ``whisper-v3-mel-bank.h`` in the current working directory.

    Builds a 128-bin mel filter bank with librosa (sr=16000, n_fft=400) and
    writes it as a flat ``constexpr float`` array in C (row-major) order,
    together with its row and column counts, inside namespace ``kaldifeat``.
    """
    m = librosa.filters.mel(sr=16000, n_fft=400, n_mels=128)
    assert m.shape == (128, 201), m.shape

    guard = "KALDIFEAT_CSRC_WHISPER_V3_MEL_BANK_H_"

    # Collect the pieces in a list and join once at the end instead of
    # growing one big string piece by piece.
    parts = [
        "// Auto-generated. Do NOT edit!\n\n",
        "// Copyright (c) 2023 Xiaomi Corporation (authors: Fangjun Kuang)\n\n",
        "\n",
        f"#ifndef {guard}\n",
        f"#define {guard}\n",
        "\n",
        # The generated header uses int32_t, so it must include <cstdint>
        # itself rather than rely on whatever the including file pulled in.
        "#include <cstdint>\n",
        "\n",
        "namespace kaldifeat {\n\n",
        f"constexpr int32_t kWhisperV3MelRows = {m.shape[0]};\n",
        f"constexpr int32_t kWhisperV3MelCols = {m.shape[1]};\n",
        "\n",
        "constexpr float kWhisperV3MelArray[] = {\n",
    ]

    # Emit the values with 8 decimals each. A line break follows every index
    # that is a non-zero multiple of 7, so the first line holds 8 values and
    # later lines hold 7; this keeps the array text identical to what the
    # generator produced before.
    sep = ""
    for i, value in enumerate(m.reshape(-1).tolist()):
        parts.append(f"{sep}{value:.8f}")
        if i and i % 7 == 0:
            parts.append(",\n")
            sep = ""
        else:
            sep = ", "

    parts.append("};\n\n")
    parts.append("}  // namespace kaldifeat\n\n")
    parts.append(f"#endif  // {guard}\n")

    with open("whisper-v3-mel-bank.h", "w") as out:
        out.write("".join(parts))


if __name__ == "__main__":
    main()
18 changes: 14 additions & 4 deletions kaldifeat/csrc/whisper-fbank.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

#include "kaldifeat/csrc/mel-computations.h"
#include "kaldifeat/csrc/whisper-mel-bank.h"
#include "kaldifeat/csrc/whisper-v3-mel-bank.h"

#ifndef M_2PI
#define M_2PI 6.283185307179586476925286766559005
Expand All @@ -31,9 +32,18 @@
namespace kaldifeat {

WhisperFbankComputer::WhisperFbankComputer(const WhisperFbankOptions &opts)
: opts_(opts),
mel_banks_(kWhisperMelArray, kWhisperMelRows, kWhisperMelCols,
opts.device) {
: opts_(opts) {
if (opts.num_mels == 80) {
mel_banks_ = std::make_unique<MelBanks>(kWhisperMelArray, kWhisperMelRows,
kWhisperMelCols, opts.device);
} else if (opts.num_mels == 128) {
mel_banks_ = std::make_unique<MelBanks>(
kWhisperV3MelArray, kWhisperV3MelRows, kWhisperV3MelCols, opts.device);
} else {
KALDIFEAT_ERR << "Unsupported num_mels: " << opts.num_mels
<< ". Support only 80 and 128";
}

opts_.frame_opts.samp_freq = 16000;
opts_.frame_opts.frame_shift_ms = 10;
opts_.frame_opts.frame_length_ms = 25;
Expand Down Expand Up @@ -67,7 +77,7 @@ torch::Tensor WhisperFbankComputer::Compute(
torch::Tensor power = (real.square() + imag.square());
#endif

torch::Tensor mel_energies = mel_banks_.Compute(power);
torch::Tensor mel_energies = mel_banks_->Compute(power);
torch::Tensor log_spec = torch::clamp_min(mel_energies, 1e-10).log10();
log_spec = torch::maximum(log_spec, log_spec.max() - 8.0);
torch::Tensor mel = (log_spec + 4.0) / 4.0;
Expand Down
6 changes: 5 additions & 1 deletion kaldifeat/csrc/whisper-fbank.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#ifndef KALDIFEAT_CSRC_WHISPER_FBANK_H_
#define KALDIFEAT_CSRC_WHISPER_FBANK_H_

#include <memory>
#include <string>
#include <vector>

Expand All @@ -30,12 +31,15 @@ namespace kaldifeat {

struct WhisperFbankOptions {
FrameExtractionOptions frame_opts;
// Number of mel bins. Only 80 and 128 are supported; use 128 for Whisper large v3.
int32_t num_mels = 80;

torch::Device device{"cpu"};
// Formats these options as a single-line string of the form
//   WhisperFbankOptions(frame_opts=<...>, num_mels=<N>, device="<device>")
// where <...> is whatever frame_opts.ToString() returns.
std::string ToString() const {
  std::ostringstream oss;
  oss << "WhisperFbankOptions("
      << "frame_opts=" << frame_opts.ToString() << ", "
      << "num_mels=" << num_mels << ", "
      << "device=\"" << device << "\")";
  return oss.str();
}
Expand Down Expand Up @@ -64,7 +68,7 @@ class WhisperFbankComputer {

private:
WhisperFbankOptions opts_;
MelBanks mel_banks_;
std::unique_ptr<MelBanks> mel_banks_;
};

using WhisperFbank = OfflineFeatureTpl<WhisperFbankComputer>;
Expand Down
Loading

0 comments on commit 2624da8

Please sign in to comment.