diff --git a/.conda_env.yml b/.conda_env.yml new file mode 100644 index 00000000..f812dc4f --- /dev/null +++ b/.conda_env.yml @@ -0,0 +1,14 @@ +name: backpack +channels: + - pytorch + - defaults +dependencies: + - cudatoolkit=9.2=0 + - pip=19.3.1 + - python=3.7.6 + - pytorch=1.3.1=py3.7_cuda9.2.148_cudnn7.6.3_0 + - torchvision=0.4.2=py37_cu92 + - pip: + - -r requirements.txt + - -r requirements-dev.txt + - -e . diff --git a/.flake8 b/.flake8 new file mode 100644 index 00000000..6b377e5a --- /dev/null +++ b/.flake8 @@ -0,0 +1,40 @@ +[flake8] +select = B,C,E,F,P,W,B9 +max-line-length = 80 +max-complexity = 10 +ignore = + # replaced by B950 (max-line-length + 10%) + E501, # max-line-length + # ignored because pytorch uses dict + C408, # use {} instead of dict() + # Not Black-compatible + E203, # whitespace before : + E231, # missing whitespace after ',' + W291, # trailing whitespace + W503, # line break before binary operator + W504, # line break after binary operator +exclude = docs, docs_src, build, .git + + +# Differences with pytorch +# +# Smaller max-line-length +# Enabled max-complexity +# No flake8-mypy (T4 range) +# +# Set of rules ignore by pytorch, probably to get around the C +# +# F401 (import unused in __init__.py) not ignored +# F403 'from module import *' used; unable to detect undefined names +# F405 Name may be undefined, or defined from star imports: module +# F821 Undefined name name +# F841 Local variable name is assigned to but never used +# +# Pytorch ignored rules that I don't see a reason to ignore (yet?): +# +# E305 Expected 2 blank lines after end of function or class +# E402 Module level import not at top of file +# E721 Do not compare types, use 'isinstance()' +# E741 Do not use variables named 'l', 'o', or 'i' +# E302 Expected 2 blank lines, found 0 +# E303 Too many blank lines (3) diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml new file mode 100644 index 00000000..27f20ca6 --- /dev/null +++ b/.github/workflows/lint.yaml @@ -0,0 +1,38 @@ +name: Lint + +on: + push: + branches: + - pep8-style + +jobs: + flake8: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v1 + - name: Set up Python 3.7 + uses: actions/setup-python@v1 + with: + python-version: 3.7 + - name: Install dependencies + run: | + python -m pip install --upgrade pip + make install-lint + - name: Run flake8 + run: | + make flake8 + black: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v1 + - name: Set up Python 3.7 + uses: actions/setup-python@v1 + with: + python-version: 3.7 + - name: Install dependencies + run: | + python -m pip install --upgrade pip + make install-lint + - name: Run black + run: | + make black-check diff --git a/.gitignore b/.gitignore index aa057f3d..30acdf7f 100644 --- a/.gitignore +++ b/.gitignore @@ -1,8 +1,10 @@ __pycache__/ +.mypy_cache *.egg-info/ **/*.pyc .cache examples/data .idea +.coverage dist/* build/* diff --git a/.isort.cfg b/.isort.cfg new file mode 100644 index 00000000..be0f1bf2 --- /dev/null +++ b/.isort.cfg @@ -0,0 +1,7 @@ +[settings] +multi_line_output=3 +include_trailing_comma=True +force_grid_wrap=0 +use_parentheses=True +line_length=88 +skip_glob=docs/*,docs_src/* diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000..76171893 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,18 @@ +repos: +- repo: https://github.com/psf/black + rev: stable + hooks: + - id: black + args: [--config=black.toml] +- repo: https://gitlab.com/pycqa/flake8 + rev: '3.7.9' + hooks: + - id: 
flake8 + additional_dependencies: [ + mccabe, + pycodestyle, + pyflakes, + pep8-naming, + flake8-bugbear, + flake8-comprehensions, + ] diff --git a/.travis.yml b/.travis.yml new file mode 100644 index 00000000..6301bfe0 --- /dev/null +++ b/.travis.yml @@ -0,0 +1,22 @@ +language: python +python: +- '3.5' +- '3.6' +- '3.7' +install: +- pip install -r requirements.txt +- pip install -r requirements/test.txt +- pip install . +- pip install pillow==6.1.0 +cache: +- pip +script: +- pytest -vx --cov=backpack/ . +- python examples/run_examples.py +after_success: +- coveralls +notifications: + email: false + slack: + secure: qAK64wEVkRC57IrNMqXetPoqWLGkHId5ayhzoRYzFfuiMuTKlG+Dwaif/TixjjKwu9vdLyuX4+0gi6IVFB9UZ0+bgMBkbh4rugPiINliiqFi91Z8Kl9ns/qmhbfKnCKwYkU+vkjuUsuHhe/3dV3XUs3RgQaJBIP4iTu1ayTbIB1QIyQJDBnlC+65mKA0qxMEIuvOYZDemDsr747583UFCcx2EC4daZuANeQTwFDnDx9TVnNJheblZ8AqH0JnoOQRJo3iPLBtxo9jDpbPupew9oY3dDS5J/+FgjYw5oGDroyM7TcP8q+HkCkmUtX9DU/DgpZgqd6Ysk9jgPK3k1uqq5oOKZ1jCs64c9K+ayekaM6wJPdJXcZSH3JCUhQc9q+xkdq3hEp27dNQcZbK7YT1nkF9MzR+StaMZsMIx5nvO6n1onBU+oTZYPURfizCnB9a/jDmMNKNkhmgit/1MWbRVMnc+YgH3R5SBhu4W2HBlw3hYQ2vgpok0lZ6mzHXnw67q91yncDXqAOx/7rPzXGm0E8iH1sfhWs26IFXOgUMPBRN/lUUznWNrdG5Ht6dTF2cRVjB8cUlJPyy5Uhv1YFnePXgSMBg7fZMEG3qMJn0V3raJT842LKdZhJRXyCX9fPacpbvKQ97mrw62KfzJ4HGpsilqOBNd2kwyvXqatBQaGU= + on_success: never diff --git a/README-dev.md b/README-dev.md new file mode 100644 index 00000000..59b79145 --- /dev/null +++ b/README-dev.md @@ -0,0 +1,10 @@ +Basics about the development setup + +| Topic | Setup | +|-|-| +| Python version | The Python 3 versions supported by PyTorch (`3.5, 3.6, 3.7`); `3.7` is used for development | +| Tooling management | [`make`](https://www.gnu.org/software/make/) as an interface to the dev tools ([makefile](makefile)) | +| Testing | [`pytest`](https://docs.pytest.org) ([testing readme](test/readme.md)) | +| Style | [`black`](https://black.readthedocs.io) ([rules](black.toml)) for formatting and [`flake8`](http://flake8.pycqa.org/) ([rules](.flake8)) for linting | +| CI/QA | [`Travis`](https://travis-ci.org/f-dangel/backpack) ([config](.travis.yml)) to run tests and [`GitHub workflows`](https://github.com/f-dangel/backpack/actions) ([config](.github/workflows)) to check formatting and linting | + diff --git a/README.md b/README.md index 11f1da9f..74528a26 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,14 @@ # BACKpropagation PACKage - a backpack for `PyTorch` +| branch | tests & examples | coverage | +|--------|---------------------------|----------| +|`master` | [![Build Status](https://travis-ci.org/f-dangel/backpack.svg?branch=master)](https://travis-ci.org/f-dangel/backpack) | [![Coverage Status](https://coveralls.io/repos/github/f-dangel/backpack/badge.svg?branch=master)](https://coveralls.io/github/f-dangel/backpack) | +| `development` | [![Build Status](https://travis-ci.org/f-dangel/backpack.svg?branch=development)](https://travis-ci.org/f-dangel/backpack) | [![Coverage Status](https://coveralls.io/repos/github/f-dangel/backpack/badge.svg?branch=development)](https://coveralls.io/github/f-dangel/backpack) | + A backpack for PyTorch that extends the backward pass of feedforward networks to compute quantities beyond the gradient. -Check out the [examples](https://f-dangel.github.io/backpack/) on how to use the code. +- Check out the [cheatsheet](examples/cheatsheet.pdf) for an overview of quantities. +- Check out the [examples](https://f-dangel.github.io/backpack/) on how to use the code. 
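For orientation, a minimal sketch of the workflow those examples cover (illustrative only, not part of this patch; the `BatchGrad` extension and the `grad_batch` attribute are taken from the BackPACK documentation and should be treated as assumptions here):

```python
# Illustrative sketch: extend the model and loss, then request an extra
# quantity (here: individual gradients) during the backward pass.
# `extensions.BatchGrad` / `param.grad_batch` are assumed from the docs.
import torch
from backpack import backpack, extend, extensions

model = extend(torch.nn.Sequential(torch.nn.Linear(10, 3)))
lossfunc = extend(torch.nn.CrossEntropyLoss())

X, y = torch.rand(8, 10), torch.randint(0, 3, (8,))
loss = lossfunc(model(X), y)

with backpack(extensions.BatchGrad()):  # activate the extension for this backward
    loss.backward()

for param in model.parameters():
    # .grad as usual, plus one gradient per sample in the mini-batch
    print(param.grad.shape, param.grad_batch.shape)
```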
## Installation ```bash @@ -15,3 +21,15 @@ git clone https://github.com/f-dangel/backpack.git ~/backpack cd ~/backpack python setup.py install ``` + +## How to cite +If you are using `backpack` for your research, consider citing the [paper](https://openreview.net/forum?id=BJlrF24twB) +``` +@inproceedings{dangel2020backpack, + title = {Back{PACK}: Packing more into Backprop}, + author = {Felix Dangel and Frederik Kunstner and Philipp Hennig}, + booktitle = {International Conference on Learning Representations}, + year = {2020}, + url = {https://openreview.net/forum?id=BJlrF24twB} +} +``` diff --git a/backpack/__init__.py b/backpack/__init__.py index 1b6a67bb..dfa11c3a 100644 --- a/backpack/__init__.py +++ b/backpack/__init__.py @@ -2,16 +2,18 @@ BackPACK """ import torch -from .context import CTX + from . import extensions +from .context import CTX -class backpack(): + +class backpack: """ Activates the BackPACK extensions passed as arguments for the :code:`backward` calls in the current :code:`with` block. """ - def __init__(self, *args): + def __init__(self, *args, debug=False): """ Activate the Backpack extensions. @@ -39,16 +41,61 @@ def __init__(self, *args): Parameters: args: [BackpropExtension] The extensions to activate for the backward pass. + debug: Bool, optional (default: False) + If true, will print debug messages during the backward pass. """ self.args = args + self.debug = debug def __enter__(self): self.old_CTX = CTX.get_active_exts() + self.old_debug = CTX.get_debug() CTX.set_active_exts(self.args) + CTX.set_debug(self.debug) def __exit__(self, type, value, traceback): CTX.set_active_exts(self.old_CTX) - CTX.clear() + CTX.set_debug(self.old_debug) + + +def hook_store_io(module, input, output): + for i in range(len(input)): + setattr(module, "input{}".format(i), input[i]) + module.output = output + + +def hook_store_shapes(module, input, output): + """Store dimensionality of output as buffer.""" + for i in range(len(input)): + module.register_buffer( + "input{}_shape".format(i), torch.IntTensor([*input[i].size()]) + ) + module.register_buffer("output_shape", torch.IntTensor([*output.size()])) + + +def memory_cleanup(module): + if hasattr(module, "output"): + delattr(module, "output") + if hasattr(module, "output_shape"): + delattr(module, "output_shape") + i = 0 + while hasattr(module, "input{}".format(i)): + delattr(module, "input{}".format(i)) + i += 1 + i = 0 + while hasattr(module, "input{}_shape".format(i)): + delattr(module, "input{}_shape".format(i)) + i += 1 + + +def hook_run_extensions(module, g_inp, g_out): + for backpack_extension in CTX.get_active_exts(): + if CTX.get_debug(): + print("[DEBUG] Running extension", backpack_extension, "on", module) + backpack_extension.apply(module, g_inp, g_out) + + if not CTX.is_extension_active(extensions.curvmatprod.CMP): + memory_cleanup(module) def extend(module, debug=False): @@ -60,7 +107,7 @@ def extend(module, debug=False): module: torch.nn.Module The module to extend debug: Bool, optional (default: False) - If true, will print debug messages during the extension and backward. + If true, will print debug messages during the extension. 
""" if debug: print("[DEBUG] Extending", module) @@ -69,61 +116,10 @@ def extend(module, debug=False): extend(child, debug=debug) module_was_already_extended = getattr(module, "_backpack_extend", False) - if module_was_already_extended: - return module - - def store_io(module, input, output): - for i in range(len(input)): - setattr(module, 'input{}'.format(i), input[i]) - setattr(module, 'output', output) - - def store_shapes(module, input, output): - """Store dimensionality of output as buffer.""" - for i in range(len(input)): - module.register_buffer( - 'input{}_shape'.format(i), - torch.IntTensor([*input[i].size()]) - ) - module.register_buffer( - 'output_shape', - torch.IntTensor([*output.size()]) - ) + if not module_was_already_extended: + CTX.add_hook_handle(module.register_forward_hook(hook_store_io)) + CTX.add_hook_handle(module.register_forward_hook(hook_store_shapes)) + CTX.add_hook_handle(module.register_backward_hook(hook_run_extensions)) + module._backpack_extend = True - def memory_cleanup(module): - if hasattr(module, "output"): - delattr(module, "output") - if hasattr(module, "output_shape"): - delattr(module, "output_shape") - i = 0 - while hasattr(module, "input{}".format(i)): - delattr(module, "input{}".format(i)) - i += 1 - i = 0 - while hasattr(module, "input{}_shape".format(i)): - delattr(module, "input{}_shape".format(i)) - i += 1 - - def run_extensions(module_, g_inp, g_out): - for backpack_extension in CTX.get_active_exts(): - if debug: - print( - "[DEBUG] Running extension", backpack_extension, - "on", module - ) - backpack_extension.apply(module_, g_inp, g_out) - - def extension_contain_curvmatprod(): - for backpack_ext in CTX.get_active_exts(): - if isinstance(backpack_ext, extensions.curvmatprod.CMP): - return True - return False - - if not extension_contain_curvmatprod(): - memory_cleanup(module_) - - CTX.add_hook_handle(module.register_forward_hook(store_io)) - CTX.add_hook_handle(module.register_forward_hook(store_shapes)) - CTX.add_hook_handle(module.register_backward_hook(run_extensions)) - - setattr(module, "_backpack_extend", True) return module diff --git a/backpack/context.py b/backpack/context.py index d7dc9bdb..15512e44 100644 --- a/backpack/context.py +++ b/backpack/context.py @@ -1,23 +1,10 @@ -import warnings - - -def get_from_ctx(name): - value = CTX.backproped_quantities.get(name, None) - if value is None: - warnings.warn("The attribute {} does not exist in CTX".format(name)) - return value - - -def set_in_ctx(name, value): - CTX.backproped_quantities[name] = value - - class CTX: """ Global Class holding the configuration of the backward pass """ + active_exts = tuple() - backproped_quantities = {} + debug = False @staticmethod def set_active_exts(active_exts): @@ -42,6 +29,16 @@ def remove_hooks(): CTX.hook_handles = [] @staticmethod - def clear(): - del CTX.backproped_quantities - CTX.backproped_quantities = {} + def is_extension_active(extension_class): + for backpack_ext in CTX.get_active_exts(): + if isinstance(backpack_ext, extension_class): + return True + return False + + @staticmethod + def get_debug(): + return CTX.debug + + @staticmethod + def set_debug(debug): + CTX.debug = debug diff --git a/backpack/core/derivatives/__init__.py b/backpack/core/derivatives/__init__.py index c379dece..6feb2f6b 100644 --- a/backpack/core/derivatives/__init__.py +++ b/backpack/core/derivatives/__init__.py @@ -1,26 +1,33 @@ -from torch.nn import Sigmoid, Tanh, ReLU, Dropout, ZeroPad2d, MaxPool2d, Linear, AvgPool2d, Conv2d -from backpack.core.layers 
import LinearConcat, Conv2dConcat -from .linear import LinearDerivatives, LinearConcatDerivatives -from .conv2d import Conv2DDerivatives, Conv2DConcatDerivatives +from torch.nn import ( + AvgPool2d, + Conv2d, + Dropout, + Linear, + MaxPool2d, + ReLU, + Sigmoid, + Tanh, + ZeroPad2d, +) + from .avgpool2d import AvgPool2DDerivatives -from .maxpool2d import MaxPool2DDerivatives -from .zeropad2d import ZeroPad2dDerivatives +from .conv2d import Conv2DDerivatives from .dropout import DropoutDerivatives +from .linear import LinearDerivatives +from .maxpool2d import MaxPool2DDerivatives from .relu import ReLUDerivatives from .sigmoid import SigmoidDerivatives from .tanh import TanhDerivatives - +from .zeropad2d import ZeroPad2dDerivatives derivatives_for = { Linear: LinearDerivatives, - LinearConcat: LinearConcatDerivatives, Conv2d: Conv2DDerivatives, - Conv2dConcat: Conv2DConcatDerivatives, AvgPool2d: AvgPool2DDerivatives, MaxPool2d: MaxPool2DDerivatives, ZeroPad2d: ZeroPad2dDerivatives, Dropout: DropoutDerivatives, ReLU: ReLUDerivatives, Tanh: TanhDerivatives, - Sigmoid: SigmoidDerivatives -} \ No newline at end of file + Sigmoid: SigmoidDerivatives, +} diff --git a/backpack/core/derivatives/avgpool2d.py b/backpack/core/derivatives/avgpool2d.py index e499a98f..8b4c2554 100644 --- a/backpack/core/derivatives/avgpool2d.py +++ b/backpack/core/derivatives/avgpool2d.py @@ -1,15 +1,11 @@ """The code relies on the insight that average pooling can be understood as convolution over single channels with a constant kernel.""" -import warnings - import torch.nn from torch.nn import AvgPool2d, Conv2d, ConvTranspose2d -from ...utils import conv as convUtils -from ...utils.utils import einsum, random_psd_matrix -from .basederivatives import BaseDerivatives -from .utils import jmp_unsqueeze_if_missing_dim +from backpack.core.derivatives.basederivatives import BaseDerivatives +from backpack.utils.ein import eingroup, einsum class AvgPool2DDerivatives(BaseDerivatives): @@ -29,7 +25,7 @@ def ea_jac_t_mat_jac_prod(self, module, g_inp, g_out, mat): # 1) apply conv_transpose to multiply with W^T result = mat.view(channels, out_x, out_y, out_features) - result = einsum('cxyf->fcxy', (result, )).contiguous() + result = einsum("cxyf->fcxy", (result,)).contiguous() result = result.view(out_features * channels, 1, out_x, out_y) # result: W^T mat result = self.__apply_jacobian_t_of(module, result) @@ -45,52 +41,38 @@ def ea_jac_t_mat_jac_prod(self, module, g_inp, g_out, mat): # 4) transpose to obtain W^T mat W return result.view(in_features, in_features).t() - # Jacobian-matrix product - @jmp_unsqueeze_if_missing_dim(mat_dim=3) - def jac_mat_prod(self, module, g_inp, g_out, mat): - assert module.count_include_pad, "Might now work for exotic hyperparameters of AvgPool2d, like count_include_pad=False" + def check_exotic_parameters(self, module): + assert module.count_include_pad, ( + "Might not work for exotic hyperparameters of AvgPool2d, " + + "like count_include_pad=False" + ) - convUtils.check_sizes_input_jac(mat, module) - mat_as_pool = self.__reshape_for_conv(mat, module) - jmp_as_pool = self.__apply_jacobian_of(module, mat_as_pool) + def _jac_mat_prod(self, module, g_inp, g_out, mat): + self.check_exotic_parameters(module) - batch, channels, out_x, out_y = module.output_shape - num_classes = mat.size(2) - assert jmp_as_pool.size(0) == num_classes * batch * channels - assert jmp_as_pool.size(1) == 1 - assert jmp_as_pool.size(2) == out_x - assert jmp_as_pool.size(3) == out_y + mat_as_pool = 
self.__make_single_channel(mat, module) + jmp_as_pool = self.__apply_jacobian_of(module, mat_as_pool) + self.__check_jmp_out_as_pool(mat, jmp_as_pool, module) - return self.__reshape_for_matmul(jmp_as_pool, module) + return self.view_like_output(jmp_as_pool, module) + # return self.__view_as_output(jmp_as_pool, module) - def __reshape_for_conv(self, mat, module): + def __make_single_channel(self, mat, module): """Create fake single-channel images, grouping batch, class and channel dimension.""" - batch, in_channels, in_x, in_y = module.input0.size() - num_columns = mat.size(-1) - - # 'fake' image for convolution - # (batch * class * channel, 1, out_x, out_y) - return einsum('bic->bci', - mat).contiguous().view(batch * num_columns * in_channels, - 1, in_x, in_y) - - def __reshape_for_matmul(self, mat, module): - """Ungroup dimensions after application of Jacobian.""" - batch, channels, out_x, out_y = module.output_shape - features = channels * out_x * out_y - # mat is of shape (batch * class * channel, 1, out_x, out_y) - # move class dimension to last - mat_view = mat.view(batch, -1, features) - return einsum('bci->bic', mat_view).contiguous() + result = eingroup("v,n,c,w,h->vnc,w,h", mat) + C_axis = 1 + return result.unsqueeze(C_axis) def __apply_jacobian_of(self, module, mat): - conv2d = Conv2d(in_channels=1, - out_channels=1, - kernel_size=module.kernel_size, - stride=module.stride, - padding=module.padding, - bias=False).to(module.input0.device) + conv2d = Conv2d( + in_channels=1, + out_channels=1, + kernel_size=module.kernel_size, + stride=module.stride, + padding=module.padding, + bias=False, + ).to(module.input0.device) conv2d.weight.requires_grad = False avg_kernel = torch.ones_like(conv2d.weight) / conv2d.weight.numel() @@ -98,58 +80,43 @@ def __apply_jacobian_of(self, module, mat): return conv2d(mat) - # Transpose Jacobian-matrix product - @jmp_unsqueeze_if_missing_dim(mat_dim=3) - def jac_t_mat_prod(self, module, g_inp, g_out, mat): + def __check_jmp_out_as_pool(self, mat, jmp_as_pool, module): + V = mat.size(0) + N, C_out, H_out, W_out = module.output_shape + assert jmp_as_pool.shape == (V * N * C_out, 1, H_out, W_out) - assert module.count_include_pad, "Might now work for exotic hyperparameters of AvgPool2d, like count_include_pad=False" + def _jac_t_mat_prod(self, module, g_inp, g_out, mat): + self.check_exotic_parameters(module) - convUtils.check_sizes_input_jac_t(mat, module) - mat_as_pool = self.__reshape_for_conv_t(mat, module) + mat_as_pool = self.__make_single_channel(mat, module) jmp_as_pool = self.__apply_jacobian_t_of(module, mat_as_pool) + self.__check_jmp_in_as_pool(mat, jmp_as_pool, module) - batch, channels, in_x, in_y = module.input0.size() - num_classes = mat.size(2) - assert jmp_as_pool.size(0) == num_classes * batch * channels - assert jmp_as_pool.size(1) == 1 - assert jmp_as_pool.size(2) == in_x - assert jmp_as_pool.size(3) == in_y - - return self.__reshape_for_matmul_t(jmp_as_pool, module) - - def __reshape_for_conv_t(self, mat, module): - """Create fake single-channel images, grouping batch, - class and channel dimension.""" - batch, out_channels, out_x, out_y = module.output_shape - num_classes = mat.size(-1) - - # 'fake' image for convolution - # (batch * class * channel, 1, out_x, out_y) - return einsum('bic->bci', mat).contiguous().view( - batch * num_classes * out_channels, 1, out_x, out_y) - - def __reshape_for_matmul_t(self, mat, module): - """Ungroup dimensions after application of Jacobian.""" - batch, channels, in_x, in_y = module.input0.size() - 
features = channels * in_x * in_y - # mat is of shape (batch * class * channel, 1, in_x, in_y) - # move class dimension to last - mat_view = mat.view(batch, -1, features) - return einsum('bci->bic', mat_view).contiguous() + return self.view_like_input(jmp_as_pool, module) def __apply_jacobian_t_of(self, module, mat): - _, _, in_x, in_y = module.input0.size() - output_size = (mat.size(0), 1, in_x, in_y) + C_for_conv_t = 1 - conv2d_t = ConvTranspose2d(in_channels=1, - out_channels=1, - kernel_size=module.kernel_size, - stride=module.stride, - padding=module.padding, - bias=False).to(module.input0.device) + conv2d_t = ConvTranspose2d( + in_channels=C_for_conv_t, + out_channels=C_for_conv_t, + kernel_size=module.kernel_size, + stride=module.stride, + padding=module.padding, + bias=False, + ).to(module.input0.device) conv2d_t.weight.requires_grad = False avg_kernel = torch.ones_like(conv2d_t.weight) / conv2d_t.weight.numel() conv2d_t.weight.data = avg_kernel + V_N_C_in = mat.size(0) + _, _, H_in, W_in = module.input0.size() + output_size = (V_N_C_in, C_for_conv_t, H_in, W_in) + return conv2d_t(mat, output_size=output_size) + + def __check_jmp_in_as_pool(self, mat, jmp_as_pool, module): + V = mat.size(0) + N, C_in, H_in, W_in = module.input0_shape + assert jmp_as_pool.shape == (V * N * C_in, 1, H_in, W_in) diff --git a/backpack/core/derivatives/basederivatives.py b/backpack/core/derivatives/basederivatives.py index ab832896..ef75bc15 100644 --- a/backpack/core/derivatives/basederivatives.py +++ b/backpack/core/derivatives/basederivatives.py @@ -1,13 +1,90 @@ -class BaseDerivatives(): +from backpack.core.derivatives import shape_check +from backpack.utils.ein import try_view - MC_SAMPLES = 1 +class BaseDerivatives: + """First- and second-order partial derivatives of a module. + + Shape conventions: + ------------------ + * Batch size: N + * Free dimension for vectorization: V + + For vector-processing layers (2d input): + * input [N, C_in], output [N, C_out] + + For image-processing layers (4d input): + * Input/output channels: C_in/C_out + * Input/output height: H_in/H_out + * Input/output width: W_in/W_out + * input [N, C_in, H_in, W_in], output [N, C_out, H_out, W_out] + + + Definitions: + ------------ + * The Jacobian J is defined as + J[n, c, w, ..., ̃n, ̃c, ̃w, ...] + = 𝜕output[n, c, w, ...] / 𝜕input[̃n, ̃c, ̃w, ...] + * The transposed Jacobian Jᵀ is defined as + Jᵀ[̃n, ̃c, ̃w, ..., n, c, w, ...] + = 𝜕output[n, c, w, ...] / 𝜕input[̃n, ̃c, ̃w, ...] + """ + + @shape_check.jac_mat_prod_accept_vectors + @shape_check.jac_mat_prod_check_shapes def jac_mat_prod(self, module, g_inp, g_out, mat): + """Apply Jacobian of the output w.r.t. input to a matrix. + + Implicit application of J: + result[v, n, c, w, ...] + = ∑_{̃n, ̃c, ̃w} J[n, c, w, ..., ̃n, ̃c, ̃w, ...] mat[v, ̃n, ̃c, ̃w, ...]. + Parameters: + ----------- + mat: torch.Tensor + Matrix the Jacobian will be applied to. + Must have shape [V, N, C_in, H_in, ...]. + + Returns: + -------- + result: torch.Tensor + Jacobian-matrix product. + Has shape [V, N, C_out, H_out, ...]. + """ + return self._jac_mat_prod(module, g_inp, g_out, mat) + + def _jac_mat_prod(self, module, g_inp, g_out, mat): + """Internal implementation of the Jacobian.""" raise NotImplementedError + + @shape_check.jac_t_mat_prod_accept_vectors + @shape_check.jac_t_mat_prod_check_shapes def jac_t_mat_prod(self, module, g_inp, g_out, mat): + """Apply transposed Jacobian of module output w.r.t. input to a matrix. + + Implicit application of Jᵀ: + result[v, ̃n, ̃c, ̃w, ...] 
+ = ∑_{n, c, w} Jᵀ[̃n, ̃c, ̃w, ..., n, c, w, ...] mat[v, n, c, w, ...]. + + Parameters: + ----------- + mat: torch.Tensor + Matrix the transposed Jacobian will be applied to. + Must have shape [V, N, C_out, H_out, ...]. + + Returns: + -------- + result: torch.Tensor + Transposed Jacobian-matrix product. + Has shape [V, N, C_in, H_in, ...]. + """ + return self._jac_t_mat_prod(module, g_inp, g_out, mat) + + def _jac_t_mat_prod(self, module, g_inp, g_out, mat): + """Internal implementation of transposed Jacobian.""" raise NotImplementedError + # TODO Add shape check + # TODO Use new convention def ea_jac_t_mat_jac_prod(self, module, g_inp, g_out, mat): raise NotImplementedError @@ -23,17 +100,208 @@ def hessian_diagonal(self): def hessian_is_psd(self): raise NotImplementedError + # TODO make accept vectors + # TODO add shape check + def make_residual_mat_prod(self, module, g_inp, g_out): + """Return multiplication routine with the residual term. + + The function performs the mapping: mat → [∑_{k} Hz_k(x) 𝛿z_k] mat. + (required for extension `curvmatprod`) + + Note: + ----- + This function only has to be implemented if the residual is not + zero and not diagonal (for instance, `BatchNorm`). + """ + raise NotImplementedError + + # TODO Refactor and remove def batch_flat(self, tensor): batch = tensor.size(0) - # TODO: Removing the clone().detach() will destroy the computation graph + # TODO Removing the clone().detach() will destroy the computation graph # Tests will fail return batch, tensor.clone().detach().view(batch, -1) + # TODO Refactor and remove def get_batch(self, module): - return self.get_input(module).size(0) - - def get_input(self, module): - return module.input0 + return module.input0.size(0) + # TODO Refactor and remove def get_output(self, module): return module.output + + @staticmethod + def _view_like(mat, like): + """View as like with trailing and additional 0th dimension. + + If like is [N, C, H, ...], returns shape [-1, N, C, H, ...] + """ + V = -1 + shape = (V, *like.shape) + return try_view(mat, shape) + + @classmethod + def view_like_input(cls, mat, module): + return cls._view_like(mat, module.input0) + + @classmethod + def view_like_output(cls, mat, module): + return cls._view_like(mat, module.output) + + +class BaseParameterDerivatives(BaseDerivatives): + """First- and second order partial derivatives of a module with parameters. + + Assumptions (true for `nn.Linear`, `nn.Conv(Transpose)Nd`, `nn.BatchNormNd`): + - Parameters are saved as `.weight` and `.bias` fields in a module + - The output is linear in the model parameters + + Shape conventions: + ------------------ + Weight [C_w, H_w, W_w, ...] (usually 1d, 2d, 4d) + Bias [C_b, ...] (usually 1d) + + For most layers, these shapes correspond to shapes of the module input or output. + """ + + @shape_check.bias_jac_mat_prod_accept_vectors + @shape_check.bias_jac_mat_prod_check_shapes + def bias_jac_mat_prod(self, module, g_inp, g_out, mat): + """Apply Jacobian of the output w.r.t. bias to a matrix. + + Parameters: + ----------- + mat: torch.Tensor + Matrix the Jacobian will be applied to. + Must have shape [V, C_b, ...]. + + Returns: + -------- + result: torch.Tensor + Jacobian-matrix product. + Has shape [V, N, C_out, H_out, ...]. 
+ """ + return self._bias_jac_mat_prod(module, g_inp, g_out, mat) + + def _bias_jac_mat_prod(self, module, g_inp, g_out, mat): + """Internal implementation of the bias Jacobian.""" + raise NotImplementedError + + @shape_check.bias_jac_t_mat_prod_accept_vectors + @shape_check.bias_jac_t_mat_prod_check_shapes + def bias_jac_t_mat_prod(self, module, g_inp, g_out, mat, sum_batch=True): + """Apply transposed Jacobian of the output w.r.t. bias to a matrix. + + Parameters: + ----------- + mat: torch.Tensor + Matrix the transposed Jacobian will be applied to. + Must have shape [V, N, C_out, H_out, ...]. + sum_batch: bool + Whether to sum over the batch dimension on the fly. + + Returns: + -------- + result: torch.Tensor + Jacobian-matrix product. + Has shape [V, N, C_b, ...] if `sum_batch == False`. + Has shape [V, C_b, ...] if `sum_batch == True`. + """ + return self._bias_jac_t_mat_prod(module, g_inp, g_out, mat, sum_batch=sum_batch) + + def _bias_jac_t_mat_prod(self, module, g_inp, g_out, mat, sum_batch=True): + """Internal implementation of the transposed bias Jacobian.""" + raise NotImplementedError + + @shape_check.weight_jac_mat_prod_accept_vectors + @shape_check.weight_jac_mat_prod_check_shapes + def weight_jac_mat_prod(self, module, g_inp, g_out, mat): + """Apply Jacobian of the output w.r.t. weight to a matrix. + + Parameters: + ----------- + mat: torch.Tensor + Matrix the Jacobian will be applied to. + Must have shape [V, C_w, H_w, ...]. + + Returns: + -------- + result: torch.Tensor + Jacobian-matrix product. + Has shape [V, N, C_out, H_out, ...]. + """ + return self._weight_jac_mat_prod(module, g_inp, g_out, mat) + + def _weight_jac_mat_prod(self, module, g_inp, g_out, mat): + """Internal implementation of weight Jacobian.""" + raise NotImplementedError + + @shape_check.weight_jac_t_mat_prod_accept_vectors + @shape_check.weight_jac_t_mat_prod_check_shapes + def weight_jac_t_mat_prod(self, module, g_inp, g_out, mat, sum_batch=True): + """Apply transposed Jacobian of the output w.r.t. weight to a matrix. + + Parameters: + ----------- + mat: torch.Tensor + Matrix the transposed Jacobian will be applied to. + Must have shape [V, N, C_out, H_out, ...]. + sum_batch: bool + Whether to sum over the batch dimension on the fly. + + Returns: + -------- + result: torch.Tensor + Jacobian-matrix product. + Has shape [V, N, C_w, H_w, ...] if `sum_batch == False`. + Has shape [V, C_w, H_w, ...] if `sum_batch == True`. + """ + return self._weight_jac_t_mat_prod( + module, g_inp, g_out, mat, sum_batch=sum_batch + ) + + def _weight_jac_t_mat_prod(self, module, g_inp, g_out, mat, sum_batch=True): + """Internal implementation of transposed weight Jacobian.""" + raise NotImplementedError + + +class BaseLossDerivatives(BaseDerivatives): + """Second- order partial derivatives of loss functions. 
+ + """ + + # TODO Add shape check + def sqrt_hessian(self, module, g_inp, g_out): + """Symmetric factorization ('sqrt') of the loss Hessian.""" + return self._sqrt_hessian(module, g_inp, g_out) + + def _sqrt_hessian(self, module, g_inp, g_out): + raise NotImplementedError + + # TODO Add shape check + def sqrt_hessian_sampled(self, module, g_inp, g_out, mc_samples=1): + """Monte-Carlo sampled symmetric factorization of the loss Hessian.""" + return self._sqrt_hessian_sampled(module, g_inp, g_out, mc_samples=mc_samples) + + def _sqrt_hessian_sampled(self, module, g_inp, g_out, mc_samples=1): + raise NotImplementedError + + @shape_check.make_hessian_mat_prod_accept_vectors + @shape_check.make_hessian_mat_prod_check_shapes + def make_hessian_mat_prod(self, module, g_inp, g_out): + """Multiplication of the input Hessian with a matrix. + + Return a function that maps mat to H * mat. + """ + return self._make_hessian_mat_prod(module, g_inp, g_out) + + def _make_hessian_mat_prod(self, module, g_inp, g_out): + raise NotImplementedError + + # TODO Add shape check + def sum_hessian(self, module, g_inp, g_out): + """Loss Hessians, summed over the batch dimension.""" + return self._sum_hessian(module, g_inp, g_out) + + def _sum_hessian(self, module, g_inp, g_out): + raise NotImplementedError diff --git a/backpack/core/derivatives/batchnorm1d.py b/backpack/core/derivatives/batchnorm1d.py index 5c673bc4..b9310db3 100644 --- a/backpack/core/derivatives/batchnorm1d.py +++ b/backpack/core/derivatives/batchnorm1d.py @@ -1,13 +1,16 @@ -import torch -import torch.nn +from warnings import warn + from torch.nn import BatchNorm1d -from ...utils.utils import einsum -from .basederivatives import BaseDerivatives -from .utils import jmp_unsqueeze_if_missing_dim +from backpack.core.derivatives.basederivatives import BaseParameterDerivatives +from backpack.core.derivatives.shape_check import ( + R_mat_prod_accept_vectors, + R_mat_prod_check_shapes, +) +from backpack.utils.ein import einsum -class BatchNorm1dDerivatives(BaseDerivatives): +class BatchNorm1dDerivatives(BaseParameterDerivatives): def get_module(self): return BatchNorm1d @@ -17,14 +20,10 @@ def hessian_is_zero(self): def hessian_is_diagonal(self): return False - # Jacobian-matrix product - @jmp_unsqueeze_if_missing_dim(mat_dim=3) - def jac_mat_prod(self, module, g_inp, g_out, mat): - return self.jac_t_mat_prod(module, g_inp, g_out, mat) + def _jac_mat_prod(self, module, g_inp, g_out, mat): + return self._jac_t_mat_prod(module, g_inp, g_out, mat) - # Transpose Jacobian-matrix product - @jmp_unsqueeze_if_missing_dim(mat_dim=3) - def jac_t_mat_prod(self, module, g_inp, g_out, mat): + def _jac_t_mat_prod(self, module, g_inp, g_out, mat): """ Note: ----- @@ -41,46 +40,66 @@ def jac_t_mat_prod(self, module, g_inp, g_out, mat): """ assert module.affine is True - batch = self.get_batch(module) + N = self.get_batch(module) x_hat, var = self.get_normalized_input_and_var(module) - ivar = 1. 
/ (var + module.eps).sqrt() + ivar = 1.0 / (var + module.eps).sqrt() - dx_hat = einsum('bic,i->bic', (mat, module.weight)) + dx_hat = einsum("vni,i->vni", (mat, module.weight)) - jac_t_mat = batch * dx_hat - jac_t_mat -= dx_hat.sum(0).unsqueeze(0).expand_as(jac_t_mat) - jac_t_mat -= einsum('bi,sic,si->bic', (x_hat, dx_hat, x_hat)) - jac_t_mat = einsum('bic,i->bic', (jac_t_mat, ivar / batch)) + jac_t_mat = N * dx_hat + jac_t_mat -= dx_hat.sum(1).unsqueeze(1).expand_as(jac_t_mat) + jac_t_mat -= einsum("ni,vsi,si->vni", (x_hat, dx_hat, x_hat)) + jac_t_mat = einsum("vni,i->vni", (jac_t_mat, ivar / N)) return jac_t_mat def get_normalized_input_and_var(self, module): - input = self.get_input(module) + input = module.input0 mean = input.mean(dim=0) var = input.var(dim=0, unbiased=False) return (input - mean) / (var + module.eps).sqrt(), var - @jmp_unsqueeze_if_missing_dim(mat_dim=2) - def weight_jac_mat_prod(self, module, g_inp, g_out, mat): - batch = self.get_batch(module) - x_hat, _ = self.get_normalized_input_and_var(module) - return einsum('bi,ic->bic', (x_hat, mat)) + @R_mat_prod_accept_vectors + @R_mat_prod_check_shapes + def make_residual_mat_prod(self, module, g_inp, g_out): + # TODO: Implement R_mat_prod for BatchNorm + def R_mat_prod(mat): + """Multiply with the residual: mat → [∑_{k} Hz_k(x) 𝛿z_k] mat. + + Second term of the module input Hessian backpropagation equation. + """ + raise NotImplementedError + + # TODO: Enable tests in test/automated_bn_test.py + raise NotImplementedError + return R_mat_prod - @jmp_unsqueeze_if_missing_dim(mat_dim=3) - def weight_jac_t_mat_prod(self, module, g_inp, g_out, mat, sum_batch=True): + def _weight_jac_mat_prod(self, module, g_inp, g_out, mat): x_hat, _ = self.get_normalized_input_and_var(module) - equation = 'bic,bi->{}ic'.format('' if sum_batch is True else 'b') + return einsum("ni,vi->vni", (x_hat, mat)) + + def _weight_jac_t_mat_prod(self, module, g_inp, g_out, mat, sum_batch): + if not sum_batch: + warn( + "BatchNorm batch summation disabled." + "This may not compute meaningful quantities" + ) + x_hat, _ = self.get_normalized_input_and_var(module) + equation = "vni,ni->v{}i".format("" if sum_batch is True else "n") operands = [mat, x_hat] return einsum(equation, operands) - @jmp_unsqueeze_if_missing_dim(mat_dim=2) - def bias_jac_mat_prod(self, module, g_inp, g_out, mat): - batch = self.get_batch(module) - return mat.unsqueeze(0).repeat(batch, 1, 1) + def _bias_jac_mat_prod(self, module, g_inp, g_out, mat): + N = self.get_batch(module) + return mat.unsqueeze(1).repeat(1, N, 1) - @jmp_unsqueeze_if_missing_dim(mat_dim=3) - def bias_jac_t_mat_prod(self, module, g_inp, g_out, mat, sum_batch=True): - if sum_batch is True: - return mat.sum(0) - else: + def _bias_jac_t_mat_prod(self, module, g_inp, g_out, mat, sum_batch=True): + if not sum_batch: + warn( + "BatchNorm batch summation disabled." 
+ "This may not compute meaningful quantities" + ) return mat + else: + N_axis = 1 + return mat.sum(N_axis) diff --git a/backpack/core/derivatives/conv2d.py b/backpack/core/derivatives/conv2d.py index 6fdd536e..9c572690 100644 --- a/backpack/core/derivatives/conv2d.py +++ b/backpack/core/derivatives/conv2d.py @@ -1,31 +1,20 @@ -import warnings - -import torch from torch.nn import Conv2d from torch.nn.functional import conv2d, conv_transpose2d -from ...core.layers import Conv2dConcat -from ...utils import conv as convUtils -from ...utils.utils import einsum, random_psd_matrix -from .basederivatives import BaseDerivatives -from .utils import jmp_unsqueeze_if_missing_dim +from backpack.core.derivatives.basederivatives import BaseParameterDerivatives +from backpack.utils import conv as convUtils +from backpack.utils.ein import eingroup, einsum -class Conv2DDerivatives(BaseDerivatives): +class Conv2DDerivatives(BaseParameterDerivatives): def get_module(self): return Conv2d def hessian_is_zero(self): return True - def get_weight_data(self, module): - return module.weight.data - - def get_input(self, module): - return module.input0 - def get_unfolded_input(self, module): - return convUtils.unfold_func(module)(self.get_input(module)) + return convUtils.unfold_func(module)(module.input0) # TODO: Require tests def ea_jac_t_mat_jac_prod(self, module, g_inp, g_out, mat): @@ -36,10 +25,11 @@ def ea_jac_t_mat_jac_prod(self, module, g_inp, g_out, mat): # 1) apply conv_transpose to multiply with W^T result = mat.view(out_c, out_x, out_y, out_features) - result = einsum('cxyf->fcxy', (result, )) + result = einsum("cxyf->fcxy", (result,)) # result: W^T mat result = self.__apply_jacobian_t_of(module, result).view( - out_features, in_features) + out_features, in_features + ) # 2) transpose: mat^T W result = result.t() @@ -51,222 +41,86 @@ def ea_jac_t_mat_jac_prod(self, module, g_inp, g_out, mat): # 4) transpose to obtain W^T mat W return result.view(in_features, in_features).t() - # Jacobian-matrix product - @jmp_unsqueeze_if_missing_dim(mat_dim=2) - def jac_mat_prod(self, module, g_inp, g_out, mat): - convUtils.check_sizes_input_jac(mat, module) - mat_as_conv = self.__reshape_for_conv_in(mat, module) - jmp_as_conv = self.__apply_jacobian_of(module, mat_as_conv) - convUtils.check_sizes_output_jac(jmp_as_conv, module) - - return self.__reshape_for_matmul(jmp_as_conv, module) - - def __apply_jacobian_of(self, module, mat): - return conv2d(mat, - self.get_weight_data(module), - stride=module.stride, - padding=module.padding, - dilation=module.dilation, - groups=module.groups) - - def __reshape_for_conv_in(self, bmat, module): - batch, in_channels, in_x, in_y = module.input0.size() - num_classes = bmat.size(2) - bmat = einsum('boc->cbo', (bmat, )).contiguous() - bmat = bmat.view(num_classes * batch, in_channels, in_x, in_y) - return bmat - - def __reshape_for_matmul(self, bconv, module): - batch = module.output_shape[0] - out_features = torch.prod(module.output_shape) / batch - bconv = bconv.view(-1, batch, out_features) - bconv = einsum('cbi->bic', (bconv, )) - return bconv - - # Transposed Jacobian-matrix product - @jmp_unsqueeze_if_missing_dim(mat_dim=3) - def jac_t_mat_prod(self, module, g_inp, g_out, mat): - convUtils.check_sizes_input_jac_t(mat, module) - mat_as_conv = self.__reshape_for_conv_out(mat, module) - jmp_as_conv = self.__apply_jacobian_t_of(module, mat_as_conv) - convUtils.check_sizes_output_jac_t(jmp_as_conv, module) - - return self.__reshape_for_matmul_t(jmp_as_conv, module) - - def 
__reshape_for_conv_out(self, bmat, module): - batch, out_channels, out_x, out_y = module.output_shape - num_classes = bmat.size(2) - - bmat = einsum('boc->cbo', (bmat, )).contiguous() - bmat = bmat.view(num_classes * batch, out_channels, out_x, out_y) - return bmat - - def __reshape_for_matmul_t(self, bconv, module): - batch = module.output_shape[0] - in_features = module.input0.numel() / batch - bconv = bconv.view(-1, batch, in_features) - bconv = einsum('cbi->bic', (bconv, )) - return bconv - - def __apply_jacobian_t_of(self, module, mat): - return conv_transpose2d(mat, - self.get_weight_data(module), - stride=module.stride, - padding=module.padding, - dilation=module.dilation, - groups=module.groups) - - # TODO: Improve performance - @jmp_unsqueeze_if_missing_dim(mat_dim=2) - def bias_jac_mat_prod(self, module, g_inp, g_out, mat): - batch, out_channels, out_x, out_y = module.output_shape - num_cols = mat.size(1) - # mat has shape (out_channels, num_cols) + def _jac_mat_prod(self, module, g_inp, g_out, mat): + mat_as_conv = eingroup("v,n,c,h,w->vn,c,h,w", mat) + jmp_as_conv = conv2d( + mat_as_conv, + module.weight.data, + stride=module.stride, + padding=module.padding, + dilation=module.dilation, + groups=module.groups, + ) + return self.view_like_output(jmp_as_conv, module) + + def _jac_t_mat_prod(self, module, g_inp, g_out, mat): + mat_as_conv = eingroup("v,n,c,h,w->vn,c,h,w", mat) + jmp_as_conv = conv_transpose2d( + mat_as_conv, + module.weight.data, + stride=module.stride, + padding=module.padding, + dilation=module.dilation, + groups=module.groups, + ) + return self.view_like_input(jmp_as_conv, module) + + def _bias_jac_mat_prod(self, module, g_inp, g_out, mat): + """mat has shape [V, C_out]""" # expand for each batch and for each channel - jac_mat = mat.view(1, out_channels, 1, 1, num_cols) - jac_mat = jac_mat.expand(batch, -1, out_x, out_y, -1).contiguous() - return jac_mat.view(batch, -1, num_cols) - - @jmp_unsqueeze_if_missing_dim(mat_dim=3) - def bias_jac_t_mat_prod(self, module, g_inp, g_out, mat, sum_batch=True): - batch, out_channels, out_x, out_y = module.output_shape - num_cols = mat.size(2) - shape = (batch, out_channels, out_x * out_y, num_cols) - # mat has shape (batch, out_features, num_cols) - # sum back over the pixels and batch dimensions - sum_dims = [0, 2] if sum_batch is True else [2] - return mat.view(shape).sum(sum_dims) - - # TODO: Improve performance, get rid of unfold - @jmp_unsqueeze_if_missing_dim(mat_dim=2) - def weight_jac_mat_prod(self, module, g_inp, g_out, mat): - batch, out_channels, out_x, out_y = module.output_shape - out_features = out_channels * out_x * out_y - num_cols = mat.size(1) - jac_mat = mat.view(1, out_channels, -1, num_cols) - jac_mat = jac_mat.expand(batch, out_channels, -1, -1) + N_axis, H_axis, W_axis = 1, 3, 4 + jac_mat = mat.unsqueeze(N_axis).unsqueeze(H_axis).unsqueeze(W_axis) - X = self.get_unfolded_input(module) - jac_mat = einsum('bij,bkic->bkjc', (X, jac_mat)).contiguous() - jac_mat = jac_mat.view(batch, out_features, num_cols) - return jac_mat + N, _, H_out, W_out = module.output_shape + return jac_mat.expand(-1, N, -1, H_out, W_out) - @jmp_unsqueeze_if_missing_dim(mat_dim=3) - def __weight_jac_t_mat_prod2(self, - module, - g_inp, - g_out, - mat, - sum_batch=True): - """Intuitive, using unfold operation.""" - batch, out_channels, out_x, out_y = module.output_shape - _, in_channels, in_x, in_y = module.input0.shape - num_cols = mat.shape[-1] + def _bias_jac_t_mat_prod(self, module, g_inp, g_out, mat, sum_batch=True): + N_axis, 
H_axis, W_axis = 1, 3, 4 + axes = [H_axis, W_axis] + if sum_batch: + axes = [N_axis] + axes - jac_t_mat = mat.view(batch, out_channels, -1, num_cols) + return mat.sum(axes) - equation = 'bij,bkjc->kic' if sum_batch is True else 'bij,bkjc->bkic' + # TODO: Improve performance by using conv instead of unfold + def _weight_jac_mat_prod(self, module, g_inp, g_out, mat): + jac_mat = eingroup("v,o,i,h,w->v,o,ihw", mat) X = self.get_unfolded_input(module) - jac_t_mat = einsum(equation, (X, jac_t_mat)).contiguous() - - sum_shape = [module.weight.numel(), num_cols] - shape = sum_shape if sum_batch is True else [batch] + sum_shape - - jac_t_mat = jac_t_mat.view(shape) - return jac_t_mat - - @jmp_unsqueeze_if_missing_dim(mat_dim=3) - def weight_jac_t_mat_prod(self, module, g_inp, g_out, mat, sum_batch=True): - """Unintuitive, but faster due to conv operation.""" - batch, out_channels, out_x, out_y = module.output_shape - _, in_channels, in_x, in_y = module.input0.shape - k_x, k_y = module.kernel_size - num_cols = mat.shape[-1] - - mat = mat.view(batch, out_channels, out_x, out_y, num_cols) - mat = einsum('boxyc->cboxy', - (mat, )).contiguous().view(num_cols * batch, out_channels, - out_x, out_y) - - mat = mat.repeat(1, in_channels, 1, 1) - mat = mat.view(num_cols * batch * out_channels * in_channels, 1, out_x, - out_y) - - input = module.input0.view(1, -1, in_x, in_y).repeat(1, num_cols, 1, 1) - - grad_weight = conv2d(input, mat, None, module.dilation, module.padding, - module.stride, in_channels * batch * num_cols) - - grad_weight = grad_weight.view(num_cols, batch, - out_channels * in_channels, k_x, k_y) - if sum_batch is True: - grad_weight = grad_weight.sum(1) - batch = 1 - - grad_weight = grad_weight.view(num_cols, batch, in_channels, - out_channels, k_x, k_y) - grad_weight = einsum('cbmnxy->bnmxyc', grad_weight).contiguous() - - grad_weight = grad_weight.view(batch, - in_channels * out_channels * k_x * k_y, - num_cols) - - if sum_batch is True: - grad_weight = grad_weight.squeeze(0) - - return grad_weight - - -class Conv2DConcatDerivatives(Conv2DDerivatives): - # override - def get_module(self): - return Conv2dConcat - - # override - def get_unfolded_input(self, module): - """Return homogeneous input, if bias exists """ - X = convUtils.unfold_func(module)(self.get_input(module)) - if module.has_bias(): - return module.append_ones(X) - else: - return X - - # override - def get_weight_data(self, module): - return module._slice_weight().data - - # override - @jmp_unsqueeze_if_missing_dim(mat_dim=3) - def weight_jac_t_mat_prod(self, module, g_inp, g_out, mat, sum_batch=True): - weight_part = super().weight_jac_t_mat_prod(module, - g_inp, - g_out, - mat, - sum_batch=sum_batch) - - if not module.has_bias(): - return weight_part - - else: - bias_part = super().bias_jac_t_mat_prod(module, - g_inp, - g_out, - mat, - sum_batch=sum_batch) - - batch = 1 if sum_batch is True else self.get_batch(module) - num_cols = mat.size(2) - w_for_cat = [batch, module.out_channels, -1, num_cols] - b_for_cat = [batch, module.out_channels, 1, num_cols] - - weight_part = weight_part.view(w_for_cat) - bias_part = bias_part.view(b_for_cat) - - wb_part = torch.cat([weight_part, bias_part], dim=2) - wb_part = wb_part.view(batch, -1, num_cols) - - if sum_batch is True: - wb_part = wb_part.squeeze(0) - return wb_part + jac_mat = einsum("nij,vki->vnkj", (X, jac_mat)) + return self.view_like_output(jac_mat, module) + + def _weight_jac_t_mat_prod(self, module, g_inp, g_out, mat, sum_batch=True): + """Unintuitive, but faster due to 
convolution.""" + V = mat.shape[0] + N, C_out, _, _ = module.output_shape + _, C_in, _, _ = module.input0_shape + + mat = eingroup("v,n,c,w,h->vn,c,w,h", mat).repeat(1, C_in, 1, 1) + C_in_axis = 1 + # a,b represent the combined/repeated dimensions + mat = eingroup("a,b,w,h->ab,w,h", mat).unsqueeze(C_in_axis) + + N_axis = 0 + input = eingroup("n,c,h,w->nc,h,w", module.input0).unsqueeze(N_axis) + input = input.repeat(1, V, 1, 1) + + grad_weight = conv2d( + input, + mat, + bias=None, + stride=module.dilation, + padding=module.padding, + dilation=module.stride, + groups=C_in * N * V, + ).squeeze(0) + + K_H_axis, K_W_axis = 1, 2 + _, _, K_H, K_W = module.weight.shape + grad_weight = grad_weight.narrow(K_H_axis, 0, K_H).narrow(K_W_axis, 0, K_W) + + eingroup_eq = "vnio,x,y->v,{}o,i,x,y".format("" if sum_batch else "n,") + return eingroup( + eingroup_eq, grad_weight, dim={"v": V, "n": N, "i": C_in, "o": C_out} + ) diff --git a/backpack/core/derivatives/crossentropyloss.py b/backpack/core/derivatives/crossentropyloss.py index 1e1f2212..f8587f66 100644 --- a/backpack/core/derivatives/crossentropyloss.py +++ b/backpack/core/derivatives/crossentropyloss.py @@ -1,84 +1,78 @@ from math import sqrt -from warnings import warn -from torch import diag, diag_embed, multinomial, ones_like, randn, softmax +from torch import diag, diag_embed, multinomial, ones_like, softmax from torch import sqrt as torchsqrt from torch.nn import CrossEntropyLoss from torch.nn.functional import one_hot -from ...utils.utils import einsum -from .basederivatives import BaseDerivatives -from .utils import hmp_unsqueeze_if_missing_dim +from backpack.core.derivatives.basederivatives import BaseLossDerivatives +from backpack.utils.ein import einsum -class CrossEntropyLossDerivatives(BaseDerivatives): +class CrossEntropyLossDerivatives(BaseLossDerivatives): def get_module(self): return CrossEntropyLoss - def sqrt_hessian(self, module, g_inp, g_out): + def _sqrt_hessian(self, module, g_inp, g_out): probs = self.get_probs(module) tau = torchsqrt(probs) - Id = diag_embed(ones_like(probs)) - Id_tautau = Id - einsum('ni,nj->nij', tau, tau) - sqrt_H = einsum('ni,nij->nij', tau, Id_tautau) + V_dim, C_dim = 0, 2 + Id = diag_embed(ones_like(probs), dim1=V_dim, dim2=C_dim) + Id_tautau = Id - einsum("nv,nc->vnc", tau, tau) + sqrt_H = einsum("nc,vnc->vnc", tau, Id_tautau) - if module.reduction is "mean": - sqrt_H /= sqrt(module.input0.shape[0]) + if module.reduction == "mean": + N = module.input0.shape[0] + sqrt_H /= sqrt(N) return sqrt_H - def sqrt_hessian_sampled(self, module, g_inp, g_out): - M = self.MC_SAMPLES + def _sqrt_hessian_sampled(self, module, g_inp, g_out, mc_samples=1): + M = mc_samples C = module.input0.shape[1] - probs = self.get_probs(module).unsqueeze(-1).repeat(1, 1, M) - - # HOTFIX (torch bug): multinomial not working with CUDA - original_dev = probs.device - if probs.is_cuda: - probs = probs.cpu() - - classes = one_hot(multinomial(probs, M, replacement=True), - num_classes=C) - - probs = probs.to(original_dev) - classes = classes.to(original_dev) - # END + probs = self.get_probs(module) + V_dim = 0 + probs_unsqueezed = probs.unsqueeze(V_dim).repeat(M, 1, 1) - classes = classes.transpose(1, 2).float() + multi = multinomial(probs, M, replacement=True) + classes = one_hot(multi, num_classes=C) + classes = einsum("nvc->vnc", classes).float() - sqrt_mc_h = (probs - classes) / sqrt(M) + sqrt_mc_h = (probs_unsqueezed - classes) / sqrt(M) - if module.reduction is "mean": - sqrt_mc_h /= sqrt(module.input0.shape[0]) + if 
module.reduction == "mean": + N = module.input0.shape[0] + sqrt_mc_h /= sqrt(N) return sqrt_mc_h - def sum_hessian(self, module, g_inp, g_out): + def _sum_hessian(self, module, g_inp, g_out): probs = self.get_probs(module) - sum_H = diag(probs.sum(0)) - einsum('bi,bj->ij', (probs, probs)) + sum_H = diag(probs.sum(0)) - einsum("bi,bj->ij", (probs, probs)) - if module.reduction is "mean": - sum_H /= module.input0.shape[0] + if module.reduction == "mean": + N = module.input0.shape[0] + sum_H /= N return sum_H - def hessian_matrix_product(self, module, g_inp, g_out): + def _make_hessian_mat_prod(self, module, g_inp, g_out): """Multiplication of the input Hessian with a matrix.""" probs = self.get_probs(module) - @hmp_unsqueeze_if_missing_dim(mat_dim=3) - def hmp(mat): - Hmat = einsum('bi,bic->bic', - (probs, mat)) - einsum('bi,bj,bjc->bic', - (probs, probs, mat)) + def hessian_mat_prod(mat): + Hmat = einsum("bi,cbi->cbi", (probs, mat)) - einsum( + "bi,bj,cbj->cbi", (probs, probs, mat) + ) - if module.reduction is "mean": - Hmat /= module.input0.shape[0] + if module.reduction == "mean": + N = module.input0.shape[0] + Hmat /= N return Hmat - return hmp + return hessian_mat_prod def hessian_is_psd(self): return True diff --git a/backpack/core/derivatives/dropout.py b/backpack/core/derivatives/dropout.py index 0cbf59aa..c97cc62e 100644 --- a/backpack/core/derivatives/dropout.py +++ b/backpack/core/derivatives/dropout.py @@ -1,6 +1,7 @@ from torch import eq from torch.nn import Dropout -from .elementwise import ElementwiseDerivatives + +from backpack.core.derivatives.elementwise import ElementwiseDerivatives class DropoutDerivatives(ElementwiseDerivatives): @@ -12,5 +13,5 @@ def hessian_is_zero(self): def df(self, module, g_inp, g_out): scaling = 1 / (1 - module.p) - mask = 1 - eq(module.output, 0.).float() + mask = 1 - eq(module.output, 0.0).float() return mask * scaling diff --git a/backpack/core/derivatives/elementwise.py b/backpack/core/derivatives/elementwise.py index 47ad4c9d..1ddc2cec 100644 --- a/backpack/core/derivatives/elementwise.py +++ b/backpack/core/derivatives/elementwise.py @@ -1,27 +1,21 @@ -from ...utils.utils import einsum -from .basederivatives import BaseDerivatives -from .utils import hmp_unsqueeze_if_missing_dim, jmp_unsqueeze_if_missing_dim +from backpack.core.derivatives.basederivatives import BaseDerivatives +from backpack.utils.ein import einsum class ElementwiseDerivatives(BaseDerivatives): - @jmp_unsqueeze_if_missing_dim(mat_dim=3) - def jac_t_mat_prod(self, module, g_inp, g_out, mat): - _, df_flat = self.batch_flat(self.df(module, g_inp, g_out)) - return einsum('bi,bic->bic', (df_flat, mat)) + def _jac_t_mat_prod(self, module, g_inp, g_out, mat): + df_elementwise = self.df(module, g_inp, g_out) + return einsum("...,v...->v...", (df_elementwise, mat)) - @jmp_unsqueeze_if_missing_dim(mat_dim=3) - def jac_mat_prod(self, module, g_inp, g_out, mat): - _, df_flat = self.batch_flat(self.df(module, g_inp, g_out)) - return einsum('bi,bic->bic', (df_flat, mat)) + def _jac_mat_prod(self, module, g_inp, g_out, mat): + return self.jac_t_mat_prod(module, g_inp, g_out, mat) def ea_jac_t_mat_jac_prod(self, module, g_inp, g_out, mat): batch, df_flat = self.batch_flat(self.df(module, g_inp, g_out)) - return einsum('bi,bj,ij->ij', (df_flat, df_flat, mat)) / batch + return einsum("ni,nj,ij->ij", (df_flat, df_flat, mat)) / batch def hessian_diagonal(self, module, g_inp, g_out): - _, d2f_flat = self.batch_flat(self.d2f(module, g_inp, g_out)) - _, g_out_flat = self.batch_flat(g_out[0]) - 
return d2f_flat * g_out_flat + return self.d2f(module, g_inp, g_out) * g_out[0] def df(self, module, g_inp, g_out): raise NotImplementedError("First derivatives not implemented") diff --git a/backpack/core/derivatives/flatten.py b/backpack/core/derivatives/flatten.py new file mode 100644 index 00000000..1fe51a09 --- /dev/null +++ b/backpack/core/derivatives/flatten.py @@ -0,0 +1,30 @@ +from torch.nn import Flatten + +from backpack.core.derivatives.basederivatives import BaseDerivatives + + +class FlattenDerivatives(BaseDerivatives): + def get_module(self): + return Flatten + + def hessian_is_zero(self): + return True + + def ea_jac_t_mat_jac_prod(self, module, g_inp, g_out, mat): + return mat + + def _jac_t_mat_prod(self, module, g_inp, g_out, mat): + return self.view_like_input(mat, module) + + def _jac_mat_prod(self, module, g_inp, g_out, mat): + return self.view_like_output(mat, module) + + def is_no_op(self, module): + """Does flatten add an operation to the computational graph. + + If the input is already flattened, no operation will be added for + the `Flatten` layer. This can lead to an intuitive order of backward + hook execution, see the discussion at https://discuss.pytorch.org/t/ + backward-hooks-changing-order-of-execution-in-nn-sequential/12447/4 . + """ + return tuple(module.input0_shape) == tuple(module.output_shape) diff --git a/backpack/core/derivatives/linear.py b/backpack/core/derivatives/linear.py index 56c6d0f8..19101341 100644 --- a/backpack/core/derivatives/linear.py +++ b/backpack/core/derivatives/linear.py @@ -1,101 +1,60 @@ -import torch -from ...utils.utils import einsum from torch.nn import Linear -from .basederivatives import BaseDerivatives -from ..layers import LinearConcat -from .utils import jmp_unsqueeze_if_missing_dim +from backpack.core.derivatives.basederivatives import BaseParameterDerivatives +from backpack.utils.ein import einsum -class LinearDerivatives(BaseDerivatives): +class LinearDerivatives(BaseParameterDerivatives): + """Partial derivatives for the Linear layer. + + Index conventions: + ------------------ + * v: Free dimension + * n: Batch dimension + * o: Output dimension + * i: Input dimension + """ + def get_module(self): return Linear - def get_input(self, module): - return module.input0 - def hessian_is_zero(self): return True - def get_weight_data(self, module): - return module.weight.data + def _jac_t_mat_prod(self, module, g_inp, g_out, mat): + """Apply transposed Jacobian of the output w.r.t. the input.""" + d_input = module.weight.data + return einsum("oi,vno->vni", (d_input, mat)) - @jmp_unsqueeze_if_missing_dim(mat_dim=3) - def jac_t_mat_prod(self, module, g_inp, g_out, mat): - d_linear = self.get_weight_data(module) - return einsum('ij,bic->bjc', (d_linear, mat)) - - @jmp_unsqueeze_if_missing_dim(mat_dim=3) - def jac_mat_prod(self, module, g_inp, g_out, mat): - d_linear = self.get_weight_data(module) - return einsum('ij,bjc->bic', (d_linear, mat)) + def _jac_mat_prod(self, module, g_inp, g_out, mat): + """Apply Jacobian of the output w.r.t. 
the input.""" + d_input = module.weight.data + return einsum("oi,vni->vno", (d_input, mat)) def ea_jac_t_mat_jac_prod(self, module, g_inp, g_out, mat): - jac = self.get_weight_data(module) - return einsum('ik,ij,jl->kl', (jac, mat, jac)) - - @jmp_unsqueeze_if_missing_dim(mat_dim=2) - def weight_jac_mat_prod(self, module, g_inp, g_out, mat): - batch = self.get_batch(module) - num_cols = mat.size(1) - shape = tuple(module.weight.size()) + (num_cols, ) - - jac_mat = einsum('bj,ijc->bic', - (self.get_input(module), mat.view(shape))) - return jac_mat - - @jmp_unsqueeze_if_missing_dim(mat_dim=3) - def weight_jac_t_mat_prod(self, - module, - g_inp, - g_out, - mat, - sum_batch=True): - batch = self.get_batch(module) - num_cols = mat.size(2) - - equation = 'bjc,bi->jic' if sum_batch is True else 'bjc,bi->bjic' - - jac_t_mat = einsum(equation, - (mat, self.get_input(module))).contiguous() - - sum_shape = [module.weight.numel(), num_cols] - shape = sum_shape if sum_batch is True else [batch] + sum_shape - - return jac_t_mat.view(*shape) - - @jmp_unsqueeze_if_missing_dim(mat_dim=2) - def bias_jac_mat_prod(self, module, g_inp, g_out, mat): - batch = self.get_batch(module) - return mat.unsqueeze(0).expand(batch, -1, -1) - - @jmp_unsqueeze_if_missing_dim(mat_dim=3) - def bias_jac_t_mat_prod(self, - module, - g_inp, - g_out, - mat, - sum_batch=True): - if sum_batch is True: - return mat.sum(0) + jac = module.weight.data + return einsum("ik,ij,jl->kl", (jac, mat, jac)) + + def _weight_jac_mat_prod(self, module, g_inp, g_out, mat): + """Apply Jacobian of the output w.r.t. the weight.""" + d_weight = module.input0 + return einsum("ni,voi->vno", (d_weight, mat)) + + def _weight_jac_t_mat_prod(self, module, g_inp, g_out, mat, sum_batch=True): + """Apply transposed Jacobian of the output w.r.t. the weight.""" + d_weight = module.input0 + contract = "vno,ni->voi" if sum_batch else "vno,ni->vnoi" + return einsum(contract, (mat, d_weight)) + + def _bias_jac_mat_prod(self, module, g_inp, g_out, mat): + """Apply Jacobian of the output w.r.t. the bias.""" + N = self.get_batch(module) + return mat.unsqueeze(1).expand(-1, N, -1) + + def _bias_jac_t_mat_prod(self, module, g_inp, g_out, mat, sum_batch=True): + """Apply transposed Jacobian of the output w.r.t. 
the bias.""" + if sum_batch: + N_axis = 1 + return mat.sum(N_axis) else: return mat - - -class LinearConcatDerivatives(LinearDerivatives): - # override - def get_module(self): - return LinearConcat - - # override - def get_input(self, module): - """Return homogeneous input, if bias exists """ - input = super().get_input(module) - if module.has_bias(): - return module.append_ones(input) - else: - return input - - # override - def get_weight_data(self, module): - return module._slice_weight().data diff --git a/backpack/core/derivatives/maxpool2d.py b/backpack/core/derivatives/maxpool2d.py index f2cf5f92..89358797 100644 --- a/backpack/core/derivatives/maxpool2d.py +++ b/backpack/core/derivatives/maxpool2d.py @@ -1,28 +1,26 @@ -import warnings - -from torch import prod, scatter_add, zeros +from torch import zeros from torch.nn import MaxPool2d from torch.nn.functional import max_pool2d -from ...utils import conv as convUtils -from ...utils.utils import random_psd_matrix -from .basederivatives import BaseDerivatives -from .utils import jmp_unsqueeze_if_missing_dim +from backpack.core.derivatives.basederivatives import BaseDerivatives +from backpack.utils.ein import eingroup class MaxPool2DDerivatives(BaseDerivatives): def get_module(self): return MaxPool2d + # TODO: Do not recompute but get from forward pass of module def get_pooling_idx(self, module): - # TODO: Do not recompute but get from forward pass of module - _, pool_idx = max_pool2d(module.input0, - kernel_size=module.kernel_size, - stride=module.stride, - padding=module.padding, - dilation=module.dilation, - return_indices=True, - ceil_mode=module.ceil_mode) + _, pool_idx = max_pool2d( + module.input0, + kernel_size=module.kernel_size, + stride=module.stride, + padding=module.padding, + dilation=module.dilation, + return_indices=True, + ceil_mode=module.ceil_mode, + ) return pool_idx def ea_jac_t_mat_jac_prod(self, module, g_inp, g_out, mat): @@ -32,86 +30,62 @@ def ea_jac_t_mat_jac_prod(self, module, g_inp, g_out, mat): in terms of the approximation and memory costs. 
""" device = mat.device - batch, channels, in_x, in_y = module.input0.size() - in_features = channels * in_x * in_y - _, _, out_x, out_y = module.output.size() - out_features = channels * out_x * out_y + N, channels, H_in, W_in = module.input0.size() + in_features = channels * H_in * W_in + _, _, H_out, W_out = module.output.size() + out_features = channels * H_out * W_out - pool_idx = self.get_pooling_idx(module).view(batch, channels, - out_x * out_y) + pool_idx = self.get_pooling_idx(module).view(N, channels, H_out * W_out) result = zeros(in_features, in_features, device=device) - for b in range(batch): + for b in range(N): idx = pool_idx[b, :, :] temp = zeros(in_features, out_features, device=device) temp.scatter_add_(1, idx, mat) result.scatter_add_(0, idx.t(), temp) - return result / batch + return result / N def hessian_is_zero(self): return True - # Jacobian-matrix product - @jmp_unsqueeze_if_missing_dim(mat_dim=3) - def jac_mat_prod(self, module, g_inp, g_out, mat): - convUtils.check_sizes_input_jac(mat, module) - mat_as_pool = self.__reshape_for_pooling_in(mat, module) + def _jac_mat_prod(self, module, g_inp, g_out, mat): + mat_as_pool = eingroup("v,n,c,h,w->v,n,c,hw", mat) jmp_as_pool = self.__apply_jacobian_of(module, mat_as_pool) - return self.__reshape_for_matmul(jmp_as_pool, module) - - def __reshape_for_pooling_in(self, mat, module): - num_classes = mat.size(-1) - batch, channels, in_x, in_y = module.input0.size() - return mat.view(batch, channels, in_x * in_y, num_classes) - - def __reshape_for_matmul(self, mat, module): - batch = module.output_shape[0] - out_features = prod(module.output_shape) / batch - num_classes = mat.size(-1) - return mat.view(batch, out_features, num_classes) + return self.view_like_output(jmp_as_pool, module) def __apply_jacobian_of(self, module, mat): - batch, channels, out_x, out_y = module.output_shape - num_classes = mat.shape[-1] + V, HW_axis = mat.shape[0], 3 + pool_idx = self.__pool_idx_for_jac(module, V) + return mat.gather(HW_axis, pool_idx) - pool_idx = self.get_pooling_idx(module) - pool_idx = pool_idx.view(batch, channels, out_x * out_y) - pool_idx = pool_idx.unsqueeze(-1).expand(-1, -1, -1, num_classes) - - return mat.gather(2, pool_idx) + def __pool_idx_for_jac(self, module, V): + """Manipulated pooling indices ready-to-use in jac(t).""" - # Transposed Jacobian-matrix product - @jmp_unsqueeze_if_missing_dim(mat_dim=3) - def jac_t_mat_prod(self, module, g_inp, g_out, mat): - convUtils.check_sizes_input_jac_t(mat, module) - mat_as_pool = self.__reshape_for_pooling_out(mat, module) + pool_idx = self.get_pooling_idx(module) + V_axis = 0 + return ( + eingroup("n,c,h,w->n,c,hw", pool_idx) + .unsqueeze(V_axis) + .expand(V, -1, -1, -1) + ) + + def _jac_t_mat_prod(self, module, g_inp, g_out, mat): + mat_as_pool = eingroup("v,n,c,h,w->v,n,c,hw", mat) jmp_as_pool = self.__apply_jacobian_t_of(module, mat_as_pool) - return self.__reshape_for_matmul_t(jmp_as_pool, module) - - def __reshape_for_pooling_out(self, mat, module): - num_classes = mat.size(-1) - batch, channels, out_x, out_y = module.output_shape - return mat.view(batch, channels, out_x * out_y, num_classes) - - def __reshape_for_matmul_t(self, mat, module): - batch = module.output_shape[0] - in_features = module.input0.numel() / batch - num_classes = mat.size(-1) - return mat.view(batch, in_features, num_classes) + return self.view_like_input(jmp_as_pool, module) def __apply_jacobian_t_of(self, module, mat): - batch, channels, out_x, out_y = module.output_shape - _, _, in_x, in_y = 
module.input0.size() - num_classes = mat.shape[-1] - - result = zeros(batch, - channels, - in_x * in_y, - num_classes, - device=mat.device) + V = mat.shape[0] + result = self.__zero_for_jac_t(module, V, mat.device) + pool_idx = self.__pool_idx_for_jac(module, V) - pool_idx = self.get_pooling_idx(module) - pool_idx = pool_idx.view(batch, channels, out_x * out_y) - pool_idx = pool_idx.unsqueeze(-1).expand(-1, -1, -1, num_classes) - result.scatter_add_(2, pool_idx, mat) + HW_axis = 3 + result.scatter_add_(HW_axis, pool_idx, mat) return result + + def __zero_for_jac_t(self, module, V, device): + N, C_out, _, _ = module.output_shape + _, _, H_in, W_in = module.input0.size() + + shape = (V, N, C_out, H_in * W_in) + return zeros(shape, device=device) diff --git a/backpack/core/derivatives/mseloss.py b/backpack/core/derivatives/mseloss.py index 71f069fa..2fb655a5 100644 --- a/backpack/core/derivatives/mseloss.py +++ b/backpack/core/derivatives/mseloss.py @@ -1,63 +1,79 @@ -from warnings import warn from math import sqrt -from torch import diag_embed, ones_like, randn, diag, ones +from warnings import warn + +from torch import diag, diag_embed, ones, ones_like from torch.nn import MSELoss -from .basederivatives import BaseDerivatives -from .utils import hmp_unsqueeze_if_missing_dim +from backpack.core.derivatives.basederivatives import BaseLossDerivatives -class MSELossDerivatives(BaseDerivatives): +class MSELossDerivatives(BaseLossDerivatives): def get_module(self): return MSELoss - def sqrt_hessian(self, module, g_inp, g_out): + def _sqrt_hessian(self, module, g_inp, g_out): self.check_input_dims(module) - sqrt_H = diag_embed(sqrt(2) * ones_like(module.input0)) + V_dim, C_dim = 0, 2 + diag = sqrt(2) * ones_like(module.input0) + sqrt_H = diag_embed(diag, dim1=V_dim, dim2=C_dim) - if module.reduction is "mean": - sqrt_H /= sqrt(module.input0.shape[0]) + if module.reduction == "mean": + N = module.input0.shape[0] + sqrt_H /= sqrt(N) return sqrt_H - def sqrt_hessian_sampled(self, module, g_inp, g_out): - warn("[MC Sampling Hessian of CrossEntropy] " + - "Returning the symmetric factorization of the full Hessian " + - "(same computation cost)") + def _sqrt_hessian_sampled(self, module, g_inp, g_out, mc_samples=None): + """ + Note: + ----- + The parameter `mc_samples` is ignored. + The method always returns the full square root. + + The computational cost of the sampled and the full version is the same, + so the method always returns the more accurate version. + + The cost is the same because the Hessian of the loss w.r.t. its inputs + for a single sample is one-dimensional.
+ """ + warn( + "[MC Sampling Hessian of MSE loss] " + + "Returning the symmetric factorization of the full Hessian " + + "(same computation cost)", + UserWarning, + ) return self.sqrt_hessian(module, g_inp, g_out) - def sum_hessian(self, module, g_inp, g_out): + def _sum_hessian(self, module, g_inp, g_out): self.check_input_dims(module) - batch = module.input0_shape[0] - num_features = module.input0.numel() // batch - sum_H = 2 * batch * diag( - ones(num_features, device=module.input0.device)) + N = module.input0_shape[0] + num_features = module.input0.numel() // N + sum_H = 2 * N * diag(ones(num_features, device=module.input0.device)) + + if module.reduction == "mean": + sum_H /= N - if module.reduction is "mean": - sum_H /= module.input0.shape[0] - print("sum H ", sum_H.shape) return sum_H - def hessian_matrix_product(self, module, g_inp, g_out): + def _make_hessian_mat_prod(self, module, g_inp, g_out): """Multiplication of the input Hessian with a matrix.""" - @hmp_unsqueeze_if_missing_dim(mat_dim=3) - def hmp(mat): + def hessian_mat_prod(mat): Hmat = 2 * mat - if module.reduction is "mean": - Hmat /= module.input0.shape[0] + if module.reduction == "mean": + N = module.input0.shape[0] + Hmat /= N return Hmat - return hmp + return hessian_mat_prod def check_input_dims(self, module): if not len(module.input0.shape) == 2: - raise ValueError( - "Only 2D inputs are currently supported for MSELoss.") + raise ValueError("Only 2D inputs are currently supported for MSELoss.") def hessian_is_psd(self): return True diff --git a/backpack/core/derivatives/relu.py b/backpack/core/derivatives/relu.py index 379cb524..cb646baa 100644 --- a/backpack/core/derivatives/relu.py +++ b/backpack/core/derivatives/relu.py @@ -1,6 +1,7 @@ from torch import gt from torch.nn import ReLU -from .elementwise import ElementwiseDerivatives + +from backpack.core.derivatives.elementwise import ElementwiseDerivatives class ReLUDerivatives(ElementwiseDerivatives): diff --git a/backpack/core/derivatives/shape_check.py b/backpack/core/derivatives/shape_check.py new file mode 100644 index 00000000..d92fbf58 --- /dev/null +++ b/backpack/core/derivatives/shape_check.py @@ -0,0 +1,251 @@ +""" +Helpers to support application of Jacobians to vectors +Helpers to check input and output sizes of Jacobian-matrix products. +""" +import functools + + +############################################################################### +# Utility functions # +############################################################################### +def add_V_dim(mat): + return mat.unsqueeze(0) + + +def remove_V_dim(mat): + if mat.shape[0] != 1: + raise RuntimeError( + "Cannot unsqueeze dimension 0. ", "Got tensor of shape {}".format(mat.shape) + ) + return mat.squeeze(0) + + +def check_shape(mat, like, diff=1): + """Compare dimension diff,diff+1, ... with dimension 0,1,...""" + mat_shape = [int(dim) for dim in mat.shape] + like_shape = [int(dim) for dim in like.shape] + + if len(mat_shape) - len(like_shape) != diff: + raise RuntimeError( + "Difference in dimension must be {}.".format(diff), + " Got {} and {}".format(mat_shape, like_shape), + ) + if mat_shape[diff:] != like_shape: + raise RuntimeError( + "Compared shapes {} and {} do not match. ".format( + mat_shape[diff:], like_shape + ), + "Got {} and {}".format(mat_shape, like_shape), + ) + + +def check_same_V_dim(mat1, mat2): + V1, V2 = mat1.shape[0], mat2.shape[0] + if V1 != V2: + raise RuntimeError("Number of vectors changed. 
Got {} and {}".format(V1, V2)) + + +def check_like(mat, module, name, diff=1, *args, **kwargs): + return check_shape(mat, getattr(module, name), diff=diff) + + +def check_like_with_sum_batch(mat, module, name, sum_batch=True, *args, **kwargs): + diff = 1 if sum_batch else 2 + return check_shape(mat, getattr(module, name), diff=diff) + + +def same_dim_as(mat, module, name, *args, **kwargs): + return len(mat.shape) == len(getattr(module, name).shape) + + +############################################################################### +# Decorators for handling vectors as matrix special case # +############################################################################### +def mat_prod_accept_vectors(mat_prod, vec_criterion): + """Add support for vectors to matrix products. + + vec_criterion(mat, module) returns if mat is a vector. + """ + + @functools.wraps(mat_prod) + def wrapped_mat_prod_accept_vectors( + self, module, g_inp, g_out, mat, *args, **kwargs + ): + is_vec = vec_criterion(mat, module, *args, **kwargs) + mat_in = mat if not is_vec else add_V_dim(mat) + mat_out = mat_prod(self, module, g_inp, g_out, mat_in, *args, **kwargs) + mat_out = mat_out if not is_vec else remove_V_dim(mat_out) + + return mat_out + + return wrapped_mat_prod_accept_vectors + + +# vec criteria +same_dim_as_output = functools.partial(same_dim_as, name="output") +same_dim_as_input = functools.partial(same_dim_as, name="input0") +same_dim_as_weight = functools.partial(same_dim_as, name="weight") +same_dim_as_bias = functools.partial(same_dim_as, name="bias") + +# decorators for handling vectors +jac_t_mat_prod_accept_vectors = functools.partial( + mat_prod_accept_vectors, vec_criterion=same_dim_as_output, +) + +weight_jac_t_mat_prod_accept_vectors = functools.partial( + mat_prod_accept_vectors, vec_criterion=same_dim_as_output, +) +bias_jac_t_mat_prod_accept_vectors = functools.partial( + mat_prod_accept_vectors, vec_criterion=same_dim_as_output, +) + +jac_mat_prod_accept_vectors = functools.partial( + mat_prod_accept_vectors, vec_criterion=same_dim_as_input, +) + +weight_jac_mat_prod_accept_vectors = functools.partial( + mat_prod_accept_vectors, vec_criterion=same_dim_as_weight, +) + +bias_jac_mat_prod_accept_vectors = functools.partial( + mat_prod_accept_vectors, vec_criterion=same_dim_as_bias, +) + + +############################################################################### +# Decorators for checking inputs and outputs of mat_prod routines # +############################################################################### +def mat_prod_check_shapes(mat_prod, in_check, out_check): + """Check that input and output have correct shapes.""" + + @functools.wraps(mat_prod) + def wrapped_mat_prod_check_shapes(self, module, g_inp, g_out, mat, *args, **kwargs): + in_check(mat, module, *args, **kwargs) + mat_out = mat_prod(self, module, g_inp, g_out, mat, *args, **kwargs) + out_check(mat_out, module, *args, **kwargs) + check_same_V_dim(mat_out, mat) + + return mat_out + + return wrapped_mat_prod_check_shapes + + +# input/output checker +shape_like_output = functools.partial(check_like, name="output") +shape_like_input = functools.partial(check_like, name="input0") +shape_like_weight = functools.partial(check_like, name="weight") +shape_like_bias = functools.partial(check_like, name="bias") +shape_like_weight_with_sum_batch = functools.partial( + check_like_with_sum_batch, name="weight" +) +shape_like_bias_with_sum_batch = functools.partial( + check_like_with_sum_batch, name="bias" +) + +# decorators for shape 
checking +jac_mat_prod_check_shapes = functools.partial( + mat_prod_check_shapes, in_check=shape_like_input, out_check=shape_like_output +) + +weight_jac_mat_prod_check_shapes = functools.partial( + mat_prod_check_shapes, in_check=shape_like_weight, out_check=shape_like_output +) + +bias_jac_mat_prod_check_shapes = functools.partial( + mat_prod_check_shapes, in_check=shape_like_bias, out_check=shape_like_output +) + +jac_t_mat_prod_check_shapes = functools.partial( + mat_prod_check_shapes, in_check=shape_like_output, out_check=shape_like_input +) + + +weight_jac_t_mat_prod_check_shapes = functools.partial( + mat_prod_check_shapes, + in_check=shape_like_output, + out_check=shape_like_weight_with_sum_batch, +) +bias_jac_t_mat_prod_check_shapes = functools.partial( + mat_prod_check_shapes, + in_check=shape_like_output, + out_check=shape_like_bias_with_sum_batch, +) + +############################################################################### +# Wrapper for second-order extensions # +############################################################################### + +# TODO Refactor using partials + + +def R_mat_prod_check_shapes(make_R_mat_prod): + """Check that input and output have correct shapes.""" + + @functools.wraps(make_R_mat_prod) + def wrapped_make_R_mat_prod(self, module, g_inp, g_out): + def checked_R_mat_prod(mat): + check_like(mat, module, "input0") + mat_out = make_R_mat_prod(self, module, g_inp, g_out)(mat) + check_like(mat_out, module, "input0") + check_same_V_dim(mat, mat_out) + + return mat_out + + return checked_R_mat_prod + + return wrapped_make_R_mat_prod + + +def R_mat_prod_accept_vectors(make_R_mat_prod): + """Add support for vectors to Residual-matrix products.""" + + @functools.wraps(make_R_mat_prod) + def wrapped_make_R_mat_prod(self, module, g_inp, g_out): + def new_R_mat_prod(mat): + is_vec = same_dim_as(mat, module, "input0") + mat_in = mat if not is_vec else add_V_dim(mat) + mat_out = make_R_mat_prod(self, module, g_inp, g_out)(mat_in) + mat_out = mat_out if not is_vec else remove_V_dim(mat_out) + + return mat_out + + return new_R_mat_prod + + return wrapped_make_R_mat_prod + + +def make_hessian_mat_prod_accept_vectors(make_hessian_mat_prod): + @functools.wraps(make_hessian_mat_prod) + def wrapped_make_hessian_mat_prod(self, module, g_inp, g_out): + + hessian_mat_prod = make_hessian_mat_prod(self, module, g_inp, g_out) + + def new_hessian_mat_prod(mat): + is_vec = same_dim_as(mat, module, "input0") + mat_in = mat if not is_vec else add_V_dim(mat) + mat_out = hessian_mat_prod(mat_in) + mat_out = mat_out if not is_vec else remove_V_dim(mat_out) + + return mat_out + + return new_hessian_mat_prod + + return wrapped_make_hessian_mat_prod + + +def make_hessian_mat_prod_check_shapes(make_hessian_mat_prod): + @functools.wraps(make_hessian_mat_prod) + def wrapped_make_hessian_mat_prod(self, module, g_inp, g_out): + + hessian_mat_prod = make_hessian_mat_prod(self, module, g_inp, g_out) + + def new_hessian_mat_prod(mat): + check_like(mat, module, "input0") + result = hessian_mat_prod(mat) + check_like(result, module, "input0") + + return result + + return new_hessian_mat_prod + + return wrapped_make_hessian_mat_prod diff --git a/backpack/core/derivatives/sigmoid.py b/backpack/core/derivatives/sigmoid.py index f3214f53..796f1b46 100644 --- a/backpack/core/derivatives/sigmoid.py +++ b/backpack/core/derivatives/sigmoid.py @@ -1,5 +1,6 @@ from torch.nn import Sigmoid -from .elementwise import ElementwiseDerivatives + +from backpack.core.derivatives.elementwise import 
ElementwiseDerivatives class SigmoidDerivatives(ElementwiseDerivatives): @@ -13,7 +14,7 @@ def hessian_is_diagonal(self): return True def df(self, module, g_inp, g_out): - return module.output * (1. - module.output) + return module.output * (1.0 - module.output) def d2f(self, module, g_inp, g_out): return module.output * (1 - module.output) * (1 - 2 * module.output) diff --git a/backpack/core/derivatives/tanh.py b/backpack/core/derivatives/tanh.py index bd45f9fe..fc728155 100644 --- a/backpack/core/derivatives/tanh.py +++ b/backpack/core/derivatives/tanh.py @@ -1,5 +1,6 @@ from torch.nn import Tanh -from .elementwise import ElementwiseDerivatives + +from backpack.core.derivatives.elementwise import ElementwiseDerivatives class TanhDerivatives(ElementwiseDerivatives): @@ -13,7 +14,7 @@ def hessian_is_diagonal(self): return True def df(self, module, g_inp, g_out): - return 1. - module.output**2 + return 1.0 - module.output ** 2 def d2f(self, module, g_inp, g_out): - return (-2. * module.output * (1. - module.output**2)) + return -2.0 * module.output * (1.0 - module.output ** 2) diff --git a/backpack/core/derivatives/zeropad2d.py b/backpack/core/derivatives/zeropad2d.py index 3fcb9ffe..1c5dad4e 100644 --- a/backpack/core/derivatives/zeropad2d.py +++ b/backpack/core/derivatives/zeropad2d.py @@ -1,11 +1,7 @@ -import warnings +from torch.nn import ZeroPad2d, functional -from torch.nn import ZeroPad2d -from torch.nn.functional import pad - -from ...utils.utils import einsum, random_psd_matrix -from .basederivatives import BaseDerivatives -from .utils import jmp_unsqueeze_if_missing_dim +from backpack.core.derivatives.basederivatives import BaseDerivatives +from backpack.utils.ein import eingroup class ZeroPad2dDerivatives(BaseDerivatives): @@ -17,62 +13,33 @@ def hessian_is_zero(self): # TODO: Require tests def ea_jac_t_mat_jac_prod(self, module, g_inp, g_out, mat): - _, out_c, out_x, out_y = module.output_shape + _, C_out, H_out, W_out = module.output_shape _, in_c, in_x, in_y = module.input0_shape in_features = in_c * in_x * in_y - # slicing indices - pad_left, pad_right, pad_top, pad_bottom = module.padding - idx_left, idx_right = pad_left, out_y - pad_right - idx_top, idx_bottom = pad_top, out_x - pad_bottom - - result = mat.view(out_c, out_x, out_y, out_c, out_x, out_y) + result = mat.view(C_out, H_out, W_out, C_out, H_out, W_out) - result = result[:, idx_top:idx_bottom, idx_left:idx_right, :, idx_top: - idx_bottom, idx_left:idx_right].contiguous() + (W_top, W_bottom), (H_bottom, H_top) = self.__unpad_indices(module) + result = result[ + :, W_top:W_bottom, H_bottom:H_top, :, W_top:W_bottom, H_bottom:H_top, + ].contiguous() return result.view(in_features, in_features) - @jmp_unsqueeze_if_missing_dim(mat_dim=3) - def jac_t_mat_prod(self, module, g_inp, g_out, mat): - # reshape feature dimension as output image - batch, out_features, num_cols = mat.size() - _, out_channels, out_x, out_y = module.output_shape - assert out_features == out_channels * out_x * out_y - mat = mat.view(batch, out_channels, out_x, out_y, num_cols) + def _jac_t_mat_prod(self, module, g_inp, g_out, mat): + (W_top, W_bottom), (H_bottom, H_top) = self.__unpad_indices(module) + return mat[:, :, :, W_top:W_bottom, H_bottom:H_top] - # remove padding by slicing + def __unpad_indices(self, module): + _, _, H_out, W_out = module.output_shape pad_left, pad_right, pad_top, pad_bottom = module.padding - idx_left, idx_right = pad_left, out_y - pad_right - idx_top, idx_bottom = pad_top, out_x - pad_bottom - mat_unpad = mat[:, :, 
idx_top:idx_bottom, idx_left: - idx_right, :].contiguous() - - # group in features - _, in_channels, in_x, in_y = module.input0_shape - in_features = in_channels * in_x * in_y - return mat_unpad.view(batch, in_features, num_cols) - - @jmp_unsqueeze_if_missing_dim(mat_dim=3) - def jac_mat_prod(self, module, g_inp, g_out, mat): - # group batch and column dimension of the matrix - batch, in_features, num_cols = mat.size() - mat = einsum('bic->bci', (mat)).contiguous() - - # reshape feature dimension as input image - _, in_channels, in_x, in_y = module.input0_shape - mat = mat.view(batch * num_cols, in_channels, in_x, in_y) - - # apply padding - pad_mat = self.apply_padding(module, mat) - # ungroup batch and column dimension - _, out_channels, out_x, out_y = module.output_shape - out_features = out_channels * out_x * out_y + H_bottom, H_top = pad_left, W_out - pad_right + W_top, W_bottom = pad_top, H_out - pad_bottom - pad_mat = pad_mat.view(batch, num_cols, out_features) - return einsum('bci->bic', (pad_mat)).contiguous() + return (W_top, W_bottom), (H_bottom, H_top) - @staticmethod - def apply_padding(module, input): - return pad(input, module.padding, 'constant', module.value) + def _jac_mat_prod(self, module, g_inp, g_out, mat): + mat = eingroup("v,n,c,h,w->vn,c,h,w", mat) + pad_mat = functional.pad(mat, module.padding, "constant", module.value) + return self.view_like_output(pad_mat, module) diff --git a/backpack/core/layers/__init__.py b/backpack/core/layers/__init__.py deleted file mode 100644 index c050a574..00000000 --- a/backpack/core/layers/__init__.py +++ /dev/null @@ -1,166 +0,0 @@ -import torch -import torch.nn.functional as F -from torch.nn import Module, Linear, Parameter, Conv2d -from torch import flatten, cat, Tensor, empty -from ...utils.conv import unfold_func - - -class Flatten(Module): - """ - NN module version of torch.nn.functional.flatten - """ - def __init__(self): - super().__init__() - - def forward(self, input): - return flatten(input, start_dim=1, end_dim=-1) - - -class SkipConnection(Module): - pass - - -class LinearConcat(Module): - """ - Drop-in replacement for torch.nn.Linear with only one parameter. - """ - def __init__(self, in_features, out_features, bias=True): - super().__init__() - - lin = Linear(in_features, out_features, bias=bias) - - if bias: - self.weight = Parameter( - empty(size=(out_features, in_features + 1))) - self.weight.data = cat( - [lin.weight.data, lin.bias.data.unsqueeze(-1)], dim=1) - else: - self.weight = Parameter(empty(size=(out_features, in_features))) - self.weight.data = lin.weight.data - - self.input_features = in_features - self.output_features = out_features - self.__bias = bias - - def forward(self, input): - return F.linear(input, self._slice_weight(), self._slice_bias()) - - def has_bias(self): - return self.__bias is True - - def homogeneous_input(self): - input = self.input0 - if self.has_bias(): - input = self.append_ones(input) - return input - - @staticmethod - def append_ones(input): - batch = input.shape[0] - ones = torch.ones(batch, 1, device=input.device) - return torch.cat([input, ones], dim=1) - - def _slice_weight(self): - return self.weight.narrow(1, 0, self.input_features) - - def _slice_bias(self): - if not self.has_bias(): - return None - else: - return self.weight.narrow(1, self.input_features, 1).squeeze(-1) - - -class Conv2dConcat(Module): - """ - Drop-in replacement for torch.nn.Conv2d with only one parameter. 
- """ - def __init__(self, - in_channels, - out_channels, - kernel_size, - stride=1, - padding=0, - dilation=1, - groups=1, - bias=True, - padding_mode="zeros"): - assert padding_mode is "zeros" - assert groups == 1 - - super().__init__() - - conv = Conv2d( - in_channels, - out_channels, - kernel_size, - stride=stride, - padding=padding, - dilation=dilation, - groups=groups, - bias=bias, - padding_mode=padding_mode) - - self._KERNEL_SHAPE = conv.weight.shape - - kernel_mat_shape = [out_channels, conv.weight.numel() // out_channels] - kernel_mat = conv.weight.data.view(kernel_mat_shape) - - if bias: - kernel_mat_shape[1] += 1 - kernel_mat = cat([kernel_mat, conv.bias.data.unsqueeze(-1)], dim=1) - - self.weight = Parameter(empty(size=kernel_mat_shape)) - self.weight.data = kernel_mat - - else: - self.weight = Parameter(empty(size=kernel_mat_shape)) - self.weight.data = kernel_mat - - self.in_channels = conv.in_channels - self.out_channels = conv.out_channels - self.kernel_size = conv.kernel_size - self.stride = conv.stride - self.padding = conv.padding - self.dilation = conv.dilation - self.transposed = conv.transposed - self.output_padding = conv.output_padding - self.groups = conv.groups - self.padding_mode = padding_mode - self.__bias = bias - - def forward(self, input): - return F.conv2d( - input, - self._slice_weight(), - bias=self._slice_bias(), - stride=self.stride, - padding=self.padding, - dilation=self.dilation, - groups=self.groups) - - def has_bias(self): - return self.__bias is True - - def homogeneous_unfolded_input(self): - unfolded_input = unfold_func(self)(self.input0) - if self.has_bias(): - unfolded_input = self.append_ones(unfolded_input) - return unfolded_input - - @staticmethod - def append_ones(input): - batch, _, cols = input.shape - ones = torch.ones(batch, 1, cols, device=input.device) - return torch.cat([input, ones], dim=1) - - def _slice_weight(self): - return self.weight.narrow(1, 0, - self.weight.size(1) - 1).view( - self._KERNEL_SHAPE) - - def _slice_bias(self): - if not self.has_bias(): - return None - else: - return self.weight.narrow(1, - self.weight.size(1) - 1, 1).squeeze(-1) diff --git a/backpack/extensions/__init__.py b/backpack/extensions/__init__.py index 0fba2300..e86ec8e4 100644 --- a/backpack/extensions/__init__.py +++ b/backpack/extensions/__init__.py @@ -4,5 +4,29 @@ from .curvmatprod import CMP from .firstorder import BatchGrad, BatchL2Grad, SumGradSquared, Variance -from .secondorder import (HBP, KFAC, KFLR, KFRA, DiagGGN, DiagGGNExact, - DiagGGNMC, DiagHessian) +from .secondorder import ( + HBP, + KFAC, + KFLR, + KFRA, + DiagGGN, + DiagGGNExact, + DiagGGNMC, + DiagHessian, +) + +__all__ = [ + "CMP", + "BatchL2Grad", + "BatchGrad", + "SumGradSquared", + "Variance", + "HBP", + "KFAC", + "KFLR", + "KFRA", + "DiagGGN", + "DiagGGNExact", + "DiagGGNMC", + "DiagHessian", +] diff --git a/backpack/extensions/backprop_extension.py b/backpack/extensions/backprop_extension.py index c5a6d365..71f45adb 100644 --- a/backpack/extensions/backprop_extension.py +++ b/backpack/extensions/backprop_extension.py @@ -24,6 +24,7 @@ class BackpropExtension: print(p.myfield) ``` """ + __external_module_extensions = {} def __init__(self, savefield, module_exts, fail_mode=FAIL_ERROR): @@ -55,7 +56,9 @@ def add_module_extension(cls, module, extension): def __get_module_extension(self, module): module_extension = self.__module_extensions.get(module.__class__) - no_op = lambda *args: None + + def no_op(*args): + return None if module_extension is None: @@ -64,14 +67,16 @@ 
def __get_module_extension(self, module): if self.__fail_mode is FAIL_ERROR: raise NotImplementedError( - "Extension saving to {} ".format(self.savefield) + - "does not have an extension for " + - "Module {}".format(module.__class__)) + "Extension saving to {} ".format(self.savefield) + + "does not have an extension for " + + "Module {}".format(module.__class__) + ) elif self.__fail_mode == FAIL_WARN: warnings.warn( - "Extension saving to {} ".format(self.savefield) + - "does not have an extension for " + - "Module {}".format(module.__class__)) + "Extension saving to {} ".format(self.savefield) + + "does not have an extension for " + + "Module {}".format(module.__class__) + ) return no_op diff --git a/backpack/extensions/curvature.py b/backpack/extensions/curvature.py index 51481534..78b474ac 100644 --- a/backpack/extensions/curvature.py +++ b/backpack/extensions/curvature.py @@ -1,7 +1,14 @@ -import torch +"""Modification of second-order module effects during Hessian backpropagation. +The residual term is tweaked to give rise to the following curvatures: +- No modification: Exact Hessian +- Neglect module second-order information: Generalized Gauss-Newton matrix +- Cast negative residual eigenvalues to their absolute value: PCH-abs +- Set negative residual eigenvalues to zero: PCH-clip +""" -class ResidualModifications(): + +class ResidualModifications: @staticmethod def nothing(res): return res @@ -18,28 +25,18 @@ def remove_negative_values(res): def to_abs(res): return res.abs() - @staticmethod - def to_med(res, if_negative_return=None): - median = res.median() - if median < 0: - return None - else: - return median * torch.ones_like(res) - -class Curvature(): - HESSIAN = 'hessian' - GGN = 'ggn' - PCH_ABS = 'pch-abs' - PCH_CLIP = 'pch-clip' - PCH_MED = 'pch-med' +class Curvature: + HESSIAN = "hessian" + GGN = "ggn" + PCH_ABS = "pch-abs" + PCH_CLIP = "pch-clip" CHOICES = [ HESSIAN, GGN, PCH_CLIP, PCH_ABS, - PCH_MED, ] REQUIRE_PSD_LOSS_HESSIAN = { @@ -47,7 +44,6 @@ class Curvature(): GGN: True, PCH_ABS: True, PCH_CLIP: True, - PCH_MED: True, } REQUIRE_RESIDUAL = { @@ -55,7 +51,6 @@ class Curvature(): GGN: False, PCH_ABS: True, PCH_CLIP: True, - PCH_MED: True, } RESIDUAL_MODS = { @@ -63,21 +58,27 @@ class Curvature(): GGN: ResidualModifications.to_zero, PCH_ABS: ResidualModifications.to_abs, PCH_CLIP: ResidualModifications.remove_negative_values, - PCH_MED: ResidualModifications.to_med, } @classmethod def __check_exists(cls, which): - if not which in cls.CHOICES: + if which not in cls.CHOICES: raise AttributeError( - "Unknown curvature: {}. Expecting one of {}".format( - which, cls.CHOICES)) + "Unknown curvature: {}.
Expecting one of {}".format(which, cls.CHOICES) + ) @classmethod def require_residual(cls, curv_type): cls.__check_exists(curv_type) return cls.REQUIRE_RESIDUAL[curv_type] + @classmethod + def is_pch(cls, curv_type): + """Is `curv_type` one of the PCHs proposed by Chen et al.""" + cls.__check_exists(curv_type) + PCH = [cls.PCH_ABS, cls.PCH_CLIP] + return curv_type in PCH + @classmethod def modify_residual(cls, residual, curv_type): # None if zero or curvature neglects 2nd-order module effects @@ -95,5 +96,7 @@ def check_loss_hessian(cls, loss_hessian_is_psd, curv_type): if require_psd and not loss_hessian_is_psd: raise ValueError( - 'Loss Hessian PSD = {}, but {} requires PSD = {}'.format( - loss_hessian_is_psd, curv_type, require_psd)) + "Loss Hessian PSD = {}, but {} requires PSD = {}".format( + loss_hessian_is_psd, curv_type, require_psd + ) + ) diff --git a/backpack/extensions/curvmatprod/__init__.py b/backpack/extensions/curvmatprod/__init__.py index 9507608f..892fe28d 100644 --- a/backpack/extensions/curvmatprod/__init__.py +++ b/backpack/extensions/curvmatprod/__init__.py @@ -9,39 +9,60 @@ of the curvature, such as the Block-diagonal Generalized Gauss-Newton """ -from backpack.core.layers import Conv2dConcat, Flatten, LinearConcat +from torch.nn import ( + AvgPool2d, + BatchNorm1d, + Conv2d, + CrossEntropyLoss, + Dropout, + Flatten, + Linear, + MaxPool2d, + MSELoss, + ReLU, + Sigmoid, + Tanh, + ZeroPad2d, +) + from backpack.extensions.backprop_extension import BackpropExtension -from torch.nn import (AvgPool2d, BatchNorm1d, Conv2d, CrossEntropyLoss, - Dropout, Linear, MaxPool2d, MSELoss, ReLU, Sigmoid, Tanh, - ZeroPad2d) -from . import (activations, batchnorm1d, conv2d, dropout, flatten, linear, - losses, padding, pooling) +from . import ( + activations, + batchnorm1d, + conv2d, + dropout, + flatten, + linear, + losses, + padding, + pooling, +) class CMP(BackpropExtension): def __init__(self, curv_type, savefield="cmp"): self.curv_type = curv_type - super().__init__(savefield=savefield, - fail_mode="ERROR", - module_exts={ - MSELoss: losses.CMPMSELoss(), - CrossEntropyLoss: losses.CMPCrossEntropyLoss(), - Linear: linear.CMPLinear(), - LinearConcat: linear.CMPLinearConcat(), - MaxPool2d: pooling.CMPMaxpool2d(), - AvgPool2d: pooling.CMPAvgPool2d(), - ZeroPad2d: padding.CMPZeroPad2d(), - Conv2d: conv2d.CMPConv2d(), - Conv2dConcat: conv2d.CMPConv2dConcat(), - Dropout: dropout.CMPDropout(), - Flatten: flatten.CMPFlatten(), - ReLU: activations.CMPReLU(), - Sigmoid: activations.CMPSigmoid(), - Tanh: activations.CMPTanh(), - BatchNorm1d: batchnorm1d.CMPBatchNorm1d(), - }) + super().__init__( + savefield=savefield, + fail_mode="ERROR", + module_exts={ + MSELoss: losses.CMPMSELoss(), + CrossEntropyLoss: losses.CMPCrossEntropyLoss(), + Linear: linear.CMPLinear(), + MaxPool2d: pooling.CMPMaxpool2d(), + AvgPool2d: pooling.CMPAvgPool2d(), + ZeroPad2d: padding.CMPZeroPad2d(), + Conv2d: conv2d.CMPConv2d(), + Dropout: dropout.CMPDropout(), + Flatten: flatten.CMPFlatten(), + ReLU: activations.CMPReLU(), + Sigmoid: activations.CMPSigmoid(), + Tanh: activations.CMPTanh(), + BatchNorm1d: batchnorm1d.CMPBatchNorm1d(), + }, + ) def get_curv_type(self): return self.curv_type diff --git a/backpack/extensions/curvmatprod/activations.py b/backpack/extensions/curvmatprod/activations.py index 3b4be167..43e04b6b 100644 --- a/backpack/extensions/curvmatprod/activations.py +++ b/backpack/extensions/curvmatprod/activations.py @@ -1,6 +1,7 @@ from backpack.core.derivatives.relu import ReLUDerivatives -from 
backpack.core.derivatives.tanh import TanhDerivatives from backpack.core.derivatives.sigmoid import SigmoidDerivatives +from backpack.core.derivatives.tanh import TanhDerivatives + from .cmpbase import CMPBase diff --git a/backpack/extensions/curvmatprod/batchnorm1d.py b/backpack/extensions/curvmatprod/batchnorm1d.py index 71e1c4e9..436441c2 100644 --- a/backpack/extensions/curvmatprod/batchnorm1d.py +++ b/backpack/extensions/curvmatprod/batchnorm1d.py @@ -1,22 +1,22 @@ from backpack.core.derivatives.batchnorm1d import BatchNorm1dDerivatives - -from .cmpbase import CMPBase +from backpack.extensions.curvmatprod.cmpbase import CMPBase class CMPBatchNorm1d(CMPBase): def __init__(self): - super().__init__(derivatives=BatchNorm1dDerivatives(), - params=["weight", "bias"]) + super().__init__( + derivatives=BatchNorm1dDerivatives(), params=["weight", "bias"] + ) def weight(self, ext, module, g_inp, g_out, backproped): CMP_out = backproped def weight_cmp(mat): - Jmat = self.derivatives.weight_jac_mat_prod( - module, g_inp, g_out, mat) + Jmat = self.derivatives.weight_jac_mat_prod(module, g_inp, g_out, mat) CJmat = CMP_out(Jmat) JTCJmat = self.derivatives.weight_jac_t_mat_prod( - module, g_inp, g_out, CJmat) + module, g_inp, g_out, CJmat + ) return JTCJmat return weight_cmp @@ -25,11 +25,9 @@ def bias(self, ext, module, g_inp, g_out, backproped): CMP_out = backproped def bias_cmp(mat): - Jmat = self.derivatives.bias_jac_mat_prod(module, g_inp, g_out, - mat) + Jmat = self.derivatives.bias_jac_mat_prod(module, g_inp, g_out, mat) CJmat = CMP_out(Jmat) - JTCJmat = self.derivatives.bias_jac_t_mat_prod( - module, g_inp, g_out, CJmat) + JTCJmat = self.derivatives.bias_jac_t_mat_prod(module, g_inp, g_out, CJmat) return JTCJmat return bias_cmp diff --git a/backpack/extensions/curvmatprod/cmpbase.py b/backpack/extensions/curvmatprod/cmpbase.py index 51b5d8c6..2ec37817 100644 --- a/backpack/extensions/curvmatprod/cmpbase.py +++ b/backpack/extensions/curvmatprod/cmpbase.py @@ -1,55 +1,114 @@ -from backpack.core.derivatives.utils import hmp_unsqueeze_if_missing_dim +from backpack.core.derivatives.shape_check import ( + R_mat_prod_accept_vectors, + R_mat_prod_check_shapes, +) from backpack.extensions.curvature import Curvature from backpack.extensions.module_extension import ModuleExtension -from backpack.utils.utils import einsum +from backpack.utils.ein import einsum class CMPBase(ModuleExtension): """ Given matrix-vector product routine `MVP(A)` backpropagate to `MVP(J^T A J)`. """ + def __init__(self, derivatives, params=None): super().__init__(params=params) self.derivatives = derivatives def backpropagate(self, ext, module, g_inp, g_out, backproped): - CMP_out = backproped + """Backpropagate Hessian multiplication routines. + + Given mat → ℋz(x) mat, backpropagate mat → ℋx mat. + """ + GGN_mat_prod = self._make_GGN_mat_prod(ext, module, g_inp, g_out, backproped) - residual = self._second_order_module_effects(module, ext, g_inp, g_out) - residual_mod = self._modify_residual(ext, residual) + R_required = self._require_residual(ext, module, g_inp, g_out, backproped) + if R_required: + R_mat_prod = self._make_R_mat_prod(ext, module, g_inp, g_out, backproped) - @hmp_unsqueeze_if_missing_dim(mat_dim=3) def CMP_in(mat): - """Multiplication of curvature matrix with matrix `mat`. + """Multiplication with curvature matrix w.r.t. the module input. Parameters: ----------- mat : torch.Tensor Matrix that will be multiplied. 
""" + out = GGN_mat_prod(mat) + + if R_required: + out.add_(R_mat_prod(mat)) + + return out + + return CMP_in + + def _make_GGN_mat_prod(self, ext, module, g_inp, g_out, backproped): + """Return multiplication routine with the first HBP term.""" + CMP_out = backproped + + def GGN_mat_prod(mat): + """Multiply with the GGN term: mat → [𝒟z(x)ᵀ ℋz 𝒟z(x)] mat. + + First term of the module input Hessian backpropagation equation. + """ Jmat = self.derivatives.jac_mat_prod(module, g_inp, g_out, mat) CJmat = CMP_out(Jmat) - JTCJmat = self.derivatives.jac_t_mat_prod(module, g_inp, g_out, - CJmat) - - if residual_mod is not None: - JTCJmat.add_(einsum('bi,bic->bic', (residual_mod, mat))) + JTCJmat = self.derivatives.jac_t_mat_prod(module, g_inp, g_out, CJmat) return JTCJmat - return CMP_in + return GGN_mat_prod + + def _require_residual(self, ext, module, g_inp, g_out, backproped): + """Is the residual term required for multiply with the curvature?""" + vanishes = self.derivatives.hessian_is_zero() + neglected = not Curvature.require_residual(ext.get_curv_type()) + + return not (vanishes or neglected) + + def _make_R_mat_prod(self, ext, module, g_inp, g_out, backproped): + """Return multiplication routine with the second HBP term.""" + if self.derivatives.hessian_is_diagonal(): + R_mat_prod = self.__make_diagonal_R_mat_prod( + ext, module, g_inp, g_out, backproped + ) + else: + R_mat_prod = self.__make_nondiagonal_R_mat_prod( + ext, module, g_inp, g_out, backproped + ) + + return R_mat_prod + + def __make_diagonal_R_mat_prod(self, ext, module, g_inp, g_out, backproped): + # TODO Refactor core: hessian_diagonal -> residual_diagonal + R = self.derivatives.hessian_diagonal(module, g_inp, g_out) + R_mod = Curvature.modify_residual(R, ext.get_curv_type()) + + @R_mat_prod_accept_vectors + @R_mat_prod_check_shapes + def make_residual_mat_prod(self, module, g_inp, g_out): + def R_mat_prod(mat): + """Multiply with the residual: mat → [∑_{k} Hz_k(x) 𝛿z_k] mat. + + Second term of the module input Hessian backpropagation equation. + """ + return einsum("n...,vn...->vn...", (R_mod, mat)) - def _second_order_module_effects(self, module, ext, g_inp, g_out): - if self.derivatives.hessian_is_zero(): - return None - if not Curvature.require_residual(ext.get_curv_type()): - return None + return R_mat_prod - if not self.derivatives.hessian_is_diagonal(): - raise NotImplementedError( - "Residual terms are only supported for elementwise functions") + return make_residual_mat_prod(self, module, g_inp, g_out) - return self.derivatives.hessian_diagonal(module, g_inp, g_out) + def __make_nondiagonal_R_mat_prod(self, ext, module, g_inp, g_out, backproped): + curv_type = ext.get_curv_type() + if not Curvature.is_pch(curv_type): + R_mat_prod = self.derivatives.make_residual_mat_prod(module, g_inp, g_out) + else: + raise ValueError( + "{} not supported for {}. 
Residual cannot be cast PSD.".format( + curv_type, module + ) + ) - def _modify_residual(self, ext, residual): - return Curvature.modify_residual(residual, ext.get_curv_type()) + return R_mat_prod diff --git a/backpack/extensions/curvmatprod/conv2d.py b/backpack/extensions/curvmatprod/conv2d.py index 9cbbceb1..2d856bcd 100644 --- a/backpack/extensions/curvmatprod/conv2d.py +++ b/backpack/extensions/curvmatprod/conv2d.py @@ -1,21 +1,16 @@ -from backpack.core.derivatives.conv2d import Conv2DDerivatives, Conv2DConcatDerivatives -from .cmpbase import CMPBase +from backpack.core.derivatives.conv2d import Conv2DDerivatives +from backpack.extensions.curvmatprod.cmpbase import CMPBase class CMPConv2d(CMPBase): def __init__(self): - super().__init__( - derivatives=Conv2DDerivatives(), - params=["weight", "bias"] - ) + super().__init__(derivatives=Conv2DDerivatives(), params=["weight", "bias"]) def weight(self, ext, module, g_inp, g_out, backproped): CMP_out = backproped def weight_cmp(mat): - Jmat = self.derivatives.weight_jac_mat_prod( - module, g_inp, g_out, mat - ) + Jmat = self.derivatives.weight_jac_mat_prod(module, g_inp, g_out, mat) CJmat = CMP_out(Jmat) JTCJmat = self.derivatives.weight_jac_t_mat_prod( module, g_inp, g_out, CJmat @@ -28,36 +23,9 @@ def bias(self, ext, module, g_inp, g_out, backproped): CMP_out = backproped def bias_cmp(mat): - Jmat = self.derivatives.bias_jac_mat_prod( - module, g_inp, g_out, mat - ) + Jmat = self.derivatives.bias_jac_mat_prod(module, g_inp, g_out, mat) CJmat = CMP_out(Jmat) - JTCJmat = self.derivatives.bias_jac_t_mat_prod( - module, g_inp, g_out, CJmat - ) + JTCJmat = self.derivatives.bias_jac_t_mat_prod(module, g_inp, g_out, CJmat) return JTCJmat return bias_cmp - - -class CMPConv2dConcat(CMPBase): - def __init__(self): - super().__init__( - derivatives=Conv2DConcatDerivatives(), - params=["weight"] - ) - - def weight(self, ext, module, g_inp, g_out, backproped): - CMP_out = backproped - - def weight_cmp(mat): - Jmat = self.derivatives.weight_jac_mat_prod( - module, g_inp, g_out, mat - ) - CJmat = CMP_out(Jmat) - JTCJmat = self.derivatives.weight_jac_t_mat_prod( - module, g_inp, g_out, CJmat - ) - return JTCJmat - - return weight_cmp diff --git a/backpack/extensions/curvmatprod/dropout.py b/backpack/extensions/curvmatprod/dropout.py index 734285bd..07ef1718 100644 --- a/backpack/extensions/curvmatprod/dropout.py +++ b/backpack/extensions/curvmatprod/dropout.py @@ -1,4 +1,5 @@ from backpack.core.derivatives.dropout import DropoutDerivatives + from .cmpbase import CMPBase diff --git a/backpack/extensions/curvmatprod/flatten.py b/backpack/extensions/curvmatprod/flatten.py index 93f13803..84a3c6ff 100644 --- a/backpack/extensions/curvmatprod/flatten.py +++ b/backpack/extensions/curvmatprod/flatten.py @@ -1,9 +1,14 @@ +from backpack.core.derivatives.flatten import FlattenDerivatives + from .cmpbase import CMPBase class CMPFlatten(CMPBase): def __init__(self): - super().__init__(derivatives=None) + super().__init__(derivatives=FlattenDerivatives()) - def backpropagate(self, ext, module, g_inp, g_out, backproped): - return backproped + def backpropagate(self, ext, module, grad_inp, grad_out, backproped): + if self.derivatives.is_no_op(module): + return backproped + else: + return super().backpropagate(ext, module, grad_inp, grad_out, backproped) diff --git a/backpack/extensions/curvmatprod/linear.py b/backpack/extensions/curvmatprod/linear.py index fb67ac21..434c0b00 100644 --- a/backpack/extensions/curvmatprod/linear.py +++ 
b/backpack/extensions/curvmatprod/linear.py @@ -1,23 +1,20 @@ -from backpack.core.derivatives.linear import (LinearConcatDerivatives, - LinearDerivatives) - -from .cmpbase import CMPBase +from backpack.core.derivatives.linear import LinearDerivatives +from backpack.extensions.curvmatprod.cmpbase import CMPBase class CMPLinear(CMPBase): def __init__(self): - super().__init__(derivatives=LinearDerivatives(), - params=["weight", "bias"]) + super().__init__(derivatives=LinearDerivatives(), params=["weight", "bias"]) def weight(self, ext, module, g_inp, g_out, backproped): CMP_out = backproped def weight_cmp(mat): - Jmat = self.derivatives.weight_jac_mat_prod( - module, g_inp, g_out, mat) + Jmat = self.derivatives.weight_jac_mat_prod(module, g_inp, g_out, mat) CJmat = CMP_out(Jmat) JTCJmat = self.derivatives.weight_jac_t_mat_prod( - module, g_inp, g_out, CJmat) + module, g_inp, g_out, CJmat + ) return JTCJmat return weight_cmp @@ -26,30 +23,9 @@ def bias(self, ext, module, g_inp, g_out, backproped): CMP_out = backproped def bias_cmp(mat): - Jmat = self.derivatives.bias_jac_mat_prod(module, g_inp, g_out, - mat) + Jmat = self.derivatives.bias_jac_mat_prod(module, g_inp, g_out, mat) CJmat = CMP_out(Jmat) - JTCJmat = self.derivatives.bias_jac_t_mat_prod( - module, g_inp, g_out, CJmat) + JTCJmat = self.derivatives.bias_jac_t_mat_prod(module, g_inp, g_out, CJmat) return JTCJmat return bias_cmp - - -class CMPLinearConcat(CMPBase): - def __init__(self): - super().__init__(derivatives=LinearConcatDerivatives(), - params=["weight"]) - - def weight(self, ext, module, g_inp, g_out, backproped): - CMP_out = backproped - - def weight_cmp(mat): - Jmat = self.derivatives.weight_jac_mat_prod( - module, g_inp, g_out, mat) - CJmat = CMP_out(Jmat) - JTCJmat = self.derivatives.weight_jac_t_mat_prod( - module, g_inp, g_out, CJmat) - return JTCJmat - - return weight_cmp diff --git a/backpack/extensions/curvmatprod/losses.py b/backpack/extensions/curvmatprod/losses.py index 0969cacc..ba8b8bfe 100644 --- a/backpack/extensions/curvmatprod/losses.py +++ b/backpack/extensions/curvmatprod/losses.py @@ -1,17 +1,16 @@ -from backpack.core.derivatives.mseloss import MSELossDerivatives from backpack.core.derivatives.crossentropyloss import CrossEntropyLossDerivatives -from .cmpbase import CMPBase +from backpack.core.derivatives.mseloss import MSELossDerivatives from backpack.extensions.curvature import Curvature +from backpack.extensions.curvmatprod.cmpbase import CMPBase class CMPLoss(CMPBase): def backpropagate(self, ext, module, g_inp, g_out, backproped): Curvature.check_loss_hessian( - self.derivatives.hessian_is_psd(), - curv_type=ext.get_curv_type() + self.derivatives.hessian_is_psd(), curv_type=ext.get_curv_type() ) - CMP = self.derivatives.hessian_matrix_product(module, g_inp, g_out) + CMP = self.derivatives.make_hessian_mat_prod(module, g_inp, g_out) return CMP diff --git a/backpack/extensions/curvmatprod/padding.py b/backpack/extensions/curvmatprod/padding.py index 9ea47634..df772242 100644 --- a/backpack/extensions/curvmatprod/padding.py +++ b/backpack/extensions/curvmatprod/padding.py @@ -1,4 +1,5 @@ from backpack.core.derivatives.zeropad2d import ZeroPad2dDerivatives + from .cmpbase import CMPBase diff --git a/backpack/extensions/curvmatprod/pooling.py b/backpack/extensions/curvmatprod/pooling.py index 6f9de2b6..c88cb2c6 100644 --- a/backpack/extensions/curvmatprod/pooling.py +++ b/backpack/extensions/curvmatprod/pooling.py @@ -1,5 +1,6 @@ from backpack.core.derivatives.avgpool2d import AvgPool2DDerivatives from 
backpack.core.derivatives.maxpool2d import MaxPool2DDerivatives + from .cmpbase import CMPBase @@ -11,4 +12,3 @@ def __init__(self): class CMPMaxpool2d(CMPBase, MaxPool2DDerivatives): def __init__(self): super().__init__(derivatives=MaxPool2DDerivatives()) - diff --git a/backpack/extensions/firstorder/__init__.py b/backpack/extensions/firstorder/__init__.py index ebcf4199..2e876f3f 100644 --- a/backpack/extensions/firstorder/__init__.py +++ b/backpack/extensions/firstorder/__init__.py @@ -12,9 +12,9 @@ - `BatchL2Grad`: The L2 norm of the individual gradients """ -from .batch_l2_grad import BatchL2Grad from .batch_grad import BatchGrad +from .batch_l2_grad import BatchL2Grad from .sum_grad_squared import SumGradSquared from .variance import Variance - +__all__ = ["BatchL2Grad", "BatchGrad", "SumGradSquared", "Variance"] diff --git a/backpack/extensions/firstorder/base.py b/backpack/extensions/firstorder/base.py index ce0c59e3..e3c62d08 100644 --- a/backpack/extensions/firstorder/base.py +++ b/backpack/extensions/firstorder/base.py @@ -2,6 +2,5 @@ class FirstOrderModuleExtension(ModuleExtension): - def backpropagate(self, ext, module, g_inp, g_out, bpQuantities): return None diff --git a/backpack/extensions/firstorder/batch_grad/__init__.py b/backpack/extensions/firstorder/batch_grad/__init__.py index 5659dd2f..0e14aec1 100644 --- a/backpack/extensions/firstorder/batch_grad/__init__.py +++ b/backpack/extensions/firstorder/batch_grad/__init__.py @@ -1,7 +1,7 @@ -from backpack.core.layers import Conv2dConcat, LinearConcat -from backpack.extensions.backprop_extension import BackpropExtension from torch.nn import BatchNorm1d, Conv2d, Linear +from backpack.extensions.backprop_extension import BackpropExtension + from . import batchnorm1d, conv2d, linear @@ -14,13 +14,14 @@ class BatchGrad(BackpropExtension): where :code:`N` is the size of the minibatch and :code:`...` is the shape of the gradient. 
""" + def __init__(self): - super().__init__(savefield="grad_batch", - fail_mode="WARNING", - module_exts={ - Linear: linear.BatchGradLinear(), - LinearConcat: linear.BatchGradLinearConcat(), - Conv2d: conv2d.BatchGradConv2d(), - Conv2dConcat: conv2d.BatchGradConv2dConcat(), - BatchNorm1d: batchnorm1d.BatchGradBatchNorm1d(), - }) + super().__init__( + savefield="grad_batch", + fail_mode="WARNING", + module_exts={ + Linear: linear.BatchGradLinear(), + Conv2d: conv2d.BatchGradConv2d(), + BatchNorm1d: batchnorm1d.BatchGradBatchNorm1d(), + }, + ) diff --git a/backpack/extensions/firstorder/batch_grad/batch_grad_base.py b/backpack/extensions/firstorder/batch_grad/batch_grad_base.py index d7faaef0..1e25de41 100644 --- a/backpack/extensions/firstorder/batch_grad/batch_grad_base.py +++ b/backpack/extensions/firstorder/batch_grad/batch_grad_base.py @@ -2,27 +2,16 @@ class BatchGradBase(FirstOrderModuleExtension): - def __init__(self, derivatives, params=None): self.derivatives = derivatives super().__init__(params=params) def bias(self, ext, module, g_inp, g_out, bpQuantities): - batch = g_out[0].shape[0] - grad_out_vec = g_out[0].contiguous().view(batch, -1) - - bias_grad = self.derivatives.bias_jac_t_mat_prod( - module, g_inp, g_out, grad_out_vec, sum_batch=False + return self.derivatives.bias_jac_t_mat_prod( + module, g_inp, g_out, g_out[0], sum_batch=False ) - return bias_grad.view((batch,) + module.bias.shape) - def weight(self, ext, module, g_inp, g_out, bpQuantities): - batch = g_out[0].shape[0] - grad_out_vec = g_out[0].contiguous().view(batch, -1) - - weight_grad = self.derivatives.weight_jac_t_mat_prod( - module, g_inp, g_out, grad_out_vec, sum_batch=False + return self.derivatives.weight_jac_t_mat_prod( + module, g_inp, g_out, g_out[0], sum_batch=False ) - - return weight_grad.view((batch,) + module.weight.shape) diff --git a/backpack/extensions/firstorder/batch_grad/batchnorm1d.py b/backpack/extensions/firstorder/batch_grad/batchnorm1d.py index 54b7f777..74d8737d 100644 --- a/backpack/extensions/firstorder/batch_grad/batchnorm1d.py +++ b/backpack/extensions/firstorder/batch_grad/batchnorm1d.py @@ -1,9 +1,9 @@ from backpack.core.derivatives.batchnorm1d import BatchNorm1dDerivatives -from backpack.extensions.firstorder.batch_grad.batch_grad_base import \ - BatchGradBase +from backpack.extensions.firstorder.batch_grad.batch_grad_base import BatchGradBase class BatchGradBatchNorm1d(BatchGradBase): def __init__(self): - super().__init__(derivatives=BatchNorm1dDerivatives(), - params=["bias", "weight"]) + super().__init__( + derivatives=BatchNorm1dDerivatives(), params=["bias", "weight"] + ) diff --git a/backpack/extensions/firstorder/batch_grad/conv2d.py b/backpack/extensions/firstorder/batch_grad/conv2d.py index 51e16a28..5ae372a2 100644 --- a/backpack/extensions/firstorder/batch_grad/conv2d.py +++ b/backpack/extensions/firstorder/batch_grad/conv2d.py @@ -1,20 +1,7 @@ -from backpack.core.derivatives.conv2d import (Conv2DDerivatives, - Conv2DConcatDerivatives) - +from backpack.core.derivatives.conv2d import Conv2DDerivatives from backpack.extensions.firstorder.batch_grad.batch_grad_base import BatchGradBase class BatchGradConv2d(BatchGradBase): def __init__(self): - super().__init__( - derivatives=Conv2DDerivatives(), - params=["bias", "weight"] - ) - - -class BatchGradConv2dConcat(BatchGradBase): - def __init__(self): - super().__init__( - derivatives=Conv2DConcatDerivatives(), - params=["weight"] - ) + super().__init__(derivatives=Conv2DDerivatives(), params=["bias", "weight"]) diff --git 
a/backpack/extensions/firstorder/batch_grad/linear.py b/backpack/extensions/firstorder/batch_grad/linear.py index 1c6b27d1..6b58a0b7 100644 --- a/backpack/extensions/firstorder/batch_grad/linear.py +++ b/backpack/extensions/firstorder/batch_grad/linear.py @@ -1,20 +1,7 @@ -from backpack.core.derivatives.linear import (LinearDerivatives, - LinearConcatDerivatives) - +from backpack.core.derivatives.linear import LinearDerivatives from backpack.extensions.firstorder.batch_grad.batch_grad_base import BatchGradBase class BatchGradLinear(BatchGradBase): def __init__(self): - super().__init__( - derivatives=LinearDerivatives(), - params=["bias", "weight"] - ) - - -class BatchGradLinearConcat(BatchGradBase): - def __init__(self): - super().__init__( - derivatives=LinearConcatDerivatives(), - params=["weight"] - ) + super().__init__(derivatives=LinearDerivatives(), params=["bias", "weight"]) diff --git a/backpack/extensions/firstorder/batch_l2_grad/__init__.py b/backpack/extensions/firstorder/batch_l2_grad/__init__.py index 3b32801a..a275725e 100644 --- a/backpack/extensions/firstorder/batch_l2_grad/__init__.py +++ b/backpack/extensions/firstorder/batch_l2_grad/__init__.py @@ -1,8 +1,8 @@ +from torch.nn import Conv2d, Linear + from backpack.extensions.backprop_extension import BackpropExtension -from backpack.core.layers import Conv2dConcat, LinearConcat -from torch.nn import Linear, Conv2d -from . import linear, conv2d +from . import conv2d, linear class BatchL2Grad(BackpropExtension): @@ -13,15 +13,13 @@ class BatchL2Grad(BackpropExtension): Stores the output in :code:`batch_l2` as a vector of the size as the minibatch. """ + def __init__(self): super().__init__( savefield="batch_l2", fail_mode="WARNING", module_exts={ Linear: linear.BatchL2Linear(), - LinearConcat: linear.BatchL2LinearConcat(), Conv2d: conv2d.BatchL2Conv2d(), - Conv2dConcat: conv2d.BatchL2Conv2dConcat(), - } + }, ) - diff --git a/backpack/extensions/firstorder/batch_l2_grad/conv2d.py b/backpack/extensions/firstorder/batch_l2_grad/conv2d.py index bd6018eb..4f18fe16 100644 --- a/backpack/extensions/firstorder/batch_l2_grad/conv2d.py +++ b/backpack/extensions/firstorder/batch_l2_grad/conv2d.py @@ -1,6 +1,6 @@ -from backpack.utils.utils import einsum -from backpack.utils import conv as convUtils from backpack.extensions.firstorder.base import FirstOrderModuleExtension +from backpack.utils import conv as convUtils +from backpack.utils.ein import einsum class BatchL2Conv2d(FirstOrderModuleExtension): @@ -8,23 +8,11 @@ def __init__(self): super().__init__(params=["bias", "weight"]) def bias(self, ext, module, g_inp, g_out, backproped): - return (g_out[0].sum(3).sum(2)**2).sum(1) - - def weight(self, ext, module, g_inp, g_out, backproped): - X, dE_dY = convUtils.get_weight_gradient_factors( - module.input0, g_out[0], module) - return einsum('bml,bkl,bmi,bki->b', (dE_dY, X, dE_dY, X)) - - -class BatchL2Conv2dConcat(FirstOrderModuleExtension): - def __init__(self): - super().__init__(params=["weight"]) + C_axis = 1 + return (einsum("nchw->nc", g_out[0]) ** 2).sum(C_axis) def weight(self, ext, module, g_inp, g_out, backproped): X, dE_dY = convUtils.get_weight_gradient_factors( - module.input0, g_out[0], module) - - if module.has_bias(): - X = module.append_ones(X) - - return einsum('bml,bkl,bmi,bki->b', (dE_dY, X, dE_dY, X)) + module.input0, g_out[0], module + ) + return einsum("nml,nkl,nmi,nki->n", (dE_dY, X, dE_dY, X)) diff --git a/backpack/extensions/firstorder/batch_l2_grad/linear.py 
b/backpack/extensions/firstorder/batch_l2_grad/linear.py index b0c90ffb..85b55f50 100644 --- a/backpack/extensions/firstorder/batch_l2_grad/linear.py +++ b/backpack/extensions/firstorder/batch_l2_grad/linear.py @@ -1,5 +1,5 @@ -from backpack.utils.utils import einsum from backpack.extensions.firstorder.base import FirstOrderModuleExtension +from backpack.utils.ein import einsum class BatchL2Linear(FirstOrderModuleExtension): @@ -7,16 +7,8 @@ def __init__(self): super().__init__(params=["bias", "weight"]) def bias(self, ext, module, g_inp, g_out, backproped): - return (g_out[0] ** 2).sum(1) + C_axis = 1 + return (g_out[0] ** 2).sum(C_axis) def weight(self, ext, module, g_inp, g_out, backproped): - return einsum('bi,bj->b', (g_out[0] ** 2, module.input0 ** 2)) - - -class BatchL2LinearConcat(FirstOrderModuleExtension): - def __init__(self): - super().__init__(params=["weight"]) - - def weight(self, ext, module, g_inp, g_out, backproped): - input = module.homogeneous_input() - return einsum('bi,bj->b', (g_out[0] ** 2, input ** 2)) + return einsum("ni,nj->n", (g_out[0] ** 2, module.input0 ** 2)) diff --git a/backpack/extensions/firstorder/gradient/__init__.py b/backpack/extensions/firstorder/gradient/__init__.py index 8b137891..7a522883 100644 --- a/backpack/extensions/firstorder/gradient/__init__.py +++ b/backpack/extensions/firstorder/gradient/__init__.py @@ -1 +1 @@ - +# TODO: Rewrite variance to not need this extension diff --git a/backpack/extensions/firstorder/gradient/base.py b/backpack/extensions/firstorder/gradient/base.py index 0b0ef5f4..3ebaf893 100644 --- a/backpack/extensions/firstorder/gradient/base.py +++ b/backpack/extensions/firstorder/gradient/base.py @@ -2,27 +2,16 @@ class GradBaseModule(FirstOrderModuleExtension): - def __init__(self, derivatives, params=None): self.derivatives = derivatives super().__init__(params=params) def bias(self, ext, module, g_inp, g_out, bpQuantities): - batch = g_out[0].shape[0] - grad_out_vec = g_out[0].contiguous().view(batch, -1) - - bias_grad = self.derivatives.bias_jac_t_mat_prod( - module, g_inp, g_out, grad_out_vec + return self.derivatives.bias_jac_t_mat_prod( + module, g_inp, g_out, g_out[0], sum_batch=True ) - return bias_grad.view(module.bias.shape) - def weight(self, ext, module, g_inp, g_out, bpQuantities): - batch = g_out[0].shape[0] - grad_out_vec = g_out[0].contiguous().view(batch, -1) - - weight_grad = self.derivatives.weight_jac_t_mat_prod( - module, g_inp, g_out, grad_out_vec + return self.derivatives.weight_jac_t_mat_prod( + module, g_inp, g_out, g_out[0], sum_batch=True ) - - return weight_grad.view(module.weight.shape) diff --git a/backpack/extensions/firstorder/gradient/batchnorm1d.py b/backpack/extensions/firstorder/gradient/batchnorm1d.py index 5f17f6f6..5e0f3b6f 100644 --- a/backpack/extensions/firstorder/gradient/batchnorm1d.py +++ b/backpack/extensions/firstorder/gradient/batchnorm1d.py @@ -5,5 +5,6 @@ class GradBatchNorm1d(GradBaseModule): def __init__(self): - super().__init__(derivatives=BatchNorm1dDerivatives(), - params=["bias", "weight"]) + super().__init__( + derivatives=BatchNorm1dDerivatives(), params=["bias", "weight"] + ) diff --git a/backpack/extensions/firstorder/gradient/conv2d.py b/backpack/extensions/firstorder/gradient/conv2d.py index 707e0543..a1f6eeb2 100644 --- a/backpack/extensions/firstorder/gradient/conv2d.py +++ b/backpack/extensions/firstorder/gradient/conv2d.py @@ -1,20 +1,8 @@ -from backpack.core.derivatives.conv2d import (Conv2DDerivatives, - Conv2DConcatDerivatives) +from 
backpack.core.derivatives.conv2d import Conv2DDerivatives from .base import GradBaseModule class GradConv2d(GradBaseModule): def __init__(self): - super().__init__( - derivatives=Conv2DDerivatives(), - params=["bias", "weight"] - ) - - -class GradConv2dConcat(GradBaseModule): - def __init__(self): - super().__init__( - derivatives=Conv2DConcatDerivatives(), - params=["weight"] - ) + super().__init__(derivatives=Conv2DDerivatives(), params=["bias", "weight"]) diff --git a/backpack/extensions/firstorder/gradient/linear.py b/backpack/extensions/firstorder/gradient/linear.py index 348f046c..31088c5e 100644 --- a/backpack/extensions/firstorder/gradient/linear.py +++ b/backpack/extensions/firstorder/gradient/linear.py @@ -1,19 +1,8 @@ -from backpack.core.derivatives.linear import (LinearDerivatives, - LinearConcatDerivatives) +from backpack.core.derivatives.linear import LinearDerivatives + from .base import GradBaseModule class GradLinear(GradBaseModule): def __init__(self): - super().__init__( - derivatives=LinearDerivatives(), - params=["bias", "weight"] - ) - - -class GradLinearConcat(GradBaseModule): - def __init__(self): - super().__init__( - derivatives=LinearConcatDerivatives(), - params=["weight"] - ) + super().__init__(derivatives=LinearDerivatives(), params=["bias", "weight"]) diff --git a/backpack/extensions/firstorder/sum_grad_squared/__init__.py b/backpack/extensions/firstorder/sum_grad_squared/__init__.py index d8c5677e..884b9a04 100644 --- a/backpack/extensions/firstorder/sum_grad_squared/__init__.py +++ b/backpack/extensions/firstorder/sum_grad_squared/__init__.py @@ -1,8 +1,8 @@ +from torch.nn import Conv2d, Linear + from backpack.extensions.backprop_extension import BackpropExtension -from backpack.core.layers import Conv2dConcat, LinearConcat -from torch.nn import Linear, Conv2d -from . import linear, conv2d +from . import conv2d, linear class SumGradSquared(BackpropExtension): @@ -10,18 +10,12 @@ class SumGradSquared(BackpropExtension): The sum of individual-gradients-squared, or second moment of the gradient. Is only meaningful is the individual functions are independent (no batchnorm). - Stores the output in :code:`sum_grad_squared`, has the same dimension as the gradient. + Stores the output in :code:`sum_grad_squared`. Same dimension as the gradient. 
""" def __init__(self): super().__init__( savefield="sum_grad_squared", fail_mode="WARNING", - module_exts={ - Linear: linear.SGSLinear(), - LinearConcat: linear.SGSLinearConcat(), - Conv2d: conv2d.SGSConv2d(), - Conv2dConcat: conv2d.SGSConv2dConcat(), - } + module_exts={Linear: linear.SGSLinear(), Conv2d: conv2d.SGSConv2d(),}, ) - diff --git a/backpack/extensions/firstorder/sum_grad_squared/conv2d.py b/backpack/extensions/firstorder/sum_grad_squared/conv2d.py index 7ac29e00..7288bbbf 100644 --- a/backpack/extensions/firstorder/sum_grad_squared/conv2d.py +++ b/backpack/extensions/firstorder/sum_grad_squared/conv2d.py @@ -1,6 +1,6 @@ -from backpack.utils.utils import einsum -from backpack.utils import conv as convUtils from backpack.extensions.firstorder.base import FirstOrderModuleExtension +from backpack.utils import conv as convUtils +from backpack.utils.ein import einsum class SGSConv2d(FirstOrderModuleExtension): @@ -8,26 +8,13 @@ def __init__(self): super().__init__(params=["bias", "weight"]) def bias(self, ext, module, g_inp, g_out, backproped): - return (g_out[0].sum(3).sum(2)**2).sum(0) + N_axis = 0 + return (einsum("nchw->nc", g_out[0]) ** 2).sum(N_axis) def weight(self, ext, module, g_inp, g_out, backproped): + N_axis = 0 X, dE_dY = convUtils.get_weight_gradient_factors( - module.input0, g_out[0], module) - d1 = einsum('bml,bkl->bmk', (dE_dY, X)) - return (d1**2).sum(0).view_as(module.weight) - - -class SGSConv2dConcat(FirstOrderModuleExtension): - def __init__(self): - super().__init__(params=["weight"]) - - def weight(self, ext, module, g_inp, g_out, backproped): - X, dE_dY = convUtils.get_weight_gradient_factors( - module.input0, g_out[0], module) - - if module.has_bias(): - X = module.append_ones(X) - - d1 = einsum('bml,bkl->bmk', (dE_dY, X)) - return (d1**2).sum(0).view_as(module.weight) - + module.input0, g_out[0], module + ) + d1 = einsum("nml,nkl->nmk", (dE_dY, X)) + return (d1 ** 2).sum(N_axis).view_as(module.weight) diff --git a/backpack/extensions/firstorder/sum_grad_squared/linear.py b/backpack/extensions/firstorder/sum_grad_squared/linear.py index 2313b312..2bcbdeef 100644 --- a/backpack/extensions/firstorder/sum_grad_squared/linear.py +++ b/backpack/extensions/firstorder/sum_grad_squared/linear.py @@ -1,5 +1,5 @@ -from backpack.utils.utils import einsum from backpack.extensions.firstorder.base import FirstOrderModuleExtension +from backpack.utils.ein import einsum class SGSLinear(FirstOrderModuleExtension): @@ -7,16 +7,8 @@ def __init__(self): super().__init__(params=["bias", "weight"]) def bias(self, ext, module, g_inp, g_out, backproped): - return (g_out[0] ** 2).sum(0) + N_axis = 0 + return (g_out[0] ** 2).sum(N_axis) def weight(self, ext, module, g_inp, g_out, backproped): - return einsum('bi,bj->ij', (g_out[0] ** 2, module.input0 ** 2)) - - -class SGSLinearConcat(FirstOrderModuleExtension): - def __init__(self): - super().__init__(params=["weight"]) - - def weight(self, ext, module, g_inp, g_out, backproped): - input = module.homogeneous_input() - return einsum('bi,bj->ij', (g_out[0] ** 2, input ** 2)) + return einsum("ni,nj->ij", (g_out[0] ** 2, module.input0 ** 2)) diff --git a/backpack/extensions/firstorder/variance/__init__.py b/backpack/extensions/firstorder/variance/__init__.py index 95f7f92a..8ea89215 100644 --- a/backpack/extensions/firstorder/variance/__init__.py +++ b/backpack/extensions/firstorder/variance/__init__.py @@ -1,8 +1,8 @@ +from torch.nn import Conv2d, Linear + from backpack.extensions.backprop_extension import BackpropExtension -from 
backpack.core.layers import Conv2dConcat, LinearConcat -from torch.nn import Linear, Conv2d -from . import linear, conv2d +from . import conv2d, linear class Variance(BackpropExtension): @@ -12,15 +12,13 @@ class Variance(BackpropExtension): Stores the output in :code:`variance`, has the same dimension as the gradient. """ + def __init__(self): super().__init__( savefield="variance", fail_mode="WARNING", module_exts={ Linear: linear.VarianceLinear(), - LinearConcat: linear.VarianceLinearConcat(), Conv2d: conv2d.VarianceConv2d(), - Conv2dConcat: conv2d.VarianceConv2dConcat(), - } + }, ) - diff --git a/backpack/extensions/firstorder/variance/conv2d.py b/backpack/extensions/firstorder/variance/conv2d.py index 1867dca9..9a4e852d 100644 --- a/backpack/extensions/firstorder/variance/conv2d.py +++ b/backpack/extensions/firstorder/variance/conv2d.py @@ -1,9 +1,6 @@ -from backpack.extensions.firstorder.gradient.conv2d import ( - GradConv2dConcat, GradConv2d -) -from backpack.extensions.firstorder.sum_grad_squared.conv2d import ( - SGSConv2dConcat, SGSConv2d -) +from backpack.extensions.firstorder.gradient.conv2d import GradConv2d +from backpack.extensions.firstorder.sum_grad_squared.conv2d import SGSConv2d + from .variance_base import VarianceBaseModule @@ -12,14 +9,5 @@ def __init__(self): super().__init__( params=["bias", "weight"], grad_extension=GradConv2d(), - sgs_extension=SGSConv2d() - ) - - -class VarianceConv2dConcat(VarianceBaseModule): - def __init__(self): - super().__init__( - params=["weight"], - grad_extension=GradConv2dConcat(), - sgs_extension=SGSConv2dConcat() + sgs_extension=SGSConv2d(), ) diff --git a/backpack/extensions/firstorder/variance/linear.py b/backpack/extensions/firstorder/variance/linear.py index 38f6dd8a..8296158c 100644 --- a/backpack/extensions/firstorder/variance/linear.py +++ b/backpack/extensions/firstorder/variance/linear.py @@ -1,9 +1,6 @@ -from backpack.extensions.firstorder.gradient.linear import ( - GradLinear, GradLinearConcat -) -from backpack.extensions.firstorder.sum_grad_squared.linear import ( - SGSLinear, SGSLinearConcat -) +from backpack.extensions.firstorder.gradient.linear import GradLinear +from backpack.extensions.firstorder.sum_grad_squared.linear import SGSLinear + from .variance_base import VarianceBaseModule @@ -12,14 +9,5 @@ def __init__(self): super().__init__( params=["bias", "weight"], grad_extension=GradLinear(), - sgs_extension=SGSLinear() - ) - - -class VarianceLinearConcat(VarianceBaseModule): - def __init__(self): - super().__init__( - params=["weight"], - grad_extension=GradLinearConcat(), - sgs_extension=SGSLinearConcat() + sgs_extension=SGSLinear(), ) diff --git a/backpack/extensions/firstorder/variance/variance_base.py b/backpack/extensions/firstorder/variance/variance_base.py index 2f21bb3c..64d8c17e 100644 --- a/backpack/extensions/firstorder/variance/variance_base.py +++ b/backpack/extensions/firstorder/variance/variance_base.py @@ -18,7 +18,7 @@ def bias(self, ext, module, g_inp, g_out, backproped): return self.variance_from( self.grad_ext.bias(ext, module, g_inp, g_out, backproped), self.sgs_ext.bias(ext, module, g_inp, g_out, backproped), - N + N, ) def weight(self, ext, module, g_inp, g_out, backproped): @@ -26,5 +26,5 @@ def weight(self, ext, module, g_inp, g_out, backproped): return self.variance_from( self.grad_ext.weight(ext, module, g_inp, g_out, backproped), self.sgs_ext.weight(ext, module, g_inp, g_out, backproped), - N + N, ) diff --git a/backpack/extensions/mat_to_mat_jac_base.py 
b/backpack/extensions/mat_to_mat_jac_base.py index 8910f2c4..ca9d214e 100644 --- a/backpack/extensions/mat_to_mat_jac_base.py +++ b/backpack/extensions/mat_to_mat_jac_base.py @@ -5,20 +5,16 @@ class MatToJacMat(ModuleExtension): """ Base class for backpropagating matrices by multiplying with Jacobians. """ + def __init__(self, derivatives, params=None): super().__init__(params) self.derivatives = derivatives def backpropagate(self, ext, module, grad_inp, grad_out, backproped): - if self.derivatives is None: - return backproped - if isinstance(backproped, list): M_list = [ - self.derivatives.jac_t_mat_prod( - module, grad_inp, grad_out, M - ) + self.derivatives.jac_t_mat_prod(module, grad_inp, grad_out, M) for M in backproped ] return list(M_list) diff --git a/backpack/extensions/module_extension.py b/backpack/extensions/module_extension.py index e0091920..1ed811a3 100644 --- a/backpack/extensions/module_extension.py +++ b/backpack/extensions/module_extension.py @@ -1,7 +1,4 @@ import warnings -from backpack.context import get_from_ctx, set_in_ctx - -SAVE_BP_QUANTITIES_IN_COMPUTATION_GRAPH = True class ModuleExtension: @@ -71,40 +68,30 @@ def apply(self, ext, module, g_inp, g_out): for param in self.__params: if self.__param_exists_and_requires_grad(module, param): extFunc = getattr(self, param) - extValue = extFunc( - ext, module, g_inp, g_out, bpQuantities - ) + extValue = extFunc(ext, module, g_inp, g_out, bpQuantities) self.__save(extValue, ext, module, param) - bpQuantities = self.backpropagate( - ext, module, g_inp, g_out, bpQuantities - ) + bpQuantities = self.backpropagate(ext, module, g_inp, g_out, bpQuantities) self.__backprop_quantities(ext, inp, out, bpQuantities) @staticmethod def __backproped_quantities(ext, out): - if not SAVE_BP_QUANTITIES_IN_COMPUTATION_GRAPH: - return get_from_ctx(ext.savefield) - else: - return getattr(out, ext.savefield, None) + return getattr(out, ext.savefield, None) @staticmethod def __backprop_quantities(ext, inp, out, bpQuantities): - if not SAVE_BP_QUANTITIES_IN_COMPUTATION_GRAPH: - set_in_ctx(ext.savefield, bpQuantities) - else: - setattr(inp, ext.savefield, bpQuantities) + setattr(inp, ext.savefield, bpQuantities) - is_a_leaf = out.grad_fn is None - retain_grad_is_on = getattr(out, "retains_grad", False) - inp_is_out = id(inp) == id(out) - should_retain_grad = is_a_leaf or retain_grad_is_on or inp_is_out + is_a_leaf = out.grad_fn is None + retain_grad_is_on = getattr(out, "retains_grad", False) + inp_is_out = id(inp) == id(out) + should_retain_grad = is_a_leaf or retain_grad_is_on or inp_is_out - if not should_retain_grad: - if hasattr(out, ext.savefield): - delattr(out, ext.savefield) + if not should_retain_grad: + if hasattr(out, ext.savefield): + delattr(out, ext.savefield) @staticmethod def __param_exists_and_requires_grad(module, param): diff --git a/backpack/extensions/secondorder/__init__.py b/backpack/extensions/secondorder/__init__.py index 12edeae0..42aba57f 100644 --- a/backpack/extensions/secondorder/__init__.py +++ b/backpack/extensions/secondorder/__init__.py @@ -19,3 +19,14 @@ from .diag_ggn import DiagGGN, DiagGGNExact, DiagGGNMC from .diag_hessian import DiagHessian from .hbp import HBP, KFAC, KFLR, KFRA + +__all__ = [ + "DiagGGN", + "DiagGGNExact", + "DiagGGNMC", + "DiagHessian", + "HBP", + "KFAC", + "KFLR", + "KFRA", +] diff --git a/backpack/extensions/secondorder/diag_ggn/__init__.py b/backpack/extensions/secondorder/diag_ggn/__init__.py index ea2de8ce..e80f0bb2 100644 --- a/backpack/extensions/secondorder/diag_ggn/__init__.py 
+++ b/backpack/extensions/secondorder/diag_ggn/__init__.py @@ -1,50 +1,58 @@ -from torch.nn import (AvgPool2d, Conv2d, CrossEntropyLoss, Dropout, Linear, - MaxPool2d, MSELoss, ReLU, Sigmoid, Tanh, ZeroPad2d) +from torch.nn import ( + AvgPool2d, + Conv2d, + CrossEntropyLoss, + Dropout, + Flatten, + Linear, + MaxPool2d, + MSELoss, + ReLU, + Sigmoid, + Tanh, + ZeroPad2d, +) -from backpack.core.layers import Conv2dConcat, Flatten, LinearConcat from backpack.extensions.backprop_extension import BackpropExtension from backpack.extensions.secondorder.hbp import LossHessianStrategy -from . import (activations, conv2d, dropout, flatten, linear, losses, padding, - pooling) +from . import activations, conv2d, dropout, flatten, linear, losses, padding, pooling class DiagGGN(BackpropExtension): VALID_LOSS_HESSIAN_STRATEGIES = [ - LossHessianStrategy.EXACT, LossHessianStrategy.SAMPLING + LossHessianStrategy.EXACT, + LossHessianStrategy.SAMPLING, ] - def __init__(self, - loss_hessian_strategy=LossHessianStrategy.EXACT, - savefield=None): + def __init__(self, loss_hessian_strategy=LossHessianStrategy.EXACT, savefield=None): if savefield is None: savefield = "diag_ggn" if loss_hessian_strategy not in self.VALID_LOSS_HESSIAN_STRATEGIES: raise ValueError( - "Unknown hessian strategy: {}".format(loss_hessian_strategy) + - "Valid strategies: [{}]".format( - self.VALID_LOSS_HESSIAN_STRATEGIES)) + "Unknown hessian strategy: {}".format(loss_hessian_strategy) + + "Valid strategies: [{}]".format(self.VALID_LOSS_HESSIAN_STRATEGIES) + ) self.loss_hessian_strategy = loss_hessian_strategy - super().__init__(savefield=savefield, - fail_mode="ERROR", - module_exts={ - MSELoss: losses.DiagGGNMSELoss(), - CrossEntropyLoss: - losses.DiagGGNCrossEntropyLoss(), - Linear: linear.DiagGGNLinear(), - LinearConcat: linear.DiagGGNLinearConcat(), - MaxPool2d: pooling.DiagGGNMaxPool2d(), - AvgPool2d: pooling.DiagGGNAvgPool2d(), - ZeroPad2d: padding.DiagGGNZeroPad2d(), - Conv2d: conv2d.DiagGGNConv2d(), - Conv2dConcat: conv2d.DiagGGNConv2dConcat(), - Dropout: dropout.DiagGGNDropout(), - Flatten: flatten.DiagGGNFlatten(), - ReLU: activations.DiagGGNReLU(), - Sigmoid: activations.DiagGGNSigmoid(), - Tanh: activations.DiagGGNTanh(), - }) + super().__init__( + savefield=savefield, + fail_mode="ERROR", + module_exts={ + MSELoss: losses.DiagGGNMSELoss(), + CrossEntropyLoss: losses.DiagGGNCrossEntropyLoss(), + Linear: linear.DiagGGNLinear(), + MaxPool2d: pooling.DiagGGNMaxPool2d(), + AvgPool2d: pooling.DiagGGNAvgPool2d(), + ZeroPad2d: padding.DiagGGNZeroPad2d(), + Conv2d: conv2d.DiagGGNConv2d(), + Dropout: dropout.DiagGGNDropout(), + Flatten: flatten.DiagGGNFlatten(), + ReLU: activations.DiagGGNReLU(), + Sigmoid: activations.DiagGGNSigmoid(), + Tanh: activations.DiagGGNTanh(), + }, + ) class DiagGGNExact(DiagGGN): @@ -59,9 +67,11 @@ class DiagGGNExact(DiagGGN): see :py:meth:`backpack.extensions.DiagGGNMC`. """ + def __init__(self): - super().__init__(loss_hessian_strategy=LossHessianStrategy.EXACT, - savefield="diag_ggn_exact") + super().__init__( + loss_hessian_strategy=LossHessianStrategy.EXACT, savefield="diag_ggn_exact" + ) class DiagGGNMC(DiagGGN): @@ -77,6 +87,12 @@ class DiagGGNMC(DiagGGN): see :py:meth:`backpack.extensions.DiagGGNExact`. 
""" - def __init__(self): - super().__init__(loss_hessian_strategy=LossHessianStrategy.SAMPLING, - savefield="diag_ggn_mc") + + def __init__(self, mc_samples=1): + self._mc_samples = mc_samples + super().__init__( + loss_hessian_strategy=LossHessianStrategy.SAMPLING, savefield="diag_ggn_mc" + ) + + def get_num_mc_samples(self): + return self._mc_samples diff --git a/backpack/extensions/secondorder/diag_ggn/activations.py b/backpack/extensions/secondorder/diag_ggn/activations.py index ccda568e..649b6348 100644 --- a/backpack/extensions/secondorder/diag_ggn/activations.py +++ b/backpack/extensions/secondorder/diag_ggn/activations.py @@ -1,7 +1,7 @@ from backpack.core.derivatives.relu import ReLUDerivatives -from backpack.core.derivatives.tanh import TanhDerivatives from backpack.core.derivatives.sigmoid import SigmoidDerivatives -from .diag_ggn_base import DiagGGNBaseModule +from backpack.core.derivatives.tanh import TanhDerivatives +from backpack.extensions.secondorder.diag_ggn.diag_ggn_base import DiagGGNBaseModule class DiagGGNReLU(DiagGGNBaseModule): diff --git a/backpack/extensions/secondorder/diag_ggn/conv2d.py b/backpack/extensions/secondorder/diag_ggn/conv2d.py index ecc1f0f7..9f4bd29f 100644 --- a/backpack/extensions/secondorder/diag_ggn/conv2d.py +++ b/backpack/extensions/secondorder/diag_ggn/conv2d.py @@ -1,38 +1,17 @@ +from backpack.core.derivatives.conv2d import Conv2DDerivatives +from backpack.extensions.secondorder.diag_ggn.diag_ggn_base import DiagGGNBaseModule from backpack.utils import conv as convUtils -from backpack.utils.utils import einsum -from backpack.core.derivatives.conv2d import Conv2DDerivatives, Conv2DConcatDerivatives -from .diag_ggn_base import DiagGGNBaseModule class DiagGGNConv2d(DiagGGNBaseModule): def __init__(self): - super().__init__( - derivatives=Conv2DDerivatives(), - params=["bias", "weight"] - ) + super().__init__(derivatives=Conv2DDerivatives(), params=["bias", "weight"]) def bias(self, ext, module, grad_inp, grad_out, backproped): - sqrt_ggn = convUtils.separate_channels_and_pixels(module, backproped) - return einsum('bijc,bikc->i', (sqrt_ggn, sqrt_ggn)) + sqrt_ggn = backproped + return convUtils.extract_bias_diagonal(module, sqrt_ggn) def weight(self, ext, module, grad_inp, grad_out, backproped): X = convUtils.unfold_func(module)(module.input0) weight_diag = convUtils.extract_weight_diagonal(module, X, backproped) - return weight_diag .view_as(module.weight) - - -class DiagGGNConv2dConcat(DiagGGNBaseModule): - def __init__(self): - super().__init__( - derivatives=Conv2DConcatDerivatives(), - params=["weight"] - ) - - def weight(self, ext, module, grad_inp, grad_out, backproped): - X = convUtils.unfold_func(module)(module.input0) - if module.has_bias: - X = module.append_ones(X) - - weight_diag = convUtils.extract_weight_diagonal(module, X, backproped) - - return weight_diag.view_as(module.weight) + return weight_diag diff --git a/backpack/extensions/secondorder/diag_ggn/dropout.py b/backpack/extensions/secondorder/diag_ggn/dropout.py index a3d0121c..bd19598b 100644 --- a/backpack/extensions/secondorder/diag_ggn/dropout.py +++ b/backpack/extensions/secondorder/diag_ggn/dropout.py @@ -1,5 +1,5 @@ from backpack.core.derivatives.dropout import DropoutDerivatives -from .diag_ggn_base import DiagGGNBaseModule +from backpack.extensions.secondorder.diag_ggn.diag_ggn_base import DiagGGNBaseModule class DiagGGNDropout(DiagGGNBaseModule): diff --git a/backpack/extensions/secondorder/diag_ggn/flatten.py b/backpack/extensions/secondorder/diag_ggn/flatten.py 
index a98c4ffe..60c1ca8d 100644 --- a/backpack/extensions/secondorder/diag_ggn/flatten.py +++ b/backpack/extensions/secondorder/diag_ggn/flatten.py @@ -1,9 +1,13 @@ -from .diag_ggn_base import DiagGGNBaseModule +from backpack.core.derivatives.flatten import FlattenDerivatives +from backpack.extensions.secondorder.diag_ggn.diag_ggn_base import DiagGGNBaseModule class DiagGGNFlatten(DiagGGNBaseModule): def __init__(self): - super().__init__(derivatives=None) + super().__init__(derivatives=FlattenDerivatives()) def backpropagate(self, ext, module, grad_inp, grad_out, backproped): - return backproped \ No newline at end of file + if self.derivatives.is_no_op(module): + return backproped + else: + return super().backpropagate(ext, module, grad_inp, grad_out, backproped) diff --git a/backpack/extensions/secondorder/diag_ggn/linear.py b/backpack/extensions/secondorder/diag_ggn/linear.py index a571bd85..b5119864 100644 --- a/backpack/extensions/secondorder/diag_ggn/linear.py +++ b/backpack/extensions/secondorder/diag_ggn/linear.py @@ -1,29 +1,14 @@ -from backpack.core.derivatives.linear import LinearDerivatives, LinearConcatDerivatives -from backpack.utils.utils import einsum -from .diag_ggn_base import DiagGGNBaseModule +import backpack.utils.linear as LinUtils +from backpack.core.derivatives.linear import LinearDerivatives +from backpack.extensions.secondorder.diag_ggn.diag_ggn_base import DiagGGNBaseModule class DiagGGNLinear(DiagGGNBaseModule): def __init__(self): - super().__init__( - derivatives=LinearDerivatives(), - params=["bias", "weight"] - ) + super().__init__(derivatives=LinearDerivatives(), params=["bias", "weight"]) def bias(self, ext, module, grad_inp, grad_out, backproped): - return einsum('bic->i', (backproped ** 2,)) + return LinUtils.extract_bias_diagonal(module, backproped) def weight(self, ext, module, grad_inp, grad_out, backproped): - return einsum('bic,bj->ij', (backproped ** 2, module.input0 ** 2)) - - -class DiagGGNLinearConcat(DiagGGNBaseModule): - def __init__(self): - super().__init__( - derivatives=LinearConcatDerivatives(), - params=["weight"] - ) - - def weight(self, ext, module, grad_inp, grad_out, backproped): - input = module.homogeneous_input() - return einsum('bic,bj->ij', (backproped ** 2, input ** 2)) + return LinUtils.extract_weight_diagonal(module, backproped) diff --git a/backpack/extensions/secondorder/diag_ggn/losses.py b/backpack/extensions/secondorder/diag_ggn/losses.py index 51b56522..377adb52 100644 --- a/backpack/extensions/secondorder/diag_ggn/losses.py +++ b/backpack/extensions/secondorder/diag_ggn/losses.py @@ -1,24 +1,32 @@ -from backpack.core.derivatives.mseloss import MSELossDerivatives +from functools import partial + from backpack.core.derivatives.crossentropyloss import CrossEntropyLossDerivatives +from backpack.core.derivatives.mseloss import MSELossDerivatives +from backpack.extensions.secondorder.diag_ggn.diag_ggn_base import DiagGGNBaseModule from backpack.extensions.secondorder.hbp import LossHessianStrategy -from .diag_ggn_base import DiagGGNBaseModule - class DiagGGNLoss(DiagGGNBaseModule): def backpropagate(self, ext, module, grad_inp, grad_out, backproped): + hess_func = self.make_loss_hessian_func(ext) + + return hess_func(module, grad_inp, grad_out) + + def make_loss_hessian_func(self, ext): + """Get function that produces the backpropagated quantity.""" + loss_hessian_strategy = ext.loss_hessian_strategy + + if loss_hessian_strategy == LossHessianStrategy.EXACT: + return self.derivatives.sqrt_hessian + elif 
loss_hessian_strategy == LossHessianStrategy.SAMPLING: + mc_samples = ext.get_num_mc_samples() + return partial(self.derivatives.sqrt_hessian_sampled, mc_samples=mc_samples) - if ext.loss_hessian_strategy == LossHessianStrategy.EXACT: - hess_func = self.derivatives.sqrt_hessian - elif ext.loss_hessian_strategy == LossHessianStrategy.SAMPLING: - hess_func = self.derivatives.sqrt_hessian_sampled else: raise ValueError( - "Unknown hessian strategy {}".format(ext.loss_hessian_strategy) + "Unknown hessian strategy {}".format(loss_hessian_strategy) ) - return hess_func(module, grad_inp, grad_out) - class DiagGGNMSELoss(DiagGGNLoss): def __init__(self): diff --git a/backpack/extensions/secondorder/diag_ggn/padding.py b/backpack/extensions/secondorder/diag_ggn/padding.py index 97df46a8..62af3a17 100644 --- a/backpack/extensions/secondorder/diag_ggn/padding.py +++ b/backpack/extensions/secondorder/diag_ggn/padding.py @@ -1,5 +1,5 @@ from backpack.core.derivatives.zeropad2d import ZeroPad2dDerivatives -from .diag_ggn_base import DiagGGNBaseModule +from backpack.extensions.secondorder.diag_ggn.diag_ggn_base import DiagGGNBaseModule class DiagGGNZeroPad2d(DiagGGNBaseModule): diff --git a/backpack/extensions/secondorder/diag_ggn/pooling.py b/backpack/extensions/secondorder/diag_ggn/pooling.py index cb50588d..5edd772a 100644 --- a/backpack/extensions/secondorder/diag_ggn/pooling.py +++ b/backpack/extensions/secondorder/diag_ggn/pooling.py @@ -1,6 +1,6 @@ from backpack.core.derivatives.avgpool2d import AvgPool2DDerivatives from backpack.core.derivatives.maxpool2d import MaxPool2DDerivatives -from .diag_ggn_base import DiagGGNBaseModule +from backpack.extensions.secondorder.diag_ggn.diag_ggn_base import DiagGGNBaseModule class DiagGGNMaxPool2d(DiagGGNBaseModule): diff --git a/backpack/extensions/secondorder/diag_hessian/__init__.py b/backpack/extensions/secondorder/diag_hessian/__init__.py index fb8f18ef..0a010a9e 100644 --- a/backpack/extensions/secondorder/diag_hessian/__init__.py +++ b/backpack/extensions/secondorder/diag_hessian/__init__.py @@ -1,7 +1,21 @@ +from torch.nn import ( + AvgPool2d, + Conv2d, + CrossEntropyLoss, + Dropout, + Flatten, + Linear, + MaxPool2d, + MSELoss, + ReLU, + Sigmoid, + Tanh, + ZeroPad2d, +) + from backpack.extensions.backprop_extension import BackpropExtension -from backpack.core.layers import Conv2dConcat, LinearConcat, Flatten -from torch.nn import Linear, Conv2d, Dropout, MaxPool2d, Tanh, Sigmoid, ReLU, CrossEntropyLoss, MSELoss, AvgPool2d, ZeroPad2d -from . import pooling, conv2d, linear, activations, losses, padding, dropout, flatten + +from . import activations, conv2d, dropout, flatten, linear, losses, padding, pooling class DiagHessian(BackpropExtension): @@ -15,6 +29,7 @@ class DiagHessian(BackpropExtension): Very expensive on networks with non-piecewise linear activations. 
""" + def __init__(self): super().__init__( savefield="diag_h", @@ -23,16 +38,14 @@ def __init__(self): MSELoss: losses.DiagHMSELoss(), CrossEntropyLoss: losses.DiagHCrossEntropyLoss(), Linear: linear.DiagHLinear(), - LinearConcat: linear.DiagHLinearConcat(), MaxPool2d: pooling.DiagHMaxPool2d(), AvgPool2d: pooling.DiagHAvgPool2d(), ZeroPad2d: padding.DiagHZeroPad2d(), Conv2d: conv2d.DiagHConv2d(), - Conv2dConcat: conv2d.DiagHConv2dConcat(), Dropout: dropout.DiagHDropout(), Flatten: flatten.DiagHFlatten(), ReLU: activations.DiagHReLU(), Sigmoid: activations.DiagHSigmoid(), Tanh: activations.DiagHTanh(), - } + }, ) diff --git a/backpack/extensions/secondorder/diag_hessian/activations.py b/backpack/extensions/secondorder/diag_hessian/activations.py index caefbd64..aa476e3a 100644 --- a/backpack/extensions/secondorder/diag_hessian/activations.py +++ b/backpack/extensions/secondorder/diag_hessian/activations.py @@ -1,7 +1,7 @@ from backpack.core.derivatives.relu import ReLUDerivatives from backpack.core.derivatives.sigmoid import SigmoidDerivatives from backpack.core.derivatives.tanh import TanhDerivatives -from .diag_h_base import DiagHBaseModule +from backpack.extensions.secondorder.diag_hessian.diag_h_base import DiagHBaseModule class DiagHReLU(DiagHBaseModule): diff --git a/backpack/extensions/secondorder/diag_hessian/conv2d.py b/backpack/extensions/secondorder/diag_hessian/conv2d.py index c119c710..d0ecea92 100644 --- a/backpack/extensions/secondorder/diag_hessian/conv2d.py +++ b/backpack/extensions/secondorder/diag_hessian/conv2d.py @@ -1,9 +1,9 @@ import torch import torch.nn -from backpack.utils.utils import einsum + +from backpack.core.derivatives.conv2d import Conv2DDerivatives +from backpack.extensions.secondorder.diag_hessian.diag_h_base import DiagHBaseModule from backpack.utils import conv as convUtils -from backpack.core.derivatives.conv2d import Conv2DDerivatives, Conv2DConcatDerivatives -from .diag_h_base import DiagHBaseModule class DiagHConv2d(DiagHBaseModule): @@ -13,42 +13,20 @@ def __init__(self): def bias(self, ext, module, g_inp, g_out, backproped): sqrt_h_outs = backproped["matrices"] sqrt_h_outs_signs = backproped["signs"] - h_diag = torch.zeros_like(module.bias) - for h_sqrt, sign in zip(sqrt_h_outs, sqrt_h_outs_signs): - h_sqrt_view = convUtils.separate_channels_and_pixels( - module, h_sqrt) - h_diag.add_(sign * einsum('bijc,bikc->i', - (h_sqrt_view, h_sqrt_view))) - return h_diag - - def weight(self, ext, module, g_inp, g_out, backproped): - sqrt_h_outs = backproped["matrices"] - sqrt_h_outs_signs = backproped["signs"] - X = convUtils.unfold_func(module)(module.input0) - h_diag = torch.zeros_like(module.weight) for h_sqrt, sign in zip(sqrt_h_outs, sqrt_h_outs_signs): - h_diag_curr = convUtils.extract_weight_diagonal(module, X, h_sqrt) - h_diag.add_(sign * h_diag_curr.view_as(module.weight)) + h_diag_curr = convUtils.extract_bias_diagonal(module, h_sqrt) + h_diag.add_(sign * h_diag_curr) return h_diag - -class DiagHConv2dConcat(DiagHBaseModule): - def __init__(self): - super().__init__(derivatives=Conv2DConcatDerivatives(), params=["weight"]) - def weight(self, ext, module, g_inp, g_out, backproped): sqrt_h_outs = backproped["matrices"] sqrt_h_outs_signs = backproped["signs"] X = convUtils.unfold_func(module)(module.input0) - - if module.has_bias(): - X = module.append_ones(X) - h_diag = torch.zeros_like(module.weight) for h_sqrt, sign in zip(sqrt_h_outs, sqrt_h_outs_signs): h_diag_curr = convUtils.extract_weight_diagonal(module, X, h_sqrt) - h_diag.add_(sign * 
h_diag_curr.view_as(module.weight)) + h_diag.add_(sign * h_diag_curr) return h_diag diff --git a/backpack/extensions/secondorder/diag_hessian/diag_h_base.py b/backpack/extensions/secondorder/diag_hessian/diag_h_base.py index 6b041986..db354937 100644 --- a/backpack/extensions/secondorder/diag_hessian/diag_h_base.py +++ b/backpack/extensions/secondorder/diag_hessian/diag_h_base.py @@ -1,10 +1,13 @@ -from backpack.extensions.mat_to_mat_jac_base import MatToJacMat +from numpy import prod from torch import clamp, diag_embed +from backpack.extensions.mat_to_mat_jac_base import MatToJacMat +from backpack.utils.ein import einsum + class DiagHBaseModule(MatToJacMat): - PLUS = 1. - MINUS = -1. + PLUS = 1.0 + MINUS = -1.0 def __init__(self, derivatives, params=None): super().__init__(derivatives=derivatives, params=params) @@ -15,24 +18,38 @@ def backpropagate(self, ext, module, g_inp, g_out, backproped): bp_matrices = super().backpropagate(ext, module, g_inp, g_out, bp_matrices) - for matrix, sign in self.local_curvatures(module, g_inp, g_out): + for matrix, sign in self.__local_curvatures(module, g_inp, g_out): bp_matrices.append(matrix) bp_signs.append(sign) - return { - "matrices": bp_matrices, - "signs": bp_signs - } + return {"matrices": bp_matrices, "signs": bp_signs} - def local_curvatures(self, module, g_inp, g_out): - if self.derivatives is None or self.derivatives.hessian_is_zero(): + def __local_curvatures(self, module, g_inp, g_out): + if self.derivatives.hessian_is_zero(): return [] if not self.derivatives.hessian_is_diagonal(): raise NotImplementedError - H = self.derivatives.hessian_diagonal(module, g_inp, g_out) + def positive_part(sign, H): + return clamp(sign * H, min=0) - for sign in [self.PLUS, self.MINUS]: - Hsign = clamp(sign * H, min=0, max=float('inf')).sqrt_() - yield((diag_embed(Hsign), sign)) + def diag_embed_multi_dim(H): + """Convert [N, C_in, H_in, ...] to [N, C_in * H_in * ...,], + embed into [N, C_in * H_in * ..., C_in * H_in = V], convert back + to [V, N, C_in, H_in, ..., V].""" + feature_shapes = H.shape[1:] + V, N = prod(feature_shapes), H.shape[0] + H_diag = diag_embed(H.view(N, V)) + # [V, N, C_in, H_in, ...] 
+ shape = (V, N, *feature_shapes) + return einsum("nic->cni", H_diag).view(shape) + + def decompose_into_positive_and_negative_sqrt(H): + return [ + [diag_embed_multi_dim(positive_part(sign, H).sqrt_()), sign] + for sign in [self.PLUS, self.MINUS] + ] + + H = self.derivatives.hessian_diagonal(module, g_inp, g_out) + return decompose_into_positive_and_negative_sqrt(H) diff --git a/backpack/extensions/secondorder/diag_hessian/dropout.py b/backpack/extensions/secondorder/diag_hessian/dropout.py index 7cd9c0cf..eacab7dd 100644 --- a/backpack/extensions/secondorder/diag_hessian/dropout.py +++ b/backpack/extensions/secondorder/diag_hessian/dropout.py @@ -1,5 +1,5 @@ from backpack.core.derivatives.dropout import DropoutDerivatives -from .diag_h_base import DiagHBaseModule +from backpack.extensions.secondorder.diag_hessian.diag_h_base import DiagHBaseModule class DiagHDropout(DiagHBaseModule): diff --git a/backpack/extensions/secondorder/diag_hessian/flatten.py b/backpack/extensions/secondorder/diag_hessian/flatten.py index 60128369..d6d28b7c 100644 --- a/backpack/extensions/secondorder/diag_hessian/flatten.py +++ b/backpack/extensions/secondorder/diag_hessian/flatten.py @@ -1,9 +1,13 @@ -from .diag_h_base import DiagHBaseModule +from backpack.core.derivatives.flatten import FlattenDerivatives +from backpack.extensions.secondorder.diag_hessian.diag_h_base import DiagHBaseModule class DiagHFlatten(DiagHBaseModule): def __init__(self): - super().__init__(derivatives=None) + super().__init__(derivatives=FlattenDerivatives()) def backpropagate(self, ext, module, grad_inp, grad_out, backproped): - return backproped \ No newline at end of file + if self.derivatives.is_no_op(module): + return backproped + else: + return super().backpropagate(ext, module, grad_inp, grad_out, backproped) diff --git a/backpack/extensions/secondorder/diag_hessian/linear.py b/backpack/extensions/secondorder/diag_hessian/linear.py index 2baa9fb8..9b6a823e 100644 --- a/backpack/extensions/secondorder/diag_hessian/linear.py +++ b/backpack/extensions/secondorder/diag_hessian/linear.py @@ -1,49 +1,30 @@ import torch -import torch.nn -from backpack.utils.utils import einsum -from backpack.core.derivatives.linear import LinearDerivatives, LinearConcatDerivatives -from .diag_h_base import DiagHBaseModule + +import backpack.utils.linear as LinUtils +from backpack.core.derivatives.linear import LinearDerivatives +from backpack.extensions.secondorder.diag_hessian.diag_h_base import DiagHBaseModule class DiagHLinear(DiagHBaseModule): def __init__(self): super().__init__(derivatives=LinearDerivatives(), params=["bias", "weight"]) - # TODO: Reuse code in ..diaggn.linear to extract the diagonal def bias(self, ext, module, g_inp, g_out, backproped): sqrt_h_outs = backproped["matrices"] sqrt_h_outs_signs = backproped["signs"] - h_diag = torch.zeros_like(module.bias) - for h_sqrt, sign in zip(sqrt_h_outs, sqrt_h_outs_signs): - h_diag.add_(sign * einsum('bic->i', (h_sqrt**2, ))) - return h_diag - - # TODO: Reuse code in ..diaggn.linear to extract the diagonal - def weight(self, ext, module, g_inp, g_out, backproped): - sqrt_h_outs = backproped["matrices"] - sqrt_h_outs_signs = backproped["signs"] - h_diag = torch.zeros_like(module.weight) for h_sqrt, sign in zip(sqrt_h_outs, sqrt_h_outs_signs): - h_diag.add_(sign * einsum('bic,bj->ij', - (h_sqrt**2, module.input0**2))) + h_diag_curr = LinUtils.extract_bias_diagonal(module, h_sqrt) + h_diag.add_(sign * h_diag_curr) return h_diag - -class DiagHLinearConcat(DiagHBaseModule): - def __init__(self): 
- super().__init__(derivatives=LinearConcatDerivatives(), params=["weight"]) - - # TODO: Reuse code in ..diaggn.linear to extract the diagonal def weight(self, ext, module, g_inp, g_out, backproped): sqrt_h_outs = backproped["matrices"] sqrt_h_outs_signs = backproped["signs"] h_diag = torch.zeros_like(module.weight) - input = module.homogeneous_input() - for h_sqrt, sign in zip(sqrt_h_outs, sqrt_h_outs_signs): - h_diag.add_(sign * einsum('bic,bj->ij', (h_sqrt**2, input**2))) + h_diag_curr = LinUtils.extract_weight_diagonal(module, h_sqrt) + h_diag.add_(sign * h_diag_curr) return h_diag - diff --git a/backpack/extensions/secondorder/diag_hessian/losses.py b/backpack/extensions/secondorder/diag_hessian/losses.py index a15ec373..097730dd 100644 --- a/backpack/extensions/secondorder/diag_hessian/losses.py +++ b/backpack/extensions/secondorder/diag_hessian/losses.py @@ -1,15 +1,12 @@ -from backpack.core.derivatives.mseloss import MSELossDerivatives from backpack.core.derivatives.crossentropyloss import CrossEntropyLossDerivatives -from .diag_h_base import DiagHBaseModule +from backpack.core.derivatives.mseloss import MSELossDerivatives +from backpack.extensions.secondorder.diag_hessian.diag_h_base import DiagHBaseModule class DiagHLoss(DiagHBaseModule): def backpropagate(self, ext, module, g_inp, g_out, backproped): sqrt_H = self.derivatives.sqrt_hessian(module, g_inp, g_out) - return { - "matrices": [sqrt_H], - "signs": [self.PLUS] - } + return {"matrices": [sqrt_H], "signs": [self.PLUS]} class DiagHMSELoss(DiagHLoss): diff --git a/backpack/extensions/secondorder/diag_hessian/padding.py b/backpack/extensions/secondorder/diag_hessian/padding.py index 6f7fa222..62db7383 100644 --- a/backpack/extensions/secondorder/diag_hessian/padding.py +++ b/backpack/extensions/secondorder/diag_hessian/padding.py @@ -1,5 +1,5 @@ from backpack.core.derivatives.zeropad2d import ZeroPad2dDerivatives -from .diag_h_base import DiagHBaseModule +from backpack.extensions.secondorder.diag_hessian.diag_h_base import DiagHBaseModule class DiagHZeroPad2d(DiagHBaseModule): diff --git a/backpack/extensions/secondorder/diag_hessian/pooling.py b/backpack/extensions/secondorder/diag_hessian/pooling.py index 4cfb1bec..f9beea60 100644 --- a/backpack/extensions/secondorder/diag_hessian/pooling.py +++ b/backpack/extensions/secondorder/diag_hessian/pooling.py @@ -1,6 +1,6 @@ from backpack.core.derivatives.avgpool2d import AvgPool2DDerivatives from backpack.core.derivatives.maxpool2d import MaxPool2DDerivatives -from .diag_h_base import DiagHBaseModule +from backpack.extensions.secondorder.diag_hessian.diag_h_base import DiagHBaseModule class DiagHAvgPool2d(DiagHBaseModule): @@ -11,4 +11,3 @@ def __init__(self): class DiagHMaxPool2d(DiagHBaseModule): def __init__(self): super().__init__(derivatives=MaxPool2DDerivatives()) - diff --git a/backpack/extensions/secondorder/hbp/__init__.py b/backpack/extensions/secondorder/hbp/__init__.py index 33652335..5760974f 100644 --- a/backpack/extensions/secondorder/hbp/__init__.py +++ b/backpack/extensions/secondorder/hbp/__init__.py @@ -1,16 +1,38 @@ +from torch.nn import ( + AvgPool2d, + Conv2d, + CrossEntropyLoss, + Dropout, + Flatten, + Linear, + MaxPool2d, + MSELoss, + ReLU, + Sigmoid, + Tanh, + ZeroPad2d, +) + +from backpack.extensions.backprop_extension import BackpropExtension from backpack.extensions.curvature import Curvature from backpack.extensions.secondorder.hbp.hbp_options import ( - LossHessianStrategy, BackpropStrategy, ExpectationApproximation + BackpropStrategy, + 
ExpectationApproximation, + LossHessianStrategy, ) -from backpack.extensions.backprop_extension import BackpropExtension -from backpack.core.layers import Conv2dConcat, LinearConcat, Flatten -from torch.nn import Linear, Conv2d, Dropout, MaxPool2d, Tanh, Sigmoid, ReLU, CrossEntropyLoss, MSELoss, AvgPool2d, ZeroPad2d -from . import pooling, conv2d, linear, activations, losses, padding, dropout, flatten +from . import activations, conv2d, dropout, flatten, linear, losses, padding, pooling class HBP(BackpropExtension): - def __init__(self, curv_type, loss_hessian_strategy, backprop_strategy, ea_strategy, savefield="hbp"): + def __init__( + self, + curv_type, + loss_hessian_strategy, + backprop_strategy, + ea_strategy, + savefield="hbp", + ): self.curv_type = curv_type self.loss_hessian_strategy = loss_hessian_strategy self.backprop_strategy = backprop_strategy @@ -23,18 +45,16 @@ def __init__(self, curv_type, loss_hessian_strategy, backprop_strategy, ea_strat MSELoss: losses.HBPMSELoss(), CrossEntropyLoss: losses.HBPCrossEntropyLoss(), Linear: linear.HBPLinear(), - LinearConcat: linear.HBPLinearConcat(), MaxPool2d: pooling.HBPMaxpool2d(), AvgPool2d: pooling.HBPAvgPool2d(), ZeroPad2d: padding.HBPZeroPad2d(), Conv2d: conv2d.HBPConv2d(), - Conv2dConcat: conv2d.HBPConv2dConcat(), Dropout: dropout.HBPDropout(), Flatten: flatten.HBPFlatten(), ReLU: activations.HBPReLU(), Sigmoid: activations.HBPSigmoid(), Tanh: activations.HBPTanh(), - } + }, ) def get_curv_type(self): @@ -83,15 +103,20 @@ class KFAC(HBP): `_ by Roger Grosse and James Martens, 2016 """ - def __init__(self): + + def __init__(self, mc_samples=1): + self._mc_samples = mc_samples super().__init__( curv_type=Curvature.GGN, loss_hessian_strategy=LossHessianStrategy.SAMPLING, backprop_strategy=BackpropStrategy.SQRT, ea_strategy=ExpectationApproximation.BOTEV_MARTENS, - savefield="kfac" + savefield="kfac", ) + def get_num_mc_samples(self): + return self._mc_samples + class KFRA(HBP): """ @@ -127,10 +152,11 @@ class KFRA(HBP): `_ by Roger Grosse and James Martens, 2016 """ + def __init__(self): super().__init__( curv_type=Curvature.GGN, - loss_hessian_strategy=LossHessianStrategy.AVERAGE, + loss_hessian_strategy=LossHessianStrategy.SUM, backprop_strategy=BackpropStrategy.BATCH_AVERAGE, ea_strategy=ExpectationApproximation.BOTEV_MARTENS, savefield="kfra", diff --git a/backpack/extensions/secondorder/hbp/activations.py b/backpack/extensions/secondorder/hbp/activations.py index bb527ccf..d6d1dd10 100644 --- a/backpack/extensions/secondorder/hbp/activations.py +++ b/backpack/extensions/secondorder/hbp/activations.py @@ -1,7 +1,7 @@ from backpack.core.derivatives.relu import ReLUDerivatives -from backpack.core.derivatives.tanh import TanhDerivatives from backpack.core.derivatives.sigmoid import SigmoidDerivatives -from .hbpbase import HBPBaseModule +from backpack.core.derivatives.tanh import TanhDerivatives +from backpack.extensions.secondorder.hbp.hbpbase import HBPBaseModule class HBPReLU(HBPBaseModule): diff --git a/backpack/extensions/secondorder/hbp/conv2d.py b/backpack/extensions/secondorder/hbp/conv2d.py index 43b19a93..62cb5cc1 100644 --- a/backpack/extensions/secondorder/hbp/conv2d.py +++ b/backpack/extensions/secondorder/hbp/conv2d.py @@ -1,18 +1,16 @@ -import warnings - -from backpack.core.derivatives.conv2d import (Conv2DConcatDerivatives, - Conv2DDerivatives) +from backpack.core.derivatives.conv2d import Conv2DDerivatives +from backpack.extensions.secondorder.hbp.hbp_options import ( + BackpropStrategy, + ExpectationApproximation, 
+) +from backpack.extensions.secondorder.hbp.hbpbase import HBPBaseModule from backpack.utils import conv as convUtils -from backpack.utils.utils import einsum - -from .hbp_options import BackpropStrategy, ExpectationApproximation -from .hbpbase import HBPBaseModule +from backpack.utils.ein import einsum class HBPConv2d(HBPBaseModule): def __init__(self): - super().__init__(derivatives=Conv2DDerivatives(), - params=["weight", "bias"]) + super().__init__(derivatives=Conv2DDerivatives(), params=["weight", "bias"]) def weight(self, ext, module, g_inp, g_out, backproped): bp_strategy = ext.get_backprop_strategy() @@ -43,13 +41,14 @@ def _factors_from_input(self, ext, module): if ExpectationApproximation.should_average_param_jac(ea_strategy): raise NotImplementedError("Undefined") else: - yield einsum('bik,bjk->ij', (X, X)) / batch + yield einsum("bik,bjk->ij", (X, X)) / batch def _factor_from_sqrt(self, module, backproped): sqrt_ggn = backproped + sqrt_ggn = convUtils.separate_channels_and_pixels(module, sqrt_ggn) - sqrt_ggn = einsum('bijc->bic', (sqrt_ggn, )) - return einsum('bic,blc->il', (sqrt_ggn, sqrt_ggn)) + sqrt_ggn = einsum("cbij->cbi", (sqrt_ggn,)) + return einsum("cbi,cbl->il", (sqrt_ggn, sqrt_ggn)) def bias(self, ext, module, g_inp, g_out, backproped): bp_strategy = ext.get_backprop_strategy() @@ -70,46 +69,8 @@ def _factor_from_batch_average(self, module, backproped): _, out_c, out_x, out_y = module.output.size() out_pixels = out_x * out_y # sum over spatial coordinates - result = backproped.view(out_c, out_pixels, out_c, - out_pixels).sum([1, 3]) + result = backproped.view(out_c, out_pixels, out_c, out_pixels).sum([1, 3]) return result.contiguous() -class HBPConv2dConcat(HBPBaseModule): - def __init__(self): - super().__init__(derivatives=Conv2DConcatDerivatives(), - params=["weight"]) - - def weight(self, ext, module, g_inp, g_out, backproped): - bp_strategy = ext.get_backprop_strategy() - - if BackpropStrategy.is_batch_average(bp_strategy): - raise NotImplementedError - elif BackpropStrategy.is_sqrt(bp_strategy): - return self._weight_for_sqrt(ext, module, backproped) - - def _weight_for_sqrt(self, ext, module, backproped): - kron_factors = [self._factor_from_sqrt(module, backproped)] - kron_factors += self._factors_from_input(ext, module) - - return kron_factors - - def _factors_from_input(self, ext, module): - X = module.homogeneous_unfolded_input() - batch = X.size(0) - - ea_strategy = ext.get_ea_strategy() - - if ExpectationApproximation.should_average_param_jac(ea_strategy): - raise NotImplementedError - else: - yield einsum('bik,bjk->ij', (X, X)) / batch - - def _factor_from_sqrt(self, module, backproped): - sqrt_ggn = backproped - sqrt_ggn = convUtils.separate_channels_and_pixels(module, sqrt_ggn) - sqrt_ggn = einsum('bijc->bic', (sqrt_ggn, )) - return einsum('bic,blc->il', (sqrt_ggn, sqrt_ggn)) - - -EXTENSIONS = [HBPConv2d(), HBPConv2dConcat()] +EXTENSIONS = [HBPConv2d()] diff --git a/backpack/extensions/secondorder/hbp/dropout.py b/backpack/extensions/secondorder/hbp/dropout.py index cb1986a3..63131434 100644 --- a/backpack/extensions/secondorder/hbp/dropout.py +++ b/backpack/extensions/secondorder/hbp/dropout.py @@ -1,5 +1,5 @@ from backpack.core.derivatives.dropout import DropoutDerivatives -from .hbpbase import HBPBaseModule +from backpack.extensions.secondorder.hbp.hbpbase import HBPBaseModule class HBPDropout(HBPBaseModule): diff --git a/backpack/extensions/secondorder/hbp/flatten.py b/backpack/extensions/secondorder/hbp/flatten.py index eefcb6ba..990d0b02 100644 
--- a/backpack/extensions/secondorder/hbp/flatten.py +++ b/backpack/extensions/secondorder/hbp/flatten.py @@ -1,9 +1,13 @@ -from .hbpbase import HBPBaseModule +from backpack.core.derivatives.flatten import FlattenDerivatives +from backpack.extensions.secondorder.hbp.hbpbase import HBPBaseModule class HBPFlatten(HBPBaseModule): def __init__(self): - super().__init__(derivatives=None) + super().__init__(derivatives=FlattenDerivatives()) - def backpropagate(self, ext, module, g_inp, g_out, backproped): - return backproped + def backpropagate(self, ext, module, grad_inp, grad_out, backproped): + if self.derivatives.is_no_op(module): + return backproped + else: + return super().backpropagate(ext, module, grad_inp, grad_out, backproped) diff --git a/backpack/extensions/secondorder/hbp/hbp_options.py b/backpack/extensions/secondorder/hbp/hbp_options.py index 843a8db7..9378b747 100644 --- a/backpack/extensions/secondorder/hbp/hbp_options.py +++ b/backpack/extensions/secondorder/hbp/hbp_options.py @@ -1,24 +1,24 @@ -class LossHessianStrategy(): +class LossHessianStrategy: EXACT = "exact" SAMPLING = "sampling" - AVERAGE = "average" + SUM = "sum" CHOICES = [ EXACT, SAMPLING, - AVERAGE, + SUM, ] @classmethod def check_exists(cls, strategy): - if not strategy in cls.CHOICES: + if strategy not in cls.CHOICES: raise AttributeError( - "Unknown loss Hessian strategy: {}. ".format(strategy) + - "Expecting one of {}".format(cls.CHOICES) + "Unknown loss Hessian strategy: {}. ".format(strategy) + + "Expecting one of {}".format(cls.CHOICES) ) -class BackpropStrategy(): +class BackpropStrategy: SQRT = "sqrt" BATCH_AVERAGE = "average" @@ -39,14 +39,14 @@ def is_sqrt(cls, strategy): @classmethod def check_exists(cls, strategy): - if not strategy in cls.CHOICES: + if strategy not in cls.CHOICES: raise AttributeError( - "Unknown backpropagation strategy: {}. ".format(strategy) + - "Expect {}".format(cls.CHOICES) + "Unknown backpropagation strategy: {}. ".format(strategy) + + "Expect {}".format(cls.CHOICES) ) -class ExpectationApproximation(): +class ExpectationApproximation: BOTEV_MARTENS = "E[J^T E(H) J]" CHEN = "E(J^T) E(H) E(J)" @@ -62,8 +62,8 @@ def should_average_param_jac(cls, strategy): @classmethod def check_exists(cls, strategy): - if not strategy in cls.CHOICES: + if strategy not in cls.CHOICES: raise AttributeError( - "Unknown EA strategy: {}. ".format(strategy) + - "Expect {}".format(cls.CHOICES) + "Unknown EA strategy: {}. 
".format(strategy) + + "Expect {}".format(cls.CHOICES) ) diff --git a/backpack/extensions/secondorder/hbp/hbpbase.py b/backpack/extensions/secondorder/hbp/hbpbase.py index 4718c48e..6bf2647a 100644 --- a/backpack/extensions/secondorder/hbp/hbpbase.py +++ b/backpack/extensions/secondorder/hbp/hbpbase.py @@ -1,6 +1,6 @@ -from backpack.extensions.module_extension import ModuleExtension -from .hbp_options import BackpropStrategy from backpack.extensions.curvature import Curvature +from backpack.extensions.module_extension import ModuleExtension +from backpack.extensions.secondorder.hbp.hbp_options import BackpropStrategy class HBPBaseModule(ModuleExtension): @@ -17,19 +17,13 @@ def backpropagate(self, ext, module, g_inp, g_out, backproped): ) elif BackpropStrategy.is_sqrt(bp_strategy): - return self.backpropagate_sqrt( - ext, module, g_inp, g_out, backproped - ) + return self.backpropagate_sqrt(ext, module, g_inp, g_out, backproped) def backpropagate_sqrt(self, ext, module, g_inp, g_out, H): - return self.derivatives.jac_t_mat_prod( - module, g_inp, g_out, H - ) + return self.derivatives.jac_t_mat_prod(module, g_inp, g_out, H) def backpropagate_batch_average(self, ext, module, g_inp, g_out, H): - ggn = self.derivatives.ea_jac_t_mat_jac_prod( - module, g_inp, g_out, H - ) + ggn = self.derivatives.ea_jac_t_mat_jac_prod(module, g_inp, g_out, H) residual = self.second_order_module_effects(module, g_inp, g_out) residual_mod = Curvature.modify_residual(residual, ext.get_curv_type()) @@ -49,9 +43,7 @@ def second_order_module_effects(self, module, g_inp, g_out): ) else: - return self.derivatives.hessian_diagonal( - module, g_inp, g_out - ).sum(0) + return self.derivatives.hessian_diagonal(module, g_inp, g_out).sum(0) @staticmethod def add_diag_to_mat(diag, mat): diff --git a/backpack/extensions/secondorder/hbp/linear.py b/backpack/extensions/secondorder/hbp/linear.py index 249681eb..385111d7 100644 --- a/backpack/extensions/secondorder/hbp/linear.py +++ b/backpack/extensions/secondorder/hbp/linear.py @@ -1,15 +1,15 @@ -from backpack.core.derivatives.linear import LinearDerivatives, LinearConcatDerivatives -from backpack.utils.utils import einsum -from .hbpbase import HBPBaseModule -from .hbp_options import BackpropStrategy, ExpectationApproximation +from backpack.core.derivatives.linear import LinearDerivatives +from backpack.extensions.secondorder.hbp.hbp_options import ( + BackpropStrategy, + ExpectationApproximation, +) +from backpack.extensions.secondorder.hbp.hbpbase import HBPBaseModule +from backpack.utils.ein import einsum class HBPLinear(HBPBaseModule): def __init__(self): - super().__init__( - derivatives=LinearDerivatives(), - params=["weight", "bias"] - ) + super().__init__(derivatives=LinearDerivatives(), params=["weight", "bias"]) def weight(self, ext, module, g_inp, g_out, backproped): bp_strategy = ext.get_backprop_strategy() @@ -37,22 +37,18 @@ def _factors_from_input(self, ext, module): mean_input = self.__mean_input(module).unsqueeze(-1) return [mean_input, mean_input.transpose()] else: - yield self.__mean_input_outer(module) + return [self.__mean_input_outer(module)] def _factor_from_sqrt(self, backproped): - return [einsum('bic,bjc->ij', (backproped, backproped))] + return [einsum("vni,vnj->ij", (backproped, backproped))] def bias(self, ext, module, g_inp, g_out, backproped): bp_strategy = ext.get_backprop_strategy() if BackpropStrategy.is_batch_average(bp_strategy): - return self._bias_for_batch_average( - backproped - ) + return self._bias_for_batch_average(backproped) elif 
BackpropStrategy.is_sqrt(bp_strategy): - return self._factor_from_sqrt( - backproped - ) + return self._factor_from_sqrt(backproped) def _bias_for_batch_average(self, backproped): return [backproped] @@ -63,53 +59,4 @@ def __mean_input(self, module): def __mean_input_outer(self, module): N, flat_input = self.derivatives.batch_flat(module.input0) - return einsum('bi,bj->ij', (flat_input, flat_input)) / N - - -class HBPLinearConcat(HBPBaseModule): - def __init__(self): - super().__init__( - derivatives=LinearConcatDerivatives(), - params=["weight"] - ) - - def weight(self, ext, module, g_inp, g_out, backproped): - bp_strategy = ext.get_backprop_strategy() - - if BackpropStrategy.is_batch_average(bp_strategy): - return self._weight_for_batch_average(ext, module, backproped) - elif BackpropStrategy.is_sqrt(bp_strategy): - return self._weight_for_sqrt(ext, module, backproped) - - def _weight_for_batch_average(self, ext, module, backproped): - kron_factors = self._bias_for_batch_average(backproped) - kron_factors += self._factors_from_input(ext, module) - return kron_factors - - def _weight_for_sqrt(self, ext, module, backproped): - kron_factors = self._factor_from_sqrt(backproped) - kron_factors += self._factors_from_input(ext, module) - return kron_factors - - def _factors_from_input(self, ext, module): - ea_strategy = ext.get_ea_strategy() - - if ExpectationApproximation.should_average_param_jac(ea_strategy): - mean_input = self.__mean_input(module).unsqueeze(-1) - return [mean_input, mean_input.transpose()] - else: - return [self.__mean_input_outer(module)] - - def _factor_from_sqrt(self, backproped): - return [einsum('bic,bjc->ij', (backproped, backproped))] - - def _bias_for_batch_average(self, backproped): - return [backproped] - - def __mean_input(self, module): - _, flat_input = self.derivatives.batch_flat(module.homogeneous_input()) - return flat_input.mean(0) - - def __mean_input_outer(self, module): - N, flat_input = self.derivatives.batch_flat(module.homogeneous_input()) - return einsum('bi,bj->ij', (flat_input, flat_input)) / N + return einsum("ni,nj->ij", (flat_input, flat_input)) / N diff --git a/backpack/extensions/secondorder/hbp/losses.py b/backpack/extensions/secondorder/hbp/losses.py index 2c92eb2c..9dad53b6 100644 --- a/backpack/extensions/secondorder/hbp/losses.py +++ b/backpack/extensions/secondorder/hbp/losses.py @@ -1,32 +1,37 @@ -from backpack.core.derivatives.mseloss import MSELossDerivatives +from functools import partial + from backpack.core.derivatives.crossentropyloss import CrossEntropyLossDerivatives -from .hbp_options import LossHessianStrategy +from backpack.core.derivatives.mseloss import MSELossDerivatives from backpack.extensions.curvature import Curvature -from .hbpbase import HBPBaseModule +from backpack.extensions.secondorder.hbp.hbp_options import LossHessianStrategy +from backpack.extensions.secondorder.hbp.hbpbase import HBPBaseModule class HBPLoss(HBPBaseModule): - def __init__(self, derivatives, params=None): - super().__init__(derivatives=derivatives, params=params) - - self.LOSS_HESSIAN_GETTERS = { - LossHessianStrategy.EXACT: self.derivatives.sqrt_hessian, - LossHessianStrategy.SAMPLING: self.derivatives.sqrt_hessian_sampled, - LossHessianStrategy.AVERAGE: self.derivatives.sum_hessian, - } - def backpropagate(self, ext, module, g_inp, g_out, backproped): Curvature.check_loss_hessian( - self.derivatives.hessian_is_psd(), - curv_type=ext.get_curv_type() + self.derivatives.hessian_is_psd(), curv_type=ext.get_curv_type() ) - hessian_strategy = 
ext.get_loss_hessian_strategy() - H_func = self.LOSS_HESSIAN_GETTERS[hessian_strategy] + H_func = self.make_loss_hessian_func(ext) H_loss = H_func(module, g_inp, g_out) return H_loss + def make_loss_hessian_func(self, ext): + """Get function that produces the backpropagated quantity.""" + hessian_strategy = ext.get_loss_hessian_strategy() + + if hessian_strategy == LossHessianStrategy.EXACT: + return self.derivatives.sqrt_hessian + elif hessian_strategy == LossHessianStrategy.SAMPLING: + mc_samples = ext.get_num_mc_samples() + return partial(self.derivatives.sqrt_hessian_sampled, mc_samples=mc_samples) + elif hessian_strategy == LossHessianStrategy.SUM: + return self.derivatives.sum_hessian + else: + raise ValueError("Unknown Hessian strategy: {}".format(hessian_strategy)) + class HBPMSELoss(HBPLoss): def __init__(self): diff --git a/backpack/extensions/secondorder/hbp/padding.py b/backpack/extensions/secondorder/hbp/padding.py index 53d8a8c0..0af7a87a 100644 --- a/backpack/extensions/secondorder/hbp/padding.py +++ b/backpack/extensions/secondorder/hbp/padding.py @@ -1,5 +1,5 @@ from backpack.core.derivatives.zeropad2d import ZeroPad2dDerivatives -from .hbpbase import HBPBaseModule +from backpack.extensions.secondorder.hbp.hbpbase import HBPBaseModule class HBPZeroPad2d(HBPBaseModule): diff --git a/backpack/extensions/secondorder/hbp/pooling.py b/backpack/extensions/secondorder/hbp/pooling.py index 085a6169..bda85833 100644 --- a/backpack/extensions/secondorder/hbp/pooling.py +++ b/backpack/extensions/secondorder/hbp/pooling.py @@ -1,6 +1,6 @@ from backpack.core.derivatives.avgpool2d import AvgPool2DDerivatives from backpack.core.derivatives.maxpool2d import MaxPool2DDerivatives -from .hbpbase import HBPBaseModule +from backpack.extensions.secondorder.hbp.hbpbase import HBPBaseModule class HBPAvgPool2d(HBPBaseModule): diff --git a/backpack/extensions/secondorder/utils.py b/backpack/extensions/secondorder/utils.py deleted file mode 100644 index dc967e22..00000000 --- a/backpack/extensions/secondorder/utils.py +++ /dev/null @@ -1,65 +0,0 @@ -from backpack.utils.utils import einsum - - -def matrix_from_kron_facs(factors): - assert all_tensors_of_order(order=2, tensors=factors) - mat = None - for factor in factors: - if mat is None: - mat = factor - else: - new_shape = (mat.shape[0] * factor.shape[0], - mat.shape[1] * factor.shape[1]) - mat = einsum('ij,kl->ikjl', - (mat, factor)).contiguous().view(new_shape) - return mat - - -def vp_from_kron_facs(factors): - """Return function v ↦ (A ⊗ B ⊗ ...)v for `factors = [A, B, ...]` """ - assert all_tensors_of_order(order=2, tensors=factors) - - shapes = [list(f.size()) for f in factors] - _, col_dims = zip(*shapes) - - num_factors = len(shapes) - equation = vp_einsum_equation(num_factors) - - def vp(v): - assert len(v.shape) == 1 - v_reshaped = v.view(col_dims) - return einsum(equation, v_reshaped, *factors).view(-1) - - return vp - - -def multiply_vec_with_kron_facs(factors, v): - """Return (A ⊗ B ⊗ ...) v for `factors = [A, B, ...]` - - All Kronecker factors have to be of order-2-tensors. 
- """ - vp = vp_from_kron_facs(factors) - return vp(v) - - -def vp_einsum_equation(num_factors): - letters = get_letters() - in_str, v_str, out_str = "", "", "" - - for _ in range(num_factors): - row_idx, col_idx = next(letters), next(letters) - - in_str += "," + row_idx + col_idx - v_str += col_idx - out_str += row_idx - - return "{}{}->{}".format(v_str, in_str, out_str) - - -def all_tensors_of_order(order, tensors): - return all([len(t.shape) == order for t in tensors]) - - -def get_letters(max_letters=26): - for i in range(max_letters): - yield chr(ord('a') + i) diff --git a/backpack/hessianfree/ggnvp.py b/backpack/hessianfree/ggnvp.py index b4ee3b7a..92c08363 100644 --- a/backpack/hessianfree/ggnvp.py +++ b/backpack/hessianfree/ggnvp.py @@ -1,6 +1,6 @@ +from .hvp import hessian_vector_product from .lop import L_op from .rop import R_op -from .hvp import hessian_vector_product def ggn_vector_product(loss, output, model, v): @@ -33,9 +33,7 @@ def ggn_vector_product(loss, output, model, v): v: [torch.Tensor] List of tensors matching the sizes of model.parameters() """ - return ggn_vector_product_from_plist( - loss, output, list(model.parameters()), v - ) + return ggn_vector_product_from_plist(loss, output, list(model.parameters()), v) def ggn_vector_product_from_plist(loss, output, plist, v): diff --git a/backpack/hessianfree/hvp.py b/backpack/hessianfree/hvp.py index 874e0df4..548013ad 100644 --- a/backpack/hessianfree/hvp.py +++ b/backpack/hessianfree/hvp.py @@ -3,7 +3,7 @@ from .rop import R_op -def hessian_vector_product(f, params, v, detach=True): +def hessian_vector_product(f, params, v, grad_params=None, detach=True): """ Multiplies the vector `v` with the Hessian, `v = H @ v` @@ -29,14 +29,24 @@ def hessian_vector_product(f, params, v, detach=True): params: torch.Tensor or [torch.Tensor] v: torch.Tensor or [torch.Tensor] Shapes must match `params` + grad_params: torch.Tensor or [torch.Tensor], optional + Gradient of `f` w.r.t. `params`. If the gradients have already + been computed elsewhere, the first of two backpropagations can + be saved. `grad_params` must have been computed with + `create_graph = True` to not destroy the computation graph for + the second backward pass. 
detach: Bool, optional Whether to detach the output from the computation graph (default: True) """ - df_dx = torch.autograd.grad(f, params, create_graph=True, retain_graph=True) + if grad_params is not None: + df_dx = tuple(grad_params) + else: + df_dx = torch.autograd.grad(f, params, create_graph=True, retain_graph=True) + Hv = R_op(df_dx, params, v) if detach: - return tuple([j.detach() for j in Hv]) + return tuple(j.detach() for j in Hv) else: return Hv diff --git a/backpack/hessianfree/lop.py b/backpack/hessianfree/lop.py index ea653272..60c0e044 100644 --- a/backpack/hessianfree/lop.py +++ b/backpack/hessianfree/lop.py @@ -12,10 +12,10 @@ def L_op(ys, xs, ws, retain_graph=True, detach=True): grad_outputs=ws, create_graph=True, retain_graph=retain_graph, - allow_unused=True + allow_unused=True, ) if detach: - return tuple([j.detach() for j in vJ]) + return tuple(j.detach() for j in vJ) else: return vJ diff --git a/backpack/hessianfree/rop.py b/backpack/hessianfree/rop.py index 9642a999..007b9e7e 100644 --- a/backpack/hessianfree/rop.py +++ b/backpack/hessianfree/rop.py @@ -17,20 +17,15 @@ def R_op(ys, xs, vs, retain_graph=True, detach=True): grad_outputs=ws, create_graph=True, retain_graph=retain_graph, - allow_unused=True + allow_unused=True, ) re = torch.autograd.grad( - gs, - ws, - grad_outputs=vs, - create_graph=True, - retain_graph=True, - allow_unused=True + gs, ws, grad_outputs=vs, create_graph=True, retain_graph=True, allow_unused=True ) if detach: - return tuple([j.detach() for j in re]) + return tuple(j.detach() for j in re) else: return re diff --git a/backpack/hessianfree/utils.py b/backpack/hessianfree/utils.py deleted file mode 100644 index 34020bda..00000000 --- a/backpack/hessianfree/utils.py +++ /dev/null @@ -1,66 +0,0 @@ -import torch - - -def _check_param_device(param, old_param_device): - r"""This helper function is to check if the parameters are located - in the same device. Currently, the conversion between model parameters - and single vector form is not supported for multiple allocations, - e.g. parameters in different GPUs, or mixture of CPU/GPU. - - Arguments: - param ([Tensor]): a Tensor of a parameter of a model - old_param_device (int): the device where the first parameter of a - model is allocated. - - Returns: - old_param_device (int): report device for the first time - """ - - # Meet the first parameter - if old_param_device is None: - old_param_device = param.get_device() if param.is_cuda else -1 - else: - warn = False - if param.is_cuda: # Check if in same GPU - warn = (param.get_device() != old_param_device) - else: # Check if in CPU - warn = (old_param_device != -1) - if warn: - raise TypeError('Found two parameters on different devices, ' - 'this is currently not supported.') - return old_param_device - - -def vector_to_parameter_list(vec, parameters): - """ - Convert the vector `vec` to a parameter-list format matching `parameters`. - - Parameters: - ----------- - vec: Tensor - a single vector represents the parameters of a model - parameters: (Iterable[Tensor]) - an iterator of Tensors that are the parameters of a model. 
- """ - # Ensure vec of type Tensor - if not isinstance(vec, torch.Tensor): - raise TypeError('expected torch.Tensor, but got: {}'.format( - torch.typename(vec))) - # Flag for the device where the parameter is located - param_device = None - params_new = [] - # Pointer for slicing the vector for each parameter - pointer = 0 - for param in parameters: - # Ensure the parameters are located in the same device - param_device = _check_param_device(param, param_device) - - # The length of the parameter - num_param = param.numel() - # Slice the vector, reshape it, and replace the old data of the parameter - param_new = vec[pointer:pointer + num_param].view_as(param).data - params_new.append(param_new) - # Increment the pointer - pointer += num_param - - return list(params_new) diff --git a/backpack/utils/__init__.py b/backpack/utils/__init__.py index dd37ec74..e69de29b 100644 --- a/backpack/utils/__init__.py +++ b/backpack/utils/__init__.py @@ -1,9 +0,0 @@ -from torch.nn import Unfold - - -def unfold_func(module): - return Unfold( - kernel_size=module.kernel_size, - dilation=module.dilation, - padding=module.padding, - stride=module.stride) diff --git a/backpack/utils/conv.py b/backpack/utils/conv.py index 34d8b41b..c45e815c 100644 --- a/backpack/utils/conv.py +++ b/backpack/utils/conv.py @@ -1,57 +1,28 @@ from torch.nn import Unfold -from backpack.utils.utils import einsum + +from backpack.utils.ein import eingroup, einsum def unfold_func(module): - return Unfold(kernel_size=module.kernel_size, - dilation=module.dilation, - padding=module.padding, - stride=module.stride) + return Unfold( + kernel_size=module.kernel_size, + dilation=module.dilation, + padding=module.padding, + stride=module.stride, + ) + def get_weight_gradient_factors(input, grad_out, module): - batch = input.size(0) + # shape [N, C_in * K_x * K_y, H_out * W_out] X = unfold_func(module)(input) - dE_dY = grad_out.contiguous().view(batch, module.out_channels, -1) + dE_dY = eingroup("n,c,h,w->n,c,hw", grad_out) return X, dE_dY def separate_channels_and_pixels(module, tensor): - """Reshape (batch, out_features, classes) - into (batch, out_channels, pixels, classes). 
- """ - batch, channels, pixels, classes = ( - module.input0.size(0), - module.out_channels, - module.output_shape[2] * module.output_shape[3], - -1, - ) - return tensor.contiguous().view(batch, channels, pixels, classes) - - -def check_sizes_input_jac_t(mat, module): - batch, out_channels, out_x, out_y = module.output_shape - assert tuple(mat.size())[:2] == (batch, out_channels * out_x * out_y) - + """Reshape (V, N, C, H, W) into (V, N, C, H * W).""" + return eingroup("v,n,c,h,w->v,n,c,hw", tensor) -def check_sizes_input_jac(mat, module): - batch, in_channels, in_x, in_y = module.input0.size() - assert tuple(mat.size())[:2] == (batch, in_channels * in_x * in_y) - - -def check_sizes_output_jac_t(jtmp, module): - if tuple(jtmp.size())[1:] != tuple(module.input0.size())[1:]: - raise ValueError( - "Size after conv_transpose does not match", "Got {}, and {}.", - "Expected all dimensions to match, except for the first.".format( - jtmp.size(), module.input0.size())) - - -def check_sizes_output_jac(jmp, module): - if tuple(jmp.size())[1:] != tuple(module.output_shape)[1:]: - raise ValueError( - "Size after conv does not match", "Got {}, and {}.", - "Expected all dimensions to match, except for the first.".format( - jmp.size(), module.output_shape)) def extract_weight_diagonal(module, input, grad_output): """ @@ -59,5 +30,15 @@ def extract_weight_diagonal(module, input, grad_output): and grad_output the backpropagated gradient """ grad_output_viewed = separate_channels_and_pixels(module, grad_output) - AX = einsum('bkl,bmlc->cbkm', (input, grad_output_viewed)) - return (AX ** 2).sum([0, 1]).transpose(0, 1) + AX = einsum("nkl,vnml->vnkm", (input, grad_output_viewed)) + weight_diagonal = (AX ** 2).sum([0, 1]).transpose(0, 1) + return weight_diagonal.view_as(module.weight) + + +def extract_bias_diagonal(module, sqrt): + """ + `sqrt` must be the backpropagated quantity for DiagH or DiagGGN(MC) + """ + V_axis, N_axis = 0, 1 + bias_diagonal = (einsum("vnchw->vnc", sqrt) ** 2).sum([V_axis, N_axis]) + return bias_diagonal diff --git a/backpack/utils/convert_parameters.py b/backpack/utils/convert_parameters.py new file mode 100644 index 00000000..b3f91731 --- /dev/null +++ b/backpack/utils/convert_parameters.py @@ -0,0 +1,48 @@ +import torch + + +def vector_to_parameter_list(vec, parameters): + """ + Convert the vector `vec` to a parameter-list format matching `parameters`. + + This function is the inverse of `parameters_to_vector` from the + pytorch module `torch.nn.utils.convert_parameters`. + Contrary to `vector_to_parameters`, which replaces the value + of the parameters, this function leaves the parameters unchanged and + returns a list of parameter views of the vector. + + ``` + from torch.nn.utils import parameters_to_vector + + vector_view = parameters_to_vector(parameters) + param_list_view = vector_to_parameter_list(vec, parameters) + + for a, b in zip(parameters, param_list_view): + assert torch.all_close(a, b) + ``` + + Parameters: + ----------- + vec: Tensor + a single vector represents the parameters of a model + parameters: (Iterable[Tensor]) + an iterator of Tensors that are of the desired shapes. 
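+
+    Returns:
+    --------
+    List[Tensor]
+        Views into `vec`, one tensor per parameter, each reshaped to the
+        corresponding parameter's shape.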
+ """ + # Ensure vec of type Tensor + if not isinstance(vec, torch.Tensor): + raise TypeError( + "expected torch.Tensor, but got: {}".format(torch.typename(vec)) + ) + params_new = [] + # Pointer for slicing the vector for each parameter + pointer = 0 + for param in parameters: + # The length of the parameter + num_param = param.numel() + # Slice the vector, reshape it + param_new = vec[pointer : pointer + num_param].view_as(param).data + params_new.append(param_new) + # Increment the pointer + pointer += num_param + + return params_new diff --git a/backpack/utils/ein.py b/backpack/utils/ein.py new file mode 100644 index 00000000..bc9812fd --- /dev/null +++ b/backpack/utils/ein.py @@ -0,0 +1,180 @@ +""" +Einsum utility functions. + +Makes it easy to switch to opt_einsum rather than torch's einsum for tests. +""" + +import numpy as np +import opt_einsum as oe +import torch + +TORCH = "torch" +OPT_EINSUM = "opt_einsum" + +BPEXTS_EINSUM = "torch" + + +def _oe_einsum(equation, *operands): + # handle old interface, passing operands as one list + # see https://pytorch.org/docs/stable/_modules/torch/functional.html#einsum + if len(operands) == 1 and isinstance(operands[0], (list, tuple)): + operands = operands[0] + return oe.contract(equation, *operands, backend="torch") + + +EINSUMS = { + TORCH: torch.einsum, + OPT_EINSUM: _oe_einsum, +} + + +def einsum(equation, *operands): + """`einsum` implementations used by `backpack`. + + Modify by setting `backpack.utils.utils.BPEXTS_EINSUM`. + See `backpack.utils.utils.EINSUMS` for supported implementations. + """ + return EINSUMS[BPEXTS_EINSUM](equation, *operands) + + +def eingroup(equation, operand, dim=None): + """Use einsum notation for (un-)grouping dimensions. + + Dimensions that cannot be inferred can be handed in via the + dictionary `dim`. + + Many operations in `backpack` require that certain axes of a tensor + be treated identically, and will therefore be grouped into a single + dimesion of the tensor. One way to do that is using `view`s or + `reshape`s. `eingroup` helps facilitate this process. It can be + used in the same way as `einsum`, but acts only on a single tensor at + a time (although this could be fixed with an improved syntax and + equation analysis). + + Idea: + ----- + * "a,b,c->ab,c": group dimension a and b into a single one + * "a,b,c->ba,c" to transpose, then group b and a dimension + + Raises: + ------- + `KeyError`: If information about a dimension in `dim` is missing + or can be removed. 
+ `RuntimeError`: If the groups inferred from `equation` do not match + the number of axes of `operand` + + Example usage: + ``` + import torch + from backpack.utils.ein import einsum, eingroup + + dim_a, dim_b, dim_c, dim_d = torch.randint(low=1, high=10, size=(4,)) + tensor = torch.randn((dim_a, dim_b, dim_c, dim_d)) + + # 1) Transposition: Note the slightly different syntax for `eingroup` + tensor_trans = einsum("abcd->cbad", tensor) + tensor_trans_eingroup = eingroup("a,b,c,d->c,b,a,d", tensor) + assert torch.allclose(tensor_trans, tensor_trans_eingroup) + + # 2) Grouping axes (a,c) and (b,d) together + tensor_group = einsum("abcd->acbd", tensor).reshape((dim_a * dim_c, dim_b * dim_d)) + tensor_group_eingroup = eingroup("a,b,c,d->ac,bd", tensor) + assert torch.allclose(tensor_group, tensor_group_eingroup) + + # 3) Ungrouping a tensor whose axes where merged + tensor_merge = tensor.reshape(dim_a * dim_b, dim_c, dim_d) + tensor_unmerge = tensor.reshape(dim_a, dim_b, dim_c, dim_d) + assert torch.allclose(tensor_unmerge, tensor) + # eingroup needs to know the dimensions of the ungrouped dimension + tensor_unmerge_eingroup = eingroup( + "ab,c,d->a,b,c,d", tensor_merge, dim={"a": dim_a, "b": dim_b} + ) + assert torch.allclose(tensor_unmerge, tensor_unmerge_eingroup) + + # 4) `einsum` functionality to sum out dimensions + # sum over dim_c, group dim_a and dim_d + tensor_sum = einsum("abcd->adb", tensor).reshape(dim_a * dim_d, dim_b) + tensor_sum_eingroup = eingroup("a,b,c,d->ad,b", tensor) + assert torch.allclose(tensor_sum, tensor_sum_eingroup) + ``` + """ + + dim = {} if dim is None else dim + in_shape, out_shape, einsum_eq = _eingroup_preprocess(equation, operand, dim=dim) + + operand_in = try_view(operand, in_shape) + result = einsum(einsum_eq, operand_in) + return try_view(result, out_shape) + + +def _eingroup_preprocess(equation, operand, dim): + """Process `eingroup` equation. + + Return the `reshape`s and `einsum` equations that have to + be performed. 
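+
+    For instance, the equation "a,b->ba" applied to an operand of shape
+    (A, B) gives the flat input shape [A, B], the output shape [B, A],
+    and the `einsum` equation "ab->ba" (a plain transpose, no grouping).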
+ """ + split, sep = "->", "," + + def groups(string): + return string.split(sep) + + lhs, rhs = equation.split(split) + in_groups, out_groups = groups(lhs), groups(rhs) + + dim = __eingroup_infer(in_groups, operand, dim) + in_shape_flat, out_shape = __eingroup_shapes(in_groups, out_groups, dim) + + return in_shape_flat, out_shape, equation.replace(sep, "") + + +def __eingroup_shapes(in_groups, out_groups, dim): + """Return shape the input needs to be reshaped, and the output shape""" + + def shape(groups, dim): + return [group_dim(group, dim) for group in groups] + + def group_dim(group, dim): + try: + return np.prod([dim[g] for g in group]) + except KeyError as e: + raise KeyError("Unknown dimension for an axis {}".format(e)) + + out_shape = shape(out_groups, dim) + + in_groups_flat = [] + for group in in_groups: + for letter in group: + in_groups_flat.append(letter) + in_shape_flat = shape(in_groups_flat, dim) + + return in_shape_flat, out_shape + + +def __eingroup_infer(in_groups, operand, dim): + """Infer the size of each axis.""" + if not len(in_groups) == len(operand.shape): + raise RuntimeError( + "Got {} input groups {}, but tensor has {} axes.".format( + len(in_groups), in_groups, len(operand.shape) + ) + ) + + for group, size in zip(in_groups, operand.shape): + if len(group) == 1: + axis = group[0] + if axis in dim.keys(): + raise KeyError( + "Can infer dimension of axis {}.".format(axis), + "Remove from dim = {}.".format(dim), + ) + dim[axis] = size + + return dim + + +def try_view(tensor, shape): + """Fall back to reshape (more expensive) if viewing does not work.""" + try: + return tensor.view(shape) + except RuntimeError: + return tensor.reshape(shape) diff --git a/backpack/utils/examples.py b/backpack/utils/examples.py new file mode 100644 index 00000000..4fe4f84b --- /dev/null +++ b/backpack/utils/examples.py @@ -0,0 +1,30 @@ +"""Utility functions for examples.""" +import torch +import torchvision + + +def download_mnist(): + """Download and normalize MNIST training data.""" + mnist_dataset = torchvision.datasets.MNIST( + root="./data", + train=True, + transform=torchvision.transforms.Compose( + [ + torchvision.transforms.ToTensor(), + torchvision.transforms.Normalize((0.1307,), (0.3081,)), + ] + ), + download=True, + ) + return mnist_dataset + + +def load_mnist_data(batch_size=64, shuffle=True): + """Return (inputs, labels) for an MNIST mini-batch.""" + mnist_dataset = download_mnist() + mnist_loader = torch.utils.data.dataloader.DataLoader( + mnist_dataset, batch_size=batch_size, shuffle=shuffle, + ) + + X, y = next(iter(mnist_loader)) + return X, y diff --git a/backpack/utils/kroneckers.py b/backpack/utils/kroneckers.py new file mode 100644 index 00000000..b68d5495 --- /dev/null +++ b/backpack/utils/kroneckers.py @@ -0,0 +1,152 @@ +from backpack.utils.ein import einsum +from backpack.utils.unsqueeze import kfacmp_unsqueeze_if_missing_dim + + +def kfacs_to_mat(factors): + """Given [A, B, C, ...], return A ⊗ B ⊗ C ⊗ ... 
.""" + mat = None + for factor in factors: + if mat is None: + assert is_matrix(factor) + mat = factor + else: + mat = two_kfacs_to_mat(mat, factor) + + return mat + + +def two_kfacs_to_mat(A, B): + """Given A, B, return A ⊗ B.""" + assert is_matrix(A) + assert is_matrix(B) + + mat_shape = ( + A.shape[0] * B.shape[0], + A.shape[1] * B.shape[1], + ) + mat = einsum("ij,kl->ikjl", (A, B)).contiguous().view(mat_shape) + return mat + + +def kfac_mat_prod(factors): + """Return function v ↦ (A ⊗ B ⊗ ...)v for `factors = [A, B, ...]` """ + assert all_tensors_of_order(order=2, tensors=factors) + + shapes = [list(f.size()) for f in factors] + _, col_dims = zip(*shapes) + + num_factors = len(shapes) + equation = kfac_mat_prod_einsum_equation(num_factors) + + @kfacmp_unsqueeze_if_missing_dim(mat_dim=2) + def kfacmp(mat): + assert is_matrix(mat) + _, mat_cols = mat.shape + mat_reshaped = mat.view(*(col_dims), mat_cols) + return einsum(equation, mat_reshaped, *factors).contiguous().view(-1, mat_cols) + + return kfacmp + + +def apply_kfac_mat_prod(factors, mat): + """Return (A ⊗ B ⊗ ...) mat for `factors = [A, B, ...]` + + All Kronecker factors have to be matrices. + """ + kfacmp = kfac_mat_prod(factors) + return kfacmp(mat) + + +def inv_kfac_mat_prod(factors, shift=None): + """ Return function M ↦ [(A + 𝜆₁I)⁻¹ ⊗ (A + 𝜆₂I)⁻¹⊗ ...] M + given [A, B, ...], [𝜆₁, 𝜆₂, ...]. + """ + inv_factors = inv_kfacs(factors, shift) + return kfac_mat_prod(inv_factors) + + +def apply_inv_kfac_mat_prod(factors, mat, shift=None): + """Return [(A + 𝜆₁I)⁻¹ ⊗ (A + 𝜆₂I)⁻¹⊗ ...] M.""" + inv_mat_prod = inv_kfac_mat_prod(factors, shift) + return inv_mat_prod(mat) + + +def inv_kfacs(factors, shift=None): + """Given [A, B, ...], [𝜆₁, 𝜆₂, ...] Return [(A + 𝜆₁I)⁻¹, (A + 𝜆₂I)⁻¹, ...]. + + I denotes the identity matrix. All KFACs are assumed symmetric. + + Parameters: + ----------- + shift: list, tuple, float: + Diagonal shift of the eigenvalues. Per default, no shift is applied. + If float, the same shift is applied to all factors. + """ + + def make_shifts(): + """Turn user-specified shift into a value for each factor.""" + same = shift is None or isinstance(shift, float) + if same: + value = 0.0 if shift is None else shift + return [value for factor in factors] + else: + assert isinstance(shift, (tuple, list)) + assert len(factors) == len(shift) + return shift + + def sym_mat_inv(mat, shift, truncate=1e-8): + """Inverse of a symmetric matrix A -> (A + 𝜆I)⁻¹. + + Computed by eigenvalue decomposition. Eigenvalues with small + absolute values are truncated. 
+ """ + eigvals, eigvecs = mat.symeig(eigenvectors=True) + eigvals.add_(shift) + inv_eigvals = 1.0 / eigvals + inv_truncate = 1.0 / truncate + inv_eigvals.clamp_(min=-inv_truncate, max=inv_truncate) + return einsum("ij,j,kj->ik", (eigvecs, inv_eigvals, eigvecs)) + + shifts = make_shifts() + return [sym_mat_inv(mat, shift) for mat, shift in zip(factors, shifts)] + + +def kfac_mat_prod_einsum_equation(num_factors): + letters = get_letters() + in_str, mat_str, out_str = "", "", "" + + for _ in range(num_factors): + row_idx, col_idx = next(letters), next(letters) + + in_str += "," + row_idx + col_idx + mat_str += col_idx + out_str += row_idx + + mat_col_idx = next(letters) + mat_str += mat_col_idx + out_str += mat_col_idx + + return "{}{}->{}".format(mat_str, in_str, out_str) + + +def all_tensors_of_order(order, tensors): + return all(is_tensor_of_order(order, t) for t in tensors) + + +def is_tensor_of_order(order, tensor): + return len(tensor.shape) == order + + +def is_matrix(tensor): + matrix_order = 2 + return is_tensor_of_order(matrix_order, tensor) + + +def is_vector(tensor): + vector_order = 1 + return is_tensor_of_order(vector_order, tensor) + + +def get_letters(max_letters=26): + for i in range(max_letters): + yield chr(ord("a") + i) diff --git a/backpack/utils/linear.py b/backpack/utils/linear.py new file mode 100644 index 00000000..6794534a --- /dev/null +++ b/backpack/utils/linear.py @@ -0,0 +1,9 @@ +from backpack.utils.ein import einsum + + +def extract_weight_diagonal(module, backproped): + return einsum("vno,ni->oi", (backproped ** 2, module.input0 ** 2)) + + +def extract_bias_diagonal(module, backproped): + return einsum("vno->o", backproped ** 2) diff --git a/backpack/core/derivatives/utils.py b/backpack/utils/unsqueeze.py similarity index 55% rename from backpack/core/derivatives/utils.py rename to backpack/utils/unsqueeze.py index e9e5aef9..21a5a16b 100644 --- a/backpack/core/derivatives/utils.py +++ b/backpack/utils/unsqueeze.py @@ -6,12 +6,10 @@ def jmp_unsqueeze_if_missing_dim(mat_dim): def jmp_wrapper(jmp): @functools.wraps(jmp) - def wrapped_jmp_support_jvp(self, module, g_inp, g_out, mat, - **kwargs): - is_vec = (len(mat.shape) == mat_dim - 1) + def wrapped_jmp_support_jvp(self, module, g_inp, g_out, mat, **kwargs): + is_vec = len(mat.shape) == mat_dim - 1 mat_used = mat.unsqueeze(-1) if is_vec else mat - result = jmp(self, module, g_inp, g_out, mat_used, - **kwargs) + result = jmp(self, module, g_inp, g_out, mat_used, **kwargs) if is_vec: return result.squeeze(-1) else: @@ -28,7 +26,7 @@ def hmp_unsqueeze_if_missing_dim(mat_dim): def hmp_wrapper(hmp): @functools.wraps(hmp) def wrapped_hmp_support_hvp(mat): - is_vec = (len(mat.shape) == mat_dim - 1) + is_vec = len(mat.shape) == mat_dim - 1 mat_used = mat.unsqueeze(-1) if is_vec else mat result = hmp(mat_used) if is_vec: @@ -39,3 +37,24 @@ def wrapped_hmp_support_hvp(mat): return wrapped_hmp_support_hvp return hmp_wrapper + + +def kfacmp_unsqueeze_if_missing_dim(mat_dim): + """ + Allows Kronecker-factored matrix-matrix routines to do matrix-vector products. 
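+
+    Usage sketch (decorator form, as used by `kfac_mat_prod`):
+
+    ```
+    @kfacmp_unsqueeze_if_missing_dim(mat_dim=2)
+    def kfacmp(mat):
+        ...  # expects a matrix of shape (d, k)
+
+    kfacmp(vec)  # a vector of shape (d,) is unsqueezed to (d, 1),
+                 # processed, and the result squeezed back to a vector
+    ```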
+ """ + + def kfacmp_wrapper(kfacmp): + @functools.wraps(kfacmp) + def wrapped_kfacmp_support_kfacvp(mat): + is_vec = len(mat.shape) == mat_dim - 1 + mat_used = mat.unsqueeze(-1) if is_vec else mat + result = kfacmp(mat_used) + if is_vec: + return result.squeeze(-1) + else: + return result + + return wrapped_kfacmp_support_kfacvp + + return kfacmp_wrapper diff --git a/backpack/utils/utils.py b/backpack/utils/utils.py deleted file mode 100644 index 3b69a04b..00000000 --- a/backpack/utils/utils.py +++ /dev/null @@ -1,43 +0,0 @@ -"""Utility functions.""" - -import opt_einsum as oe -import torch - -TORCH = "torch" -OPT_EINSUM = "opt_einsum" - -BPEXTS_EINSUM = "torch" - - -def _oe_einsum(equation, *operands): - # handle old interface, passing operands as one list - # see https://pytorch.org/docs/stable/_modules/torch/functional.html#einsum - if len(operands) == 1 and isinstance(operands[0], (list, tuple)): - operands = operands[0] - return oe.contract(equation, *operands, backend='torch') - - -EINSUMS = { - TORCH: torch.einsum, - OPT_EINSUM: _oe_einsum, -} - - -def einsum(equation, *operands): - """`einsum` implementations used by `backpack`. - - Modify by setting `backpack.utils.utils.BPEXTS_EINSUM`. - See `backpack.utils.utils.EINSUMS` for supported implementations. - """ - return EINSUMS[BPEXTS_EINSUM](equation, *operands) - - -def random_psd_matrix(dim, device=None): - """Random positive semi-definite matrix on device.""" - if device is None: - device = torch.device("cpu") - - rand_mat = torch.randn(dim, dim, device=device) - rand_mat = 0.5 * (rand_mat + rand_mat.t()) - shift = dim * torch.eye(dim, device=device) - return rand_mat + shift diff --git a/black.toml b/black.toml new file mode 100644 index 00000000..3c9ec678 --- /dev/null +++ b/black.toml @@ -0,0 +1,18 @@ +[tool.black] +line-length = 88 +target-version = ['py35', 'py36', 'py37'] +include = '\.pyi?$' +exclude = ''' +( + /( + \.eggs + | \.git + | \.pytest_cache + | \.benchmarks + | docs_src + | docs + | build + | dist + )/ +) +''' diff --git a/changelog.md b/changelog.md new file mode 100644 index 00000000..2ac20681 --- /dev/null +++ b/changelog.md @@ -0,0 +1,50 @@ +# Changelog +All notable changes to this project will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). 
+ +## [Unreleased] + +## [1.1.0] - 2020-02-11 + +### Added +- Support MC sampling + [[Issue](https://github.com/f-dangel/backpack/issues/21), + [PR](https://github.com/f-dangel/backpack/pull/36)] +- Utilities to handle Kronecker factors + [[PR](https://github.com/f-dangel/backpack/pull/17)] +- Examples + [[PR](https://github.com/f-dangel/backpack/pull/34)] + +### Fixed +- Fixed documentation issue in `Batch l2` + [[PR](https://github.com/f-dangel/backpack/pull/33)] +- Added support for stride parameter in Conv2d + [[Issue](https://github.com/f-dangel/backpack/issues/30), + [PR](https://github.com/f-dangel/backpack/pull/31)] +- Pytorch `1.3.0` compatibility + [[PR](https://github.com/f-dangel/backpack/pull/8), + [PR](https://github.com/f-dangel/backpack/pull/9)] + +### Internal +- Added + continuous integration [[PR](https://github.com/f-dangel/backpack/pull/19)], + test coverage [[PR](https://github.com/f-dangel/backpack/pull/25)], + style guide enforcement [[PR](https://github.com/f-dangel/backpack/pull/27)] +- Changed internal shape conventions of backpropagated quantities for performance improvements + [[PR](https://github.com/f-dangel/backpack/pull/37)] + +## [1.0.1] - 2019-09-05 + +### Fixed +- Fixed PyPI installaton + +## [1.0.0] - 2019-10-03 + +Initial release + +[Unreleased]: https://github.com/f-dangel/backpack/compare/v1.1.0...HEAD +[1.1.0]: https://github.com/f-dangel/backpack/compare/1.0.1...1.1.0 +[1.0.1]: https://github.com/f-dangel/backpack/compare/1.0.0...1.0.1 +[1.0.0]: https://github.com/f-dangel/backpack/releases/tag/1.0.0 diff --git a/docs_src/rtd/main-api.rst b/docs_src/rtd/main-api.rst index 21790cb0..9be34714 100644 --- a/docs_src/rtd/main-api.rst +++ b/docs_src/rtd/main-api.rst @@ -13,7 +13,7 @@ Extending the model and loss function import torch model = torch.nn.Sequential( - torch.nn.Linear(764, 64), + torch.nn.Linear(784, 64), torch.nn.ReLU(), torch.nn.Linear(64, 10) ) diff --git a/docs_src/rtd/supported-layers.rst b/docs_src/rtd/supported-layers.rst index 18de3f8f..394f0038 100644 --- a/docs_src/rtd/supported-layers.rst +++ b/docs_src/rtd/supported-layers.rst @@ -9,7 +9,7 @@ For example, .. 
code-block:: python model = torch.nn.Sequential( - torch.nn.Linear(764, 64), + torch.nn.Linear(784, 64), torch.nn.ReLU(), torch.nn.Linear(64, 10) ) diff --git a/docs_src/splash/_includes/code-samples.html b/docs_src/splash/_includes/code-samples.html index f98594f7..e5a98ba5 100644 --- a/docs_src/splash/_includes/code-samples.html +++ b/docs_src/splash/_includes/code-samples.html @@ -48,7 +48,7 @@ X, y = load_mnist_data() -model = Linear(764, 10) +model = Linear(784, 10) lossfunc = CrossEntropyLoss() loss = lossfunc(model(X), y) @@ -74,7 +74,7 @@ from backpack import extend, backpack, Variance X, y = load_mnist_data() -model = extend(Linear(764, 10)) +model = extend(Linear(784, 10)) lossfunc = extend(CrossEntropyLoss()) loss = lossfunc(model(X), y) @@ -83,7 +83,7 @@ for param in model.parameters(): print(param.grad) - print(param.var) + print(param.variance) @@ -95,19 +95,19 @@ """ from torch.nn import CrossEntropyLoss, Linear from utils import load_mnist_data -from backpack import extend, backpack, SecondMoment +from backpack import extend, backpack, SumGradSquared X, y = load_mnist_data() -model = extend(Linear(764, 10)) +model = extend(Linear(784, 10)) lossfunc = extend(CrossEntropyLoss()) loss = lossfunc(model(X), y) -with backpack(SecondMoment()): +with backpack(SumGradSquared()): loss.backward() for param in model.parameters(): print(param.grad) - print(param.secondMoment) + print(param.sum_grad_squared) @@ -118,19 +118,19 @@ """ from torch.nn import CrossEntropyLoss, Linear from utils import load_mnist_data -from backpack import extend, backpack, DiagGGN +from backpack import extend, backpack, DiagGGNExact X, y = load_mnist_data() -model = extend(Linear(764, 10)) +model = extend(Linear(784, 10)) lossfunc = extend(CrossEntropyLoss()) loss = lossfunc(model(X), y) -with backpack(DiagGGN()): +with backpack(DiagGGNExact()): loss.backward() for param in model.parameters(): print(param.grad) - print(param.diagggn) + print(param.diag_ggn_exact) @@ -144,7 +144,7 @@ from backpack import extend, backpack, KFAC X, y = load_mnist_data() -model = extend(Linear(764, 10)) +model = extend(Linear(784, 10)) lossfunc = extend(CrossEntropyLoss()) loss = lossfunc(model(X), y) @@ -153,7 +153,7 @@ for param in model.parameters(): print(param.grad) - print(param.kfac1, param.kfac2) + print(param.kfac) @@ -231,4 +231,4 @@ return false } } - \ No newline at end of file + diff --git a/examples/cheatsheet.pdf b/examples/cheatsheet.pdf new file mode 100644 index 00000000..3ded88b4 Binary files /dev/null and b/examples/cheatsheet.pdf differ diff --git a/examples/example_all_in_one.py b/examples/example_all_in_one.py new file mode 100644 index 00000000..64ca147a --- /dev/null +++ b/examples/example_all_in_one.py @@ -0,0 +1,71 @@ +""" +Compute the gradient with PyTorch and other quantities with BackPACK. 
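+All quantities are requested in a single `with backpack(...)` context, so one
+backward pass populates `.grad` plus every BackPACK attribute printed below.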
+""" + +from torch.nn import CrossEntropyLoss, Flatten, Linear, Sequential + +from backpack import backpack, extend, extensions +from backpack.utils.examples import load_mnist_data + +B = 4 +X, y = load_mnist_data(B) + +print("# Gradient with PyTorch, other quantities with BackPACK | B =", B) + +model = Sequential(Flatten(), Linear(784, 10),) +lossfunc = CrossEntropyLoss() + +model = extend(model) +lossfunc = extend(lossfunc) + +loss = lossfunc(model(X), y) + +with backpack( + # individual gradients + extensions.BatchGrad(), + # gradient variance + extensions.Variance(), + # gradient 2nd moment + extensions.SumGradSquared(), + # individual gradient L2 norm + extensions.BatchL2Grad(), + # MC-sampled GGN diagonal + # number of samples optional (default: 1) + extensions.DiagGGNMC(mc_samples=1), + # Exact GGN diagonal + extensions.DiagGGNExact(), + # Exact Hessian diagonal + extensions.DiagHessian(), + # KFAC (Martens et al.) + # number of samples optional (default: 1) + extensions.KFAC(mc_samples=1), + # KFLR (Botev et al.) + extensions.KFLR(), + # KFRA (Botev et al.) + extensions.KFRA(), +): + loss.backward() + +for name, param in model.named_parameters(): + print(name) + print(".grad.shape: ", param.grad.shape) + # individual gradients + print(".grad_batch.shape: ", param.grad_batch.shape) + # gradient variance + print(".variance.shape: ", param.variance.shape) + # gradient 2nd moment + print(".sum_grad_squared.shape: ", param.sum_grad_squared.shape) + # individual gradient L2 norm + print(".batch_l2.shape: ", param.batch_l2.shape) + # MC-sampled GGN diagonal + print(".diag_ggn_mc.shape: ", param.diag_ggn_mc.shape) + # Exact GGN diagonal + print(".diag_ggn_exact.shape: ", param.diag_ggn_exact.shape) + # Exact Hessian diagonal + print(".diag_h.shape: ", param.diag_h.shape) + # KFAC (Martens et al.) + print(".kfac (shapes): ", [kfac.shape for kfac in param.kfac]) + # KFLR (Botev et al.) + print(".kflr (shapes): ", [kflr.shape for kflr in param.kflr]) + # KFRA (Botev et al.) + print(".kfra (shapes): ", [kfra.shape for kfra in param.kfra]) diff --git a/examples/example_diag_ggn_exact.py b/examples/example_diag_ggn_exact.py new file mode 100644 index 00000000..b9db5e09 --- /dev/null +++ b/examples/example_diag_ggn_exact.py @@ -0,0 +1,29 @@ +""" +Compute the gradient with PyTorch and the exact GGN diagonal with BackPACK. +""" + +from torch.nn import CrossEntropyLoss, Flatten, Linear, Sequential + +from backpack import backpack, extend, extensions +from backpack.utils.examples import load_mnist_data + +B = 4 +X, y = load_mnist_data(B) + +print("# Gradient with PyTorch, exact GGN diagonal with BackPACK | B =", B) + +model = Sequential(Flatten(), Linear(784, 10),) +lossfunc = CrossEntropyLoss() + +model = extend(model) +lossfunc = extend(lossfunc) + +loss = lossfunc(model(X), y) + +with backpack(extensions.DiagGGNExact()): + loss.backward() + +for name, param in model.named_parameters(): + print(name) + print(".grad.shape: ", param.grad.shape) + print(".diag_ggn_exact.shape: ", param.diag_ggn_exact.shape) diff --git a/examples/example_diag_ggn_mc.py b/examples/example_diag_ggn_mc.py new file mode 100644 index 00000000..64688012 --- /dev/null +++ b/examples/example_diag_ggn_mc.py @@ -0,0 +1,30 @@ +""" +Compute the gradient with PyTorch and the MC-sampled GGN diagonal with BackPACK. 
+""" + +from torch.nn import CrossEntropyLoss, Flatten, Linear, Sequential + +from backpack import backpack, extend, extensions +from backpack.utils.examples import load_mnist_data + +B = 4 +X, y = load_mnist_data(B) + +print("# Gradient with PyTorch, MC-sampled GGN diagonal with BackPACK | B =", B) + +model = Sequential(Flatten(), Linear(784, 10),) +lossfunc = CrossEntropyLoss() + +model = extend(model) +lossfunc = extend(lossfunc) + +loss = lossfunc(model(X), y) + +# number of MC samples is optional, defaults to 1 +with backpack(extensions.DiagGGNMC(mc_samples=1)): + loss.backward() + +for name, param in model.named_parameters(): + print(name) + print(".grad.shape: ", param.grad.shape) + print(".diag_ggn_mc.shape: ", param.diag_ggn_mc.shape) diff --git a/examples/example_diag_ggn_optimizer.py b/examples/example_diag_ggn_optimizer.py index 24f812ca..feb59b56 100644 --- a/examples/example_diag_ggn_optimizer.py +++ b/examples/example_diag_ggn_optimizer.py @@ -18,13 +18,10 @@ """ import torch -import torchvision -# The main BackPACK functionalities + from backpack import backpack, extend -# The diagonal GGN extension from backpack.extensions import DiagGGNMC -# This layer did not exist in Pytorch 1.0 -from backpack.core.layers import Flatten +from backpack.utils.examples import download_mnist # Hyperparameters BATCH_SIZE = 64 @@ -41,20 +38,9 @@ and fit a 3-layer MLP with ReLU activations. """ - +mnist_dataset = download_mnist() mnist_loader = torch.utils.data.dataloader.DataLoader( - torchvision.datasets.MNIST( - './data', - train=True, - download=True, - transform=torchvision.transforms.Compose([ - torchvision.transforms.ToTensor(), - torchvision.transforms.Normalize( - (0.1307,), (0.3081,) - ) - ])), - batch_size=BATCH_SIZE, - shuffle=True + mnist_dataset, batch_size=BATCH_SIZE, shuffle=True ) model = torch.nn.Sequential( @@ -64,15 +50,15 @@ torch.nn.Conv2d(20, 50, 5, 1), torch.nn.ReLU(), torch.nn.MaxPool2d(2, 2), - Flatten(), - # Pytorch <1.2 doesn't have a Flatten layer - torch.nn.Linear(4*4*50, 500), + torch.nn.Flatten(), + torch.nn.Linear(4 * 4 * 50, 500), torch.nn.ReLU(), torch.nn.Linear(500, 10), ) loss_function = torch.nn.CrossEntropyLoss() + def get_accuracy(output, targets): """Helper function to print the accuracy""" predictions = output.argmax(dim=1, keepdim=True).view_as(targets) @@ -96,10 +82,7 @@ def get_accuracy(output, targets): class DiagGGNOptimizer(torch.optim.Optimizer): def __init__(self, parameters, step_size, damping): - super().__init__( - parameters, - dict(step_size=step_size, damping=damping) - ) + super().__init__(parameters, dict(step_size=step_size, damping=damping)) def step(self): for group in self.param_groups: @@ -109,7 +92,6 @@ def step(self): return loss - """ Step 3: Tell BackPACK about the model and loss function, create the optimizer, and we will be ready to go @@ -118,11 +100,7 @@ def step(self): extend(model) extend(loss_function) -optimizer = DiagGGNOptimizer( - model.parameters(), - step_size=STEP_SIZE, - damping=DAMPING -) +optimizer = DiagGGNOptimizer(model.parameters(), step_size=STEP_SIZE, damping=DAMPING) """ @@ -149,9 +127,10 @@ def step(self): optimizer.step() print( - "Iteration %3.d/%d " % (batch_idx, MAX_ITER) + - "Minibatch Loss %.3f " % (loss) + - "Accuracy %.0f" % (accuracy * 100) + "%" + "Iteration %3.d/%d " % (batch_idx, MAX_ITER) + + "Minibatch Loss %.3f " % (loss) + + "Accuracy %.0f" % (accuracy * 100) + + "%" ) if batch_idx >= MAX_ITER: diff --git a/examples/example_diag_hessian.py b/examples/example_diag_hessian.py new file mode 
100644 index 00000000..981018f0 --- /dev/null +++ b/examples/example_diag_hessian.py @@ -0,0 +1,29 @@ +""" +Compute the gradient with PyTorch and the Hessian diagonal with BackPACK. +""" + +from torch.nn import CrossEntropyLoss, Flatten, Linear, Sequential + +from backpack import backpack, extend, extensions +from backpack.utils.examples import load_mnist_data + +B = 4 +X, y = load_mnist_data(B) + +print("# Gradient with PyTorch, Hessian diagonal with BackPACK | B =", B) + +model = Sequential(Flatten(), Linear(784, 10),) +lossfunc = CrossEntropyLoss() + +model = extend(model) +lossfunc = extend(lossfunc) + +loss = lossfunc(model(X), y) + +with backpack(extensions.DiagHessian()): + loss.backward() + +for name, param in model.named_parameters(): + print(name) + print(".grad.shape: ", param.grad.shape) + print(".diag_h.shape: ", param.diag_h.shape) diff --git a/examples/example_ggn_matrix.py b/examples/example_ggn_matrix.py new file mode 100644 index 00000000..c12eb4c2 --- /dev/null +++ b/examples/example_ggn_matrix.py @@ -0,0 +1,41 @@ +""" +Compute the full GGN matrix with automatic differentiation. +Use GGN-vector products for row-wise construction. +""" +import torch +from torch.nn import CrossEntropyLoss, Flatten, Linear, Sequential +from torch.nn.utils.convert_parameters import parameters_to_vector + +from backpack.hessianfree.ggnvp import ggn_vector_product +from backpack.utils.convert_parameters import vector_to_parameter_list +from backpack.utils.examples import load_mnist_data + +B = 4 +X, y = load_mnist_data(B) + +print("# GGN matrix with automatic differentiation | B =", B) + +model = Sequential(Flatten(), Linear(784, 10),) +lossfunc = CrossEntropyLoss() + +output = model(X) +loss = lossfunc(output, y) + +num_params = sum(p.numel() for p in model.parameters()) +ggn = torch.zeros(num_params, num_params) + +for i in range(num_params): + # GGN-vector product with i.th unit vector yields the i.th row + e_i = torch.zeros(num_params) + e_i[i] = 1.0 + + # convert to model parameter shapes + e_i_list = vector_to_parameter_list(e_i, model.parameters()) + ggn_i_list = ggn_vector_product(loss, output, model, e_i_list) + + ggn_i = parameters_to_vector(ggn_i_list) + ggn[i, :] = ggn_i + +print("Model parameters: ", num_params) +print("GGN shape: ", ggn.shape) +print("GGN: ", ggn) diff --git a/examples/example_ggn_vector_product.py b/examples/example_ggn_vector_product.py new file mode 100644 index 00000000..f1101f1f --- /dev/null +++ b/examples/example_ggn_vector_product.py @@ -0,0 +1,59 @@ +""" +Compute the gradient and Hessian-vector products with PyTorch. 
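+More precisely, the curvature products below are GGN-vector products,
+computed with `ggn_vector_product` from `backpack.hessianfree.ggnvp`.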
+""" + +import torch +from torch.nn import CrossEntropyLoss, Flatten, Linear, Sequential +from torch.nn.utils import parameters_to_vector + +from backpack.hessianfree.ggnvp import ggn_vector_product +from backpack.utils.convert_parameters import vector_to_parameter_list +from backpack.utils.examples import load_mnist_data + +B = 4 +X, y = load_mnist_data(B) + +print("# GGN-vector product and gradients with PyTorch | B =", B) + +model = Sequential(Flatten(), Linear(784, 10),) +lossfunc = CrossEntropyLoss() + +print("# 1) Vector with shapes like parameters | B =", B) + +output = model(X) +loss = lossfunc(output, y) +v = [torch.randn_like(p) for p in model.parameters()] + +GGNv = ggn_vector_product(loss, output, model, v) + +# has to be called afterwards, or with create_graph=True +loss.backward() + +for (name, param), vec, GGNvec in zip(model.named_parameters(), v, GGNv): + print(name) + print(".grad.shape: ", param.grad.shape) + # vector + print("vector shape: ", vec.shape) + # Hessian-vector product + print("GGN-vector product shape: ", GGNvec.shape) + +print("# 2) Flattened vector | B =", B) + +output = model(X) +loss = lossfunc(output, y) + +num_params = sum(p.numel() for p in model.parameters()) +v_flat = torch.randn(num_params) + +v = vector_to_parameter_list(v_flat, model.parameters()) +GGNv = ggn_vector_product(loss, output, model, v) +GGNv_flat = parameters_to_vector(GGNv) + +# has to be called afterwards, or with create_graph=True +loss.backward() + +print("Model parameters: ", num_params) +# vector +print("flat vector shape: ", v_flat.shape) +# individual gradient L2 norm +print("flat GGN-vector product shape: ", GGNv_flat.shape) diff --git a/examples/example_grad.py b/examples/example_grad.py new file mode 100644 index 00000000..418e0333 --- /dev/null +++ b/examples/example_grad.py @@ -0,0 +1,20 @@ +"""Compute the gradient with PyTorch.""" + +from torch.nn import CrossEntropyLoss, Flatten, Linear, Sequential + +from backpack.utils.examples import load_mnist_data + +B = 4 +X, y = load_mnist_data(B) + +print("# Gradient with PyTorch | B =", B) + +model = Sequential(Flatten(), Linear(784, 10),) +lossfunc = CrossEntropyLoss() + +loss = lossfunc(model(X), y) +loss.backward() + +for name, param in model.named_parameters(): + print(name) + print(".grad.shape: ", param.grad.shape) diff --git a/examples/example_grad_2nd_moment.py b/examples/example_grad_2nd_moment.py new file mode 100644 index 00000000..0ea34106 --- /dev/null +++ b/examples/example_grad_2nd_moment.py @@ -0,0 +1,29 @@ +""" +Compute the gradient with PyTorch and the gradient 2nd moment with BackPACK. 
+""" + +from torch.nn import CrossEntropyLoss, Flatten, Linear, Sequential + +from backpack import backpack, extend, extensions +from backpack.utils.examples import load_mnist_data + +B = 4 +X, y = load_mnist_data(B) + +print("# Gradient with PyTorch, gradient 2nd moment with BackPACK | B =", B) + +model = Sequential(Flatten(), Linear(784, 10),) +lossfunc = CrossEntropyLoss() + +model = extend(model) +lossfunc = extend(lossfunc) + +loss = lossfunc(model(X), y) + +with backpack(extensions.SumGradSquared()): + loss.backward() + +for name, param in model.named_parameters(): + print(name) + print(".grad.shape: ", param.grad.shape) + print(".sum_grad_squared.shape: ", param.sum_grad_squared.shape) diff --git a/examples/example_grad_l2.py b/examples/example_grad_l2.py new file mode 100644 index 00000000..bdcb4386 --- /dev/null +++ b/examples/example_grad_l2.py @@ -0,0 +1,29 @@ +""" +Compute the gradient with PyTorch and the individual gradients' L2 norms with BackPACK. +""" + +from torch.nn import CrossEntropyLoss, Flatten, Linear, Sequential + +from backpack import backpack, extend, extensions +from backpack.utils.examples import load_mnist_data + +B = 4 +X, y = load_mnist_data(B) + +print("# Gradient with PyTorch, individual gradients' L2 norms with BackPACK | B =", B) + +model = Sequential(Flatten(), Linear(784, 10),) +lossfunc = CrossEntropyLoss() + +model = extend(model) +lossfunc = extend(lossfunc) + +loss = lossfunc(model(X), y) + +with backpack(extensions.BatchL2Grad()): + loss.backward() + +for name, param in model.named_parameters(): + print(name) + print(".grad.shape: ", param.grad.shape) + print(".batch_l2.shape: ", param.batch_l2.shape) diff --git a/examples/example_grad_var.py b/examples/example_grad_var.py new file mode 100644 index 00000000..3e95b4c8 --- /dev/null +++ b/examples/example_grad_var.py @@ -0,0 +1,29 @@ +""" +Compute the gradient with PyTorch and the gradient variance with BackPACK. +""" + +from torch.nn import CrossEntropyLoss, Flatten, Linear, Sequential + +from backpack import backpack, extend, extensions +from backpack.utils.examples import load_mnist_data + +B = 4 +X, y = load_mnist_data(B) + +print("# Gradient with PyTorch, gradient variance with BackPACK | B =", B) + +model = Sequential(Flatten(), Linear(784, 10),) +lossfunc = CrossEntropyLoss() + +model = extend(model) +lossfunc = extend(lossfunc) + +loss = lossfunc(model(X), y) + +with backpack(extensions.Variance()): + loss.backward() + +for name, param in model.named_parameters(): + print(name) + print(".grad.shape: ", param.grad.shape) + print(".variance.shape: ", param.variance.shape) diff --git a/examples/example_hessian_matrix.py b/examples/example_hessian_matrix.py new file mode 100644 index 00000000..07cbb366 --- /dev/null +++ b/examples/example_hessian_matrix.py @@ -0,0 +1,78 @@ +""" +Compute the full Hessian matrix with automatic differentiation. +Use Hessian-vector products for row-wise construction. 
+""" +import time + +import torch +from torch.nn import CrossEntropyLoss, Flatten, Linear, Sequential +from torch.nn.utils.convert_parameters import parameters_to_vector + +from backpack.hessianfree.hvp import hessian_vector_product +from backpack.utils.convert_parameters import vector_to_parameter_list +from backpack.utils.examples import load_mnist_data + +B = 4 +X, y = load_mnist_data(B) + +model = Sequential(Flatten(), Linear(784, 10),) +lossfunc = CrossEntropyLoss() + +print("# 1) Hessian matrix with automatic differentiation | B =", B) + +loss = lossfunc(model(X), y) + +num_params = sum(p.numel() for p in model.parameters()) +hessian = torch.zeros(num_params, num_params) + + +start = time.time() +for i in range(num_params): + # GGN-vector product with i.th unit vector yields the i.th row + e_i = torch.zeros(num_params) + e_i[i] = 1.0 + + # convert to model parameter shapes + e_i_list = vector_to_parameter_list(e_i, model.parameters()) + hessian_i_list = hessian_vector_product(loss, list(model.parameters()), e_i_list) + + hessian_i = parameters_to_vector(hessian_i_list) + hessian[i, :] = hessian_i +end = time.time() + +print("Model parameters: ", num_params) +print("Hessian shape: ", hessian.shape) +print("Hessian: ", hessian) +print("Time [s]: ", end - start) + +print("# 2) Hessian matrix with automatic differentiation (faster) | B =", B) +print("# Save one backpropagation for each HVP by recycling gradients") + +loss = lossfunc(model(X), y) +loss.backward(create_graph=True) + +grad_params = [p.grad for p in model.parameters()] + +num_params = sum(p.numel() for p in model.parameters()) +hessian = torch.zeros(num_params, num_params) + +start = time.time() +for i in range(num_params): + # GGN-vector product with i.th unit vector yields the i.th row + e_i = torch.zeros(num_params) + e_i[i] = 1.0 + + # convert to model parameter shapes + e_i_list = vector_to_parameter_list(e_i, model.parameters()) + hessian_i_list = hessian_vector_product( + loss, list(model.parameters()), e_i_list, grad_params=grad_params + ) + + hessian_i = parameters_to_vector(hessian_i_list) + hessian[i, :] = hessian_i +end = time.time() + +print("Model parameters: ", num_params) +print("Hessian shape: ", hessian.shape) +print("Hessian: ", hessian) +print("Time [s]: ", end - start) diff --git a/examples/example_hessian_vector_product.py b/examples/example_hessian_vector_product.py new file mode 100644 index 00000000..168c18ea --- /dev/null +++ b/examples/example_hessian_vector_product.py @@ -0,0 +1,78 @@ +""" +Compute the gradient and Hessian-vector products with PyTorch. 
+""" + +import torch +from torch.nn import CrossEntropyLoss, Flatten, Linear, Sequential +from torch.nn.utils import parameters_to_vector + +from backpack.hessianfree.hvp import hessian_vector_product +from backpack.utils.convert_parameters import vector_to_parameter_list +from backpack.utils.examples import load_mnist_data + +B = 4 +X, y = load_mnist_data(B) + +print("# Hessian-vector product and gradients with PyTorch | B =", B) + +model = Sequential(Flatten(), Linear(784, 10),) +lossfunc = CrossEntropyLoss() + +print("# 1) Vector with shapes like parameters | B =", B) + +loss = lossfunc(model(X), y) +v = [torch.randn_like(p) for p in model.parameters()] + +Hv = hessian_vector_product(loss, list(model.parameters()), v) + +# has to be called afterwards, or with create_graph=True +loss.backward() + +for (name, param), vec, Hvec in zip(model.named_parameters(), v, Hv): + print(name) + print(".grad.shape: ", param.grad.shape) + # vector + print("vector shape: ", vec.shape) + # Hessian-vector product + print("Hessian-vector product shape: ", Hvec.shape) + +print("# 2) Flattened vector | B =", B) + +loss = lossfunc(model(X), y) + +num_params = sum(p.numel() for p in model.parameters()) +v_flat = torch.randn(num_params) + +v = vector_to_parameter_list(v_flat, model.parameters()) +Hv = hessian_vector_product(loss, list(model.parameters()), v) +Hv_flat = parameters_to_vector(Hv) + +# has to be called afterwards, or with create_graph=True +loss.backward() + +print("Model parameters: ", num_params) +# vector +print("flat vector shape: ", v_flat.shape) +# individual gradient L2 norm +print("flat Hessian-vector product shape: ", Hv_flat.shape) + + +print("# 3) Using gradients to save one backward pass | B =", B) + +loss = lossfunc(model(X), y) +# has to be called with create_graph=True +loss.backward(create_graph=True) + +v = [torch.randn_like(p) for p in model.parameters()] +params = list(model.parameters()) +grad_params = [p.grad for p in params] +Hv = hessian_vector_product(loss, params, v, grad_params=grad_params) + + +for (name, param), vec, Hvec in zip(model.named_parameters(), v, Hv): + print(name) + print(".grad.shape: ", param.grad.shape) + # vector + print("vector shape: ", vec.shape) + # Hessian-vector product + print("Hessian-vector product shape: ", Hvec.shape) diff --git a/examples/example_indiv_grads.py b/examples/example_indiv_grads.py new file mode 100644 index 00000000..79239068 --- /dev/null +++ b/examples/example_indiv_grads.py @@ -0,0 +1,27 @@ +"""Compute the gradient with PyTorch and the variance with BackPACK.""" + +from torch.nn import CrossEntropyLoss, Flatten, Linear, Sequential + +from backpack import backpack, extend, extensions +from backpack.utils.examples import load_mnist_data + +B = 4 +X, y = load_mnist_data(B) + +print("# Gradient with PyTorch, individual gradients with BackPACK | B =", B) + +model = Sequential(Flatten(), Linear(784, 10),) +lossfunc = CrossEntropyLoss() + +model = extend(model) +lossfunc = extend(lossfunc) + +loss = lossfunc(model(X), y) + +with backpack(extensions.BatchGrad()): + loss.backward() + +for name, param in model.named_parameters(): + print(name) + print(".grad.shape: ", param.grad.shape) + print(".grad_batch.shape: ", param.grad_batch.shape) diff --git a/examples/example_indiv_grads_arbitrary_ops.py b/examples/example_indiv_grads_arbitrary_ops.py new file mode 100644 index 00000000..104a43d1 --- /dev/null +++ b/examples/example_indiv_grads_arbitrary_ops.py @@ -0,0 +1,15 @@ +"""Compute the gradient with PyTorch and the variance with 
BackPACK.""" +import torch +from torch.nn import Flatten, Linear, Sequential + +from backpack import backpack, extend, extensions + +X = torch.randn(size=(50, 784), requires_grad=True) +model = Sequential(Flatten(), extend(Linear(784, 10)),) +loss = torch.mean(torch.sqrt(torch.abs(model(X)))) + +with backpack(extensions.BatchGrad()): + loss.backward() + +for name, param in model.named_parameters(): + print(name, param.grad_batch.shape) diff --git a/examples/example_kfac.py b/examples/example_kfac.py new file mode 100644 index 00000000..92e2c0fe --- /dev/null +++ b/examples/example_kfac.py @@ -0,0 +1,30 @@ +""" +Compute the gradient with PyTorch and the KFAC approximation with BackPACK. +""" + +from torch.nn import CrossEntropyLoss, Flatten, Linear, Sequential + +from backpack import backpack, extend, extensions +from backpack.utils.examples import load_mnist_data + +B = 4 +X, y = load_mnist_data(B) + +print("# Gradient with PyTorch, KFAC approximation with BackPACK | B =", B) + +model = Sequential(Flatten(), Linear(784, 10),) +lossfunc = CrossEntropyLoss() + +model = extend(model) +lossfunc = extend(lossfunc) + +loss = lossfunc(model(X), y) + +# number of MC samples is optional, defaults to 1 +with backpack(extensions.KFAC(mc_samples=1)): + loss.backward() + +for name, param in model.named_parameters(): + print(name) + print(".grad.shape: ", param.grad.shape) + print(".kfac (shapes): ", [kfac.shape for kfac in param.kfac]) diff --git a/examples/example_kflr.py b/examples/example_kflr.py new file mode 100644 index 00000000..258e0240 --- /dev/null +++ b/examples/example_kflr.py @@ -0,0 +1,29 @@ +""" +Compute the gradient with PyTorch and the KFLR approximation with BackPACK. +""" + +from torch.nn import CrossEntropyLoss, Flatten, Linear, Sequential + +from backpack import backpack, extend, extensions +from backpack.utils.examples import load_mnist_data + +B = 4 +X, y = load_mnist_data(B) + +print("# Gradient with PyTorch, KFLR approximation with BackPACK | B =", B) + +model = Sequential(Flatten(), Linear(784, 10),) +lossfunc = CrossEntropyLoss() + +model = extend(model) +lossfunc = extend(lossfunc) + +loss = lossfunc(model(X), y) + +with backpack(extensions.KFLR()): + loss.backward() + +for name, param in model.named_parameters(): + print(name) + print(".grad.shape: ", param.grad.shape) + print(".kflr (shapes): ", [kflr.shape for kflr in param.kflr]) diff --git a/examples/example_kfra.py b/examples/example_kfra.py new file mode 100644 index 00000000..3a51ef55 --- /dev/null +++ b/examples/example_kfra.py @@ -0,0 +1,29 @@ +""" +Compute the gradient with PyTorch and the KFRA approximation with BackPACK. +""" + +from torch.nn import CrossEntropyLoss, Flatten, Linear, Sequential + +from backpack import backpack, extend, extensions +from backpack.utils.examples import load_mnist_data + +B = 4 +X, y = load_mnist_data(B) + +print("# Gradient with PyTorch, KFRA approximation with BackPACK | B =", B) + +model = Sequential(Flatten(), Linear(784, 10),) +lossfunc = CrossEntropyLoss() + +model = extend(model) +lossfunc = extend(lossfunc) + +loss = lossfunc(model(X), y) + +with backpack(extensions.KFRA()): + loss.backward() + +for name, param in model.named_parameters(): + print(name) + print(".grad.shape: ", param.grad.shape) + print(".kfra (shapes): ", [kfra.shape for kfra in param.kfra]) diff --git a/examples/run_examples.py b/examples/run_examples.py new file mode 100644 index 00000000..b51c51aa --- /dev/null +++ b/examples/run_examples.py @@ -0,0 +1,20 @@ +""" +Run all example files. 
+Example files are identified by the pattern 'example_*.py'.
+"""
+import glob
+import os
+import subprocess
+
+HERE = os.path.dirname(os.path.realpath(__file__))
+PATTERN = os.path.join(HERE, r"example_*.py")
+FILES = glob.glob(PATTERN)
+
+for example in FILES:
+    print("\nRunning {}".format(example))
+
+    exit_code = subprocess.call(["python", example])
+    crash = exit_code != 0
+
+    if crash:
+        raise RuntimeError("Error running {}".format(example))
diff --git a/makefile b/makefile
new file mode 100644
index 00000000..6394b7d4
--- /dev/null
+++ b/makefile
@@ -0,0 +1,95 @@
+.PHONY: help
+.PHONY: black black-check flake8
+.PHONY: install install-dev install-devtools install-test install-lint
+.PHONY: test
+.PHONY: conda-env
+.PHONY: black isort format
+.PHONY: black-check isort-check format-check
+.PHONY: flake8
+
+.DEFAULT: help
+help:
+	@echo "test"
+	@echo " Run pytest on the project and report coverage"
+	@echo "black"
+	@echo " Run black on the project"
+	@echo "black-check"
+	@echo " Check if black would change files"
+	@echo "flake8"
+	@echo " Run flake8 on the project"
+	@echo "install"
+	@echo " Install backpack and dependencies"
+	@echo "install-dev"
+	@echo " Install all development tools"
+	@echo "install-lint"
+	@echo " Install only the linter tools (included in install-dev)"
+	@echo "install-test"
+	@echo " Install only the testing tools (included in install-dev)"
+	@echo "conda-env"
+	@echo " Create conda environment 'backpack' with dev setup"
+###
+# Test coverage
+test:
+	@pytest -vx --cov=backpack .
+
+###
+# Linter and autoformatter
+
+# Uses black.toml config instead of pyproject.toml to avoid pip issues. See
+# - https://github.com/psf/black/issues/683
+# - https://github.com/pypa/pip/pull/6370
+# - https://pip.pypa.io/en/stable/reference/pip/#pep-517-and-518-support
+black:
+	@black . --config=black.toml
+
+black-check:
+	@black . --config=black.toml --check
+
+flake8:
+	@flake8 .
+
+isort:
+	@isort --apply
+
+isort-check:
+	@isort --check
+
+format:
+	@make black
+	@make isort
+	@make black-check
+
+format-check: black-check isort-check
+
+
+###
+# Installation
+
+install:
+	@pip install -r requirements.txt
+	@pip install .
+
+install-lint:
+	@pip install -r requirements/lint.txt
+
+install-test:
+	@pip install -r requirements/test.txt
+
+install-devtools:
+	@echo "Install dev tools..."
+	@pip install -r requirements-dev.txt
+
+install-dev: install-devtools
+	@echo "Install dependencies..."
+	@pip install -r requirements.txt
+	@echo "Uninstall existing version of backpack..."
+	@pip uninstall backpack-for-pytorch
+	@echo "Install backpack in editable mode..."
+	@pip install -e .
+	@echo "Install pre-commit hooks..."
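# [Editor's note -- aside on examples/run_examples.py earlier in this hunk, not
# part of the patch.] The runner checks the exit code of `subprocess.call` by
# hand. The standard library offers the same behaviour via
# `subprocess.run(..., check=True)`, which raises `CalledProcessError` on a
# non-zero exit code; `sys.executable` avoids hard-coding the interpreter name.
import subprocess
import sys


def run_example(path):
    """Run one example script and raise if it fails."""
    subprocess.run([sys.executable, path], check=True)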
+ @pre-commit install + +### +# Conda environment +conda-env: + @conda env create --file .conda_env.yml diff --git a/requirements-dev.txt b/requirements-dev.txt new file mode 100644 index 00000000..ab0f1785 --- /dev/null +++ b/requirements-dev.txt @@ -0,0 +1,3 @@ +-r requirements/test.txt +-r requirements/lint.txt +pre-commit diff --git a/requirements.txt b/requirements.txt index 15446d72..d4129fd5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,3 @@ opt-einsum >= 2.3.2, < 3.0.0 -pytest >= 4.0.1, < 5.0.0 -pytest-benchmark >= 3.2.2, < 4.0.0 -torch >= 1.1.0, < 2.0.0 +torch >= 1.3.0, < 2.0.0 torchvision >= 0.3.0, < 1.0.0 diff --git a/requirements/lint.txt b/requirements/lint.txt new file mode 100644 index 00000000..6fc5c8cd --- /dev/null +++ b/requirements/lint.txt @@ -0,0 +1,10 @@ +flake8 +mccabe +pycodestyle +pyflakes +pep8-naming +flake8-bugbear +flake8-comprehensions +flake8-tidy-imports +black +isort \ No newline at end of file diff --git a/requirements/test.txt b/requirements/test.txt new file mode 100644 index 00000000..e0e09ff8 --- /dev/null +++ b/requirements/test.txt @@ -0,0 +1,7 @@ +scipy +pytest >= 4.5.0, < 5.0.0 +pytest-benchmark >= 3.2.2, < 4.0.0 +pytest-optional-tests >= 0.1.1 +pytest-cov +coveralls + diff --git a/setup.py b/setup.py index ff7d1480..397d9c37 100644 --- a/setup.py +++ b/setup.py @@ -8,10 +8,13 @@ NAME = "backpack-for-pytorch" PACKAGES = find_packages() -DESCRIPTION = r"""BACKpropagation PACKage - A backpack for PyTorch to compute quantities beyond the gradient.""" -LONG_DESCR = "https://github.com/f-dangel/backpack" +DESCRIPTION = r"""BACKpropagation PACKage""" +LONG_DESCR = """ + A backpack for PyTorch to compute quantities beyond the gradient. + https://github.com/f-dangel/backpack + """ -VERSION = "1.0.1" +VERSION = "1.1.0" URL = "https://github.com/f-dangel/backpack" LICENSE = "MIT" @@ -23,15 +26,17 @@ with open(REQUIREMENTS_FILE) as f: requirements = f.read().splitlines() -setup(author=AUTHORS, - name=NAME, - version=VERSION, - description=DESCRIPTION, - long_description=LONG_DESCR, - long_description_content_type="text/markdown", - install_requires=requirements, - url=URL, - license=LICENSE, - packages=PACKAGES, - zip_safe=False, - python_requires='>=3.5') +setup( + author=AUTHORS, + name=NAME, + version=VERSION, + description=DESCRIPTION, + long_description=LONG_DESCR, + long_description_content_type="text/markdown", + install_requires=requirements, + url=URL, + license=LICENSE, + packages=PACKAGES, + zip_safe=False, + python_requires=">=3.5", +) diff --git a/test/automated_bn_test.py b/test/automated_bn_test.py index 4028bc15..58e5e95d 100644 --- a/test/automated_bn_test.py +++ b/test/automated_bn_test.py @@ -1,9 +1,8 @@ """ TODO: Implement all features for BN, then add to automated tests. 
""" import pytest - import torch -from .automated_test import atol, check_sizes, check_values, rtol +from .automated_test import check_sizes, check_values from .implementation.implementation_autograd import AutogradImpl from .implementation.implementation_bpext import BpextImpl from .test_problems_bn import TEST_PROBLEMS as BN_TEST_PROBLEMS @@ -26,35 +25,31 @@ CONFIGURATION_IDS = [] for dev_name, dev in DEVICES.items(): for probname, prob in TEST_PROBLEMS.items(): - ALL_CONFIGURATIONS.append(tuple([prob, dev])) + ALL_CONFIGURATIONS.append((prob, dev)) CONFIGURATION_IDS.append(probname + "-" + dev_name) ### # Tests ### -@pytest.mark.parametrize("problem,device", - ALL_CONFIGURATIONS, - ids=CONFIGURATION_IDS) +@pytest.mark.parametrize("problem,device", ALL_CONFIGURATIONS, ids=CONFIGURATION_IDS) def test_batch_gradients_sum_to_grad(problem, device): problem.to(device) backpack_batch_res = BpextImpl(problem).batch_gradients() - backpack_res = list([g.sum(0) for g in backpack_batch_res]) + backpack_res = [g.sum(0) for g in backpack_batch_res] autograd_res = AutogradImpl(problem).gradient() check_sizes(autograd_res, backpack_res, list(problem.model.parameters())) check_values(autograd_res, backpack_res) -@pytest.mark.parametrize("problem,device", - ALL_CONFIGURATIONS, - ids=CONFIGURATION_IDS) +@pytest.mark.parametrize("problem,device", ALL_CONFIGURATIONS, ids=CONFIGURATION_IDS) def test_ggn_mp(problem, device): problem.to(device) NUM_COLS = 10 matrices = [ - torch.randn(p.numel(), NUM_COLS, device=device) + torch.randn(NUM_COLS, *p.shape, device=device) for p in problem.model.parameters() ] @@ -65,19 +60,56 @@ def test_ggn_mp(problem, device): check_values(autograd_res, backpack_res) -@pytest.mark.parametrize("problem,device", - ALL_CONFIGURATIONS, - ids=CONFIGURATION_IDS) +@pytest.mark.parametrize("problem,device", ALL_CONFIGURATIONS, ids=CONFIGURATION_IDS) def test_ggn_vp(problem, device): problem.to(device) - vecs = [ - torch.randn(p.numel(), device=device) - for p in problem.model.parameters() - ] + vecs = [torch.randn(*p.shape, device=device) for p in problem.model.parameters()] backpack_res = BpextImpl(problem).ggn_vp(vecs) autograd_res = AutogradImpl(problem).ggn_vp(vecs) check_sizes(autograd_res, backpack_res) check_values(autograd_res, backpack_res) + + +@pytest.mark.parametrize("problem,device", ALL_CONFIGURATIONS, ids=CONFIGURATION_IDS) +def test_hvp_is_not_implemented(problem, device): + # TODO: Rename after implementing BatchNorm R_mat_prod + problem.to(device) + + vecs = [torch.randn(*p.shape, device=device) for p in problem.model.parameters()] + + # TODO: Implement BatchNorm R_mat_prod in backpack/core/derivatives/batchnorm1d.py + try: + backpack_res = BpextImpl(problem).hvp(vecs) + except NotImplementedError: + return + + autograd_res = AutogradImpl(problem).hvp(vecs) + + check_sizes(autograd_res, backpack_res) + check_values(autograd_res, backpack_res) + + +@pytest.mark.parametrize("problem,device", ALL_CONFIGURATIONS, ids=CONFIGURATION_IDS) +def test_hmp_is_not_implemented(problem, device): + # TODO: Rename after implementing BatchNorm R_mat_prod + problem.to(device) + + NUM_COLS = 10 + matrices = [ + torch.randn(NUM_COLS, *p.shape, device=device) + for p in problem.model.parameters() + ] + + # TODO: Implement BatchNorm R_mat_prod in backpack/core/derivatives/batchnorm1d.py + try: + backpack_res = BpextImpl(problem).hmp(matrices) + except NotImplementedError: + return + + autograd_res = AutogradImpl(problem).hmp(matrices) + + check_sizes(autograd_res, backpack_res) + 
check_values(autograd_res, backpack_res) diff --git a/test/automated_kfac_test.py b/test/automated_kfac_test.py index 4fd42430..ace28402 100644 --- a/test/automated_kfac_test.py +++ b/test/automated_kfac_test.py @@ -1,11 +1,11 @@ -import torch import pytest -from .test_problems_kfacs import TEST_PROBLEMS as BATCH1_PROBLEMS -from .test_problems_kfacs import REGRESSION_PROBLEMS as BATCH1_REGRESSION_PROBLEMS -from .implementation.implementation_autograd import AutogradImpl -from .implementation.implementation_bpext import BpextImpl +import torch from .automated_test import check_sizes, check_values +from .implementation.implementation_autograd import AutogradImpl +from .implementation.implementation_bpext import BpextImpl +from .test_problems_kfacs import REGRESSION_PROBLEMS as BATCH1_REGRESSION_PROBLEMS +from .test_problems_kfacs import TEST_PROBLEMS as BATCH1_PROBLEMS if torch.cuda.is_available(): DEVICES = { @@ -25,7 +25,7 @@ CONFIGURATION_IDS = [] for dev_name, dev in DEVICES.items(): for probname, prob in BATCH1_TEST_PROBLEMS.items(): - BATCH1_CONFIGURATIONS.append(tuple([prob, dev])) + BATCH1_CONFIGURATIONS.append((prob, dev)) CONFIGURATION_IDS.append(probname + "-" + dev_name) BATCH1_TEST_REGRESSION_PROBLEMS = { @@ -36,15 +36,14 @@ REGRESSION_CONFIGURATION_IDS = [] for dev_name, dev in DEVICES.items(): for probname, prob in BATCH1_TEST_REGRESSION_PROBLEMS.items(): - BATCH1_REGRESSION_CONFIGURATIONS.append(tuple([prob, dev])) + BATCH1_REGRESSION_CONFIGURATIONS.append((prob, dev)) REGRESSION_CONFIGURATION_IDS.append(probname + "-" + dev_name) ### # Tests ### -@pytest.mark.parametrize( - "problem,device", BATCH1_CONFIGURATIONS, ids=CONFIGURATION_IDS) +@pytest.mark.parametrize("problem,device", BATCH1_CONFIGURATIONS, ids=CONFIGURATION_IDS) def test_kfra_should_equal_ggn(problem, device): problem.to(device) @@ -55,8 +54,7 @@ def test_kfra_should_equal_ggn(problem, device): check_values(autograd_res, backpack_res) -@pytest.mark.parametrize( - "problem,device", BATCH1_CONFIGURATIONS, ids=CONFIGURATION_IDS) +@pytest.mark.parametrize("problem,device", BATCH1_CONFIGURATIONS, ids=CONFIGURATION_IDS) def test_kflr_should_equal_ggn(problem, device): problem.to(device) @@ -67,8 +65,7 @@ def test_kflr_should_equal_ggn(problem, device): check_values(autograd_res, backpack_res) -@pytest.mark.parametrize( - "problem,device", BATCH1_CONFIGURATIONS, ids=CONFIGURATION_IDS) +@pytest.mark.parametrize("problem,device", BATCH1_CONFIGURATIONS, ids=CONFIGURATION_IDS) def test_hbp_ggn_mode_should_equal_ggn(problem, device): problem.to(device) @@ -79,8 +76,7 @@ def test_hbp_ggn_mode_should_equal_ggn(problem, device): check_values(autograd_res, backpack_res) -@pytest.mark.parametrize( - "problem,device", BATCH1_CONFIGURATIONS, ids=CONFIGURATION_IDS) +@pytest.mark.parametrize("problem,device", BATCH1_CONFIGURATIONS, ids=CONFIGURATION_IDS) def test_hbp_h_mode_should_equal_h(problem, device): problem.to(device) @@ -92,9 +88,8 @@ def test_hbp_h_mode_should_equal_h(problem, device): @pytest.mark.parametrize( - "problem,device", - BATCH1_REGRESSION_CONFIGURATIONS, - ids=REGRESSION_CONFIGURATION_IDS) + "problem,device", BATCH1_REGRESSION_CONFIGURATIONS, ids=REGRESSION_CONFIGURATION_IDS +) def test_kfac_regression_should_equal_ggn(problem, device): problem.to(device) @@ -103,3 +98,28 @@ def test_kfac_regression_should_equal_ggn(problem, device): check_sizes(autograd_res, backpack_res) check_values(autograd_res, backpack_res) + + +@pytest.mark.montecarlo +@pytest.mark.parametrize("problem,device", BATCH1_CONFIGURATIONS, 
ids=CONFIGURATION_IDS) +def test_kfac_should_approx_ggn_montecarlo(problem, device): + problem.to(device) + + torch.manual_seed(0) + autograd_res = AutogradImpl(problem).ggn_blocks() + + backpack_average_res = [] + for param_res in autograd_res: + backpack_average_res.append(torch.zeros_like(param_res)) + + mc_samples = 200 + for _ in range(mc_samples): + backpack_res = BpextImpl(problem).kfac_blocks() + for i, param_res in enumerate(backpack_res): + backpack_average_res[i] += param_res + + for i in range(len(backpack_average_res)): + backpack_average_res[i] /= mc_samples + + check_sizes(autograd_res, backpack_average_res) + check_values(autograd_res, backpack_average_res, atol=1e-1, rtol=1e-1) diff --git a/test/automated_test.py b/test/automated_test.py index e10c97b9..676db9a5 100644 --- a/test/automated_test.py +++ b/test/automated_test.py @@ -1,13 +1,14 @@ -import torch import numpy as np import pytest +import torch + +from .implementation.implementation_autograd import AutogradImpl +from .implementation.implementation_bpext import BpextImpl +from .test_problems_activations import TEST_PROBLEMS as ACT_TEST_PROBLEMS from .test_problems_convolutions import TEST_PROBLEMS as CONV_TEST_PROBLEMS from .test_problems_linear import TEST_PROBLEMS as LIN_TEST_PROBLEMS -from .test_problems_activations import TEST_PROBLEMS as ACT_TEST_PROBLEMS -from .test_problems_pooling import TEST_PROBLEMS as POOL_TEST_PROBLEMS from .test_problems_padding import TEST_PROBLEMS as PAD_TEST_PROBLEMS -from .implementation.implementation_autograd import AutogradImpl -from .implementation.implementation_bpext import BpextImpl +from .test_problems_pooling import TEST_PROBLEMS as POOL_TEST_PROBLEMS if torch.cuda.is_available(): DEVICES = { @@ -30,7 +31,7 @@ CONFIGURATION_IDS = [] for dev_name, dev in DEVICES.items(): for probname, prob in TEST_PROBLEMS.items(): - ALL_CONFIGURATIONS.append(tuple([prob, dev])) + ALL_CONFIGURATIONS.append((prob, dev)) CONFIGURATION_IDS.append(probname + "-" + dev_name) atol = 1e-5 @@ -49,7 +50,7 @@ def report_nonclose_values(x, y): where_not_close = np.argwhere(np.logical_not(close)) for idx in where_not_close: x, y = x_numpy[idx], y_numpy[idx] - print('{} versus {}. Ratio of {}'.format(x, y, y / x)) + print("{} versus {}. 
Ratio of {}".format(x, y, y / x)) def check_sizes(*plists): @@ -61,7 +62,7 @@ def check_sizes(*plists): assert params[i].size() == params[i + 1].size() -def check_values(list1, list2): +def check_values(list1, list2, atol=atol, rtol=rtol): for i, (g1, g2) in enumerate(zip(list1, list2)): print(i) print(g1.size()) @@ -74,8 +75,7 @@ def check_values(list1, list2): ### -@pytest.mark.parametrize( - "problem,device", ALL_CONFIGURATIONS, ids=CONFIGURATION_IDS) +@pytest.mark.parametrize("problem,device", ALL_CONFIGURATIONS, ids=CONFIGURATION_IDS) def test_batch_gradients(problem, device): problem.to(device) backpack_res = BpextImpl(problem).batch_gradients() @@ -85,20 +85,18 @@ def test_batch_gradients(problem, device): check_values(autograd_res, backpack_res) -@pytest.mark.parametrize( - "problem,device", ALL_CONFIGURATIONS, ids=CONFIGURATION_IDS) +@pytest.mark.parametrize("problem,device", ALL_CONFIGURATIONS, ids=CONFIGURATION_IDS) def test_batch_gradients_sum_to_grad(problem, device): problem.to(device) backpack_batch_res = BpextImpl(problem).batch_gradients() - backpack_res = list([g.sum(0) for g in backpack_batch_res]) + backpack_res = [g.sum(0) for g in backpack_batch_res] autograd_res = AutogradImpl(problem).gradient() check_sizes(autograd_res, backpack_res, list(problem.model.parameters())) check_values(autograd_res, backpack_res) -@pytest.mark.parametrize( - "problem,device", ALL_CONFIGURATIONS, ids=CONFIGURATION_IDS) +@pytest.mark.parametrize("problem,device", ALL_CONFIGURATIONS, ids=CONFIGURATION_IDS) def test_sgs(problem, device): problem.to(device) autograd_res = AutogradImpl(problem).sgs() @@ -108,8 +106,7 @@ def test_sgs(problem, device): check_values(autograd_res, backpack_res) -@pytest.mark.parametrize( - "problem,device", ALL_CONFIGURATIONS, ids=CONFIGURATION_IDS) +@pytest.mark.parametrize("problem,device", ALL_CONFIGURATIONS, ids=CONFIGURATION_IDS) def test_diag_ggn(problem, device): problem.to(device) @@ -120,8 +117,32 @@ def test_diag_ggn(problem, device): check_values(autograd_res, backpack_res) -@pytest.mark.parametrize( - "problem,device", ALL_CONFIGURATIONS, ids=CONFIGURATION_IDS) +@pytest.mark.montecarlo +@pytest.mark.parametrize("problem,device", ALL_CONFIGURATIONS, ids=CONFIGURATION_IDS) +def test_diag_ggn_mc_approx_ggn_montecarlo(problem, device): + problem.to(device) + + torch.manual_seed(0) + bp_diagggn = BpextImpl(problem).diag_ggn() + + bp_diagggn_mc_avg = [] + for param_res in bp_diagggn: + bp_diagggn_mc_avg.append(torch.zeros_like(param_res)) + + mc_samples = 500 + for _ in range(mc_samples): + bp_diagggn_mc = BpextImpl(problem).diag_ggn_mc() + for i, param_res in enumerate(bp_diagggn_mc): + bp_diagggn_mc_avg[i] += param_res + + for i in range(len(bp_diagggn_mc_avg)): + bp_diagggn_mc_avg[i] /= mc_samples + + check_sizes(bp_diagggn, bp_diagggn_mc_avg) + check_values(bp_diagggn, bp_diagggn_mc_avg, atol=1e-1, rtol=1e-1) + + +@pytest.mark.parametrize("problem,device", ALL_CONFIGURATIONS, ids=CONFIGURATION_IDS) def test_batch_l2(problem, device): problem.to(device) @@ -132,8 +153,7 @@ def test_batch_l2(problem, device): check_values(autograd_res, backpack_res) -@pytest.mark.parametrize( - "problem,device", ALL_CONFIGURATIONS, ids=CONFIGURATION_IDS) +@pytest.mark.parametrize("problem,device", ALL_CONFIGURATIONS, ids=CONFIGURATION_IDS) def test_variance(problem, device): problem.to(device) @@ -144,8 +164,7 @@ def test_variance(problem, device): check_values(autograd_res, backpack_res) -@pytest.mark.parametrize( - "problem,device", ALL_CONFIGURATIONS, 
ids=CONFIGURATION_IDS) +@pytest.mark.parametrize("problem,device", ALL_CONFIGURATIONS, ids=CONFIGURATION_IDS) def test_diag_h(problem, device): problem.to(device) @@ -156,14 +175,13 @@ def test_diag_h(problem, device): check_values(autograd_res, backpack_res) -@pytest.mark.parametrize( - "problem,device", ALL_CONFIGURATIONS, ids=CONFIGURATION_IDS) +@pytest.mark.parametrize("problem,device", ALL_CONFIGURATIONS, ids=CONFIGURATION_IDS) def test_hmp(problem, device): problem.to(device) NUM_COLS = 10 matrices = [ - torch.randn(p.numel(), NUM_COLS, device=device) + torch.randn(NUM_COLS, *p.shape, device=device) for p in problem.model.parameters() ] @@ -174,14 +192,13 @@ def test_hmp(problem, device): check_values(autograd_res, backpack_res) -@pytest.mark.parametrize( - "problem,device", ALL_CONFIGURATIONS, ids=CONFIGURATION_IDS) +@pytest.mark.parametrize("problem,device", ALL_CONFIGURATIONS, ids=CONFIGURATION_IDS) def test_ggn_mp(problem, device): problem.to(device) NUM_COLS = 10 matrices = [ - torch.randn(p.numel(), NUM_COLS, device=device) + torch.randn(NUM_COLS, *p.shape, device=device) for p in problem.model.parameters() ] @@ -192,15 +209,11 @@ def test_ggn_mp(problem, device): check_values(autograd_res, backpack_res) -@pytest.mark.parametrize( - "problem,device", ALL_CONFIGURATIONS, ids=CONFIGURATION_IDS) +@pytest.mark.parametrize("problem,device", ALL_CONFIGURATIONS, ids=CONFIGURATION_IDS) def test_hvp(problem, device): problem.to(device) - vecs = [ - torch.randn(p.numel(), device=device) - for p in problem.model.parameters() - ] + vecs = [torch.randn(*p.shape, device=device) for p in problem.model.parameters()] backpack_res = BpextImpl(problem).hvp(vecs) autograd_res = AutogradImpl(problem).hvp(vecs) @@ -209,15 +222,11 @@ def test_hvp(problem, device): check_values(autograd_res, backpack_res) -@pytest.mark.parametrize( - "problem,device", ALL_CONFIGURATIONS, ids=CONFIGURATION_IDS) +@pytest.mark.parametrize("problem,device", ALL_CONFIGURATIONS, ids=CONFIGURATION_IDS) def test_ggn_vp(problem, device): problem.to(device) - vecs = [ - torch.randn(p.numel(), device=device) - for p in problem.model.parameters() - ] + vecs = [torch.randn(*p.shape, device=device) for p in problem.model.parameters()] backpack_res = BpextImpl(problem).ggn_vp(vecs) autograd_res = AutogradImpl(problem).ggn_vp(vecs) diff --git a/test/benchmark/functionality.py b/test/benchmark/functionality.py index 8e81f203..d4a90ba9 100644 --- a/test/benchmark/functionality.py +++ b/test/benchmark/functionality.py @@ -1,9 +1,11 @@ -import torch import pytest +import torch + from backpack import extend -from ..test_problem import TestProblem + from ..implementation.implementation_autograd import AutogradImpl from ..implementation.implementation_bpext import BpextImpl +from ..test_problem import TestProblem def make_large_linear_classification_problem(): @@ -19,7 +21,7 @@ def make_large_linear_classification_problem(): ) N = 128 X = torch.randn(size=(N, Ds[0])) - Y = torch.randint(high=Ds[-1], size=(N, )) + Y = torch.randint(high=Ds[-1], size=(N,)) lossfunc = extend(torch.nn.CrossEntropyLoss()) return TestProblem(X, Y, model, lossfunc) @@ -33,7 +35,7 @@ def make_smallest_linear_classification_problem(): ) N = 16 X = torch.randn(size=(N, Ds[0])) - Y = torch.randint(high=Ds[-1], size=(N, )) + Y = torch.randint(high=Ds[-1], size=(N,)) lossfunc = extend(torch.nn.CrossEntropyLoss()) return TestProblem(X, Y, model, lossfunc) @@ -47,7 +49,7 @@ def make_small_linear_classification_problem(): ) N = 32 X = torch.randn(size=(N, Ds[0])) - Y = 
torch.randint(high=Ds[-1], size=(N, )) + Y = torch.randint(high=Ds[-1], size=(N,)) lossfunc = extend(torch.nn.CrossEntropyLoss()) return TestProblem(X, Y, model, lossfunc) @@ -61,41 +63,36 @@ def make_small_linear_classification_problem(): ALL_CONFIGURATIONS = [] CONFIGURATION_IDS = [] for probname, prob in reversed(list(TEST_PROBLEMS.items())): - ALL_CONFIGURATIONS.append(tuple([prob, AutogradImpl])) + ALL_CONFIGURATIONS.append((prob, AutogradImpl)) CONFIGURATION_IDS.append(probname + "-autograd") - ALL_CONFIGURATIONS.append(tuple([prob, BpextImpl])) + ALL_CONFIGURATIONS.append((prob, BpextImpl)) CONFIGURATION_IDS.append(probname + "-bpext") -@pytest.mark.parametrize( - "problem,impl", ALL_CONFIGURATIONS, ids=CONFIGURATION_IDS) +@pytest.mark.parametrize("problem,impl", ALL_CONFIGURATIONS, ids=CONFIGURATION_IDS) def test_diag_ggn(problem, impl, tmp_path, benchmark): if "large_autograd" in str(tmp_path): pytest.skip() benchmark(impl(problem).diag_ggn) -@pytest.mark.parametrize( - "problem,impl", ALL_CONFIGURATIONS, ids=CONFIGURATION_IDS) +@pytest.mark.parametrize("problem,impl", ALL_CONFIGURATIONS, ids=CONFIGURATION_IDS) def test_sgs(problem, impl, benchmark): benchmark(impl(problem).sgs) -@pytest.mark.parametrize( - "problem,impl", ALL_CONFIGURATIONS, ids=CONFIGURATION_IDS) +@pytest.mark.parametrize("problem,impl", ALL_CONFIGURATIONS, ids=CONFIGURATION_IDS) def test_batch_gradients(problem, impl, benchmark): benchmark(impl(problem).batch_gradients) @pytest.mark.skip() -@pytest.mark.parametrize( - "problem,impl", ALL_CONFIGURATIONS, ids=CONFIGURATION_IDS) +@pytest.mark.parametrize("problem,impl", ALL_CONFIGURATIONS, ids=CONFIGURATION_IDS) def test_var(problem, impl, benchmark): raise NotImplementedError -@pytest.mark.parametrize( - "problem,impl", ALL_CONFIGURATIONS, ids=CONFIGURATION_IDS) +@pytest.mark.parametrize("problem,impl", ALL_CONFIGURATIONS, ids=CONFIGURATION_IDS) def test_diag_h(problem, impl, tmp_path, benchmark): if "large_autograd" in str(tmp_path): pytest.skip() diff --git a/test/benchmark/jacobians.py b/test/benchmark/jacobians.py deleted file mode 100644 index 7267e4fd..00000000 --- a/test/benchmark/jacobians.py +++ /dev/null @@ -1,78 +0,0 @@ -import torch -import pytest - -from torch.nn import Linear -from torch.nn import ReLU, Sigmoid, Tanh -from torch.nn import Conv2d, MaxPool2d, AvgPool2d -from torch.nn import Dropout -from torch.nn import MSELoss, CrossEntropyLoss -from torch.nn import Sequential - -from .implementation.implementation_autograd import AutogradImpl -from .implementation.implementation_bpext import BpextImpl -from .test_problem import TestProblem - -from backpack.core.layers import Flatten -from backpack import extend as xtd - - -N, C, H, W = 100, 3, 4, 4 -D = C * H * W - - -def X_Y(input_type, output_type): - - if input_type is "IMAGE": - X = torch.randn(N, C, H, W) - elif input_type is "LINEAR": - X = torch.randn(N, D) - else: - raise NotImplementedError - - if output_type is "CE": - Y = torch.randint(high=2, size=(N, )) - else: - raise NotImplementedError - - return (X, Y) - - -models = [ - Sequential(xtd(Linear(D, 2))), - Sequential(xtd(Linear(D, 2)), xtd(ReLU())), - Sequential(xtd(Linear(D, 2)), xtd(Sigmoid())), - Sequential(xtd(Linear(D, 2)), xtd(Tanh())), - Sequential(xtd(Linear(D, 2)), xtd(Dropout())), -] -img_models = [ - Sequential(xtd(Conv2d(3, 2, 2)), Flatten(), xtd(Linear(18, 2))), - Sequential(xtd(MaxPool2d(3)), Flatten(), xtd(Linear(3, 2))), - Sequential(xtd(AvgPool2d(3)), Flatten(), xtd(Linear(3, 2))), - # Sequential(xtd(Conv2d(3, 2, 2)), 
xtd(MaxPool2d(3)), Flatten(), xtd(Linear(2, 2))), - # Sequential(xtd(Conv2d(3, 2, 2)), xtd(AvgPool2d(3)), Flatten(), xtd(Linear(2, 2))), - # Sequential(xtd(Conv2d(3, 2, 2)), xtd(ReLU()), Flatten(), xtd(Linear(18, 2))), - # Sequential(xtd(Conv2d(3, 2, 2)), xtd(Sigmoid()), Flatten(), xtd(Linear(18, 2))), - # Sequential(xtd(Conv2d(3, 2, 2)), xtd(Tanh()), Flatten(), xtd(Linear(18, 2))), - # Sequential(xtd(Conv2d(3, 2, 2)), xtd(Dropout()), Flatten(), xtd(Linear(18, 2))), -] - - -def all_problems(): - problems = [] - for model in models: - problems.append( - TestProblem(*X_Y("LINEAR", "CE"), model, xtd(CrossEntropyLoss()))) - for model in img_models: - problems.append( - TestProblem(*X_Y("IMAGE", "CE"), model, xtd(CrossEntropyLoss()))) - return problems - - -@pytest.mark.parametrize("problem", all_problems()) -def test_all_jacobian_ag(problem, benchmark): - benchmark(AutogradImpl(problem).gradient) - - -@pytest.mark.parametrize("problem", all_problems()) -def test_all_jacobian_bp(problem, benchmark): - benchmark(BpextImpl(problem).diag_ggn) diff --git a/test/benchmark/jvp.py b/test/benchmark/jvp.py index 9eefbac3..7e5c8925 100644 --- a/test/benchmark/jvp.py +++ b/test/benchmark/jvp.py @@ -2,18 +2,19 @@ import pytest import torch +from torch import allclose +from torch.nn import Dropout, ReLU, Sigmoid, Tanh + from backpack.core.derivatives import derivatives_for from backpack.hessianfree.lop import transposed_jacobian_vector_product from backpack.hessianfree.rop import jacobian_vector_product -from torch import allclose -from .jvp_linear import data_linear, data_linearconcat -from .jvp_conv2d import data_conv2d, data_conv2dconcat +from .jvp_activations import data as data_activation from .jvp_avgpool2d import data as data_avgpool2d +from .jvp_conv2d import data_conv2d +from .jvp_linear import data_linear from .jvp_maxpool2d import data as data_maxpool2d from .jvp_zeropad2d import data as data_zeropad2d -from .jvp_activations import data as data_activation -from torch.nn import Dropout, ReLU, Tanh, Sigmoid ATOL = 1e-3 RTOL = 1e-3 @@ -22,9 +23,7 @@ PROBLEMS = { "Linear": data_linear, - "LinearConcat": data_linearconcat, "Conv2d": data_conv2d, - "Conv2dConcat": data_conv2dconcat, "AvgPool2d": data_avgpool2d, "MaxPool2d": data_maxpool2d, "ZeroPad2d": data_zeropad2d, @@ -63,9 +62,9 @@ def skip_if_attribute_does_not_exists(module, attr): def ag_jtv_func(X, out, vin): def f(): - r = transposed_jacobian_vector_product( - out, X, vin, detach=False - )[0].contiguous() + r = transposed_jacobian_vector_product(out, X, vin, detach=False)[ + 0 + ].contiguous() if vin.is_cuda: torch.cuda.synchronize() return r @@ -75,9 +74,7 @@ def f(): def ag_jv_func(X, out, vout): def f(): - r = jacobian_vector_product( - out, X, vout, detach=False - )[0].contiguous() + r = jacobian_vector_product(out, X, vout, detach=False)[0].contiguous() if vout.is_cuda: torch.cuda.synchronize() return r @@ -87,9 +84,11 @@ def f(): def bp_jtv_func(module, vin): def f(): - r = derivatives_for[module.__class__]().jac_t_mat_prod( - module, None, None, vin - ).contiguous() + r = ( + derivatives_for[module.__class__]() + .jac_t_mat_prod(module, None, None, vin) + .contiguous() + ) if vin.is_cuda: torch.cuda.synchronize() return r @@ -99,9 +98,11 @@ def f(): def bp_jv_func(module, vout): def f(): - r = derivatives_for[module.__class__]().jac_mat_prod( - module, None, None, vout - ).contiguous() + r = ( + derivatives_for[module.__class__]() + .jac_mat_prod(module, None, None, vout) + .contiguous() + ) if vout.is_cuda: torch.cuda.synchronize() 
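        # [Editor's note] CUDA kernels launch asynchronously; synchronizing here
        # (as in the other benchmark helpers) makes sure a benchmark timer
        # measures the completed Jacobian-vector product rather than just the
        # kernel launch.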
return r @@ -113,9 +114,9 @@ def ag_jtv_weight_func(module, out, vin): skip_if_attribute_does_not_exists(module, "weight") def f(): - r = transposed_jacobian_vector_product( - out, module.weight, vin, detach=False - )[0].contiguous() + r = transposed_jacobian_vector_product(out, module.weight, vin, detach=False)[ + 0 + ].contiguous() if vin.is_cuda: torch.cuda.synchronize() return r @@ -127,9 +128,11 @@ def bp_jtv_weight_func(module, vin): skip_if_attribute_does_not_exists(module, "weight") def f(): - r = derivatives_for[module.__class__]().weight_jac_t_mat_prod( - module, None, None, vin - ).contiguous() + r = ( + derivatives_for[module.__class__]() + .weight_jac_t_mat_prod(module, None, None, vin) + .contiguous() + ) if vin.is_cuda: torch.cuda.synchronize() return r @@ -141,9 +144,9 @@ def ag_jtv_bias_func(module, out, vin): skip_if_attribute_does_not_exists(module, "bias") def f(): - r = transposed_jacobian_vector_product( - out, module.bias, vin, detach=False - )[0].contiguous() + r = transposed_jacobian_vector_product(out, module.bias, vin, detach=False)[ + 0 + ].contiguous() if vin.is_cuda: torch.cuda.synchronize() return r @@ -155,9 +158,11 @@ def bp_jtv_bias_func(module, vin): skip_if_attribute_does_not_exists(module, "bias") def f(): - r = derivatives_for[module.__class__]().bias_jac_t_mat_prod( - module, None, None, vin.unsqueeze(2) - ).contiguous() + r = ( + derivatives_for[module.__class__]() + .bias_jac_t_mat_prod(module, None, None, vin.unsqueeze(2)) + .contiguous() + ) if vin.is_cuda: torch.cuda.synchronize() return r diff --git a/test/benchmark/jvp_activations.py b/test/benchmark/jvp_activations.py index 3d884e1c..a3390357 100644 --- a/test/benchmark/jvp_activations.py +++ b/test/benchmark/jvp_activations.py @@ -1,4 +1,5 @@ from torch import randn + from backpack import extend diff --git a/test/benchmark/jvp_avgpool2d.py b/test/benchmark/jvp_avgpool2d.py index 879ada89..33374340 100644 --- a/test/benchmark/jvp_avgpool2d.py +++ b/test/benchmark/jvp_avgpool2d.py @@ -1,5 +1,6 @@ from torch import randn from torch.nn import AvgPool2d + from backpack import extend diff --git a/test/benchmark/jvp_conv2d.py b/test/benchmark/jvp_conv2d.py index 2a9365b6..521f46f8 100644 --- a/test/benchmark/jvp_conv2d.py +++ b/test/benchmark/jvp_conv2d.py @@ -1,6 +1,6 @@ -from backpack.core.layers import Conv2dConcat from torch import randn from torch.nn import Conv2d + from backpack import extend @@ -26,26 +26,3 @@ def data_conv2d(device="cpu"): "vin_ag": vin, "vin_bp": vin.view(N, -1, 1), } - -def data_conv2dconcat(device="cpu"): - N, Cin, Hin, Win = 100, 10, 32, 32 - Cout, KernelH, KernelW = 25, 5, 5 - - X = randn(N, Cin, Hin, Win, requires_grad=True, device=device) - module = extend(Conv2dConcat(Cin, Cout, (KernelH, KernelW))).to(device=device) - out = module(X) - - Hout = Hin - (KernelH - 1) - Wout = Win - (KernelW - 1) - vin = randn(N, Cout, Hout, Wout, device=device) - vout = randn(N, Cin, Hin, Win, device=device) - - return { - "X": X, - "module": module, - "output": out, - "vout_ag": vout, - "vout_bp": vout.view(N, -1, 1), - "vin_ag": vin, - "vin_bp": vin.view(N, -1, 1), - } \ No newline at end of file diff --git a/test/benchmark/jvp_linear.py b/test/benchmark/jvp_linear.py index a3571875..5be34c2b 100644 --- a/test/benchmark/jvp_linear.py +++ b/test/benchmark/jvp_linear.py @@ -1,6 +1,6 @@ -from backpack.core.layers import LinearConcat from torch import randn from torch.nn import Linear + from backpack import extend @@ -23,24 +23,3 @@ def data_linear(device="cpu"): "vin_ag": vin, "vin_bp": 
vin.unsqueeze(2), } - - -def data_linearconcat(device="cpu"): - N, D1, D2 = 100, 64, 256 - - X = randn(N, D1, requires_grad=True, device=device) - linear = extend(LinearConcat(D1, D2)).to(device=device) - out = linear(X) - - vin = randn(N, D2, device=device) - vout = randn(N, D1, device=device) - - return { - "X": X, - "module": linear, - "output": out, - "vout_ag": vout, - "vout_bp": vout.unsqueeze(2), - "vin_ag": vin, - "vin_bp": vin.unsqueeze(2), - } diff --git a/test/benchmark/jvp_maxpool2d.py b/test/benchmark/jvp_maxpool2d.py index ff5cb312..9cdd7e48 100644 --- a/test/benchmark/jvp_maxpool2d.py +++ b/test/benchmark/jvp_maxpool2d.py @@ -1,5 +1,6 @@ from torch import randn from torch.nn import MaxPool2d + from backpack import extend diff --git a/test/benchmark/jvp_zeropad2d.py b/test/benchmark/jvp_zeropad2d.py index deb743ca..60cb035d 100644 --- a/test/benchmark/jvp_zeropad2d.py +++ b/test/benchmark/jvp_zeropad2d.py @@ -1,10 +1,12 @@ from torch import randn from torch.nn import ZeroPad2d + from backpack import extend + def data(device="cpu"): N, C, Hin, Win = 100, 10, 32, 32 - padding = [1,2,3,4] + padding = [1, 2, 3, 4] Hout = Hin + padding[2] + padding[3] Wout = Win + padding[0] + padding[1] diff --git a/test/bugfixes_test.py b/test/bugfixes_test.py new file mode 100644 index 00000000..958ec760 --- /dev/null +++ b/test/bugfixes_test.py @@ -0,0 +1,59 @@ +import itertools + +import pytest +import torch + +import backpack + + +def parameters_issue_30(): + possible_values = { + "N": [4], + "C_in": [4], + "C_out": [6], + "H": [6], + "W": [6], + "K": [3], + "S": [1, 3], + "pad": [0, 2], + "dil": [1, 2], + } + + configs = [ + dict(zip(possible_values.keys(), config_tuple)) + for config_tuple in itertools.product(*possible_values.values()) + ] + + return { + "argvalues": configs, + "ids": [str(config) for config in configs], + } + + +@pytest.mark.parametrize("params", **parameters_issue_30()) +def test_convolutions_stride_issue_30(params): + """ + https://github.com/f-dangel/backpack/issues/30 + + The gradient for the convolution is wrong when `stride` is not a multiple of + `D + 2*padding - dilation*(kernel-1) - 1`. + """ + torch.manual_seed(0) + + mod = torch.nn.Conv2d( + in_channels=params["C_in"], + out_channels=params["C_out"], + kernel_size=params["K"], + stride=params["S"], + padding=params["pad"], + dilation=params["dil"], + ) + backpack.extend(mod) + x = torch.randn(size=(params["N"], params["C_in"], params["W"], params["H"])) + + with backpack.backpack(backpack.extensions.BatchGrad()): + loss = torch.sum(mod(x)) + loss.backward() + + for p in mod.parameters(): + assert torch.allclose(p.grad, p.grad_batch.sum(0)) diff --git a/test/conv2d_test.py b/test/conv2d_test.py index 4a85de84..0421c869 100644 --- a/test/conv2d_test.py +++ b/test/conv2d_test.py @@ -5,12 +5,14 @@ Chellapilla: High Performance Convolutional Neural Networks for Document Processing (2007). 
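# [Editor's note -- illustrative sketch, not part of the patch.] The Chellapilla
# et al. (2007) reference above casts 2d convolution as a matrix multiplication
# (the im2col trick). A minimal version of that identity, using
# `torch.nn.functional.unfold`, for stride 1, no padding, and no bias:
import torch
import torch.nn.functional as F


def conv2d_as_matmul(x, weight):
    """Convolution via patch extraction followed by one matrix multiplication."""
    N, C, H, W = x.shape
    C_out, _, kH, kW = weight.shape
    H_out, W_out = H - kH + 1, W - kW + 1
    # (N, C * kH * kW, H_out * W_out): every column is one flattened input patch
    cols = F.unfold(x, kernel_size=(kH, kW))
    # (C_out, C * kH * kW) @ (N, C * kH * kW, L) broadcasts over the batch
    out = weight.reshape(C_out, -1) @ cols
    return out.reshape(N, C_out, H_out, W_out)


x, w = torch.randn(2, 3, 8, 8), torch.randn(4, 3, 3, 3)
assert torch.allclose(conv2d_as_matmul(x, w), F.conv2d(x, w), atol=1e-4)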
""" +from random import choice, randint + import pytest -from torch import (Tensor, randn, allclose) +from torch import Tensor, allclose, randn from torch.nn import Conv2d -from random import (randint, choice) -from backpack import extend, backpack + import backpack.extensions as new_ext +from backpack import backpack, extend def ExtConv2d(*args, **kwargs): @@ -19,13 +21,15 @@ def ExtConv2d(*args, **kwargs): TEST_ATOL = 1e-4 + ### # Problem settings ### -def make_conv_params(in_channels, out_channels, kernel_size, stride, padding, - dilation, bias, kernel): +def make_conv_params( + in_channels, out_channels, kernel_size, stride, padding, dilation, bias, kernel +): return { "in_channels": in_channels, "out_channels": out_channels, @@ -46,7 +50,8 @@ def make_conv_layer(LayerClass, conv_params): stride=conv_params["stride"], padding=conv_params["padding"], dilation=conv_params["dilation"], - bias=conv_params["bias"]) + bias=conv_params["bias"], + ) layer.weight.data = conv_params["kernel"] return layer @@ -57,8 +62,9 @@ def make_conv_layer(LayerClass, conv_params): kernel21 = [[1, 0], [0, 1]] kernel22 = [[2, 1], [2, 1]] kernel23 = [[1, 2], [2, 0]] -kernel = Tensor([[kernel11, kernel12, kernel13], - [kernel21, kernel22, kernel23]]).float() +kernel = Tensor( + [[kernel11, kernel12, kernel13], [kernel21, kernel22, kernel23]] +).float() CONV_PARAMS = make_conv_params( in_channels=3, @@ -68,7 +74,8 @@ def make_conv_layer(LayerClass, conv_params): padding=(0, 0), dilation=(1, 1), bias=False, - kernel=kernel) + kernel=kernel, +) # input (1 sample) in_feature1 = [[1, 2, 0], [1, 1, 3], [0, 2, 2]] @@ -88,7 +95,7 @@ def make_conv_layer(LayerClass, conv_params): def loss_function(tensor): """Test loss function. Sum over squared entries.""" - return ((tensor.contiguous().view(-1))**2).sum() + return ((tensor.contiguous().view(-1)) ** 2).sum() def test_forward(): @@ -103,19 +110,34 @@ def test_forward(): def random_convolutions_and_inputs( - in_channels=randint(1, 3), - out_channels=randint(1, 3), - kernel_size=(randint(1, 3), randint(1, 3)), - stride=(randint(1, 3), randint(1, 3)), - padding=(randint(0, 2), randint(0, 2)), - dilation=(randint(1, 3), randint(1, 3)), - bias=choice([True, False]), - batch_size=randint(1, 3), - in_size=(randint(8, 12), randint(8, 12))): + in_channels=None, + out_channels=None, + kernel_size=None, + stride=None, + padding=None, + dilation=None, + bias=None, + batch_size=None, + in_size=None, +): """Return same torch/exts 2d conv modules and random inputs. Arguments can be fixed by handing them over. 
""" + + def __replace_if_none(var, by): + return by if var is None else var + + in_channels = __replace_if_none(in_channels, randint(1, 3)) + out_channels = __replace_if_none(out_channels, randint(1, 3)) + kernel_size = __replace_if_none(kernel_size, (randint(1, 3), randint(1, 3))) + stride = __replace_if_none(stride, (randint(1, 3), randint(1, 3))) + padding = __replace_if_none(padding, (randint(0, 2), randint(0, 2))) + dilation = __replace_if_none(dilation, (randint(1, 3), randint(1, 3))) + bias = __replace_if_none(bias, choice([True, False])) + batch_size = __replace_if_none(batch_size, randint(1, 3)) + in_size = __replace_if_none(in_size, (randint(8, 12), randint(8, 12))) + kernel_shape = (out_channels, in_channels) + kernel_size kernel = randn(kernel_shape) in_shape = (batch_size, in_channels) + in_size @@ -129,7 +151,8 @@ def random_convolutions_and_inputs( padding=padding, dilation=dilation, bias=bias, - kernel=kernel) + kernel=kernel, + ) conv2d = make_conv_layer(Conv2d, conv_params) g_conv2d = make_conv_layer(ExtConv2d, conv_params) @@ -155,24 +178,25 @@ def compare_grads(conv2d, g_conv2d, input): assert allclose(g_conv2d.bias.grad, conv2d.bias.grad, atol=TEST_ATOL) assert allclose(g_conv2d.weight.grad, conv2d.weight.grad, atol=TEST_ATOL) + assert allclose(g_conv2d.bias.grad_batch.sum(0), conv2d.bias.grad, atol=TEST_ATOL) assert allclose( - g_conv2d.bias.grad_batch.sum(0), conv2d.bias.grad, atol=TEST_ATOL) - assert allclose( - g_conv2d.weight.grad_batch.sum(0), conv2d.weight.grad, atol=TEST_ATOL) + g_conv2d.weight.grad_batch.sum(0), conv2d.weight.grad, atol=TEST_ATOL + ) @pytest.mark.skip("Test does not consistently fail or pass") def test_random_grad(random_runs=10): """Compare bias gradients for a single sample.""" - for i in range(random_runs): + for _ in range(random_runs): conv2d, g_conv2d, input = random_convolutions_and_inputs( - bias=True, batch_size=1) + bias=True, batch_size=1 + ) compare_grads(conv2d, g_conv2d, input) @pytest.mark.skip("Test does not consistently fail or pass") def test_random_grad_batch(random_runs=10): """Check bias gradients for a batch.""" - for i in range(random_runs): + for _ in range(random_runs): conv2d, g_conv2d, input = random_convolutions_and_inputs(bias=True) compare_grads(conv2d, g_conv2d, input) diff --git a/test/implementation/implementation.py b/test/implementation/implementation.py index 06c23ca0..14a0aff7 100644 --- a/test/implementation/implementation.py +++ b/test/implementation/implementation.py @@ -1,4 +1,4 @@ -class Implementation(): +class Implementation: def __init__(self, test_problem, device=None): self.problem = test_problem self.model = self.problem.model diff --git a/test/implementation/implementation_autograd.py b/test/implementation/implementation_autograd.py index 52d000dd..b07dd6f5 100644 --- a/test/implementation/implementation_autograd.py +++ b/test/implementation/implementation_autograd.py @@ -1,11 +1,11 @@ import torch -from .implementation import Implementation -from backpack.hessianfree.hvp import hessian_vector_product from backpack.hessianfree.ggnvp import ggn_vector_product_from_plist -from backpack.hessianfree.lop import L_op +from backpack.hessianfree.hvp import hessian_vector_product from backpack.hessianfree.rop import R_op -from backpack.hessianfree.utils import vector_to_parameter_list +from backpack.utils.convert_parameters import vector_to_parameter_list + +from .implementation import Implementation class AutogradImpl(Implementation): @@ -19,8 +19,7 @@ def batch_gradients(self): ] for b in 
range(self.N): - gradients = torch.autograd.grad( - self.loss(b), self.model.parameters()) + gradients = torch.autograd.grad(self.loss(b), self.model.parameters()) for idx, g in enumerate(gradients): batch_grads[idx][b, :] = g.detach() / self.N @@ -28,8 +27,7 @@ def batch_gradients(self): def batch_l2(self): batch_grad = self.batch_gradients() - batch_l2 = [(g**2).sum(list(range(1, len(g.shape)))) - for g in batch_grad] + batch_l2 = [(g ** 2).sum(list(range(1, len(g.shape)))) for g in batch_grad] return batch_l2 def variance(self): @@ -41,10 +39,9 @@ def sgs(self): sgs = self.plist_like(self.model.parameters()) for b in range(self.N): - gradients = torch.autograd.grad( - self.loss(b), self.model.parameters()) + gradients = torch.autograd.grad(self.loss(b), self.model.parameters()) for idx, g in enumerate(gradients): - sgs[idx] += (g.detach() / self.N)**2 + sgs[idx] += (g.detach() / self.N) ** 2 return sgs @@ -54,7 +51,7 @@ def diag_ggn(self): def extract_ith_element_of_diag_ggn(i, p): v = torch.zeros(p.numel()).to(self.device) - v[i] = 1. + v[i] = 1.0 vs = vector_to_parameter_list(v, [p]) GGN_vs = ggn_vector_product_from_plist(loss, outputs, [p], vs) GGN_v = torch.cat([g.detach().view(-1) for g in GGN_vs]) @@ -65,8 +62,7 @@ def extract_ith_element_of_diag_ggn(i, p): diag_ggn_p = torch.zeros_like(p).view(-1) for parameter_index in range(p.numel()): - diag_value = extract_ith_element_of_diag_ggn( - parameter_index, p) + diag_value = extract_ith_element_of_diag_ggn(parameter_index, p) diag_ggn_p[parameter_index] = diag_value diag_ggns.append(diag_ggn_p.view(p.size())) @@ -74,16 +70,15 @@ def extract_ith_element_of_diag_ggn(i, p): return diag_ggns def diag_h(self): - loss = self.problem.lossfunc( - self.model(self.problem.X), self.problem.Y) + loss = self.problem.lossfunc(self.model(self.problem.X), self.problem.Y) def hvp(df_dx, x, v): Hv = R_op(df_dx, x, v) - return tuple([j.detach() for j in Hv]) + return [j.detach() for j in Hv] def extract_ith_element_of_diag_h(i, p, df_dx): v = torch.zeros(p.numel()).to(self.device) - v[i] = 1. 
+ v[i] = 1.0 vs = vector_to_parameter_list(v, [p]) Hvs = hvp(df_dx, [p], vs) @@ -95,11 +90,9 @@ def extract_ith_element_of_diag_h(i, p, df_dx): for p in list(self.model.parameters()): diag_h_p = torch.zeros_like(p).view(-1) - df_dx = torch.autograd.grad( - loss, [p], create_graph=True, retain_graph=True) + df_dx = torch.autograd.grad(loss, [p], create_graph=True, retain_graph=True) for parameter_index in range(p.numel()): - diag_value = extract_ith_element_of_diag_h( - parameter_index, p, df_dx) + diag_value = extract_ith_element_of_diag_h(parameter_index, p, df_dx) diag_h_p[parameter_index] = diag_value diag_hs.append(diag_h_p.view(p.size())) @@ -109,20 +102,26 @@ def extract_ith_element_of_diag_h(i, p, df_dx): def h_blocks(self): mat_list = [] for p in self.model.parameters(): - mat_list.append(torch.eye(p.numel(), device=p.device)) - return self.hmp(mat_list) + mat_list.append( + torch.eye(p.numel(), device=p.device).reshape(p.numel(), *p.shape) + ) + # return self.hmp(mat_list) + hmp_list = self.hmp(mat_list) + return [ + mat.reshape(p.numel(), p.numel()) + for mat, p in zip(hmp_list, self.model.parameters()) + ] def hvp(self, vec_list): - mat_list = [vec.unsqueeze(-1) for vec in vec_list] + mat_list = [vec.unsqueeze(0) for vec in vec_list] results = self.hmp(mat_list) - results_vec = [mat.squeeze(-1) for mat in results] + results_vec = [mat.squeeze(0) for mat in results] return results_vec def hmp(self, mat_list): assert len(mat_list) == len(list(self.model.parameters())) - loss = self.problem.lossfunc( - self.model(self.problem.X), self.problem.Y) + loss = self.problem.lossfunc(self.model(self.problem.X), self.problem.Y) results = [] for p, mat in zip(self.model.parameters(), mat_list): @@ -132,22 +131,29 @@ def hmp(self, mat_list): def hvp_applied_columnwise(self, f, p, mat): h_cols = [] - for i in range(mat.size(1)): - hvp_col_i = hessian_vector_product(f, [p], mat[:, i].view_as(p))[0] - h_cols.append(hvp_col_i.view(-1, 1)) + for i in range(mat.size(0)): + hvp_col_i = hessian_vector_product(f, [p], mat[i, :])[0] + h_cols.append(hvp_col_i.unsqueeze(0)) - return torch.cat(h_cols, dim=1) + return torch.cat(h_cols, dim=0) def ggn_blocks(self): mat_list = [] for p in self.model.parameters(): - mat_list.append(torch.eye(p.numel(), device=p.device)) - return self.ggn_mp(mat_list) + mat_list.append( + torch.eye(p.numel(), device=p.device).reshape(p.numel(), *p.shape) + ) + ggn_mp_list = self.ggn_mp(mat_list) + return [ + mat.reshape(p.numel(), p.numel()) + for mat, p in zip(ggn_mp_list, self.model.parameters()) + ] + # return ggn_mp_list def ggn_vp(self, vec_list): - mat_list = [vec.unsqueeze(-1) for vec in vec_list] + mat_list = [vec.unsqueeze(0) for vec in vec_list] results = self.ggn_mp(mat_list) - results_vec = [mat.squeeze(-1) for mat in results] + results_vec = [mat.squeeze(0) for mat in results] return results_vec def ggn_mp(self, mat_list): @@ -158,21 +164,18 @@ def ggn_mp(self, mat_list): results = [] for p, mat in zip(self.model.parameters(), mat_list): - results.append( - self.ggn_vp_applied_columnwise(loss, outputs, p, mat)) + results.append(self.ggn_vp_applied_columnwise(loss, outputs, p, mat)) return results def ggn_vp_applied_columnwise(self, loss, out, p, mat): ggn_cols = [] - for i in range(mat.size(1)): - col_i = vector_to_parameter_list(mat[:, i], [p]) - - GGN_col_i = ggn_vector_product_from_plist(loss, out, [p], col_i) - GGN_col_i = torch.cat([g.detach().view(-1) for g in GGN_col_i]) - ggn_cols.append(GGN_col_i.view(-1, 1)) + for i in range(mat.size(0)): + col_i = 
mat[i, :] + GGN_col_i = ggn_vector_product_from_plist(loss, out, [p], col_i)[0] + ggn_cols.append(GGN_col_i.unsqueeze(0)) - return torch.cat(ggn_cols, dim=1) + return torch.cat(ggn_cols, dim=0) def plist_like(self, plist): - return list([torch.zeros(*p.size()).to(self.device) for p in plist]) + return [torch.zeros(*p.size()).to(self.device) for p in plist] diff --git a/test/implementation/implementation_bpext.py b/test/implementation/implementation_bpext.py index 23ba5ebb..9ee086af 100644 --- a/test/implementation/implementation_bpext.py +++ b/test/implementation/implementation_bpext.py @@ -1,14 +1,16 @@ import torch -from .implementation import Implementation + +import backpack.extensions as new_ext from backpack import backpack from backpack.extensions.curvature import Curvature -from backpack.extensions.secondorder.utils import matrix_from_kron_facs from backpack.extensions.secondorder.hbp import ( - ExpectationApproximation, BackpropStrategy, + ExpectationApproximation, LossHessianStrategy, ) -import backpack.extensions as new_ext +from backpack.utils.kroneckers import kfacs_to_mat + +from .implementation import Implementation class BpextImpl(Implementation): @@ -40,12 +42,17 @@ def sgs(self): return sgs def diag_ggn(self): - print(self.model) with backpack(new_ext.DiagGGN()): self.loss().backward() diag_ggn = [p.diag_ggn for p in self.model.parameters()] return diag_ggn + def diag_ggn_mc(self): + with backpack(new_ext.DiagGGNMC()): + self.loss().backward() + diag_ggn = [p.diag_ggn_mc for p in self.model.parameters()] + return diag_ggn + def diag_h(self): with backpack(new_ext.DiagHessian()): self.loss().backward() @@ -82,7 +89,7 @@ def matrices_from_kronecker_curvature(self, extension_cls, savefield): self.loss().backward() for p in self.model.parameters(): factors = getattr(p, savefield) - results.append(matrix_from_kron_facs(factors)) + results.append(kfacs_to_mat(factors)) return results def kfra_blocks(self): @@ -94,23 +101,26 @@ def kflr_blocks(self): def kfac_blocks(self): return self.matrices_from_kronecker_curvature(new_ext.KFAC, "kfac") - def hbp_with_curv(self, - curv_type, - loss_hessian_strategy=LossHessianStrategy.AVERAGE, - backprop_strategy=BackpropStrategy.BATCH_AVERAGE, - ea_strategy=ExpectationApproximation.BOTEV_MARTENS): + def hbp_with_curv( + self, + curv_type, + loss_hessian_strategy=LossHessianStrategy.SUM, + backprop_strategy=BackpropStrategy.BATCH_AVERAGE, + ea_strategy=ExpectationApproximation.BOTEV_MARTENS, + ): results = [] with backpack( - new_ext.HBP( - curv_type=curv_type, - loss_hessian_strategy=loss_hessian_strategy, - backprop_strategy=backprop_strategy, - ea_strategy=ea_strategy, - )): + new_ext.HBP( + curv_type=curv_type, + loss_hessian_strategy=loss_hessian_strategy, + backprop_strategy=backprop_strategy, + ea_strategy=ea_strategy, + ) + ): self.loss().backward() for p in self.model.parameters(): factors = p.hbp - results.append(matrix_from_kron_facs(factors)) + results.append(kfacs_to_mat(factors)) return results def hbp_single_sample_ggn_blocks(self): diff --git a/test/interface_test.py b/test/interface_test.py index ea0a17e2..92f64d96 100644 --- a/test/interface_test.py +++ b/test/interface_test.py @@ -3,12 +3,10 @@ """ import pytest import torch -from torch.nn import Linear, ReLU, CrossEntropyLoss -from torch.nn import Sequential -from torch.nn import Conv2d -from backpack.core.layers import Flatten -from backpack import extend, backpack +from torch.nn import Conv2d, CrossEntropyLoss, Linear, ReLU, Sequential + import backpack.extensions as 
new_ext +from backpack import backpack, extend def dummy_forward_pass(): @@ -40,7 +38,7 @@ def dummy_forward_pass_conv(): Y = torch.randint(high=5, size=(N,)) conv = Conv2d(3, 2, 2) lin = Linear(18, 5) - model = extend(Sequential(conv, Flatten(), lin)) + model = extend(Sequential(conv, torch.nn.Flatten(), lin)) loss = extend(CrossEntropyLoss()) def forward(): @@ -53,10 +51,7 @@ def forward(): forward_func_conv, weights_conv, bias_conv = dummy_forward_pass_conv() -def interface_test(feature, - weight_has_attr=True, - bias_has_attr=True, - use_conv=False): +def interface_test(feature, weight_has_attr=True, bias_has_attr=True, use_conv=False): if use_conv: f, ws, bs = forward_func_conv, weights_conv, bias_conv else: diff --git a/test/layers.py b/test/layers.py index 1106a920..2b7d18ca 100644 --- a/test/layers.py +++ b/test/layers.py @@ -1,31 +1,28 @@ -from backpack.core.layers import Conv2dConcat, LinearConcat from torch import nn LINEARS = { - 'Linear': nn.Linear, - 'LinearConcat': LinearConcat, + "Linear": nn.Linear, } ACTIVATIONS = { - 'ReLU': nn.ReLU, - 'Sigmoid': nn.Sigmoid, - 'Tanh': nn.Tanh, + "ReLU": nn.ReLU, + "Sigmoid": nn.Sigmoid, + "Tanh": nn.Tanh, } CONVS = { - 'Conv2d': nn.Conv2d, - 'Conv2dConcat': Conv2dConcat, + "Conv2d": nn.Conv2d, } PADDINGS = { - 'ZeroPad2d': nn.ZeroPad2d, + "ZeroPad2d": nn.ZeroPad2d, } POOLINGS = { - 'MaxPool2d': nn.MaxPool2d, - 'AvgPool2d': nn.AvgPool2d, + "MaxPool2d": nn.MaxPool2d, + "AvgPool2d": nn.AvgPool2d, } BN = { - 'BatchNorm1d': nn.BatchNorm1d, + "BatchNorm1d": nn.BatchNorm1d, } diff --git a/test/layers_test.py b/test/layers_test.py index 2f6d3fec..e69de29b 100644 --- a/test/layers_test.py +++ b/test/layers_test.py @@ -1,142 +0,0 @@ -"""Test batch gradient computation of linear layer.""" -from backpack import extend -from torch import allclose, randn, randint, manual_seed, cat -from torch.autograd import grad -from torch.nn import Linear, Sequential, CrossEntropyLoss, Conv2d -from backpack.core.layers import LinearConcat, Conv2dConcat, Flatten - -# Linear - - -def data(): - N = 5 - Ds = [20, 10, 3] - - X = randn(N, Ds[0]) - Y = randint(high=Ds[-1], size=(N, )) - - manual_seed(0) - model1 = Sequential( - extend(Linear(Ds[0], Ds[1])), extend(Linear(Ds[1], Ds[2]))) - - manual_seed(0) - model2 = Sequential( - extend(LinearConcat(Ds[0], Ds[1])), extend(LinearConcat(Ds[1], Ds[2]))) - - loss = CrossEntropyLoss() - - return X, Y, model1, model2, loss - - -def test_LinearConcat_forward(): - X, Y, model1, model2, loss = data() - assert allclose(model1(X), model2(X)) - - -def test_LinearConcat_backward(): - X, Y, model1, model2, loss = data() - - d1 = grad(loss(model1(X), Y), model1.parameters()) - d2 = grad(loss(model2(X), Y), model2.parameters()) - - d1 = list(d1) - d2 = list(d2) - - d1_cat = list() - - # take grad of separated parameters and concat them - for i in range(len(d2)): - d1_cat.append(cat([ - d1[2 * i], - d1[2 * i + 1].unsqueeze(-1), - ], dim=1)) - - for p1, p2 in zip(d1_cat, d2): - assert allclose(p1, p2) - - -# Conv -TEST_SETTINGS = { - "in_features": (3, 4, 5), - "out_channels": 3, - "kernel_size": (3, 2), - "padding": (1, 1), - "bias": True, - "batch": 3, - "rtol": 1e-5, - "atol": 5e-4 -} - - -def convlayer(join_params): - conv_cls = Conv2dConcat if join_params else Conv2d - return extend( - conv_cls( - in_channels=TEST_SETTINGS["in_features"][0], - out_channels=TEST_SETTINGS["out_channels"], - kernel_size=TEST_SETTINGS["kernel_size"], - padding=TEST_SETTINGS["padding"], - bias=TEST_SETTINGS["bias"])) - - -def convlayer2(join_params): - 
conv_cls = Conv2dConcat if join_params else Conv2d - return extend( - conv_cls( - in_channels=TEST_SETTINGS["in_features"][0], - out_channels=TEST_SETTINGS["out_channels"], - kernel_size=TEST_SETTINGS["kernel_size"], - padding=TEST_SETTINGS["padding"], - bias=TEST_SETTINGS["bias"])) - - -def data_conv(): - input_size = (TEST_SETTINGS["batch"], ) + TEST_SETTINGS["in_features"] - - temp_model = Sequential(convlayer(False), convlayer2(False), Flatten()) - - X = randn(size=input_size) - Y = randint(high=X.shape[1], size=(temp_model(X).shape[0], )) - - del temp_model - - manual_seed(0) - model1 = Sequential(convlayer(False), convlayer2(False), Flatten()) - - manual_seed(0) - model2 = Sequential(convlayer(True), convlayer2(True), Flatten()) - - loss = CrossEntropyLoss() - - return X, Y, model1, model2, loss - - -def test_Conv2dConcat_forward(): - X, Y, model1, model2, loss = data_conv() - assert allclose(model1(X), model2(X)) - - -def test_Conv2dConcat_backward(): - X, Y, model1, model2, loss = data_conv() - - d1 = grad(loss(model1(X), Y), model1.parameters()) - d2 = grad(loss(model2(X), Y), model2.parameters()) - - d1 = list(d1) - d2 = list(d2) - - d1_cat = list() - - # take grad of separated parameters and concat them - for i in range(len(d2)): - d1_cat.append( - cat( - [ - # require view because concat stores kernel as 2d tensor - d1[2 * i].view(d1[2 * i].shape[0], -1), - d1[2 * i + 1].unsqueeze(-1), - ], - dim=1)) - - for p1, p2 in zip(d1_cat, d2): - assert allclose(p1, p2) diff --git a/test/linear_test.py b/test/linear_test.py index b4f52468..91bc7f7a 100644 --- a/test/linear_test.py +++ b/test/linear_test.py @@ -2,8 +2,9 @@ from torch import Tensor, allclose from torch.nn import Linear -from backpack import extend, backpack + import backpack.extensions as new_ext +from backpack import backpack, extend def ExtLinear(*args, **kwargs): @@ -33,13 +34,13 @@ def make_lin_layer(LayerClass, in_features, out_features, weight, bias): def loss_function(tensor): """Test loss function. 
Sum over squared entries.""" - return ((tensor.view(-1))**2).sum() + return ((tensor.view(-1)) ** 2).sum() EXAMPLE_1 = { "in": Tensor([[1, 1, 1]]).float(), "out": Tensor([[6 + 7, 15 + 8]]).float(), - "loss": 13**2 + 23**2, + "loss": 13 ** 2 + 23 ** 2, "bias_grad": Tensor([2 * 13, 2 * 23]).float(), "bias_grad_batch": Tensor([2 * 13, 2 * 23]).float(), "weight_grad": Tensor([[26, 26, 26], [46, 46, 46]]).float(), @@ -47,20 +48,15 @@ def loss_function(tensor): } EXAMPLE_2 = { - "in": - Tensor([[1, 0, 1], [0, 1, 0]]).float(), - "out": - Tensor([[4 + 7, 10 + 8], [2 + 7, 5 + 8]]).float(), - "loss": - 11**2 + 18**2 + 9**2 + 13**2, - "bias_grad": - Tensor([2 * (11 + 9), 2 * (18 + 13)]), - "bias_grad_batch": - Tensor([[2 * 11, 2 * 18], [2 * 9, 2 * 13]]).float(), - "weight_grad": - Tensor([[22, 18, 22], [36, 26, 36]]).float(), - "weight_grad_batch": - Tensor([[[22, 0, 22], [36, 0, 36]], [[0, 18, 0], [0, 26, 0]]]).float(), + "in": Tensor([[1, 0, 1], [0, 1, 0]]).float(), + "out": Tensor([[4 + 7, 10 + 8], [2 + 7, 5 + 8]]).float(), + "loss": 11 ** 2 + 18 ** 2 + 9 ** 2 + 13 ** 2, + "bias_grad": Tensor([2 * (11 + 9), 2 * (18 + 13)]), + "bias_grad_batch": Tensor([[2 * 11, 2 * 18], [2 * 9, 2 * 13]]).float(), + "weight_grad": Tensor([[22, 18, 22], [36, 26, 36]]).float(), + "weight_grad_batch": Tensor( + [[[22, 0, 22], [36, 0, 36]], [[0, 18, 0], [0, 26, 0]]] + ).float(), } EXAMPLES = [EXAMPLE_1, EXAMPLE_2] @@ -109,18 +105,22 @@ def test_grad(): def test_grad_batch(): """Test computation of bias/weight batch gradients.""" for ex in EXAMPLES: - input, b_grad_batch, w_grad_batch = ex["in"], ex[ - "bias_grad_batch"], ex["weight_grad_batch"] + input, b_grad_batch, w_grad_batch = ( + ex["in"], + ex["bias_grad_batch"], + ex["weight_grad_batch"], + ) loss = loss_function(g_lin(input)) with backpack(new_ext.BatchGrad()): loss.backward() assert allclose(g_lin.bias.grad_batch, b_grad_batch), "{} ≠ {}".format( - g_lin.bias.grad_batch, b_grad_batch) - assert allclose(g_lin.weight.grad_batch, - w_grad_batch), "{} ≠ {}".format( - g_lin.weight.grad_batch, w_grad_batch) + g_lin.bias.grad_batch, b_grad_batch + ) + assert allclose(g_lin.weight.grad_batch, w_grad_batch), "{} ≠ {}".format( + g_lin.weight.grad_batch, w_grad_batch + ) del g_lin.bias.grad del g_lin.weight.grad diff --git a/test/networks.py b/test/networks.py index 1016e003..9b4018f0 100644 --- a/test/networks.py +++ b/test/networks.py @@ -1,4 +1,4 @@ -class HiddenLayer(): +class HiddenLayer: def __init__(self, linear_cls, activation_cls=None): self.linear_cls = linear_cls self.activation_cls = activation_cls diff --git a/test/problems.py b/test/problems.py index a6af8909..bb85aa4e 100644 --- a/test/problems.py +++ b/test/problems.py @@ -1,10 +1,11 @@ import torch -from .test_problem import TestProblem + from backpack import extend -from backpack.core.layers import Flatten + +from .test_problem import TestProblem -class ProblemBase(): +class ProblemBase: def __init__(self, input_shape, network_modules): self.input_shape = input_shape self.net_modules = network_modules @@ -54,14 +55,13 @@ def get_loss_func(self): def get_modules(self): modules = self.get_network_modules() - modules.append(Flatten()) + modules.append(torch.nn.Flatten()) modules.append(self.sum_output_layer()) return modules def sum_output_layer(self): num_outputs = self.get_num_network_outputs() - return torch.nn.Linear( - in_features=num_outputs, out_features=1, bias=True) + return torch.nn.Linear(in_features=num_outputs, out_features=1, bias=True) def get_XY(self, model): X = 
torch.randn(size=self.input_shape) @@ -77,12 +77,12 @@ def get_loss_func(self): def get_modules(self): modules = self.get_network_modules() - modules.append(Flatten()) + modules.append(torch.nn.Flatten()) return modules def get_XY(self, model): X = torch.randn(size=self.input_shape) - Y = torch.randint(high=model(X).shape[1], size=(X.shape[0], )) + Y = torch.randint(high=model(X).shape[1], size=(X.shape[0],)) return X, Y diff --git a/test/pytest.ini b/test/pytest.ini new file mode 100644 index 00000000..478d7ace --- /dev/null +++ b/test/pytest.ini @@ -0,0 +1,5 @@ +[pytest] +optional_tests: + montecarlo: slow tests using low-precision allclose after Monte-Carlo sampling +filterwarnings = + ignore:cannot collect test class 'TestProblem':pytest.PytestCollectionWarning: diff --git a/test/readme.md b/test/readme.md new file mode 100644 index 00000000..48c33d49 --- /dev/null +++ b/test/readme.md @@ -0,0 +1,28 @@ +# Testing +Automated testing based on [`pytest`](https://docs.pytest.org/en/latest/). +Install with `pip install pytest`, run tests with `pytest` from this directory. + +Useful options: +``` +-v verbose output +-k text select tests containing text in their name +-x stop if a test fails +--tb=no disable trace output +--help +``` + +## Optional tests +Uses [`pytest-optional-tests`](https://pypi.org/project/pytest-optional-tests) f +or optional tests. Install with `pip install pytest-optional-tests`. + +Optional test categories are defined in `pytest.ini` +and tests are marked with `@pytest.mark.OPTIONAL_TEST_CATEGORY`. + +To run the optional tests, use +`pytest --run-optional-tests=OPTIONAL_TEST_CATEGORY` + +## Run all tests for BackPACK +In working directory `tests/`, run +```bash +pytest -vx --run-optional-tests=montecarlo . +``` diff --git a/test/test_problem.py b/test/test_problem.py index b0ba71b0..b67432d5 100644 --- a/test/test_problem.py +++ b/test/test_problem.py @@ -1,10 +1,12 @@ import torch + from backpack import extend +DEVICE_CPU = torch.device("cpu") -class TestProblem(): - def __init__(self, X, Y, model, lossfunc, device=torch.device("cpu")): +class TestProblem: + def __init__(self, X, Y, model, lossfunc, device=DEVICE_CPU): """ A traditional machine learning test problem, loss(model(X), Y) @@ -48,12 +50,7 @@ def clear(self): """ Clear saved state """ - attrs = [ - "sum_grad_squared" - "grad_batch" - "grad" - "diag_ggn" - ] + attrs = ["sum_grad_squared" "grad_batch" "grad" "diag_ggn"] def safeclear(p, attr): if hasattr(p, attr): diff --git a/test/test_problems_activations.py b/test/test_problems_activations.py index 1670b92d..fd94462a 100644 --- a/test/test_problems_activations.py +++ b/test/test_problems_activations.py @@ -1,6 +1,6 @@ -from .layers import LINEARS, ACTIVATIONS -from .problems import make_regression_problem, make_classification_problem +from .layers import ACTIVATIONS, LINEARS from .networks import single_linear_layer, two_linear_layers +from .problems import make_classification_problem, make_regression_problem TEST_SETTINGS = { "in_features": 7, @@ -9,7 +9,7 @@ "bias": True, "batch": 5, "rtol": 1e-5, - "atol": 1e-5 + "atol": 1e-5, } INPUT_SHAPE = (TEST_SETTINGS["batch"], TEST_SETTINGS["in_features"]) @@ -18,20 +18,23 @@ for act_name, act_cls in ACTIVATIONS.items(): for lin_name, lin_cls in LINEARS.items(): - TEST_PROBLEMS["{}{}-regression".format( - lin_name, act_name)] = make_regression_problem( - INPUT_SHAPE, - single_linear_layer( - TEST_SETTINGS, lin_cls, activation_cls=act_cls)) + TEST_PROBLEMS[ + "{}{}-regression".format(lin_name, act_name) + ] = 
make_regression_problem( + INPUT_SHAPE, + single_linear_layer(TEST_SETTINGS, lin_cls, activation_cls=act_cls), + ) - TEST_PROBLEMS["{}{}-classification".format( - lin_name, act_name)] = make_classification_problem( - INPUT_SHAPE, - single_linear_layer( - TEST_SETTINGS, lin_cls, activation_cls=act_cls)) + TEST_PROBLEMS[ + "{}{}-classification".format(lin_name, act_name) + ] = make_classification_problem( + INPUT_SHAPE, + single_linear_layer(TEST_SETTINGS, lin_cls, activation_cls=act_cls), + ) - TEST_PROBLEMS["{}{}-2layer-classification".format( - lin_name, act_name)] = make_classification_problem( - INPUT_SHAPE, - two_linear_layers( - TEST_SETTINGS, lin_cls, activation_cls=act_cls)) + TEST_PROBLEMS[ + "{}{}-2layer-classification".format(lin_name, act_name) + ] = make_classification_problem( + INPUT_SHAPE, + two_linear_layers(TEST_SETTINGS, lin_cls, activation_cls=act_cls), + ) diff --git a/test/test_problems_bn.py b/test/test_problems_bn.py index 04d7a50d..a337a44a 100644 --- a/test/test_problems_bn.py +++ b/test/test_problems_bn.py @@ -1,4 +1,3 @@ -import torch from torch.nn import BatchNorm1d from .layers import LINEARS @@ -12,7 +11,7 @@ "bias": True, "batch": 5, "rtol": 1e-5, - "atol": 1e-5 + "atol": 1e-5, } @@ -30,20 +29,23 @@ def bn_layer2(): for lin_name, lin_cls in LINEARS.items(): - TEST_PROBLEMS["{}-bn-regression".format( - lin_name)] = make_regression_problem( - INPUT_SHAPE, - single_linear_layer(TEST_SETTINGS, lin_cls, activation_cls=None) + - [BatchNorm1d(TEST_SETTINGS["out_features"])]) - - TEST_PROBLEMS["{}-bn-classification".format( - lin_name)] = make_classification_problem( - INPUT_SHAPE, - single_linear_layer(TEST_SETTINGS, lin_cls, activation_cls=None) + - [bn_layer1()]) - - TEST_PROBLEMS["{}-bn-2layer-classification".format( - lin_name)] = make_classification_problem( - INPUT_SHAPE, - two_linear_layers(TEST_SETTINGS, lin_cls, activation_cls=None) + - [bn_layer2()]) + TEST_PROBLEMS["{}-bn-regression".format(lin_name)] = make_regression_problem( + INPUT_SHAPE, + single_linear_layer(TEST_SETTINGS, lin_cls, activation_cls=None) + + [BatchNorm1d(TEST_SETTINGS["out_features"])], + ) + + TEST_PROBLEMS[ + "{}-bn-classification".format(lin_name) + ] = make_classification_problem( + INPUT_SHAPE, + single_linear_layer(TEST_SETTINGS, lin_cls, activation_cls=None) + + [bn_layer1()], + ) + + TEST_PROBLEMS[ + "{}-bn-2layer-classification".format(lin_name) + ] = make_classification_problem( + INPUT_SHAPE, + two_linear_layers(TEST_SETTINGS, lin_cls, activation_cls=None) + [bn_layer2()], + ) diff --git a/test/test_problems_convolutions.py b/test/test_problems_convolutions.py index c0af8905..2e483219 100644 --- a/test/test_problems_convolutions.py +++ b/test/test_problems_convolutions.py @@ -1,10 +1,10 @@ import numpy as np import torch -from backpack.core.layers import Flatten, Conv2dConcat + from backpack import extend + +from .layers import ACTIVATIONS, CONVS from .test_problem import TestProblem -from .layers import CONVS -from .layers import ACTIVATIONS TEST_SETTINGS = { "in_features": (3, 4, 5), @@ -14,43 +14,57 @@ "bias": True, "batch": 3, "rtol": 1e-5, - "atol": 5e-4 + "atol": 5e-4, } def convlayer(conv_cls, settings): return extend( - conv_cls(in_channels=settings["in_features"][0], - out_channels=settings["out_channels"], - kernel_size=settings["kernel_size"], - padding=settings["padding"], - bias=settings["bias"])) + conv_cls( + in_channels=settings["in_features"][0], + out_channels=settings["out_channels"], + kernel_size=settings["kernel_size"], + padding=settings["padding"], + 
bias=settings["bias"], + ) + ) def convlayer2(conv_cls, settings): return extend( - conv_cls(in_channels=settings["in_features"][0], - out_channels=settings["out_channels"], - kernel_size=settings["kernel_size"], - padding=settings["padding"], - bias=settings["bias"])) + conv_cls( + in_channels=settings["in_features"][0], + out_channels=settings["out_channels"], + kernel_size=settings["kernel_size"], + padding=settings["padding"], + bias=settings["bias"], + ) + ) -input_size = (TEST_SETTINGS["batch"], ) + TEST_SETTINGS["in_features"] +input_size = (TEST_SETTINGS["batch"],) + TEST_SETTINGS["in_features"] X = torch.randn(size=input_size) def convearlayer(settings): return extend( - torch.nn.Linear(in_features=np.prod( - [f - settings["padding"][0] - for f in settings["in_features"]]) * settings["out_channels"], - out_features=1)) + torch.nn.Linear( + in_features=np.prod( + [f - settings["padding"][0] for f in settings["in_features"]] + ) + * settings["out_channels"], + out_features=1, + ) + ) def make_regression_problem(conv_cls, act_cls): - model = torch.nn.Sequential(convlayer(conv_cls, TEST_SETTINGS), act_cls(), - Flatten(), convearlayer(TEST_SETTINGS)) + model = torch.nn.Sequential( + convlayer(conv_cls, TEST_SETTINGS), + act_cls(), + torch.nn.Flatten(), + convearlayer(TEST_SETTINGS), + ) Y = torch.randn(size=(model(X).shape[0], 1)) @@ -60,10 +74,11 @@ def make_regression_problem(conv_cls, act_cls): def make_classification_problem(conv_cls, act_cls): - model = torch.nn.Sequential(convlayer(conv_cls, TEST_SETTINGS), act_cls(), - Flatten()) + model = torch.nn.Sequential( + convlayer(conv_cls, TEST_SETTINGS), act_cls(), torch.nn.Flatten() + ) - Y = torch.randint(high=X.shape[1], size=(model(X).shape[0], )) + Y = torch.randint(high=X.shape[1], size=(model(X).shape[0],)) lossfunc = extend(torch.nn.CrossEntropyLoss()) @@ -71,11 +86,15 @@ def make_classification_problem(conv_cls, act_cls): def make_2layer_classification_problem(conv_cls, act_cls): - model = torch.nn.Sequential(convlayer(conv_cls, TEST_SETTINGS), act_cls(), - convlayer2(conv_cls, TEST_SETTINGS), act_cls(), - Flatten()) + model = torch.nn.Sequential( + convlayer(conv_cls, TEST_SETTINGS), + act_cls(), + convlayer2(conv_cls, TEST_SETTINGS), + act_cls(), + torch.nn.Flatten(), + ) - Y = torch.randint(high=X.shape[1], size=(model(X).shape[0], )) + Y = torch.randint(high=X.shape[1], size=(model(X).shape[0],)) lossfunc = extend(torch.nn.CrossEntropyLoss()) @@ -85,11 +104,12 @@ def make_2layer_classification_problem(conv_cls, act_cls): TEST_PROBLEMS = {} for conv_name, conv_cls in CONVS.items(): for act_name, act_cls in ACTIVATIONS.items(): - TEST_PROBLEMS["{}-{}-regression".format( - conv_name, act_name)] = make_regression_problem(conv_cls, act_cls) - TEST_PROBLEMS["{}-{}-classification".format( - conv_name, - act_name)] = make_classification_problem(conv_cls, act_cls) - TEST_PROBLEMS["{}-{}-2layer-classification".format( - conv_name, - act_name)] = make_2layer_classification_problem(conv_cls, act_cls) + TEST_PROBLEMS[ + "{}-{}-regression".format(conv_name, act_name) + ] = make_regression_problem(conv_cls, act_cls) + TEST_PROBLEMS[ + "{}-{}-classification".format(conv_name, act_name) + ] = make_classification_problem(conv_cls, act_cls) + TEST_PROBLEMS[ + "{}-{}-2layer-classification".format(conv_name, act_name) + ] = make_2layer_classification_problem(conv_cls, act_cls) diff --git a/test/test_problems_kfacs.py b/test/test_problems_kfacs.py index 426eaf75..39c9e08c 100644 --- a/test/test_problems_kfacs.py +++ 
b/test/test_problems_kfacs.py @@ -1,6 +1,6 @@ -from .problems import make_regression_problem, make_classification_problem -from .networks import single_linear_layer, two_linear_layers from .layers import ACTIVATIONS, LINEARS +from .networks import single_linear_layer, two_linear_layers +from .problems import make_classification_problem, make_regression_problem TEST_SETTINGS = { "in_features": 7, @@ -9,7 +9,7 @@ "bias": True, "batch": 1, "rtol": 1e-5, - "atol": 1e-5 + "atol": 1e-5, } INPUT_SHAPE = (TEST_SETTINGS["batch"], TEST_SETTINGS["in_features"]) assert TEST_SETTINGS["batch"] == 1 @@ -17,11 +17,12 @@ REGRESSION_PROBLEMS = {} for act_name, act_cls in ACTIVATIONS.items(): for lin_name, lin_cls in LINEARS.items(): - REGRESSION_PROBLEMS["{}{}-regression".format( - lin_name, act_name)] = make_regression_problem( - INPUT_SHAPE, - single_linear_layer( - TEST_SETTINGS, lin_cls, activation_cls=act_cls)) + REGRESSION_PROBLEMS[ + "{}{}-regression".format(lin_name, act_name) + ] = make_regression_problem( + INPUT_SHAPE, + single_linear_layer(TEST_SETTINGS, lin_cls, activation_cls=act_cls), + ) TEST_PROBLEMS = { **REGRESSION_PROBLEMS, @@ -29,14 +30,16 @@ for act_name, act_cls in ACTIVATIONS.items(): for lin_name, lin_cls in LINEARS.items(): - TEST_PROBLEMS["{}{}-classification".format( - lin_name, act_name)] = make_classification_problem( - INPUT_SHAPE, - single_linear_layer( - TEST_SETTINGS, lin_cls, activation_cls=act_cls)) + TEST_PROBLEMS[ + "{}{}-classification".format(lin_name, act_name) + ] = make_classification_problem( + INPUT_SHAPE, + single_linear_layer(TEST_SETTINGS, lin_cls, activation_cls=act_cls), + ) - TEST_PROBLEMS["{}{}-2layer-classification".format( - lin_name, act_name)] = make_classification_problem( - INPUT_SHAPE, - two_linear_layers( - TEST_SETTINGS, lin_cls, activation_cls=act_cls)) + TEST_PROBLEMS[ + "{}{}-2layer-classification".format(lin_name, act_name) + ] = make_classification_problem( + INPUT_SHAPE, + two_linear_layers(TEST_SETTINGS, lin_cls, activation_cls=act_cls), + ) diff --git a/test/test_problems_linear.py b/test/test_problems_linear.py index 19703ba8..94fec994 100644 --- a/test/test_problems_linear.py +++ b/test/test_problems_linear.py @@ -1,6 +1,6 @@ from .layers import LINEARS -from .problems import make_regression_problem, make_classification_problem from .networks import single_linear_layer, two_linear_layers +from .problems import make_classification_problem, make_regression_problem TEST_SETTINGS = { "in_features": 7, @@ -9,7 +9,7 @@ "bias": True, "batch": 5, "rtol": 1e-5, - "atol": 1e-5 + "atol": 1e-5, } INPUT_SHAPE = (TEST_SETTINGS["batch"], TEST_SETTINGS["in_features"]) @@ -18,15 +18,15 @@ for lin_name, lin_cls in LINEARS.items(): TEST_PROBLEMS["{}-regression".format(lin_name)] = make_regression_problem( - INPUT_SHAPE, - single_linear_layer(TEST_SETTINGS, lin_cls, activation_cls=None)) + INPUT_SHAPE, single_linear_layer(TEST_SETTINGS, lin_cls, activation_cls=None) + ) - TEST_PROBLEMS["{}-classification".format( - lin_name)] = make_classification_problem( - INPUT_SHAPE, - single_linear_layer(TEST_SETTINGS, lin_cls, activation_cls=None)) + TEST_PROBLEMS["{}-classification".format(lin_name)] = make_classification_problem( + INPUT_SHAPE, single_linear_layer(TEST_SETTINGS, lin_cls, activation_cls=None) + ) - TEST_PROBLEMS["{}-2layer-classification".format( - lin_name)] = make_classification_problem( - INPUT_SHAPE, - two_linear_layers(TEST_SETTINGS, lin_cls, activation_cls=None)) + TEST_PROBLEMS[ + "{}-2layer-classification".format(lin_name) + ] = 
make_classification_problem( + INPUT_SHAPE, two_linear_layers(TEST_SETTINGS, lin_cls, activation_cls=None) + ) diff --git a/test/test_problems_padding.py b/test/test_problems_padding.py index 6055b4ef..eec06d09 100644 --- a/test/test_problems_padding.py +++ b/test/test_problems_padding.py @@ -1,8 +1,9 @@ import torch -from backpack.core.layers import Flatten + from backpack import extend -from .test_problem import TestProblem + from .layers import PADDINGS +from .test_problem import TestProblem TEST_SETTINGS = { "in_features": (3, 4, 5), @@ -22,10 +23,12 @@ def conv_no_padding_layer(): in_channels=TEST_SETTINGS["in_features"][0], out_channels=TEST_SETTINGS["out_channels"], kernel_size=TEST_SETTINGS["kernel_size"], - bias=TEST_SETTINGS["bias"])) + bias=TEST_SETTINGS["bias"], + ) + ) -input_size = (TEST_SETTINGS["batch"], ) + TEST_SETTINGS["in_features"] +input_size = (TEST_SETTINGS["batch"],) + TEST_SETTINGS["in_features"] X = torch.randn(size=input_size) @@ -35,10 +38,14 @@ def padding(padding_cls): def make_2layer_classification_problem(padding_cls): model = torch.nn.Sequential( - padding(padding_cls), conv_no_padding_layer(), padding(padding_cls), - conv_no_padding_layer(), Flatten()) + padding(padding_cls), + conv_no_padding_layer(), + padding(padding_cls), + conv_no_padding_layer(), + torch.nn.Flatten(), + ) - Y = torch.randint(high=X.shape[1], size=(model(X).shape[0], )) + Y = torch.randint(high=X.shape[1], size=(model(X).shape[0],)) lossfunc = extend(torch.nn.CrossEntropyLoss()) @@ -47,5 +54,6 @@ def make_2layer_classification_problem(padding_cls): TEST_PROBLEMS = {} for pad_name, pad_cls in PADDINGS.items(): - TEST_PROBLEMS["conv+{}-classification-2layer".format( - pad_name)] = make_2layer_classification_problem(pad_cls) + TEST_PROBLEMS[ + "conv+{}-classification-2layer".format(pad_name) + ] = make_2layer_classification_problem(pad_cls) diff --git a/test/test_problems_pooling.py b/test/test_problems_pooling.py index db9af177..23c5e469 100644 --- a/test/test_problems_pooling.py +++ b/test/test_problems_pooling.py @@ -1,9 +1,9 @@ -import numpy as np import torch -from backpack.core.layers import Flatten + from backpack import extend -from .test_problem import TestProblem + from .layers import POOLINGS +from .test_problem import TestProblem TEST_SETTINGS = { "in_features": (3, 4, 5), @@ -25,10 +25,12 @@ def convlayer(): out_channels=TEST_SETTINGS["out_channels"], kernel_size=TEST_SETTINGS["kernel_size"], padding=TEST_SETTINGS["padding"], - bias=TEST_SETTINGS["bias"])) + bias=TEST_SETTINGS["bias"], + ) + ) -input_size = (TEST_SETTINGS["batch"], ) + TEST_SETTINGS["in_features"] +input_size = (TEST_SETTINGS["batch"],) + TEST_SETTINGS["in_features"] X = torch.randn(size=input_size) @@ -41,8 +43,9 @@ def pooling(pooling_cls): def make_regression_problem(pooling_cls): - model = torch.nn.Sequential(convlayer(), pooling(pooling_cls), Flatten(), - linearlayer()) + model = torch.nn.Sequential( + convlayer(), pooling(pooling_cls), torch.nn.Flatten(), linearlayer() + ) Y = torch.randn(size=(model(X).shape[0], 1)) @@ -52,9 +55,9 @@ def make_regression_problem(pooling_cls): def make_classification_problem(pooling_cls): - model = torch.nn.Sequential(convlayer(), pooling(pooling_cls), Flatten()) + model = torch.nn.Sequential(convlayer(), pooling(pooling_cls), torch.nn.Flatten()) - Y = torch.randint(high=X.shape[1], size=(model(X).shape[0], )) + Y = torch.randint(high=X.shape[1], size=(model(X).shape[0],)) lossfunc = extend(torch.nn.CrossEntropyLoss()) @@ -62,10 +65,15 @@ def 
make_classification_problem(pooling_cls): def make_2layer_classification_problem(pooling_cls): - model = torch.nn.Sequential(convlayer(), pooling(pooling_cls), convlayer(), - pooling(pooling_cls), Flatten()) + model = torch.nn.Sequential( + convlayer(), + pooling(pooling_cls), + convlayer(), + pooling(pooling_cls), + torch.nn.Flatten(), + ) - Y = torch.randint(high=X.shape[1], size=(model(X).shape[0], )) + Y = torch.randint(high=X.shape[1], size=(model(X).shape[0],)) lossfunc = extend(torch.nn.CrossEntropyLoss()) @@ -74,9 +82,12 @@ def make_2layer_classification_problem(pooling_cls): TEST_PROBLEMS = {} for pool_name, pool_cls in POOLINGS.items(): - TEST_PROBLEMS["conv+{}-regression".format( - pool_name)] = make_regression_problem(pool_cls) - TEST_PROBLEMS["conv+{}-classification".format( - pool_name)] = make_classification_problem(pool_cls) - TEST_PROBLEMS["conv+{}-classification-2layer".format( - pool_name)] = make_2layer_classification_problem(pool_cls) + TEST_PROBLEMS["conv+{}-regression".format(pool_name)] = make_regression_problem( + pool_cls + ) + TEST_PROBLEMS[ + "conv+{}-classification".format(pool_name) + ] = make_classification_problem(pool_cls) + TEST_PROBLEMS[ + "conv+{}-classification-2layer".format(pool_name) + ] = make_2layer_classification_problem(pool_cls) diff --git a/test/utils_test.py b/test/utils_test.py new file mode 100644 index 00000000..7232a024 --- /dev/null +++ b/test/utils_test.py @@ -0,0 +1,191 @@ +"""Test of Kronecker utilities.""" + +import random +import unittest + +import scipy.linalg +import torch + +from backpack.utils import kroneckers as bp_utils +from backpack.utils.ein import einsum + + +class KroneckerUtilsTest(unittest.TestCase): + RUNS = 100 + + # Precision of results + ATOL = 1e-6 + RTOL = 1e-5 + + # Restriction of dimension and number of factors + MIN_DIM = 1 + MAX_DIM = 5 + MIN_FACS = 1 + MAX_FACS = 3 + + # Number of columns for KFAC-matrix products + KFACMP_COLS = 7 + + # Minimum eigenvalue of positive semi-definite + PSD_KFAC_MIN_EIGVAL = 1 + + # HELPERS + ########################################################################## + def allclose(self, tensor1, tensor2): + return torch.allclose(tensor1, tensor2, rtol=self.RTOL, atol=self.ATOL) + + def list_allclose(self, tensor_list1, tensor_list2): + assert len(tensor_list1) == len(tensor_list2) + close = [self.allclose(t1, t2) for t1, t2 in zip(tensor_list1, tensor_list2)] + print(close) + for is_close, t1, t2 in zip(close, tensor_list1, tensor_list2): + if not is_close: + print(t1) + print(t2) + return all(close) + + def make_random_kfacs(self, num_facs=None): + def random_kfac(): + def random_dim(): + return random.randint(self.MIN_DIM, self.MAX_DIM) + + shape = [random_dim(), random_dim()] + return torch.rand(shape) + + def random_num_facs(): + return random.randint(self.MIN_FACS, self.MAX_FACS) + + num_facs = num_facs if num_facs is not None else random_num_facs() + return [random_kfac() for _ in range(num_facs)] + + def make_random_psd_kfacs(self, num_facs=None): + def make_quadratic_psd(mat): + """Make matrix positive semi-definite: A -> AAᵀ.""" + mat_squared = einsum("ij,kj->ik", (mat, mat)) + shift = self.PSD_KFAC_MIN_EIGVAL * self.torch_eye_like(mat_squared) + return mat_squared + shift + + kfacs = self.make_random_kfacs(num_facs=num_facs) + return [make_quadratic_psd(fac) for fac in kfacs] + + # Torch helpers + ######################################################################### + @staticmethod + def torch_eye_like(tensor): + return torch.eye(*tensor.size(), 
out=torch.empty_like(tensor)) + + # SCIPY implementations + ########################################################################## + def scipy_two_kfacs_to_mat(self, A, B): + return torch.from_numpy(scipy.linalg.kron(A.numpy(), B.numpy())) + + def scipy_kfacs_to_mat(self, factors): + mat = None + for factor in factors: + if mat is None: + assert bp_utils.is_matrix(factor) + mat = factor + else: + mat = self.scipy_two_kfacs_to_mat(mat, factor) + + return mat + + def make_matrix_for_multiplication_with(self, kfac, cols=None): + cols = cols if cols is not None else self.KFACMP_COLS + assert bp_utils.is_matrix(kfac) + _, rows = kfac.shape + return torch.rand(rows, cols) + + def make_vector_for_multiplication_with(self, kfac): + vec = self.make_matrix_for_multiplication_with(kfac, cols=1).squeeze(-1) + assert bp_utils.is_vector(vec) + return vec + + def scipy_inv(self, mat, shift): + mat_shifted = (shift * self.torch_eye_like(mat) + mat).numpy() + inv = scipy.linalg.inv(mat_shifted) + return torch.from_numpy(inv) + + def scipy_inv_kfacs(self, factors, shift_list): + assert len(factors) == len(shift_list) + return [self.scipy_inv(fac, shift) for fac, shift in zip(factors, shift_list)] + + # TESTS + ########################################################################## + + def test_two_kfacs_to_mat(self): + """Check matrix from two Kronecker factors with `scipy`.""" + NUM_FACS = 2 + + for _ in range(self.RUNS): + A, B = self.make_random_kfacs(NUM_FACS) + + bp_result = bp_utils.two_kfacs_to_mat(A, B) + sp_result = self.scipy_two_kfacs_to_mat(A, B) + + assert self.allclose(bp_result, sp_result) + + def test_kfacs_to_mat(self): + """Check matrix from list of Kronecker factors with `scipy`.""" + for _ in range(self.RUNS): + factors = self.make_random_kfacs() + + bp_result = bp_utils.kfacs_to_mat(factors) + sp_result = self.scipy_kfacs_to_mat(factors) + + assert self.allclose(bp_result, sp_result) + + def test_apply_kfac_mat_prod(self): + """Check matrix multiplication from Kronecker factors with matrix.""" + make_vec = self.make_vector_for_multiplication_with + self.compare_kfac_tensor_prod(make_vec) + + def test_apply_kfac_vec_prod(self): + """Check matrix multiplication from Kronecker factors with vector.""" + make_mat = self.make_matrix_for_multiplication_with + self.compare_kfac_tensor_prod(make_mat) + + def compare_kfac_tensor_prod(self, make_tensor): + def set_up(): + factors = self.make_random_kfacs() + kfac = bp_utils.kfacs_to_mat(factors) + tensor = make_tensor(kfac) + return factors, kfac, tensor + + for _ in range(self.RUNS): + factors, kfac, tensor = set_up() + + bp_result = bp_utils.apply_kfac_mat_prod(factors, tensor) + torch_result = torch.matmul(kfac, tensor) + + assert self.allclose(bp_result, torch_result) + + def test_inv_kfacs(self): + def get_shift(): + return random.random() + + for _ in range(self.RUNS): + kfacs = self.make_random_psd_kfacs() + num_kfacs = len(kfacs) + + # None vs 0. + default_result = bp_utils.inv_kfacs(kfacs) + no_shift_result = bp_utils.inv_kfacs(kfacs, shift=0.0) + assert self.list_allclose(default_result, no_shift_result) + + # 0. vs tiny + tiny = 1e-4 + tiny_shift_result = bp_utils.inv_kfacs(kfacs, shift=tiny) + assert not self.list_allclose(no_shift_result, tiny_shift_result) + + # scalar vs. list of scalar: shift a should equal shift [a, a, ...] 
+            shift = get_shift()
+            scalar_result = bp_utils.inv_kfacs(kfacs, shift=shift)
+            list_result = bp_utils.inv_kfacs(kfacs, shift=num_kfacs * [shift])
+            assert self.list_allclose(scalar_result, list_result)
+
+            # scipy vs. torch
+            shift_list = [get_shift() for _ in range(num_kfacs)]
+            bp_result = bp_utils.inv_kfacs(kfacs, shift=shift_list)
+            sp_result = self.scipy_inv_kfacs(kfacs, shift_list)
+            assert self.list_allclose(bp_result, sp_result)
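
The new `test/utils_test.py` above pins down the contract of `backpack.utils.kroneckers`: `two_kfacs_to_mat`/`kfacs_to_mat` must agree with `scipy.linalg.kron`, and `apply_kfac_mat_prod(factors, t)` must equal multiplying with the materialized Kronecker matrix. The standard identity (A ⊗ B) vec(X) = vec(B X Aᵀ) is what makes such a product computable from the small factors alone; the diff only shows the contract, not BackPACK's implementation, so the self-contained sketch below checks the identity with plain `torch` and `scipy`, using made-up shapes and tolerances.

```python
import scipy.linalg
import torch

torch.manual_seed(0)

# Illustrative shapes: A is (3 x 2), B is (4 x 5), X is (5 x 2).
A, B, X = torch.rand(3, 2), torch.rand(4, 5), torch.rand(5, 2)


def vec(mat):
    """Column-stacking (column-major) vectorization."""
    return mat.t().reshape(-1)


# Dense Kronecker product, built the same way the tests do via scipy.
kron_AB = torch.from_numpy(scipy.linalg.kron(A.numpy(), B.numpy()))

lhs = kron_AB @ vec(X)    # multiply with the full (12 x 10) matrix
rhs = vec(B @ X @ A.t())  # same result from two small matrix products

assert torch.allclose(lhs, rhs, rtol=1e-5, atol=1e-6)
```

With more than two factors the identity can be applied factor by factor, which is consistent with how `kfacs_to_mat` and `apply_kfac_mat_prod` are compared in `compare_kfac_tensor_prod` above.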
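`test/linear_test.py` in this diff exercises individual gradients through the `extend` / `with backpack(...)` pattern and reads the result from the `grad_batch` attribute. The following is a minimal sketch of that workflow using only calls that appear in the diff; the layer size, batch size, and random input are invented for illustration.

```python
import torch
from torch.nn import Linear

import backpack.extensions as new_ext
from backpack import backpack, extend

# Extend the layer so BackPACK can attach its hooks.
layer = extend(Linear(in_features=3, out_features=2))
X = torch.rand(5, 3)

# Sum-of-squares loss, in the spirit of loss_function() in linear_test.py.
loss = (layer(X).view(-1) ** 2).sum()

with backpack(new_ext.BatchGrad()):
    loss.backward()

# One gradient per sample, stacked along a leading batch dimension.
print(layer.weight.grad_batch.shape)  # expected: torch.Size([5, 2, 3])
print(layer.bias.grad_batch.shape)    # expected: torch.Size([5, 2])
```

The expected shapes follow the convention visible in `EXAMPLE_2` of `linear_test.py`: a leading batch dimension in front of each parameter's shape.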
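`test/pytest.ini` and `test/readme.md` introduce an optional `montecarlo` category for slow, low-precision tests. Following the `@pytest.mark.OPTIONAL_TEST_CATEGORY` pattern the readme describes, a marked test would look roughly like the sketch below; the test body, seed, and tolerance are made up for illustration, and the marker name is inferred from the category defined in `pytest.ini`.

```python
import pytest
import torch


@pytest.mark.montecarlo
def test_mc_estimate_is_roughly_unbiased():
    """Skipped unless run with `pytest --run-optional-tests=montecarlo`."""
    torch.manual_seed(0)
    samples = torch.randn(100000)
    # Low-precision check, since Monte-Carlo estimates are noisy.
    assert torch.allclose(samples.mean(), torch.tensor(0.0), atol=5e-2)
```

As the readme above states, such tests only run when the `--run-optional-tests` flag names their category; otherwise `pytest-optional-tests` skips them.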