Merge pull request #28 from pomonam/openwebtext

Code Refactor for Final Release

pomonam authored Jul 11, 2024
2 parents e81e653 + 05d06a0 commit 52a67f9

Showing 104 changed files with 7,164 additions and 3,837 deletions.
4 changes: 3 additions & 1 deletion .github/workflows/python-test.yml
@@ -28,8 +28,10 @@ jobs:
  pytest -vx tests/test_dataset_utils.py
  pytest -vx tests/test_testable_tasks.py
  pytest -vx tests/factors/test_covariances.py
- pytest -vx tests/factors/test_eigens.py
+ pytest -vx tests/factors/test_eigendecompositions.py
  pytest -vx tests/factors/test_lambdas.py
+ pytest -vx tests/modules/test_modules.py
  pytest -vx tests/modules/test_per_sample_gradients.py
+ pytest -vx tests/modules/test_matmul.py
  pytest -vx tests/scores/test_pairwise_scores.py
  pytest -vx tests/scores/test_self_scores.py
2 changes: 1 addition & 1 deletion .gitignore
@@ -164,8 +164,8 @@ cython_debug/

  # Checkpoints and influence outputs
  checkpoints/
- analyses/
+ influence_results/
  data/
  cache/
  *.pth
  *.pt
98 changes: 51 additions & 47 deletions DOCUMENTATION.md
@@ -67,7 +67,7 @@ class YourTask(Task):
  ) -> torch.Tensor:
      # TODO: Complete this method.

- def tracked_modules(self) -> Optional[List[str]]:
+ def get_influence_tracked_modules(self) -> Optional[List[str]]:
      # TODO: [Optional] Complete this method.
      return None  # Compute influence scores on all available modules.

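To make the renamed interface concrete, here is a minimal sketch of a complete classification `Task`. The model, batch format, and loss choices are illustrative assumptions, not part of this diff:

```python
from typing import Any, List, Optional

import torch
import torch.nn.functional as F
from kronfluence.task import Task


class ClassificationTask(Task):
    def compute_train_loss(self, batch: Any, model: torch.nn.Module, sample: bool = False) -> torch.Tensor:
        inputs, labels = batch
        logits = model(inputs)
        if not sample:
            # Use the labels from the batch (note the summed, not averaged, loss).
            return F.cross_entropy(logits, labels, reduction="sum")
        # Sample labels from the model's predictions (true Fisher).
        with torch.no_grad():
            sampled_labels = torch.multinomial(logits.softmax(dim=-1), num_samples=1).flatten()
        return F.cross_entropy(logits, sampled_labels, reduction="sum")

    def compute_measurement(self, batch: Any, model: torch.nn.Module) -> torch.Tensor:
        inputs, labels = batch
        return F.cross_entropy(model(inputs), labels, reduction="sum")

    def get_influence_tracked_modules(self) -> Optional[List[str]]:
        return None  # Compute influence scores on all available modules.
```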
@@ -89,7 +89,7 @@ model = prepare_model(model=model, task=task)
...
```

- If you have specified specific module names in `Task.tracked_modules`, `TrackedModule` will only be installed for these modules.
+ If you have specified specific module names in `Task.get_influence_tracked_modules`, `TrackedModule` will only be installed for these modules.

**\[Optional\] Create a DDP and FSDP Module.**
After calling `prepare_model`, you can create [DistributedDataParallel (DDP)](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html) or
@@ -140,7 +140,7 @@ Try rewriting the model so that it uses supported modules (as done for the `conv
Alternatively, you can create a subclass of `TrackedModule` to compute influence scores for your custom module.
If there are specific modules you would like to see supported, please submit an issue.

- **How should I write task.tracked_modules?**
+ **How should I write task.get_influence_tracked_modules?**
We recommend using all supported modules for influence computations. However, if you would like to compute influence scores
on a subset of the modules (e.g., influence computations only on MLP layers for a transformer, or influence computation only on the last layer),
inspect `model.named_modules()` to determine what modules to use. You can specify the list of module names you want to analyze, as in the sketch below.
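A minimal way to pick module names; the layer names below are hypothetical and depend entirely on your architecture:

```python
# Print module names and types to decide what to track.
for name, module in model.named_modules():
    print(name, type(module).__name__)

# Example: restrict analysis to the MLP linear layers of a 12-block
# transformer (hypothetical names).
def get_influence_tracked_modules(self) -> Optional[List[str]]:
    tracked = []
    for i in range(12):
        tracked.append(f"transformer.h.{i}.mlp.c_fc")
        tracked.append(f"transformer.h.{i}.mlp.c_proj")
    return tracked
```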
@@ -183,7 +183,7 @@ def forward(x: torch.Tensor) -> torch.Tensor:
> [!WARNING]
> The default arguments assume the module is used only once during the forward pass.
> If your model shares parameters (e.g., the module is used in multiple places during the forward pass), set
- > `shared_parameters_exist=True` in `FactorArguments`.
+ > `has_shared_parameters=True` in `FactorArguments`.
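As an illustration, parameter sharing arises when the same module instance is called more than once in `forward`. This is a hypothetical example, not taken from this diff:

```python
import torch
import torch.nn as nn
from kronfluence.arguments import FactorArguments


class TiedModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(64, 64)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The same `linear` module is applied twice, so its parameters
        # are shared across the forward pass.
        return self.linear(self.linear(x))


factor_args = FactorArguments(has_shared_parameters=True)
```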
**Why are there so many arguments?**
Kronfluence was originally developed to compute influence scores on large-scale models, which is why `FactorArguments` and `ScoreArguments`
@@ -204,14 +204,13 @@ from kronfluence.arguments import FactorArguments
factor_args = FactorArguments(
    strategy="ekfac",  # Choose from "identity", "diagonal", "kfac", or "ekfac".
    use_empirical_fisher=False,
-   distributed_sync_steps=1000,
    amp_dtype=None,
-   shared_parameters_exist=False,
+   has_shared_parameters=False,

    # Settings for covariance matrix fitting.
    covariance_max_examples=100_000,
-   covariance_data_partition_size=1,
-   covariance_module_partition_size=1,
+   covariance_data_partitions=1,
+   covariance_module_partitions=1,
    activation_covariance_dtype=torch.float32,
    gradient_covariance_dtype=torch.float32,

@@ -220,10 +219,10 @@ factor_args = FactorArguments(

    # Settings for Lambda matrix fitting.
    lambda_max_examples=100_000,
-   lambda_data_partition_size=1,
-   lambda_module_partition_size=1,
-   lambda_iterative_aggregate=False,
-   cached_activation_cpu_offload=False,
+   lambda_data_partitions=1,
+   lambda_module_partitions=1,
+   use_iterative_lambda_aggregation=False,
+   offload_activations_to_cpu=False,
    per_sample_gradient_dtype=torch.float32,
    lambda_dtype=torch.float32,
)
@@ -237,7 +236,7 @@ You can change:
- `use_empirical_fisher`: Determines whether to use the [empirical Fisher](https://arxiv.org/abs/1905.12558) (using actual labels from the batch)
instead of the true Fisher (using sampled labels from the model's predictions). It is recommended to be `False`.
- `amp_dtype`: Selects the dtype for [automatic mixed precision (AMP)](https://pytorch.org/docs/stable/amp.html). Disables AMP if set to `None`.
- - `shared_parameters_exist`: Specifies whether the shared parameters exist in the forward pass.
+ - `has_shared_parameters`: Specifies whether shared parameters exist in the forward pass.
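Putting the arguments to use might look like the sketch below, assuming `analyzer` and `train_dataset` are set up as in the earlier sections:

```python
factor_args = FactorArguments(strategy="ekfac", has_shared_parameters=False)
analyzer.fit_all_factors(
    factors_name="initial_factor",
    dataset=train_dataset,
    per_device_batch_size=None,  # None lets Kronfluence search for a workable batch size.
    factor_args=factor_args,
)
```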

### Fitting Covariance Matrices

@@ -254,13 +253,13 @@ covariance_matrices = analyzer.load_covariance_matrices(factors_name="initial_factor")
This step corresponds to **Equation 16** in the paper. You can tune:
- `covariance_max_examples`: Controls the maximum number of data points for fitting covariance matrices. When set to `None`,
Kronfluence computes covariance matrices for all data points.
- - `covariance_data_partition_size`: Number of data partitions to use for computing covariance matrices.
- For example, when `covariance_data_partition_size = 2`, the dataset is split into 2 chunks and covariance matrices
+ - `covariance_data_partitions`: Number of data partitions to use for computing covariance matrices.
+ For example, when `covariance_data_partitions=2`, the dataset is split into 2 chunks and covariance matrices
are separately computed for each chunk. These chunked covariance matrices are later aggregated. This is useful under GPU preemption, as intermediate
covariance matrices will be saved to disk. It can also be helpful when launching multiple parallel jobs, where each GPU
can compute covariance matrices on some partition of the data (you can specify `target_data_partitions` as a parameter).
- - `covariance_module_partition_size`: Number of module partitions to use for computing covariance matrices.
- For example, when `covariance_module_partition_size = 2`, the module is split into 2 chunks and covariance matrices
+ - `covariance_module_partitions`: Number of module partitions to use for computing covariance matrices.
+ For example, when `covariance_module_partitions=2`, the modules are split into 2 chunks and covariance matrices
are separately computed for each chunk. This is useful when the available GPU memory is limited (e.g., the total
covariance matrices cannot fit into GPU memory). However, this will require multiple iterations over the dataset and can be slow.
- `activation_covariance_dtype`: `dtype` for computing activation covariance matrices. You can also use `torch.bfloat16`
@@ -271,7 +270,7 @@ or `torch.float16`.
**Dealing with OOMs.** Here are some steps to fix Out of Memory (OOM) errors.
1. Try reducing the `per_device_batch_size` when fitting covariance matrices.
2. Try using lower precision for `activation_covariance_dtype` and `gradient_covariance_dtype`.
- 3. Try setting `covariance_module_partition_size > 1`.
+ 3. Try setting `covariance_module_partitions > 1`.
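For example, a partitioned covariance fit could look like this sketch (argument names follow the API above; `analyzer` and `train_dataset` are placeholders):

```python
factor_args = FactorArguments(
    strategy="ekfac",
    covariance_max_examples=100_000,
    covariance_data_partitions=2,    # Save intermediate results per chunk (preemption-friendly).
    covariance_module_partitions=2,  # Lower peak GPU memory at the cost of extra passes.
)
analyzer.fit_covariance_matrices(
    factors_name="initial_factor",
    dataset=train_dataset,
    per_device_batch_size=32,
    factor_args=factor_args,
)
covariance_matrices = analyzer.load_covariance_matrices(factors_name="initial_factor")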

### Performing Eigendecomposition

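The body of this section is collapsed in the diff; invoking the eigendecomposition step might look like the sketch below (assuming `perform_eigendecomposition` and `load_eigendecomposition` on the library's `Analyzer`, with `factor_args` from above):

```python
analyzer.perform_eigendecomposition(
    factors_name="initial_factor",
    factor_args=factor_args,
)
eigen_factors = analyzer.load_eigendecomposition(factors_name="initial_factor")
```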
@@ -301,22 +300,22 @@ lambda_matrices = analyzer.load_lambda_matrices(factors_name="initial_factor")

This corresponds to **Equation 20** in the paper. You can tune:
- `lambda_max_examples`: Controls the maximum number of data points for fitting Lambda matrices.
- - `lambda_data_partition_size`: Number of data partitions to use for computing Lambda matrices.
- - `lambda_module_partition_size`: Number of module partitions to use for computing Lambda matrices.
- - `cached_activation_cpu_offload`: Computing the per-sample-gradient requires saving the intermediate activation in memory.
- You can set `cached_activation_cpu_offload=True` to cache these activations in CPU. This is helpful for dealing with OOMs, but will make the overall computation slower.
- - `lambda_iterative_aggregate`: Whether to compute the Lambda matrices with for-loops instead of batched matrix multiplications.
+ - `lambda_data_partitions`: Number of data partitions to use for computing Lambda matrices.
+ - `lambda_module_partitions`: Number of module partitions to use for computing Lambda matrices.
+ - `offload_activations_to_cpu`: Computing the per-sample-gradient requires saving the intermediate activations in memory.
+ You can set `offload_activations_to_cpu=True` to cache these activations on the CPU. This is helpful for dealing with OOMs, but will make the overall computation slower.
+ - `use_iterative_lambda_aggregation`: Whether to compute the Lambda matrices with for-loops instead of batched matrix multiplications.
This is helpful for reducing peak GPU memory, as it avoids holding multiple copies of tensors with the same shape as the per-sample-gradient.
- `per_sample_gradient_dtype`: `dtype` for computing per-sample-gradient. You can also use `torch.bfloat16`
or `torch.float16`.
- `lambda_dtype`: `dtype` for computing Lambda matrices. You can also use `torch.bfloat16`
- or `torch.float16`. Recommended to use `torch.float32`.
+ or `torch.float16`.

**Dealing with OOMs.** Here are some steps to fix Out of Memory (OOM) errors.
1. Try reducing the `per_device_batch_size` when fitting Lambda matrices.
- 2. Try setting `lambda_iterative_aggregate=True` or `cached_activation_cpu_offload=True`. (Try out `lambda_iterative_aggregate=True` first.)
+ 2. Try setting `use_iterative_lambda_aggregation=True` or `offload_activations_to_cpu=True`. (Try out `use_iterative_lambda_aggregation=True` first.)
3. Try using lower precision for `per_sample_gradient_dtype` and `lambda_dtype`.
- 4. Try using `lambda_module_partition_size > 1`.
+ 4. Try using `lambda_module_partitions > 1`.
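A memory-conscious Lambda fit might look like the sketch below (names follow the arguments above; `analyzer` and `train_dataset` are placeholders):

```python
factor_args = FactorArguments(
    strategy="ekfac",
    use_iterative_lambda_aggregation=True,  # Reduce peak memory first.
    offload_activations_to_cpu=True,        # Offload cached activations if still OOM.
    per_sample_gradient_dtype=torch.bfloat16,
    lambda_dtype=torch.float32,
)
analyzer.fit_lambda_matrices(
    factors_name="initial_factor",
    dataset=train_dataset,
    per_device_batch_size=16,
    factor_args=factor_args,
)
lambda_matrices = analyzer.load_lambda_matrices(factors_name="initial_factor")
```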

### FAQs

@@ -339,21 +338,24 @@ import torch
from kronfluence.arguments import ScoreArguments

score_args = ScoreArguments(
-   damping=1e-08,
-   cached_activation_cpu_offload=False,
-   distributed_sync_steps=1000,
+   damping_factor=1e-08,
    amp_dtype=None,
+   offload_activations_to_cpu=False,

    # More functionalities to compute influence scores.
-   data_partition_size=1,
-   module_partition_size=1,
-   per_module_score=False,
+   data_partitions=1,
+   module_partitions=1,
+   compute_per_module_scores=False,
+   compute_per_token_scores=False,
    use_measurement_for_self_influence=False,
+   aggregate_query_gradients=False,
+   aggregate_train_gradients=False,

    # Configuration for query batching.
-   query_gradient_rank=None,
+   query_gradient_low_rank=None,
+   use_full_svd=False,
    query_gradient_svd_dtype=torch.float32,
-   num_query_gradient_accumulations=1,
+   query_gradient_accumulation_steps=1,

    # Configuration for dtype.
    score_dtype=torch.float32,
@@ -362,23 +364,25 @@ score_args = ScoreArguments(
)
```

- - `damping`: A damping factor for the damped inverse Hessian-vector product (iHVP). Uses a heuristic based on mean eigenvalues
+ - `damping_factor`: A damping factor for the damped inverse Hessian-vector product (iHVP). Uses a heuristic based on mean eigenvalues
`(0.1 x mean eigenvalues)` if `None`, as done in [this paper](https://arxiv.org/abs/2308.03296).
- - `cached_activation_cpu_offload`: Whether to offload cached activations to CPU.
- `amp_dtype`: Selects the dtype for [automatic mixed precision (AMP)](https://pytorch.org/docs/stable/amp.html). Disables AMP if set to `None`.
- - `data_partition_size`: Number of data partitions for computing influence scores.
- - `module_partition_size`: Number of module partitions for computing influence scores.
- - `per_module_score`: Whether to return a per-module influence scores. Instead of summing over influences across
+ - `offload_activations_to_cpu`: Whether to offload cached activations to CPU.
+ - `data_partitions`: Number of data partitions for computing influence scores.
+ - `module_partitions`: Number of module partitions for computing influence scores.
+ - `compute_per_module_scores`: Whether to return per-module influence scores. Instead of summing over influences across
all modules, this will keep track of intermediate module-wise scores.
- - `use_measurement_for_self_influence`: Whether to use the measurement (instead of the loss) when computing self-influence scores.
- - `query_gradient_rank`: The rank for the query batching (low-rank approximation to the preconditioned query gradient; see **Section 3.2.2**). If `None`, no query batching will be used.
+ - `compute_per_token_scores`: Whether to return per-token influence scores. Only applicable to transformer-based models.
+ - `aggregate_query_gradients`: Whether to use the summed query gradient instead of per-sample query gradients.
+ - `aggregate_train_gradients`: Whether to use the summed training gradient instead of per-sample training gradients.
+ - `use_measurement_for_self_influence`: Whether to use the measurement (instead of the loss) when computing self-influence scores.
+ - `query_gradient_low_rank`: The rank for the query batching (low-rank approximation to the preconditioned query gradient; see **Section 3.2.2**). If `None`, no query batching will be used.
- `query_gradient_svd_dtype`: `dtype` for performing singular value decomposition (SVD) for query batching. You can also use `torch.float64`.
- - `num_query_gradient_accumulations`: Number of query gradients to accumulate over. For example, when `num_query_gradient_accumulations=2` with
+ - `query_gradient_accumulation_steps`: Number of query gradients to accumulate over. For example, when `query_gradient_accumulation_steps=2` with
`query_batch_size=16`, a total of 32 query gradients will be stored in memory when computing dot products with training gradients.
- `score_dtype`: `dtype` for computing influence scores. You can use `torch.bfloat16` or `torch.float16`.
- `per_sample_gradient_dtype`: `dtype` for computing per-sample-gradient. You can use `torch.bfloat16` or `torch.float16`.
- - `precondition_dtype`: `dtype` for performing preconditioning. You can use `torch.bfloat16` or `torch.float16`,
- but `torch.float32` is recommended.
+ - `precondition_dtype`: `dtype` for performing preconditioning. You can use `torch.bfloat16` or `torch.float16`.
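For reference, computing pairwise scores with query batching enabled could look like the sketch below (`analyzer`, `eval_dataset`, and `train_dataset` are placeholders):

```python
score_args = ScoreArguments(
    query_gradient_low_rank=64,           # Enable query batching (Section 3.2.2).
    query_gradient_accumulation_steps=2,  # Accumulate query gradients across batches.
)
analyzer.compute_pairwise_scores(
    scores_name="pairwise_scores",
    factors_name="initial_factor",
    query_dataset=eval_dataset,
    train_dataset=train_dataset,
    per_device_query_batch_size=16,
    per_device_train_batch_size=64,
    score_args=score_args,
)
scores = analyzer.load_pairwise_scores(scores_name="pairwise_scores")
```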

### Computing Influence Scores

@@ -409,12 +413,12 @@ vector will correspond to `g_m^T ⋅ H^{-1} ⋅ g_l`, where `g_m` is the gradient

**Dealing with OOMs.** Here are some steps to fix Out of Memory (OOM) errors.
1. Try reducing the `per_device_query_batch_size` or `per_device_train_batch_size`.
- 2. Try setting `cached_activation_cpu_offload=True`.
+ 2. Try setting `offload_activations_to_cpu=True`.
3. Try using lower precision for `per_sample_gradient_dtype` and `score_dtype`.
4. Try using lower precision for `precondition_dtype`.
- 5. Try setting `query_gradient_rank > 1`. The recommended values are `16`, `32`, `64`, `128`, and `256`. Note that query
+ 5. Try setting `query_gradient_low_rank > 1`. The recommended values are `16`, `32`, `64`, `128`, and `256`. Note that query
batching is only supported for computing pairwise influence scores, not self-influence scores.
- 6. Try setting `module_partition_size > 1`.
+ 6. Try setting `module_partitions > 1`.
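Self-influence scores follow the same pattern; a sketch assuming a `compute_self_scores` method on the library's `Analyzer` (query batching does not apply here):

```python
score_args = ScoreArguments(use_measurement_for_self_influence=False)
analyzer.compute_self_scores(
    scores_name="self_scores",
    factors_name="initial_factor",
    train_dataset=train_dataset,
    per_device_train_batch_size=64,
    score_args=score_args,
)
self_scores = analyzer.load_self_scores(scores_name="self_scores")
```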

### FAQs

2 changes: 1 addition & 1 deletion README.md
@@ -182,7 +182,7 @@ Please address any reported issues before submitting your PR.
## Acknowledgements

[Omkar Dige](https://github.com/xeon27) contributed to the profiling, DDP, and FSDP utilities, and [Adil Asif](https://github.com/adil-a/) provided valuable insights and suggestions on structuring the DDP and FSDP implementations.
- I also thank Hwijeen Ahn, Sang Keun Choe, Youngseog Chung, Minsoo Kang, Lev McKinney, Laura Ruis, Andrew Wang, and Kewen Zhao for their feedback.
+ I also thank Hwijeen Ahn, Sang Keun Choe, Youngseog Chung, Minsoo Kang, Sophie Liao, Lev McKinney, Laura Ruis, Andrew Wang, and Kewen Zhao for their feedback.

## License

3 changes: 2 additions & 1 deletion dev_requirements.txt
@@ -4,10 +4,11 @@ accelerate>=0.31.0
einops>=0.8.0
einconv>=0.1.0
opt_einsum>=3.3.0
+ scikit-learn>=1.4.0
safetensors>=0.4.2
tqdm>=4.66.4
datasets>=2.20.0
- transformers>=4.41.2
+ transformers>=4.42.0
isort==5.13.2
pylint==3.2.3
pytest==8.2.2
10 changes: 5 additions & 5 deletions examples/README.md
@@ -12,22 +12,22 @@ pip install -r requirements.txt

Alternatively, navigate to each example folder and run `pip install -r requirements.txt`.

## List of Tasks

Our examples cover the following tasks:

<div align="center">

- | Task | Example datasets |
+ | Task | Example Datasets |
|----------------------|:------------------------:|
| Regression | UCI |
- | Image Classification | CIFAR-10 / ImageNet |
+ | Image Classification | CIFAR-10 & ImageNet |
| Text Classification | GLUE |
| Multiple-Choice | SWAG |
- | Language Modeling | WikiText-2 / OpenWebText |
+ | Summarization | CNN/DailyMail |
+ | Language Modeling | WikiText-2 & OpenWebText |

</div>

These examples demonstrate various use cases of Kronfluence, including the usage of AMP (Automatic Mixed Precision) and DDP (Distributed Data Parallel).
- Many examples aim to replicate the settings used in [our paper](https://arxiv.org/abs/2405.12186). If you would like to see more examples added to this repository, please leave an issue.
\ No newline at end of file
+ Many examples aim to replicate the settings used in [our paper](https://arxiv.org/abs/2405.12186). If you would like to see more examples added to this repository, please leave an issue.