diff --git a/examples/cifar/README.md b/examples/cifar/README.md index e25a65e..a6ae155 100644 --- a/examples/cifar/README.md +++ b/examples/cifar/README.md @@ -26,15 +26,30 @@ This will train the model using the specified hyperparameters and save the train ## Computing Pairwise Influence Scores -To obtain pairwise influence scores on 2000 query data points using `ekfac`, run the following command: +To compute pairwise influence scores on 2000 query data points using the `ekfac` factorization strategy, run the following command: + ```bash python analyze.py --query_batch_size 1000 \ --dataset_dir ./data \ --checkpoint_dir ./checkpoints \ --factor_strategy ekfac ``` -You can also use `identity`, `diagonal`, and `kfac`. On A100 (80GB), it takes roughly 1.5 minutes to compute the -pairwise scores (including computing EKFAC factors). + +In addition to `ekfac`, you can also use `identity`, `diagonal`, and `kfac` as the `factor_strategy`. On an A100 (80GB) GPU, it takes roughly 2 minutes to compute the pairwise scores (including computing the EKFAC factors): + +``` + +``` + +To use AMP when computing influence scores (in addition to half precision when computing influence factors and scores), run: + +```bash +python analyze.py --query_batch_size 1000 \ + --dataset_dir ./data \ + --checkpoint_dir ./checkpoints \ + --factor_strategy ekfac \ + --use_half_precision +``` ## Mislabeled Data Detection diff --git a/examples/cifar/analyze.py b/examples/cifar/analyze.py index 65f34c3..5731191 100644 --- a/examples/cifar/analyze.py +++ b/examples/cifar/analyze.py @@ -138,6 +138,7 @@ def main(): analysis_name="cifar10", model=model, task=task, + profile=args.profile, ) # Configure parameters for DataLoader. dataloader_kwargs = DataLoaderKwargs(num_workers=4)