From 16344289d2038d605747b095a690d50cfa64b3e9 Mon Sep 17 00:00:00 2001
From: kengz
Date: Sat, 26 Jun 2021 11:11:53 -0700
Subject: [PATCH 1/2] feat(dict): replace TensorTuple with tensor dict
- replace the namedtuple TensorTuple with a plain dict of tensors to make modules pickle-safe (see the sketch below)
- use native Python functions instead of pydash to reduce Python-level overhead (the heavy lifting stays in torch)
BREAKING CHANGE: TensorTuple is replaced with a dict of tensors
---
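A minimal sketch of the pickle problem motivating the change, using hypothetical TupleFork/DictFork stand-ins for the real fork modules in torcharc/module/fork.py:

```python
# sketch: a namedtuple class created at runtime is not importable,
# so pickle cannot serialize an object that holds one
import pickle
from collections import namedtuple


class TupleFork:  # hypothetical stand-in for the old ReuseFork
    def __init__(self, names):
        self.TensorTuple = namedtuple('TensorTuple', names)  # dynamic class


class DictFork:  # hypothetical stand-in for the new ReuseFork
    def __init__(self, names):
        self.names = names  # plain data only


try:
    pickle.dumps(TupleFork(['image', 'vector']))
except pickle.PicklingError as e:
    print(f'namedtuple fork fails to pickle: {e}')

pickle.dumps(DictFork(['image', 'vector']))  # pickles without issue
```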
README.md | 10 +++----
setup.py | 2 +-
test/module/test_dag.py | 18 +++++--------
test/module/test_fork.py | 9 +++----
test/module/test_merge.py | 4 +--
test/test_module_builder.py | 15 +++++------
test/test_net_util.py | 5 ++--
torcharc/__init__.py | 1 -
torcharc/module/dag.py | 8 +++---
torcharc/module/fork.py | 17 +++++-------
torcharc/module/merge.py | 14 +++++-----
torcharc/module_builder.py | 53 ++++++++++++++++++-------------------
torcharc/net_util.py | 20 +++++---------
13 files changed, 76 insertions(+), 100 deletions(-)
diff --git a/README.md b/README.md
index c7492d8..9d58982 100644
--- a/README.md
+++ b/README.md
@@ -359,10 +359,8 @@ model = torcharc.build(arc)
batch_size = 16
dag_in_shape = arc['dag_in_shape']
-data = {'image': torch.rand([batch_size, *dag_in_shape['image']]), 'vector': torch.rand([batch_size, *dag_in_shape['vector']])}
-# convert from a dict of Tensors into a TensorTuple - a namedtuple
-xs = torcharc.to_namedtuple(data)
-# returns TensorTuple if output is multi-model, Tensor otherwise
+xs = {'image': torch.rand([batch_size, *dag_in_shape['image']]), 'vector': torch.rand([batch_size, *dag_in_shape['vector']])}
+# returns dict if output is multi-model, Tensor otherwise
ys = model(xs)
```
@@ -406,9 +404,9 @@ DAGNet(
-DAG module accepts a `TensorTuple` (example below) as input, and the module selects its input by matching its own name in the arc and the `in_name`, then carry forward the output together with any unconsumed inputs.
+DAG module accepts a `dict` (example below) as input, and the module selects its input by matching its own name in the arc and the `in_name`, then carries the output forward together with any unconsumed inputs.
-For example, the input `xs` with keys `image, vector` passes through the first `image` module, and the output becomes `TensorTuple(image=image_module(xs.image), vector=xs.vector)`. This is then passed through the remainder of the modules in the arc as declared.
+For example, the input `xs` with keys `image, vector` passes through the first `image` module, and the output becomes `{'image': image_module(xs.image), 'vector': xs.vector}`. This is then passed through the remainder of the modules in the arc as declared.
## Development
diff --git a/setup.py b/setup.py
index 9258551..6c6b74b 100644
--- a/setup.py
+++ b/setup.py
@@ -35,7 +35,7 @@ def run_tests(self):
setup(
name='torcharc',
- version='0.0.6',
+ version='1.0.0',
description='Build PyTorch networks by specifying architectures.',
long_description='https://github.com/kengz/torcharc',
keywords='torcharc',
diff --git a/test/module/test_dag.py b/test/module/test_dag.py
index f71576d..596729e 100644
--- a/test/module/test_dag.py
+++ b/test/module/test_dag.py
@@ -1,6 +1,5 @@
from torcharc import arc_ref, net_util
from torcharc.module import dag
-import pydash as ps
import torch
@@ -37,7 +36,7 @@ def test_dag_reusefork():
xs = net_util.get_rand_tensor(in_shapes)
model = dag.DAGNet(arc)
ys = model(xs)
- assert ps.is_tuple(ys)
+ assert isinstance(ys, dict)
def test_dag_splitfork():
@@ -46,7 +45,7 @@ def test_dag_splitfork():
xs = net_util.get_rand_tensor(in_shapes)
model = dag.DAGNet(arc)
ys = model(xs)
- assert ps.is_tuple(ys)
+ assert isinstance(ys, dict)
def test_dag_merge_fork():
@@ -54,9 +53,8 @@ def test_dag_merge_fork():
in_shapes = arc['dag_in_shape']
xs = net_util.get_rand_tensor(in_shapes)
model = dag.DAGNet(arc)
- ys = model(xs._asdict()) # test dict input for tracing
ys = model(xs)
- assert ps.is_tuple(ys)
+ assert isinstance(ys, dict)
def test_dag_fork_merge():
@@ -74,7 +72,7 @@ def test_dag_reuse_fork_forward():
xs = net_util.get_rand_tensor(in_shapes)
model = dag.DAGNet(arc)
ys = model(xs)
- assert ps.is_tuple(ys)
+ assert isinstance(ys, dict)
def test_dag_split_fork_forward():
@@ -83,7 +81,7 @@ def test_dag_split_fork_forward():
xs = net_util.get_rand_tensor(in_shapes)
model = dag.DAGNet(arc)
ys = model(xs)
- assert ps.is_tuple(ys)
+ assert isinstance(ys, dict)
def test_dag_merge_forward_split():
@@ -91,9 +89,8 @@ def test_dag_merge_forward_split():
in_shapes = arc['dag_in_shape']
xs = net_util.get_rand_tensor(in_shapes)
model = dag.DAGNet(arc)
- ys = model(xs._asdict()) # test dict input for tracing
ys = model(xs)
- assert ps.is_tuple(ys)
+ assert isinstance(ys, dict)
def test_dag_hydra():
@@ -101,6 +98,5 @@ def test_dag_hydra():
in_shapes = arc['dag_in_shape']
xs = net_util.get_rand_tensor(in_shapes)
model = dag.DAGNet(arc)
- ys = model(xs._asdict()) # test dict input for tracing
ys = model(xs)
- assert ps.is_tuple(ys)
+ assert isinstance(ys, dict)
diff --git a/test/module/test_fork.py b/test/module/test_fork.py
index 3a63317..f92c2df 100644
--- a/test/module/test_fork.py
+++ b/test/module/test_fork.py
@@ -1,5 +1,4 @@
from torcharc.module.fork import Fork, ReuseFork, SplitFork
-import pydash as ps
import pytest
import torch
@@ -14,8 +13,8 @@ def test_reuse_fork(names, x):
fork = ReuseFork(names)
assert isinstance(fork, Fork)
ys = fork(x)
- assert ps.is_tuple(ys)
- assert ys._fields == tuple(names)
+ assert isinstance(ys, dict)
+ assert list(ys) == names
@pytest.mark.parametrize('shapes,x', [
@@ -28,5 +27,5 @@ def test_split_fork(shapes, x):
fork = SplitFork(shapes)
assert isinstance(fork, Fork)
ys = fork(x)
- assert ps.is_tuple(ys)
- assert ys._fields == tuple(shapes.keys())
+ assert isinstance(ys, dict)
+ assert ys.keys() == shapes.keys()
diff --git a/test/module/test_merge.py b/test/module/test_merge.py
index a8b99d1..767543b 100644
--- a/test/module/test_merge.py
+++ b/test/module/test_merge.py
@@ -16,7 +16,7 @@
def test_concat_merge(xs, out_shape):
merge = ConcatMerge()
assert isinstance(merge, Merge)
- y = merge(net_util.to_namedtuple(xs))
+ y = merge(xs)
assert y.shape == torch.Size(out_shape)
@@ -56,5 +56,5 @@ def test_film_affine_transform(feature):
def test_film_merge(names, shapes, xs):
merge = FiLMMerge(names, shapes)
assert isinstance(merge, Merge)
- y = merge(net_util.to_namedtuple(xs))
+ y = merge(xs)
assert y.shape == xs[names['feature']].shape
diff --git a/test/test_module_builder.py b/test/test_module_builder.py
index c92d7fd..3ad10aa 100644
--- a/test/test_module_builder.py
+++ b/test/test_module_builder.py
@@ -1,7 +1,6 @@
from fixture.net import CONV1D_ARC, CONV2D_ARC, CONV3D_ARC, LINEAR_ARC
from torcharc import module_builder, net_util
from torch import nn
-import pydash as ps
import pytest
import torch
@@ -27,7 +26,7 @@
])
def test_get_init_fn(init, activation):
init_fn = module_builder.get_init_fn(init, activation)
- assert ps.is_function(init_fn)
+ assert callable(init_fn)
@pytest.mark.parametrize('arc,nn_class', [
@@ -248,11 +247,11 @@ def test_carry_forward_tensor(arc, xs):
net_util.get_rand_tensor({'vector': [LINEAR_ARC['in_features']], 'image': CONV2D_ARC['in_shape']}),
)
])
-def test_carry_forward_tensor_tuple_default(arc, xs):
+def test_carry_forward_dict_default(arc, xs):
module = module_builder.build_module(arc)
- assert ps.is_tuple(xs)
+ assert isinstance(xs, dict)
ys = module_builder.carry_forward(module, xs)
- assert ps.is_tuple(ys)
+ assert isinstance(ys, dict)
@pytest.mark.parametrize('xs', [
@@ -281,8 +280,8 @@ def test_carry_forward_tensor_tuple_default(arc, xs):
['image', 'vector'],
)
])
-def test_carry_forward_tensor_tuple(arc, xs, in_names):
+def test_carry_forward_dict(arc, xs, in_names):
module = module_builder.build_module(arc)
- assert ps.is_tuple(xs)
+ assert isinstance(xs, dict)
ys = module_builder.carry_forward(module, xs, in_names)
- assert isinstance(ys, (torch.Tensor, tuple))
+ assert isinstance(ys, (torch.Tensor, dict))
diff --git a/test/test_net_util.py b/test/test_net_util.py
index d316253..c1dcea9 100644
--- a/test/test_net_util.py
+++ b/test/test_net_util.py
@@ -1,7 +1,6 @@
from fixture.net import CONV1D_ARC, CONV2D_ARC, CONV3D_ARC, LINEAR_ARC
from torcharc import module_builder, net_util
from torch import nn
-import pydash as ps
import pytest
import torch
@@ -39,7 +38,7 @@ def test_get_rand_tensor(shape, batch_size, tensor_shape):
])
def test_get_rand_tensor_dict(shapes, batch_size, tensor_shapes):
xs = net_util.get_rand_tensor(shapes, batch_size)
- assert ps.is_tuple(xs)
+ assert isinstance(xs, dict)
for name, tensor_shape in tensor_shapes.items():
- x = getattr(xs, name)
+ x = xs[name]
assert list(x.shape) == tensor_shape
diff --git a/torcharc/__init__.py b/torcharc/__init__.py
index 055cc96..6fd4fd9 100644
--- a/torcharc/__init__.py
+++ b/torcharc/__init__.py
@@ -1,5 +1,4 @@
from torcharc import module_builder
-from torcharc.net_util import to_namedtuple
from torcharc.module import dag
from torch import nn
diff --git a/torcharc/module/dag.py b/torcharc/module/dag.py
index 225aa29..28791c6 100644
--- a/torcharc/module/dag.py
+++ b/torcharc/module/dag.py
@@ -1,7 +1,7 @@
# build DAG of nn modules
from torcharc import module_builder, net_util
from torch import nn
-from typing import NamedTuple, Union
+from typing import Union
import pydash as ps
import torch
@@ -30,10 +30,8 @@ def __init__(self, arc: dict) -> None:
xs = module_builder.carry_forward(module, xs, m_arc.get('in_names'))
self.module_dict.update({name: module})
- def forward(self, xs: Union[torch.Tensor, NamedTuple]) -> Union[torch.Tensor, NamedTuple]:
- # jit.trace will spread args on encountering a namedtuple, thus xs needs to be passed as dict then converted back into namedtuple
- if ps.is_dict(xs): # guard to convert dict xs into namedtuple
- xs = net_util.to_namedtuple(xs)
+ def forward(self, xs: Union[torch.Tensor, dict]) -> Union[torch.Tensor, dict]:
+ # dict xs is safe for jit.trace (a namedtuple would be spread into positional args)
for name, module in self.module_dict.items():
m_arc = self.arc[name]
xs = module_builder.carry_forward(module, xs, m_arc.get('in_names'))
diff --git a/torcharc/module/fork.py b/torcharc/module/fork.py
index 5c972db..284e942 100644
--- a/torcharc/module/fork.py
+++ b/torcharc/module/fork.py
@@ -1,7 +1,6 @@
from abc import ABC, abstractmethod
-from collections import namedtuple
from torch import nn
-from typing import Dict, List, NamedTuple
+from typing import Dict, List
import pydash as ps
import torch
@@ -10,21 +9,20 @@ class Fork(ABC, nn.Module):
'''A Fork module forks one tensor into a dict of multiple tensors.'''
@abstractmethod
- def forward(self, x: torch.Tensor) -> NamedTuple: # pragma: no cover
+ def forward(self, x: torch.Tensor) -> dict: # pragma: no cover
raise NotImplementedError
class ReuseFork(Fork):
- '''Fork layer to reuse a tensor multiple times via ref in TensorTuple'''
+ '''Fork layer to reuse a tensor multiple times via references in a dict'''
def __init__(self, names: List[str]) -> None:
super().__init__()
self.names = names
self.num_reuse = len(names)
- self.TensorTuple = namedtuple('TensorTuple', names)
- def forward(self, x: torch.Tensor) -> NamedTuple:
- return self.TensorTuple(*[x] * self.num_reuse)
+ def forward(self, x: torch.Tensor) -> dict:
+ return dict(zip(self.names, [x] * self.num_reuse))
class SplitFork(Fork):
@@ -34,7 +32,6 @@ def __init__(self, shapes: Dict[str, List[int]]) -> None:
super().__init__()
self.shapes = shapes
self.split_size = ps.flatten(self.shapes.values())
- self.TensorTuple = namedtuple('TensorTuple', shapes.keys())
- def forward(self, x: torch.Tensor) -> NamedTuple:
- return self.TensorTuple(*x.split(self.split_size, dim=1))
+ def forward(self, x: torch.Tensor) -> dict:
+ return dict(zip(self.shapes, x.split(self.split_size, dim=1)))
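For reference, a quick sketch of what the two forks now return (tensor sizes chosen arbitrarily):

```python
# sketch: both Fork variants now return plain dicts keyed by name
import torch
from torcharc.module.fork import ReuseFork, SplitFork

x = torch.rand(4, 6)
reused = ReuseFork(['a', 'b'])(x)           # {'a': x, 'b': x}, same tensor twice
split = SplitFork({'a': [2], 'b': [4]})(x)  # split along dim=1 into (4, 2) and (4, 4)
assert list(split) == ['a', 'b'] and split['a'].shape == (4, 2)
```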
diff --git a/torcharc/module/merge.py b/torcharc/module/merge.py
index b639f63..106647c 100644
--- a/torcharc/module/merge.py
+++ b/torcharc/module/merge.py
@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from torch import nn
-from typing import Dict, List, NamedTuple
+from typing import Dict, List
import torch
@@ -8,15 +8,15 @@ class Merge(ABC, nn.Module):
'''A Merge module merges a dict of tensors into one tensor'''
@abstractmethod
- def forward(self, xs: NamedTuple) -> torch.Tensor: # pragma: no cover
+ def forward(self, xs: dict) -> torch.Tensor: # pragma: no cover
raise NotImplementedError
class ConcatMerge(Merge):
'''Merge layer to merge a dict of tensors by concatenating along dim=1. Reverse of Split'''
- def forward(self, xs: NamedTuple) -> torch.Tensor:
- return torch.cat(xs, dim=1)
+ def forward(self, xs: dict) -> torch.Tensor:
+ return torch.cat(list(xs.values()), dim=1)
class FiLMMerge(Merge):
@@ -43,10 +43,10 @@ def affine_transform(cls, feature: torch.Tensor, conditioner_scale: torch.Tensor
view_shape = list(conditioner_scale.shape) + [1] * (feature.dim() - conditioner_scale.dim())
return conditioner_scale.view(*view_shape) * feature + conditioner_shift.view(*view_shape)
- def forward(self, xs: NamedTuple) -> torch.Tensor:
+ def forward(self, xs: dict) -> torch.Tensor:
'''Apply FiLM affine transform on feature using conditioner'''
- feature = getattr(xs, self.feature_name)
- conditioner = getattr(xs, self.conditioner_name)
+ feature = xs[self.feature_name]
+ conditioner = xs[self.conditioner_name]
conditioner_scale = self.conditioner_scale(conditioner)
conditioner_shift = self.conditioner_shift(conditioner)
return self.affine_transform(feature, conditioner_scale, conditioner_shift)
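As a sanity check on the broadcasting in `affine_transform` above, a standalone sketch with arbitrary dimensions:

```python
# sketch: FiLM modulates a feature map with per-channel scale and shift
import torch

feature = torch.rand(4, 8, 10, 10)  # e.g. a conv feature map
scale = torch.rand(4, 8)            # stand-in for conditioner_scale(conditioner)
shift = torch.rand(4, 8)            # stand-in for conditioner_shift(conditioner)

# pad (4, 8) to (4, 8, 1, 1) so it broadcasts over the spatial dims
view_shape = list(scale.shape) + [1] * (feature.dim() - scale.dim())
y = scale.view(*view_shape) * feature + shift.view(*view_shape)
assert y.shape == feature.shape
```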
diff --git a/torcharc/module_builder.py b/torcharc/module_builder.py
index dee25c6..84887b2 100644
--- a/torcharc/module_builder.py
+++ b/torcharc/module_builder.py
@@ -1,10 +1,9 @@
# build neural networks modularly
from torch import nn
-from torcharc import net_util
from torcharc import optim
from torcharc.module import fork, merge, sequential
from torcharc.module.transformer import pytorch_tst, tst
-from typing import Callable, List, Optional, NamedTuple, Union
+from typing import Callable, List, Optional, Union
import inspect
import pydash as ps
import torch
@@ -29,11 +28,11 @@ def get_init_fn(init: Union[str, dict], activation: Optional[str] = None) -> Cal
def init_fn(module: nn.Module) -> None:
if init is None:
return
- elif ps.is_string(init):
+ elif isinstance(init, str):
init_type = init
init_kwargs = {}
else:
- assert ps.is_dict(init)
+ assert isinstance(init, dict)
init_type = init['type']
init_kwargs = ps.omit(init, 'type')
fn = getattr(nn.init, init_type)
@@ -70,60 +69,60 @@ def build_module(arc: dict) -> nn.Module:
return module
-def infer_in_shape(arc: dict, xs: Union[torch.Tensor, NamedTuple]) -> None:
+def infer_in_shape(arc: dict, xs: Union[torch.Tensor, dict]) -> None:
'''Infer the input shape(s) for arc depending on its type and the input tensor. This updates the arc with the appropriate key.'''
nn_type = arc['type']
if nn_type == 'Linear':
- if ps.is_tuple(xs):
- in_names = arc.get('in_names', xs._fields[:1])
- xs = getattr(xs, in_names[0])
+ if isinstance(xs, dict):
+ in_names = arc.get('in_names', list(xs)[:1])
+ xs = xs[in_names[0]]
assert isinstance(xs, torch.Tensor)
assert len(xs.shape) == 2, f'xs shape {xs.shape} is not meant for {nn_type} layer'
in_features = xs.shape[1]
arc.update(in_features=in_features)
elif nn_type.startswith('Conv') or nn_type == 'transformer':
- if ps.is_tuple(xs):
- in_names = arc.get('in_names', xs._fields[:1])
- xs = getattr(xs, in_names[0])
+ if isinstance(xs, dict):
+ in_names = arc.get('in_names', list(xs)[:1])
+ xs = xs[in_names[0]]
assert isinstance(xs, torch.Tensor)
assert len(xs.shape) >= 2, f'xs shape {xs.shape} is not meant for {nn_type} layer'
in_shape = list(xs.shape)[1:]
arc.update(in_shape=in_shape)
elif nn_type == 'FiLMMerge':
- assert ps.is_tuple(xs)
+ assert isinstance(xs, dict)
assert len(arc['in_names']) == 2, 'FiLMMerge in_names should only specify 2 keys for feature and conditioner'
- shapes = {name: list(x.shape)[1:] for name, x in xs._asdict().items() if name in arc['in_names']}
+ shapes = {name: list(x.shape)[1:] for name, x in xs.items() if name in arc['in_names']}
arc.update(shapes=shapes)
else:
pass
-def carry_forward(module: nn.Module, xs: Union[torch.Tensor, NamedTuple], in_names: Optional[List[str]] = None) -> Union[torch.Tensor, NamedTuple]:
+def carry_forward(module: nn.Module, xs: Union[torch.Tensor, dict], in_names: Optional[List[str]] = None) -> Union[torch.Tensor, dict]:
'''
- Main method to call module.forward and handle tensor and namedtuple input/output
+ Main method to call module.forward and handle tensor and dict input/output
If xs and ys are tensors, forward as usual
- If xs or ys is namedtuple, then arc.in_names must specify the inputs names to be used in forward, and any unused names will be carried with the output, which will be namedtuple.
+ If xs or ys is a dict, then arc.in_names must specify the input names to be used in forward, and any unused names will be carried with the output, which will be a dict.
'''
- if ps.is_tuple(xs):
+ if isinstance(xs, dict):
if in_names is None: # use the first by default
- in_names = xs._fields[:1]
+ in_names = list(xs)[:1]
if len(in_names) == 1: # single input is tensor
- m_xs = getattr(xs, in_names[0])
- else: # multi input is namedtuple of tensors
- m_xs = net_util.to_namedtuple({name: getattr(xs, name) for name in in_names})
+ m_xs = xs[in_names[0]]
+ else: # multi input is dict of tensors
+ m_xs = {name: xs[name] for name in in_names}
ys = module(m_xs)
- # any unused_xs must be carried with the output as namedtuple
- d_xs = xs._asdict()
+ # any unused_xs must be carried with the output as dict
+ d_xs = xs
unused_d_xs = ps.omit(d_xs, in_names)
if unused_d_xs:
- if ps.is_tuple(ys):
- d_ys = {**ys._asdict(), **unused_d_xs}
- else: # when formed as namedtuple, single output will use the first of in_names
+ if isinstance(ys, dict):
+ d_ys = {**ys, **unused_d_xs}
+ else: # a single tensor output is keyed by the first of in_names
d_ys = {**{in_names[0]: ys}, **unused_d_xs}
- ys = net_util.to_namedtuple(d_ys)
+ ys = d_ys
else:
ys = module(xs)
return ys
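A short sketch of the dict carry-forward behavior described in the docstring (module and shapes are arbitrary):

```python
# sketch: consumed names are replaced by the output, unused names pass through
import torch
from torch import nn
from torcharc import module_builder

xs = {'vector': torch.rand(4, 8), 'image': torch.rand(4, 3, 20, 20)}
ys = module_builder.carry_forward(nn.Linear(8, 2), xs, in_names=['vector'])
assert ys['vector'].shape == (4, 2)         # output keyed by the consumed name
assert ys['image'].shape == (4, 3, 20, 20)  # unconsumed input carried through
```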
diff --git a/torcharc/net_util.py b/torcharc/net_util.py
index f4adb77..e312fa1 100644
--- a/torcharc/net_util.py
+++ b/torcharc/net_util.py
@@ -1,7 +1,5 @@
-from collections import namedtuple
from torch import nn
-from typing import Dict, List, NamedTuple, Union
-import pydash as ps
+from typing import Dict, List, Union
import torch
@@ -18,20 +16,14 @@ def get_layer_names(nn_layers: List[nn.Module]) -> List[str]:
return [nn_layer._get_name() for nn_layer in nn_layers]
def _get_rand_tensor(shape: Union[list, tuple], batch_size: int = 4) -> torch.Tensor:
'''Get a random tensor given a shape and a batch size'''
return torch.rand([batch_size] + list(shape))
-def get_rand_tensor(shapes: Union[List[int], Dict[str, list]], batch_size: int = 4) -> Union[torch.Tensor, NamedTuple]:
- '''Get a random tensor tuple with default batch size for a dict of shapes'''
- if ps.is_dict(shapes):
- TensorTuple = namedtuple('TensorTuple', shapes.keys())
- return TensorTuple(*[_get_rand_tensor(shape, batch_size) for shape in shapes.values()])
+def get_rand_tensor(shapes: Union[List[int], Dict[str, list]], batch_size: int = 4) -> Union[torch.Tensor, dict]:
+ '''Get a random tensor dict with default batch size for a dict of shapes'''
+ if isinstance(shapes, dict):
+ return {k: _get_rand_tensor(shape, batch_size) for k, shape in shapes.items()}
else:
return _get_rand_tensor(shapes, batch_size)
-
-
-def to_namedtuple(data: dict, name='NamedTensor') -> NamedTuple:
- '''Convert a dictionary to namedtuple.'''
- return namedtuple(name, data)(**data)
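Correspondingly, `net_util.get_rand_tensor` now yields a dict for dict shapes; a brief usage sketch:

```python
# sketch: dict of shapes in, dict of batched random tensors out
from torcharc import net_util

xs = net_util.get_rand_tensor({'image': [3, 20, 20], 'vector': [8]}, batch_size=4)
assert list(xs) == ['image', 'vector']
assert list(xs['image'].shape) == [4, 3, 20, 20]
```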
From 6fccf2a2f3895e823b27dcdd18f163bf3bdf149d Mon Sep 17 00:00:00 2001
From: kengz
Date: Sat, 26 Jun 2021 11:19:24 -0700
Subject: [PATCH 2/2] chore(ci): update github actions CI and tag-release
---
.github/workflows/ci.yml | 54 ++++++++++++++++++++++++-------
.github/workflows/tag-release.yml | 30 +++++++++++++++++
setup.py | 3 --
3 files changed, 73 insertions(+), 14 deletions(-)
create mode 100644 .github/workflows/tag-release.yml
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 5eb1080..645e497 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -4,10 +4,39 @@ on:
push:
branches: [main]
pull_request:
- branches: [main]
+ branches: ["**"]
jobs:
+ lint:
+ runs-on: ubuntu-latest
+
+ steps:
+ - name: Check out Git repository
+ uses: actions/checkout@v2
+
+ - name: Set up Python
+ uses: actions/setup-python@v2
+ with:
+ python-version: 3.8
+
+ - uses: liskin/gh-problem-matcher-wrap@v1
+ with:
+ action: add
+ linters: flake8
+
+ - name: Lint with flake8
+ run: |
+ pip install flake8
+ # exit-zero treats all errors as warnings.
+ flake8 . --ignore=E501 --count --exit-zero --statistics
+
+ - uses: liskin/gh-problem-matcher-wrap@v1
+ with:
+ action: remove
+ linters: flake8
+
build:
+ needs: lint
runs-on: ubuntu-latest
steps:
@@ -35,16 +64,19 @@ jobs:
conda info
conda list
- - name: Setup flake8 annotations
- uses: rbialon/flake8-annotations@v1
- - name: Lint with flake8
- shell: bash -l {0}
- run: |
- pip install flake8
- # exit-zero treats all errors as warnings.
- flake8 . --ignore=E501 --count --exit-zero --statistics
+ - uses: liskin/gh-problem-matcher-wrap@v1
+ with:
+ action: add
+ linters: pytest
- name: Run tests
shell: bash -l {0}
- run: |
- python setup.py test
+ run: python setup.py test | tee pytest-coverage.txt
+
+ - name: Post coverage to PR comment
+ uses: coroo/pytest-coverage-commentator@v1.0.2
+
+ - uses: liskin/gh-problem-matcher-wrap@v1
+ with:
+ action: remove
+ linters: pytest
diff --git a/.github/workflows/tag-release.yml b/.github/workflows/tag-release.yml
new file mode 100644
index 0000000..b187658
--- /dev/null
+++ b/.github/workflows/tag-release.yml
@@ -0,0 +1,30 @@
+# tag ref: https://github.com/marketplace/actions/github-tag#bumping
+# commit msg format for tag: https://github.com/angular/angular.js/blob/master/DEVELOPERS.md#-git-commit-guidelines
+name: Tag and release version
+
+on:
+ push:
+ branches: [main]
+
+jobs:
+ tag_and_release:
+ runs-on: ubuntu-latest
+
+ steps:
+ - uses: actions/checkout@v2
+
+ - name: Bump version and push tag
+ id: tag_version
+ uses: mathieudutour/github-tag-action@v5.1
+ with:
+ github_token: ${{ secrets.GITHUB_TOKEN }}
+ tag_prefix: ''
+
+ - name: Create a GitHub release
+ uses: actions/create-release@v1
+ env:
+ GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+ with:
+ tag_name: ${{ steps.tag_version.outputs.new_tag }}
+ release_name: Release ${{ steps.tag_version.outputs.new_tag }}
+ body: ${{ steps.tag_version.outputs.changelog }}
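Per the angular commit guidelines referenced above, the tag action infers the bump from commit subjects. As I understand mathieudutour/github-tag-action's defaults (an assumption, not verified against the v5.1 docs):

```
fix(scope): ...                      -> patch bump (1.0.0 -> 1.0.1)
feat(scope): ...                     -> minor bump (1.0.0 -> 1.1.0)
any type + "BREAKING CHANGE:" footer -> major bump (1.0.0 -> 2.0.0)
```

This matches patch 1 of this series: a feat commit with a BREAKING CHANGE footer, alongside the manual 0.0.6 -> 1.0.0 bump in setup.py.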
diff --git a/setup.py b/setup.py
index 6c6b74b..5a242b8 100644
--- a/setup.py
+++ b/setup.py
@@ -11,9 +11,6 @@
'--log-file-level=INFO',
'--no-flaky-report',
'--timeout=300',
- '--cov-report=html',
- '--cov-report=term',
- '--cov-report=xml',
'--cov=torcharc',
'test',
]