[custom ops] Add register_vmap for custom ops (pytorch#130589)
Fixes pytorch#130284
Fixes pytorch#130653

- Add `torch.library.register_vmap` for custom ops (see the usage sketch below).
- Add `register_vmap` rules for the operators in `custom_op_db`.
- Make `torch.autograd.Function` support keyword-only arguments for vmap (a sketch follows the `test_eager_transforms.py` diff below).
- Test the operators in `custom_op_db` with `test/functorch/test_vmap.py`.
- Change `test_vmap` to allow a custom `out_dim` and to allow `None` entries in `out_dim` when testing.
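
A minimal usage sketch of the new API (the `mylib::numpy_mul` op and its batching rule below are illustrative, not part of this PR): the registered rule receives `(info, in_dims, *args, **kwargs)`, where `info` carries batching metadata such as the batch size, and returns the outputs together with their `out_dims`.

import torch
from torch import Tensor

# Illustrative custom op: elementwise multiply routed through NumPy, so vmap
# cannot decompose it and needs an explicit batching rule.
@torch.library.custom_op("mylib::numpy_mul", mutates_args=())
def numpy_mul(x: Tensor, y: Tensor) -> Tensor:
    out = x.numpy(force=True) * y.numpy(force=True)
    return torch.from_numpy(out).to(x.device)

def numpy_mul_vmap(info, in_dims, x, y):
    # Move each batch dim to the front (or add a size-1 dim for unbatched
    # inputs), call the op once, and report the output's batch dim as 0.
    x_bdim, y_bdim = in_dims
    x = x.movedim(x_bdim, 0) if x_bdim is not None else x.unsqueeze(0)
    y = y.movedim(y_bdim, 0) if y_bdim is not None else y.unsqueeze(0)
    return numpy_mul(x, y), 0

torch.library.register_vmap("mylib::numpy_mul", numpy_mul_vmap)

x, y = torch.randn(3, 4), torch.randn(3, 4)
out = torch.vmap(numpy_mul)(x, y)  # now dispatches to numpy_mul_vmap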

Pull Request resolved: pytorch#130589
Approved by: https://github.com/zou3519
yushangdi authored and pytorchmergebot committed Jul 23, 2024
1 parent 1e5ecc4 commit 074b420
Showing 9 changed files with 683 additions and 15 deletions.
1 change: 1 addition & 0 deletions docs/source/library.rst
@@ -42,6 +42,7 @@ via PyTorch's C++ operator registration APIs).
.. autofunction:: register_kernel
.. autofunction:: register_autograd
.. autofunction:: register_fake
.. autofunction:: register_vmap
.. autofunction:: impl_abstract
.. autofunction:: get_ctx
.. autofunction:: register_torch_dispatch
30 changes: 25 additions & 5 deletions test/functorch/common_utils.py
@@ -18,6 +18,8 @@
from torch.testing._internal.common_device_type import toleranceOverride
from torch.testing._internal.common_methods_invocations import DecorateInfo, op_db
from torch.testing._internal.common_modules import module_db
from torch.testing._internal.custom_op_db import custom_op_db



IS_FBCODE = os.getenv("FUNCTORCH_TEST_FBCODE") == "1"
@@ -38,8 +40,26 @@ def loop(op, in_dims, out_dim, batch_size, *batched_args, **kwarg_values):
flat_out, out_spec = pytree.tree_flatten(out)
outs.append(flat_out)

# use the same out_dim for all outputs
if isinstance(out_dim, int):
flat_out_dim = [out_dim for _ in flat_out]
else:
flat_out_dim, _ = pytree.tree_flatten(out_dim)

outs = zip(*outs)
result = [torch.stack(out_lst) for out_lst in outs]

result = []
for i, out_lst in enumerate(outs):
if flat_out_dim[i] is not None:
if not all(isinstance(x, torch.Tensor) for x in out_lst):
raise ValueError(
f"vmap `{op}` must only return "
"Tensors. Did you mean to set out_dims= to None for output?"
)
result.append(torch.stack(out_lst))
else:
# not batched over, result should be the same for all batches
result.append(out_lst[0])
return pytree.tree_unflatten(result, out_spec)


@@ -317,9 +337,9 @@ def f(dummy, *args, **kwargs):
inner_in_dims = (0,) + pytree.tree_map(lambda x: None, in_dims)
outer_in_dims = (0,) + in_dims
batched_args, kwarg_values = maybe_clone_inputs()
vmapvmap_output = vmap(vmap(f, inner_in_dims), outer_in_dims)(
dummy, *batched_args, **kwarg_values
)
vmapvmap_output = vmap(
vmap(f, inner_in_dims, out_dims=out_dim), outer_in_dims, out_dims=out_dim
)(dummy, *batched_args, **kwarg_values)

yield (batched_out, loop_out, vmapvmap_output, vmapvmap_expected)

@@ -440,7 +460,7 @@ def skip(op_name, variant_name="", *, device_type=None, dtypes=None):


def skipOps(test_case_name, base_test_name, to_skip):
all_opinfos = op_db + additional_op_db + autograd_function_db
all_opinfos = op_db + additional_op_db + autograd_function_db + custom_op_db
for decorate_meta in to_skip:
matching_opinfos = [
o
22 changes: 22 additions & 0 deletions test/functorch/test_eager_transforms.py
@@ -1617,6 +1617,28 @@ def vmap(info, in_dims, input):
with self.assertRaisesRegex(RuntimeError, "returned an incompatible"):
result = vmap(Zeros.apply)(x)

def test_kwarg_only_tensors(self, device):
with self.assertRaisesRegex(NotImplementedError, "kwarg-only Tensor args"):

class MyClass(torch.autograd.Function):
@staticmethod
def forward(x, *, y):
return x + y

@staticmethod
def setup_context(ctx, inputs, output):
pass

@staticmethod
def vmap(info, in_dims, x, *, y):
assert in_dims == (0,)
return x + y, 0

x = torch.randn(3)
y = torch.randn(3)

vmap(MyClass.apply)(x, y=y)


@markDynamoStrictTest
class TestVmapOfGrad(TestCase):
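The kwarg-only change is exercised above only for the unsupported case (keyword-only Tensor arguments still raise NotImplementedError). A hedged sketch of what the support is intended to allow, using an illustrative Function (not from this PR) with a non-Tensor keyword-only argument:

import torch
from torch import Tensor

class Scale(torch.autograd.Function):
    @staticmethod
    def forward(x: Tensor, *, factor: float) -> Tensor:
        return x * factor

    @staticmethod
    def setup_context(ctx, inputs, output):
        pass

    @staticmethod
    def vmap(info, in_dims, x, *, factor):
        # in_dims covers only the positional inputs; the keyword-only,
        # non-Tensor argument arrives as a keyword, as in the test above.
        (x_bdim,) = in_dims
        return x * factor, x_bdim

x = torch.randn(4, 3)
out = torch.vmap(Scale.apply)(x, factor=2.0)  # batched over dim 0 of x
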
47 changes: 41 additions & 6 deletions test/functorch/test_vmap.py
@@ -68,6 +68,9 @@
unMarkDynamoStrictTest,
xfailIfTorchDynamo,
)

from torch.testing._internal.custom_op_db import custom_op_db

from torch.utils import _pytree as pytree


@@ -3937,10 +3940,17 @@ def discover_variants(opinfo):
@unMarkDynamoStrictTest
class TestVmapOperatorsOpInfo(TestCase):
def vmap_outplace_test(
self, func, args, kwargs, in_dims, check_shape_only=False, postprocess_fn=None
self,
func,
args,
kwargs,
in_dims,
check_shape_only=False,
postprocess_fn=None,
out_dim=0,
):
for vmap_out, loop_out in compute_quantities_for_vmap_test(
func, args, kwargs, in_dims
func, args, kwargs, in_dims, out_dim=out_dim
):
if postprocess_fn is not None:
loop_out = postprocess_fn(loop_out)
@@ -3950,7 +3960,9 @@ def vmap_outplace_test(
continue
self.assertEqual(vmap_out, loop_out)

def vmap_inplace_test(self, func, args, kwargs, in_dims, postprocess_fn=None):
def vmap_inplace_test(
self, func, args, kwargs, in_dims, postprocess_fn=None, out_dim=0
):
# NB: This test assumes that the first argument is being modified.
# This is OK because it's what every other OpInfo-based test assumes,
# but it is going to need a more robust solution eventually.
@@ -3963,13 +3975,19 @@ def vmap_inplace_test(self, func, args, kwargs, in_dims, postprocess_fn=None):
args,
kwargs,
in_dims,
out_dim=out_dim,
compute_loop_out=False,
clone_inputs=True,
):
pass
return
for vmap_out, loop_out in compute_quantities_for_vmap_test(
func, args, kwargs, in_dims, clone_inputs=True
func,
args,
kwargs,
in_dims,
clone_inputs=True,
out_dim=out_dim,
):
if postprocess_fn is not None:
loop_out = postprocess_fn(loop_out)
@@ -4027,6 +4045,13 @@ def test():
continue
kwargs = sample_input.kwargs
is_batch_norm_and_training = is_batch_norm_training(op.name, kwargs)
out_dim = 0
if op.name == "NumpySplitCopyWithIntCustomOp":
# special case for this custom op
def sample_vmap_out_dim_numpy_split_copy_with_int(x, splits, dim):
return [0 for _ in range(len(splits) + 1)], None

out_dim = sample_vmap_out_dim_numpy_split_copy_with_int(*args)
for batched_args, in_dims, _ in generate_vmap_inputs(
args, {}, is_batch_norm_and_training=is_batch_norm_and_training
):
@@ -4038,6 +4063,7 @@ def test():
in_dims,
check_shape_only,
postprocess_fn,
out_dim=out_dim,
)
if op.name in skip_inplace:
continue
@@ -4109,6 +4135,9 @@ def test():
"linalg.eigh", ""
), # not always return the same result for the same input, see test_linalg_eigh for manual test
skip("to"), # RuntimeError: required rank 4 tensor to use channels_last format
# UnimplementedError: data-dependent operators cannot be vmapped
xfail("NumpyNonzeroCustomOp"),
xfail("NumpyNMSCustomOp"),
# ----------------------------------------------------------------------
# ---------------------------- BUGS ------------------------------------
# entries in here don't work and need to be fixed.
@@ -4187,7 +4216,10 @@ def test():
}

@with_tf32_off # https://github.com/pytorch/pytorch/issues/86798
@ops(op_db + additional_op_db + autograd_function_db, dtypes=OpDTypes.any_one)
@ops(
op_db + additional_op_db + autograd_function_db + custom_op_db,
dtypes=OpDTypes.any_one,
)
@opsToleranceOverride(
"TestVmapOperatorsOpInfo",
"test_vmap_exhaustive",
@@ -4248,7 +4280,10 @@ def test_vmap_exhaustive(self, device, dtype, op):
)

@with_tf32_off
@ops(op_db + additional_op_db + autograd_function_db, dtypes=OpDTypes.any_one)
@ops(
op_db + additional_op_db + autograd_function_db + custom_op_db,
dtypes=OpDTypes.any_one,
)
@opsToleranceOverride(
"TestVmapOperatorsOpInfo",
"test_op_has_batch_rule",
(Diffs for the remaining 5 changed files are not shown.)
