Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: missing neg in rsqrt derivative #233

Merged
merged 1 commit into from
Jan 11, 2025
Merged

Conversation

avik-pal
Copy link
Collaborator

No description provided.

@Pangoraw
Copy link
Collaborator

The corresponding test needs to be updated:

// REVERSE: func.func @main(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
// REVERSE-NEXT: %cst = stablehlo.constant dense<2.000000e+00> : tensor<2xf32>
// REVERSE-NEXT: %cst_0 = arith.constant dense<0.000000e+00> : tensor<2xf32>
// REVERSE-NEXT: %0 = arith.addf %arg1, %cst_0 : tensor<2xf32>
// REVERSE-NEXT: %1 = stablehlo.sqrt %arg0 : tensor<2xf32>
// REVERSE-NEXT: %2 = stablehlo.multiply %arg0, %1 : tensor<2xf32>
// REVERSE-NEXT: %3 = stablehlo.multiply %cst, %2 : tensor<2xf32>
// REVERSE-NEXT: %4 = stablehlo.divide %0, %3 : tensor<2xf32>
// REVERSE-NEXT: %5 = arith.addf %4, %cst_0 : tensor<2xf32>
// REVERSE-NEXT: return %5 : tensor<2xf32>
// REVERSE-NEXT: }

@@ -1107,7 +1107,7 @@ def : HLOInactiveOp<"RngBitGeneratorOp">;
def : HLODerivative<"RsqrtOp", (Op $x),
[
// (Select (FCmpUEQ $x, (ConstantFP<"0"> $x)), (ConstantFP<"0"> $x), (FDiv (DiffeRet), (FMul (ConstantFP<"2"> $x), (Call<(SameFunc), [ReadNone,NoUnwind]> $x))))
(Div (DiffeRet), (Mul (HLOConstantFP<"2"> $x), (Mul $x, (Sqrt $x))))
(Neg (Div (DiffeRet), (Mul (HLOConstantFP<"2"> $x), (Mul $x, (Sqrt $x)))))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor but can you put the neg on the mul and not on the div? This will enable better constant propagation (since often times differrt will be 1 or 0 or a phi thereof so then it simplifies nicely)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also actually why not just make the constant 2 into -2?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This also only changes one AD mode, presumably you’ll need to change the 2 to be a -2 in the line above as well

@avik-pal avik-pal force-pushed the ap/fix_rsqrt_derivative branch from 547f9b8 to c592ee0 Compare January 11, 2025 18:24
@wsmoses wsmoses merged commit 66cac35 into main Jan 11, 2025
6 of 11 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants