Remove tooling to auto-register operators
ghstack-source-id: 457565b18dfd3a7958402320751d1b539c684bb9
Pull Request resolved: fairinternal/xformers#1198

__original_commit__ = fairinternal/xformers@79e7111
lw authored and xFormers Bot committed Aug 19, 2024
1 parent cb4945e commit 0b75bcf
Showing 1 changed file with 1 addition and 97 deletions.
--- a/xformers/ops/common.py
+++ b/xformers/ops/common.py
@@ -3,13 +3,9 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.
 
-import inspect
-from dataclasses import dataclass
-from functools import wraps
-from typing import Any, Callable, Dict, List, Type, TypeVar, Union
+from typing import Any, Dict, List, Type, TypeVar
 
 import torch
-from typing_extensions import Annotated, get_args, get_origin
 
 
 def get_operator(library: str, name: str):
@@ -71,95 +67,3 @@ def register_operator(cls: ClsT) -> ClsT:
 
 def _get_storage_base(x: torch.Tensor) -> int:
     return _GET_TENSOR_STORAGE(x).data_ptr()  # type: ignore
-
-
-@dataclass(frozen=True)
-class Alias:
-    name: str
-    write: bool
-
-
-def make_pytorch_cuda_operator(fn: ClsT) -> ClsT:
-    return turn_into_pytorch_op(fn, "CUDA")
-
-
-def make_pytorch_operator_for_dispatch_key(dispatch_key: str) -> Callable[[ClsT], ClsT]:
-    def decorator(fn: ClsT) -> ClsT:
-        return turn_into_pytorch_op(fn, dispatch_key)
-
-    return decorator
-
-
-def turn_into_pytorch_op(fn: ClsT, dispatch_key: str) -> ClsT:
-    from .. import get_python_lib
-
-    def render_arg_type(annotation) -> str:
-        # Optional[T] is an alias for Union[T, None]
-        if get_origin(annotation) is Union:
-            inner_types = [
-                t for t in get_args(annotation) if t is not type(None)  # noqa: E721
-            ]
-            if len(inner_types) == 1:
-                return f"{render_arg_type(inner_types[0])}?"
-        if get_origin(annotation) is list:
-            (inner_type,) = get_args(annotation)
-            return f"{render_arg_type(inner_type)}[]"
-        if get_origin(annotation) is tuple:
-            return (
-                "("
-                + ", ".join([render_arg_type(t) for t in get_args(annotation)])
-                + ")"
-            )
-        if get_origin(annotation) is Annotated:
-            inner_type, annotation = get_args(annotation)
-            if isinstance(annotation, Alias):
-                alias = annotation.name + ("!" if annotation.write else "")
-                return f"{render_arg_type(inner_type)}({alias})"
-        if annotation is torch.Tensor:
-            return "Tensor"
-        if annotation is bool:
-            return "bool"
-        if annotation is int:
-            return "int"
-        if annotation is float:
-            return "float"
-        if annotation is torch.dtype:
-            return "ScalarType"
-        if annotation is torch.distributed.ProcessGroup:
-            return "__torch__.torch.classes.c10d.ProcessGroup"
-        assert False, f"Unable to parse annotation: `{annotation}`"
-
-    def render_default_value(default):
-        if default is inspect.Parameter.empty:
-            return ""
-        return f" = {default!r}"
-
-    sign = inspect.signature(fn)  # type: ignore
-    arguments = [
-        f"{render_arg_type(arg.annotation)} {arg.name}{render_default_value(arg.default)}"
-        for arg in sign.parameters.values()
-    ]
-    op_name = fn.__name__  # type: ignore
-    definition = f"{op_name}({', '.join(arguments)}) -> {render_arg_type(sign.return_annotation)}"
-
-    def callee(*args, **kwargs):
-        ba = sign.bind(*args, **kwargs)
-        for name, value in ba.arguments.items():
-            if sign.parameters[name].annotation is torch.distributed.ProcessGroup:
-                ba.arguments[name] = torch.distributed.ProcessGroup.unbox(value)
-        return fn(*ba.args, **ba.kwargs)
-
-    xformers_lib = get_python_lib()
-    xformers_lib.define(definition)
-    xformers_lib.impl(op_name, callee, dispatch_key)
-    dispatcher_impl = getattr(getattr(torch.ops, xformers_lib.ns), op_name)
-
-    @wraps(fn)  # type: ignore[arg-type]
-    def caller(*args, **kwargs):
-        ba = sign.bind(*args, **kwargs)
-        for name, value in ba.arguments.items():
-            if sign.parameters[name].annotation is torch.distributed.ProcessGroup:
-                ba.arguments[name] = value.boxed()
-        return dispatcher_impl(*ba.args, **ba.kwargs)
-
-    return caller  # type: ignore

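With the helper gone, an operator of this shape would be registered directly against the standard torch.library API, which the helper itself wrapped (via get_python_lib). A hedged sketch under that assumption; the "myxf" namespace and scaled_op are hypothetical, and the commit does not show what, if anything, replaced the tooling:

import torch

lib = torch.library.Library("myxf", "DEF")  # hypothetical namespace
lib.define("scaled_op(Tensor x, float scale = 1.0) -> Tensor")

def scaled_op_impl(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
    return x * scale  # trivial stand-in implementation

lib.impl("scaled_op", scaled_op_impl, "CUDA")
# After registration, the op is callable through the dispatcher:
#   torch.ops.myxf.scaled_op(torch.ones(2, device="cuda"), 2.0)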