
Commit 1919097
Fix linting
pomonam committed Jul 9, 2024
1 parent a2efa47 commit 1919097
Showing 5 changed files with 20 additions and 25 deletions.
24 changes: 12 additions & 12 deletions kronfluence/factor/config.py
@@ -248,12 +248,12 @@ def requires_lambda_matrices_for_precondition(self) -> bool:
         return False

     def prepare(self, storage: STORAGE_TYPE, score_args: Any, device: torch.device) -> None:
-        storage[ACTIVATION_EIGENVECTORS_NAME] = storage[ACTIVATION_EIGENVECTORS_NAME].to(
-            dtype=score_args.precondition_dtype
-        ).contiguous()
-        storage[GRADIENT_EIGENVECTORS_NAME] = storage[GRADIENT_EIGENVECTORS_NAME].to(
-            dtype=score_args.precondition_dtype
-        ).contiguous()
+        storage[ACTIVATION_EIGENVECTORS_NAME] = (
+            storage[ACTIVATION_EIGENVECTORS_NAME].to(dtype=score_args.precondition_dtype).contiguous()
+        )
+        storage[GRADIENT_EIGENVECTORS_NAME] = (
+            storage[GRADIENT_EIGENVECTORS_NAME].to(dtype=score_args.precondition_dtype).contiguous()
+        )
         activation_eigenvalues = storage[ACTIVATION_EIGENVALUES_NAME].to(device=device)
         gradient_eigenvalues = storage[GRADIENT_EIGENVALUES_NAME].to(device=device)
         lambda_matrix = torch.kron(activation_eigenvalues.unsqueeze(0), gradient_eigenvalues.unsqueeze(-1)).unsqueeze(0)
@@ -316,12 +316,12 @@ def requires_lambda_matrices_for_precondition(self) -> bool:
         return True

     def prepare(self, storage: STORAGE_TYPE, score_args: Any, device: torch.device) -> None:
-        storage[ACTIVATION_EIGENVECTORS_NAME] = storage[ACTIVATION_EIGENVECTORS_NAME].to(
-            dtype=score_args.precondition_dtype
-        ).contiguous()
-        storage[GRADIENT_EIGENVECTORS_NAME] = storage[GRADIENT_EIGENVECTORS_NAME].to(
-            dtype=score_args.precondition_dtype
-        ).contiguous()
+        storage[ACTIVATION_EIGENVECTORS_NAME] = (
+            storage[ACTIVATION_EIGENVECTORS_NAME].to(dtype=score_args.precondition_dtype).contiguous()
+        )
+        storage[GRADIENT_EIGENVECTORS_NAME] = (
+            storage[GRADIENT_EIGENVECTORS_NAME].to(dtype=score_args.precondition_dtype).contiguous()
+        )
         storage[ACTIVATION_EIGENVALUES_NAME] = None
         storage[GRADIENT_EIGENVALUES_NAME] = None
         lambda_matrix = storage[LAMBDA_MATRIX_NAME].to(device=device)
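A note on the torch.kron line kept as context at the end of the first hunk: kron of the activation eigenvalues reshaped to a (1, d_in) row and the gradient eigenvalues reshaped to a (d_out, 1) column is just the outer product of the two vectors, i.e. every pairwise product of an activation eigenvalue with a gradient eigenvalue (the eigenvalues of a Kronecker product are exactly these products). A minimal sketch of that identity, with made-up sizes rather than anything taken from the repository:

import torch

# Hypothetical eigenvalue vectors: d_in = 4 input dimensions, d_out = 3 output dimensions.
activation_eigenvalues = torch.rand(4)
gradient_eigenvalues = torch.rand(3)

# The pattern from prepare(): kron of a (1, d_in) row with a (d_out, 1) column.
lambda_matrix = torch.kron(activation_eigenvalues.unsqueeze(0), gradient_eigenvalues.unsqueeze(-1))

# It equals the plain outer product gradient_eigenvalues[i] * activation_eigenvalues[j].
assert torch.allclose(lambda_matrix, torch.outer(gradient_eigenvalues, activation_eigenvalues))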
7 changes: 2 additions & 5 deletions kronfluence/module/linear.py
@@ -79,8 +79,8 @@ def compute_per_sample_gradient(
     def compute_pairwise_score(
         self, preconditioned_gradient: torch.Tensor, input_activation: torch.Tensor, output_gradient: torch.Tensor
     ) -> torch.Tensor:
+        input_activation = self._flatten_input_activation(input_activation=input_activation)
         if isinstance(preconditioned_gradient, list):
-            input_activation = self._flatten_input_activation(input_activation=input_activation)
             left_mat, right_mat = preconditioned_gradient
             if self.einsum_expression is None:
                 if self.score_args.compute_per_token_scores and len(input_activation.shape) == 3:
@@ -97,11 +97,8 @@
                     )
             return self.einsum_expression(left_mat, right_mat, output_gradient, input_activation).contiguous()
         if self.score_args.compute_per_token_scores and len(input_activation.shape) == 3:
-            input_activation = self._flatten_input_activation(input_activation=input_activation)
             return torch.einsum("qio,bti,bto->qbt", preconditioned_gradient, output_gradient, input_activation)
-        gradient = self.compute_per_sample_gradient(input_activation=input_activation, output_gradient=output_gradient)
-        # return torch.einsum("qio,b...i,b...o->qb", preconditioned_gradient, output_gradient, input_activation)
-        return torch.matmul(preconditioned_gradient.view(preconditioned_gradient.size(0), -1), gradient.view(gradient.size(0), -1).T)
+        return torch.einsum("qio,b...i,b...o->qb", preconditioned_gradient, output_gradient, input_activation)

     def compute_self_measurement_score(
         self, preconditioned_gradient: torch.Tensor, input_activation: torch.Tensor, output_gradient: torch.Tensor
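The last change in this file replaces the materialize-then-matmul pairwise score with a single einsum. Both compute the same (query, train) score matrix: the matmul path first builds each per-sample gradient by contracting output gradients with input activations over the token dimension and then takes an inner product with each preconditioned query gradient, while the einsum does the whole contraction at once without materializing the per-sample gradients. A small numerical check under assumed shapes (the names, sizes, and index ordering below are illustrative, not the module's actual tensors):

import torch

# Hypothetical sizes: q query gradients, b train samples, t tokens, i/o feature dimensions.
q, b, t, i, o = 3, 4, 5, 6, 7
preconditioned_gradient = torch.randn(q, i, o)
output_gradient = torch.randn(b, t, i)
input_activation = torch.randn(b, t, o)

# Old path: materialize per-sample gradients (b, i, o) by summing over tokens, then matmul.
gradient = torch.einsum("bti,bto->bio", output_gradient, input_activation)
via_matmul = torch.matmul(preconditioned_gradient.view(q, -1), gradient.view(b, -1).T)

# New path: one einsum that contracts over tokens and both feature dimensions directly.
via_einsum = torch.einsum("qio,b...i,b...o->qb", preconditioned_gradient, output_gradient, input_activation)

assert torch.allclose(via_matmul, via_einsum, atol=1e-5)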
3 changes: 2 additions & 1 deletion tests/gpu_tests/fsdp_test.py
@@ -12,8 +12,9 @@
 from kronfluence.utils.common.factor_arguments import pytest_factor_arguments
 from kronfluence.utils.common.score_arguments import pytest_score_arguments
 from kronfluence.utils.constants import (
+    ALL_MODULE_NAME,
     COVARIANCE_FACTOR_NAMES,
-    LAMBDA_FACTOR_NAMES, ALL_MODULE_NAME,
+    LAMBDA_FACTOR_NAMES,
 )
 from kronfluence.utils.model import apply_fsdp
 from tests.gpu_tests.pipeline import GpuTestTask, construct_test_mlp, get_mnist_dataset
3 changes: 2 additions & 1 deletion tests/modules/test_matmul.py
@@ -1,9 +1,10 @@
+import time
+
 import opt_einsum
 import pytest
 import torch
 from accelerate.utils import set_seed
 from opt_einsum import DynamicProgramming
-import time


 def test_query_gradient_svd(
8 changes: 2 additions & 6 deletions tests/scores/test_pairwise_scores.py
@@ -651,9 +651,7 @@ def test_query_accumulation_steps(

 @pytest.mark.parametrize(
     "test_name",
-    [
-        "mlp", "conv"
-    ],
+    ["mlp", "conv"],
 )
 @pytest.mark.parametrize("query_size", [50])
 @pytest.mark.parametrize("train_size", [32])
@@ -739,9 +737,7 @@ def test_query_gradient_aggregation(

 @pytest.mark.parametrize(
     "test_name",
-    [
-        "mlp", "conv"
-    ],
+    ["mlp", "conv"],
 )
 @pytest.mark.parametrize("query_size", [64])
 @pytest.mark.parametrize("train_size", [32])
