From 77c0e3a7ba2e405bb6e280760e1acae33e0572d3 Mon Sep 17 00:00:00 2001
From: Juhan Bae
Date: Tue, 9 Jul 2024 03:02:47 -0400
Subject: [PATCH] Use einsum

---
 kronfluence/module/conv2d.py | 25 +++++++++++++------------
 kronfluence/module/linear.py | 35 +++++++++++++++++------------------
 2 files changed, 30 insertions(+), 30 deletions(-)

diff --git a/kronfluence/module/conv2d.py b/kronfluence/module/conv2d.py
index 4c6f632..7f15a0f 100644
--- a/kronfluence/module/conv2d.py
+++ b/kronfluence/module/conv2d.py
@@ -193,17 +193,18 @@ def compute_pairwise_score(
                     input_activation.shape,
                     optimize=DynamicProgramming(search_outer=True, minimize="size"),
                 )
-            return self.einsum_expression(left_mat, right_mat, output_gradient, input_activation)
-
-        if self.einsum_expression is None:
-            self.einsum_expression = contract_expression(
-                "qio,bti,bto->qb",
-                preconditioned_gradient.shape,
-                output_gradient.shape,
-                input_activation.shape,
-                optimize=DynamicProgramming(search_outer=True, minimize="flops"),
-            )
-        return self.einsum_expression(preconditioned_gradient, output_gradient, input_activation)
+            return self.einsum_expression(left_mat, right_mat, output_gradient, input_activation).contiguous()
+        return torch.einsum("qio,bti,bto->qb", preconditioned_gradient, output_gradient, input_activation)
+
+        # if self.einsum_expression is None:
+        #     self.einsum_expression = contract_expression(
+        #         "qio,bti,bto->qb",
+        #         preconditioned_gradient.shape,
+        #         output_gradient.shape,
+        #         input_activation.shape,
+        #         optimize=DynamicProgramming(search_outer=True, minimize="flops"),
+        #     )
+        # return self.einsum_expression(preconditioned_gradient, output_gradient, input_activation)
 
     def compute_self_measurement_score(
         self, preconditioned_gradient: torch.Tensor, input_activation: torch.Tensor, output_gradient: torch.Tensor
@@ -211,4 +212,4 @@
         input_activation = self._flatten_input_activation(input_activation=input_activation)
         input_activation = input_activation.view(output_gradient.size(0), -1, input_activation.size(-1))
         output_gradient = rearrange(tensor=output_gradient, pattern="b o i1 i2 -> b (i1 i2) o")
-        return contract("bio,bci,bco->b", preconditioned_gradient, output_gradient, input_activation)
+        return contract("bio,bci,bco->b", preconditioned_gradient, output_gradient, input_activation).contiguous()
diff --git a/kronfluence/module/linear.py b/kronfluence/module/linear.py
index 8f6d2c1..171873f 100644
--- a/kronfluence/module/linear.py
+++ b/kronfluence/module/linear.py
@@ -95,26 +95,25 @@ def compute_pairwise_score(
                     input_activation.shape,
                     optimize=DynamicProgramming(search_outer=True, minimize="size"),
                 )
-            return self.einsum_expression(left_mat, right_mat, output_gradient, input_activation)
-
-        if self.einsum_expression is None:
-            if self.score_args.compute_per_token_scores and len(input_activation.shape) == 3:
-                expr = "qio,bti,bto->qbt"
-                minimize = "size"
-            else:
-                expr = "qio,b...i,b...o->qb"
-                minimize = "flops"
-            self.einsum_expression = contract_expression(
-                expr,
-                preconditioned_gradient.shape,
-                output_gradient.shape,
-                input_activation.shape,
-                optimize=DynamicProgramming(search_outer=True, minimize=minimize),
-            )
-        return self.einsum_expression(preconditioned_gradient, output_gradient, input_activation)
+            return self.einsum_expression(left_mat, right_mat, output_gradient, input_activation).contiguous()
+
+        if self.score_args.compute_per_token_scores and len(input_activation.shape) == 3:
+            return torch.einsum("qio,bti,bto->qbt", preconditioned_gradient, output_gradient, input_activation)
+        return torch.einsum("qio,b...i,b...o->qb", preconditioned_gradient, output_gradient, input_activation)
+        # else:
+        #     expr = "qio,b...i,b...o->qb"
+        #     minimize = "flops"
+        #     self.einsum_expression = contract_expression(
+        #         expr,
+        #         preconditioned_gradient.shape,
+        #         output_gradient.shape,
+        #         input_activation.shape,
+        #         optimize=DynamicProgramming(search_outer=True, minimize=minimize),
+        #     )
+        # return self.einsum_expression(preconditioned_gradient, output_gradient, input_activation)
 
     def compute_self_measurement_score(
         self, preconditioned_gradient: torch.Tensor, input_activation: torch.Tensor, output_gradient: torch.Tensor
     ) -> torch.Tensor:
         input_activation = self._flatten_input_activation(input_activation=input_activation)
-        return contract("bio,b...i,b...o->b", preconditioned_gradient, output_gradient, input_activation)
+        return contract("bio,b...i,b...o->b", preconditioned_gradient, output_gradient, input_activation).contiguous()
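
Below is a minimal, self-contained sketch (not part of the patch) of the "qio,bti,bto->qb" contraction that the patched compute_pairwise_score now evaluates directly with torch.einsum instead of a cached opt_einsum contract_expression. The tensor names follow the arguments in the diff, but the dimension sizes are illustrative assumptions only.

import torch

# Dimensions are named after the einsum subscripts; sizes are arbitrary for illustration.
q, b, t, i, o = 4, 8, 16, 32, 64

preconditioned_gradient = torch.randn(q, i, o)  # "qio"
output_gradient = torch.randn(b, t, i)          # "bti"
input_activation = torch.randn(b, t, o)         # "bto"

# One pairwise score per (query, train) example pair; i, o, and t are summed out.
scores = torch.einsum("qio,bti,bto->qb", preconditioned_gradient, output_gradient, input_activation)
print(scores.shape)  # torch.Size([4, 8])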