Showing 16 changed files with 412 additions and 140 deletions.
**examples/glue/README.md**
# GLUE & BERT Example

This directory contains scripts for fine-tuning BERT on the GLUE benchmark. The pipeline is adapted from the [HuggingFace text-classification example](https://github.com/huggingface/transformers/tree/main/examples/pytorch/text-classification). Begin by installing the necessary packages:

```bash
pip install -r requirements.txt
```
## Training

To fine-tune BERT on a specific GLUE dataset, run the following command (here we use the `SST2` dataset):

```bash
python train.py --dataset_name sst2 \
    --checkpoint_dir ./checkpoints \
    --train_batch_size 32 \
    --eval_batch_size 32 \
    --learning_rate 3e-05 \
    --weight_decay 0.01 \
    --num_train_epochs 3 \
    --seed 1004
```
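Assuming `train.py` mirrors the checkpoint layout used by `analyze.py` below, the fine-tuned weights are saved as `model.pth` under a per-dataset subdirectory (e.g., `./checkpoints/sst2`), which is where the analysis script expects to find them.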
## Computing Pairwise Influence Scores

To obtain pairwise influence scores on a maximum of 2,000 query data points using the `ekfac` factorization strategy, run:

```bash
python analyze.py --query_batch_size 1000 \
    --checkpoint_dir ./checkpoints \
    --factor_strategy ekfac
```

You can also pass `identity`, `diagonal`, or `kfac` as the factor strategy. On an A100 (80GB) GPU, it takes roughly 1.5 minutes to compute the pairwise scores (including the time to fit the EKFAC factors).
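After the script finishes, the saved scores can be reloaded without recomputing anything. Below is a minimal sketch of inspecting them; it assumes an `Analyzer` constructed exactly as in `analyze.py` (shown later in this commit) and that each entry of the returned dictionary is a `[num_queries, num_train]` score tensor:

```python
# Reload the scores saved under `scores_name="pairwise"` in analyze.py.
scores = analyzer.load_pairwise_scores("pairwise")

# Each entry maps a module (or aggregate) name to a score tensor; rank the
# training examples by their influence on the first query point.
for name, score_tensor in scores.items():
    top_scores, top_indices = score_tensor[0].topk(k=5)
    print(name, top_indices.tolist(), top_scores.tolist())
```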
**examples/glue/analyze.py**
import argparse
import logging
import os
from typing import Dict

import torch
import torch.nn.functional as F
from torch import nn

from examples.glue.pipeline import construct_bert, get_glue_dataset
from kronfluence.analyzer import Analyzer, prepare_model
from kronfluence.arguments import FactorArguments
from kronfluence.task import Task
from kronfluence.utils.dataset import DataLoaderKwargs

BATCH_TYPE = Dict[str, torch.Tensor]

def parse_args():
    parser = argparse.ArgumentParser(description="Influence analysis on GLUE datasets.")

    parser.add_argument(
        "--dataset_name",
        type=str,
        default="sst2",
        help="Name of the GLUE dataset.",
    )

    parser.add_argument(
        "--query_batch_size",
        type=int,
        default=32,
        help="Batch size for computing query gradients.",
    )

    parser.add_argument(
        "--checkpoint_dir",
        type=str,
        default="./checkpoints",
        help="Path to the directory containing the fine-tuned checkpoint.",
    )

    parser.add_argument(
        "--factor_strategy",
        type=str,
        default="ekfac",
        help="Strategy to compute preconditioning factors.",
    )

    args = parser.parse_args()

    if args.checkpoint_dir is not None:
        os.makedirs(args.checkpoint_dir, exist_ok=True)
        args.checkpoint_dir = os.path.join(args.checkpoint_dir, args.dataset_name)
        os.makedirs(args.checkpoint_dir, exist_ok=True)

    return args

class TextClassificationTask(Task):
    def compute_train_loss(
        self,
        batch: BATCH_TYPE,
        model: nn.Module,
        sample: bool = False,
    ) -> torch.Tensor:
        logits = model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            token_type_ids=batch["token_type_ids"],
        ).logits

        if not sample:
            return F.cross_entropy(logits, batch["labels"], reduction="sum")
        # When `sample=True`, draw labels from the model's own predictive
        # distribution instead of using the true labels.
        with torch.no_grad():
            probs = torch.nn.functional.softmax(logits, dim=-1)
            sampled_labels = torch.multinomial(
                probs,
                num_samples=1,
            ).flatten()
        return F.cross_entropy(logits, sampled_labels.detach(), reduction="sum")

    def compute_measurement(
        self,
        batch: BATCH_TYPE,
        model: nn.Module,
    ) -> torch.Tensor:
        # Margin-based measurement, copied from:
        # https://github.com/MadryLab/trak/blob/main/trak/modelout_functions.py.
        logits = model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            token_type_ids=batch["token_type_ids"],
        ).logits

        labels = batch["labels"]
        bindex = torch.arange(logits.shape[0]).to(device=logits.device, non_blocking=False)
        logits_correct = logits[bindex, labels]

        # Mask out the correct class and compute the log-sum-exp over the
        # remaining classes; the difference is the classification margin.
        cloned_logits = logits.clone()
        cloned_logits[bindex, labels] = torch.tensor(
            -torch.inf, device=logits.device, dtype=logits.dtype
        )

        margins = logits_correct - cloned_logits.logsumexp(dim=-1)
        return -margins.sum()

def main():
    args = parse_args()
    logging.basicConfig(level=logging.INFO)

    train_dataset = get_glue_dataset(
        data_name=args.dataset_name,
        split="eval_train",
    )
    eval_dataset = get_glue_dataset(
        data_name=args.dataset_name,
        split="valid",
    )

    model = construct_bert()
    checkpoint_path = os.path.join(args.checkpoint_dir, "model.pth")
    if not os.path.isfile(checkpoint_path):
        raise ValueError(f"No checkpoint found at {checkpoint_path}.")
    model.load_state_dict(torch.load(checkpoint_path))

    task = TextClassificationTask()
    model = prepare_model(model, task)

    analyzer = Analyzer(
        analysis_name="glue",
        model=model,
        task=task,
        cpu=False,
    )

    dataloader_kwargs = DataLoaderKwargs(num_workers=4)
    analyzer.set_dataloader_kwargs(dataloader_kwargs)

    # Fit the preconditioning factors (e.g., EKFAC) on the training set.
    factor_args = FactorArguments(strategy=args.factor_strategy)
    analyzer.fit_all_factors(
        factors_name=args.factor_strategy,
        dataset=train_dataset,
        per_device_batch_size=None,
        factor_args=factor_args,
        overwrite_output_dir=True,
    )
    # Compute pairwise influence scores on up to 2,000 query data points.
    analyzer.compute_pairwise_scores(
        scores_name="pairwise",
        factors_name=args.factor_strategy,
        query_dataset=eval_dataset,
        query_indices=list(range(min(len(eval_dataset), 2000))),
        train_dataset=train_dataset,
        per_device_query_batch_size=args.query_batch_size,
        overwrite_output_dir=True,
    )
    scores = analyzer.load_pairwise_scores("pairwise")
    print(scores)


if __name__ == "__main__":
    main()
**examples/glue/requirements.txt**
evaluate