Skip to content

Commit

Permalink
Math raising
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Jan 14, 2025
1 parent 9bf7f2f commit 662ab31
Showing 1 changed file with 316 additions and 0 deletions.
316 changes: 316 additions & 0 deletions src/enzyme_ad/jax/Passes/LibDeviceFuncsRaisingPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,135 @@
#include "src/enzyme_ad/jax/Passes/PassDetails.h"
#include "src/enzyme_ad/jax/Passes/Passes.h"

#include "mlir/Conversion/LLVMCommon/VectorPattern.h"

using namespace mlir;
using namespace mlir::enzyme;

template <typename SourceOp, typename TargetOp,
template <typename, typename> typename AttrConvert =
AttrConvertPassThrough>
class VectorConvertFromLLVMPattern : public OpRewritePattern<SourceOp> {
public:
using OpRewritePattern<SourceOp>::OpRewritePattern;

LogicalResult matchAndRewrite(SourceOp op,
PatternRewriter &rewriter) const override {
static_assert(
std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
"expected single result op");
// Determine attributes for the target op
AttrConvert<SourceOp, TargetOp> attrConvert(op);

auto operands = op->getOperands();
auto llvmNDVectorTy = operands[0].getType();
if (isa<LLVM::LLVMArrayType, LLVM::LLVMFixedVectorType,
LLVM::LLVMScalableVectorType>(llvmNDVectorTy)) {
return failure();
}
Operation *newOp = rewriter.create(
op->getLoc(), rewriter.getStringAttr(TargetOp::getOperationName()),
operands, op->getResultTypes(), attrConvert.getAttrs());

rewriter.replaceOp(op, newOp->getResult(0));
return success();
}
};

arith::IntegerOverflowFlags
convertArithOverflowFlagsFromLLVM(LLVM::IntegerOverflowFlags llvmFlags) {
arith::IntegerOverflowFlags arithFlags{};
const std::pair<arith::IntegerOverflowFlags, LLVM::IntegerOverflowFlags>
flags[] = {
{arith::IntegerOverflowFlags::nsw, LLVM::IntegerOverflowFlags::nsw},
{arith::IntegerOverflowFlags::nuw, LLVM::IntegerOverflowFlags::nuw}};
for (auto [arithFlag, llvmFlag] : flags) {
if (bitEnumContainsAny(llvmFlags, llvmFlag))
arithFlags = arithFlags | arithFlag;
}
return arithFlags;
}

template <typename SourceOp, typename TargetOp>
class AttrConvertOverflowFromLLVM {
public:
AttrConvertOverflowFromLLVM(SourceOp srcOp) {
// Copy the source attributes.
convertedAttr = NamedAttrList{srcOp->getAttrs()};
// Get the name of the arith overflow attribute.
StringRef arithAttrName = SourceOp::getIntegerOverflowAttrName();
// Remove the source overflow attribute.
if (auto arithAttr = dyn_cast_if_present<LLVM::IntegerOverflowFlagsAttr>(
convertedAttr.erase(arithAttrName))) {
if (arithAttr.getValue() != LLVM::IntegerOverflowFlags::none) {
StringRef targetAttrName = TargetOp::getOverflowFlagsAttrName();
convertedAttr.set(targetAttrName, arith::IntegerOverflowFlagsAttr::get(
srcOp->getContext(),
convertArithOverflowFlagsFromLLVM(
arithAttr.getValue())));
}
}
}

ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }

private:
NamedAttrList convertedAttr;
};

arith::FastMathFlags
convertArithFastMathFlagsFromLLVM(LLVM::FastmathFlags llvmFMF) {
arith::FastMathFlags arithFMF{};
const std::pair<arith::FastMathFlags, LLVM::FastmathFlags> flags[] = {
{arith::FastMathFlags::nnan, LLVM::FastmathFlags::nnan},
{arith::FastMathFlags::ninf, LLVM::FastmathFlags::ninf},
{arith::FastMathFlags::nsz, LLVM::FastmathFlags::nsz},
{arith::FastMathFlags::arcp, LLVM::FastmathFlags::arcp},
{arith::FastMathFlags::contract, LLVM::FastmathFlags::contract},
{arith::FastMathFlags::afn, LLVM::FastmathFlags::afn},
{arith::FastMathFlags::reassoc, LLVM::FastmathFlags::reassoc}};
for (auto [arithFlag, llvmFlag] : flags) {
if (bitEnumContainsAny(llvmFMF, llvmFlag))
arithFMF = arithFMF | arithFlag;
}
return arithFMF;
}

arith::FastMathFlagsAttr
convertArithFastMathAttrFromLLVM(LLVM::FastmathFlagsAttr fmfAttr) {
auto arithFMF = fmfAttr.getValue();
return arith::FastMathFlagsAttr::get(
fmfAttr.getContext(), convertArithFastMathFlagsFromLLVM(arithFMF));
}

// Attribute converter that populates a NamedAttrList by removing the fastmath
// attribute from the source operation attributes, and replacing it with an
// equivalent LLVM fastmath attribute.
template <typename SourceOp, typename TargetOp>
class AttrConvertFastMathFromLLVM {
public:
AttrConvertFastMathFromLLVM(SourceOp srcOp) {
// Copy the source attributes.
convertedAttr = NamedAttrList{srcOp->getAttrs()};
// Get the name of the arith fastmath attribute.
StringRef arithFMFAttrName = SourceOp::getFastmathAttrName();
// Remove the source fastmath attribute.
auto arithFMFAttr = dyn_cast_if_present<LLVM::FastmathFlagsAttr>(
convertedAttr.erase(arithFMFAttrName));
if (arithFMFAttr &&
arithFMFAttr.getValue() != mlir::LLVM::FastmathFlags::none) {
StringRef targetAttrName = TargetOp::getFastMathAttrName();
convertedAttr.set(targetAttrName,
convertArithFastMathAttrFromLLVM(arithFMFAttr));
}
}

ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }

private:
NamedAttrList convertedAttr;
};

namespace {
template <typename TargetOp>
class CallToOpRaising : public OpRewritePattern<LLVM::CallOp> {
Expand Down Expand Up @@ -58,6 +184,151 @@ static void populateOpPatterns(MLIRContext *context,
patterns.add<CallToOpRaising<TargetOp>>(context, f16Func);
}

namespace {

// From
// https://github.com/llvm/llvm-project/blob/7d8b4eb0ead277f41ff69525ed807f9f6e227f37/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp#L31
// except we invert source and target
template <typename SourceOp, typename TargetOp>
using ConvertFastMath = AttrConvertFastMathFromLLVM<SourceOp, TargetOp>;

template <typename SourceOp, typename TargetOp,
template <typename, typename> typename AttrConvert =
AttrConvertPassThrough>
using InvVectorConvertFromLLVMPattern =
VectorConvertFromLLVMPattern<TargetOp, SourceOp, AttrConvertPassThrough>;

template <typename SourceOp, typename TargetOp>
using ConvertFMFMathFromLLVMPattern =
VectorConvertFromLLVMPattern<TargetOp, SourceOp, ConvertFastMath>;

using AbsFOpLowering =
ConvertFMFMathFromLLVMPattern<math::AbsFOp, LLVM::FAbsOp>;
using CeilOpLowering =
ConvertFMFMathFromLLVMPattern<math::CeilOp, LLVM::FCeilOp>;
using CopySignOpLowering =
ConvertFMFMathFromLLVMPattern<math::CopySignOp, LLVM::CopySignOp>;
using CosOpLowering = ConvertFMFMathFromLLVMPattern<math::CosOp, LLVM::CosOp>;
using CtPopFOpLowering =
VectorConvertFromLLVMPattern<LLVM::CtPopOp, math::CtPopOp>;
using Exp2OpLowering =
ConvertFMFMathFromLLVMPattern<math::Exp2Op, LLVM::Exp2Op>;
using ExpOpLowering = ConvertFMFMathFromLLVMPattern<math::ExpOp, LLVM::ExpOp>;
using FloorOpLowering =
ConvertFMFMathFromLLVMPattern<math::FloorOp, LLVM::FFloorOp>;
using FmaOpLowering = ConvertFMFMathFromLLVMPattern<math::FmaOp, LLVM::FMAOp>;
using Log10OpLowering =
ConvertFMFMathFromLLVMPattern<math::Log10Op, LLVM::Log10Op>;
using Log2OpLowering =
ConvertFMFMathFromLLVMPattern<math::Log2Op, LLVM::Log2Op>;
using LogOpLowering = ConvertFMFMathFromLLVMPattern<math::LogOp, LLVM::LogOp>;
using PowFOpLowering = ConvertFMFMathFromLLVMPattern<math::PowFOp, LLVM::PowOp>;
using FPowIOpLowering =
ConvertFMFMathFromLLVMPattern<math::FPowIOp, LLVM::PowIOp>;
using RoundEvenOpLowering =
ConvertFMFMathFromLLVMPattern<math::RoundEvenOp, LLVM::RoundEvenOp>;
using RoundOpLowering =
ConvertFMFMathFromLLVMPattern<math::RoundOp, LLVM::RoundOp>;
using SinOpLowering = ConvertFMFMathFromLLVMPattern<math::SinOp, LLVM::SinOp>;
using SqrtOpLowering =
ConvertFMFMathFromLLVMPattern<math::SqrtOp, LLVM::SqrtOp>;
using FTruncOpLowering =
ConvertFMFMathFromLLVMPattern<math::TruncOp, LLVM::FTruncOp>;

using AddFOpLowering =
InvVectorConvertFromLLVMPattern<arith::AddFOp, LLVM::FAddOp,
AttrConvertFastMathFromLLVM>;
using AddIOpLowering =
InvVectorConvertFromLLVMPattern<arith::AddIOp, LLVM::AddOp,
AttrConvertOverflowFromLLVM>;
using AndIOpLowering =
InvVectorConvertFromLLVMPattern<arith::AndIOp, LLVM::AndOp>;
using BitcastOpLowering =
InvVectorConvertFromLLVMPattern<arith::BitcastOp, LLVM::BitcastOp>;
using DivFOpLowering =
InvVectorConvertFromLLVMPattern<arith::DivFOp, LLVM::FDivOp,
AttrConvertFastMathFromLLVM>;
using DivSIOpLowering =
InvVectorConvertFromLLVMPattern<arith::DivSIOp, LLVM::SDivOp>;
using DivUIOpLowering =
InvVectorConvertFromLLVMPattern<arith::DivUIOp, LLVM::UDivOp>;
using ExtFOpLowering =
InvVectorConvertFromLLVMPattern<arith::ExtFOp, LLVM::FPExtOp>;
using ExtSIOpLowering =
InvVectorConvertFromLLVMPattern<arith::ExtSIOp, LLVM::SExtOp>;
using ExtUIOpLowering =
InvVectorConvertFromLLVMPattern<arith::ExtUIOp, LLVM::ZExtOp>;
using FPToSIOpLowering =
InvVectorConvertFromLLVMPattern<arith::FPToSIOp, LLVM::FPToSIOp>;
using FPToUIOpLowering =
InvVectorConvertFromLLVMPattern<arith::FPToUIOp, LLVM::FPToUIOp>;
using MaximumFOpLowering =
InvVectorConvertFromLLVMPattern<arith::MaximumFOp, LLVM::MaximumOp,
AttrConvertFastMathFromLLVM>;
using MaxNumFOpLowering =
InvVectorConvertFromLLVMPattern<arith::MaxNumFOp, LLVM::MaxNumOp,
AttrConvertFastMathFromLLVM>;
using MaxSIOpLowering =
InvVectorConvertFromLLVMPattern<arith::MaxSIOp, LLVM::SMaxOp>;
using MaxUIOpLowering =
InvVectorConvertFromLLVMPattern<arith::MaxUIOp, LLVM::UMaxOp>;
using MinimumFOpLowering =
InvVectorConvertFromLLVMPattern<arith::MinimumFOp, LLVM::MinimumOp,
AttrConvertFastMathFromLLVM>;
using MinNumFOpLowering =
InvVectorConvertFromLLVMPattern<arith::MinNumFOp, LLVM::MinNumOp,
AttrConvertFastMathFromLLVM>;
using MinSIOpLowering =
InvVectorConvertFromLLVMPattern<arith::MinSIOp, LLVM::SMinOp>;
using MinUIOpLowering =
InvVectorConvertFromLLVMPattern<arith::MinUIOp, LLVM::UMinOp>;
using MulFOpLowering =
InvVectorConvertFromLLVMPattern<arith::MulFOp, LLVM::FMulOp,
AttrConvertFastMathFromLLVM>;
using MulIOpLowering =
InvVectorConvertFromLLVMPattern<arith::MulIOp, LLVM::MulOp,
AttrConvertOverflowFromLLVM>;
using NegFOpLowering =
InvVectorConvertFromLLVMPattern<arith::NegFOp, LLVM::FNegOp,
AttrConvertFastMathFromLLVM>;
using OrIOpLowering = InvVectorConvertFromLLVMPattern<arith::OrIOp, LLVM::OrOp>;
using RemFOpLowering =
InvVectorConvertFromLLVMPattern<arith::RemFOp, LLVM::FRemOp,
AttrConvertFastMathFromLLVM>;
using RemSIOpLowering =
InvVectorConvertFromLLVMPattern<arith::RemSIOp, LLVM::SRemOp>;
using RemUIOpLowering =
InvVectorConvertFromLLVMPattern<arith::RemUIOp, LLVM::URemOp>;
using SelectOpLowering =
InvVectorConvertFromLLVMPattern<arith::SelectOp, LLVM::SelectOp>;
using ShLIOpLowering =
InvVectorConvertFromLLVMPattern<arith::ShLIOp, LLVM::ShlOp,
AttrConvertOverflowFromLLVM>;
using ShRSIOpLowering =
InvVectorConvertFromLLVMPattern<arith::ShRSIOp, LLVM::AShrOp>;
using ShRUIOpLowering =
InvVectorConvertFromLLVMPattern<arith::ShRUIOp, LLVM::LShrOp>;
using SIToFPOpLowering =
InvVectorConvertFromLLVMPattern<arith::SIToFPOp, LLVM::SIToFPOp>;
using SubFOpLowering =
InvVectorConvertFromLLVMPattern<arith::SubFOp, LLVM::FSubOp,
AttrConvertFastMathFromLLVM>;
using SubIOpLowering =
InvVectorConvertFromLLVMPattern<arith::SubIOp, LLVM::SubOp,
AttrConvertOverflowFromLLVM>;
// using TruncFOpLowering =
// ConstrainedVectorConvertFromLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp,
// false>;
// using ConstrainedTruncFOpLowering = ConstrainedVectorConvertFromLLVMPattern<
// arith::TruncFOp, LLVM::ConstrainedFPTruncIntr, true,
// arith::AttrConverterConstrainedFPFromLLVM>;
using TruncIOpLowering =
VectorConvertFromLLVMPattern<arith::TruncIOp, LLVM::TruncOp>;
using UIToFPOpLowering =
VectorConvertFromLLVMPattern<arith::UIToFPOp, LLVM::UIToFPOp>;
using XOrIOpLowering = VectorConvertFromLLVMPattern<arith::XOrIOp, LLVM::XOrOp>;
} // namespace

void mlir::enzyme::populateLibDeviceFuncsToOpsPatterns(
MLIRContext *context, RewritePatternSet &patterns) {
// XXX: Keep in sync with
Expand Down Expand Up @@ -130,6 +401,50 @@ void mlir::enzyme::populateLibDeviceFuncsToOpsPatterns(
"__nv_tanh");
}

void populateLLVMToMathPatterns(MLIRContext *context,
RewritePatternSet &patterns) {
auto *converter = context;
// From
// https://github.com/llvm/llvm-project/blob/7d8b4eb0ead277f41ff69525ed807f9f6e227f37/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp#L306
// patterns.add<FTruncOpLowering>(converter);
patterns.add<AbsFOpLowering,
// AbsIOpLowering,
CeilOpLowering, CopySignOpLowering, CosOpLowering,
// CountLeadingZerosOpLowering,
// CountTrailingZerosOpLowering,
// CtPopFOpLowering,
Exp2OpLowering,
// ExpM1OpLowering,
ExpOpLowering, FPowIOpLowering, FloorOpLowering, FmaOpLowering,
Log10OpLowering, Log2OpLowering, LogOpLowering, PowFOpLowering,
RoundEvenOpLowering, RoundOpLowering,
// RsqrtOpLowering,
SinOpLowering, SqrtOpLowering, FTruncOpLowering>(converter);

patterns
.add<AddFOpLowering, AddIOpLowering, AndIOpLowering,
// AddUIExtendedOpLowering,
BitcastOpLowering,
// ConstantOpLowering,
// CmpFOpLowering,
// CmpIOpLowering,
DivFOpLowering, DivSIOpLowering, DivUIOpLowering, ExtFOpLowering,
ExtSIOpLowering, ExtUIOpLowering, FPToSIOpLowering, FPToUIOpLowering,
// IndexCastOpSILowering,
// IndexCastOpUILowering,
MaximumFOpLowering, MaxNumFOpLowering, MaxSIOpLowering,
MaxUIOpLowering, MinimumFOpLowering, MinNumFOpLowering,
MinSIOpLowering, MinUIOpLowering, MulFOpLowering, MulIOpLowering,
// MulSIExtendedOpLowering,
// MulUIExtendedOpLowering,
NegFOpLowering, OrIOpLowering, RemFOpLowering, RemSIOpLowering,
RemUIOpLowering, SelectOpLowering, ShLIOpLowering, ShRSIOpLowering,
ShRUIOpLowering, SIToFPOpLowering, SubFOpLowering, SubIOpLowering,
// TruncFOpLowering,
// ConstrainedTruncFOpLowering,
TruncIOpLowering, UIToFPOpLowering, XOrIOpLowering>(converter);
}

namespace {
class LibDeviceFuncsRaisingPass
: public LibDeviceFuncsRaisingPassBase<LibDeviceFuncsRaisingPass> {
Expand All @@ -138,6 +453,7 @@ class LibDeviceFuncsRaisingPass

void runOnOperation() override {
RewritePatternSet patterns(getOperation()->getContext());
populateLLVMToMathPatterns(getOperation()->getContext(), patterns);
populateLibDeviceFuncsToOpsPatterns(getOperation()->getContext(), patterns);
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
emitError(getOperation()->getLoc()) << "failed to raise __nv functions";
Expand Down

0 comments on commit 662ab31

Please sign in to comment.