diff --git a/Examples/WhisperAX/WhisperAX/Views/ContentView.swift b/Examples/WhisperAX/WhisperAX/Views/ContentView.swift index 19320c5..89f387a 100644 --- a/Examples/WhisperAX/WhisperAX/Views/ContentView.swift +++ b/Examples/WhisperAX/WhisperAX/Views/ContentView.swift @@ -31,7 +31,7 @@ struct ContentView: View { @State private var availableModels: [String] = [] @State private var availableLanguages: [String] = [] @State private var disabledModels: [String] = WhisperKit.recommendedModels().disabled - + @AppStorage("promptText") private var promptText: String? @AppStorage("selectedAudioInput") private var selectedAudioInput: String = "No Audio Input" @AppStorage("selectedModel") private var selectedModel: String = WhisperKit.recommendedModels().default @AppStorage("selectedTab") private var selectedTab: String = "Transcribe" @@ -765,6 +765,11 @@ struct ContentView: View { var settingsForm: some View { List { + + + + + HStack { Text("Show Timestamps") InfoButton("Toggling this will include/exclude timestamps in both the UI and the prefill tokens.\nEither <|notimestamps|> or <|0.00|> will be forced based on this setting unless \"Prompt Prefill\" is de-selected.") @@ -817,6 +822,14 @@ struct ContentView: View { } .padding(.horizontal) .padding(.bottom) + + TextField("Enter prompt text", text: Binding( + get: { self.promptText ?? "" }, + set: { self.promptText = $0.isEmpty ? nil : $0 } + )) + .textFieldStyle(.roundedBorder) + .padding(.horizontal) + .padding(.bottom) VStack { Text("Starting Temperature:") @@ -1303,7 +1316,7 @@ struct ContentView: View { let task: DecodingTask = selectedTask == "transcribe" ? .transcribe : .translate let seekClip: [Float] = [lastConfirmedSegmentEndSeconds] - let options = DecodingOptions( + var options = DecodingOptions( verbose: true, task: task, language: languageCode, @@ -1318,6 +1331,19 @@ struct ContentView: View { clipTimestamps: seekClip, chunkingStrategy: chunkingStrategy ) + + // Prompt + if let promptText = promptText { + guard whisperKit.tokenizer != nil else { + throw WhisperError.tokenizerUnavailable() + } + + if promptText.count > 0, let tokenizer = whisperKit.tokenizer { + options.promptTokens = tokenizer.encode(text: " " + promptText.trimmingCharacters(in: .whitespaces)).filter { $0 < tokenizer.specialTokens.specialTokenBegin } + options.usePrefillPrompt = true + } + } + // Early stopping checks let decodingCallback: ((TranscriptionProgress) -> Bool?) = { (progress: TranscriptionProgress) in @@ -1542,7 +1568,7 @@ struct ContentView: View { print(selectedLanguage) print(languageCode) - let options = DecodingOptions( + var options = DecodingOptions( verbose: true, task: task, language: languageCode, @@ -1556,6 +1582,18 @@ struct ContentView: View { wordTimestamps: true, // required for eager mode firstTokenLogProbThreshold: -1.5 // higher threshold to prevent fallbacks from running to often ) + + // Prompt + if let promptText = promptText { + guard whisperKit.tokenizer != nil else { + throw WhisperError.tokenizerUnavailable() + } + + if promptText.count > 0, let tokenizer = whisperKit.tokenizer { + options.promptTokens = tokenizer.encode(text: " " + promptText.trimmingCharacters(in: .whitespaces)).filter { $0 < tokenizer.specialTokens.specialTokenBegin } + options.usePrefillPrompt = true + } + } // Early stopping checks let decodingCallback: ((TranscriptionProgress) -> Bool?) = { progress in