From 312f43e33bd231be808ddbf3fd20325f0a75bf58 Mon Sep 17 00:00:00 2001 From: Juhan Bae Date: Mon, 1 Jul 2024 17:08:14 -0400 Subject: [PATCH] Clean up factor arguments --- .github/workflows/python-test.yml | 1 + kronfluence/arguments.py | 80 ++- kronfluence/computer/computer.py | 28 +- kronfluence/computer/factor_computer.py | 61 +-- kronfluence/factor/covariance.py | 8 +- kronfluence/factor/eigen.py | 14 +- kronfluence/module/conv2d.py | 36 +- kronfluence/module/linear.py | 16 +- kronfluence/module/tracked_module.py | 49 +- kronfluence/utils/common/factor_arguments.py | 30 +- tests/factors/test_covariances.py | 51 +- tests/factors/test_eigens.py | 464 +----------------- tests/factors/test_lambdas.py | 488 +++++++++++++++++++ tests/test_analyzer.py | 23 +- tests/testable_tasks/language_modeling.py | 10 +- tests/testable_tasks/multiple_choice.py | 10 +- tests/testable_tasks/text_classification.py | 8 +- 17 files changed, 724 insertions(+), 653 deletions(-) create mode 100644 tests/factors/test_lambdas.py diff --git a/.github/workflows/python-test.yml b/.github/workflows/python-test.yml index 038d882..ebadd47 100644 --- a/.github/workflows/python-test.yml +++ b/.github/workflows/python-test.yml @@ -29,6 +29,7 @@ jobs: pytest -vx tests/test_testable_tasks.py pytest -vx tests/factors/test_covariances.py pytest -vx tests/factors/test_eigens.py + pytest -vx tests/factors/test_lambdas.py pytest -vx tests/modules/test_modules.py pytest -vx tests/modules/test_per_sample_gradients.py pytest -vx tests/modules/test_svd.py diff --git a/kronfluence/arguments.py b/kronfluence/arguments.py index 6c68c36..faa2d62 100644 --- a/kronfluence/arguments.py +++ b/kronfluence/arguments.py @@ -32,110 +32,106 @@ class FactorArguments(Arguments): # General configuration. # strategy: str = field( default="ekfac", - metadata={"help": "Strategy for computing preconditioning factors."}, + metadata={ + "help": "Specifies the algorithm for computing influence factors. Default is 'ekfac' " + "(Eigenvalue-corrected Kronecker-factored Approximate Curvature)." + }, ) use_empirical_fisher: bool = field( default=False, metadata={ - "help": "Whether to use empirical Fisher (using labels from batch) instead of " + "help": "Determines whether to approximate empirical Fisher (using true labels) or " "true Fisher (using sampled labels)." }, ) - distributed_sync_steps: int = field( + distributed_sync_interval: int = field( default=1_000, - metadata={ - "help": "Specifies the total iteration step to synchronize the process when using distributed setting." - }, + metadata={"help": "Number of iterations between synchronization steps in distributed computing settings."}, ) amp_dtype: Optional[torch.dtype] = field( default=None, - metadata={"help": "Dtype for automatic mixed precision (AMP). Disables AMP if None."}, + metadata={"help": "Data type for automatic mixed precision (AMP). If `None`, AMP is disabled."}, ) - shared_parameters_exist: bool = field( + has_shared_parameters: bool = field( default=False, - metadata={"help": "Specifies whether the shared parameters exist in the forward pass."}, + metadata={"help": "Indicates whether shared parameters are present in the model's forward pass."}, ) # Configuration for fitting covariance matrices. # covariance_max_examples: Optional[int] = field( default=100_000, metadata={ - "help": "Maximum number of examples for fitting covariance matrices. " - "Uses all data examples for the given dataset if None." + "help": "Maximum number of examples to use when fitting covariance matrices. 
" + "Uses entire dataset if `None`." }, ) - covariance_data_partition_size: int = field( + covariance_data_partitions: int = field( default=1, - metadata={ - "help": "Number of data partitions for computing covariance matrices. " - "For example, when `covariance_data_partition_size = 2`, the dataset is split " - "into 2 chunks and covariance matrices are separately computed for each chunk." - }, + metadata={"help": "Number of partitions to divide the dataset into for covariance matrix computation."}, ) - covariance_module_partition_size: int = field( + covariance_module_partitions: int = field( default=1, metadata={ - "help": "Number of module partitions for computing covariance matrices. " - "For example, when `covariance_module_partition_size = 2`, the modules (layers) are split " - "into 2 chunks and covariance matrices are separately computed for each chunk." + "help": "Number of partitions to divide the model's modules (layers) into for " + "covariance matrix computation." }, ) activation_covariance_dtype: torch.dtype = field( default=torch.float32, - metadata={"help": "Dtype for computing activation covariance matrices."}, + metadata={"help": "Data type for activation covariance computations."}, ) gradient_covariance_dtype: torch.dtype = field( default=torch.float32, - metadata={"help": "Dtype for computing pseudo-gradient covariance matrices."}, + metadata={"help": "Data type for pseudo-gradient covariance computations."}, ) # Configuration for performing eigendecomposition. # eigendecomposition_dtype: torch.dtype = field( default=torch.float64, - metadata={"help": "Dtype for performing Eigendecomposition. Recommended to use `torch.float64`."}, + metadata={ + "help": "Data type for eigendecomposition computations. Double precision (`torch.float64`) is " + "recommended for numerical stability." + }, ) # Configuration for fitting Lambda matrices. # lambda_max_examples: Optional[int] = field( default=100_000, metadata={ - "help": "Maximum number of examples for fitting Lambda matrices. " - "Uses all data examples for the given dataset if None." + "help": "Maximum number of examples to use when fitting Lambda matrices. Uses entire dataset if `None`." }, ) - lambda_data_partition_size: int = field( + lambda_data_partitions: int = field( default=1, - metadata={ - "help": "Number of data partitions for computing Lambda matrices. " - "For example, when `lambda_data_partition_size = 2`, the dataset is split " - "into 2 chunks and Lambda matrices are separately computed for each chunk." - }, + metadata={"help": "Number of partitions to divide the dataset into for Lambda matrix computation."}, ) - lambda_module_partition_size: int = field( + lambda_module_partitions: int = field( default=1, metadata={ - "help": "Number of module partitions for computing Lambda matrices. " - "For example, when `lambda_module_partition_size = 2`, the modules (layers) are split " - "into 2 chunks and Lambda matrices are separately computed for each chunk." + "help": "Number of partitions to divide the model's modules (layers) into for Lambda matrix computation." }, ) - lambda_iterative_aggregate: bool = field( + use_iterative_lambda_aggregation: bool = field( default=False, metadata={ - "help": "Whether to aggregate squared sum of projected per-sample-gradient with for-loop iterations." + "help": "If True, aggregates the squared sum of projected per-sample gradients " + "iteratively to reduce GPU memory usage." 
}, ) - cached_activation_cpu_offload: bool = field( + offload_activations_to_cpu: bool = field( default=False, - metadata={"help": "Whether to offload cached activation to CPU for computing the per-sample-gradient."}, + metadata={ + "help": "If True, offloads cached activations to CPU memory when computing " + "per-sample gradients, reducing GPU memory usage." + }, ) per_sample_gradient_dtype: torch.dtype = field( default=torch.float32, - metadata={"help": "Dtype for computing per-sample-gradients."}, + metadata={"help": "Data type for per-sample-gradient computations."}, ) lambda_dtype: torch.dtype = field( default=torch.float32, - metadata={"help": "Dtype for computing Lambda (corrected eigenvalues) matrices."}, + metadata={"help": "Data type for Lambda matrix computations."}, ) diff --git a/kronfluence/computer/computer.py b/kronfluence/computer/computer.py index ded0e3e..9abd0d3 100644 --- a/kronfluence/computer/computer.py +++ b/kronfluence/computer/computer.py @@ -249,13 +249,13 @@ def _configure_dataloader(self, dataloader_kwargs: DataLoaderKwargs) -> Dict[str def _get_data_partition( self, total_data_examples: int, - data_partition_size: int, + data_partitions: int, target_data_partitions: Optional[Union[int, List[int]]], ) -> Tuple[List[Tuple[int, int]], List[int]]: """Partitions the dataset into several chunks.""" - if total_data_examples < data_partition_size: + if total_data_examples < data_partitions: error_msg = ( - f"Data partition size ({data_partition_size}) cannot be greater than the " + f"Data partition size ({data_partitions}) cannot be greater than the " f"total data points ({total_data_examples}). Please reduce the data partition " f"size in the argument." ) @@ -263,20 +263,20 @@ def _get_data_partition( raise ValueError(error_msg) indices_partitions = make_indices_partition( - total_data_examples=total_data_examples, partition_size=data_partition_size + total_data_examples=total_data_examples, partition_size=data_partitions ) if target_data_partitions is None: - target_data_partitions = list(range(data_partition_size)) + target_data_partitions = list(range(data_partitions)) if isinstance(target_data_partitions, int): target_data_partitions = [target_data_partitions] for data_partition in target_data_partitions: - if data_partition < 0 or data_partition > data_partition_size: + if data_partition < 0 or data_partition >= data_partitions: error_msg = ( f"Invalid data partition {data_partition} encountered. " - f"The module partition needs to be in between [0, {data_partition_size})." + f"The data partition needs to be in the range [0, {data_partitions})." ) self.logger.error(error_msg) raise ValueError(error_msg) @@ -285,15 +285,15 @@ def _get_module_partition( self, - module_partition_size: int, + module_partitions: int, target_module_partitions: Optional[Union[int, List[int]]], ) -> Tuple[List[List[str]], List[int]]: """Partitions the modules into several chunks.""" tracked_module_names = get_tracked_module_names(self.model) - if len(tracked_module_names) < module_partition_size: + if len(tracked_module_names) < module_partitions: error_msg = ( - f"Module partition size ({module_partition_size}) cannot be greater than the " + f"Module partition size ({module_partitions}) cannot be greater than the " f"total tracked modules ({len(tracked_module_names)}). Please reduce the module partition " f"size in the argument."
) @@ -302,20 +302,20 @@ def _get_module_partition( modules_partition_list = make_modules_partition( total_module_names=tracked_module_names, - partition_size=module_partition_size, + partition_size=module_partitions, ) if target_module_partitions is None: - target_module_partitions = list(range(module_partition_size)) + target_module_partitions = list(range(module_partitions)) if isinstance(target_module_partitions, int): target_module_partitions = [target_module_partitions] for module_partition in target_module_partitions: - if module_partition < 0 or module_partition > module_partition_size: + if module_partition < 0 or module_partition >= module_partitions: error_msg = ( f"Invalid module partition {module_partition} encountered. " - f"The module partition needs to be in between [0, {module_partition_size})." + f"The module partition needs to be in the range [0, {module_partitions})." ) self.logger.error(error_msg) raise ValueError(error_msg) diff --git a/kronfluence/computer/factor_computer.py b/kronfluence/computer/factor_computer.py index 205a5ba..af67eba 100644 --- a/kronfluence/computer/factor_computer.py +++ b/kronfluence/computer/factor_computer.py @@ -60,9 +60,9 @@ def _configure_and_save_factor_args( def _aggregate_factors( self, factors_name: str, - data_partition_size: int, - module_partition_size: int, - exists_fnc: Callable, + data_partitions: int, + module_partitions: int, + exist_fnc: Callable, load_fnc: Callable, save_fnc: Callable, ) -> Optional[FACTOR_TYPE]: @@ -73,9 +73,9 @@ def _aggregate_factors( self.logger.error(error_msg) raise FileNotFoundError(error_msg) - all_required_partitions = [(i, j) for i in range(data_partition_size) for j in range(module_partition_size)] + all_required_partitions = [(i, j) for i in range(data_partitions) for j in range(module_partitions)] all_partition_exists = all( - exists_fnc(output_dir=factors_output_dir, partition=partition) for partition in all_required_partitions + exist_fnc(output_dir=factors_output_dir, partition=partition) for partition in all_required_partitions ) if not all_partition_exists: self.logger.warning("Factors are not aggregated as factors for some partitions are not yet computed.") @@ -83,8 +83,8 @@ start_time = time.time() aggregated_factors: FACTOR_TYPE = {} - for data_partition in range(data_partition_size): - for module_partition in range(module_partition_size): + for data_partition in range(data_partitions): + for module_partition in range(module_partitions): loaded_factors = load_fnc( output_dir=factors_output_dir, partition=(data_partition, module_partition), @@ -95,9 +95,10 @@ for module_name in factors: if module_name not in aggregated_factors[factor_name]: - aggregated_factors[factor_name][module_name] = factors[module_name] - else: - aggregated_factors[factor_name][module_name].add_(factors[module_name]) + aggregated_factors[factor_name][module_name] = torch.zeros( + size=factors[module_name].shape, dtype=factors[module_name].dtype, requires_grad=False + ) + aggregated_factors[factor_name][module_name].add_(factors[module_name]) del loaded_factors save_fnc( output_dir=factors_output_dir, @@ -227,9 +228,7 @@ def fit_covariance_matrices( total_data_examples = min([factor_args.covariance_max_examples, len(dataset)]) self.logger.info(f"Total data examples to fit covariance matrices: {total_data_examples}.") - no_partition = ( - factor_args.covariance_data_partition_size == 1 and factor_args.covariance_module_partition_size == 1 - ) + no_partition =
factor_args.covariance_data_partitions == 1 and factor_args.covariance_module_partitions == 1 partition_provided = target_data_partitions is not None or target_module_partitions is not None if no_partition and partition_provided: error_msg = ( @@ -241,17 +240,20 @@ data_partition_indices, target_data_partitions = self._get_data_partition( total_data_examples=total_data_examples, - data_partition_size=factor_args.covariance_data_partition_size, + data_partitions=factor_args.covariance_data_partitions, target_data_partitions=target_data_partitions, ) - max_partition_examples = total_data_examples // factor_args.covariance_data_partition_size + max_partition_examples = total_data_examples // factor_args.covariance_data_partitions module_partition_names, target_module_partitions = self._get_module_partition( - module_partition_size=factor_args.covariance_module_partition_size, + module_partitions=factor_args.covariance_module_partitions, target_module_partitions=target_module_partitions, ) if max_partition_examples < self.state.num_processes: - error_msg = "The number of processes are more than the data examples. Try reducing the number of processes." + error_msg = ( + "The number of processes is larger than the total data examples. " + "Try reducing the number of processes." + ) self.logger.error(error_msg) raise ValueError(error_msg) @@ -372,9 +374,9 @@ def aggregate_covariance_matrices( with self.profiler.profile("Aggregate Covariance"): self._aggregate_factors( factors_name=factors_name, - data_partition_size=factor_args.covariance_data_partition_size, - module_partition_size=factor_args.covariance_module_partition_size, - exists_fnc=covariance_matrices_exist, + data_partitions=factor_args.covariance_data_partitions, + module_partitions=factor_args.covariance_module_partitions, + exist_fnc=covariance_matrices_exist, load_fnc=load_covariance_matrices, save_fnc=save_covariance_matrices, ) @@ -580,7 +582,7 @@ def fit_lambda_matrices( total_data_examples = min([factor_args.lambda_max_examples, len(dataset)]) self.logger.info(f"Total data examples to fit Lambda matrices: {total_data_examples}.") - no_partition = factor_args.lambda_data_partition_size == 1 and factor_args.lambda_module_partition_size == 1 + no_partition = factor_args.lambda_data_partitions == 1 and factor_args.lambda_module_partitions == 1 partition_provided = target_data_partitions is not None or target_module_partitions is not None if no_partition and partition_provided: error_msg = ( @@ -592,17 +594,20 @@ data_partition_indices, target_data_partitions = self._get_data_partition( total_data_examples=total_data_examples, - data_partition_size=factor_args.lambda_data_partition_size, + data_partitions=factor_args.lambda_data_partitions, target_data_partitions=target_data_partitions, ) - max_partition_examples = total_data_examples // factor_args.lambda_data_partition_size + max_partition_examples = total_data_examples // factor_args.lambda_data_partitions module_partition_names, target_module_partitions = self._get_module_partition( - module_partition_size=factor_args.lambda_module_partition_size, + module_partitions=factor_args.lambda_module_partitions, target_module_partitions=target_module_partitions, ) if max_partition_examples < self.state.num_processes: - error_msg = "The number of processes are more than the data examples. Try reducing the number of processes." + error_msg = ( + "The number of processes is larger than the total data examples. 
" + "Try reducing the number of processes." + ) self.logger.error(error_msg) raise ValueError(error_msg) @@ -725,9 +730,9 @@ def aggregate_lambda_matrices( with self.profiler.profile("Aggregate Lambda"): self._aggregate_factors( factors_name=factors_name, - data_partition_size=factor_args.lambda_data_partition_size, - module_partition_size=factor_args.lambda_module_partition_size, - exists_fnc=lambda_matrices_exist, + data_partitions=factor_args.lambda_data_partitions, + module_partitions=factor_args.lambda_module_partitions, + exist_fnc=lambda_matrices_exist, load_fnc=load_lambda_matrices, save_fnc=save_lambda_matrices, ) diff --git a/kronfluence/factor/covariance.py b/kronfluence/factor/covariance.py index 86ad2ca..3eb5009 100644 --- a/kronfluence/factor/covariance.py +++ b/kronfluence/factor/covariance.py @@ -162,8 +162,8 @@ def fit_covariance_matrices_with_loader( if attention_mask is not None: set_attention_mask(model=model, attention_mask=attention_mask) - model.zero_grad(set_to_none=True) with no_sync(model=model, state=state): + model.zero_grad(set_to_none=True) with autocast(device_type=state.device.type, enabled=enable_amp, dtype=factor_args.amp_dtype): loss = task.compute_train_loss( batch=batch, @@ -174,7 +174,7 @@ def fit_covariance_matrices_with_loader( if ( state.use_distributed - and total_steps % factor_args.distributed_sync_steps == 0 + and total_steps % factor_args.distributed_sync_interval == 0 and index not in [len(loader) - 1, len(loader) - 2] ): # Periodically synchronizes all processes to avoid timeout at the final synchronization. @@ -193,13 +193,13 @@ def fit_covariance_matrices_with_loader( saved_factors: FACTOR_TYPE = {} for factor_name in COVARIANCE_FACTOR_NAMES: - saved_factors[factor_name] = load_factors(model=model, factor_name=factor_name, clone=True) - state.wait_for_everyone() + saved_factors[factor_name] = load_factors(model=model, factor_name=factor_name, clone=False) # Clean up the memory. model.zero_grad(set_to_none=True) if enable_amp: set_gradient_scale(model=model, gradient_scale=1.0) set_mode(model=model, mode=ModuleMode.DEFAULT, keep_factors=False) + state.wait_for_everyone() return num_data_processed, saved_factors diff --git a/kronfluence/factor/eigen.py b/kronfluence/factor/eigen.py index b9db412..d28a8d9 100644 --- a/kronfluence/factor/eigen.py +++ b/kronfluence/factor/eigen.py @@ -115,9 +115,7 @@ def perform_eigendecomposition( Returns: FACTOR_TYPE: The Eigendecomposition results. These results are organized in nested dictionaries, where the first key - is the name of the factor (e.g.,activation eigenvector), and the second key is the module name. - disable_tqdm (bool, optional): - Disables TQDM progress bars. Defaults to False. + is the name of the factor (e.g., activation eigenvector), and the second key is the module name. 
""" eigen_factors: FACTOR_TYPE = {} for factor_name in EIGENDECOMPOSITION_FACTOR_NAMES: @@ -306,8 +304,8 @@ def fit_lambda_matrices_with_loader( for index, batch in enumerate(loader): batch = send_to_device(tensor=batch, device=state.device) - model.zero_grad(set_to_none=True) with no_sync(model=model, state=state): + model.zero_grad(set_to_none=True) with autocast(device_type=state.device.type, enabled=enable_amp, dtype=factor_args.amp_dtype): loss = task.compute_train_loss( batch=batch, @@ -316,14 +314,14 @@ def fit_lambda_matrices_with_loader( ) scaler.scale(loss).backward() - if factor_args.shared_parameters_exist: + if factor_args.has_shared_parameters: # If shared parameter exists, Lambda matrices are computed and updated only after all # per-sample-gradients are aggregated. finalize_lambda_matrices(model=model) if ( state.use_distributed - and total_steps % factor_args.distributed_sync_steps == 0 + and total_steps % factor_args.distributed_sync_interval == 0 and index not in [len(loader) - 1, len(loader) - 2] ): # Periodically synchronizes all processes to avoid timeout at the final synchronization. @@ -343,13 +341,13 @@ def fit_lambda_matrices_with_loader( saved_factors: FACTOR_TYPE = {} if state.is_main_process: for factor_name in LAMBDA_FACTOR_NAMES: - saved_factors[factor_name] = load_factors(model=model, factor_name=factor_name, clone=True) - state.wait_for_everyone() + saved_factors[factor_name] = load_factors(model=model, factor_name=factor_name, clone=False) # Clean up the memory. model.zero_grad(set_to_none=True) if enable_amp: set_gradient_scale(model=model, gradient_scale=1.0) set_mode(model=model, mode=ModuleMode.DEFAULT, keep_factors=False) + state.wait_for_everyone() return num_data_processed, saved_factors diff --git a/kronfluence/module/conv2d.py b/kronfluence/module/conv2d.py index 5e0df30..6e08fda 100644 --- a/kronfluence/module/conv2d.py +++ b/kronfluence/module/conv2d.py @@ -1,4 +1,4 @@ -from typing import Tuple, Union +from typing import Optional, Tuple, Union import torch import torch.nn.functional as F @@ -68,13 +68,39 @@ class TrackedConv2d(TrackedModule, module_type=nn.Conv2d): """A tracking wrapper for `nn.Conv2D` modules.""" @property - def weight(self) -> torch.Tensor: - """Returns the weight matrix.""" + def in_channels(self) -> int: # pylint: disable=missing-function-docstring + return self.original_module.in_channels + + @property + def out_channels(self) -> int: # pylint: disable=missing-function-docstring + return self.original_module.out_channels + + @property + def kernel_size(self) -> Tuple[int, int]: # pylint: disable=missing-function-docstring + return self.original_module.kernel_size + + @property + def padding(self) -> Tuple[int, int]: # pylint: disable=missing-function-docstring + return self.original_module.padding + + @property + def dilation(self) -> Tuple[int, int]: # pylint: disable=missing-function-docstring + return self.original_module.dilation + + @property + def groups(self) -> int: # pylint: disable=missing-function-docstring + return self.original_module.groups + + @property + def padding_mode(self) -> str: # pylint: disable=missing-function-docstring + return self.original_module.padding_mode + + @property + def weight(self) -> torch.Tensor: # pylint: disable=missing-function-docstring return self.original_module.weight @property - def bias(self) -> torch.Tensor: - """Returns the bias.""" + def bias(self) -> Optional[torch.Tensor]: # pylint: disable=missing-function-docstring return self.original_module.bias def 
_get_flattened_activation( diff --git a/kronfluence/module/linear.py b/kronfluence/module/linear.py index 6644f9f..54b5e67 100644 --- a/kronfluence/module/linear.py +++ b/kronfluence/module/linear.py @@ -1,4 +1,4 @@ -from typing import Tuple, Union +from typing import Optional, Tuple, Union import torch from einops import rearrange @@ -12,13 +12,19 @@ class TrackedLinear(TrackedModule, module_type=nn.Linear): """A tracking wrapper for `nn.Linear` modules.""" @property - def weight(self) -> torch.Tensor: - """Returns the weight matrix.""" + def in_features(self) -> int: # pylint: disable=missing-function-docstring + return self.original_module.in_features + + @property + def out_features(self) -> int: # pylint: disable=missing-function-docstring + return self.original_module.out_features + + @property + def weight(self) -> torch.Tensor: # pylint: disable=missing-function-docstring return self.original_module.weight @property - def bias(self) -> torch.Tensor: - """Returns the bias.""" + def bias(self) -> Optional[torch.Tensor]: # pylint: disable=missing-function-docstring return self.original_module.bias def _get_flattened_activation( diff --git a/kronfluence/module/tracked_module.py b/kronfluence/module/tracked_module.py index e6d41dd..b74c53b 100644 --- a/kronfluence/module/tracked_module.py +++ b/kronfluence/module/tracked_module.py @@ -268,6 +268,7 @@ def _update_gradient_covariance_matrix(self, output_gradient: torch.Tensor) -> N output_gradient = output_gradient.to(dtype=self.factor_args.gradient_covariance_dtype) flattened_gradient, count = self._get_flattened_gradient(output_gradient=output_gradient) if self._gradient_scale != 1.0: + # Avoids an in-place operation here. flattened_gradient = flattened_gradient * self._gradient_scale if self._storage[GRADIENT_COVARIANCE_MATRIX_NAME] is None: @@ -280,9 +281,9 @@ def _update_gradient_covariance_matrix(self, output_gradient: torch.Tensor) -> N ) self._storage[GRADIENT_COVARIANCE_MATRIX_NAME].addmm_(flattened_gradient.t(), flattened_gradient) - # This is not necessary as `NUM_GRADIENT_COVARIANCE_PROCESSED` should be identical to - # `NUM_ACTIVATION_COVARIANCE_PROCESSED` in most cases. However, they can be different when using - # gradient checkpointing or torch compile. + # This is not necessary in most cases, as `NUM_GRADIENT_COVARIANCE_PROCESSED` should typically be identical to + # `NUM_ACTIVATION_COVARIANCE_PROCESSED`. However, they can differ when using gradient checkpointing + # or `torch.compile`. if self._storage[NUM_GRADIENT_COVARIANCE_PROCESSED] is None: self._storage[NUM_GRADIENT_COVARIANCE_PROCESSED] = torch.zeros( size=(1,), @@ -298,14 +299,14 @@ def _register_covariance_hooks(self) -> None: @torch.no_grad() def forward_hook(module: nn.Module, inputs: Tuple[torch.Tensor], outputs: torch.Tensor) -> None: del module - # Computes and updates activation covariance matrix in the forward pass. + # Computes and updates activation covariance in the forward pass. self._update_activation_covariance_matrix(inputs[0].detach().clone()) # Registers backward hook to obtain gradient with respect to the output. outputs.register_hook(backward_hook) @torch.no_grad() def backward_hook(output_gradient: torch.Tensor) -> None: - # Computes and updates pseudo-gradient covariance matrix in the backward pass. + # Computes and updates pseudo-gradient covariance in the backward pass.
self._update_gradient_covariance_matrix(output_gradient.detach()) self._registered_hooks.append(self.original_module.register_forward_hook(forward_hook)) @@ -357,7 +358,7 @@ def _compute_per_sample_gradient( self, input_activation: torch.Tensor, output_gradient: torch.Tensor ) -> torch.Tensor: """Returns the flattened per-sample-gradient tensor. For a brief introduction to - per-sample-gradients, see https://pytorch.org/functorch/stable/notebooks/per_sample_grads.html. + per-sample-gradients, see https://pytorch.org/functorch/stable/notebooks/per_sample_grads.html. Args: input_activation (torch.Tensor): @@ -412,7 +413,7 @@ def _update_lambda_matrix(self, per_sample_gradient: torch.Tensor) -> None: ) if FactorConfig.CONFIGS[self.factor_args.strategy].requires_eigendecomposition_for_lambda: - if self.factor_args.lambda_iterative_aggregate: + if self.factor_args.use_iterative_lambda_aggregation: # This batch-wise iterative update can be useful when the GPU memory is limited. per_sample_gradient = torch.matmul( per_sample_gradient, @@ -443,16 +444,16 @@ def _update_lambda_matrix(self, per_sample_gradient: torch.Tensor) -> None: self._storage[NUM_LAMBDA_PROCESSED].add_(batch_size) def _register_lambda_hooks(self) -> None: - """Installs forward and backward hooks for computation of the Lambda matrices.""" + """Installs forward and backward hooks for computation of Lambda matrices.""" @torch.no_grad() def forward_hook(module: nn.Module, inputs: Tuple[torch.Tensor], outputs: torch.Tensor) -> None: del module cached_activation = inputs[0].detach().clone().to(dtype=self.factor_args.per_sample_gradient_dtype) - if self.factor_args.cached_activation_cpu_offload: + if self.factor_args.offload_activations_to_cpu: cached_activation = cached_activation.cpu() - if self.factor_args.shared_parameters_exist: + if self.factor_args.has_shared_parameters: if self._cached_activations is None: self._cached_activations = [] self._cached_activations.append(cached_activation) else: self._cached_activations = cached_activation # Registers backward hook to obtain gradient with respect to the output. - outputs.register_hook(shared_backward_hook if self.factor_args.shared_parameters_exist else backward_hook) + outputs.register_hook(shared_backward_hook if self.factor_args.has_shared_parameters else backward_hook) @torch.no_grad() def backward_hook(output_gradient: torch.Tensor) -> None: if self._cached_activations is None: raise RuntimeError( f"The module {self.name} is used several times during a forward pass. " - "Set `shared_parameters_exist=True` to avoid this error." + "Set `has_shared_parameters=True` to avoid this error." ) per_sample_gradient = self._compute_per_sample_gradient( input_activation=self._cached_activations.to(device=output_gradient.device), @@ -620,21 +621,21 @@ def forward_hook(module: nn.Module, inputs: Tuple[torch.Tensor], outputs: torch.
if self.score_args.cached_activation_cpu_offload: cached_activation = cached_activation.cpu() - if self.factor_args.shared_parameters_exist: + if self.factor_args.has_shared_parameters: if self._cached_activations is None: self._cached_activations = [] self._cached_activations.append(cached_activation) else: self._cached_activations = cached_activation - outputs.register_hook(shared_backward_hook if self.factor_args.shared_parameters_exist else backward_hook) + outputs.register_hook(shared_backward_hook if self.factor_args.has_shared_parameters else backward_hook) @torch.no_grad() def backward_hook(output_gradient: torch.Tensor) -> None: if self._cached_activations is None: raise RuntimeError( f"The module {self.name} is used several times during a forward pass. " - "Set `shared_parameters_exist=True` to avoid this error." + "Set `has_shared_parameters=True` to avoid this error." ) per_sample_gradient = self._compute_per_sample_gradient( input_activation=self._cached_activations.to(device=output_gradient.device), @@ -824,21 +825,21 @@ def forward_hook(module: nn.Module, inputs: Tuple[torch.Tensor], outputs: torch. if self.score_args.cached_activation_cpu_offload: cached_activation = cached_activation.cpu() - if self.factor_args.shared_parameters_exist: + if self.factor_args.has_shared_parameters: if self._cached_activations is None: self._cached_activations = [] self._cached_activations.append(cached_activation) else: self._cached_activations = cached_activation - outputs.register_hook(shared_backward_hook if self.factor_args.shared_parameters_exist else backward_hook) + outputs.register_hook(shared_backward_hook if self.factor_args.has_shared_parameters else backward_hook) @torch.no_grad() def backward_hook(output_gradient: torch.Tensor) -> None: if self._cached_activations is None: raise RuntimeError( f"The module {self.name} is used several times during a forward pass. " - "Set `shared_parameters_exist=True` to avoid the error." + "Set `has_shared_parameters=True` to avoid the error." ) per_sample_gradient = self._compute_per_sample_gradient( input_activation=self._cached_activations.to(device=output_gradient.device), @@ -915,21 +916,21 @@ def forward_hook(module: nn.Module, inputs: Tuple[torch.Tensor], outputs: torch. if self.score_args.cached_activation_cpu_offload: cached_activation = cached_activation.cpu() - if self.factor_args.shared_parameters_exist: + if self.factor_args.has_shared_parameters: if self._cached_activations is None: self._cached_activations = [] self._cached_activations.append(cached_activation) else: self._cached_activations = cached_activation - outputs.register_hook(shared_backward_hook if self.factor_args.shared_parameters_exist else backward_hook) + outputs.register_hook(shared_backward_hook if self.factor_args.has_shared_parameters else backward_hook) @torch.no_grad() def backward_hook(output_gradient: torch.Tensor) -> None: if self._cached_activations is None: raise RuntimeError( f"The module {self.name} is used several times during a forward pass. " - "Set `shared_parameters_exist=True` to avoid this error." + "Set `has_shared_parameters=True` to avoid this error." ) per_sample_gradient = self._compute_per_sample_gradient( input_activation=self._cached_activations.to(device=output_gradient.device), @@ -997,21 +998,21 @@ def forward_hook(module: nn.Module, inputs: Tuple[torch.Tensor], outputs: torch. 
if self.score_args.cached_activation_cpu_offload: cached_activation = cached_activation.cpu() - if self.factor_args.shared_parameters_exist: + if self.factor_args.has_shared_parameters: if self._cached_activations is None: self._cached_activations = [] self._cached_activations.append(cached_activation) else: self._cached_activations = cached_activation - outputs.register_hook(shared_backward_hook if self.factor_args.shared_parameters_exist else backward_hook) + outputs.register_hook(shared_backward_hook if self.factor_args.has_shared_parameters else backward_hook) @torch.no_grad() def backward_hook(output_gradient: torch.Tensor) -> None: if self._cached_activations is None: raise RuntimeError( f"The module {self.name} is used several times during a forward pass. " - "Set `shared_parameters_exist=True` to avoid this error." + "Set `has_shared_parameters=True` to avoid this error." ) per_sample_gradient = self._compute_per_sample_gradient( input_activation=self._cached_activations.to(device=output_gradient.device), diff --git a/kronfluence/utils/common/factor_arguments.py b/kronfluence/utils/common/factor_arguments.py index 122baea..a40d3e3 100644 --- a/kronfluence/utils/common/factor_arguments.py +++ b/kronfluence/utils/common/factor_arguments.py @@ -9,14 +9,15 @@ def default_factor_arguments(strategy: str = "ekfac") -> FactorArguments: return factor_args -def test_factor_arguments(strategy: str = "ekfac") -> FactorArguments: +def pytest_factor_arguments(strategy: str = "ekfac") -> FactorArguments: """Factor arguments used for unit tests.""" factor_args = FactorArguments(strategy=strategy) + # Makes the computations deterministic. factor_args.use_empirical_fisher = True factor_args.activation_covariance_dtype = torch.float64 factor_args.gradient_covariance_dtype = torch.float64 factor_args.per_sample_gradient_dtype = torch.float64 - factor_args.lambda_dtype = torch.float32 + factor_args.lambda_dtype = torch.float64 return factor_args @@ -34,7 +35,7 @@ def smart_low_precision_factor_arguments( def all_low_precision_factor_arguments(strategy: str = "ekfac", dtype: torch.dtype = torch.bfloat16) -> FactorArguments: - """Factor arguments with low precision.""" + """Factor arguments with low precision for all computations.""" factor_args = FactorArguments(strategy=strategy) factor_args.amp_dtype = dtype factor_args.activation_covariance_dtype = dtype @@ -45,27 +46,18 @@ def all_low_precision_factor_arguments(strategy: str = "ekfac", dtype: torch.dty def reduce_memory_factor_arguments(strategy: str = "ekfac", dtype: torch.dtype = torch.bfloat16) -> FactorArguments: - """Factor arguments with low precision + iterative lambda update.""" + """Factor arguments with low precision and iterative lambda aggregations.""" factor_args = all_low_precision_factor_arguments(strategy=strategy, dtype=dtype) - factor_args.lambda_iterative_aggregate = True + factor_args.use_iterative_lambda_aggregation = True return factor_args def extreme_reduce_memory_factor_arguments( strategy: str = "ekfac", dtype: torch.dtype = torch.bfloat16 ) -> FactorArguments: - """Factor arguments for models that is difficult to fit in a single GPU.""" - factor_args = all_low_precision_factor_arguments(strategy=strategy, dtype=dtype) - factor_args.lambda_iterative_aggregate = True - factor_args.cached_activation_cpu_offload = True - factor_args.covariance_module_partition_size = 4 - factor_args.lambda_module_partition_size = 4 - return factor_args - - -def large_dataset_factor_arguments(strategy: str = "ekfac", dtype: torch.dtype = 
torch.bfloat16) -> FactorArguments: - """Factor arguments for large models and datasets.""" - factor_args = smart_low_precision_factor_arguments(strategy=strategy, dtype=dtype) - factor_args.covariance_data_partition_size = 4 - factor_args.lambda_data_partition_size = 4 + """Factor arguments for models that are difficult to fit on a single GPU.""" + factor_args = reduce_memory_factor_arguments(strategy=strategy, dtype=dtype) + factor_args.offload_activations_to_cpu = True + factor_args.covariance_module_partitions = 4 + factor_args.lambda_module_partitions = 4 return factor_args diff --git a/tests/factors/test_covariances.py b/tests/factors/test_covariances.py index 1c6e97f..78b804b 100644 --- a/tests/factors/test_covariances.py +++ b/tests/factors/test_covariances.py @@ -5,7 +5,7 @@ from kronfluence.utils.common.factor_arguments import ( default_factor_arguments, - test_factor_arguments, + pytest_factor_arguments, ) from kronfluence.utils.constants import ( ACTIVATION_COVARIANCE_MATRIX_NAME, @@ -76,6 +76,7 @@ def test_fit_covariance_matrices( ) assert set(covariance_factors.keys()) == set(COVARIANCE_FACTOR_NAMES) assert len(covariance_factors[ACTIVATION_COVARIANCE_MATRIX_NAME]) > 0 + assert len(covariance_factors[GRADIENT_COVARIANCE_MATRIX_NAME]) > 0 for module_name in covariance_factors[ACTIVATION_COVARIANCE_MATRIX_NAME]: assert covariance_factors[ACTIVATION_COVARIANCE_MATRIX_NAME][module_name].dtype == activation_covariance_dtype assert covariance_factors[GRADIENT_COVARIANCE_MATRIX_NAME][module_name].dtype == gradient_covariance_dtype @@ -109,7 +110,7 @@ def test_covariance_matrices_batch_size_equivalence( task=task, ) - factor_args = test_factor_arguments() + factor_args = pytest_factor_arguments() analyzer.fit_covariance_matrices( factors_name=DEFAULT_FACTORS_NAME, dataset=train_dataset, @@ -146,14 +147,14 @@ def test_covariance_matrices_batch_size_equivalence( "conv_bn", ], ) -@pytest.mark.parametrize("data_partition_size", [2, 4]) -@pytest.mark.parametrize("module_partition_size", [2, 3]) +@pytest.mark.parametrize("data_partitions", [2, 4]) +@pytest.mark.parametrize("module_partitions", [2, 3]) @pytest.mark.parametrize("train_size", [62]) @pytest.mark.parametrize("seed", [2]) def test_covariance_matrices_partition_equivalence( test_name: str, - data_partition_size: int, - module_partition_size: int, + data_partitions: int, + module_partitions: int, train_size: int, seed: int, ) -> None: @@ -170,7 +171,7 @@ def test_covariance_matrices_partition_equivalence( task=task, ) - factor_args = test_factor_arguments() + factor_args = pytest_factor_arguments() analyzer.fit_covariance_matrices( factors_name=DEFAULT_FACTORS_NAME, dataset=train_dataset, @@ -181,10 +182,10 @@ def test_covariance_matrices_partition_equivalence( ) covariance_factors = analyzer.load_covariance_matrices(factors_name=DEFAULT_FACTORS_NAME) - factor_args.covariance_data_partition_size = data_partition_size - factor_args.covariance_module_partition_size = module_partition_size + factor_args.covariance_data_partitions = data_partitions + factor_args.covariance_module_partitions = module_partitions analyzer.fit_covariance_matrices( - factors_name=custom_factors_name(f"{data_partition_size}_{module_partition_size}"), + factors_name=custom_factors_name(f"{data_partitions}_{module_partitions}"), dataset=train_dataset, factor_args=factor_args, per_device_batch_size=7, overwrite_output_dir=True, dataloader_kwargs=kwargs, ) partitioned_covariance_factors =
analyzer.load_covariance_matrices( - factors_name=custom_factors_name(f"{data_partition_size}_{module_partition_size}"), + factors_name=custom_factors_name(f"{data_partitions}_{module_partitions}"), ) for name in COVARIANCE_FACTOR_NAMES: @@ -238,7 +239,7 @@ def test_covariance_matrices_attention_mask( ) kwargs = DataLoaderKwargs(collate_fn=data_collator) - factor_args = test_factor_arguments() + factor_args = pytest_factor_arguments() analyzer.fit_covariance_matrices( factors_name=DEFAULT_FACTORS_NAME, dataset=train_dataset, @@ -305,7 +306,7 @@ def test_covariance_matrices_automatic_batch_size( task=task, ) - factor_args = test_factor_arguments() + factor_args = pytest_factor_arguments() analyzer.fit_covariance_matrices( factors_name=DEFAULT_FACTORS_NAME, dataset=train_dataset, @@ -338,12 +339,16 @@ def test_covariance_matrices_automatic_batch_size( @pytest.mark.parametrize("test_name", ["mlp"]) -@pytest.mark.parametrize("data_partition_size", [1, 4]) +@pytest.mark.parametrize("max_examples", [4, 26]) +@pytest.mark.parametrize("data_partitions", [1, 4]) +@pytest.mark.parametrize("module_partitions", [1, 2]) @pytest.mark.parametrize("train_size", [80]) @pytest.mark.parametrize("seed", [5]) def test_covariance_matrices_max_examples( test_name: str, - data_partition_size: int, + max_examples: int, + data_partitions: int, + module_partitions: int, train_size: int, seed: int, ) -> None: @@ -359,10 +364,10 @@ def test_covariance_matrices_max_examples( task=task, ) - MAX_EXAMPLES = 26 - factor_args = test_factor_arguments() - factor_args.covariance_max_examples = MAX_EXAMPLES - factor_args.covariance_data_partition_size = data_partition_size + factor_args = pytest_factor_arguments() + factor_args.covariance_max_examples = max_examples + factor_args.covariance_data_partitions = data_partitions + factor_args.covariance_module_partitions = module_partitions analyzer.fit_covariance_matrices( factors_name=DEFAULT_FACTORS_NAME, @@ -375,10 +380,10 @@ def test_covariance_matrices_max_examples( covariance_factors = analyzer.load_covariance_matrices(factors_name=DEFAULT_FACTORS_NAME) for num_examples in covariance_factors[NUM_ACTIVATION_COVARIANCE_PROCESSED].values(): - assert num_examples == MAX_EXAMPLES + assert num_examples == max_examples for num_examples in covariance_factors[NUM_GRADIENT_COVARIANCE_PROCESSED].values(): - assert num_examples == MAX_EXAMPLES + assert num_examples == max_examples @pytest.mark.parametrize( @@ -407,7 +412,7 @@ def test_covariance_matrices_amp( task=task, ) - factor_args = test_factor_arguments() + factor_args = pytest_factor_arguments() analyzer.fit_covariance_matrices( factors_name=DEFAULT_FACTORS_NAME, dataset=train_dataset, @@ -460,7 +465,7 @@ def test_covariance_matrices_gradient_checkpoint( task=task, ) - factor_args = test_factor_arguments() + factor_args = pytest_factor_arguments() analyzer.fit_covariance_matrices( factors_name=DEFAULT_FACTORS_NAME, dataset=train_dataset, diff --git a/tests/factors/test_eigens.py b/tests/factors/test_eigens.py index 0773e12..ee74da2 100644 --- a/tests/factors/test_eigens.py +++ b/tests/factors/test_eigens.py @@ -4,25 +4,13 @@ import torch from kronfluence.arguments import FactorArguments -from kronfluence.utils.common.factor_arguments import test_factor_arguments from kronfluence.utils.constants import ( ACTIVATION_EIGENVECTORS_NAME, EIGENDECOMPOSITION_FACTOR_NAMES, GRADIENT_EIGENVECTORS_NAME, - LAMBDA_FACTOR_NAMES, - LAMBDA_MATRIX_NAME, - NUM_LAMBDA_PROCESSED, ) from kronfluence.utils.dataset import DataLoaderKwargs -from 
tests.utils import ( - ATOL, - DEFAULT_FACTORS_NAME, - RTOL, - check_tensor_dict_equivalence, - custom_factors_name, - prepare_model_and_analyzer, - prepare_test, -) +from tests.utils import DEFAULT_FACTORS_NAME, prepare_model_and_analyzer, prepare_test @pytest.mark.parametrize( @@ -34,7 +22,7 @@ ], ) @pytest.mark.parametrize("eigendecomposition_dtype", [torch.float32, torch.float64]) -@pytest.mark.parametrize("train_size", [16]) +@pytest.mark.parametrize("train_size", [1, 30]) @pytest.mark.parametrize("seed", [0]) def test_perform_eigendecomposition( test_name: str, @@ -72,453 +60,7 @@ def test_perform_eigendecomposition( eigen_factors = analyzer.load_eigendecomposition(factors_name=DEFAULT_FACTORS_NAME) assert set(eigen_factors.keys()) == set(EIGENDECOMPOSITION_FACTOR_NAMES) assert len(eigen_factors[ACTIVATION_EIGENVECTORS_NAME]) > 0 + assert len(eigen_factors[GRADIENT_EIGENVECTORS_NAME]) > 0 for module_name in eigen_factors[ACTIVATION_EIGENVECTORS_NAME]: assert eigen_factors[ACTIVATION_EIGENVECTORS_NAME][module_name] is not None assert eigen_factors[GRADIENT_EIGENVECTORS_NAME][module_name] is not None - - -@pytest.mark.parametrize( - "test_name", - [ - "mlp", - "repeated_mlp", - "conv", - "bert", - "gpt", - "gpt_checkpoint", - ], -) -@pytest.mark.parametrize("per_sample_gradient_dtype", [torch.float32, torch.bfloat16]) -@pytest.mark.parametrize("lambda_dtype", [torch.float32, torch.bfloat16]) -@pytest.mark.parametrize("train_size", [16]) -@pytest.mark.parametrize("seed", [1]) -def test_fit_lambda_matrices( - test_name: str, - per_sample_gradient_dtype: torch.dtype, - lambda_dtype: torch.dtype, - train_size: int, - seed: int, -) -> None: - # Makes sure that the Lambda computations are working properly. - model, train_dataset, _, data_collator, task = prepare_test( - test_name=test_name, - train_size=train_size, - seed=seed, - ) - kwargs = DataLoaderKwargs(collate_fn=data_collator) - model, analyzer = prepare_model_and_analyzer( - model=model, - task=task, - ) - - factor_args = FactorArguments( - lambda_dtype=lambda_dtype, - per_sample_gradient_dtype=per_sample_gradient_dtype, - ) - if test_name == "repeated_mlp": - factor_args.shared_parameters_exist = True - - analyzer.fit_all_factors( - factors_name=DEFAULT_FACTORS_NAME, - dataset=train_dataset, - per_device_batch_size=train_size // 4, - factor_args=factor_args, - dataloader_kwargs=kwargs, - overwrite_output_dir=True, - ) - - lambda_factors = analyzer.load_lambda_matrices(factors_name=DEFAULT_FACTORS_NAME) - assert set(lambda_factors.keys()) == set(LAMBDA_FACTOR_NAMES) - assert len(lambda_factors[LAMBDA_MATRIX_NAME]) > 0 - for module_name in lambda_factors[LAMBDA_MATRIX_NAME]: - assert lambda_factors[LAMBDA_MATRIX_NAME][module_name].dtype == lambda_dtype - - -@pytest.mark.parametrize( - "test_name", - [ - "mlp", - "conv", - "roberta", - ], -) -@pytest.mark.parametrize("strategy", ["diagonal", "ekfac"]) -@pytest.mark.parametrize("train_size", [50]) -@pytest.mark.parametrize("seed", [1]) -def test_lambda_matrices_batch_size_equivalence( - test_name: str, - strategy: str, - train_size: int, - seed: int, -) -> None: - # Lambda matrices should be identical regardless of what batch size used. 
- model, train_dataset, _, data_collator, task = prepare_test( - test_name=test_name, - train_size=train_size, - seed=seed, - ) - kwargs = DataLoaderKwargs(collate_fn=data_collator) - model = model.to(dtype=torch.float64) - model, analyzer = prepare_model_and_analyzer( - model=model, - task=task, - ) - - factor_args = test_factor_arguments(strategy=strategy) - analyzer.fit_all_factors( - factors_name=DEFAULT_FACTORS_NAME, - dataset=train_dataset, - per_device_batch_size=1, - factor_args=factor_args, - dataloader_kwargs=kwargs, - overwrite_output_dir=True, - ) - bs1_lambda_factors = analyzer.load_lambda_matrices( - factors_name=DEFAULT_FACTORS_NAME, - ) - - analyzer.fit_all_factors( - factors_name=custom_factors_name("bs8"), - dataset=train_dataset, - per_device_batch_size=8, - factor_args=factor_args, - dataloader_kwargs=kwargs, - overwrite_output_dir=True, - ) - bs8_lambda_factors = analyzer.load_lambda_matrices( - factors_name=custom_factors_name("bs8"), - ) - - for name in LAMBDA_FACTOR_NAMES: - assert check_tensor_dict_equivalence(bs1_lambda_factors[name], bs8_lambda_factors[name], atol=ATOL, rtol=RTOL) - - -@pytest.mark.parametrize("test_name", ["mlp"]) -@pytest.mark.parametrize("strategy", ["diagonal", "ekfac"]) -@pytest.mark.parametrize("data_partition_size", [2, 4]) -@pytest.mark.parametrize("module_partition_size", [2, 3]) -@pytest.mark.parametrize("train_size", [81]) -@pytest.mark.parametrize("seed", [2]) -def test_lambda_matrices_partition_equivalence( - test_name: str, - strategy: str, - data_partition_size: int, - module_partition_size: int, - train_size: int, - seed: int, -) -> None: - # Lambda matrices should be identical regardless of what the partition used. - model, train_dataset, _, data_collator, task = prepare_test( - test_name=test_name, - train_size=train_size, - seed=seed, - ) - kwargs = DataLoaderKwargs(collate_fn=data_collator) - model, analyzer = prepare_model_and_analyzer( - model=model, - task=task, - ) - - factor_args = test_factor_arguments(strategy=strategy) - analyzer.fit_all_factors( - factors_name=DEFAULT_FACTORS_NAME, - dataset=train_dataset, - factor_args=factor_args, - per_device_batch_size=8, - overwrite_output_dir=True, - dataloader_kwargs=kwargs, - ) - lambda_factors = analyzer.load_lambda_matrices( - factors_name=DEFAULT_FACTORS_NAME, - ) - - factor_args.lambda_data_partition_size = data_partition_size - factor_args.lambda_module_partition_size = module_partition_size - analyzer.fit_all_factors( - factors_name=custom_factors_name(f"{data_partition_size}_{module_partition_size}"), - dataset=train_dataset, - factor_args=factor_args, - per_device_batch_size=6, - overwrite_output_dir=True, - dataloader_kwargs=kwargs, - ) - partitioned_lambda_factors = analyzer.load_lambda_matrices( - factors_name=custom_factors_name(f"{data_partition_size}_{module_partition_size}"), - ) - for name in LAMBDA_FACTOR_NAMES: - assert check_tensor_dict_equivalence( - lambda_factors[name], partitioned_lambda_factors[name], atol=ATOL, rtol=RTOL - ) - - -@pytest.mark.parametrize( - "test_name", - [ - "mlp", - "conv_bn", - "bert", - ], -) -@pytest.mark.parametrize("train_size", [63]) -@pytest.mark.parametrize("seed", [4]) -def test_lambda_matrices_iterative_aggregate( - test_name: str, - train_size: int, - seed: int, -) -> None: - # Makes sure iterative lambda computation is working properly. 
- model, train_dataset, _, data_collator, task = prepare_test( - test_name=test_name, - train_size=train_size, - seed=seed, - ) - kwargs = DataLoaderKwargs(collate_fn=data_collator) - model = model.to(dtype=torch.float64) - model, analyzer = prepare_model_and_analyzer( - model=model, - task=task, - ) - - factor_args = test_factor_arguments() - factor_args.lambda_iterative_aggregate = False - analyzer.fit_all_factors( - factors_name=DEFAULT_FACTORS_NAME, - dataset=train_dataset, - factor_args=factor_args, - per_device_batch_size=8, - overwrite_output_dir=True, - dataloader_kwargs=kwargs, - ) - lambda_factors = analyzer.load_lambda_matrices( - factors_name=DEFAULT_FACTORS_NAME, - ) - - factor_args.lambda_iterative_aggregate = True - analyzer.fit_all_factors( - factors_name=custom_factors_name("iterative"), - dataset=train_dataset, - factor_args=factor_args, - per_device_batch_size=4, - overwrite_output_dir=True, - dataloader_kwargs=kwargs, - ) - iterative_lambda_factors = analyzer.load_lambda_matrices( - factors_name=custom_factors_name("iterative"), - ) - - for name in LAMBDA_FACTOR_NAMES: - assert check_tensor_dict_equivalence(lambda_factors[name], iterative_lambda_factors[name], atol=ATOL, rtol=RTOL) - - -@pytest.mark.parametrize( - "test_name", - ["mlp", "gpt"], -) -@pytest.mark.parametrize("data_partition_size", [1, 4]) -@pytest.mark.parametrize("train_size", [82]) -@pytest.mark.parametrize("seed", [3]) -def test_lambda_matrices_max_examples( - test_name: str, - data_partition_size: int, - train_size: int, - seed: int, -) -> None: - # Makes sure the max Lambda data selection is working properly. - model, train_dataset, _, data_collator, task = prepare_test( - test_name=test_name, - train_size=train_size, - seed=seed, - ) - kwargs = DataLoaderKwargs(collate_fn=data_collator) - model, analyzer = prepare_model_and_analyzer( - model=model, - task=task, - ) - - MAX_EXAMPLES = 33 - factor_args = FactorArguments( - use_empirical_fisher=True, lambda_max_examples=MAX_EXAMPLES, lambda_data_partition_size=data_partition_size - ) - analyzer.fit_all_factors( - factors_name=DEFAULT_FACTORS_NAME, - dataset=train_dataset, - factor_args=factor_args, - per_device_batch_size=8, - overwrite_output_dir=True, - dataloader_kwargs=kwargs, - ) - lambda_factors = analyzer.load_lambda_matrices( - factors_name=DEFAULT_FACTORS_NAME, - ) - for num_examples in lambda_factors[NUM_LAMBDA_PROCESSED].values(): - assert num_examples == MAX_EXAMPLES - - -@pytest.mark.parametrize( - "test_name", - [ - "mlp", - "conv_bn", - ], -) -@pytest.mark.parametrize("train_size", [100]) -@pytest.mark.parametrize("seed", [8]) -def test_lambda_matrices_amp( - test_name: str, - train_size: int, - seed: int, -) -> None: - # Lambda matrices should be similar when AMP is enabled. 
- model, train_dataset, _, data_collator, task = prepare_test( - test_name=test_name, - train_size=train_size, - seed=seed, - ) - kwargs = DataLoaderKwargs(collate_fn=data_collator) - model, analyzer = prepare_model_and_analyzer( - model=model, - task=task, - ) - - factor_args = test_factor_arguments() - analyzer.fit_all_factors( - factors_name=DEFAULT_FACTORS_NAME, - dataset=train_dataset, - factor_args=factor_args, - per_device_batch_size=8, - overwrite_output_dir=True, - dataloader_kwargs=kwargs, - ) - lambda_factors = analyzer.load_lambda_matrices( - factors_name=DEFAULT_FACTORS_NAME, - ) - - factor_args.amp_dtype = torch.float16 - analyzer.fit_all_factors( - factors_name=custom_factors_name("amp"), - dataset=train_dataset, - per_device_batch_size=8, - overwrite_output_dir=True, - factor_args=factor_args, - dataloader_kwargs=kwargs, - ) - amp_lambda_factors = analyzer.load_lambda_matrices( - factors_name=custom_factors_name("amp"), - ) - - for name in LAMBDA_FACTOR_NAMES: - assert check_tensor_dict_equivalence(lambda_factors[name], amp_lambda_factors[name], atol=1e-01, rtol=1e-02) - - -@pytest.mark.parametrize("train_size", [105]) -@pytest.mark.parametrize("seed", [12]) -def test_lambda_matrices_gradient_checkpoint( - train_size: int, - seed: int, -) -> None: - # Lambda matrices should be the same even when gradient checkpointing is used. - model, train_dataset, _, data_collator, task = prepare_test( - test_name="mlp", - train_size=train_size, - seed=seed, - ) - model = model.to(dtype=torch.float64) - model, analyzer = prepare_model_and_analyzer( - model=model, - task=task, - ) - - factor_args = test_factor_arguments() - analyzer.fit_all_factors( - factors_name=DEFAULT_FACTORS_NAME, - dataset=train_dataset, - per_device_batch_size=5, - overwrite_output_dir=True, - factor_args=factor_args, - ) - lambda_factors = analyzer.load_lambda_matrices( - factors_name=DEFAULT_FACTORS_NAME, - ) - - model, _, _, _, task = prepare_test( - test_name="mlp_checkpoint", - train_size=train_size, - seed=seed, - ) - model = model.to(dtype=torch.float64) - model, analyzer = prepare_model_and_analyzer( - model=model, - task=task, - ) - analyzer.fit_all_factors( - factors_name=custom_factors_name("cp"), - dataset=train_dataset, - per_device_batch_size=6, - overwrite_output_dir=True, - factor_args=factor_args, - ) - checkpoint_lambda_factors = analyzer.load_lambda_matrices( - factors_name=custom_factors_name("cp"), - ) - - for name in LAMBDA_FACTOR_NAMES: - assert check_tensor_dict_equivalence( - lambda_factors[name], checkpoint_lambda_factors[name], atol=ATOL, rtol=RTOL - ) - - -@pytest.mark.parametrize( - "test_name", - ["mlp", "conv", "gpt"], -) -@pytest.mark.parametrize("train_size", [105]) -@pytest.mark.parametrize("seed", [12]) -def test_lambda_matrices_shared_parameters( - test_name: str, - train_size: int, - seed: int, -) -> None: - # When there are no shared parameters, results with and without `shared_parameters_exist` should - # produce the same results. 
- model, train_dataset, _, data_collator, task = prepare_test( - test_name=test_name, - train_size=train_size, - seed=seed, - ) - model = model.to(dtype=torch.float64) - model, analyzer = prepare_model_and_analyzer( - model=model, - task=task, - ) - kwargs = DataLoaderKwargs(collate_fn=data_collator) - - factor_args = test_factor_arguments() - analyzer.fit_all_factors( - factors_name=DEFAULT_FACTORS_NAME, - dataset=train_dataset, - per_device_batch_size=5, - overwrite_output_dir=True, - factor_args=factor_args, - dataloader_kwargs=kwargs, - ) - lambda_factors = analyzer.load_lambda_matrices( - factors_name=DEFAULT_FACTORS_NAME, - ) - - factor_args.shared_parameters_exist = True - analyzer.fit_all_factors( - factors_name=custom_factors_name("shared"), - dataset=train_dataset, - per_device_batch_size=6, - overwrite_output_dir=True, - factor_args=factor_args, - dataloader_kwargs=kwargs, - ) - checkpoint_lambda_factors = analyzer.load_lambda_matrices( - factors_name=custom_factors_name("shared"), - ) - - for name in LAMBDA_FACTOR_NAMES: - assert check_tensor_dict_equivalence( - lambda_factors[name], checkpoint_lambda_factors[name], atol=ATOL, rtol=RTOL - ) diff --git a/tests/factors/test_lambdas.py b/tests/factors/test_lambdas.py new file mode 100644 index 0000000..58c4c49 --- /dev/null +++ b/tests/factors/test_lambdas.py @@ -0,0 +1,488 @@ +# pylint: skip-file + +import pytest +import torch + +from kronfluence.arguments import FactorArguments +from kronfluence.utils.common.factor_arguments import pytest_factor_arguments +from kronfluence.utils.constants import ( + LAMBDA_FACTOR_NAMES, + LAMBDA_MATRIX_NAME, + NUM_LAMBDA_PROCESSED, +) +from kronfluence.utils.dataset import DataLoaderKwargs +from tests.utils import ( + ATOL, + DEFAULT_FACTORS_NAME, + RTOL, + check_tensor_dict_equivalence, + custom_factors_name, + prepare_model_and_analyzer, + prepare_test, +) + + +@pytest.mark.parametrize( + "test_name", + [ + "mlp", + "repeated_mlp", + "conv", + "bert", + "gpt", + "gpt_checkpoint", + ], +) +@pytest.mark.parametrize("per_sample_gradient_dtype", [torch.float32, torch.bfloat16]) +@pytest.mark.parametrize("lambda_dtype", [torch.float32, torch.bfloat16]) +@pytest.mark.parametrize("train_size", [16]) +@pytest.mark.parametrize("seed", [1]) +def test_fit_lambda_matrices( + test_name: str, + per_sample_gradient_dtype: torch.dtype, + lambda_dtype: torch.dtype, + train_size: int, + seed: int, +) -> None: + # Makes sure that the Lambda computations are working properly. 
+    model, train_dataset, _, data_collator, task = prepare_test(
+        test_name=test_name,
+        train_size=train_size,
+        seed=seed,
+    )
+    kwargs = DataLoaderKwargs(collate_fn=data_collator)
+    model, analyzer = prepare_model_and_analyzer(
+        model=model,
+        task=task,
+    )
+
+    factor_args = FactorArguments(
+        lambda_dtype=lambda_dtype,
+        per_sample_gradient_dtype=per_sample_gradient_dtype,
+    )
+    if test_name == "repeated_mlp":
+        factor_args.has_shared_parameters = True
+
+    analyzer.fit_all_factors(
+        factors_name=DEFAULT_FACTORS_NAME,
+        dataset=train_dataset,
+        per_device_batch_size=train_size // 4,
+        factor_args=factor_args,
+        dataloader_kwargs=kwargs,
+        overwrite_output_dir=True,
+    )
+
+    lambda_factors = analyzer.load_lambda_matrices(factors_name=DEFAULT_FACTORS_NAME)
+    assert set(lambda_factors.keys()) == set(LAMBDA_FACTOR_NAMES)
+    assert len(lambda_factors[LAMBDA_MATRIX_NAME]) > 0
+    for module_name in lambda_factors[LAMBDA_MATRIX_NAME]:
+        assert lambda_factors[LAMBDA_MATRIX_NAME][module_name].dtype == lambda_dtype
+
+
+@pytest.mark.parametrize(
+    "test_name",
+    [
+        "mlp",
+        "conv",
+        "roberta",
+    ],
+)
+@pytest.mark.parametrize("strategy", ["diagonal", "ekfac"])
+@pytest.mark.parametrize("train_size", [50])
+@pytest.mark.parametrize("seed", [1])
+def test_lambda_matrices_batch_size_equivalence(
+    test_name: str,
+    strategy: str,
+    train_size: int,
+    seed: int,
+) -> None:
+    # Lambda matrices should be identical regardless of the batch size used.
+    model, train_dataset, _, data_collator, task = prepare_test(
+        test_name=test_name,
+        train_size=train_size,
+        seed=seed,
+    )
+    kwargs = DataLoaderKwargs(collate_fn=data_collator)
+    model = model.to(dtype=torch.float64)
+    model, analyzer = prepare_model_and_analyzer(
+        model=model,
+        task=task,
+    )
+
+    factor_args = pytest_factor_arguments(strategy=strategy)
+    analyzer.fit_all_factors(
+        factors_name=DEFAULT_FACTORS_NAME,
+        dataset=train_dataset,
+        per_device_batch_size=1,
+        factor_args=factor_args,
+        dataloader_kwargs=kwargs,
+        overwrite_output_dir=True,
+    )
+    bs1_lambda_factors = analyzer.load_lambda_matrices(
+        factors_name=DEFAULT_FACTORS_NAME,
+    )
+
+    analyzer.fit_all_factors(
+        factors_name=custom_factors_name("bs8"),
+        dataset=train_dataset,
+        per_device_batch_size=8,
+        factor_args=factor_args,
+        dataloader_kwargs=kwargs,
+        overwrite_output_dir=True,
+    )
+    bs8_lambda_factors = analyzer.load_lambda_matrices(
+        factors_name=custom_factors_name("bs8"),
+    )
+
+    for name in LAMBDA_FACTOR_NAMES:
+        assert check_tensor_dict_equivalence(bs1_lambda_factors[name], bs8_lambda_factors[name], atol=ATOL, rtol=RTOL)
+
+
+@pytest.mark.parametrize("test_name", ["mlp"])
+@pytest.mark.parametrize("strategy", ["diagonal", "ekfac"])
+@pytest.mark.parametrize("data_partitions", [2, 4])
+@pytest.mark.parametrize("module_partitions", [2, 3])
+@pytest.mark.parametrize("train_size", [81])
+@pytest.mark.parametrize("seed", [2])
+def test_lambda_matrices_partition_equivalence(
+    test_name: str,
+    strategy: str,
+    data_partitions: int,
+    module_partitions: int,
+    train_size: int,
+    seed: int,
+) -> None:
+    # Lambda matrices should be identical regardless of the partitioning scheme used.
+    model, train_dataset, _, data_collator, task = prepare_test(
+        test_name=test_name,
+        train_size=train_size,
+        seed=seed,
+    )
+    kwargs = DataLoaderKwargs(collate_fn=data_collator)
+    model, analyzer = prepare_model_and_analyzer(
+        model=model,
+        task=task,
+    )
+
+    factor_args = pytest_factor_arguments(strategy=strategy)
+    analyzer.fit_all_factors(
+        factors_name=DEFAULT_FACTORS_NAME,
+        dataset=train_dataset,
+        factor_args=factor_args,
+        per_device_batch_size=8,
+        overwrite_output_dir=True,
+        dataloader_kwargs=kwargs,
+    )
+    lambda_factors = analyzer.load_lambda_matrices(
+        factors_name=DEFAULT_FACTORS_NAME,
+    )
+
+    factor_args.lambda_data_partitions = data_partitions
+    factor_args.lambda_module_partitions = module_partitions
+    analyzer.fit_all_factors(
+        factors_name=custom_factors_name(f"{data_partitions}_{module_partitions}"),
+        dataset=train_dataset,
+        factor_args=factor_args,
+        per_device_batch_size=6,
+        overwrite_output_dir=True,
+        dataloader_kwargs=kwargs,
+    )
+    partitioned_lambda_factors = analyzer.load_lambda_matrices(
+        factors_name=custom_factors_name(f"{data_partitions}_{module_partitions}"),
+    )
+    for name in LAMBDA_FACTOR_NAMES:
+        assert check_tensor_dict_equivalence(
+            lambda_factors[name], partitioned_lambda_factors[name], atol=ATOL, rtol=RTOL
+        )
+
+
+@pytest.mark.parametrize(
+    "test_name",
+    [
+        "mlp",
+        "conv_bn",
+        "bert",
+    ],
+)
+@pytest.mark.parametrize("train_size", [63, 121])
+@pytest.mark.parametrize("seed", [4])
+def test_lambda_matrices_iterative_lambda_aggregation(
+    test_name: str,
+    train_size: int,
+    seed: int,
+) -> None:
+    # Makes sure iterative Lambda aggregation is working properly.
+    model, train_dataset, _, data_collator, task = prepare_test(
+        test_name=test_name,
+        train_size=train_size,
+        seed=seed,
+    )
+    kwargs = DataLoaderKwargs(collate_fn=data_collator)
+    model = model.to(dtype=torch.float64)
+    model, analyzer = prepare_model_and_analyzer(
+        model=model,
+        task=task,
+    )
+
+    factor_args = pytest_factor_arguments()
+    factor_args.use_iterative_lambda_aggregation = False
+    analyzer.fit_all_factors(
+        factors_name=DEFAULT_FACTORS_NAME,
+        dataset=train_dataset,
+        factor_args=factor_args,
+        per_device_batch_size=8,
+        overwrite_output_dir=True,
+        dataloader_kwargs=kwargs,
+    )
+    lambda_factors = analyzer.load_lambda_matrices(
+        factors_name=DEFAULT_FACTORS_NAME,
+    )
+
+    factor_args.use_iterative_lambda_aggregation = True
+    analyzer.fit_all_factors(
+        factors_name=custom_factors_name("iterative"),
+        dataset=train_dataset,
+        factor_args=factor_args,
+        per_device_batch_size=8,
+        overwrite_output_dir=True,
+        dataloader_kwargs=kwargs,
+    )
+    iterative_lambda_factors = analyzer.load_lambda_matrices(
+        factors_name=custom_factors_name("iterative"),
+    )
+
+    for name in LAMBDA_FACTOR_NAMES:
+        assert check_tensor_dict_equivalence(lambda_factors[name], iterative_lambda_factors[name], atol=ATOL, rtol=RTOL)
+
+
+@pytest.mark.parametrize(
+    "test_name",
+    ["conv_bn", "gpt"],
+)
+@pytest.mark.parametrize("max_examples", [4, 31])
+@pytest.mark.parametrize("data_partitions", [1, 3])
+@pytest.mark.parametrize("module_partitions", [1, 2])
+@pytest.mark.parametrize("train_size", [82])
+@pytest.mark.parametrize("seed", [3])
+def test_lambda_matrices_max_examples(
+    test_name: str,
+    max_examples: int,
+    data_partitions: int,
+    module_partitions: int,
+    train_size: int,
+    seed: int,
+) -> None:
+    # Makes sure the maximum example selection for Lambda matrices is working properly.
+    model, train_dataset, _, data_collator, task = prepare_test(
+        test_name=test_name,
+        train_size=train_size,
+        seed=seed,
+    )
+    kwargs = DataLoaderKwargs(collate_fn=data_collator)
+    model, analyzer = prepare_model_and_analyzer(
+        model=model,
+        task=task,
+    )
+
+    factor_args = FactorArguments(
+        lambda_max_examples=max_examples,
+        lambda_data_partitions=data_partitions,
+        lambda_module_partitions=module_partitions,
+    )
+    analyzer.fit_all_factors(
+        factors_name=DEFAULT_FACTORS_NAME,
+        dataset=train_dataset,
+        factor_args=factor_args,
+        per_device_batch_size=8,
+        overwrite_output_dir=True,
+        dataloader_kwargs=kwargs,
+    )
+    lambda_factors = analyzer.load_lambda_matrices(
+        factors_name=DEFAULT_FACTORS_NAME,
+    )
+    for num_examples in lambda_factors[NUM_LAMBDA_PROCESSED].values():
+        assert num_examples == max_examples
+
+
+@pytest.mark.parametrize(
+    "test_name",
+    [
+        "mlp",
+        "conv_bn",
+    ],
+)
+@pytest.mark.parametrize("module_partitions", [1, 2])
+@pytest.mark.parametrize("train_size", [100])
+@pytest.mark.parametrize("seed", [8])
+def test_lambda_matrices_amp(
+    test_name: str,
+    module_partitions: int,
+    train_size: int,
+    seed: int,
+) -> None:
+    # Lambda matrices should be similar when AMP is enabled.
+    model, train_dataset, _, data_collator, task = prepare_test(
+        test_name=test_name,
+        train_size=train_size,
+        seed=seed,
+    )
+    kwargs = DataLoaderKwargs(collate_fn=data_collator)
+    model, analyzer = prepare_model_and_analyzer(
+        model=model,
+        task=task,
+    )
+
+    factor_args = pytest_factor_arguments()
+    factor_args.lambda_module_partitions = module_partitions
+    analyzer.fit_all_factors(
+        factors_name=DEFAULT_FACTORS_NAME,
+        dataset=train_dataset,
+        factor_args=factor_args,
+        per_device_batch_size=8,
+        overwrite_output_dir=True,
+        dataloader_kwargs=kwargs,
+    )
+    lambda_factors = analyzer.load_lambda_matrices(
+        factors_name=DEFAULT_FACTORS_NAME,
+    )
+
+    factor_args.amp_dtype = torch.float16
+    analyzer.fit_all_factors(
+        factors_name=custom_factors_name("amp"),
+        dataset=train_dataset,
+        per_device_batch_size=8,
+        overwrite_output_dir=True,
+        factor_args=factor_args,
+        dataloader_kwargs=kwargs,
+    )
+    amp_lambda_factors = analyzer.load_lambda_matrices(
+        factors_name=custom_factors_name("amp"),
+    )
+
+    for name in LAMBDA_FACTOR_NAMES:
+        assert check_tensor_dict_equivalence(lambda_factors[name], amp_lambda_factors[name], atol=1e-01, rtol=1e-02)
+
+
+@pytest.mark.parametrize(
+    "test_name",
+    [
+        "mlp",
+        "gpt",
+    ],
+)
+@pytest.mark.parametrize("train_size", [105])
+@pytest.mark.parametrize("seed", [12])
+def test_lambda_matrices_gradient_checkpoint(
+    test_name: str,
+    train_size: int,
+    seed: int,
+) -> None:
+    # Lambda matrices should be the same even when gradient checkpointing is used.
+    model, train_dataset, _, data_collator, task = prepare_test(
+        test_name=test_name,
+        train_size=train_size,
+        seed=seed,
+    )
+    model = model.to(dtype=torch.float64)
+    model, analyzer = prepare_model_and_analyzer(
+        model=model,
+        task=task,
+    )
+    kwargs = DataLoaderKwargs(collate_fn=data_collator)
+
+    factor_args = pytest_factor_arguments()
+    analyzer.fit_all_factors(
+        factors_name=DEFAULT_FACTORS_NAME,
+        dataset=train_dataset,
+        per_device_batch_size=5,
+        overwrite_output_dir=True,
+        factor_args=factor_args,
+        dataloader_kwargs=kwargs,
+    )
+    lambda_factors = analyzer.load_lambda_matrices(
+        factors_name=DEFAULT_FACTORS_NAME,
+    )
+
+    model, _, _, _, task = prepare_test(
+        test_name=test_name + "_checkpoint",
+        train_size=train_size,
+        seed=seed,
+    )
+    model = model.to(dtype=torch.float64)
+    model, analyzer = prepare_model_and_analyzer(
+        model=model,
+        task=task,
+    )
+    analyzer.fit_all_factors(
+        factors_name=custom_factors_name("cp"),
+        dataset=train_dataset,
+        per_device_batch_size=6,
+        overwrite_output_dir=True,
+        factor_args=factor_args,
+        dataloader_kwargs=kwargs,
+    )
+    checkpoint_lambda_factors = analyzer.load_lambda_matrices(
+        factors_name=custom_factors_name("cp"),
+    )
+
+    for name in LAMBDA_FACTOR_NAMES:
+        assert check_tensor_dict_equivalence(
+            lambda_factors[name], checkpoint_lambda_factors[name], atol=ATOL, rtol=RTOL
+        )
+
+
+@pytest.mark.parametrize(
+    "test_name",
+    ["mlp", "conv", "gpt"],
+)
+@pytest.mark.parametrize("train_size", [105])
+@pytest.mark.parametrize("seed", [12])
+def test_lambda_matrices_shared_parameters(
+    test_name: str,
+    train_size: int,
+    seed: int,
+) -> None:
+    # When there are no shared parameters, fitting with and without `has_shared_parameters`
+    # should produce identical results.
+    model, train_dataset, _, data_collator, task = prepare_test(
+        test_name=test_name,
+        train_size=train_size,
+        seed=seed,
+    )
+    model = model.to(dtype=torch.float64)
+    model, analyzer = prepare_model_and_analyzer(
+        model=model,
+        task=task,
+    )
+    kwargs = DataLoaderKwargs(collate_fn=data_collator)
+
+    factor_args = pytest_factor_arguments()
+    analyzer.fit_all_factors(
+        factors_name=DEFAULT_FACTORS_NAME,
+        dataset=train_dataset,
+        per_device_batch_size=5,
+        overwrite_output_dir=True,
+        factor_args=factor_args,
+        dataloader_kwargs=kwargs,
+    )
+    lambda_factors = analyzer.load_lambda_matrices(
+        factors_name=DEFAULT_FACTORS_NAME,
+    )
+
+    factor_args.has_shared_parameters = True
+    analyzer.fit_all_factors(
+        factors_name=custom_factors_name("shared"),
+        dataset=train_dataset,
+        per_device_batch_size=6,
+        overwrite_output_dir=True,
+        factor_args=factor_args,
+        dataloader_kwargs=kwargs,
+    )
+    checkpoint_lambda_factors = analyzer.load_lambda_matrices(
+        factors_name=custom_factors_name("shared"),
+    )
+
+    for name in LAMBDA_FACTOR_NAMES:
+        assert check_tensor_dict_equivalence(
+            lambda_factors[name], checkpoint_lambda_factors[name], atol=ATOL, rtol=RTOL
+        )
diff --git a/tests/test_analyzer.py b/tests/test_analyzer.py
index a66879b..9f92c3b 100644
--- a/tests/test_analyzer.py
+++ b/tests/test_analyzer.py
@@ -13,10 +13,11 @@
     "test_name",
     [
         "mlp",
+        "mlp_checkpoint",
         "repeated_mlp",
-        "conv",
         "conv_bn",
         "bert",
+        "roberta",
         "gpt",
     ],
 )
@@ -48,8 +49,8 @@ def test_analyzer(
         analysis_name=f"pytest_{test_name}",
         model=model,
         task=task,
-        disable_model_save=True,
         disable_tqdm=True,
+        disable_model_save=True,
         cpu=True,
     )
     kwargs = DataLoaderKwargs(collate_fn=data_collator)
@@ -89,7 +90,7 @@
         scores_name="self",
factors_name=f"pytest_{test_analyzer.__name__}_{test_name}", train_dataset=train_dataset, - per_device_train_batch_size=8, + per_device_train_batch_size=6, dataloader_kwargs=kwargs, score_args=score_args, overwrite_output_dir=True, @@ -101,23 +102,23 @@ def test_default_factor_arguments() -> None: assert factor_args.strategy == "ekfac" assert factor_args.use_empirical_fisher is False - assert factor_args.distributed_sync_steps == 1000 + assert factor_args.distributed_sync_interval == 1000 assert factor_args.amp_dtype is None - assert factor_args.shared_parameters_exist is False + assert factor_args.has_shared_parameters is False assert factor_args.covariance_max_examples == 100_000 - assert factor_args.covariance_data_partition_size == 1 - assert factor_args.covariance_module_partition_size == 1 + assert factor_args.covariance_data_partitions == 1 + assert factor_args.covariance_module_partitions == 1 assert factor_args.activation_covariance_dtype == torch.float32 assert factor_args.gradient_covariance_dtype == torch.float32 assert factor_args.eigendecomposition_dtype == torch.float64 assert factor_args.lambda_max_examples == 100_000 - assert factor_args.lambda_data_partition_size == 1 - assert factor_args.lambda_module_partition_size == 1 - assert factor_args.lambda_iterative_aggregate is False - assert factor_args.cached_activation_cpu_offload is False + assert factor_args.lambda_data_partitions == 1 + assert factor_args.lambda_module_partitions == 1 + assert factor_args.use_iterative_lambda_aggregation is False + assert factor_args.offload_activations_to_cpu is False assert factor_args.per_sample_gradient_dtype == torch.float32 assert factor_args.lambda_dtype == torch.float32 diff --git a/tests/testable_tasks/language_modeling.py b/tests/testable_tasks/language_modeling.py index bec805a..e3fb016 100644 --- a/tests/testable_tasks/language_modeling.py +++ b/tests/testable_tasks/language_modeling.py @@ -8,11 +8,17 @@ from datasets import load_dataset from torch import nn from torch.utils import data -from torch.utils.checkpoint import checkpoint_sequential -from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, Conv1D +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + AutoTokenizer, + Conv1D, + logging, +) from kronfluence.task import Task +logging.set_verbosity_error() BATCH_TYPE = Dict[str, torch.Tensor] diff --git a/tests/testable_tasks/multiple_choice.py b/tests/testable_tasks/multiple_choice.py index 8d60f50..52f40de 100644 --- a/tests/testable_tasks/multiple_choice.py +++ b/tests/testable_tasks/multiple_choice.py @@ -8,10 +8,11 @@ from datasets import load_dataset from torch import nn from torch.utils import data -from transformers import AutoConfig, AutoModelForMultipleChoice, AutoTokenizer +from transformers import AutoConfig, AutoModelForMultipleChoice, AutoTokenizer, logging from kronfluence.task import Task +logging.set_verbosity_error() BATCH_TYPE = Dict[str, torch.Tensor] @@ -53,11 +54,9 @@ def preprocess_function(examples: Any): ] labels = examples[label_column_name] - # Flatten out. first_sentences = list(chain(*first_sentences)) second_sentences = list(chain(*second_sentences)) - # Tokenize. tokenized_examples = tokenizer( first_sentences, second_sentences, @@ -65,7 +64,6 @@ def preprocess_function(examples: Any): padding=padding, truncation=True, ) - # Un-flatten. 
         tokenized_inputs = {k: [v[i : i + 4] for i in range(0, len(v), 4)] for k, v in tokenized_examples.items()}
         tokenized_inputs["labels"] = labels
         return tokenized_inputs
@@ -98,12 +96,12 @@ def compute_train_loss(
         if not sample:
             return F.cross_entropy(logits, batch["labels"], reduction="sum")
         with torch.no_grad():
-            probs = torch.nn.functional.softmax(logits, dim=-1)
+            probs = torch.nn.functional.softmax(logits.detach(), dim=-1)
             sampled_labels = torch.multinomial(
                 probs,
                 num_samples=1,
             ).flatten()
-        return F.cross_entropy(logits, sampled_labels.detach(), reduction="sum")
+        return F.cross_entropy(logits, sampled_labels, reduction="sum")
 
     def compute_measurement(
         self,
diff --git a/tests/testable_tasks/text_classification.py b/tests/testable_tasks/text_classification.py
index ea964ba..6e9c005 100644
--- a/tests/testable_tasks/text_classification.py
+++ b/tests/testable_tasks/text_classification.py
@@ -7,10 +7,16 @@
 from datasets import load_dataset
 from torch import nn
 from torch.utils import data
-from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer
+from transformers import (
+    AutoConfig,
+    AutoModelForSequenceClassification,
+    AutoTokenizer,
+    logging,
+)
 
 from kronfluence.task import Task
 
+logging.set_verbosity_error()
 BATCH_TYPE = Dict[str, torch.Tensor]
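
For downstream code updating against this commit, a minimal migration sketch (not part of the patch): it uses only `FactorArguments` field names that appear in the diff above, and the values shown are arbitrary placeholders rather than recommendations.

import torch

from kronfluence.arguments import FactorArguments

# Renames exercised by this patch (old name -> new name):
#   shared_parameters_exist          -> has_shared_parameters
#   distributed_sync_steps           -> distributed_sync_interval
#   covariance_data_partition_size   -> covariance_data_partitions
#   covariance_module_partition_size -> covariance_module_partitions
#   lambda_data_partition_size       -> lambda_data_partitions
#   lambda_module_partition_size     -> lambda_module_partitions
#   lambda_iterative_aggregate       -> use_iterative_lambda_aggregation
#   cached_activation_cpu_offload    -> offload_activations_to_cpu
factor_args = FactorArguments(
    strategy="ekfac",
    has_shared_parameters=False,
    covariance_data_partitions=1,
    covariance_module_partitions=1,
    lambda_data_partitions=1,
    lambda_module_partitions=1,
    use_iterative_lambda_aggregation=True,
    amp_dtype=torch.float16,  # arbitrary example value; `None` disables AMP
)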