Separate out num covariance matrices
pomonam committed Jun 20, 2024
1 parent ba1404c commit e3a5cc6
Showing 6 changed files with 43 additions and 60 deletions.
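In short, this commit splits the single `NUM_COVARIANCE_PROCESSED` counter into `NUM_ACTIVATION_COVARIANCE_PROCESSED` and `NUM_GRADIENT_COVARIANCE_PROCESSED`, and drops the forward/backward pass bookkeeping (`scalar_tensor`, `_num_forward_passes`, `_num_backward_passes`, `finalize_covariance_matrices`) that previously rescaled the activation covariance when the two pass counts diverged. A minimal sketch of the resulting accounting, with toy shapes and none of the library's own classes:

```python
import torch

a_cov, g_cov = torch.zeros(4, 4), torch.zeros(3, 3)
n_act = torch.zeros(1, dtype=torch.int64)   # NUM_ACTIVATION_COVARIANCE_PROCESSED
n_grad = torch.zeros(1, dtype=torch.int64)  # NUM_GRADIENT_COVARIANCE_PROCESSED

acts = torch.randn(10, 4)   # ten flattened activation rows
grads = torch.randn(7, 3)   # seven flattened gradient rows (e.g. padded tokens dropped)

a_cov.addmm_(acts.t(), acts); n_act.add_(acts.size(0))
g_cov.addmm_(grads.t(), grads); n_grad.add_(grads.size(0))

mean_a_cov = a_cov / n_act   # each factor is normalized by its own count
mean_g_cov = g_cov / n_grad
```

Because each matrix carries its own count, normalization no longer depends on the forward and backward hooks firing the same number of times.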
8 changes: 2 additions & 6 deletions kronfluence/factor/covariance.py
@@ -1,5 +1,5 @@
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
from typing import Dict, List, Optional, Tuple

import torch
import torch.distributed as dist
@@ -8,20 +8,17 @@
from torch import autocast, nn
from torch.cuda.amp import GradScaler
from torch.utils import data
from torch.utils.checkpoint import checkpoint_sequential
from tqdm import tqdm

from kronfluence.arguments import FactorArguments
from kronfluence.module.tracked_module import ModuleMode
from kronfluence.module.utils import (
load_factors,
remove_attention_mask,
remove_gradient_scale,
set_attention_mask,
set_gradient_scale,
set_mode,
synchronize_covariance_matrices,
update_factor_args, finalize_covariance_matrices,
update_factor_args,
)
from kronfluence.task import Task
from kronfluence.utils.constants import (
@@ -187,7 +184,6 @@ def fit_covariance_matrices_with_loader(
pbar.update(1)

with torch.no_grad():
finalize_covariance_matrices(model=model)
if factor_args.compile_mode is not None:
model = original_model

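For context (not part of this diff): the removed `finalize_covariance_matrices(model=model)` call compensated for forward hooks firing more often than backward hooks, e.g. under gradient checkpointing, which also explains the dropped `checkpoint_sequential` import. A self-contained illustration of that mismatch, with a made-up module and shapes:

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

calls = {"fwd": 0, "bwd": 0}
layer = nn.Linear(4, 4)
layer.register_forward_hook(lambda mod, inp, out: calls.__setitem__("fwd", calls["fwd"] + 1))
layer.register_full_backward_hook(lambda mod, gin, gout: calls.__setitem__("bwd", calls["bwd"] + 1))

x = torch.randn(2, 4, requires_grad=True)
checkpoint(layer, x, use_reentrant=False).sum().backward()
print(calls)  # typically {'fwd': 2, 'bwd': 1}: the forward re-runs during recomputation
```

With per-factor counters, the activation counter grows with every forward hook invocation just like the accumulated matrix does, so the normalized covariance stays correct without a post-hoc rescaling step.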
9 changes: 6 additions & 3 deletions kronfluence/factor/eigen.py
@@ -34,7 +34,8 @@
GRADIENT_EIGENVALUES_NAME,
GRADIENT_EIGENVECTORS_NAME,
LAMBDA_FACTOR_NAMES,
NUM_COVARIANCE_PROCESSED,
NUM_ACTIVATION_COVARIANCE_PROCESSED,
NUM_GRADIENT_COVARIANCE_PROCESSED,
PARTITION_TYPE,
)
from kronfluence.utils.logger import TQDM_BAR_FORMAT
@@ -126,14 +127,16 @@ def perform_eigendecomposition(
disable=not state.is_main_process,
) as pbar:
for module_name in tracked_module_names:
for covariance_name, eigenvectors_name, eigenvalues_name in [
for covariance_name, num_processed_name, eigenvectors_name, eigenvalues_name in [
(
ACTIVATION_COVARIANCE_MATRIX_NAME,
NUM_ACTIVATION_COVARIANCE_PROCESSED,
ACTIVATION_EIGENVECTORS_NAME,
ACTIVATION_EIGENVALUES_NAME,
),
(
GRADIENT_COVARIANCE_MATRIX_NAME,
NUM_GRADIENT_COVARIANCE_PROCESSED,
GRADIENT_EIGENVECTORS_NAME,
GRADIENT_EIGENVALUES_NAME,
),
@@ -145,7 +148,7 @@
)
# Normalize covariance matrices.
covariance_matrix.div_(
covariance_factors[NUM_COVARIANCE_PROCESSED][module_name].to(device=state.device)
covariance_factors[num_processed_name][module_name].to(device=state.device)
)
# In cases where covariance matrices are not exactly symmetric due to numerical issues.
covariance_matrix = covariance_matrix + covariance_matrix.t()
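A compact, self-contained sketch of the per-factor normalization and eigendecomposition step shown above (the `0.5` symmetrization factor and the use of `torch.linalg.eigh` are illustrative assumptions; only the division by the per-factor count appears in the hunk):

```python
import torch

num_processed = torch.tensor([128], dtype=torch.int64)
covariance_matrix = torch.randn(8, 8, dtype=torch.float64)
covariance_matrix = covariance_matrix @ covariance_matrix.t()          # toy accumulated covariance

covariance_matrix = covariance_matrix / num_processed                  # normalize by per-factor count
covariance_matrix = 0.5 * (covariance_matrix + covariance_matrix.t())  # enforce exact symmetry
eigenvalues, eigenvectors = torch.linalg.eigh(covariance_matrix)       # ascending eigenvalues
```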
7 changes: 6 additions & 1 deletion kronfluence/module/linear.py
@@ -57,7 +57,12 @@ def _get_flattened_gradient(self, output_gradient: torch.Tensor) -> torch.Tensor
The flattened output gradient tensor. The flattened gradient is a 2-dimensional matrix
with dimension `gradient_num x gradient_dim`.
"""
return rearrange(tensor=output_gradient, pattern="b ... d_out -> (b ...) d_out")
flattened_gradient = rearrange(tensor=output_gradient, pattern="b ... d_out -> (b ...) d_out")
if self._attention_mask is not None and flattened_gradient.size(0) == self._attention_mask.numel():
count = self._attention_mask.sum()
else:
count = flattened_gradient.size(0)
return flattened_gradient, count

def _compute_per_sample_gradient(
self, input_activation: torch.Tensor, output_gradient: torch.Tensor
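A standalone sketch of the masked counting the new `_get_flattened_gradient` performs, with a made-up batch and attention mask: when a token-level mask is present and matches the number of flattened rows, only unmasked tokens contribute to the count; otherwise every row does.

```python
import torch
from einops import rearrange

output_gradient = torch.randn(2, 3, 4)                 # (batch, seq, d_out)
attention_mask = torch.tensor([[1, 1, 0], [1, 0, 0]])  # three of the six tokens are padding

flattened_gradient = rearrange(tensor=output_gradient, pattern="b ... d_out -> (b ...) d_out")
if attention_mask is not None and flattened_gradient.size(0) == attention_mask.numel():
    count = attention_mask.sum()        # tensor(3): only real tokens are counted
else:
    count = flattened_gradient.size(0)  # fall back to counting every flattened row
```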
61 changes: 25 additions & 36 deletions kronfluence/module/tracked_module.py
@@ -6,7 +6,6 @@
from accelerate.utils.dataclasses import BaseEnum
from opt_einsum import contract
from torch import nn
from torch.utils.hooks import RemovableHandle

from kronfluence.arguments import FactorArguments, ScoreArguments
from kronfluence.factor.config import FactorConfig
@@ -20,7 +19,8 @@
GRADIENT_EIGENVECTORS_NAME,
LAMBDA_FACTOR_NAMES,
LAMBDA_MATRIX_NAME,
NUM_COVARIANCE_PROCESSED,
NUM_ACTIVATION_COVARIANCE_PROCESSED,
NUM_GRADIENT_COVARIANCE_PROCESSED,
NUM_LAMBDA_PROCESSED,
PAIRWISE_SCORE_MATRIX_NAME,
PRECONDITIONED_GRADIENT_NAME,
@@ -46,15 +46,6 @@ def do_nothing(_: Any) -> None:
"""Does not perform any operations."""
pass


def scalar_tensor() -> torch.Tensor:
return torch.zeros(
1,
requires_grad=False,
dtype=torch.int64,
)


class TrackedModule(nn.Module):
"""A wrapper class for PyTorch modules to compute preconditioning factors and influence scores."""

@@ -102,8 +93,6 @@ def __init__(
# Operations that will be performed before and after a forward pass.
self._pre_forward = do_nothing
self._post_forward = do_nothing
self._num_forward_passes = scalar_tensor()
self._num_backward_passes = scalar_tensor()

if factor_args is None:
factor_args = FactorArguments()
@@ -178,8 +167,6 @@ def set_mode(self, mode: ModuleMode, keep_factors: bool = True) -> None:
"""Sets the module mode of all `TrackedModule` instances within a model."""
self.remove_attention_mask()
self.remove_gradient_scale()
self._num_forward_passes = scalar_tensor()
self._num_backward_passes = scalar_tensor()

if not keep_factors:
self._release_covariance_matrices()
@@ -243,7 +230,7 @@ def _get_flattened_activation(
The input tensor to the module.
Returns:
Tuple[torch.Tensor, torch.Tensor]:
Tuple[torch.Tensor, Union[torch.Tensor, int]]:
The flattened activation tensor and the number of stacked activations. The flattened
activation is a 2-dimensional matrix with dimension `activation_num x activation_dim`.
"""
@@ -270,22 +257,21 @@ def _update_activation_covariance_matrix(self, input_activation: torch.Tensor) -
# Adds the current batch's activation covariance to the stored activation covariance matrix.
self._storage[ACTIVATION_COVARIANCE_MATRIX_NAME].addmm_(flattened_activation.t(), flattened_activation)

if self._storage[NUM_COVARIANCE_PROCESSED] is None:
if self._storage[NUM_ACTIVATION_COVARIANCE_PROCESSED] is None:
device = None
if isinstance(count, torch.Tensor):
# When using attention masks, `count` can be a tensor.
device = count.device
self._storage[NUM_COVARIANCE_PROCESSED] = torch.zeros(
self._storage[NUM_ACTIVATION_COVARIANCE_PROCESSED] = torch.zeros(
size=(1,),
dtype=torch.int64,
device=device,
requires_grad=False,
)
# Keeps track of total number of elements used to aggregate covariance matrices.
self._storage[NUM_COVARIANCE_PROCESSED].add_(count)
self._num_forward_passes.add_(1)
self._storage[NUM_ACTIVATION_COVARIANCE_PROCESSED].add_(count)

def _get_flattened_gradient(self, output_gradient: torch.Tensor) -> torch.Tensor:
def _get_flattened_gradient(self, output_gradient: torch.Tensor) -> Tuple[torch.Tensor, Union[torch.Tensor, int]]:
"""Returns the flattened gradient tensor.
Args:
@@ -294,9 +280,9 @@ def _get_flattened_gradient(self, output_gradient: torch.Tensor) -> torch.Tensor
PyTorch's backward hook.
Returns:
torch.Tensor:
The flattened output gradient tensor. The flattened gradient is a 2-dimensional matrix
with dimension `gradient_num x gradient_dim`.
Tuple[torch.Tensor, Union[torch.Tensor, int]]:
The flattened output gradient tensor and the number of stacked gradients. The flattened
gradient is a 2-dimensional matrix with dimension `gradient_num x gradient_dim`.
"""
raise NotImplementedError("Subclasses must implement the `_get_flattened_gradient` method.")

@@ -309,7 +295,7 @@ def _update_gradient_covariance_matrix(self, output_gradient: torch.Tensor) -> N
PyTorch's backward hook.
"""
output_gradient = output_gradient.to(dtype=self.factor_args.gradient_covariance_dtype)
flattened_gradient = self._get_flattened_gradient(output_gradient)
flattened_gradient, count = self._get_flattened_gradient(output_gradient)
if self._gradient_scale != 1.0:
# Avoiding in-place operation here.
flattened_gradient = self._gradient_scale * flattened_gradient
@@ -325,7 +311,20 @@ def _update_gradient_covariance_matrix(self, output_gradient: torch.Tensor) -> N
)
# Adds the current batch's pseudo-gradient covariance to the stored pseudo-gradient covariance matrix.
self._storage[GRADIENT_COVARIANCE_MATRIX_NAME].addmm_(flattened_gradient.t(), flattened_gradient)
self._num_backward_passes.add_(1)

if self._storage[NUM_GRADIENT_COVARIANCE_PROCESSED] is None:
device = None
if isinstance(count, torch.Tensor):
# When using attention masks, `count` can be a tensor.
device = count.device
self._storage[NUM_GRADIENT_COVARIANCE_PROCESSED] = torch.zeros(
size=(1,),
dtype=torch.int64,
device=device,
requires_grad=False,
)
# Keeps track of total number of elements used to aggregate covariance matrices.
self._storage[NUM_GRADIENT_COVARIANCE_PROCESSED].add_(count)

def _covariance_pre_forward(self, inputs: torch.Tensor) -> None:
"""Computes and updates activation covariance matrix in the forward pass."""
@@ -351,16 +350,6 @@ def _covariance_matrices_available(self) -> bool:
return False
return True

def finalize_covariance_matrices(self) -> None:
"""Rescales the activation covariance matrix if the number of forward and backward passes do not match. This
could happen when using gradient checkpointing or torch.compile."""
if self._num_forward_passes == self._num_backward_passes:
return
assert self._num_forward_passes % self._num_backward_passes == 0
mismatch_ratio = self._num_forward_passes // self._num_backward_passes
self._storage[ACTIVATION_COVARIANCE_MATRIX_NAME].div_(mismatch_ratio)
self._storage[NUM_COVARIANCE_PROCESSED].div_(mismatch_ratio)

@torch.no_grad()
def synchronize_covariance_matrices(self) -> None:
"""Aggregates covariance matrices across multiple devices or nodes in a distributed setting."""
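The two `_update_*_covariance_matrix` hunks above share the same lazy-initialization pattern for their counters. A standalone sketch (a plain dict stands in for `self._storage`; `count` may be a tensor when attention masks are used):

```python
import torch

storage = {"num_activation_covariance_processed": None}

def add_count(storage, key, count):
    if storage[key] is None:
        # When using attention masks, `count` can be a tensor; keep the counter on its device.
        device = count.device if isinstance(count, torch.Tensor) else None
        storage[key] = torch.zeros(size=(1,), dtype=torch.int64, device=device, requires_grad=False)
    storage[key].add_(count)

add_count(storage, "num_activation_covariance_processed", 6)                # plain int count
add_count(storage, "num_activation_covariance_processed", torch.tensor(4))  # masked count as a tensor
print(storage["num_activation_covariance_processed"])                       # tensor([10])
```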
11 changes: 0 additions & 11 deletions kronfluence/module/utils.py
@@ -145,17 +145,6 @@ def get_tracked_module_names(model: nn.Module) -> List[str]:
return tracked_modules


def finalize_covariance_matrices(model: nn.Module) -> None:
"""Finalizes covariance matrices of all `TrackedModule` instances within a model."""
tracked_module_count = 0
for module in model.modules():
if isinstance(module, TrackedModule):
module.finalize_covariance_matrices()
tracked_module_count += 1
if tracked_module_count == 0:
raise TrackedModuleNotFoundError("Tracked modules not found when trying to finalize covariance matrices.")


def synchronize_covariance_matrices(model: nn.Module) -> None:
"""Synchronizes covariance matrices of all `TrackedModule` instances within a model."""
tracked_module_count = 0
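For reference, the model-level helpers in this file follow one pattern: walk `model.modules()`, dispatch to every `TrackedModule`, and raise if none is found. A hedged, standalone sketch of that pattern (the `hasattr` check and exception class here are stand-ins, not the library's actual types):

```python
from torch import nn

class TrackedModuleNotFoundError(Exception):
    """Raised when a model contains no tracked modules."""

def apply_to_tracked_modules(model: nn.Module, method_name: str) -> None:
    tracked_module_count = 0
    for module in model.modules():
        if hasattr(module, method_name):   # stand-in for `isinstance(module, TrackedModule)`
            getattr(module, method_name)()
            tracked_module_count += 1
    if tracked_module_count == 0:
        raise TrackedModuleNotFoundError(f"Tracked modules not found for `{method_name}`.")
```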
7 changes: 4 additions & 3 deletions kronfluence/utils/constants.py
@@ -13,16 +13,17 @@
# Pseudo-gradient covariance matrix.
GRADIENT_COVARIANCE_MATRIX_NAME = "gradient_covariance"
# Number of elements used to aggregate activation and gradient covariance.
NUM_COVARIANCE_PROCESSED = "num_covariance_processed"
NUM_ACTIVATION_COVARIANCE_PROCESSED = "num_activation_covariance_processed"
NUM_GRADIENT_COVARIANCE_PROCESSED = "num_gradient_covariance_processed"

# A list of factors to keep track of when computing covariance matrices.
COVARIANCE_FACTOR_NAMES = [
ACTIVATION_COVARIANCE_MATRIX_NAME,
GRADIENT_COVARIANCE_MATRIX_NAME,
NUM_COVARIANCE_PROCESSED,
NUM_ACTIVATION_COVARIANCE_PROCESSED,
NUM_GRADIENT_COVARIANCE_PROCESSED,
]


# Eigenvectors for the activation covariance matrix.
ACTIVATION_EIGENVECTORS_NAME = "activation_eigenvectors"
# Eigenvalues for the activation covariance matrix.
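Finally, a small sketch of the nested layout that `perform_eigendecomposition` indexes above — factor name, then module name, then tensor — using the two new counter names; the module name and values are dummies:

```python
import torch

covariance_factors = {
    "activation_covariance": {"layer1": torch.eye(4) * 10.0},
    "num_activation_covariance_processed": {"layer1": torch.tensor([10], dtype=torch.int64)},
    "gradient_covariance": {"layer1": torch.eye(4) * 5.0},
    "num_gradient_covariance_processed": {"layer1": torch.tensor([5], dtype=torch.int64)},
}
normalized = (
    covariance_factors["activation_covariance"]["layer1"]
    / covariance_factors["num_activation_covariance_processed"]["layer1"]
)
```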
