diff --git a/i6_models/parts/factored_hybrid/util.py b/i6_models/parts/factored_hybrid/util.py index 8b7f33b3..48d88ae9 100644 --- a/i6_models/parts/factored_hybrid/util.py +++ b/i6_models/parts/factored_hybrid/util.py @@ -14,6 +14,17 @@ class PhonemeStateClassV1(Enum): def factor(self): return self.value + @staticmethod + def from_flags(cls, use_word_end_classes: bool, use_boundary_classes: bool) -> "PhonemeStateClassV1": + assert not (use_word_end_classes and use_boundary_classes), "cannot use both classes" + + if use_boundary_classes: + return cls.boundary + elif use_word_end_classes: + return cls.word_end + else: + return cls.none + def get_center_dim( n_contexts: int,