diff --git a/examples/common.cpp b/examples/common.cpp index 603c655a184..997803b3dca 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -264,6 +264,82 @@ std::wstring convert_to_wstring(const std::string & input) { return converter.from_bytes(input); } +// split UTF-8 string into valid and invalid parts +// eg. (a = "�123456", result = {"�", "123456", ""}) +// eg. (a = "123456�", result = {"", "123456", "�"}) +// eg. (a = "�123456�", result = {"�", "123456", "�"}) +// eg. (a = "�123�456�", result = {"�", "123�456", "�"}) +// result = {invalid, valid?, invalid} +std::vector utf8_split(const std::string & a) { + if (a.empty()) {return {"", "", ""};} + std::string str1; + std::string str2; + std::string str3; + + // forward pass + for (int64_t i = 0; i < static_cast(a.length()); i++) { + auto value = static_cast(a[i]); + if (value >= 0 && value <= 127 || value >= 192 && value <= 247) { + // 1, 2, 3, 4 byte head + break; + } else if (value >= 128 && value <= 191) { + // body byte + str1 += a[i]; + } + } + + // backward pass + int length = 0; + int expect = 0; + for (int64_t i = static_cast(a.length()) - 1; i >= 0; i--) { + auto value = static_cast(a[i]); + if (value >= 0 && value <= 127) { + // 1 byte head + expect = 1; + length++; + break; + } else if (value >= 128 && value <= 191){ + // body byte + length++; + } else if (value >= 192 && value <= 223){ + // 2 bytes head + expect = 2; + length++; + break; + } else if (value >= 224 && value <= 239){ + // 3 bytes head + expect = 3; + length++; + break; + } else if (value >= 240 && value <= 247){ + // 4 bytes head + expect = 4; + length++; + break; + } + } + if (expect != length) { + str3 = a.substr(a.length() - length, length); + } + + str2 = a.substr(str1.length(), a.length() - str3.length()); + + if (str1 == str3 && str1.length() + str2.length() + str3.length() > a.length()) { + return {str1, str2, ""}; + } + return {str1, str2, str3}; +} + +// check if the start and end of the std::string are UTF-8 encoded +bool utf8_is_valid(const std::string & a) { + if (a.empty()) {return true;} + auto result = utf8_split(a); + if (result[0].empty() && result[2].empty()) { + return true; + } + return false; +} + void gpt_split_words(std::string str, std::vector& words) { const std::string pattern = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)"; const std::regex re(pattern); @@ -639,10 +715,17 @@ bool read_wav(const std::string & fname, std::vector& pcmf32, std::vector fprintf(stderr, "%s: read %zu bytes from stdin\n", __func__, wav_data.size()); } +#if _WIN32 + else if (drwav_init_file_w(&wav, ConvertUTF8toUTF16(fname).c_str(), nullptr) == false) { + fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname.c_str()); + return false; + } +#else else if (drwav_init_file(&wav, fname.c_str(), nullptr) == false) { fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname.c_str()); return false; } +#endif if (wav.channels != 1 && wav.channels != 2) { fprintf(stderr, "%s: WAV file '%s' must be mono or stereo\n", __func__, fname.c_str()); diff --git a/examples/common.h b/examples/common.h index 54f0b00d0ef..36df1c7f94b 100644 --- a/examples/common.h +++ b/examples/common.h @@ -2,6 +2,8 @@ #pragma once +#include "console.h" + #include #include #include @@ -77,6 +79,47 @@ std::string convert_to_utf8(const std::wstring & input); std::wstring convert_to_wstring(const std::string & input); +std::vector utf8_split(const std::string & a); + +bool utf8_is_valid(const std::string & a); + +// used to store merged tokens +struct utf8_token { + std::string text; // text of tokens + float p_sum; // token probability sum + int token_c; // total number of tokens in buffer + int64_t t0; // start time + int64_t t1; // end time + bool start_of_seg; // start of segment + + void clear() { + text = ""; + p_sum = 0.0; + token_c = 0; + t0 = 0; + t1 = 0; + start_of_seg = false; + } + + utf8_token() + : text(""), + p_sum(0.0), + token_c(0), + t0(0), + t1(0), + start_of_seg(false) + {} + + utf8_token(const std::string& text, float p_sum, int token_c, int64_t t0, int64_t t1, bool start_of_seg) + : text(text), + p_sum(p_sum), + token_c(token_c), + t0(t0), + t1(t1), + start_of_seg(start_of_seg) + {} +}; + void gpt_split_words(std::string str, std::vector& words); // split text into tokens diff --git a/examples/console.h b/examples/console.h new file mode 100644 index 00000000000..b70551ce1e8 --- /dev/null +++ b/examples/console.h @@ -0,0 +1,99 @@ +// +// Created by bobqianic on 9/19/2023. +// + +#ifndef CONSOLE_H +#define CONSOLE_H + +#include +#if _WIN32 +#define NOMINMAX +#define _WINSOCKAPI_ +#include +#include +#include +#endif + +#if _WIN32 +// use std::wstring on Windows +typedef std::wstring ustring; +#else +// use std::string on other platforms +typedef std::string ustring; +#endif + +#if _WIN32 +// Convert UTF-8 to UTF-16 +// Windows only +inline std::wstring ConvertUTF8toUTF16(const std::string& utf8Str) { + if (utf8Str.empty()) return {std::wstring()}; + + int requiredSize = MultiByteToWideChar(CP_UTF8, 0, utf8Str.c_str(), -1, NULL, 0); + if (requiredSize == 0) { + // Handle error here + return {std::wstring()}; + } + + std::wstring utf16Str(requiredSize, 0); + if (MultiByteToWideChar(CP_UTF8, 0, utf8Str.c_str(), -1, &utf16Str[0], requiredSize) == 0) { + // Handle error here + return {std::wstring()}; + } + + // Remove the additional null byte from the end + utf16Str.resize(requiredSize - 1); + + return utf16Str; +} +#endif + +#if _WIN32 +// Convert UTF-16 to UTF-8 +// Windows only +inline std::string ConvertUTF16toUTF8(const std::wstring & utf16Str) { + if (utf16Str.empty()) return {std::string()}; + + int requiredSize = WideCharToMultiByte(CP_UTF8, 0, utf16Str.c_str(), -1, NULL, 0, NULL, NULL); + if (requiredSize == 0) { + // Handle error here + return {std::string()}; + } + + std::string utf8Str(requiredSize, 0); + if (WideCharToMultiByte(CP_UTF8, 0, utf16Str.c_str(), -1, &utf8Str[0], requiredSize, NULL, NULL) == 0) { + // Handle error here + return {std::string()}; + } + + // Remove the additional null byte from the end + utf8Str.resize(requiredSize - 1); + + return utf8Str; +} +#endif + +// initialize the console +// set output encoding +inline bool init_console() { +#if _WIN32 + // set output encoding to UTF-8 + SetConsoleOutputCP(CP_UTF8); + HANDLE hOut = GetStdHandle(STD_OUTPUT_HANDLE); + if (hOut == INVALID_HANDLE_VALUE) { + return GetLastError(); + } + + DWORD dwMode = 0; + if (!GetConsoleMode(hOut, &dwMode)) { + return GetLastError(); + } + + dwMode |= ENABLE_VIRTUAL_TERMINAL_PROCESSING; + if (!SetConsoleMode(hOut, dwMode)) { + return GetLastError(); + } +#endif + return true; +} + +#endif //CONSOLE_H diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 9699802e023..6af3a3ced88 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -1,5 +1,4 @@ #include "common.h" - #include "whisper.h" #include @@ -14,11 +13,13 @@ #pragma warning(disable: 4244 4267) // possible loss of data #endif +typedef std::vector whisper_merged_tokens; + // Terminal color map. 10 colors grouped in ranges [0.0, 0.1, ..., 0.9] // Lowest is red, middle is yellow, highest is green. const std::vector k_colors = { "\033[38;5;196m", "\033[38;5;202m", "\033[38;5;208m", "\033[38;5;214m", "\033[38;5;220m", - "\033[38;5;226m", "\033[38;5;190m", "\033[38;5;154m", "\033[38;5;118m", "\033[38;5;82m", + "\033[38;5;226m", "\033[38;5;190m", "\033[38;5;154m", "\033[38;5;118m", "\033[38;5;82m" }; // 500 -> 00:05.000 @@ -106,9 +107,9 @@ struct whisper_params { std::vector fname_out = {}; }; -void whisper_print_usage(int argc, char ** argv, const whisper_params & params); +void whisper_print_usage(int argc, const char ** argv, const whisper_params & params); -bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { +bool whisper_params_parse(int argc, const char ** argv, whisper_params & params) { for (int i = 1; i < argc; i++) { std::string arg = argv[i]; @@ -177,7 +178,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { return true; } -void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & params) { +void whisper_print_usage(int /*argc*/, const char ** argv, const whisper_params & params) { fprintf(stderr, "\n"); fprintf(stderr, "usage: %s [options] file0.wav file1.wav ...\n", argv[0]); fprintf(stderr, "\n"); @@ -275,6 +276,51 @@ void whisper_print_progress_callback(struct whisper_context * /*ctx*/, struct wh } } +whisper_merged_tokens whisper_merge_tokens(struct whisper_context * ctx, const whisper_params & params, int s0, int n_segments) { + whisper_merged_tokens result; + utf8_token buf; + + // Loop through each token within the segments, merging any neighboring tokens that are incomplete + for (int i = s0; i < n_segments; i++) { + int64_t t0 = whisper_full_get_segment_t0(ctx, i); + int64_t t1 = whisper_full_get_segment_t1(ctx, i); + bool start_of_seg = true; + + for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) { + if (!params.print_special) { + const whisper_token id = whisper_full_get_token_id(ctx, i, j); + if (id >= whisper_token_eot(ctx)) { + continue; + } + } + + const char * token_text = whisper_full_get_token_text(ctx, i, j); + const float token_p = whisper_full_get_token_p (ctx, i, j); + + if (utf8_is_valid(token_text)) { + result.emplace_back(std::string(token_text), token_p, 1, t0, t1, start_of_seg); + } else { + buf.text += std::string(token_text); + buf.p_sum += token_p; + buf.token_c++; + if (buf.token_c == 1) {buf.t0 = t0;} + buf.t1 = t1; + if (buf.token_c == 1 && start_of_seg) { + buf.start_of_seg = start_of_seg; + } + } + + if (buf.token_c > 1 && utf8_is_valid(buf.text)) { + result.push_back(buf); + buf.clear(); + } + + start_of_seg = false; + } + } + return result; +} + void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper_state * /*state*/, int n_new, void * user_data) { const auto & params = *((whisper_print_user_data *) user_data)->params; const auto & pcmf32s = *((whisper_print_user_data *) user_data)->pcmf32s; @@ -283,69 +329,64 @@ void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper std::string speaker = ""; - int64_t t0 = 0; - int64_t t1 = 0; - // print the last n_new segments const int s0 = n_segments - n_new; - if (s0 == 0) { - printf("\n"); - } - - for (int i = s0; i < n_segments; i++) { - if (!params.no_timestamps || params.diarize) { - t0 = whisper_full_get_segment_t0(ctx, i); - t1 = whisper_full_get_segment_t1(ctx, i); - } + // merge tokens, ensuring each one is encoded in UTF-8 without any truncation + auto merged_tokens = whisper_merge_tokens(ctx, params, s0, n_segments); - if (!params.no_timestamps) { - printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str()); + // print tokens to terminal + for (size_t i = 0; i < merged_tokens.size(); i++) { + // print headers at the beginning of each segment + if (merged_tokens[i].start_of_seg) { + if (!params.no_timestamps) { + printf("[%s --> %s] ", to_timestamp(merged_tokens[i].t0).c_str(), to_timestamp(merged_tokens[i].t1).c_str()); + } + if (params.diarize && pcmf32s.size() == 2) { + speaker = estimate_diarization_speaker(pcmf32s, merged_tokens[i].t0, merged_tokens[i].t1); + } + printf("%s", speaker.c_str()); } - if (params.diarize && pcmf32s.size() == 2) { - speaker = estimate_diarization_speaker(pcmf32s, t0, t1); + // print a single token + if (params.print_colors) { + const int color_idx = std::max(0, std::min((int) k_colors.size() - 1, (int) (std::pow(merged_tokens[i].p_sum / static_cast(merged_tokens[i].token_c), 3)*float(k_colors.size())))); + printf("%s%s%s", k_colors[color_idx].c_str(), merged_tokens[i].text.c_str(), "\033[0m"); + } else { + printf("%s", merged_tokens[i].text.c_str()); } - if (params.print_colors) { - for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) { - if (params.print_special == false) { - const whisper_token id = whisper_full_get_token_id(ctx, i, j); - if (id >= whisper_token_eot(ctx)) { - continue; - } + // print suffix at the end of each segment + if (i == merged_tokens.size() - 1 || (i < merged_tokens.size() - 1 && merged_tokens[i + 1].start_of_seg)) { + if (params.tinydiarize) { + if (whisper_full_get_segment_speaker_turn_next(ctx, i)) { + printf("%s", params.tdrz_speaker_turn.c_str()); } - - const char * text = whisper_full_get_token_text(ctx, i, j); - const float p = whisper_full_get_token_p (ctx, i, j); - - const int col = std::max(0, std::min((int) k_colors.size() - 1, (int) (std::pow(p, 3)*float(k_colors.size())))); - - printf("%s%s%s%s", speaker.c_str(), k_colors[col].c_str(), text, "\033[0m"); } - } else { - const char * text = whisper_full_get_segment_text(ctx, i); - - printf("%s%s", speaker.c_str(), text); - } - if (params.tinydiarize) { - if (whisper_full_get_segment_speaker_turn_next(ctx, i)) { - printf("%s", params.tdrz_speaker_turn.c_str()); + // with timestamps or speakers: each segment on new line + if (!params.no_timestamps || params.diarize) { + printf("\n"); } - } - // with timestamps or speakers: each segment on new line - if (!params.no_timestamps || params.diarize) { - printf("\n"); + fflush(stdout); } - - fflush(stdout); } } +// convert UTF-8 path to UTF-16LE and open file with std::ofstream on Windows +// use UTF-8 path open file with std::ofstream on other systems +std::ofstream open(const std::string & path) { +#if _WIN32 + std::ofstream file_out(ConvertUTF8toUTF16(path)); +#else + std::ofstream file_out(path); +#endif + return file_out; +} + bool output_txt(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector> pcmf32s) { - std::ofstream fout(fname); + auto fout = open(fname); if (!fout.is_open()) { fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname); return false; @@ -372,7 +413,7 @@ bool output_txt(struct whisper_context * ctx, const char * fname, const whisper_ } bool output_vtt(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector> pcmf32s) { - std::ofstream fout(fname); + auto fout = open(fname); if (!fout.is_open()) { fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname); return false; @@ -404,7 +445,7 @@ bool output_vtt(struct whisper_context * ctx, const char * fname, const whisper_ } bool output_srt(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector> pcmf32s) { - std::ofstream fout(fname); + auto fout = open(fname); if (!fout.is_open()) { fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname); return false; @@ -464,7 +505,7 @@ char *escape_double_quotes_and_backslashes(const char *str) { } bool output_csv(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector> pcmf32s) { - std::ofstream fout(fname); + auto fout = open(fname); if (!fout.is_open()) { fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname); return false; @@ -499,7 +540,7 @@ bool output_csv(struct whisper_context * ctx, const char * fname, const whisper_ } bool output_score(struct whisper_context * ctx, const char * fname, const whisper_params & /*params*/, std::vector> /*pcmf32s*/) { - std::ofstream fout(fname); + auto fout = open(fname); fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname); const int n_segments = whisper_full_n_segments(ctx); @@ -523,7 +564,7 @@ bool output_json( const whisper_params & params, std::vector> pcmf32s, bool full) { - std::ofstream fout(fname); + auto fout = open(fname); int indent = 0; auto doindent = [&]() { @@ -688,7 +729,7 @@ bool output_json( // outputs a bash script that uses ffmpeg to generate a video with the subtitles // TODO: font parameter adjustments bool output_wts(struct whisper_context * ctx, const char * fname, const char * fname_inp, const whisper_params & params, float t_sec, std::vector> pcmf32s) { - std::ofstream fout(fname); + auto fout = open(fname); fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname); @@ -813,7 +854,7 @@ bool output_wts(struct whisper_context * ctx, const char * fname, const char * f } bool output_lrc(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector> pcmf32s) { - std::ofstream fout(fname); + auto fout = open(fname); if (!fout.is_open()) { fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname); return false; @@ -852,7 +893,7 @@ bool output_lrc(struct whisper_context * ctx, const char * fname, const whisper_ return true; } -int main(int argc, char ** argv) { +int run(int argc, const char ** argv) { whisper_params params; if (whisper_params_parse(argc, argv, params) == false) { @@ -1076,3 +1117,21 @@ int main(int argc, char ** argv) { return 0; } + +#if _WIN32 +int wmain(int argc, const wchar_t ** argv_UTF16LE) { + init_console(); + std::vector buffer(argc); + std::vector argv_UTF8(argc); + for (int i = 0; i < argc; ++i) { + buffer[i] = ConvertUTF16toUTF8(argv_UTF16LE[i]); + argv_UTF8[i] = buffer[i].c_str(); + } + return run(argc, argv_UTF8.data()); +} +#else +int main(int argc, const char ** argv_UTF8) { + init_console(); + return run(argc, argv_UTF8); +} +#endif diff --git a/examples/stream/stream.cpp b/examples/stream/stream.cpp index 47f1780b4ea..b43c15e20b7 100644 --- a/examples/stream/stream.cpp +++ b/examples/stream/stream.cpp @@ -56,9 +56,9 @@ struct whisper_params { std::string fname_out; }; -void whisper_print_usage(int argc, char ** argv, const whisper_params & params); +void whisper_print_usage(int argc, const char ** argv, const whisper_params & params); -bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { +bool whisper_params_parse(int argc, const char ** argv, whisper_params & params) { for (int i = 1; i < argc; i++) { std::string arg = argv[i]; @@ -97,7 +97,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { return true; } -void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & params) { +void whisper_print_usage(int /*argc*/, const char ** argv, const whisper_params & params) { fprintf(stderr, "\n"); fprintf(stderr, "usage: %s [options]\n", argv[0]); fprintf(stderr, "\n"); @@ -126,7 +126,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, "\n"); } -int main(int argc, char ** argv) { +int run(int argc, const char ** argv) { whisper_params params; if (whisper_params_parse(argc, argv, params) == false) { @@ -428,6 +428,23 @@ int main(int argc, char ** argv) { whisper_print_timings(ctx); whisper_free(ctx); - return 0; } + +#if _WIN32 +int wmain(int argc, const wchar_t ** argv_UTF16LE) { + init_console(); + std::vector buffer(argc); + std::vector argv_UTF8(argc); + for (int i = 0; i < argc; ++i) { + buffer[i] = ConvertUTF16toUTF8(argv_UTF16LE[i]); + argv_UTF8[i] = buffer[i].c_str(); + } + return run(argc, argv_UTF8.data()); +} +#else +int main(int argc, const char ** argv_UTF8) { + init_console(); + return run(argc, argv_UTF8); +} +#endif \ No newline at end of file