Merge pull request #5 from pomonam/uci
Add uci example
pomonam authored Mar 19, 2024
2 parents 97afac1 + b67a9be commit 3f28fba
Showing 4 changed files with 24 additions and 28 deletions.
6 changes: 3 additions & 3 deletions DOCUMENTATION.md
@@ -387,9 +387,9 @@ scores = analyzer.load_self_scores(scores_name="self")
2. Try setting `cached_activation_cpu_offload=True`.
3. Try using lower precision for `per_sample_gradient_dtype` and `score_dtype`.
4. Try setting `immediate_gradient_removal=True`.
-5. Try setting `query_gradient_rank > 1`. The recommended values are `16`, `32`, `64`, `128`, and `256`.
-6Try setting `module_partition_size > 1`.
+5. Try setting `query_gradient_rank > 1`. The recommended values are `16`, `32`, `64`, `128`, and `256`. Note that query
+batching is only supported for computing pairwise influence scores.
+6. Try setting `module_partition_size > 1`.
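
A hedged sketch of how the options above might be combined is shown below. The import path and the split of fields between `FactorArguments` and `ScoreArguments` are assumptions; check the argument definitions for the authoritative names and defaults.

```python
# Sketch only: one possible combination of the memory-saving options listed above.
# The FactorArguments/ScoreArguments field placement is an assumption.
import torch

from kronfluence.arguments import FactorArguments, ScoreArguments  # assumed import path

factor_args = FactorArguments(
    strategy="ekfac",
    cached_activation_cpu_offload=True,  # tip 2: keep cached activations on CPU
    immediate_gradient_removal=True,     # tip 4: free per-module gradients eagerly
)
score_args = ScoreArguments(
    per_sample_gradient_dtype=torch.bfloat16,  # tip 3: lower-precision per-sample gradients
    score_dtype=torch.bfloat16,                # tip 3: lower-precision scores
    query_gradient_rank=32,                    # tip 5: pairwise scores only
    module_partition_size=2,                   # tip 6: process modules in partitions
)
```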

### FAQs

2 changes: 1 addition & 1 deletion README.md
@@ -51,7 +51,7 @@ pip install -e .
## Getting Started

Kronfluence supports influence computations on `nn.Linear` and `nn.Conv2d` modules. See the [**Technical Documentation**](https://github.com/pomonam/kronfluence/blob/main/DOCUMENTATION.md)
-page for a comprehensive guide,
+page for a comprehensive guide.
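
A minimal sketch of a model built entirely from supported module types (layer sizes are illustrative, not taken from the examples):

```python
# Sketch only: a model whose parametric layers are all supported module types.
import torch.nn as nn

model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),  # supported: nn.Conv2d
    nn.ReLU(),                                   # parameter-free, nothing to track
    nn.Flatten(),
    nn.Linear(16 * 32 * 32, 10),                 # supported: nn.Linear
)
# As in the examples below, the model is wrapped with `prepare_model(model, task)`
# before an `Analyzer` is constructed for it.
```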

### Examples

19 changes: 15 additions & 4 deletions examples/uci/README.md
@@ -1,6 +1,11 @@
# UCI Regression Example

-This directory contains scripts designed for training a regression model and conducting influence analysis with datasets obtained from the [UCI Machine Learning Repository](https://archive.ics.uci.edu/datasets).
+This directory contains scripts designed for training a regression model and conducting influence analysis with
+datasets from the [UCI Machine Learning Repository](https://archive.ics.uci.edu/datasets). Install all necessary packages:

+```bash
+pip install -r requirements.txt
+```

## Training

@@ -17,10 +22,16 @@ python train.py --dataset_name concrete \
--seed 1004
```
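
After training, the analysis step reloads the model from the saved checkpoint through the helpers in `pipeline.py`, as `analyze.py` does below. A rough sketch; the checkpoint file name and the assumption that `construct_regression_mlp` takes no arguments are illustrative, not taken from the scripts.

```python
# Sketch only: reloading a trained checkpoint before influence analysis.
# The checkpoint file name below is hypothetical.
import os

import torch

from examples.uci.pipeline import construct_regression_mlp

checkpoint_path = "./checkpoints/concrete_model.pth"  # hypothetical file name
if not os.path.isfile(checkpoint_path):
    raise ValueError(f"No checkpoint found at {checkpoint_path}.")

model = construct_regression_mlp()  # assumed to take no arguments
model.load_state_dict(torch.load(checkpoint_path))
model.eval()
```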

-# Influence Analysis
+# Computing Pairwise Influence Scores

-To obtain a pairwise influence scores using EKFAC,
+To obtain pairwise influence scores using EKFAC, run the following command:
+```bash
+python analyze.py --dataset_name concrete \
+--dataset_dir ./data \
+--checkpoint_dir ./checkpoints \
+--factor_strategy ekfac
+```
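
Once computed, the scores can be reloaded with `analyzer.load_pairwise_scores("pairwise")` as in `analyze.py` and ranked per query. A small sketch, assuming the loaded scores expose a `(num_queries, num_train)` tensor (if they come back as a per-module dictionary, pick the aggregated entry first):

```python
# Sketch only: ranking training examples by their influence on one query point.
import torch


def top_influences(pairwise_scores: torch.Tensor, query_index: int = 0, k: int = 10):
    """Return the indices and values of the k most influential training points."""
    values, indices = torch.topk(pairwise_scores[query_index], k=k)
    return indices.tolist(), values.tolist()


# Stand-in tensor (3 queries x 100 training points) purely for illustration.
fake_scores = torch.randn(3, 100)
print(top_influences(fake_scores, query_index=0, k=5))
```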

# Counterfactual Evaluation

To evaluate the accuracy of the influence estimates, we can perform counterfactual evaluation.
See the notebook `tutorial.ipynb` for instructions on running the counterfactual evaluation.
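
The notebook is the authoritative reference; the sketch below only outlines the usual counterfactual loop (drop the training points ranked most influential for a query, retrain, and compare the query loss). The retraining callable is a hypothetical stand-in for the loop in `train.py`.

```python
# Sketch only: a generic counterfactual check for influence estimates.
import torch
from torch.utils.data import Dataset, Subset


def counterfactual_gap(
    train_dataset: Dataset,
    pairwise_scores: torch.Tensor,  # shape: (num_queries, num_train)
    query_index: int,
    retrain_and_eval_loss,          # hypothetical callable: dataset -> query loss
    k: int = 20,
) -> float:
    """Loss change on the query after removing its k most influential training points."""
    ranked = torch.argsort(pairwise_scores[query_index], descending=True)
    removed = set(ranked[:k].tolist())
    kept = [i for i in range(len(train_dataset)) if i not in removed]

    full_loss = retrain_and_eval_loss(train_dataset)
    ablated_loss = retrain_and_eval_loss(Subset(train_dataset, kept))
    # A clearly positive gap suggests the removed points were genuinely influential.
    return ablated_loss - full_loss
```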
25 changes: 5 additions & 20 deletions examples/uci/analyze.py
@@ -6,7 +6,7 @@

import torch
import torch.nn.functional as F
from arguments import FactorArguments, ScoreArguments
from arguments import FactorArguments
from torch import nn

from examples.uci.pipeline import construct_regression_mlp, get_regression_dataset
@@ -94,49 +94,34 @@ def main():
raise ValueError(f"No checkpoint found at {checkpoint_path}.")
model.load_state_dict(torch.load(checkpoint_path))

print(Analyzer.get_module_summary(model))

task = RegressionTask()
model = prepare_model(model, task)

analyzer = Analyzer(
analysis_name=args.dataset_name,
model=model,
task=task,
profile=True,
cpu=True,
)
factor_args = FactorArguments(strategy=args.factor_strategy, lambda_iterative_aggregate=True)

factor_args = FactorArguments(strategy=args.factor_strategy)
analyzer.fit_all_factors(
factors_name=args.factor_strategy,
dataset=train_dataset,
per_device_batch_size=None,
factor_args=factor_args,
overwrite_output_dir=True,
)

score_args = ScoreArguments(query_gradient_rank=16)
analyzer.compute_pairwise_scores(
scores_name="pairwise",
factors_name=args.factor_strategy,
query_dataset=eval_dataset,
train_dataset=train_dataset,
per_device_query_batch_size=len(eval_dataset),
score_args=score_args,
# per_device_train_batch_size=8,
overwrite_output_dir=True,
)

analyzer.compute_self_scores(
scores_name="self",
factors_name=args.factor_strategy,
# query_dataset=eval_dataset,
train_dataset=train_dataset,
# per_device_query_batch_size=len(eval_dataset),
# per_device_train_batch_size=8,
overwrite_output_dir=True,
)
# # logging.info(f"Scores: {scores}")
scores = analyzer.load_pairwise_scores("pairwise")
print(scores)


if __name__ == "__main__":
