Skip to content

Commit

Permalink
Merge pull request #12 from bobqianic/base
Browse files Browse the repository at this point in the history
Fix compatibility issue
  • Loading branch information
bobqianic authored Jun 25, 2024
2 parents 2b61aec + a53175a commit 7ea8a64
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 53 deletions.
6 changes: 5 additions & 1 deletion examples/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -894,7 +894,11 @@ int timestamp_to_sample(int64_t t, int n_samples, int whisper_sample_rate) {

bool is_file_exist(const char *fileName)
{
std::ifstream infile(fileName);
#ifdef _WIN32
std::wifstream infile(console::UTF8toUTF16(fileName).c_str());
#else
std::ifstream infile(fileName);
#endif
return infile.good();
}

Expand Down
102 changes: 51 additions & 51 deletions examples/main/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,12 +99,11 @@ struct whisper_params {

void whisper_print_usage(int argc, const char ** argv, const whisper_params & params);

char* whisper_param_turn_lowercase(char* in){
int string_len = strlen(in);
for(int i = 0; i < string_len; i++){
*(in+i) = tolower((unsigned char)*(in+i));
}
return in;
std::string toLowerCase(const std::string& input) {
std::string result = input; // Create a copy of the input string
std::transform(result.begin(), result.end(), result.begin(),
[](unsigned char c){ return std::tolower(c); });
return result;
}

bool whisper_params_parse(int argc, const char ** argv, whisper_params & params) {
Expand Down Expand Up @@ -163,7 +162,7 @@ bool whisper_params_parse(int argc, const char ** argv, whisper_params & params)
else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; }
else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; }
else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; }
else if (arg == "-l" || arg == "--language") { params.language = whisper_param_turn_lowercase(argv[++i]); }
else if (arg == "-l" || arg == "--language") { params.language = toLowerCase(argv[++i]); }
else if (arg == "-dl" || arg == "--detect-language") { params.detect_language = true; }
else if ( arg == "--prompt") { params.prompt = argv[++i]; }
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
Expand Down Expand Up @@ -954,35 +953,6 @@ void cb_log_disable(enum ggml_log_level , const char * , void * ) { }
int run(int argc, const char ** argv) {
whisper_params params;

// If the only argument starts with "@", read arguments line-by-line
// from the given file.
std::vector<std::string> vec_args;
if (argc == 2 && argv != nullptr && argv[1] != nullptr && argv[1][0] == '@') {
// Save the name of the executable.
vec_args.push_back(argv[0]);

// Open the response file.
char const * rspfile = argv[1] + sizeof(char);
std::ifstream fin(rspfile);
if (fin.is_open() == false) {
fprintf(stderr, "error: response file '%s' not found\n", rspfile);
return 1;
}

// Read the entire response file.
std::string line;
while (std::getline(fin, line)) {
vec_args.push_back(line);
}

// Use the contents of the response file as the command-line arguments.
argc = static_cast<int>(vec_args.size());
argv = static_cast<char **>(alloca(argc * sizeof (char *)));
for (int i = 0; i < argc; ++i) {
argv[i] = const_cast<char *>(vec_args[i].c_str());
}
}

if (whisper_params_parse(argc, argv, params) == false) {
whisper_print_usage(argc, argv, params);
return 1;
Expand Down Expand Up @@ -1151,7 +1121,6 @@ int run(int argc, const char ** argv) {
wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len;

wparams.heuristic = params.heuristic;
wparams.split_on_word = params.split_on_word;
wparams.audio_ctx = params.audio_ctx;

wparams.debug_mode = params.debug_mode;
Expand Down Expand Up @@ -1295,22 +1264,53 @@ int run(int argc, const char ** argv) {
return 0;
}

#if _WIN32
int wmain(int argc, const wchar_t ** argv_UTF16LE) {
console::init(true, true);
atexit([]() { console::cleanup(); });
std::vector<std::string> buffer(argc);
std::vector<const char*> argv_UTF8(argc);
for (int i = 0; i < argc; ++i) {
buffer[i] = console::UTF16toUTF8(argv_UTF16LE[i]);
argv_UTF8[i] = buffer[i].c_str();
}
return run(argc, argv_UTF8.data());
// Platform-specific function to convert UTF-16 to UTF-8
#ifdef _WIN32
std::string UTF16toUTF8(const wchar_t* utf16str) {
return console::UTF16toUTF8(utf16str);
}
#define MAIN wmain
#define CHAR_TYPE const wchar_t
#else
int main(int argc, const char ** argv_UTF8) {
#define MAIN main
#define CHAR_TYPE const char
#endif

int MAIN(int argc, CHAR_TYPE** argv) {
console::init(true, true);
atexit([]() { console::cleanup(); });
return run(argc, argv_UTF8);
}

#ifdef _WIN32
auto convert_to_utf8 = UTF16toUTF8;
#else
auto convert_to_utf8 = [](const char* str) { return std::string(str); };
#endif

std::vector<std::string> args;
if (argc == 2 && argv != nullptr && argv[1] != nullptr && convert_to_utf8(argv[1])[0] == '@') {
args.push_back(convert_to_utf8(argv[0]));
const char* rspfile = convert_to_utf8(argv[1]).c_str() + 1; // skip '@'
std::ifstream fin(rspfile);

if (!fin.is_open()) {
fprintf(stderr, "error: response file '%s' not found\n", rspfile);
return 1;
}

std::string line;
while (std::getline(fin, line)) {
args.push_back(line);
}
} else {
for (int i = 0; i < argc; ++i) {
args.push_back(convert_to_utf8(argv[i]));
}
}

std::vector<const char*> argv_converted(args.size());
for (size_t i = 0; i < args.size(); ++i) {
argv_converted[i] = args[i].c_str();
}

return run(static_cast<int>(args.size()), argv_converted.data());
}
1 change: 0 additions & 1 deletion examples/server/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -751,7 +751,6 @@ int main(int argc, char ** argv) {

wparams.thold_pt = params.word_thold;
wparams.max_len = params.max_len == 0 ? 60 : params.max_len;
wparams.split_on_word = params.split_on_word;
wparams.audio_ctx = params.audio_ctx;

wparams.debug_mode = params.debug_mode;
Expand Down

0 comments on commit 7ea8a64

Please sign in to comment.