Skip to content

Commit

Permalink
In-place EKFAC preconditioning
Browse files Browse the repository at this point in the history
  • Loading branch information
pomonam committed Mar 13, 2024
1 parent 61852da commit 935ec78
Showing 1 changed file with 19 additions and 13 deletions.
32 changes: 19 additions & 13 deletions kronfluence/factor/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,21 +306,27 @@ def precondition_gradient(
lambda_matrix = storage[LAMBDA_MATRIX_NAME].to(dtype=gradient.dtype, device=gradient.device)
num_lambda_processed = storage[NUM_LAMBDA_PROCESSED].to(device=gradient.device)

rotated_gradient = torch.einsum(
"ij,bjl,lk->bik",
(
gradient_eigenvectors.t(),
gradient,
activation_eigenvectors,
),
gradient = torch.matmul(
gradient_eigenvectors.t(),
torch.matmul(gradient, activation_eigenvectors)
)

# rotated_gradient = torch.einsum(
# "ij,bjl,lk->bik",
# (
# gradient_eigenvectors.t(),
# gradient,
# activation_eigenvectors,
# ),
# )

if damping is None:
damping = 0.1 * torch.mean(lambda_matrix)

rotated_gradient.div_(lambda_matrix + damping)
return (num_lambda_processed *
torch.einsum(
"ij,bjl,lk->bik",
(gradient_eigenvectors, rotated_gradient, activation_eigenvectors.t()),
))
gradient.div_(lambda_matrix + damping)
gradient = torch.matmul(
gradient_eigenvectors,
torch.matmul(gradient, activation_eigenvectors.t())
)
gradient.mul_(num_lambda_processed)
return gradient

0 comments on commit 935ec78

Please sign in to comment.