From d863fbd6892e078fb3b3176e3f1b45790253dc31 Mon Sep 17 00:00:00 2001 From: Juhan Bae Date: Sun, 17 Mar 2024 03:16:21 -0400 Subject: [PATCH] Clean up the code to avoid repeats --- examples/_test_requirements.txt | 3 +- examples/cifar/pipeline.py | 19 +- examples/cifar/train.py | 190 ++++++ examples/glue/pipeline.py | 7 +- examples/glue/train.py | 132 +++- examples/imagenet/ddp_analyze.py | 7 +- examples/uci/analyze.py | 27 +- examples/uci/train.py | 12 +- kronfluence/analyzer.py | 6 +- kronfluence/arguments.py | 30 + kronfluence/computer/computer.py | 184 ++++-- kronfluence/computer/covariance_computer.py | 336 ---------- kronfluence/computer/eigen_computer.py | 454 -------------- kronfluence/computer/factor_computer.py | 653 ++++++++++++++++++++ kronfluence/computer/score_computer.py | 442 +++++++++++++ kronfluence/factor/config.py | 35 +- kronfluence/factor/covariance.py | 66 +- kronfluence/factor/eigen.py | 110 ++-- kronfluence/module/conv2d.py | 23 +- kronfluence/module/linear.py | 20 +- kronfluence/module/tracked_module.py | 201 +++--- kronfluence/module/utils.py | 23 +- kronfluence/score/pairwise.py | 88 ++- kronfluence/score/self.py | 2 +- kronfluence/task.py | 2 +- kronfluence/utils/dataset.py | 20 +- kronfluence/utils/logger.py | 168 +++-- kronfluence/utils/save.py | 4 +- kronfluence/utils/state.py | 24 +- requirements.txt | 4 +- tests/factors/test_covariances.py | 6 + tests/factors/test_eigens.py | 8 +- tests/test_dataset_utils.py | 68 ++ tests/test_module_utils.py | 0 tests/testable_tasks/language_modeling.py | 2 +- 35 files changed, 2073 insertions(+), 1303 deletions(-) create mode 100644 examples/cifar/train.py delete mode 100644 kronfluence/computer/covariance_computer.py delete mode 100644 kronfluence/computer/eigen_computer.py create mode 100644 kronfluence/computer/factor_computer.py create mode 100644 kronfluence/computer/score_computer.py create mode 100644 tests/test_dataset_utils.py create mode 100644 tests/test_module_utils.py diff --git a/examples/_test_requirements.txt b/examples/_test_requirements.txt index 4a67121..ec96d22 100644 --- a/examples/_test_requirements.txt +++ b/examples/_test_requirements.txt @@ -1,2 +1,3 @@ scikit-learn -jupyter \ No newline at end of file +jupyter +evaluate \ No newline at end of file diff --git a/examples/cifar/pipeline.py b/examples/cifar/pipeline.py index 3aae10d..420d1fd 100644 --- a/examples/cifar/pipeline.py +++ b/examples/cifar/pipeline.py @@ -1,6 +1,6 @@ import copy import math -from typing import Dict, List, Optional, Tuple +from typing import List, Optional import datasets import numpy as np @@ -33,15 +33,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: def construct_resnet9() -> nn.Module: + # ResNet-9 architecture from: https://github.com/MadryLab/trak/blob/main/examples/cifar_quickstart.ipynb. 
def conv_bn( channels_in: int, channels_out: int, kernel_size: int = 3, stride: int = 1, padding: int = 1, - groups=1, + groups: int = 1, ) -> nn.Module: - assert groups == 1 return torch.nn.Sequential( torch.nn.Conv2d( channels_in, @@ -74,9 +74,9 @@ def conv_bn( def get_cifar10_dataset( split: str, - do_corrupt: bool, indices: List[int] = None, - data_dir: str = "data/", + corrupt_percentage: Optional[float] = None, + dataset_dir: str = "data/", ) -> datasets.Dataset: assert split in ["train", "eval_train", "valid"] @@ -99,16 +99,17 @@ def get_cifar10_dataset( ) dataset = torchvision.datasets.CIFAR10( - root=data_dir, + root=dataset_dir, download=True, - train=split in ["train", "eval_train", "eval_train_with_aug"], + train=split in ["train", "eval_train"], transform=transform_config, ) - if do_corrupt: + if corrupt_percentage is not None: if split == "valid": raise NotImplementedError("Performing corruption on the validation dataset is not supported.") - num_corrupt = math.ceil(len(dataset) * 0.1) + assert 0.0 < corrupt_percentage <= 1.0 + num_corrupt = math.ceil(len(dataset) * corrupt_percentage) original_targets = np.array(copy.deepcopy(dataset.targets[:num_corrupt])) new_targets = torch.randint( 0, diff --git a/examples/cifar/train.py b/examples/cifar/train.py new file mode 100644 index 0000000..fdae844 --- /dev/null +++ b/examples/cifar/train.py @@ -0,0 +1,190 @@ +import argparse +import logging +import os +from typing import Tuple + +import numpy as np +import torch +import torch.nn.functional as F +from accelerate.utils import set_seed +from torch import nn +from torch.optim import lr_scheduler +from torch.utils import data +from tqdm import tqdm + +from examples.cifar.pipeline import construct_resnet9, get_cifar10_dataset + +DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +def parse_args(): + parser = argparse.ArgumentParser(description="Train ResNet-9 model on CIFAR-10 dataset.") + + parser.add_argument( + "--corrupt_percentage", + type=float, + default=None, + help="Percentage of the training dataset to corrupt.", + ) + parser.add_argument( + "--dataset_dir", + type=str, + default="./data", + help="A folder to download or load CIFAR-10 dataset.", + ) + + parser.add_argument( + "--train_batch_size", + type=int, + default=512, + help="Batch size for the training dataloader.", + ) + parser.add_argument( + "--eval_batch_size", + type=int, + default=1024, + help="Batch size for the evaluation dataloader.", + ) + + parser.add_argument( + "--learning_rate", + type=float, + default=0.4, + help="Initial learning rate to train the model.", + ) + parser.add_argument( + "--weight_decay", + type=float, + default=0.001, + help="Weight decay to train the model.", + ) + parser.add_argument( + "--num_train_epochs", + type=int, + default=25, + help="Total number of epochs to train the model.", + ) + + parser.add_argument( + "--seed", + type=int, + default=1004, + help="A seed for reproducible training pipeline.", + ) + parser.add_argument( + "--checkpoint_dir", + type=str, + default="./checkpoints", + help="A path to store the final checkpoint.", + ) + + args = parser.parse_args() + + if args.checkpoint_dir is not None: + os.makedirs(args.checkpoint_dir, exist_ok=True) + + return args + + +def train( + dataset: data.Dataset, + batch_size: int, + num_train_epochs: int, + learning_rate: float, + weight_decay: float, + disable_tqdm: bool = False, +) -> nn.Module: + train_dataloader = data.DataLoader( + dataset=dataset, + batch_size=batch_size, + shuffle=True, + 
drop_last=True, + ) + + model = construct_resnet9().to(DEVICE) + optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, weight_decay=weight_decay) + + iters_per_epoch = len(train_dataloader) + lr_peak_epoch = num_train_epochs // 4 + lr_schedule = np.interp( + np.arange((num_train_epochs + 1) * iters_per_epoch), + [0, lr_peak_epoch * iters_per_epoch, num_train_epochs * iters_per_epoch], + [0, 1, 0], + ) + scheduler = lr_scheduler.LambdaLR(optimizer, lr_schedule.__getitem__) + + model.train() + for epoch in range(num_train_epochs): + total_loss = 0.0 + with tqdm(train_dataloader, unit="batch", disable=disable_tqdm) as tepoch: + for batch in tepoch: + tepoch.set_description(f"Epoch {epoch}") + model.zero_grad() + inputs, labels = batch + inputs, labels = inputs.to(DEVICE), labels.to(DEVICE) + outputs = model(inputs) + loss = F.cross_entropy(outputs, labels) + loss.backward() + optimizer.step() + scheduler.step() + total_loss += loss.detach().float() + tepoch.set_postfix(loss=total_loss.item() / len(train_dataloader)) + return model + + +def evaluate(model: nn.Module, dataset: data.Dataset, batch_size: int) -> Tuple[float, float]: + dataloader = data.DataLoader( + dataset=dataset, + batch_size=batch_size, + shuffle=False, + drop_last=False, + ) + + model.eval() + total_loss, total_correct = 0.0, 0 + for batch in dataloader: + with torch.no_grad(): + inputs, labels = batch + inputs, labels = inputs.to(DEVICE), labels.to(DEVICE) + outputs = model(inputs) + loss = F.cross_entropy(outputs, labels, reduction="sum") + total_loss += loss.detach().float() + total_correct += outputs.detach().argmax(1).eq(labels).sum() + + return total_loss.item() / len(dataloader.dataset), total_correct.item() / len(dataloader.dataset) + + +def main(): + args = parse_args() + logging.basicConfig(level=logging.INFO) + logger = logging.getLogger() + + if args.seed is not None: + set_seed(args.seed) + + train_dataset = get_cifar10_dataset(split="train", corrupt_percentage=args.corrupt_percentage, dataset_dir=args.dataset_dir) + model = train( + dataset=train_dataset, + batch_size=args.train_batch_size, + num_train_epochs=args.num_train_epochs, + learning_rate=args.learning_rate, + weight_decay=args.weight_decay, + ) + + eval_train_dataset = get_cifar10_dataset(split="eval_train", dataset_dir=args.dataset_dir) + train_loss, train_acc = evaluate(model=model, dataset=eval_train_dataset, batch_size=args.eval_batch_size) + logger.info(f"Train loss: {train_loss}, Train Accuracy: {train_acc}") + + eval_dataset = get_cifar10_dataset(split="valid", dataset_dir=args.dataset_dir) + eval_loss, eval_acc = evaluate(model=model, dataset=eval_dataset, batch_size=args.eval_batch_size) + logger.info(f"Evaluation loss: {eval_loss}, Evaluation Accuracy: {eval_acc}") + + if args.checkpoint_dir is not None: + model_name = "model" + if args.corrupt_percentage is not None: + model_name += "_corrupt_" + str(args.corrupt_percentage) + torch.save(model.state_dict(), os.path.join(args.checkpoint_dir, f"{model_name}.pth")) + + +if __name__ == "__main__": + main() diff --git a/examples/glue/pipeline.py b/examples/glue/pipeline.py index 31d23ae..d38d44e 100644 --- a/examples/glue/pipeline.py +++ b/examples/glue/pipeline.py @@ -21,14 +21,13 @@ } -def construct_bert(data_name) -> nn.Module: +def construct_bert(data_name: str = "sst2") -> nn.Module: config = AutoConfig.from_pretrained( "bert-base-cased", num_labels=2, finetuning_task=data_name, trust_remote_code=True, ) - return AutoModelForSequenceClassification.from_pretrained( 
"bert-base-cased", from_tf=False, @@ -42,14 +41,14 @@ def get_glue_dataset( data_name: str, split: str, indices: List[int] = None, - data_path: str = "data/", + dataset_dir: str = "data/", ) -> Dataset: assert split in ["train", "eval_train", "valid"] raw_datasets = load_dataset( path="glue", name=data_name, - data_dir=data_path, + # data_dir=dataset_dir, ) label_list = raw_datasets["train"].features["label"].names num_labels = len(label_list) diff --git a/examples/glue/train.py b/examples/glue/train.py index 54c6c2a..d03e59b 100644 --- a/examples/glue/train.py +++ b/examples/glue/train.py @@ -1,57 +1,65 @@ import argparse import logging import os +from typing import Tuple + +import evaluate import torch import torch.nn.functional as F +from transformers import default_data_collator + +import torch from accelerate.utils import set_seed -from torch.utils.data import DataLoader +from torch import nn +from torch.utils import data from tqdm import tqdm -from transformers import default_data_collator + from examples.glue.pipeline import construct_bert, get_glue_dataset +DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") + def parse_args(): parser = argparse.ArgumentParser(description="Train text classification models on GLUE datasets.") parser.add_argument( "--dataset_name", type=str, - default="sst", - help="A folder containing the MNIST dataset.", + default="sst2", + help="A name of GLUE dataset.", ) - parser.add_argument( "--dataset_dir", type=str, default="./data", - help="A folder containing the MNIST dataset.", + help="A folder to download or load GLUE dataset.", ) parser.add_argument( "--train_batch_size", type=int, - default=128, + default=32, help="Batch size for the training dataloader.", ) parser.add_argument( "--eval_batch_size", type=int, - default=512, + default=32, help="Batch size for the evaluation dataloader.", ) parser.add_argument( "--learning_rate", type=float, - default=0.03, + default=3e-05, help="Fixed learning rate to train the model.", ) parser.add_argument( "--weight_decay", type=float, - default=1e-4, + default=0.01, help="Weight decay to train the model.", ) parser.add_argument( @@ -78,27 +86,111 @@ def parse_args(): if args.checkpoint_dir is not None: os.makedirs(args.checkpoint_dir, exist_ok=True) + args.checkpoint_dir = os.path.join(args.checkpoint_dir, args.dataset_name) + os.makedirs(args.checkpoint_dir, exist_ok=True) return args +def train( + dataset: data.Dataset, + batch_size: int, + num_train_epochs: int, + learning_rate: float, + weight_decay: float, + disable_tqdm: bool = False, +) -> nn.Module: + train_dataloader = data.DataLoader( + dataset=dataset, + batch_size=batch_size, + shuffle=True, + drop_last=True, + collate_fn=default_data_collator, + ) + model = construct_bert().to(DEVICE) + optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay) + + model.train() + for epoch in range(num_train_epochs): + total_loss = 0.0 + with tqdm(train_dataloader, unit="batch", disable=disable_tqdm) as tepoch: + for batch in tepoch: + tepoch.set_description(f"Epoch {epoch}") + model.zero_grad() + outputs = model( + input_ids=batch["input_ids"].to(device=DEVICE), + attention_mask=batch["attention_mask"].to(device=DEVICE), + token_type_ids=batch["token_type_ids"].to(device=DEVICE), + ).logits + loss = F.cross_entropy(outputs, batch["labels"].to(device=DEVICE)) + total_loss += loss.detach().float() + loss.backward() + optimizer.step() + tepoch.set_postfix(loss=total_loss.item() / len(train_dataloader)) + return model 
+
+
+def evaluate_model(model: nn.Module, dataset: data.Dataset, batch_size: int) -> Tuple[float, float]:
+    dataloader = data.DataLoader(
+        dataset=dataset,
+        batch_size=batch_size,
+        shuffle=False,
+        drop_last=False,
+        collate_fn=default_data_collator
+    )
+
+    model.eval()
+    metric = evaluate.load("glue", "sst2")
+    total_loss = 0.0
+    for batch in dataloader:
+        with torch.no_grad():
+            outputs = model(
+                input_ids=batch["input_ids"].to(device=DEVICE),
+                attention_mask=batch["attention_mask"].to(device=DEVICE),
+                token_type_ids=batch["token_type_ids"].to(device=DEVICE),
+            ).logits
+            labels = batch["labels"].to(device=DEVICE)
+            total_loss += F.cross_entropy(outputs, labels, reduction="sum").detach().item()
+            predictions = outputs.argmax(dim=-1)
+            metric.add_batch(
+                predictions=predictions,
+                references=labels,
+            )
+    eval_metric = metric.compute()
+    return total_loss / len(dataloader.dataset), eval_metric["accuracy"]
+
+
 def main():
     args = parse_args()
-
     logging.basicConfig(level=logging.INFO)
+    logger = logging.getLogger()
 
     if args.seed is not None:
         set_seed(args.seed)
 
-    # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    # train_dataset = get_glue_dataset(data_name=args.dataset_name, split="train", data_path=args.dataset_dir)
-    # train_dataloader = DataLoader(
-    #     dataset=train_dataset,
-    #     batch_size=args.train_batch_size,
-    #     shuffle=True,
-    #     collate_fn=default_data_collator,
-    #     drop_last=True,
-    # )
+    train_dataset = get_glue_dataset(data_name=args.dataset_name,
+                                     split="train",
+                                     dataset_dir=args.dataset_dir)
+    model = train(
+        dataset=train_dataset,
+        batch_size=args.train_batch_size,
+        num_train_epochs=args.num_train_epochs,
+        learning_rate=args.learning_rate,
+        weight_decay=args.weight_decay,
+    )
+
+    eval_train_dataset = get_glue_dataset(
+        data_name=args.dataset_name, split="eval_train", dataset_dir=args.dataset_dir
+    )
+    train_loss, train_acc = evaluate_model(model=model, dataset=eval_train_dataset, batch_size=args.eval_batch_size)
+    logger.info(f"Train loss: {train_loss}, Train Accuracy: {train_acc}")
+
+    eval_dataset = get_glue_dataset(data_name=args.dataset_name, split="valid", dataset_dir=args.dataset_dir)
+    eval_loss, eval_acc = evaluate_model(model=model, dataset=eval_dataset, batch_size=args.eval_batch_size)
+    logger.info(f"Evaluation loss: {eval_loss}, Evaluation Accuracy: {eval_acc}")
+
+    if args.checkpoint_dir is not None:
+        torch.save(model.state_dict(), os.path.join(args.checkpoint_dir, "model.pth"))
 
 
 if __name__ == "__main__":
diff --git a/examples/imagenet/ddp_analyze.py b/examples/imagenet/ddp_analyze.py
index 8fba262..0e3d0fd 100644
--- a/examples/imagenet/ddp_analyze.py
+++ b/examples/imagenet/ddp_analyze.py
@@ -6,13 +6,13 @@
 import torch
 import torch.distributed as dist
 import torch.nn.functional as F
-from kronfluence.analyzer import Analyzer, prepare_model
-from kronfluence.arguments import FactorArguments
-from kronfluence.task import Task
 from torch import nn
 from torch.nn.parallel.distributed import DistributedDataParallel
 
 from examples.imagenet.pipeline import construct_resnet50, get_imagenet_dataset
+from kronfluence.analyzer import Analyzer, prepare_model
+from kronfluence.arguments import FactorArguments
+from kronfluence.task import Task
 from kronfluence.utils.dataset import DataLoaderKwargs
 
 torch.backends.cudnn.benchmark = True
@@ -68,7 +68,6 @@ def parse_args():
 
 
 class ClassificationTask(Task):
-
     def compute_train_loss(
         self,
         batch: BATCH_DTYPE,
diff --git a/examples/uci/analyze.py b/examples/uci/analyze.py
index cc12ebd..744ead1 100644
--- a/examples/uci/analyze.py
+++
b/examples/uci/analyze.py @@ -83,7 +83,9 @@ def main(): args = parse_args() logging.basicConfig(level=logging.INFO) - train_dataset = get_regression_dataset(data_name=args.dataset_name, split="eval_train", dataset_dir=args.dataset_dir) + train_dataset = get_regression_dataset( + data_name=args.dataset_name, split="eval_train", dataset_dir=args.dataset_dir + ) eval_dataset = get_regression_dataset(data_name=args.dataset_name, split="valid", dataset_dir=args.dataset_dir) model = construct_regression_mlp() @@ -102,9 +104,7 @@ def main(): profile=True, cpu=True, ) - factor_args = FactorArguments( - strategy=args.factor_strategy, - ) + factor_args = FactorArguments(strategy=args.factor_strategy, lambda_iterative_aggregate=True) analyzer.fit_all_factors( factors_name=args.factor_strategy, dataset=train_dataset, @@ -113,15 +113,16 @@ def main(): overwrite_output_dir=True, ) - scores = analyzer.compute_pairwise_scores( - scores_name="pairwise", - factors_name=args.factor_strategy, - query_dataset=eval_dataset, - train_dataset=train_dataset, - per_device_query_batch_size=len(eval_dataset), - overwrite_output_dir=True, - ) - logging.info(f"Scores: {scores}") + # scores = analyzer.compute_pairwise_scores( + # scores_name="pairwise", + # factors_name=args.factor_strategy, + # query_dataset=eval_dataset, + # train_dataset=train_dataset, + # per_device_query_batch_size=len(eval_dataset), + # per_device_train_batch_size=8, + # overwrite_output_dir=True, + # ) + # logging.info(f"Scores: {scores}") if __name__ == "__main__": diff --git a/examples/uci/train.py b/examples/uci/train.py index 2ae0fdb..8a70077 100644 --- a/examples/uci/train.py +++ b/examples/uci/train.py @@ -89,6 +89,7 @@ def train( num_train_epochs: int, learning_rate: float, weight_decay: float, + disable_tqdm: bool = False, ) -> nn.Module: train_dataloader = data.DataLoader( dataset=dataset, @@ -96,23 +97,22 @@ def train( shuffle=True, drop_last=True, ) - model = construct_regression_mlp() optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, weight_decay=weight_decay) model.train() for epoch in range(num_train_epochs): - total_loss = 0 - with tqdm(train_dataloader, unit="batch") as tepoch: + total_loss = 0.0 + with tqdm(train_dataloader, unit="batch", disable=disable_tqdm) as tepoch: for batch in tepoch: tepoch.set_description(f"Epoch {epoch}") + model.zero_grad() inputs, targets = batch outputs = model(inputs) loss = F.mse_loss(outputs, targets) total_loss += loss.detach().float() loss.backward() optimizer.step() - optimizer.zero_grad() tepoch.set_postfix(loss=total_loss.item() / len(train_dataloader)) return model @@ -126,7 +126,7 @@ def evaluate(model: nn.Module, dataset: data.Dataset, batch_size: int) -> float: ) model.eval() - total_loss = 0 + total_loss = 0.0 for batch in dataloader: with torch.no_grad(): inputs, targets = batch @@ -139,7 +139,6 @@ def evaluate(model: nn.Module, dataset: data.Dataset, batch_size: int) -> float: def main(): args = parse_args() - logging.basicConfig(level=logging.INFO) logger = logging.getLogger() @@ -147,7 +146,6 @@ def main(): set_seed(args.seed) train_dataset = get_regression_dataset(data_name=args.dataset_name, split="train", dataset_dir=args.dataset_dir) - model = train( dataset=train_dataset, batch_size=args.train_batch_size, diff --git a/kronfluence/analyzer.py b/kronfluence/analyzer.py index 98da37f..eec6512 100644 --- a/kronfluence/analyzer.py +++ b/kronfluence/analyzer.py @@ -1,13 +1,13 @@ from typing import Optional from accelerate.utils import extract_model_from_parallel +from 
kronfluence.computer.factor_computer import FactorComputer
+from kronfluence.computer.score_computer import ScoreComputer
 from safetensors.torch import save_file
 from torch import nn
 from torch.utils import data
 
 from kronfluence.arguments import FactorArguments
-from kronfluence.computer.covariance_computer import CovarianceComputer
-from kronfluence.computer.eigen_computer import EigenComputer
 from kronfluence.computer.pairwise_score_computer import PairwiseScoreComputer
 from kronfluence.computer.self_score_computer import SelfScoreComputer
 from kronfluence.module.utils import wrap_tracked_modules
@@ -40,7 +40,7 @@ def prepare_model(
     return model
 
 
-class Analyzer(CovarianceComputer, EigenComputer, PairwiseScoreComputer, SelfScoreComputer):
+class Analyzer(FactorComputer, ScoreComputer):
     """
     Handles the computation of all factors (e.g., covariance and Lambda matrices for EKFAC) and
     influence scores for a given PyTorch model.
diff --git a/kronfluence/arguments.py b/kronfluence/arguments.py
index 9c9055f..488a647 100644
--- a/kronfluence/arguments.py
+++ b/kronfluence/arguments.py
@@ -18,6 +18,15 @@ def to_dict(self) -> Dict[str, Any]:
             config[key] = str(value)
         return config
 
+    def to_str_dict(self) -> Dict[str, str]:
+        """Converts the arguments to a dictionary, where all values are converted to strings."""
+        config = copy.deepcopy(self.__dict__)
+
+        for key, value in config.items():
+            config[key] = str(value)
+
+        return config
+
 
 @dataclass
 class FactorArguments(Arguments):
@@ -42,6 +51,20 @@ class FactorArguments(Arguments):
         default=False,
         metadata={"help": "Whether to immediately remove computed `.grad` by Autograd within the backward hook."},
     )
+    ignore_bias: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether to ignore the bias terms when computing "
+            "the factors for each tracked module."
+        },
+    )
+    distributed_sync_steps: int = field(
+        default=1_000,
+        metadata={
+            "help": "Number of steps to wait before synchronizing the computed "
+            "factors across processes in the distributed setting."
+        },
+    )
 
     # Configuration for fitting covariance matrices. #
     covariance_max_examples: Optional[int] = field(
@@ -155,6 +178,13 @@ class ScoreArguments(Arguments):
         default=False,
         metadata={"help": "Whether to immediately remove computed `.grad` by Autograd within the backward hook."},
     )
+    ddp_sync_steps: int = field(
+        default=1_000,
+        metadata={
+            "help": "Number of steps to wait before synchronizing computations "
+            "across processes when using DistributedDataParallel."
+        },
+    )
 
     data_partition_size: int = field(
         default=1,
diff --git a/kronfluence/computer/computer.py b/kronfluence/computer/computer.py
index d1e8b20..30c3eb5 100644
--- a/kronfluence/computer/computer.py
+++ b/kronfluence/computer/computer.py
@@ -5,6 +5,7 @@
 from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
 
 import torch
+from kronfluence.factor.config import FactorConfig
 from torch import nn
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.nn import DataParallel
@@ -12,7 +13,7 @@
 from torch.utils import data
 from torch.utils.data import DistributedSampler, SequentialSampler
 
-from kronfluence.arguments import Arguments, ScoreArguments
+from kronfluence.arguments import Arguments, FactorArguments, ScoreArguments
 from kronfluence.factor.covariance import (
     covariance_matrices_exist,
     load_covariance_matrices,
@@ -34,12 +35,17 @@
 from kronfluence.score.self import load_self_scores, self_scores_exist
 from kronfluence.task import Task
 from kronfluence.utils.dataset import (
+    DataLoaderKwargs,
     DistributedEvalSampler,
     DistributedSamplerWithStack,
     find_executable_batch_size,
     make_indices_partition,
 )
-from kronfluence.utils.exceptions import FactorsNotFoundError, UnsupportableModuleError, TrackedModuleNotFoundError
+from kronfluence.utils.exceptions import (
+    FactorsNotFoundError,
+    TrackedModuleNotFoundError,
+    UnsupportableModuleError,
+)
 from kronfluence.utils.logger import PassThroughProfiler, Profiler, get_logger, get_time
 from kronfluence.utils.save import (
     FACTOR_ARGUMENTS_NAME,
@@ -114,13 +120,15 @@ def __init__(
         self.model.to(self.state.device)
 
         # Create and configure profiler.
-        self.profiler = Profiler() if profile else PassThroughProfiler()
-        self.profiler.set_local_rank(local_rank=self.state.local_process_index)
+        self.profiler = Profiler(state=self.state) if profile else PassThroughProfiler(state=self.state)
 
         # Create and configure output directory.
         self.output_dir = Path(output_dir).joinpath(name).resolve()
         os.makedirs(name=self.output_dir, exist_ok=True)
 
+        # DataLoader parameters.
+        self._dataloader_params = DataLoaderKwargs()
+
     def _save_arguments(
         self,
         arguments_name: str,
@@ -235,6 +243,19 @@ def _get_dataloader(
         } | dataloader_params
         return data.DataLoader(dataset=dataset, **dataloader_params)
 
+    def set_dataloader_kwargs(self, dataloader_kwargs: DataLoaderKwargs) -> None:
+        """Sets the default DataLoader parameters to use for all experiments."""
+        self._dataloader_params = dataloader_kwargs
+
+    def _configure_dataloader(self, dataloader_kwargs: DataLoaderKwargs) -> Dict[str, Any]:
+        """Configures the DataLoader, logging appropriate messages."""
+        if dataloader_kwargs is None:
+            dataloader_kwargs = self._dataloader_params
+            self.logger.info(f"DataLoader arguments not provided.
Using the configuration: {dataloader_kwargs}.") + else: + self.logger.info(f"Using the DataLoader parameters: {dataloader_kwargs.to_dict()}.") + return dataloader_kwargs.to_dict() + def _get_data_partition( self, total_data_examples: int, @@ -316,56 +337,14 @@ def factors_output_dir(self, factors_name: str) -> Path: return (self.output_dir / (FACTOR_SAVE_PREFIX + factors_name)).resolve() def scores_output_dir(self, scores_name: str) -> Path: - """Generates an output directory for storing all influence scores.""" + """Generates an output directory for storing all scores.""" return (self.output_dir / (SCORE_SAVE_PREFIX + scores_name)).resolve() - def _find_executable_factors_batch_size( - self, - func: Callable, - func_kwargs: Dict[str, Any], - dataset: data.Dataset, - dataloader_params: Dict[str, Any], - start_batch_size: int, - ) -> int: - """Automatically finds executable batch size for performing `func`.""" - self.logger.info("Automatically determining executable batch size.") - - def executable_batch_size_func(batch_size: int) -> None: - self.logger.info(f"Attempting to set per-device batch size to {batch_size}.") - set_mode(model=self.model, mode=ModuleMode.DEFAULT, keep_factors=False) - self.model.zero_grad(set_to_none=True) - release_memory() - total_batch_size = batch_size * self.state.num_processes - loader = self._get_dataloader( - dataset=dataset, - per_device_batch_size=batch_size, - # Only runs for a single step. - indices=list(range(total_batch_size)), - dataloader_params=dataloader_params, - ) - func(loader=loader, **func_kwargs) - - per_device_batch_size = find_executable_batch_size( - func=executable_batch_size_func, - start_batch_size=start_batch_size, - ) - self.logger.info(f"Executable batch size determined: {per_device_batch_size}.") - return per_device_batch_size - - @torch.no_grad() - def _aggregate_factors(self, aggregated_factors: FACTOR_TYPE, loaded_factors: FACTOR_TYPE) -> FACTOR_TYPE: - """Aggregates factors from the current loaded factors.""" - for factor_name, factors in loaded_factors.items(): - if factor_name not in aggregated_factors: - aggregated_factors[factor_name]: Dict[str, torch.Tensor] = {} - - for module_name in factors: - if module_name not in aggregated_factors[factor_name]: - aggregated_factors[factor_name][module_name] = (factors[module_name]).to(device=self.state.device) - else: - # Aggregate the factors from `loaded_factors` to `aggregated_factors`. - aggregated_factors[factor_name][module_name].add_(factors[module_name].to(device=self.state.device)) - return aggregated_factors + def _log_profile_summary(self) -> None: + """Log the summary of the profiling results.""" + profile_summary = self.profiler.summary() + if profile_summary != "": + self.logger.info(self.profiler.summary()) def load_factor_args(self, factors_name: str) -> Optional[Dict[str, Any]]: """Loads factor arguments with the given factor name.""" @@ -418,6 +397,41 @@ def load_self_scores(self, scores_name: str) -> Optional[SCORE_TYPE]: return load_self_scores(output_dir=scores_output_dir) return None + def _configure_score_args(self, score_args: ScoreArguments) -> ScoreArguments: + """Configures the ScoreArguments, logging appropriate messages.""" + if score_args is None: + score_args = ScoreArguments() + self.logger.info(f"Score arguments not provided. 
Using the default configuration: {score_args}.") + else: + self.logger.info(f"Using the provided configuration: {score_args}.") + return score_args + + def _load_and_configure_factor_args(self, factors_name: str) -> Tuple[FactorArguments, FactorConfig]: + """Loads factor arguments and its configuration from disk.""" + factor_args = self.load_factor_args(factors_name=factors_name) + factors_output_dir = self.factors_output_dir(factors_name=factors_name) + if factor_args is None: + error_msg = f"Factors with name `{factors_name}` was not found at {factors_output_dir}." + self.logger.error(error_msg) + raise FactorsNotFoundError(error_msg) + factor_args = FactorArguments(**factor_args) + self.logger.info(f"Loaded FactorArguments with configuration: {factor_args}.") + strategy = factor_args.strategy + factor_config = FactorConfig.CONFIGS[strategy] + return factor_args, factor_config + + def _load_and_configure_score_args(self, scores_name: str) -> ScoreArguments: + """Loads score arguments from disk.""" + score_args = self.load_score_args(scores_name=scores_name) + scores_output_dir = self.scores_output_dir(scores_name=scores_name) + if score_args is None: + error_msg = f"Scores with name `{scores_name}` was not found at {scores_output_dir}." + self.logger.error(error_msg) + raise FactorsNotFoundError(error_msg) + score_args = ScoreArguments(**score_args) + self.logger.info(f"Loaded ScoreArguments with configuration: {score_args}.") + return score_args + def _load_all_required_factors(self, factors_name: str, strategy: str, factor_config: Any) -> FACTOR_TYPE: loaded_factors: FACTOR_TYPE = {} if factor_config.requires_covariance_matrices_for_precondition: @@ -454,6 +468,62 @@ def _load_all_required_factors(self, factors_name: str, strategy: str, factor_co loaded_factors.update(lambda_factors) return loaded_factors + @torch.no_grad() + def _aggregate_factors( + self, + factors_name: str, + data_partition_size: int, + module_partition_size: int, + exists_fnc: Callable, + load_fnc: Callable, + save_fnc: Callable, + ) -> Optional[FACTOR_TYPE]: + """Aggregates all factors computed for all data and module partitions.""" + factors_output_dir = self.factors_output_dir(factors_name=factors_name) + if not factors_output_dir.exists(): + error_msg = ( + f"Factors output directory {factors_output_dir} is not found when trying to " + f"aggregate partitioned factors." 
+ ) + self.logger.error(error_msg) + raise FileNotFoundError(error_msg) + + all_required_partitions = [(i, j) for i in range(data_partition_size) for j in range(module_partition_size)] + all_partition_exists = all( + exists_fnc(output_dir=factors_output_dir, partition=partition) for partition in all_required_partitions + ) + if not all_partition_exists: + self.logger.warning("Factors are not aggregated as factors for some partitions are not yet computed.") + return + + start_time = get_time(state=self.state) + if self.state.is_main_process: + aggregated_factors: FACTOR_TYPE = {} + for data_partition in range(data_partition_size): + for module_partition in range(module_partition_size): + loaded_factors = load_fnc( + output_dir=factors_output_dir, + partition=(data_partition, module_partition), + ) + for factor_name, factors in loaded_factors.items(): + if factor_name not in aggregated_factors: + aggregated_factors[factor_name]: Dict[str, torch.Tensor] = {} + + for module_name in factors: + if module_name not in aggregated_factors[factor_name]: + aggregated_factors[factor_name][module_name] = factors[module_name] + else: + aggregated_factors[factor_name][module_name].add_(factors[module_name]) + del loaded_factors + save_fnc( + output_dir=factors_output_dir, + factors=aggregated_factors, + ) + self.state.wait_for_everyone() + end_time = get_time(state=self.state) + elapsed_time = end_time - start_time + self.logger.info(f"Aggregated all partitioned factors in {elapsed_time:.2f} seconds.") + @torch.no_grad() def _aggregate_scores( self, @@ -479,13 +549,11 @@ def _aggregate_scores( all_required_partitions = [ (i, j) for i in range(score_args.data_partition_size) for j in range(score_args.module_partition_size) ] - all_partition_exists = [ - exists_fnc(output_dir=scores_output_dir, partition=partition) for partition in all_required_partitions - ] + all_partition_exists = all( + [exists_fnc(output_dir=scores_output_dir, partition=partition) for partition in all_required_partitions] + ) if not all_partition_exists: - self.logger.info( - "Influence scores are not aggregated as scores for some partitions " "are not yet computed." 
- ) + self.logger.info("Influence scores are not aggregated as scores for some partitions are not yet computed.") return start_time = get_time(state=self.state) diff --git a/kronfluence/computer/covariance_computer.py b/kronfluence/computer/covariance_computer.py deleted file mode 100644 index 8eeda6b..0000000 --- a/kronfluence/computer/covariance_computer.py +++ /dev/null @@ -1,336 +0,0 @@ -import os -from typing import Any, Dict, List, Optional, Sequence - -import torch -from torch.utils import data - -from kronfluence.arguments import FactorArguments -from kronfluence.computer.computer import Computer -from kronfluence.factor.config import FactorConfig -from kronfluence.factor.covariance import ( - covariance_matrices_exist, - fit_covariance_matrices_with_loader, - load_covariance_matrices, - save_covariance_matrices, -) -from kronfluence.module.constants import FACTOR_TYPE -from kronfluence.utils.dataset import DataLoaderKwargs -from kronfluence.utils.logger import get_time -from kronfluence.utils.save import FACTOR_ARGUMENTS_NAME -from kronfluence.utils.state import release_memory - - -class CovarianceComputer(Computer): - """Handles the computation of all covariance matrices for a given PyTorch model.""" - - def _find_executable_covariance_factors_batch_size( - self, - total_data_examples: int, - dataset: data.Dataset, - dataloader_params: Dict[str, Any], - factor_args: FactorArguments, - tracked_module_names: Optional[List[str]], - ) -> int: - """Automatically finds executable batch size for computing covariance matrices.""" - if self.state.num_processes > 1: - error_msg = ( - "Automatic batch size search is currently not supported for multi-GPU training. " - "Please manually configure the batch size." - ) - self.logger.error(error_msg) - raise NotImplementedError(error_msg) - - kwargs = { - "model": self.model, - "state": self.state, - "task": self.task, - "factor_args": factor_args, - "tracked_module_names": tracked_module_names, - } - start_batch_size = min( - [ - factor_args.initial_per_device_batch_size_attempt, - total_data_examples, - ] - ) - return self._find_executable_factors_batch_size( - func=fit_covariance_matrices_with_loader, - func_kwargs=kwargs, - dataset=dataset, - dataloader_params=dataloader_params, - start_batch_size=start_batch_size, - ) - - def _fit_partitioned_covariance_matrices( - self, - dataset: data.Dataset, - per_device_batch_size: int, - dataloader_params: Dict[str, Any], - factor_args: FactorArguments, - indices: Optional[List[int]] = None, - tracked_module_names: Optional[List[str]] = None, - ) -> FACTOR_TYPE: - """Fits all covariance matrices for the given data and module partition.""" - release_memory() - start_time = get_time(state=self.state) - with self.profiler.profile("Fit Covariance"): - loader = self._get_dataloader( - dataset=dataset, - per_device_batch_size=per_device_batch_size, - dataloader_params=dataloader_params, - indices=indices, - ) - num_data_processed, covariance_factors = fit_covariance_matrices_with_loader( - model=self.model, - state=self.state, - task=self.task, - loader=loader, - factor_args=factor_args, - tracked_module_names=tracked_module_names, - ) - end_time = get_time(state=self.state) - elapsed_time = end_time - start_time - self.logger.info( - f"Fitted covariance matrices on {num_data_processed.item()} data points in " f"{elapsed_time:.2f} seconds." 
- ) - return covariance_factors - - def fit_covariance_matrices( - self, - factors_name: str, - dataset: data.Dataset, - per_device_batch_size: Optional[int] = None, - dataloader_kwargs: Optional[DataLoaderKwargs] = None, - factor_args: Optional[FactorArguments] = None, - target_data_partitions: Optional[Sequence[int]] = None, - target_module_partitions: Optional[Sequence[int]] = None, - overwrite_output_dir: bool = False, - ) -> Optional[FACTOR_TYPE]: - """Computes covariance matrices for all available modules. See `fit_all_factors` for - the complete docstring with detailed description of each parameter.""" - self.logger.debug(f"Fitting covariance matrices with parameters: {locals()}") - - factors_output_dir = self.factors_output_dir(factors_name=factors_name) - os.makedirs(factors_output_dir, exist_ok=True) - if covariance_matrices_exist(output_dir=factors_output_dir) and not overwrite_output_dir: - self.logger.info(f"Found existing covariance matrices at {factors_output_dir}. Skipping.") - return self.load_covariance_matrices(factors_name=factors_name) - - if factor_args is None: - factor_args = FactorArguments() - self.logger.info(f"Factor arguments not provided. Using the default configuration: {factor_args}.") - else: - self.logger.info(f"Using the provided configuration: {factor_args}.") - - if self.state.is_main_process: - self._save_arguments( - arguments_name=FACTOR_ARGUMENTS_NAME, - arguments=factor_args, - output_dir=factors_output_dir, - overwrite_output_dir=overwrite_output_dir, - ) - - if not FactorConfig.CONFIGS[factor_args.strategy].requires_covariance_matrices: - self.logger.info( - f"Strategy `{factor_args.strategy}` does not require fitting covariance matrices. " f"Skipping." - ) - return - - if self.state.is_main_process: - self._save_dataset_metadata( - dataset_name="covariance", - dataset=dataset, - output_dir=factors_output_dir, - overwrite_output_dir=overwrite_output_dir, - ) - - if dataloader_kwargs is None: - dataloader_kwargs = DataLoaderKwargs() - self.logger.info( - f"DataLoader arguments not provided. Using the default configuration: {dataloader_kwargs}." - ) - else: - self.logger.info(f"Using the DataLoader parameters: {dataloader_kwargs.to_dict()}.") - dataloader_params = dataloader_kwargs.to_dict() - - total_data_examples = min([factor_args.covariance_max_examples, len(dataset)]) - self.logger.info(f"Total data examples to fit covariance matrices: {total_data_examples}.") - - no_partition = ( - factor_args.covariance_data_partition_size == 1 and factor_args.covariance_module_partition_size == 1 - ) - partition_provided = target_data_partitions is not None or target_module_partitions is not None - if no_partition and partition_provided: - error_msg = ( - "`target_data_partitions` or `target_module_partitions` were specified, while" - "the `FactorArguments` did not expect any partitions when computing covariance matrices." - ) - self.logger.error(error_msg) - raise ValueError(error_msg) - - if no_partition: - if total_data_examples < self.state.num_processes: - error_msg = "The number of processes are more than the data examples." 
- self.logger.error(error_msg) - raise ValueError(error_msg) - if per_device_batch_size is None: - per_device_batch_size = self._find_executable_covariance_factors_batch_size( - dataloader_params=dataloader_params, - dataset=dataset, - total_data_examples=total_data_examples, - factor_args=factor_args, - tracked_module_names=None, - ) - covariance_factors = self._fit_partitioned_covariance_matrices( - dataset=dataset, - per_device_batch_size=per_device_batch_size, - dataloader_params=dataloader_params, - factor_args=factor_args, - indices=list(range(total_data_examples)), - tracked_module_names=None, - ) - with self.profiler.profile("Save Covariance"): - if self.state.is_main_process: - save_covariance_matrices( - output_dir=factors_output_dir, - covariance_factors=covariance_factors, - ) - self.state.wait_for_everyone() - self.logger.info(f"Saved covariance matrices at {factors_output_dir}.") - - profile_summary = self.profiler.summary() - if profile_summary != "": - self.logger.info(self.profiler.summary()) - return covariance_factors - - data_partition_indices, target_data_partitions = self._get_data_partition( - total_data_examples=total_data_examples, - data_partition_size=factor_args.covariance_data_partition_size, - target_data_partitions=target_data_partitions, - ) - module_partition_names, target_module_partitions = self._get_module_partition( - module_partition_size=factor_args.covariance_module_partition_size, - target_module_partitions=target_module_partitions, - ) - - all_start_time = get_time(state=self.state) - for data_partition in target_data_partitions: - for module_partition in target_module_partitions: - if ( - covariance_matrices_exist( - output_dir=factors_output_dir, - partition=(data_partition, module_partition), - ) - and not overwrite_output_dir - ): - self.logger.info( - f"Found existing covariance matrices for data partition {data_partition} " - f"and module partition {module_partition} at {factors_output_dir}. Skipping." - ) - continue - - start_index, end_index = data_partition_indices[data_partition] - self.logger.info( - f"Fitting covariance matrices for data partition with data indices ({start_index}, " - f"{end_index}) and modules {module_partition_names[module_partition]}." - ) - - max_total_examples = total_data_examples // factor_args.covariance_data_partition_size - if max_total_examples < self.state.num_processes: - error_msg = "The number of processes are more than the data examples." 
- self.logger.error(error_msg) - raise ValueError(error_msg) - if per_device_batch_size is None: - per_device_batch_size = self._find_executable_covariance_factors_batch_size( - dataloader_params=dataloader_params, - dataset=dataset, - factor_args=factor_args, - total_data_examples=max_total_examples, - tracked_module_names=module_partition_names[0], - ) - covariance_factors = self._fit_partitioned_covariance_matrices( - dataset=dataset, - per_device_batch_size=per_device_batch_size, - dataloader_params=dataloader_params, - factor_args=factor_args, - indices=list(range(start_index, end_index)), - tracked_module_names=module_partition_names[module_partition], - ) - with self.profiler.profile("Save Covariance"): - if self.state.is_main_process: - save_covariance_matrices( - output_dir=factors_output_dir, - covariance_factors=covariance_factors, - partition=(data_partition, module_partition), - ) - self.state.wait_for_everyone() - del covariance_factors - self.logger.info(f"Saved partitioned covariance matrices at {factors_output_dir}.") - - all_end_time = get_time(state=self.state) - elapsed_time = all_end_time - all_start_time - self.logger.info(f"Fitted all partitioned covariance matrices in {elapsed_time:.2f} seconds.") - aggregated_covariance_factors = self.aggregate_covariance_matrices( - factors_name=factors_name, factor_args=factor_args - ) - - profile_summary = self.profiler.summary() - if profile_summary != "": - self.logger.info(self.profiler.summary()) - return aggregated_covariance_factors - - @torch.no_grad() - def aggregate_covariance_matrices( - self, - factors_name: str, - factor_args: FactorArguments, - ) -> Optional[FACTOR_TYPE]: - """Aggregates covariance matrices computed for all data and module partitions.""" - factors_output_dir = self.factors_output_dir(factors_name=factors_name) - if not factors_output_dir.exists(): - error_msg = ( - f"Factors output directory {factors_output_dir} is not found " - f"when trying to aggregate partitioned covariance matrices." - ) - self.logger.error(error_msg) - raise FileNotFoundError(error_msg) - - data_partition_size = factor_args.covariance_data_partition_size - module_partition_size = factor_args.covariance_module_partition_size - all_required_partitions = [(i, j) for i in range(data_partition_size) for j in range(module_partition_size)] - all_partition_exists = all( - covariance_matrices_exist(output_dir=factors_output_dir, partition=partition) - for partition in all_required_partitions - ) - if not all_partition_exists: - self.logger.info( - "Covariance matrices are not aggregated as covariance matrices for some partitions " - "are not yet computed." 
- ) - return - - start_time = get_time(state=self.state) - with self.profiler.profile("Aggregate Covariance"): - if self.state.is_main_process: - aggregated_covariance_factors: FACTOR_TYPE = {} - for data_partition in range(data_partition_size): - for module_partition in range(module_partition_size): - loaded_covariance_factors = load_covariance_matrices( - output_dir=factors_output_dir, - partition=(data_partition, module_partition), - ) - aggregated_covariance_factors = self._aggregate_factors( - aggregated_factors=aggregated_covariance_factors, - loaded_factors=loaded_covariance_factors, - ) - del loaded_covariance_factors - with self.profiler.profile("Save Covariance"): - save_covariance_matrices( - output_dir=factors_output_dir, - covariance_factors=aggregated_covariance_factors, - ) - self.state.wait_for_everyone() - end_time = get_time(state=self.state) - elapsed_time = end_time - start_time - self.logger.info(f"Aggregated all partitioned covariance matrices in {elapsed_time:.2f} seconds.") - return aggregated_covariance_factors diff --git a/kronfluence/computer/eigen_computer.py b/kronfluence/computer/eigen_computer.py deleted file mode 100644 index ce6505f..0000000 --- a/kronfluence/computer/eigen_computer.py +++ /dev/null @@ -1,454 +0,0 @@ -import os -import time -from typing import Any, Dict, List, Optional, Sequence - -import torch -from torch.utils import data - -from kronfluence.arguments import FactorArguments -from kronfluence.computer.computer import Computer -from kronfluence.factor.config import FactorConfig -from kronfluence.factor.covariance import ( - covariance_matrices_exist, - load_covariance_matrices, -) -from kronfluence.factor.eigen import ( - eigendecomposition_exist, - fit_lambda_matrices_with_loader, - lambda_matrices_exist, - load_eigendecomposition, - load_lambda_matrices, - perform_eigendecomposition, - save_eigendecomposition, - save_lambda_matrices, -) -from kronfluence.module.constants import FACTOR_TYPE -from kronfluence.utils.dataset import DataLoaderKwargs -from kronfluence.utils.exceptions import FactorsNotFoundError -from kronfluence.utils.logger import get_time -from kronfluence.utils.save import FACTOR_ARGUMENTS_NAME -from kronfluence.utils.state import release_memory - - -class EigenComputer(Computer): - """Handles the computation of Eigendecomposition and Lambda matrices for a given PyTorch model.""" - - def perform_eigendecomposition( - self, - factors_name: str, - factor_args: Optional[FactorArguments] = None, - overwrite_output_dir: bool = False, - load_from_factors_name: Optional[str] = None, - ) -> Optional[FACTOR_TYPE]: - """Performs Eigendecomposition for all available covariance matrices. See `fit_all_factors` for - the complete docstring with detailed description of each parameter.""" - self.logger.debug(f"Performing Eigendecomposition with parameters: {locals()}") - - factors_output_dir = self.factors_output_dir(factors_name=factors_name) - os.makedirs(factors_output_dir, exist_ok=True) - if eigendecomposition_exist(output_dir=factors_output_dir) and not overwrite_output_dir: - self.logger.info(f"Found existing Eigendecomposition results at {factors_output_dir}. Skipping.") - return self.load_eigendecomposition(factors_name=factors_name) - - if factor_args is None: - factor_args = FactorArguments() - self.logger.info(f"Factor arguments not provided. 
Using the default configuration: {factor_args}.") - else: - self.logger.info(f"Using the provided configuration: {factor_args}.") - - if self.state.is_main_process: - self._save_arguments( - arguments_name=FACTOR_ARGUMENTS_NAME, - arguments=factor_args, - output_dir=factors_output_dir, - overwrite_output_dir=overwrite_output_dir, - ) - - if not FactorConfig.CONFIGS[factor_args.strategy].requires_eigendecomposition: - self.logger.info( - f"Strategy `{factor_args.strategy}` does not require performing Eigendecomposition. Skipping." - ) - return None - - if load_from_factors_name is not None: - self.logger.info(f"Loading covariance matrices from factors with name `{load_from_factors_name}`.") - load_factors_output_dir = self.factors_output_dir(factors_name=load_from_factors_name) - else: - load_factors_output_dir = factors_output_dir - - if not covariance_matrices_exist(output_dir=load_factors_output_dir): - error_msg = ( - f"Aggregated covariance matrices not found at {load_factors_output_dir}. " - f"To perform Eigendecomposition, covariance matrices need to be first computed." - ) - self.logger.error(error_msg) - raise FactorsNotFoundError(error_msg) - - with self.profiler.profile("Load Covariance"): - covariance_factors = load_covariance_matrices(output_dir=load_factors_output_dir) - - eigen_factors = None - if self.state.is_main_process: - release_memory() - start_time = time.time() - with self.profiler.profile("Perform Eigendecomposition"): - eigen_factors = perform_eigendecomposition( - covariance_factors=covariance_factors, - model=self.model, - state=self.state, - factor_args=factor_args, - ) - end_time = time.time() - elapsed_time = end_time - start_time - self.logger.info(f"Performed Eigendecomposition in {elapsed_time:.2f} seconds.") - with self.profiler.profile("Save Eigendecomposition"): - save_eigendecomposition( - output_dir=factors_output_dir, - eigen_factors=eigen_factors, - ) - self.logger.info(f"Saved Eigendecomposition results at {factors_output_dir}.") - self.state.wait_for_everyone() - - profile_summary = self.profiler.summary() - if profile_summary != "": - self.logger.info(self.profiler.summary()) - return eigen_factors - - def _find_executable_lambda_factors_batch_size( - self, - eigen_factors: FACTOR_TYPE, - total_data_examples: int, - dataset: data.Dataset, - dataloader_params: Dict[str, Any], - factor_args: FactorArguments, - tracked_module_names: Optional[List[str]], - ) -> int: - """Automatically finds executable batch size for computing Lambda matrices.""" - if self.state.num_processes > 1: - error_msg = ( - "Automatic batch size search is currently not supported for multi-GPU training. " - "Please manually configure the batch size." 
- ) - self.logger.error(error_msg) - raise NotImplementedError(error_msg) - - kwargs = { - "eigen_factors": eigen_factors, - "model": self.model, - "state": self.state, - "task": self.task, - "factor_args": factor_args, - "tracked_module_names": tracked_module_names, - } - start_batch_size = min( - [ - factor_args.initial_per_device_batch_size_attempt, - total_data_examples, - ] - ) - return self._find_executable_factors_batch_size( - func=fit_lambda_matrices_with_loader, - func_kwargs=kwargs, - dataset=dataset, - dataloader_params=dataloader_params, - start_batch_size=start_batch_size, - ) - - def _fit_partitioned_lambda_matrices( - self, - eigen_factors: Optional[FACTOR_TYPE], - dataset: data.Dataset, - per_device_batch_size: int, - dataloader_params: Dict[str, Any], - factor_args: FactorArguments, - indices: Optional[List[int]] = None, - tracked_module_names: Optional[List[str]] = None, - ) -> FACTOR_TYPE: - """Fits all Lambda matrices for the given data and module partition.""" - release_memory() - start_time = get_time(state=self.state) - with self.profiler.profile("Fit Lambda"): - loader = self._get_dataloader( - dataset=dataset, - per_device_batch_size=per_device_batch_size, - indices=indices, - dataloader_params=dataloader_params, - ) - num_data_processed, lambda_factors = fit_lambda_matrices_with_loader( - model=self.model, - eigen_factors=eigen_factors, - state=self.state, - task=self.task, - loader=loader, - factor_args=factor_args, - tracked_module_names=tracked_module_names, - ) - end_time = get_time(state=self.state) - elapsed_time = end_time - start_time - self.logger.info( - f"Fitted Lambda matrices on {num_data_processed.item()} data points in " f"{elapsed_time:.2f} seconds." - ) - return lambda_factors - - def fit_lambda_matrices( - self, - factors_name: str, - dataset: data.Dataset, - per_device_batch_size: Optional[int] = None, - dataloader_kwargs: Optional[DataLoaderKwargs] = None, - factor_args: Optional[FactorArguments] = None, - target_data_partitions: Optional[Sequence[int]] = None, - target_module_partitions: Optional[Sequence[int]] = None, - overwrite_output_dir: bool = False, - load_from_factors_name: Optional[str] = None, - ) -> Optional[FACTOR_TYPE]: - """Computes Lambda matrices for all `TrackedModule`. See `fit_all_factors` for - the complete docstring with detailed description of each parameter.""" - self.logger.debug(f"Fitting Lambda matrices with parameters: {locals()}") - - factors_output_dir = self.factors_output_dir(factors_name=factors_name) - os.makedirs(factors_output_dir, exist_ok=True) - if lambda_matrices_exist(output_dir=factors_output_dir) and not overwrite_output_dir: - self.logger.info(f"Found existing Lambda matrices at {factors_output_dir}. Skipping.") - return self.load_lambda_matrices(factors_name=factors_name) - - if factor_args is None: - factor_args = FactorArguments() - self.logger.info(f"Factor arguments not provided. Using the default configuration: {factor_args}.") - else: - self.logger.info(f"Using the provided configuration: {factor_args}.") - - if self.state.is_main_process: - self._save_arguments( - arguments_name=FACTOR_ARGUMENTS_NAME, - arguments=factor_args, - output_dir=factors_output_dir, - overwrite_output_dir=overwrite_output_dir, - ) - - if not FactorConfig.CONFIGS[factor_args.strategy].requires_lambda_matrices: - self.logger.info( - f"Strategy `{factor_args.strategy}` does not require fitting Lambda matrices. " f"Skipping." 
- ) - return None - - if self.state.is_main_process: - self._save_dataset_metadata( - dataset_name="lambda", - dataset=dataset, - output_dir=factors_output_dir, - overwrite_output_dir=overwrite_output_dir, - ) - - if load_from_factors_name is not None: - self.logger.info( - f"Will be loading Eigendecomposition results from factors with name `{load_from_factors_name}`." - ) - load_factors_output_dir = self.factors_output_dir(factors_name=load_from_factors_name) - else: - load_factors_output_dir = factors_output_dir - - if ( - not eigendecomposition_exist(output_dir=load_factors_output_dir) - and FactorConfig.CONFIGS[factor_args.strategy].requires_eigendecomposition_for_lambda - ): - error_msg = ( - f"Eigendecomposition results not found at {load_factors_output_dir}. " - f"To fit Lambda matrices for {factor_args.strategy}, Eigendecomposition must be " - f"performed before computing Lambda matrices." - ) - self.logger.error(error_msg) - raise FactorsNotFoundError(error_msg) - - if dataloader_kwargs is None: - dataloader_kwargs = DataLoaderKwargs() - self.logger.info( - f"DataLoader arguments not provided. Using the default configuration: {dataloader_kwargs}." - ) - else: - self.logger.info(f"Using the DataLoader parameters: {dataloader_kwargs.to_dict()}.") - dataloader_params = dataloader_kwargs.to_dict() - - eigen_factors = None - if FactorConfig.CONFIGS[factor_args.strategy].requires_eigendecomposition_for_lambda: - with self.profiler.profile("Load Eigendecomposition"): - eigen_factors = load_eigendecomposition(output_dir=load_factors_output_dir) - - total_data_examples = min([factor_args.lambda_max_examples, len(dataset)]) - self.logger.info(f"Total data examples to fit Lambda matrices: {total_data_examples}.") - - no_partition = factor_args.lambda_data_partition_size == 1 and factor_args.lambda_module_partition_size == 1 - partition_provided = target_data_partitions is not None or target_module_partitions is not None - if no_partition and partition_provided: - error_msg = ( - "`target_data_partitions` or `target_module_partitions` were specified, while" - "the `FactorArguments` did not expect any partitions for computing Lambda matrices." - ) - self.logger.error(error_msg) - raise ValueError(error_msg) - - if no_partition: - if total_data_examples < self.state.num_processes: - error_msg = "The number of processes are more than the data examples." 
- self.logger.error(error_msg) - raise ValueError(error_msg) - if per_device_batch_size is None: - per_device_batch_size = self._find_executable_lambda_factors_batch_size( - eigen_factors=eigen_factors, - dataloader_params=dataloader_params, - dataset=dataset, - total_data_examples=total_data_examples, - factor_args=factor_args, - tracked_module_names=None, - ) - lambda_factors = self._fit_partitioned_lambda_matrices( - eigen_factors=eigen_factors, - dataset=dataset, - per_device_batch_size=per_device_batch_size, - dataloader_params=dataloader_params, - factor_args=factor_args, - indices=list(range(total_data_examples)), - tracked_module_names=None, - ) - with self.profiler.profile("Save Lambda"): - if self.state.is_main_process: - save_lambda_matrices(output_dir=factors_output_dir, lambda_factors=lambda_factors) - self.state.wait_for_everyone() - self.logger.info(f"Saved Lambda matrices at {factors_output_dir}.") - - profile_summary = self.profiler.summary() - if profile_summary != "": - self.logger.info(self.profiler.summary()) - return lambda_factors - - data_partition_indices, target_data_partitions = self._get_data_partition( - total_data_examples=total_data_examples, - data_partition_size=factor_args.lambda_data_partition_size, - target_data_partitions=target_data_partitions, - ) - module_partition_names, target_module_partitions = self._get_module_partition( - module_partition_size=factor_args.lambda_module_partition_size, - target_module_partitions=target_module_partitions, - ) - - all_start_time = get_time(state=self.state) - for data_partition in target_data_partitions: - for module_partition in target_module_partitions: - if ( - lambda_matrices_exist( - output_dir=factors_output_dir, - partition=(data_partition, module_partition), - ) - and not overwrite_output_dir - ): - self.logger.info( - f"Found existing Lambda matrices for data partition {data_partition} " - f"and module partition {module_partition} at {factors_output_dir}. Skipping." - ) - continue - - start_index, end_index = data_partition_indices[data_partition] - self.logger.info( - f"Fitting Lambda matrices for data partition with data indices ({start_index}, " - f"{end_index}) and modules {module_partition_names[module_partition]}." - ) - - max_total_examples = total_data_examples // factor_args.lambda_data_partition_size - if max_total_examples < self.state.num_processes: - error_msg = "The number of processes are more than the data examples." 
- self.logger.error(error_msg) - raise ValueError(error_msg) - if per_device_batch_size is None: - per_device_batch_size = self._find_executable_lambda_factors_batch_size( - eigen_factors=eigen_factors, - dataloader_params=dataloader_params, - dataset=dataset, - factor_args=factor_args, - total_data_examples=max_total_examples, - tracked_module_names=module_partition_names[0], - ) - lambda_factors = self._fit_partitioned_lambda_matrices( - eigen_factors=eigen_factors, - dataset=dataset, - per_device_batch_size=per_device_batch_size, - dataloader_params=dataloader_params, - factor_args=factor_args, - indices=list(range(start_index, end_index)), - tracked_module_names=module_partition_names[module_partition], - ) - with self.profiler.profile("Save Lambda"): - if self.state.is_main_process: - save_lambda_matrices( - output_dir=factors_output_dir, - lambda_factors=lambda_factors, - partition=(data_partition, module_partition), - ) - self.state.wait_for_everyone() - del lambda_factors - self.logger.info(f"Saved partitioned Lambda matrices at {factors_output_dir}.") - - all_end_time = get_time(state=self.state) - elapsed_time = all_end_time - all_start_time - self.logger.info(f"Fitted all partitioned Lambda matrices in {elapsed_time:.2f} seconds.") - aggregated_lambda_factors = self.aggregate_lambda_matrices(factors_name=factors_name, factor_args=factor_args) - - profile_summary = self.profiler.summary() - if profile_summary != "": - self.logger.info(self.profiler.summary()) - return aggregated_lambda_factors - - @torch.no_grad() - def aggregate_lambda_matrices( - self, - factors_name: str, - factor_args: FactorArguments, - ) -> Optional[FACTOR_TYPE]: - """Aggregates Lambda matrices computed for all data and module partitions.""" - factors_output_dir = self.factors_output_dir(factors_name=factors_name) - - if not factors_output_dir.exists(): - error_msg = ( - f"Factors output directory {factors_output_dir} is not found " - f"when trying to aggregate partitioned Lambda matrices." - ) - self.logger.error(error_msg) - raise FileNotFoundError(error_msg) - - data_partition_size = factor_args.lambda_data_partition_size - module_partition_size = factor_args.lambda_module_partition_size - all_required_partitions = [(i, j) for i in range(data_partition_size) for j in range(module_partition_size)] - all_partition_exists = [ - lambda_matrices_exist(output_dir=factors_output_dir, partition=partition) - for partition in all_required_partitions - ] - if not all_partition_exists: - self.logger.info( - "Lambda matrices are not aggregated as Lambda matrices for some partitions are not yet computed." 
- ) - return - - start_time = get_time(state=self.state) - with self.profiler.profile("Aggregate Lambda"): - if self.state.is_main_process: - aggregated_lambda_factors: FACTOR_TYPE = {} - for data_partition in range(data_partition_size): - for module_partition in range(module_partition_size): - loaded_lambda_factors = load_lambda_matrices( - output_dir=factors_output_dir, - partition=(data_partition, module_partition), - ) - aggregated_lambda_factors = self._aggregate_factors( - aggregated_factors=aggregated_lambda_factors, - loaded_factors=loaded_lambda_factors, - ) - del loaded_lambda_factors - with self.profiler.profile("Save Lambda"): - save_lambda_matrices( - output_dir=factors_output_dir, - lambda_factors=aggregated_lambda_factors, - ) - self.state.wait_for_everyone() - end_time = get_time(state=self.state) - elapsed_time = end_time - start_time - self.logger.info(f"Aggregated all partitioned Lambda matrices in {elapsed_time:.2f} seconds.") - return aggregated_lambda_factors diff --git a/kronfluence/computer/factor_computer.py b/kronfluence/computer/factor_computer.py new file mode 100644 index 0000000..da16436 --- /dev/null +++ b/kronfluence/computer/factor_computer.py @@ -0,0 +1,653 @@ +import os +import time +from pathlib import Path +from typing import Any, Callable, Dict, Optional, Sequence + +import torch +from torch.utils import data + +from kronfluence.arguments import FactorArguments +from kronfluence.computer.computer import Computer +from kronfluence.factor.config import FactorConfig +from kronfluence.factor.covariance import ( + covariance_matrices_exist, + fit_covariance_matrices_with_loader, + load_covariance_matrices, + save_covariance_matrices, +) +from kronfluence.factor.eigen import ( + eigendecomposition_exist, + fit_lambda_matrices_with_loader, + lambda_matrices_exist, + load_eigendecomposition, + load_lambda_matrices, + perform_eigendecomposition, + save_eigendecomposition, + save_lambda_matrices, +) +from kronfluence.module.tracked_module import ModuleMode +from kronfluence.module.utils import set_mode +from kronfluence.utils.dataset import DataLoaderKwargs, find_executable_batch_size +from kronfluence.utils.exceptions import FactorsNotFoundError +from kronfluence.utils.logger import get_time +from kronfluence.utils.save import FACTOR_ARGUMENTS_NAME +from kronfluence.utils.state import release_memory + + +class FactorComputer(Computer): + """Handles the computation of all factors for a given PyTorch model.""" + + def _configure_and_save_factor_args( + self, factor_args: Optional[FactorArguments], factors_output_dir: Path, overwrite_output_dir: bool + ) -> FactorArguments: + """Configure the provided factor arguments and save it in disk.""" + if factor_args is None: + factor_args = FactorArguments() + self.logger.info(f"Factor arguments not provided. 
Using the default configuration: {factor_args}.") + else: + self.logger.info(f"Using the provided configuration: {factor_args}.") + + if self.state.is_main_process: + self._save_arguments( + arguments_name=FACTOR_ARGUMENTS_NAME, + arguments=factor_args, + output_dir=factors_output_dir, + overwrite_output_dir=overwrite_output_dir, + ) + self.state.wait_for_everyone() + return factor_args + + def _find_executable_factors_batch_size( + self, + func: Callable, + func_kwargs: Dict[str, Any], + factor_args: FactorArguments, + dataset: data.Dataset, + dataloader_params: Dict[str, Any], + total_data_examples: Optional[int] = None, + ) -> int: + """Automatically finds executable batch size for performing `func`.""" + if self.state.use_distributed: + error_msg = ( + "Automatic batch size search is currently not supported for multi-GPU training. " + "Please manually configure the batch size by passing in `per_device_batch_size`." + ) + self.logger.error(error_msg) + raise NotImplementedError(error_msg) + + self.logger.info("Automatically determining executable batch size.") + if total_data_examples is None: + total_data_examples = len(dataset) + start_batch_size = min( + [ + factor_args.initial_per_device_batch_size_attempt, + total_data_examples, + ] + ) + + def executable_batch_size_func(batch_size: int) -> None: + self.logger.info(f"Attempting to set per-device batch size to {batch_size}.") + # Release all memory that could be caused by the previous OOM. + set_mode(model=self.model, mode=ModuleMode.DEFAULT, keep_factors=False) + self.model.zero_grad(set_to_none=True) + release_memory() + total_batch_size = batch_size * self.state.num_processes + loader = self._get_dataloader( + dataset=dataset, + per_device_batch_size=batch_size, + indices=list(range(total_batch_size)), + dataloader_params=dataloader_params, + allow_duplicates=True, + ) + func(loader=loader, **func_kwargs) + + per_device_batch_size = find_executable_batch_size( + func=executable_batch_size_func, + start_batch_size=start_batch_size, + ) + self.logger.info(f"Executable batch size determined: {per_device_batch_size}.") + return per_device_batch_size + + def fit_covariance_matrices( + self, + factors_name: str, + dataset: data.Dataset, + per_device_batch_size: Optional[int] = None, + dataloader_kwargs: Optional[DataLoaderKwargs] = None, + factor_args: Optional[FactorArguments] = None, + target_data_partitions: Optional[Sequence[int]] = None, + target_module_partitions: Optional[Sequence[int]] = None, + overwrite_output_dir: bool = False, + ) -> None: + """Computes activation and pseudo-covariance matrices with the given dataset. + + Args: + factors_name (str): + The unique identifier for the factor, used to organize and retrieve the results. + dataset (data.Dataset): + The dataset that will be used to fit covariance matrices. + per_device_batch_size (int, optional): + The per-device batch size used to fit the factors. If not specified, executable + batch size is automatically determined. + dataloader_kwargs (DataLoaderKwargs, optional): + Controls additional arguments for PyTorch's DataLoader. + factor_args (FactorArguments, optional): + Arguments related to computing the factors. If not specified, the default values of + `FactorArguments` will be used. + target_data_partitions(Sequence[int], optional): + The list of data partition to fit covariance matrices. By default, covariance + matrices will be computed for all partitions. 
+ target_module_partitions(Sequence[int], optional): + The list of module partition to fit covariance matrices. By default, covariance + matrices will be computed for all partitions. + overwrite_output_dir (bool, optional): + If True, the existing factors with the same `factors_name` will be overwritten. + """ + self.logger.debug(f"Fitting covariance matrices with parameters: {locals()}") + + factors_output_dir = self.factors_output_dir(factors_name=factors_name) + os.makedirs(factors_output_dir, exist_ok=True) + if covariance_matrices_exist(output_dir=factors_output_dir) and not overwrite_output_dir: + self.logger.info(f"Found existing covariance matrices at `{factors_output_dir}`. Skipping.") + return + + factor_args = self._configure_and_save_factor_args( + factor_args=factor_args, factors_output_dir=factors_output_dir, overwrite_output_dir=overwrite_output_dir + ) + + if not FactorConfig.CONFIGS[factor_args.strategy].requires_covariance_matrices: + self.logger.info( + f"Strategy `{factor_args.strategy}` does not require fitting covariance matrices. Skipping." + ) + return + + dataloader_params = self._configure_dataloader(dataloader_kwargs) + if self.state.is_main_process: + self._save_dataset_metadata( + dataset_name="covariance", + dataset=dataset, + output_dir=factors_output_dir, + overwrite_output_dir=overwrite_output_dir, + ) + + if factor_args.covariance_max_examples is None: + total_data_examples = len(dataset) + else: + total_data_examples = min([factor_args.covariance_max_examples, len(dataset)]) + self.logger.info(f"Total data examples to fit covariance matrices: {total_data_examples}.") + + no_partition = ( + factor_args.covariance_data_partition_size == 1 and factor_args.covariance_module_partition_size == 1 + ) + partition_provided = target_data_partitions is not None or target_module_partitions is not None + if no_partition and partition_provided: + error_msg = ( + "`target_data_partitions` or `target_module_partitions` were specified, while" + "the `FactorArguments` did not expect any data and module partition to compute covariance matrices." + ) + self.logger.error(error_msg) + raise ValueError(error_msg) + + data_partition_indices, target_data_partitions = self._get_data_partition( + total_data_examples=total_data_examples, + data_partition_size=factor_args.covariance_data_partition_size, + target_data_partitions=target_data_partitions, + ) + max_partition_examples = total_data_examples // factor_args.covariance_data_partition_size + module_partition_names, target_module_partitions = self._get_module_partition( + module_partition_size=factor_args.covariance_module_partition_size, + target_module_partitions=target_module_partitions, + ) + + if max_partition_examples < self.state.num_processes: + error_msg = "The number of processes are more than the data examples. Try reducing the number of processes." + self.logger.error(error_msg) + raise ValueError(error_msg) + + all_start_time = get_time(state=self.state) + for data_partition in target_data_partitions: + for module_partition in target_module_partitions: + if no_partition: + partition = None + else: + partition = (data_partition, module_partition) + + if ( + covariance_matrices_exist( + output_dir=factors_output_dir, + partition=partition, + ) + and not overwrite_output_dir + ): + self.logger.info( + f"Found existing covariance matrices for data partition {data_partition} " + f"and module partition {module_partition} at {factors_output_dir}. Skipping." 
+ ) + continue + + start_index, end_index = data_partition_indices[data_partition] + self.logger.info( + f"Fitting covariance matrices with data indices ({start_index}, {end_index}) and " + f"modules {module_partition_names[module_partition]}." + ) + + if per_device_batch_size is None: + kwargs = { + "model": self.model, + "state": self.state, + "task": self.task, + "factor_args": factor_args, + "tracked_module_names": module_partition_names[module_partition], + } + per_device_batch_size = self._find_executable_factors_batch_size( + func=fit_covariance_matrices_with_loader, + func_kwargs=kwargs, + dataset=dataset, + factor_args=factor_args, + dataloader_params=dataloader_params, + total_data_examples=max_partition_examples, + ) + + release_memory() + start_time = get_time(state=self.state) + with self.profiler.profile("Fit Covariance"): + loader = self._get_dataloader( + dataset=dataset, + per_device_batch_size=per_device_batch_size, + dataloader_params=dataloader_params, + indices=list(range(start_index, end_index)), + allow_duplicates=False, + ) + num_data_processed, covariance_factors = fit_covariance_matrices_with_loader( + model=self.model, + state=self.state, + task=self.task, + loader=loader, + factor_args=factor_args, + tracked_module_names=module_partition_names[module_partition], + ) + end_time = get_time(state=self.state) + elapsed_time = end_time - start_time + self.logger.info( + f"Fitted covariance matrices with {num_data_processed.item()} data points in " + f"{elapsed_time:.2f} seconds." + ) + + with self.profiler.profile("Save Covariance"): + if self.state.is_main_process: + save_covariance_matrices( + output_dir=factors_output_dir, + factors=covariance_factors, + partition=partition, + metadata=factor_args.to_str_dict(), + ) + self.state.wait_for_everyone() + del covariance_factors, loader + self.logger.info(f"Saved covariance matrices at `{factors_output_dir}`.") + + all_end_time = get_time(state=self.state) + elapsed_time = all_end_time - all_start_time + if not no_partition: + self.logger.info(f"Fitted all partitioned covariance matrices in {elapsed_time:.2f} seconds.") + self.aggregate_covariance_matrices(factors_name=factors_name) + self.logger.info(f"Saved aggregated covariance matrices at `{factors_output_dir}`.") + self._log_profile_summary() + + @torch.no_grad() + def aggregate_covariance_matrices( + self, + factors_name: str, + ) -> None: + """Aggregates all partitioned covariance matrices. The factors will not be aggregated if covariance matrices + for some data or module partitions are missing. + + Args: + factors_name (str): + The unique identifier for the factor, used to organize and retrieve the results. + """ + factor_args, _ = self._load_and_configure_factor_args(factors_name=factors_name) + with self.profiler.profile("Aggregate Covariance"): + self._aggregate_factors( + factors_name=factors_name, + data_partition_size=factor_args.covariance_data_partition_size, + module_partition_size=factor_args.covariance_module_partition_size, + exists_fnc=covariance_matrices_exist, + load_fnc=load_covariance_matrices, + save_fnc=save_covariance_matrices, + ) + + def perform_eigendecomposition( + self, + factors_name: str, + factor_args: Optional[FactorArguments] = None, + overwrite_output_dir: bool = False, + load_from_factors_name: Optional[str] = None, + ) -> None: + """Performs Eigendecomposition on all covariance matrices. + + Args: + factors_name (str): + The unique identifier for the factor, used to organize and retrieve the results. 
+ factor_args (FactorArguments, optional): + Arguments related to computing the factors. If not specified, the default values of + `FactorArguments` will be used. + overwrite_output_dir (bool, optional): + If True, the existing factors with the same `factors_name` will be overwritten. + load_from_factors_name (str, optional): + The `factor_name` to load covariance matrices from. By default, covariance matrices with + the same `factor_name` will be used. + """ + self.logger.debug(f"Performing Eigendecomposition with parameters: {locals()}") + + factors_output_dir = self.factors_output_dir(factors_name=factors_name) + os.makedirs(factors_output_dir, exist_ok=True) + if eigendecomposition_exist(output_dir=factors_output_dir) and not overwrite_output_dir: + self.logger.info(f"Found existing Eigendecomposition results at `{factors_output_dir}`. Skipping.") + return + + factor_args = self._configure_and_save_factor_args( + factor_args=factor_args, factors_output_dir=factors_output_dir, overwrite_output_dir=overwrite_output_dir + ) + + if not FactorConfig.CONFIGS[factor_args.strategy].requires_eigendecomposition: + self.logger.info( + f"Strategy `{factor_args.strategy}` does not require performing Eigendecomposition. Skipping." + ) + return None + + load_factors_output_dir = factors_output_dir + if load_from_factors_name is not None: + self.logger.info(f"Will be loading covariance matrices from factors with name `{load_from_factors_name}`.") + load_factors_output_dir = self.factors_output_dir(factors_name=load_from_factors_name) + + if not covariance_matrices_exist(output_dir=load_factors_output_dir): + error_msg = ( + f"Covariance matrices not found at `{load_factors_output_dir}`. " + f"To perform Eigendecomposition, covariance matrices need to be first computed." + ) + self.logger.error(error_msg) + raise FactorsNotFoundError(error_msg) + + with self.profiler.profile("Load Covariance"): + covariance_factors = load_covariance_matrices(output_dir=load_factors_output_dir) + + if load_from_factors_name is not None and self.state.is_main_process: + # Save the loaded covariances to the current factor output directory. 
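+            # The arguments used to fit the loaded covariance matrices are saved alongside them under a "_loaded_covariance" suffix.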
+ with self.profiler.profile("Save Covariance"): + save_covariance_matrices(output_dir=factors_output_dir, factors=covariance_factors) + loaded_factor_args, _ = self._load_and_configure_factor_args(factors_name=load_from_factors_name) + self._save_arguments( + arguments_name=FACTOR_ARGUMENTS_NAME + "_loaded_covariance", + arguments=loaded_factor_args, + output_dir=factors_output_dir, + overwrite_output_dir=True, + ) + + eigen_factors = None + if self.state.is_main_process: + release_memory() + start_time = time.time() + with self.profiler.profile("Perform Eigendecomposition"): + eigen_factors = perform_eigendecomposition( + covariance_factors=covariance_factors, + model=self.model, + state=self.state, + factor_args=factor_args, + ) + end_time = time.time() + elapsed_time = end_time - start_time + self.logger.info(f"Performed Eigendecomposition in {elapsed_time:.2f} seconds.") + + with self.profiler.profile("Save Eigendecomposition"): + save_eigendecomposition( + output_dir=factors_output_dir, factors=eigen_factors, metadata=factor_args.to_str_dict() + ) + self.logger.info(f"Saved Eigendecomposition results at `{factors_output_dir}`.") + self.state.wait_for_everyone() + self._log_profile_summary() + + def fit_lambda_matrices( + self, + factors_name: str, + dataset: data.Dataset, + per_device_batch_size: Optional[int] = None, + dataloader_kwargs: Optional[DataLoaderKwargs] = None, + factor_args: Optional[FactorArguments] = None, + target_data_partitions: Optional[Sequence[int]] = None, + target_module_partitions: Optional[Sequence[int]] = None, + overwrite_output_dir: bool = False, + load_from_factors_name: Optional[str] = None, + ) -> None: + """Computes Lambda (corrected-eigenvalues) matrices with the given dataset. + + Args: + factors_name (str): + The unique identifier for the factor, used to organize and retrieve the results. + dataset (data.Dataset): + The dataset that will be used to fit Lambda matrices. + per_device_batch_size (int, optional): + The per-device batch size used to fit the factors. If not specified, executable + batch size is automatically determined. + dataloader_kwargs (DataLoaderKwargs, optional): + Controls additional arguments for PyTorch's DataLoader. + factor_args (FactorArguments, optional): + Arguments related to computing the factors. If not specified, the default values of + `FactorArguments` will be used. + target_data_partitions(Sequence[int], optional): + The list of data partition to fit Lambda matrices. By default, Lambda + matrices will be computed for all partitions. + target_module_partitions(Sequence[int], optional): + The list of module partition to fit Lambda matrices. By default, Lambda + matrices will be computed for all partitions. + overwrite_output_dir (bool, optional): + If True, the existing factors with the same `factors_name` will be overwritten. + load_from_factors_name (str, optional): + The `factor_name` to load Eigendecomposition results from. By default, Eigendecomposition + results with the same `factor_name` will be used. + """ + self.logger.debug(f"Fitting Lambda matrices with parameters: {locals()}") + + factors_output_dir = self.factors_output_dir(factors_name=factors_name) + os.makedirs(factors_output_dir, exist_ok=True) + if lambda_matrices_exist(output_dir=factors_output_dir) and not overwrite_output_dir: + self.logger.info(f"Found existing Lambda matrices at `{factors_output_dir}`. 
Skipping.") + return + + factor_args = self._configure_and_save_factor_args( + factor_args=factor_args, factors_output_dir=factors_output_dir, overwrite_output_dir=overwrite_output_dir + ) + + if not FactorConfig.CONFIGS[factor_args.strategy].requires_lambda_matrices: + self.logger.info(f"Strategy `{factor_args.strategy}` does not require fitting Lambda matrices. Skipping.") + return + + dataloader_params = self._configure_dataloader(dataloader_kwargs) + if self.state.is_main_process: + self._save_dataset_metadata( + dataset_name="lambda", + dataset=dataset, + output_dir=factors_output_dir, + overwrite_output_dir=overwrite_output_dir, + ) + + if load_from_factors_name is not None: + self.logger.info( + f"Will be loading Eigendecomposition results from factors with name `{load_from_factors_name}`." + ) + load_factors_output_dir = self.factors_output_dir(factors_name=load_from_factors_name) + else: + load_factors_output_dir = factors_output_dir + + if ( + not eigendecomposition_exist(output_dir=load_factors_output_dir) + and FactorConfig.CONFIGS[factor_args.strategy].requires_eigendecomposition_for_lambda + ): + error_msg = ( + f"Eigendecomposition results not found at `{load_factors_output_dir}`. " + f"To fit Lambda matrices for `{factor_args.strategy}`, Eigendecomposition must be " + f"performed before computing Lambda matrices." + ) + self.logger.error(error_msg) + raise FactorsNotFoundError(error_msg) + + eigen_factors = None + if FactorConfig.CONFIGS[factor_args.strategy].requires_eigendecomposition_for_lambda: + with self.profiler.profile("Load Eigendecomposition"): + eigen_factors = load_eigendecomposition(output_dir=load_factors_output_dir) + if load_from_factors_name is not None and self.state.is_main_process: + with self.profiler.profile("Save Eigendecomposition"): + save_eigendecomposition(output_dir=factors_output_dir, factors=eigen_factors) + loaded_factor_args, _ = self._load_and_configure_factor_args(factors_name=load_from_factors_name) + self._save_arguments( + arguments_name=FACTOR_ARGUMENTS_NAME + "_loaded_eigendecomposition", + arguments=loaded_factor_args, + output_dir=factors_output_dir, + overwrite_output_dir=True, + ) + self.state.wait_for_everyone() + + if factor_args.lambda_max_examples is None: + total_data_examples = len(dataset) + else: + total_data_examples = min([factor_args.lambda_max_examples, len(dataset)]) + self.logger.info(f"Total data examples to fit Lambda matrices: {total_data_examples}.") + + no_partition = factor_args.lambda_data_partition_size == 1 and factor_args.lambda_module_partition_size == 1 + partition_provided = target_data_partitions is not None or target_module_partitions is not None + if no_partition and partition_provided: + error_msg = ( + "`target_data_partitions` or `target_module_partitions` were specified, while" + "the `FactorArguments` did not expect any data and module partition to compute Lambda matrices." 
+ ) + self.logger.error(error_msg) + raise ValueError(error_msg) + + data_partition_indices, target_data_partitions = self._get_data_partition( + total_data_examples=total_data_examples, + data_partition_size=factor_args.lambda_data_partition_size, + target_data_partitions=target_data_partitions, + ) + max_partition_examples = total_data_examples // factor_args.lambda_data_partition_size + module_partition_names, target_module_partitions = self._get_module_partition( + module_partition_size=factor_args.lambda_module_partition_size, + target_module_partitions=target_module_partitions, + ) + + if max_partition_examples < self.state.num_processes: + error_msg = "The number of processes are more than the data examples. Try reducing the number of processes." + self.logger.error(error_msg) + raise ValueError(error_msg) + + all_start_time = get_time(state=self.state) + for data_partition in target_data_partitions: + for module_partition in target_module_partitions: + if no_partition: + partition = None + else: + partition = (data_partition, module_partition) + + if ( + lambda_matrices_exist( + output_dir=factors_output_dir, + partition=partition, + ) + and not overwrite_output_dir + ): + self.logger.info( + f"Found existing Lambda matrices for data partition {data_partition} " + f"and module partition {module_partition} at {factors_output_dir}. Skipping." + ) + continue + + start_index, end_index = data_partition_indices[data_partition] + self.logger.info( + f"Fitting Lambda matrices with data indices ({start_index}, {end_index}) and " + f"modules {module_partition_names[module_partition]}." + ) + + if per_device_batch_size is None: + kwargs = { + "eigen_factors": eigen_factors, + "model": self.model, + "state": self.state, + "task": self.task, + "factor_args": factor_args, + "tracked_module_names": module_partition_names[module_partition], + } + per_device_batch_size = self._find_executable_factors_batch_size( + func=fit_lambda_matrices_with_loader, + func_kwargs=kwargs, + dataset=dataset, + factor_args=factor_args, + dataloader_params=dataloader_params, + total_data_examples=max_partition_examples, + ) + + release_memory() + start_time = get_time(state=self.state) + with self.profiler.profile("Fit Lambda"): + loader = self._get_dataloader( + dataset=dataset, + per_device_batch_size=per_device_batch_size, + dataloader_params=dataloader_params, + indices=list(range(start_index, end_index)), + allow_duplicates=False, + ) + num_data_processed, lambda_factors = fit_lambda_matrices_with_loader( + eigen_factors=eigen_factors, + model=self.model, + state=self.state, + task=self.task, + loader=loader, + factor_args=factor_args, + tracked_module_names=module_partition_names[module_partition], + ) + end_time = get_time(state=self.state) + elapsed_time = end_time - start_time + self.logger.info( + f"Fitted Lambda matrices with {num_data_processed.item()} data points in " + f"{elapsed_time:.2f} seconds." 
+ ) + + with self.profiler.profile("Save Lambda"): + if self.state.is_main_process: + save_lambda_matrices( + output_dir=factors_output_dir, + factors=lambda_factors, + partition=partition, + metadata=factor_args.to_str_dict(), + ) + self.state.wait_for_everyone() + del lambda_factors, loader + self.logger.info(f"Saved Lambda matrices at `{factors_output_dir}`.") + + all_end_time = get_time(state=self.state) + elapsed_time = all_end_time - all_start_time + if not no_partition: + self.logger.info(f"Fitted all partitioned Lambda matrices in {elapsed_time:.2f} seconds.") + self.aggregate_lambda_matrices(factors_name=factors_name) + self.logger.info(f"Saved aggregated Lambda matrices at `{factors_output_dir}`.") + self._log_profile_summary() + + @torch.no_grad() + def aggregate_lambda_matrices( + self, + factors_name: str, + ) -> None: + """Aggregates all partitioned Lambda matrices. The factors will not be aggregated if Lambda matrices + for some data or module partitions are missing. + + Args: + factors_name (str): + The unique identifier for the factor, used to organize and retrieve the results. + """ + factor_args, _ = self._load_and_configure_factor_args(factors_name=factors_name) + with self.profiler.profile("Aggregate Lambda"): + self._aggregate_factors( + factors_name=factors_name, + data_partition_size=factor_args.lambda_data_partition_size, + module_partition_size=factor_args.lambda_module_partition_size, + exists_fnc=lambda_matrices_exist, + load_fnc=load_lambda_matrices, + save_fnc=save_lambda_matrices, + ) diff --git a/kronfluence/computer/score_computer.py b/kronfluence/computer/score_computer.py new file mode 100644 index 0000000..674f374 --- /dev/null +++ b/kronfluence/computer/score_computer.py @@ -0,0 +1,442 @@ +import os +from typing import Callable, Optional, Sequence + +import torch +from score.self import ( + compute_self_scores_with_loaders, + load_self_scores, + save_self_scores, + self_scores_exist, +) +from torch.utils import data + +from kronfluence.arguments import FactorArguments, ScoreArguments +from kronfluence.computer.computer import Computer +from kronfluence.module.constants import FACTOR_TYPE, SCORE_TYPE +from kronfluence.module.tracked_module import ModuleMode +from kronfluence.module.utils import set_mode +from kronfluence.score.pairwise import ( + compute_pairwise_scores_with_loaders, + load_pairwise_scores, + pairwise_scores_exist, + save_pairwise_scores, +) +from kronfluence.utils.dataset import DataLoaderKwargs, find_executable_batch_size +from kronfluence.utils.exceptions import FactorsNotFoundError +from kronfluence.utils.logger import get_time +from kronfluence.utils.save import FACTOR_ARGUMENTS_NAME, SCORE_ARGUMENTS_NAME +from kronfluence.utils.state import release_memory + + +class ScoreComputer(Computer): + """Handles the computation of pairwise influence scores for a given PyTorch model.""" + + def _find_executable_scores_batch_size( + self, + func: Callable, + factor_args: FactorArguments, + loaded_factors, + query_dataset: data.Dataset, + dataloader_params, + per_device_query_batch_size, + train_dataset: data.Dataset, + score_args, + tracked_modules_name, + total_data_examples: Optional[int] = None, + ) -> int: + """Automatically finds executable batch size for performing `func`.""" + if self.state.num_processes > 1: + error_msg = ( + "Automatic batch size search is currently not supported for multi-GPU training. " + "Please manually configure the batch size." 
+ ) + self.logger.error(error_msg) + raise NotImplementedError(error_msg) + + self.logger.info("Automatically determining executable batch size.") + + if total_data_examples is None: + total_data_examples = len(train_dataset) + start_batch_size = min( + [ + factor_args.initial_per_device_batch_size_attempt, + total_data_examples, + ] + ) + + total_query_batch_size = per_device_query_batch_size * self.state.num_processes + query_dataset = data.Subset(dataset=query_dataset, indices=list(range(total_query_batch_size))) + + def executable_batch_size_func(batch_size: int) -> None: + self.logger.info(f"Attempting to set per-device batch size to {batch_size}.") + set_mode(model=self.model, mode=ModuleMode.DEFAULT, keep_factors=False) + self.model.zero_grad(set_to_none=True) + release_memory() + total_batch_size = batch_size * self.state.num_processes + func( + loaded_factors=loaded_factors, + query_dataset=query_dataset, + train_dataset=train_dataset, + per_device_query_batch_size=total_query_batch_size, + per_device_train_batch_size=batch_size, + dataloader_params=dataloader_params, + score_args=score_args, + factor_args=factor_args, + indices=list(range(total_batch_size)), + tracked_module_names=tracked_modules_name, + ) + + per_device_batch_size = find_executable_batch_size( + func=executable_batch_size_func, + start_batch_size=start_batch_size, + ) + self.logger.info(f"Executable batch size determined: {per_device_batch_size}.") + return per_device_batch_size + + def compute_pairwise_scores( + self, + scores_name: str, + factors_name: str, + query_dataset: data.Dataset, + train_dataset: data.Dataset, + per_device_query_batch_size: int, + per_device_train_batch_size: Optional[int] = None, + query_indices: Optional[Sequence[int]] = None, + train_indices: Optional[Sequence[int]] = None, + dataloader_kwargs: Optional[DataLoaderKwargs] = None, + score_args: Optional[ScoreArguments] = None, + target_data_partitions: Optional[Sequence[int]] = None, + target_module_partitions: Optional[Sequence[int]] = None, + overwrite_output_dir: bool = False, + ) -> None: + """Fits all pairwise scores for the given data and module partition.""" + self.logger.debug(f"Computing pairwise scores with parameters: {locals()}") + + def compute_fnc( + loaded_factors, + query_dataset, + per_device_query_batch_size, + train_dataset, + per_device_train_batch_size, + dataloader_params, + score_args, + factor_args, + indices, + tracked_module_names, + ) -> SCORE_TYPE: + query_loader = self._get_dataloader( + dataset=query_dataset, + per_device_batch_size=per_device_query_batch_size, + allow_duplicates=True, + dataloader_params=dataloader_params, + ) + train_loader = self._get_dataloader( + dataset=train_dataset, + per_device_batch_size=per_device_train_batch_size, + indices=indices, + allow_duplicates=True, + stack=True, + dataloader_params=dataloader_params, + ) + scores = compute_pairwise_scores_with_loaders( + model=self.model, + state=self.state, + task=self.task, + loaded_factors=loaded_factors, + query_loader=query_loader, + train_loader=train_loader, + per_device_query_batch_size=per_device_query_batch_size, + score_args=score_args, + factor_args=factor_args, + tracked_module_names=tracked_module_names, + ) + return scores + + self._compute_scores( + scores_name=scores_name, + factors_name=factors_name, + query_dataset=query_dataset, + train_dataset=train_dataset, + exist_fnc=pairwise_scores_exist, + compute_fnc=compute_fnc, + save_fnc=save_pairwise_scores, + per_device_train_batch_size=per_device_train_batch_size, + 
per_device_query_batch_size=per_device_query_batch_size, + query_indices=query_indices, + train_indices=train_indices, + dataloader_kwargs=dataloader_kwargs, + score_args=score_args, + target_data_partitions=target_data_partitions, + target_module_partitions=target_module_partitions, + overwrite_output_dir=overwrite_output_dir, + ) + + self.aggregate_pairwise_scores(scores_name) + self._log_profile_summary() + + @torch.no_grad() + def aggregate_pairwise_scores(self, scores_name: str) -> None: + """Aggregates pairwise scores computed for all data and module partitions.""" + score_args = self._load_and_configure_score_args(scores_name=scores_name) + no_partition = score_args.data_partition_size == 1 and score_args.module_partition_size == 1 + if not no_partition: + self._aggregate_scores( + scores_name=scores_name, + score_args=score_args, + exists_fnc=pairwise_scores_exist, + load_fnc=load_pairwise_scores, + save_fnc=save_pairwise_scores, + dim=1, + ) + + def compute_self_scores( + self, + scores_name: str, + factors_name: str, + train_dataset: data.Dataset, + per_device_train_batch_size: Optional[int] = None, + train_indices: Optional[Sequence[int]] = None, + dataloader_kwargs: Optional[DataLoaderKwargs] = None, + score_args: Optional[ScoreArguments] = None, + target_data_partitions: Optional[Sequence[int]] = None, + target_module_partitions: Optional[Sequence[int]] = None, + overwrite_output_dir: bool = False, + ) -> None: + """Fits all pairwise scores for the given data and module partition.""" + self.logger.debug(f"Computing pairwise scores with parameters: {locals()}") + + def compute_fnc( + loaded_factors, + query_dataset, + per_device_query_batch_size, + train_dataset, + per_device_train_batch_size, + dataloader_params, + score_args, + factor_args, + indices, + tracked_module_names, + ) -> SCORE_TYPE: + del query_dataset, per_device_query_batch_size + train_loader = self._get_dataloader( + dataset=train_dataset, + per_device_batch_size=per_device_train_batch_size, + indices=indices, + allow_duplicates=True, + stack=True, + dataloader_params=dataloader_params, + ) + scores = compute_self_scores_with_loaders( + model=self.model, + state=self.state, + task=self.task, + loaded_factors=loaded_factors, + train_loader=train_loader, + score_args=score_args, + factor_args=factor_args, + tracked_module_names=tracked_module_names, + ) + return scores + + self._compute_scores( + scores_name=scores_name, + factors_name=factors_name, + query_dataset=None, + train_dataset=train_dataset, + exist_fnc=self_scores_exist, + compute_fnc=compute_fnc, + save_fnc=save_self_scores, + per_device_train_batch_size=per_device_train_batch_size, + per_device_query_batch_size=None, + query_indices=None, + train_indices=train_indices, + dataloader_kwargs=dataloader_kwargs, + score_args=score_args, + target_data_partitions=target_data_partitions, + target_module_partitions=target_module_partitions, + overwrite_output_dir=overwrite_output_dir, + ) + + self.aggregate_self_scores(scores_name) + self._log_profile_summary() + + @torch.no_grad() + def aggregate_self_scores(self, scores_name: str) -> None: + """Aggregates pairwise scores computed for all data and module partitions.""" + score_args = self._load_and_configure_score_args(scores_name=scores_name) + no_partition = score_args.data_partition_size == 1 and score_args.module_partition_size == 1 + if not no_partition: + self._aggregate_scores( + scores_name=scores_name, + score_args=score_args, + exists_fnc=self_scores_exist, + load_fnc=load_self_scores, + 
save_fnc=save_self_scores, + dim=0, + ) + + def _compute_scores( + self, + scores_name: str, + factors_name: str, + query_dataset: Optional[data.Dataset], + train_dataset: data.Dataset, + exist_fnc, + compute_fnc, + save_fnc, + per_device_query_batch_size: Optional[int], + per_device_train_batch_size: Optional[int] = None, + query_indices: Optional[Sequence[int]] = None, + train_indices: Optional[Sequence[int]] = None, + dataloader_kwargs: Optional[DataLoaderKwargs] = None, + score_args: Optional[ScoreArguments] = None, + target_data_partitions: Optional[Sequence[int]] = None, + target_module_partitions: Optional[Sequence[int]] = None, + overwrite_output_dir: bool = False, + ) -> Optional[SCORE_TYPE]: + scores_output_dir = self.scores_output_dir(scores_name=scores_name) + os.makedirs(scores_output_dir, exist_ok=True) + if exist_fnc(output_dir=scores_output_dir) and not overwrite_output_dir: + self.logger.info(f"Found existing scores at {scores_output_dir}. Skipping.") + return + + score_args = self._configure_score_args(score_args) + factor_args, factor_config = self._load_and_configure_factor_args(factors_name=factors_name) + + if self.state.is_main_process: + if query_dataset is not None: + self._save_dataset_metadata( + dataset_name="query", + dataset=query_dataset, + indices=query_indices, + output_dir=scores_output_dir, + overwrite_output_dir=overwrite_output_dir, + ) + self._save_dataset_metadata( + dataset_name="train", + dataset=train_dataset, + indices=train_indices, + output_dir=scores_output_dir, + overwrite_output_dir=overwrite_output_dir, + ) + self._save_arguments( + arguments_name=SCORE_ARGUMENTS_NAME, + arguments=score_args, + output_dir=scores_output_dir, + overwrite_output_dir=overwrite_output_dir, + ) + self._save_arguments( + arguments_name=FACTOR_ARGUMENTS_NAME, + arguments=factor_args, + output_dir=scores_output_dir, + overwrite_output_dir=overwrite_output_dir, + ) + + dataloader_params = self._configure_dataloader(dataloader_kwargs) + if query_indices is not None: + query_dataset = data.Subset(dataset=query_dataset, indices=query_indices) + if train_indices is not None: + train_dataset = data.Subset(dataset=train_dataset, indices=train_indices) + + with self.profiler.profile("Load All Factors"): + loaded_factors = self._load_all_required_factors( + factors_name=factors_name, + strategy=factor_args.strategy, + factor_config=factor_config, + ) + + no_partition = score_args.data_partition_size == 1 and score_args.module_partition_size == 1 + partition_provided = target_data_partitions is not None or target_module_partitions is not None + if no_partition and partition_provided: + error_msg = ( + "`target_data_partitions` or `target_module_partitions` were specified, while" + "the `ScoreArguments` did not expect any data or module partition to compute scores." 
+ ) + self.logger.error(error_msg) + raise ValueError(error_msg) + + data_partition_indices, target_data_partitions = self._get_data_partition( + total_data_examples=len(train_dataset), + data_partition_size=score_args.data_partition_size, + target_data_partitions=target_data_partitions, + ) + module_partition_names, target_module_partitions = self._get_module_partition( + module_partition_size=score_args.module_partition_size, + target_module_partitions=target_module_partitions, + ) + + all_start_time = get_time(state=self.state) + for data_partition in target_data_partitions: + for module_partition in target_module_partitions: + if no_partition: + partition = None + else: + partition = (data_partition, module_partition) + + if ( + exist_fnc( + output_dir=scores_output_dir, + partition=partition, + ) + and not overwrite_output_dir + ): + self.logger.info( + f"Found existing pairwise scores for data partition {data_partition} " + f"and module partition {module_partition} at {scores_output_dir}. Skipping." + ) + continue + + start_index, end_index = data_partition_indices[data_partition] + self.logger.info( + f"Computing pairwise scores for data partition with data indices ({start_index}, " + f"{end_index}) and modules {module_partition_names[module_partition]}..." + ) + + if per_device_train_batch_size is None: + per_device_train_batch_size = self._find_executable_scores_batch_size( + loaded_factors=loaded_factors, + func=compute_fnc, + query_dataset=query_dataset, + per_device_query_batch_size=per_device_query_batch_size, + train_dataset=train_dataset, + dataloader_params=dataloader_params, + total_data_examples=len(train_dataset) // score_args.data_partition_size, + score_args=score_args, + factor_args=factor_args, + tracked_modules_name=module_partition_names[0], + ) + + release_memory() + start_time = get_time(state=self.state) + with self.profiler.profile("Compute Score"): + scores = compute_fnc( + loaded_factors=loaded_factors, + query_dataset=query_dataset, + per_device_query_batch_size=per_device_query_batch_size, + train_dataset=train_dataset, + per_device_train_batch_size=per_device_train_batch_size, + dataloader_params=dataloader_params, + score_args=score_args, + factor_args=factor_args, + indices=list(range(start_index, end_index)), + tracked_module_names=module_partition_names[module_partition], + ) + end_time = get_time(state=self.state) + elapsed_time = end_time - start_time + self.logger.info(f"Computed pairwise scores in {elapsed_time:.2f} seconds.") + + with self.profiler.profile("Save Score"): + if self.state.is_main_process: + save_fnc( + output_dir=scores_output_dir, + scores=scores, + partition=partition, + ) + self.state.wait_for_everyone() + del scores + self.logger.info(f"Saved partitioned pairwise scores at {scores_output_dir}.") + + all_end_time = get_time(state=self.state) + elapsed_time = all_end_time - all_start_time + if not no_partition: + self.logger.info(f"Computed all scores in {elapsed_time:.2f} seconds.") diff --git a/kronfluence/factor/config.py b/kronfluence/factor/config.py index 8ee66c0..89f73d0 100644 --- a/kronfluence/factor/config.py +++ b/kronfluence/factor/config.py @@ -241,23 +241,13 @@ def precondition_gradient( gradient_eigenvalues = storage[GRADIENT_EIGENVALUES_NAME].to(dtype=gradient.dtype, device=gradient.device) lambda_matrix = torch.kron(activation_eigenvalues.unsqueeze(0), gradient_eigenvalues.unsqueeze(-1)).unsqueeze(0) - rotated_gradient = torch.einsum( - "ij,bjl,lk->bik", - ( - gradient_eigenvectors.t(), - gradient, - 
activation_eigenvectors, - ), - ) + gradient = torch.matmul(gradient_eigenvectors.t(), torch.matmul(gradient, activation_eigenvectors)) if damping is None: damping = 0.1 * torch.mean(lambda_matrix) - rotated_gradient.div_(lambda_matrix + damping) - return torch.einsum( - "ij,bjl,lk->bik", - (gradient_eigenvectors, rotated_gradient, activation_eigenvectors.t()), - ) + gradient.div_(lambda_matrix + damping) + return torch.matmul(gradient_eigenvectors, torch.matmul(gradient, activation_eigenvectors.t())) class Ekfac(FactorConfig, factor_strategy=FactorStrategy.EKFAC): @@ -306,27 +296,12 @@ def precondition_gradient( lambda_matrix = storage[LAMBDA_MATRIX_NAME].to(dtype=gradient.dtype, device=gradient.device) num_lambda_processed = storage[NUM_LAMBDA_PROCESSED].to(device=gradient.device) - gradient = torch.matmul( - gradient_eigenvectors.t(), - torch.matmul(gradient, activation_eigenvectors) - ) - - # rotated_gradient = torch.einsum( - # "ij,bjl,lk->bik", - # ( - # gradient_eigenvectors.t(), - # gradient, - # activation_eigenvectors, - # ), - # ) + gradient = torch.matmul(gradient_eigenvectors.t(), torch.matmul(gradient, activation_eigenvectors)) if damping is None: damping = 0.1 * torch.mean(lambda_matrix) gradient.div_(lambda_matrix + damping) - gradient = torch.matmul( - gradient_eigenvectors, - torch.matmul(gradient, activation_eigenvectors.t()) - ) + gradient = torch.matmul(gradient_eigenvectors, torch.matmul(gradient, activation_eigenvectors.t())) gradient.mul_(num_lambda_processed) return gradient diff --git a/kronfluence/factor/covariance.py b/kronfluence/factor/covariance.py index 3af6d1e..d5e24e4 100644 --- a/kronfluence/factor/covariance.py +++ b/kronfluence/factor/covariance.py @@ -31,34 +31,34 @@ def covariance_matrices_save_path( output_dir: Path, - covariance_factor_name: str, + factor_name: str, partition: Optional[PARTITION_TYPE] = None, ) -> Path: """Generates the path for saving/loading covariance matrices.""" - assert covariance_factor_name in COVARIANCE_FACTOR_NAMES + assert factor_name in COVARIANCE_FACTOR_NAMES if partition is not None: data_partition, module_partition = partition return output_dir / ( - f"{covariance_factor_name}_covariance_data_partition{data_partition}" - f"_module_partition{module_partition}.safetensors" + f"{factor_name}_data_partition{data_partition}_module_partition{module_partition}.safetensors" ) - return output_dir / f"{covariance_factor_name}_covariance.safetensors" + return output_dir / f"{factor_name}.safetensors" def save_covariance_matrices( output_dir: Path, - covariance_factors: Dict[str, Dict[str, torch.Tensor]], + factors: FACTOR_TYPE, partition: Optional[PARTITION_TYPE] = None, + metadata: Optional[Dict[str, str]] = None, ) -> None: """Saves covariance matrices to disk.""" - assert set(covariance_factors.keys()) == set(COVARIANCE_FACTOR_NAMES) - for name in covariance_factors: + assert set(factors.keys()) == set(COVARIANCE_FACTOR_NAMES) + for factor_name in factors: save_path = covariance_matrices_save_path( output_dir=output_dir, - covariance_factor_name=name, + factor_name=factor_name, partition=partition, ) - save_file(tensors=covariance_factors[name], filename=save_path) + save_file(tensors=factors[factor_name], filename=save_path, metadata=metadata) def load_covariance_matrices( @@ -67,13 +67,13 @@ def load_covariance_matrices( ) -> FACTOR_TYPE: """Loads covariance matrices from disk.""" covariance_factors = {} - for name in COVARIANCE_FACTOR_NAMES: + for factor_name in COVARIANCE_FACTOR_NAMES: save_path = 
covariance_matrices_save_path( output_dir=output_dir, - covariance_factor_name=name, + factor_name=factor_name, partition=partition, ) - covariance_factors[name] = load_file(filename=save_path) + covariance_factors[factor_name] = load_file(filename=save_path) return covariance_factors @@ -82,10 +82,10 @@ def covariance_matrices_exist( partition: Optional[PARTITION_TYPE] = None, ) -> bool: """Checks if covariance matrices exist at specified directory.""" - for name in COVARIANCE_FACTOR_NAMES: + for factor_name in COVARIANCE_FACTOR_NAMES: save_path = covariance_matrices_save_path( output_dir=output_dir, - covariance_factor_name=name, + factor_name=factor_name, partition=partition, ) if not save_path.exists(): @@ -127,13 +127,15 @@ def fit_covariance_matrices_with_loader( """ with torch.no_grad(): update_factor_args(model=model, factor_args=factor_args) + remove_attention_mask(model=model) set_mode(model=model, mode=ModuleMode.DEFAULT, keep_factors=False) set_mode( model=model, tracked_module_names=tracked_module_names, mode=ModuleMode.COVARIANCE, ) - num_data_processed = torch.zeros((1,), dtype=torch.int64, device=state.device, requires_grad=False) + total_steps = 0 + num_data_processed = torch.zeros((1,), dtype=torch.int64, requires_grad=False) with tqdm( total=len(loader), @@ -141,11 +143,12 @@ def fit_covariance_matrices_with_loader( bar_format=TQDM_BAR_FORMAT, disable=not state.is_main_process, ) as pbar: - for batch in loader: + for index, batch in enumerate(loader): batch = send_to_device(batch, device=state.device) with torch.no_grad(): attention_mask = task.get_attention_mask(batch=batch) - set_attention_mask(model=model, attention_mask=attention_mask) + if attention_mask is not None: + set_attention_mask(model=model, attention_mask=attention_mask) with no_sync(model=model, state=state): model.zero_grad(set_to_none=True) @@ -155,20 +158,31 @@ def fit_covariance_matrices_with_loader( sample=not factor_args.use_empirical_fisher, ) loss.backward() - num_data_processed += find_batch_size(batch) + num_data_processed += find_batch_size(data=batch) + total_steps += 1 + + if ( + state.use_distributed + and total_steps % factor_args.distributed_sync_steps == 0 + and index not in [len(loader) - 1, len(loader) - 2] + ): + # Periodically synchronize all processes to avoid timeout at the final covariance synchronization. + state.wait_for_everyone() + pbar.update(1) with torch.no_grad(): remove_attention_mask(model=model) - if state.use_distributed: - # Aggregate covariance matrices across multiple devices or nodes. - synchronize_covariance_matrices(model=model) - dist.all_reduce(tensor=num_data_processed, op=torch.distributed.ReduceOp.SUM) + if state.use_distributed: + # Aggregate covariance matrices across multiple devices or nodes. 
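+            # The count of processed data points is likewise moved to the device and summed across processes.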
+ synchronize_covariance_matrices(model=model) + num_data_processed = num_data_processed.to(device=state.device) + dist.all_reduce(tensor=num_data_processed, op=torch.distributed.ReduceOp.SUM) - with torch.no_grad(): saved_factors: FACTOR_TYPE = {} - for covariance_factor_name in COVARIANCE_FACTOR_NAMES: - saved_factors[covariance_factor_name] = load_factors(model=model, factor_name=covariance_factor_name) + for factor_name in COVARIANCE_FACTOR_NAMES: + saved_factors[factor_name] = load_factors(model=model, factor_name=factor_name) + state.wait_for_everyone() set_mode(model=model, mode=ModuleMode.DEFAULT, keep_factors=False) return num_data_processed, saved_factors diff --git a/kronfluence/factor/eigen.py b/kronfluence/factor/eigen.py index ab13b07..dde2012 100644 --- a/kronfluence/factor/eigen.py +++ b/kronfluence/factor/eigen.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import List, Optional, Tuple +from typing import Dict, List, Optional, Tuple import torch import torch.distributed as dist @@ -27,6 +27,7 @@ from kronfluence.module.utils import ( get_tracked_module_names, load_factors, + remove_attention_mask, set_factors, set_mode, synchronize_lambda_matrices, @@ -39,25 +40,22 @@ def eigendecomposition_save_path( output_dir: Path, - eigen_factor_name: str, + factor_name: str, ) -> Path: """Generates the path for saving/loading Eigendecomposition results.""" - assert eigen_factor_name in EIGENDECOMPOSITION_FACTOR_NAMES - return output_dir / f"{eigen_factor_name}_eigendecomposition.safetensors" + assert factor_name in EIGENDECOMPOSITION_FACTOR_NAMES + return output_dir / f"{factor_name}.safetensors" -def save_eigendecomposition( - output_dir: Path, - eigen_factors: FACTOR_TYPE, -) -> None: +def save_eigendecomposition(output_dir: Path, factors: FACTOR_TYPE, metadata: Optional[Dict[str, str]] = None) -> None: """Saves Eigendecomposition results to disk.""" - assert set(eigen_factors.keys()) == set(EIGENDECOMPOSITION_FACTOR_NAMES) - for name in eigen_factors: + assert set(factors.keys()) == set(EIGENDECOMPOSITION_FACTOR_NAMES) + for factor_name in factors: save_path = eigendecomposition_save_path( output_dir=output_dir, - eigen_factor_name=name, + factor_name=factor_name, ) - save_file(tensors=eigen_factors[name], filename=save_path) + save_file(tensors=factors[factor_name], filename=save_path, metadata=metadata) def load_eigendecomposition( @@ -65,12 +63,12 @@ def load_eigendecomposition( ) -> FACTOR_TYPE: """Loads Eigendecomposition results from disk.""" eigen_factors = {} - for name in EIGENDECOMPOSITION_FACTOR_NAMES: + for factor_name in EIGENDECOMPOSITION_FACTOR_NAMES: save_path = eigendecomposition_save_path( output_dir=output_dir, - eigen_factor_name=name, + factor_name=factor_name, ) - eigen_factors[name] = load_file(filename=save_path) + eigen_factors[factor_name] = load_file(filename=save_path) return eigen_factors @@ -78,10 +76,10 @@ def eigendecomposition_exist( output_dir: Path, ) -> bool: """Checks if Eigendecomposition results exist at specified path.""" - for name in EIGENDECOMPOSITION_FACTOR_NAMES: + for factor_name in EIGENDECOMPOSITION_FACTOR_NAMES: save_path = eigendecomposition_save_path( output_dir=output_dir, - eigen_factor_name=name, + factor_name=factor_name, ) if not save_path.exists(): return False @@ -138,17 +136,18 @@ def perform_eigendecomposition( ), ]: original_dtype = covariance_factors[covariance_name][module_name].dtype - covariance_factors[covariance_name][module_name].div_( - covariance_factors[NUM_COVARIANCE_PROCESSED][module_name] - ) 
covariance_matrix = covariance_factors[covariance_name][module_name].to( device=state.device, dtype=factor_args.eigendecomposition_dtype, ) - # Deal with cases where covariance matrices are not symmetric due to numerical issues. + # Normalize covariance matrices. + covariance_matrix.div_( + covariance_factors[NUM_COVARIANCE_PROCESSED][module_name].to(device=state.device) + ) + # In cases where covariance matrices are not exactly symmetric due to numerical issues. covariance_matrix = 0.5 * (covariance_matrix + covariance_matrix.t()) eigenvalues, eigenvectors = torch.linalg.eigh(covariance_matrix) - eigen_factors[eigenvalues_name][module_name] = eigenvalues.to(dtype=original_dtype).cpu() + eigen_factors[eigenvalues_name][module_name] = eigenvalues.to(dtype=original_dtype).contiguous().cpu() eigen_factors[eigenvectors_name][module_name] = eigenvectors.to(dtype=original_dtype).contiguous().cpu() del eigenvalues, eigenvectors pbar.update(1) @@ -157,34 +156,34 @@ def perform_eigendecomposition( def lambda_matrices_save_path( output_dir: Path, - lambda_factor_name: str, + factor_name: str, partition: Optional[PARTITION_TYPE] = None, ) -> Path: """Generates the path for saving/loading Lambda matrices.""" - assert lambda_factor_name in LAMBDA_FACTOR_NAMES + assert factor_name in LAMBDA_FACTOR_NAMES if partition is not None: data_partition, module_partition = partition return output_dir / ( - f"{lambda_factor_name}_lambda_data_partition{data_partition}" - f"_module_partition{module_partition}.safetensors" + f"{factor_name}_data_partition{data_partition}_module_partition{module_partition}.safetensors" ) - return output_dir / f"{lambda_factor_name}_lambda.safetensors" + return output_dir / f"{factor_name}.safetensors" def save_lambda_matrices( output_dir: Path, - lambda_factors: FACTOR_TYPE, + factors: FACTOR_TYPE, partition: Optional[PARTITION_TYPE] = None, + metadata: Optional[Dict[str, str]] = None, ) -> None: """Saves Lambda matrices to disk.""" - assert set(lambda_factors.keys()) == set(LAMBDA_FACTOR_NAMES) - for name in lambda_factors: + assert set(factors.keys()) == set(LAMBDA_FACTOR_NAMES) + for factor_name in factors: save_path = lambda_matrices_save_path( output_dir=output_dir, - lambda_factor_name=name, + factor_name=factor_name, partition=partition, ) - save_file(tensors=lambda_factors[name], filename=save_path) + save_file(tensors=factors[factor_name], filename=save_path, metadata=metadata) def load_lambda_matrices( @@ -193,13 +192,13 @@ def load_lambda_matrices( ) -> FACTOR_TYPE: """Loads Lambda matrices from disk.""" lambda_factors = {} - for name in LAMBDA_FACTOR_NAMES: + for factor_name in LAMBDA_FACTOR_NAMES: save_path = lambda_matrices_save_path( output_dir=output_dir, - lambda_factor_name=name, + factor_name=factor_name, partition=partition, ) - lambda_factors[name] = load_file(filename=save_path) + lambda_factors[factor_name] = load_file(filename=save_path) return lambda_factors @@ -208,10 +207,10 @@ def lambda_matrices_exist( partition: Optional[PARTITION_TYPE] = None, ) -> bool: """Check if Lambda matrices exist at specified path.""" - for name in LAMBDA_FACTOR_NAMES: + for factor_name in LAMBDA_FACTOR_NAMES: save_path = lambda_matrices_save_path( output_dir=output_dir, - lambda_factor_name=name, + factor_name=factor_name, partition=partition, ) if not save_path.exists(): @@ -255,6 +254,7 @@ def fit_lambda_matrices_with_loader( """ with torch.no_grad(): update_factor_args(model=model, factor_args=factor_args) + remove_attention_mask(model=model) set_mode(model=model, 
mode=ModuleMode.DEFAULT, keep_factors=False) set_mode( model=model, @@ -264,7 +264,8 @@ def fit_lambda_matrices_with_loader( if eigen_factors is not None: for name in eigen_factors: set_factors(model=model, factor_name=name, factors=eigen_factors[name]) - num_data_processed = torch.zeros((1,), dtype=torch.int64, device=state.device, requires_grad=False) + total_steps = 0 + num_data_processed = torch.zeros((1,), dtype=torch.int64, requires_grad=False) with tqdm( total=len(loader), @@ -272,8 +273,8 @@ def fit_lambda_matrices_with_loader( bar_format=TQDM_BAR_FORMAT, disable=not state.is_main_process, ) as pbar: - for batch in loader: - batch = send_to_device(batch, device=state.device) + for index, batch in enumerate(loader): + batch = send_to_device(tensor=batch, device=state.device) with no_sync(model=model, state=state): model.zero_grad(set_to_none=True) @@ -283,17 +284,30 @@ def fit_lambda_matrices_with_loader( sample=not factor_args.use_empirical_fisher, ) loss.backward() - num_data_processed += find_batch_size(batch) - pbar.update(1) + num_data_processed += find_batch_size(data=batch) + total_steps += 1 - if state.use_distributed: - # Aggregate Lambda matrices across multiple devices or nodes. - synchronize_lambda_matrices(model=model) - dist.all_reduce(tensor=num_data_processed, op=torch.distributed.ReduceOp.SUM) + if ( + state.use_distributed + and total_steps % factor_args.distributed_sync_steps == 0 + and index not in [len(loader) - 1, len(loader) - 2] + ): + # Periodically synchronize all processes to avoid timeout at the final Lambda synchronization. + state.wait_for_everyone() + + pbar.update(1) with torch.no_grad(): + if state.use_distributed: + # Aggregate Lambda matrices across multiple devices or nodes. + synchronize_lambda_matrices(model=model) + num_data_processed = num_data_processed.to(device=state.device) + dist.all_reduce(tensor=num_data_processed, op=torch.distributed.ReduceOp.SUM) + saved_factors: FACTOR_TYPE = {} - for lambda_factor_name in LAMBDA_FACTOR_NAMES: - saved_factors[lambda_factor_name] = load_factors(model=model, factor_name=lambda_factor_name) + if state.is_main_process: + for factor_name in LAMBDA_FACTOR_NAMES: + saved_factors[factor_name] = load_factors(model=model, factor_name=factor_name) + state.wait_for_everyone() set_mode(model=model, mode=ModuleMode.DEFAULT, keep_factors=False) return num_data_processed, saved_factors diff --git a/kronfluence/module/conv2d.py b/kronfluence/module/conv2d.py index 8ab5bf7..db5aaf0 100644 --- a/kronfluence/module/conv2d.py +++ b/kronfluence/module/conv2d.py @@ -4,6 +4,7 @@ import torch.nn.functional as F from einconv.utils import get_conv_paddings from einops import rearrange, reduce +from opt_einsum import contract from torch import nn from torch.nn.modules.utils import _pair @@ -26,7 +27,7 @@ def extract_patches( inputs (torch.Tensor): The inputs tensor to the `nn.Conv2d` module. kernel_size (tuple, int): - Size of the convolving kernel. + Size of the convolutional kernel. stride (tuple, int): Stride of the convolution. 
padding (int, tuple, str): @@ -53,14 +54,14 @@ def extract_patches( inputs = rearrange(tensor=inputs, pattern="b (g c_in) i1 i2 -> b g c_in i1 i2", g=groups) inputs = reduce(tensor=inputs, pattern="b g c_in i1 i2 -> b c_in i1 i2", reduction="mean") - inputs = F.unfold( + inputs_unfold = F.unfold( input=inputs, kernel_size=kernel_size, dilation=dilation, padding=padding, stride=stride, ) - return rearrange(tensor=inputs, pattern="b c_in_k1_k2 o1_o2 -> b o1_o2 c_in_k1_k2") + return rearrange(tensor=inputs_unfold, pattern="b c_in_k1_k2 o1_o2 -> b o1_o2 c_in_k1_k2") class TrackedConv2d(TrackedModule, module_type=nn.Conv2d): @@ -93,11 +94,11 @@ def _get_flattened_activation( pattern="b o1_o2 c_in_k1_k2 -> (b o1_o2) c_in_k1_k2", ) - if self.original_module.bias is not None: + if self.original_module.bias is not None and not self.factor_args.ignore_bias: input_activation = torch.cat( [ input_activation, - input_activation.new_ones(input_activation.shape[0], 1), + input_activation.new_ones((input_activation.size(0), 1), requires_grad=False), ], dim=-1, ) @@ -120,7 +121,9 @@ def _get_flattened_gradient(self, output_gradient: torch.Tensor) -> torch.Tensor return rearrange(output_gradient, "b c o1 o2 -> (b o1 o2) c") def _compute_per_sample_gradient( - self, input_activation: torch.Tensor, output_gradient: torch.Tensor + self, + input_activation: torch.Tensor, + output_gradient: torch.Tensor, ) -> torch.Tensor: """Returns the flattened per-sample-gradient tensor. @@ -134,7 +137,7 @@ def _compute_per_sample_gradient( Returns: torch.Tensor: The per-sample-gradient tensor. The per-sample-gradient is a 3-dimensional matrix - with dimension `batch_size x input_dim x gradient_dim`. An additional dimension is added + with dimension `batch_size x gradient_dim x activation_dim`. An additional dimension is added when the bias term is used. """ input_activation = extract_patches( @@ -150,14 +153,14 @@ def _compute_per_sample_gradient( pattern="b o1_o2 c_in_k1_k2 -> (b o1_o2) c_in_k1_k2", ) - if self.original_module.bias is not None: + if self.original_module.bias is not None and not self.factor_args.ignore_bias: input_activation = torch.cat( [ input_activation, - input_activation.new_ones(input_activation.shape[0], 1), + input_activation.new_ones((input_activation.size(0), 1), requires_grad=False), ], dim=-1, ) input_activation = input_activation.view(output_gradient.size(0), -1, input_activation.size(-1)) output_gradient = rearrange(tensor=output_gradient, pattern="b o i1 i2 -> b (i1 i2) o") - return torch.einsum("abm,abn->amn", (output_gradient, input_activation)) + return contract("abm,abn->amn", output_gradient, input_activation) diff --git a/kronfluence/module/linear.py b/kronfluence/module/linear.py index 285c899..e1d0050 100644 --- a/kronfluence/module/linear.py +++ b/kronfluence/module/linear.py @@ -2,6 +2,7 @@ import torch from einops import rearrange +from opt_einsum import contract from torch import nn from kronfluence.module.tracked_module import TrackedModule @@ -30,12 +31,14 @@ def _get_flattened_activation( if self._attention_mask is not None and flattened_activation.size(0) == self._attention_mask.numel(): # If the binary attention mask is provided, zero-out appropriate activations. flattened_attention_mask = rearrange(tensor=self._attention_mask, pattern="b ... -> (b ...) 1") - flattened_activation = flattened_activation * flattened_attention_mask + # Make sure in-place operation does not change the activation during the forward pass. 
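# A standalone sketch of why `extract_patches` above unfolds the convolution
# input: after `F.unfold`, an `nn.Conv2d` is an ordinary matrix multiplication
# over patches, so the same activation/gradient bookkeeping as for `nn.Linear`
# carries over. All shapes and names here are illustrative.
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)
conv = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3, padding=1, bias=False)
x = torch.randn(2, 3, 10, 10)

patches = F.unfold(x, kernel_size=3, padding=1)         # (batch, c_in * k * k, num_positions)
flat_weight = conv.weight.view(conv.out_channels, -1)   # (c_out, c_in * k * k)
as_matmul = (flat_weight @ patches).view(2, 8, 10, 10)  # the conv expressed as a per-patch linear map

assert torch.allclose(as_matmul, conv(x), atol=1e-5)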
+ flattened_activation = flattened_activation.clone() + flattened_activation.mul_(flattened_attention_mask) - if self.original_module.bias is not None: - append_term = flattened_activation.new_ones(flattened_activation.shape[0], 1) + if self.original_module.bias is not None and not self.factor_args.ignore_bias: + append_term = flattened_activation.new_ones((flattened_activation.size(0), 1), requires_grad=False) if flattened_attention_mask is not None: - append_term = append_term * flattened_attention_mask + append_term.mul_(flattened_attention_mask) flattened_activation = torch.cat([flattened_activation, append_term], dim=-1) count = flattened_activation.size(0) if flattened_attention_mask is None else flattened_attention_mask.sum() @@ -71,12 +74,11 @@ def _compute_per_sample_gradient( Returns: torch.Tensor: The per-sample-gradient tensor. The per-sample-gradient is a 3-dimensional matrix - with dimension `batch_size x gradient_dim x input_dim`. An additional dimension is added + with dimension `batch_size x gradient_dim x activation_dim`. An additional dimension is added when the bias term is used. """ - if self.original_module.bias is not None: - shape = list(input_activation.shape[:-1]) + [1] + if self.original_module.bias is not None and not self.factor_args.ignore_bias: + shape = list(input_activation.size()[:-1]) + [1] append_term = input_activation.new_ones(shape, requires_grad=False) input_activation = torch.cat([input_activation, append_term], dim=-1) - - return torch.einsum("b...i,b...o->boi", (input_activation, output_gradient)) + return contract("b...i,b...o->bio", output_gradient, input_activation) diff --git a/kronfluence/module/tracked_module.py b/kronfluence/module/tracked_module.py index 123ee6d..aaad62c 100644 --- a/kronfluence/module/tracked_module.py +++ b/kronfluence/module/tracked_module.py @@ -4,6 +4,7 @@ import torch import torch.distributed as dist from accelerate.utils.dataclasses import BaseEnum +from opt_einsum import contract from torch import nn from torch.utils.hooks import RemovableHandle @@ -47,7 +48,7 @@ def full_backward_gradient_removal_hook( ) -> None: """Removes all saved `.grad` computed by Autograd from model's parameters.""" del grad_inputs, grad_outputs - for parameter in module.original_module.parameters(): + for parameter in module.parameters(): parameter.grad = None @@ -100,6 +101,7 @@ def __init__( self._registered_hooks: List[RemovableHandle] = [] self._cached_hooks: List[RemovableHandle] = [] self._storage: Dict[str, Optional[Any]] = {} + self._storge_at_current_device: bool = False # Storage for activation and pseudo-gradient covariance matrices. # for covariance_factor_name in COVARIANCE_FACTOR_NAMES: @@ -239,8 +241,8 @@ def _update_activation_covariance_matrix(self, input_activation: torch.Tensor) - input_activation (torch.Tensor): The input tensor to the module, provided by the PyTorch's forward hook. 
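# The per-sample gradient assembled above from cached activations and output
# gradients is a batched outer product, summed over any token/position
# dimensions. A small sketch checking that contraction against an explicit
# per-example autograd loop for a tiny `nn.Linear` (all names illustrative).
import torch
from torch import nn

torch.manual_seed(0)
linear = nn.Linear(5, 3, bias=False)
x = torch.randn(4, 7, 5)        # (batch, tokens, in_features)
targets = torch.randn(4, 7, 3)

with torch.no_grad():
    output = linear(x)
output_gradient = output - targets  # d(0.5 * squared error) / d(output)

# Batched outer product: (batch, out_features, in_features) per-sample gradients.
per_sample_gradient = torch.einsum("bto,bti->boi", output_gradient, x)

for b in range(x.shape[0]):
    linear.zero_grad()
    loss_b = 0.5 * ((linear(x[b]) - targets[b]) ** 2).sum()
    loss_b.backward()
    assert torch.allclose(per_sample_gradient[b], linear.weight.grad, atol=1e-5)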
""" + input_activation = input_activation.to(dtype=self.factor_args.activation_covariance_dtype) flattened_activation, count = self._get_flattened_activation(input_activation) - flattened_activation = flattened_activation.to(dtype=self.factor_args.activation_covariance_dtype) if self._storage[ACTIVATION_COVARIANCE_MATRIX_NAME] is None: dimension = flattened_activation.size(1) @@ -254,10 +256,14 @@ def _update_activation_covariance_matrix(self, input_activation: torch.Tensor) - self._storage[ACTIVATION_COVARIANCE_MATRIX_NAME].addmm_(flattened_activation.t(), flattened_activation) if self._storage[NUM_COVARIANCE_PROCESSED] is None: + device = None + if isinstance(count, torch.Tensor): + # When using attention masks, `count` can be tensor. + device = count.device self._storage[NUM_COVARIANCE_PROCESSED] = torch.zeros( size=(1,), dtype=torch.int64, - device=flattened_activation.device, + device=device, requires_grad=False, ) # Keep track of total number of elements used to aggregate covariance matrices. @@ -288,8 +294,8 @@ def _update_gradient_covariance_matrix(self, output_gradient: torch.Tensor) -> N The gradient tensor with respect to the output of the module, provided by the PyTorch's backward hook. """ + output_gradient = output_gradient.to(dtype=self.factor_args.gradient_covariance_dtype) flattened_gradient = self._get_flattened_gradient(output_gradient) - flattened_gradient = flattened_gradient.to(dtype=self.factor_args.gradient_covariance_dtype) if self._storage[GRADIENT_COVARIANCE_MATRIX_NAME] is None: # Initialize pseudo-gradient covariance matrix if it does not exist. @@ -308,8 +314,9 @@ def _register_covariance_hooks(self) -> None: def forward_hook(module: nn.Module, inputs: Tuple[torch.Tensor], outputs: Tuple[torch.Tensor]) -> None: del module - # Compute and update activation covariance matrix in the forward pass. - self._update_activation_covariance_matrix(inputs[0].detach()) + with torch.no_grad(): + # Compute and update activation covariance matrix in the forward pass. + self._update_activation_covariance_matrix(inputs[0].detach()) # Register backward hook to obtain gradient with respect to the output. self._cached_hooks.append(outputs.register_hook(backward_hook)) @@ -340,12 +347,14 @@ def _covariance_matrices_available(self) -> bool: return False return True + @torch.no_grad() def synchronize_covariance_matrices(self) -> None: """Aggregates covariance matrices across multiple devices or nodes in a distributed setting.""" - if dist.is_initialized(): + if dist.is_initialized() and torch.cuda.is_available(): + # Note that only the main process holds the aggregated covariance matrix. for covariance_factor_name in COVARIANCE_FACTOR_NAMES: - torch.distributed.all_reduce( - tensor=self._storage[covariance_factor_name], + dist.reduce( + tensor=self._storage[covariance_factor_name].cuda(), op=dist.ReduceOp.SUM, ) @@ -382,7 +391,7 @@ def _compute_per_sample_gradient( Returns: torch.Tensor: The per-sample-gradient tensor. The per-sample-gradient is a 3-dimensional matrix - with dimension `batch_size x input_dim x gradient_dim`. An additional dimension is added + with dimension `batch_size x gradient_dim x activation_dim`. An additional dimension is added when the bias term is used. 
""" raise NotImplementedError("Subclasses must implement the `_compute_per_sample_gradient` method.") @@ -442,16 +451,17 @@ def _update_lambda_matrix(self, per_sample_gradient: torch.Tensor) -> None: for i in range(batch_size): sqrt_lambda = torch.matmul( self._storage[GRADIENT_EIGENVECTORS_NAME].t(), - per_sample_gradient[i, :, :], + per_sample_gradient[i], ) self._storage[LAMBDA_MATRIX_NAME].add_(sqrt_lambda.square_()) else: per_sample_gradient = torch.matmul( self._storage[GRADIENT_EIGENVECTORS_NAME].t(), - torch.matmul(per_sample_gradient, self._storage[ACTIVATION_EIGENVECTORS_NAME]) + torch.matmul(per_sample_gradient, self._storage[ACTIVATION_EIGENVECTORS_NAME]), ) self._storage[LAMBDA_MATRIX_NAME].add_(per_sample_gradient.square_().sum(dim=0)) else: + # Assume that the eigenbasis is identity. self._storage[LAMBDA_MATRIX_NAME].add_(per_sample_gradient.square_().sum(dim=0)) self._storage[NUM_LAMBDA_PROCESSED].add_(batch_size) @@ -461,11 +471,12 @@ def _register_lambda_hooks(self) -> None: def forward_hook(module: nn.Module, inputs: Tuple[torch.Tensor], outputs: Tuple[torch.Tensor]) -> None: del module - cached_activation = inputs[0].detach() - if self.factor_args.cached_activation_cpu_offload: - self._cached_activations.append(cached_activation.cpu()) - else: - self._cached_activations.append(cached_activation) + with torch.no_grad(): + cached_activation = inputs[0].detach().to(dtype=self.factor_args.lambda_dtype) + if self.factor_args.cached_activation_cpu_offload: + self._cached_activations.append(cached_activation.cpu()) + else: + self._cached_activations.append(cached_activation) # Register backward hook to obtain gradient with respect to the output. self._cached_hooks.append(outputs.register_hook(backward_hook)) @@ -477,18 +488,20 @@ def backward_hook(output_gradient: torch.Tensor) -> None: if self.factor_args.cached_activation_cpu_offload: cached_activation = cached_activation.to(device=output_gradient.device) per_sample_gradient = self._compute_per_sample_gradient( - input_activation=cached_activation.to(dtype=self.factor_args.lambda_dtype), + input_activation=cached_activation, output_gradient=output_gradient.detach().to(dtype=self.factor_args.lambda_dtype), ) - del cached_activation, output_gradient + del cached_activation if self._cached_per_sample_gradient is None: self._cached_per_sample_gradient = per_sample_gradient else: self._cached_per_sample_gradient.add_(per_sample_gradient) + del per_sample_gradient if len(self._cached_activations) == 0: self._update_lambda_matrix(per_sample_gradient=self._cached_per_sample_gradient) + del self._cached_per_sample_gradient self._cached_per_sample_gradient = None self._registered_hooks.append(self.original_module.register_forward_hook(forward_hook)) @@ -513,18 +526,21 @@ def _lambda_matrix_available(self) -> bool: return False return True + @torch.no_grad() def synchronize_lambda_matrices(self) -> None: """Aggregates Lambda matrices across multiple devices or nodes in a distributed setting.""" - if dist.is_initialized(): + if dist.is_initialized() and torch.cuda.is_available(): + # Note that only the main process holds the aggregated Lambda matrix. for lambda_factor_name in LAMBDA_FACTOR_NAMES: - torch.distributed.all_reduce( - tensor=self._storage[lambda_factor_name], + torch.distributed.reduce( + tensor=self._storage[lambda_factor_name].cuda(), op=dist.ReduceOp.SUM, ) ################################################## # Methods for computing preconditioned gradient. 
# ################################################## + @torch.no_grad() def _compute_low_rank_preconditioned_gradient( self, preconditioned_gradient: torch.Tensor, @@ -551,7 +567,8 @@ def _compute_low_rank_preconditioned_gradient( ) U_k = U[:, :, :rank] S_k = S[:, :rank] - V_k = V[:, :, :rank] + # Avoid holding the full memory of the original tensor before indexing. + V_k = V[:, :, :rank].clone() return [ torch.matmul(U_k, torch.diag_embed(S_k)).contiguous(), torch.transpose(V_k, 1, 2).contiguous(), @@ -562,11 +579,12 @@ def _register_precondition_gradient_hooks(self) -> None: def forward_hook(module: nn.Module, inputs: Tuple[torch.Tensor], outputs: Tuple[torch.Tensor]) -> None: del module - cached_activation = inputs[0].detach() - if self.score_args.cached_activation_cpu_offload: - self._cached_activations.append(cached_activation.cpu()) - else: - self._cached_activations.append(cached_activation) + with torch.no_grad(): + cached_activation = inputs[0].detach().to(dtype=self.score_args.per_sample_gradient_dtype) + if self.score_args.cached_activation_cpu_offload: + self._cached_activations.append(cached_activation.cpu()) + else: + self._cached_activations.append(cached_activation) # Register backward hook to obtain gradient with respect to the output. self._cached_hooks.append(outputs.register_hook(backward_hook)) @@ -578,10 +596,10 @@ def backward_hook(output_gradient: torch.Tensor) -> None: if self.score_args.cached_activation_cpu_offload: cached_activation = cached_activation.to(device=output_gradient.device) per_sample_gradient = self._compute_per_sample_gradient( - input_activation=cached_activation.to(dtype=self.score_args.per_sample_gradient_dtype), + input_activation=cached_activation, output_gradient=output_gradient.detach().to(dtype=self.score_args.per_sample_gradient_dtype), ) - del cached_activation, output_gradient + del cached_activation if self._cached_per_sample_gradient is None: self._cached_per_sample_gradient = per_sample_gradient @@ -590,13 +608,14 @@ def backward_hook(output_gradient: torch.Tensor) -> None: del per_sample_gradient if len(self._cached_activations) == 0: - preconditioned_gradient = FactorConfig.CONFIGS[self.factor_args.strategy].precondition_gradient( - gradient=self._cached_per_sample_gradient.to(dtype=self.score_args.precondition_dtype), - storage=self._storage, - damping=self.score_args.damping, - ) + preconditioned_gradient = ( + FactorConfig.CONFIGS[self.factor_args.strategy].precondition_gradient( + gradient=self._cached_per_sample_gradient.to(dtype=self.score_args.precondition_dtype), + storage=self._storage, + damping=self.score_args.damping, + ) + ).to(dtype=self.score_args.score_dtype) self._cached_per_sample_gradient = None - preconditioned_gradient = preconditioned_gradient.to(dtype=self.score_args.score_dtype) if ( self.score_args.query_gradient_rank is not None @@ -631,22 +650,26 @@ def get_preconditioned_gradient_batch_size(self) -> Optional[int]: return self._storage[PRECONDITIONED_GRADIENT_NAME].size(0) return None + @torch.no_grad() def truncate_preconditioned_gradient(self, keep_size: int) -> None: """Truncates and keeps only the first keep_size dimension for the preconditioned gradient.""" if self._storage[PRECONDITIONED_GRADIENT_NAME] is not None: if isinstance(self._storage[PRECONDITIONED_GRADIENT_NAME], list): self._storage[PRECONDITIONED_GRADIENT_NAME] = [ - self._storage[PRECONDITIONED_GRADIENT_NAME][0][:keep_size], - self._storage[PRECONDITIONED_GRADIENT_NAME][1][:keep_size], + 
self._storage[PRECONDITIONED_GRADIENT_NAME][0][:keep_size].clone(), + self._storage[PRECONDITIONED_GRADIENT_NAME][1][:keep_size].clone(), ] else: - self._storage[PRECONDITIONED_GRADIENT_NAME] = self._storage[PRECONDITIONED_GRADIENT_NAME][:keep_size] + self._storage[PRECONDITIONED_GRADIENT_NAME] = self._storage[PRECONDITIONED_GRADIENT_NAME][ + :keep_size + ].clone() + @torch.no_grad() def synchronize_preconditioned_gradient(self, num_processes: int) -> None: """Stacks preconditioned gradient across multiple devices or nodes in a distributed setting.""" - if dist.is_initialized(): + if dist.is_initialized() and torch.cuda.is_available(): if isinstance(self._storage[PRECONDITIONED_GRADIENT_NAME], list): - for i in range(2): + for i in range(len(self._storage[PRECONDITIONED_GRADIENT_NAME])): size = self._storage[PRECONDITIONED_GRADIENT_NAME][i].size() stacked_matrix = torch.empty( size=(num_processes, size[0], size[1], size[2]), @@ -657,10 +680,9 @@ def synchronize_preconditioned_gradient(self, num_processes: int) -> None: output_tensor=stacked_matrix, input_tensor=self._storage[PRECONDITIONED_GRADIENT_NAME][i].contiguous(), ) - self._storage[PRECONDITIONED_GRADIENT_NAME][i] = stacked_matrix.transpose(0, 1).reshape( - num_processes * size[0], size[1], size[2] + self._storage[PRECONDITIONED_GRADIENT_NAME][i] = ( + stacked_matrix.transpose(0, 1).reshape(num_processes * size[0], size[1], size[2]).contiguous() ) - else: size = self._storage[PRECONDITIONED_GRADIENT_NAME].size() stacked_preconditioned_gradient = torch.empty( @@ -672,8 +694,10 @@ def synchronize_preconditioned_gradient(self, num_processes: int) -> None: output_tensor=stacked_preconditioned_gradient, input_tensor=self._storage[PRECONDITIONED_GRADIENT_NAME].contiguous(), ) - self._storage[PRECONDITIONED_GRADIENT_NAME] = stacked_preconditioned_gradient.transpose(0, 1).reshape( - num_processes * size[0], size[1], size[2] + self._storage[PRECONDITIONED_GRADIENT_NAME] = ( + stacked_preconditioned_gradient.transpose(0, 1) + .reshape(num_processes * size[0], size[1], size[2]) + .contiguous() ) ########################################### @@ -684,11 +708,12 @@ def _register_pairwise_score_hooks(self) -> None: def forward_hook(module: nn.Module, inputs: Tuple[torch.Tensor], outputs: Tuple[torch.Tensor]) -> None: del module - cached_activation = inputs[0].detach() - if self.score_args.cached_activation_cpu_offload: - self._cached_activations.append(cached_activation.cpu()) - else: - self._cached_activations.append(cached_activation) + with torch.no_grad(): + cached_activation = inputs[0].detach().to(dtype=self.score_args.per_sample_gradient_dtype) + if self.score_args.cached_activation_cpu_offload: + self._cached_activations.append(cached_activation.cpu()) + else: + self._cached_activations.append(cached_activation) # Register backward hook to obtain gradient with respect to the output. 
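# Keeping the full preconditioned gradient for every query example can be
# memory-heavy, so it may be replaced by a truncated SVD as above: a left
# factor `U_k diag(S_k)` and a right factor `V_k^T`. A sketch of the rank-k
# factorization with illustrative shapes and names.
import torch

torch.manual_seed(0)
num_query, out_dim, in_dim, rank = 8, 32, 64, 4
preconditioned_gradient = torch.randn(num_query, out_dim, in_dim)

U, S, Vh = torch.linalg.svd(preconditioned_gradient, full_matrices=False)
left_mat = U[:, :, :rank] * S[:, None, :rank]  # (num_query, out_dim, rank)
right_mat = Vh[:, :rank, :]                    # (num_query, rank, in_dim)

# The product of the two factors is the best rank-k approximation per example.
low_rank = left_mat @ right_mat
error = (low_rank - preconditioned_gradient).norm() / preconditioned_gradient.norm()
print(f"relative error at rank {rank}: {error:.3f}")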
self._cached_hooks.append(outputs.register_hook(backward_hook)) @@ -698,10 +723,10 @@ def backward_hook(output_gradient: torch.Tensor) -> None: handle.remove() cached_activation = self._cached_activations.pop() per_sample_gradient = self._compute_per_sample_gradient( - input_activation=cached_activation.to(dtype=self.score_args.per_sample_gradient_dtype), + input_activation=cached_activation, output_gradient=output_gradient.detach().to(dtype=self.score_args.per_sample_gradient_dtype), - ) - del cached_activation, output_gradient + ).to(dtype=self.score_args.score_dtype) + del cached_activation if self._cached_per_sample_gradient is None: self._cached_per_sample_gradient = per_sample_gradient @@ -712,30 +737,21 @@ def backward_hook(output_gradient: torch.Tensor) -> None: # If the module was used multiple times throughout the forward pass, # only compute scores after aggregating all per-sample-gradients. if len(self._cached_activations) == 0: - self._cached_per_sample_gradient = self._cached_per_sample_gradient.to( - dtype=self.score_args.score_dtype - ) if isinstance(self._storage[PRECONDITIONED_GRADIENT_NAME], list): # The preconditioned gradient is stored as a low-rank approximation. left_mat, right_mat = self._storage[PRECONDITIONED_GRADIENT_NAME] - scores = torch.einsum( + self._storage[PAIRWISE_SCORE_MATRIX_NAME] = contract( "qki,toi,qok->qt", right_mat, self._cached_per_sample_gradient, left_mat, ) else: - query_batch_size = self._storage[PRECONDITIONED_GRADIENT_NAME].size(0) - train_batch_size = self._cached_per_sample_gradient.size(0) - # scores = torch.einsum( - # "qio,tio->qt", - # self._storage[PRECONDITIONED_GRADIENT_NAME], - # self._cached_per_sample_gradient, - # ) - scores = torch.matmul(self._storage[PRECONDITIONED_GRADIENT_NAME].view(query_batch_size, -1), - self._cached_per_sample_gradient.view(train_batch_size, -1).t() - ) - self._storage[PAIRWISE_SCORE_MATRIX_NAME] = scores + self._storage[PAIRWISE_SCORE_MATRIX_NAME] = contract( + "qio,tio->qt", + self._storage[PRECONDITIONED_GRADIENT_NAME], + self._cached_per_sample_gradient, + ) self._cached_per_sample_gradient = None self._registered_hooks.append(self.original_module.register_forward_hook(forward_hook)) @@ -750,31 +766,36 @@ def _register_self_score_hooks(self) -> None: def forward_hook(module: nn.Module, inputs: Tuple[torch.Tensor], outputs: Tuple[torch.Tensor]) -> None: del module - cached_activation = inputs[0].detach() - if self.score_args.cached_activation_cpu_offload: - self._cached_activations.append(cached_activation.cpu()) - else: - self._cached_activations.append(cached_activation) + with torch.no_grad(): + cached_activation = inputs[0].detach().to(dtype=self.score_args.per_sample_gradient_dtype) + if self.score_args.cached_activation_cpu_offload: + self._cached_activations.append(cached_activation.cpu()) + else: + self._cached_activations.append(cached_activation) # Register backward hook to obtain gradient with respect to the output. 
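# The pairwise score computed above is, per module, a dot product between each
# preconditioned query gradient and each training per-sample gradient. When the
# query side is stored in low-rank form, the same scores come from contracting
# the two factors directly, as the einsum above does. A sketch checking that
# the two routes agree; index labels mirror that contraction, other names are
# illustrative.
import torch

torch.manual_seed(0)
num_query, num_train, out_dim, in_dim, rank = 5, 7, 12, 16, 3
train_gradient = torch.randn(num_train, out_dim, in_dim, dtype=torch.float64)
left_mat = torch.randn(num_query, out_dim, rank, dtype=torch.float64)   # U_k diag(S_k)
right_mat = torch.randn(num_query, rank, in_dim, dtype=torch.float64)   # V_k^T
query_gradient = left_mat @ right_mat  # full reconstruction, (num_query, out_dim, in_dim)

full_scores = query_gradient.reshape(num_query, -1) @ train_gradient.reshape(num_train, -1).t()
low_rank_scores = torch.einsum("qki,toi,qok->qt", right_mat, train_gradient, left_mat)

assert torch.allclose(full_scores, low_rank_scores)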
- outputs.register_hook(backward_hook) + self._cached_hooks.append(outputs.register_hook(backward_hook)) @torch.no_grad() def backward_hook(output_gradient: torch.Tensor) -> None: + handle = self._cached_hooks.pop() + handle.remove() cached_activation = self._cached_activations.pop() per_sample_gradient = self._compute_per_sample_gradient( - input_activation=cached_activation.to(dtype=self.score_args.per_sample_gradient_dtype), + input_activation=cached_activation, output_gradient=output_gradient.detach().to(dtype=self.score_args.per_sample_gradient_dtype), ) - del cached_activation, output_gradient + del cached_activation # The preconditioning factors need to be loaded to appropriate device as they will be # used at each iteration. - for name, factor in self._storage.items(): - if factor is not None and isinstance(factor, torch.Tensor): - self._storage[name] = factor.to( - device=per_sample_gradient.device, - dtype=self.score_args.precondition_dtype, - ) + if not self._storge_at_current_device: + for name, factor in self._storage.items(): + if factor is not None and isinstance(factor, torch.Tensor): + self._storage[name] = factor.to( + device=per_sample_gradient.device, + dtype=self.score_args.precondition_dtype, + ) + self._storge_at_current_device = True if self._cached_per_sample_gradient is None: self._cached_per_sample_gradient = per_sample_gradient @@ -785,12 +806,15 @@ def backward_hook(output_gradient: torch.Tensor) -> None: # If the module was used multiple times throughout the forward pass, # only compute scores after aggregating all per-sample-gradients. if len(self._cached_activations) == 0: - preconditioned_gradient = FactorConfig.CONFIGS[self.factor_args.strategy].precondition_gradient( - gradient=self._cached_per_sample_gradient.to(dtype=self.score_args.precondition_dtype), - storage=self._storage, - damping=self.score_args.damping, + preconditioned_gradient = ( + FactorConfig.CONFIGS[self.factor_args.strategy] + .precondition_gradient( + gradient=self._cached_per_sample_gradient.to(dtype=self.score_args.precondition_dtype), + storage=self._storage, + damping=self.score_args.damping, + ) + .to(dtype=self.score_args.score_dtype) ) - preconditioned_gradient = preconditioned_gradient.to(dtype=self.score_args.score_dtype) self._cached_per_sample_gradient = self._cached_per_sample_gradient.to( dtype=self.score_args.score_dtype ) @@ -813,3 +837,4 @@ def release_scores(self) -> None: self._storage[SELF_SCORE_VECTOR_NAME] = None self._cached_activations = [] self._cached_per_sample_gradient = None + self._storge_at_current_device = False diff --git a/kronfluence/module/utils.py b/kronfluence/module/utils.py index adb179f..1394e24 100644 --- a/kronfluence/module/utils.py +++ b/kronfluence/module/utils.py @@ -29,8 +29,7 @@ def wrap_tracked_modules( factor_args: Optional[FactorArguments] = None, score_args: Optional[ScoreArguments] = None, ) -> nn.Module: - """Inspects all modules within the model and if supported modules for factor & influence - computations are found, wraps them with `TrackedModule`. + """Inspects all modules within the model and, if supported modules are found, wraps them with `TrackedModule`. Args: model (nn.Module): @@ -43,7 +42,8 @@ def wrap_tracked_modules( Arguments related to computing the influence scores. Returns: - nn.Module: The wrapped model with `TrackedModule`. + nn.Module: + The wrapped model with `TrackedModule`. 
""" if isinstance(model, (DP, DDP, FSDP)): raise ValueError( @@ -52,7 +52,11 @@ def wrap_tracked_modules( ) tracked_module_count = 0 - tracked_module_names = task.influence_modules() if task is not None else None + tracked_module_names = task.tracked_modules() if task is not None else None + tracked_module_exists_dict = None + if tracked_module_names is not None: + tracked_module_exists_dict = {name: False for name in tracked_module_names} + named_modules = model.named_modules() for module_name, module in named_modules: if len(list(module.children())) > 0: @@ -70,11 +74,21 @@ def wrap_tracked_modules( factor_args=factor_args, score_args=score_args, ) + # We need backward hooks for these modules to be activated. tracked_module.requires_grad_(True) parent, target_name = _get_submodules(model=model, key=module_name) setattr(parent, target_name, tracked_module) tracked_module_count += 1 + if tracked_module_exists_dict is not None: + tracked_module_exists_dict[module_name] = True + + if tracked_module_exists_dict is not None and not all(list(tracked_module_exists_dict.values())): + error_msg = ( + f"Some provided tracked modules were not found. The current mapping: " f"{tracked_module_exists_dict}." + ) + raise IllegalTaskConfigurationError(error_msg) + if tracked_module_count == 0: supported_modules_names = [module.__name__ for module in TrackedModule.SUPPORTED_MODULES] error_msg = ( @@ -85,6 +99,7 @@ def wrap_tracked_modules( ) error_msg += f"\n{model}" raise IllegalTaskConfigurationError(error_msg) + return model diff --git a/kronfluence/score/pairwise.py b/kronfluence/score/pairwise.py index 5b30e47..aa853b9 100644 --- a/kronfluence/score/pairwise.py +++ b/kronfluence/score/pairwise.py @@ -2,7 +2,7 @@ from typing import Dict, List, Optional, Tuple, Union import torch -from accelerate.utils import find_batch_size, send_to_device +from accelerate.utils import send_to_device from safetensors.torch import load_file, save_file from torch import nn from torch.utils import data @@ -18,7 +18,6 @@ ) from kronfluence.module.tracked_module import ModuleMode from kronfluence.module.utils import ( - get_preconditioned_gradient_batch_size, get_tracked_module_names, release_scores, set_factors, @@ -41,7 +40,7 @@ def pairwise_scores_save_path( if partition is not None: data_partition, module_partition = partition return output_dir / ( - f"pairwise_scores_data_partition{data_partition}" f"_module_partition{module_partition}.safetensors" + f"pairwise_scores_data_partition{data_partition}_module_partition{module_partition}.safetensors" ) return output_dir / "pairwise_scores.safetensors" @@ -104,7 +103,6 @@ def _compute_pairwise_dot_products_with_loader( score_chunks[ALL_MODULE_NAME] = [] with torch.no_grad(): - total_query_batch_size = get_preconditioned_gradient_batch_size(model=model) set_mode( model=model, mode=ModuleMode.PAIRWISE_SCORE, @@ -119,7 +117,7 @@ def _compute_pairwise_dot_products_with_loader( disable=not state.is_main_process, ) as pbar: for batch in train_loader: - batch = send_to_device(batch, device=state.device) + batch = send_to_device(tensor=batch, device=state.device) with no_sync(model=model, state=state): model.zero_grad(set_to_none=True) @@ -139,16 +137,14 @@ def _compute_pairwise_dot_products_with_loader( ) else: # Aggregate the pairwise scores across all modules. 
- batch_size = find_batch_size(batch) - pairwise_scores = torch.zeros( - size=(total_query_batch_size, batch_size), - dtype=score_args.score_dtype, - device=state.device, - requires_grad=False, - ) + pairwise_scores = None for module in model.modules(): if isinstance(module, TrackedModule) and module.name in tracked_module_names: - pairwise_scores.add_(module.get_factor(factor_name=PAIRWISE_SCORE_MATRIX_NAME)) + if pairwise_scores is None: + pairwise_scores = module.get_factor(factor_name=PAIRWISE_SCORE_MATRIX_NAME) + else: + pairwise_scores.add_(module.get_factor(factor_name=PAIRWISE_SCORE_MATRIX_NAME)) + # `.cpu()` synchronizes the CUDA stream. score_chunks[ALL_MODULE_NAME].append(pairwise_scores.cpu()) release_scores(model=model) @@ -166,28 +162,13 @@ def _compute_pairwise_dot_products_with_loader( for module_name, chunks in score_chunks.items(): total_scores[module_name] = torch.cat(chunks, dim=1) if state.use_distributed: - size = total_scores[module_name].size() - total_scores[module_name] = total_scores[module_name].to( - device=state.device, - non_blocking=False, - ) - torch.cuda.synchronize(state.device) - total_scores[module_name] = total_scores[module_name].t().contiguous() - release_memory() - stacked_scores = torch.empty( - size=(size[1] * state.num_processes, size[0]), - dtype=total_scores[module_name].dtype, - device=state.device, - requires_grad=False, - ) - torch.distributed.all_gather_into_tensor( - output_tensor=stacked_scores, - input_tensor=total_scores[module_name], - ) - stacked_scores = stacked_scores.t().contiguous() - stacked_scores = stacked_scores[:, :dataset_size] - total_scores[module_name] = stacked_scores.cpu() - del stacked_scores + total_scores[module_name] = total_scores[module_name].to(device=state.device) + gather_list = None + if state.is_main_process: + gather_list = [torch.zeros_like(total_scores[module_name]) for _ in range(state.num_processes)] + torch.distributed.gather(total_scores[module_name], gather_list) + if state.is_main_process: + total_scores[module_name] = torch.cat(gather_list, dim=1)[:, :dataset_size].cpu() return total_scores @@ -272,7 +253,7 @@ def compute_pairwise_scores_with_loaders( ) as pbar: for query_index, query_batch in enumerate(query_loader): query_batch = send_to_device( - query_batch, + tensor=query_batch, device=state.device, ) @@ -281,13 +262,13 @@ def compute_pairwise_scores_with_loaders( measurement = task.compute_measurement(batch=query_batch, model=model) measurement.backward() - if state.use_distributed: - # Stack preconditioned query gradient across multiple devices or nodes. - synchronize_preconditioned_gradient(model=model, num_processes=state.num_processes) - if query_index == len(query_loader) - 1 and query_remainder > 0: - # Remove duplicate data points if the dataset is not exactly divisible - # by the current batch size. - truncate_preconditioned_gradient(model=model, keep_size=query_remainder) + if state.use_distributed: + # Stack preconditioned query gradient across multiple devices or nodes. + synchronize_preconditioned_gradient(model=model, num_processes=state.num_processes) + if query_index == len(query_loader) - 1 and query_remainder > 0: + # Remove duplicate data points if the dataset is not exactly divisible + # by the current batch size. + truncate_preconditioned_gradient(model=model, keep_size=query_remainder) # Compute the dot product between preconditioning query gradient and all training gradients. 
release_memory() @@ -299,18 +280,19 @@ def compute_pairwise_scores_with_loaders( score_args=score_args, tracked_module_names=tracked_module_names, ) - - with torch.no_grad(): - for module_name, current_scores in scores.items(): - if module_name not in total_scores_chunks: - total_scores_chunks[module_name] = [] - total_scores_chunks[module_name].append(current_scores) + if state.is_main_process: + with torch.no_grad(): + for module_name, current_scores in scores.items(): + if module_name not in total_scores_chunks: + total_scores_chunks[module_name] = [] + total_scores_chunks[module_name].append(current_scores) + state.wait_for_everyone() pbar.update(1) with torch.no_grad(): set_mode(model=model, mode=ModuleMode.DEFAULT, keep_factors=False) - - for module_name in total_scores_chunks: - total_scores_chunks[module_name] = torch.cat(total_scores_chunks[module_name], dim=0) - + if state.is_main_process: + for module_name in total_scores_chunks: + total_scores_chunks[module_name] = torch.cat(total_scores_chunks[module_name], dim=0) + state.wait_for_everyone() return total_scores_chunks diff --git a/kronfluence/score/self.py b/kronfluence/score/self.py index 1c19741..85a7f51 100644 --- a/kronfluence/score/self.py +++ b/kronfluence/score/self.py @@ -38,7 +38,7 @@ def self_scores_save_path( if partition is not None: data_partition, module_partition = partition return output_dir / ( - f"self_scores_data_partition{data_partition}" f"_module_partition{module_partition}.safetensors" + f"self_scores_data_partition{data_partition}_module_partition{module_partition}.safetensors" ) return output_dir / "self_scores.safetensors" diff --git a/kronfluence/task.py b/kronfluence/task.py index 90e6947..6bf0eba 100644 --- a/kronfluence/task.py +++ b/kronfluence/task.py @@ -58,7 +58,7 @@ def compute_measurement( """ raise NotImplementedError("Subclasses must implement the `compute_measurement` method.") - def influence_modules(self) -> Optional[List[str]]: + def tracked_modules(self) -> Optional[List[str]]: """Specifies modules for preconditioning factors and influence scores computation. Returns None by default, applying computations to all supported modules (e.g., nn.Linear, nn.Conv2d). diff --git a/kronfluence/utils/dataset.py b/kronfluence/utils/dataset.py index 6557af6..29632c9 100644 --- a/kronfluence/utils/dataset.py +++ b/kronfluence/utils/dataset.py @@ -1,15 +1,18 @@ import math import multiprocessing from dataclasses import dataclass -from typing import Callable, Iterable, List, Optional, Tuple +from typing import Callable, Iterable, List, Optional, Tuple, TypeVar import numpy as np import torch import torch.distributed as dist from accelerate.utils import KwargsHandler from accelerate.utils.memory import should_reduce_batch_size +from torch.utils import data from torch.utils.data import Sampler +T_co = TypeVar("T_co", covariant=True) + @dataclass class DataLoaderKwargs(KwargsHandler): @@ -31,9 +34,12 @@ class DataLoaderKwargs(KwargsHandler): def make_indices_partition(total_data_examples: int, partition_size: int) -> List[Tuple[int, int]]: """Returns partitioned indices from the total data examples.""" + if total_data_examples < partition_size: + raise ValueError("The total data examples must be equal or greater than the partition size.") + # See https://stackoverflow.com/questions/2130016/splitting-a-list-into-n-parts-of-approximately-equal-length. 
bins = list(map(len, np.array_split(range(total_data_examples), partition_size))) - indices_bin = [] start_idx = 0 + indices_bin = [] for i in range(partition_size): indices_bin.append((start_idx, start_idx + bins[i])) start_idx += bins[i] @@ -41,7 +47,7 @@ def make_indices_partition(total_data_examples: int, partition_size: int) -> Lis def find_executable_batch_size(func: Callable, start_batch_size: int) -> int: - """Finds executable batch size for calling the function that does not have OOM error. The code is motivated + """Finds executable batch size for calling the function that does not encounter OOM error. The code is motivated from https://github.com/huggingface/accelerate/blob/v0.27.2/src/accelerate/utils/memory.py#L83. """ batch_size = start_batch_size @@ -61,8 +67,8 @@ def find_executable_batch_size(func: Callable, start_batch_size: int) -> int: return batch_size -class DistributedEvalSampler(Sampler): - """DistributedEvalSampler is different from DistributedSampler. It does NOT add extra samples to make +class DistributedEvalSampler(Sampler[T_co]): + """DistributedEvalSampler is different from DistributedSampler: ut does not add extra samples to make it evenly divisible. DistributedEvalSampler should not be used for training. The distributed processes could hang forever. See this issue for details: https://github.com/pytorch/pytorch/issues/22584. @@ -106,13 +112,13 @@ def __len__(self) -> int: return self.num_samples -class DistributedSamplerWithStack(Sampler): +class DistributedSamplerWithStack(Sampler[T_co]): """DistributedSampleWithStack is different from DistributedSampler. Instead of subsampling, it stacks the dataset.""" def __init__( # pylint: disable=super-init-not-called self, - dataset: torch.utils.data.Dataset, + dataset: data.Dataset, num_replicas: Optional[int] = None, rank: Optional[int] = None, seed: int = 0, diff --git a/kronfluence/utils/logger.py b/kronfluence/utils/logger.py index 8a5ccd9..96b1caf 100644 --- a/kronfluence/utils/logger.py +++ b/kronfluence/utils/logger.py @@ -1,12 +1,10 @@ -import functools import logging import os import time from collections import defaultdict from contextlib import contextmanager -from typing import Dict, Generator, List, Optional, Tuple +from typing import Dict, Generator, List, Tuple -import numpy as np import torch import torch.distributed as dist @@ -16,6 +14,9 @@ "{desc} [{n_fmt}/{total_fmt}] {percentage:3.0f}%|{bar}{postfix} " "[time left: {remaining}, time spent: {elapsed}]" ) +_TABLE_ROW = Tuple[str, float, int, float, float] +_TABLE_DATA = List[_TABLE_ROW] + class MultiProcessAdapter(logging.LoggerAdapter): """An adapter to assist with logging in multiprocess. @@ -24,21 +25,15 @@ class MultiProcessAdapter(logging.LoggerAdapter): minor modifications. 
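# A quick illustration of `make_indices_partition` above: the dataset is cut
# into `partition_size` contiguous (start, end) ranges of near-equal length via
# `np.array_split`. Re-stated under a different name so the example runs on its own.
import numpy as np

def partition_indices(total_data_examples: int, partition_size: int):
    bins = list(map(len, np.array_split(range(total_data_examples), partition_size)))
    start_idx, indices_bin = 0, []
    for i in range(partition_size):
        indices_bin.append((start_idx, start_idx + bins[i]))
        start_idx += bins[i]
    return indices_bin

print(partition_indices(total_data_examples=10, partition_size=3))
# [(0, 4), (4, 7), (7, 10)]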
""" - def log(self, level, msg, *args, **kwargs): + def log(self, level: int, msg: str, *args, **kwargs) -> None: """Delegates logger call after checking if we should log.""" if self.isEnabledFor(level) and not self.extra["disable_log"]: msg, kwargs = self.process(msg, kwargs) self.logger.log(level, msg, *args, **kwargs) - @functools.lru_cache(None) # pylint: disable=method-cache-max-size-none - def warning_once(self, *args, **kwargs): - """This method is identical to `logger.warning()`, but will emit the warning with the same - message only once.""" - self.warning(*args, **kwargs) - -def get_logger(name: str, disable_log: bool = False, log_level: int = None): - """Returns the logger with an option to disable.""" +def get_logger(name: str, disable_log: bool = False, log_level: int = None) -> MultiProcessAdapter: + """Returns the logger with an option to disable logging.""" logger = logging.getLogger(name) if log_level is not None: logger.setLevel(log_level) @@ -46,48 +41,38 @@ def get_logger(name: str, disable_log: bool = False, log_level: int = None): return MultiProcessAdapter(logger, {"disable_log": disable_log}) -def _get_monotonic_time() -> float: - """Gets the time after CUDA synchronization.""" - if torch.cuda.is_available() and torch.cuda.is_initialized(): - torch.cuda.synchronize() - return time.monotonic() - - class Profiler: - """Profiling object to measure the time taken to run a certain operation. + """Profiling object to measure the time taken to run a certain operation. The profiler is helpful + for checking any bottlenecks in the code. The code is modified from: - - https://github.com/Lightning-AI/lightning/tree/master/src/pytorch_lightning/profilers - - https://github.com/mlcommons/algorithmic-efficiency/blob/main/algorithmic_efficiency/profiler.py + - https://github.com/Lightning-AI/lightning/tree/master/src/pytorch_lightning/profilers. + - https://github.com/mlcommons/algorithmic-efficiency/blob/main/algorithmic_efficiency/profiler.py. """ - def __init__(self, local_rank: Optional[int] = None) -> None: - self._local_rank = local_rank + def __init__(self, state: State) -> None: + """Initializes an instance of the Profiler class. + + Args: + state (State): + The current process's information (e.g., device being used). 
+ """ + self.state = state self.current_actions: Dict[str, float] = {} self.recorded_durations = defaultdict(list) - self.start_time = _get_monotonic_time() - - def set_local_rank(self, local_rank: int) -> None: - """Sets the current local rank.""" - self._local_rank = local_rank - - @property - def local_rank(self) -> int: - """Returns the current local rank.""" - return 0 if self._local_rank is None else self._local_rank def start(self, action_name: str) -> None: - """Start recording the initial time for an action.""" - if self.local_rank != 0: - pass + """Defines how to start recording an action.""" + if not self.state.is_main_process: + return if action_name in self.current_actions: raise ValueError(f"Attempted to start {action_name} which has already started.") self.current_actions[action_name] = _get_monotonic_time() def stop(self, action_name: str) -> None: - """Stops recording the initial time for an action.""" - if self.local_rank != 0: - pass + """Defines how to record the duration once an action is complete.""" + if not self.state.is_main_process: + return end_time = _get_monotonic_time() if action_name not in self.current_actions: raise ValueError(f"Attempting to stop recording an action " f"({action_name}) which was never started.") @@ -97,30 +82,30 @@ def stop(self, action_name: str) -> None: @contextmanager def profile(self, action_name: str) -> Generator: - """A context manager for Profiler.""" + """Yields a context manager to encapsulate the scope of a profiled action.""" try: self.start(action_name) yield action_name finally: self.stop(action_name) - def _make_report( - self, - ) -> Tuple[List[Tuple[str, float, float, int, float, float]], int, float]: - total_duration = _get_monotonic_time() - self.start_time - report = [ - ( - str(a), - float(np.mean(d)), - float(np.std(d)), - len(d), - float(np.sum(d)), - 100.0 * float(np.sum(d)) / total_duration, - ) - for a, d in self.recorded_durations.items() - ] - report.sort(key=lambda x: x[5], reverse=True) - total_calls = sum(x[3] for x in report) + @torch.no_grad() + def _make_report(self) -> Tuple[_TABLE_DATA, float, float]: + total_duration = 0.0 + for a, d in self.recorded_durations.items(): + d_tensor = torch.tensor(d, dtype=torch.float64, requires_grad=False) + total_duration += torch.sum(d_tensor).item() + + report = [] + for a, d in self.recorded_durations.items(): + d_tensor = torch.tensor(d, dtype=torch.float64, requires_grad=False) + len_d = len(d) + sum_d = torch.sum(d_tensor).item() + percentage_d = 100.0 * sum_d / total_duration + report.append((a, sum_d / len_d, len_d, sum_d, percentage_d)) + + report.sort(key=lambda x: x[4], reverse=True) + total_calls = sum(x[2] for x in report) return report, total_calls, total_duration def summary(self) -> str: @@ -131,45 +116,22 @@ def summary(self) -> str: if len(self.recorded_durations) > 0: max_key = max(len(k) for k in self.recorded_durations.keys()) - def log_row(action, mean, std, num_calls, total, per): - row = f"{sep}| {action:<{max_key}s}\t| " - row += f"{mean:<15}\t| {std:<15}\t|" + def log_row(action: str, mean: str, num_calls: str, total: str, per: str) -> str: + row = f"{sep}| {action:<{max_key}s}\t| {mean:<15}\t|" row += f" {num_calls:<15}\t| {total:<15}\t| {per:<15}\t|" return row - header_string = log_row( - "Action", - "Mean Duration (s)", - "Std Duration (s)", - "Num Calls", - "Total Time (s)", - "Percentage %", - ) + header_string = log_row("Action", "Mean duration (s)", "Num calls", "Total time (s)", "Percentage %") output_string_len = 
len(header_string.expandtabs()) - sep_lines = f'{sep}{"-" * output_string_len}' + sep_lines = f"{sep}{'-' * output_string_len}" output_string += sep_lines + header_string + sep_lines - report, total_calls, total_duration = self._make_report() - output_string += log_row( - "Total", - "-----", - "-----", - f"{total_calls:}", - f"{total_duration:.5}", - "100 %", - ) + report_extended, total_calls, total_duration = self._make_report() + output_string += log_row("Total", "-", f"{total_calls:}", f"{total_duration:.5}", "100 %") output_string += sep_lines - for ( - action, - mean_duration, - std_duration, - num_calls, - total_duration, - duration_per, - ) in report: + for action, mean_duration, num_calls, total_duration, duration_per in report_extended: output_string += log_row( action, f"{mean_duration:.5}", - f"{std_duration:.5}", f"{num_calls}", f"{total_duration:.5}", f"{duration_per:.5}", @@ -180,31 +142,39 @@ def log_row(action, mean, std, num_calls, total, per): class PassThroughProfiler(Profiler): - """A pass through Profiler objective.""" + """A pass through Profiler objective that does not record timing for the profiler.""" def start(self, action_name: str) -> None: - pass + """Defines how to start recording an action.""" + return def stop(self, action_name: str) -> None: - pass + """Defines how to record the duration once an action is complete.""" + return def summary(self) -> str: + """Returns a formatted summary for the Profiler.""" return "" -def sync_ddp_time(_time: float, device: torch.device) -> float: - """Synchronizes the time.""" - time_tensor = torch.tensor(_time, dtype=torch.float64, device=device) - dist.all_reduce(time_tensor, op=dist.ReduceOp.MAX) - return time_tensor.item() +# Timing utilities copied from +# https://github.com/mlcommons/algorithmic-efficiency/blob/main/algorithmic_efficiency/pytorch_utils.py. 
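# A tiny sketch of the reduction performed by `_make_report` above: per-action
# mean duration, call count, total time, and share of overall time, sorted by
# that share. Action names and timings below are made up for illustration.
recorded_durations = {
    "Fit Covariance": [4.0, 4.2, 3.8],
    "Perform Eigendecomposition": [1.5],
    "Fit Lambda": [6.1, 5.9],
}

total_duration = sum(sum(durations) for durations in recorded_durations.values())
report = []
for action, durations in recorded_durations.items():
    total = sum(durations)
    report.append((action, total / len(durations), len(durations), total, 100.0 * total / total_duration))
report.sort(key=lambda row: row[4], reverse=True)

for action, mean, num_calls, total, percentage in report:
    print(f"{action:<30s} {mean:8.3f} {num_calls:5d} {total:8.3f} {percentage:7.2f}%")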
+def _get_monotonic_time() -> float: + """Gets the monotonic time after the CUDA synchronization if necessary.""" + if torch.cuda.is_available() and torch.cuda.is_initialized(): + torch.cuda.synchronize() + return time.monotonic() +@torch.no_grad() def get_time(state: State) -> float: """Gets the current time after synchronizing with other devices.""" if not state.use_distributed: - if torch.cuda.is_available(): + if torch.cuda.is_available() and torch.cuda.is_initialized(): torch.cuda.synchronize() return time.time() torch.cuda.synchronize() - t = time.time() - return sync_ddp_time(t, state.device) + current_time = time.time() + time_tensor = torch.tensor(current_time, dtype=torch.float64, device=state.device, requires_grad=False) + dist.all_reduce(time_tensor, op=dist.ReduceOp.MAX) + return time_tensor.item() diff --git a/kronfluence/utils/save.py b/kronfluence/utils/save.py index 589d892..f280d76 100644 --- a/kronfluence/utils/save.py +++ b/kronfluence/utils/save.py @@ -22,13 +22,13 @@ def load_file(path: Path) -> Dict[str, torch.Tensor]: def save_json(obj: Any, path: Path) -> None: - """Saves the object to a json file.""" + """Saves the object to a JSON file.""" with open(path, "w", encoding="utf-8") as f: json.dump(obj, f, indent=4) def load_json(path: Path) -> Dict[str, Any]: - """Loads an object from the json file.""" + """Loads an object from the JSON file.""" with open(path, "rb") as f: obj = json.load(f) return obj diff --git a/kronfluence/utils/state.py b/kronfluence/utils/state.py index a432607..7d485e3 100644 --- a/kronfluence/utils/state.py +++ b/kronfluence/utils/state.py @@ -7,16 +7,17 @@ import torch.distributed as dist from accelerate.state import SharedDict from torch import nn +from torch.distributed.fsdp import FullyShardedDataParallel class State: """A singleton class to manage the process environment state, such as device and process count. This class is inspired by Accelerate's `PartialState`: - https://github.com/huggingface/accelerate/blob/main/src/accelerate/state.py + https://github.com/huggingface/accelerate/blob/main/src/accelerate/state.py. The direct use of `PartialState` from Accelerate can be problematic, since the analysis - (influence computation) environment may be different from training environment. + (influence computation) environment may be different from the training environment. 
""" _shared_state: Dict[str, Any] = SharedDict() @@ -37,8 +38,8 @@ def __init__(self, cpu: bool = False) -> None: if int(os.environ.get("LOCAL_RANK", -1)) != -1 and not cpu and torch.cuda.is_available(): if not dist.is_initialized(): dist.init_process_group(backend="nccl") - self.num_processes = torch.distributed.get_world_size() - self.process_index = torch.distributed.get_rank() + self.num_processes = dist.get_world_size() + self.process_index = dist.get_rank() self.local_process_index = int(os.environ.get("LOCAL_RANK", -1)) self.device = torch.device("cuda", self.local_process_index) self.n_gpus = torch.cuda.device_count() @@ -91,25 +92,21 @@ def wait_for_everyone(self) -> None: """Will stop the execution of the current process until every other process has reached that point (so this does nothing when the script is only run in one process).""" if self.use_distributed: - torch.distributed.barrier() + dist.barrier() @property def default_device(self) -> torch.device: """Finds the default device currently available.""" - if torch.backends.mps.is_available() and torch.backends.mps.is_built(): - os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" - return torch.device("mps") if torch.cuda.is_available(): return torch.device("cuda") return torch.device("cpu") def release_memory() -> None: - """Releases the memory.""" + """Releases the memory by calling `gc.collect()` and `torch.cuda.empty_cache()`.""" + gc.collect() if torch.cuda.is_available(): - gc.collect() torch.cuda.empty_cache() - torch.cuda.reset_peak_memory_stats() @contextlib.contextmanager @@ -118,7 +115,10 @@ def no_sync(model: nn.Module, state: State) -> Callable: https://github.com/huggingface/accelerate/blob/v0.27.2/src/accelerate/accelerator.py#L852. """ context = contextlib.nullcontext - if state.use_distributed: + + # `no_sync()` for FSDP instance can result in higher memory usage, detailed in: + # https://pytorch.org/docs/stable/fsdp.html. + if state.use_distributed and not isinstance(model, FullyShardedDataParallel): context = getattr(model, "no_sync", context) with context(): diff --git a/requirements.txt b/requirements.txt index 72112e0..353244f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,4 +3,6 @@ torchvision>=0.16.0 accelerate>=0.27.2 einops>=0.7.0 einconv>=0.1.0 -opt_einsum>=3.3.0 \ No newline at end of file +opt_einsum>=3.3.0 +safetensors>=0.4.2 +tqdm>=4.66.2 \ No newline at end of file diff --git a/tests/factors/test_covariances.py b/tests/factors/test_covariances.py index 5ae8463..32366da 100644 --- a/tests/factors/test_covariances.py +++ b/tests/factors/test_covariances.py @@ -100,6 +100,7 @@ def test_covariance_matrices_batch_size_equivalence( train_size: int, seed: int, ) -> None: + # Covariance matrices should be identical regardless of the batch size used. model, train_dataset, _, data_collator, task = prepare_test( test_name=test_name, train_size=train_size, @@ -162,6 +163,7 @@ def test_covariance_matrices_partition_equivalence( train_size: int, seed: int, ) -> None: + # Covariance matrices should be identical regardless of the partition used. model, train_dataset, _, data_collator, task = prepare_test( test_name=test_name, train_size=train_size, @@ -221,6 +223,8 @@ def test_covariance_matrices_attention_mask( train_size: int, seed: int, ) -> None: + # Make sure the attention mask is correctly implemented by comparing with the results + # without any padding (and batch size of 1). 
model, train_dataset, _, data_collator, task = prepare_test( test_name=test_name, train_size=train_size, @@ -298,6 +302,7 @@ def test_covariance_matrices_automatic_batch_size( train_size: int, seed: int, ) -> None: + # Make sure the automatic batch size search feature is working. model, train_dataset, _, data_collator, task = prepare_test( test_name=test_name, train_size=train_size, @@ -360,6 +365,7 @@ def test_covariance_matrices_max_examples( train_size: int, seed: int, ) -> None: + # Make sure the max covariance data selection is working. model, train_dataset, _, data_collator, task = prepare_test( test_name=test_name, train_size=train_size, diff --git a/tests/factors/test_eigens.py b/tests/factors/test_eigens.py index aec2338..0cedba3 100644 --- a/tests/factors/test_eigens.py +++ b/tests/factors/test_eigens.py @@ -152,6 +152,7 @@ def test_lambda_matrices_batch_size_equivalence( train_size: int, seed: int, ) -> None: + # Lambda matrices should be identical regardless of the batch size used. model, train_dataset, _, data_collator, task = prepare_test( test_name=test_name, train_size=train_size, @@ -216,6 +217,7 @@ def test_lambda_matrices_partition_equivalence( train_size: int, seed: int, ) -> None: + # Covariance matrices should be identical regardless of the partition used. model, train_dataset, _, data_collator, task = prepare_test( test_name=test_name, train_size=train_size, @@ -275,13 +277,14 @@ def test_lambda_matrices_partition_equivalence( "gpt", ], ) -@pytest.mark.parametrize("train_size", [50]) +@pytest.mark.parametrize("train_size", [63]) @pytest.mark.parametrize("seed", [3]) def test_lambda_matrices_iterative_aggregate( test_name: str, train_size: int, seed: int, ) -> None: + # Make sure aggregated lambda computation is working and the results are identical. model, train_dataset, _, data_collator, task = prepare_test( test_name=test_name, train_size=train_size, @@ -318,7 +321,7 @@ def test_lambda_matrices_iterative_aggregate( factors_name=factors_name + "_iterative", dataset=train_dataset, factor_args=factor_args, - per_device_batch_size=8, + per_device_batch_size=4, overwrite_output_dir=True, dataloader_kwargs=kwargs, ) @@ -345,6 +348,7 @@ def test_lambda_matrices_max_examples( train_size: int, seed: int, ) -> None: + # Make sure the max Lambda data selection is working. model, train_dataset, _, data_collator, task = prepare_test( test_name=test_name, train_size=train_size, diff --git a/tests/test_dataset_utils.py b/tests/test_dataset_utils.py new file mode 100644 index 0000000..66857b6 --- /dev/null +++ b/tests/test_dataset_utils.py @@ -0,0 +1,68 @@ +import numpy as np +import pytest + +from kronfluence.utils.dataset import ( + DistributedEvalSampler, + DistributedSamplerWithStack, + make_indices_partition, +) +from tests.utils import prepare_test + + +@pytest.mark.parametrize("dataset_size", [3, 105, 1027]) +@pytest.mark.parametrize("num_replicas", [4, 105, 1027]) +def test_eval_distributed_sampler( + dataset_size: int, + num_replicas: int, +): + _, train_dataset, _, _, _ = prepare_test( + test_name="mlp", + train_size=dataset_size, + seed=0, + ) + + indices = [] + for rank in range(num_replicas): + sampler = DistributedEvalSampler(train_dataset, num_replicas=num_replicas, rank=rank) + indices.append(np.array(list(iter(sampler)))) + + assert len(np.hstack(indices)) == dataset_size + # Make sure that there aren't any duplicates. 
+ assert len(np.unique(np.hstack(indices))) == dataset_size + + +@pytest.mark.parametrize("dataset_size", [3, 105, 1027]) +@pytest.mark.parametrize("num_replicas", [4, 105, 1027]) +def test_eval_distributed_sampler_with_stack( + dataset_size: int, + num_replicas: int, +): + dataset_size = 1002 + _, train_dataset, _, _, _ = prepare_test( + test_name="mlp", + train_size=dataset_size, + seed=0, + ) + + num_replicas = 4 + indices = [] + for rank in range(num_replicas): + sampler = DistributedSamplerWithStack(train_dataset, num_replicas=num_replicas, rank=rank) + indices.append(np.array(list(iter(sampler)))) + + for i, sample_indices in enumerate(indices): + if i != len(indices) - 1: + assert np.all(np.sort(sample_indices) == sample_indices) + assert len(np.unique(np.hstack(indices[:-1]))) == len(np.hstack(indices[:-1])) + + +@pytest.mark.parametrize("total_data_examples", [520, 1000, 8129]) +@pytest.mark.parametrize("partition_size", [2, 270, 520]) +def test_make_indices_partition(total_data_examples: int, partition_size: int): + indices = make_indices_partition(total_data_examples=total_data_examples, partition_size=partition_size) + assert len(indices) == partition_size + reconstructions = [] + for start_index, end_index in indices: + reconstructions.extend(list(range(start_index, end_index))) + assert len(reconstructions) == total_data_examples + assert len(set(reconstructions)) == len(reconstructions) diff --git a/tests/test_module_utils.py b/tests/test_module_utils.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/testable_tasks/language_modeling.py b/tests/testable_tasks/language_modeling.py index 74a8f85..17b5147 100644 --- a/tests/testable_tasks/language_modeling.py +++ b/tests/testable_tasks/language_modeling.py @@ -125,7 +125,7 @@ def compute_measurement( ) -> torch.Tensor: return self.compute_train_loss(batch, model) - def influence_modules(self) -> List[str]: + def tracked_modules(self) -> List[str]: total_modules = [] for i in range(4):