
Lint fix
pomonam committed Jun 29, 2024
1 parent 8e5eae3 commit d1aecdc
Showing 5 changed files with 13 additions and 12 deletions.
2 changes: 0 additions & 2 deletions kronfluence/computer/score_computer.py
@@ -8,8 +8,6 @@
 
 from kronfluence.arguments import FactorArguments, ScoreArguments
 from kronfluence.computer.computer import Computer
-from kronfluence.module.tracked_module import ModuleMode
-from kronfluence.module.utils import set_mode
 from kronfluence.score.pairwise import (
     compute_pairwise_scores_with_loaders,
     load_pairwise_scores,
2 changes: 2 additions & 0 deletions kronfluence/module/conv2d.py
@@ -69,10 +69,12 @@ class TrackedConv2d(TrackedModule, module_type=nn.Conv2d):
 
     @property
     def weight(self) -> torch.Tensor:
+        """Returns the weight matrix."""
         return self.original_module.weight
 
     @property
     def bias(self) -> torch.Tensor:
+        """Returns the bias."""
         return self.original_module.bias
 
     def _get_flattened_activation(
2 changes: 2 additions & 0 deletions kronfluence/module/linear.py
@@ -13,10 +13,12 @@ class TrackedLinear(TrackedModule, module_type=nn.Linear):
 
     @property
     def weight(self) -> torch.Tensor:
+        """Returns the weight matrix."""
         return self.original_module.weight
 
     @property
     def bias(self) -> torch.Tensor:
+        """Returns the bias."""
         return self.original_module.bias
 
     def _get_flattened_activation(
17 changes: 8 additions & 9 deletions kronfluence/module/tracked_module.py
@@ -571,15 +571,14 @@ def _compute_low_rank_preconditioned_gradient(
                 torch.matmul(U_k, torch.diag_embed(S_k)).to(dtype=self.score_args.score_dtype).contiguous().clone(),
                 V_k.to(dtype=self.score_args.score_dtype),
             ]
-        else:
-            U, S, V = torch.svd_lowrank(
-                preconditioned_gradient.contiguous().to(dtype=self.score_args.query_gradient_svd_dtype),
-                q=rank,
-            )
-            return [
-                torch.matmul(U, torch.diag_embed(S)).to(dtype=self.score_args.score_dtype).contiguous().clone(),
-                V.transpose(1, 2).contiguous().to(dtype=self.score_args.score_dtype),
-            ]
+        U, S, V = torch.svd_lowrank(
+            preconditioned_gradient.contiguous().to(dtype=self.score_args.query_gradient_svd_dtype),
+            q=rank,
+        )
+        return [
+            torch.matmul(U, torch.diag_embed(S)).to(dtype=self.score_args.score_dtype).contiguous().clone(),
+            V.transpose(1, 2).contiguous().to(dtype=self.score_args.score_dtype),
+        ]
 
     def _compute_preconditioned_gradient(self, per_sample_gradient: torch.Tensor) -> None:
         """Computes the preconditioned per-sample-gradient.
2 changes: 1 addition & 1 deletion tests/utils.py
@@ -59,7 +59,7 @@ def custom_scores_name(name: str) -> str:
 def prepare_model_and_analyzer(model: nn.Module, task: Task) -> Tuple[nn.Module, Analyzer]:
     model = prepare_model(model=model, task=task)
     analyzer = Analyzer(
-        analysis_name=f"pytest",
+        analysis_name="pytest",
         model=model,
         task=task,
         disable_model_save=True,
