From 0a6cf5bc6e5862bc821fa2a0dfc6c8028548b9b2 Mon Sep 17 00:00:00 2001 From: Juhan Bae Date: Wed, 20 Mar 2024 02:00:48 -0400 Subject: [PATCH] Polish examples --- examples/cifar/README.md | 9 +- examples/cifar/analyze.py | 2 +- examples/cifar/detect_mislabeled_dataset.py | 6 +- examples/glue/README.md | 34 ++++ examples/glue/analyze.py | 168 ++++++++++++++++++++ examples/glue/pipeline.py | 12 +- examples/glue/requirements.txt | 1 + examples/glue/train.py | 15 +- examples/imagenet/README.md | 40 +---- examples/imagenet/analyze.py | 6 +- examples/imagenet/ddp_analyze.py | 112 ++++--------- examples/uci/README.md | 3 +- examples/wikitext/README.md | 34 ++++ examples/wikitext/analysis.py | 0 examples/wikitext/pipeline.py | 110 +++++++++++++ examples/wikitext/train.py | 0 16 files changed, 412 insertions(+), 140 deletions(-) create mode 100644 examples/glue/README.md create mode 100644 examples/glue/analyze.py create mode 100644 examples/glue/requirements.txt create mode 100644 examples/wikitext/README.md create mode 100644 examples/wikitext/analysis.py create mode 100644 examples/wikitext/pipeline.py create mode 100644 examples/wikitext/train.py diff --git a/examples/cifar/README.md b/examples/cifar/README.md index bc79945..96f6e95 100644 --- a/examples/cifar/README.md +++ b/examples/cifar/README.md @@ -1,11 +1,14 @@ # CIFAR-10 & ResNet-9 Example -This directory contains scripts designed for training ResNet-9 on CIFAR-10. The pipeline is motivated from -[TRAK repository](https://github.com/MadryLab/trak/blob/main/examples/cifar_quickstart.ipynb). +This directory contains scripts for training ResNet-9 on CIFAR-10. The pipeline is motivated from +[TRAK repository](https://github.com/MadryLab/trak/blob/main/examples/cifar_quickstart.ipynb). Please begin by installing necessary packages. 
+```bash +pip install -r requirements.txt +``` ## Training -To train the model on the CIFAR-10 dataset, run the following command: +To train ResNet-9 on CIFAR-10 dataset, run the following command: ```bash python train.py --dataset_dir ./data \ --checkpoint_dir ./checkpoints \ diff --git a/examples/cifar/analyze.py b/examples/cifar/analyze.py index 292b102..62eea03 100644 --- a/examples/cifar/analyze.py +++ b/examples/cifar/analyze.py @@ -5,11 +5,11 @@ import torch import torch.nn.functional as F -from kronfluence.arguments import FactorArguments from torch import nn from examples.cifar.pipeline import construct_resnet9, get_cifar10_dataset from kronfluence.analyzer import Analyzer, prepare_model +from kronfluence.arguments import FactorArguments from kronfluence.task import Task from kronfluence.utils.dataset import DataLoaderKwargs diff --git a/examples/cifar/detect_mislabeled_dataset.py b/examples/cifar/detect_mislabeled_dataset.py index e496f1c..d48b4b0 100644 --- a/examples/cifar/detect_mislabeled_dataset.py +++ b/examples/cifar/detect_mislabeled_dataset.py @@ -3,11 +3,11 @@ import os import torch -from kronfluence.arguments import FactorArguments + from examples.cifar.analyze import ClassificationTask from examples.cifar.pipeline import construct_resnet9, get_cifar10_dataset from kronfluence.analyzer import Analyzer, prepare_model - +from kronfluence.arguments import FactorArguments from kronfluence.utils.dataset import DataLoaderKwargs @@ -102,7 +102,7 @@ def main(): accuracies = [] for interval in intervals: interval = interval.item() - predicted_indices = torch.argsort(scores, descending=True)[:int(interval * len(train_dataset))] + predicted_indices = torch.argsort(scores, descending=True)[: int(interval * len(train_dataset))] predicted_indices = list(predicted_indices.numpy()) accuracies.append(len(set(predicted_indices) & set(corrupted_indices)) / total_corrupt_size) diff --git a/examples/glue/README.md b/examples/glue/README.md new file mode 100644 
index 0000000..a7a8ea4 --- /dev/null +++ b/examples/glue/README.md @@ -0,0 +1,34 @@ +# GLUE & BERT Example + +This directory contains scripts for fine-tuning BERT on GLUE benchmark. The pipeline is motivated from [HuggingFace Example](https://github.com/huggingface/transformers/tree/main/examples/pytorch/text-classification). +Please begin by installing necessary packages. +```bash +pip install -r requirements.txt +``` + +## Training + +To fine-tune BERT on some specific dataset, run the following command (we are using `SST2` dataset): +```bash +python train.py --dataset_name sst2 \ + --checkpoint_dir ./checkpoints \ + --train_batch_size 32 \ + --eval_batch_size 32 \ + --learning_rate 3e-05 \ + --weight_decay 0.01 \ + --num_train_epochs 3 \ + --seed 1004 +``` + +## Computing Pairwise Influence Scores + +To obtain a pairwise influence scores on maximum of 2000 query data points using `ekfac`, run the following command: +```bash +python analyze.py --query_batch_size 1000 \ + --dataset_dir ./data \ + --checkpoint_dir ./checkpoints \ + --factor_strategy ekfac +``` +You can also use `identity`, `diagonal`, and `kfac`. On A100 (80GB), it takes roughly 1.5 minutes to compute the +pairwise scores (including computing EKFAC factors). 
+ diff --git a/examples/glue/analyze.py b/examples/glue/analyze.py new file mode 100644 index 0000000..49d5a0b --- /dev/null +++ b/examples/glue/analyze.py @@ -0,0 +1,168 @@ +import argparse +import logging +import os +from typing import Tuple, Dict + +import torch +import torch.nn.functional as F +from torch import nn + +from examples.glue.pipeline import construct_bert, get_glue_dataset +from kronfluence.analyzer import Analyzer, prepare_model +from kronfluence.arguments import FactorArguments +from kronfluence.task import Task +from kronfluence.utils.dataset import DataLoaderKwargs + +BATCH_TYPE = Dict[str, torch.Tensor] + + +def parse_args(): + parser = argparse.ArgumentParser(description="Influence analysis on CIFAR-10 dataset.") + + parser.add_argument( + "--dataset_name", + type=str, + default="sst2", + help="A name of GLUE dataset.", + ) + + parser.add_argument( + "--query_batch_size", + type=int, + default=32, + help="Batch size for computing query gradients.", + ) + + parser.add_argument( + "--checkpoint_dir", + type=str, + default="./checkpoints", + help="A path to store the final checkpoint.", + ) + + parser.add_argument( + "--factor_strategy", + type=str, + default="ekfac", + help="Strategy to compute preconditioning factors.", + ) + + args = parser.parse_args() + + if args.checkpoint_dir is not None: + os.makedirs(args.checkpoint_dir, exist_ok=True) + args.checkpoint_dir = os.path.join(args.checkpoint_dir, args.dataset_name) + os.makedirs(args.checkpoint_dir, exist_ok=True) + + return args + + +class TextClassificationTask(Task): + def compute_train_loss( + self, + batch: BATCH_TYPE, + model: nn.Module, + sample: bool = False, + ) -> torch.Tensor: + logits = model( + input_ids=batch["input_ids"], + attention_mask=batch["attention_mask"], + token_type_ids=batch["token_type_ids"], + ).logits + + if not sample: + return F.cross_entropy( + logits, batch["labels"], reduction="sum" + ) + with torch.no_grad(): + probs = torch.nn.functional.softmax(logits, 
dim=-1) + sampled_labels = torch.multinomial( + probs, num_samples=1, + ).flatten() + return F.cross_entropy( + logits, sampled_labels.detach(), reduction="sum" + ) + + def compute_measurement( + self, + batch: BATCH_TYPE, + model: nn.Module, + ) -> torch.Tensor: + # Copied from: https://github.com/MadryLab/trak/blob/main/trak/modelout_functions.py. + logits = model( + input_ids=batch["input_ids"], + attention_mask=batch["attention_mask"], + token_type_ids=batch["token_type_ids"], + ).logits + + labels = batch["labels"] + bindex = torch.arange(logits.shape[0]).to( + device=logits.device, non_blocking=False + ) + logits_correct = logits[bindex, labels] + + cloned_logits = logits.clone() + cloned_logits[bindex, labels] = torch.tensor( + -torch.inf, device=logits.device, dtype=logits.dtype + ) + + margins = logits_correct - cloned_logits.logsumexp(dim=-1) + return -margins.sum() + + +def main(): + args = parse_args() + logging.basicConfig(level=logging.INFO) + + train_dataset = get_glue_dataset( + data_name=args.dataset_name, split="eval_train", + ) + eval_dataset = get_glue_dataset( + data_name=args.dataset_name, split="valid", + ) + + model = construct_bert() + model_name = "model" + if getattr(args, "corrupt_percentage", None) is not None: + model_name += "_corrupt_" + str(args.corrupt_percentage) + checkpoint_path = os.path.join(args.checkpoint_dir, f"{model_name}.pth") + if not os.path.isfile(checkpoint_path): + raise ValueError(f"No checkpoint found at {checkpoint_path}.") + model.load_state_dict(torch.load(checkpoint_path)) + + task = TextClassificationTask() + model = prepare_model(model, task) + + analyzer = Analyzer( + analysis_name="glue", + model=model, + task=task, + cpu=False, + ) + + dataloader_kwargs = DataLoaderKwargs(num_workers=4) + analyzer.set_dataloader_kwargs(dataloader_kwargs) + + factor_args = FactorArguments(strategy=args.factor_strategy) + analyzer.fit_all_factors( + factors_name=args.factor_strategy, + dataset=train_dataset, + per_device_batch_size=None, + 
factor_args=factor_args, + overwrite_output_dir=True, + ) + analyzer.compute_pairwise_scores( + scores_name="pairwise", + factors_name=args.factor_strategy, + query_dataset=eval_dataset, + query_indices=list(range(2000)), + train_dataset=train_dataset, + per_device_query_batch_size=args.query_batch_size, + overwrite_output_dir=True, + ) + scores = analyzer.load_pairwise_scores("pairwise") + print(scores) + + +if __name__ == "__main__": + main() diff --git a/examples/glue/pipeline.py b/examples/glue/pipeline.py index d38d44e..d92bc67 100644 --- a/examples/glue/pipeline.py +++ b/examples/glue/pipeline.py @@ -1,9 +1,6 @@ -import os from typing import List -import torch import torch.nn as nn -import torchvision from datasets import load_dataset from torch.utils.data import Dataset from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer @@ -41,14 +38,12 @@ def get_glue_dataset( data_name: str, split: str, indices: List[int] = None, - dataset_dir: str = "data/", ) -> Dataset: assert split in ["train", "eval_train", "valid"] raw_datasets = load_dataset( path="glue", name=data_name, - # data_dir=dataset_dir, ) label_list = raw_datasets["train"].features["label"].names num_labels = len(label_list) @@ -86,3 +81,10 @@ def preprocess_function(examples): ds = ds.select(indices) return ds + + +if __name__ == "__main__": + from kronfluence import Analyzer + + model = construct_bert() + print(Analyzer.get_module_summary(model)) diff --git a/examples/glue/requirements.txt b/examples/glue/requirements.txt new file mode 100644 index 0000000..c30592a --- /dev/null +++ b/examples/glue/requirements.txt @@ -0,0 +1 @@ +evaluate \ No newline at end of file diff --git a/examples/glue/train.py b/examples/glue/train.py index 675233f..0cb626c 100644 --- a/examples/glue/train.py +++ b/examples/glue/train.py @@ -26,12 +26,6 @@ def parse_args(): default="sst2", help="A name of GLUE dataset.", ) - parser.add_argument( - "--dataset_dir", - type=str, - 
default="./data", - help="A folder to download or load GLUE dataset.", - ) parser.add_argument( "--train_batch_size", @@ -103,6 +97,7 @@ def train( drop_last=True, collate_fn=default_data_collator, ) + model = construct_bert().to(DEVICE) optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay) @@ -140,7 +135,7 @@ def evaluate_model(model: nn.Module, dataset: data.Dataset, batch_size: int) -> batch["input_ids"].to(device=DEVICE), batch["token_type_ids"].to(device=DEVICE), batch["attention_mask"].to(device=DEVICE), - ) + ).logits labels = batch["labels"].to(device=DEVICE) total_loss += F.cross_entropy(outputs, labels, reduction="sum").detach().item() predictions = outputs.argmax(dim=-1) @@ -160,7 +155,7 @@ def main(): if args.seed is not None: set_seed(args.seed) - train_dataset = get_glue_dataset(data_name=args.dataset_name, split="train", dataset_dir=args.dataset_dir) + train_dataset = get_glue_dataset(data_name=args.dataset_name, split="train") model = train( dataset=train_dataset, batch_size=args.train_batch_size, @@ -169,11 +164,11 @@ def main(): weight_decay=args.weight_decay, ) - eval_train_dataset = get_glue_dataset(data_name=args.dataset_name, split="eval_train", dataset_dir=args.dataset_dir) + eval_train_dataset = get_glue_dataset(data_name=args.dataset_name, split="eval_train") train_loss, train_acc = evaluate_model(model=model, dataset=eval_train_dataset, batch_size=args.eval_batch_size) logger.info(f"Train loss: {train_loss}, Train Accuracy: {train_acc}") - eval_dataset = get_glue_dataset(data_name=args.dataset_name, split="valid", dataset_dir=args.dataset_dir) + eval_dataset = get_glue_dataset(data_name=args.dataset_name, split="valid") eval_loss, eval_acc = evaluate_model(model=model, dataset=eval_dataset, batch_size=args.eval_batch_size) logger.info(f"Evaluation loss: {eval_loss}, Evaluation Accuracy: {eval_acc}") diff --git a/examples/imagenet/README.md b/examples/imagenet/README.md index bbacd1c..ce4110a 100644 
--- a/examples/imagenet/README.md +++ b/examples/imagenet/README.md @@ -1,8 +1,3 @@ -```bash -torchrun --standalone --nnodes=1 --nproc-per-node=4 ddp_analyze.py -``` - - # ImageNet & ResNet-50 Example This directory contains scripts for training ResNet-50 on ImageNet. @@ -13,40 +8,19 @@ We will use the pre-trained dataset from `torchvision.models.resnet50`. ## Computing Pairwise Influence Scores -To obtain a pairwise influence scores on 2000 query data points using `ekfac`, run the following command: +To obtain a pairwise influence scores on 1000 query data points using `ekfac`, run the following command: ```bash python analyze.py --dataset_dir /mfs1/datasets/imagenet_pytorch/ \ - --query_gradient_rank None \ + --query_gradient_rank -1 \ --query_batch_size 100 \ - --train_batch_size 128 \ + --train_batch_size 256 \ --factor_strategy ekfac ``` -You can also use `identity`, `diagonal`, and `kfac`. On A100 (80GB), it takes roughly 1.5 minutes to compute the -pairwise scores (including computing EKFAC factors). +On A100 (80GB), it takes roughly 1.5 minutes to compute the pairwise scores (including computing EKFAC factors). -## Mislabeled Data Detection -First, train the model with 10% of training dataset mislabeled by running the following command: -```bash -python train.py --dataset_dir ./data \ - --corrupt_percentage 0.1 \ - --checkpoint_dir ./checkpoints \ - --train_batch_size 512 \ - --eval_batch_size 1024 \ - --learning_rate 0.4 \ - --weight_decay 0.0001 \ - --num_train_epochs 25 \ - --seed 1004 -``` +## Computing Pairwise Influence Scores with DDP -Then, compute self-influence scores with the following command: ```bash -python detect_mislabeled_dataset.py --dataset_dir ./data \ - --corrupt_percentage 0.1 \ - --checkpoint_dir ./checkpoints \ - --factor_strategy ekfac -``` - -On A100 (80GB), it takes roughly 1.5 minutes to compute the self-influence scores (including computing EKFAC factors). 
-We can detect around 82% of mislabeled data points by inspecting 10% of the dataset using self-influence scores -(96% by inspecting 20%). \ No newline at end of file +torchrun --standalone --nnodes=1 --nproc-per-node=4 ddp_analyze.py +``` \ No newline at end of file diff --git a/examples/imagenet/analyze.py b/examples/imagenet/analyze.py index fd9e365..1fce92a 100644 --- a/examples/imagenet/analyze.py +++ b/examples/imagenet/analyze.py @@ -8,9 +8,10 @@ from arguments import FactorArguments, ScoreArguments from task import Task from torch import nn -from examples.imagenet.pipeline import construct_resnet50, get_imagenet_dataset from utils.dataset import DataLoaderKwargs +from examples.imagenet.pipeline import construct_resnet50, get_imagenet_dataset + BATCH_TYPE = Tuple[torch.Tensor, torch.Tensor] @@ -93,14 +94,12 @@ def compute_measurement( def main(): args = parse_args() - logging.basicConfig(level=logging.INFO) train_dataset = get_imagenet_dataset(split="eval_train", dataset_dir=args.dataset_dir) eval_dataset = get_imagenet_dataset(split="valid", dataset_dir=args.dataset_dir) model = construct_resnet50() - task = ClassificationTask() model = prepare_model(model, task) @@ -109,7 +108,6 @@ def main(): model=model, task=task, ) - dataloader_kwargs = DataLoaderKwargs( num_workers=4, ) diff --git a/examples/imagenet/ddp_analyze.py b/examples/imagenet/ddp_analyze.py index 0e3d0fd..dbbafef 100644 --- a/examples/imagenet/ddp_analyze.py +++ b/examples/imagenet/ddp_analyze.py @@ -5,14 +5,12 @@ import torch import torch.distributed as dist -import torch.nn.functional as F -from torch import nn from torch.nn.parallel.distributed import DistributedDataParallel +from examples.imagenet.analyze import ClassificationTask from examples.imagenet.pipeline import construct_resnet50, get_imagenet_dataset from kronfluence.analyzer import Analyzer, prepare_model -from kronfluence.arguments import FactorArguments -from kronfluence.task import Task +from kronfluence.arguments import 
FactorArguments, ScoreArguments from kronfluence.utils.dataset import DataLoaderKwargs torch.backends.cudnn.benchmark = True @@ -33,28 +31,28 @@ def parse_args(): ) parser.add_argument( - "--factor_strategy", - type=str, - default="ekfac", - help="Strategy to compute preconditioning factors.", + "--query_gradient_rank", + type=int, + default=-1, + help="Rank for the low-rank query gradient approximation.", ) parser.add_argument( "--covariance_batch_size", type=int, default=512, - help="Batch size for computing covariance matrices.", + help="Batch size for computing query gradients.", ) parser.add_argument( "--lambda_batch_size", type=int, - default=256, - help="Batch size for computing Lambda matrices.", + default=512, + help="Batch size for computing query gradients.", ) parser.add_argument( "--query_batch_size", type=int, - default=64, - help="Batch size for computing query gradient.", + default=100, + help="Batch size for computing query gradients.", ) parser.add_argument( "--train_batch_size", @@ -62,50 +60,17 @@ def parse_args(): default=128, help="Batch size for computing training gradient.", ) + parser.add_argument( + "--factor_strategy", + type=str, + default="ekfac", + help="Strategy to compute preconditioning factors.", + ) args = parser.parse_args() return args -class ClassificationTask(Task): - def compute_train_loss( - self, - batch: BATCH_DTYPE, - model: nn.Module, - sample: bool = False, - ) -> torch.Tensor: - inputs, labels = batch - logits = model(inputs) - - if not sample: - return F.cross_entropy(logits, labels, reduction="sum") - with torch.no_grad(): - probs = torch.nn.functional.softmax(logits, dim=-1) - sampled_labels = torch.multinomial( - probs, - num_samples=1, - ).flatten() - return F.cross_entropy(logits, sampled_labels.detach(), reduction="sum") - - def compute_measurement( - self, - batch: BATCH_DTYPE, - model: nn.Module, - ) -> torch.Tensor: - # Copied from https://github.com/MadryLab/trak/blob/main/trak/modelout_functions.py. 
- inputs, labels = batch - logits = model(inputs) - - bindex = torch.arange(logits.shape[0]).to(device=logits.device, non_blocking=False) - logits_correct = logits[bindex, labels] - - cloned_logits = logits.clone() - cloned_logits[bindex, labels] = torch.tensor(-torch.inf, device=logits.device, dtype=logits.dtype) - - margins = logits_correct - cloned_logits.logsumexp(dim=-1) - return -margins.sum() - - def main(): args = parse_args() logging.basicConfig(level=logging.INFO) @@ -124,53 +89,42 @@ def main(): model = DistributedDataParallel(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK) analyzer = Analyzer( - analysis_name="ddp", + analysis_name="imagenet_ddp", model=model, task=task, - profile=True, - disable_model_save=True, ) dataloader_kwargs = DataLoaderKwargs( - num_workers=2, - pin_memory=True, - prefetch_factor=2, + num_workers=4, ) + analyzer.set_dataloader_kwargs(dataloader_kwargs) factor_args = FactorArguments( strategy=args.factor_strategy, ) - analyzer.fit_covariance_matrices( + analyzer.fit_all_factors( factors_name=args.factor_strategy, dataset=train_dataset, + per_device_batch_size=None, factor_args=factor_args, - per_device_batch_size=args.covariance_batch_size, - dataloader_kwargs=dataloader_kwargs, overwrite_output_dir=False, ) - analyzer.perform_eigendecomposition( - factors_name=args.factor_strategy, - factor_args=factor_args, - overwrite_output_dir=False, - ) - analyzer.fit_lambda_matrices( - factors_name=args.factor_strategy, - dataset=train_dataset, - factor_args=factor_args, - per_device_batch_size=args.lambda_batch_size, - dataloader_kwargs=dataloader_kwargs, - overwrite_output_dir=False, - ) - scores = analyzer.compute_pairwise_scores( - scores_name="pairwise", + + rank = args.query_gradient_rank if args.query_gradient_rank != -1 else None + score_args = ScoreArguments(query_gradient_rank=rank) + scores_name = "pairwise" + if rank is not None: + scores_name += f"_qlr{rank}" + analyzer.compute_pairwise_scores( + 
score_args=score_args, + scores_name=scores_name, factors_name=args.factor_strategy, query_dataset=eval_dataset, + query_indices=list(range(1000)), train_dataset=train_dataset, - per_device_train_batch_size=args.train_batch_size, per_device_query_batch_size=args.query_batch_size, - query_indices=list(range(1000)), + per_device_train_batch_size=args.train_batch_size, overwrite_output_dir=False, ) - logging.info(f"Scores: {scores}") if __name__ == "__main__": diff --git a/examples/uci/README.md b/examples/uci/README.md index d673507..27cf2c4 100644 --- a/examples/uci/README.md +++ b/examples/uci/README.md @@ -1,8 +1,7 @@ # UCI Regression Example -This directory contains scripts designed for training a regression model and conducting influence analysis with +This directory contains scripts for training a regression model and conducting influence analysis with datasets from the [UCI Machine Learning Repository](https://archive.ics.uci.edu/datasets). Please begin by installing necessary packages. - ```bash pip install -r requirements.txt ``` diff --git a/examples/wikitext/README.md b/examples/wikitext/README.md new file mode 100644 index 0000000..a7a8ea4 --- /dev/null +++ b/examples/wikitext/README.md @@ -0,0 +1,34 @@ +# GLUE & BERT Example + +This directory contains scripts for fine-tuning BERT on GLUE benchmark. The pipeline is motivated from [HuggingFace Example](https://github.com/huggingface/transformers/tree/main/examples/pytorch/text-classification). +Please begin by installing necessary packages. 
+```bash +pip install -r requirements.txt +``` + +## Training + +To fine-tune BERT on some specific dataset, run the following command (we are using `SST2` dataset): +```bash +python train.py --dataset_name sst2 \ + --checkpoint_dir ./checkpoints \ + --train_batch_size 32 \ + --eval_batch_size 32 \ + --learning_rate 3e-05 \ + --weight_decay 0.01 \ + --num_train_epochs 3 \ + --seed 1004 +``` + +## Computing Pairwise Influence Scores + +To obtain a pairwise influence scores on maximum of 2000 query data points using `ekfac`, run the following command: +```bash +python analyze.py --query_batch_size 1000 \ + --dataset_dir ./data \ + --checkpoint_dir ./checkpoints \ + --factor_strategy ekfac +``` +You can also use `identity`, `diagonal`, and `kfac`. On A100 (80GB), it takes roughly 1.5 minutes to compute the +pairwise scores (including computing EKFAC factors). + diff --git a/examples/wikitext/analysis.py b/examples/wikitext/analysis.py new file mode 100644 index 0000000..e69de29 diff --git a/examples/wikitext/pipeline.py b/examples/wikitext/pipeline.py new file mode 100644 index 0000000..b8341ec --- /dev/null +++ b/examples/wikitext/pipeline.py @@ -0,0 +1,110 @@ +from itertools import chain +from typing import List + +import torch +import torch.nn as nn +from datasets import load_dataset +from torch.utils.data import DataLoader +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + AutoTokenizer, +) +from transformers.pytorch_utils import Conv1D + + +def replace_conv1d_modules(model: nn.Module) -> None: + # GPT-2 is defined in terms of Conv1D. However, this does not work for Kronfluence. + # Here, we convert these Conv1D modules to linear modules recursively. 
+ for name, module in model.named_children(): + if len(list(module.children())) > 0: + replace_conv1d_modules(module) + + if isinstance(module, Conv1D): + new_module = nn.Linear( + in_features=module.weight.shape[0], out_features=module.weight.shape[1] + ) + new_module.weight.data.copy_(module.weight.data.t()) + new_module.bias.data.copy_(module.bias.data) + setattr(model, name, new_module) + + +def construct_gpt2() -> nn.Module: + config = AutoConfig.from_pretrained( + "gpt2", + trust_remote_code=True, + ) + model = AutoModelForCausalLM.from_pretrained( + "gpt2", + from_tf=False, + config=config, + ignore_mismatched_sizes=False, + trust_remote_code=True, + ) + replace_conv1d_modules(model) + return model + + +def get_wikitext_dataset( + split: str, + indices: List[int] = None, +) -> torch.utils.data.DataLoader: + assert split in ["train", "eval_train", "valid"] + + raw_datasets = load_dataset("wikitext", "wikitext-2-raw-v1") + tokenizer = AutoTokenizer.from_pretrained( + "gpt2", use_fast=True, trust_remote_code=True + ) + column_names = raw_datasets["train"].column_names + text_column_name = "text" if "text" in column_names else column_names[0] + + def tokenize_function(examples): + return tokenizer(examples[text_column_name]) + + tokenized_datasets = raw_datasets.map( + tokenize_function, + batched=True, + num_proc=None, + remove_columns=column_names, + load_from_cache_file=True, + desc="Running tokenizer on dataset", + ) + block_size = 512 + + def group_texts(examples): + concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()} + total_length = len(concatenated_examples[list(examples.keys())[0]]) + total_length = (total_length // block_size) * block_size + result = { + k: [t[i : i + block_size] for i in range(0, total_length, block_size)] + for k, t in concatenated_examples.items() + } + result["labels"] = result["input_ids"].copy() + return result + + lm_datasets = tokenized_datasets.map( + group_texts, + batched=True, + num_proc=None, + 
load_from_cache_file=True, + desc=f"Grouping texts in chunks of {block_size}", + ) + + if split == "train" or split == "eval_train": + train_dataset = lm_datasets["train"] + ds = train_dataset + else: + eval_dataset = lm_datasets["validation"] + ds = eval_dataset + + if indices is not None: + ds = ds.select(indices) + + return ds + + +if __name__ == "__main__": + from kronfluence import Analyzer + + model = construct_gpt2() + print(Analyzer.get_module_summary(model)) diff --git a/examples/wikitext/train.py b/examples/wikitext/train.py new file mode 100644 index 0000000..e69de29