Skip to content

Commit

Permalink
Polish examples
Browse files Browse the repository at this point in the history
  • Loading branch information
pomonam committed Mar 20, 2024
1 parent fd6b045 commit 0a6cf5b
Show file tree
Hide file tree
Showing 16 changed files with 412 additions and 140 deletions.
9 changes: 6 additions & 3 deletions examples/cifar/README.md
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
# CIFAR-10 & ResNet-9 Example

This directory contains scripts designed for training ResNet-9 on CIFAR-10. The pipeline is motivated from
[TRAK repository](https://github.com/MadryLab/trak/blob/main/examples/cifar_quickstart.ipynb).
This directory contains scripts for training ResNet-9 on CIFAR-10. The pipeline is adapted from the
[TRAK repository](https://github.com/MadryLab/trak/blob/main/examples/cifar_quickstart.ipynb). Please begin by installing the necessary packages.
```bash
pip install -r requirements.txt
```

## Training

To train the model on the CIFAR-10 dataset, run the following command:
To train ResNet-9 on CIFAR-10 dataset, run the following command:
```bash
python train.py --dataset_dir ./data \
--checkpoint_dir ./checkpoints \
Expand Down
2 changes: 1 addition & 1 deletion examples/cifar/analyze.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@

import torch
import torch.nn.functional as F
from kronfluence.arguments import FactorArguments
from torch import nn

from examples.cifar.pipeline import construct_resnet9, get_cifar10_dataset
from kronfluence.analyzer import Analyzer, prepare_model
from kronfluence.arguments import FactorArguments
from kronfluence.task import Task
from kronfluence.utils.dataset import DataLoaderKwargs

Expand Down
6 changes: 3 additions & 3 deletions examples/cifar/detect_mislabeled_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
import os

import torch
from kronfluence.arguments import FactorArguments

from examples.cifar.analyze import ClassificationTask
from examples.cifar.pipeline import construct_resnet9, get_cifar10_dataset
from kronfluence.analyzer import Analyzer, prepare_model

from kronfluence.arguments import FactorArguments
from kronfluence.utils.dataset import DataLoaderKwargs


Expand Down Expand Up @@ -102,7 +102,7 @@ def main():
accuracies = []
for interval in intervals:
interval = interval.item()
predicted_indices = torch.argsort(scores, descending=True)[:int(interval * len(train_dataset))]
predicted_indices = torch.argsort(scores, descending=True)[: int(interval * len(train_dataset))]
predicted_indices = list(predicted_indices.numpy())
accuracies.append(len(set(predicted_indices) & set(corrupted_indices)) / total_corrupt_size)

Expand Down
34 changes: 34 additions & 0 deletions examples/glue/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# GLUE & BERT Example

This directory contains scripts for fine-tuning BERT on the GLUE benchmark. The pipeline is adapted from the [HuggingFace Example](https://github.com/huggingface/transformers/tree/main/examples/pytorch/text-classification).
Please begin by installing the necessary packages.
```bash
pip install -r requirements.txt
```

## Training

To fine-tune BERT on a specific GLUE task, run the following command (here, we use the `SST2` dataset):
```bash
python train.py --dataset_name sst2 \
--checkpoint_dir ./checkpoints \
--train_batch_size 32 \
--eval_batch_size 32 \
--learning_rate 3e-05 \
--weight_decay 0.01 \
--num_train_epochs 3 \
--seed 1004
```

## Computing Pairwise Influence Scores

To obtain pairwise influence scores on a maximum of 2000 query data points using `ekfac`, run the following command:
```bash
python analyze.py --query_batch_size 1000 \
    --dataset_name sst2 \
    --checkpoint_dir ./checkpoints \
    --factor_strategy ekfac
```
You can also use `identity`, `diagonal`, and `kfac` as the `factor_strategy`. On an A100 (80GB), it takes roughly 1.5 minutes to compute the
pairwise scores (including computing the EKFAC factors).

168 changes: 168 additions & 0 deletions examples/glue/analyze.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
import argparse
import logging
import os
from typing import Tuple, Dict

import torch
import torch.nn.functional as F
from torch import nn

from examples.glue.pipeline import construct_bert, get_glue_dataset
from kronfluence.analyzer import Analyzer, prepare_model
from kronfluence.arguments import FactorArguments
from kronfluence.task import Task
from kronfluence.utils.dataset import DataLoaderKwargs

BATCH_TYPE = Dict[str, torch.Tensor]


def parse_args(argv=None):
    """Parse command-line arguments for the GLUE influence-analysis script.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]`` (useful for testing). Defaults to ``None``.

    Returns:
        The parsed ``argparse.Namespace``. ``checkpoint_dir`` is created if
        missing and rewritten to the per-dataset subdirectory
        ``<checkpoint_dir>/<dataset_name>``.
    """
    # Fixed: the description previously said "CIFAR-10", a copy-paste
    # leftover from the CIFAR example; this script analyzes GLUE/BERT.
    parser = argparse.ArgumentParser(description="Influence analysis on GLUE datasets.")

    parser.add_argument(
        "--dataset_name",
        type=str,
        default="sst2",
        help="A name of GLUE dataset.",
    )

    parser.add_argument(
        "--query_batch_size",
        type=int,
        default=32,
        help="Batch size for computing query gradients.",
    )

    parser.add_argument(
        "--checkpoint_dir",
        type=str,
        default="./checkpoints",
        help="A path to store the final checkpoint.",
    )

    parser.add_argument(
        "--factor_strategy",
        type=str,
        default="ekfac",
        help="Strategy to compute preconditioning factors.",
    )

    args = parser.parse_args(argv)

    if args.checkpoint_dir is not None:
        # Checkpoints are grouped in a per-dataset subdirectory.
        os.makedirs(args.checkpoint_dir, exist_ok=True)
        args.checkpoint_dir = os.path.join(args.checkpoint_dir, args.dataset_name)
        os.makedirs(args.checkpoint_dir, exist_ok=True)

    return args


class TextClassificationTask(Task):
    """Kronfluence task for BERT-style sequence classification on GLUE."""

    def compute_train_loss(
        self,
        batch: BATCH_TYPE,
        model: nn.Module,
        sample: bool = False,
    ) -> torch.Tensor:
        """Return the summed cross-entropy loss over the batch.

        When ``sample`` is True, labels are drawn from the model's own
        predictive distribution instead of using the true labels.
        """
        outputs = model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            token_type_ids=batch["token_type_ids"],
        )
        logits = outputs.logits

        if sample:
            # Draw one label per example from the model's softmax
            # distribution; gradients must not flow through the sampling.
            with torch.no_grad():
                distribution = F.softmax(logits, dim=-1)
                drawn_labels = torch.multinomial(distribution, num_samples=1).flatten()
            return F.cross_entropy(logits, drawn_labels.detach(), reduction="sum")

        return F.cross_entropy(logits, batch["labels"], reduction="sum")

    def compute_measurement(
        self,
        batch: BATCH_TYPE,
        model: nn.Module,
    ) -> torch.Tensor:
        """Return the negative summed margin measurement for the batch.

        Copied from: https://github.com/MadryLab/trak/blob/main/trak/modelout_functions.py.
        """
        outputs = model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            token_type_ids=batch["token_type_ids"],
        )
        logits = outputs.logits
        labels = batch["labels"]

        row_index = torch.arange(logits.shape[0]).to(device=logits.device, non_blocking=False)
        correct_logits = logits[row_index, labels]

        # Mask out the correct class so logsumexp runs over the rest.
        masked_logits = logits.clone()
        masked_logits[row_index, labels] = torch.tensor(
            -torch.inf, device=logits.device, dtype=logits.dtype
        )

        margins = correct_logits - masked_logits.logsumexp(dim=-1)
        return -margins.sum()


def main():
    """Fit influence factors and compute pairwise influence scores for a GLUE task."""
    args = parse_args()
    logging.basicConfig(level=logging.INFO)

    # NOTE(review): assumes "eval_train" is the training split with
    # evaluation-time preprocessing — confirm against pipeline.py.
    # Fixed: the namespace attribute is `dataset_name`, not `data_name`
    # (the parser defines `--dataset_name`); the original raised AttributeError.
    train_dataset = get_glue_dataset(
        data_name=args.dataset_name,
        split="eval_train",
    )
    eval_dataset = get_glue_dataset(
        data_name=args.dataset_name,
        split="valid",
    )

    model = construct_bert()
    # Fixed: removed the `args.corrupt_percentage` branch — this script's
    # parser defines no such option (leftover from the CIFAR example).
    checkpoint_path = os.path.join(args.checkpoint_dir, "model.pth")
    if not os.path.isfile(checkpoint_path):
        raise ValueError(f"No checkpoint found at {checkpoint_path}.")
    model.load_state_dict(torch.load(checkpoint_path))

    task = TextClassificationTask()
    model = prepare_model(model, task)

    # Fixed: analysis name was the copy-pasted "cifar10"; use the GLUE
    # dataset name so results are stored per-task.
    analyzer = Analyzer(
        analysis_name=args.dataset_name,
        model=model,
        task=task,
        cpu=False,
    )

    dataloader_kwargs = DataLoaderKwargs(num_workers=4)
    analyzer.set_dataloader_kwargs(dataloader_kwargs)

    factor_args = FactorArguments(strategy=args.factor_strategy)
    analyzer.fit_all_factors(
        factors_name=args.factor_strategy,
        dataset=train_dataset,
        per_device_batch_size=None,
        factor_args=factor_args,
        overwrite_output_dir=True,
    )
    analyzer.compute_pairwise_scores(
        scores_name="pairwise",
        factors_name=args.factor_strategy,
        query_dataset=eval_dataset,
        # Cap at 2000 query points, but never exceed the dataset size.
        query_indices=list(range(min(len(eval_dataset), 2000))),
        train_dataset=train_dataset,
        per_device_query_batch_size=args.query_batch_size,
        overwrite_output_dir=True,
    )
    scores = analyzer.load_pairwise_scores("pairwise")
    print(scores)


if __name__ == "__main__":
    main()
12 changes: 7 additions & 5 deletions examples/glue/pipeline.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
import os
from typing import List

import torch
import torch.nn as nn
import torchvision
from datasets import load_dataset
from torch.utils.data import Dataset
from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer
Expand Down Expand Up @@ -41,14 +38,12 @@ def get_glue_dataset(
data_name: str,
split: str,
indices: List[int] = None,
dataset_dir: str = "data/",
) -> Dataset:
assert split in ["train", "eval_train", "valid"]

raw_datasets = load_dataset(
path="glue",
name=data_name,
# data_dir=dataset_dir,
)
label_list = raw_datasets["train"].features["label"].names
num_labels = len(label_list)
Expand Down Expand Up @@ -86,3 +81,10 @@ def preprocess_function(examples):
ds = ds.select(indices)

return ds


if __name__ == "__main__":
    # Quick manual check: print Kronfluence's module summary of the BERT
    # model, useful for deciding which modules to track for influence.
    from kronfluence import Analyzer

    model = construct_bert()
    print(Analyzer.get_module_summary(model))
1 change: 1 addition & 0 deletions examples/glue/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
evaluate
15 changes: 5 additions & 10 deletions examples/glue/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,6 @@ def parse_args():
default="sst2",
help="A name of GLUE dataset.",
)
parser.add_argument(
"--dataset_dir",
type=str,
default="./data",
help="A folder to download or load GLUE dataset.",
)

parser.add_argument(
"--train_batch_size",
Expand Down Expand Up @@ -103,6 +97,7 @@ def train(
drop_last=True,
collate_fn=default_data_collator,
)

model = construct_bert().to(DEVICE)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

Expand Down Expand Up @@ -140,7 +135,7 @@ def evaluate_model(model: nn.Module, dataset: data.Dataset, batch_size: int) ->
batch["input_ids"].to(device=DEVICE),
batch["token_type_ids"].to(device=DEVICE),
batch["attention_mask"].to(device=DEVICE),
)
).logits
labels = batch["labels"].to(device=DEVICE)
total_loss += F.cross_entropy(outputs, labels, reduction="sum").detach().item()
predictions = outputs.argmax(dim=-1)
Expand All @@ -160,7 +155,7 @@ def main():
if args.seed is not None:
set_seed(args.seed)

train_dataset = get_glue_dataset(data_name=args.dataset_name, split="train", dataset_dir=args.dataset_dir)
train_dataset = get_glue_dataset(data_name=args.dataset_name, split="train")
model = train(
dataset=train_dataset,
batch_size=args.train_batch_size,
Expand All @@ -169,11 +164,11 @@ def main():
weight_decay=args.weight_decay,
)

eval_train_dataset = get_glue_dataset(data_name=args.dataset_name, split="eval_train", dataset_dir=args.dataset_dir)
eval_train_dataset = get_glue_dataset(data_name=args.dataset_name, split="eval_train")
train_loss, train_acc = evaluate_model(model=model, dataset=eval_train_dataset, batch_size=args.eval_batch_size)
logger.info(f"Train loss: {train_loss}, Train Accuracy: {train_acc}")

eval_dataset = get_glue_dataset(data_name=args.dataset_name, split="valid", dataset_dir=args.dataset_dir)
eval_dataset = get_glue_dataset(data_name=args.dataset_name, split="valid")
eval_loss, eval_acc = evaluate_model(model=model, dataset=eval_dataset, batch_size=args.eval_batch_size)
logger.info(f"Evaluation loss: {eval_loss}, Evaluation Accuracy: {eval_acc}")

Expand Down
Loading

0 comments on commit 0a6cf5b

Please sign in to comment.