diff --git a/examples/cifar/README.md b/examples/cifar/README.md index a6ae155..652a5b3 100644 --- a/examples/cifar/README.md +++ b/examples/cifar/README.md @@ -38,7 +38,23 @@ python analyze.py --query_batch_size 1000 \ In addition to `ekfac`, you can also use `identity`, `diagonal`, and `kfac` as the `factor_strategy`. On an A100 (80GB) GPU, it takes roughly 2 minutes to compute the pairwise scores (including computing the EKFAC factors): ``` - +---------------------------------------------------------------------------------------------------------------------------------- +| Action | Mean duration (s) | Num calls | Total time (s) | Percentage % | +---------------------------------------------------------------------------------------------------------------------------------- +| Total | - | 11 | 112.83 | 100 % | +---------------------------------------------------------------------------------------------------------------------------------- +| Compute Pairwise Score | 47.989 | 1 | 47.989 | 42.532 | +| Fit Lambda | 34.639 | 1 | 34.639 | 30.7 | +| Fit Covariance | 21.841 | 1 | 21.841 | 19.357 | +| Save Pairwise Score | 3.5998 | 1 | 3.5998 | 3.1905 | +| Perform Eigendecomposition | 2.7724 | 1 | 2.7724 | 2.4572 | +| Save Covariance | 0.85695 | 1 | 0.85695 | 0.75951 | +| Save Eigendecomposition | 0.85628 | 1 | 0.85628 | 0.75892 | +| Save Lambda | 0.12327 | 1 | 0.12327 | 0.10925 | +| Load Eigendecomposition | 0.056494 | 1 | 0.056494 | 0.05007 | +| Load All Factors | 0.048981 | 1 | 0.048981 | 0.043412 | +| Load Covariance | 0.046798 | 1 | 0.046798 | 0.041476 | +---------------------------------------------------------------------------------------------------------------------------------- ``` To use AMP when computing influence scores (in addition to half precision when computing influence factors and scores), run: diff --git a/examples/cifar/analyze.py b/examples/cifar/analyze.py index 5731191..d628b96 100644 --- a/examples/cifar/analyze.py +++ b/examples/cifar/analyze.py @@ -151,10 +151,10 @@ def main(): factor_args = all_low_precision_factor_arguments(strategy=args.factor_strategy, dtype=torch.bfloat16) factors_name += "_half" analyzer.fit_all_factors( - factors_name=args.factor_strategy, + factors_name=factors_name, + factor_args=factor_args, dataset=train_dataset, per_device_batch_size=None, - factor_args=factor_args, overwrite_output_dir=False, )