Merge pull request #32 from pomonam/low_precision
Merge dev branch
pomonam authored Jul 16, 2024
2 parents 1bb64f5 + 8541589 commit d6414c5
Showing 51 changed files with 39,714 additions and 6,183 deletions.
3 changes: 1 addition & 2 deletions DOCUMENTATION.md
@@ -258,8 +258,7 @@ Kronfluence computes covariance matrices for all data points.
- `covariance_data_partitions`: Number of data partitions to use for computing covariance matrices.
For example, when `covariance_data_partitions=2`, the dataset is split into 2 chunks and covariance matrices
are separately computed for each chunk. These chunked covariance matrices are later aggregated. This is useful with GPU preemption as intermediate
-covariance matrices will be saved in disk. It can be also helpful when launching multiple parallel jobs, where each GPU
-can compute covariance matrices on some partitioned data (you can specify `target_data_partitions` in the parameter).
+covariance matrices will be saved to disk. It is also helpful when using low precision.
- `covariance_module_partitions`: Number of module partitions to use for computing covariance matrices.
For example, when `covariance_module_partitions=2`, the module is split into 2 chunks and covariance matrices
are separately computed for each chunk. This is useful when the available GPU memory is limited (e.g., the total
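As a concrete illustration of the two partition options documented above, here is a minimal sketch that fits EKFAC factors with both the data and the module dimension split into two chunks. The `covariance_data_partitions` and `covariance_module_partitions` fields come straight from this documentation; the toy model, task, and dataset are placeholders, and the `prepare_model` / `Analyzer.fit_all_factors` calls follow Kronfluence's documented interface, so check the exact signatures against the installed version.

```python
# Sketch only: partitioned covariance computation with EKFAC factors.
# The model, task, and dataset below are toy stand-ins.
import torch
from torch import nn
from torch.utils.data import TensorDataset

from kronfluence.analyzer import Analyzer, prepare_model
from kronfluence.arguments import FactorArguments
from kronfluence.task import Task


class RegressionTask(Task):
    """Minimal task: summed MSE loss on (input, target) pairs."""

    def compute_train_loss(self, batch, model, sample=False):
        inputs, targets = batch
        return torch.nn.functional.mse_loss(model(inputs), targets, reduction="sum")

    def compute_measurement(self, batch, model):
        return self.compute_train_loss(batch, model)


model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1))
train_dataset = TensorDataset(torch.randn(64, 8), torch.randn(64, 1))
task = RegressionTask()

model = prepare_model(model=model, task=task)
analyzer = Analyzer(analysis_name="partition_demo", model=model, task=task)

factor_args = FactorArguments(
    strategy="ekfac",                # also accepts "identity", "diagonal", "kfac"
    covariance_data_partitions=2,    # dataset split into 2 chunks, aggregated afterwards
    covariance_module_partitions=2,  # modules split into 2 chunks to limit GPU memory
)
analyzer.fit_all_factors(
    factors_name="ekfac_partitioned",
    dataset=train_dataset,
    factor_args=factor_args,
)
```

Because each partition's intermediate covariance matrices are saved to disk, a preempted or parallelized job can resume from (or divide up) the partitions rather than starting over.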
16 changes: 8 additions & 8 deletions examples/README.md
@@ -18,14 +18,14 @@ Our examples cover the following tasks:

<div align="center">

-| Task | Example Datasets |
-|----------------------|:------------------------:|
-| Regression | UCI |
-| Image Classification | CIFAR-10 & ImageNet |
-| Text Classification | GLUE |
-| Multiple-Choice | SWAG |
-| Summarization | CNN/DailyMail |
-| Language Modeling | WikiText-2 & OpenWebText |
+| Task | Example Datasets |
+|----------------------|:------------------------:|
+| Regression | [UCI](https://github.com/pomonam/kronfluence/tree/main/examples/uci) |
+| Image Classification | [CIFAR-10](https://github.com/pomonam/kronfluence/tree/main/examples/cifar) & [ImageNet](https://github.com/pomonam/kronfluence/tree/main/examples/imagenet) |
+| Text Classification | [GLUE](https://github.com/pomonam/kronfluence/tree/main/examples/glue) |
+| Multiple-Choice | [SWAG](https://github.com/pomonam/kronfluence/tree/main/examples/swag) |
+| Summarization | [CNN/DailyMail](https://github.com/pomonam/kronfluence/tree/main/examples/dailymail) |
+| Language Modeling | [WikiText-2](https://github.com/pomonam/kronfluence/tree/main/examples/wikitext) & [OpenWebText](https://github.com/pomonam/kronfluence/tree/main/examples/openwebtext) |

</div>

21 changes: 11 additions & 10 deletions examples/cifar/README.md
@@ -1,6 +1,6 @@
# CIFAR-10 & ResNet-9 Example

-This directory contains scripts for training ResNet-9 and computing influence scores on CIFAR-10 dataset. The pipeline is motivated from
+This directory contains scripts for training ResNet-9 and computing influence scores on the CIFAR-10 dataset. The pipeline is motivated by the
[TRAK repository](https://github.com/MadryLab/trak/blob/main/examples/cifar_quickstart.ipynb). To get started, please install the necessary packages by running the following command:

```bash
@@ -9,7 +9,7 @@ pip install -r requirements.txt

## Training

-To train ResNet-9 on the CIFAR-10 dataset, run the following command:
+To train ResNet-9 on CIFAR-10, execute:

```bash
python train.py --dataset_dir ./data \
@@ -35,7 +35,8 @@ python analyze.py --query_batch_size 1000 \
--factor_strategy ekfac
```

-In addition to `ekfac`, you can also use `identity`, `diagonal`, and `kfac` as the `factor_strategy`. On an A100 (80GB) GPU, it takes roughly 2 minutes to compute the pairwise scores (including computing the EKFAC factors):
+In addition to `ekfac`, you can also use `identity`, `diagonal`, and `kfac` as the `factor_strategy`.
+On an A100 (80GB) GPU, computation takes approximately 2 minutes, including EKFAC factor calculation:

```
----------------------------------------------------------------------------------------------------------------------------------
@@ -57,7 +58,7 @@ In addition to `ekfac`, you can also use `identity`, `diagonal`, and `kfac` as t
----------------------------------------------------------------------------------------------------------------------------------
```

-To use AMP when computing influence scores, run:
+To use AMP for faster computation, add the `--use_half_precision` flag:

```bash
python analyze.py --query_batch_size 1000 \
@@ -89,20 +90,20 @@ This reduces computation time to about 40 seconds on an A100 (80GB) GPU:
----------------------------------------------------------------------------------------------------------------------------------
```

-You can run `half_precision_analysis.py` to verify that the scores computed with AMP have high correlations with those of the default configuration.
+Run `half_precision_analysis.py` to verify that AMP-computed scores maintain high correlations with default configuration scores.

<p align="center">
<a href="#"><img width="380" img src="figure/half_precision.png" alt="Half Precision"/></a>
</p>
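The gist of that check can be reproduced in a few lines: load the two pairwise score files and compute one Spearman correlation per query point. The result paths and the `"all_modules"` key below are assumptions about where the example scripts write their outputs, so adjust them to the local `influence_results` layout.

```python
# Sketch: compare AMP (half-precision) pairwise scores against full-precision ones.
# File paths and the "all_modules" key are assumed; point them at your own outputs.
import numpy as np
from scipy.stats import spearmanr

from kronfluence.analyzer import Analyzer

scores = Analyzer.load_file(
    "influence_results/cifar10/scores_cifar10_pairwise/pairwise_scores.safetensors"
)["all_modules"].float().numpy()
half_scores = Analyzer.load_file(
    "influence_results/cifar10/scores_cifar10_pairwise_half_precision/pairwise_scores.safetensors"
)["all_modules"].float().numpy()

# One Spearman rank correlation per query point, then summary statistics.
correlations = np.array(
    [spearmanr(scores[i], half_scores[i])[0] for i in range(scores.shape[0])]
)
print(f"mean={correlations.mean():.4f}  min={correlations.min():.4f}  max={correlations.max():.4f}")
```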

## Visualizing Influential Training Images

-[This Colab notebook](https://colab.research.google.com/drive/1KIwIbeJh_om4tRwceuZ005fVKDsiXKgr?usp=sharing) provides a tutorial on visualizing the top influential training images.
+For a tutorial on visualizing top influential training images, refer to [this Colab notebook](https://colab.research.google.com/drive/1KIwIbeJh_om4tRwceuZ005fVKDsiXKgr?usp=sharing).

## Mislabeled Data Detection

We can use self-influence scores (see **Section 5.4** of the [paper](https://arxiv.org/pdf/1703.04730.pdf)) to detect mislabeled examples.
-First, train the model with 10% of the training examples mislabeled by running:
+First, train the model with 10% of the training examples mislabeled:

```bash
python train.py --dataset_dir ./data \
@@ -116,7 +117,7 @@ python train.py --dataset_dir ./data \
--seed 1004
```

-Then, compute the self-influence scores with:
+Then compute self-influence scores:

```bash
python detect_mislabeled_dataset.py --dataset_dir ./data \
@@ -125,7 +126,7 @@ python detect_mislabeled_dataset.py --dataset_dir ./data \
--factor_strategy ekfac
```

-On an A100 (80GB) GPU, it takes roughly 2 minutes to compute the self-influence scores:
+On an A100 (80GB) GPU, this takes approximately 2 minutes:

```
----------------------------------------------------------------------------------------------------------------------------------
@@ -147,7 +148,7 @@ On an A100 (80GB) GPU, it takes roughly 2 minutes to compute the self-influence
----------------------------------------------------------------------------------------------------------------------------------
```

-Around 80% of mislabeled data points can be detected by inspecting 10% of the dataset (97% by inspecting 20%).
+By inspecting just 10% of the dataset, about 80% of mislabeled data points can be detected (97% by inspecting 20%).

<p align="center">
<a href="#"><img width="380" img src="figure/mislabel.png" alt="Mislabeled Data Detection"/></a>
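The detection-rate numbers above come from ranking training examples by self-influence and counting how many of the deliberately corrupted labels land near the top of that ranking. Below is a small, self-contained sketch of that bookkeeping; the synthetic scores and the `self_scores` / `corrupted_indices` names are stand-ins for whatever `detect_mislabeled_dataset.py` actually saves.

```python
# Sketch: detection rate of mislabeled examples from self-influence rankings.
# Synthetic data only; replace with the scores and corrupted indices saved by the example.
import numpy as np


def detection_rate(self_scores: np.ndarray, corrupted_indices: np.ndarray, inspect_fraction: float) -> float:
    """Fraction of corrupted examples found in the top `inspect_fraction` by score."""
    num_inspected = int(len(self_scores) * inspect_fraction)
    ranked = np.argsort(-self_scores)  # highest self-influence first
    inspected = set(ranked[:num_inspected].tolist())
    return sum(int(idx) in inspected for idx in corrupted_indices) / len(corrupted_indices)


rng = np.random.default_rng(0)
self_scores = rng.normal(0.0, 1.0, size=50_000)
corrupted_indices = rng.choice(50_000, size=5_000, replace=False)
self_scores[corrupted_indices] += 2.0  # corrupted points tend to receive larger self-influence

print(detection_rate(self_scores, corrupted_indices, 0.10))
print(detection_rate(self_scores, corrupted_indices, 0.20))
```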
Binary file modified examples/cifar/figure/half_precision.png
6 changes: 4 additions & 2 deletions examples/cifar/half_precision_analysis.py
@@ -27,7 +27,7 @@ def main():
plt.rcParams["axes.axisbelow"] = True

# Only plot first 3000 points to avoid clutter.
-idx = 79
+idx = 0
plt.scatter(half_scores[idx][:3000], scores[idx][:3000], edgecolor="k")
plt.grid()
plt.xlabel("bfloat16")
@@ -36,9 +36,11 @@

# Compute the averaged spearman correlation.
all_corr = []
-for i in range(100):
+for i in range(2000):
all_corr.append(spearmanr(scores[i], half_scores[i])[0])
logging.info(f"Averaged Spearman Correlation: {np.array(all_corr).mean()}")
logging.info(f"Lowest Spearman Correlation: {np.array(all_corr).min()}")
logging.info(f"Highest Spearman Correlation: {np.array(all_corr).max()}")


if __name__ == "__main__":
