From d0f38def0878d13ee9fbb50a6c31c0232c4d07d9 Mon Sep 17 00:00:00 2001 From: Tamotsu Takahashi Date: Wed, 1 Jan 2025 23:44:29 +0900 Subject: [PATCH] Expose more ctx->vocab interfaces. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit I need these functions to implement a kind of weighting coefficient logits_filter_callback like: ``` void filter_callback( struct whisper_context * ctx, struct whisper_state * state, const whisper_token_data * tokens, int n_tokens, float * logits, void * user_data ) { const static std::vector good_words = { "音声", "認識" }; std::wstring_convert, char32_t> conv; auto prev = n_tokens > 0 ? std::string(whisper_token_to_str(ctx, tokens[n_tokens - 1].id)) : ""; for (const std::string & token : good_words) { auto s32 = conv.from_bytes(token); auto s0 = conv.to_bytes(s32[0]); auto s1 = conv.to_bytes(s32[1]); if (whisper_token_exists(ctx, token.c_str())) { logits[whisper_str_to_token(ctx, token.c_str())] *= 2; } else if ( prev.size() >= s0.size() && prev.compare(prev.size() - s0.size(), s0.size(), s0) == 0 && whisper_token_exists(ctx, s1.c_str()) ) { logits[whisper_str_to_token(ctx, s1.c_str())] *= 1.6; } else if (whisper_token_exists(ctx, s0.c_str())) { logits[whisper_str_to_token(ctx, s0.c_str())] *= 1.2; } } } ``` --- include/whisper.h | 3 +++ src/whisper.cpp | 8 ++++++++ 2 files changed, 11 insertions(+) diff --git a/include/whisper.h b/include/whisper.h index 71949bdd397..8999c1c3636 100644 --- a/include/whisper.h +++ b/include/whisper.h @@ -408,6 +408,9 @@ extern "C" { WHISPER_API const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token); WHISPER_API const char * whisper_model_type_readable(struct whisper_context * ctx); + // String -> Token Id. Uses the vocabulary in the provided context + WHISPER_API bool whisper_token_exists(struct whisper_context * ctx, const char * str); + WHISPER_API whisper_token whisper_str_to_token(struct whisper_context * ctx, const char * str); // Special tokens WHISPER_API whisper_token whisper_token_eot (struct whisper_context * ctx); diff --git a/src/whisper.cpp b/src/whisper.cpp index bcc530ae891..aea3f20bc71 100644 --- a/src/whisper.cpp +++ b/src/whisper.cpp @@ -4068,6 +4068,14 @@ const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token to return ctx->vocab.id_to_token.at(token).c_str(); } +whisper_token whisper_str_to_token(struct whisper_context * ctx, const char * str) { + return ctx->vocab.token_to_id.at(str); +} + +bool whisper_token_exists(struct whisper_context * ctx, const char * str) { + return ctx->vocab.token_to_id.find(str) != ctx->vocab.token_to_id.end(); +} + whisper_token whisper_token_eot(struct whisper_context * ctx) { return ctx->vocab.token_eot; }