From e3a5cc66580dd94372db787d4612d249f9b9ed0c Mon Sep 17 00:00:00 2001 From: Juhan Bae Date: Thu, 20 Jun 2024 12:46:08 -0400 Subject: [PATCH] Separate out num covariance matrices --- kronfluence/factor/covariance.py | 8 +--- kronfluence/factor/eigen.py | 9 ++-- kronfluence/module/linear.py | 7 +++- kronfluence/module/tracked_module.py | 61 ++++++++++++---------------- kronfluence/module/utils.py | 11 ----- kronfluence/utils/constants.py | 7 ++-- 6 files changed, 43 insertions(+), 60 deletions(-) diff --git a/kronfluence/factor/covariance.py b/kronfluence/factor/covariance.py index d9b9490..0ce97d4 100644 --- a/kronfluence/factor/covariance.py +++ b/kronfluence/factor/covariance.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple import torch import torch.distributed as dist @@ -8,20 +8,17 @@ from torch import autocast, nn from torch.cuda.amp import GradScaler from torch.utils import data -from torch.utils.checkpoint import checkpoint_sequential from tqdm import tqdm from kronfluence.arguments import FactorArguments from kronfluence.module.tracked_module import ModuleMode from kronfluence.module.utils import ( load_factors, - remove_attention_mask, - remove_gradient_scale, set_attention_mask, set_gradient_scale, set_mode, synchronize_covariance_matrices, - update_factor_args, finalize_covariance_matrices, + update_factor_args, ) from kronfluence.task import Task from kronfluence.utils.constants import ( @@ -187,7 +184,6 @@ def fit_covariance_matrices_with_loader( pbar.update(1) with torch.no_grad(): - finalize_covariance_matrices(model=model) if factor_args.compile_mode is not None: model = original_model diff --git a/kronfluence/factor/eigen.py b/kronfluence/factor/eigen.py index fa44666..f78c1e2 100644 --- a/kronfluence/factor/eigen.py +++ b/kronfluence/factor/eigen.py @@ -34,7 +34,8 @@ GRADIENT_EIGENVALUES_NAME, GRADIENT_EIGENVECTORS_NAME, LAMBDA_FACTOR_NAMES, - NUM_COVARIANCE_PROCESSED, + NUM_ACTIVATION_COVARIANCE_PROCESSED, + NUM_GRADIENT_COVARIANCE_PROCESSED, PARTITION_TYPE, ) from kronfluence.utils.logger import TQDM_BAR_FORMAT @@ -126,14 +127,16 @@ def perform_eigendecomposition( disable=not state.is_main_process, ) as pbar: for module_name in tracked_module_names: - for covariance_name, eigenvectors_name, eigenvalues_name in [ + for covariance_name, num_processed_name, eigenvectors_name, eigenvalues_name in [ ( ACTIVATION_COVARIANCE_MATRIX_NAME, + NUM_ACTIVATION_COVARIANCE_PROCESSED, ACTIVATION_EIGENVECTORS_NAME, ACTIVATION_EIGENVALUES_NAME, ), ( GRADIENT_COVARIANCE_MATRIX_NAME, + NUM_GRADIENT_COVARIANCE_PROCESSED, GRADIENT_EIGENVECTORS_NAME, GRADIENT_EIGENVALUES_NAME, ), @@ -145,7 +148,7 @@ def perform_eigendecomposition( ) # Normalize covariance matrices. covariance_matrix.div_( - covariance_factors[NUM_COVARIANCE_PROCESSED][module_name].to(device=state.device) + covariance_factors[num_processed_name][module_name].to(device=state.device) ) # In cases where covariance matrices are not exactly symmetric due to numerical issues. covariance_matrix = covariance_matrix + covariance_matrix.t() diff --git a/kronfluence/module/linear.py b/kronfluence/module/linear.py index c1d74a2..2dfad40 100644 --- a/kronfluence/module/linear.py +++ b/kronfluence/module/linear.py @@ -57,7 +57,12 @@ def _get_flattened_gradient(self, output_gradient: torch.Tensor) -> torch.Tensor The flattened output gradient tensor. The flattened gradient is a 2-dimensional matrix with dimension `gradient_num x gradient_dim`. 
""" - return rearrange(tensor=output_gradient, pattern="b ... d_out -> (b ...) d_out") + flattened_gradient = rearrange(tensor=output_gradient, pattern="b ... d_out -> (b ...) d_out") + if self._attention_mask is not None and flattened_gradient.size(0) == self._attention_mask.numel(): + count = self._attention_mask.sum() + else: + count = flattened_gradient.size(0) + return flattened_gradient, count def _compute_per_sample_gradient( self, input_activation: torch.Tensor, output_gradient: torch.Tensor diff --git a/kronfluence/module/tracked_module.py b/kronfluence/module/tracked_module.py index 911d779..cb3b75d 100644 --- a/kronfluence/module/tracked_module.py +++ b/kronfluence/module/tracked_module.py @@ -6,7 +6,6 @@ from accelerate.utils.dataclasses import BaseEnum from opt_einsum import contract from torch import nn -from torch.utils.hooks import RemovableHandle from kronfluence.arguments import FactorArguments, ScoreArguments from kronfluence.factor.config import FactorConfig @@ -20,7 +19,8 @@ GRADIENT_EIGENVECTORS_NAME, LAMBDA_FACTOR_NAMES, LAMBDA_MATRIX_NAME, - NUM_COVARIANCE_PROCESSED, + NUM_ACTIVATION_COVARIANCE_PROCESSED, + NUM_GRADIENT_COVARIANCE_PROCESSED, NUM_LAMBDA_PROCESSED, PAIRWISE_SCORE_MATRIX_NAME, PRECONDITIONED_GRADIENT_NAME, @@ -46,15 +46,6 @@ def do_nothing(_: Any) -> None: """Does not perform any operations.""" pass - -def scalar_tensor() -> torch.Tensor: - return torch.zeros( - 1, - requires_grad=False, - dtype=torch.int64, - ) - - class TrackedModule(nn.Module): """A wrapper class for PyTorch modules to compute preconditioning factors and influence scores.""" @@ -102,8 +93,6 @@ def __init__( # Operations that will be performed before and after a forward pass. self._pre_forward = do_nothing self._post_forward = do_nothing - self._num_forward_passes = scalar_tensor() - self._num_backward_passes = scalar_tensor() if factor_args is None: factor_args = FactorArguments() @@ -178,8 +167,6 @@ def set_mode(self, mode: ModuleMode, keep_factors: bool = True) -> None: """Sets the module mode of all `TrackedModule` instances within a model.""" self.remove_attention_mask() self.remove_gradient_scale() - self._num_forward_passes = scalar_tensor() - self._num_backward_passes = scalar_tensor() if not keep_factors: self._release_covariance_matrices() @@ -243,7 +230,7 @@ def _get_flattened_activation( The input tensor to the module. Returns: - Tuple[torch.Tensor, torch.Tensor]: + Tuple[torch.Tensor, Union[torch.Tensor, int]]: The flattened activation tensor and the number of stacked activations. The flattened activation is a 2-dimensional matrix with dimension `activation_num x activation_dim`. """ @@ -270,22 +257,21 @@ def _update_activation_covariance_matrix(self, input_activation: torch.Tensor) - # Adds the current batch's activation covariance to the stored activation covariance matrix. self._storage[ACTIVATION_COVARIANCE_MATRIX_NAME].addmm_(flattened_activation.t(), flattened_activation) - if self._storage[NUM_COVARIANCE_PROCESSED] is None: + if self._storage[NUM_ACTIVATION_COVARIANCE_PROCESSED] is None: device = None if isinstance(count, torch.Tensor): # When using attention masks, `count` can be a tensor. device = count.device - self._storage[NUM_COVARIANCE_PROCESSED] = torch.zeros( + self._storage[NUM_ACTIVATION_COVARIANCE_PROCESSED] = torch.zeros( size=(1,), dtype=torch.int64, device=device, requires_grad=False, ) # Keeps track of total number of elements used to aggregate covariance matrices. 
- self._storage[NUM_COVARIANCE_PROCESSED].add_(count) - self._num_forward_passes.add_(1) + self._storage[NUM_ACTIVATION_COVARIANCE_PROCESSED].add_(count) - def _get_flattened_gradient(self, output_gradient: torch.Tensor) -> torch.Tensor: + def _get_flattened_gradient(self, output_gradient: torch.Tensor) -> Tuple[torch.Tensor, Union[torch.Tensor, int]]: """Returns the flattened gradient tensor. Args: @@ -294,9 +280,9 @@ def _get_flattened_gradient(self, output_gradient: torch.Tensor) -> torch.Tensor PyTorch's backward hook. Returns: - torch.Tensor: - The flattened output gradient tensor. The flattened gradient is a 2-dimensional matrix - with dimension `gradient_num x gradient_dim`. + Tuple[torch.Tensor, Union[torch.Tensor, int]]: + The flattened output gradient tensor and the number of stacked gradients. The flattened + gradient is a 2-dimensional matrix with dimension `gradient_num x gradient_dim`. """ raise NotImplementedError("Subclasses must implement the `_get_flattened_gradient` method.") @@ -309,7 +295,7 @@ def _update_gradient_covariance_matrix(self, output_gradient: torch.Tensor) -> N PyTorch's backward hook. """ output_gradient = output_gradient.to(dtype=self.factor_args.gradient_covariance_dtype) - flattened_gradient = self._get_flattened_gradient(output_gradient) + flattened_gradient, count = self._get_flattened_gradient(output_gradient) if self._gradient_scale != 1.0: # Avoiding in-place operation here. flattened_gradient = self._gradient_scale * flattened_gradient @@ -325,7 +311,20 @@ def _update_gradient_covariance_matrix(self, output_gradient: torch.Tensor) -> N ) # Adds the current batch's pseudo-gradient covariance to the stored pseudo-gradient covariance matrix. self._storage[GRADIENT_COVARIANCE_MATRIX_NAME].addmm_(flattened_gradient.t(), flattened_gradient) - self._num_backward_passes.add_(1) + + if self._storage[NUM_GRADIENT_COVARIANCE_PROCESSED] is None: + device = None + if isinstance(count, torch.Tensor): + # When using attention masks, `count` can be a tensor. + device = count.device + self._storage[NUM_GRADIENT_COVARIANCE_PROCESSED] = torch.zeros( + size=(1,), + dtype=torch.int64, + device=device, + requires_grad=False, + ) + # Keeps track of total number of elements used to aggregate covariance matrices. + self._storage[NUM_GRADIENT_COVARIANCE_PROCESSED].add_(count) def _covariance_pre_forward(self, inputs: torch.Tensor) -> None: """Computes and updates activation covariance matrix in the forward pass.""" @@ -351,16 +350,6 @@ def _covariance_matrices_available(self) -> bool: return False return True - def finalize_covariance_matrices(self) -> None: - """Rescales the activation covariance matrix if the number of forward and backward passes do not match. 
This - could happen when using gradient checkpointing or torch.compile.""" - if self._num_forward_passes == self._num_backward_passes: - return - assert self._num_forward_passes % self._num_backward_passes == 0 - mismatch_ratio = self._num_forward_passes // self._num_backward_passes - self._storage[ACTIVATION_COVARIANCE_MATRIX_NAME].div_(mismatch_ratio) - self._storage[NUM_COVARIANCE_PROCESSED].div_(mismatch_ratio) - @torch.no_grad() def synchronize_covariance_matrices(self) -> None: """Aggregates covariance matrices across multiple devices or nodes in a distributed setting.""" diff --git a/kronfluence/module/utils.py b/kronfluence/module/utils.py index 4946195..1cd71b3 100644 --- a/kronfluence/module/utils.py +++ b/kronfluence/module/utils.py @@ -145,17 +145,6 @@ def get_tracked_module_names(model: nn.Module) -> List[str]: return tracked_modules -def finalize_covariance_matrices(model: nn.Module) -> None: - """Finalizes covariance matrices of all `TrackedModule` instances within a model.""" - tracked_module_count = 0 - for module in model.modules(): - if isinstance(module, TrackedModule): - module.finalize_covariance_matrices() - tracked_module_count += 1 - if tracked_module_count == 0: - raise TrackedModuleNotFoundError("Tracked modules not found when trying to finalize covariance matrices.") - - def synchronize_covariance_matrices(model: nn.Module) -> None: """Synchronizes covariance matrices of all `TrackedModule` instances within a model.""" tracked_module_count = 0 diff --git a/kronfluence/utils/constants.py b/kronfluence/utils/constants.py index 74847dc..202d72d 100644 --- a/kronfluence/utils/constants.py +++ b/kronfluence/utils/constants.py @@ -13,16 +13,17 @@ # Pseudo-gradient covariance matrix. GRADIENT_COVARIANCE_MATRIX_NAME = "gradient_covariance" # Number of elements used to aggregate activation and gradient covariance. -NUM_COVARIANCE_PROCESSED = "num_covariance_processed" +NUM_ACTIVATION_COVARIANCE_PROCESSED = "num_activation_covariance_processed" +NUM_GRADIENT_COVARIANCE_PROCESSED = "num_gradient_covariance_processed" # A list of factors to keep track of when computing covariance matrices. COVARIANCE_FACTOR_NAMES = [ ACTIVATION_COVARIANCE_MATRIX_NAME, GRADIENT_COVARIANCE_MATRIX_NAME, - NUM_COVARIANCE_PROCESSED, + NUM_ACTIVATION_COVARIANCE_PROCESSED, + NUM_GRADIENT_COVARIANCE_PROCESSED, ] - # Eigenvectors for the activation covariance matrix. ACTIVATION_EIGENVECTORS_NAME = "activation_eigenvectors" # Eigenvalues for the activation covariance matrix.
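
Illustrative sketch (not part of the patch): the snippet below shows why the single NUM_COVARIANCE_PROCESSED counter is split into separate activation and gradient counters. With attention masks or gradient checkpointing, the number of elements folded into the activation covariance can differ from the number folded into the pseudo-gradient covariance, so each matrix has to be normalized by its own count before eigendecomposition. The `storage` dict and `accumulate` helper here are hypothetical stand-ins, not kronfluence API, and the constant string values are assumed.

    import torch

    # Assumed stand-ins for the constants introduced by this patch.
    ACTIVATION_COVARIANCE_MATRIX_NAME = "activation_covariance"
    GRADIENT_COVARIANCE_MATRIX_NAME = "gradient_covariance"
    NUM_ACTIVATION_COVARIANCE_PROCESSED = "num_activation_covariance_processed"
    NUM_GRADIENT_COVARIANCE_PROCESSED = "num_gradient_covariance_processed"

    storage = {
        ACTIVATION_COVARIANCE_MATRIX_NAME: None,
        GRADIENT_COVARIANCE_MATRIX_NAME: None,
        NUM_ACTIVATION_COVARIANCE_PROCESSED: None,
        NUM_GRADIENT_COVARIANCE_PROCESSED: None,
    }

    def accumulate(matrix_name: str, count_name: str, flattened: torch.Tensor, count) -> None:
        """Adds one batch's outer product and its element count (hypothetical helper)."""
        dim = flattened.size(1)
        if storage[matrix_name] is None:
            storage[matrix_name] = torch.zeros(dim, dim, dtype=flattened.dtype)
            storage[count_name] = torch.zeros(1, dtype=torch.int64)
        storage[matrix_name].addmm_(flattened.t(), flattened)
        storage[count_name].add_(count)

    # Forward hook: activations flattened to (batch * seq, d_in). With an attention
    # mask, `count` would be `mask.sum()` rather than the number of rows.
    activations = torch.randn(8, 4)
    accumulate(ACTIVATION_COVARIANCE_MATRIX_NAME, NUM_ACTIVATION_COVARIANCE_PROCESSED,
               activations, count=activations.size(0))

    # Backward hook: with gradient checkpointing the forward hook can fire more often
    # than the backward hook, so the gradient covariance keeps an independent counter.
    gradients = torch.randn(8, 4)
    accumulate(GRADIENT_COVARIANCE_MATRIX_NAME, NUM_GRADIENT_COVARIANCE_PROCESSED,
               gradients, count=gradients.size(0))

    # Eigendecomposition then divides each covariance matrix by its own counter,
    # mirroring the per-name normalization loop in `perform_eigendecomposition`.
    for matrix_name, count_name in (
        (ACTIVATION_COVARIANCE_MATRIX_NAME, NUM_ACTIVATION_COVARIANCE_PROCESSED),
        (GRADIENT_COVARIANCE_MATRIX_NAME, NUM_GRADIENT_COVARIANCE_PROCESSED),
    ):
        covariance = storage[matrix_name] / storage[count_name]
        covariance = 0.5 * (covariance + covariance.t())  # enforce exact symmetry
        eigenvalues, eigenvectors = torch.linalg.eigh(covariance)

This also removes the need for the old `finalize_covariance_matrices` rescaling step, since a mismatch between forward and backward pass counts no longer has to be corrected after the fact.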