Skip to content

Commit

Permalink
Add cifar analysis
Browse files Browse the repository at this point in the history
  • Loading branch information
pomonam committed Mar 20, 2024
1 parent edbfeb8 commit 38bdb08
Show file tree
Hide file tree
Showing 2 changed files with 143 additions and 1 deletion.
142 changes: 142 additions & 0 deletions examples/cifar/analyze.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
import argparse
import logging
import os
from typing import Tuple

import torch
import torch.nn.functional as F
from arguments import FactorArguments

from examples.cifar.pipeline import construct_resnet9, get_cifar10_dataset
from kronfluence.analyzer import Analyzer, prepare_model
from kronfluence.task import Task

# Type alias for one mini-batch as yielded by the data loader: (inputs, labels).
BATCH_DTYPE = Tuple[torch.Tensor, torch.Tensor]


def parse_args(argv=None):
    """Parse command-line arguments for the CIFAR-10 influence analysis.

    Args:
        argv: Optional list of argument strings; defaults to ``sys.argv[1:]``.
            Exposed as a parameter (backward-compatible) so the parser can be
            driven programmatically, e.g. from tests.

    Returns:
        argparse.Namespace with ``corrupt_percentage``, ``dataset_dir``,
        ``checkpoint_dir``, and ``factor_strategy`` attributes.

    Side effects:
        Creates ``args.checkpoint_dir`` if it does not already exist.
    """
    # BUG FIX: the original description said "UCI datasets", but every option
    # here is CIFAR-10 specific — corrected to match what the script does.
    parser = argparse.ArgumentParser(description="Influence analysis on CIFAR-10 dataset.")

    parser.add_argument(
        "--corrupt_percentage",
        type=float,
        default=None,
        help="Percentage of the training dataset to corrupt.",
    )
    parser.add_argument(
        "--dataset_dir",
        type=str,
        default="./data",
        help="A folder to download or load CIFAR-10 dataset.",
    )
    parser.add_argument(
        "--checkpoint_dir",
        type=str,
        default="./checkpoints",
        help="A path to store the final checkpoint.",
    )
    parser.add_argument(
        "--factor_strategy",
        type=str,
        default="ekfac",
        help="Strategy to compute preconditioning factors.",
    )

    args = parser.parse_args(argv)

    # Ensure the checkpoint directory exists before main() tries to read from it.
    if args.checkpoint_dir is not None:
        os.makedirs(args.checkpoint_dir, exist_ok=True)

    return args


class ClassificationTask(Task):
    """Kronfluence task defining the training loss and measurement for CIFAR-10 classification."""

    def compute_train_loss(
        self,
        batch: BATCH_DTYPE,
        outputs: torch.Tensor,
        sample: bool = False,
    ) -> torch.Tensor:
        """Return the summed cross-entropy loss over the batch.

        When ``sample`` is True, labels are drawn from the model's own
        predictive distribution instead of using the true labels.
        """
        _, true_labels = batch
        if not sample:
            return F.cross_entropy(outputs, true_labels, reduction="sum")
        # Sample pseudo-labels from softmax(outputs); no gradients flow
        # through the sampling step.
        with torch.no_grad():
            distribution = F.softmax(outputs, dim=-1)
        drawn = torch.multinomial(distribution, num_samples=1).flatten()
        return F.cross_entropy(outputs, drawn.detach(), reduction="sum")

    def compute_measurement(
        self,
        batch: BATCH_DTYPE,
        outputs: torch.Tensor,
    ) -> torch.Tensor:
        """Return the negated sum of per-example margins.

        The margin is the correct-class logit minus the logsumexp of all
        other logits (the correct logit is masked with -inf first).
        Copied from: https://github.com/MadryLab/trak/blob/main/trak/modelout_functions.py.
        """
        _, true_labels = batch

        row_idx = torch.arange(outputs.shape[0]).to(device=outputs.device, non_blocking=False)
        correct_logits = outputs[row_idx, true_labels]

        masked = outputs.clone()
        # Exclude the correct class from the logsumexp denominator.
        masked[row_idx, true_labels] = torch.tensor(-torch.inf, device=outputs.device, dtype=outputs.dtype)

        per_example_margin = correct_logits - masked.logsumexp(dim=-1)
        return -per_example_margin.sum()


def main():
    """Fit EKFAC/other influence factors on CIFAR-10 and compute pairwise scores.

    Loads a trained ResNet-9 checkpoint (optionally one trained on partially
    corrupted labels), fits preconditioning factors on the training split,
    and computes pairwise influence scores of all training examples on the
    validation set. Raises ValueError if the expected checkpoint is missing.
    """
    args = parse_args()
    logging.basicConfig(level=logging.INFO)

    # "eval_train" is the training data without augmentation, as required for
    # factor fitting / scoring; "valid" is the query set.
    train_dataset = get_cifar10_dataset(
        split="eval_train", corrupt_percentage=args.corrupt_percentage, dataset_dir=args.dataset_dir
    )
    eval_dataset = get_cifar10_dataset(split="valid", dataset_dir=args.dataset_dir)

    model = construct_resnet9()
    model_name = "model"
    if args.corrupt_percentage is not None:
        # Checkpoint naming must match the one produced by the training script.
        model_name += "_corrupt_" + str(args.corrupt_percentage)
    checkpoint_path = os.path.join(args.checkpoint_dir, f"{model_name}.pth")
    if not os.path.isfile(checkpoint_path):
        raise ValueError(f"No checkpoint found at {checkpoint_path}.")
    model.load_state_dict(torch.load(checkpoint_path))

    task = ClassificationTask()
    model = prepare_model(model, task)

    # BUG FIX: the argument parser defines no --dataset_name option, so the
    # original `analysis_name=args.dataset_name` raised AttributeError at
    # runtime. Use a fixed analysis name for this CIFAR-10 example instead.
    analyzer = Analyzer(
        analysis_name="cifar10",
        model=model,
        task=task,
        cpu=True,
    )

    factor_args = FactorArguments(strategy=args.factor_strategy)
    analyzer.fit_all_factors(
        factors_name=args.factor_strategy,
        dataset=train_dataset,
        per_device_batch_size=None,
        factor_args=factor_args,
        overwrite_output_dir=True,
    )
    analyzer.compute_pairwise_scores(
        scores_name="pairwise",
        factors_name=args.factor_strategy,
        query_dataset=eval_dataset,
        train_dataset=train_dataset,
        # The whole validation set fits in a single query batch on CPU here.
        per_device_query_batch_size=len(eval_dataset),
        overwrite_output_dir=True,
    )
    scores = analyzer.load_pairwise_scores("pairwise")
    print(scores)


if __name__ == "__main__":
main()
2 changes: 1 addition & 1 deletion examples/cifar/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def get_cifar10_dataset(
assert split in ["train", "eval_train", "valid"]

normalize = torchvision.transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.247, 0.243, 0.261))
if split in ["train", "eval_train"]:
if split == "train":
transform_config = torchvision.transforms.Compose(
[
torchvision.transforms.RandomCrop(32, padding=4),
Expand Down

0 comments on commit 38bdb08

Please sign in to comment.