
Commit

Merge pull request #4 from kengz/initfn
allow arc.init to use dict for init kwargs
kengz authored Jun 26, 2020
2 parents 895a42e + 2a2c4bf commit f01b6cb
Showing 5 changed files with 30 additions and 10 deletions.
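In short, the 'init' field of an arc spec may now be either a string naming a torch.nn.init function, or a dict whose 'type' names the function and whose remaining keys are forwarded to it as kwargs. A minimal before/after sketch (the Linear fields other than 'init' are illustrative and not part of this commit):

import torcharc

# string form: call the named nn.init function with its default arguments
arc = {
    'type': 'Linear',
    'in_features': 8,
    'out_features': 4,
    'init': 'kaiming_uniform_',
}

# dict form (new in this commit): 'type' picks the nn.init function,
# the remaining keys ('std' here) are forwarded to it as kwargs
arc = {
    'type': 'Linear',
    'in_features': 8,
    'out_features': 4,
    'init': {
        'type': 'normal_',
        'std': 0.01,
    },
}

model = torcharc.build(arc)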
5 changes: 4 additions & 1 deletion README.md
@@ -83,7 +83,10 @@ arc = {
     'batch_norm': True,
     'activation': 'ReLU',
     'dropout': 0.2,
-    'init': 'kaiming_uniform_',
+    'init': {
+        'type': 'normal_',
+        'std': 0.01,
+    },
 }
 model = torcharc.build(arc)

4 changes: 3 additions & 1 deletion setup.py
@@ -35,7 +35,7 @@ def run_tests(self):

 setup(
     name='torcharc',
-    version='0.0.4',
+    version='0.0.5',
     description='Build PyTorch networks by specifying architectures.',
     long_description='https://github.com/kengz/torcharc',
     keywords='torcharc',
@@ -57,6 +57,8 @@ def run_tests(self):
     extras_require={},
     classifiers=[],
     tests_require=[
+        'autopep8==1.5.3',
+        'flake8==3.8.3',
         'flaky==3.6.1',
         'pytest==5.4.1',
         'pytest-cov==2.8.1',

1 change: 1 addition & 0 deletions test/test_module_builder.py
@@ -88,6 +88,7 @@ def test_build_module(arc, nn_class):
     'kaiming_uniform_',
     'kaiming_normal_',
     'orthogonal_',
+    {'type': 'normal_', 'std': 0.01}, # when using init kwargs
 ])
 @pytest.mark.parametrize('activation', [
     'ReLU',

5 changes: 4 additions & 1 deletion torcharc/arc_ref.py
@@ -43,7 +43,10 @@
         'batch_norm': True,
         'activation': 'ReLU',
         'dropout': 0.2,
-        'init': 'kaiming_uniform_',
+        'init': {
+            'type': 'normal_',
+            'std': 0.01,
+        },
     },
     'tstransformer': {
         'type': 'TSTransformer',

25 changes: 18 additions & 7 deletions torcharc/module_builder.py
@@ -24,23 +24,34 @@
 setattr(torch.optim, 'RAdam', optim.RAdam)


-def get_init_fn(init: str, activation: Optional[str] = None) -> Callable:
+def get_init_fn(init: Union[str, dict], activation: Optional[str] = None) -> Callable:
     '''Get init function that can be called as `module.apply(init_fn)`. Initializes weights only. Internally this also takes care of gain and nonlinearity args. Ref: https://pytorch.org/docs/stable/nn.init.html'''
     def init_fn(module: nn.Module) -> None:
-        fn = getattr(nn.init, init)
+        if init is None:
+            return
+        elif ps.is_string(init):
+            init_type = init
+            init_kwargs = {}
+        else:
+            assert ps.is_dict(init)
+            init_type = init['type']
+            init_kwargs = ps.omit(init, 'type')
+        fn = getattr(nn.init, init_type)
         args = inspect.getfullargspec(fn).args
         try:
             try:
                 # first try with gain/activation args
                 if 'gain' in args:
                     gain = nn.init.calculate_gain(activation)
-                    fn(module.weight, gain=gain)
+                    ext_init_kwargs = {'gain': gain, **init_kwargs}
+                    fn(module.weight, **ext_init_kwargs)
                 elif 'nonlinearity' in args:
-                    fn(module.weight, nonlinearity=activation)
+                    ext_init_kwargs = {'nonlinearity': activation, **init_kwargs}
+                    fn(module.weight, **ext_init_kwargs)
                 else:
-                    fn(module.weight)
+                    fn(module.weight, **init_kwargs)
             except Exception: # first fallback to plain init
-                fn(module.weight)
+                fn(module.weight, **init_kwargs)
         except Exception: # second fallback: module weight cannot be initialized, ok
             pass
     return init_fn
@@ -51,7 +62,7 @@ def build_module(arc: dict) -> nn.Module:
     if arc.get('layers'): # if given layers, build as sequential
         module = sequential.build(arc)
     else:
-        kwargs = ps.omit(arc, 'type', 'in_names')
+        kwargs = ps.omit(arc, 'type', 'in_names', 'init')
         module = getattr(nn, arc['type'])(**kwargs)
     # initialize weights if 'init' is given
     if arc.get('init'):
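For reference, a minimal sketch (not part of the commit) of how the updated get_init_fn behaves with the dict form, assuming torcharc is importable from this revision; the init spec mirrors the new test case:

import torch.nn as nn
from torcharc import module_builder

layer = nn.Linear(8, 4)
# dict init spec: 'type' selects nn.init.normal_, 'std' is passed through as a kwarg
init_fn = module_builder.get_init_fn({'type': 'normal_', 'std': 0.01}, activation='ReLU')
layer.apply(init_fn)  # ends up calling nn.init.normal_(layer.weight, std=0.01)
print(layer.weight.std())  # roughly 0.01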
