Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Onnx Feature Scorer #452

Merged
merged 12 commits into from
Sep 21, 2023
68 changes: 50 additions & 18 deletions rasr/feature_scorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
"ReturnnScorer",
"InvAlignmentPassThroughFeatureScorer",
"PrecomputedHybridFeatureScorer",
"OnnxFeatureScorer",
]

from sisyphus import *
Expand All @@ -15,6 +16,8 @@

import os

from typing import Union, Dict

from .config import *
from i6_core.util import get_returnn_root

Expand Down Expand Up @@ -126,25 +129,54 @@ def __init__(self, prior_mixtures, scale=1.0, priori_scale=0.0, prior_file=None)
self.config.normalize_mixture_weights = False


class TFLabelContextFeatureScorer(FeatureScorer):
class OnnxFeatureScorer(FeatureScorer):
def __init__(
Atticus1806 marked this conversation as resolved.
Show resolved Hide resolved
self,
fs_tf_config,
contextPriorFile,
diphonePriorFile,
prior_mixtures,
prior_scale,
*,
mixtures: tk.Path,
model: tk.Path,
io_map: Dict[str, str],
label_log_posterior_scale: float = 1.0,
label_prior_scale: float = 0.7,
Marvin84 marked this conversation as resolved.
Show resolved Hide resolved
label_log_prior_file: tk.Path = None,
Marvin84 marked this conversation as resolved.
Show resolved Hide resolved
apply_log_on_output: bool = False,
negate_output: bool = True,
intra_op_threads: int = 1,
inter_op_threads: int = 1,
**kwargs,
):
super().__init__()
"""
:param mixtures: path to a *.mix file e.g. output of either EstimateMixturesJob or CreateDummyMixturesJob
:param model: path of a model e.g. output of ExportPyTorchModelToOnnxJob
:param io_map: mapping between internal rasr identifiers and the model related input/output. Default key values
are "features" and "output", and optionally "features-size", e.g.
io_map = {"features": "data", "output": "classes"}
:param label_log_posterior_scale: scales for the log probability of a label e.g. 1.0 is recommended
:param label_prior_scale: scale for the prior log probability of a label reasonable e.g. values in [0.1, 0.7] interval
:param label_log_prior_file: xml file containing log prior probabilities e.g. estimated from the model via povey method
:param apply_log_on_output: whether to apply the log-function on the output, usefull if the model outputs softmax instead of log-softmax
:param negate_output: whether negate output (because the model outputs log softmax and not negative log softmax
:param intra_op_threads: Onnxruntime session's number of parallel threads within each operator
:param inter_op_threads: Onnxruntime session's number of parallel threads between operators used only for parallel execution mode
"""
Marvin84 marked this conversation as resolved.
Show resolved Hide resolved
super().__init__(*args, **kwargs)
Marvin84 marked this conversation as resolved.
Show resolved Hide resolved

self.config = RasrConfig()
self.config.feature_scorer_type = "tf-label-context-scorer"
self.config.file = prior_mixtures
self.config.num_label_contexts = 46
self.config.prior_scale = prior_scale
self.config.context_prior = contextPriorFile
self.config.diphone_prior = diphonePriorFile
self.config.normalize_mixture_weights = False
self.config.loader = fs_tf_config.loader
self.config.input_map = fs_tf_config.input_map
self.config.output_map = fs_tf_config.output_map
self.config.feature_scorer_type = "onnx-feature-scorer"
self.config.file = mixtures
self.config.scale = label_log_posterior_scale
self.config.priori_scale = label_prior_scale
if label_log_prior_file is not None:
self.config.prior_file = label_log_prior_file

self.config.session.file = model

if apply_log_on_output:
self.config.apply_log_on_output = apply_log_on_output
if not negate_output:
Marvin84 marked this conversation as resolved.
Show resolved Hide resolved
self.config.negate_output = negate_output

self.post_config.session.intra_op_num_threads = intra_op_threads
self.post_config.session.inter_op_num_threads = inter_op_threads

for k, v in io_map.items():
self.config.io_map[k] = v
Loading