Skip to content

Commit

Permalink
Move tracker to tensor
Browse files Browse the repository at this point in the history
  • Loading branch information
pomonam committed Jun 20, 2024
1 parent fed831d commit ba1404c
Showing 1 changed file with 14 additions and 6 deletions.
20 changes: 14 additions & 6 deletions kronfluence/module/tracked_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,14 @@ def do_nothing(_: Any) -> None:
pass


def scalar_tensor() -> torch.Tensor:
    """Return a fresh one-element ``int64`` tensor initialized to zero.

    The result is used as an in-place counter: callers in this file assign it
    to ``_num_forward_passes`` / ``_num_backward_passes`` and bump it with
    ``add_(1)``.

    Note that, despite the name, the tensor has shape ``(1,)`` rather than
    being zero-dimensional.

    Returns:
        torch.Tensor:
            A new tensor of shape ``(1,)``, dtype ``torch.int64``, value ``0``,
            with ``requires_grad=False``.
    """
    return torch.zeros(
        1,
        requires_grad=False,
        dtype=torch.int64,
    )


class TrackedModule(nn.Module):
"""A wrapper class for PyTorch modules to compute preconditioning factors and influence scores."""

Expand Down Expand Up @@ -94,8 +102,8 @@ def __init__(
# Operations that will be performed before and after a forward pass.
self._pre_forward = do_nothing
self._post_forward = do_nothing
self._num_forward_passes = 0
self._num_backward_passes = 0
self._num_forward_passes = scalar_tensor()
self._num_backward_passes = scalar_tensor()

if factor_args is None:
factor_args = FactorArguments()
Expand Down Expand Up @@ -170,8 +178,8 @@ def set_mode(self, mode: ModuleMode, keep_factors: bool = True) -> None:
"""Sets the module mode of all `TrackedModule` instances within a model."""
self.remove_attention_mask()
self.remove_gradient_scale()
self._num_forward_passes = 0
self._num_backward_passes = 0
self._num_forward_passes = scalar_tensor()
self._num_backward_passes = scalar_tensor()

if not keep_factors:
self._release_covariance_matrices()
Expand Down Expand Up @@ -275,7 +283,7 @@ def _update_activation_covariance_matrix(self, input_activation: torch.Tensor) -
)
# Keeps track of total number of elements used to aggregate covariance matrices.
self._storage[NUM_COVARIANCE_PROCESSED].add_(count)
self._num_forward_passes += 1
self._num_forward_passes.add_(1)

def _get_flattened_gradient(self, output_gradient: torch.Tensor) -> torch.Tensor:
"""Returns the flattened gradient tensor.
Expand Down Expand Up @@ -317,7 +325,7 @@ def _update_gradient_covariance_matrix(self, output_gradient: torch.Tensor) -> N
)
# Adds the current batch's pseudo-gradient covariance to the stored pseudo-gradient covariance matrix.
self._storage[GRADIENT_COVARIANCE_MATRIX_NAME].addmm_(flattened_gradient.t(), flattened_gradient)
self._num_backward_passes += 1
self._num_backward_passes.add_(1)

def _covariance_pre_forward(self, inputs: torch.Tensor) -> None:
"""Computes and updates activation covariance matrix in the forward pass."""
Expand Down

0 comments on commit ba1404c

Please sign in to comment.