From e6dd644f493809c0a8827c9ea916dea0a6a5962d Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Mon, 16 Dec 2024 18:19:00 -0600 Subject: [PATCH] fixup --- src/enzyme_ad/jax/Implementations/HLODerivatives.td | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/enzyme_ad/jax/Implementations/HLODerivatives.td b/src/enzyme_ad/jax/Implementations/HLODerivatives.td index da8ac71ed..9040add94 100644 --- a/src/enzyme_ad/jax/Implementations/HLODerivatives.td +++ b/src/enzyme_ad/jax/Implementations/HLODerivatives.td @@ -893,7 +893,7 @@ def FftLength : GlobalExprgetResult(0).getType().cast(); - auto lengths = op.getFftLength(); + auto lengths = op.getFftLengthAttr().getValues(); auto N = std::accumulate(lengths.begin(), lengths.end(), llvm::APInt(64, 1, true), std::multiplies{}).getSExtValue(); double value = N; @@ -919,15 +919,15 @@ def FftMultiplier : GlobalExpr(op.getLoc(), SplatElementsAttr::get( RT, FloatAttr::get(resTy.getElementType(), 0))); auto end_constant = builder.create(op.getLoc(), SplatElementsAttr::get( - RT, FloatAttr::get(resTy.getElementType(), lengths.back()-1))); + RT, FloatAttr::get(resTy.getElementType(), lengths[lengths.size()-1]-1))); auto RT64 = RankedTensorType::get({1}, builder.getIntegerType(64)); Value start[] = { - builder.create(op.getLoc(), SplatElementsAttr::get(RT64, op.getI64IntegerAttr(0))) + builder.create(op.getLoc(), SplatElementsAttr::get(RT64, rewriter.getI64IntegerAttr(0))) }; Value end[] = { - builder.create(op.getLoc(), SplatElementsAttr::get(RT64, op.getI64IntegerAttr(lengths.size()-1))) + builder.create(op.getLoc(), SplatElementsAttr::get(RT64, rewriter.getI64IntegerAttr(lengths.size()-1))) }; ret_constant = builder.create(op.getLoc(), resTy, ret_constant, zero_constant, start); ret_constant = builder.create(op.getLoc(), resTy, ret_constant, end_constant, end);