diff --git a/tuning/autotune.py b/tuning/autotune.py index a6e4ac8..5139117 100755 --- a/tuning/autotune.py +++ b/tuning/autotune.py @@ -144,6 +144,11 @@ def get_exe_format(self, path: Path) -> str: return f"./{path.as_posix()}" +class CompilationMode(Enum): + DEFAULT = "default" + WINOGRAD = "winograd" + + @dataclass class TaskTuple: args: argparse.Namespace @@ -353,7 +358,7 @@ def parse_arguments() -> argparse.Namespace: # Required arguments parser.add_argument( - "mode", choices=["default", "winograd"], help="Compilation mode" + "mode", choices=[m.value for m in CompilationMode], help="Compilation mode" ) parser.add_argument( "input_file", type=Path, help="Path to the input benchmark file (.mlir)"