
Merge dev branch #30

Merged
merged 251 commits on Jul 13, 2024
Commits
35568b9
Update paper links
pomonam Jun 19, 2024
0601611
Fix spacing issue
pomonam Jun 19, 2024
ccfd2ad
Merge pull request #21 from pomonam/documentation
pomonam Jun 19, 2024
3c0b52f
Update documentation
pomonam Jun 22, 2024
d43a06c
Merge pull request #23 from pomonam/docu
pomonam Jun 22, 2024
b7ad379
Fix typos & improve readability
pomonam Jun 22, 2024
883cf28
Update requirements
pomonam Jun 22, 2024
80513bc
Cleanup factor computation codes
pomonam Jun 22, 2024
b79206b
Change output directory name
pomonam Jun 22, 2024
9e4e7b3
More comprehensive factor tests
pomonam Jun 22, 2024
293e3d2
Global refactor
pomonam Jun 23, 2024
2b7dbd3
Merge pull request #24 from pomonam/cleanup_factor
pomonam Jun 23, 2024
d4eccbc
Improve test coverage
pomonam Jun 23, 2024
5f6f0a0
Reformat gpu tests
pomonam Jun 23, 2024
1100ace
Add tests to check identical results with cpu offload
pomonam Jun 23, 2024
8f8dbf8
Reduce tol for amp tests
pomonam Jun 23, 2024
378f95a
Make a collection of common factor arguments
pomonam Jun 23, 2024
2c0472a
Unscale at addmm
pomonam Jun 23, 2024
e2197ea
Clone output gradients
pomonam Jun 23, 2024
d469691
Only clone for covariance backward
pomonam Jun 23, 2024
2c249ff
Avoid in-place for normalization
pomonam Jun 23, 2024
fa67826
Use test factor arguments
pomonam Jun 24, 2024
1a7ecfc
Add allclose for compile tests
pomonam Jun 24, 2024
2e2f06a
Revert changes
pomonam Jun 24, 2024
c5ad5d1
Add per module score for DDP tests
pomonam Jun 24, 2024
594cbed
Move to cuda before synchronize
pomonam Jun 24, 2024
fce33d9
Move to cuda before synchronize
pomonam Jun 24, 2024
f5ab479
Add common score arguments
pomonam Jun 24, 2024
99c9b36
Modify tests to use common arugments
pomonam Jun 24, 2024
acd6529
Use common score arguments
pomonam Jun 24, 2024
180c2ec
Merge pull request #26 from pomonam/finalize_documents
pomonam Jun 24, 2024
769f56f
Improve UCI examples
pomonam Jun 24, 2024
c7700a5
Clean up fine-tuning script
pomonam Jun 24, 2024
d13effe
Clean up code for influence scores
pomonam Jun 24, 2024
322076a
Fix perp computations
pomonam Jun 24, 2024
d2ddd63
Lint fix
pomonam Jun 24, 2024
163327f
Disable max-tune
pomonam Jun 24, 2024
a8ff151
Reset initial batch size
pomonam Jun 24, 2024
d093fed
Add procedure for torch.compile
pomonam Jun 24, 2024
bd42b5c
Finalize wiki
pomonam Jun 24, 2024
1d50c0f
Change num workers for CIFAR-10
pomonam Jun 25, 2024
0f5f8b2
minor
pomonam Jun 25, 2024
5cd802d
Add half precision support
pomonam Jun 25, 2024
d742bb4
Add profile option
pomonam Jun 25, 2024
94f4e60
Add profile option
pomonam Jun 25, 2024
81277eb
Add figure for mislabled
pomonam Jun 25, 2024
490f5f1
Fix typos
pomonam Jun 25, 2024
5872feb
Lint fix
pomonam Jun 25, 2024
ec88148
Disable overwrite
pomonam Jun 25, 2024
08616db
Add imagenet example
pomonam Jun 25, 2024
189bf00
Add swag model
pomonam Jun 25, 2024
b9d8b0c
Rerun notebook
pomonam Jun 26, 2024
4bcc81d
Add swag example
pomonam Jun 26, 2024
7eabaf6
Finalize UCI example
pomonam Jun 26, 2024
38bb08a
Finalize wikitext example
pomonam Jun 26, 2024
2b736bb
Finalize readme file
pomonam Jun 26, 2024
bd19cf0
Push LDS results pickle files
pomonam Jun 26, 2024
674869d
Update RTE examples
pomonam Jun 26, 2024
05ce8eb
minor
pomonam Jun 26, 2024
548149e
minor
pomonam Jun 26, 2024
fb0ecfa
Add LDS files
pomonam Jun 26, 2024
f1e341f
Finalize imagenet
pomonam Jun 26, 2024
9d8c065
Update interval
pomonam Jun 26, 2024
7af7c64
Modify base repeat
pomonam Jun 26, 2024
0575225
Add swag and openwebtext examples
pomonam Jun 26, 2024
00e1112
minor
pomonam Jun 26, 2024
b6c8be2
minor
pomonam Jun 26, 2024
b874348
minor
pomonam Jun 26, 2024
dcecdf4
Set proper target size
pomonam Jun 26, 2024
3971c6c
minor
pomonam Jun 26, 2024
979b0c8
Modify tasks
pomonam Jun 26, 2024
b15d820
Test out prompt
pomonam Jun 26, 2024
c8f7782
minor
pomonam Jun 26, 2024
9b81123
minor
pomonam Jun 26, 2024
7de1d9a
minor
pomonam Jun 26, 2024
d33eabe
minor
pomonam Jun 26, 2024
ab99c93
minor
pomonam Jun 26, 2024
49311d3
m
pomonam Jun 26, 2024
8d4c8ac
minor
pomonam Jun 26, 2024
eec7887
minor
pomonam Jun 26, 2024
a6fe4fd
Add iterative update
pomonam Jun 26, 2024
1bafa6d
minor
pomonam Jun 26, 2024
6c3109f
minor
pomonam Jun 26, 2024
24c6de9
minor
pomonam Jun 26, 2024
3e57935
Example
pomonam Jun 26, 2024
30d7e93
More examples
pomonam Jun 26, 2024
02ddf54
minor
pomonam Jun 26, 2024
e54953b
Add joke
pomonam Jun 26, 2024
8f4e0dd
Add proper requirements
pomonam Jun 26, 2024
9ec6bb6
Lint fix
pomonam Jun 26, 2024
11ddd4f
Finalize GLUE description
pomonam Jun 26, 2024
8fc8d12
Add LDS
pomonam Jun 26, 2024
dd8e25f
Enable auto batch
pomonam Jun 26, 2024
f7b4559
Update openwebtext pipeline
pomonam Jun 26, 2024
1339777
Don't aggregate
pomonam Jun 26, 2024
a16c8fa
Add factor batch size
pomonam Jun 26, 2024
d35e015
Add template for openwebtext
pomonam Jun 27, 2024
eb1de8a
Add data
pomonam Jun 27, 2024
ad02134
minor
pomonam Jun 27, 2024
bec08c9
minor
pomonam Jun 27, 2024
7278052
Lint fix
pomonam Jun 27, 2024
e81e653
Merge pull request #27 from pomonam/document_example
pomonam Jun 27, 2024
37e5d32
Modify gitignore
pomonam Jun 27, 2024
db7222a
Modify openwebtext dataset
pomonam Jun 27, 2024
0ee47f9
Format the code
pomonam Jun 27, 2024
58c7006
Add pipeline
pomonam Jun 27, 2024
40115e3
Fix tests
pomonam Jun 27, 2024
55cdb96
Add valid modules
pomonam Jun 27, 2024
5fce5e2
Fix minor typos
pomonam Jun 27, 2024
d545d50
Use more module partitions
pomonam Jun 27, 2024
e423e49
Change extreme reduce
pomonam Jun 27, 2024
d112dbf
Fix tests
pomonam Jun 27, 2024
2707a8b
Change pipeline
pomonam Jun 27, 2024
78e3e3f
Fix tests
pomonam Jun 27, 2024
798406e
Fix tests
pomonam Jun 27, 2024
6213c91
Put the whole model into half
pomonam Jun 27, 2024
b6eaddd
minor
pomonam Jun 27, 2024
d9df4ed
Load bfloat16
pomonam Jun 28, 2024
b2cc4e1
Load bfloat16
pomonam Jun 28, 2024
1c48e35
Revert to older version for tests
pomonam Jun 28, 2024
de8ba8e
Add post process func
pomonam Jun 28, 2024
b472ccd
Add dailymail code
pomonam Jun 29, 2024
6aeed66
Load to right device
pomonam Jun 29, 2024
60a16d5
Change to t5
pomonam Jun 29, 2024
2569fbb
Lint fix & Simplfiy pytest codes
pomonam Jun 29, 2024
db1c89f
Clean up tests
pomonam Jun 29, 2024
98c174d
Lint fix
pomonam Jun 29, 2024
8e5eae3
Use smaller models
pomonam Jun 29, 2024
d1aecdc
Lint fix
pomonam Jun 29, 2024
7fad0d7
Lint fix
pomonam Jun 29, 2024
74ad18e
Add debug points
pomonam Jun 29, 2024
2c26f58
minor
pomonam Jun 30, 2024
46df33a
Add analyze scripts
pomonam Jun 30, 2024
50961a8
Add exact modules to track
pomonam Jul 1, 2024
f403704
Reload dataset
pomonam Jul 1, 2024
312f43e
Clean up factor arguments
pomonam Jul 1, 2024
2828847
Initial commit for final planned optimization
pomonam Jul 3, 2024
427f6ce
Various optimizations done
pomonam Jul 4, 2024
a2d3318
Fix device mismatch problem
pomonam Jul 4, 2024
67e2847
minor
pomonam Jul 4, 2024
87ae3c5
Remove CPU requirements
pomonam Jul 5, 2024
8573276
Factors code cleanup
pomonam Jul 5, 2024
346d4fd
Debug code to track memory
pomonam Jul 5, 2024
3dde1de
Add cuda condition
pomonam Jul 5, 2024
27aa722
add debug code
pomonam Jul 5, 2024
57442ed
Print device
pomonam Jul 5, 2024
c56e285
Remove reset memory
pomonam Jul 5, 2024
c78c092
minor
pomonam Jul 5, 2024
c135ecd
Memory cleanup
pomonam Jul 5, 2024
288427e
m
pomonam Jul 5, 2024
d0154f1
Final covariance cleanup
pomonam Jul 6, 2024
a80595e
Finalize factor computations
pomonam Jul 6, 2024
42e539b
Add score computations
pomonam Jul 6, 2024
1c603fd
Release memory after prepare
pomonam Jul 6, 2024
743070a
Do GPU tests
pomonam Jul 7, 2024
1f26c88
minor
pomonam Jul 7, 2024
9ece91c
Set condition
pomonam Jul 7, 2024
614ccb0
Use default names
pomonam Jul 7, 2024
2df875b
Modify fsdp
pomonam Jul 7, 2024
c918176
Only initialize when necessary
pomonam Jul 7, 2024
398503c
Modify all other tests
pomonam Jul 7, 2024
133f53c
Finalize refactor
pomonam Jul 8, 2024
7ec811f
Add measurement score CPU tests
pomonam Jul 8, 2024
207457b
Remove state initialization
pomonam Jul 8, 2024
b665723
Add DDP tests
pomonam Jul 8, 2024
4b2f9bb
Disable nccl initialization
pomonam Jul 8, 2024
ea1f03d
Reduce logging level
pomonam Jul 8, 2024
7d9c794
Change contiguous tensor
pomonam Jul 8, 2024
0a4cf37
Fix DDP test
pomonam Jul 8, 2024
9119228
Add FSDP tests
pomonam Jul 8, 2024
c1e19fa
Add AMP tests
pomonam Jul 8, 2024
7c6e0a3
Add compile tests
pomonam Jul 8, 2024
e0bc934
Add debug code
pomonam Jul 8, 2024
f5aa709
Add normal lines
pomonam Jul 8, 2024
3dce62d
Add reset compiler
pomonam Jul 8, 2024
1dc7d85
Remove debug lines
pomonam Jul 8, 2024
1c42025
Add debug line for wiki
pomonam Jul 8, 2024
85f99a5
minor
pomonam Jul 8, 2024
3f83568
minor
pomonam Jul 8, 2024
7556c56
Remove debug lines
pomonam Jul 8, 2024
8f5e80f
Remove reference to score
pomonam Jul 8, 2024
37aacd8
Add debug lines
pomonam Jul 8, 2024
4d9dad1
Check mem leak
pomonam Jul 8, 2024
132ac9c
Explicitly remove
pomonam Jul 8, 2024
3cfcc48
More debug lines
pomonam Jul 8, 2024
386b335
Reduce size
pomonam Jul 8, 2024
75962ae
Explicit deletion
pomonam Jul 8, 2024
0535cf4
Remove processed lambda count
pomonam Jul 8, 2024
6645ca9
Minimize size for dot product
pomonam Jul 9, 2024
9744252
Remove self cache
pomonam Jul 9, 2024
c757145
Change to torch einsum
pomonam Jul 9, 2024
77c0e3a
Use einsum
pomonam Jul 9, 2024
a2efa47
Remove einsum
pomonam Jul 9, 2024
1919097
Fix linting
pomonam Jul 9, 2024
23eda1c
Fix linting in tests
pomonam Jul 9, 2024
62123a4
Start wikitext add
pomonam Jul 9, 2024
eb11cf0
Improve the examples
pomonam Jul 9, 2024
6de4afb
Modify score_args name
pomonam Jul 9, 2024
2e2802d
Fix depreciated commands
pomonam Jul 9, 2024
1868e86
Modify default batch_size
pomonam Jul 9, 2024
2a6e1d2
Fix incorrect paths
pomonam Jul 9, 2024
4a7ed69
Fix examples
pomonam Jul 9, 2024
94c22a9
Make two modules split
pomonam Jul 9, 2024
c15ece6
Load 32 models
pomonam Jul 9, 2024
51f64b5
Try out lambda batch size
pomonam Jul 9, 2024
6f5c2af
Test out lambda
pomonam Jul 9, 2024
5cf4864
Remove contiguous calls
pomonam Jul 10, 2024
0ad5a3c
Let minimze flops
pomonam Jul 10, 2024
8ca4e86
Add einsum
pomonam Jul 10, 2024
9544c36
Remove inspect arguments
pomonam Jul 10, 2024
afff397
Remove contract operation
pomonam Jul 10, 2024
9a5626d
Remove cpu flag
pomonam Jul 10, 2024
3aa89bc
fix tests
pomonam Jul 10, 2024
9c3f89c
Remove contract path dependency
pomonam Jul 10, 2024
395bbc1
Fix task formulation
pomonam Jul 10, 2024
7209dac
Finish dailymail example
pomonam Jul 10, 2024
7374ddd
add openwebtext
pomonam Jul 10, 2024
c6efd69
Ignore Attn computation
pomonam Jul 10, 2024
ab16f3f
Remove blank einsum calls with >3 operands
pomonam Jul 10, 2024
2ce91a6
Update documentations
pomonam Jul 10, 2024
44764e3
Finalize all examples
pomonam Jul 10, 2024
3310f19
Fix spacing issues
pomonam Jul 10, 2024
7ea5ca7
Add debug statement
pomonam Jul 10, 2024
dd69813
Remove debug statement
pomonam Jul 10, 2024
c1eed9d
Try unscale
pomonam Jul 10, 2024
496b8ed
Increase timeout
pomonam Jul 10, 2024
ad06d11
Add score computation script
pomonam Jul 11, 2024
37b37d1
Fix typo
pomonam Jul 11, 2024
041a6e3
Update commands
pomonam Jul 11, 2024
c6be439
Disable module parition
pomonam Jul 11, 2024
05d06a0
Add factors name arguments
pomonam Jul 11, 2024
52a67f9
Merge pull request #28 from pomonam/openwebtext
pomonam Jul 11, 2024
c7198ad
Fix AMP logic + disable gradscaler for bfloat16
pomonam Jul 11, 2024
1b58d30
Disable double model for AMP tests
pomonam Jul 11, 2024
7ca5af4
float16 for amp
pomonam Jul 11, 2024
7a37594
All to float16
pomonam Jul 11, 2024
849c90a
Increase amp scale
pomonam Jul 11, 2024
ead0fb4
Add debug line
pomonam Jul 11, 2024
3f397cf
Remove debug statements
pomonam Jul 11, 2024
d7e35bc
Merge pull request #29 from pomonam/amp_logic
pomonam Jul 11, 2024
7220ecc
Lint fix
pomonam Jul 11, 2024
4adc1b7
backward amp_scale
pomonam Jul 11, 2024
728ea1b
Remove warning sign
pomonam Jul 11, 2024
7a19b74
Add flag factor_name
pomonam Jul 12, 2024
4c099ba
Fix typo
pomonam Jul 12, 2024
28e682b
Fix typo
pomonam Jul 12, 2024
9240217
More comments
pomonam Jul 12, 2024
6e5f812
Add Openwebtext results
pomonam Jul 12, 2024
37293dd
Merge pull request #31 from pomonam/llama
pomonam Jul 12, 2024
ad2d082
Fix typos
pomonam Jul 12, 2024
4 changes: 3 additions & 1 deletion .github/workflows/python-test.yml
@@ -28,8 +28,10 @@ jobs:
 pytest -vx tests/test_dataset_utils.py
 pytest -vx tests/test_testable_tasks.py
 pytest -vx tests/factors/test_covariances.py
-pytest -vx tests/factors/test_eigens.py
+pytest -vx tests/factors/test_eigendecompositions.py
 pytest -vx tests/factors/test_lambdas.py
+pytest -vx tests/modules/test_modules.py
 pytest -vx tests/modules/test_per_sample_gradients.py
+pytest -vx tests/modules/test_matmul.py
 pytest -vx tests/scores/test_pairwise_scores.py
 pytest -vx tests/scores/test_self_scores.py
3 changes: 2 additions & 1 deletion .gitignore
@@ -164,7 +164,8 @@ cython_debug/
 
 # Checkpoints and influence outputs
 checkpoints/
-analyses/
+influence_results/
 data/
+cache/
 *.pth
 *.pt
211 changes: 125 additions & 86 deletions DOCUMENTATION.md

Large diffs are not rendered by default.

65 changes: 55 additions & 10 deletions README.md
@@ -22,14 +22,11 @@

---

> **Kronfluence** is a research repository designed to compute [influence functions](https://arxiv.org/abs/1703.04730) using [Kronecker-factored Approximate Curvature (KFAC)](https://arxiv.org/abs/1503.05671) or [Eigenvalue-corrected KFAC (EKFAC)](https://arxiv.org/abs/1806.03884).
For a detailed description of the methodology, see the [**paper**](https://arxiv.org/abs/2308.03296) *Studying Large Language Model Generalization with Influence Functions*.
> **Kronfluence** is a PyTorch package designed to compute [influence functions](https://arxiv.org/abs/1703.04730) using [Kronecker-factored Approximate Curvature (KFAC)](https://arxiv.org/abs/1503.05671) or [Eigenvalue-corrected KFAC (EKFAC)](https://arxiv.org/abs/1806.03884).
For a detailed description of the methodology, see the [**paper**](https://arxiv.org/abs/2308.03296), *Studying Large Language Model Generalization with Influence Functions*.
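For orientation, the quantity these curvature approximations make tractable is the classical influence score of Koh & Liang (linked above). In their notation, the influence of a training example $z$ on the loss at a query example $z_q$ is, up to sign convention,

$$
\mathcal{I}(z, z_q) = \nabla_\theta L(z_q, \hat\theta)^{\top} \, H_{\hat\theta}^{-1} \, \nabla_\theta L(z, \hat\theta),
\qquad
H_{\hat\theta} = \frac{1}{n} \sum_{i=1}^{n} \nabla_\theta^2 L(z_i, \hat\theta),
$$

where $\hat\theta$ is the trained parameter vector. KFAC and EKFAC replace the intractable Hessian $H_{\hat\theta}$ with a per-module Kronecker-factored curvature approximation; this is a standard summary of the linked papers, not text from this repository.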

---

> [!WARNING]
> This repository is under active development and has not reached its first stable release.

## Installation

> [!IMPORTANT]
@@ -53,11 +53,9 @@ pip install -e .

## Getting Started

Kronfluence supports influence computations on `nn.Linear` and `nn.Conv2d` modules. See the [**Technical Documentation**](https://github.com/pomonam/kronfluence/blob/main/DOCUMENTATION.md) page for a comprehensive guide.

### Learn More
Kronfluence supports influence computations on [`nn.Linear`](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html) and [`nn.Conv2d`](https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html) modules.
See the [**Technical Documentation**](https://github.com/pomonam/kronfluence/blob/main/DOCUMENTATION.md) page for a comprehensive guide.

The [examples](https://github.com/pomonam/kronfluence/tree/main/examples) folder contains several examples demonstrating how to use Kronfluence. More examples will be added in the future.
**TL;DR** You need to prepare a trained model and datasets, and pass them into the `Analyzer` class.

```python
@@ -115,6 +110,30 @@ analyzer.compute_pairwise_scores(
scores = analyzer.load_pairwise_scores(scores_name="my_scores")
```

Kronfluence supports various PyTorch features; the following table summarizes the supported features:

<div align="center">

| Feature | Supported |
|-----------------------------------------------------------------------------------------------------------------------------|:---------:|
| [Distributed Data Parallel (DDP)](https://pytorch.org/docs/master/generated/torch.nn.parallel.DistributedDataParallel.html) | ✅ |
| [Automatic Mixed Precision (AMP)](https://pytorch.org/docs/stable/amp.html) | ✅ |
| [Torch Compile](https://pytorch.org/docs/stable/generated/torch.compile.html) | ✅ |
| [Gradient Checkpointing](https://pytorch.org/docs/stable/checkpoint.html) | ✅ |
| [Fully Sharded Data Parallel (FSDP)](https://pytorch.org/docs/stable/fsdp.html) | ✅ |

</div>

The [examples](https://github.com/pomonam/kronfluence/tree/main/examples) folder contains several examples demonstrating how to use Kronfluence.

## LogIX

While Kronfluence supports influence function computations on large-scale models like `Meta-Llama-3-8B-Instruct`, for those
interested in running influence analysis on even larger models or with a large number of query data points, our
project [LogIX](https://github.com/logix-project/logix) may be worth exploring. It integrates with frameworks like
[HuggingFace Trainer](https://huggingface.co/docs/transformers/en/main_classes/trainer) and [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/)
and is also compatible with many PyTorch features (DDP & FSDP & [DeepSpeed](https://github.com/microsoft/DeepSpeed)).

## Contributing

Contributions are welcome! To get started, please review our [Code of Conduct](https://github.com/pomonam/kronfluence/blob/main/CODE_OF_CONDUCT.md). For bug fixes, please submit a pull request.
@@ -131,10 +150,36 @@ cd kronfluence
pip install -e ."[dev]"
```

### Style Testing

To maintain code quality and consistency, we run isort, ruff, and pylint checks on pull requests. Before submitting a
pull request, please ensure that your code adheres to our formatting and linting guidelines. The following commands will
modify your code in place, so it is recommended to create a Git commit beforehand to easily revert any unintended changes.

Sort import orderings using [isort](https://pycqa.github.io/isort/):

```bash
isort kronfluence
```

Format code using [ruff](https://docs.astral.sh/ruff/):

```bash
ruff format kronfluence
```

To view all [pylint](https://www.pylint.org/) complaints, run the following command:

```bash
pylint kronfluence
```

Please address any reported issues before submitting your PR.

## Acknowledgements

[Omkar Dige](https://github.com/xeon27) contributed to the profiling, DDP, and FSDP utilities, and [Adil Asif](https://github.com/adil-a/) provided valuable insights and suggestions on structuring the DDP and FSDP implementations.
I also thank Hwijeen Ahn, Sang Keun Choe, Youngseog Chung, Minsoo Kang, Lev McKinney, Laura Ruis, Andrew Wang, and Kewen Zhao for their feedback.
I also thank Hwijeen Ahn, Sang Keun Choe, Youngseog Chung, Minsoo Kang, Sophie Liao, Lev McKinney, Laura Ruis, Andrew Wang, and Kewen Zhao for their feedback.

## License

27 changes: 14 additions & 13 deletions dev_requirements.txt
@@ -1,14 +1,15 @@
-torch
-torchvision
-accelerate
-einops
-einconv
-opt_einsum
-safetensors
-tqdm
-datasets
-transformers
+torch>=2.1.0
+torchvision>=0.16.0
+accelerate>=0.31.0
+einops>=0.8.0
+einconv>=0.1.0
+opt_einsum>=3.3.0
+scikit-learn>=1.4.0
+safetensors>=0.4.2
+tqdm>=4.66.4
+datasets>=2.20.0
+transformers>=4.42.0
 isort==5.13.2
-pylint==3.0.3
-pytest==8.0.0
-ruff==0.3.0
+pylint==3.2.3
+pytest==8.2.2
+ruff==0.4.0
33 changes: 33 additions & 0 deletions examples/README.md
@@ -0,0 +1,33 @@
# Kronfluence: Examples

For detailed technical documentation of Kronfluence, please refer to the [Technical Documentation](https://github.com/pomonam/kronfluence/blob/main/DOCUMENTATION.md) page.

## Getting Started

To run all examples, install the necessary packages:

```bash
pip install -r requirements.txt
```

Alternatively, navigate to each example folder and run `pip install -r requirements.txt`.

## List of Tasks

Our examples cover the following tasks:

<div align="center">

| Task | Example Datasets |
|----------------------|:------------------------:|
| Regression | UCI |
| Image Classification | CIFAR-10 & ImageNet |
| Text Classification | GLUE |
| Multiple-Choice | SWAG |
| Summarization | DNN/DailyMail |
| Language Modeling | WikiText-2 & OpenWebText |

</div>

These examples demonstrate various use cases of Kronfluence, including the usage of AMP (Automatic Mixed Precision) and DDP (Distributed Data Parallel).
Many examples aim to replicate the settings used in [our paper](https://arxiv.org/abs/2405.12186). If you would like to see more examples added to this repository, please leave an issue.
123 changes: 110 additions & 13 deletions examples/cifar/README.md
@@ -1,57 +1,154 @@
# CIFAR-10 & ResNet-9 Example

This directory contains scripts for training ResNet-9 on CIFAR-10. The pipeline is motivated from
[TRAK repository](https://github.com/MadryLab/trak/blob/main/examples/cifar_quickstart.ipynb).
This directory contains scripts for training ResNet-9 and computing influence scores on the CIFAR-10 dataset. The pipeline is adapted from the
[TRAK repository](https://github.com/MadryLab/trak/blob/main/examples/cifar_quickstart.ipynb). To get started, please install the necessary packages by running the following command:

```bash
pip install -r requirements.txt
```

## Training

To train ResNet-9 on CIFAR-10 dataset, run the following command:
To train ResNet-9 on the CIFAR-10 dataset, run the following command:

```bash
python train.py --dataset_dir ./data \
--checkpoint_dir ./checkpoints \
--train_batch_size 512 \
--eval_batch_size 1024 \
--learning_rate 0.4 \
--weight_decay 0.0001 \
--weight_decay 0.001 \
--num_train_epochs 25 \
--seed 1004
```

This will train the model using the specified hyperparameters and save the trained checkpoint in the `./checkpoints` directory.

## Computing Pairwise Influence Scores

To obtain pairwise influence scores on 2000 query data points using `ekfac`, run the following command:
To compute pairwise influence scores on 2000 query data points using the `ekfac` strategy, run the following command:

```bash
python analyze.py --query_batch_size 1000 \
--dataset_dir ./data \
--checkpoint_dir ./checkpoints \
--factor_strategy ekfac
```
You can also use `identity`, `diagonal`, and `kfac`. On A100 (80GB), it takes roughly 1.5 minutes to compute the
pairwise scores (including computing EKFAC factors).

In addition to `ekfac`, you can also use `identity`, `diagonal`, and `kfac` as the `factor_strategy`. On an A100 (80GB) GPU, it takes roughly 2 minutes to compute the pairwise scores (including computing the EKFAC factors):

```
----------------------------------------------------------------------------------------------------------------------------------
| Action | Mean duration (s) | Num calls | Total time (s) | Percentage % |
----------------------------------------------------------------------------------------------------------------------------------
| Total | - | 11 | 106.38 | 100 % |
----------------------------------------------------------------------------------------------------------------------------------
| Compute Pairwise Score | 46.745 | 1 | 46.745 | 43.941 |
| Fit Lambda | 34.885 | 1 | 34.885 | 32.793 |
| Fit Covariance | 22.538 | 1 | 22.538 | 21.187 |
| Perform Eigendecomposition | 0.91424 | 1 | 0.91424 | 0.85941 |
| Save Pairwise Score | 0.81219 | 1 | 0.81219 | 0.76348 |
| Save Covariance | 0.22351 | 1 | 0.22351 | 0.21011 |
| Save Eigendecomposition | 0.21617 | 1 | 0.21617 | 0.20321 |
| Save Lambda | 0.031038 | 1 | 0.031038 | 0.029177 |
| Load Eigendecomposition | 0.010442 | 1 | 0.010442 | 0.0098156 |
| Load All Factors | 0.0026517 | 1 | 0.0026517 | 0.0024927 |
| Load Covariance | 0.0016419 | 1 | 0.0016419 | 0.0015435 |
----------------------------------------------------------------------------------------------------------------------------------
```
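As a quick sanity check on how the `Percentage %` column relates to the others (an inference from the numbers above, not documented profiler behavior), each action's share appears to be its total time divided by the overall wall time:

```python
# Recompute the largest entries of the "Percentage %" column from the
# profiling table above. The seconds are copied from the table; treating
# percentage as time / total * 100 is an assumption the numbers bear out.
total_seconds = 106.38
actions = {
    "Compute Pairwise Score": 46.745,
    "Fit Lambda": 34.885,
    "Fit Covariance": 22.538,
}
for name, seconds in actions.items():
    print(f"{name}: {100.0 * seconds / total_seconds:.3f}%")
```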

To use AMP when computing influence scores, run:

```bash
python analyze.py --query_batch_size 1000 \
--dataset_dir ./data \
--checkpoint_dir ./checkpoints \
--factor_strategy ekfac \
--use_half_precision
```

This reduces computation time to about 40 seconds on an A100 (80GB) GPU:

```
----------------------------------------------------------------------------------------------------------------------------------
| Action | Mean duration (s) | Num calls | Total time (s) | Percentage % |
----------------------------------------------------------------------------------------------------------------------------------
| Total | - | 11 | 35.965 | 100 % |
----------------------------------------------------------------------------------------------------------------------------------
| Compute Pairwise Score | 18.012 | 1 | 18.012 | 50.082 |
| Fit Lambda | 9.2271 | 1 | 9.2271 | 25.656 |
| Fit Covariance | 7.134 | 1 | 7.134 | 19.836 |
| Perform Eigendecomposition | 0.87962 | 1 | 0.87962 | 2.4457 |
| Save Pairwise Score | 0.45432 | 1 | 0.45432 | 1.2632 |
| Save Covariance | 0.12861 | 1 | 0.12861 | 0.35759 |
| Save Eigendecomposition | 0.11296 | 1 | 0.11296 | 0.31407 |
| Save Lambda | 0.010712 | 1 | 0.010712 | 0.029784 |
| Load All Factors | 0.002736 | 1 | 0.002736 | 0.0076074 |
| Load Covariance | 0.0016696 | 1 | 0.0016696 | 0.0046421 |
| Load Eigendecomposition | 0.0014892 | 1 | 0.0014892 | 0.0041406 |
----------------------------------------------------------------------------------------------------------------------------------
```

You can run `half_precision_analysis.py` to verify that the scores computed with AMP have high correlations with those of the default configuration.

<p align="center">
<a href="#"><img width="380" img src="figure/half_precision.png" alt="Half Precision"/></a>
</p>
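A stdlib-only sketch of the kind of agreement check `half_precision_analysis.py` performs (the actual script's internals are not shown in this diff; Spearman rank correlation is one reasonable agreement measure, and the score values below are invented):

```python
# Hedged sketch: compare two sets of influence scores (e.g. default vs.
# half precision) by Spearman rank correlation, using only the stdlib.

def _ranks(values):
    # Rank positions (ties broken by order); average-rank handling is
    # omitted since influence scores are effectively continuous.
    order = sorted(range(len(values)), key=values.__getitem__)
    ranks = [0] * len(values)
    for rank, idx in enumerate(order):
        ranks[idx] = rank
    return ranks

def spearman(a, b):
    # Both rank lists are permutations of 0..n-1, so they share the same
    # variance and the Pearson correlation of ranks simplifies as below.
    ra, rb = _ranks(a), _ranks(b)
    n = len(a)
    mean = (n - 1) / 2
    cov = sum((x - mean) * (y - mean) for x, y in zip(ra, rb))
    var = sum((x - mean) ** 2 for x in ra)
    return cov / var

full = [0.52, -1.10, 0.03, 2.41, -0.77, 1.30]
half = [0.50, -1.08, 0.05, 2.39, -0.80, 1.28]  # small fp16-style noise
print(round(spearman(full, half), 3))  # → 1.0 (identical ordering)
```

A correlation near 1.0 indicates that half precision preserves the ranking of training examples, which is usually what influence analyses care about.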

## Visualizing Influential Training Images

[This Colab notebook](https://colab.research.google.com/drive/1KIwIbeJh_om4tRwceuZ005fVKDsiXKgr?usp=sharing) provides a tutorial on visualizing the top influential training images.

## Mislabeled Data Detection

We can use self-influence scores (see Section 5.4 for the [paper](https://arxiv.org/pdf/1703.04730.pdf)) to detect mislabeled examples.
First, train the model with 10% of training examples mislabeled by running the following command:
We can use self-influence scores (see **Section 5.4** of the [paper](https://arxiv.org/pdf/1703.04730.pdf)) to detect mislabeled examples.
First, train the model with 10% of the training examples mislabeled by running:

```bash
python train.py --dataset_dir ./data \
--corrupt_percentage 0.1 \
--checkpoint_dir ./checkpoints \
--train_batch_size 512 \
--eval_batch_size 1024 \
--learning_rate 0.4 \
--weight_decay 0.0001 \
--weight_decay 0.001 \
--num_train_epochs 25 \
--seed 1004
```

Then, compute self-influence scores with the following command:
Then, compute the self-influence scores with:

```bash
python detect_mislabeled_dataset.py --dataset_dir ./data \
--corrupt_percentage 0.1 \
--checkpoint_dir ./checkpoints \
--factor_strategy ekfac
```

On A100 (80GB), it takes roughly 1.5 minutes to compute the self-influence scores.
We can detect around 82% of mislabeled data points by inspecting 10% of the dataset (96% by inspecting 20%).
On an A100 (80GB) GPU, it takes roughly 2 minutes to compute the self-influence scores:

```
----------------------------------------------------------------------------------------------------------------------------------
| Action | Mean duration (s) | Num calls | Total time (s) | Percentage % |
----------------------------------------------------------------------------------------------------------------------------------
| Total | - | 11 | 121.85 | 100 % |
----------------------------------------------------------------------------------------------------------------------------------
| Compute Self-Influence Score | 62.778 | 1 | 62.778 | 51.519 |
| Fit Lambda | 35.174 | 1 | 35.174 | 28.866 |
| Fit Covariance | 22.582 | 1 | 22.582 | 18.532 |
| Perform Eigendecomposition | 0.82656 | 1 | 0.82656 | 0.67832 |
| Save Covariance | 0.2478 | 1 | 0.2478 | 0.20336 |
| Save Eigendecomposition | 0.22042 | 1 | 0.22042 | 0.18088 |
| Save Lambda | 0.018463 | 1 | 0.018463 | 0.015152 |
| Load All Factors | 0.0027554 | 1 | 0.0027554 | 0.0022612 |
| Load Covariance | 0.0016607 | 1 | 0.0016607 | 0.0013628 |
| Load Eigendecomposition | 0.0015408 | 1 | 0.0015408 | 0.0012645 |
| Save Self-Influence Score | 0.0010841 | 1 | 0.0010841 | 0.00088966 |
----------------------------------------------------------------------------------------------------------------------------------
```

Around 80% of mislabeled data points can be detected by inspecting 10% of the dataset (97% by inspecting 20%).

<p align="center">
<a href="#"><img width="380" img src="figure/mislabel.png" alt="Mislabeled Data Detection"/></a>
</p>
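The detection recipe described above can be sketched in plain Python. The synthetic scores below are an invented stand-in for real self-influence scores, used only to illustrate the ranking-and-inspecting procedure:

```python
import random

# Hedged sketch of mislabeled-data detection: mislabeled points tend to
# receive larger self-influence scores, so we inspect the highest-scoring
# fraction of the training set and count how many corrupted labels it covers.
random.seed(0)
n, corrupt_frac = 50_000, 0.10
is_corrupt = [i < int(n * corrupt_frac) for i in range(n)]
# Synthetic stand-in: corrupted points get scores shifted upward on
# average, mimicking the behavior reported above.
scores = [random.gauss(3.0 if c else 0.0, 1.0) for c in is_corrupt]

def detection_rate(scores, is_corrupt, inspect_frac):
    # Fraction of all corrupted points found in the top-scoring slice.
    k = int(len(scores) * inspect_frac)
    top = sorted(range(len(scores)), key=lambda i: -scores[i])[:k]
    return sum(is_corrupt[i] for i in top) / sum(is_corrupt)

for frac in (0.10, 0.20):
    rate = detection_rate(scores, is_corrupt, frac)
    print(f"inspect {frac:.0%}: detect {rate:.0%}")
```

On this toy data the detection curve has the same shape as the reported CIFAR-10 numbers: most corrupted points land in the top decile of scores.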