Skip to content

Commit

Permalink
Linting fix
Browse files Browse the repository at this point in the history
  • Loading branch information
pomonam committed Mar 20, 2024
1 parent 426c983 commit 5b78068
Show file tree
Hide file tree
Showing 8 changed files with 145 additions and 20 deletions.
20 changes: 16 additions & 4 deletions examples/cifar/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ python train.py --dataset_dir ./data \
--seed 1004
```

# Computing Pairwise Influence Scores
## Computing Pairwise Influence Scores

To obtain pairwise influence scores on 2000 query data points using `ekfac`, run the following command:
```bash
Expand All @@ -26,8 +26,20 @@ python analyze.py --query_batch_size 1000 \
--checkpoint_dir ./checkpoints \
--factor_strategy ekfac
```
You can also use `identity`, `diagonal`, and `kfac`. On A100 (80GB), it takes roughly 2 minutes to compute the pairwise scores.
You can also use `identity`, `diagonal`, and `kfac`. On A100 (80GB), it takes roughly 2 minutes to compute the
pairwise scores (including computing EKFAC factors).

# Counterfactual Evaluation
## Mislabeled Data Detection

You can check the notebook `tutorial.ipynb` for running the counterfactual evaluation.
First, train the model with 10% of the training dataset mislabeled by running the following command:
```bash
python train.py --dataset_dir ./data \
--corrupt_percentage 0.1 \
--checkpoint_dir ./checkpoints \
--train_batch_size 512 \
--eval_batch_size 1024 \
--learning_rate 0.4 \
--weight_decay 0.0001 \
--num_train_epochs 25 \
--seed 1004
```
7 changes: 6 additions & 1 deletion examples/cifar/analyze.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@

import torch
import torch.nn.functional as F
from arguments import FactorArguments
from kronfluence.arguments import FactorArguments
from torch import nn

from examples.cifar.pipeline import construct_resnet9, get_cifar10_dataset
from kronfluence.analyzer import Analyzer, prepare_model
from kronfluence.task import Task
from kronfluence.utils.dataset import DataLoaderKwargs

BATCH_TYPE = Tuple[torch.Tensor, torch.Tensor]

Expand Down Expand Up @@ -125,6 +127,9 @@ def main():
cpu=False,
)

dataloader_kwargs = DataLoaderKwargs(num_workers=4)
analyzer.set_dataloader_kwargs(dataloader_kwargs)

factor_args = FactorArguments(strategy=args.factor_strategy)
analyzer.fit_all_factors(
factors_name=args.factor_strategy,
Expand Down
109 changes: 109 additions & 0 deletions examples/cifar/detect_mislabled_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import argparse
import logging
import os
from typing import Tuple

import torch
from arguments import FactorArguments
from examples.cifar.analyze import ClassificationTask
from examples.cifar.pipeline import construct_resnet9, get_cifar10_dataset
from kronfluence.analyzer import Analyzer, prepare_model



def parse_args():
    """Parse command-line arguments for mislabeled CIFAR-10 data detection.

    Returns:
        argparse.Namespace with `corrupt_percentage`, `dataset_dir`,
        `query_batch_size`, `checkpoint_dir`, and `factor_strategy`.

    Side effect: creates `checkpoint_dir` if it does not already exist.
    """
    # Fix copy-pasted description: this script targets CIFAR-10, not UCI.
    parser = argparse.ArgumentParser(description="Detecting mislabeled CIFAR-10 data points.")

    parser.add_argument(
        "--corrupt_percentage",
        type=float,
        default=0.1,
        help="Percentage of the training dataset to corrupt.",
    )
    parser.add_argument(
        "--dataset_dir",
        type=str,
        default="./data",
        help="A folder to download or load CIFAR-10 dataset.",
    )

    parser.add_argument(
        "--query_batch_size",
        type=int,
        default=1000,
        help="Batch size for computing query gradients.",
    )

    parser.add_argument(
        "--checkpoint_dir",
        type=str,
        default="./checkpoints",
        # This script loads a previously trained checkpoint from this path.
        help="A path to load the trained model checkpoint from.",
    )

    parser.add_argument(
        "--factor_strategy",
        type=str,
        default="ekfac",
        help="Strategy to compute preconditioning factors.",
    )

    args = parser.parse_args()

    if args.checkpoint_dir is not None:
        os.makedirs(args.checkpoint_dir, exist_ok=True)

    return args


def main():
    """Compute pairwise influence scores for a (partially corrupted) CIFAR-10 model.

    Loads a trained checkpoint, fits factors with the requested strategy,
    scores 2000 validation queries against the corrupted training split,
    and prints the resulting score tensor.
    """
    args = parse_args()
    logging.basicConfig(level=logging.INFO)

    # Training split may carry corrupted labels; the validation split stays clean.
    corrupted_train = get_cifar10_dataset(
        split="eval_train", corrupt_percentage=args.corrupt_percentage, dataset_dir=args.dataset_dir
    )
    clean_eval = get_cifar10_dataset(split="valid", dataset_dir=args.dataset_dir)

    net = construct_resnet9()
    # The checkpoint file name encodes the corruption level used at training time.
    suffix = "" if args.corrupt_percentage is None else f"_corrupt_{args.corrupt_percentage}"
    ckpt_path = os.path.join(args.checkpoint_dir, f"model{suffix}.pth")
    if not os.path.isfile(ckpt_path):
        raise ValueError(f"No checkpoint found at {ckpt_path}.")
    net.load_state_dict(torch.load(ckpt_path))

    task = ClassificationTask()
    net = prepare_model(net, task)

    analyzer = Analyzer(analysis_name="cifar10", model=net, task=task, cpu=False)

    analyzer.fit_all_factors(
        factors_name=args.factor_strategy,
        dataset=corrupted_train,
        per_device_batch_size=None,
        factor_args=FactorArguments(strategy=args.factor_strategy),
        overwrite_output_dir=True,
    )
    analyzer.compute_pairwise_scores(
        scores_name="pairwise",
        factors_name=args.factor_strategy,
        query_dataset=clean_eval,
        query_indices=list(range(2000)),
        train_dataset=corrupted_train,
        per_device_query_batch_size=args.query_batch_size,
        overwrite_output_dir=True,
    )
    print(analyzer.load_pairwise_scores("pairwise"))


if __name__ == "__main__":
    main()
8 changes: 4 additions & 4 deletions examples/uci/README.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# UCI Regression Example

This directory contains scripts designed for training a regression model and conducting influence analysis with
datasets from the [UCI Machine Learning Repository](https://archive.ics.uci.edu/datasets). Install all necessary packages:
datasets from the [UCI Machine Learning Repository](https://archive.ics.uci.edu/datasets). Please begin by installing the necessary packages.

```bash
pip install -r requirements.txt
Expand All @@ -22,7 +22,7 @@ python train.py --dataset_name concrete \
--seed 1004
```

# Computing Pairwise Influence Scores
## Computing Pairwise Influence Scores

To obtain pairwise influence scores using `ekfac`, run the following command:
```bash
Expand All @@ -33,6 +33,6 @@ python analyze.py --dataset_name concrete \
```
You can also use `identity`, `diagonal`, and `kfac`.

# Counterfactual Evaluation
## Counterfactual Evaluation

You can check the notebook `tutorial.ipynb` for running the counterfactual evaluation.
You can check the notebook `tutorial.ipynb` to run the counterfactual evaluation.
6 changes: 3 additions & 3 deletions examples/uci/analyze.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from kronfluence.analyzer import Analyzer, prepare_model
from kronfluence.task import Task

BATCH_DTYPE = Tuple[torch.Tensor, torch.Tensor]
BATCH_TYPE = Tuple[torch.Tensor, torch.Tensor]


def parse_args():
Expand Down Expand Up @@ -58,7 +58,7 @@ def parse_args():
class RegressionTask(Task):
def compute_train_loss(
self,
batch: BATCH_DTYPE,
batch: BATCH_TYPE,
model: nn.Module,
sample: bool = False,
) -> torch.Tensor:
Expand All @@ -72,7 +72,7 @@ def compute_train_loss(

def compute_measurement(
self,
batch: BATCH_DTYPE,
batch: BATCH_TYPE,
model: nn.Module,
) -> torch.Tensor:
# The measurement function is set as a training loss.
Expand Down
3 changes: 1 addition & 2 deletions kronfluence/computer/computer.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
from kronfluence.utils.exceptions import (
FactorsNotFoundError,
TrackedModuleNotFoundError,
UnsupportableModuleError,
)
from kronfluence.utils.logger import PassThroughProfiler, Profiler, get_logger
from kronfluence.utils.save import (
Expand Down Expand Up @@ -82,7 +81,7 @@ def __init__(
f"Analyzer."
)
self.logger.error(error_msg)
raise UnsupportableModuleError(error_msg)
raise TrackedModuleNotFoundError(error_msg)
self.logger.info(f"Tracking modules with names: {tracked_module_names}.")

if self.state.use_distributed and not isinstance(model, (DDP, FSDP)):
Expand Down
6 changes: 3 additions & 3 deletions tests/gpu_tests/ddp_variation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from kronfluence.arguments import FactorArguments, ScoreArguments
from kronfluence.task import Task
from tests.gpu_tests.ddp_test import OLD_FACTOR_NAME
from tests.gpu_tests.pipeline import BATCH_DTYPE, construct_test_mlp, get_mnist_dataset
from tests.gpu_tests.pipeline import BATCH_TYPE, construct_test_mlp, get_mnist_dataset

LOCAL_RANK = int(os.environ["LOCAL_RANK"])
WORLD_RANK = int(os.environ["RANK"])
Expand All @@ -27,7 +27,7 @@
class GpuVariationTask(Task):
def compute_train_loss(
self,
batch: BATCH_DTYPE,
batch: BATCH_TYPE,
model: nn.Module,
sample: bool = False,
) -> torch.Tensor:
Expand All @@ -45,7 +45,7 @@ def compute_train_loss(

def compute_measurement(
self,
batch: BATCH_DTYPE,
batch: BATCH_TYPE,
model: nn.Module,
) -> torch.Tensor:
inputs, labels = batch
Expand Down
6 changes: 3 additions & 3 deletions tests/gpu_tests/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@

from kronfluence.task import Task

BATCH_DTYPE = Tuple[torch.Tensor, torch.Tensor]
BATCH_TYPE = Tuple[torch.Tensor, torch.Tensor]


class GpuTestTask(Task):
def compute_train_loss(
self,
batch: BATCH_DTYPE,
batch: BATCH_TYPE,
model: nn.Module,
sample: bool = False,
) -> torch.Tensor:
Expand All @@ -34,7 +34,7 @@ def compute_train_loss(

def compute_measurement(
self,
batch: BATCH_DTYPE,
batch: BATCH_TYPE,
model: nn.Module,
) -> torch.Tensor:
inputs, labels = batch
Expand Down

0 comments on commit 5b78068

Please sign in to comment.