Merge pull request #5 from kengz/scalartst
Support scalar output for transformers
kengz authored Dec 30, 2020
2 parents f01b6cb + 5170718 commit 0be2905
Showing 5 changed files with 61 additions and 43 deletions.
74 changes: 40 additions & 34 deletions .github/workflows/ci.yml
@@ -2,43 +2,49 @@ name: CI
 
 on:
   push:
-    branches: [ master ]
+    branches: [main]
   pull_request:
-    branches: [ master ]
+    branches: [main]
 
 jobs:
   build:
     runs-on: ubuntu-latest
 
     steps:
-      - uses: actions/checkout@v2
-      - name: Cache Conda
-        uses: actions/cache@v1
-        with:
-          path: /usr/share/miniconda/envs/torcharc
-          key: ${{ runner.os }}-conda-${{ hashFiles('environment.yml') }}
-          restore-keys: |
-            ${{ runner.os }}-conda-
-      - name: Setup Conda dependencies
-        uses: goanpeca/setup-miniconda@v1
-        with:
-          activate-environment: torcharc
-          environment-file: environment.yml
-          python-version: 3.8
-          auto-activate-base: false
-      - name: Conda info
-        shell: bash -l {0}
-        run: |
-          conda info
-          conda list
-      - name: Setup flake8 annotations
-        uses: rbialon/flake8-annotations@v1
-      - name: Lint with flake8
-        shell: bash -l {0}
-        run: |
-          pip install flake8
-          # exit-zero treats all errors as warnings.
-          flake8 . --ignore=E501 --count --exit-zero --statistics
-      - name: Run tests
-        shell: bash -l {0}
-        run: |
-          python setup.py test
+      - uses: actions/checkout@v2
+
+      - name: Cache Conda
+        uses: actions/cache@v2
+        with:
+          path: /usr/share/miniconda/envs/torcharc
+          key: ${{ runner.os }}-conda-${{ hashFiles('environment.yml') }}
+          restore-keys: |
+            ${{ runner.os }}-conda-
+      - name: Setup Conda dependencies
+        uses: conda-incubator/setup-miniconda@v2
+        with:
+          activate-environment: torcharc
+          environment-file: environment.yml
+          python-version: 3.8
+          auto-activate-base: false
+
+      - name: Conda info
+        shell: bash -l {0}
+        run: |
+          conda info
+          conda list
+      - name: Setup flake8 annotations
+        uses: rbialon/flake8-annotations@v1
+      - name: Lint with flake8
+        shell: bash -l {0}
+        run: |
+          pip install flake8
+          # exit-zero treats all errors as warnings.
+          flake8 . --ignore=E501 --count --exit-zero --statistics
+      - name: Run tests
+        shell: bash -l {0}
+        run: |
+          python setup.py test
2 changes: 1 addition & 1 deletion setup.py
@@ -35,7 +35,7 @@ def run_tests(self):
 
 setup(
     name='torcharc',
-    version='0.0.5',
+    version='0.0.6',
     description='Build PyTorch networks by specifying architectures.',
     long_description='https://github.com/kengz/torcharc',
     keywords='torcharc',
12 changes: 9 additions & 3 deletions torcharc/module/transformer/pytorch_tst.py
@@ -47,6 +47,7 @@ def __init__(
         q: int = 8,  # Dimension of queries and keys.
         v: int = 8,  # Dimension of values.
         chunk_mode: bool = 'chunk',
+        scalar_output: bool = False,
     ) -> None:
         super().__init__()
 
@@ -66,6 +67,7 @@ def __init__(
         self.decoders = nn.TransformerDecoder(decoder_layer, num_layers=num_decoder_layers)
 
         self.out_linear = nn.Linear(d_model, out_channels)
+        self.scalar_output = scalar_output
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         '''
@@ -78,9 +80,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         encoding = self.encoders(encoding.transpose(0, 1))
 
         decoding = encoding
-        if self.pe is not None:  # position encoding
-            decoding = self.pe(decoding.transpose(0, 1)).transpose(0, 1)
-        decoding = self.decoders(decoding, encoding).transpose(0, 1)
+        if len(self.decoders.layers):
+            if self.pe is not None:  # position encoding
+                decoding = self.pe(decoding.transpose(0, 1)).transpose(0, 1)
+            decoding = self.decoders(decoding, encoding).transpose(0, 1)
+
+        if self.scalar_output:  # if want scalar instead of seq output, take the first index from seq
+            decoding = decoding[:, 0, :]
 
         output = self.out_linear(decoding)
         return output
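Illustrative sketch (not part of the commit): the effect of scalar_output on tensor shapes. Assuming the decoder output has shape (batch, seq_len, d_model), taking index 0 along the sequence dimension yields (batch, d_model), so out_linear returns one (batch, out_channels) vector per sample instead of a per-step sequence. The sizes below are arbitrary example values.

import torch
import torch.nn as nn

batch, seq_len, d_model, out_channels = 4, 16, 32, 2  # arbitrary example sizes
out_linear = nn.Linear(d_model, out_channels)
decoding = torch.randn(batch, seq_len, d_model)  # stand-in for the decoder output

seq_out = out_linear(decoding)              # (4, 16, 2): one output vector per time step
scalar_out = out_linear(decoding[:, 0, :])  # (4, 2): first seq index only, one vector per sample

assert seq_out.shape == (batch, seq_len, out_channels)
assert scalar_out.shape == (batch, out_channels)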
14 changes: 10 additions & 4 deletions torcharc/module/transformer/tst.py
@@ -49,6 +49,7 @@ def __init__(
         q: int = 8,  # Dimension of queries and keys.
         v: int = 8,  # Dimension of values.
         chunk_mode: bool = 'chunk',
+        scalar_output: bool = False,
     ) -> None:
         super().__init__()
 
@@ -77,6 +78,7 @@ def __init__(
             activation=activation,
             chunk_mode=chunk_mode) for _ in range(num_decoder_layers)])
         self.out_linear = nn.Linear(d_model, out_channels)
+        self.scalar_output = scalar_output
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         '''
@@ -89,10 +91,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         encoding = self.encoders(encoding)
 
         decoding = encoding
-        if self.pe is not None:  # position encoding
-            decoding = self.pe(decoding)
-        for layer in self.decoders:
-            decoding = layer(decoding, encoding)
+        if len(self.decoders):
+            if self.pe is not None:  # position encoding
+                decoding = self.pe(decoding)
+            for layer in self.decoders:
+                decoding = layer(decoding, encoding)
+
+        if self.scalar_output:  # if want scalar instead of seq output, take the first index from seq
+            decoding = decoding[:, 0, :]
 
         output = self.out_linear(decoding)
         return output
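A minimal sketch of the new guard, under the assumption that self.decoders is an nn.ModuleList: when num_decoder_layers is 0 the list is empty, len() returns 0, and the positional-encoding/decoder stage is skipped, so the encoder output passes straight to the output head. MockDecoderLayer and run_decoder_stage are hypothetical stand-ins, not torcharc code.

import torch
import torch.nn as nn

class MockDecoderLayer(nn.Module):
    '''Hypothetical stand-in for a transformer decoder layer.'''
    def forward(self, decoding, encoding):
        return decoding + encoding  # placeholder computation

def run_decoder_stage(encoding, decoders):
    decoding = encoding
    if len(decoders):  # skip the whole stage when there are no decoder layers
        for layer in decoders:
            decoding = layer(decoding, encoding)
    return decoding

encoding = torch.randn(4, 16, 32)
no_decoders = nn.ModuleList([])  # num_decoder_layers = 0: encoder-only
two_decoders = nn.ModuleList([MockDecoderLayer() for _ in range(2)])

assert torch.equal(run_decoder_stage(encoding, no_decoders), encoding)  # pass-through
assert run_decoder_stage(encoding, two_decoders).shape == encoding.shape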
2 changes: 1 addition & 1 deletion torcharc/module_builder.py
@@ -91,7 +91,7 @@ def infer_in_shape(arc: dict, xs: Union[torch.Tensor, NamedTuple]) -> None:
         arc.update(in_shape=in_shape)
     elif nn_type == 'FiLMMerge':
         assert ps.is_tuple(xs)
-        assert len(arc['in_names']) == 2, f'FiLMMerge in_names should only specify 2 keys for feature and conditioner'
+        assert len(arc['in_names']) == 2, 'FiLMMerge in_names should only specify 2 keys for feature and conditioner'
         shapes = {name: list(x.shape)[1:] for name, x in xs._asdict().items() if name in arc['in_names']}
         arc.update(shapes=shapes)
     else:
