From bb737504e3cd8692a8b6ccd03503d7c8b0894bf6 Mon Sep 17 00:00:00 2001 From: inaomIIsFarell <1344594208@qq.com> Date: Thu, 26 Sep 2024 16:42:48 +0800 Subject: [PATCH 1/3] =?UTF-8?q?=E3=80=90Hackathon=207th=20No.38=E3=80=91?= =?UTF-8?q?=E4=B8=BA=20Paddle=20=E4=BB=A3=E7=A0=81=E8=BD=AC=E6=8D=A2?= =?UTF-8?q?=E5=B7=A5=E5=85=B7=E6=96=B0=E5=A2=9E=20API=20=E8=BD=AC=E6=8D=A2?= =?UTF-8?q?=E8=A7=84=E5=88=99=EF=BC=88=E7=AC=AC5=E7=BB=84=EF=BC=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- paconvert/api_alias_mapping.json | 1 + paconvert/api_mapping.json | 165 +++++++++++++++++- paconvert/api_matcher.py | 258 +++++++++++++++++++++++++++- tests/test_Tensor_isneginf.py | 51 ++++++ tests/test_Tensor_isposinf.py | 51 ++++++ tests/test_Tensor_isreal.py | 52 ++++++ tests/test_Tensor_positive.py | 41 +++++ tests/test_Tensor_scatter_reduce.py | 71 ++++++-- tests/test_block_diag.py | 73 ++++++++ tests/test_can_cast.py | 29 +++- tests/test_cartesian_prod.py | 54 ++++++ tests/test_float_power.py | 103 +++++++++++ tests/test_isin.py | 47 ++--- tests/test_isneginf.py | 69 ++++++++ tests/test_isposinf.py | 69 ++++++++ tests/test_isreal.py | 50 ++++++ tests/test_positive.py | 16 +- tests/test_scatter_reduce.py | 71 ++++++-- 18 files changed, 1173 insertions(+), 98 deletions(-) create mode 100644 tests/test_Tensor_isneginf.py create mode 100644 tests/test_Tensor_isposinf.py create mode 100644 tests/test_Tensor_isreal.py create mode 100644 tests/test_Tensor_positive.py create mode 100644 tests/test_block_diag.py create mode 100644 tests/test_cartesian_prod.py create mode 100644 tests/test_float_power.py create mode 100644 tests/test_isneginf.py create mode 100644 tests/test_isposinf.py create mode 100644 tests/test_isreal.py diff --git a/paconvert/api_alias_mapping.json b/paconvert/api_alias_mapping.json index 563f49f17..4ab131099 100644 --- a/paconvert/api_alias_mapping.json +++ b/paconvert/api_alias_mapping.json @@ -48,6 +48,7 @@ "torch.bilinear": "torch.nn.functional.bilinear", "torch.celu_": "torch.nn.functional.celu_", "torch.channel_shuffle": "torch.nn.functional.channel_shuffle", + "torch.concatenate": "torch.cat", "torch.clip": "torch.clamp", "torch.conv1d": "torch.nn.functional.conv1d", "torch.conv2d": "torch.nn.functional.conv2d", diff --git a/paconvert/api_mapping.json b/paconvert/api_mapping.json index 97cc4aea7..36fad504e 100644 --- a/paconvert/api_mapping.json +++ b/paconvert/api_mapping.json @@ -1999,9 +1999,21 @@ "paddle_api": "paddle.Tensor.isnan", "min_input_args": 0 }, - "torch.Tensor.isneginf": {}, - "torch.Tensor.isposinf": {}, - "torch.Tensor.isreal": {}, + "torch.Tensor.isneginf": { + "Matcher": "GenericMatcher", + "paddle_api": "paddle.Tensor.isneginf", + "min_input_args": 0 + }, + "torch.Tensor.isposinf": { + "Matcher": "GenericMatcher", + "paddle_api": "paddle.Tensor.isposinf", + "min_input_args": 0 + }, + "torch.Tensor.isreal": { + "Matcher": "GenericMatcher", + "paddle_api": "paddle.Tensor.isreal", + "min_input_args": 0 + }, "torch.Tensor.istft": { "Matcher": "GenericMatcher", "paddle_api": "paddle.Tensor.istft", @@ -2923,7 +2935,9 @@ "n" ] }, - "torch.Tensor.positive": {}, + "torch.Tensor.positive": { + "Matcher": "PositiveMatcher" + }, "torch.Tensor.pow": { "Matcher": "GenericMatcher", "paddle_api": "paddle.Tensor.pow", @@ -3285,7 +3299,27 @@ "reduce": "'add'" } }, - "torch.Tensor.scatter_reduce": {}, + "torch.Tensor.scatter_reduce": { + "Matcher": "ScatterReduceMatcher", + "paddle_api": "paddle.Tensor.put_along_axis", + "min_input_args": 3, + "args_list": [ + "dim", + "index", + "src", + "reduce", + "*", + "include_self" + ], + "kwargs_change": { + "dim": "axis", + "index": "indices", + "src": "values" + }, + "paddle_default_kwargs": { + "broadcast": "False" + } + }, "torch.Tensor.scatter_reduce_": {}, "torch.Tensor.select": { "Matcher": "SelectMatcher", @@ -4885,6 +4919,17 @@ "other": "y" } }, + "torch.block_diag": { + "Matcher": "ScalableVarMatcher", + "paddle_api": "paddle.block_diag", + "min_input_args": 1, + "args_list": [ + "*tensors" + ], + "kwargs_change": { + "tensors": "inputs" + } + }, "torch.bmm": { "Matcher": "GenericMatcher", "paddle_api": "paddle.bmm", @@ -4949,6 +4994,24 @@ "boundaries": "sorted_sequence" } }, + "torch.can_cast": { + "Matcher": "CanCastMatcher", + "args_list": [ + "from_", + "to" + ] + }, + "torch.cartesian_prod": { + "Matcher": "CartesianProdMatcher", + "paddle_api": "paddle.cartesian_prod", + "min_input_args": 1, + "args_list": [ + "*tensors" + ], + "kwargs_change": { + "tensors": "x" + } + }, "torch.cat": { "Matcher": "GenericMatcher", "paddle_api": "paddle.concat", @@ -7270,6 +7333,16 @@ "axis": 0 } }, + "torch.float_power": { + "Matcher": "FloatPowerMatcher", + "min_input_args": 2, + "args_list": [ + "input", + "exponent", + "*", + "out" + ] + }, "torch.floor": { "Matcher": "GenericMatcher", "paddle_api": "paddle.floor", @@ -7915,6 +7988,22 @@ "input": "x" } }, + "torch.isin": { + "Matcher": "GenericMatcher", + "paddle_api": "paddle.isin", + "min_input_args": 2, + "args_list": [ + "elements", + "test_elements", + "*", + "assume_unique", + "invert" + ], + "kwargs_change": { + "elements": "x", + "test_elements": "test_x" + } + }, "torch.isinf": { "Matcher": "GenericMatcher", "paddle_api": "paddle.isinf", @@ -7937,6 +8026,43 @@ "input": "x" } }, + "torch.isneginf": { + "Matcher": "GenericMatcher", + "paddle_api": "paddle.isneginf", + "min_input_args": 1, + "args_list": [ + "input", + "*", + "out" + ], + "kwargs_change": { + "input": "x" + } + }, + "torch.isposinf": { + "Matcher": "GenericMatcher", + "paddle_api": "paddle.isposinf", + "min_input_args": 1, + "args_list": [ + "input", + "*", + "out" + ], + "kwargs_change": { + "input": "x" + } + }, + "torch.isreal": { + "Matcher": "GenericMatcher", + "paddle_api": "paddle.isreal", + "min_input_args": 1, + "args_list": [ + "input" + ], + "kwargs_change": { + "input": "x" + } + }, "torch.istft": { "Matcher": "GenericMatcher", "paddle_api": "paddle.signal.istft", @@ -14002,6 +14128,12 @@ "out" ] }, + "torch.positive": { + "Matcher": "PositiveMatcher", + "args_list": [ + "input" + ] + }, "torch.pow": { "Matcher": "GenericMatcher", "paddle_api": "paddle.pow", @@ -14521,6 +14653,29 @@ "reduce": "'add'" } }, + "torch.scatter_reduce": { + "Matcher": "ScatterReduceMatcher", + "paddle_api": "paddle.put_along_axis", + "min_input_args": 4, + "args_list": [ + "input", + "dim", + "index", + "src", + "reduce", + "*", + "include_self" + ], + "kwargs_change": { + "input": "arr", + "dim": "axis", + "index": "indices", + "src": "values" + }, + "paddle_default_kwargs": { + "broadcast": "False" + } + }, "torch.searchsorted": { "Matcher": "SearchsortedMatcher", "paddle_api": "paddle.searchsorted", diff --git a/paconvert/api_matcher.py b/paconvert/api_matcher.py index fa0dfb2be..b2276d358 100644 --- a/paconvert/api_matcher.py +++ b/paconvert/api_matcher.py @@ -1450,6 +1450,14 @@ def generate_code(self, kwargs): return GenericMatcher.generate_code(self, kwargs) +class ScatterReduceMatcher(BaseMatcher): + def generate_code(self, kwargs): + reduce_mapping = {'"""sum"""': '"add"', '"""prod"""': '"multiply"'} + if "reduce" in kwargs and kwargs["reduce"] in reduce_mapping: + kwargs["reduce"] = reduce_mapping[kwargs["reduce"]] + return GenericMatcher.generate_code(self, kwargs) + + class SparseSoftmaxMatcher(BaseMatcher): def generate_code(self, kwargs): code = "" @@ -3138,6 +3146,16 @@ def generate_code(self, kwargs): return code +class CartesianProdMatcher(BaseMatcher): + def get_paddle_nodes(self, args, kwargs): + new_args = self.parse_args(args) + code = "paddle.cartesian_prod([ {}".format(new_args[0]) + for arg in new_args[1:]: + code = code + ", {}".format(arg) + code = code + "])" + return ast.parse(code).body + + class Chain_MatmulMatcher(BaseMatcher): def get_paddle_nodes(self, args, kwargs): if len(args) == 1 and isinstance(args[0], ast.Starred): @@ -4267,11 +4285,245 @@ def generate_code(self, kwargs): return "paddle_aux._CONVERT_SYMEIG({})".format(self.kwargs_to_str(kwargs)) -class FloatPowerMatcher(BaseMatcher): +class CanCastMatcher(BaseMatcher): + def generate_aux_code(self): + CODE_TEMPLATE = textwrap.dedent( + """ + def can_cast(from_, to): + can_cast_dict = { + paddle.bfloat16: { + paddle.bfloat16: True, + paddle.float16: True, + paddle.float32: True, + paddle.float64: True, + paddle.complex64: True, + paddle.complex128: True, + paddle.uint8: False, + paddle.int8: False, + paddle.int16: False, + paddle.int32: False, + paddle.int64: False, + paddle.bool: False + }, + paddle.float16: { + paddle.bfloat16: True, + paddle.float16: True, + paddle.float32: True, + paddle.float64: True, + paddle.complex64: True, + paddle.complex128: True, + paddle.uint8: False, + paddle.int8: False, + paddle.int16: False, + paddle.int32: False, + paddle.int64: False, + paddle.bool: False, + }, + paddle.float32: { + paddle.bfloat16: True, + paddle.float16: True, + paddle.float32: True, + paddle.float64: True, + paddle.complex64: True, + paddle.complex128: True, + paddle.uint8: False, + paddle.int8: False, + paddle.int16: False, + paddle.int32: False, + paddle.int64: False, + paddle.bool: False, + }, + paddle.float64: { + paddle.bfloat16: True, + paddle.float16: True, + paddle.float32: True, + paddle.float64: True, + paddle.complex64: True, + paddle.complex128: True, + paddle.uint8: False, + paddle.int8: False, + paddle.int16: False, + paddle.int32: False, + paddle.int64: False, + paddle.bool: False, + }, + paddle.complex64: { + paddle.bfloat16: False, + paddle.float16: False, + paddle.float32: False, + paddle.float64: False, + paddle.complex64: True, + paddle.complex128: True, + paddle.uint8: False, + paddle.int8: False, + paddle.int16: False, + paddle.int32: False, + paddle.int64: False, + paddle.bool: False, + }, + paddle.complex128: { + paddle.bfloat16: False, + paddle.float16: False, + paddle.float32: False, + paddle.float64: False, + paddle.complex64: True, + paddle.complex128: True, + paddle.uint8: False, + paddle.int8: False, + paddle.int16: False, + paddle.int32: False, + paddle.int64: False, + paddle.bool: False, + }, + paddle.uint8: { + paddle.bfloat16: True, + paddle.float16: True, + paddle.float32: True, + paddle.float64: True, + paddle.complex64: True, + paddle.complex128: True, + paddle.uint8: True, + paddle.int8: True, + paddle.int16: True, + paddle.int32: True, + paddle.int64: True, + paddle.bool: False, + }, + paddle.int8: { + paddle.bfloat16: True, + paddle.float16: True, + paddle.float32: True, + paddle.float64: True, + paddle.complex64: True, + paddle.complex128: True, + paddle.uint8: True, + paddle.int8: True, + paddle.int16: True, + paddle.int32: True, + paddle.int64: True, + paddle.bool: False, + }, + paddle.int16: { + paddle.bfloat16: True, + paddle.float16: True, + paddle.float32: True, + paddle.float64: True, + paddle.complex64: True, + paddle.complex128: True, + paddle.uint8: True, + paddle.int8: True, + paddle.int16: True, + paddle.int32: True, + paddle.int64: True, + paddle.bool: False, + }, + paddle.int32: { + paddle.bfloat16: True, + paddle.float16: True, + paddle.float32: True, + paddle.float64: True, + paddle.complex64: True, + paddle.complex128: True, + paddle.uint8: True, + paddle.int8: True, + paddle.int16: True, + paddle.int32: True, + paddle.int64: True, + paddle.bool: False, + }, + paddle.int64: { + paddle.bfloat16: True, + paddle.float16: True, + paddle.float32: True, + paddle.float64: True, + paddle.complex64: True, + paddle.complex128: True, + paddle.uint8: True, + paddle.int8: True, + paddle.int16: True, + paddle.int32: True, + paddle.int64: True, + paddle.bool: False, + }, + paddle.bool: { + paddle.bfloat16: True, + paddle.float16: True, + paddle.float32: True, + paddle.float64: True, + paddle.complex64: True, + paddle.complex128: True, + paddle.uint8: True, + paddle.int8: True, + paddle.int16: True, + paddle.int32: True, + paddle.int64: True, + paddle.bool: True, + } + } + return can_cast_dict[from_][to] + setattr(paddle, 'can_cast', can_cast) + """ + ) + return CODE_TEMPLATE + def generate_code(self, kwargs): - return "{}.cast(paddle.float64).pow({})".format( - self.paddleClass, kwargs["exponent"] + self.write_aux_code() + _from_dtype = kwargs["from_"][3:-3] + _to_dtype = kwargs["to"][3:-3] + code = "paddle_aux.can_cast(paddle.{}, paddle.{})".format( + _from_dtype, _to_dtype + ) + return code + + +class PositiveMatcher(BaseMatcher): + def generate_code(self, kwargs): + API_TEMPLATE = textwrap.dedent( + """ + def positive({}): + if {}.dtype != paddle.bool: + return {} + else: + raise RuntimeError("boolean tensors is not supported.") + positive({}) + """ ) + if "input" not in kwargs: + code = API_TEMPLATE.format( + self.paddleClass, self.paddleClass, self.paddleClass, self.paddleClass + ) + else: + code = API_TEMPLATE.format( + kwargs["input"], kwargs["input"], kwargs["input"], kwargs["input"] + ) + return code + + +class FloatPowerMatcher(BaseMatcher): + def generate_code(self, kwargs): + if "input" not in kwargs: + return "{}.cast(paddle.float64).pow({}.cast(paddle.float64) if isinstance({}, paddle.Tensor) else {})".format( + self.paddleClass, + kwargs["exponent"], + kwargs["exponent"], + kwargs["exponent"], + ) + else: + if "out" not in kwargs: + return "paddle.pow({}.cast(paddle.float64), {}.cast(paddle.float64) if isinstance({}, paddle.Tensor) else {})".format( + kwargs["input"], + kwargs["exponent"], + kwargs["exponent"], + kwargs["exponent"], + ) + else: + return "paddle.assign(paddle.pow({}.cast(paddle.float64), {}.cast(paddle.float64) if isinstance({}, paddle.Tensor) else {}), {})".format( + kwargs["input"], + kwargs["exponent"], + kwargs["exponent"], + kwargs["exponent"], + kwargs["out"], + ) class FloatPowerInplaceMatcher(BaseMatcher): diff --git a/tests/test_Tensor_isneginf.py b/tests/test_Tensor_isneginf.py new file mode 100644 index 000000000..006157b4c --- /dev/null +++ b/tests/test_Tensor_isneginf.py @@ -0,0 +1,51 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.Tensor.isneginf") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.tensor([1, float('inf'), 2, float('-inf'), float('nan')]).isneginf() + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([1, float('inf'), 2, float('-inf'), float('nan')]) + result = input.isneginf() + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([1, 6.9, 2]) + result = input.isneginf() + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_Tensor_isposinf.py b/tests/test_Tensor_isposinf.py new file mode 100644 index 000000000..3177792cf --- /dev/null +++ b/tests/test_Tensor_isposinf.py @@ -0,0 +1,51 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.Tensor.isposinf") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.tensor([1, float('inf'), 2, float('-inf'), float('nan')]).isposinf() + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([1, float('inf'), 2, float('-inf'), float('nan')]) + result = input.isposinf() + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([1, 6.9, 2]) + result = input.isposinf() + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_Tensor_isreal.py b/tests/test_Tensor_isreal.py new file mode 100644 index 000000000..1070e36b3 --- /dev/null +++ b/tests/test_Tensor_isreal.py @@ -0,0 +1,52 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.Tensor.isreal") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([1, 1+1j, 2+0j]) + result = x.isreal() + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([-0., -2.1, 2.5]) + result = x.isreal() + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([(-0.+1j), (-2.1+0.2j), (2.5-3.1j)]) + result = x.isreal() + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_Tensor_positive.py b/tests/test_Tensor_positive.py new file mode 100644 index 000000000..589c125ed --- /dev/null +++ b/tests/test_Tensor_positive.py @@ -0,0 +1,41 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.Tensor.positive") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([4., 1., 1., 16.]) + result = x.positive() + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[4., 1., 1., 16.], [5., 1., 1., 17.]]) + result = x.positive() + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_Tensor_scatter_reduce.py b/tests/test_Tensor_scatter_reduce.py index d1383aba7..cf0cc3350 100644 --- a/tests/test_Tensor_scatter_reduce.py +++ b/tests/test_Tensor_scatter_reduce.py @@ -29,12 +29,7 @@ def test_case_1(): result = input.scatter_reduce(0, index, src, reduce="sum") """ ) - obj.run( - pytorch_code, - ["result"], - unsupport=True, - reason="paddle does not support this function temporarily", - ) + obj.run(pytorch_code, ["result"]) def test_case_2(): @@ -47,12 +42,7 @@ def test_case_2(): result = input.scatter_reduce(0, index, src, reduce="sum", include_self=False) """ ) - obj.run( - pytorch_code, - ["result"], - unsupport=True, - reason="paddle does not support this function temporarily", - ) + obj.run(pytorch_code, ["result"]) def test_case_3(): @@ -65,9 +55,56 @@ def test_case_3(): result = input.scatter_reduce(0, index, src, reduce="amax") """ ) - obj.run( - pytorch_code, - ["result"], - unsupport=True, - reason="paddle does not support this function temporarily", + obj.run(pytorch_code, ["result"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + src = torch.tensor([1., 2., 3., 4., 5., 6.]) + index = torch.tensor([0, 1, 0, 1, 2, 1]) + input = torch.tensor([1., 2., 3., 4.]) + result = input.scatter_reduce(0, index, src, reduce="amin") + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch + src = torch.tensor([1., 2., 3., 4., 5., 6.]) + index = torch.tensor([0, 1, 0, 1, 2, 1]) + input = torch.tensor([1., 2., 3., 4.]) + result = input.scatter_reduce(0, index, src, reduce="prod") + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_6(): + pytorch_code = textwrap.dedent( + """ + import torch + src = torch.tensor([[1., 2.],[3., 4.]]) + index = torch.tensor([[0, 0], [0, 0]]) + input = torch.tensor([[10., 30., 20.], [60., 40., 50.]]) + result = input.scatter_reduce(0, index, src, reduce="sum") + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_7(): + pytorch_code = textwrap.dedent( + """ + import torch + src = torch.tensor([[1., 2.],[3., 4.]]) + index = torch.tensor([[0, 0], [0, 0]]) + input = torch.tensor([[10., 30., 20.], [60., 40., 50.]]) + result = input.scatter_reduce(0, index, src, reduce="prod", include_self=False) + """ ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_block_diag.py b/tests/test_block_diag.py new file mode 100644 index 000000000..a5a34428b --- /dev/null +++ b/tests/test_block_diag.py @@ -0,0 +1,73 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.block_diag") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + A = torch.tensor([[0, 1], [1, 0]]) + B = torch.tensor([[3, 4, 5], [6, 7, 8]]) + C = torch.tensor(7) + D = torch.tensor([1, 2, 3]) + E = torch.tensor([[4], [5], [6]]) + result = torch.block_diag(A, B, C, D, E) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + A = torch.tensor([[4], [3], [2]]) + B = torch.tensor([7, 6, 5]) + C = torch.tensor(1) + result = torch.block_diag(A, B, C) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + A = torch.tensor([[4], [3], [2]]) + B = torch.tensor([[5, 6], [9, 1]]) + C = torch.tensor([1, 2, 3]) + result = torch.block_diag(A, B, C) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + A = torch.tensor([[4], [3], [2]]) + B = torch.tensor([[5], [6]]) + C = torch.tensor([1, 2, 3]) + result = torch.block_diag(A, B, C) + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_can_cast.py b/tests/test_can_cast.py index 1487616c1..40be1183f 100644 --- a/tests/test_can_cast.py +++ b/tests/test_can_cast.py @@ -23,14 +23,27 @@ def test_case_1(): pytorch_code = textwrap.dedent( """ import torch - x = torch.tensor([[0,1,2]]) - y = torch.tensor([[0],[1]]) - result = torch.can_cast(x, y) + result = torch.can_cast(torch.double, torch.float) """ ) - obj.run( - pytorch_code, - ["result"], - unsupport=True, - reason="paddle not support this API now", + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.can_cast(torch.float, torch.int) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.can_cast(torch.complex64, torch.complex128) + """ ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_cartesian_prod.py b/tests/test_cartesian_prod.py new file mode 100644 index 000000000..b10d23f7b --- /dev/null +++ b/tests/test_cartesian_prod.py @@ -0,0 +1,54 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.cartesian_prod") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.tensor([1, 2, 3]) + b = torch.tensor([5, 6]) + result = torch.cartesian_prod(a, b) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.tensor([1, 2, 3]) + result = torch.cartesian_prod(a) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + tensor_a = torch.tensor([1, 2, 4, 5]) + tensor_b = torch.tensor([5, 6]) + result = torch.cartesian_prod(tensor_a, tensor_b) + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_float_power.py b/tests/test_float_power.py new file mode 100644 index 000000000..7633aa479 --- /dev/null +++ b/tests/test_float_power.py @@ -0,0 +1,103 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.float_power") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([4, 6, 7, 1]) + result = torch.float_power(x, 2) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([4, 6, 7, 1]) + out = torch.zeros([4, 1], dtype=torch.double) + result = torch.float_power(x, 2, out=out) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([4, 6, 7, 1]) + y = torch.tensor([2, -3, 4, -5]) + result = torch.float_power(x, y) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([4, 6, 7, 1]) + y = torch.tensor([2, -3, 4, -5]) + out = torch.zeros([4, 1], dtype=torch.double) + result = torch.float_power(x, y, out=out) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[1, -2], [2, 5]]) + result = torch.float_power(x, 2) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_6(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[1., -2.], [2., 5.]]) + out = torch.zeros([2, 2], dtype=torch.double) + result = torch.float_power(x, 2, out=out) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_7(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[1, -2], [2, 5]]) + y = torch.tensor([[-2, 3], [-1, 2]]) + out = torch.zeros([2, 2], dtype=torch.double) + result = torch.float_power(x, y, out=out) + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_isin.py b/tests/test_isin.py index b993dd819..e5bba6837 100644 --- a/tests/test_isin.py +++ b/tests/test_isin.py @@ -27,70 +27,47 @@ def test_case_1(): result = torch.isin(torch.tensor([[1, 2], [3, 4]]), torch.tensor([2, 3])) """ ) - obj.run( - pytorch_code, - ["result"], - unsupport=True, - reason="paddle does not support this function temporarily", - ) + obj.run(pytorch_code, ["result"]) def test_case_2(): pytorch_code = textwrap.dedent( """ import torch - result = torch.isin(elements=torch.tensor([[1, 2], [3, 4]]), test_elements=torch.tensor([2, 3])) + result = torch.isin(torch.tensor([[1, 2], [3, 4]]), torch.tensor([2, 3]), assume_unique=True) """ ) - obj.run( - pytorch_code, - ["result"], - unsupport=True, - reason="paddle does not support this function temporarily", - ) + obj.run(pytorch_code, ["result"]) def test_case_3(): pytorch_code = textwrap.dedent( """ import torch - result = torch.isin(test_elements=torch.tensor([2, 3]), elements=torch.tensor([[1, 2], [3, 4]])) + result = torch.isin(torch.tensor([[1, 2], [3, 4]]), torch.tensor([2, 3]), invert=True) """ ) - obj.run( - pytorch_code, - ["result"], - unsupport=True, - reason="paddle does not support this function temporarily", - ) + obj.run(pytorch_code, ["result"]) def test_case_4(): pytorch_code = textwrap.dedent( """ import torch - result = torch.isin(torch.tensor([[1, 2], [3, 4]]), torch.tensor([2, 3]), assume_unique=False, invert=False) + result = torch.isin(torch.tensor([[1, 2], [3, 4]]), torch.tensor([2, 3]), assume_unique=True, invert=True) """ ) - obj.run( - pytorch_code, - ["result"], - unsupport=True, - reason="paddle does not support this function temporarily", - ) + obj.run(pytorch_code, ["result"]) def test_case_5(): pytorch_code = textwrap.dedent( """ import torch - result = torch.isin(elements=torch.tensor([[1, 2], [3, 4]]), test_elements=torch.tensor([2, 3]), - assume_unique=False, invert=False) + x = torch.tensor([0., 1., 2.]*20).reshape([20, 3]) + test_x = torch.tensor([0., 1.]*20) + correct_result = torch.isin(x, test_x, assume_unique=False) + incorrect_result = torch.isin(x, test_x, assume_unique=True) """ ) - obj.run( - pytorch_code, - ["result"], - unsupport=True, - reason="paddle does not support this function temporarily", - ) + obj.run(pytorch_code, ["correct_result", "incorrect_result"]) diff --git a/tests/test_isneginf.py b/tests/test_isneginf.py new file mode 100644 index 000000000..4cc944821 --- /dev/null +++ b/tests/test_isneginf.py @@ -0,0 +1,69 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.isneginf") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([-float('inf'), float('inf'), 1.2]) + result = torch.isneginf(x) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + out = torch.tensor([False, False, False]) + x = torch.tensor([-float('inf'), float('inf'), 1.2]) + result = torch.isneginf(x, out = out) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[-float('inf'), float('inf'), 1.2, 0., 2.5], + [-1.35 , -float('inf') , 0.18, -0.33, float('inf')], + [-float('inf'), float('inf'), 1., 2., 4.]]) + result = torch.isneginf(x) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + out = torch.zeros(3, 5, dtype=torch.bool) + x = torch.tensor([[-float('inf'), float('inf'), 1.2, 0., 2.5], + [-1.35 , -float('inf') , 0.18, -0.33, float('inf')], + [-float('inf'), float('inf'), 1., 2., 4.]]) + result = torch.isneginf(x, out = out) + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_isposinf.py b/tests/test_isposinf.py new file mode 100644 index 000000000..2f1ead5ee --- /dev/null +++ b/tests/test_isposinf.py @@ -0,0 +1,69 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.isposinf") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([-float('inf'), float('inf'), 1.2]) + result = torch.isposinf(x) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + out = torch.tensor([False, False, False]) + x = torch.tensor([-float('inf'), float('inf'), 1.2]) + result = torch.isposinf(x, out = out) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[-float('inf'), float('inf'), 1.2, 0., 2.5], + [-1.35 , -float('inf') , 0.18, -0.33, float('inf')], + [-float('inf'), float('inf'), 1., 2., 4.]]) + result = torch.isposinf(x) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + out = torch.zeros(3, 5, dtype=torch.bool) + x = torch.tensor([[-float('inf'), float('inf'), 1.2, 0., 2.5], + [-1.35 , -float('inf') , 0.18, -0.33, float('inf')], + [-float('inf'), float('inf'), 1., 2., 4.]]) + result = torch.isposinf(x, out = out) + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_isreal.py b/tests/test_isreal.py new file mode 100644 index 000000000..db941906f --- /dev/null +++ b/tests/test_isreal.py @@ -0,0 +1,50 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.isreal") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.isreal(torch.tensor([1, 1+1j, 2+0j])) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.isreal(torch.tensor([-0., -2.1, 2.5])) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([(-0.+1j), (-2.1+0.2j), (2.5-3.1j)]) + result = torch.isreal(x) + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_positive.py b/tests/test_positive.py index 277709b40..a9288d066 100644 --- a/tests/test_positive.py +++ b/tests/test_positive.py @@ -24,16 +24,11 @@ def test_case_1(): pytorch_code = textwrap.dedent( """ import torch - x = torch.tensor([4., 1., 1., 16.], ) + x = torch.tensor([4., 1., 1., 16.]) result = torch.positive(x) """ ) - obj.run( - pytorch_code, - ["result"], - unsupport=True, - reason="paddle does not support this function temporarily", - ) + obj.run(pytorch_code, ["result"]) def test_case_2(): @@ -44,9 +39,4 @@ def test_case_2(): result = torch.positive(t) """ ) - obj.run( - pytorch_code, - ["result"], - unsupport=True, - reason="paddle does not support this function temporarily", - ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_scatter_reduce.py b/tests/test_scatter_reduce.py index 626e0ae67..aa4dcdc6b 100644 --- a/tests/test_scatter_reduce.py +++ b/tests/test_scatter_reduce.py @@ -29,12 +29,7 @@ def test_case_1(): result = torch.scatter_reduce(input, 0, index, src, reduce="sum") """ ) - obj.run( - pytorch_code, - ["result"], - unsupport=True, - reason="paddle does not support this function temporarily", - ) + obj.run(pytorch_code, ["result"]) def test_case_2(): @@ -47,12 +42,7 @@ def test_case_2(): result = torch.scatter_reduce(input, 0, index, src, reduce="sum", include_self=False) """ ) - obj.run( - pytorch_code, - ["result"], - unsupport=True, - reason="paddle does not support this function temporarily", - ) + obj.run(pytorch_code, ["result"]) def test_case_3(): @@ -65,9 +55,56 @@ def test_case_3(): result = torch.scatter_reduce(input, 0, index, src, reduce="amax") """ ) - obj.run( - pytorch_code, - ["result"], - unsupport=True, - reason="paddle does not support this function temporarily", + obj.run(pytorch_code, ["result"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + src = torch.tensor([1., 2., 3., 4., 5., 6.]) + index = torch.tensor([0, 1, 0, 1, 2, 1]) + input = torch.tensor([1., 2., 3., 4.]) + result = torch.scatter_reduce(input, 0, index, src, reduce="amin") + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch + src = torch.tensor([1., 2., 3., 4., 5., 6.]) + index = torch.tensor([0, 1, 0, 1, 2, 1]) + input = torch.tensor([1., 2., 3., 4.]) + result = torch.scatter_reduce(input, 0, index, src, reduce="prod") + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_6(): + pytorch_code = textwrap.dedent( + """ + import torch + src = torch.tensor([[1., 2.],[3., 4.]]) + index = torch.tensor([[0, 0], [0, 0]]) + input = torch.tensor([[10., 30., 20.], [60., 40., 50.]]) + result = torch.scatter_reduce(input, 0, index, src, reduce="sum") + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_7(): + pytorch_code = textwrap.dedent( + """ + import torch + src = torch.tensor([[1., 2.],[3., 4.]]) + index = torch.tensor([[0, 0], [0, 0]]) + input = torch.tensor([[10., 30., 20.], [60., 40., 50.]]) + result = torch.scatter_reduce(input, 0, index, src, reduce="prod", include_self=False) + """ ) + obj.run(pytorch_code, ["result"]) From a1d3c42019112a65e4a8580e3297856b6a12cbcf Mon Sep 17 00:00:00 2001 From: inaomIIsFarell <1344594208@qq.com> Date: Sat, 12 Oct 2024 18:45:01 +0800 Subject: [PATCH 2/3] fix --- paconvert/api_matcher.py | 431 ++++++++++++++-------------- tests/test_Tensor_scatter_reduce.py | 23 +- tests/test_block_diag.py | 4 +- tests/test_can_cast.py | 28 +- tests/test_cartesian_prod.py | 4 +- tests/test_float_power.py | 14 +- tests/test_isin.py | 18 +- tests/test_isneginf.py | 12 +- tests/test_isposinf.py | 12 +- tests/test_isreal.py | 13 +- tests/test_positive.py | 24 +- tests/test_scatter_reduce.py | 23 +- 12 files changed, 357 insertions(+), 249 deletions(-) diff --git a/paconvert/api_matcher.py b/paconvert/api_matcher.py index b2276d358..78972431a 100644 --- a/paconvert/api_matcher.py +++ b/paconvert/api_matcher.py @@ -1450,11 +1450,28 @@ def generate_code(self, kwargs): return GenericMatcher.generate_code(self, kwargs) -class ScatterReduceMatcher(BaseMatcher): +class ScatterReduceMatcher(BaseMatcher): + def generate_aux_code(self): + CODE_TEMPLATE = textwrap.dedent( + """ + def reduce_type(type): + map = {"sum": "add", "prod": "multiply"} + if type == "sum" or type == "prod": + type = map[type] + return type + setattr(paddle, 'reduce_type', reduce_type) + """ + ) + return CODE_TEMPLATE + def generate_code(self, kwargs): + allowed_reduce_type = ['"""sum"""', '"""prod"""', '"""amax"""', '"""amin"""', '"""mean"""'] reduce_mapping = {'"""sum"""': '"add"', '"""prod"""': '"multiply"'} if "reduce" in kwargs and kwargs["reduce"] in reduce_mapping: kwargs["reduce"] = reduce_mapping[kwargs["reduce"]] + elif "reduce" in kwargs and kwargs["reduce"] not in allowed_reduce_type: + self.write_aux_code() + kwargs["reduce"] = "paddle_aux.reduce_type({})".format(kwargs["reduce"]) return GenericMatcher.generate_code(self, kwargs) @@ -3149,10 +3166,7 @@ def generate_code(self, kwargs): class CartesianProdMatcher(BaseMatcher): def get_paddle_nodes(self, args, kwargs): new_args = self.parse_args(args) - code = "paddle.cartesian_prod([ {}".format(new_args[0]) - for arg in new_args[1:]: - code = code + ", {}".format(arg) - code = code + "])" + code = "paddle.cartesian_prod([{}])".format(", ".join(new_args)) return ast.parse(code).body @@ -4289,175 +4303,175 @@ class CanCastMatcher(BaseMatcher): def generate_aux_code(self): CODE_TEMPLATE = textwrap.dedent( """ - def can_cast(from_, to): + def can_cast(from_, to): can_cast_dict = { - paddle.bfloat16: { - paddle.bfloat16: True, - paddle.float16: True, - paddle.float32: True, - paddle.float64: True, - paddle.complex64: True, - paddle.complex128: True, - paddle.uint8: False, - paddle.int8: False, - paddle.int16: False, - paddle.int32: False, - paddle.int64: False, - paddle.bool: False + 'bfloat16': { + 'bfloat16': True, + 'float16': True, + 'float32': True, + 'float64': True, + 'complex64': True, + 'complex128': True, + 'uint8': False, + 'int8': False, + 'int16': False, + 'int32': False, + 'int64': False, + 'bool': False }, - paddle.float16: { - paddle.bfloat16: True, - paddle.float16: True, - paddle.float32: True, - paddle.float64: True, - paddle.complex64: True, - paddle.complex128: True, - paddle.uint8: False, - paddle.int8: False, - paddle.int16: False, - paddle.int32: False, - paddle.int64: False, - paddle.bool: False, + 'float16': { + 'bfloat16': True, + 'float16': True, + 'float32': True, + 'float64': True, + 'complex64': True, + 'complex128': True, + 'uint8': False, + 'int8': False, + 'int16': False, + 'int32': False, + 'int64': False, + 'bool': False, }, - paddle.float32: { - paddle.bfloat16: True, - paddle.float16: True, - paddle.float32: True, - paddle.float64: True, - paddle.complex64: True, - paddle.complex128: True, - paddle.uint8: False, - paddle.int8: False, - paddle.int16: False, - paddle.int32: False, - paddle.int64: False, - paddle.bool: False, + 'float32': { + 'bfloat16': True, + 'float16': True, + 'float32': True, + 'float64': True, + 'complex64': True, + 'complex128': True, + 'uint8': False, + 'int8': False, + 'int16': False, + 'int32': False, + 'int64': False, + 'bool': False, }, - paddle.float64: { - paddle.bfloat16: True, - paddle.float16: True, - paddle.float32: True, - paddle.float64: True, - paddle.complex64: True, - paddle.complex128: True, - paddle.uint8: False, - paddle.int8: False, - paddle.int16: False, - paddle.int32: False, - paddle.int64: False, - paddle.bool: False, + 'float64': { + 'bfloat16': True, + 'float16': True, + 'float32': True, + 'float64': True, + 'complex64': True, + 'complex128': True, + 'uint8': False, + 'int8': False, + 'int16': False, + 'int32': False, + 'int64': False, + 'bool': False, }, - paddle.complex64: { - paddle.bfloat16: False, - paddle.float16: False, - paddle.float32: False, - paddle.float64: False, - paddle.complex64: True, - paddle.complex128: True, - paddle.uint8: False, - paddle.int8: False, - paddle.int16: False, - paddle.int32: False, - paddle.int64: False, - paddle.bool: False, + 'complex64': { + 'bfloat16': False, + 'float16': False, + 'float32': False, + 'float64': False, + 'complex64': True, + 'complex128': True, + 'uint8': False, + 'int8': False, + 'int16': False, + 'int32': False, + 'int64': False, + 'bool': False, }, - paddle.complex128: { - paddle.bfloat16: False, - paddle.float16: False, - paddle.float32: False, - paddle.float64: False, - paddle.complex64: True, - paddle.complex128: True, - paddle.uint8: False, - paddle.int8: False, - paddle.int16: False, - paddle.int32: False, - paddle.int64: False, - paddle.bool: False, + 'complex128': { + 'bfloat16': False, + 'float16': False, + 'float32': False, + 'float64': False, + 'complex64': True, + 'complex128': True, + 'uint8': False, + 'int8': False, + 'int16': False, + 'int32': False, + 'int64': False, + 'bool': False, }, - paddle.uint8: { - paddle.bfloat16: True, - paddle.float16: True, - paddle.float32: True, - paddle.float64: True, - paddle.complex64: True, - paddle.complex128: True, - paddle.uint8: True, - paddle.int8: True, - paddle.int16: True, - paddle.int32: True, - paddle.int64: True, - paddle.bool: False, + 'uint8': { + 'bfloat16': True, + 'float16': True, + 'float32': True, + 'float64': True, + 'complex64': True, + 'complex128': True, + 'uint8': True, + 'int8': True, + 'int16': True, + 'int32': True, + 'int64': True, + 'bool': False, }, - paddle.int8: { - paddle.bfloat16: True, - paddle.float16: True, - paddle.float32: True, - paddle.float64: True, - paddle.complex64: True, - paddle.complex128: True, - paddle.uint8: True, - paddle.int8: True, - paddle.int16: True, - paddle.int32: True, - paddle.int64: True, - paddle.bool: False, + 'int8': { + 'bfloat16': True, + 'float16': True, + 'float32': True, + 'float64': True, + 'complex64': True, + 'complex128': True, + 'uint8': True, + 'int8': True, + 'int16': True, + 'int32': True, + 'int64': True, + 'bool': False, }, - paddle.int16: { - paddle.bfloat16: True, - paddle.float16: True, - paddle.float32: True, - paddle.float64: True, - paddle.complex64: True, - paddle.complex128: True, - paddle.uint8: True, - paddle.int8: True, - paddle.int16: True, - paddle.int32: True, - paddle.int64: True, - paddle.bool: False, + 'int16': { + 'bfloat16': True, + 'float16': True, + 'float32': True, + 'float64': True, + 'complex64': True, + 'complex128': True, + 'uint8': True, + 'int8': True, + 'int16': True, + 'int32': True, + 'int64': True, + 'bool': False, }, - paddle.int32: { - paddle.bfloat16: True, - paddle.float16: True, - paddle.float32: True, - paddle.float64: True, - paddle.complex64: True, - paddle.complex128: True, - paddle.uint8: True, - paddle.int8: True, - paddle.int16: True, - paddle.int32: True, - paddle.int64: True, - paddle.bool: False, + 'int32': { + 'bfloat16': True, + 'float16': True, + 'float32': True, + 'float64': True, + 'complex64': True, + 'complex128': True, + 'uint8': True, + 'int8': True, + 'int16': True, + 'int32': True, + 'int64': True, + 'bool': False, }, - paddle.int64: { - paddle.bfloat16: True, - paddle.float16: True, - paddle.float32: True, - paddle.float64: True, - paddle.complex64: True, - paddle.complex128: True, - paddle.uint8: True, - paddle.int8: True, - paddle.int16: True, - paddle.int32: True, - paddle.int64: True, - paddle.bool: False, + 'int64': { + 'bfloat16': True, + 'float16': True, + 'float32': True, + 'float64': True, + 'complex64': True, + 'complex128': True, + 'uint8': True, + 'int8': True, + 'int16': True, + 'int32': True, + 'int64': True, + 'bool': False, }, - paddle.bool: { - paddle.bfloat16: True, - paddle.float16: True, - paddle.float32: True, - paddle.float64: True, - paddle.complex64: True, - paddle.complex128: True, - paddle.uint8: True, - paddle.int8: True, - paddle.int16: True, - paddle.int32: True, - paddle.int64: True, - paddle.bool: True, + 'bool': { + 'bfloat16': True, + 'float16': True, + 'float32': True, + 'float64': True, + 'complex64': True, + 'complex128': True, + 'uint8': True, + 'int8': True, + 'int16': True, + 'int32': True, + 'int64': True, + 'bool': True, } } return can_cast_dict[from_][to] @@ -4465,66 +4479,61 @@ def can_cast(from_, to): """ ) return CODE_TEMPLATE - - def generate_code(self, kwargs): + def get_paddle_nodes(self, args, kwargs): self.write_aux_code() - _from_dtype = kwargs["from_"][3:-3] - _to_dtype = kwargs["to"][3:-3] - code = "paddle_aux.can_cast(paddle.{}, paddle.{})".format( - _from_dtype, _to_dtype - ) - return code + new_args = self.parse_args(args) + new_kwargs = self.parse_kwargs(kwargs) + can_cast_template = "paddle_aux.can_cast({}, {})" + from_type = new_kwargs.get("from_", new_args[0] if new_args else None) + to_type = new_kwargs.get("to", new_args[1] if len(new_args) > 1 else None) + code = can_cast_template.format(from_type, to_type) + return ast.parse(code).body class PositiveMatcher(BaseMatcher): - def generate_code(self, kwargs): - API_TEMPLATE = textwrap.dedent( - """ - def positive({}): - if {}.dtype != paddle.bool: - return {} - else: - raise RuntimeError("boolean tensors is not supported.") - positive({}) - """ + def generate_aux_code(self): + CODE_TEMPLATE = textwrap.dedent( + """ + def positive(x): + if x.dtype != paddle.bool: + return x + else: + raise RuntimeError("boolean tensors is not supported.") + setattr(paddle, "positive", positive) + """ ) - if "input" not in kwargs: - code = API_TEMPLATE.format( - self.paddleClass, self.paddleClass, self.paddleClass, self.paddleClass - ) - else: - code = API_TEMPLATE.format( - kwargs["input"], kwargs["input"], kwargs["input"], kwargs["input"] - ) + return CODE_TEMPLATE + def generate_code(self, kwargs): + self.write_aux_code() + if "input" in kwargs and kwargs["input"] is not None: + code = "paddle_aux.positive({})".format(kwargs["input"]) + else : + code = "paddle_aux.positive({})".format(self.paddleClass) return code class FloatPowerMatcher(BaseMatcher): + def generate_aux_code(self): + CODE_TEMPLATE = textwrap.dedent( + """ + def get_exponent(exponent): + return exponent.cast(paddle.float64) if isinstance(exponent, paddle.Tensor) else exponent + setattr(paddle, "get_exponent", get_exponent) + """ + ) + return CODE_TEMPLATE + def generate_code(self, kwargs): - if "input" not in kwargs: - return "{}.cast(paddle.float64).pow({}.cast(paddle.float64) if isinstance({}, paddle.Tensor) else {})".format( - self.paddleClass, - kwargs["exponent"], - kwargs["exponent"], - kwargs["exponent"], - ) + self.write_aux_code() + pow_expression = "paddle.pow({}.cast(paddle.float64), paddle_aux.get_exponent({}))".format( + kwargs["input"], + kwargs["exponent"] + ) + if "out" in kwargs and kwargs["out"] is not None: + code = "paddle.assign({}, {})".format(pow_expression, kwargs["out"]) else: - if "out" not in kwargs: - return "paddle.pow({}.cast(paddle.float64), {}.cast(paddle.float64) if isinstance({}, paddle.Tensor) else {})".format( - kwargs["input"], - kwargs["exponent"], - kwargs["exponent"], - kwargs["exponent"], - ) - else: - return "paddle.assign(paddle.pow({}.cast(paddle.float64), {}.cast(paddle.float64) if isinstance({}, paddle.Tensor) else {}), {})".format( - kwargs["input"], - kwargs["exponent"], - kwargs["exponent"], - kwargs["exponent"], - kwargs["out"], - ) - + code = pow_expression + return code class FloatPowerInplaceMatcher(BaseMatcher): def generate_code(self, kwargs): diff --git a/tests/test_Tensor_scatter_reduce.py b/tests/test_Tensor_scatter_reduce.py index cf0cc3350..8fe9aa143 100644 --- a/tests/test_Tensor_scatter_reduce.py +++ b/tests/test_Tensor_scatter_reduce.py @@ -26,7 +26,8 @@ def test_case_1(): src = torch.tensor([1., 2., 3., 4., 5., 6.]) index = torch.tensor([0, 1, 0, 1, 2, 1]) input = torch.tensor([1., 2., 3., 4.]) - result = input.scatter_reduce(0, index, src, reduce="sum") + type = "sum" + result = input.scatter_reduce(0, index, src, reduce=type) """ ) obj.run(pytorch_code, ["result"]) @@ -39,7 +40,8 @@ def test_case_2(): src = torch.tensor([1., 2., 3., 4., 5., 6.]) index = torch.tensor([0, 1, 0, 1, 2, 1]) input = torch.tensor([1., 2., 3., 4.]) - result = input.scatter_reduce(0, index, src, reduce="sum", include_self=False) + re_type = "sum" + result = input.scatter_reduce(dim=0, index=index, src=src, reduce=re_type, include_self=False) """ ) obj.run(pytorch_code, ["result"]) @@ -52,7 +54,7 @@ def test_case_3(): src = torch.tensor([1., 2., 3., 4., 5., 6.]) index = torch.tensor([0, 1, 0, 1, 2, 1]) input = torch.tensor([1., 2., 3., 4.]) - result = input.scatter_reduce(0, index, src, reduce="amax") + result = input.scatter_reduce(0, index, src, "amax") """ ) obj.run(pytorch_code, ["result"]) @@ -91,7 +93,7 @@ def test_case_6(): src = torch.tensor([[1., 2.],[3., 4.]]) index = torch.tensor([[0, 0], [0, 0]]) input = torch.tensor([[10., 30., 20.], [60., 40., 50.]]) - result = input.scatter_reduce(0, index, src, reduce="sum") + result = input.scatter_reduce(index=index, src=src, reduce="sum", dim=0) """ ) obj.run(pytorch_code, ["result"]) @@ -108,3 +110,16 @@ def test_case_7(): """ ) obj.run(pytorch_code, ["result"]) + +def test_case_8(): + pytorch_code = textwrap.dedent( + """ + import torch + src = torch.tensor([[1., 2.],[3., 4.]]) + index = torch.tensor([[0, 0], [0, 0]]) + input = torch.tensor([[10., 30., 20.], [60., 40., 50.]]) + re_type = "prod" + result = input.scatter_reduce(0, index, src, re_type) + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_block_diag.py b/tests/test_block_diag.py index a5a34428b..463e2558f 100644 --- a/tests/test_block_diag.py +++ b/tests/test_block_diag.py @@ -41,7 +41,9 @@ def test_case_2(): A = torch.tensor([[4], [3], [2]]) B = torch.tensor([7, 6, 5]) C = torch.tensor(1) - result = torch.block_diag(A, B, C) + result = torch.block_diag(torch.tensor([[4], [3], [2]]), + torch.tensor([7, 6, 5]), + torch.tensor(1)) """ ) obj.run(pytorch_code, ["result"]) diff --git a/tests/test_can_cast.py b/tests/test_can_cast.py index 40be1183f..d08a0396a 100644 --- a/tests/test_can_cast.py +++ b/tests/test_can_cast.py @@ -33,7 +33,7 @@ def test_case_2(): pytorch_code = textwrap.dedent( """ import torch - result = torch.can_cast(torch.float, torch.int) + result = torch.can_cast(from_=torch.complex64, to=torch.complex128) """ ) obj.run(pytorch_code, ["result"]) @@ -43,7 +43,31 @@ def test_case_3(): pytorch_code = textwrap.dedent( """ import torch - result = torch.can_cast(torch.complex64, torch.complex128) + from_type = torch.float + to_type = torch.int + result = torch.can_cast(from_type, to=to_type) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + from_type = torch.int + to_type = torch.bool + result = torch.can_cast(to=to_type, from_=from_type) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.can_cast(to=torch.bool, from_=torch.int) """ ) obj.run(pytorch_code, ["result"]) diff --git a/tests/test_cartesian_prod.py b/tests/test_cartesian_prod.py index b10d23f7b..03e8bd636 100644 --- a/tests/test_cartesian_prod.py +++ b/tests/test_cartesian_prod.py @@ -46,9 +46,7 @@ def test_case_3(): pytorch_code = textwrap.dedent( """ import torch - tensor_a = torch.tensor([1, 2, 4, 5]) - tensor_b = torch.tensor([5, 6]) - result = torch.cartesian_prod(tensor_a, tensor_b) + result = torch.cartesian_prod(torch.tensor([1, 2, 4, 5]), torch.tensor([5, 6]), torch.tensor([7])) """ ) obj.run(pytorch_code, ["result"]) diff --git a/tests/test_float_power.py b/tests/test_float_power.py index 7633aa479..d5a63c31b 100644 --- a/tests/test_float_power.py +++ b/tests/test_float_power.py @@ -35,7 +35,7 @@ def test_case_2(): """ import torch x = torch.tensor([4, 6, 7, 1]) - out = torch.zeros([4, 1], dtype=torch.double) + out = torch.zeros([4], dtype=torch.double) result = torch.float_power(x, 2, out=out) """ ) @@ -58,10 +58,8 @@ def test_case_4(): pytorch_code = textwrap.dedent( """ import torch - x = torch.tensor([4, 6, 7, 1]) - y = torch.tensor([2, -3, 4, -5]) - out = torch.zeros([4, 1], dtype=torch.double) - result = torch.float_power(x, y, out=out) + out = torch.zeros([4], dtype=torch.double) + result = torch.float_power(torch.tensor([4, 6, 7, 1]), torch.tensor([2, -3, 4, -5]), out=out) """ ) obj.run(pytorch_code, ["result"]) @@ -72,7 +70,7 @@ def test_case_5(): """ import torch x = torch.tensor([[1, -2], [2, 5]]) - result = torch.float_power(x, 2) + result = torch.float_power(x, exponent=2) """ ) obj.run(pytorch_code, ["result"]) @@ -84,7 +82,7 @@ def test_case_6(): import torch x = torch.tensor([[1., -2.], [2., 5.]]) out = torch.zeros([2, 2], dtype=torch.double) - result = torch.float_power(x, 2, out=out) + result = torch.float_power(input=x, exponent=2, out=out) """ ) obj.run(pytorch_code, ["result"]) @@ -97,7 +95,7 @@ def test_case_7(): x = torch.tensor([[1, -2], [2, 5]]) y = torch.tensor([[-2, 3], [-1, 2]]) out = torch.zeros([2, 2], dtype=torch.double) - result = torch.float_power(x, y, out=out) + result = torch.float_power(out=out, exponent=y, input=x) """ ) obj.run(pytorch_code, ["result"]) diff --git a/tests/test_isin.py b/tests/test_isin.py index e5bba6837..395c3f74b 100644 --- a/tests/test_isin.py +++ b/tests/test_isin.py @@ -54,13 +54,25 @@ def test_case_4(): pytorch_code = textwrap.dedent( """ import torch - result = torch.isin(torch.tensor([[1, 2], [3, 4]]), torch.tensor([2, 3]), assume_unique=True, invert=True) + result = torch.isin(elements=torch.tensor([[1, 2], [3, 4]]), test_elements=torch.tensor([2, 3]), assume_unique=True, invert=True) """ ) obj.run(pytorch_code, ["result"]) def test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch + elemnts = torch.tensor([[1, 2], [3, 4]]) + test_elemnts = torch.tensor([2, 3]) + result = torch.isin(assume_unique=True, invert=True, test_elements=test_elemnts, elements=elemnts) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_6(): pytorch_code = textwrap.dedent( """ import torch @@ -71,3 +83,7 @@ def test_case_5(): """ ) obj.run(pytorch_code, ["correct_result", "incorrect_result"]) + + +if __name__ == "__main__": + test_case_5() \ No newline at end of file diff --git a/tests/test_isneginf.py b/tests/test_isneginf.py index 4cc944821..724af29e8 100644 --- a/tests/test_isneginf.py +++ b/tests/test_isneginf.py @@ -36,7 +36,7 @@ def test_case_2(): import torch out = torch.tensor([False, False, False]) x = torch.tensor([-float('inf'), float('inf'), 1.2]) - result = torch.isneginf(x, out = out) + result = torch.isneginf(input=x, out = out) """ ) obj.run(pytorch_code, ["result"]) @@ -46,10 +46,10 @@ def test_case_3(): pytorch_code = textwrap.dedent( """ import torch - x = torch.tensor([[-float('inf'), float('inf'), 1.2, 0., 2.5], - [-1.35 , -float('inf') , 0.18, -0.33, float('inf')], - [-float('inf'), float('inf'), 1., 2., 4.]]) - result = torch.isneginf(x) + result = torch.isneginf(torch.tensor([[-float('inf'), float('inf'), 1.2, 0., 2.5], + [-1.35 , -float('inf') , 0.18, -0.33, float('inf')], + [-float('inf'), float('inf'), 1., 2., 4.]]) + ) """ ) obj.run(pytorch_code, ["result"]) @@ -63,7 +63,7 @@ def test_case_4(): x = torch.tensor([[-float('inf'), float('inf'), 1.2, 0., 2.5], [-1.35 , -float('inf') , 0.18, -0.33, float('inf')], [-float('inf'), float('inf'), 1., 2., 4.]]) - result = torch.isneginf(x, out = out) + result = torch.isneginf(out = out, input=x) """ ) obj.run(pytorch_code, ["result"]) diff --git a/tests/test_isposinf.py b/tests/test_isposinf.py index 2f1ead5ee..173bff9eb 100644 --- a/tests/test_isposinf.py +++ b/tests/test_isposinf.py @@ -36,7 +36,7 @@ def test_case_2(): import torch out = torch.tensor([False, False, False]) x = torch.tensor([-float('inf'), float('inf'), 1.2]) - result = torch.isposinf(x, out = out) + result = torch.isposinf(input=x, out = out) """ ) obj.run(pytorch_code, ["result"]) @@ -46,10 +46,10 @@ def test_case_3(): pytorch_code = textwrap.dedent( """ import torch - x = torch.tensor([[-float('inf'), float('inf'), 1.2, 0., 2.5], - [-1.35 , -float('inf') , 0.18, -0.33, float('inf')], - [-float('inf'), float('inf'), 1., 2., 4.]]) - result = torch.isposinf(x) + result = torch.isposinf(torch.tensor([[-float('inf'), float('inf'), 1.2, 0., 2.5], + [-1.35 , -float('inf') , 0.18, -0.33, float('inf')], + [-float('inf'), float('inf'), 1., 2., 4.]]) + ) """ ) obj.run(pytorch_code, ["result"]) @@ -63,7 +63,7 @@ def test_case_4(): x = torch.tensor([[-float('inf'), float('inf'), 1.2, 0., 2.5], [-1.35 , -float('inf') , 0.18, -0.33, float('inf')], [-float('inf'), float('inf'), 1., 2., 4.]]) - result = torch.isposinf(x, out = out) + result = torch.isposinf(out = out, input=x) """ ) obj.run(pytorch_code, ["result"]) diff --git a/tests/test_isreal.py b/tests/test_isreal.py index db941906f..596481524 100644 --- a/tests/test_isreal.py +++ b/tests/test_isreal.py @@ -33,7 +33,7 @@ def test_case_2(): pytorch_code = textwrap.dedent( """ import torch - result = torch.isreal(torch.tensor([-0., -2.1, 2.5])) + result = torch.isreal(input=torch.tensor([-0., -2.1, 2.5])) """ ) obj.run(pytorch_code, ["result"]) @@ -48,3 +48,14 @@ def test_case_3(): """ ) obj.run(pytorch_code, ["result"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([(-0.+1j), (-2.1+0.2j), (2.5-3.1j)]) + result = torch.isreal(input=x) + """ + ) + obj.run(pytorch_code, ["result"]) \ No newline at end of file diff --git a/tests/test_positive.py b/tests/test_positive.py index a9288d066..a0674b5ac 100644 --- a/tests/test_positive.py +++ b/tests/test_positive.py @@ -35,8 +35,28 @@ def test_case_2(): pytorch_code = textwrap.dedent( """ import torch - t = torch.tensor([[1, 2, 4, 8], [10, 20, 40, 80]]) - result = torch.positive(t) + result = torch.positive(torch.tensor([[1, 2, 4, 8], [10, 20, 40, 80]])) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[-4., 1., 1., 16.]]) + result = torch.positive(input=x) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.positive(input=torch.tensor([[-4., 1., 1., 16.]])) """ ) obj.run(pytorch_code, ["result"]) diff --git a/tests/test_scatter_reduce.py b/tests/test_scatter_reduce.py index aa4dcdc6b..c252806c3 100644 --- a/tests/test_scatter_reduce.py +++ b/tests/test_scatter_reduce.py @@ -26,7 +26,8 @@ def test_case_1(): src = torch.tensor([1., 2., 3., 4., 5., 6.]) index = torch.tensor([0, 1, 0, 1, 2, 1]) input = torch.tensor([1., 2., 3., 4.]) - result = torch.scatter_reduce(input, 0, index, src, reduce="sum") + type = "sum" + result = torch.scatter_reduce(input, 0, index, src, reduce=type) """ ) obj.run(pytorch_code, ["result"]) @@ -39,7 +40,7 @@ def test_case_2(): src = torch.tensor([1., 2., 3., 4., 5., 6.]) index = torch.tensor([0, 1, 0, 1, 2, 1]) input = torch.tensor([1., 2., 3., 4.]) - result = torch.scatter_reduce(input, 0, index, src, reduce="sum", include_self=False) + result = torch.scatter_reduce(input=input, dim=0, index=index, src=src, reduce="sum", include_self=False) """ ) obj.run(pytorch_code, ["result"]) @@ -52,7 +53,7 @@ def test_case_3(): src = torch.tensor([1., 2., 3., 4., 5., 6.]) index = torch.tensor([0, 1, 0, 1, 2, 1]) input = torch.tensor([1., 2., 3., 4.]) - result = torch.scatter_reduce(input, 0, index, src, reduce="amax") + result = torch.scatter_reduce(input, 0, index, src, "amax") """ ) obj.run(pytorch_code, ["result"]) @@ -78,7 +79,8 @@ def test_case_5(): src = torch.tensor([1., 2., 3., 4., 5., 6.]) index = torch.tensor([0, 1, 0, 1, 2, 1]) input = torch.tensor([1., 2., 3., 4.]) - result = torch.scatter_reduce(input, 0, index, src, reduce="prod") + re_type = "prod" + result = torch.scatter_reduce(input, 0, index, src, reduce=re_type) """ ) obj.run(pytorch_code, ["result"]) @@ -108,3 +110,16 @@ def test_case_7(): """ ) obj.run(pytorch_code, ["result"]) + +def test_case_8(): + pytorch_code = textwrap.dedent( + """ + import torch + src = torch.tensor([[1., 2.],[3., 4.]]) + index = torch.tensor([[0, 0], [0, 0]]) + input = torch.tensor([[10., 30., 20.], [60., 40., 50.]]) + re_type = "prod" + result = torch.scatter_reduce(input, 0, index, src, re_type) + """ + ) + obj.run(pytorch_code, ["result"]) From 6789270c70be74221ca87f41471f0be22b08483e Mon Sep 17 00:00:00 2001 From: inaomIIsFarell <1344594208@qq.com> Date: Tue, 15 Oct 2024 13:24:09 +0800 Subject: [PATCH 3/3] fix --- paconvert/api_matcher.py | 262 ++++++++++++++++++++++++++++++++++++--- 1 file changed, 244 insertions(+), 18 deletions(-) diff --git a/paconvert/api_matcher.py b/paconvert/api_matcher.py index ec3d680f8..cab07045e 100644 --- a/paconvert/api_matcher.py +++ b/paconvert/api_matcher.py @@ -807,6 +807,74 @@ def generate_code(self, kwargs): return code +class AssertMatcher(BaseMatcher): + def generate_code(self, kwargs): + API_TEMPLATE = textwrap.dedent( + """ + assert {}, '{}' + """ + ) + code = API_TEMPLATE.format( + kwargs["condition"], + kwargs["message"], + ) + return code + + +class MakeTMatcher(BaseMatcher): + def get_paddle_nodes(self, args, kwargs): + kwargs = self.parse_kwargs(kwargs) + if "shape" not in kwargs: + if len(args) > 1 or (len(args) == 1 and isinstance(args[0], ast.Constant)): + shape = self.parse_args(args) + elif isinstance(args[0], ast.Starred): + shape = astor.to_source(args[0].value).strip("\n") + else: + shape = self.parse_args(args)[0] + kwargs = {"shape": str(shape).replace("'", ""), **kwargs} + + if "dtype" not in kwargs: + kwargs["dtype"] = "float32" + + if "low" not in kwargs: + kwargs["low"] = 0 + + if "high" not in kwargs: + kwargs["high"] = 1 + + if "requires_grad" not in kwargs.keys(): + API_TEMPLATE = textwrap.dedent( + """ + paddle.uniform({}, dtype={}, min={}, max={}).to({}) + """ + ) + code = API_TEMPLATE.format( + kwargs["shape"], + kwargs["dtype"], + kwargs["low"], + kwargs["high"], + kwargs["device"], + ) + else: + API_TEMPLATE = textwrap.dedent( + """ + out = paddle.uniform({}, dtype={}, min={}, max={}).to({}) + out.stop_gradient = not {} + out + """ + ) + code = API_TEMPLATE.format( + kwargs["shape"], + kwargs["dtype"], + kwargs["low"], + kwargs["high"], + kwargs["device"], + kwargs["requires_grad"], + ) + + return ast.parse(code).body + + class CreateMatcher(BaseMatcher): def get_paddle_nodes(self, args, kwargs): kwargs = self.parse_kwargs(kwargs) @@ -1450,7 +1518,7 @@ def generate_code(self, kwargs): return GenericMatcher.generate_code(self, kwargs) -class ScatterReduceMatcher(BaseMatcher): +class ScatterReduceMatcher(BaseMatcher): def generate_aux_code(self): CODE_TEMPLATE = textwrap.dedent( """ @@ -1465,7 +1533,13 @@ def reduce_type(type): return CODE_TEMPLATE def generate_code(self, kwargs): - allowed_reduce_type = ['"""sum"""', '"""prod"""', '"""amax"""', '"""amin"""', '"""mean"""'] + allowed_reduce_type = [ + '"""sum"""', + '"""prod"""', + '"""amax"""', + '"""amin"""', + '"""mean"""', + ] reduce_mapping = {'"""sum"""': '"add"', '"""prod"""': '"multiply"'} if "reduce" in kwargs and kwargs["reduce"] in reduce_mapping: kwargs["reduce"] = reduce_mapping[kwargs["reduce"]] @@ -2369,11 +2443,10 @@ def generate_code(self, kwargs): return code -class ReverseMomentumMatcher(BaseMatcher): +class ReverseMatcher(BaseMatcher): def generate_code(self, kwargs): if "momentum" in kwargs: kwargs["momentum"] = f"1 - {kwargs.pop('momentum')}" - return GenericMatcher.generate_code(self, kwargs) @@ -3511,6 +3584,72 @@ def generate_code(self, kwargs): return code +class SpecialNdtrMatcher(BaseMatcher): + def generate_code(self, kwargs): + API_TEMPLATE = textwrap.dedent( + """ + (paddle.erf({}/paddle.sqrt(paddle.to_tensor(2.)))-paddle.erf(paddle.to_tensor(-float('inf'))))/2 + """ + ) + code = API_TEMPLATE.format(kwargs["input"]) + if "out" in kwargs and kwargs["out"] != "None": + code = "paddle.assign({}, output={})".format(code, kwargs["out"]) + + return code + + +class LinalgInvExMatcher(BaseMatcher): + def generate_code(self, kwargs): + if "out" in kwargs and kwargs["out"] != "None": + out_v = kwargs["out"] + API_TEMPLATE = textwrap.dedent( + """ + out1 = paddle.linalg.inv({}) + out2 = paddle.zeros({}.shape[:-2], dtype='int32') + paddle.assign(out1, output={}[0]), paddle.assign(out2, output={}[1]) + """ + ) + code = API_TEMPLATE.format(kwargs["A"], kwargs["A"], out_v, out_v) + return code + else: + API_TEMPLATE = textwrap.dedent( + """ + (paddle.linalg.inv({}), paddle.zeros({}.shape[:-2], dtype='int32')) + """ + ) + code = API_TEMPLATE.format(kwargs["A"], kwargs["A"]) + return code + + +class LinalgCholeskyExMatcher(BaseMatcher): + def generate_code(self, kwargs): + if "upper" not in kwargs: + kwargs["upper"] = False + if "out" in kwargs and kwargs["out"] != "None": + out_v = kwargs["out"] + API_TEMPLATE = textwrap.dedent( + """ + out1 = paddle.linalg.cholesky(x={}, upper={}) + out2 = paddle.zeros({}.shape[:-2], dtype='int32') + paddle.assign(out1, output={}[0]), paddle.assign(out2, output={}[1]) + """ + ) + code = API_TEMPLATE.format( + kwargs["input"], kwargs["upper"], kwargs["input"], out_v, out_v + ) + return code + else: + API_TEMPLATE = textwrap.dedent( + """ + (paddle.linalg.cholesky(x={}, upper={}), paddle.zeros({}.shape[:-2], dtype='int32')) + """ + ) + code = API_TEMPLATE.format( + kwargs["input"], kwargs["upper"], kwargs["input"] + ) + return code + + class AdjointMatcher(BaseMatcher): def generate_code(self, kwargs): if "input" not in kwargs: @@ -4430,7 +4569,7 @@ class CanCastMatcher(BaseMatcher): def generate_aux_code(self): CODE_TEMPLATE = textwrap.dedent( """ - def can_cast(from_, to): + def can_cast(from_, to): can_cast_dict = { 'bfloat16': { 'bfloat16': True, @@ -4606,6 +4745,7 @@ def can_cast(from_, to): """ ) return CODE_TEMPLATE + def get_paddle_nodes(self, args, kwargs): self.write_aux_code() new_args = self.parse_args(args) @@ -4615,12 +4755,12 @@ def get_paddle_nodes(self, args, kwargs): to_type = new_kwargs.get("to", new_args[1] if len(new_args) > 1 else None) code = can_cast_template.format(from_type, to_type) return ast.parse(code).body - + class PositiveMatcher(BaseMatcher): def generate_aux_code(self): CODE_TEMPLATE = textwrap.dedent( - """ + """ def positive(x): if x.dtype != paddle.bool: return x @@ -4630,31 +4770,33 @@ def positive(x): """ ) return CODE_TEMPLATE + def generate_code(self, kwargs): self.write_aux_code() if "input" in kwargs and kwargs["input"] is not None: code = "paddle_aux.positive({})".format(kwargs["input"]) - else : - code = "paddle_aux.positive({})".format(self.paddleClass) + else: + code = "paddle_aux.positive({})".format(self.paddleClass) return code class FloatPowerMatcher(BaseMatcher): def generate_aux_code(self): CODE_TEMPLATE = textwrap.dedent( - """ - def get_exponent(exponent): - return exponent.cast(paddle.float64) if isinstance(exponent, paddle.Tensor) else exponent - setattr(paddle, "get_exponent", get_exponent) - """ + """ + def get_exponent(exponent): + return exponent.cast(paddle.float64) if isinstance(exponent, paddle.Tensor) else exponent + setattr(paddle, "get_exponent", get_exponent) + """ ) return CODE_TEMPLATE - + def generate_code(self, kwargs): self.write_aux_code() - pow_expression = "paddle.pow({}.cast(paddle.float64), paddle_aux.get_exponent({}))".format( - kwargs["input"], - kwargs["exponent"] + pow_expression = ( + "paddle.pow({}.cast(paddle.float64), paddle_aux.get_exponent({}))".format( + kwargs["input"], kwargs["exponent"] + ) ) if "out" in kwargs and kwargs["out"] is not None: code = "paddle.assign({}, {})".format(pow_expression, kwargs["out"]) @@ -4662,6 +4804,7 @@ def generate_code(self, kwargs): code = pow_expression return code + class FloatPowerInplaceMatcher(BaseMatcher): def generate_code(self, kwargs): return "{}.cast_(paddle.float64).pow_({})".format( @@ -5180,3 +5323,86 @@ def generate_code(self, kwargs): self.kwargs_to_str(kwargs_bin_edges), ) return code + + +class FromBufferMatcher(BaseMatcher): + def generate_code(self, kwargs): + API_TEMPLATE = textwrap.dedent( + """ + import numpy as np + paddle.to_tensor(np.frombuffer(np.array({}), {})) + """ + ) + code = API_TEMPLATE.format(kwargs["buffer"], kwargs["dtype"]) + + return code + + +class GetNumThreadsMatcher(BaseMatcher): + def generate_code(self, kwargs): + API_TEMPLATE = textwrap.dedent( + """ + import os + os.getenv("CPU_NUM",1) + """ + ) + code = API_TEMPLATE.format() + return code + + +class GetNumInteropThreadsMatcher(BaseMatcher): + def generate_code(self, kwargs): + API_TEMPLATE = textwrap.dedent( + """ + import os + int(os.environ['OMP_NUM_THREADS']) + """ + ) + code = API_TEMPLATE.format() + return code + + +class SetNumInteropThreadsMatcher(BaseMatcher): + def generate_aux_code(self): + CODE_TEMPLATE = textwrap.dedent( + """ + import os + def _set_num_interop_threads(int): + os.environ['OMP_NUM_THREADS'] = str(int) + """ + ) + return CODE_TEMPLATE + + def generate_code(self, kwargs): + self.write_aux_code() + API_TEMPLATE = textwrap.dedent( + """ + paddle_aux._set_num_interop_threads({}) + """ + ) + code = API_TEMPLATE.format(kwargs["int"]) + + return code + + +class SetNumThreadsMatcher(BaseMatcher): + def generate_aux_code(self): + CODE_TEMPLATE = textwrap.dedent( + """ + import os + def _set_num_threads(int): + os.environ['CPU_NUM'] = str(int) + """ + ) + return CODE_TEMPLATE + + def generate_code(self, kwargs): + self.write_aux_code() + API_TEMPLATE = textwrap.dedent( + """ + paddle_aux._set_num_threads({}) + """ + ) + code = API_TEMPLATE.format(kwargs["int"]) + + return code