From 2a2c4bf00fbcdc59646ad8a268974a8aedcc00a4 Mon Sep 17 00:00:00 2001 From: kengz Date: Thu, 25 Jun 2020 20:53:46 -0700 Subject: [PATCH] allow arc.init to use dict for init kwargs --- README.md | 5 ++++- setup.py | 4 +++- test/test_module_builder.py | 1 + torcharc/arc_ref.py | 5 ++++- torcharc/module_builder.py | 25 ++++++++++++++++++------- 5 files changed, 30 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index 7f3c319..c7492d8 100644 --- a/README.md +++ b/README.md @@ -83,7 +83,10 @@ arc = { 'batch_norm': True, 'activation': 'ReLU', 'dropout': 0.2, - 'init': 'kaiming_uniform_', + 'init': { + 'type': 'normal_', + 'std': 0.01, + }, } model = torcharc.build(arc) diff --git a/setup.py b/setup.py index 3c9d5e2..c187952 100644 --- a/setup.py +++ b/setup.py @@ -35,7 +35,7 @@ def run_tests(self): setup( name='torcharc', - version='0.0.4', + version='0.0.5', description='Build PyTorch networks by specifying architectures.', long_description='https://github.com/kengz/torcharc', keywords='torcharc', @@ -57,6 +57,8 @@ def run_tests(self): extras_require={}, classifiers=[], tests_require=[ + 'autopep8==1.5.3', + 'flake8==3.8.3', 'flaky==3.6.1', 'pytest==5.4.1', 'pytest-cov==2.8.1', diff --git a/test/test_module_builder.py b/test/test_module_builder.py index 90aad5f..c92d7fd 100644 --- a/test/test_module_builder.py +++ b/test/test_module_builder.py @@ -88,6 +88,7 @@ def test_build_module(arc, nn_class): 'kaiming_uniform_', 'kaiming_normal_', 'orthogonal_', + {'type': 'normal_', 'std': 0.01}, # when using init kwargs ]) @pytest.mark.parametrize('activation', [ 'ReLU', diff --git a/torcharc/arc_ref.py b/torcharc/arc_ref.py index 2118350..9ec9223 100644 --- a/torcharc/arc_ref.py +++ b/torcharc/arc_ref.py @@ -43,7 +43,10 @@ 'batch_norm': True, 'activation': 'ReLU', 'dropout': 0.2, - 'init': 'kaiming_uniform_', + 'init': { + 'type': 'normal_', + 'std': 0.01, + }, }, 'tstransformer': { 'type': 'TSTransformer', diff --git a/torcharc/module_builder.py b/torcharc/module_builder.py index a6de618..3f4f494 100644 --- a/torcharc/module_builder.py +++ b/torcharc/module_builder.py @@ -24,23 +24,34 @@ setattr(torch.optim, 'RAdam', optim.RAdam) -def get_init_fn(init: str, activation: Optional[str] = None) -> Callable: +def get_init_fn(init: Union[str, dict], activation: Optional[str] = None) -> Callable: '''Get init function that can be called as `module.apply(init_fn)`. Initializes weights only. Internally this also takes care of gain and nonlinearity args. Ref: https://pytorch.org/docs/stable/nn.init.html''' def init_fn(module: nn.Module) -> None: - fn = getattr(nn.init, init) + if init is None: + return + elif ps.is_string(init): + init_type = init + init_kwargs = {} + else: + assert ps.is_dict(init) + init_type = init['type'] + init_kwargs = ps.omit(init, 'type') + fn = getattr(nn.init, init_type) args = inspect.getfullargspec(fn).args try: try: # first try with gain/activation args if 'gain' in args: gain = nn.init.calculate_gain(activation) - fn(module.weight, gain=gain) + ext_init_kwargs = {'gain': gain, **init_kwargs} + fn(module.weight, **ext_init_kwargs) elif 'nonlinearity' in args: - fn(module.weight, nonlinearity=activation) + ext_init_kwargs = {'nonlinearity': activation, **init_kwargs} + fn(module.weight, **ext_init_kwargs) else: - fn(module.weight) + fn(module.weight, **init_kwargs) except Exception: # first fallback to plain init - fn(module.weight) + fn(module.weight, **init_kwargs) except Exception: # second fallback: module weight cannot be initialized, ok pass return init_fn @@ -51,7 +62,7 @@ def build_module(arc: dict) -> nn.Module: if arc.get('layers'): # if given layers, build as sequential module = sequential.build(arc) else: - kwargs = ps.omit(arc, 'type', 'in_names') + kwargs = ps.omit(arc, 'type', 'in_names', 'init') module = getattr(nn, arc['type'])(**kwargs) # initialize weights if 'init' is given if arc.get('init'):