From ee330c37265217ea72676f640d87579eacfebf26 Mon Sep 17 00:00:00 2001
From: Juhan Bae
Date: Wed, 13 Mar 2024 02:04:03 -0400
Subject: [PATCH] Use view instead of reshape

---
 kronfluence/factor/config.py         |  5 +++--
 kronfluence/module/tracked_module.py | 16 ++++++++--------
 2 files changed, 11 insertions(+), 10 deletions(-)

diff --git a/kronfluence/factor/config.py b/kronfluence/factor/config.py
index e3d315c..070f2e6 100644
--- a/kronfluence/factor/config.py
+++ b/kronfluence/factor/config.py
@@ -319,7 +319,8 @@ def precondition_gradient(
             damping = 0.1 * torch.mean(lambda_matrix)
         rotated_gradient.div_(lambda_matrix + damping)
-        return num_lambda_processed * torch.einsum(
+        return (num_lambda_processed *
+            torch.einsum(
             "ij,bjl,lk->bik",
             (gradient_eigenvectors, rotated_gradient, activation_eigenvectors.t()),
-        )
+        ))

diff --git a/kronfluence/module/tracked_module.py b/kronfluence/module/tracked_module.py
index 90b9c6e..cddf9a0 100644
--- a/kronfluence/module/tracked_module.py
+++ b/kronfluence/module/tracked_module.py
@@ -428,25 +428,25 @@ def _update_lambda_matrix(self, per_sample_gradient: torch.Tensor) -> None:
             self._storage[NUM_LAMBDA_PROCESSED] = torch.zeros(
                 size=(1,),
                 dtype=torch.int64,
-                device=per_sample_gradient.device,
+                # device=per_sample_gradient.device,
                 requires_grad=False,
             )

         if FactorConfig.CONFIGS[self.factor_args.strategy].requires_eigendecomposition_for_lambda:
             if self.factor_args.lambda_iterative_aggregate:
                 # This batch-wise iterative update can be useful when the GPU memory is limited.
-                rotated_gradient = torch.matmul(
+                per_sample_gradient = torch.matmul(
                     per_sample_gradient,
                     self._storage[ACTIVATION_EIGENVECTORS_NAME],
                 )
                 for i in range(batch_size):
                     sqrt_lambda = torch.matmul(
                         self._storage[GRADIENT_EIGENVECTORS_NAME].t(),
-                        rotated_gradient[i, :, :],
+                        per_sample_gradient[i, :, :],
                     )
                     self._storage[LAMBDA_MATRIX_NAME].add_(sqrt_lambda.square_())
             else:
-                sqrt_lambda = torch.matmul(
+                per_sample_gradient = torch.matmul(
                     self._storage[GRADIENT_EIGENVECTORS_NAME].t(),
                     torch.matmul(per_sample_gradient, self._storage[ACTIVATION_EIGENVECTORS_NAME])
                 )
@@ -458,8 +458,8 @@ def _update_lambda_matrix(self, per_sample_gradient: torch.Tensor) -> None:
                 # self._storage[ACTIVATION_EIGENVECTORS_NAME],
                 # ),
                 # )
-                del per_sample_gradient
-                self._storage[LAMBDA_MATRIX_NAME].add_(sqrt_lambda.square_().sum(dim=0))
+                # del per_sample_gradient
+                self._storage[LAMBDA_MATRIX_NAME].add_(per_sample_gradient.square_().sum(dim=0))
         else:
             self._storage[LAMBDA_MATRIX_NAME].add_(per_sample_gradient.square_().sum(dim=0))

@@ -662,7 +662,7 @@ def synchronize_preconditioned_gradient(self, num_processes: int) -> None:
                 output_tensor=stacked_matrix,
                 input_tensor=self._storage[PRECONDITIONED_GRADIENT_NAME][i].contiguous(),
             )
-            self._storage[PRECONDITIONED_GRADIENT_NAME][i] = stacked_matrix.transpose(0, 1).reshape(
+            self._storage[PRECONDITIONED_GRADIENT_NAME][i] = stacked_matrix.transpose(0, 1).view(
                 num_processes * size[0], size[1], size[2]
             )

@@ -677,7 +677,7 @@ def synchronize_preconditioned_gradient(self, num_processes: int) -> None:
                 output_tensor=stacked_preconditioned_gradient,
                 input_tensor=self._storage[PRECONDITIONED_GRADIENT_NAME].contiguous(),
             )
-            self._storage[PRECONDITIONED_GRADIENT_NAME] = stacked_preconditioned_gradient.transpose(0, 1).reshape(
+            self._storage[PRECONDITIONED_GRADIENT_NAME] = stacked_preconditioned_gradient.transpose(0, 1).view(
                 num_processes * size[0], size[1], size[2]
             )
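
The subject line leans on a standard PyTorch contract that the diff itself does not spell out: Tensor.view never copies data and raises a RuntimeError when the requested shape cannot be expressed over the tensor's existing strides, whereas Tensor.reshape silently falls back to a copy in that case. A minimal sketch of the difference, using illustrative shapes only (not code from kronfluence):

    import torch

    x = torch.arange(24).view(2, 3, 4)   # contiguous base tensor
    t = x.transpose(0, 1)                # shape (3, 2, 4), non-contiguous view

    copied = t.reshape(6, 4)             # succeeds, but materializes a hidden copy

    try:
        t.view(6, 4)                     # fails: the permuted dims cannot be merged without copying
    except RuntimeError as err:
        print("view refused to copy:", err)

    merged = t.contiguous().view(6, 4)   # explicit copy first, then a genuine no-copy view

If view raises for one of the gathered preconditioned-gradient shapes, the transpose has produced a layout that genuinely requires a copy; reshape would have paid that cost silently, while view surfaces it immediately.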