Skip to content

Commit

Permalink
Use einsum
Browse files Browse the repository at this point in the history
  • Loading branch information
pomonam committed Jul 9, 2024
1 parent c757145 commit 77c0e3a
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 30 deletions.
25 changes: 13 additions & 12 deletions kronfluence/module/conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,22 +193,23 @@ def compute_pairwise_score(
input_activation.shape,
optimize=DynamicProgramming(search_outer=True, minimize="size"),
)
return self.einsum_expression(left_mat, right_mat, output_gradient, input_activation)

if self.einsum_expression is None:
self.einsum_expression = contract_expression(
"qio,bti,bto->qb",
preconditioned_gradient.shape,
output_gradient.shape,
input_activation.shape,
optimize=DynamicProgramming(search_outer=True, minimize="flops"),
)
return self.einsum_expression(preconditioned_gradient, output_gradient, input_activation)
return self.einsum_expression(left_mat, right_mat, output_gradient, input_activation).contiguous()
return torch.einsum("qio,bti,bto->qb", preconditioned_gradient, output_gradient, input_activation)

# if self.einsum_expression is None:
# self.einsum_expression = contract_expression(
# "qio,bti,bto->qb",
# preconditioned_gradient.shape,
# output_gradient.shape,
# input_activation.shape,
# optimize=DynamicProgramming(search_outer=True, minimize="flops"),
# )
# return self.einsum_expression(preconditioned_gradient, output_gradient, input_activation)

def compute_self_measurement_score(
self, preconditioned_gradient: torch.Tensor, input_activation: torch.Tensor, output_gradient: torch.Tensor
) -> torch.Tensor:
input_activation = self._flatten_input_activation(input_activation=input_activation)
input_activation = input_activation.view(output_gradient.size(0), -1, input_activation.size(-1))
output_gradient = rearrange(tensor=output_gradient, pattern="b o i1 i2 -> b (i1 i2) o")
return contract("bio,bci,bco->b", preconditioned_gradient, output_gradient, input_activation)
return contract("bio,bci,bco->b", preconditioned_gradient, output_gradient, input_activation).contiguous()
35 changes: 17 additions & 18 deletions kronfluence/module/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,26 +95,25 @@ def compute_pairwise_score(
input_activation.shape,
optimize=DynamicProgramming(search_outer=True, minimize="size"),
)
return self.einsum_expression(left_mat, right_mat, output_gradient, input_activation)

if self.einsum_expression is None:
if self.score_args.compute_per_token_scores and len(input_activation.shape) == 3:
expr = "qio,bti,bto->qbt"
minimize = "size"
else:
expr = "qio,b...i,b...o->qb"
minimize = "flops"
self.einsum_expression = contract_expression(
expr,
preconditioned_gradient.shape,
output_gradient.shape,
input_activation.shape,
optimize=DynamicProgramming(search_outer=True, minimize=minimize),
)
return self.einsum_expression(preconditioned_gradient, output_gradient, input_activation)
return self.einsum_expression(left_mat, right_mat, output_gradient, input_activation).contiguous()

if self.score_args.compute_per_token_scores and len(input_activation.shape) == 3:
return torch.einsum("qio,bti,bto->qbt", preconditioned_gradient, output_gradient, input_activation)
return torch.einsum("qio,b...i,b...o->qb", preconditioned_gradient, output_gradient, input_activation)
# else:
# expr = "qio,b...i,b...o->qb"
# minimize = "flops"
# self.einsum_expression = contract_expression(
# expr,
# preconditioned_gradient.shape,
# output_gradient.shape,
# input_activation.shape,
# optimize=DynamicProgramming(search_outer=True, minimize=minimize),
# )
# return self.einsum_expression(preconditioned_gradient, output_gradient, input_activation)

def compute_self_measurement_score(
self, preconditioned_gradient: torch.Tensor, input_activation: torch.Tensor, output_gradient: torch.Tensor
) -> torch.Tensor:
input_activation = self._flatten_input_activation(input_activation=input_activation)
return contract("bio,b...i,b...o->b", preconditioned_gradient, output_gradient, input_activation)
return contract("bio,b...i,b...o->b", preconditioned_gradient, output_gradient, input_activation).contiguous()

0 comments on commit 77c0e3a

Please sign in to comment.