diff --git a/rasr/feature_scorer.py b/rasr/feature_scorer.py index 50a0fe52..6a390068 100644 --- a/rasr/feature_scorer.py +++ b/rasr/feature_scorer.py @@ -7,6 +7,7 @@ "ReturnnScorer", "InvAlignmentPassThroughFeatureScorer", "PrecomputedHybridFeatureScorer", + "OnnxFeatureScorer", ] from sisyphus import * @@ -15,6 +16,8 @@ import os +from typing import Union, Dict, Optional + from .config import * from i6_core.util import get_returnn_root @@ -126,25 +129,51 @@ def __init__(self, prior_mixtures, scale=1.0, priori_scale=0.0, prior_file=None) self.config.normalize_mixture_weights = False -class TFLabelContextFeatureScorer(FeatureScorer): +class OnnxFeatureScorer(FeatureScorer): def __init__( self, - fs_tf_config, - contextPriorFile, - diphonePriorFile, - prior_mixtures, - prior_scale, + *, + mixtures: tk.Path, + model: tk.Path, + io_map: Dict[str, str], + label_log_posterior_scale: float = 1.0, + label_prior_scale: float, + label_log_prior_file: Optional[tk.Path] = None, + apply_log_on_output: bool = False, + negate_output: bool = True, + intra_op_threads: int = 1, + inter_op_threads: int = 1, + **kwargs, ): - super().__init__() + """ + :param mixtures: path to a *.mix file e.g. output of either EstimateMixturesJob or CreateDummyMixturesJob + :param model: path of a model e.g. output of ExportPyTorchModelToOnnxJob + :param io_map: mapping between internal rasr identifiers and the model related input/output. Default key values + are "features" and "output", and optionally "features-size", e.g. + io_map = {"features": "data", "output": "classes"} + :param label_log_posterior_scale: scales for the log probability of a label e.g. 1.0 is recommended + :param label_prior_scale: scale for the prior log probability of a label reasonable e.g. values in [0.1, 0.7] interval + :param label_log_prior_file: xml file containing log prior probabilities e.g. estimated from the model via povey method + :param apply_log_on_output: whether to apply the log-function on the output, usefull if the model outputs softmax instead of log-softmax + :param negate_output: whether negate output (because the model outputs log softmax and not negative log softmax + :param intra_op_threads: Onnxruntime session's number of parallel threads within each operator + :param inter_op_threads: Onnxruntime session's number of parallel threads between operators used only for parallel execution mode + """ + super().__init__(*args, **kwargs) - self.config = RasrConfig() - self.config.feature_scorer_type = "tf-label-context-scorer" - self.config.file = prior_mixtures - self.config.num_label_contexts = 46 - self.config.prior_scale = prior_scale - self.config.context_prior = contextPriorFile - self.config.diphone_prior = diphonePriorFile - self.config.normalize_mixture_weights = False - self.config.loader = fs_tf_config.loader - self.config.input_map = fs_tf_config.input_map - self.config.output_map = fs_tf_config.output_map + self.config.feature_scorer_type = "onnx-feature-scorer" + self.config.file = mixtures + self.config.scale = label_log_posterior_scale + self.config.priori_scale = label_prior_scale + if label_log_prior_file is not None: + self.config.prior_file = label_log_prior_file + + self.config.session.file = model + self.config.apply_log_on_output = apply_log_on_output + self.config.negate_output = negate_output + + self.post_config.session.intra_op_num_threads = intra_op_threads + self.post_config.session.inter_op_num_threads = inter_op_threads + + for k, v in io_map.items(): + self.config.io_map[k] = v