Commit
Change to torch einsum
pomonam committed Jul 9, 2024
1 parent 9744252 commit c757145
Showing 2 changed files with 5 additions and 5 deletions.
kronfluence/module/conv2d.py (4 changes: 2 additions & 2 deletions)
@@ -158,7 +158,7 @@ def compute_summed_gradient(self, input_activation: torch.Tensor, output_gradient: torch.Tensor) -> torch.Tensor:
         input_activation = self._flatten_input_activation(input_activation=input_activation)
         input_activation = input_activation.view(output_gradient.size(0), -1, input_activation.size(-1))
         output_gradient = rearrange(tensor=output_gradient, pattern="b o i1 i2 -> b (i1 i2) o")
-        summed_gradient = contract("bci,bco->io", output_gradient, input_activation).unsqueeze_(dim=0)
+        summed_gradient = torch.einsum("bci,bco->io", output_gradient, input_activation).unsqueeze_(dim=0)
         return summed_gradient

     def compute_per_sample_gradient(
@@ -169,7 +169,7 @@ def compute_per_sample_gradient(
         input_activation = self._flatten_input_activation(input_activation=input_activation)
         input_activation = input_activation.view(output_gradient.size(0), -1, input_activation.size(-1))
         output_gradient = rearrange(tensor=output_gradient, pattern="b o i1 i2 -> b (i1 i2) o")
-        per_sample_gradient = contract("bci,bco->bio", output_gradient, input_activation)
+        per_sample_gradient = torch.einsum("bci,bco->bio", output_gradient, input_activation)
         if self.per_sample_gradient_process_fnc is not None:
             per_sample_gradient = self.per_sample_gradient_process_fnc(
                 module_name=self.name, gradient=per_sample_gradient
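Note: the swap above is intended to be behavior-preserving; opt_einsum's contract and torch.einsum evaluate the same contraction for these subscripts. A minimal sketch with made-up shapes (all sizes are illustrative, not the module's real dimensions):

import torch
from opt_einsum import contract

# Hypothetical sizes: batch b, spatial positions c, and the two channel-like dims i and o.
b, c, i, o = 4, 9, 8, 27
output_gradient = torch.randn(b, c, i)
input_activation = torch.randn(b, c, o)

# The old call (opt_einsum) and the new call (torch.einsum) compute the same per-sample contraction.
old = contract("bci,bco->bio", output_gradient, input_activation)
new = torch.einsum("bci,bco->bio", output_gradient, input_activation)
assert torch.allclose(old, new, atol=1e-6)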
kronfluence/module/linear.py (6 changes: 3 additions & 3 deletions)
@@ -62,14 +62,14 @@ def _flatten_input_activation(self, input_activation: torch.Tensor) -> torch.Tensor:

     def compute_summed_gradient(self, input_activation: torch.Tensor, output_gradient: torch.Tensor) -> torch.Tensor:
         input_activation = self._flatten_input_activation(input_activation=input_activation)
-        summed_gradient = contract("b...i,b...o->io", output_gradient, input_activation).unsqueeze_(dim=0)
+        summed_gradient = torch.einsum("b...i,b...o->io", output_gradient, input_activation).unsqueeze_(dim=0)
         return summed_gradient

     def compute_per_sample_gradient(
         self, input_activation: torch.Tensor, output_gradient: torch.Tensor
     ) -> torch.Tensor:
         input_activation = self._flatten_input_activation(input_activation=input_activation)
-        per_sample_gradient = contract("b...i,b...o->bio", output_gradient, input_activation)
+        per_sample_gradient = torch.einsum("b...i,b...o->bio", output_gradient, input_activation)
         if self.per_sample_gradient_process_fnc is not None:
             per_sample_gradient = self.per_sample_gradient_process_fnc(
                 module_name=self.name, gradient=per_sample_gradient
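The linear expressions use an ellipsis to absorb any extra leading dimensions (for example a token dimension). Unlike NumPy's einsum, torch.einsum sums over the dimensions covered by "..." when the ellipsis does not appear in the output, so the replacement keeps the same reduction that contract performed. A small sketch with hypothetical shapes:

import torch

# Hypothetical sizes: batch b, a sequence dimension t covered by "...", and dims i and o.
b, t, i, o = 4, 5, 8, 16
output_gradient = torch.randn(b, t, i)
input_activation = torch.randn(b, t, o)

# "..." dims absent from the output are summed over, matching an explicit reduction over t.
per_sample = torch.einsum("b...i,b...o->bio", output_gradient, input_activation)
reference = torch.einsum("bti,bto->bio", output_gradient, input_activation)
assert torch.allclose(per_sample, reference, atol=1e-6)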
@@ -103,7 +103,7 @@ def compute_pairwise_score(
             minimize = "size"
         else:
             expr = "qio,b...i,b...o->qb"
-            minimize = "size"
+            minimize = "flops"
         self.einsum_expression = contract_expression(
             expr,
             preconditioned_gradient.shape,
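The remaining change switches the pairwise-score expression from a contraction-path search that minimizes the largest intermediate ("size") to one that minimizes estimated FLOPs ("flops"). The diff does not show how the minimize string is forwarded to opt_einsum, so the DynamicProgramming wiring below is an assumption, and all shapes are hypothetical:

import torch
from opt_einsum import contract_expression
from opt_einsum.paths import DynamicProgramming

# Hypothetical sizes: q query gradients, batch b, and dims i and o.
q, b, i, o = 4, 8, 16, 32
preconditioned_gradient = torch.randn(q, i, o)
output_gradient = torch.randn(b, i)
input_activation = torch.randn(b, o)

# Assumption: minimize="flops" is handed to a dynamic-programming path optimizer, which picks
# the contraction order with the lowest estimated FLOP count instead of the smallest intermediate.
expression = contract_expression(
    "qio,b...i,b...o->qb",
    preconditioned_gradient.shape,
    output_gradient.shape,
    input_activation.shape,
    optimize=DynamicProgramming(minimize="flops"),
)
pairwise_scores = expression(preconditioned_gradient, output_gradient, input_activation)  # shape (q, b)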
