Skip to content

Commit

Permalink
Merge pull request #45 from f-dangel/release
Browse files Browse the repository at this point in the history
Release BackPACK1.1.0
  • Loading branch information
fKunstner authored Feb 11, 2020
2 parents 4e7fed2 + e710df9 commit 3122de0
Show file tree
Hide file tree
Showing 178 changed files with 4,098 additions and 2,736 deletions.
14 changes: 14 additions & 0 deletions .conda_env.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
name: backpack
channels:
- pytorch
- defaults
dependencies:
- cudatoolkit=9.2=0
- pip=19.3.1
- python=3.7.6
- pytorch=1.3.1=py3.7_cuda9.2.148_cudnn7.6.3_0
- torchvision=0.4.2=py37_cu92
- pip:
- -r requirements.txt
- -r requirements-dev.txt
- -e .
40 changes: 40 additions & 0 deletions .flake8
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
[flake8]
select = B,C,E,F,P,W,B9
max-line-length = 80
max-complexity = 10
ignore =
# replaced by B950 (max-line-length + 10%)
E501, # max-line-length
# ignored because pytorch uses dict
C408, # use {} instead of dict()
# Not Black-compatible
E203, # whitespace before :
E231, # missing whitespace after ','
W291, # trailing whitespace
W503, # line break before binary operator
W504, # line break after binary operator
exclude = docs, docs_src, build, .git


# Differences with pytorch
#
# Smaller max-line-length
# Enabled max-complexity
# No flake8-mypy (T4 range)
#
# Set of rules ignored by pytorch, probably to get around the CI
#
# F401 (import unused in __init__.py) not ignored
# F403 'from module import *' used; unable to detect undefined names
# F405 Name may be undefined, or defined from star imports: module
# F821 Undefined name name
# F841 Local variable name is assigned to but never used
#
# Pytorch ignored rules that I don't see a reason to ignore (yet?):
#
# E305 Expected 2 blank lines after end of function or class
# E402 Module level import not at top of file
# E721 Do not compare types, use 'isinstance()'
# E741 Do not use variables named 'l', 'o', or 'i'
# E302 Expected 2 blank lines, found 0
# E303 Too many blank lines (3)
38 changes: 38 additions & 0 deletions .github/workflows/lint.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
name: Lint

on:
push:
branches:
- pep8-style

jobs:
flake8:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v1
- name: Set up Python 3.7
uses: actions/setup-python@v1
with:
python-version: 3.7
- name: Install dependencies
run: |
python -m pip install --upgrade pip
make install-lint
- name: Run flake8
run: |
make flake8
black:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v1
- name: Set up Python 3.7
uses: actions/setup-python@v1
with:
python-version: 3.7
- name: Install dependencies
run: |
python -m pip install --upgrade pip
make install-lint
- name: Run black
run: |
make black-check
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
__pycache__/
.mypy_cache
*.egg-info/
**/*.pyc
.cache
examples/data
.idea
.coverage
dist/*
build/*
7 changes: 7 additions & 0 deletions .isort.cfg
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
[settings]
multi_line_output=3
include_trailing_comma=True
force_grid_wrap=0
use_parentheses=True
line_length=88
skip_glob=docs/*,docs_src/*
18 changes: 18 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
repos:
- repo: https://github.com/psf/black
rev: stable
hooks:
- id: black
args: [--config=black.toml]
- repo: https://gitlab.com/pycqa/flake8
rev: '3.7.9'
hooks:
- id: flake8
additional_dependencies: [
mccabe,
pycodestyle,
pyflakes,
pep8-naming,
flake8-bugbear,
flake8-comprehensions,
]
22 changes: 22 additions & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
language: python
python:
- '3.5'
- '3.6'
- '3.7'
install:
- pip install -r requirements.txt
- pip install -r requirements/test.txt
- pip install .
- pip install pillow==6.1.0
cache:
- pip
script:
- pytest -vx --cov=backpack/ .
- python examples/run_examples.py
after_success:
- coveralls
notifications:
email: false
slack:
secure: qAK64wEVkRC57IrNMqXetPoqWLGkHId5ayhzoRYzFfuiMuTKlG+Dwaif/TixjjKwu9vdLyuX4+0gi6IVFB9UZ0+bgMBkbh4rugPiINliiqFi91Z8Kl9ns/qmhbfKnCKwYkU+vkjuUsuHhe/3dV3XUs3RgQaJBIP4iTu1ayTbIB1QIyQJDBnlC+65mKA0qxMEIuvOYZDemDsr747583UFCcx2EC4daZuANeQTwFDnDx9TVnNJheblZ8AqH0JnoOQRJo3iPLBtxo9jDpbPupew9oY3dDS5J/+FgjYw5oGDroyM7TcP8q+HkCkmUtX9DU/DgpZgqd6Ysk9jgPK3k1uqq5oOKZ1jCs64c9K+ayekaM6wJPdJXcZSH3JCUhQc9q+xkdq3hEp27dNQcZbK7YT1nkF9MzR+StaMZsMIx5nvO6n1onBU+oTZYPURfizCnB9a/jDmMNKNkhmgit/1MWbRVMnc+YgH3R5SBhu4W2HBlw3hYQ2vgpok0lZ6mzHXnw67q91yncDXqAOx/7rPzXGm0E8iH1sfhWs26IFXOgUMPBRN/lUUznWNrdG5Ht6dTF2cRVjB8cUlJPyy5Uhv1YFnePXgSMBg7fZMEG3qMJn0V3raJT842LKdZhJRXyCX9fPacpbvKQ97mrw62KfzJ4HGpsilqOBNd2kwyvXqatBQaGU=
on_success: never
10 changes: 10 additions & 0 deletions README-dev.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
Basics about the development setup

| Topic | Details |
|-|-|
| Python version | The subset of Python 3 supported by Pytorch (`3.5, 3.6, 3.7`); use `3.7` for development |
| Tooling management | [`make`](https://www.gnu.org/software/make/) as an interface to the dev tools ([makefile](makefile)) |
| Testing | [`pytest`](https://docs.pytest.org) ([testing readme](test/readme.md)) |
| Style | [`black`](https://black.readthedocs.io) ([rules](black.toml)) for formatting and [`flake8`](http://flake8.pycqa.org/) ([rules](.flake8)) for linting |
| CI/QA | [`Travis`](https://travis-ci.org/f-dangel/backpack) ([config](.travis.yml)) to run tests and [`Github workflows`](https://github.com/f-dangel/backpack/actions) ([config](.github/workflows)) to check formatting and linting |

20 changes: 19 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
# BACKpropagation PACKage - a backpack for `PyTorch`

| branch | tests & examples | coverage |
|--------|---------------------------|----------|
|`master` | [![Build Status](https://travis-ci.org/f-dangel/backpack.svg?branch=master)](https://travis-ci.org/f-dangel/backpack) | [![Coverage Status](https://coveralls.io/repos/github/f-dangel/backpack/badge.svg?branch=master)](https://coveralls.io/github/f-dangel/backpack) |
| `development` | [![Build Status](https://travis-ci.org/f-dangel/backpack.svg?branch=development)](https://travis-ci.org/f-dangel/backpack) | [![Coverage Status](https://coveralls.io/repos/github/f-dangel/backpack/badge.svg?branch=development)](https://coveralls.io/github/f-dangel/backpack) |

A backpack for PyTorch that extends the backward pass of feedforward networks to compute quantities beyond the gradient.

Check out the [examples](https://f-dangel.github.io/backpack/) on how to use the code.
- Check out the [cheatsheet](examples/cheatsheet.pdf) for an overview of quantities.
- Check out the [examples](https://f-dangel.github.io/backpack/) on how to use the code.

## Installation
```bash
Expand All @@ -15,3 +21,15 @@ git clone https://github.com/f-dangel/backpack.git ~/backpack
cd ~/backpack
python setup.py install
```

## How to cite
If you are using `backpack` for your research, consider citing the [paper](https://openreview.net/forum?id=BJlrF24twB)
```
@inproceedings{dangel2020backpack,
title = {Back{PACK}: Packing more into Backprop},
author = {Felix Dangel and Frederik Kunstner and Philipp Hennig},
booktitle = {International Conference on Learning Representations},
year = {2020},
url = {https://openreview.net/forum?id=BJlrF24twB}
}
```
118 changes: 57 additions & 61 deletions backpack/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,18 @@
BackPACK
"""
import torch
from .context import CTX

from . import extensions
from .context import CTX

class backpack():

class backpack:
"""
Activates the BackPACK extensions passed as arguments for the
:code:`backward` calls in the current :code:`with` block.
"""

def __init__(self, *args):
def __init__(self, *args, debug=False):
"""
Activate the Backpack extensions.
Expand Down Expand Up @@ -39,16 +41,61 @@ def __init__(self, *args):
Parameters:
args: [BackpropExtension]
The extensions to activate for the backward pass.
debug: Bool, optional (default: False)
If true, will print debug messages during the backward pass.
"""
self.args = args
self.debug = debug

def __enter__(self):
self.old_CTX = CTX.get_active_exts()
self.old_debug = CTX.get_debug()
CTX.set_active_exts(self.args)
CTX.set_debug(self.debug)

def __exit__(self, type, value, traceback):
CTX.set_active_exts(self.old_CTX)
CTX.clear()
CTX.set_debug(self.old_debug)


def hook_store_io(module, input, output):
    """Save the forward-pass inputs and output on the module.

    Each positional input is stored as ``module.input0``, ``module.input1``,
    ... and the output as ``module.output``, so BackPACK extensions can
    access them later during the backward pass.
    """
    for idx, tensor in enumerate(input):
        setattr(module, "input{}".format(idx), tensor)
    module.output = output


def hook_store_shapes(module, input, output):
    """Store the dimensionality of inputs and output as module buffers.

    The shape of the ``i``-th positional input is registered as the buffer
    ``input{i}_shape`` and the output shape as ``output_shape``, each as a
    ``torch.IntTensor``.
    """
    for idx, tensor in enumerate(input):
        shape = torch.IntTensor(list(tensor.size()))
        module.register_buffer("input{}_shape".format(idx), shape)
    module.register_buffer("output_shape", torch.IntTensor(list(output.size())))


def memory_cleanup(module):
    """Delete the stored I/O attributes and shape buffers from the module.

    Removes ``output`` and ``output_shape`` if present, then every
    consecutively numbered ``input{i}`` and ``input{i}_shape`` attribute,
    freeing the memory held for the backward pass.
    """
    for name in ("output", "output_shape"):
        if hasattr(module, name):
            delattr(module, name)
    for template in ("input{}", "input{}_shape"):
        idx = 0
        while hasattr(module, template.format(idx)):
            delattr(module, template.format(idx))
            idx += 1


def hook_run_extensions(module, g_inp, g_out):
    """Apply every active BackPACK extension to the module during backward.

    Afterwards, the stored I/O is released via ``memory_cleanup`` unless the
    curvature-matrix-product extension is active — presumably it needs the
    stored I/O across repeated backward passes (TODO confirm).
    """
    for ext in CTX.get_active_exts():
        if CTX.get_debug():
            print("[DEBUG] Running extension", ext, "on", module)
        ext.apply(module, g_inp, g_out)

    keep_io = CTX.is_extension_active(extensions.curvmatprod.CMP)
    if not keep_io:
        memory_cleanup(module)


def extend(module, debug=False):
Expand All @@ -60,7 +107,7 @@ def extend(module, debug=False):
module: torch.nn.Module
The module to extend
debug: Bool, optional (default: False)
If true, will print debug messages during the extension and backward.
If true, will print debug messages during the extension.
"""
if debug:
print("[DEBUG] Extending", module)
Expand All @@ -69,61 +116,10 @@ def extend(module, debug=False):
extend(child, debug=debug)

module_was_already_extended = getattr(module, "_backpack_extend", False)
if module_was_already_extended:
return module

def store_io(module, input, output):
for i in range(len(input)):
setattr(module, 'input{}'.format(i), input[i])
setattr(module, 'output', output)

def store_shapes(module, input, output):
"""Store dimensionality of output as buffer."""
for i in range(len(input)):
module.register_buffer(
'input{}_shape'.format(i),
torch.IntTensor([*input[i].size()])
)
module.register_buffer(
'output_shape',
torch.IntTensor([*output.size()])
)
if not module_was_already_extended:
CTX.add_hook_handle(module.register_forward_hook(hook_store_io))
CTX.add_hook_handle(module.register_forward_hook(hook_store_shapes))
CTX.add_hook_handle(module.register_backward_hook(hook_run_extensions))
module._backpack_extend = True

def memory_cleanup(module):
if hasattr(module, "output"):
delattr(module, "output")
if hasattr(module, "output_shape"):
delattr(module, "output_shape")
i = 0
while hasattr(module, "input{}".format(i)):
delattr(module, "input{}".format(i))
i += 1
i = 0
while hasattr(module, "input{}_shape".format(i)):
delattr(module, "input{}_shape".format(i))
i += 1

def run_extensions(module_, g_inp, g_out):
for backpack_extension in CTX.get_active_exts():
if debug:
print(
"[DEBUG] Running extension", backpack_extension,
"on", module
)
backpack_extension.apply(module_, g_inp, g_out)

def extension_contain_curvmatprod():
for backpack_ext in CTX.get_active_exts():
if isinstance(backpack_ext, extensions.curvmatprod.CMP):
return True
return False

if not extension_contain_curvmatprod():
memory_cleanup(module_)

CTX.add_hook_handle(module.register_forward_hook(store_io))
CTX.add_hook_handle(module.register_forward_hook(store_shapes))
CTX.add_hook_handle(module.register_backward_hook(run_extensions))

setattr(module, "_backpack_extend", True)
return module
Loading

0 comments on commit 3122de0

Please sign in to comment.