diff --git a/build_tools/autogen_ltc_backend.py b/build_tools/autogen_ltc_backend.py index 13753a6d5949..2b4c07cf76b3 100644 --- a/build_tools/autogen_ltc_backend.py +++ b/build_tools/autogen_ltc_backend.py @@ -30,6 +30,11 @@ TORCHGEN_DIR = Path(torchgen.__path__[0]).resolve() TORCH_MLIR_DIR = Path(__file__).resolve().parent.parent +# Safely load fast C Yaml loader if it is are available +try: + from yaml import CSafeLoader as Loader +except ImportError: + from yaml import SafeLoader as Loader #type:ignore[assignment, misc] def reindent(text, prefix=""): return indent(dedent(text), prefix) @@ -175,7 +180,7 @@ def generate_native_functions(self): ) ts_native_yaml = None if ts_native_yaml_path.exists(): - ts_native_yaml = yaml.load(ts_native_yaml_path.read_text(), yaml.CLoader) + ts_native_yaml = yaml.load(ts_native_yaml_path.read_text(), Loader) else: logging.warning( f"Could not find `ts_native_functions.yaml` at {ts_native_yaml_path}" @@ -208,7 +213,7 @@ def get_opnames(ops): ) with self.config_path.open() as f: - config = yaml.load(f, yaml.CLoader) + config = yaml.load(f, Loader) # List of unsupported ops in LTC autogen because of some error blacklist = set(config.get("blacklist", []))