From 495fde356c97a6095c34d42a503d82de70cfe63f Mon Sep 17 00:00:00 2001 From: Paul Berg Date: Thu, 9 Jan 2025 23:25:54 +0100 Subject: [PATCH] tablegen: Add StaticSelect to select based on static condition (#2206) * tablegen: Add StaticIf to select based on static condition * rename to StaticSelect and implement SelectIfActive and SelectIfComplex with it * define for llvm * put vector mode for LLVM back * basic use analysis * StaticSelect use analysis * fixup --------- Co-authored-by: William S. Moses --- enzyme/Enzyme/InstructionDerivatives.td | 8 +- enzyme/Enzyme/MLIR/Implementations/Common.td | 25 +- enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp | 446 +++++++++++-------- 3 files changed, 273 insertions(+), 206 deletions(-) diff --git a/enzyme/Enzyme/InstructionDerivatives.td b/enzyme/Enzyme/InstructionDerivatives.td index c778e63aa25..1c38bd9dfca 100644 --- a/enzyme/Enzyme/InstructionDerivatives.td +++ b/enzyme/Enzyme/InstructionDerivatives.td @@ -13,6 +13,11 @@ class ConstantFP : Operation { string value = val; } +class StaticSelect : Operation { + string condition = condition_; +} + +def SelectIfActive : StaticSelect<"!gutils->isConstantValue(imVal)">; class Attribute { string name = name_; @@ -62,9 +67,6 @@ class Inst : Operation { def TypeOf : Operation { } def VectorSize : Operation { -} -def SelectIfActive : Operation { - } // Define ops to rewrite. diff --git a/enzyme/Enzyme/MLIR/Implementations/Common.td b/enzyme/Enzyme/MLIR/Implementations/Common.td index 29f977c95a9..170f4e26bf1 100644 --- a/enzyme/Enzyme/MLIR/Implementations/Common.td +++ b/enzyme/Enzyme/MLIR/Implementations/Common.td @@ -86,10 +86,20 @@ def DiffeRet : DiffeRetIndex<[-1]>; def Shadow : Operation { } -class GlobalExpr : Operation{ +class GlobalExpr : Operation { string value = val; } +// Class for a dag operator that generates either a or b +// It can then be used with a two or three arguments. 
+// The two arguments version is (StaticSelect a, b) +// The three arguments version accepts a name as a first argument +// which is then available in the condition as a `Value` under the +// variable `imVal`. +class StaticSelect : Operation { + string condition = condition_; +} + class Inst : Operation { string name = mnemonic; string dialect = dialect_; @@ -99,13 +109,14 @@ class Inst : Operation { - -} - -def SelectIfComplex : Operation { +def SelectIfActive : StaticSelect<"!gutils->isConstantValue(imVal)">; -} +def SelectIfComplex : StaticSelect<[{ + auto ty = imVal.getType(); + ty.isa() || + ty.isa() && + ty.cast().getElementType().isa(); +}]>; class ConstantFP : Operation { string value = val; diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp index 3f85a07548e..50efdaeae61 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp @@ -436,53 +436,100 @@ bool handle(const Twine &curIndent, const Twine &argPattern, raw_ostream &os, os << ".Min"; #endif return false; - } else if (opName == "SelectIfActive" || - Def->isSubClassOf("SelectIfActive")) { - if (resultRoot->getNumArgs() != 3) + } else if (Def->isSubClassOf("StaticSelect")) { + auto numArgs = resultRoot->getNumArgs(); + + if (numArgs != 2 && numArgs != 3) PrintFatalError(pattern->getLoc(), - "only three op SelectIfActive supported"); + "only two/three op StaticSelect supported"); os << "({\n"; - os << curIndent << INDENT << "// Computing SelectIfActive\n"; + os << curIndent << INDENT << "// Computing " << opName << "\n"; if (intrinsic == MLIRDerivatives) - os << curIndent << INDENT << "mlir::Value imVal = nullptr;\n"; + os << curIndent << INDENT << "mlir::Value imVal = "; else - os << curIndent << INDENT << "llvm::Value *imVal = nullptr;\n"; + os << curIndent << INDENT << "llvm::Value *imVal = "; - os << curIndent << INDENT << "if (!gutils->isConstantValue("; + int index = numArgs == 3; - if 
(isa(resultRoot->getArg(0)) && resultRoot->getArgName(0)) { - auto name = resultRoot->getArgName(0)->getAsUnquotedString(); - auto [ord, isVec, ext] = - nameToOrdinal.lookup(name, pattern, resultRoot); - assert(!isVec); - // This assumes that activity of inner extractions are the same as - // outer. assert(!ext.size()); - os << ord; - } else - assert("Requires name for arg"); + // First one is a name, set imVal to it + if (numArgs == 3) { + if (isa(resultRoot->getArg(0)) && + resultRoot->getArgName(0)) { + auto name = resultRoot->getArgName(0)->getAsUnquotedString(); + auto [ord, isVec, ext] = + nameToOrdinal.lookup(name, pattern, resultRoot); + assert(!isVec); + os << ord << ";\n"; + } else + assert("Requires name for arg"); + } else { + os << "nullptr;\n"; + } + + os << curIndent << INDENT << "bool condition = "; + + auto condition = dyn_cast(Def->getValueInit("condition")); + if (!condition) + PrintFatalError(pattern->getLoc(), + Twine("string 'condition' not defined in ") + + resultTree->getAsString()); + auto conditionStr = condition->getValue(); + + if (conditionStr.contains("imVal") && numArgs == 2) + PrintFatalError(pattern->getLoc(), "need a name as first argument"); - os << ")) {\n"; + bool complexExpr = conditionStr.contains(';'); + if (complexExpr) + os << "({\n"; + os << conditionStr; + if (complexExpr) + os << "\n" << curIndent << INDENT << "})"; + + os << ";\n"; - for (size_t i = 1; i < 3; i++) { + os << curIndent << INDENT << "bool vectorized = false;\n"; + + os << curIndent << INDENT << "if (condition) {\n"; + + bool any_vector = false; + bool all_vector = true; + for (size_t i = index; i < numArgs; ++i) { os << curIndent << INDENT << INDENT << "imVal = "; + bool vector; if (isa(resultRoot->getArg(i)) && resultRoot->getArgName(i)) { auto name = resultRoot->getArgName(i)->getAsUnquotedString(); auto [ord, isVec, ext] = nameToOrdinal.lookup(name, pattern, resultRoot); - vector = isVec; assert(!ext.size()); + vector = isVec; os << ord; - } else - 
vector = handle(curIndent + INDENT + INDENT, - argPattern + "_sia_" + Twine(i), os, pattern, - resultRoot->getArg(i), builder, nameToOrdinal, lookup, - retidx, origName, newFromOriginal, intrinsic); + } else { + vector = + handle(curIndent + INDENT + INDENT, argPattern + "_cic", os, + pattern, resultRoot->getArg(i), builder, nameToOrdinal, + lookup, retidx, origName, newFromOriginal, intrinsic); + } os << ";\n"; + if (vector) { + any_vector = true; + os << curIndent << INDENT << INDENT << "vectorized = true;\n"; + } else { + all_vector = false; + } + + if (i == numArgs - 1) { + os << curIndent << INDENT << "}\n"; + } else { + os << curIndent << INDENT << "} else {\n"; + } + } - if (!vector && intrinsic != MLIRDerivatives) { + if (any_vector && !all_vector) { + os << curIndent << INDENT << "if (!vectorized) {\n"; + if (intrinsic != MLIRDerivatives) { os << curIndent << INDENT << INDENT << "llvm::Value* vec_imVal = gutils->getWidth() == 1 ? imVal : " "UndefValue::get(gutils->getShadowType(imVal" @@ -496,81 +543,19 @@ bool handle(const Twine &curIndent, const Twine &argPattern, raw_ostream &os, << ".CreateInsertValue(vec_imVal, imVal, " "std::vector({(unsigned)i}));\n"; os << curIndent << INDENT << INDENT << "imVal = vec_imVal;\n"; + } else { + os << curIndent << INDENT << "if (gutils->width != 1)\n" + << curIndent << INDENT << INDENT + << "imVal = builder.create(imVal.getLoc(), " + "imVal, SmallVector({gutils->width}));\n"; } - if (i == 1) - os << curIndent << INDENT << "} else {\n"; - else - os << curIndent << INDENT << "}\n"; + os << curIndent << INDENT << "}\n"; } os << curIndent << INDENT << "imVal;\n"; - os << curIndent << "})"; - return true; - } else if (opName == "SelectIfComplex" || - Def->isSubClassOf("SelectIfComplex")) { - if (resultRoot->getNumArgs() != 3) - PrintFatalError(pattern->getLoc(), - "only three op SelectIfComplex supported"); - - os << "({\n"; - os << curIndent << INDENT << "// Computing SelectIfComplex\n"; - if (intrinsic == MLIRDerivatives) 
- os << curIndent << INDENT << "mlir::Value imVal = "; - else - os << curIndent << INDENT << "llvm::Value *imVal = "; - - if (isa(resultRoot->getArg(0)) && resultRoot->getArgName(0)) { - auto name = resultRoot->getArgName(0)->getAsUnquotedString(); - auto [ord, isVec, ext] = - nameToOrdinal.lookup(name, pattern, resultRoot); - os << ord << ";\n"; - } else { - handle(curIndent + INDENT + INDENT, argPattern + "_cic", os, pattern, - resultRoot->getArg(0), builder, nameToOrdinal, lookup, retidx, - origName, newFromOriginal, intrinsic); - os << ";\n"; - } + os << curIndent << INDENT << "})"; - os << curIndent << INDENT - << "if (isa(imVal.getType()) || " - "(isa(imVal.getType()) && " - "isa(cast(imVal.getType()).getElementType(" - ")))) {\n"; - - os << curIndent << INDENT << INDENT << "imVal = "; - if (isa(resultRoot->getArg(1)) && resultRoot->getArgName(1)) { - auto name = resultRoot->getArgName(1)->getAsUnquotedString(); - auto [ord, isVec, ext] = - nameToOrdinal.lookup(name, pattern, resultRoot); - assert(!ext.size()); - os << ord << ";\n"; - } else { - handle(curIndent + INDENT + INDENT, argPattern + "_cic", os, pattern, - resultRoot->getArg(1), builder, nameToOrdinal, lookup, retidx, - origName, newFromOriginal, intrinsic); - os << ";\n"; - } - - os << curIndent << INDENT << "} else {\n"; - - os << curIndent << INDENT << INDENT << "imVal = "; - if (isa(resultRoot->getArg(2)) && resultRoot->getArgName(2)) { - auto name = resultRoot->getArgName(2)->getAsUnquotedString(); - auto [ord, isVec, ext] = - nameToOrdinal.lookup(name, pattern, resultRoot); - assert(!ext.size()); - os << ord << ";\n"; - } else { - handle(curIndent + INDENT + INDENT, argPattern + "_cic", os, pattern, - resultRoot->getArg(2), builder, nameToOrdinal, lookup, retidx, - origName, newFromOriginal, intrinsic); - os << ";\n"; - } - - os << curIndent << INDENT << "}\n"; - os << curIndent << INDENT << "imVal;"; - os << curIndent << INDENT << "})\n"; - return true; + return any_vector; } else if (opName 
/// Returns a copy of `str` in which every non-overlapping occurrence of
/// `from` has been replaced by `to`, scanning from left to right.
///
/// Text inserted by a replacement is never rescanned, so a `to` that itself
/// contains `from` does not cause repeated substitution.
///
/// \param str   the text to rewrite (taken by value and edited in place).
/// \param from  the substring to search for; if empty, `str` is returned
///              unchanged.
/// \param to    the replacement text; may be empty to delete occurrences.
/// \return the rewritten string.
std::string ReplaceAll(std::string str, const std::string &from,
                       const std::string &to) {
  // An empty pattern matches at every position, so the loop below would keep
  // inserting `to` (or spin in place when `to` is empty) and never finish.
  if (from.empty())
    return str;

  size_t start_pos = 0;
  while ((start_pos = str.find(from, start_pos)) != std::string::npos) {
    str.replace(start_pos, from.length(), to);
    // Resume after the inserted text so that a `to` containing `from` is not
    // matched again.
    start_pos += to.length();
  }
  return str;
}
precondition : "", tree, varNameToCondition); + + if (name.size()) { + if (foundPrimalUse2.size() && + !(startsWith(foundPrimalUse, foundPrimalUse2) || + endsWith(foundPrimalUse, foundPrimalUse2))) { + if (foundPrimalUse.size() == 0) + foundPrimalUse = foundPrimalUse2; + else + foundPrimalUse += " || " + foundPrimalUse2; + } + if (foundShadowUse2.size() && + !(startsWith(foundShadowUse, foundShadowUse2) || + endsWith(foundShadowUse, foundShadowUse2))) { + if (foundShadowUse.size() == 0) + foundShadowUse = foundShadowUse2; + else + foundShadowUse += " || " + foundShadowUse2; + } + foundDiffRet |= foundDiffRet2; + + varNameToCondition[name] = + std::make_tuple(foundPrimalUse2, foundShadowUse2, foundDiffRet2); + } + } else { + assert(name.size()); + + if (name.size()) { + auto found = varNameToCondition.find(name); + if (found == varNameToCondition.end()) { + llvm::errs() << "tree scope: " << *tree << "\n"; + llvm::errs() << "root scope: " << *root << "\n"; + llvm::errs() << "could not find var name: " << name << "\n"; + } + assert(found != varNameToCondition.end()); + } + + if (precondition.size()) { + auto [foundPrimalUse2, foundShadowUse2, foundDiffRet2] = + varNameToCondition[name]; + if (precondition != "true") { + if (foundPrimalUse2.size()) { + foundPrimalUse2 = + "((" + foundPrimalUse2 + ")&&(" + precondition + ")"; + } + if (foundShadowUse2.size()) { + foundShadowUse2 = + "((" + foundShadowUse2 + ")&&(" + precondition + ")"; + } + } + if (usesPrimal) { + if (foundPrimalUse2.size() && + !(startsWith(foundPrimalUse, foundPrimalUse2) || + endsWith(foundPrimalUse, foundPrimalUse2))) { + if (foundPrimalUse.size() == 0) + foundPrimalUse = foundPrimalUse2; + else + foundPrimalUse += " || " + foundPrimalUse2; + } + if (foundShadowUse2.size() && + !(startsWith(foundShadowUse, foundShadowUse2) || + endsWith(foundShadowUse, foundShadowUse2))) { + if (foundShadowUse.size() == 0) + foundShadowUse = foundShadowUse2; + else + foundShadowUse += " || " + foundShadowUse2; + } 
+ foundDiffRet |= foundDiffRet2; + } + if (usesShadow) { + if (foundPrimalUse2.size() && + !(startsWith(foundShadowUse, foundPrimalUse2) || + endsWith(foundShadowUse, foundPrimalUse2))) { + if (foundShadowUse.size() == 0) + foundShadowUse = foundPrimalUse2; + else + foundShadowUse += " || " + foundPrimalUse2; + } + assert(!foundDiffRet2); + assert(foundShadowUse2 == ""); + } + } + } +} void handleUse( const DagInit *root, const DagInit *resultTree, std::string &foundPrimalUse, std::string &foundShadowUse, bool &foundDiffRet, std::string precondition, @@ -1215,113 +1325,57 @@ void handleUse( bool usesShadow = Def->getValueAsBit("usesShadow"); bool usesCustom = Def->getValueAsBit("usesCustom"); - // We don't handle any custom primal/shadow - (void)usesCustom; - assert(!usesCustom); + if (Def->isSubClassOf("StaticSelect")) { + auto numArgs = resultTree->getNumArgs(); - for (auto argEn : llvm::enumerate(resultTree->getArgs())) { - auto name = resultTree->getArgNameStr(argEn.index()); + assert(numArgs == 2 || numArgs == 3); + auto condition = dyn_cast(Def->getValueInit("condition")); + assert(condition); + std::string conditionStr = condition->getValue().str(); - auto arg2 = dyn_cast(argEn.value()); + assert(!(StringRef(conditionStr).contains("imVal") && numArgs == 2)); - if (arg2) { - // Recursive use of shadow is unhandled - assert(!usesShadow); + // First one is a name, set imVal to it + if (numArgs == 3) { + if (isa(resultTree->getArg(0)) && resultTree->getArgName(0)) { + auto name = resultTree->getArgName(0)->getAsUnquotedString(); + conditionStr = ReplaceAll(conditionStr, "imVal", name); + } else + assert("Requires name for arg"); + } - std::string foundPrimalUse2 = ""; - std::string foundShadowUse2 = ""; + bool complexExpr = StringRef(conditionStr).contains(';'); + if (complexExpr) { + conditionStr = "({ " + conditionStr + " })"; + } - bool foundDiffRet2 = false; - // We set precondition to be false (aka "") if we do not need the - // primal, since we are now 
only recurring to set variables - // correctly. - if (name.size() || usesPrimal) - handleUse(root, arg2, name.size() ? foundPrimalUse2 : foundPrimalUse, - name.size() ? foundShadowUse2 : foundShadowUse, - name.size() ? foundDiffRet2 : foundDiffRet, - usesPrimal ? precondition : "", tree, varNameToCondition); + for (size_t i = numArgs == 3; i < numArgs; ++i) { + std::string conditionStr2 = + (i == numArgs - 1) ? ("!(" + conditionStr + ")") : conditionStr; + std::string precondition2; + if (precondition == "true") + precondition2 = conditionStr2; + else + precondition2 = "((" + precondition + ")&&(" + conditionStr2 + ")"; - if (name.size()) { - if (foundPrimalUse2.size() && - !(startsWith(foundPrimalUse, foundPrimalUse2) || - endsWith(foundPrimalUse, foundPrimalUse2))) { - if (foundPrimalUse.size() == 0) - foundPrimalUse = foundPrimalUse2; - else - foundPrimalUse += " || " + foundPrimalUse2; - } - if (foundShadowUse2.size() && - !(startsWith(foundShadowUse, foundShadowUse2) || - endsWith(foundShadowUse, foundShadowUse2))) { - if (foundShadowUse.size() == 0) - foundShadowUse = foundShadowUse2; - else - foundShadowUse += " || " + foundShadowUse2; - } - foundDiffRet |= foundDiffRet2; + auto name = resultTree->getArgNameStr(i); + auto arg = resultTree->getArg(i); + handleUseArgument(name, arg, true, false, root, resultTree, + foundPrimalUse, foundShadowUse, foundDiffRet, + precondition2, tree, varNameToCondition); + } - varNameToCondition[name] = - std::make_tuple(foundPrimalUse2, foundShadowUse2, foundDiffRet2); - } - } else { - assert(name.size()); - - if (name.size()) { - auto found = varNameToCondition.find(name); - if (found == varNameToCondition.end()) { - llvm::errs() << "tree scope: " << *tree << "\n"; - llvm::errs() << "root scope: " << *root << "\n"; - llvm::errs() << "could not find var name: " << name << "\n"; - } - assert(found != varNameToCondition.end()); - } + return; + } - if (precondition.size()) { - auto [foundPrimalUse2, foundShadowUse2, 
foundDiffRet2] = - varNameToCondition[name]; - if (precondition != "true") { - if (foundPrimalUse2.size()) { - foundPrimalUse2 = - "((" + foundPrimalUse2 + ")&&(" + precondition + ")"; - } - if (foundShadowUse2.size()) { - foundShadowUse2 = - "((" + foundShadowUse2 + ")&&(" + precondition + ")"; - } - } - if (usesPrimal) { - if (foundPrimalUse2.size() && - !(startsWith(foundPrimalUse, foundPrimalUse2) || - endsWith(foundPrimalUse, foundPrimalUse2))) { - if (foundPrimalUse.size() == 0) - foundPrimalUse = foundPrimalUse2; - else - foundPrimalUse += " || " + foundPrimalUse2; - } - if (foundShadowUse2.size() && - !(startsWith(foundShadowUse, foundShadowUse2) || - endsWith(foundShadowUse, foundShadowUse2))) { - if (foundShadowUse.size() == 0) - foundShadowUse = foundShadowUse2; - else - foundShadowUse += " || " + foundShadowUse2; - } - foundDiffRet |= foundDiffRet2; - } - if (usesShadow) { - if (foundPrimalUse2.size() && - !(startsWith(foundShadowUse, foundPrimalUse2) || - endsWith(foundShadowUse, foundPrimalUse2))) { - if (foundShadowUse.size() == 0) - foundShadowUse = foundPrimalUse2; - else - foundShadowUse += " || " + foundPrimalUse2; - } - assert(!foundDiffRet2); - assert(foundShadowUse2 == ""); - } - } - } + (void)usesCustom; + assert(!usesCustom); + + for (auto argEn : llvm::enumerate(resultTree->getArgs())) { + auto name = resultTree->getArgNameStr(argEn.index()); + handleUseArgument(name, argEn.value(), usesPrimal, usesShadow, root, + resultTree, foundPrimalUse, foundShadowUse, foundDiffRet, + precondition, tree, varNameToCondition); } }