Update RTE examples

pomonam committed Jun 26, 2024
1 parent bd19cf0 commit 674869d
Showing 6 changed files with 400 additions and 33 deletions.
103 changes: 83 additions & 20 deletions examples/glue/README.md
@@ -1,14 +1,16 @@
# GLUE & BERT Example

This directory contains scripts for fine-tuning BERT on GLUE benchmark. The pipeline is motivated from [HuggingFace Example](https://github.com/huggingface/transformers/tree/main/examples/pytorch/text-classification).
Please begin by installing necessary packages.
This directory contains scripts for fine-tuning BERT and computing influence scores on the GLUE benchmark. The pipeline is adapted from [this HuggingFace example](https://github.com/huggingface/transformers/tree/main/examples/pytorch/text-classification).
To get started, please install the necessary packages:

```bash
pip install -r requirements.txt
```

## Training

To fine-tune BERT on some specific dataset, run the following command (we are using `SST2` dataset):
To fine-tune BERT on a specific dataset, run the following command (we are using the `SST2` dataset in this example):

```bash
python train.py --dataset_name sst2 \
--checkpoint_dir ./checkpoints \
@@ -22,39 +24,100 @@ python train.py --dataset_name sst2 \

## Computing Pairwise Influence Scores

To obtain a pairwise influence scores on maximum of 2000 query data points using `ekfac`, run the following command:
To obtain pairwise influence scores on a maximum of 2000 query data points using `ekfac`, run the following command:

```bash
python analyze.py --dataset_name sst2 \
--query_batch_size 175 \
--train_batch_size 128 \
--checkpoint_dir ./checkpoints \
--factor_strategy ekfac
```
On A100 (80GB), it takes roughly 80 minutes to compute the pairwise scores for SST2 with around 900 query data points
(including computing EKFAC factors).

We can also use query batching (low-rank approximation to the query gradient; see Section 3.2.2 from the [paper](https://arxiv.org/pdf/2308.03296.pdf)) to compute influence scores with a
larger query batch size.
On an A100 (80GB), it takes roughly 95 minutes to compute the pairwise scores for SST2 with around 900 query data points (including computing EKFAC factors):

```
```
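
Once the scores are computed, they can be loaded and inspected with the `Analyzer`. Below is a minimal sketch; the output path follows this example's defaults (analysis name `sst2`, scores name `ekfac`) and may need adjusting for your run:

```python
import torch

from kronfluence.analyzer import Analyzer

# Assumed output location based on this example's defaults; adjust the path
# if you used a different dataset name or factor strategy.
scores = Analyzer.load_file("influence_results/sst2/scores_ekfac/pairwise_scores.safetensors")[
    "all_modules"
].to(dtype=torch.float32)

# Rows index query (validation) examples, columns index training examples.
print(scores.shape)

# The ten most positively influential training examples for the first query.
print(torch.argsort(scores[0], descending=True)[:10])
```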

For more efficient computation, use half precision:

```bash
python analyze.py --dataset_name sst2 \
--query_gradient_rank 32 \
--query_batch_size 436 \
--train_batch_size 256 \
--query_batch_size 175 \
--train_batch_size 128 \
--checkpoint_dir ./checkpoints \
--factor_strategy ekfac
--factor_strategy ekfac \
--use_half_precision
```
Note that query batching is slower in this case (140 minutes in total), as the number of training data points is small and the cost of performing SVD dominates the overall cost.
Assuming that you ran above two commands, `query_batching_analysis.py` contains code to compute the correlations between the full rank and low-rank scores.

<p align="center">
<a href="#"><img width="380" img src="figure/query_batching.png" alt="Counterfactual"/></a>
</p>
The averaged correlations between the low-rank and full rank scores for 100 data points is 0.98.
This reduces computation time to about 20 minutes on an A100 (80GB) GPU:

```
```
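
Passing `--use_half_precision` simply swaps in the low-precision argument helpers shown in the `analyze.py` diff below; a sketch of the roughly equivalent explicit configuration:

```python
import torch

from kronfluence.utils.common.factor_arguments import all_low_precision_factor_arguments
from kronfluence.utils.common.score_arguments import all_low_precision_score_arguments

# Compute EKFAC factors and pairwise scores in bfloat16 instead of float32.
factor_args = all_low_precision_factor_arguments(strategy="ekfac", dtype=torch.bfloat16)
score_args = all_low_precision_score_arguments(dtype=torch.bfloat16)
```

The resulting factors and scores are stored under names with a `_half` suffix (e.g., `ekfac_half`), which is why the half-precision run writes to a different output directory than the full-precision run.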

## Counterfactual Evaluation

We plan to add a simple demo for counterfactual evaluation on the RTE dataset soon.
Can we make some queries misclassified by removing their most positively influential training examples? Subset removal counterfactual evaluation
selects a correctly classified query data point, removes the top-k positively influential training samples, and retrains the network on the modified dataset to see whether that query
data point becomes misclassified.

We first fine-tune BERT on the `RTE` dataset and compute pairwise influence scores (with both the `ekfac` and `identity` strategies):

```bash
python train.py --dataset_name rte \
--checkpoint_dir ./checkpoints \
--train_batch_size 32 \
--eval_batch_size 32 \
--learning_rate 2e-05 \
--weight_decay 0.01 \
--num_train_epochs 3 \
--seed 1004

python analyze.py --dataset_name rte \
--query_batch_size 175 \
--train_batch_size 128 \
--checkpoint_dir ./checkpoints \
--factor_strategy ekfac

python analyze.py --dataset_name rte \
--query_batch_size 175 \
--train_batch_size 128 \
--checkpoint_dir ./checkpoints \
--factor_strategy identity
```

`run_counterfactual.py` contains the script to run the counterfactual experiment.
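
At a high level, the subset-removal loop looks roughly like the sketch below; `train_and_evaluate` is a hypothetical helper standing in for the retraining logic, the score path is assumed from this example's defaults, and the removal count is illustrative:

```python
import torch

from kronfluence.analyzer import Analyzer

# Assumed output path from the analyze.py run above; adjust if needed.
scores = Analyzer.load_file("influence_results/rte/scores_ekfac/pairwise_scores.safetensors")["all_modules"]

query_idx = 0        # a correctly classified validation example
num_to_remove = 50   # illustrative top-k

# Rank training examples by how positively they influence this query.
ranked = torch.argsort(scores[query_idx], descending=True)
removed = set(ranked[:num_to_remove].tolist())
kept_indices = [i for i in range(scores.shape[1]) if i not in removed]

# Hypothetical helper: retrain BERT on the reduced dataset and report whether
# the query is still classified correctly.
# still_correct = train_and_evaluate(train_indices=kept_indices, query_index=query_idx)
```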

<p align="center">
<a href="#"><img width="380" img src="figure/counterfactual.png" alt="Counterfactual"/></a>
</p>
</p>

## Evaluating Linear Datamodeling Score

The `evaluate_lds.py` script computes the [linear datamodeling score (LDS)](https://arxiv.org/abs/2303.14186). The LDS is measured against 500 retrained networks,
each trained on a different subset of the dataset (5 repeats and 100 masks). Running `evaluate_lds.py`, we obtain `xx` LDS (and `xx` LDS with half precision).
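
The core computation is a per-query Spearman correlation between the margins of the retrained models and the influence-based predictions. A condensed sketch of what `evaluate_lds.py` does (the score path and the `files/` layout are assumptions based on this example):

```python
import numpy as np
import torch
from scipy.stats import spearmanr

from kronfluence.analyzer import Analyzer

# Pairwise scores from analyze.py (path assumed; adjust to your run).
scores = Analyzer.load_file("influence_results/rte/scores_ekfac/pairwise_scores.safetensors")[
    "all_modules"
].to(dtype=torch.float32)

# Precomputed margins (num_models x num_queries) and masks (num_models x num_train).
margins = torch.from_numpy(torch.load(open("files/margins.pt", "rb")))
masks = torch.from_numpy(torch.load(open("files/masks.pt", "rb"))).float()

# Predicted margin under each mask: negated sum of influences of the retained examples.
preds = -masks @ scores.T

# LDS = average Spearman correlation between predicted and actual margins across queries.
lds = np.mean([spearmanr(preds[:, j], margins[:, j])[0] for j in range(margins.shape[1])])
print(f"LDS: {lds:.3f}")
```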

The script also includes functionality to print out the top influential sequences for a given query.

```
Query Example:
= Homarus gammarus =
Homarus gammarus, known as the European lobster or common lobster, is a species of clawed lobster from the eastern Atlantic Ocean, Mediterranean Sea and parts of the Black Sea. It is closely related to the American lobster, H. americanus. It may grow to a length of 60 cm ( 24 in ) and a mass of 6 kilograms ( 13 lb ), and bears a conspicuous pair of claws. In life, the lobsters are blue, only becoming " lobster red " on cooking. Mating occurs in the summer, producing eggs which are carried by the females for up to a year before hatching into planktonic larvae. Homarus gammarus is a highly esteemed food, and is widely caught using lobster pots, mostly around the British Isles.
= = Description = =
Homarus gammarus is a large crustacean, with a body length up to 60 centimetres ( 24 in ) and weighing up to 5 – 6 kilograms ( 11 – 13 lb ), although the lobsters caught in lobster pots are usually 23 – 38 cm ( 9 – 15 in ) long and weigh 0 @.@ 7 – 2 @.@ 2 kg ( 1 @.@ 5 – 4 @.@ 9 lb ). Like other crustaceans, lobsters have a hard exoskeleton which they must shed in order to grow, in a process called ecdysis ( moulting ). This may occur several times a year for young lobsters, but decreases to once every 1 – 2 years for larger animals.
The first pair of pereiopods is armed with a large, asymmetrical pair of claws. The larger one is the " crusher ", and has rounded nodules used for crushing prey ; the other is the " cutter ", which has sharp inner edges, and is used for holding or tearing the prey. Usually, the left claw is the crusher, and the right is the cutter.
The exoskeleton is generally blue above, with spots that coalesce, and yellow below. The red colour associated with lobsters only appears after cooking. This occurs because, in life, the red pigment astaxanthin is bound to a protein complex, but the complex is broken up by the heat of cooking, releasing the red pigment.
The closest relative of H. gammarus is the American lobster, Homarus americanus. The two species are very similar, and can be crossed artificially
Top Influential Example:
Sector Headquarters, Port Moresby
= Cape lobster =
The Cape lobster, Homarinus capensis, is a species of small lobster that lives off the coast of South Africa, from Dassen Island to Haga Haga. Only a few dozen specimens are known, mostly regurgitated by reef @-@ dwelling fish. It lives in rocky reefs, and is thought to lay large eggs that have a short larval phase, or that hatch directly as a juvenile. The species grows to a total length of 10 cm ( 3 @.@ 9 in ), and resembles a small European or American lobster ; it was previously included in the same genus, Homarus, although it is not very closely related to those species, and is now considered to form a separate, monotypic genus – Homarinus. Its closest relatives are the genera Thymops and Thymopides.
= = Distribution and ecology = =
The Cape lobster is endemic to South Africa. It occurs from Dassen Island, Western Cape in the west to Haga Haga, Eastern Cape in the east, a range of 900 kilometres ( 560 mi ). Most of the known specimens were regurgitated by fish caught on reefs at depths of 20 – 40 metres ( 66 – 131 ft ). This suggests that the Cape lobster inhabits rocky substrates, and may explain its apparent rarity, since such areas are not amenable to dredging or trawling, and the species may be too small to be retained by lobster traps.
= = Description = =
Homarinus capensis is considerably smaller than the large northern lobsters of the Atlantic Ocean, Homarus gammarus ( the European lobster ) and Homarus americanus ( the American lobster ), at 8 – 10 centimetres ( 3 @.@ 1 – 3 @.@ 9 in ) total length, or 4 – 5 cm ( 1 @.@ 6 – 2 @.@ 0 in ) carapace length. Accounts of the colouration of H. capensis are very variable, from tawny, red or yellow to " a rather dark olive ", similar to Homarus gammarus.
Homarinus and Homarus are considered to be the most plesiomorphic genera in the family Nephropidae. Nonetheless, the Cape lobster differs from Homarus in a number of characters. The rostrum of the Cape lobster is flattened, while that of Homarus is rounded in section
```
45 changes: 34 additions & 11 deletions examples/glue/analyze.py
@@ -12,6 +12,8 @@
from kronfluence.analyzer import Analyzer, prepare_model
from kronfluence.arguments import FactorArguments, ScoreArguments
from kronfluence.task import Task
from kronfluence.utils.common.factor_arguments import all_low_precision_factor_arguments
from kronfluence.utils.common.score_arguments import all_low_precision_score_arguments
from kronfluence.utils.dataset import DataLoaderKwargs

BATCH_TYPE = Dict[str, torch.Tensor]
@@ -33,12 +35,24 @@ def parse_args():
help="A path that is storing the final checkpoint of the model.",
)

parser.add_argument(
"--factor_strategy",
type=str,
default="ekfac",
help="Strategy to compute influence factors.",
)
parser.add_argument(
"--query_gradient_rank",
type=int,
default=-1,
help="Rank for the low-rank query gradient approximation.",
)
parser.add_argument(
"--use_half_precision",
action="store_true",
default=False,
help="Whether to use half precision for computing factors and scores.",
)
parser.add_argument(
"--query_batch_size",
type=int,
@@ -52,12 +66,11 @@
help="Batch size for computing training gradients.",
)
parser.add_argument(
"--factor_strategy",
type=str,
default="ekfac",
help="Strategy to compute influence factors.",
"--profile",
action="store_true",
default=False,
help="Boolean flag to profile computations.",
)

args = parser.parse_args()

if args.checkpoint_dir is not None:
@@ -146,38 +159,48 @@ def main():
analysis_name=args.dataset_name,
model=model,
task=task,
cpu=False,
profile=args.profile,
)
# Configure parameters for DataLoader.
dataloader_kwargs = DataLoaderKwargs(collate_fn=default_data_collator)
analyzer.set_dataloader_kwargs(dataloader_kwargs)

# Compute influence factors.
factors_name = args.factor_strategy
factor_args = FactorArguments(strategy=args.factor_strategy)
if args.use_half_precision:
factor_args = all_low_precision_factor_arguments(strategy=args.factor_strategy, dtype=torch.bfloat16)
factors_name += "_half"
analyzer.fit_all_factors(
factors_name=args.factor_strategy,
factors_name=factors_name,
dataset=train_dataset,
per_device_batch_size=None,
factor_args=factor_args,
overwrite_output_dir=True,
initial_per_device_batch_size_attempt=512,
)

# Compute pairwise scores.
score_args = ScoreArguments()
scores_name = factor_args.strategy
if args.use_half_precision:
score_args = all_low_precision_score_arguments(dtype=torch.bfloat16)
scores_name += "_half"
rank = args.query_gradient_rank if args.query_gradient_rank != -1 else None
score_args = ScoreArguments(query_gradient_rank=rank, query_gradient_svd_dtype=torch.float32)
scores_name = args.factor_strategy
if rank is not None:
score_args.query_gradient_rank = rank
score_args.num_query_gradient_accumulations = 10
scores_name += f"_qlr{rank}"
analyzer.compute_pairwise_scores(
score_args=score_args,
scores_name=scores_name,
factors_name=args.factor_strategy,
factors_name=factors_name,
query_dataset=eval_dataset,
query_indices=list(range(min([len(eval_dataset), 2000]))),
train_dataset=train_dataset,
per_device_query_batch_size=args.query_batch_size,
per_device_train_batch_size=args.train_batch_size,
overwrite_output_dir=True,
overwrite_output_dir=False,
)
scores = analyzer.load_pairwise_scores(scores_name)["all_modules"]
logging.info(f"Scores shape: {scores.shape}")
68 changes: 68 additions & 0 deletions examples/glue/evaluate_lds.py
@@ -0,0 +1,68 @@
import logging

import numpy as np
import torch
import tqdm
from scipy.stats import spearmanr
from transformers import AutoTokenizer

from examples.wikitext.pipeline import get_wikitext_dataset
from kronfluence.analyzer import Analyzer


def evaluate_correlations(scores: torch.Tensor) -> float:
margins = torch.from_numpy(torch.load(open("files/margins.pt", "rb")))
masks = torch.from_numpy(torch.load(open("files/masks.pt", "rb"))).float()

val_indices = np.arange(481)
preds = -masks @ scores.T

rs = []
ps = []
for j in tqdm.tqdm(val_indices):
r, p = spearmanr(preds[:, j], margins[:, j])
rs.append(r)
ps.append(p)
rs, ps = np.array(rs), np.array(ps)
return rs.mean()


def main():
logging.basicConfig(level=logging.INFO)

margins = torch.from_numpy(torch.load(open(f"files/margins.pt", "rb")))
masks = torch.from_numpy(torch.load(open(f"files/masks.pt", "rb"))).float()

# You might need to change the path.
scores = Analyzer.load_file("influence_results/wikitext/scores_ekfac/pairwise_scores.safetensors")[
"all_modules"
].to(dtype=torch.float32)
# scores = Analyzer.load_file("influence_results/wikitext/scores_ekfac_half/pairwise_scores.safetensors")[
# "all_modules"
# ].to(dtype=torch.float32)
# scores = Analyzer.load_file("influence_results/wikitext/scores_ekfac_half_compile/pairwise_scores.safetensors")[
# "all_modules"
# ].to(dtype=torch.float32)

corr_mean = evaluate_correlations(scores)
logging.info(f"LDS: {np.mean(corr_mean)}")

# We can also visualize the top influential sequences.
eval_idx = 0
train_dataset = get_wikitext_dataset(
split="eval_train",
)
eval_dataset = get_wikitext_dataset(
split="valid",
)
tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=True, trust_remote_code=True)
print("Query Data Example:")
print(tokenizer.decode(eval_dataset[eval_idx]["input_ids"]))

top_idx = int(torch.argsort(scores[eval_idx], descending=True)[0])
print("Top Influential Example:")
print(tokenizer.decode(train_dataset[top_idx]["input_ids"]))


if __name__ == "__main__":
main()
2 changes: 2 additions & 0 deletions examples/glue/pipeline.py
@@ -74,6 +74,8 @@ def preprocess_function(examples):
if split in ["train", "eval_train"]:
train_dataset = raw_datasets["train"]
ds = train_dataset
if data_name == "rte":
ds = ds.select(range(2432))
else:
eval_dataset = raw_datasets["validation"]
ds = eval_dataset