Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Onnx Feature Scorer #452

Merged
merged 12 commits into from
Sep 21, 2023
63 changes: 45 additions & 18 deletions rasr/feature_scorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
"ReturnnScorer",
"InvAlignmentPassThroughFeatureScorer",
"PrecomputedHybridFeatureScorer",
"OnnxFeatureScorer",
]

from sisyphus import *
Expand Down Expand Up @@ -126,25 +127,51 @@ def __init__(self, prior_mixtures, scale=1.0, priori_scale=0.0, prior_file=None)
self.config.normalize_mixture_weights = False


class TFLabelContextFeatureScorer(FeatureScorer):
class OnnxFeatureScorer(FeatureScorer):
def __init__(
Atticus1806 marked this conversation as resolved.
Show resolved Hide resolved
self,
fs_tf_config,
contextPriorFile,
diphonePriorFile,
prior_mixtures,
prior_scale,
mixtures,
model,
io_map,
*args,
Marvin84 marked this conversation as resolved.
Show resolved Hide resolved
label_log_posterior_scale=1.0,
label_prior_scale=0.7,
label_log_prior_file=None,
apply_log_on_output=False,
negate_output=True,
intra_op_threads=1,
inter_op_threads=1,
**kwargs,
):
super().__init__()
"""
:param str mixtures: path to a *.mix file e.g. output of either EstimateMixturesJob or CreateDummyMixturesJob
Marvin84 marked this conversation as resolved.
Show resolved Hide resolved
:param str model: path of a model e.g. output of ExportPyTorchModelToOnnxJob
Marvin84 marked this conversation as resolved.
Show resolved Hide resolved
:param dict io_map: mapping between internal rasr identifiers and the model related input/output. Default key values
Marvin84 marked this conversation as resolved.
Show resolved Hide resolved
are "features" and "output", and optionally "features-size"
Marvin84 marked this conversation as resolved.
Show resolved Hide resolved
:param float label_log_posterior_scale: scales for the log probability of a label e.g. 1.0 is recommended
:param float label_prior_scale: scale for the prior log probability of a label reasonable e.g. values in [0.1, 0.7] interval
:param str label_log_prior_file: xml file containing log prior probabilities e.g. estimated from the model via povey method
Marvin84 marked this conversation as resolved.
Show resolved Hide resolved
:param bool apply_log_on_output: whether to apply the log-function on the output, usefull if the model outputs softmax instead of log-softmax
:param bool negate_output: wheter negate output (because the model outputs log softmax and not negative log softmax
Marvin84 marked this conversation as resolved.
Show resolved Hide resolved
"""
Marvin84 marked this conversation as resolved.
Show resolved Hide resolved
super().__init__(*args, **kwargs)
Marvin84 marked this conversation as resolved.
Show resolved Hide resolved

self.config = RasrConfig()
self.config.feature_scorer_type = "tf-label-context-scorer"
self.config.file = prior_mixtures
self.config.num_label_contexts = 46
self.config.prior_scale = prior_scale
self.config.context_prior = contextPriorFile
self.config.diphone_prior = diphonePriorFile
self.config.normalize_mixture_weights = False
self.config.loader = fs_tf_config.loader
self.config.input_map = fs_tf_config.input_map
self.config.output_map = fs_tf_config.output_map
self.config.feature_scorer_type = "onnx-feature-scorer"
self.config.file = mixtures
self.config.scale = label_log_posterior_scale
self.config.priori_scale = label_prior_scale
if label_log_prior_file is not None:
self.config.prior_file = label_log_prior_file

self.config.session.file = model

if label_log_prior_file:
Marvin84 marked this conversation as resolved.
Show resolved Hide resolved
self.config.apply_log_on_output = apply_log_on_output
if not negate_output:
Marvin84 marked this conversation as resolved.
Show resolved Hide resolved
self.config.negate_output = negate_output

self.post_config.session.intra_op_num_threads = intra_op_threads
self.post_config.session.inter_op_num_threads = inter_op_threads

for k, v in io_map.items():
self.config.io_map[k] = v
Loading