diff --git a/examples/cifar/README.md b/examples/cifar/README.md index a86283a..66ac654 100644 --- a/examples/cifar/README.md +++ b/examples/cifar/README.md @@ -26,7 +26,7 @@ python analyze.py --query_batch_size 1000 \ --checkpoint_dir ./checkpoints \ --factor_strategy ekfac ``` -You can also use `identity`, `diagonal`, and `kfac`. On A100 (80GB), it takes roughly 2 minutes to compute the +You can also use `identity`, `diagonal`, and `kfac`. On A100 (80GB), it takes roughly 1.5 minutes to compute the pairwise scores (including computing EKFAC factors). ## Mislabeled Data Detection @@ -43,3 +43,16 @@ python train.py --dataset_dir ./data \ --num_train_epochs 25 \ --seed 1004 ``` + +Then, compute self-influence scores with the following command: +```bash +python detect_mislabeled_dataset.py --dataset_dir ./data \ + --corrupt_percentage 0.1 \ + --checkpoint_dir ./checkpoints \ + --train_batch_size 512 \ + --eval_batch_size 1024 \ + --learning_rate 0.4 \ + --weight_decay 0.0001 \ + --num_train_epochs 25 \ + --seed 1004 +``` \ No newline at end of file diff --git a/examples/cifar/analyze.py b/examples/cifar/analyze.py index 6f7556f..292b102 100644 --- a/examples/cifar/analyze.py +++ b/examples/cifar/analyze.py @@ -17,7 +17,7 @@ def parse_args(): - parser = argparse.ArgumentParser(description="Influence analysis on UCI datasets.") + parser = argparse.ArgumentParser(description="Influence analysis on CIFAR-10 dataset.") parser.add_argument( "--corrupt_percentage", diff --git a/examples/cifar/detect_mislabled_dataset.py b/examples/cifar/detect_mislabeled_dataset.py similarity index 78% rename from examples/cifar/detect_mislabled_dataset.py rename to examples/cifar/detect_mislabeled_dataset.py index 266ee1c..bccbd77 100644 --- a/examples/cifar/detect_mislabled_dataset.py +++ b/examples/cifar/detect_mislabeled_dataset.py @@ -1,7 +1,6 @@ import argparse import logging import os -from typing import Tuple import torch from arguments import FactorArguments @@ -10,9 +9,8 @@ from kronfluence.analyzer import Analyzer, prepare_model - def parse_args(): - parser = argparse.ArgumentParser(description="Influence analysis on UCI datasets.") + parser = argparse.ArgumentParser(description="Detecting mislabeled CIFAR-10 dataset.") parser.add_argument( "--corrupt_percentage", @@ -27,13 +25,6 @@ def parse_args(): help="A folder to download or load CIFAR-10 dataset.", ) - parser.add_argument( - "--query_batch_size", - type=int, - default=1000, - help="Batch size for computing query gradients.", - ) - parser.add_argument( "--checkpoint_dir", type=str, @@ -63,7 +54,6 @@ def main(): train_dataset = get_cifar10_dataset( split="eval_train", corrupt_percentage=args.corrupt_percentage, dataset_dir=args.dataset_dir ) - eval_dataset = get_cifar10_dataset(split="valid", dataset_dir=args.dataset_dir) model = construct_resnet9() model_name = "model" @@ -90,18 +80,15 @@ def main(): dataset=train_dataset, per_device_batch_size=None, factor_args=factor_args, - overwrite_output_dir=True, + overwrite_output_dir=False, ) - analyzer.compute_pairwise_scores( - scores_name="pairwise", + analyzer.compute_self_scores( + scores_name="self", factors_name=args.factor_strategy, - query_dataset=eval_dataset, - query_indices=list(range(2000)), train_dataset=train_dataset, - per_device_query_batch_size=args.query_batch_size, overwrite_output_dir=True, ) - scores = analyzer.load_pairwise_scores("pairwise") + scores = analyzer.load_pairwise_scores("self") print(scores)