Skip to content

Commit

Permalink
Add make_pytorch_cuda_operator to make creating PyTorch operators seamless
Browse files Browse the repository at this point in the history

__original_commit__ = fairinternal/xformers@dab6e3e2137448247757c3e806c3861c666c15f2
  • Loading branch information
danthe3rd authored and xFormers Bot committed Jul 10, 2023
1 parent f7aeb35 commit b31f4a1
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 0 deletions.
5 changes: 5 additions & 0 deletions xformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,11 @@ def _is_triton_available():
return False


@compute_once
def get_python_lib():
    """Return the process-wide ``torch.library.Library`` for xformers
    Python operators.

    ``compute_once`` ensures the ``"xformers_python"`` namespace is only
    created a single time per process.
    """
    lib = torch.library.Library("xformers_python", "DEF")
    return lib


if _is_functorch_available:
try:
from xformers.components.nvfuser import NVFusedBiasActivationDropout # noqa
Expand Down
36 changes: 36 additions & 0 deletions xformers/ops/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import inspect
from typing import Any, Dict, List, Type, TypeVar

import torch
Expand Down Expand Up @@ -63,3 +64,38 @@ def register_operator(cls: ClsT) -> ClsT:

def _get_storage_base(x: torch.Tensor) -> int:
    """Return the data pointer of ``x``'s underlying storage.

    Two tensors sharing the same storage yield the same value, so callers
    can use this to detect aliasing.
    """
    storage = _GET_TENSOR_STORAGE(x)  # type: ignore
    return storage.data_ptr()


def make_pytorch_cuda_operator(fn: ClsT) -> ClsT:
    """Register ``fn`` as a CUDA operator on the shared xformers dispatcher
    library and return a wrapper that invokes it through the PyTorch
    dispatcher.

    The operator schema is derived from ``fn``'s type annotations, so every
    parameter and the return value must be annotated with one of the
    supported types: ``torch.Tensor``, ``bool``, ``int``, ``List[int]`` or
    ``List[torch.Tensor]``.

    :raises NotImplementedError: if an annotation cannot be mapped to a
        PyTorch schema type.
    """
    import functools

    from .. import get_python_lib

    def render_arg_type(annotation) -> str:
        # Map a Python annotation to the corresponding PyTorch schema type.
        if annotation is torch.Tensor:
            return "Tensor"
        if annotation is bool:
            return "bool"
        if annotation is int:
            return "int"
        if annotation is List[int]:
            return "int[]"
        if annotation is List[torch.Tensor]:
            return "Tensor[]"
        # Raise (not `assert`) so the check still fires under `python -O`.
        raise NotImplementedError(f"Unable to parse annotation: `{annotation}`")

    # Build the schema string, e.g. "my_op(Tensor x, int n) -> Tensor".
    sign = inspect.signature(fn)  # type: ignore
    arguments = [
        f"{render_arg_type(arg.annotation)} {arg.name}"
        for arg in sign.parameters.values()
    ]
    op_name = fn.__name__  # type: ignore
    definition = f"{op_name}({', '.join(arguments)}) -> {render_arg_type(sign.return_annotation)}"

    # Register schema + CUDA implementation, then fetch the dispatcher entry.
    xformers_lib = get_python_lib()
    xformers_lib.define(definition)
    xformers_lib.impl(op_name, fn, "CUDA")
    dispatcher_impl = getattr(getattr(torch.ops, xformers_lib.ns), op_name)

    @functools.wraps(fn)  # preserve `fn`'s name/docstring on the wrapper
    def wrapper(*args, **kwargs):
        return dispatcher_impl(*args, **kwargs)

    return wrapper  # type: ignore

0 comments on commit b31f4a1

Please sign in to comment.