Skip to content

Commit

Permalink
Add self-influence computation
Browse files Browse the repository at this point in the history
  • Loading branch information
pomonam committed Mar 20, 2024
1 parent 5b78068 commit a998164
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 20 deletions.
15 changes: 14 additions & 1 deletion examples/cifar/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ python analyze.py --query_batch_size 1000 \
--checkpoint_dir ./checkpoints \
--factor_strategy ekfac
```
You can also use `identity`, `diagonal`, and `kfac`. On A100 (80GB), it takes roughly 2 minutes to compute the
You can also use `identity`, `diagonal`, and `kfac`. On A100 (80GB), it takes roughly 1.5 minutes to compute the
pairwise scores (including computing EKFAC factors).

## Mislabeled Data Detection
Expand All @@ -43,3 +43,16 @@ python train.py --dataset_dir ./data \
--num_train_epochs 25 \
--seed 1004
```

Then, compute self-influence scores with the following command:
```bash
python detect_mislabeled_dataset.py --dataset_dir ./data \
--corrupt_percentage 0.1 \
--checkpoint_dir ./checkpoints \
--train_batch_size 512 \
--eval_batch_size 1024 \
--learning_rate 0.4 \
--weight_decay 0.0001 \
--num_train_epochs 25 \
--seed 1004
```
2 changes: 1 addition & 1 deletion examples/cifar/analyze.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@


def parse_args():
parser = argparse.ArgumentParser(description="Influence analysis on UCI datasets.")
parser = argparse.ArgumentParser(description="Influence analysis on CIFAR-10 dataset.")

parser.add_argument(
"--corrupt_percentage",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import argparse
import logging
import os
from typing import Tuple

import torch
from arguments import FactorArguments
Expand All @@ -10,9 +9,8 @@
from kronfluence.analyzer import Analyzer, prepare_model



def parse_args():
parser = argparse.ArgumentParser(description="Influence analysis on UCI datasets.")
parser = argparse.ArgumentParser(description="Detecting mislabeled CIFAR-10 dataset.")

parser.add_argument(
"--corrupt_percentage",
Expand All @@ -27,13 +25,6 @@ def parse_args():
help="A folder to download or load CIFAR-10 dataset.",
)

parser.add_argument(
"--query_batch_size",
type=int,
default=1000,
help="Batch size for computing query gradients.",
)

parser.add_argument(
"--checkpoint_dir",
type=str,
Expand Down Expand Up @@ -63,7 +54,6 @@ def main():
train_dataset = get_cifar10_dataset(
split="eval_train", corrupt_percentage=args.corrupt_percentage, dataset_dir=args.dataset_dir
)
eval_dataset = get_cifar10_dataset(split="valid", dataset_dir=args.dataset_dir)

model = construct_resnet9()
model_name = "model"
Expand All @@ -90,18 +80,15 @@ def main():
dataset=train_dataset,
per_device_batch_size=None,
factor_args=factor_args,
overwrite_output_dir=True,
overwrite_output_dir=False,
)
analyzer.compute_pairwise_scores(
scores_name="pairwise",
analyzer.compute_self_scores(
scores_name="self",
factors_name=args.factor_strategy,
query_dataset=eval_dataset,
query_indices=list(range(2000)),
train_dataset=train_dataset,
per_device_query_batch_size=args.query_batch_size,
overwrite_output_dir=True,
)
scores = analyzer.load_pairwise_scores("pairwise")
scores = analyzer.load_pairwise_scores("self")
print(scores)


Expand Down

0 comments on commit a998164

Please sign in to comment.