From 0b75bcf7fdd38b8fed57fb533c6783a3296a42ed Mon Sep 17 00:00:00 2001 From: Luca Wehrstedt Date: Mon, 19 Aug 2024 14:11:26 +0000 Subject: [PATCH] Remove tooling to auto-register operators ghstack-source-id: 457565b18dfd3a7958402320751d1b539c684bb9 Pull Request resolved: https://github.com/fairinternal/xformers/pull/1198 __original_commit__ = fairinternal/xformers@79e7111d0f6d6dbec833578446a08dd5a3b222f6 --- xformers/ops/common.py | 98 +----------------------------------------- 1 file changed, 1 insertion(+), 97 deletions(-) diff --git a/xformers/ops/common.py b/xformers/ops/common.py index 420b0f82b..7dab42caa 100644 --- a/xformers/ops/common.py +++ b/xformers/ops/common.py @@ -3,13 +3,9 @@ # This source code is licensed under the BSD license found in the # LICENSE file in the root directory of this source tree. -import inspect -from dataclasses import dataclass -from functools import wraps -from typing import Any, Callable, Dict, List, Type, TypeVar, Union +from typing import Any, Dict, List, Type, TypeVar import torch -from typing_extensions import Annotated, get_args, get_origin def get_operator(library: str, name: str): @@ -71,95 +67,3 @@ def register_operator(cls: ClsT) -> ClsT: def _get_storage_base(x: torch.Tensor) -> int: return _GET_TENSOR_STORAGE(x).data_ptr() # type: ignore - - -@dataclass(frozen=True) -class Alias: - name: str - write: bool - - -def make_pytorch_cuda_operator(fn: ClsT) -> ClsT: - return turn_into_pytorch_op(fn, "CUDA") - - -def make_pytorch_operator_for_dispatch_key(dispatch_key: str) -> Callable[[ClsT], ClsT]: - def decorator(fn: ClsT) -> ClsT: - return turn_into_pytorch_op(fn, dispatch_key) - - return decorator - - -def turn_into_pytorch_op(fn: ClsT, dispatch_key: str) -> ClsT: - from .. import get_python_lib - - def render_arg_type(annotation) -> str: - # Optional[T] is an alias for Union[T, None] - if get_origin(annotation) is Union: - inner_types = [ - t for t in get_args(annotation) if t is not type(None) # noqa: E721 - ] - if len(inner_types) == 1: - return f"{render_arg_type(inner_types[0])}?" - if get_origin(annotation) is list: - (inner_type,) = get_args(annotation) - return f"{render_arg_type(inner_type)}[]" - if get_origin(annotation) is tuple: - return ( - "(" - + ", ".join([render_arg_type(t) for t in get_args(annotation)]) - + ")" - ) - if get_origin(annotation) is Annotated: - inner_type, annotation = get_args(annotation) - if isinstance(annotation, Alias): - alias = annotation.name + ("!" if annotation.write else "") - return f"{render_arg_type(inner_type)}({alias})" - if annotation is torch.Tensor: - return "Tensor" - if annotation is bool: - return "bool" - if annotation is int: - return "int" - if annotation is float: - return "float" - if annotation is torch.dtype: - return "ScalarType" - if annotation is torch.distributed.ProcessGroup: - return "__torch__.torch.classes.c10d.ProcessGroup" - assert False, f"Unable to parse annotation: `{annotation}`" - - def render_default_value(default): - if default is inspect.Parameter.empty: - return "" - return f" = {default!r}" - - sign = inspect.signature(fn) # type: ignore - arguments = [ - f"{render_arg_type(arg.annotation)} {arg.name}{render_default_value(arg.default)}" - for arg in sign.parameters.values() - ] - op_name = fn.__name__ # type: ignore - definition = f"{op_name}({', '.join(arguments)}) -> {render_arg_type(sign.return_annotation)}" - - def callee(*args, **kwargs): - ba = sign.bind(*args, **kwargs) - for name, value in ba.arguments.items(): - if sign.parameters[name].annotation is torch.distributed.ProcessGroup: - ba.arguments[name] = torch.distributed.ProcessGroup.unbox(value) - return fn(*ba.args, **ba.kwargs) - - xformers_lib = get_python_lib() - xformers_lib.define(definition) - xformers_lib.impl(op_name, callee, dispatch_key) - dispatcher_impl = getattr(getattr(torch.ops, xformers_lib.ns), op_name) - - @wraps(fn) # type: ignore[arg-type] - def caller(*args, **kwargs): - ba = sign.bind(*args, **kwargs) - for name, value in ba.arguments.items(): - if sign.parameters[name].annotation is torch.distributed.ProcessGroup: - ba.arguments[name] = value.boxed() - return dispatcher_impl(*ba.args, **ba.kwargs) - - return caller # type: ignore