diff --git a/sherpa/cpp_api/bin/online-recognizer.cc b/sherpa/cpp_api/bin/online-recognizer.cc index ffbd26920..b55a369ae 100644 --- a/sherpa/cpp_api/bin/online-recognizer.cc +++ b/sherpa/cpp_api/bin/online-recognizer.cc @@ -132,6 +132,12 @@ int32_t main(int32_t argc, char *argv[]) { } config.Validate(); + + if (config.use_gpu) { + config.feat_config.fbank_opts.device = torch::Device("cuda:0"); + } else { + config.feat_config.fbank_opts.device = torch::Device("cpu"); + } SHERPA_CHECK_EQ(config.feat_config.fbank_opts.frame_opts.samp_freq, expected_sample_rate) @@ -147,6 +153,8 @@ int32_t main(int32_t argc, char *argv[]) { torch::Tensor tail_padding = torch::zeros( {static_cast(padding_seconds * expected_sample_rate)}, torch::kFloat); + + tail_padding = tail_padding.to(config.feat_config.fbank_opts.device); sherpa::OnlineRecognizer recognizer(config); if (use_wav_scp) { @@ -193,6 +201,7 @@ int32_t main(int32_t argc, char *argv[]) { {d.NumCols()}, torch::kFloat) / 32768; auto s = recognizer.CreateStream(); + tensor = tensor.to(config.feat_config.fbank_opts.device); s->AcceptWaveform(expected_sample_rate, tensor); s->AcceptWaveform(expected_sample_rate, tail_padding); s->InputFinished(); @@ -227,6 +236,7 @@ int32_t main(int32_t argc, char *argv[]) { wave.index({torch::indexing::Slice(start, end)}); start = end; + samples = samples.to(config.feat_config.fbank_opts.device); s->AcceptWaveform(expected_sample_rate, samples); while (recognizer.IsReady(s.get())) { @@ -265,7 +275,7 @@ int32_t main(int32_t argc, char *argv[]) { torch::Tensor wave = sherpa::ReadWave(po.GetArg(i), expected_sample_rate).first; - + wave = wave.to(config.feat_config.fbank_opts.device); s->AcceptWaveform(expected_sample_rate, wave); s->AcceptWaveform(expected_sample_rate, tail_padding);