diff --git a/.github/workflows/doc.yml b/.github/workflows/doc.yml
index 1ecd55c02c..fca8581f31 100644
--- a/.github/workflows/doc.yml
+++ b/.github/workflows/doc.yml
@@ -14,7 +14,7 @@ jobs:
     runs-on: ubuntu-latest
     strategy:
       matrix:
-        python-version: [ 3.7 ]
+        python-version: [ 3.8 ]
 
     services:
       plantuml:
diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 198119f41d..fa5b65194b 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -36,6 +36,8 @@ jobs:
           numpy-version: '1.22.0'
         - python-version: '3.7'
           numpy-version: '1.24.0'
+        - python-version: '3.9'
+          numpy-version: '1.18.0'
         - python-version: '3.10'
           numpy-version: '1.18.0'
         - python-version: '3.11'
diff --git a/requirements-doc.txt b/requirements-doc.txt
index c836abd335..6e7c6200de 100644
--- a/requirements-doc.txt
+++ b/requirements-doc.txt
@@ -1,7 +1,7 @@
 Jinja2~=3.0.0
 sphinx~=3.2.0
 sphinx_rtd_theme~=0.4.3
-enum_tools
+enum_tools~=0.9.0
 sphinx-toolbox
 plantumlcli>=0.0.2
 packaging
diff --git a/test/torch/funcs/test_construct.py b/test/torch/funcs/test_construct.py
index 6ea120b858..ec037ece6d 100644
--- a/test/torch/funcs/test_construct.py
+++ b/test/torch/funcs/test_construct.py
@@ -190,6 +190,50 @@ def test_randn_like(self):
             }
         })
 
+    @choose_mark()
+    def test_rand(self):
+        _target = ttorch.rand(200, 300)
+        assert 0.45 <= _target.view(60000).mean().tolist() <= 0.55
+        assert _target.shape == torch.Size([200, 300])
+
+        _target = ttorch.rand({
+            'a': (2, 3),
+            'b': (5, 6),
+            'x': {
+                'c': (2, 3, 4),
+            }
+        })
+        assert _target.shape == ttorch.Size({
+            'a': torch.Size([2, 3]),
+            'b': torch.Size([5, 6]),
+            'x': {
+                'c': torch.Size([2, 3, 4]),
+            }
+        })
+
+    @choose_mark()
+    def test_rand_like(self):
+        _target = ttorch.rand_like(torch.ones(200, 300))
+        assert 0.45 <= _target.view(60000).mean().tolist() <= 0.55
+        assert _target.shape == torch.Size([200, 300])
+
+        _target = ttorch.rand_like({
+            'a': torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float32),
+            'b': torch.tensor([1, 2, 3, 4], dtype=torch.float),
+            'x': {
+                'c': torch.tensor([5, 6, 7], dtype=torch.float64),
+                'd': torch.tensor([[[8, 9]]], dtype=torch.float32),
+            }
+        })
+        assert _target.shape == ttorch.Size({
+            'a': torch.Size([2, 3]),
+            'b': torch.Size([4]),
+            'x': {
+                'c': torch.Size([3]),
+                'd': torch.Size([1, 1, 2]),
+            }
+        })
+
     @choose_mark()
     def test_randint(self):
         _target = ttorch.randint(-10, 10, {
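The 0.45 <= mean <= 0.55 assertions in test_rand and test_rand_like are loose statistical sanity checks rather than exact comparisons: a uniform variable on [0, 1) has mean 1/2 and variance 1/12, so the mean of 60000 samples has a standard deviation of sqrt(1 / (12 * 60000)) ≈ 0.0012, and the 0.05 window is roughly 42 standard deviations wide, making a spurious failure vanishingly unlikely. A minimal sketch of the same bound outside the test suite (illustrative only, not part of the patch):

    import torch

    samples = torch.rand(200, 300).view(60000)
    # standard deviation of the sample mean: sqrt(var / n), with var = 1/12 for U[0, 1)
    sigma = (1 / 12 / 60000) ** 0.5          # ~0.0012
    assert abs(samples.mean().item() - 0.5) < 0.05   # far looser than ~42 * sigma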
diff --git a/test/torch/funcs/test_wrapper.py b/test/torch/funcs/test_wrapper.py
new file mode 100644
index 0000000000..dc147a4539
--- /dev/null
+++ b/test/torch/funcs/test_wrapper.py
@@ -0,0 +1,106 @@
+from unittest import skipUnless
+
+import pytest
+import torch
+from hbutils.testing import vpip
+
+import treetensor.torch as ttorch
+from treetensor.torch import Size
+
+
+@pytest.fixture()
+def treetensor_x():
+    return ttorch.randn({
+        'a': (2, 5, 7),
+        'b': {
+            'x': (3, 4, 6),
+        }
+    })
+
+
+@pytest.fixture()
+def treetensor_y():
+    return ttorch.randn({
+        'a': (2, 5, 7),
+        'b': {
+            'x': (3, 4, 6),
+        }
+    })
+
+
+@pytest.mark.unittest
+class TestTorchTensorWrapper:
+    @skipUnless(vpip('torch') >= '2', 'Torch 2 required.')
+    def test_vmap(self, treetensor_x, treetensor_y):
+        f = lambda x, y: (x.sum() + y.mean() * 2)
+        native_vf = torch.vmap(f)
+        tv_vf = ttorch.vmap(f)
+        r = tv_vf(treetensor_x, treetensor_y)
+
+        assert r.shape == Size({
+            'a': (2,),
+            'b': {
+                'x': (3,)
+            },
+        })
+        assert ttorch.isclose(
+            r,
+            ttorch.tensor({
+                'a': native_vf(treetensor_x.a, treetensor_y.a),
+                'b': {
+                    'x': native_vf(treetensor_x.b.x, treetensor_y.b.x),
+                }
+            })
+        ).all()
+
+    @skipUnless(vpip('torch') >= '2', 'Torch 2 required.')
+    def test_vmap_in_dims(self, treetensor_x, treetensor_y):
+        f = lambda x, y: (x.sum() + y.mean() * 2)
+        native_vf = torch.vmap(f, in_dims=1)
+        tv_vf = ttorch.vmap(f, in_dims=1)
+        r = tv_vf(treetensor_x, treetensor_y)
+
+        assert r.shape == Size({
+            'a': (5,),
+            'b': {
+                'x': (4,)
+            },
+        })
+        assert ttorch.isclose(
+            r,
+            ttorch.tensor({
+                'a': native_vf(treetensor_x.a, treetensor_y.a),
+                'b': {
+                    'x': native_vf(treetensor_x.b.x, treetensor_y.b.x),
+                }
+            })
+        ).all()
+
+    @skipUnless(vpip('torch') >= '2', 'Torch 2 required.')
+    def test_vmap_nested(self, treetensor_x, treetensor_y):
+        f = lambda x, y: (x.sum() + y.mean() * 2)
+        native_vf = torch.vmap(torch.vmap(f))
+        tv_vf = ttorch.vmap(ttorch.vmap(f))
+        r = tv_vf(treetensor_x, treetensor_y)
+
+        assert r.shape == Size({
+            'a': (2, 5),
+            'b': {
+                'x': (3, 4)
+            },
+        })
+        assert ttorch.isclose(
+            r,
+            ttorch.tensor({
+                'a': native_vf(treetensor_x.a, treetensor_y.a),
+                'b': {
+                    'x': native_vf(treetensor_x.b.x, treetensor_y.b.x),
+                }
+            })
+        ).all()
+
+    @skipUnless(vpip('torch') < '2', 'Torch 1.x required.')
+    def test_vmap_torch_1x(self, treetensor_x, treetensor_y):
+        f = lambda x, y: (x.sum() + y.mean() * 2)
+        with pytest.raises(NotImplementedError):
+            _ = ttorch.vmap(f)
diff --git a/treetensor/torch/funcs/__init__.py b/treetensor/torch/funcs/__init__.py
index 98b029bf89..51a4c89230 100644
--- a/treetensor/torch/funcs/__init__.py
+++ b/treetensor/torch/funcs/__init__.py
@@ -14,6 +14,8 @@
 from .operation import __all__ as _operation_all
 from .reduction import *
 from .reduction import __all__ as _reduction_all
+from .wrapper import *
+from .wrapper import __all__ as _wrapper_all
 from ...utils import module_autoremove
 
 __all__ = [
@@ -24,6 +26,7 @@
     *_matrix_all,
     *_operation_all,
     *_reduction_all,
+    *_wrapper_all,
 ]
 
 _current_module = sys.modules[__name__]
diff --git a/treetensor/torch/funcs/base.py b/treetensor/torch/funcs/base.py
index d5ee52a939..1bf21a72dc 100644
--- a/treetensor/torch/funcs/base.py
+++ b/treetensor/torch/funcs/base.py
@@ -1,4 +1,7 @@
+from functools import wraps
+
 import torch
+from hbutils.testing import vpip
 from treevalue import func_treelize as original_func_treelize
 
 from ..tensor import Tensor
@@ -11,3 +14,17 @@
 auto_tensor = replaceable_partial(auto_tree, cls=[(torch.is_tensor, Tensor)])
 get_func_from_torch = module_func_loader(torch, Tensor,
                                          [(torch.is_tensor, Tensor)])
+
+_is_torch_2 = vpip('torch') >= '2'
+
+
+def wrap_for_treelize(*args, **kwargs):
+    def _decorator(func):
+        @wraps(func)
+        def _new_func(*args_, **kwargs_):
+            retval = func(*args_, **kwargs_)
+            return func_treelize(*args, **kwargs)(retval)
+
+        return _new_func
+
+    return _decorator
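Note that wrap_for_treelize is not the same thing as func_treelize: it leaves the decorated function itself untouched and instead treelizes the function's return value, which is expected to be a callable. That is exactly the shape of higher-order APIs such as torch.vmap, which take a function and hand back a transformed function. Roughly equivalent behaviour, sketched with treevalue's plain func_treelize and a hypothetical helper name, without this module's Tensor-specific defaults (assumes torch 2.x):

    import torch
    from treevalue import func_treelize

    def tree_vmap(f):
        vmapped = torch.vmap(f)          # the higher-order call returns a function
        return func_treelize()(vmapped)  # lift it to map over tree leaves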
diff --git a/treetensor/torch/funcs/construct.py b/treetensor/torch/funcs/construct.py
index ef166e9787..a0062ccfe3 100644
--- a/treetensor/torch/funcs/construct.py
+++ b/treetensor/torch/funcs/construct.py
@@ -10,6 +10,7 @@
     'tensor', 'as_tensor', 'clone',
     'zeros', 'zeros_like',
     'randn', 'randn_like',
+    'rand', 'rand_like',
     'randint', 'randint_like',
     'ones', 'ones_like',
     'full', 'full_like',
@@ -216,6 +217,62 @@ def randn_like(input, *args, **kwargs):
     return stream_call(torch.randn_like, input, *args, **kwargs)
 
 
+@doc_from_base()
+@args_treelize
+@func_treelize()
+def rand(*args, **kwargs):
+    """
+    In ``treetensor``, you can use ``rand`` to create a tree of tensors with numbers
+    sampled from a uniform distribution on ``[0, 1)``.
+
+    Example::
+
+        >>> import torch
+        >>> import treetensor.torch as ttorch
+        >>> ttorch.rand(2, 3)  # the same as torch.rand(2, 3)
+        tensor([[0.8534, 0.5754, 0.2507],
+                [0.0826, 0.4110, 0.9748]])
+
+        >>> ttorch.rand({'a': (2, 3), 'b': {'x': (4, )}})
+        <Tensor 0x...>
+        ├── a --> tensor([[0.5398, 0.7529, 0.0339],
+        │                 [0.5722, 0.1900, 0.7945]])
+        └── b --> <Tensor 0x...>
+            └── x --> tensor([0.7181, 0.1670, 0.3587, 0.5129])
+    """
+    return stream_call(torch.rand, *args, **kwargs)
+
+
+# noinspection PyShadowingBuiltins
+@doc_from_base()
+@args_treelize
+@func_treelize()
+def rand_like(input, *args, **kwargs):
+    """
+    In ``treetensor``, you can use ``rand_like`` to create a tree of tensors with numbers
+    sampled from a uniform distribution on ``[0, 1)``, shaped like another tree.
+
+    Example::
+
+        >>> import torch
+        >>> import treetensor.torch as ttorch
+        >>> ttorch.rand_like(torch.ones(2, 3))  # the same as torch.rand_like(torch.ones(2, 3))
+        tensor([[0.8436, 0.2601, 0.9687],
+                [0.6430, 0.1765, 0.1732]])
+
+        >>> ttorch.rand_like({
+        ...     'a': torch.ones(2, 3),
+        ...     'b': {'x': torch.ones(4, )},
+        ... })
+        <Tensor 0x...>
+        ├── a --> tensor([[0.1532, 0.3965, 0.2956],
+        │                 [0.0750, 0.6475, 0.1421]])
+        └── b --> <Tensor 0x...>
+            └── x --> tensor([0.1730, 0.6085, 0.6487, 0.1022])
+    """
+    return stream_call(torch.rand_like, input, *args, **kwargs)
+
+
 @doc_from_base()
 @args_treelize
 @func_treelize()
diff --git a/treetensor/torch/funcs/wrapper.py b/treetensor/torch/funcs/wrapper.py
new file mode 100644
index 0000000000..2f442c6d99
--- /dev/null
+++ b/treetensor/torch/funcs/wrapper.py
@@ -0,0 +1,21 @@
+import torch
+
+from .base import doc_from_base, wrap_for_treelize, _is_torch_2
+
+__all__ = [
+    'vmap',
+]
+
+if _is_torch_2:
+    @doc_from_base()
+    @wrap_for_treelize()
+    def vmap(func, *args, **kwargs):
+        return torch.vmap(func, *args, **kwargs)
+
+else:
+    def vmap(func, *args, **kwargs):
+        """
+        .. warning::
+            :func:`treetensor.torch.vmap` is not supported for torch 1.x.
+        """
+        raise NotImplementedError(f'Function vmap is not supported in torch {torch.__version__}.')
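With the wrapper in place, torch.vmap-style batching applies directly to tree-structured tensors under torch 2.x: each leaf is vmapped independently and the tree layout is preserved. A usage sketch mirroring the new tests (assumes torch >= 2 is installed):

    import torch
    import treetensor.torch as ttorch

    x = ttorch.randn({'a': (2, 5, 7), 'b': {'x': (3, 4, 6)}})
    y = ttorch.randn({'a': (2, 5, 7), 'b': {'x': (3, 4, 6)}})

    f = lambda a, b: a.sum() + b.mean() * 2
    vf = ttorch.vmap(f)   # wraps torch.vmap(f), then maps it over matching leaves
    r = vf(x, y)          # r.a has shape (2,); r.b.x has shape (3,)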