Use view instead of reshape
pomonam committed Mar 13, 2024
1 parent 0b0c1a1 commit ee330c3
Showing 2 changed files with 11 additions and 10 deletions.
5 changes: 3 additions & 2 deletions kronfluence/factor/config.py
@@ -319,7 +319,8 @@ def precondition_gradient(
             damping = 0.1 * torch.mean(lambda_matrix)
 
         rotated_gradient.div_(lambda_matrix + damping)
-        return num_lambda_processed * torch.einsum(
+        return (num_lambda_processed *
+            torch.einsum(
             "ij,bjl,lk->bik",
             (gradient_eigenvectors, rotated_gradient, activation_eigenvectors.t()),
-        )
+        ))
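For reference, a minimal sketch (not repository code; the shapes and toy values are assumptions for illustration) of what the einsum in this hunk computes: a per-sample change of basis Q_g @ R_b @ Q_a^T applied to the damped, rotated gradient.

import torch

batch, d_out, d_in = 2, 4, 3
gradient_eigenvectors = torch.randn(d_out, d_out)    # Q_g
activation_eigenvectors = torch.randn(d_in, d_in)    # Q_a
rotated_gradient = torch.randn(batch, d_out, d_in)   # already divided by (lambda + damping)

out_einsum = torch.einsum(
    "ij,bjl,lk->bik",
    gradient_eigenvectors, rotated_gradient, activation_eigenvectors.t(),
)
# Equivalent batched matrix products: Q_g @ R_b @ Q_a^T for every sample b.
out_matmul = gradient_eigenvectors @ rotated_gradient @ activation_eigenvectors.t()
assert torch.allclose(out_einsum, out_matmul, atol=1e-5)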
16 changes: 8 additions & 8 deletions kronfluence/module/tracked_module.py
@@ -428,25 +428,25 @@ def _update_lambda_matrix(self, per_sample_gradient: torch.Tensor) -> None:
             self._storage[NUM_LAMBDA_PROCESSED] = torch.zeros(
                 size=(1,),
                 dtype=torch.int64,
-                device=per_sample_gradient.device,
+                # device=per_sample_gradient.device,
                 requires_grad=False,
             )
 
         if FactorConfig.CONFIGS[self.factor_args.strategy].requires_eigendecomposition_for_lambda:
             if self.factor_args.lambda_iterative_aggregate:
                 # This batch-wise iterative update can be useful when the GPU memory is limited.
-                rotated_gradient = torch.matmul(
+                per_sample_gradient = torch.matmul(
                     per_sample_gradient,
                     self._storage[ACTIVATION_EIGENVECTORS_NAME],
                 )
                 for i in range(batch_size):
                     sqrt_lambda = torch.matmul(
                         self._storage[GRADIENT_EIGENVECTORS_NAME].t(),
-                        rotated_gradient[i, :, :],
+                        per_sample_gradient[i, :, :],
                     )
                     self._storage[LAMBDA_MATRIX_NAME].add_(sqrt_lambda.square_())
             else:
-                sqrt_lambda = torch.matmul(
+                per_sample_gradient = torch.matmul(
                     self._storage[GRADIENT_EIGENVECTORS_NAME].t(),
                     torch.matmul(per_sample_gradient, self._storage[ACTIVATION_EIGENVECTORS_NAME])
                 )
@@ -458,8 +458,8 @@ def _update_lambda_matrix(self, per_sample_gradient: torch.Tensor) -> None:
                 #         self._storage[ACTIVATION_EIGENVECTORS_NAME],
                 #     ),
                 # )
-                del per_sample_gradient
-                self._storage[LAMBDA_MATRIX_NAME].add_(sqrt_lambda.square_().sum(dim=0))
+                # del per_sample_gradient
+                self._storage[LAMBDA_MATRIX_NAME].add_(per_sample_gradient.square_().sum(dim=0))
         else:
             self._storage[LAMBDA_MATRIX_NAME].add_(per_sample_gradient.square_().sum(dim=0))
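As a sanity check on these two hunks, a toy sketch (assumed shapes and names, not repository code) showing that the iterative and batched branches accumulate the same quantity, the sum over samples of (Q_g^T G_b Q_a) squared elementwise:

import torch

batch, d_out, d_in = 2, 4, 3
q_g = torch.randn(d_out, d_out)                       # gradient eigenvectors
q_a = torch.randn(d_in, d_in)                         # activation eigenvectors
per_sample_gradient = torch.randn(batch, d_out, d_in)

# Iterative branch: rotate once, then square and accumulate one sample at a time.
lambda_iterative = torch.zeros(d_out, d_in)
rotated = torch.matmul(per_sample_gradient, q_a)
for i in range(batch):
    sqrt_lambda = torch.matmul(q_g.t(), rotated[i])
    lambda_iterative.add_(sqrt_lambda.square_())

# Batched branch: one fused rotation, then square and sum over the batch dimension.
lambda_batched = torch.matmul(q_g.t(), torch.matmul(per_sample_gradient, q_a)).square_().sum(dim=0)
assert torch.allclose(lambda_iterative, lambda_batched, atol=1e-4)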

@@ -662,7 +662,7 @@ def synchronize_preconditioned_gradient(self, num_processes: int) -> None:
                     output_tensor=stacked_matrix,
                     input_tensor=self._storage[PRECONDITIONED_GRADIENT_NAME][i].contiguous(),
                 )
-                self._storage[PRECONDITIONED_GRADIENT_NAME][i] = stacked_matrix.transpose(0, 1).reshape(
+                self._storage[PRECONDITIONED_GRADIENT_NAME][i] = stacked_matrix.transpose(0, 1).view(
                     num_processes * size[0], size[1], size[2]
                 )

@@ -677,7 +677,7 @@ def synchronize_preconditioned_gradient(self, num_processes: int) -> None:
                 output_tensor=stacked_preconditioned_gradient,
                 input_tensor=self._storage[PRECONDITIONED_GRADIENT_NAME].contiguous(),
             )
-            self._storage[PRECONDITIONED_GRADIENT_NAME] = stacked_preconditioned_gradient.transpose(0, 1).reshape(
+            self._storage[PRECONDITIONED_GRADIENT_NAME] = stacked_preconditioned_gradient.transpose(0, 1).view(
                 num_processes * size[0], size[1], size[2]
             )

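The commit's subject is the reshape-to-view swap in the two hunks above. As a general reminder (a toy sketch, not repository code): Tensor.view never copies and therefore requires the tensor's memory layout to be compatible with the requested shape, while Tensor.reshape silently falls back to a copy when it is not.

import torch

x = torch.arange(24).reshape(2, 3, 4)   # contiguous
y = x.view(6, 4)                        # no copy; shares storage with x
assert y.data_ptr() == x.data_ptr()

t = x.transpose(0, 1)                   # non-contiguous view, shape (3, 2, 4)
z = t.reshape(6, 4)                     # works, but materializes a copy here
assert z.data_ptr() != t.data_ptr()
try:
    t.view(6, 4)                        # raises: size/stride not compatible with view
except RuntimeError as err:
    print("view failed:", err)

In other words, view is the stricter, explicitly no-copy operation; whether a transposed tensor satisfies its layout requirement depends on the shapes and strides involved.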
