Code Refactor for Final Release #28

Merged
129 commits merged on Jul 11, 2024

Commits (129)
37e5d32
Modify gitignore
pomonam Jun 27, 2024
db7222a
Modify openwebtext dataset
pomonam Jun 27, 2024
0ee47f9
Format the code
pomonam Jun 27, 2024
58c7006
Add pipeline
pomonam Jun 27, 2024
40115e3
Fix tests
pomonam Jun 27, 2024
55cdb96
Add valid modules
pomonam Jun 27, 2024
5fce5e2
Fix minor typos
pomonam Jun 27, 2024
d545d50
Use more module partitions
pomonam Jun 27, 2024
e423e49
Change extreme reduce
pomonam Jun 27, 2024
d112dbf
Fix tests
pomonam Jun 27, 2024
2707a8b
Change pipeline
pomonam Jun 27, 2024
78e3e3f
Fix tests
pomonam Jun 27, 2024
798406e
Fix tests
pomonam Jun 27, 2024
6213c91
Put the whole model into half
pomonam Jun 27, 2024
b6eaddd
minor
pomonam Jun 27, 2024
d9df4ed
Load bfloat16
pomonam Jun 28, 2024
b2cc4e1
Load bfloat16
pomonam Jun 28, 2024
1c48e35
Revert to older version for tests
pomonam Jun 28, 2024
de8ba8e
Add post process func
pomonam Jun 28, 2024
b472ccd
Add dailymail code
pomonam Jun 29, 2024
6aeed66
Load to right device
pomonam Jun 29, 2024
60a16d5
Change to t5
pomonam Jun 29, 2024
2569fbb
Lint fix & Simplify pytest codes
pomonam Jun 29, 2024
db1c89f
Clean up tests
pomonam Jun 29, 2024
98c174d
Lint fix
pomonam Jun 29, 2024
8e5eae3
Use smaller models
pomonam Jun 29, 2024
d1aecdc
Lint fix
pomonam Jun 29, 2024
7fad0d7
Lint fix
pomonam Jun 29, 2024
74ad18e
Add debug points
pomonam Jun 29, 2024
2c26f58
minor
pomonam Jun 30, 2024
46df33a
Add analyze scripts
pomonam Jun 30, 2024
50961a8
Add exact modules to track
pomonam Jul 1, 2024
f403704
Reload dataset
pomonam Jul 1, 2024
312f43e
Clean up factor arguments
pomonam Jul 1, 2024
2828847
Initial commit for final planned optimization
pomonam Jul 3, 2024
427f6ce
Various optimizations done
pomonam Jul 4, 2024
a2d3318
Fix device mismatch problem
pomonam Jul 4, 2024
67e2847
minor
pomonam Jul 4, 2024
87ae3c5
Remove CPU requirements
pomonam Jul 5, 2024
8573276
Factors code cleanup
pomonam Jul 5, 2024
346d4fd
Debug code to track memory
pomonam Jul 5, 2024
3dde1de
Add cuda condition
pomonam Jul 5, 2024
27aa722
add debug code
pomonam Jul 5, 2024
57442ed
Print device
pomonam Jul 5, 2024
c56e285
Remove reset memory
pomonam Jul 5, 2024
c78c092
minor
pomonam Jul 5, 2024
c135ecd
Memory cleanup
pomonam Jul 5, 2024
288427e
m
pomonam Jul 5, 2024
d0154f1
Final covariance cleanup
pomonam Jul 6, 2024
a80595e
Finalize factor computations
pomonam Jul 6, 2024
42e539b
Add score computations
pomonam Jul 6, 2024
1c603fd
Release memory after prepare
pomonam Jul 6, 2024
743070a
Do GPU tests
pomonam Jul 7, 2024
1f26c88
minor
pomonam Jul 7, 2024
9ece91c
Set condition
pomonam Jul 7, 2024
614ccb0
Use default names
pomonam Jul 7, 2024
2df875b
Modify fsdp
pomonam Jul 7, 2024
c918176
Only initialize when necessary
pomonam Jul 7, 2024
398503c
Modify all other tests
pomonam Jul 7, 2024
133f53c
Finalize refactor
pomonam Jul 8, 2024
7ec811f
Add measurement score CPU tests
pomonam Jul 8, 2024
207457b
Remove state initialization
pomonam Jul 8, 2024
b665723
Add DDP tests
pomonam Jul 8, 2024
4b2f9bb
Disable nccl initialization
pomonam Jul 8, 2024
ea1f03d
Reduce logging level
pomonam Jul 8, 2024
7d9c794
Change contiguous tensor
pomonam Jul 8, 2024
0a4cf37
Fix DDP test
pomonam Jul 8, 2024
9119228
Add FSDP tests
pomonam Jul 8, 2024
c1e19fa
Add AMP tests
pomonam Jul 8, 2024
7c6e0a3
Add compile tests
pomonam Jul 8, 2024
e0bc934
Add debug code
pomonam Jul 8, 2024
f5aa709
Add normal lines
pomonam Jul 8, 2024
3dce62d
Add reset compiler
pomonam Jul 8, 2024
1dc7d85
Remove debug lines
pomonam Jul 8, 2024
1c42025
Add debug line for wiki
pomonam Jul 8, 2024
85f99a5
minor
pomonam Jul 8, 2024
3f83568
minor
pomonam Jul 8, 2024
7556c56
Remove debug lines
pomonam Jul 8, 2024
8f5e80f
Remove reference to score
pomonam Jul 8, 2024
37aacd8
Add debug lines
pomonam Jul 8, 2024
4d9dad1
Check mem leak
pomonam Jul 8, 2024
132ac9c
Explicitly remove
pomonam Jul 8, 2024
3cfcc48
More debug lines
pomonam Jul 8, 2024
386b335
Reduce size
pomonam Jul 8, 2024
75962ae
Explicit deletion
pomonam Jul 8, 2024
0535cf4
Remove processed lambda count
pomonam Jul 8, 2024
6645ca9
Minimize size for dot product
pomonam Jul 9, 2024
9744252
Remove self cache
pomonam Jul 9, 2024
c757145
Change to torch einsum
pomonam Jul 9, 2024
77c0e3a
Use einsum
pomonam Jul 9, 2024
a2efa47
Remove einsum
pomonam Jul 9, 2024
1919097
Fix linting
pomonam Jul 9, 2024
23eda1c
Fix linting in tests
pomonam Jul 9, 2024
62123a4
Start wikitext add
pomonam Jul 9, 2024
eb11cf0
Improve the examples
pomonam Jul 9, 2024
6de4afb
Modify score_args name
pomonam Jul 9, 2024
2e2802d
Fix deprecated commands
pomonam Jul 9, 2024
1868e86
Modify default batch_size
pomonam Jul 9, 2024
2a6e1d2
Fix incorrect paths
pomonam Jul 9, 2024
4a7ed69
Fix examples
pomonam Jul 9, 2024
94c22a9
Make two modules split
pomonam Jul 9, 2024
c15ece6
Load 32 models
pomonam Jul 9, 2024
51f64b5
Try out lambda batch size
pomonam Jul 9, 2024
6f5c2af
Test out lambda
pomonam Jul 9, 2024
5cf4864
Remove contiguous calls
pomonam Jul 10, 2024
0ad5a3c
Let minimize flops
pomonam Jul 10, 2024
8ca4e86
Add einsum
pomonam Jul 10, 2024
9544c36
Remove inspect arguments
pomonam Jul 10, 2024
afff397
Remove contract operation
pomonam Jul 10, 2024
9a5626d
Remove cpu flag
pomonam Jul 10, 2024
3aa89bc
fix tests
pomonam Jul 10, 2024
9c3f89c
Remove contract path dependency
pomonam Jul 10, 2024
395bbc1
Fix task formulation
pomonam Jul 10, 2024
7209dac
Finish dailymail example
pomonam Jul 10, 2024
7374ddd
add openwebtext
pomonam Jul 10, 2024
c6efd69
Ignore Attn computation
pomonam Jul 10, 2024
ab16f3f
Remove blank einsum calls with >3 operands
pomonam Jul 10, 2024
2ce91a6
Update documentations
pomonam Jul 10, 2024
44764e3
Finalize all examples
pomonam Jul 10, 2024
3310f19
Fix spacing issues
pomonam Jul 10, 2024
7ea5ca7
Add debug statement
pomonam Jul 10, 2024
dd69813
Remove debug statement
pomonam Jul 10, 2024
c1eed9d
Try unscale
pomonam Jul 10, 2024
496b8ed
Increase timeout
pomonam Jul 10, 2024
ad06d11
Add score computation script
pomonam Jul 11, 2024
37b37d1
Fix typo
pomonam Jul 11, 2024
041a6e3
Update commands
pomonam Jul 11, 2024
c6be439
Disable module partition
pomonam Jul 11, 2024
05d06a0
Add factors name arguments
pomonam Jul 11, 2024
4 changes: 3 additions & 1 deletion .github/workflows/python-test.yml
@@ -28,8 +28,10 @@ jobs:
 pytest -vx tests/test_dataset_utils.py
 pytest -vx tests/test_testable_tasks.py
 pytest -vx tests/factors/test_covariances.py
-pytest -vx tests/factors/test_eigens.py
+pytest -vx tests/factors/test_eigendecompositions.py
 pytest -vx tests/factors/test_lambdas.py
+pytest -vx tests/modules/test_modules.py
 pytest -vx tests/modules/test_per_sample_gradients.py
+pytest -vx tests/modules/test_matmul.py
 pytest -vx tests/scores/test_pairwise_scores.py
 pytest -vx tests/scores/test_self_scores.py
2 changes: 1 addition & 1 deletion .gitignore
@@ -164,8 +164,8 @@ cython_debug/
 
 # Checkpoints and influence outputs
 checkpoints/
-analyses/
+influence_results/
 data/
 cache/
 *.pth
 *.pt
98 changes: 51 additions & 47 deletions DOCUMENTATION.md
@@ -67,7 +67,7 @@ class YourTask(Task):
     ) -> torch.Tensor:
         # TODO: Complete this method.
 
-    def tracked_modules(self) -> Optional[List[str]]:
+    def get_influence_tracked_modules(self) -> Optional[List[str]]:
         # TODO: [Optional] Complete this method.
         return None  # Compute influence scores on all available modules.
@@ -89,7 +89,7 @@ model = prepare_model(model=model, task=task)
 ...
 ```
 
-If you have specified specific module names in `Task.tracked_modules`, `TrackedModule` will only be installed for these modules.
+If you have specified specific module names in `Task.get_influence_tracked_modules`, `TrackedModule` will only be installed for these modules.
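A rough end-to-end sketch of this installation step (assuming the `Analyzer` entry point from `kronfluence.analyzer`; `model`, `task`, and the analysis name are placeholders):

```python
from kronfluence.analyzer import Analyzer, prepare_model  # assumed import path

# `model` and `task` are placeholders defined elsewhere.
model = prepare_model(model=model, task=task)  # installs TrackedModule wrappers

# The analyzer coordinates factor fitting and score computation.
analyzer = Analyzer(analysis_name="my_analysis", model=model, task=task)
```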

**\[Optional\] Create a DDP and FSDP Module.**
After calling `prepare_model`, you can create [DistributedDataParallel (DDP)](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html) or
@@ -140,7 +140,7 @@ Try rewriting the model so that it uses supported modules (as done for the `conv
 Alternatively, you can create a subclass of `TrackedModule` to compute influence scores for your custom module.
 If there are specific modules you would like to see supported, please submit an issue.
 
-**How should I write task.tracked_modules?**
+**How should I write task.get_influence_tracked_modules?**
 We recommend using all supported modules for influence computations. However, if you would like to compute influence scores
 on a subset of the modules (e.g., influence computations only on MLP layers for transformers or influence computation only on the last layer),
 inspect `model.named_modules()` to determine what modules to use. You can specify the list of module names you want to analyze.
@@ -183,7 +183,7 @@ def forward(x: torch.Tensor) -> torch.Tensor:
 > [!WARNING]
 > The default arguments assume the module is used only once during the forward pass.
 > If your model shares parameters (e.g., the module is used in multiple places during the forward pass), set
-> `shared_parameters_exist=True` in `FactorArguments`.
+> `has_shared_parameters=True` in `FactorArguments`.
 
 **Why are there so many arguments?**
 Kronfluence was originally developed to compute influence scores on large-scale models, which is why `FactorArguments` and `ScoreArguments`
@@ -204,14 +204,13 @@ from kronfluence.arguments import FactorArguments
 factor_args = FactorArguments(
     strategy="ekfac",  # Choose from "identity", "diagonal", "kfac", or "ekfac".
     use_empirical_fisher=False,
-    distributed_sync_steps=1000,
     amp_dtype=None,
-    shared_parameters_exist=False,
+    has_shared_parameters=False,
 
     # Settings for covariance matrix fitting.
     covariance_max_examples=100_000,
-    covariance_data_partition_size=1,
-    covariance_module_partition_size=1,
+    covariance_data_partitions=1,
+    covariance_module_partitions=1,
     activation_covariance_dtype=torch.float32,
     gradient_covariance_dtype=torch.float32,
 
@@ -220,10 +219,10 @@ factor_args = FactorArguments(
 
     # Settings for Lambda matrix fitting.
     lambda_max_examples=100_000,
-    lambda_data_partition_size=1,
-    lambda_module_partition_size=1,
-    lambda_iterative_aggregate=False,
-    cached_activation_cpu_offload=False,
+    lambda_data_partitions=1,
+    lambda_module_partitions=1,
+    use_iterative_lambda_aggregation=False,
+    offload_activations_to_cpu=False,
     per_sample_gradient_dtype=torch.float32,
     lambda_dtype=torch.float32,
 )
@@ -237,7 +236,7 @@ You can change:
 - `use_empirical_fisher`: Determines whether to use the [empirical Fisher](https://arxiv.org/abs/1905.12558) (using actual labels from batch)
 instead of the true Fisher (using sampled labels from model's predictions). It is recommended to be `False`.
 - `amp_dtype`: Selects the dtype for [automatic mixed precision (AMP)](https://pytorch.org/docs/stable/amp.html). Disables AMP if set to `None`.
-- `shared_parameters_exist`: Specifies whether the shared parameters exist in the forward pass.
+- `has_shared_parameters`: Specifies whether the shared parameters exist in the forward pass.

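As a hedged usage sketch, these arguments are passed when fitting factors (assuming the `Analyzer.fit_all_factors` API; `train_dataset` is a placeholder for your training `torch.utils.data.Dataset`):

```python
# Fit covariance matrices, eigendecompositions, and Lambda matrices in one call.
analyzer.fit_all_factors(
    factors_name="initial_factor",  # referenced later by load_* and score calls
    dataset=train_dataset,          # placeholder: your training dataset
    factor_args=factor_args,
)
```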
### Fitting Covariance Matrices

Expand All @@ -254,13 +253,13 @@ covariance_matrices = analyzer.load_covariance_matrices(factors_name="initial_fa
This step corresponds to **Equation 16** in the paper. You can tune:
- `covariance_max_examples`: Controls the maximum number of data points for fitting covariance matrices. Setting it to `None`,
Kronfluence computes covariance matrices for all data points.
- `covariance_data_partition_size`: Number of data partitions to use for computing covariance matrices.
For example, when `covariance_data_partition_size = 2`, the dataset is split into 2 chunks and covariance matrices
- `covariance_data_partitions`: Number of data partitions to use for computing covariance matrices.
For example, when `covariance_data_partitions=2`, the dataset is split into 2 chunks and covariance matrices
are separately computed for each chunk. These chunked covariance matrices are later aggregated. This is useful with GPU preemption as intermediate
covariance matrices will be saved in disk. It can be also helpful when launching multiple parallel jobs, where each GPU
can compute covariance matrices on some partitioned data (you can specify `target_data_partitions` in the parameter).
- `covariance_module_partition_size`: Number of module partitions to use for computing covariance matrices.
For example, when `covariance_module_partition_size = 2`, the module is split into 2 chunks and covariance matrices
- `covariance_module_partitions`: Number of module partitions to use for computing covariance matrices.
For example, when `covariance_module_partitions=2`, the module is split into 2 chunks and covariance matrices
are separately computed for each chunk. This is useful when the available GPU memory is limited (e.g., the total
covariance matrices cannot fit into GPU memory). However, this will require multiple iterations over the dataset and can be slow.
- `activation_covariance_dtype`: `dtype` for computing activation covariance matrices. You can also use `torch.bfloat16`
@@ -271,7 +270,7 @@ or `torch.float16`.
 **Dealing with OOMs.** Here are some steps to fix Out of Memory (OOM) errors.
 1. Try reducing the `per_device_batch_size` when fitting covariance matrices.
 2. Try using lower precision for `activation_covariance_dtype` and `gradient_covariance_dtype`.
-3. Try setting `covariance_module_partition_size > 1`.
+3. Try setting `covariance_module_partitions > 1`.
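A memory-conscious configuration combining these remedies might look like the following sketch (exact values depend on your model and hardware):

```python
import torch
from kronfluence.arguments import FactorArguments

factor_args = FactorArguments(
    strategy="ekfac",
    covariance_module_partitions=2,              # split modules into two passes
    activation_covariance_dtype=torch.bfloat16,  # lower-precision accumulation
    gradient_covariance_dtype=torch.bfloat16,
)
```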

### Performing Eigendecomposition

Expand Down Expand Up @@ -301,22 +300,22 @@ lambda_matrices = analyzer.load_lambda_matrices(factors_name="initial_factor")

This corresponds to **Equation 20** in the paper. You can tune:
- `lambda_max_examples`: Controls the maximum number of data points for fitting Lambda matrices.
- `lambda_data_partition_size`: Number of data partitions to use for computing Lambda matrices.
- `lambda_module_partition_size`: Number of module partitions to use for computing Lambda matrices.
- `cached_activation_cpu_offload`: Computing the per-sample-gradient requires saving the intermediate activation in memory.
You can set `cached_activation_cpu_offload=True` to cache these activations in CPU. This is helpful for dealing with OOMs, but will make the overall computation slower.
- `lambda_iterative_aggregate`: Whether to compute the Lambda matrices with for-loops instead of batched matrix multiplications.
- `lambda_data_partitions`: Number of data partitions to use for computing Lambda matrices.
- `lambda_module_partitions`: Number of module partitions to use for computing Lambda matrices.
- `offload_activations_to_cpu`: Computing the per-sample-gradient requires saving the intermediate activation in memory.
You can set `offload_activations_to_cpu=True` to cache these activations in CPU. This is helpful for dealing with OOMs, but will make the overall computation slower.
- `use_iterative_lambda_aggregation`: Whether to compute the Lambda matrices with for-loops instead of batched matrix multiplications.
This is helpful for reducing peak GPU memory, as it avoids holding multiple copies of tensors with the same shape as the per-sample-gradient.
- `per_sample_gradient_dtype`: `dtype` for computing per-sample-gradient. You can also use `torch.bfloat16`
or `torch.float16`.
- `lambda_dtype`: `dtype` for computing Lambda matrices. You can also use `torch.bfloat16`
or `torch.float16`. Recommended to use `torch.float32`.
or `torch.float16`.

**Dealing with OOMs.** Here are some steps to fix Out of Memory (OOM) errors.
1. Try reducing the `per_device_batch_size` when fitting Lambda matrices.
2. Try setting `lambda_iterative_aggregate=True` or `cached_activation_cpu_offload=True`. (Try out `lambda_iterative_aggregate=True` first.)
2. Try setting `use_iterative_lambda_aggregation=True` or `offload_activations_to_cpu=True`. (Try out `use_iterative_lambda_aggregation=True` first.)
3. Try using lower precision for `per_sample_gradient_dtype` and `lambda_dtype`.
4. Try using `lambda_module_partition_size > 1`.
4. Try using `lambda_module_partitions > 1`.
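One way to combine these steps, using only the renamed arguments from this diff (a sketch, not a recommendation for every model):

```python
import torch
from kronfluence.arguments import FactorArguments

factor_args = FactorArguments(
    strategy="ekfac",
    use_iterative_lambda_aggregation=True,  # for-loop aggregation lowers peak memory
    offload_activations_to_cpu=True,        # add only if iteration alone is not enough
    per_sample_gradient_dtype=torch.bfloat16,
    lambda_dtype=torch.bfloat16,
)
```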

### FAQs

@@ -339,21 +338,24 @@ import torch
 from kronfluence.arguments import ScoreArguments
 
 score_args = ScoreArguments(
-    damping=1e-08,
-    cached_activation_cpu_offload=False,
-    distributed_sync_steps=1000,
+    damping_factor=1e-08,
     amp_dtype=None,
+    offload_activations_to_cpu=False,
 
     # More functionalities to compute influence scores.
-    data_partition_size=1,
-    module_partition_size=1,
-    per_module_score=False,
+    data_partitions=1,
+    module_partitions=1,
+    compute_per_module_scores=False,
+    compute_per_token_scores=False,
     use_measurement_for_self_influence=False,
+    aggregate_query_gradients=False,
+    aggregate_train_gradients=False,
 
     # Configuration for query batching.
-    query_gradient_rank=None,
+    query_gradient_low_rank=None,
+    use_full_svd=False,
     query_gradient_svd_dtype=torch.float32,
-    num_query_gradient_accumulations=1,
+    query_gradient_accumulation_steps=1,
 
     # Configuration for dtype.
     score_dtype=torch.float32,
@@ -362,23 +364,25 @@ score_args = ScoreArguments(
 )
 ```
 
-- `damping`: A damping factor for the damped inverse Hessian-vector product (iHVP). Uses a heuristic based on mean eigenvalues
+- `damping_factor`: A damping factor for the damped inverse Hessian-vector product (iHVP). Uses a heuristic based on mean eigenvalues
 `(0.1 x mean eigenvalues)` if `None`, as done in [this paper](https://arxiv.org/abs/2308.03296).
-- `cached_activation_cpu_offload`: Whether to offload cached activations to CPU.
 - `amp_dtype`: Selects the dtype for [automatic mixed precision (AMP)](https://pytorch.org/docs/stable/amp.html). Disables AMP if set to `None`.
-- `data_partition_size`: Number of data partitions for computing influence scores.
-- `module_partition_size`: Number of module partitions for computing influence scores.
-- `per_module_score`: Whether to return a per-module influence scores. Instead of summing over influences across
+- `offload_activations_to_cpu`: Whether to offload cached activations to CPU.
+- `data_partitions`: Number of data partitions for computing influence scores.
+- `module_partitions`: Number of module partitions for computing influence scores.
+- `compute_per_module_scores`: Whether to return per-module influence scores. Instead of summing over influences across
 all modules, this will keep track of intermediate module-wise scores.
-- `use_measurement_for_self_influence`: Whether to use the measurement (instead of the loss) when computing self-influence scores.
-- `query_gradient_rank`: The rank for the query batching (low-rank approximation to the preconditioned query gradient; see **Section 3.2.2**). If `None`, no query batching will be used.
+- `compute_per_token_scores`: Whether to return per-token influence scores. Only applicable to transformer-based models.
+- `aggregate_query_gradients`: Whether to use the summed query gradient instead of per-sample query gradients.
+- `aggregate_train_gradients`: Whether to use the summed training gradient instead of per-sample training gradients.
+- `use_measurement_for_self_influence`: Whether to use the measurement (instead of the loss) when computing self-influence scores.
+- `query_gradient_low_rank`: The rank for the query batching (low-rank approximation to the preconditioned query gradient; see **Section 3.2.2**). If `None`, no query batching will be used.
 - `query_gradient_svd_dtype`: `dtype` for performing singular value decomposition (SVD) for query batch. You can also use `torch.float64`.
-- `num_query_gradient_accumulations`: Number of query gradients to accumulate over. For example, when `num_query_gradient_accumulations=2` with
+- `query_gradient_accumulation_steps`: Number of query gradients to accumulate over. For example, when `query_gradient_accumulation_steps=2` with
 `query_batch_size=16`, a total of 32 query gradients will be stored in memory when computing dot products with training gradients.
 - `score_dtype`: `dtype` for computing influence scores. You can use `torch.bfloat16` or `torch.float16`.
 - `per_sample_gradient_dtype`: `dtype` for computing per-sample-gradient. You can use `torch.bfloat16` or `torch.float16`.
-- `precondition_dtype`: `dtype` for performing preconditioning. You can use `torch.bfloat16` or `torch.float16`,
-but `torch.float32` is recommended.
+- `precondition_dtype`: `dtype` for performing preconditioning. You can use `torch.bfloat16` or `torch.float16`.
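For orientation, a sketch of how `score_args` is consumed downstream (assuming the `Analyzer.compute_pairwise_scores` API; the datasets are placeholders):

```python
# Query examples are the points being explained; train examples are candidates.
analyzer.compute_pairwise_scores(
    scores_name="initial_score",
    factors_name="initial_factor",
    query_dataset=eval_dataset,    # placeholder
    train_dataset=train_dataset,   # placeholder
    per_device_query_batch_size=16,
    per_device_train_batch_size=32,
    score_args=score_args,
)
scores = analyzer.load_pairwise_scores(scores_name="initial_score")
```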

### Computing Influence Scores

@@ -409,12 +413,12 @@ vector will correspond to `g_m^T ⋅ H^{-1} ⋅ g_l`, where `g_m` is the gradien
 
 **Dealing with OOMs.** Here are some steps to fix Out of Memory (OOM) errors.
 1. Try reducing the `per_device_query_batch_size` or `per_device_train_batch_size`.
-2. Try setting `cached_activation_cpu_offload=True`.
+2. Try setting `offload_activations_to_cpu=True`.
 3. Try using lower precision for `per_sample_gradient_dtype` and `score_dtype`.
 4. Try using lower precision for `precondition_dtype`.
-5. Try setting `query_gradient_rank > 1`. The recommended values are `16`, `32`, `64`, `128`, and `256`. Note that query
+5. Try setting `query_gradient_low_rank > 1`. The recommended values are `16`, `32`, `64`, `128`, and `256`. Note that query
 batching is only supported for computing pairwise influence scores, not self-influence scores.
-6. Try setting `module_partition_size > 1`.
+6. Try setting `module_partitions > 1`.
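Steps 2, 5, and 6 together might look like this sketch:

```python
import torch
from kronfluence.arguments import ScoreArguments

score_args = ScoreArguments(
    offload_activations_to_cpu=True,
    query_gradient_low_rank=64,   # low-rank query batching (pairwise scores only)
    module_partitions=2,
    score_dtype=torch.bfloat16,
)
```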

### FAQs

2 changes: 1 addition & 1 deletion README.md
@@ -182,7 +182,7 @@ Please address any reported issues before submitting your PR.
 ## Acknowledgements
 
 [Omkar Dige](https://github.com/xeon27) contributed to the profiling, DDP, and FSDP utilities, and [Adil Asif](https://github.com/adil-a/) provided valuable insights and suggestions on structuring the DDP and FSDP implementations.
-I also thank Hwijeen Ahn, Sang Keun Choe, Youngseog Chung, Minsoo Kang, Lev McKinney, Laura Ruis, Andrew Wang, and Kewen Zhao for their feedback.
+I also thank Hwijeen Ahn, Sang Keun Choe, Youngseog Chung, Minsoo Kang, Sophie Liao, Lev McKinney, Laura Ruis, Andrew Wang, and Kewen Zhao for their feedback.

## License

3 changes: 2 additions & 1 deletion dev_requirements.txt
@@ -4,10 +4,11 @@ accelerate>=0.31.0
 einops>=0.8.0
 einconv>=0.1.0
 opt_einsum>=3.3.0
+scikit-learn>=1.4.0
 safetensors>=0.4.2
 tqdm>=4.66.4
 datasets>=2.20.0
-transformers>=4.41.2
+transformers>=4.42.0
 isort==5.13.2
 pylint==3.2.3
 pytest==8.2.2
10 changes: 5 additions & 5 deletions examples/README.md
@@ -12,22 +12,22 @@ pip install -r requirements.txt
 
 Alternatively, navigate to each example folder and run `pip install -r requirements.txt`.
 
-
 ## List of Tasks
 
 Our examples cover the following tasks:
 
 <div align="center">
 
-| Task                 |     Example datasets     |
+| Task                 |     Example Datasets     |
 |----------------------|:------------------------:|
 | Regression           |           UCI            |
-| Image Classification |   CIFAR-10 / ImageNet    |
+| Image Classification |   CIFAR-10 & ImageNet    |
 | Text Classification  |           GLUE           |
 | Multiple-Choice      |           SWAG           |
-| Language Modeling    | WikiText-2 / OpenWebText |
+| Summarization        |      CNN/DailyMail       |
+| Language Modeling    | WikiText-2 & OpenWebText |
 
 </div>
 
 These examples demonstrate various use cases of Kronfluence, including the usage of AMP (Automatic Mixed Precision) and DDP (Distributed Data Parallel).
 Many examples aim to replicate the settings used in [our paper](https://arxiv.org/abs/2405.12186). If you would like to see more examples added to this repository, please leave an issue.