diff --git a/CMakeLists.txt b/CMakeLists.txt index f027de5..27a841c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -9,7 +9,7 @@ if (CMAKE_VERSION VERSION_GREATER_EQUAL "3.24.0") cmake_policy(SET CMP0135 NEW) endif() -set(KALDI_DECODER_VERSION "0.2.3") +set(KALDI_DECODER_VERSION "0.2.4") set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib") set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib") diff --git a/cmake/kaldifst.cmake b/cmake/kaldifst.cmake index 81b35e8..3360b64 100644 --- a/cmake/kaldifst.cmake +++ b/cmake/kaldifst.cmake @@ -1,18 +1,18 @@ function(download_kaldifst) include(FetchContent) - set(kaldifst_URL "https://github.com/k2-fsa/kaldifst/archive/refs/tags/v1.7.6.tar.gz") - set(kaldifst_URL2 "https://huggingface.co/csukuangfj/kaldi-hmm-gmm-cmake-deps/resolve/main/kaldifst-1.7.6.tar.gz") - set(kaldifst_HASH "SHA256=79280c0bb08b5ed1a2ab7c21320a2b071f1f0eb10d2f047e8d6f027f0d32b4d2") + set(kaldifst_URL "https://github.com/k2-fsa/kaldifst/archive/refs/tags/v1.7.10.tar.gz") + set(kaldifst_URL2 "https://hub.nuaa.cf/k2-fsa/kaldifst/archive/refs/tags/v1.7.10.tar.gz") + set(kaldifst_HASH "SHA256=7f7b3173a6584a6b1987f65ae7af2ac453d66b845f875a9d31074b8d2cd0de54") # If you don't have access to the Internet, # please pre-download kaldifst set(possible_file_locations - $ENV{HOME}/Downloads/kaldifst-1.7.6.tar.gz - ${PROJECT_SOURCE_DIR}/kaldifst-1.7.6.tar.gz - ${PROJECT_BINARY_DIR}/kaldifst-1.7.6.tar.gz - /tmp/kaldifst-1.7.6.tar.gz - /star-fj/fangjun/download/github/kaldifst-1.7.6.tar.gz + $ENV{HOME}/Downloads/kaldifst-1.7.10.tar.gz + ${CMAKE_SOURCE_DIR}/kaldifst-1.7.10.tar.gz + ${CMAKE_BINARY_DIR}/kaldifst-1.7.10.tar.gz + /tmp/kaldifst-1.7.10.tar.gz + /star-fj/fangjun/download/github/kaldifst-1.7.10.tar.gz ) foreach(f IN LISTS possible_file_locations) diff --git a/kaldi-decoder/csrc/CMakeLists.txt b/kaldi-decoder/csrc/CMakeLists.txt index 03cbbf5..9e74d23 100644 --- a/kaldi-decoder/csrc/CMakeLists.txt +++ b/kaldi-decoder/csrc/CMakeLists.txt @@ -5,9 +5,11 @@ set(srcs decodable-ctc.cc eigen.cc faster-decoder.cc + lattice-faster-decoder.cc + lattice-simple-decoder.cc + simple-decoder.cc ) - add_library(kaldi-decoder-core ${srcs}) target_link_libraries(kaldi-decoder-core PUBLIC kaldifst_core) diff --git a/kaldi-decoder/csrc/decodable-ctc.cc b/kaldi-decoder/csrc/decodable-ctc.cc index d09009d..cc97fee 100644 --- a/kaldi-decoder/csrc/decodable-ctc.cc +++ b/kaldi-decoder/csrc/decodable-ctc.cc @@ -8,15 +8,16 @@ namespace kaldi_decoder { -DecodableCtc::DecodableCtc(const FloatMatrix &log_probs) - : log_probs_(log_probs) { +DecodableCtc::DecodableCtc(const FloatMatrix &log_probs, int32_t offset /*= 0*/) + : log_probs_(log_probs), offset_(offset) { p_ = &log_probs_(0, 0); num_rows_ = log_probs_.rows(); num_cols_ = log_probs_.cols(); } -DecodableCtc::DecodableCtc(const float *p, int32_t num_rows, int32_t num_cols) - : p_(p), num_rows_(num_rows), num_cols_(num_cols) {} +DecodableCtc::DecodableCtc(const float *p, int32_t num_rows, int32_t num_cols, + int32_t offset /*= 0*/) + : p_(p), num_rows_(num_rows), num_cols_(num_cols), offset_(offset) {} float DecodableCtc::LogLikelihood(int32_t frame, int32_t index) { // Note: We need to use index - 1 here since @@ -24,10 +25,10 @@ float DecodableCtc::LogLikelihood(int32_t frame, int32_t index) { // construction assert(index >= 1); - return *(p_ + frame * num_cols_ + index - 1); + return *(p_ + (frame - offset_) * num_cols_ + index - 1); } -int32_t DecodableCtc::NumFramesReady() const { return num_rows_; } +int32_t DecodableCtc::NumFramesReady() const { return offset_ + num_rows_; } int32_t DecodableCtc::NumIndices() const { return num_cols_; } diff --git a/kaldi-decoder/csrc/decodable-ctc.h b/kaldi-decoder/csrc/decodable-ctc.h index 7b9a419..71a482c 100644 --- a/kaldi-decoder/csrc/decodable-ctc.h +++ b/kaldi-decoder/csrc/decodable-ctc.h @@ -13,14 +13,15 @@ namespace kaldi_decoder { class DecodableCtc : public DecodableInterface { public: // It copies the input log_probs - explicit DecodableCtc(const FloatMatrix &log_probs); + explicit DecodableCtc(const FloatMatrix &log_probs, int32_t offset = 0); // It shares the memory with the input array. // // @param p Pointer to a 2-d array of shape (num_rows, num_cols). // The array should be kept alive as long as this object is still // alive. - DecodableCtc(const float *p, int32_t num_rows, int32_t num_cols); + DecodableCtc(const float *p, int32_t num_rows, int32_t num_cols, + int32_t offset = 0); float LogLikelihood(int32_t frame, int32_t index) override; @@ -38,6 +39,7 @@ class DecodableCtc : public DecodableInterface { const float *p_ = nullptr; // pointer to a 2-d array int32_t num_rows_; // number of rows in the 2-d array int32_t num_cols_; // number of cols in the 2-d array + int32_t offset_ = 0; }; } // namespace kaldi_decoder diff --git a/kaldi-decoder/csrc/kaldi-math.h b/kaldi-decoder/csrc/kaldi-math.h new file mode 100644 index 0000000..02fd4c0 --- /dev/null +++ b/kaldi-decoder/csrc/kaldi-math.h @@ -0,0 +1,48 @@ +// kaldi-decoder/csrc/kaldi-math.h + +// Copyright 2009-2011 Ondrej Glembek; Microsoft Corporation; Yanmin Qian; +// Jan Silovsky; Saarland University +// 2023 Xiaomi Corporation + +// this file is copied and modified from +// kaldi/src/base/kaldi-math.h +#ifndef KALDI_DECODER_CSRC_KALDI_MATH_H_ +#define KALDI_DECODER_CSRC_KALDI_MATH_H_ + +#include +#include + +#ifndef DBL_EPSILON +#define DBL_EPSILON 2.2204460492503131e-16 +#endif + +#ifndef FLT_EPSILON +#define FLT_EPSILON 1.19209290e-7f +#endif + +// M_LOG_2PI = log(2*pi) +#ifndef M_LOG_2PI +#define M_LOG_2PI 1.8378770664093454835606594728112 +#endif + +#define KALDI_ISINF std::isinf +#define KALDI_ISNAN std::isnan + +#include "kaldi-decoder/csrc/log.h" + +namespace kaldi_decoder { + +/// return abs(a - b) <= relative_tolerance * (abs(a)+abs(b)). +static inline bool ApproxEqual(float a, float b, + float relative_tolerance = 0.001) { + // a==b handles infinities. + if (a == b) return true; + float diff = std::abs(a - b); + if (diff == std::numeric_limits::infinity() || diff != diff) + return false; // diff is +inf or nan. + return (diff <= relative_tolerance * (std::abs(a) + std::abs(b))); +} + +} // namespace kaldi_decoder + +#endif // KALDI_DECODER_CSRC_KALDI_MATH_H_ diff --git a/kaldi-decoder/csrc/lattice-faster-decoder.cc b/kaldi-decoder/csrc/lattice-faster-decoder.cc new file mode 100644 index 0000000..b073576 --- /dev/null +++ b/kaldi-decoder/csrc/lattice-faster-decoder.cc @@ -0,0 +1,13 @@ +// kaldi-decoder/csrc/lattice-faster-decoder.cc + +// Copyright 2009-2012 Microsoft Corporation Mirko Hannemann +// 2013-2018 Johns Hopkins University (Author: Daniel Povey) +// 2014 Guoguo Chen +// 2018 Zhehuai Chen +// Copyright (c) 2023 Xiaomi Corporation + +// this file is copied and modified from +// kaldi/src/decoder/lattice-faster-decoder.cc + +#include "kaldi-decoder/csrc/lattice-faster-decoder.h" +namespace kaldi_decoder {} diff --git a/kaldi-decoder/csrc/lattice-faster-decoder.h b/kaldi-decoder/csrc/lattice-faster-decoder.h new file mode 100644 index 0000000..0717660 --- /dev/null +++ b/kaldi-decoder/csrc/lattice-faster-decoder.h @@ -0,0 +1,274 @@ +// kaldi-decoder/csrc/lattice-faster-decoder.h + +// Copyright 2009-2013 Microsoft Corporation; Mirko Hannemann; +// 2013-2014 Johns Hopkins University (Author: Daniel Povey) +// 2014 Guoguo Chen +// 2018 Zhehuai Chen +// Copyright (c) 2023 Xiaomi Corporation + +// this file is copied and modified from +// kaldi/src/decoder/lattice-faster-decoder.h +#ifndef KALDI_DECODER_CSRC_LATTICE_FASTER_DECODER_H_ +#define KALDI_DECODER_CSRC_LATTICE_FASTER_DECODER_H_ + +#include +#include + +#include "fst/fst.h" +#include "fst/fstlib.h" +#include "kaldi-decoder/csrc/log.h" + +namespace kaldi_decoder { + +struct LatticeFasterDecoderConfig { + float beam; + int32_t max_active; + int32_t min_active; + float lattice_beam; + int32_t prune_interval; + bool determinize_lattice; // not inspected by this class... used in + // command-line program. + float beam_delta; + float hash_ratio; + // Note: we don't make prune_scale configurable on the command line, it's not + // a very important parameter. It affects the algorithm that prunes the + // tokens as we go. + float prune_scale; + + // Number of elements in the block for Token and ForwardLink memory + // pool allocation. + int32_t memory_pool_tokens_block_size; + int32_t memory_pool_links_block_size; + + // Most of the options inside det_opts are not actually queried by the + // LatticeFasterDecoder class itself, but by the code that calls it, for + // example in the function DecodeUtteranceLatticeFaster. + // fst::DeterminizeLatticePhonePrunedOptions det_opts; + + LatticeFasterDecoderConfig( + float beam = 16.0, + int32_t max_active = std::numeric_limits::max(), + int32_t min_active = 200, float lattice_beam = 10.0, + int32_t prune_interval = 25, bool determinize_lattice = true, + float beam_delta = 0.5, float hash_ratio = 2.0, float prune_scale = 0.1, + int32_t memory_pool_tokens_block_size = 1 << 8, + int32_t memory_pool_links_block_size = 1 << 8) + : beam(beam), + max_active(max_active), + min_active(min_active), + lattice_beam(lattice_beam), + prune_interval(prune_interval), + determinize_lattice(determinize_lattice), + beam_delta(beam_delta), + hash_ratio(hash_ratio), + prune_scale(prune_scale), + memory_pool_tokens_block_size(memory_pool_tokens_block_size), + memory_pool_links_block_size(memory_pool_links_block_size) {} +#if 0 + void Register(OptionsItf *opts) { + det_opts.Register(opts); + opts->Register("beam", &beam, + "Decoding beam. Larger->slower, more accurate."); + opts->Register("max-active", &max_active, + "Decoder max active states. Larger->slower; " + "more accurate"); + opts->Register("min-active", &min_active, + "Decoder minimum #active states."); + opts->Register("lattice-beam", &lattice_beam, + "Lattice generation beam. Larger->slower, " + "and deeper lattices"); + opts->Register("prune-interval", &prune_interval, + "Interval (in frames) at " + "which to prune tokens"); + opts->Register( + "determinize-lattice", &determinize_lattice, + "If true, " + "determinize the lattice (lattice-determinization, keeping only " + "best pdf-sequence for each word-sequence)."); + opts->Register( + "beam-delta", &beam_delta, + "Increment used in decoding-- this " + "parameter is obscure and relates to a speedup in the way the " + "max-active constraint is applied. Larger is more accurate."); + opts->Register("hash-ratio", &hash_ratio, + "Setting used in decoder to " + "control hash behavior"); + opts->Register( + "memory-pool-tokens-block-size", &memory_pool_tokens_block_size, + "Memory pool block size suggestion for storing tokens (in elements). " + "Smaller uses less memory but increases cache misses."); + opts->Register( + "memory-pool-links-block-size", &memory_pool_links_block_size, + "Memory pool block size suggestion for storing links (in elements). " + "Smaller uses less memory but increases cache misses."); + } +#endif + std::string ToString() const { + std::ostringstream os; + + os << "LatticeFasterDecoderConfig("; + os << "beam=" << beam << ", "; + os << "max_active=" << max_active << ", "; + os << "min_active=" << min_active << ", "; + os << "lattice_beam=" << lattice_beam << ", "; + os << "prune_interval=" << prune_interval << ", "; + os << "determinize_lattice=" << (determinize_lattice ? "True" : "False") + << ", "; + + os << "beam_delta=" << beam_delta << ", "; + os << "hash_ratio=" << hash_ratio << ", "; + os << "prune_scale=" << prune_scale << ", "; + os << "memory_pool_tokens_block_size=" << memory_pool_tokens_block_size + << ", "; + os << "memory_pool_links_block_size=" << memory_pool_links_block_size + << ")"; + + return os.str(); + } + void Check() const { + KALDI_DECODER_ASSERT(beam > 0.0 && max_active > 1 && lattice_beam > 0.0 && + min_active <= max_active && prune_interval > 0 && + beam_delta > 0.0 && hash_ratio >= 1.0 && + prune_scale > 0.0 && prune_scale < 1.0); + } +}; + +namespace decoder { +// We will template the decoder on the token type as well as the FST type; this +// is a mechanism so that we can use the same underlying decoder code for +// versions of the decoder that support quickly getting the best path +// (LatticeFasterOnlineDecoder, see lattice-faster-online-decoder.h) and also +// those that do not (LatticeFasterDecoder). + +// ForwardLinks are the links from a token to a token on the next frame. +// or sometimes on the current frame (for input-epsilon links). +template +struct ForwardLink { + using Label = fst::StdArc::Label; + + Token *next_tok; // the next token [or NULL if represents final-state] + Label ilabel; // ilabel on arc + Label olabel; // olabel on arc + float graph_cost; // graph cost of traversing arc (contains LM, etc.) + float acoustic_cost; // acoustic cost (pre-scaled) of traversing arc + ForwardLink *next; // next in singly-linked list of forward arcs (arcs + // in the state-level lattice) from a token. + ForwardLink(Token *next_tok, Label ilabel, Label olabel, float graph_cost, + float acoustic_cost, ForwardLink *next) + : next_tok(next_tok), + ilabel(ilabel), + olabel(olabel), + graph_cost(graph_cost), + acoustic_cost(acoustic_cost), + next(next) {} +}; + +struct StdToken { + using ForwardLinkT = ForwardLink; + using Token = StdToken; + + // Standard token type for LatticeFasterDecoder. Each active HCLG + // (decoding-graph) state on each frame has one token. + + // tot_cost is the total (LM + acoustic) cost from the beginning of the + // utterance up to this point. (but see cost_offset_, which is subtracted + // to keep it in a good numerical range). + float tot_cost; + + // extra_cost is >= 0. After calling PruneForwardLinks, this equals the + // minimum difference between the cost of the best path that this link is a + // part of, and the cost of the absolute best path, under the assumption that + // any of the currently active states at the decoding front may eventually + // succeed (e.g. if you were to take the currently active states one by one + // and compute this difference, and then take the minimum). + float extra_cost; + + // 'links' is the head of singly-linked list of ForwardLinks, which is what we + // use for lattice generation. + ForwardLinkT *links; + + //'next' is the next in the singly-linked list of tokens for this frame. + Token *next; + + // This function does nothing and should be optimized out; it's needed + // so we can share the regular LatticeFasterDecoderTpl code and the code + // for LatticeFasterOnlineDecoder that supports fast traceback. + inline void SetBackpointer(Token *backpointer) {} + + // This constructor just ignores the 'backpointer' argument. That argument is + // needed so that we can use the same decoder code for LatticeFasterDecoderTpl + // and LatticeFasterOnlineDecoderTpl (which needs backpointers to support a + // fast way to obtain the best path). + StdToken(float tot_cost, float extra_cost, ForwardLinkT *links, Token *next, + Token *backpointer) + : tot_cost(tot_cost), extra_cost(extra_cost), links(links), next(next) {} +}; + +struct BackpointerToken { + using ForwardLinkT = ForwardLink; + using Token = BackpointerToken; + + // BackpointerToken is like Token but also + // Standard token type for LatticeFasterDecoder. Each active HCLG + // (decoding-graph) state on each frame has one token. + + // tot_cost is the total (LM + acoustic) cost from the beginning of the + // utterance up to this point. (but see cost_offset_, which is subtracted + // to keep it in a good numerical range). + float tot_cost; + + // exta_cost is >= 0. After calling PruneForwardLinks, this equals + // the minimum difference between the cost of the best path, and the cost of + // this is on, and the cost of the absolute best path, under the assumption + // that any of the currently active states at the decoding front may + // eventually succeed (e.g. if you were to take the currently active states + // one by one and compute this difference, and then take the minimum). + float extra_cost; + + // 'links' is the head of singly-linked list of ForwardLinks, which is what we + // use for lattice generation. + ForwardLinkT *links; + + //'next' is the next in the singly-linked list of tokens for this frame. + BackpointerToken *next; + + // Best preceding BackpointerToken (could be a on this frame, connected to + // this via an epsilon transition, or on a previous frame). This is only + // required for an efficient GetBestPath function in + // LatticeFasterOnlineDecoderTpl; it plays no part in the lattice generation + // (the "links" list is what stores the forward links, for that). + Token *backpointer; + + void SetBackpointer(Token *backpointer) { this->backpointer = backpointer; } + + BackpointerToken(float tot_cost, float extra_cost, ForwardLinkT *links, + Token *next, Token *backpointer) + : tot_cost(tot_cost), + extra_cost(extra_cost), + links(links), + next(next), + backpointer(backpointer) {} +}; + +} // namespace decoder + +/** This is the "normal" lattice-generating decoder. + See \ref lattices_generation \ref decoders_faster and \ref decoders_simple + for more information. + + The decoder is templated on the FST type and the token type. The token type + will normally be StdToken, but also may be BackpointerToken which is to + support quick lookup of the current best path (see + lattice-faster-online-decoder.h) + + The FST you invoke this decoder which is expected to equal + Fst::Fst, a.k.a. StdFst, or GrammarFst. If you invoke it with + FST == StdFst and it notices that the actual FST type is + fst::VectorFst or fst::ConstFst, the decoder object + will internally cast itself to one that is templated on those more specific + types; this is an optimization for speed. + */ + +} // namespace kaldi_decoder + +#endif // KALDI_DECODER_CSRC_LATTICE_FASTER_DECODER_H_ diff --git a/kaldi-decoder/csrc/lattice-simple-decoder.cc b/kaldi-decoder/csrc/lattice-simple-decoder.cc new file mode 100644 index 0000000..484ffb8 --- /dev/null +++ b/kaldi-decoder/csrc/lattice-simple-decoder.cc @@ -0,0 +1,659 @@ +// kaldi-decoder/csrc/lattice-simple-decoder.cc + +// Copyright 2009-2012 Microsoft Corporation +// 2013-2014 Johns Hopkins University (Author: Daniel Povey) +// 2014 Guoguo Chen +// Copyright (c) 2023 Xiaomi Corporation + +// this file is copied and modified from +// kaldi/src/decoder/lattice-simple-decoder.cc + +#include "kaldi-decoder/csrc/lattice-simple-decoder.h" + +#include "kaldi-decoder/csrc/kaldi-math.h" + +namespace kaldi_decoder { + +void LatticeSimpleDecoder::InitDecoding() { + // clean up from last time: + cur_toks_.clear(); + prev_toks_.clear(); + ClearActiveTokens(); + warned_ = false; + decoding_finalized_ = false; + final_costs_.clear(); + num_toks_ = 0; + StateId start_state = fst_.Start(); + KALDI_DECODER_ASSERT(start_state != fst::kNoStateId); + active_toks_.resize(1); + Token *start_tok = new Token(0.0, 0.0, nullptr, nullptr); + active_toks_[0].toks = start_tok; + cur_toks_[start_state] = start_tok; + num_toks_++; + ProcessNonemitting(); +} + +void LatticeSimpleDecoder::ClearActiveTokens() { // a cleanup routine, at utt + // end/begin + for (size_t i = 0; i < active_toks_.size(); i++) { + // Delete all tokens alive on this frame, and any forward + // links they may have. + for (Token *tok = active_toks_[i].toks; tok != nullptr;) { + tok->DeleteForwardLinks(); + Token *next_tok = tok->next; + delete tok; + num_toks_--; + tok = next_tok; + } + } + active_toks_.clear(); + KALDI_DECODER_ASSERT(num_toks_ == 0); +} + +bool LatticeSimpleDecoder::Decode(DecodableInterface *decodable) { + InitDecoding(); + + while (!decodable->IsLastFrame(NumFramesDecoded() - 1)) { + if (NumFramesDecoded() % config_.prune_interval == 0) { + PruneActiveTokens(config_.lattice_beam * config_.prune_scale); + } + + ProcessEmitting(decodable); + // Important to call PruneCurrentTokens before ProcessNonemitting, or we + // would get dangling forward pointers. Anyway, ProcessNonemitting uses the + // beam. + PruneCurrentTokens(config_.beam, &cur_toks_); + ProcessNonemitting(); + } + FinalizeDecoding(); + + // Returns true if we have any kind of traceback available (not necessarily + // to the end state; query ReachedFinal() for that). + return !final_costs_.empty(); +} + +// FindOrAddToken either locates a token in cur_toks_, or if necessary inserts a +// new, empty token (i.e. with no forward links) for the current frame. [note: +// it's inserted if necessary into cur_toks_ and also into the singly linked +// list of tokens active on this frame (whose head is at active_toks_[frame]). +// +// Returns the Token pointer. Sets "changed" (if non-NULL) to true +// if the token was newly created or the cost changed. +LatticeSimpleDecoder::Token *LatticeSimpleDecoder::FindOrAddToken( + StateId state, int32_t frame, float tot_cost, bool emitting, + bool *changed) { + KALDI_DECODER_ASSERT(frame < active_toks_.size()); + Token *&toks = active_toks_[frame].toks; + + auto find_iter = cur_toks_.find(state); + if (find_iter == cur_toks_.end()) { // no such token presently. + // Create one. + const float extra_cost = 0.0; + // tokens on the currently final frame have zero extra_cost + // as any of them could end up + // on the winning path. + Token *new_tok = new Token(tot_cost, extra_cost, nullptr, toks); + toks = new_tok; + num_toks_++; + cur_toks_[state] = new_tok; + + if (changed) { + *changed = true; + } + + return new_tok; + } else { + Token *tok = + find_iter->second; // There is an existing Token for this state. + if (tok->tot_cost > tot_cost) { + tok->tot_cost = tot_cost; + if (changed) { + *changed = true; + } + } else { + if (changed) { + *changed = false; + } + } + return tok; + } +} + +void LatticeSimpleDecoder::ProcessNonemitting() { + KALDI_DECODER_ASSERT(!active_toks_.empty()); + int32_t frame = static_cast(active_toks_.size()) - 2; + // Note: "frame" is the time-index we just processed, or -1 if + // we are processing the nonemitting transitions before the + // first frame (called from InitDecoding()). + + // Processes nonemitting arcs for one frame. Propagates within + // cur_toks_. Note-- this queue structure is is not very optimal as + // it may cause us to process states unnecessarily (e.g. more than once), + // but in the baseline code, turning this vector into a set to fix this + // problem did not improve overall speed. + std::vector queue; + float best_cost = std::numeric_limits::infinity(); + for (auto iter = cur_toks_.begin(); iter != cur_toks_.end(); ++iter) { + StateId state = iter->first; + + if (fst_.NumInputEpsilons(state) != 0) { + queue.push_back(state); + } + + best_cost = std::min(best_cost, iter->second->tot_cost); + } + + if (queue.empty()) { + if (!warned_) { + KALDI_DECODER_LOG + << "Error in ProcessNonEmitting: no surviving tokens: frame is " + << frame; + warned_ = true; + } + } + float cutoff = best_cost + config_.beam; + + while (!queue.empty()) { + StateId state = queue.back(); + queue.pop_back(); + Token *tok = cur_toks_[state]; + // If "tok" has any existing forward links, delete them, + // because we're about to regenerate them. This is a kind + // of non-optimality (remember, this is the simple decoder), + // but since most states are emitting it's not a huge issue. + tok->DeleteForwardLinks(); + tok->links = nullptr; + for (fst::ArcIterator> aiter(fst_, state); !aiter.Done(); + aiter.Next()) { + const Arc &arc = aiter.Value(); + if (arc.ilabel == 0) { // propagate nonemitting only... + float graph_cost = arc.weight.Value(); + float cur_cost = tok->tot_cost; + float tot_cost = cur_cost + graph_cost; + + if (tot_cost < cutoff) { + bool changed; + Token *new_tok = FindOrAddToken(arc.nextstate, frame + 1, tot_cost, + false, &changed); + + tok->links = new ForwardLink(new_tok, 0, arc.olabel, graph_cost, 0, + tok->links); + + // "changed" tells us whether the new token has a different + // cost from before, or is new [if so, add into queue]. + if (changed && fst_.NumInputEpsilons(arc.nextstate) != 0) { + queue.push_back(arc.nextstate); + } + } + } + } + } +} + +// Go backwards through still-alive tokens, pruning them, starting not from +// the current frame (where we want to keep all tokens) but from the frame +// before that. We go backwards through the frames and stop when we reach a +// point where the delta-costs are not changing (and the delta controls when we +// consider a cost to have "not changed"). +void LatticeSimpleDecoder::PruneActiveTokens(float delta) { + int32_t cur_frame_plus_one = NumFramesDecoded(); + int32_t num_toks_begin = num_toks_; + // The index "f" below represents a "frame plus one", i.e. you'd have to + // subtract one to get the corresponding index for the decodable object. + for (int32_t f = cur_frame_plus_one - 1; f >= 0; f--) { + // Reason why we need to prune forward links in this situation: + // (1) we have never pruned them + // (2) we never pruned the forward links on the next frame, which + // + if (active_toks_[f].must_prune_forward_links) { + bool extra_costs_changed = false, links_pruned = false; + PruneForwardLinks(f, &extra_costs_changed, &links_pruned, delta); + if (extra_costs_changed && f > 0) + active_toks_[f - 1].must_prune_forward_links = true; + if (links_pruned) active_toks_[f].must_prune_tokens = true; + active_toks_[f].must_prune_forward_links = false; + } + if (f + 1 < cur_frame_plus_one && active_toks_[f + 1].must_prune_tokens) { + PruneTokensForFrame(f + 1); + active_toks_[f + 1].must_prune_tokens = false; + } + } + KALDI_DECODER_LOG << "PruneActiveTokens: pruned tokens from " + << num_toks_begin << " to " << num_toks_; +} + +// delta is the amount by which the extra_costs must +// change before it sets "extra_costs_changed" to true. If delta is larger, +// we'll tend to go back less far toward the beginning of the file. +void LatticeSimpleDecoder::PruneForwardLinks(int32_t frame, + bool *extra_costs_changed, + bool *links_pruned, float delta) { + // We have to iterate until there is no more change, because the links + // are not guaranteed to be in topological order. + + *extra_costs_changed = false; + *links_pruned = false; + KALDI_DECODER_ASSERT(frame >= 0 && frame < active_toks_.size()); + if (active_toks_[frame].toks == nullptr) { // empty list; this should + // not happen. + if (!warned_) { + KALDI_DECODER_WARN << "No tokens alive [doing pruning].. warning first " + "time only for each utterance\n"; + warned_ = true; + } + } + + bool changed = true; + while (changed) { + changed = false; + for (Token *tok = active_toks_[frame].toks; tok != nullptr; + tok = tok->next) { + ForwardLink *link, *prev_link = nullptr; + // will recompute tok_extra_cost. + float tok_extra_cost = std::numeric_limits::infinity(); + for (link = tok->links; link != nullptr;) { + // See if we need to excise this link... + Token *next_tok = link->next_tok; + float link_extra_cost = + next_tok->extra_cost + + ((tok->tot_cost + link->acoustic_cost + link->graph_cost) - + next_tok->tot_cost); + KALDI_DECODER_ASSERT(link_extra_cost == + link_extra_cost); // check for NaN + + if (link_extra_cost > config_.lattice_beam) { // excise link + ForwardLink *next_link = link->next; + if (prev_link != nullptr) { + prev_link->next = next_link; + } else { + tok->links = next_link; + } + + delete link; + link = next_link; // advance link but leave prev_link the same. + *links_pruned = true; + } else { // keep the link and update the tok_extra_cost if needed. + if (link_extra_cost < 0.0) { // this is just a precaution. + if (link_extra_cost < -0.01) { + KALDI_DECODER_WARN << "Negative extra_cost: " << link_extra_cost; + } + link_extra_cost = 0.0; + } + if (link_extra_cost < tok_extra_cost) { + tok_extra_cost = link_extra_cost; + } + + prev_link = link; + link = link->next; + } + } + if (fabs(tok_extra_cost - tok->extra_cost) > delta) { + changed = true; + } + + tok->extra_cost = + tok_extra_cost; // will be +infinity or <= lattice_beam_. + } + if (changed) { + *extra_costs_changed = true; + } + + // Note: it's theoretically possible that aggressive compiler + // optimizations could cause an infinite loop here for small delta and + // high-dynamic-range scores. + } +} + +// Prune away any tokens on this frame that have no forward links. [we don't do +// this in PruneForwardLinks because it would give us a problem with dangling +// pointers]. +void LatticeSimpleDecoder::PruneTokensForFrame(int32_t frame) { + KALDI_DECODER_ASSERT(frame >= 0 && frame < active_toks_.size()); + Token *&toks = active_toks_[frame].toks; + if (toks == nullptr) { + KALDI_DECODER_WARN << "No tokens alive [doing pruning]"; + } + + Token *tok, *next_tok, *prev_tok = nullptr; + for (tok = toks; tok != nullptr; tok = next_tok) { + next_tok = tok->next; + if (tok->extra_cost == std::numeric_limits::infinity()) { + // Next token is unreachable from end of graph; excise tok from list + // and delete tok. + if (prev_tok != nullptr) { + prev_tok->next = tok->next; + } else { + toks = tok->next; + } + delete tok; + num_toks_--; + } else { + prev_tok = tok; + } + } +} + +// PruneCurrentTokens deletes the tokens from the "toks" map, but not +// from the active_toks_ list, which could cause dangling forward pointers +// (will delete it during regular pruning operation). +void LatticeSimpleDecoder::PruneCurrentTokens( + float beam, std::unordered_map *toks) { + if (toks->empty()) { + KALDI_DECODER_LOG << "No tokens to prune.\n"; + return; + } + float best_cost = 1.0e+10; // positive == high cost == bad. + for (auto iter = toks->begin(); iter != toks->end(); ++iter) { + best_cost = std::min(best_cost, static_cast(iter->second->tot_cost)); + } + std::vector retained; + float cutoff = best_cost + beam; + for (auto iter = toks->begin(); iter != toks->end(); ++iter) { + if (iter->second->tot_cost < cutoff) { + retained.push_back(iter->first); + } + } + std::unordered_map tmp; + for (size_t i = 0; i < retained.size(); i++) { + tmp[retained[i]] = (*toks)[retained[i]]; + } + KALDI_DECODER_LOG << "Pruned to " << (retained.size()) << " toks.\n"; + tmp.swap(*toks); +} + +void LatticeSimpleDecoder::ProcessEmitting(DecodableInterface *decodable) { + int32_t frame = static_cast(active_toks_.size()) - + 1; // frame is the frame-index + // (zero-based) used to get likelihoods + // from the decodable object. + active_toks_.resize(active_toks_.size() + 1); + prev_toks_.clear(); + cur_toks_.swap(prev_toks_); + + // Processes emitting arcs for one frame. Propagates from + // prev_toks_ to cur_toks_. + float cutoff = std::numeric_limits::infinity(); + for (auto iter = prev_toks_.begin(); iter != prev_toks_.end(); ++iter) { + StateId state = iter->first; + Token *tok = iter->second; + for (fst::ArcIterator> aiter(fst_, state); !aiter.Done(); + aiter.Next()) { + const Arc &arc = aiter.Value(); + if (arc.ilabel != 0) { // propagate.. + float ac_cost = -decodable->LogLikelihood(frame, arc.ilabel), + graph_cost = arc.weight.Value(), cur_cost = tok->tot_cost, + tot_cost = cur_cost + ac_cost + graph_cost; + if (tot_cost >= cutoff) { + continue; + } else if (tot_cost + config_.beam < cutoff) { + cutoff = tot_cost + config_.beam; + } + + // AddToken adds the next_tok to cur_toks_ (if not already present). + Token *next_tok = + FindOrAddToken(arc.nextstate, frame + 1, tot_cost, true, NULL); + + // Add ForwardLink from tok to next_tok (put on head of list tok->links) + tok->links = new ForwardLink(next_tok, arc.ilabel, arc.olabel, + graph_cost, ac_cost, tok->links); + } + } + } +} + +// FinalizeDecoding() is a version of PruneActiveTokens that we call +// (optionally) on the final frame. Takes into account the final-prob of +// tokens. This function used to be called PruneActiveTokensFinal(). +void LatticeSimpleDecoder::FinalizeDecoding() { + int32_t final_frame_plus_one = NumFramesDecoded(); + int32_t num_toks_begin = num_toks_; + PruneForwardLinksFinal(); + for (int32_t f = final_frame_plus_one - 1; f >= 0; f--) { + bool b1, b2; // values not used. + float dontcare = 0.0; + PruneForwardLinks(f, &b1, &b2, dontcare); + PruneTokensForFrame(f + 1); + } + PruneTokensForFrame(0); + KALDI_DECODER_LOG << "pruned tokens from " << num_toks_begin << " to " + << num_toks_; +} + +// PruneForwardLinksFinal is a version of PruneForwardLinks that we call +// on the final frame. If there are final tokens active, it uses the +// final-probs for pruning, otherwise it treats all tokens as final. +void LatticeSimpleDecoder::PruneForwardLinksFinal() { + KALDI_DECODER_ASSERT(!active_toks_.empty()); + int32_t frame_plus_one = active_toks_.size() - 1; + + if (active_toks_[frame_plus_one].toks == + nullptr) { // empty list; should not happen. + KALDI_DECODER_WARN << "No tokens alive at end of file\n"; + } + + ComputeFinalCosts(&final_costs_, &final_relative_cost_, &final_best_cost_); + decoding_finalized_ = true; + // We're about to delete some of the tokens active on the final frame, so we + // clear cur_toks_ because otherwise it would then contain dangling pointers. + cur_toks_.clear(); + + // Now go through tokens on this frame, pruning forward links... may have to + // iterate a few times until there is no more change, because the list is not + // in topological order. This is a modified version of the code in + // PruneForwardLinks, but here we also take account of the final-probs. + bool changed = true; + float delta = 1.0e-05; + while (changed) { + changed = false; + for (Token *tok = active_toks_[frame_plus_one].toks; tok != nullptr; + tok = tok->next) { + ForwardLink *link, *prev_link = nullptr; + // will recompute tok_extra_cost. It has a term in it that corresponds + // to the "final-prob", so instead of initializing tok_extra_cost to + // infinity below we set it to the difference between the + // (score+final_prob) of this token, and the best such (score+final_prob). + + float final_cost; + if (final_costs_.empty()) { + final_cost = 0.0; + } else { + auto iter = final_costs_.find(tok); + if (iter != final_costs_.end()) { + final_cost = iter->second; + } else { + final_cost = std::numeric_limits::infinity(); + } + } + float tok_extra_cost = tok->tot_cost + final_cost - final_best_cost_; + // tok_extra_cost will be a "min" over either directly being final, or + // being indirectly final through other links, and the loop below may + // decrease its value: + for (link = tok->links; link != nullptr;) { + // See if we need to excise this link... + Token *next_tok = link->next_tok; + float link_extra_cost = + next_tok->extra_cost + + ((tok->tot_cost + link->acoustic_cost + link->graph_cost) - + next_tok->tot_cost); + if (link_extra_cost > config_.lattice_beam) { // excise link + ForwardLink *next_link = link->next; + if (prev_link != nullptr) { + prev_link->next = next_link; + } else { + tok->links = next_link; + } + delete link; + link = next_link; // advance link but leave prev_link the same. + } else { // keep the link and update the tok_extra_cost if needed. + if (link_extra_cost < 0.0) { // this is just a precaution. + if (link_extra_cost < -0.01) { + KALDI_DECODER_WARN << "Negative extra_cost: " << link_extra_cost; + } + + link_extra_cost = 0.0; + } + if (link_extra_cost < tok_extra_cost) { + tok_extra_cost = link_extra_cost; + } + + prev_link = link; + link = link->next; + } + } + // prune away tokens worse than lattice_beam above best path. This step + // was not necessary in the non-final case because then, this case + // showed up as having no forward links. Here, the tok_extra_cost has + // an extra component relating to the final-prob. + if (tok_extra_cost > config_.lattice_beam) { + tok_extra_cost = std::numeric_limits::infinity(); + } + // to be pruned in PruneTokensForFrame + + if (!ApproxEqual(tok->extra_cost, tok_extra_cost, delta)) { + changed = true; + } + + // will be +infinity or <= lattice_beam_. + tok->extra_cost = tok_extra_cost; + } + } // while changed +} + +void LatticeSimpleDecoder::ComputeFinalCosts( + std::unordered_map *final_costs, float *final_relative_cost, + float *final_best_cost) const { + KALDI_DECODER_ASSERT(!decoding_finalized_); + if (final_costs != nullptr) { + final_costs->clear(); + } + + float infinity = std::numeric_limits::infinity(); + float best_cost = infinity, best_cost_with_final = infinity; + + for (auto iter = cur_toks_.begin(); iter != cur_toks_.end(); ++iter) { + StateId state = iter->first; + Token *tok = iter->second; + float final_cost = fst_.Final(state).Value(); + float cost = tok->tot_cost, cost_with_final = cost + final_cost; + best_cost = std::min(cost, best_cost); + best_cost_with_final = std::min(cost_with_final, best_cost_with_final); + if (final_costs != nullptr && final_cost != infinity) { + (*final_costs)[tok] = final_cost; + } + } + if (final_relative_cost != nullptr) { + if (best_cost == infinity && best_cost_with_final == infinity) { + // Likely this will only happen if there are no tokens surviving. + // This seems the least bad way to handle it. + *final_relative_cost = infinity; + } else { + *final_relative_cost = best_cost_with_final - best_cost; + } + } + if (final_best_cost != nullptr) { + if (best_cost_with_final != infinity) { // final-state exists. + *final_best_cost = best_cost_with_final; + } else { // no final-state exists. + *final_best_cost = best_cost; + } + } +} + +float LatticeSimpleDecoder::FinalRelativeCost() const { + if (!decoding_finalized_) { + float relative_cost; + ComputeFinalCosts(nullptr, &relative_cost, nullptr); + return relative_cost; + } else { + // we're not allowed to call that function if FinalizeDecoding() has + // been called; return a cached value. + return final_relative_cost_; + } +} + +bool LatticeSimpleDecoder::GetBestPath(fst::Lattice *ofst, + bool use_final_probs) const { + fst::VectorFst fst; + GetRawLattice(&fst, use_final_probs); + ShortestPath(fst, ofst); + return (ofst->NumStates() > 0); +} + +// Outputs an FST corresponding to the raw, state-level +// tracebacks. +bool LatticeSimpleDecoder::GetRawLattice(fst::Lattice *ofst, + bool use_final_probs) const { + using Arc = fst::LatticeArc; + using StateId = Arc::StateId; + using Weight = Arc::Weight; + using Label = Arc::Label; + + // Note: you can't use the old interface (Decode()) if you want to + // get the lattice with use_final_probs = false. You'd have to do + // InitDecoding() and then AdvanceDecoding(). + if (decoding_finalized_ && !use_final_probs) { + KALDI_DECODER_ERR << "You cannot call FinalizeDecoding() and then call " + << "GetRawLattice() with use_final_probs == false"; + } + + std::unordered_map final_costs_local; + + const std::unordered_map &final_costs = + (decoding_finalized_ ? final_costs_ : final_costs_local); + + if (!decoding_finalized_ && use_final_probs) { + ComputeFinalCosts(&final_costs_local, nullptr, nullptr); + } + + ofst->DeleteStates(); + int32_t num_frames = NumFramesDecoded(); + KALDI_DECODER_ASSERT(num_frames > 0); + const int32_t bucket_count = num_toks_ / 2 + 3; + std::unordered_map tok_map(bucket_count); + // First create all states. + for (int32_t f = 0; f <= num_frames; f++) { + if (active_toks_[f].toks == nullptr) { + KALDI_DECODER_WARN << "GetRawLattice: no tokens active on frame " << f + << ": not producing lattice.\n"; + return false; + } + for (Token *tok = active_toks_[f].toks; tok != nullptr; tok = tok->next) { + tok_map[tok] = ofst->AddState(); + } + // The next statement sets the start state of the output FST. + // Because we always add new states to the head of the list + // active_toks_[f].toks, and the start state was the first one + // added, it will be the last one added to ofst. + if (f == 0 && ofst->NumStates() > 0) { + ofst->SetStart(ofst->NumStates() - 1); + } + } + StateId cur_state = 0; // we rely on the fact that we numbered these + // consecutively (AddState() returns the numbers in order..) + for (int32_t f = 0; f <= num_frames; f++) { + for (Token *tok = active_toks_[f].toks; tok != nullptr; + tok = tok->next, cur_state++) { + for (ForwardLink *l = tok->links; l != nullptr; l = l->next) { + auto iter = tok_map.find(l->next_tok); + StateId nextstate = iter->second; + KALDI_DECODER_ASSERT(iter != tok_map.end()); + Arc arc(l->ilabel, l->olabel, Weight(l->graph_cost, l->acoustic_cost), + nextstate); + ofst->AddArc(cur_state, arc); + } + if (f == num_frames) { + if (use_final_probs && !final_costs.empty()) { + auto iter = final_costs.find(tok); + if (iter != final_costs.end()) + ofst->SetFinal(cur_state, fst::LatticeWeight(iter->second, 0)); + } else { + ofst->SetFinal(cur_state, fst::LatticeWeight::One()); + } + } + } + } + KALDI_DECODER_ASSERT(cur_state == ofst->NumStates()); + return (cur_state != 0); +} + +} // namespace kaldi_decoder diff --git a/kaldi-decoder/csrc/lattice-simple-decoder.h b/kaldi-decoder/csrc/lattice-simple-decoder.h new file mode 100644 index 0000000..b1269d5 --- /dev/null +++ b/kaldi-decoder/csrc/lattice-simple-decoder.h @@ -0,0 +1,324 @@ +// kaldi-decoder/csrc/lattice-simple-decoder.h + +// Copyright 2009-2012 Microsoft Corporation +// 2012-2014 Johns Hopkins University (Author: Daniel Povey) +// 2014 Guoguo Chen +// Copyright (c) 2023 Xiaomi Corporation + +// this file is copied and modified from +// kaldi/src/decoder/lattice-simple-decoder.h +#ifndef KALDI_DECODER_CSRC_LATTICE_SIMPLE_DECODER_H_ +#define KALDI_DECODER_CSRC_LATTICE_SIMPLE_DECODER_H_ +#include +#include +#include + +#include "fst/fst.h" +#include "fst/fstlib.h" +#include "kaldi-decoder/csrc/decodable-itf.h" +#include "kaldi-decoder/csrc/log.h" +#include "kaldifst/csrc/lattice-weight.h" + +namespace kaldi_decoder { + +struct LatticeSimpleDecoderConfig { + float beam; + float lattice_beam; + int32_t prune_interval; + bool determinize_lattice; // not inspected by this class... used in + // command-line program. + bool prune_lattice; + float beam_ratio; + float prune_scale; // Note: we don't make this configurable on the command + // line, it's not a very important parameter. It affects + // the algorithm that prunes the tokens as we go. + // fst::DeterminizeLatticePhonePrunedOptions det_opts; + + LatticeSimpleDecoderConfig(float beam = 16.0, float lattice_beam = 10.0, + int32_t prune_interval = 25, + bool determinize_lattice = true, + bool prune_lattice = true, float beam_ratio = 0.9, + float prune_scale = 0.1) + : beam(beam), + lattice_beam(lattice_beam), + prune_interval(prune_interval), + determinize_lattice(determinize_lattice), + prune_lattice(prune_lattice), + beam_ratio(beam_ratio), + prune_scale(prune_scale) {} +#if 0 + void Register(OptionsItf *opts) { + det_opts.Register(opts); + opts->Register("beam", &beam, "Decoding beam."); + opts->Register("lattice-beam", &lattice_beam, "Lattice generation beam"); + opts->Register("prune-interval", &prune_interval, + "Interval (in frames) at " + "which to prune tokens"); + opts->Register("determinize-lattice", &determinize_lattice, + "If true, " + "determinize the lattice (in a special sense, keeping only " + "best pdf-sequence for each word-sequence)."); + } +#endif + void Check() const { + KALDI_DECODER_ASSERT(beam > 0.0 && lattice_beam > 0.0 && + prune_interval > 0); + } + + std::string ToString() const { + std::ostringstream os; + + os << "LatticeSimpleDecoderConfig("; + os << "beam=" << beam << ", "; + os << "lattice_beam=" << lattice_beam << ", "; + os << "prune_interval=" << prune_interval << ", "; + os << "determinize_lattice=" << (determinize_lattice ? "True" : "False") + << ", "; + + os << "prune_lattice=" << prune_lattice << ", "; + os << "beam_ratio=" << beam_ratio << ", "; + os << "prune_scale=" << prune_scale << ")"; + + return os.str(); + } +}; + +/** Simplest possible decoder, included largely for didactic purposes and as a + means to debug more highly optimized decoders. See \ref decoders_simple + for more information. + */ +class LatticeSimpleDecoder { + public: + using Arc = fst::StdArc; + using Label = Arc::Label; + using StateId = Arc::StateId; + using Weight = Arc::Weight; + + // instantiate this class once for each thing you have to decode. + LatticeSimpleDecoder(const fst::Fst &fst, + const LatticeSimpleDecoderConfig &config) + : fst_(fst), config_(config), num_toks_(0) { + config.Check(); + } + + /// InitDecoding initializes the decoding, and should only be used if you + /// intend to call AdvanceDecoding(). If you call Decode(), you don't need + /// to call this. You can call InitDecoding if you have already decoded an + /// utterance and want to start with a new utterance. + void InitDecoding(); + + /// This function may be optionally called after AdvanceDecoding(), when you + /// do not plan to decode any further. It does an extra pruning step that + /// will help to prune the lattices output by GetLattice and (particularly) + /// GetRawLattice more accurately, particularly toward the end of the + /// utterance. It does this by using the final-probs in pruning (if any + /// final-state survived); it also does a final pruning step that visits all + /// states (the pruning that is done during decoding may fail to prune states + /// that are within kPruningScale = 0.1 outside of the beam). If you call + /// this, you cannot call AdvanceDecoding again (it will fail), and you + /// cannot call GetLattice() and related functions with use_final_probs = + /// false. + /// Used to be called PruneActiveTokensFinal(). + void FinalizeDecoding(); + + // Outputs an FST corresponding to the single best path + // through the lattice. Returns true if result is nonempty + // (using the return status is deprecated, it will become void). + // If "use_final_probs" is true AND we reached the final-state + // of the graph then it will include those as final-probs, else + // it will treat all final-probs as one. + bool GetBestPath(fst::Lattice *lat, bool use_final_probs = true) const; + + // Outputs an FST corresponding to the raw, state-level + // tracebacks. Returns true if result is nonempty + // (using the return status is deprecated, it will become void). + // If "use_final_probs" is true AND we reached the final-state + // of the graph then it will include those as final-probs, else + // it will treat all final-probs as one. + bool GetRawLattice(fst::Lattice *lat, bool use_final_probs = true) const; + + ~LatticeSimpleDecoder() { ClearActiveTokens(); } + + /// FinalRelativeCost() serves the same purpose as ReachedFinal(), but gives + /// more information. It returns the difference between the best (final-cost + /// plus cost) of any token on the final frame, and the best cost of any token + /// on the final frame. If it is infinity it means no final-states were + /// present on the final frame. It will usually be nonnegative. If it not + /// too positive (e.g. < 5 is my first guess, but this is not tested) you can + /// take it as a good indication that we reached the final-state with + /// reasonable likelihood. + float FinalRelativeCost() const; + + const LatticeSimpleDecoderConfig &GetOptions() const { return config_; } + + int32_t NumFramesDecoded() const { return active_toks_.size() - 1; } + + // Returns true if any kind of traceback is available (not necessarily from + // a final state). + bool Decode(DecodableInterface *decodable); + + private: + struct Token; + // ForwardLinks are the links from a token to a token on the next frame. + // or sometimes on the current frame (for input-epsilon links). + struct ForwardLink { + Token *next_tok; // the next token [or NULL if represents final-state] + Label ilabel; // ilabel on link. + Label olabel; // olabel on link. + float graph_cost; // graph cost of traversing link (contains LM, etc.) + float acoustic_cost; // acoustic cost (pre-scaled) of traversing link + ForwardLink *next; // next in singly-linked list of forward links from a + // token. + ForwardLink(Token *next_tok, Label ilabel, Label olabel, float graph_cost, + float acoustic_cost, ForwardLink *next) + : next_tok(next_tok), + ilabel(ilabel), + olabel(olabel), + graph_cost(graph_cost), + acoustic_cost(acoustic_cost), + next(next) {} + }; + + // Token is what's resident in a particular state at a particular time. + // In this decoder a Token actually contains *forward* links. + // When first created, a Token just has the (total) cost. We add forward + // links from it when we process the next frame. + struct Token { + float tot_cost; // would equal weight.Value()... cost up to this point. + float extra_cost; // >= 0. After calling PruneForwardLinks, this equals + // the minimum difference between the cost of the best path this is on, + // and the cost of the absolute best path, under the assumption + // that any of the currently active states at the decoding front may + // eventually succeed (e.g. if you were to take the currently active states + // one by one and compute this difference, and then take the minimum). + + ForwardLink *links; // Head of singly linked list of ForwardLinks + + Token *next; // Next in list of tokens for this frame. + + Token(float tot_cost, float extra_cost, ForwardLink *links, Token *next) + : tot_cost(tot_cost), + extra_cost(extra_cost), + links(links), + next(next) {} + + Token() = default; + + void DeleteForwardLinks() { + ForwardLink *l = links; + ForwardLink *m; + + while (l != nullptr) { + m = l->next; + delete l; + l = m; + } + links = nullptr; + } + }; + + // head and tail of per-frame list of Tokens (list is in topological order), + // and something saying whether we ever pruned it using PruneForwardLinks. + struct TokenList { + Token *toks; + bool must_prune_forward_links; + bool must_prune_tokens; + TokenList() + : toks(nullptr), + must_prune_forward_links(true), + must_prune_tokens(true) {} + }; + + // FindOrAddToken either locates a token in cur_toks_, or if necessary inserts + // a new, empty token (i.e. with no forward links) for the current frame. + // [note: it's inserted if necessary into cur_toks_ and also into the singly + // linked list of tokens active on this frame (whose head is at + // active_toks_[frame]). + // + // Returns the Token pointer. Sets "changed" (if non-NULL) to true + // if the token was newly created or the cost changed. + inline Token *FindOrAddToken(StateId state, int32_t frame_plus_one, + float tot_cost, bool emitting, bool *changed); + + // delta is the amount by which the extra_costs must + // change before it sets "extra_costs_changed" to true. If delta is larger, + // we'll tend to go back less far toward the beginning of the file. + void PruneForwardLinks(int32_t frame, bool *extra_costs_changed, + bool *links_pruned, float delta); + + // Prune away any tokens on this frame that have no forward links. [we don't + // do this in PruneForwardLinks because it would give us a problem with + // dangling pointers]. + void PruneTokensForFrame(int32_t frame); + + void ClearActiveTokens(); // a cleanup routine, at utt end/begin + + void ProcessNonemitting(); + + // PruneForwardLinksFinal is a version of PruneForwardLinks that we call + // on the final frame. If there are final tokens active, it uses the + // final-probs for pruning, otherwise it treats all tokens as final. + void PruneForwardLinksFinal(); + + // This function computes the final-costs for tokens active on the final + // frame. It outputs to final-costs, if non-NULL, a map from the Token* + // pointer to the final-prob of the corresponding state, or zero for all + // states if none were final. It outputs to final_relative_cost, if non-NULL, + // the difference between the best forward-cost including the final-prob cost, + // and the best forward-cost without including the final-prob cost (this will + // usually be positive), or infinity if there were no final-probs. It outputs + // to final_best_cost, if non-NULL, the lowest for any token t active on the + // final frame, of t + final-cost[t], where final-cost[t] is the final-cost + // in the graph of the state corresponding to token t, or zero if there + // were no final-probs active on the final frame. + // You cannot call this after FinalizeDecoding() has been called; in that + // case you should get the answer from class-member variables. + void ComputeFinalCosts(std::unordered_map *final_costs, + float *final_relative_cost, + float *final_best_cost) const; + + // Go backwards through still-alive tokens, pruning them if the + // forward+backward cost is more than lat_beam away from the best path. It's + // possible to prove that this is "correct" in the sense that we won't lose + // anything outside of lat_beam, regardless of what happens in the future. + // delta controls when it considers a cost to have changed enough to continue + // going backward and propagating the change. larger delta -> will recurse + // less far. + void PruneActiveTokens(float delta); + + // PruneCurrentTokens deletes the tokens from the "toks" map, but not + // from the active_toks_ list, which could cause dangling forward pointers + // (will delete it during regular pruning operation). + void PruneCurrentTokens(float beam, + std::unordered_map *toks); + + void ProcessEmitting(DecodableInterface *decodable); + + private: + const fst::Fst &fst_; + LatticeSimpleDecoderConfig config_; + int32_t num_toks_; // current total #toks allocated... + bool warned_; + + std::unordered_map cur_toks_; + std::unordered_map prev_toks_; + std::vector active_toks_; // Lists of tokens, indexed by frame + + /// decoding_finalized_ is true if someone called FinalizeDecoding(). [note, + /// calling this is optional]. If true, it's forbidden to decode more. Also, + /// if this is set, then the output of ComputeFinalCosts() is in the next + /// three variables. The reason we need to do this is that after + /// FinalizeDecoding() calls PruneTokensForFrame() for the final frame, some + /// of the tokens on the last frame are freed, so we free the list from + /// cur_toks_ to avoid having dangling pointers hanging around. + bool decoding_finalized_; + /// For the meaning of the next 3 variables, see the comment for + /// decoding_finalized_ above., and ComputeFinalCosts(). + std::unordered_map final_costs_; + float final_relative_cost_; + float final_best_cost_; +}; + +} // namespace kaldi_decoder + +#endif // KALDI_DECODER_CSRC_LATTICE_SIMPLE_DECODER_H_ diff --git a/kaldi-decoder/csrc/simple-decoder.cc b/kaldi-decoder/csrc/simple-decoder.cc new file mode 100644 index 0000000..77721b1 --- /dev/null +++ b/kaldi-decoder/csrc/simple-decoder.cc @@ -0,0 +1,283 @@ +// kaldi-decoder/csrc/simple-decoder.cc + +// Copyright 2009-2011 Microsoft Corporation +// 2012-2013 Johns Hopkins University (author: Daniel Povey) +// Copyright (c) 2023 Xiaomi Corporation + +#include "kaldi-decoder/csrc/simple-decoder.h" + +#include +#include +#include +#include + +#include "kaldi-decoder/csrc/log.h" +#include "kaldifst/csrc/remove-eps-local.h" + +namespace kaldi_decoder { + +SimpleDecoder::~SimpleDecoder() { + ClearToks(cur_toks_); + ClearToks(prev_toks_); +} + +bool SimpleDecoder::Decode(DecodableInterface *decodable) { + InitDecoding(); + AdvanceDecoding(decodable); + return (!cur_toks_.empty()); +} + +void SimpleDecoder::InitDecoding() { + // clean up from last time: + ClearToks(cur_toks_); + ClearToks(prev_toks_); + // initialize decoding: + StateId start_state = fst_.Start(); + KALDI_DECODER_ASSERT(start_state != fst::kNoStateId); + StdArc dummy_arc(0, 0, StdWeight::One(), start_state); + cur_toks_[start_state] = new Token(dummy_arc, 0.0, nullptr); + num_frames_decoded_ = 0; + ProcessNonemitting(); +} + +void SimpleDecoder::AdvanceDecoding(DecodableInterface *decodable, + int32_t max_num_frames /*= -1*/) { + KALDI_DECODER_ASSERT(num_frames_decoded_ >= 0 && + "You must call InitDecoding() before AdvanceDecoding()"); + int32_t num_frames_ready = decodable->NumFramesReady(); + // num_frames_ready must be >= num_frames_decoded, or else + // the number of frames ready must have decreased (which doesn't + // make sense) or the decodable object changed between calls + // (which isn't allowed). + KALDI_DECODER_ASSERT(num_frames_ready >= num_frames_decoded_); + int32_t target_frames_decoded = num_frames_ready; + if (max_num_frames >= 0) { + target_frames_decoded = + std::min(target_frames_decoded, num_frames_decoded_ + max_num_frames); + } + + while (num_frames_decoded_ < target_frames_decoded) { + // note: ProcessEmitting() increments num_frames_decoded_ + ClearToks(prev_toks_); + cur_toks_.swap(prev_toks_); + ProcessEmitting(decodable); + ProcessNonemitting(); + PruneToks(beam_, &cur_toks_); + } +} + +bool SimpleDecoder::ReachedFinal() const { + for (auto iter = cur_toks_.begin(); iter != cur_toks_.end(); ++iter) { + if (iter->second->cost_ != std::numeric_limits::infinity() && + fst_.Final(iter->first) != StdWeight::Zero()) + return true; + } + return false; +} + +float SimpleDecoder::FinalRelativeCost() const { + // as a special case, if there are no active tokens at all (e.g. some kind of + // pruning failure), return infinity. + double infinity = std::numeric_limits::infinity(); + if (cur_toks_.empty()) return infinity; + double best_cost = infinity, best_cost_with_final = infinity; + for (auto iter = cur_toks_.begin(); iter != cur_toks_.end(); ++iter) { + // Note: Plus is taking the minimum cost, since we're in the tropical + // semiring. + best_cost = std::min(best_cost, iter->second->cost_); + best_cost_with_final = + std::min(best_cost_with_final, + iter->second->cost_ + fst_.Final(iter->first).Value()); + } + float extra_cost = best_cost_with_final - best_cost; + if (extra_cost != extra_cost) { // NaN. This shouldn't happen; it indicates + // some kind of error, most likely. + KALDI_DECODER_WARN << "Found NaN (likely search failure in decoding)"; + return infinity; + } + // Note: extra_cost will be infinity if no states were final. + return extra_cost; +} + +// Outputs an FST corresponding to the single best path +// through the lattice. +bool SimpleDecoder::GetBestPath(fst::Lattice *fst_out, + bool use_final_probs) const { + fst_out->DeleteStates(); + Token *best_tok = nullptr; + bool is_final = ReachedFinal(); + if (!is_final) { + for (auto iter = cur_toks_.begin(); iter != cur_toks_.end(); ++iter) + if (best_tok == nullptr || *best_tok < *(iter->second)) + best_tok = iter->second; + } else { + double infinity = std::numeric_limits::infinity(), + best_cost = infinity; + for (auto iter = cur_toks_.begin(); iter != cur_toks_.end(); ++iter) { + double this_cost = iter->second->cost_ + fst_.Final(iter->first).Value(); + if (this_cost != infinity && this_cost < best_cost) { + best_cost = this_cost; + best_tok = iter->second; + } + } + } + if (best_tok == nullptr) return false; // No output. + + std::vector arcs_reverse; // arcs in reverse order. + for (Token *tok = best_tok; tok != nullptr; tok = tok->prev_) + arcs_reverse.push_back(tok->arc_); + KALDI_DECODER_ASSERT(arcs_reverse.back().nextstate == fst_.Start()); + arcs_reverse.pop_back(); // that was a "fake" token... gives no info. + + StateId cur_state = fst_out->AddState(); + fst_out->SetStart(cur_state); + for (ssize_t i = static_cast(arcs_reverse.size()) - 1; i >= 0; i--) { + fst::LatticeArc arc = arcs_reverse[i]; + arc.nextstate = fst_out->AddState(); + fst_out->AddArc(cur_state, arc); + cur_state = arc.nextstate; + } + if (is_final && use_final_probs) + fst_out->SetFinal( + cur_state, + fst::LatticeWeight(fst_.Final(best_tok->arc_.nextstate).Value(), 0.0)); + else + fst_out->SetFinal(cur_state, fst::LatticeWeight::One()); + fst::RemoveEpsLocal(fst_out); + return true; +} + +void SimpleDecoder::ProcessEmitting(DecodableInterface *decodable) { + int32_t frame = num_frames_decoded_; + // Processes emitting arcs for one frame. Propagates from + // prev_toks_ to cur_toks_. + double cutoff = std::numeric_limits::infinity(); + for (auto iter = prev_toks_.begin(); iter != prev_toks_.end(); ++iter) { + StateId state = iter->first; + Token *tok = iter->second; + KALDI_DECODER_ASSERT(state == tok->arc_.nextstate); + for (fst::ArcIterator> aiter(fst_, state); !aiter.Done(); + aiter.Next()) { + const StdArc &arc = aiter.Value(); + if (arc.ilabel == 0) { + continue; + } + + // propagate.. + float acoustic_cost = -decodable->LogLikelihood(frame, arc.ilabel); + double total_cost = tok->cost_ + arc.weight.Value() + acoustic_cost; + + if (total_cost >= cutoff) { + continue; + } + + if (total_cost + beam_ < cutoff) { + cutoff = total_cost + beam_; + } + + Token *new_tok = new Token(arc, acoustic_cost, tok); + auto find_iter = cur_toks_.find(arc.nextstate); + if (find_iter == cur_toks_.end()) { + cur_toks_[arc.nextstate] = new_tok; + } else { + if (*(find_iter->second) < *new_tok) { + Token::TokenDelete(find_iter->second); + find_iter->second = new_tok; + } else { + Token::TokenDelete(new_tok); + } + } + } + } + num_frames_decoded_++; +} + +void SimpleDecoder::ProcessNonemitting() { + // Processes nonemitting arcs for one frame. Propagates within + // cur_toks_. + std::vector queue; + double infinity = std::numeric_limits::infinity(); + double best_cost = infinity; + for (auto iter = cur_toks_.begin(); iter != cur_toks_.end(); ++iter) { + queue.push_back(iter->first); + best_cost = std::min(best_cost, iter->second->cost_); + } + double cutoff = best_cost + beam_; + + while (!queue.empty()) { + StateId state = queue.back(); + queue.pop_back(); + Token *tok = cur_toks_[state]; + KALDI_DECODER_ASSERT(tok != nullptr && state == tok->arc_.nextstate); + for (fst::ArcIterator> aiter(fst_, state); !aiter.Done(); + aiter.Next()) { + const StdArc &arc = aiter.Value(); + if (arc.ilabel != 0) { // propagate nonemitting only... + continue; + } + + const float acoustic_cost = 0.0; + Token *new_tok = new Token(arc, acoustic_cost, tok); + if (new_tok->cost_ > cutoff) { + Token::TokenDelete(new_tok); + } else { + auto find_iter = cur_toks_.find(arc.nextstate); + if (find_iter == cur_toks_.end()) { + cur_toks_[arc.nextstate] = new_tok; + queue.push_back(arc.nextstate); + } else { + if (*(find_iter->second) < *new_tok) { + // find_iter has a higher cost + Token::TokenDelete(find_iter->second); + find_iter->second = new_tok; + queue.push_back(arc.nextstate); + } else { + Token::TokenDelete(new_tok); + } + } + } + } + } +} + +// static +void SimpleDecoder::ClearToks(std::unordered_map &toks) { + for (auto iter = toks.begin(); iter != toks.end(); ++iter) { + Token::TokenDelete(iter->second); + } + toks.clear(); +} + +// static +void SimpleDecoder::PruneToks(float beam, + std::unordered_map *toks) { + if (toks->empty()) { + KALDI_DECODER_LOG << "No tokens to prune.\n"; + return; + } + double best_cost = std::numeric_limits::infinity(); + for (auto iter = toks->begin(); iter != toks->end(); ++iter) { + best_cost = std::min(best_cost, iter->second->cost_); + } + + std::vector retained; + double cutoff = best_cost + beam; + for (auto iter = toks->begin(); iter != toks->end(); ++iter) { + if (iter->second->cost_ < cutoff) { + retained.push_back(iter->first); + } else { + Token::TokenDelete(iter->second); + } + } + + std::unordered_map tmp; + for (size_t i = 0; i < retained.size(); i++) { + tmp[retained[i]] = (*toks)[retained[i]]; + } + + KALDI_DECODER_LOG << "Pruned from " << toks->size() << " to " + << (retained.size()) << " toks.\n"; + tmp.swap(*toks); +} + +} // namespace kaldi_decoder diff --git a/kaldi-decoder/csrc/simple-decoder.h b/kaldi-decoder/csrc/simple-decoder.h new file mode 100644 index 0000000..d45df76 --- /dev/null +++ b/kaldi-decoder/csrc/simple-decoder.h @@ -0,0 +1,138 @@ +// kaldi-decoder/csrc/simple-decoder.h + +// Copyright 2009-2013 Microsoft Corporation; Lukas Burget; +// Saarland University (author: Arnab Ghoshal); +// Johns Hopkins University (author: Daniel Povey) +// Copyright (c) 2023 Xiaomi Corporation +#ifndef KALDI_DECODER_CSRC_SIMPLE_DECODER_H_ +#define KALDI_DECODER_CSRC_SIMPLE_DECODER_H_ + +#include + +#include "fst/fst.h" +#include "fst/fstlib.h" +#include "kaldi-decoder/csrc/decodable-itf.h" +#include "kaldi-decoder/csrc/log.h" +#include "kaldifst/csrc/lattice-weight.h" + +namespace kaldi_decoder { + +/** Simplest possible decoder, included largely for didactic purposes and as a + means to debug more highly optimized decoders. See \ref decoders_simple + for more information. + */ +class SimpleDecoder { + public: + using StdArc = fst::StdArc; + using StdWeight = StdArc::Weight; + using Label = StdArc::Label; + using StateId = StdArc::StateId; + + SimpleDecoder(const fst::Fst &fst, float beam) + : fst_(fst), beam_(beam) {} + + ~SimpleDecoder(); + SimpleDecoder(const SimpleDecoder &) = delete; + SimpleDecoder &operator=(const SimpleDecoder &) = delete; + + /// Decode this utterance. + /// Returns true if any tokens reached the end of the file (regardless of + /// whether they are in a final state); query ReachedFinal() after Decode() + /// to see whether we reached a final state. + bool Decode(DecodableInterface *decodable); + + bool ReachedFinal() const; + + // GetBestPath gets the decoding traceback. If "use_final_probs" is true + // AND we reached a final state, it limits itself to final states; + // otherwise it gets the most likely token not taking into account + // final-probs. fst_out will be empty (Start() == kNoStateId) if nothing was + // available due to search error. If Decode() returned true, it is safe to + // assume GetBestPath will return true. It returns true if the output lattice + // was nonempty (i.e. had states in it); using the return value is deprecated. + bool GetBestPath(fst::Lattice *fst_out, bool use_final_probs = true) const; + + /// *** The next functions are from the "new interface". *** + + /// FinalRelativeCost() serves the same function as ReachedFinal(), but gives + /// more information. It returns the difference between the best (final-cost + /// plus cost) of any token on the final frame, and the best cost of any token + /// on the final frame. If it is infinity it means no final-states were + /// present on the final frame. It will usually be nonnegative. + float FinalRelativeCost() const; + + /// InitDecoding initializes the decoding, and should only be used if you + /// intend to call AdvanceDecoding(). If you call Decode(), you don't need + /// to call this. You can call InitDecoding if you have already decoded an + /// utterance and want to start with a new utterance. + void InitDecoding(); + + /// This will decode until there are no more frames ready in the decodable + /// object, but if max_num_frames is >= 0 it will decode no more than + /// that many frames. If it returns false, then no tokens are alive, + /// which is a kind of error state. + void AdvanceDecoding(DecodableInterface *decodable, + int32_t max_num_frames = -1); + + /// Returns the number of frames already decoded. + int32_t NumFramesDecoded() const { return num_frames_decoded_; } + + private: + class Token { + public: + fst::LatticeArc arc_; // We use LatticeArc so that we can separately + // store the acoustic and graph cost, in case + // we need to produce lattice-formatted output. + Token *prev_; + int32_t ref_count_; + double cost_; // accumulated total cost up to this point. + Token(const StdArc &arc, float acoustic_cost, Token *prev) + : prev_(prev), ref_count_(1) { + arc_.ilabel = arc.ilabel; + arc_.olabel = arc.olabel; + arc_.weight = fst::LatticeWeight(arc.weight.Value(), acoustic_cost); + arc_.nextstate = arc.nextstate; + if (prev) { + prev->ref_count_++; + cost_ = prev->cost_ + (arc.weight.Value() + acoustic_cost); + } else { + cost_ = arc.weight.Value() + acoustic_cost; + } + } + bool operator<(const Token &other) { return cost_ > other.cost_; } + + static void TokenDelete(Token *tok) { + while (--tok->ref_count_ == 0) { + Token *prev = tok->prev_; + delete tok; + if (prev == nullptr) { + return; + } else { + tok = prev; + } + } + KALDI_DECODER_ASSERT(tok->ref_count_ > 0); + } + }; + + // ProcessEmitting decodes the frame num_frames_decoded_ of the + // decodable object, then increments num_frames_decoded_. + void ProcessEmitting(DecodableInterface *decodable); + + void ProcessNonemitting(); + + std::unordered_map cur_toks_; + std::unordered_map prev_toks_; + const fst::Fst &fst_; + float beam_; + // Keep track of the number of frames decoded in the current file. + int32_t num_frames_decoded_ = -1; + + static void ClearToks(std::unordered_map &toks); // NOLINT + + static void PruneToks(float beam, std::unordered_map *toks); +}; + +} // namespace kaldi_decoder + +#endif // KALDI_DECODER_CSRC_SIMPLE_DECODER_H_ diff --git a/kaldi-decoder/python/csrc/CMakeLists.txt b/kaldi-decoder/python/csrc/CMakeLists.txt index 9470aa0..a8a8637 100644 --- a/kaldi-decoder/python/csrc/CMakeLists.txt +++ b/kaldi-decoder/python/csrc/CMakeLists.txt @@ -5,6 +5,8 @@ set(srcs decodable-itf.cc faster-decoder.cc kaldi-decoder.cc + lattice-simple-decoder.cc + simple-decoder.cc ) pybind11_add_module(_kaldi_decoder ${srcs}) diff --git a/kaldi-decoder/python/csrc/decodable-ctc.cc b/kaldi-decoder/python/csrc/decodable-ctc.cc index 7920421..27381f6 100644 --- a/kaldi-decoder/python/csrc/decodable-ctc.cc +++ b/kaldi-decoder/python/csrc/decodable-ctc.cc @@ -11,7 +11,8 @@ namespace kaldi_decoder { void PybindDecodableCtc(py::module *m) { using PyClass = DecodableCtc; py::class_(*m, "DecodableCtc") - .def(py::init(), py::arg("feats")); + .def(py::init(), py::arg("feats"), + py::arg("offset") = 0); } } // namespace kaldi_decoder diff --git a/kaldi-decoder/python/csrc/faster-decoder.cc b/kaldi-decoder/python/csrc/faster-decoder.cc index 95938d1..75a3581 100644 --- a/kaldi-decoder/python/csrc/faster-decoder.cc +++ b/kaldi-decoder/python/csrc/faster-decoder.cc @@ -53,7 +53,7 @@ void PybindFasterDecoder(py::module *m) { }, py::arg("use_final_probs") = true) .def("init_decoding", &PyClass::InitDecoding) - .def("advanced_decoding", &PyClass::AdvanceDecoding, py::arg("decodable"), + .def("advance_decoding", &PyClass::AdvanceDecoding, py::arg("decodable"), py::arg("max_num_frames") = -1) .def("num_frames_decoded", &PyClass::NumFramesDecoded); } diff --git a/kaldi-decoder/python/csrc/kaldi-decoder.cc b/kaldi-decoder/python/csrc/kaldi-decoder.cc index 1b00418..7294fc5 100644 --- a/kaldi-decoder/python/csrc/kaldi-decoder.cc +++ b/kaldi-decoder/python/csrc/kaldi-decoder.cc @@ -7,6 +7,8 @@ #include "kaldi-decoder/python/csrc/decodable-ctc.h" #include "kaldi-decoder/python/csrc/decodable-itf.h" #include "kaldi-decoder/python/csrc/faster-decoder.h" +#include "kaldi-decoder/python/csrc/lattice-simple-decoder.h" +#include "kaldi-decoder/python/csrc/simple-decoder.h" namespace kaldi_decoder { @@ -14,6 +16,8 @@ PYBIND11_MODULE(_kaldi_decoder, m) { m.doc() = "pybind11 binding of kaldi-decoder"; PybindDecodableItf(&m); PybindFasterDecoder(&m); + PybindLatticeSimpleDecoder(&m); + PybindSimpleDecoder(&m); PybindDecodableCtc(&m); } diff --git a/kaldi-decoder/python/csrc/lattice-simple-decoder.cc b/kaldi-decoder/python/csrc/lattice-simple-decoder.cc new file mode 100644 index 0000000..6044c0f --- /dev/null +++ b/kaldi-decoder/python/csrc/lattice-simple-decoder.cc @@ -0,0 +1,71 @@ +// kaldi-decoder/python/csrc/lattice-simple-decoder.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "kaldi-decoder/python/csrc/lattice-simple-decoder.h" + +#include "kaldi-decoder/csrc/lattice-simple-decoder.h" + +namespace kaldi_decoder { + +static void PybindLatticeSimpleDecoderConfig(py::module *m) { + using PyClass = LatticeSimpleDecoderConfig; + + py::class_(*m, "LatticeSimpleDecoderConfig") + .def(py::init(), + py::arg("beam") = 16.0, py::arg("lattice_beam") = 10.0, + py::arg("prune_interval") = 25, + py::arg("determinize_lattice") = true, + py::arg("prune_lattice") = true, py::arg("beam_ratio") = 0.9, + py::arg("prune_scale") = 0.1) + .def_readwrite("beam", &PyClass::beam) + .def_readwrite("lattice_beam", &PyClass::lattice_beam) + .def_readwrite("prune_interval", &PyClass::prune_interval) + .def_readwrite("determinize_lattice", &PyClass::determinize_lattice) + .def_readwrite("prune_lattice", &PyClass::prune_lattice) + .def_readwrite("beam_ratio", &PyClass::beam_ratio) + .def_readwrite("prune_scale", &PyClass::prune_scale) + .def("__str__", &PyClass::ToString); +} + +void PybindLatticeSimpleDecoder(py::module *m) { + PybindLatticeSimpleDecoderConfig(m); + + using PyClass = LatticeSimpleDecoder; + py::class_(*m, "LatticeSimpleDecoder") + .def(py::init &, + const LatticeSimpleDecoderConfig &>(), + py::arg("fst"), py::arg("config")) + .def(py::init &, + const LatticeSimpleDecoderConfig &>(), + py::arg("fst"), py::arg("config")) + .def(py::init &, + const LatticeSimpleDecoderConfig &>(), + py::arg("fst"), py::arg("config")) + .def("get_config", &PyClass::GetOptions) + .def("num_frames_decoded", &PyClass::NumFramesDecoded) + .def("final_relative_cost", &PyClass::FinalRelativeCost) + .def("decode", &PyClass::Decode, py::arg("decodable")) + .def("init_decoding", &PyClass::InitDecoding) + .def("finalize_decoding", &PyClass::FinalizeDecoding) + .def( + "get_best_path", + [](PyClass &self, bool use_final_probs) + -> std::pair> { + fst::VectorFst fst; + bool ok = self.GetBestPath(&fst, use_final_probs); + return std::make_pair(ok, fst); + }, + py::arg("use_final_probs") = true) + .def( + "get_raw_lattice", + [](PyClass &self, bool use_final_probs) + -> std::pair> { + fst::VectorFst fst; + bool ok = self.GetRawLattice(&fst, use_final_probs); + return std::make_pair(ok, fst); + }, + py::arg("use_final_probs") = true); +} + +} // namespace kaldi_decoder diff --git a/kaldi-decoder/python/csrc/lattice-simple-decoder.h b/kaldi-decoder/python/csrc/lattice-simple-decoder.h new file mode 100644 index 0000000..3853740 --- /dev/null +++ b/kaldi-decoder/python/csrc/lattice-simple-decoder.h @@ -0,0 +1,16 @@ +// kaldi-decoder/python/csrc/lattice-simple-decoder.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef KALDI_DECODER_PYTHON_CSRC_LATTICE_SIMPLE_DECODER_H_ +#define KALDI_DECODER_PYTHON_CSRC_LATTICE_SIMPLE_DECODER_H_ + +#include "kaldi-decoder/python/csrc/kaldi-decoder.h" + +namespace kaldi_decoder { + +void PybindLatticeSimpleDecoder(py::module *m); + +} + +#endif // KALDI_DECODER_PYTHON_CSRC_LATTICE_SIMPLE_DECODER_H_ diff --git a/kaldi-decoder/python/csrc/simple-decoder.cc b/kaldi-decoder/python/csrc/simple-decoder.cc new file mode 100644 index 0000000..2633df6 --- /dev/null +++ b/kaldi-decoder/python/csrc/simple-decoder.cc @@ -0,0 +1,40 @@ +// kaldi-decoder/python/csrc/simple-decoder.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "kaldi-decoder/python/csrc/simple-decoder.h" + +#include + +#include "kaldi-decoder/csrc/simple-decoder.h" + +namespace kaldi_decoder { + +void PybindSimpleDecoder(py::module *m) { + using PyClass = SimpleDecoder; + py::class_(*m, "SimpleDecoder") + .def(py::init &, float>(), py::arg("fst"), + py::arg("beam")) + .def(py::init &, float>(), + py::arg("fst"), py::arg("beam")) + .def(py::init &, float>(), + py::arg("fst"), py::arg("beam")) + .def("decode", &PyClass::Decode, py::arg("decodable")) + .def("reached_final", &PyClass::ReachedFinal) + .def( + "get_best_path", + [](PyClass &self, bool use_final_probs) + -> std::pair> { + fst::VectorFst fst; + bool ok = self.GetBestPath(&fst, use_final_probs); + return std::make_pair(ok, fst); + }, + py::arg("use_final_probs") = true) + .def("final_relative_cost", &PyClass::FinalRelativeCost) + .def("init_decoding", &PyClass::InitDecoding) + .def("advance_decoding", &PyClass::AdvanceDecoding, py::arg("decodable"), + py::arg("max_num_frames") = -1) + .def("num_frames_decoded", &PyClass::NumFramesDecoded); +} + +} // namespace kaldi_decoder diff --git a/kaldi-decoder/python/csrc/simple-decoder.h b/kaldi-decoder/python/csrc/simple-decoder.h new file mode 100644 index 0000000..f160e2a --- /dev/null +++ b/kaldi-decoder/python/csrc/simple-decoder.h @@ -0,0 +1,16 @@ +// kaldi-decoder/python/csrc/simple-decoder.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef KALDI_DECODER_PYTHON_CSRC_SIMPLE_DECODER_H_ +#define KALDI_DECODER_PYTHON_CSRC_SIMPLE_DECODER_H_ + +#include "kaldi-decoder/python/csrc/kaldi-decoder.h" + +namespace kaldi_decoder { + +void PybindSimpleDecoder(py::module *m); + +} + +#endif // KALDI_DECODER_PYTHON_CSRC_SIMPLE_DECODER_H_ diff --git a/kaldi-decoder/python/kaldi_decoder/__init__.py b/kaldi-decoder/python/kaldi_decoder/__init__.py index b70339a..326f935 100644 --- a/kaldi-decoder/python/kaldi_decoder/__init__.py +++ b/kaldi-decoder/python/kaldi_decoder/__init__.py @@ -1,6 +1,9 @@ from _kaldi_decoder import ( - FasterDecoderOptions, - FasterDecoder, - DecodableInterface, DecodableCtc, + DecodableInterface, + FasterDecoder, + FasterDecoderOptions, + LatticeSimpleDecoder, + LatticeSimpleDecoderConfig, + SimpleDecoder, )