Skip to content

Commit

Permalink
Add: visual example output
Browse files Browse the repository at this point in the history
  • Loading branch information
nimarb committed Jan 10, 2020
1 parent a6f34f6 commit 417d1af
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 2 deletions.
54 changes: 52 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,25 @@ This is a PyTorch reimplementation of Influence Functions from the ICML2017 best
[Understanding Black-box Predictions via Influence Functions](https://arxiv.org/abs/1703.04730) by Pang Wei Koh and Percy Liang.
The reference implementation can be found here: [link](https://github.com/kohpangwei/influence-release).

- [Why Use Influence Functions?](#why-use-influence-functions)
- [Requirements](#requirements)
- [Installation](#installation)
- [Usage](#usage)
- [Background and Documentation](#background-and-documentation)
- [config](#config)
- [Misc parameters](#misc-parameters)
- [Calculation parameters](#calculation-parameters)
- [s_test](#stest)
- [Modes of computation](#modes-of-computation)
- [Output variables](#output-variables)
- [Influences](#influences)
- [Harmful](#harmful)
- [Helpful](#helpful)
- [Roadmap](#roadmap)
- [v0.2](#v02)
- [v0.3](#v03)
- [v0.4](#v04)

## Why Use Influence Functions?

Influence functions help you to debug the results of your deep learning model
Expand Down Expand Up @@ -111,6 +130,10 @@ affecting everything else.
Greater recursion depth improves precision.
* `r`: Default = 1, number of `s_test` calculations to take the average of.
Greater r averaging improves precision.
* Combined, the original paper suggests that `recursion_depth * r` should equal
the training dataset size, thus the above values of `r = 10` and
`recursion_depth = 5000` are valid for CIFAR-10 with a training dataset size
of 50000 items.
* `damp`: Default = 0.01, damping factor during `s_test` calculation.
* `scale`: Default = 25, scaling factor during `s_test` calculation.

Expand All @@ -129,13 +152,37 @@ can take significant amounts of disk space (100s of GBs) but with a fast SSD
can speed up the calculation significantly as no duplicate calculations take
place. This is the case because `grad_z` has to be calculated twice, once for
the first approximation in `s_test` and once to combine with the `s_test`
vector to calculate the influence. The paper has a ton of more detail on that.
vector to calculate the influence. Most importantnly however, `s_test` is only
dependent on the test sample(s). While one `grad_z` is used to estimate the
initial value of the Hessian during the `s_test` calculation, this is
insignificant. `grad_z` on the other hand is only dependent on the training
sample. Thus, in the `calc_img_wise` mode, we throw away all `grad_z`
calculations even if we could reuse them for all subsequent `s_test`
calculations, which could potentially be 10s of thousands. However, as stated
above, keeping the `grad_z`s only makes sense if they can be loaded faster/
kept in RAM than calculating them on-the-fly.

**TL;DR**: The recommended way is using `calc_img_wise` unless you have a crazy
fast SSD and lots of free storage space.
fast SSD, lots of free storage space, and want to calculate the influences on
the prediction outcomes of an entire dataset or even >1000 test samples.

### Output variables

Visualised, the output can look like this:

![influences for ship on cifar10-resnet](figs/inf_resnet_basic_110_ship_1.png)

The test image on the top left is test image for which the influences were
calculated. To get the correct test outcome of _ship_, the Helpful images from
the training dataset were the most helpful, whereas the Harmful images were the
most harmful. Here, we used CIFAR-10 as dataset. The model was ResNet-110. The
numbers above the images show the actual influence value which was calculated.

The next figure shows the same but for a different model, DenseNet-100/12.
Thus, we can see that different models learn more from different images.

![influences for ship on cifar10-densenet](figs/inf_densenet_BC_100_12_ship_1.png)

#### Influences

Is a dict/json containting the influences calculated of all training data
Expand Down Expand Up @@ -200,6 +247,9 @@ prediction outcome of the processed test samples.
* [ ] ability to disable shell output eg for `display_progress` from the config
* [ ] add proper result plotting support
* [ ] add a dataloader for training on the most influential samples only
* [x] add some visualisation of the outcome
* [ ] add recreation of some graphs of the original paper to verify
implementation

### v0.3

Expand Down
Binary file added figs/inf_densenet_BC_100_12_ship_1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added figs/inf_resnet_basic_110_cat_0.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added figs/inf_resnet_basic_110_ship_1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit 417d1af

Please sign in to comment.