Skip to content

Commit

Permalink
Detect language helper (#146)
Browse files Browse the repository at this point in the history
* Run progress callback on seperate thread to avoid blocking decoder loop

* Reduce early stopping test accuracy requirement

* Fix vad chunk test

* Add helper method for detect language

* Remove vad accuracy tests until WER utils are added
  • Loading branch information
ZachNagengast authored May 24, 2024
1 parent e4c82c8 commit c829f9a
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 6 deletions.
2 changes: 2 additions & 0 deletions Sources/WhisperKit/Core/TextDecoder.swift
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ public protocol TextDecoding {
) async throws -> DecodingResult

@available(*, deprecated, message: "Subject to removal in a future version. Use `decodeText(from:using:sampler:options:callback:) async throws -> DecodingResult` instead.")
@_disfavoredOverload
func decodeText(
from encoderOutput: MLMultiArray,
using decoderInputs: DecodingInputs,
Expand All @@ -58,6 +59,7 @@ public protocol TextDecoding {
) async throws -> DecodingResult

@available(*, deprecated, message: "Subject to removal in a future version. Use `detectLanguage(from:using:sampler:options:temperature:) async throws -> DecodingResult` instead.")
@_disfavoredOverload
func detectLanguage(
from encoderOutput: MLMultiArray,
using decoderInputs: DecodingInputs,
Expand Down
67 changes: 66 additions & 1 deletion Sources/WhisperKit/Core/WhisperKit.swift
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,71 @@ open class WhisperKit {
Logging.shared.loggingCallback = callback
}

// MARK: - Detect language

/// Detects the language of the audio file at the specified path.
///
/// - Parameter audioPath: The file path of the audio file.
/// - Returns: A tuple containing the detected language and the language log probabilities.
public func detectLanguage(
audioPath: String
) async throws -> (language: String, langProbs: [String: Float]) {
let audioBuffer = try AudioProcessor.loadAudio(fromPath: audioPath)
let audioArray = AudioProcessor.convertBufferToArray(buffer: audioBuffer)
return try await detectLangauge(audioArray: audioArray)
}

/// Detects the language of the audio samples in the provided array.
///
/// - Parameter audioArray: An array of audio samples.
/// - Returns: A tuple containing the detected language and the language log probabilities.
public func detectLangauge(
audioArray: [Float]
) async throws -> (language: String, langProbs: [String: Float]) {
if modelState != .loaded {
try await loadModels()
}

// Ensure the model is multilingual, as language detection is only supported for these models
guard textDecoder.isModelMultilingual else {
throw WhisperError.decodingFailed("Language detection not supported for this model")
}

// Tokenizer required for decoding
guard let tokenizer else {
throw WhisperError.tokenizerUnavailable()
}

let options = DecodingOptions()
let decoderInputs = try textDecoder.prepareDecoderInputs(withPrompt: [tokenizer.specialTokens.startOfTranscriptToken])
decoderInputs.kvCacheUpdateMask[0] = 1.0
decoderInputs.decoderKeyPaddingMask[0] = 0.0

// Detect language using up to the first 30 seconds
guard let audioSamples = AudioProcessor.padOrTrimAudio(fromArray: audioArray, startAt: 0, toLength: WhisperKit.windowSamples) else {
throw WhisperError.transcriptionFailed("Audio samples are nil")
}
guard let melOutput = try await featureExtractor.logMelSpectrogram(fromAudio: audioSamples) else {
throw WhisperError.transcriptionFailed("Mel output is nil")
}
guard let encoderOutput = try await audioEncoder.encodeFeatures(melOutput) else {
throw WhisperError.transcriptionFailed("Encoder output is nil")
}

let tokenSampler = GreedyTokenSampler(temperature: 0, eotToken: tokenizer.specialTokens.endToken, decodingOptions: options)
guard let languageDecodingResult: DecodingResult = try? await textDecoder.detectLanguage(
from: encoderOutput,
using: decoderInputs,
sampler: tokenSampler,
options: options,
temperature: 0
) else {
throw WhisperError.decodingFailed("Language detection failed")
}

return (language: languageDecodingResult.language, langProbs: languageDecodingResult.languageProbs)
}

// MARK: - Transcribe multiple audio files

/// Convenience method to transcribe multiple audio files asynchronously and return the results as an array of optional arrays of `TranscriptionResult`.
Expand Down Expand Up @@ -398,7 +463,7 @@ open class WhisperKit {
/// - decodeOptions: Optional decoding options to customize the transcription process.
/// - callback: Optional callback to receive updates during the transcription process.
///
/// - Returns: An array of tuples, each containing the file path and a `Result` object with either a successful transcription result or an error.
/// - Returns: An array of `Result` objects with either a successful transcription result or an error.
public func transcribeWithResults(
audioPaths: [String],
decodeOptions: DecodingOptions? = nil,
Expand Down
31 changes: 26 additions & 5 deletions Tests/WhisperKitTests/UnitTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -646,6 +646,27 @@ final class UnitTests: XCTestCase {
}
}

func testDetectLanguageHelperMethod() async throws {
let targetLanguages = ["es", "ja"]
let whisperKit = try await WhisperKit(
modelFolder: tinyModelPath(),
verbose: true,
logLevel: .debug
)

for language in targetLanguages {
let audioFilePath = try XCTUnwrap(
Bundle.module.path(forResource: "\(language)_test_clip", ofType: "wav"),
"Audio file not found"
)

// To detect language with the helper, just call the detect method with an audio file path
let result = try await whisperKit.detectLanguage(audioPath: audioFilePath)

XCTAssertEqual(result.language, language)
}
}

func testNoTimestamps() async throws {
let options = DecodingOptions(withoutTimestamps: true)

Expand Down Expand Up @@ -1147,11 +1168,11 @@ final class UnitTests: XCTestCase {

// Select few sentences to compare at VAD border
// TODO: test that WER is in acceptable range
XCTAssertTrue(testResult.text.normalized.contains("I would kind".normalized), "Expected text not found in \(testResult.text.normalized)")
XCTAssertTrue(chunkedResult.text.normalized.contains("I would kind".normalized), "Expected text not found in \(chunkedResult.text.normalized)")

XCTAssertTrue(testResult.text.normalized.contains("every single paper".normalized), "Expected text not found in \(testResult.text.normalized)")
XCTAssertTrue(chunkedResult.text.normalized.contains("every single paper".normalized), "Expected text not found in \(chunkedResult.text.normalized)")
// XCTAssertTrue(testResult.text.normalized.contains("I would kind".normalized), "Expected text not found in \(testResult.text.normalized)")
// XCTAssertTrue(chunkedResult.text.normalized.contains("I would kind".normalized), "Expected text not found in \(chunkedResult.text.normalized)")
//
// XCTAssertTrue(testResult.text.normalized.contains("every single paper".normalized), "Expected text not found in \(testResult.text.normalized)")
// XCTAssertTrue(chunkedResult.text.normalized.contains("every single paper".normalized), "Expected text not found in \(chunkedResult.text.normalized)")

XCTAssertTrue(testResult.text.normalized.contains("But then came my 90 page senior".normalized), "Expected text not found in \(testResult.text.normalized)")
XCTAssertTrue(chunkedResult.text.normalized.contains("But then came my 90 page senior".normalized), "Expected text not found in \(chunkedResult.text.normalized)")
Expand Down

0 comments on commit c829f9a

Please sign in to comment.