Skip to content

Commit

Permalink
【Hackathon 7th No.36】为 Paddle 代码转换工具新增 API 转换规则(第 3 组)PART 2 (#484)
Browse files Browse the repository at this point in the history
* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix
  • Loading branch information
enkilee authored Oct 12, 2024
1 parent c06c5ab commit da330a9
Show file tree
Hide file tree
Showing 10 changed files with 813 additions and 12 deletions.
126 changes: 116 additions & 10 deletions paconvert/api_mapping.json
Original file line number Diff line number Diff line change
Expand Up @@ -4996,6 +4996,29 @@
"other": "y"
}
},
"torch.blackman_window": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.audio.functional.get_window",
"min_input_args": 1,
"args_list": [
"window_length",
"periodic",
"*",
"dtype",
"layout",
"device",
"requires_grad"
],
"kwargs_change": {
"window_length": "win_length",
"periodic": "fftbins",
"dtype": "dtype"
},
"paddle_default_kwargs": {
"dtype": "'float32'",
"window": "'blackman'"
}
},
"torch.bmm": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.bmm",
Expand Down Expand Up @@ -7504,6 +7527,17 @@
"ndarray": "data"
}
},
"torch.frombuffer": {
"Matcher": "FromBufferMatcher",
"args_list": [
"buffer",
"*",
"dtype",
"count",
"offset",
"requires_grad"
]
},
"torch.full": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.full",
Expand Down Expand Up @@ -7609,6 +7643,12 @@
"Matcher": "GenericMatcher",
"paddle_api": "paddle.get_default_dtype"
},
"torch.get_num_interop_threads": {
"Matcher": "GetNumInteropThreadsMatcher"
},
"torch.get_num_threads": {
"Matcher": "GetNumThreadsMatcher"
},
"torch.get_rng_state": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.get_rng_state",
Expand Down Expand Up @@ -7647,6 +7687,58 @@
"out"
]
},
"torch.hamming_window": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.audio.functional.get_window",
"min_input_args": 1,
"args_list": [
"window_length",
"periodic",
"alpha",
"beta",
"*",
"dtype",
"layout",
"device",
"requires_grad"
],
"unsupport_args": [
"alpha",
"beta"
],
"kwargs_change": {
"window_length": "win_length",
"periodic": "fftbins",
"dtype": "dtype"
},
"paddle_default_kwargs": {
"dtype": "'float32'",
"window": "'hamming'"
}
},
"torch.hann_window": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.audio.functional.get_window",
"min_input_args": 1,
"args_list": [
"window_length",
"periodic",
"*",
"dtype",
"layout",
"device",
"requires_grad"
],
"kwargs_change": {
"window_length": "win_length",
"periodic": "fftbins",
"dtype": "dtype"
},
"paddle_default_kwargs": {
"dtype": "'float32'",
"window": "'hann'"
}
},
"torch.heaviside": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.heaviside",
Expand Down Expand Up @@ -9576,7 +9668,7 @@
]
},
"torch.nn.BatchNorm1d": {
"Matcher": "ReverseMomentumMatcher",
"Matcher": "ReverseMatcher",
"paddle_api": "paddle.nn.BatchNorm1D",
"min_input_args": 1,
"args_list": [
Expand All @@ -9600,7 +9692,7 @@
}
},
"torch.nn.BatchNorm2d": {
"Matcher": "ReverseMomentumMatcher",
"Matcher": "ReverseMatcher",
"paddle_api": "paddle.nn.BatchNorm2D",
"min_input_args": 1,
"args_list": [
Expand All @@ -9624,7 +9716,7 @@
}
},
"torch.nn.BatchNorm3d": {
"Matcher": "ReverseMomentumMatcher",
"Matcher": "ReverseMatcher",
"paddle_api": "paddle.nn.BatchNorm3D",
"min_input_args": 1,
"args_list": [
Expand Down Expand Up @@ -10177,7 +10269,7 @@
]
},
"torch.nn.InstanceNorm1d": {
"Matcher": "ReverseMomentumMatcher",
"Matcher": "ReverseMatcher",
"paddle_api": "paddle.nn.InstanceNorm1D",
"min_input_args": 1,
"args_list": [
Expand All @@ -10201,7 +10293,7 @@
}
},
"torch.nn.InstanceNorm2d": {
"Matcher": "ReverseMomentumMatcher",
"Matcher": "ReverseMatcher",
"paddle_api": "paddle.nn.InstanceNorm2D",
"min_input_args": 1,
"args_list": [
Expand All @@ -10225,7 +10317,7 @@
}
},
"torch.nn.InstanceNorm3d": {
"Matcher": "ReverseMomentumMatcher",
"Matcher": "ReverseMatcher",
"paddle_api": "paddle.nn.InstanceNorm3D",
"min_input_args": 1,
"args_list": [
Expand Down Expand Up @@ -11263,7 +11355,7 @@
"min_input_args": 0
},
"torch.nn.SyncBatchNorm": {
"Matcher": "ReverseMomentumMatcher",
"Matcher": "ReverseMatcher",
"paddle_api": "paddle.nn.SyncBatchNorm",
"args_list": [
"num_features",
Expand Down Expand Up @@ -11684,7 +11776,7 @@
]
},
"torch.nn.functional.batch_norm": {
"Matcher": "ReverseMomentumMatcher",
"Matcher": "ReverseMatcher",
"paddle_api": "paddle.nn.functional.batch_norm",
"args_list": [
"input",
Expand Down Expand Up @@ -12243,7 +12335,7 @@
"min_input_args": 2
},
"torch.nn.functional.instance_norm": {
"Matcher": "ReverseMomentumMatcher",
"Matcher": "ReverseMatcher",
"paddle_api": "paddle.nn.functional.instance_norm",
"args_list": [
"input",
Expand Down Expand Up @@ -13265,7 +13357,7 @@
}
},
"torch.nn.modules.batchnorm._BatchNorm": {
"Matcher": "ReverseMomentumMatcher",
"Matcher": "ReverseMatcher",
"paddle_api": "paddle.nn.layer.norm._BatchNormBase",
"args_list": [
"num_features",
Expand Down Expand Up @@ -14723,6 +14815,20 @@
"mode"
]
},
"torch.set_num_interop_threads": {
"Matcher": "SetNumInteropThreadsMatcher",
"min_input_args": 1,
"args_list": [
"int"
]
},
"torch.set_num_threads": {
"Matcher": "SetNumThreadsMatcher",
"min_input_args": 1,
"args_list": [
"int"
]
},
"torch.set_printoptions": {
"Matcher": "SetPrintOptionsMatcher",
"paddle_api": "paddle.set_printoptions",
Expand Down
86 changes: 84 additions & 2 deletions paconvert/api_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -2344,11 +2344,10 @@ def generate_code(self, kwargs):
return code


class ReverseMomentumMatcher(BaseMatcher):
class ReverseMatcher(BaseMatcher):
def generate_code(self, kwargs):
if "momentum" in kwargs:
kwargs["momentum"] = f"1 - {kwargs.pop('momentum')}"

return GenericMatcher.generate_code(self, kwargs)


Expand Down Expand Up @@ -4919,3 +4918,86 @@ def generate_code(self, kwargs):
self.kwargs_to_str(kwargs_bin_edges),
)
return code


class FromBufferMatcher(BaseMatcher):
def generate_code(self, kwargs):
API_TEMPLATE = textwrap.dedent(
"""
import numpy as np
paddle.to_tensor(np.frombuffer(np.array({}), {}))
"""
)
code = API_TEMPLATE.format(kwargs["buffer"], kwargs["dtype"])

return code


class GetNumThreadsMatcher(BaseMatcher):
def generate_code(self, kwargs):
API_TEMPLATE = textwrap.dedent(
"""
import os
os.getenv("CPU_NUM",1)
"""
)
code = API_TEMPLATE.format()
return code


class GetNumInteropThreadsMatcher(BaseMatcher):
def generate_code(self, kwargs):
API_TEMPLATE = textwrap.dedent(
"""
import os
int(os.environ['OMP_NUM_THREADS'])
"""
)
code = API_TEMPLATE.format()
return code


class SetNumInteropThreadsMatcher(BaseMatcher):
def generate_aux_code(self):
CODE_TEMPLATE = textwrap.dedent(
"""
import os
def _set_num_interop_threads(int):
os.environ['OMP_NUM_THREADS'] = str(int)
"""
)
return CODE_TEMPLATE

def generate_code(self, kwargs):
self.write_aux_code()
API_TEMPLATE = textwrap.dedent(
"""
paddle_aux._set_num_interop_threads({})
"""
)
code = API_TEMPLATE.format(kwargs["int"])

return code


class SetNumThreadsMatcher(BaseMatcher):
def generate_aux_code(self):
CODE_TEMPLATE = textwrap.dedent(
"""
import os
def _set_num_threads(int):
os.environ['CPU_NUM'] = str(int)
"""
)
return CODE_TEMPLATE

def generate_code(self, kwargs):
self.write_aux_code()
API_TEMPLATE = textwrap.dedent(
"""
paddle_aux._set_num_threads({})
"""
)
code = API_TEMPLATE.format(kwargs["int"])

return code
Loading

0 comments on commit da330a9

Please sign in to comment.