From 674869dc2346a7173208507c380b8bbd886e8135 Mon Sep 17 00:00:00 2001 From: Juhan Bae Date: Wed, 26 Jun 2024 02:46:47 -0400 Subject: [PATCH] Update RTE examples --- examples/glue/README.md | 103 +++++++++++--- examples/glue/analyze.py | 45 ++++-- examples/glue/evaluate_lds.py | 68 +++++++++ examples/glue/pipeline.py | 2 + examples/glue/run_counterfactual.py | 212 ++++++++++++++++++++++++++++ examples/glue/train.py | 3 +- 6 files changed, 400 insertions(+), 33 deletions(-) diff --git a/examples/glue/README.md b/examples/glue/README.md index 12398a1..9b4671a 100644 --- a/examples/glue/README.md +++ b/examples/glue/README.md @@ -1,14 +1,16 @@ # GLUE & BERT Example -This directory contains scripts for fine-tuning BERT on GLUE benchmark. The pipeline is motivated from [HuggingFace Example](https://github.com/huggingface/transformers/tree/main/examples/pytorch/text-classification). -Please begin by installing necessary packages. +This directory contains scripts for fine-tuning BERT and computing influence scores on GLUE benchmark. The pipeline is motivated from [this HuggingFace Example](https://github.com/huggingface/transformers/tree/main/examples/pytorch/text-classification). +To get started, please install the necessary packages: + ```bash pip install -r requirements.txt ``` ## Training -To fine-tune BERT on some specific dataset, run the following command (we are using `SST2` dataset): +To fine-tune BERT on a specific dataset, run the following command (we are using the `SST2` dataset in this example): + ```bash python train.py --dataset_name sst2 \ --checkpoint_dir ./checkpoints \ @@ -22,7 +24,8 @@ python train.py --dataset_name sst2 \ ## Computing Pairwise Influence Scores -To obtain a pairwise influence scores on maximum of 2000 query data points using `ekfac`, run the following command: +To obtain pairwise influence scores on a maximum of 2000 query data points using `ekfac`, run the following command: + ```bash python analyze.py --dataset_name sst2 \ --query_batch_size 175 \ @@ -30,31 +33,91 @@ python analyze.py --dataset_name sst2 \ --checkpoint_dir ./checkpoints \ --factor_strategy ekfac ``` -On A100 (80GB), it takes roughly 80 minutes to compute the pairwise scores for SST2 with around 900 query data points -(including computing EKFAC factors). -We can also use query batching (low-rank approximation to the query gradient; see Section 3.2.2 from the [paper](https://arxiv.org/pdf/2308.03296.pdf)) to compute influence scores with a -larger query batch size. +On an A100 (80GB), it takes roughly 95 minutes to compute the pairwise scores for SST2 with around 900 query data points (including computing EKFAC factors): + +``` + +``` + +For more efficient computation, use half precision: + ```bash python analyze.py --dataset_name sst2 \ - --query_gradient_rank 32 \ - --query_batch_size 436 \ - --train_batch_size 256 \ + --query_batch_size 175 \ + --train_batch_size 128 \ --checkpoint_dir ./checkpoints \ - --factor_strategy ekfac + --factor_strategy ekfac \ + --use_half_precision ``` -Note that query batching is slower in this case (140 minutes in total), as the number of training data points is small and the cost of performing SVD dominates the overall cost. -Assuming that you ran above two commands, `query_batching_analysis.py` contains code to compute the correlations between the full rank and low-rank scores. -

-Counterfactual -

-The averaged correlations between the low-rank and full rank scores for 100 data points is 0.98. +This reduces computation time to about 20 minutes on an A100 (80GB) GPU: + +``` + +``` ## Counterfactual Evaluation -We plan to add a simple demo for counterfactual evaluation on the RTE dataset soon. +Can we remove top positively influential training examples to make some queries misclassify? Subset removal counterfactual evaluation +selects correctly classified query data point, removes top-k positively influential training samples, and retrain the network with the modified dataset to see if that query +data point gets misclassified. + +We first need to compute pairwise influence scores for the `RTE` dataset: + +```bash +python train.py --dataset_name rte \ + --checkpoint_dir ./checkpoints \ + --train_batch_size 32 \ + --eval_batch_size 32 \ + --learning_rate 2e-05 \ + --weight_decay 0.01 \ + --num_train_epochs 3 \ + --seed 1004 + +python analyze.py --dataset_name rte \ + --query_batch_size 175 \ + --train_batch_size 128 \ + --checkpoint_dir ./checkpoints \ + --factor_strategy ekfac + +python analyze.py --dataset_name rte \ + --query_batch_size 175 \ + --train_batch_size 128 \ + --checkpoint_dir ./checkpoints \ + --factor_strategy identity +``` + +`run_counterfactual.py` contains the script to run the counterfactual experiment.

Counterfactual -

\ No newline at end of file +

+ +## Evaluating Linear Datamodeling Score + +The `evaluate_lds.py` script computes the [linear datamodeling score (LDS)](https://arxiv.org/abs/2303.14186). It measures the LDS obtained by +retraining the network 500 times with different subsets of the dataset (5 repeats and 100 masks). By running `evaludate_lds.py`, we obtain `xx` LDS (we get `xx` LDS with the half precision). + +The script also includes functionality to print out top influential sequences for a given query. + +``` +Query Example: + = Homarus gammarus = + Homarus gammarus, known as the European lobster or common lobster, is a species of clawed lobster from the eastern Atlantic Ocean, Mediterranean Sea and parts of the Black Sea. It is closely related to the American lobster, H. americanus. It may grow to a length of 60 cm ( 24 in ) and a mass of 6 kilograms ( 13 lb ), and bears a conspicuous pair of claws. In life, the lobsters are blue, only becoming " lobster red " on cooking. Mating occurs in the summer, producing eggs which are carried by the females for up to a year before hatching into planktonic larvae. Homarus gammarus is a highly esteemed food, and is widely caught using lobster pots, mostly around the British Isles. + = = Description = = + Homarus gammarus is a large crustacean, with a body length up to 60 centimetres ( 24 in ) and weighing up to 5 – 6 kilograms ( 11 – 13 lb ), although the lobsters caught in lobster pots are usually 23 – 38 cm ( 9 – 15 in ) long and weigh 0 @.@ 7 – 2 @.@ 2 kg ( 1 @.@ 5 – 4 @.@ 9 lb ). Like other crustaceans, lobsters have a hard exoskeleton which they must shed in order to grow, in a process called ecdysis ( moulting ). This may occur several times a year for young lobsters, but decreases to once every 1 – 2 years for larger animals. + The first pair of pereiopods is armed with a large, asymmetrical pair of claws. The larger one is the " crusher ", and has rounded nodules used for crushing prey ; the other is the " cutter ", which has sharp inner edges, and is used for holding or tearing the prey. Usually, the left claw is the crusher, and the right is the cutter. + The exoskeleton is generally blue above, with spots that coalesce, and yellow below. The red colour associated with lobsters only appears after cooking. This occurs because, in life, the red pigment astaxanthin is bound to a protein complex, but the complex is broken up by the heat of cooking, releasing the red pigment. + The closest relative of H. gammarus is the American lobster, Homarus americanus. The two species are very similar, and can be crossed artificially + +Top Influential Example: + Sector Headquarters, Port Moresby + = Cape lobster = + The Cape lobster, Homarinus capensis, is a species of small lobster that lives off the coast of South Africa, from Dassen Island to Haga Haga. Only a few dozen specimens are known, mostly regurgitated by reef @-@ dwelling fish. It lives in rocky reefs, and is thought to lay large eggs that have a short larval phase, or that hatch directly as a juvenile. The species grows to a total length of 10 cm ( 3 @.@ 9 in ), and resembles a small European or American lobster ; it was previously included in the same genus, Homarus, although it is not very closely related to those species, and is now considered to form a separate, monotypic genus – Homarinus. Its closest relatives are the genera Thymops and Thymopides. + = = Distribution and ecology = = + The Cape lobster is endemic to South Africa. It occurs from Dassen Island, Western Cape in the west to Haga Haga, Eastern Cape in the east, a range of 900 kilometres ( 560 mi ). Most of the known specimens were regurgitated by fish caught on reefs at depths of 20 – 40 metres ( 66 – 131 ft ). This suggests that the Cape lobster inhabits rocky substrates, and may explain its apparent rarity, since such areas are not amenable to dredging or trawling, and the species may be too small to be retained by lobster traps. + = = Description = = + Homarinus capensis is considerably smaller than the large northern lobsters of the Atlantic Ocean, Homarus gammarus ( the European lobster ) and Homarus americanus ( the American lobster ), at 8 – 10 centimetres ( 3 @.@ 1 – 3 @.@ 9 in ) total length, or 4 – 5 cm ( 1 @.@ 6 – 2 @.@ 0 in ) carapace length. Accounts of the colouration of H. capensis are very variable, from tawny, red or yellow to " a rather dark olive ", similar to Homarus gammarus. + Homarinus and Homarus are considered to be the most plesiomorphic genera in the family Nephropidae. Nonetheless, the Cape lobster differs from Homarus in a number of characters. The rostrum of the Cape lobster is flattened, while that of Homarus is rounded in section +``` \ No newline at end of file diff --git a/examples/glue/analyze.py b/examples/glue/analyze.py index 022dfd9..df610f5 100644 --- a/examples/glue/analyze.py +++ b/examples/glue/analyze.py @@ -12,6 +12,8 @@ from kronfluence.analyzer import Analyzer, prepare_model from kronfluence.arguments import FactorArguments, ScoreArguments from kronfluence.task import Task +from kronfluence.utils.common.factor_arguments import all_low_precision_factor_arguments +from kronfluence.utils.common.score_arguments import all_low_precision_score_arguments from kronfluence.utils.dataset import DataLoaderKwargs BATCH_TYPE = Dict[str, torch.Tensor] @@ -33,12 +35,24 @@ def parse_args(): help="A path that is storing the final checkpoint of the model.", ) + parser.add_argument( + "--factor_strategy", + type=str, + default="ekfac", + help="Strategy to compute influence factors.", + ) parser.add_argument( "--query_gradient_rank", type=int, default=-1, help="Rank for the low-rank query gradient approximation.", ) + parser.add_argument( + "--use_half_precision", + action="store_true", + default=False, + help="Whether to use half precision for computing factors and scores.", + ) parser.add_argument( "--query_batch_size", type=int, @@ -52,12 +66,11 @@ def parse_args(): help="Batch size for computing training gradients.", ) parser.add_argument( - "--factor_strategy", - type=str, - default="ekfac", - help="Strategy to compute influence factors.", + "--profile", + action="store_true", + default=False, + help="Boolean flag to profile computations.", ) - args = parser.parse_args() if args.checkpoint_dir is not None: @@ -146,38 +159,48 @@ def main(): analysis_name=args.dataset_name, model=model, task=task, - cpu=False, + profile=args.profile, ) # Configure parameters for DataLoader. dataloader_kwargs = DataLoaderKwargs(collate_fn=default_data_collator) analyzer.set_dataloader_kwargs(dataloader_kwargs) # Compute influence factors. + factors_name = args.factor_strategy factor_args = FactorArguments(strategy=args.factor_strategy) + if args.use_half_precision: + factor_args = all_low_precision_factor_arguments(strategy=args.factor_strategy, dtype=torch.bfloat16) + factors_name += "_half" analyzer.fit_all_factors( - factors_name=args.factor_strategy, + factors_name=factors_name, dataset=train_dataset, per_device_batch_size=None, factor_args=factor_args, overwrite_output_dir=True, initial_per_device_batch_size_attempt=512, ) + # Compute pairwise scores. + score_args = ScoreArguments() + scores_name = factor_args.strategy + if args.use_half_precision: + score_args = all_low_precision_score_arguments(dtype=torch.bfloat16) + scores_name += "_half" rank = args.query_gradient_rank if args.query_gradient_rank != -1 else None - score_args = ScoreArguments(query_gradient_rank=rank, query_gradient_svd_dtype=torch.float32) - scores_name = args.factor_strategy if rank is not None: + score_args.query_gradient_rank = rank + score_args.num_query_gradient_accumulations = 10 scores_name += f"_qlr{rank}" analyzer.compute_pairwise_scores( score_args=score_args, scores_name=scores_name, - factors_name=args.factor_strategy, + factors_name=factors_name, query_dataset=eval_dataset, query_indices=list(range(min([len(eval_dataset), 2000]))), train_dataset=train_dataset, per_device_query_batch_size=args.query_batch_size, per_device_train_batch_size=args.train_batch_size, - overwrite_output_dir=True, + overwrite_output_dir=False, ) scores = analyzer.load_pairwise_scores(scores_name)["all_modules"] logging.info(f"Scores shape: {scores.shape}") diff --git a/examples/glue/evaluate_lds.py b/examples/glue/evaluate_lds.py index e69de29..3f86722 100644 --- a/examples/glue/evaluate_lds.py +++ b/examples/glue/evaluate_lds.py @@ -0,0 +1,68 @@ +import logging + +import numpy as np +import torch +import tqdm +from scipy.stats import spearmanr +from transformers import AutoTokenizer + +from examples.wikitext.pipeline import get_wikitext_dataset +from kronfluence.analyzer import Analyzer + + +def evaluate_correlations(data_name: str, scores: torch.Tensor) -> float: + margins = torch.from_numpy(torch.load(open(f"files/{data_name}/margins.pt", "rb"))) + masks = torch.from_numpy(torch.load(open(f"files/{data_name}/masks.pt", "rb"))).float() + + val_indices = np.arange(481) + preds = -masks @ scores.T + + rs = [] + ps = [] + for j in tqdm.tqdm(val_indices): + r, p = spearmanr(preds[:, j], margins[:, j]) + rs.append(r) + ps.append(p) + rs, ps = np.array(rs), np.array(ps) + return rs.mean() + + +def main(): + logging.basicConfig(level=logging.INFO) + + margins = torch.from_numpy(torch.load(open(f"files/margins.pt", "rb"))) + masks = torch.from_numpy(torch.load(open(f"files/masks.pt", "rb"))).float() + + # You might need to change the path. + scores = Analyzer.load_file("influence_results/wikitext/scores_ekfac/pairwise_scores.safetensors")[ + "all_modules" + ].to(dtype=torch.float32) + # scores = Analyzer.load_file("influence_results/wikitext/scores_ekfac_half/pairwise_scores.safetensors")[ + # "all_modules" + # ].to(dtype=torch.float32) + # scores = Analyzer.load_file("influence_results/wikitext/scores_ekfac_half_compile/pairwise_scores.safetensors")[ + # "all_modules" + # ].to(dtype=torch.float32) + + corr_mean = evaluate_correlations(scores) + logging.info(f"LDS: {np.mean(corr_mean)}") + + # We can also visualize the top influential sequences. + eval_idx = 0 + train_dataset = get_wikitext_dataset( + split="eval_train", + ) + eval_dataset = get_wikitext_dataset( + split="valid", + ) + tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=True, trust_remote_code=True) + print("Query Data Example:") + print(tokenizer.decode(eval_dataset[eval_idx]["input_ids"])) + + top_idx = int(torch.argsort(scores[eval_idx], descending=True)[0]) + print("Top Influential Example:") + print(tokenizer.decode(train_dataset[top_idx]["input_ids"])) + + +if __name__ == "__main__": + main() diff --git a/examples/glue/pipeline.py b/examples/glue/pipeline.py index 3db7e55..1506a09 100644 --- a/examples/glue/pipeline.py +++ b/examples/glue/pipeline.py @@ -74,6 +74,8 @@ def preprocess_function(examples): if split in ["train", "eval_train"]: train_dataset = raw_datasets["train"] ds = train_dataset + if data_name == "rte": + ds = ds.select(range(2432)) else: eval_dataset = raw_datasets["validation"] ds = eval_dataset diff --git a/examples/glue/run_counterfactual.py b/examples/glue/run_counterfactual.py index e69de29..30a293d 100644 --- a/examples/glue/run_counterfactual.py +++ b/examples/glue/run_counterfactual.py @@ -0,0 +1,212 @@ +import time +from typing import List, Optional + +import numpy as np +import torch +import torch.nn as nn +from accelerate.utils import set_seed +from torch.utils import data +from transformers import default_data_collator + +from examples.glue.pipeline import get_glue_dataset +from examples.glue.train import train +from kronfluence import Analyzer + +DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +def get_accuracy(model: nn.Module, dataset: data.Dataset) -> torch.Tensor: + dataloader = data.DataLoader( + dataset=dataset, batch_size=32, shuffle=False, drop_last=False, collate_fn=default_data_collator + ) + + model.eval() + with torch.no_grad(): + acc_lst = [] + for batch in dataloader: + outputs = model( + input_ids=batch["input_ids"].to(device=DEVICE), + token_type_ids=batch["token_type_ids"].to(device=DEVICE), + attention_mask=batch["attention_mask"].to(device=DEVICE), + ).logits + labels = batch["labels"].to(device=DEVICE) + accs = (outputs.argmax(-1) == labels).float().cpu() + acc_lst.append(accs) + all_accs = torch.cat(acc_lst) + return all_accs + + +def train_with_indices(dataset: data.Dataset, seed: int, indices_to_keep: Optional[List[int]] = None) -> nn.Module: + if indices_to_keep is not None: + dataset = dataset.select(indices_to_keep) + + set_seed(seed) + model = train(dataset=dataset, batch_size=16, num_train_epochs=3, learning_rate=2e-05, weight_decay=0.01) + return model + + +def train_with_configurations( + dataset: data.Dataset, + valid_dataset: data.Dataset, + top_indices: List[int], + interval: int, + seed_ids: List[int], +) -> List[torch.Tensor]: + num_train = len(dataset) + indices_to_remove = top_indices[:interval] + indices_to_keep = list(set(range(num_train)) - set(indices_to_remove)) + assert len(indices_to_keep) + len(indices_to_remove) == num_train + + valid_acc_lst = [] + for seed in seed_ids: + model = train_with_indices(dataset=dataset, indices_to_keep=indices_to_keep, seed=seed + 2008) + valid_results = get_accuracy(model, valid_dataset) + valid_acc_lst.append(valid_results) + return valid_acc_lst + + +def main(): + train_dataset = get_glue_dataset( + data_name="rte", + split="eval_train", + ) + eval_dataset = get_glue_dataset( + data_name="rte", + split="valid", + ) + num_target = 100 + assert num_target <= len(eval_dataset) + + remove_intervals = [10, 20, 30, 40, 50, 60] + num_base_repeat = 10 + num_repeat = 3 + + large_seed_ids = list(range(num_base_repeat)) + seed_ids = list(range(num_repeat)) + + valid_acc_lst = [] + for seed in large_seed_ids: + model = train_with_indices(dataset=train_dataset, seed=seed + 79, indices_to_keep=None) + valid_results = get_accuracy(model, eval_dataset) + valid_acc_lst.append(valid_results) + + # Selects validation data points that get correctly classified on all seeds. + mask = np.array(valid_acc_lst).mean(0) >= 1.0 + + # Get random baseline. + start_time = time.time() + random_results = [] + for valid_idx in range(num_target): + print(f"{valid_idx}th validation data point.") + if mask[valid_idx]: + # Selects training data points with the same label. + correct_label = eval_dataset[valid_idx]["label"] + random_indices = list( + np.random.permutation( + [ + i + for i, x in enumerate([x["label"] for x in eval_dataset]) + if x == correct_label and i < num_target + ] + ) + ) + + success_lst = [] + for interval in remove_intervals: + results = train_with_configurations( + dataset=train_dataset, + top_indices=random_indices, + valid_dataset=eval_dataset.select([valid_idx]), + interval=interval, + seed_ids=seed_ids, + ) + if np.array(results).mean() < 0.5: + success_lst.append(1) + break + else: + success_lst.append(0) + + while len(success_lst) < len(remove_intervals): + success_lst.append(1) + + random_results.append(success_lst) + + end_time = time.time() + print(f"Took {end_time - start_time} seconds for the random baseline.") + print(f"Results: {random_results}") + + # Get EKFAC baseline. + start_time = time.time() + scores = Analyzer.load_file("influence_results/wikitext/scores_ekfac/pairwise_scores.safetensors")[ + "all_modules" + ].to(dtype=torch.float32) + ekfac_results = [] + for valid_idx in range(num_target): + print(f"{valid_idx}th validation data point.") + if mask[valid_idx]: + top_indices = torch.argsort(scores[valid_idx], descending=True) + top_indices = [idx.item() for idx in top_indices] + + success_lst = [] + for interval in remove_intervals: + results = train_with_configurations( + dataset=train_dataset, + top_indices=top_indices, + valid_dataset=eval_dataset.select([valid_idx]), + interval=interval, + seed_ids=seed_ids, + ) + if np.array(results).mean() < 0.5: + success_lst.append(1) + break + else: + success_lst.append(0) + + while len(success_lst) < len(remove_intervals): + success_lst.append(1) + + ekfac_results.append(success_lst) + + end_time = time.time() + print(f"Took {end_time - start_time} seconds for the EKFAC baseline.") + print(f"Results: {ekfac_results}") + + # Get Identity baseline. + start_time = time.time() + scores = Analyzer.load_file("influence_results/wikitext/scores_ekfac/pairwise_scores.safetensors")[ + "all_modules" + ].to(dtype=torch.float32) + identity_results = [] + for valid_idx in range(num_target): + print(f"{valid_idx}th validation data point.") + if mask[valid_idx]: + top_indices = torch.argsort(scores[valid_idx], descending=True) + top_indices = [idx.item() for idx in top_indices] + + success_lst = [] + for interval in remove_intervals: + results = train_with_configurations( + dataset=train_dataset, + top_indices=top_indices, + valid_dataset=eval_dataset.select([valid_idx]), + interval=interval, + seed_ids=seed_ids, + ) + if np.array(results).mean() < 0.5: + success_lst.append(1) + break + else: + success_lst.append(0) + + while len(success_lst) < len(remove_intervals): + success_lst.append(1) + + identity_results.append(success_lst) + + end_time = time.time() + print(f"Took {end_time - start_time} seconds for the identity baseline.") + print(f"Results: {identity_results}") + + +if __name__ == "__main__": + main() diff --git a/examples/glue/train.py b/examples/glue/train.py index b6dacf6..f2f6950 100644 --- a/examples/glue/train.py +++ b/examples/glue/train.py @@ -71,7 +71,6 @@ def parse_args(): default="./checkpoints", help="A path to store the final checkpoint.", ) - args = parser.parse_args() if args.checkpoint_dir is not None: @@ -105,13 +104,13 @@ def train( for epoch in range(num_train_epochs): total_loss = 0.0 for batch in train_dataloader: + optimizer.zero_grad(set_to_none=True) loss = model( input_ids=batch["input_ids"].to(device=DEVICE), attention_mask=batch["attention_mask"].to(device=DEVICE), token_type_ids=batch["token_type_ids"].to(device=DEVICE), labels=batch["labels"].to(device=DEVICE), ).loss - optimizer.zero_grad() loss.backward() optimizer.step() total_loss += loss.detach().float()