diff --git a/accera/acc-gpu-runner/src/ACCGPURunnerMain.cpp b/accera/acc-gpu-runner/src/ACCGPURunnerMain.cpp index 850f529b..a776d7ea 100644 --- a/accera/acc-gpu-runner/src/ACCGPURunnerMain.cpp +++ b/accera/acc-gpu-runner/src/ACCGPURunnerMain.cpp @@ -119,7 +119,7 @@ void AddMLIRVulkanRunnerPasses(PassManager& passManager) passManager.addPass(createConvertGpuLaunchFuncToVulkanLaunchFuncPass()); passManager.addPass(accera::transforms::vulkan::createEmitVulkanWrapperPass()); - passManager.addPass(createLowerToCFGPass()); + passManager.addPass(createConvertSCFToCFPass()); passManager.addPass(LLVM::createLegalizeForExportPass()); LowerToLLVMOptions llvmOptions(passManager.getContext()); llvmOptions.useBarePtrCallConv = false; diff --git a/accera/acc-opt/test/commandline.mlir b/accera/acc-opt/test/commandline.mlir index d2b4d2f2..b4076ad7 100644 --- a/accera/acc-opt/test/commandline.mlir +++ b/accera/acc-opt/test/commandline.mlir @@ -8,6 +8,7 @@ // CHECK-NEXT: affine // CHECK-NEXT: arith // CHECK-NEXT: builtin +// CHECK-NEXT: cf // CHECK-NEXT: gpu // CHECK-NEXT: llvm // CHECK-NEXT: math diff --git a/accera/acc-opt/test/vectorization.mlir b/accera/acc-opt/test/vectorization.mlir index 31993ce0..f696aa4a 100644 --- a/accera/acc-opt/test/vectorization.mlir +++ b/accera/acc-opt/test/vectorization.mlir @@ -15,7 +15,7 @@ module @test_accera_vectorization attributes {accv.target_device_features = "-av // mlir::affine::AffineLoadOp non-sequential // mlir::affine::AffineStoreOp sequential // mlir::affine::AffineStoreOp non-sequential - // mlir::SelectOp + // mlir::arith::SelectOp // mlir::arith::ShLIOp // mlir::arith::FPToSIOp // mlir::arith::ExtSIOp diff --git a/accera/acc-translate/src/Target/Cpp/StdDialectCppPrinter.cpp b/accera/acc-translate/src/Target/Cpp/StdDialectCppPrinter.cpp index 637dc45e..41f9bca2 100644 --- a/accera/acc-translate/src/Target/Cpp/StdDialectCppPrinter.cpp +++ b/accera/acc-translate/src/Target/Cpp/StdDialectCppPrinter.cpp @@ -415,7 +415,7 @@ namespace 
cpp_printer return success(); } - LogicalResult StdDialectCppPrinter::printSelectOp(SelectOp selectOp) + LogicalResult StdDialectCppPrinter::printSelectOp(arith::SelectOp selectOp) { if (selectOp.getNumOperands() != 3) { @@ -728,7 +728,7 @@ namespace cpp_printer if (auto returnOp = dyn_cast(op)) return printReturnOp(returnOp); - if (auto selectOp = dyn_cast(op)) + if (auto selectOp = dyn_cast(op)) return printSelectOp(selectOp); if (auto getGlobal = dyn_cast(op)) diff --git a/accera/acc-translate/src/Target/Cpp/StdDialectCppPrinter.h b/accera/acc-translate/src/Target/Cpp/StdDialectCppPrinter.h index 3dcccf26..b99c0a94 100644 --- a/accera/acc-translate/src/Target/Cpp/StdDialectCppPrinter.h +++ b/accera/acc-translate/src/Target/Cpp/StdDialectCppPrinter.h @@ -83,7 +83,7 @@ namespace cpp_printer LogicalResult printReturnOp(ReturnOp returnOp); /// print SelectOp as ternary operator - LogicalResult printSelectOp(SelectOp selectOp); + LogicalResult printSelectOp(arith::SelectOp selectOp); /// print GetGlobalOp as a call to the global variable LogicalResult printGetGlobalOp(memref::GetGlobalOp getGlobalOp); diff --git a/accera/accc/accc.py b/accera/accc/accc.py index a104fee0..08dd2f77 100644 --- a/accera/accc/accc.py +++ b/accera/accc/accc.py @@ -23,6 +23,7 @@ class SystemTarget(Enum): HOST = "host" + AVX2 = "avx2" AVX512 = "avx512" RPI4 = "pi4" RPI3 = "pi3" @@ -116,6 +117,7 @@ def bstr(val): "-O3", "--march=arm", "-mcpu=arm1136jf-s", "--mtriple=armv6-linux-gnueabihf" ], SystemTarget.AVX512.value: ["-O3", "--march=x86-64", "-mcpu=skylake-avx512"], + SystemTarget.AVX2.value: ["-O3", "--march=x86-64", "-mcpu=skylake"], SystemTarget.ARM_CORTEX_M4.value: [ "-Oz", "-mcpu=cortex-m4", "--mtriple=thumbv7em-arm-none-eabi", ], diff --git a/accera/ir/include/argo/ArgoOps.td b/accera/ir/include/argo/ArgoOps.td index c9d45a18..8352bc54 100644 --- a/accera/ir/include/argo/ArgoOps.td +++ b/accera/ir/include/argo/ArgoOps.td @@ -43,6 +43,9 @@ def Argo_YieldOp : Argo_Op<"yield", 
[NoSideEffect, ReturnLike, Terminator]>, argo.yield %f0, %f1 : f32, f32 ``` }]; + + // Indicate that the operation has a custom parser and printer method. + let hasCustomAssemblyFormat = 1; } def Argo_EntryPointOp : Argo_Op<"entry_point", [IsolatedFromAbove, FunctionOpInterface, @@ -90,6 +93,9 @@ def Argo_EntryPointOp : Argo_Op<"entry_point", [IsolatedFromAbove, FunctionOpInt let skipDefaultBuilders = 1; + // Indicate that the operation has a custom parser and printer method. + let hasCustomAssemblyFormat = 1; + let builders = [ OpBuilder<(ins "StringRef":$entryName, "FunctionType":$type, "StringRef":$kernelName, diff --git a/accera/ir/include/argo/ArgoStructuredOps.td b/accera/ir/include/argo/ArgoStructuredOps.td index 9bed523f..eddf220d 100644 --- a/accera/ir/include/argo/ArgoStructuredOps.td +++ b/accera/ir/include/argo/ArgoStructuredOps.td @@ -246,6 +246,9 @@ def OpaqueOp : ArgoStructuredBase_Op<"opaque", let regions = (region AnyRegion:$region); + // Indicate that the operation has a custom parser and printer method. 
+ let hasCustomAssemblyFormat = 1; + let builders = [ OpBuilder< (ins "ValueRange":$args, "int64_t":$argsIn, "int64_t":$argsOut, diff --git a/accera/ir/src/IRUtil.cpp b/accera/ir/src/IRUtil.cpp index acdeeead..cd9aa47f 100644 --- a/accera/ir/src/IRUtil.cpp +++ b/accera/ir/src/IRUtil.cpp @@ -1000,14 +1000,31 @@ namespace util mlir::Operation* GetDefiningOpOrForLoop(mlir::Value val) { - if (mlir::isForInductionVar(val)) // AffineForOp + if (auto affineForOp = mlir::getForInductionVarOwner(val)) // AffineForOp { - return mlir::getForInductionVarOwner(val); + return affineForOp; } else if (auto scfForOp = mlir::scf::getForInductionVarOwner(val)) // SCFForOp { return scfForOp; } + else if (auto ivArg = val.dyn_cast()) + { + auto block = ivArg.getOwner(); + if (!block) + { + return nullptr; + } + auto parentOp = block->getParentOp(); + + // only handle AffineParallelOp and scf::ParallelOp, other block args such as function args should not return their associated ops + if (mlir::isa(parentOp) || + mlir::isa(parentOp)) + { + return parentOp; + } + return nullptr; + } else // Arbitrary other op { return val.getDefiningOp(); diff --git a/accera/ir/src/argo/ArgoOps.cpp b/accera/ir/src/argo/ArgoOps.cpp index 84c03bb7..94fb5ce0 100644 --- a/accera/ir/src/argo/ArgoOps.cpp +++ b/accera/ir/src/argo/ArgoOps.cpp @@ -128,17 +128,17 @@ static LogicalResult verify(CopyOp op) // YieldOp //===----------------------------------------------------------------------===// -static void print(OpAsmPrinter& p, argo::YieldOp op) +void YieldOp::print(OpAsmPrinter& p) { - p << op.getOperationName(); - if (op.getNumOperands() > 0) - p << ' ' << op.getOperands(); - p.printOptionalAttrDict(op->getAttrs()); - if (op.getNumOperands() > 0) - p << " : " << op.getOperandTypes(); + p << getOperationName(); + if (getNumOperands() > 0) + p << ' ' << getOperands(); + p.printOptionalAttrDict((*this)->getAttrs()); + if (getNumOperands() > 0) + p << " : " << getOperandTypes(); } -static ParseResult 
parseYieldOp(OpAsmParser& parser, OperationState& result) +ParseResult YieldOp::parse(OpAsmParser& parser, OperationState& result) { SmallVector opInfo; SmallVector types; @@ -243,33 +243,33 @@ void OpaqueOp::build( bodyBuild(odsBuilder, odsState.location, bodyBlock->getArguments()); } -static void print(OpAsmPrinter& p, OpaqueOp op) +void OpaqueOp::print(OpAsmPrinter& p) { - auto attrNames = op.argoTraitAttrNames(); + auto attrNames = argoTraitAttrNames(); llvm::StringSet<> argoTraitAttrsSet; argoTraitAttrsSet.insert(attrNames.begin(), attrNames.end()); SmallVector attrs; - for (auto attr : op->getAttrs()) + for (auto attr : (*this)->getAttrs()) if (argoTraitAttrsSet.count(attr.getName().strref()) > 0) attrs.push_back(attr); - auto dictAttr = DictionaryAttr::get(op.getContext(), attrs); - p << op.getOperationName() << " " << dictAttr; - p.printOptionalAttrDict(op->getAttrs(), attrNames); - p << " (" << op.getOperands() << ")"; - if (!op.region().empty()) + auto dictAttr = DictionaryAttr::get(getContext(), attrs); + p << getOperationName() << " " << dictAttr; + p.printOptionalAttrDict((*this)->getAttrs(), attrNames); + p << " (" << getOperands() << ")"; + if (!region().empty()) { - p.printRegion(op.region()); + p.printRegion(region()); } - auto inputTypes = op.getOperandTypes(); + auto inputTypes = getOperandTypes(); if (!inputTypes.empty()) { p << " : " << inputTypes; } } -static ParseResult parseOpaqueOp(OpAsmParser& parser, OperationState& result) +ParseResult OpaqueOp::parse(OpAsmParser& parser, OperationState& result) { SmallVector operandsInfo, regionOperandsInfo; DictionaryAttr dictAttr; @@ -340,8 +340,8 @@ void EntryPointOp::build(OpBuilder& builder, OperationState& result, StringRef e /// Parse an Argo entry_point op /// ::= `argo.entry_point` symbol-ref-id `(` argument-list `)` /// (`->` function-result-list)? function-attributes? 
-static ParseResult parseEntryPointOp(OpAsmParser& parser, - OperationState& result) +ParseResult EntryPointOp::parse(OpAsmParser& parser, + OperationState& result) { SmallVector entryArgs; SmallVector argTypes; @@ -387,17 +387,17 @@ static ParseResult parseEntryPointOp(OpAsmParser& parser, return success(); } -static void printEntryPointOp(OpAsmPrinter& p, EntryPointOp op) +void EntryPointOp::print(OpAsmPrinter& p) { p << EntryPointOp::getOperationName() << ' '; - p.printSymbolName(op.getName()); + p.printSymbolName(getName()); - FunctionType type = op.getType(); - function_interface_impl::printFunctionSignature(p, op.getOperation(), type.getInputs(), - /*isVariadic=*/false, - type.getResults()); + FunctionType type = getType(); + function_interface_impl::printFunctionSignature(p, getOperation(), type.getInputs(), + /*isVariadic=*/false, + type.getResults()); - function_interface_impl::printFunctionAttributes(p, op.getOperation(), type.getNumInputs(), type.getNumResults()); + function_interface_impl::printFunctionAttributes(p, getOperation(), type.getNumInputs(), type.getNumResults()); } static LogicalResult verify(EntryPointOp op) diff --git a/accera/ir/test/nest_dialect_test/IRTestVerification.cpp b/accera/ir/test/nest_dialect_test/IRTestVerification.cpp index 3ea0ea3c..00719dd8 100644 --- a/accera/ir/test/nest_dialect_test/IRTestVerification.cpp +++ b/accera/ir/test/nest_dialect_test/IRTestVerification.cpp @@ -16,6 +16,7 @@ #include #include +#include #include #include #include @@ -376,7 +377,7 @@ bool VerifyLowerToLLVM(mlir::OwningOpRef& module, mlir::FuncOp& funcPm.addPass(mlir::arith::createArithmeticExpandOpsPass()); // --arith-expand pm.addPass(mlir::createLowerAffinePass()); // --lower-affine - pm.addPass(mlir::createLowerToCFGPass()); // --convert-scf-to-std + pm.addPass(mlir::createConvertSCFToCFPass()); // --convert-scf-to-cf pm.addPass(mlir::createMemRefToLLVMPass()); // --convert-memref-to-llvm pm.addPass(mlir::createLowerToLLVMPass()); // 
--convert-std-to-llvm="use-bare-ptr-memref-call-conv" pm.addPass(mlir::createConvertVectorToLLVMPass()); // --convert-vector-to-llvm @@ -437,7 +438,7 @@ bool VerifyTranslateToLLVMIR(mlir::OwningOpRef& module, mlir::Fu funcPm.addPass(mlir::arith::createArithmeticExpandOpsPass()); // --arith-expand pm.addPass(mlir::createLowerAffinePass()); // --lower-affine - pm.addPass(mlir::createLowerToCFGPass()); // --convert-scf-to-std + pm.addPass(mlir::createConvertSCFToCFPass()); // --convert-scf-to-cf pm.addPass(mlir::createMemRefToLLVMPass()); // --convert-memref-to-llvm pm.addPass(mlir::createLowerToLLVMPass()); // --convert-std-to-llvm="use-bare-ptr-memref-call-conv" pm.addPass(mlir::createConvertVectorToLLVMPass()); // --convert-vector-to-llvm diff --git a/accera/ir/test/nest_dialect_test/LowLevelIRTests.cpp b/accera/ir/test/nest_dialect_test/LowLevelIRTests.cpp index c66ebc3a..8525c209 100644 --- a/accera/ir/test/nest_dialect_test/LowLevelIRTests.cpp +++ b/accera/ir/test/nest_dialect_test/LowLevelIRTests.cpp @@ -178,10 +178,10 @@ TEST_CASE("Int8Test1") auto bOdd = builder.create(loc, halfVecType, b, b, oddMask); // extend to 32 bits - auto aEvenExt = builder.create(loc, aEven, bigVecType); - auto bEvenExt = builder.create(loc, bEven, bigVecType); - auto aOddExt = builder.create(loc, aOdd, bigVecType); - auto bOddExt = builder.create(loc, bOdd, bigVecType); + auto aEvenExt = builder.create(loc, bigVecType, aEven); + auto bEvenExt = builder.create(loc, bigVecType, bEven); + auto aOddExt = builder.create(loc, bigVecType, aOdd); + auto bOddExt = builder.create(loc, bigVecType, bOdd); auto mulEven = builder.create(loc, aEvenExt, bEvenExt); auto mulOdd = builder.create(loc, aOddExt, bOddExt); @@ -189,14 +189,14 @@ TEST_CASE("Int8Test1") auto sum = builder.create(loc, mulEven, mulOdd); // Make sum be saturated - auto minI16Val = builder.create(loc, builder.create(loc, -32768, 32), bigVecType); - auto maxI16Val = builder.create(loc, builder.create(loc, 32767, 32), bigVecType); 
+ auto minI16Val = builder.create(loc, builder.create(loc, -32768, 32), bigVecType); + auto maxI16Val = builder.create(loc, builder.create(loc, 32767, 32), bigVecType); auto maxCmp = builder.create(loc, mlir::arith::CmpIPredicate::sgt, sum, minI16Val); - auto maxVal = builder.create(loc, maxCmp, sum, minI16Val); + auto maxVal = builder.create(loc, maxCmp, sum, minI16Val); auto minCmp = builder.create(loc, mlir::arith::CmpIPredicate::slt, maxVal, maxI16Val); - auto minVal = builder.create(loc, minCmp, maxVal, maxI16Val); - auto truncVal = builder.create(loc, minVal, cElemType); + auto minVal = builder.create(loc, minCmp, maxVal, maxI16Val); + auto truncVal = builder.create(loc, cElemType, minVal); return truncVal; }; @@ -279,10 +279,10 @@ TEST_CASE("Int8Test1b") auto bOdd = builder.create(loc, halfVecType, b, b, oddMask); // extend to 32 bits - auto aEvenExt = builder.create(loc, aEven, bigVecType); - auto bEvenExt = builder.create(loc, bEven, bigVecType); - auto aOddExt = builder.create(loc, aOdd, bigVecType); - auto bOddExt = builder.create(loc, bOdd, bigVecType); + auto aEvenExt = builder.create(loc, bigVecType, aEven); + auto bEvenExt = builder.create(loc, bigVecType, bEven); + auto aOddExt = builder.create(loc, bigVecType, aOdd); + auto bOddExt = builder.create(loc, bigVecType, bOdd); auto mulEven = builder.create(loc, aEvenExt, bEvenExt); auto mulOdd = builder.create(loc, aOddExt, bOddExt); @@ -290,14 +290,14 @@ TEST_CASE("Int8Test1b") auto sum = builder.create(loc, mulEven, mulOdd); // Make sum be saturated - auto minI16Val = builder.create(loc, builder.create(loc, -32768, 32), bigVecType); - auto maxI16Val = builder.create(loc, builder.create(loc, 32767, 32), bigVecType); + auto minI16Val = builder.create(loc, builder.create(loc, -32768, 32), bigVecType); + auto maxI16Val = builder.create(loc, builder.create(loc, 32767, 32), bigVecType); auto maxCmp = builder.create(loc, mlir::arith::CmpIPredicate::sgt, sum, minI16Val); - auto maxVal = builder.create(loc, 
maxCmp, sum, minI16Val); + auto maxVal = builder.create(loc, maxCmp, sum, minI16Val); auto minCmp = builder.create(loc, mlir::arith::CmpIPredicate::slt, maxVal, maxI16Val); - auto minVal = builder.create(loc, minCmp, maxVal, maxI16Val); - auto truncVal = builder.create(loc, minVal, cElemType); + auto minVal = builder.create(loc, minCmp, maxVal, maxI16Val); + auto truncVal = builder.create(loc, cElemType, minVal); return truncVal; }; @@ -378,10 +378,10 @@ TEST_CASE("Int8Test1c") auto bOdd = builder.create(loc, b, std::vector{ 1 }); // extend to 32 bits - auto aEvenExt = builder.create(loc, aEven, bigVecType); - auto bEvenExt = builder.create(loc, bEven, bigVecType); - auto aOddExt = builder.create(loc, aOdd, bigVecType); - auto bOddExt = builder.create(loc, bOdd, bigVecType); + auto aEvenExt = builder.create(loc, bigVecType, aEven); + auto bEvenExt = builder.create(loc, bigVecType, bEven); + auto aOddExt = builder.create(loc, bigVecType, aOdd); + auto bOddExt = builder.create(loc, bigVecType, bOdd); auto mulEven = builder.create(loc, aEvenExt, bEvenExt); auto mulOdd = builder.create(loc, aOddExt, bOddExt); @@ -389,14 +389,14 @@ TEST_CASE("Int8Test1c") auto sum = builder.create(loc, mulEven, mulOdd); // Make sum be saturated - auto minI16Val = builder.create(loc, builder.create(loc, -32768, 32), bigVecType); - auto maxI16Val = builder.create(loc, builder.create(loc, 32767, 32), bigVecType); + auto minI16Val = builder.create(loc, builder.create(loc, -32768, 32), bigVecType); + auto maxI16Val = builder.create(loc, builder.create(loc, 32767, 32), bigVecType); auto maxCmp = builder.create(loc, mlir::arith::CmpIPredicate::sgt, sum, minI16Val); - auto maxVal = builder.create(loc, maxCmp, sum, minI16Val); + auto maxVal = builder.create(loc, maxCmp, sum, minI16Val); auto minCmp = builder.create(loc, mlir::arith::CmpIPredicate::slt, maxVal, maxI16Val); - auto minVal = builder.create(loc, minCmp, maxVal, maxI16Val); - auto truncVal = builder.create(loc, minVal, cElemType); + 
auto minVal = builder.create(loc, minCmp, maxVal, maxI16Val); + auto truncVal = builder.create(loc, cElemType, minVal); return truncVal; }; @@ -497,10 +497,10 @@ TEST_CASE("Int8Test2") auto bOdd = builder.create(loc, halfVecType, b, b, oddMask); // extend to 32 bits - auto aEvenExt = builder.create(loc, aEven, bigVecType); - auto bEvenExt = builder.create(loc, bEven, bigVecType); - auto aOddExt = builder.create(loc, aOdd, bigVecType); - auto bOddExt = builder.create(loc, bOdd, bigVecType); + auto aEvenExt = builder.create(loc, bigVecType, aEven); + auto bEvenExt = builder.create(loc, bigVecType, bEven); + auto aOddExt = builder.create(loc, bigVecType, aOdd); + auto bOddExt = builder.create(loc, bigVecType, bOdd); auto mulEven = builder.create(loc, aEvenExt, bEvenExt); auto mulOdd = builder.create(loc, aOddExt, bOddExt); @@ -508,14 +508,14 @@ TEST_CASE("Int8Test2") auto sum = builder.create(loc, mulEven, mulOdd); // Make sum be saturated - auto minI16Val = builder.create(loc, builder.create(loc, -32768, 32), bigVecType); - auto maxI16Val = builder.create(loc, builder.create(loc, 32767, 32), bigVecType); + auto minI16Val = builder.create(loc, builder.create(loc, -32768, 32), bigVecType); + auto maxI16Val = builder.create(loc, builder.create(loc, 32767, 32), bigVecType); auto maxCmp = builder.create(loc, mlir::arith::CmpIPredicate::sgt, sum, minI16Val); - auto maxVal = builder.create(loc, maxCmp, sum, minI16Val); + auto maxVal = builder.create(loc, maxCmp, sum, minI16Val); auto minCmp = builder.create(loc, mlir::arith::CmpIPredicate::slt, maxVal, maxI16Val); - auto minVal = builder.create(loc, minCmp, maxVal, maxI16Val); - auto truncVal = builder.create(loc, minVal, truncVecType); + auto minVal = builder.create(loc, minCmp, maxVal, maxI16Val); + auto truncVal = builder.create(loc, truncVecType, minVal); return truncVal; }; @@ -529,10 +529,10 @@ TEST_CASE("Int8Test2") auto r2Even = builder.create(loc, halfResultVecType, r2, r2, evenMask); auto r2Odd = 
builder.create(loc, halfResultVecType, r2, r2, oddMask); - auto r1Sum = builder.create(loc, builder.create(loc, r1Even, resultVecType), builder.create(loc, r1Odd, resultVecType)); - auto r2Sum = builder.create(loc, builder.create(loc, r2Even, resultVecType), builder.create(loc, r2Odd, resultVecType)); - // auto r1Sum = builder.create(loc, builder.create(loc, r1Even, resultVecType), builder.create(loc, r1Odd, resultVecType)); - // auto r2Sum = builder.create(loc, builder.create(loc, r2Even, resultVecType), builder.create(loc, r2Odd, resultVecType)); + auto r1Sum = builder.create(loc, builder.create(loc, resultVecType, r1Even), builder.create(loc, resultVecType, r1Odd)); + auto r2Sum = builder.create(loc, builder.create(loc, resultVecType, r2Even), builder.create(loc, resultVecType, r2Odd)); + // auto r1Sum = builder.create(loc, builder.create(loc, resultVecType, r1Even), builder.create(loc, resultVecType, r1Odd)); + // auto r2Sum = builder.create(loc, builder.create(loc, resultVecType, r2Even), builder.create(loc, resultVecType, r2Odd)); auto finalSum = builder.create(loc, r1Sum, r2Sum); builder.create(loc, finalSum, C); }); @@ -616,10 +616,10 @@ TEST_CASE("Int8Test2b") auto bOdd = builder.create(loc, halfVecType, b, b, oddMask); // extend to 32 bits - auto aEvenExt = builder.create(loc, aEven, bigVecType); - auto bEvenExt = builder.create(loc, bEven, bigVecType); - auto aOddExt = builder.create(loc, aOdd, bigVecType); - auto bOddExt = builder.create(loc, bOdd, bigVecType); + auto aEvenExt = builder.create(loc, bigVecType, aEven); + auto bEvenExt = builder.create(loc, bigVecType, bEven); + auto aOddExt = builder.create(loc, bigVecType, aOdd); + auto bOddExt = builder.create(loc, bigVecType, bOdd); auto mulEven = builder.create(loc, aEvenExt, bEvenExt); auto mulOdd = builder.create(loc, aOddExt, bOddExt); @@ -627,14 +627,14 @@ TEST_CASE("Int8Test2b") auto sum = builder.create(loc, mulEven, mulOdd); // Make sum be saturated - auto minI16Val = builder.create(loc, 
builder.create(loc, -32768, 32), bigVecType); - auto maxI16Val = builder.create(loc, builder.create(loc, 32767, 32), bigVecType); + auto minI16Val = builder.create(loc, builder.create(loc, -32768, 32), bigVecType); + auto maxI16Val = builder.create(loc, builder.create(loc, 32767, 32), bigVecType); auto maxCmp = builder.create(loc, mlir::arith::CmpIPredicate::sgt, sum, minI16Val); - auto maxVal = builder.create(loc, maxCmp, sum, minI16Val); + auto maxVal = builder.create(loc, maxCmp, sum, minI16Val); auto minCmp = builder.create(loc, mlir::arith::CmpIPredicate::slt, maxVal, maxI16Val); - auto minVal = builder.create(loc, minCmp, maxVal, maxI16Val); - auto truncVal = builder.create(loc, minVal, truncVecType); + auto minVal = builder.create(loc, minCmp, maxVal, maxI16Val); + auto truncVal = builder.create(loc, truncVecType, minVal); return truncVal; }; @@ -648,10 +648,10 @@ TEST_CASE("Int8Test2b") auto r2Even = builder.create(loc, halfResultVecType, r2, r2, evenMask); auto r2Odd = builder.create(loc, halfResultVecType, r2, r2, oddMask); - auto r1Sum = builder.create(loc, builder.create(loc, r1Even, resultVecType), builder.create(loc, r1Odd, resultVecType)); - auto r2Sum = builder.create(loc, builder.create(loc, r2Even, resultVecType), builder.create(loc, r2Odd, resultVecType)); - // auto r1Sum = builder.create(loc, builder.create(loc, r1Even, resultVecType), builder.create(loc, r1Odd, resultVecType)); - // auto r2Sum = builder.create(loc, builder.create(loc, r2Even, resultVecType), builder.create(loc, r2Odd, resultVecType)); + auto r1Sum = builder.create(loc, builder.create(loc, resultVecType, r1Even), builder.create(loc, resultVecType, r1Odd)); + auto r2Sum = builder.create(loc, builder.create(loc, resultVecType, r2Even), builder.create(loc, resultVecType, r2Odd)); + // auto r1Sum = builder.create(loc, builder.create(loc, resultVecType, r1Even), builder.create(loc, resultVecType, r1Odd)); + // auto r2Sum = builder.create(loc, builder.create(loc, resultVecType, 
r2Even), builder.create(loc, resultVecType, r2Odd)); auto finalSum = builder.create(loc, r1Sum, r2Sum); builder.create(loc, finalSum, C); }); @@ -759,10 +759,10 @@ TEST_CASE("Int8Test3") auto bOdd = builder.create(loc, halfVecType, b, b, oddMask); // extend to 32 bits - auto aEvenExt = builder.create(loc, aEven, bigVecType); - auto bEvenExt = builder.create(loc, bEven, bigVecType); - auto aOddExt = builder.create(loc, aOdd, bigVecType); - auto bOddExt = builder.create(loc, bOdd, bigVecType); + auto aEvenExt = builder.create(loc, bigVecType, aEven); + auto bEvenExt = builder.create(loc, bigVecType, bEven); + auto aOddExt = builder.create(loc, bigVecType, aOdd); + auto bOddExt = builder.create(loc, bigVecType, bOdd); auto mulEven = builder.create(loc, aEvenExt, bEvenExt); auto mulOdd = builder.create(loc, aOddExt, bOddExt); @@ -770,14 +770,14 @@ TEST_CASE("Int8Test3") auto sum = builder.create(loc, mulEven, mulOdd); // Make sum be saturated - auto minI16Val = builder.create(loc, builder.create(loc, -32768, 32), bigVecType); - auto maxI16Val = builder.create(loc, builder.create(loc, 32767, 32), bigVecType); + auto minI16Val = builder.create(loc, builder.create(loc, -32768, 32), bigVecType); + auto maxI16Val = builder.create(loc, builder.create(loc, 32767, 32), bigVecType); auto maxCmp = builder.create(loc, mlir::arith::CmpIPredicate::sgt, sum, minI16Val); - auto maxVal = builder.create(loc, maxCmp, sum, minI16Val); + auto maxVal = builder.create(loc, maxCmp, sum, minI16Val); auto minCmp = builder.create(loc, mlir::arith::CmpIPredicate::slt, maxVal, maxI16Val); - auto minVal = builder.create(loc, minCmp, maxVal, maxI16Val); - auto truncVal = builder.create(loc, minVal, cElemType); + auto minVal = builder.create(loc, minCmp, maxVal, maxI16Val); + auto truncVal = builder.create(loc, cElemType, minVal); auto c = CC.Get(builder, i); auto cPlus = builder.create(loc, c, truncVal); @@ -901,29 +901,29 @@ TEST_CASE("Int8Test3b") auto bOdd = builder.create(loc, halfVecType, b, 
b, oddMask); // extend to 16 bits - auto aEvenExt = builder.create(loc, aEven, medVecType); - auto bEvenExt = builder.create(loc, bEven, medVecType); - auto aOddExt = builder.create(loc, aOdd, medVecType); - auto bOddExt = builder.create(loc, bOdd, medVecType); + auto aEvenExt = builder.create(loc, medVecType, aEven); + auto bEvenExt = builder.create(loc, medVecType, bEven); + auto aOddExt = builder.create(loc, medVecType, aOdd); + auto bOddExt = builder.create(loc, medVecType, bOdd); auto mulEven = builder.create(loc, aEvenExt, bEvenExt); auto mulOdd = builder.create(loc, aOddExt, bOddExt); // extend to 32 bits - auto mulEvenExt = builder.create(loc, mulEven, bigVecType); - auto mulOddExt = builder.create(loc, mulOdd, bigVecType); + auto mulEvenExt = builder.create(loc, bigVecType, mulEven); + auto mulOddExt = builder.create(loc, bigVecType, mulOdd); auto sum = builder.create(loc, mulEvenExt, mulOddExt); // Make sum be saturated - auto minI16Val = builder.create(loc, builder.create(loc, -32768, 32), bigVecType); - auto maxI16Val = builder.create(loc, builder.create(loc, 32767, 32), bigVecType); + auto minI16Val = builder.create(loc, builder.create(loc, -32768, 32), bigVecType); + auto maxI16Val = builder.create(loc, builder.create(loc, 32767, 32), bigVecType); auto maxCmp = builder.create(loc, mlir::arith::CmpIPredicate::sgt, sum, minI16Val); - auto maxVal = builder.create(loc, maxCmp, sum, minI16Val); + auto maxVal = builder.create(loc, maxCmp, sum, minI16Val); auto minCmp = builder.create(loc, mlir::arith::CmpIPredicate::slt, maxVal, maxI16Val); - auto minVal = builder.create(loc, minCmp, maxVal, maxI16Val); - auto truncVal = builder.create(loc, minVal, cElemType); + auto minVal = builder.create(loc, minCmp, maxVal, maxI16Val); + auto truncVal = builder.create(loc, cElemType, minVal); auto c = CC.Get(builder, i); auto cPlus = builder.create(loc, c, truncVal); @@ -1090,10 +1090,10 @@ TEST_CASE("Int8Test3c") auto bOdd = builder.create(loc, halfVecType, b, b, 
oddBMask); // extend to 32 bits - auto aEvenExt = builder.create(loc, aEven, bigVecType); - auto bEvenExt = builder.create(loc, bEven, bigVecType); - auto aOddExt = builder.create(loc, aOdd, bigVecType); - auto bOddExt = builder.create(loc, bOdd, bigVecType); + auto aEvenExt = builder.create(loc, bigVecType, aEven); + auto bEvenExt = builder.create(loc, bigVecType, bEven); + auto aOddExt = builder.create(loc, bigVecType, aOdd); + auto bOddExt = builder.create(loc, bigVecType, bOdd); auto mulEven = builder.create(loc, aEvenExt, bEvenExt); auto mulOdd = builder.create(loc, aOddExt, bOddExt); @@ -1101,16 +1101,16 @@ TEST_CASE("Int8Test3c") auto sum = builder.create(loc, mulEven, mulOdd); // Make sum be saturated - auto minI16Val = builder.create(loc, builder.create(loc, -32768, 32), bigVecType); - auto maxI16Val = builder.create(loc, builder.create(loc, 32767, 32), bigVecType); + auto minI16Val = builder.create(loc, builder.create(loc, -32768, 32), bigVecType); + auto maxI16Val = builder.create(loc, builder.create(loc, 32767, 32), bigVecType); auto maxCmp = builder.create(loc, mlir::arith::CmpIPredicate::sgt, sum, minI16Val); - auto maxVal = builder.create(loc, maxCmp, sum, minI16Val); + auto maxVal = builder.create(loc, maxCmp, sum, minI16Val); auto minCmp = builder.create(loc, mlir::arith::CmpIPredicate::slt, maxVal, maxI16Val); - auto minVal = builder.create(loc, minCmp, maxVal, maxI16Val); - auto truncVal = builder.create(loc, minVal, midVecType); + auto minVal = builder.create(loc, minCmp, maxVal, maxI16Val); + auto truncVal = builder.create(loc, midVecType, minVal); - auto embiggenVal = builder.create(loc, truncVal, bigVecType); + auto embiggenVal = builder.create(loc, bigVecType, truncVal); auto c = CC.Get(builder, i, j); auto cPlus = builder.create(loc, c, embiggenVal); @@ -1241,10 +1241,10 @@ TEST_CASE("Int8Test4") auto bOdd = BB.Get(builder, iOdd); // extend to 32 bits - auto aEvenExt = builder.create(loc, aEven, i32Type); - auto bEvenExt = 
builder.create(loc, bEven, i32Type); - auto aOddExt = builder.create(loc, aOdd, i32Type); - auto bOddExt = builder.create(loc, bOdd, i32Type); + auto aEvenExt = builder.create(loc, i32Type, aEven); + auto bEvenExt = builder.create(loc, i32Type, bEven); + auto aOddExt = builder.create(loc, i32Type, aOdd); + auto bOddExt = builder.create(loc, i32Type, bOdd); auto mulEven = builder.create(loc, aEvenExt, bEvenExt); auto mulOdd = builder.create(loc, aOddExt, bOddExt); @@ -1255,10 +1255,10 @@ TEST_CASE("Int8Test4") auto maxI16Val = builder.create(loc, 32767, 32); auto maxCmp = builder.create(loc, mlir::arith::CmpIPredicate::sgt, sum, minI16Val); - auto maxVal = builder.create(loc, maxCmp, sum, minI16Val); + auto maxVal = builder.create(loc, maxCmp, sum, minI16Val); auto minCmp = builder.create(loc, mlir::arith::CmpIPredicate::slt, maxVal, maxI16Val); - auto minVal = builder.create(loc, minCmp, maxVal, maxI16Val); - auto truncVal = builder.create(loc, minVal, cElemType); + auto minVal = builder.create(loc, minCmp, maxVal, maxI16Val); + auto truncVal = builder.create(loc, cElemType, minVal); auto c = CC.Get(builder, i); auto cPlus = builder.create(loc, c, truncVal); @@ -1391,8 +1391,8 @@ TEST_CASE("Int8Test5") auto b = BB.Get(builder, j, i); // extend to 32 bits - auto aExt = builder.create(loc, a, i32Type); - auto bExt = builder.create(loc, b, i32Type); + auto aExt = builder.create(loc, i32Type, a); + auto bExt = builder.create(loc, i32Type, b); auto prod = builder.create(loc, aExt, bExt); auto accumVal = builder.create(loc, accum); @@ -1406,13 +1406,13 @@ TEST_CASE("Int8Test5") // Make sum be saturated auto minI16Val = builder.create(loc, -32768, 32); auto maxCmp = builder.create(loc, mlir::arith::CmpIPredicate::sgt, sum, minI16Val); - auto maxVal = builder.create(loc, maxCmp, sum, minI16Val); + auto maxVal = builder.create(loc, maxCmp, sum, minI16Val); auto maxI16Val = builder.create(loc, 32767, 32); auto minCmp = builder.create(loc, mlir::arith::CmpIPredicate::slt, 
maxVal, maxI16Val); - auto minVal = builder.create(loc, minCmp, maxVal, maxI16Val); + auto minVal = builder.create(loc, minCmp, maxVal, maxI16Val); - auto truncVal = builder.create(loc, minVal, cElemType); + auto truncVal = builder.create(loc, cElemType, minVal); auto c = CC.Get(builder, j); auto cPlus = builder.create(loc, c, truncVal); CC.Set(builder, cPlus, j); @@ -1555,8 +1555,8 @@ TEST_CASE("Int8Test5b") auto b = BB.Get(builder, j, i); // extend to 32 bits - auto aExt = builder.create(loc, a, i32Type); - auto bExt = builder.create(loc, b, i32Type); + auto aExt = builder.create(loc, i32Type, a); + auto bExt = builder.create(loc, i32Type, b); auto mul = builder.create(loc, aExt, bExt); auto accumVal = builder.create(loc, accum); @@ -1571,10 +1571,10 @@ TEST_CASE("Int8Test5b") auto maxI16Val = builder.create(loc, 32767, 32); auto maxCmp = builder.create(loc, mlir::arith::CmpIPredicate::sgt, sum, minI16Val); - auto maxVal = builder.create(loc, maxCmp, sum, minI16Val); + auto maxVal = builder.create(loc, maxCmp, sum, minI16Val); auto minCmp = builder.create(loc, mlir::arith::CmpIPredicate::slt, maxVal, maxI16Val); - auto minVal = builder.create(loc, minCmp, maxVal, maxI16Val); - auto truncVal = builder.create(loc, minVal, cElemType); + auto minVal = builder.create(loc, minCmp, maxVal, maxI16Val); + auto truncVal = builder.create(loc, cElemType, minVal); auto c = CC.Get(builder, j); auto cPlus = builder.create(loc, c, truncVal); @@ -1724,8 +1724,8 @@ TEST_CASE("Int8Test6") auto b = BB.Get(builder, jInner, iInner); // extend to 32 bits - auto aExt = builder.create(loc, a, i32Type); - auto bExt = builder.create(loc, b, i32Type); + auto aExt = builder.create(loc, i32Type, a); + auto bExt = builder.create(loc, i32Type, b); auto mul = builder.create(loc, aExt, bExt); auto accumVal = builder.create(loc, accum); @@ -1740,10 +1740,10 @@ TEST_CASE("Int8Test6") auto maxI16Val = builder.create(loc, 32767, 32); auto maxCmp = builder.create(loc, mlir::arith::CmpIPredicate::sgt, 
sum, minI16Val); - auto maxVal = builder.create(loc, maxCmp, sum, minI16Val); + auto maxVal = builder.create(loc, maxCmp, sum, minI16Val); auto minCmp = builder.create(loc, mlir::arith::CmpIPredicate::slt, maxVal, maxI16Val); - auto minVal = builder.create(loc, minCmp, maxVal, maxI16Val); - auto truncVal = builder.create(loc, minVal, cElemType); + auto minVal = builder.create(loc, minCmp, maxVal, maxI16Val); + auto truncVal = builder.create(loc, cElemType, minVal); auto c = CC.Get(builder, jInner); auto cPlus = builder.create(loc, c, truncVal); @@ -1982,8 +1982,8 @@ TEST_CASE("Int8Test8") auto b = BB.Get(builder, jInner, kInner); // extend to 32 bits - auto aExt = builder.create(loc, a, i32Type); - auto bExt = builder.create(loc, b, i32Type); + auto aExt = builder.create(loc, i32Type, a); + auto bExt = builder.create(loc, i32Type, b); auto mul = builder.create(loc, aExt, bExt); auto accumVal = builder.create(loc, accum); @@ -1998,19 +1998,19 @@ TEST_CASE("Int8Test8") auto maxI16Val = builder.create(loc, 32767, 32); auto maxCmp = builder.create(loc, mlir::arith::CmpIPredicate::sgt, sum, minI16Val); - auto maxVal = builder.create(loc, maxCmp, sum, minI16Val); + auto maxVal = builder.create(loc, maxCmp, sum, minI16Val); auto minCmp = builder.create(loc, mlir::arith::CmpIPredicate::slt, maxVal, maxI16Val); - auto minVal = builder.create(loc, minCmp, maxVal, maxI16Val); + auto minVal = builder.create(loc, minCmp, maxVal, maxI16Val); if (cccElemType == cElemType) { - auto truncVal = builder.create(loc, minVal, i16Type); - auto expandVal = builder.create(loc, truncVal, cccElemType); + auto truncVal = builder.create(loc, i16Type, minVal); + auto expandVal = builder.create(loc, cccElemType, truncVal); CCC.Set(builder, expandVal, jInner, kInner1Count); } else { - auto truncVal = builder.create(loc, minVal, cccElemType); + auto truncVal = builder.create(loc, cccElemType, minVal); auto c = CCC.Get(builder, jInner, kInner1Count); auto cPlus = builder.create(loc, c, truncVal); 
CCC.Set(builder, cPlus, jInner, kInner1Count); @@ -2031,7 +2031,7 @@ TEST_CASE("Int8Test8") auto c = CCC.Get(builder, jInner, kInner); if (cccElemType != ccElemType) { - c = builder.create(loc, c, ccElemType); + c = builder.create(loc, ccElemType, c); } auto c2 = CC.Get(builder, iInner, jInner); @@ -2063,7 +2063,7 @@ TEST_CASE("Int8Test8") } else { - auto ccVal = builder.create(loc, CC.Get(builder, iInner, jInner), i32Type); + auto ccVal = builder.create(loc, i32Type, CC.Get(builder, iInner, jInner)); C.Set(builder, ccVal, i, j); } } @@ -2148,10 +2148,10 @@ TEST_CASE("Int16Test1") auto bOdd = builder.create(loc, halfVecType, b, b, oddMask); // extend to 32 bits - auto aEvenExt = builder.create(loc, aEven, bigVecType); - auto bEvenExt = builder.create(loc, bEven, bigVecType); - auto aOddExt = builder.create(loc, aOdd, bigVecType); - auto bOddExt = builder.create(loc, bOdd, bigVecType); + auto aEvenExt = builder.create(loc, bigVecType, aEven); + auto bEvenExt = builder.create(loc, bigVecType, bEven); + auto aOddExt = builder.create(loc, bigVecType, aOdd); + auto bOddExt = builder.create(loc, bigVecType, bOdd); auto mulEven = builder.create(loc, aEvenExt, bEvenExt); auto mulOdd = builder.create(loc, aOddExt, bOddExt); @@ -2235,10 +2235,10 @@ TEST_CASE("Int16Test1b") auto bOdd = builder.create(loc, halfVecType, b, b, oddMask); // extend to 32 bits - auto aEvenExt = builder.create(loc, aEven, bigVecType); - auto bEvenExt = builder.create(loc, bEven, bigVecType); - auto aOddExt = builder.create(loc, aOdd, bigVecType); - auto bOddExt = builder.create(loc, bOdd, bigVecType); + auto aEvenExt = builder.create(loc, bigVecType, aEven); + auto bEvenExt = builder.create(loc, bigVecType, bEven); + auto aOddExt = builder.create(loc, bigVecType, aOdd); + auto bOddExt = builder.create(loc, bigVecType, bOdd); auto mulEven = builder.create(loc, aEvenExt, bEvenExt); auto mulOdd = builder.create(loc, aOddExt, bOddExt); @@ -2328,10 +2328,10 @@ TEST_CASE("Int16Test2") auto bOdd = 
builder.create(loc, halfVecType, b, b, oddMask); // extend to 32 bits - auto aEvenExt = builder.create(loc, aEven, bigVecType); - auto bEvenExt = builder.create(loc, bEven, bigVecType); - auto aOddExt = builder.create(loc, aOdd, bigVecType); - auto bOddExt = builder.create(loc, bOdd, bigVecType); + auto aEvenExt = builder.create(loc, bigVecType, aEven); + auto bEvenExt = builder.create(loc, bigVecType, bEven); + auto aOddExt = builder.create(loc, bigVecType, aOdd); + auto bOddExt = builder.create(loc, bigVecType, bOdd); auto mulEven = builder.create(loc, aEvenExt, bEvenExt); auto mulOdd = builder.create(loc, aOddExt, bOddExt); diff --git a/accera/ir/test/nest_dialect_test/NestIRTests.cpp b/accera/ir/test/nest_dialect_test/NestIRTests.cpp index 16d7bfa4..a6a696fb 100644 --- a/accera/ir/test/nest_dialect_test/NestIRTests.cpp +++ b/accera/ir/test/nest_dialect_test/NestIRTests.cpp @@ -202,9 +202,9 @@ TEST_CASE("UnrankedMemRefTest") mlir::Value val6 = Alloca(builder, memrefType6, Value{ builder.create(builder.getUnknownLoc(), 10) }); // auto cast_3_4 = builder.create(loc, val3, val4.getType()); // cast <10xi32> -> <2x5xi32> -- illegal - [[maybe_unused]] auto cast_4_5 = builder.create(loc, val3, val5.getType()); // cast <10xi32> -> - [[maybe_unused]] auto cast_4_6 = builder.create(loc, val4, val6.getType()); // cast <2x5xi32> -> <2x?xi32> - [[maybe_unused]] auto cast_4_7 = builder.create(loc, val4, memrefType7); // cast <2x5xi32> -> <*xi32> + [[maybe_unused]] auto cast_4_5 = builder.create(loc, val5.getType(), val3); // cast <10xi32> -> + [[maybe_unused]] auto cast_4_6 = builder.create(loc, val6.getType(), val4); // cast <2x5xi32> -> <2x?xi32> + [[maybe_unused]] auto cast_4_7 = builder.create(loc, memrefType7, val4); // cast <2x5xi32> -> <*xi32> }); SECTION("Parsing") diff --git a/accera/mlirHelpers/CMakeLists.txt b/accera/mlirHelpers/CMakeLists.txt index c3f844a9..005c105d 100644 --- a/accera/mlirHelpers/CMakeLists.txt +++ b/accera/mlirHelpers/CMakeLists.txt @@ -79,9 
+79,11 @@ target_link_libraries( ${conversion_libs} ir MLIRStandardToLLVM - MLIRSCFToStandard + MLIRSCFToControlFlow + MLIRControlFlowToLLVM MLIRAffineToStandard MLIRAffineTransforms + MLIRAffineUtils MLIRExecutionEngine MLIRLinalgToLLVM MLIRLinalgTransforms diff --git a/accera/mlirHelpers/src/ConvertToLLVM.cpp b/accera/mlirHelpers/src/ConvertToLLVM.cpp index 24c29e38..2ba9b218 100644 --- a/accera/mlirHelpers/src/ConvertToLLVM.cpp +++ b/accera/mlirHelpers/src/ConvertToLLVM.cpp @@ -9,9 +9,10 @@ #include #include -#include -#include +#include #include +#include +#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" #include #include #include @@ -49,8 +50,8 @@ namespace ir // affine -> loops funcOpPM.addPass(mlir::createLowerAffinePass()); - // loops -> std - pm.addPass(mlir::createLowerToCFGPass()); + // loops -> cf + pm.addPass(mlir::createConvertSCFToCFPass()); // add custom LLVM passes addLLVMPassesFn(pm); @@ -58,6 +59,9 @@ namespace ir // linalg -> llvm pm.addPass(mlir::createConvertLinalgToLLVMPass()); + // cf -> llvm + pm.addPass(mlir::cf::createConvertControlFlowToLLVMPass()); + // another canonicalization pass pm.addPass(mlir::createCanonicalizerPass()); diff --git a/accera/python/accera/Package.py b/accera/python/accera/Package.py index 461eadf0..eb401b35 100644 --- a/accera/python/accera/Package.py +++ b/accera/python/accera/Package.py @@ -48,9 +48,7 @@ def _(arg: lang.Array): @singledispatch def _resolve_array_shape(source, arr: lang.Array): is_infinite_value = ( - arr.shape[-1].get_value() == inf - if isinstance(arr.shape[-1], DelayedParameter) - else arr.shape[-1] == inf + arr.shape[-1].get_value() == inf if isinstance(arr.shape[-1], DelayedParameter) else arr.shape[-1] == inf ) if is_infinite_value: # TODO: support shape inference for lang.Function, Callable if needed @@ -63,9 +61,7 @@ def _(source, arr: lang.Array): from .lang.IntrospectionUtilities import get_array_access_indices is_infinite_value = ( - arr.shape[-1].get_value() == inf - if 
isinstance(arr.shape[-1], DelayedParameter) - else arr.shape[-1] == inf + arr.shape[-1].get_value() == inf if isinstance(arr.shape[-1], DelayedParameter) else arr.shape[-1] == inf ) if is_infinite_value: # introspect array access index to determine dimensions of the array @@ -73,9 +69,7 @@ def _(source, arr: lang.Array): # TODO: support multiple logic fns if needed assert len(logic_fns) == 1, "Only one logic function is supported" access_indices = get_array_access_indices(arr, logic_fns[0]) - assert len(access_indices) == len( - arr.shape - ), "Access indices and shape must have the same dimensions" + assert len(access_indices) == len(arr.shape), "Access indices and shape must have the same dimensions" idx = source.get_indices().index(access_indices[-1]) # initialize the array with the new shape @@ -100,9 +94,7 @@ def _emit_module(module_to_emit, target, mode, output_dir, name): working_dir = os.path.join(output_dir, "_tmp") proj = accc.AcceraProject(output_dir=working_dir, library_name=name) - proj.module_file_sets = [ - accc.ModuleFileSet(name=name, common_module_dir=working_dir) - ] + proj.module_file_sets = [accc.ModuleFileSet(name=name, common_module_dir=working_dir)] module_to_emit.Save(proj.module_file_sets[0].generated_mlir_filepath) proj.generate_and_emit( @@ -117,9 +109,7 @@ def _emit_module(module_to_emit, target, mode, output_dir, name): # Complete the HAT file with information we have stored at this layer hat_file = hat.HATFile.Deserialize(header_path) - hat_file.dependencies.link_target = os.path.basename( - proj.module_file_sets[0].object_filepath - ) + hat_file.dependencies.link_target = os.path.basename(proj.module_file_sets[0].object_filepath) hat_file.Serialize(header_path) # copy HAT package files into output directory @@ -128,6 +118,7 @@ def _emit_module(module_to_emit, target, mode, output_dir, name): class SetActiveModule: + def __init__(self, module): self.module = module @@ -150,28 +141,20 @@ class Format(Flag): MLIR = auto() MLIR_VERBOSE = 
auto() SOURCE = auto() - DEFAULT = auto() # HAT_DYNAMIC on HOST target, HAT_STATIC otherwise - HAT_DYNAMIC = ( - HAT_PACKAGE | DYNAMIC_LIBRARY - ) #: HAT package format, dynamically linked. - HAT_STATIC = ( - HAT_PACKAGE | STATIC_LIBRARY - ) #: HAT package format, statically linked + DEFAULT = auto() # HAT_DYNAMIC on HOST target, HAT_STATIC otherwise + HAT_DYNAMIC = (HAT_PACKAGE | DYNAMIC_LIBRARY) #: HAT package format, dynamically linked. + HAT_STATIC = (HAT_PACKAGE | STATIC_LIBRARY) #: HAT package format, statically linked HAT_SOURCE = HAT_PACKAGE | SOURCE - MLIR_DYNAMIC = ( - HAT_DYNAMIC | MLIR - ) #: MLIR (debugging) package format, dynamically linked. - MLIR_STATIC = ( - HAT_STATIC | MLIR - ) #: MLIR (debugging) package format, statically linked. + MLIR_DYNAMIC = (HAT_DYNAMIC | MLIR) #: MLIR (debugging) package format, dynamically linked. + MLIR_STATIC = (HAT_STATIC | MLIR) #: MLIR (debugging) package format, statically linked. MLIR_SOURCE = HAT_SOURCE | MLIR class Mode(Enum): - RELEASE = "Release" #: Release (maximally optimized). - DEBUG = "Debug" #: Debug mode (automatically tests logical equivalence). + RELEASE = "Release" #: Release (maximally optimized). + DEBUG = "Debug" #: Debug mode (automatically tests logical equivalence). 
class _Options(Flag): - NONE = auto() # (enable auto unroll | low precision fp ops) + NONE = auto() # (enable auto unroll | low precision fp ops) DISABLE_AUTO_UNROLL = auto() HIGH_PRECISION_FLOATING_POINT_OPS = auto() @@ -185,21 +168,15 @@ def __init__(self): self._description = {} self._dynamic_dependencies = set() - def _create_gpu_utility_module( - self, compiler_options, target, mode, output_dir, name="AcceraGPUUtilities" - ): + def _create_gpu_utility_module(self, compiler_options, target, mode, output_dir, name="AcceraGPUUtilities"): gpu_utility_module = _lang_python._Module(name=name, options=compiler_options) with SetActiveModule(gpu_utility_module): gpu_init_fn = _lang_python._DeclareFunction("AcceraGPUInitialize") gpu_deinit_fn = _lang_python._DeclareFunction("AcceraGPUDeInitialize") - gpu_init_fn.public(True).decorated(False).headerDecl(True).rawPointerAPI( - True - ).addTag("rc_gpu_init") - gpu_deinit_fn.public(True).decorated(False).headerDecl(True).rawPointerAPI( - True - ).addTag("rc_gpu_deinit") + gpu_init_fn.public(True).decorated(False).headerDecl(True).rawPointerAPI(True).addTag("rc_gpu_init") + gpu_deinit_fn.public(True).decorated(False).headerDecl(True).rawPointerAPI(True).addTag("rc_gpu_deinit") # No common initialization / de-initialization at this layer, however lowering passes may add steps def empty_func(args): @@ -211,8 +188,7 @@ def empty_func(args): return _emit_module(gpu_utility_module, target, mode, output_dir, name) def _create_mapping_of_heuristic_parameters_with_possible_values( - self, - source: Union["accera.Nest", "accera.Schedule", "accera.Plan", "accera.Function", Callable] + self, source: Union["accera.Nest", "accera.Schedule", "accera.Plan", "accera.Function", Callable] ): parameter_dict = {} heuristic_parameters = source._get_heuristic_parameters() @@ -226,9 +202,7 @@ def _create_mapping_of_heuristic_parameters_with_possible_values( def add( self, - source: Union[ - "accera.Nest", "accera.Schedule", "accera.Plan", 
"accera.Function", Callable - ], + source: Union["accera.Nest", "accera.Schedule", "accera.Plan", "accera.Function", Callable], args: List[Union["accera.Dimension", "accera.Array"]] = None, base_name: str = "", parameters: Union[dict, List[dict]] = {}, @@ -255,8 +229,11 @@ def add( # Note: this does not prevent TEMP arrays from being passed as an argument to a function, but they cannot be the # api-defining arguments for the function for idx, arg in enumerate(args): - if isinstance(arg, (lang.Array, _lang_python._lang.Scalar, _lang_python._lang.Dimension)) and arg.role == _lang_python.Role.TEMP: - raise ValueError(f"Error in package.add() for function {base_name}: args includes TEMP array at positions {idx}") + if isinstance(arg, (lang.Array, _lang_python._lang.Scalar, + _lang_python._lang.Dimension)) and arg.role == _lang_python.Role.TEMP: + raise ValueError( + f"Error in package.add() for function {base_name}: args includes TEMP array at positions {idx}" + ) heuristic_parameters_dict = {} if isinstance(source, lang.Plan): @@ -267,7 +244,7 @@ def add( product_parameter_grid = [] # Create a list of delayed parameter for each possible value separately using `get_parameters_from_grid` - if heuristic_parameters_dict: + if heuristic_parameters_dict: product_parameter_grid = get_parameters_from_grid(heuristic_parameters_dict) # TODO: Add functions for product parameter grid in next PR instead of adding fns separately for @@ -275,17 +252,15 @@ def add( if parameters and not isinstance(parameters, dict): return [self._add_function(source, args, base_name, p, function_opts, auxiliary) for p in parameters] elif product_parameter_grid and not isinstance(product_parameter_grid, dict): - return [self._add_function(source, args, base_name, p, function_opts, auxiliary) for p in product_parameter_grid] + return [ + self._add_function(source, args, base_name, p, function_opts, auxiliary) for p in product_parameter_grid + ] else: - return self._add_function( - source, args, 
base_name, parameters, function_opts, auxiliary - ) + return self._add_function(source, args, base_name, parameters, function_opts, auxiliary) def _add_function( self, - source: Union[ - "accera.Nest", "accera.Schedule", "accera.Plan", "accera.Function", Callable - ], + source: Union["accera.Nest", "accera.Schedule", "accera.Plan", "accera.Function", Callable], args: List[Union["accera.Dimension", "accera.Array"]] = None, base_name: str = "", parameters: dict = {}, @@ -315,14 +290,14 @@ def _add_function( else: if isinstance(value, tuple) or isinstance(value, list): if all(isinstance(v, LoopIndex) for v in value): - param_value_dict[delayed_param._name] = str( - [x._name for x in value] - ) + param_value_dict[delayed_param._name] = str([x._name for x in value]) else: raise ValueError("Invalid value of parameters") else: param_value_dict[delayed_param._name] = str(value) - auxiliary_metadata["accera"] = {"parameters": param_value_dict} + auxiliary_metadata["accera"] = { + "parameters": param_value_dict + } def validate_target(target: Target): # can't use set because targets are mutable (therefore unhashable) @@ -347,20 +322,14 @@ def get_function_name(target: Target): base_name or token_hex(4), target, auxiliary_metadata["accera"], - ] - + [ - (a.role, a.element_type, a.shape, a.layout) - if isinstance(a, lang.Array) else (a.name, a.type) - if isinstance(a, _lang_python._lang.Dimension) else None - for a in args - ], + ] + [(a.role, a.element_type, a.shape, a.layout) if isinstance(a, lang.Array) else + (a.name, a.type) if isinstance(a, _lang_python._lang.Dimension) else None + for a in args], ) ) ).encode("utf-8") - ) - .digest() - .hex()[:16] - ) # truncate + ).digest().hex()[:16] + ) # truncate # Function names must begin with an _ or alphabetical character return f"{base_name}_{suffix}" if base_name else f"_{suffix}" @@ -381,13 +350,15 @@ def compute_arg_size_references(args, SENTINEL_VALUE=-1): arg_size_refs = [] for arg in args: if isinstance(arg, lang.Array): 
- arr_dim_mappings = [args.index(dim) if isinstance(dim, _lang_python._lang.Dimension) else SENTINEL_VALUE for dim in arg.shape] + arr_dim_mappings = [ + args.index(dim) if isinstance(dim, _lang_python._lang.Dimension) else SENTINEL_VALUE + for dim in arg.shape + ] arg_size_refs.append(arr_dim_mappings) else: arg_size_refs.append([SENTINEL_VALUE]) return arg_size_refs - # Resolve any undefined argument shapes based on the source usage pattern for arr in args: if isinstance(arr, lang.Array): @@ -400,12 +371,13 @@ def compute_arg_size_references(args, SENTINEL_VALUE=-1): if isinstance(source, lang.Plan): self._dynamic_dependencies.update(source._dynamic_dependencies) - source = source._create_function( - args, **function_opts - ) + source = source._create_function(args, **function_opts) # fall-through - arg_names = [arg.name if isinstance(arg, lang.Array) or isinstance(arg, _lang_python._lang.Dimension) else "" for arg in args] + arg_names = [ + arg.name if isinstance(arg, lang.Array) or isinstance(arg, _lang_python._lang.Dimension) else "" + for arg in args + ] arg_sizes = [arg._size_str if isinstance(arg, lang.Array) else "" for arg in args] if isinstance(source, lang.Function): @@ -414,7 +386,7 @@ def compute_arg_size_references(args, SENTINEL_VALUE=-1): # due to the fall-through, we only need to validate here validate_target(source.target) - native_array_dim_args = [arg._get_native_array() if isinstance(arg, lang.Array) else arg for arg in args ] + native_array_dim_args = [arg._get_native_array() if isinstance(arg, lang.Array) else arg for arg in args] source.name = get_function_name(source.target) source.base_name = base_name @@ -426,7 +398,7 @@ def compute_arg_size_references(args, SENTINEL_VALUE=-1): source.arg_sizes = arg_sizes source.requested_args = args self._fns[source.name] = source - return source # for composability + return source # for composability elif isinstance(source, Callable): @@ -454,7 +426,7 @@ def wrapper_fn(args): ) self._fns[name] = 
wrapped_func - return wrapped_func # for composability + return wrapped_func # for composability else: raise ValueError("Invalid type for source") @@ -467,9 +439,7 @@ def _add_functions_to_module(self, module, fail_on_error=False): wrapped_func._emit() except Exception as e: to_pop.append(name) - logging.error( - f"Compiler error when trying to build function {name}" - ) + logging.error(f"Compiler error when trying to build function {name}") logging.error(e) if fail_on_error: raise @@ -492,11 +462,8 @@ def _add_debug_utilities(self, tolerance): # only add if there are actually arguments to debug return add_debugging_functions( self, - { - name: fn_and_args - for name, fn_and_args in fns_to_add.items() - if fn_and_args[1] - }, + {name: fn_and_args + for name, fn_and_args in fns_to_add.items() if fn_and_args[1]}, atol=tolerance, ) @@ -511,10 +478,10 @@ def _generate_target_options(self, platform: Platform, mode: Mode = Mode.RELEASE host_target_device = _lang_python._GetTargetDeviceFromName("host") if platform in [ - Package.Platform.HOST, - Package.Platform.LINUX, - Package.Platform.MACOS, - Package.Platform.WINDOWS, + Package.Platform.HOST, + Package.Platform.LINUX, + Package.Platform.MACOS, + Package.Platform.WINDOWS, ]: target_device = _lang_python._GetTargetDeviceFromName(platform.value) else: @@ -536,15 +503,19 @@ def _generate_target_options(self, platform: Platform, mode: Mode = Mode.RELEASE elif target.architecture == Target.Architecture.X86_64: target_device.architecture = "x86_64" - if "AVX512" in target.extensions: - target_device.device_name = "avx512" - target_device.cpu = "skylake-avx512" - # TODO: make this functionality less hidden - avx512_feat_str = ",".join( - [f"+{feature.lower()}" for feature in target.extensions] - ) + if not target_device.is_macOS(): # macOS does not support these instructions + if "AVX512" in target.extensions: + target_device.device_name = "avx512" + target_device.cpu = "skylake-avx512" + # TODO: make this functionality less 
hidden + avx512_feat_str = ",".join([f"+{feature.lower()}" for feature in target.extensions]) - target_device.features = avx512_feat_str + target_device.features = avx512_feat_str + + elif "AVX2" in target.extensions: + target_device.device_name = "avx2" + target_device.cpu = "skylake" + target_device.features = "+avx2" elif target.architecture == Target.Architecture.X86: target_device.architecture = "x86" @@ -554,21 +525,14 @@ def _generate_target_options(self, platform: Platform, mode: Mode = Mode.RELEASE compiler_options = _lang_python.CompilerOptions() compiler_options.target_device = target_device compiler_options.debug = mode == Package.Mode.DEBUG - compiler_options.gpu_only = ( - target.category == Target.Category.GPU and target.runtime != Runtime.VULKAN - ) + compiler_options.gpu_only = (target.category == Target.Category.GPU and target.runtime != Runtime.VULKAN) BuildConfig.obj_extension = ".obj" if target_device.is_windows() else ".o" - libs = list( - filter( - None, - [ - get_library_reference(dep, platform) - for dep in self._dynamic_dependencies - ], - ) - ) + libs = list(filter( + None, + [get_library_reference(dep, platform) for dep in self._dynamic_dependencies], + )) return target, target_device, compiler_options, libs def _make_accc_options(self, options: _Options): @@ -627,9 +591,9 @@ def build( format_is_default = bool( format & Package.Format.DEFAULT - ) # store it as a boolean because we're going to turn off the actual flag + ) # store it as a boolean because we're going to turn off the actual flag if format_is_default: - format &= ~Package.Format.DEFAULT # Turn off "DEFAULT" + format &= ~Package.Format.DEFAULT # Turn off "DEFAULT" if target.runtime in [Target.Runtime.CUDA, Target.Runtime.ROCM]: format |= Package.Format.HAT_SOURCE @@ -640,9 +604,7 @@ def build( dynamic_link = bool(format & Package.Format.DYNAMIC_LIBRARY) if cross_compile and dynamic_link: - raise ValueError( - "Package.Format.DYNAMIC_LIBRARY is not supported when 
cross-compiling" - ) + raise ValueError("Package.Format.DYNAMIC_LIBRARY is not supported when cross-compiling") output_dir = output_dir or os.getcwd() working_dir = os.path.join(output_dir, "_tmp") @@ -662,44 +624,22 @@ def build( # Emit the package module if format & Package.Format.SOURCE: - output_type = ( - accc.ModuleOutputType.CUDA - if compiler_options.gpu_only - else accc.ModuleOutputType.CPP - ) + output_type = (accc.ModuleOutputType.CUDA if compiler_options.gpu_only else accc.ModuleOutputType.CPP) else: output_type = accc.ModuleOutputType.OBJECT # Emit the supporting modules supporting_hats = [] - if ( - not compiler_options.gpu_only - and output_type == accc.ModuleOutputType.OBJECT - ): + if (not compiler_options.gpu_only and output_type == accc.ModuleOutputType.OBJECT): supporting_hats.append( - Package._emit_default_module( - compiler_options, target, mode, output_dir, f"{name}_Globals" - ) + Package._emit_default_module(compiler_options, target, mode, output_dir, f"{name}_Globals") ) - if any( - fn.target.category == Target.Category.GPU - and fn.target.runtime == Target.Runtime.VULKAN - for fn in self._fns.values() - ): - supporting_hats.append( - self._create_gpu_utility_module( - compiler_options, target, mode, output_dir - ) - ) + if any(fn.target.category == Target.Category.GPU and fn.target.runtime == Target.Runtime.VULKAN + for fn in self._fns.values()): + supporting_hats.append(self._create_gpu_utility_module(compiler_options, target, mode, output_dir)) - proj = accc.AcceraProject( - output_dir=working_dir, library_name=name, output_type=output_type - ) - proj.module_file_sets = [ - accc.ModuleFileSet( - name=name, common_module_dir=working_dir, output_type=output_type - ) - ] + proj = accc.AcceraProject(output_dir=working_dir, library_name=name, output_type=output_type) + proj.module_file_sets = [accc.ModuleFileSet(name=name, common_module_dir=working_dir, output_type=output_type)] 
package_module.Save(proj.module_file_sets[0].generated_mlir_filepath) # Enable dumping of IR passes based on build format @@ -737,40 +677,28 @@ def build( # Complete the HAT file with information we have stored at this layer hat_file: hat.HATFile = hat.HATFile.Deserialize(header_path) - if format & ( - Package.Format.DYNAMIC_LIBRARY | Package.Format.STATIC_LIBRARY - ): - hat_file.dependencies.link_target = os.path.basename( - proj.module_file_sets[0].object_filepath - ) + if format & (Package.Format.DYNAMIC_LIBRARY | Package.Format.STATIC_LIBRARY): + hat_file.dependencies.link_target = os.path.basename(proj.module_file_sets[0].object_filepath) supporting_hats = map(hat.HATFile.Deserialize, supporting_hats) supporting_objs = [] supporting_decls = [] for support in supporting_hats: path = os.path - dependency_path = path.abspath( - path.join(output_dir, support.dependencies.link_target) - ) + dependency_path = path.abspath(path.join(output_dir, support.dependencies.link_target)) # Collect the supporting modules as dependencies - supporting_objs.append( - hat.LibraryReference(target_file=dependency_path) - ) + supporting_objs.append(hat.LibraryReference(target_file=dependency_path)) # Collecting the supporting code decls supporting_decls.append(support.declaration.code) # Merge the function maps - hat_file._function_table.function_map.update( - support._function_table.function_map - ) + hat_file._function_table.function_map.update(support._function_table.function_map) decl_code = hat_file.declaration.code hat_file.dependencies.dynamic = dynamic_dependencies + supporting_objs - hat_file.declaration.code = decl_code._new( - "\n".join(map(str, ["", decl_code] + supporting_decls)) - ) + hat_file.declaration.code = decl_code._new("\n".join(map(str, ["", decl_code] + supporting_decls))) for fn_name in self._fns: fn: lang.Function = self._fns[fn_name] @@ -779,28 +707,19 @@ def build( hat_func = hat_file.function_map.get(fn_name) if hat_func is None: - raise ValueError( - 
f"Couldn't find header-declared function {fn_name} in emitted HAT file" - ) + raise ValueError(f"Couldn't find header-declared function {fn_name} in emitted HAT file") hat_func.auxiliary = fn.auxiliary - if ( - fn.target.category == Target.Category.GPU - and fn.target.runtime != Target.Runtime.VULKAN - ): + if (fn.target.category == Target.Category.GPU and fn.target.runtime != Target.Runtime.VULKAN): # TODO: Remove this when the header is emitted as part of the compilation gpu_source = proj.module_file_sets[0].translated_source_filepath gpu_device_func = fn_name + "__gpu__" with open(gpu_source) as gpu_source_f: - s = re.search( - gpu_device_func + _R_GPU_LAUNCH, gpu_source_f.read() - ) + s = re.search(gpu_device_func + _R_GPU_LAUNCH, gpu_source_f.read()) if not s: raise RuntimeError("Couldn't parse emitted source code") - launch_parameters = list( - map(int, [s[n] for n in range(1, 7)]) - ) + launch_parameters = list(map(int, [s[n] for n in range(1, 7)])) dynamic_shared_mem_bytes = int(s[7]) gpu_source = os.path.split(gpu_source)[1] @@ -847,11 +766,11 @@ def build( if not cross_compile and (format & Package.Format.STATIC_LIBRARY): lib_hat_path = f"{path_root}_lib{extension}" hat.create_static_package(header_path, lib_hat_path) - + lib_hat_file = hat_file.Deserialize(lib_hat_path) lib_hat_file.dependencies.auxiliary["static"] = lib_hat_file.dependencies.link_target lib_hat_file.Serialize() - + shutil.move(lib_hat_path, header_path) if dynamic_link: @@ -863,7 +782,7 @@ def build( dyn_hat_file.Serialize() shutil.move(dyn_hat_path, header_path) - + # TODO: plumb cross-compilation of static libs return proj.module_file_sets @@ -892,9 +811,7 @@ def add_description( self._description["auxiliary"].update(other) # remove any keys marked None - keys_to_remove = [ - k for k, v in self._description["auxiliary"].items() if v is None - ] + keys_to_remove = [k for k, v in self._description["auxiliary"].items() if v is None] for k in keys_to_remove: del 
self._description["auxiliary"][k] diff --git a/accera/python/accera/Targets.py b/accera/python/accera/Targets.py index 2df3b721..ce54551e 100644 --- a/accera/python/accera/Targets.py +++ b/accera/python/accera/Targets.py @@ -6,13 +6,15 @@ import copy import cpuinfo import re + from typing import List, Union from dataclasses import dataclass, field, fields from enum import Enum, auto from ._lang_python import ScalarType, _GetKnownDeviceNames from ._lang_python._lang import ( - BLOCK_X, BLOCK_Y, BLOCK_Z, THREAD_X, THREAD_Y, THREAD_Z, WARP_X, WARP_Y, _MemorySpace, MMAShape, _ExecutionRuntime as Runtime + BLOCK_X, BLOCK_Y, BLOCK_Z, THREAD_X, THREAD_Y, THREAD_Z, WARP_X, WARP_Y, _MemorySpace, MMAShape, _ExecutionRuntime + as Runtime ) @@ -768,14 +770,14 @@ def supports( def mma_shape_to_tuple(self, mma_shape: MMAShape): return { - MMAShape.M64xN64xK1_B4 : (64, 64, 1), - MMAShape.M64xN64xK1_B2 : (64, 64, 1), - MMAShape.M32xN32xK2_B1 : (32, 32, 2), - MMAShape.M16xN16xK4_B1 : (16, 16, 4), - MMAShape.M64xN64xK4_B4 : (64, 64, 4), - MMAShape.M64xN64xK4_B2 : (64, 64, 4), - MMAShape.M32xN32xK8_B1 : (32, 32, 8), - MMAShape.M16xN16xK16_B1 : (16, 16, 16), + MMAShape.M64xN64xK1_B4: (64, 64, 1), + MMAShape.M64xN64xK1_B2: (64, 64, 1), + MMAShape.M32xN32xK2_B1: (32, 32, 2), + MMAShape.M16xN16xK4_B1: (16, 16, 4), + MMAShape.M64xN64xK4_B4: (64, 64, 4), + MMAShape.M64xN64xK4_B2: (64, 64, 4), + MMAShape.M32xN32xK8_B1: (32, 32, 8), + MMAShape.M16xN16xK16_B1: (16, 16, 16), MMAShape.M64xN64xK2_B4: (64, 64, 2), MMAShape.M64xN64xK2_B2: (64, 64, 2), MMAShape.M32xN32xK4_B1: (32, 32, 4), @@ -791,6 +793,7 @@ def compute_tensor_splits(self, mma_shape: MMAShape, num_total_passes: int = 1): return tuple(mutable_tensor_splits) +# yapf: disable MI100_TENSORCORE_INFO = TensorCoreInformation([ TensorCoreInformationEntry(shape=MMAShape.M16xN16xK4_B1, inType=ScalarType.float32, outType=ScalarType.float32), # maps to the 16x16x4 mfma instruction TensorCoreInformationEntry(shape=MMAShape.M32xN32xK2_B1, 
inType=ScalarType.float32, outType=ScalarType.float32), # maps to the 32x32x2 mfma instruction @@ -916,6 +919,14 @@ def __post_init__(self): device_name = \ self.family.lower() if self.family else (self.name.lower() if self.name else self._device_name) + + # FIXUP: determine the device name if it matches with certain known extensions + # revisit using self.family for the device_name? + if "AVX512" in self.extensions: + device_name = "avx512" + elif "AVX2" in self.extensions: + device_name = "avx2" + if device_name in _GetKnownDeviceNames(): self._device_name = device_name @@ -1045,7 +1056,7 @@ def __init__( num_threads: int = None, num_cores: int = None, vector_bytes: int = 0, - vector_registers: int = None, + vector_registers: int = 0, frequency_GHz: float = None, tensor_core_info: TensorCoreInformation = None, turbo_frequency_GHz: float = None, @@ -1060,15 +1071,25 @@ def __init__( known_name = self._try_get_known_name(known_name) if known_name == "HOST": + if not extensions: + # infer extensions from cpu information + cpu_info = cpuinfo.get_cpu_info() + extensions = [ + ext for ext in + ["MMX", "SSE", "SSE2", "SSE3", "SSSE3", "SSE4", "SSE4.1", "SSE4.2", "AVX", "AVX2", "FMA3"] + if "flags" in cpu_info and ext.lower() in cpu_info["flags"] + ] + + if "AVX2" in extensions: + vector_bytes = 32 # There are 32-bytes per full SIMD register + vector_registers = 16 # There are 16 YMM registers super().__init__( category=category or Target.Category.CPU, architecture=Target.Architecture["HOST"], - vector_bytes=32, # There are 32-bytes per full SIMD register - vector_registers=16, # There are 16 YMM registers - extensions=[ - "MMX", "SSE", "SSE2", "SSE3", "SSSE3", "SSE4", "SSE4.1", "SSE4.2", "AVX", "AVX2", "FMA3" - ] + vector_bytes=vector_bytes, + vector_registers=vector_registers, + extensions=extensions ) else: diff --git a/accera/python/accera/test/dsl_tests.py b/accera/python/accera/test/dsl_tests.py index 72eda320..e4ff0126 100644 --- 
a/accera/python/accera/test/dsl_tests.py +++ b/accera/python/accera/test/dsl_tests.py @@ -27,10 +27,9 @@ from accera import ScalarType, Array, Function, Nest, Target, Package, algorithms, cast, AllocateFlags, Role from accera.test import verifiers -from accera.test.test_utils import expectedFailure, FailedReason +from accera.test.test_utils import expectedFailure, FailedReason, avx2_cpu, get_avx_platform from accera._lang_python._lang import Dimension, EnterProfileRegion, ExitProfileRegion, PrintProfileResults - INTERNAL_FUNCTION_OPTS = { "no_inline_into": True, "public": False @@ -38,6 +37,7 @@ TEST_MODE = Package.Mode.DEBUG if DEV_MODE else Package.Mode.RELEASE TEST_FORMAT = Package.Format.MLIR_DYNAMIC if DEV_MODE else Package.Format.HAT_DYNAMIC +TEST_FORMAT_XCOMPILE = Package.Format.MLIR_STATIC TEST_PACKAGE_DIR = "test_acccgen" # Groups of types commonly used for tests @@ -75,7 +75,13 @@ def _verify_nest(self, nest, args: Tuple[Array], package_name, correctness_check # build the HAT package with verifiers.VerifyPackage(self, package_name, output_dir) as v: - package.build(package_name, format=TEST_FORMAT, mode=_get_test_mode(correctness_check_values), output_dir=output_dir, _quiet=quiet) + package.build( + package_name, + format=TEST_FORMAT, + mode=_get_test_mode(correctness_check_values), + output_dir=output_dir, + _quiet=quiet + ) if correctness_check_values: v.check_correctness( function.name, @@ -373,6 +379,7 @@ def _(): ) def test_dynamic_temp_array(self) -> None: + def make_test_fn(package, A, B, C, N): T = Array(role=Role.TEMP, element_type=A.element_type, shape=A.shape) @@ -538,7 +545,9 @@ def _(): "pre": (A_test, B_test, C_test), "post": (A_test, B_expected, C_expected), } - self._verify_nest(plan, (A, B, C), "test_array_vectorize_cast", correctness_check_values=correctness_check_values) + self._verify_nest( + plan, (A, B, C), "test_array_vectorize_cast", correctness_check_values=correctness_check_values + ) def test_interleaved_vectorize_cast(self) -> 
None: shape = (64, 32, 8, 2) @@ -643,7 +652,7 @@ def test_reinterpret_cast(self) -> None: def reinterpret_arr_as_int16(array: Array): # Assumes array is f32 - num_elements = reduce(lambda x, y: x*y, array.shape, 1) + num_elements = reduce(lambda x, y: x * y, array.shape, 1) arr_mb = array._get_memory_buffer() self.assertEqual(arr_mb.shape, [num_elements * 4]) self.assertEqual(arr_mb.element_type, ScalarType.uint8) @@ -659,7 +668,7 @@ def reinterpret_arr_as_int16(array: Array): def reinterpret_arr_as_int32(array: Array): # Assumes array is f32 - num_elements = reduce(lambda x, y: x*y, array.shape, 1) + num_elements = reduce(lambda x, y: x * y, array.shape, 1) arr_mb = array._get_memory_buffer() self.assertEqual(arr_mb.shape, [num_elements * 4]) self.assertEqual(arr_mb.element_type, ScalarType.uint8) @@ -675,7 +684,7 @@ def reinterpret_arr_as_int32(array: Array): def simple_reinterpret_arr_as_int32(array: Array): # Assumes array is f32 - num_elements = reduce(lambda x, y: x*y, array.shape, 1) + num_elements = reduce(lambda x, y: x * y, array.shape, 1) arr_as_int32 = array._reinterpret_cast(ScalarType.int32) self.assertEqual(arr_as_int32.shape, array.shape) self.assertEqual(arr_as_int32.element_type, ScalarType.int32) @@ -718,18 +727,14 @@ def main(array): output_dir = pathlib.Path(TEST_PACKAGE_DIR) / package_name with verifiers.VerifyPackage(self, package_name, output_dir): package.build( - package_name, - format=TEST_FORMAT, - mode=Package.Mode.RELEASE, - output_dir=output_dir, - _quiet=False + package_name, format=TEST_FORMAT, mode=Package.Mode.RELEASE, output_dir=output_dir, _quiet=False ) def test_heap_alloc_reinterpret_cast(self) -> None: package = Package() - temp = Array(role=Role.TEMP, element_type=ScalarType.int32, shape=(32,), flags=AllocateFlags.HEAP) - output = Array(role=Role.INPUT_OUTPUT, element_type=ScalarType.float32, shape=(32,)) + temp = Array(role=Role.TEMP, element_type=ScalarType.int32, shape=(32, ), flags=AllocateFlags.HEAP) + output = 
Array(role=Role.INPUT_OUTPUT, element_type=ScalarType.float32, shape=(32, )) nest = Nest(shape=output.shape) i, = nest.get_indices() @@ -739,18 +744,14 @@ def _(): temp_mb = temp._get_memory_buffer() cast_temp = temp_mb._reinterpret_cast(ScalarType.float32) output[i] = cast_temp[i] - + package.add(nest, args=(output, )) package_name = "test_heap_alloc_reinterpret_cast" output_dir = pathlib.Path(TEST_PACKAGE_DIR) / package_name with verifiers.VerifyPackage(self, package_name, output_dir): package.build( - package_name, - format=TEST_FORMAT, - mode=Package.Mode.RELEASE, - output_dir=output_dir, - _quiet=False + package_name, format=TEST_FORMAT, mode=Package.Mode.RELEASE, output_dir=output_dir, _quiet=False ) def test_reinterpret_cast_partially_dynamic_shape(self) -> None: @@ -769,18 +770,12 @@ def test_reinterpret_cast_partially_dynamic_shape(self) -> None: def _(): float_A = A._reinterpret_cast(ScalarType.float32) B[indices] = float_A[indices] - + package.add(nest, args=(M, N, A, B), base_name=test_name) output_dir = pathlib.Path(TEST_PACKAGE_DIR) / test_name with verifiers.VerifyPackage(self, test_name, output_dir): - package.build( - test_name, - format=TEST_FORMAT, - mode=Package.Mode.RELEASE, - output_dir=output_dir, - _quiet=False - ) + package.build(test_name, format=TEST_FORMAT, mode=Package.Mode.RELEASE, output_dir=output_dir, _quiet=False) def test_subarray(self) -> None: package = Package() @@ -885,7 +880,9 @@ def _verify_helper(self, package, test_name, function_name=None, correctness_che output_dir = pathlib.Path(TEST_PACKAGE_DIR) / test_name with verifiers.VerifyPackage(self, test_name, output_dir) as v: shutil.rmtree(output_dir, ignore_errors=True) - package.build(test_name, format=TEST_FORMAT, mode=_get_test_mode(correctness_check_values), output_dir=output_dir) + package.build( + test_name, format=TEST_FORMAT, mode=_get_test_mode(correctness_check_values), output_dir=output_dir + ) if function_name and correctness_check_values: v.check_correctness( 
function_name, @@ -1523,11 +1520,13 @@ def test_output_array_range_node2(self) -> None: output_start = Scalar(type=ScalarType.float32, role=Role.TEMP) nest1 = Nest((1, )) + @nest1.iteration_logic def _(): outputDim.set(cast(floor((limit - start) / delta), ScalarType.int64)) nest2 = Nest((1, )) + @nest2.iteration_logic def _(): output_start.set(start) @@ -1549,7 +1548,7 @@ def _(): package = Package() # BUGBUG: dim args ordered first due to issue with Debug mode package.add(nest1, args=(start, limit, delta, outputDim), base_name=f"range_get_size") - ini_start_fn = package.add(nest2, args=(start,), base_name=f"ini_start") + ini_start_fn = package.add(nest2, args=(start, ), base_name=f"ini_start") get_result_fn = package.add(nest3, args=(inputDim, output, delta), base_name=f"get_result") nest4 = Nest((1, )) @@ -1692,7 +1691,16 @@ def _create_nest(self, shape: Tuple[int], type=ScalarType.float32) -> Tuple: return Nest(shape=(M, N, S)), A, B, C - def _build_nest(self, nest, args: Tuple[Array], package_name, correctness_check_values=None, quiet=True) -> None: + def _build_nest( + self, + nest, + args: Tuple[Array], + package_name, + correctness_check_values=None, + quiet=True, + file_check_fn=None, + platform=Package.Platform.HOST + ) -> None: # helper function to build a nest so that we can focus on the logic function # create a HAT package and add the nest to it package = Package() @@ -1701,13 +1709,22 @@ def _build_nest(self, nest, args: Tuple[Array], package_name, correctness_check_ # build the HAT package output_dir = pathlib.Path(TEST_PACKAGE_DIR) / package_name with verifiers.VerifyPackage(self, package_name, output_dir) as v: - package.build(package_name, format=TEST_FORMAT, mode=_get_test_mode(correctness_check_values), output_dir=output_dir, _quiet=quiet) + package.build( + package_name, + format=TEST_FORMAT if correctness_check_values else TEST_FORMAT_XCOMPILE, + mode=_get_test_mode(correctness_check_values), + output_dir=output_dir, + platform=platform, + 
_quiet=quiet + ) if correctness_check_values: v.check_correctness( function.name, before=correctness_check_values["pre"], after=correctness_check_values["post"], ) + if file_check_fn: + file_check_fn(v) def test_signed_types(self) -> None: for t in [ScalarType.int16, ScalarType.int32, ScalarType.int64] + FLOAT_TYPES: @@ -1939,7 +1956,6 @@ def _(): self._build_nest(nest, [A, B], "test_round_intrinsic", correctness_check_values=correctness_check_values) - @expectedFailure(FailedReason.INVALID, "x86 round intrinsic not supported on MacOS", sys.platform == "darwin") def test_round_intrinsic_vectorized(self) -> None: from accera import round as accround @@ -1962,7 +1978,7 @@ def _(): j: 8 }) sched.reorder(i, j, ii, jj) - plan = sched.create_plan() + plan = sched.create_plan(Target("Intel 6700")) plan.vectorize(ii) A_test = np.random.uniform(low=-1000.0, high=1000.0, size=A.shape).astype(np.float32) @@ -1978,10 +1994,24 @@ def _(): correctness_check_values = { "pre": [A_test, B_test], "post": [A_test, B_ref] - } + } if avx2_cpu() else None + + package_name = "test_round_intrinsic_vectorized" + + def file_check_fn(v): + checker = v.file_checker(f"{package_name}_llvm.mlir") + checker.check_label("llvm.func @{{[a-z_]*}}" + package_name) + checker.check_count( + '%{{[0-9]+}} = "accintr.x86.avx.cvt.ps2dq.256"(%{{[0-9]+}}) : (vector<8xf32>) -> vector<8xi32>', 4 + ) + checker.run() self._build_nest( - plan, [A, B], "test_round_intrinsic_vectorized", correctness_check_values=correctness_check_values + plan, [A, B], + package_name, + correctness_check_values=correctness_check_values, + file_check_fn=file_check_fn, + platform=get_avx_platform() ) # TODO : fix this test - it appears to abort on just the linux buddy build machine @@ -2018,7 +2048,6 @@ def _(): # self._build_nest(nest, [A, B], "test_remainderf_intrinsic_rounding", correctness_check_values=correctness_check_values) - @expectedFailure(FailedReason.INVALID, "x86 max min intrinsics not supported on MacOS", sys.platform 
== "darwin") def test_vectorized_max_min(self) -> None: from accera import max, min @@ -2052,7 +2081,7 @@ def _(): j: 8 }) sched.reorder(i, j, ii, jj) - plan = sched.create_plan() + plan = sched.create_plan(Target("Intel 6700")) plan.vectorize(ii) function = package.add(plan, args=(A, B, C_max, C_min), base_name=fn_name) @@ -2064,18 +2093,20 @@ def _(): C_max_ref = np.maximum(A_test, B_test) C_min_ref = np.minimum(A_test, B_test) - correctness_check_values[fn_name] = { - "pre": [A_test, B_test, C_max_test, C_min_test], - "post": [A_test, B_test, C_max_ref, C_min_ref] - } + if avx2_cpu(): + correctness_check_values[fn_name] = { + "pre": [A_test, B_test, C_max_test, C_min_test], + "post": [A_test, B_test, C_max_ref, C_min_ref] + } # build the HAT package output_dir = pathlib.Path(TEST_PACKAGE_DIR) / package_name with verifiers.VerifyPackage(self, package_name, output_dir) as v: package.build( package_name, - format=TEST_FORMAT | Package.Format.MLIR_VERBOSE, + format=TEST_FORMAT if avx2_cpu() else TEST_FORMAT_XCOMPILE, mode=Package.Mode.RELEASE, + platform=get_avx_platform(), output_dir=output_dir ) for fn_name in func_names: @@ -2086,7 +2117,22 @@ def _(): after=correctness_check_values[fn_name]["post"], ) - @expectedFailure(FailedReason.INVALID, "x86 max min intrinsics not supported on MacOS", sys.platform == "darwin") + max_checker = v.file_checker(f"{package_name}_llvm.mlir") + max_checker.check_label(f'llvm.func @{function.name}') + max_checker.check_count( + '%{{[0-9]+}} = "accintr.x86.avx.max.ps.256"(%{{[0-9]+}}, %{{[0-9]+}}) : (vector<8xf32>, vector<8xf32>) -> vector<8xf32>', + 4 + ) + max_checker.run() + + min_checker = v.file_checker(f"{package_name}_llvm.mlir") + max_checker.check_label(f'llvm.func @{function.name}') + min_checker.check_count( + '%{{[0-9]+}} = "accintr.x86.avx.min.ps.256"(%{{[0-9]+}}, %{{[0-9]+}}) : (vector<8xf32>, vector<8xf32>) -> vector<8xf32>', + 4 + ) + min_checker.run() + def test_vectorized_single_max_min_block(self) -> None: # In 
this test we're trying to find the single max and single min value of a 2-D array. # To vectorize this, we'll want to compute several maxs and mins in paralle and then reduce them @@ -2134,7 +2180,7 @@ def _(): io_A_min_cache[inner_i, inner_j] = min(io_A_min_cache[inner_i, inner_j], A[i, j]) inner_sched = inner_nest.create_schedule() - inner_plan = inner_sched.create_plan() + inner_plan = inner_sched.create_plan(Target("Intel 6700")) inner_plan.vectorize(inner_i) inner_fn = package.add( inner_plan, @@ -2158,7 +2204,7 @@ def _(): outer_j: N_tile }) outer_sched.reorder(outer_i, outer_j, outer_ii, outer_iii, outer_jj) - outer_plan = outer_sched.create_plan() + outer_plan = outer_sched.create_plan(Target("Intel 6700")) outer_plan._erase_loops([outer_iii, outer_jj]) outer_fn = package.add( outer_plan, @@ -2178,7 +2224,10 @@ def _(): arr[indices] = outer_arr[indices] return package.add( - zero_nest, args=(outer_arr, arr), base_name=base_name, function_opts=INTERNAL_FUNCTION_OPTS + zero_nest.create_plan(Target("Intel 6700")), + args=(outer_arr, arr), + base_name=base_name, + function_opts=INTERNAL_FUNCTION_OPTS ) zero_max_cache_fn = _make_init_fn(package, A, io_A_max_cache, "max_cache_zeroing") @@ -2201,7 +2250,10 @@ def _(): outer_arr[0] = min(outer_arr[0], cache[indices]) return package.add( - reduce_nest, args=(cache, outer_arr), base_name=base_name, function_opts=INTERNAL_FUNCTION_OPTS + reduce_nest.create_plan(Target("Intel 6700")), + args=(cache, outer_arr), + base_name=base_name, + function_opts=INTERNAL_FUNCTION_OPTS ) reduce_max_cache_fn = _make_cache_reduce_fn(package, io_A_max_cache, A_max, "max_cache_reduce", True) @@ -2219,7 +2271,9 @@ def _(): reduce_max_cache_fn(A_max_cache, A_max) reduce_min_cache_fn(A_min_cache, A_min) - function = package.add(top_nest, args=(A, A_max, A_min), base_name=fn_name) + function = package.add( + top_nest.create_plan(Target("Intel 6700")), args=(A, A_max, A_min), base_name=fn_name + ) A_test = 
np.random.random(A.shape).astype(np.float32) A_max_test = np.random.random(A_max.shape).astype(np.float32) @@ -2228,18 +2282,20 @@ def _(): A_max_ref = np.max(A_test).reshape((1, )) A_min_ref = np.min(A_test).reshape((1, )) - correctness_check_values[fn_name] = { - "pre": [A_test, A_max_test, A_min_test], - "post": [A_test, A_max_ref, A_min_ref] - } + if avx2_cpu(): + correctness_check_values[fn_name] = { + "pre": [A_test, A_max_test, A_min_test], + "post": [A_test, A_max_ref, A_min_ref] + } # build the HAT package output_dir = pathlib.Path(TEST_PACKAGE_DIR) / package_name with verifiers.VerifyPackage(self, package_name, output_dir) as v: package.build( package_name, - format=TEST_FORMAT | Package.Format.MLIR_VERBOSE, + format=TEST_FORMAT if avx2_cpu() else TEST_FORMAT_XCOMPILE, mode=Package.Mode.RELEASE, + platform=get_avx_platform(), output_dir=output_dir ) for fn_name in func_names: @@ -2250,6 +2306,22 @@ def _(): after=correctness_check_values[fn_name]["post"], ) + max_checker = v.file_checker(f"{package_name}_llvm.mlir") + max_checker.check_label(f'llvm.func @{function.name}') + max_checker.check_count( + '%{{[0-9]+}} = "accintr.x86.avx.max.ps.256"(%{{[0-9]+}}, %{{[0-9]+}}) : (vector<8xf32>, vector<8xf32>) -> vector<8xf32>', + 4 + ) + max_checker.run() + + min_checker = v.file_checker(f"{package_name}_llvm.mlir") + min_checker.check_label(f'llvm.func @{function.name}') + min_checker.check_count( + '%{{[0-9]+}} = "accintr.x86.avx.min.ps.256"(%{{[0-9]+}}, %{{[0-9]+}}) : (vector<8xf32>, vector<8xf32>) -> vector<8xf32>', + 4 + ) + min_checker.run() + def test_intrinsics_float(self) -> None: from accera import ( abs, @@ -2361,7 +2433,9 @@ def _verify_schedule(self, schedule, args: Tuple[Array], package_name, correctne # build the HAT package with verifiers.VerifyPackage(self, package_name, output_dir) as v: - package.build(package_name, format=TEST_FORMAT, mode=_get_test_mode(correctness_check_values), output_dir=output_dir) + package.build( + package_name, 
format=TEST_FORMAT, mode=_get_test_mode(correctness_check_values), output_dir=output_dir + ) if correctness_check_values: v.check_correctness( function.name, @@ -4148,7 +4222,9 @@ def _verify_plan(self, plan, args: Tuple[Array], package_name, correctness_check # build the HAT package with verifiers.VerifyPackage(self, package_name, output_dir) as v: - package.build(package_name, format=TEST_FORMAT, mode=_get_test_mode(correctness_check_values), output_dir=output_dir) + package.build( + package_name, format=TEST_FORMAT, mode=_get_test_mode(correctness_check_values), output_dir=output_dir + ) if correctness_check_values: v.check_correctness( function.name, @@ -4351,13 +4427,21 @@ def _(): class DSLTest_07PlansVectorizationParallelization(unittest.TestCase): - def _verify_plan(self, plan, args: Tuple[int], package_name, correctness_check_values=None) -> None: + def _verify_plan(self, plan, args: Tuple[int], package_name, correctness_check_values=None, check_parallelization=False) -> None: package = Package() function = package.add(plan, args, base_name="vectorization_parallelization_test") output_dir = pathlib.Path(TEST_PACKAGE_DIR) / package_name with verifiers.VerifyPackage(self, package_name, output_dir) as v: - package.build(package_name, format=TEST_FORMAT, mode=_get_test_mode(correctness_check_values), output_dir=output_dir) + package.build( + package_name, format=TEST_FORMAT, mode=_get_test_mode(correctness_check_values), output_dir=output_dir + ) + + if check_parallelization: + checker = v.file_checker(f"{package_name}_llvm.mlir") + checker.check_label("omp.parallel") + checker.run() + if correctness_check_values: v.check_correctness( function.name, @@ -4622,6 +4706,7 @@ def _(): [A, B, C], f"test_parallelize_i_{policy}", correctness_check_values, + check_parallelization=True ) # parallelizing middle index @@ -4632,6 +4717,7 @@ def _(): [A, B, C], f"test_parallelize_ii_{policy}", correctness_check_values, + check_parallelization=True ) try: @@ -4643,6 +4729,7 
@@ def _(): [A, B, C], f"test_parallelize_i_ii_j_{policy}", correctness_check_values, + check_parallelization=True ) # partial collapsed inner indices @@ -4653,6 +4740,7 @@ def _(): [A, B, C], f"test_parallelize_ii_j_{policy}", correctness_check_values, + check_parallelization=True ) except: # BUGBUG: partial collapsed + dynamic is broken in mlir-translate since LLVM 14 @@ -4679,7 +4767,9 @@ def _verify_package(self, plan, args, package_name, correctness_check_values) -> output_dir = pathlib.Path(TEST_PACKAGE_DIR) / package_name with verifiers.VerifyPackage(self, package_name, output_dir) as v: - package.build(package_name, format=TEST_FORMAT, mode=_get_test_mode(correctness_check_values), output_dir=output_dir) + package.build( + package_name, format=TEST_FORMAT, mode=_get_test_mode(correctness_check_values), output_dir=output_dir + ) if correctness_check_values: v.check_correctness( function.name, @@ -5053,7 +5143,9 @@ def _(): # build the HAT package with verifiers.VerifyPackage(self, package_name, output_dir) as v: - package.build(package_name, format=TEST_FORMAT, mode=_get_test_mode(correctness_check_values), output_dir=output_dir) + package.build( + package_name, format=TEST_FORMAT, mode=_get_test_mode(correctness_check_values), output_dir=output_dir + ) if correctness_check_values: v.check_correctness( function.name, @@ -6464,6 +6556,7 @@ def test_autoplan(self) -> None: class DSLTest_12Profiling(unittest.TestCase): + def _verify_func( self, package, function, package_name, correctness_check_values, quiet=True, mode=TEST_MODE ) -> None: @@ -6471,7 +6564,9 @@ def _verify_func( # build the HAT package with verifiers.VerifyPackage(self, package_name, output_dir) as v: - package.build(package_name, format=TEST_FORMAT, mode=mode, output_dir=output_dir, _quiet=quiet, profile=True) + package.build( + package_name, format=TEST_FORMAT, mode=mode, output_dir=output_dir, _quiet=quiet, profile=True + ) if correctness_check_values: v.check_correctness( function.name, @@ 
-6540,7 +6635,7 @@ def _tile_logic(): EnterProfileRegion("pack_b_fn") pack_b_fn(B, B_temp, j, k) ExitProfileRegion("pack_b_fn") - + EnterProfileRegion("matmul_fn") matmul_fn(A, B, C, B_temp, i, j, k) ExitProfileRegion("matmul_fn") diff --git a/accera/python/accera/test/smoke_tests.py b/accera/python/accera/test/smoke_tests.py index 582ec90f..c2d5d067 100644 --- a/accera/python/accera/test/smoke_tests.py +++ b/accera/python/accera/test/smoke_tests.py @@ -8,11 +8,9 @@ import logging import os import pathlib -import platform import shutil import sys import unittest -from itertools import product from typing import Callable, List import numpy as np @@ -54,13 +52,13 @@ from accera._lang_python._lang import Dimension, _MemorySpace, _If, as_index from accera._lang_python._lang._gpu import Barrier from accera.samples import MatrixMultiplication -from accera.Targets import KNOWN_DEVICES +from accera.test.test_utils import expectedFailure, FailedReason from accera.test import verifiers -from accera.test.test_utils import FailedReason, expectedFailure +from accera.test.test_utils import FailedReason, expectedFailure, avx2_cpu, avx512_cpu, get_avx_platform from accera import ( - AUTO, AllocateFlags, Array, Constants, Nest, Package, Role, Scalar, ScalarType, Target, cast, MMAShape, create_dimensions, - create_parameters, fuse + AUTO, AllocateFlags, Array, Constants, Nest, Package, Role, Scalar, ScalarType, Target, cast, MMAShape, + create_dimensions, create_parameters, fuse ) from accera import min as accmin from accera import abs as accabs @@ -73,10 +71,84 @@ # TODO: Remove all @expectedFailure decorators as implementation converges with spec +def make_asm_sequence_filechecker(v: verifiers.VerifyPackage, file_name: str, interleave_loc_tags: bool = True): + + class FileCheckerWrapper: + + def __init__(self, file_checker): + self.file_checker = file_checker + + # Forward all other calls to the file_checker + def __getattr__(self, item): + return getattr(self.file_checker, item) 
+ + def scan_to_next_loc(self): + self.file_checker.check(".loc") + + def check_label(self, check_str: str, check_loc: bool = interleave_loc_tags): + self.file_checker.check_label(check_str) + if check_loc: + self.file_checker.check_next(".loc") + + def check_next(self, check_str: str, check_loc: bool = interleave_loc_tags): + self.file_checker.check_next(check_str) + if check_loc: + self.file_checker.check_next(".loc") + + def check(self, check_str: str, check_loc: bool = interleave_loc_tags): + self.file_checker.check(check_str) + if check_loc: + self.file_checker.check_next(".loc") + + def get_simple_register_regex(self): + # Matches %rbp, %rsi, %r8, %r14, etc. + # Doesn't match xmm, ymm, zmm registers + return r"%([a-z]{3}|r[0-9]+)" + + def get_vector_register_regex(self, required_prefix: str = None): + # Matches xmm, ymm, zmm registers + suffix = r"mm[0-9]+" + if required_prefix is not None: + return r"%" + required_prefix + suffix + return r"%(x|y|z)" + suffix + + def get_register_regex(self): + # Matches any register name prefixed with % + simple_register_regex = self.get_simple_register_regex() + vector_register_regex = self.get_vector_register_regex() + return r"(%s|%s)" % (simple_register_regex, vector_register_regex) + + def get_memory_position_regex(self): + optional_offset = r"(-?[0-9]+)?" 
+ # Examples: + # - any register name in isolation: %rbp + register_regex = self.get_register_regex() + + # - any register name in parentheses without an offset: (%rbp) + # - any register offset by a constant: -64(%rbp), 32(%rsi) + no_offset_register_regex = r"\(%s\)" % register_regex + possibly_offset_register = optional_offset + no_offset_register_regex + + # - or offset by a scaled value in another register: (%rbp,%rsi,8) + # - or offset by a constant and offset by a scaled value in another register: -1028(%rbp,%rsi,8) + scaled_offset_regex = r"\(%s,%s,[0-9]+\)" % (register_regex, register_regex) + possibly_constant_offset_scaled_offset_regex = optional_offset + scaled_offset_regex + + # Matches any of the above + return r"(%s|%s|%s)" % ( + register_regex, possibly_offset_register, possibly_constant_offset_scaled_offset_regex + ) + + return FileCheckerWrapper(v.file_checker(file_name)) + + class SmokeTest(unittest.TestCase): PACKAGE_FORMAT = Package.Format.MLIR_DYNAMIC if DEV_MODE else Package.Format.HAT_DYNAMIC PACKAGE_MODE = Package.Mode.RELEASE + def get_xcompile_package_format(self, matcher_fn: Callable[[], bool] = avx2_cpu) -> Package.Format: + return self.PACKAGE_FORMAT if matcher_fn() else Package.Format.MLIR_STATIC + def test_full_fusion_trivial(self) -> None: A = Array(role=Role.INPUT, shape=(16, 16)) B = Array(role=Role.INPUT, shape=(16, 16)) @@ -626,7 +698,7 @@ def _(): v.check_correctness(function.name, before=(In_test, Out_test), after=(In_test, Out_ref)) def _test_fast_exp_mlas(self, func_level_precision: bool): - from accera import fast_exp, fast_exp_mlas + from accera import fast_exp_mlas M = 64 N = 64 @@ -665,33 +737,31 @@ def _(): with verifiers.VerifyPackage(self, package_name, TEST_PACKAGE_DIR) as v: package.build( package_name, - format=self.PACKAGE_FORMAT | Package.Format.MLIR_VERBOSE, + format=self.get_xcompile_package_format(), mode=self.PACKAGE_MODE, output_dir=TEST_PACKAGE_DIR, + platform=get_avx_platform(), _opts=pkg_opt, _quiet=False 
) + checker = v.file_checker(f"{package_name}_llvm.mlir") + checker.check( + '%{{[0-9]+}} = "accintr.x86.avx.max.ps.256"(%{{[0-9]+}}, %{{[0-9]+}}) : (vector<8xf32>, vector<8xf32>) -> vector<8xf32>' + ) + checker.check( + '%{{[0-9]+}} = "llvm.intr.fmuladd"(%{{[0-9]+}}, %{{[0-9]+}}, %{{[0-9]+}}) : (vector<8xf32>, vector<8xf32>, vector<8xf32>) -> vector<8xf32>' + ) + checker.run() - v.check_correctness(function.name, before=(In_test, Out_test), after=(In_test, Out_ref)) + if avx2_cpu(): + v.check_correctness(function.name, before=(In_test, Out_test), after=(In_test, Out_ref)) - @expectedFailure( - FailedReason.INVALID, "avx2 instructions not supported on MacOS arm64", sys.platform == "darwin" - and platform.machine() == "arm64" - ) def test_fast_exp_mlas_w_func_level_precision(self): self._test_fast_exp_mlas(True) - @expectedFailure( - FailedReason.INVALID, "avx2 instructions not supported on MacOS arm64", sys.platform == "darwin" - and platform.machine() == "arm64" - ) def test_fast_exp_mlas_w_pkg_level_precision(self): self._test_fast_exp_mlas(False) - @expectedFailure( - FailedReason.INVALID, "avx2 instructions not supported on MacOS arm64", sys.platform == "darwin" - and platform.machine() == "arm64" - ) def test_fast_exp_mlas_with_3_vectors(self): from accera import fast_exp_mlas @@ -731,20 +801,30 @@ def _(): with verifiers.VerifyPackage(self, package_name, TEST_PACKAGE_DIR) as v: package.build( package_name, - format=self.PACKAGE_FORMAT | Package.Format.MLIR_VERBOSE, + format=self.get_xcompile_package_format(), mode=self.PACKAGE_MODE, output_dir=TEST_PACKAGE_DIR, + platform=get_avx_platform(), _opts=Package._Options.HIGH_PRECISION_FLOATING_POINT_OPS, _quiet=False ) - v.check_correctness(function.name, before=(In_test, Max_test, Out_test), after=(In_test, Max_test, Out_ref)) + checker = v.file_checker(f"{package_name}_llvm.mlir") + checker.check_count( + '%{{[0-9]+}} = "accintr.x86.avx.max.ps.256"(%{{[0-9]+}}, %{{[0-9]+}}) : (vector<8xf32>, vector<8xf32>) -> 
vector<8xf32>', + 3 + ) + checker.check_count( + '%{{[0-9]+}} = "llvm.intr.fmuladd"(%{{[0-9]+}}, %{{[0-9]+}}, %{{[0-9]+}}) : (vector<8xf32>, vector<8xf32>, vector<8xf32>) -> vector<8xf32>', + 3 + ) + checker.run() + if avx2_cpu(): + v.check_correctness( + function.name, before=(In_test, Max_test, Out_test), after=(In_test, Max_test, Out_ref) + ) - @expectedFailure( - FailedReason.INVALID, "avx2 instructions not supported on MacOS arm64", sys.platform == "darwin" - and platform.machine() == "arm64" - ) def test_fast_exp_sum(self): from accera import fast_exp_mlas @@ -1051,14 +1131,27 @@ def call_fn_masked_vec_fast_exp_sum(): with verifiers.VerifyPackage(self, package_name, TEST_PACKAGE_DIR) as v: package.build( package_name, - format=self.PACKAGE_FORMAT | Package.Format.MLIR_VERBOSE, + format=self.get_xcompile_package_format(), mode=self.PACKAGE_MODE, output_dir=TEST_PACKAGE_DIR, + platform=get_avx_platform(), _opts=Package._Options.HIGH_PRECISION_FLOATING_POINT_OPS, _quiet=False ) - v.check_correctness(function.name, before=(In_test, Max_test, Out_test), after=(In_test, Max_test, Out_ref)) + checker = v.file_checker(f"{package_name}_llvm.mlir") + checker.check( + '%{{[0-9]+}} = "accintr.x86.avx.max.ps.256"(%{{[0-9]+}}, %{{[0-9]+}}) : (vector<8xf32>, vector<8xf32>) -> vector<8xf32>' + ) + checker.check( + '%{{[0-9]+}} = "llvm.intr.fmuladd"(%{{[0-9]+}}, %{{[0-9]+}}, %{{[0-9]+}}) : (vector<8xf32>, vector<8xf32>, vector<8xf32>) -> vector<8xf32>' + ) + checker.run() + + if avx2_cpu(): + v.check_correctness( + function.name, before=(In_test, Max_test, Out_test), after=(In_test, Max_test, Out_ref) + ) def test_emittime_cache_mlas_matmul(self) -> None: from accera.samples.OfflineCacheMatrixMultiplication import \ @@ -1085,7 +1178,8 @@ def test_emittime_cache_mlas_matmul(self) -> None: package.build(package_name, output_dir=output_dir, mode=self.PACKAGE_MODE, format=self.PACKAGE_FORMAT) # check build and correctness - v.check_correctness(function.name, before=(A_test, 
B_test, C_test), after=(A_test, B_test, C_ref)) + if avx2_cpu(): + v.check_correctness(function.name, before=(A_test, B_test, C_test), after=(A_test, B_test, C_ref)) def test_runtime_init_cache_mlas_matmul(self) -> None: from accera.samples.OfflineCacheMatrixMultiplication import \ @@ -2443,7 +2537,10 @@ def _(): # TODO : Bug - this could occur during buffer packing at an app initialization-time and/or runtime, # but the underlying issue is probably related to the next bug with static split sizes - @expectedFailure(FailedReason.BUG, "Multiple _split_dimension's of a dynamically sized dimension with a dynamic size is not working.") + @expectedFailure( + FailedReason.BUG, + "Multiple _split_dimension's of a dynamically sized dimension with a dynamic size is not working." + ) def test_multiple_dynamic_split_dim_dynamic_size(self) -> None: test_name = "test_multiple_dynamic_split_dim_dynamic_size" @@ -2490,7 +2587,10 @@ def _(): ) # TODO : Bug - this could occur during buffer packing in an app compile-time, initialization-time or runtime - @expectedFailure(FailedReason.BUG, "Multiple _split_dimension's of a dynamically sized buffer into statically sized dimensions is not working.") + @expectedFailure( + FailedReason.BUG, + "Multiple _split_dimension's of a dynamically sized buffer into statically sized dimensions is not working." 
+ ) def test_multiple_dynamic_split_dim_static_size(self) -> None: test_name = "test_multiple_dynamic_split_dim_static_size" @@ -2508,7 +2608,7 @@ def test_multiple_dynamic_split_dim_static_size(self) -> None: Input = Array(role=Role.INPUT, element_type=ScalarType.float32, shape=(MN, )) Output = Array(role=Role.INPUT_OUTPUT, element_type=ScalarType.float32, shape=(M, N)) - nest = Nest(shape=(M, N_0, N_1, N_2)) # (M, 4, 5, 6) + nest = Nest(shape=(M, N_0, N_1, N_2)) # (M, 4, 5, 6) i, j_0, j_1, j_2 = nest.get_indices() @nest.iteration_logic @@ -2591,7 +2691,7 @@ def _(): after=(test_MN, test_M, test_N, test_input, test_output_ref) ) - # This test uses all static sizes to make sure the fix for dynamic size (test_dynamic_split_dim_static_size) + # This test uses all static sizes to make sure the fix for dynamic size (test_dynamic_split_dim_static_size) # won't regress the static size case. def test_dynamic_split_dim_all_static(self) -> None: test_name = "test_dynamic_split_dim_all_static" @@ -4612,10 +4712,6 @@ def _(): after=correctness_check_values["post"], ) - @expectedFailure( - FailedReason.INVALID, "generated x86_64 lib not readable by MacOS arm64 build tools", sys.platform == "darwin" - and platform.machine() == "arm64" - ) def test_int16_matmul_vpmaddwd_16_element_avx512(self): test_name = "test_int16_matmul_vpmaddwd_16_element_avx512" M = 240 @@ -4652,7 +4748,7 @@ def _(): schedule.reorder(i, j, k, ii, jj, kk, kkk, iii, jjj, jjjj, kkkk) # The Intel 8351N is a known Xeon Platinum with AVX-512 support - target = KNOWN_DEVICES[Target.Category.CPU]["Intel 8351N"] + target = Target("Intel 8351N") plan = schedule.create_plan(target) plan.cache(A, index=ii, element_type=ScalarType.int16, vectorize=False) plan.cache(B, index=jjjj, trigger_index=jj, layout=Array.Layout.LAST_MAJOR, vectorize=False) @@ -4660,14 +4756,101 @@ def _(): plan.vectorize(jjjj) package = Package() - function = package.add(plan, args=(A, B, C), base_name=test_name) + package.add(plan, args=(A, B, 
C), base_name=test_name) output_dir = pathlib.Path(TEST_PACKAGE_DIR) / test_name # build the HAT package with verifiers.VerifyPackage(self, test_name, output_dir) as v: - package.build(test_name, format=Package.Format.DEFAULT, mode=Package.Mode.RELEASE, output_dir=output_dir) - # Don't check correctness as we've set a target that we may not be running the tests on + package.build( + test_name, + format=self.get_xcompile_package_format(avx512_cpu), + mode=Package.Mode.RELEASE, + output_dir=output_dir, + platform=get_avx_platform() + ) + checker = v.file_checker(f"{test_name}_llvm.mlir") + checker.check( + '%{{[0-9]+}} = "accintr.x86.avx512.pmaddw.d.512"(%{{[0-9]+}}, %{{[0-9]+}}) : (vector<32xi16>, vector<32xi16>) -> vector<16xi32>' + ) + checker.run() + + asm_checker = make_asm_sequence_filechecker(v, f"{test_name}.s", interleave_loc_tags=True) + + memory_pos_regex = asm_checker.get_memory_position_regex() + zmm_register_regex = asm_checker.get_vector_register_regex("z") + + # This is computing a 6x32 region using vpmaddwd and vpaddd instructions on zmm registers + # one vpmaddwd followed by vpaddd computes a single 1x16 output for 2 elements of k + store_reg = [[f"store_reg_{row_offset}_{col_offset}" for col_offset in range(0, 32, 16)] + for row_offset in range(0, 6, 1)] + + # 6 different broadcast values from A + broadcast_reg = [f"broadcast_reg_{row_offset}" for row_offset in range(0, 6, 1)] + + # Load 4 different vectors of B data, one for each of (k,j) = (0,0), (0,16), (2,0), (2,16) + load_reg = [[f"load_reg_{k_offset}_{col_offset}" for col_offset in range(0, 32, 16)] + for k_offset in range(0, 4, 2)] + + asm_checker.check_label( + "{{.*}}LBB0_14:", check_loc=False + ) # check_loc=False because there may be comments on the following lines before the next instruction + # Don't use python f-strings because FileCheck uses {{}} as delimiters that would collide with {} delimiters in the f-str. 
This could be worked around, but %-strs make it simpler + + # Broadcast A[0,0:2] + asm_checker.check( + "vpbroadcastd {{%s}}, [[%s:%s]]" % (memory_pos_regex, broadcast_reg[0], zmm_register_regex) + ) # vpbroadcastd -1028(%rbp,%rsi,8), %zmm5 + + # Load 4 different vectors of B data, one for each of (k,j) = (0,0), (0,16), (2,0), (2,16) + asm_checker.check_next( + "vmovdqu64 {{%s}}, [[%s:%s]]" % (memory_pos_regex, load_reg[0][0], zmm_register_regex), check_loc=False + ) # vmovdqu64 -192(%rbx), %zmm13 + asm_checker.check_next( + "vmovdqu64 {{%s}}, [[%s:%s]]" % (memory_pos_regex, load_reg[0][1], zmm_register_regex), check_loc=False + ) # vmovdqu64 -128(%rbx), %zmm14 + asm_checker.check_next( + "vmovdqu64 {{%s}}, [[%s:%s]]" % (memory_pos_regex, load_reg[1][0], zmm_register_regex), check_loc=False + ) # vmovdqu64 -64(%rbx), %zmm15 + asm_checker.check_next( + "vmovdqu64 {{%s}}, [[%s:%s]]" % (memory_pos_regex, load_reg[1][1], zmm_register_regex), check_loc=False + ) # vmovdqu64 (%rbx), %zmm8 + asm_checker.check_next( + ".loc", check_loc=False + ) # We could have set check_loc=True on the previous line, but we write it this way for better clarity + + def match_vpmaddwd_vpaddd_block(broadcast_reg, load0, load1, store0, store1): + asm_checker.check_next( + "vpmaddwd [[%s]], [[%s]], [[temp_reg0:%s]]" % (load0, broadcast_reg, zmm_register_regex) + ) + asm_checker.check_next( + "vpaddd [[temp_reg0]], {{%s}}, [[%s:%s]]" % (zmm_register_regex, store0, zmm_register_regex) + ) # We don't really care what the original source register was, so don't capture it + asm_checker.check_next( + "vpmaddwd [[%s]], [[%s]], [[temp_reg1:%s]]" % (load1, broadcast_reg, zmm_register_regex) + ) + asm_checker.check_next( + "vpaddd [[temp_reg1]], {{%s}}, [[%s:%s]]" % (zmm_register_regex, store1, zmm_register_regex), + check_loc=False + ) + + for k_idx in range(2): + for row_idx in range(6): + if not (row_idx == 0 and k_idx == 0): + # First broadcast occurs before the loads, so we need to check it 
separately + asm_checker.check( + "vpbroadcastd {{%s}}, [[%s:%s]]" % + (memory_pos_regex, broadcast_reg[row_idx], zmm_register_regex) + ) # vpbroadcastd -1028(%rbp,%rsi,8), %zmm5 + match_vpmaddwd_vpaddd_block( + broadcast_reg[row_idx], load_reg[k_idx][0], load_reg[k_idx][1], store_reg[row_idx][0], + store_reg[row_idx][1] + ) + + # Jump back to the top of the compute loop + asm_checker.check_label("jne {{.*}}LBB0_14", check_loc=False) + + asm_checker.run() def test_int16_matmul_vpmaddwd_16_element_host(self): test_name = "test_int16_matmul_vpmaddwd_16_element_host" diff --git a/accera/python/accera/test/test_utils.py b/accera/python/accera/test/test_utils.py index 8a3b274b..4e22bea8 100644 --- a/accera/python/accera/test/test_utils.py +++ b/accera/python/accera/test/test_utils.py @@ -5,11 +5,13 @@ from enum import Enum from typing import Callable +import cpuinfo import unittest +import sys import numpy as np -from accera import ScalarType +from accera import Package, ScalarType class FailedReason(Enum): @@ -24,6 +26,7 @@ def expectedFailure(reason: FailedReason, msg: str, condition: bool = True) -> C "Extends the unittest.expectedFailure decorator to print failure details and takes an optional condition" def _decorator(func): + @unittest.expectedFailure def _wrapper(x): print(f"\n{reason.value}: {msg}") @@ -38,8 +41,23 @@ def _wrapper(x): return _decorator +class SkipReason(Enum): + BUG = "Bug" + TARGET_MISMATCH = "Test doesn't execute for target" + + +def skipTest(reason: SkipReason, msg: str, condition: bool = True) -> Callable: + "Extends the unittest.expectedFailure decorator to print failure details and takes an optional condition" + + def _decorator(func): + + return unittest.skipIf(condition, f"\n{reason.value}: {msg}")(func) + + return _decorator + + def get_type_str(datatype: ScalarType): - return datatype.name + return datatype.name def accera_to_np_type(datatype: ScalarType): @@ -62,3 +80,17 @@ def accera_to_np_type(datatype: ScalarType): return np.int8 
return None + + +def avx2_cpu(): + cpu_info = cpuinfo.get_cpu_info() + return "flags" in cpu_info and "avx2" in cpu_info["flags"] + + +def avx512_cpu(): + cpu_info = cpuinfo.get_cpu_info() + return "flags" in cpu_info and "avx512" in cpu_info["flags"] + + +def get_avx_platform(): + return Package.Platform.HOST if sys.platform != "darwin" else Package.Platform.WINDOWS \ No newline at end of file diff --git a/accera/python/accera/test/unit_tests.py b/accera/python/accera/test/unit_tests.py index 14f3dc7b..122b32dc 100644 --- a/accera/python/accera/test/unit_tests.py +++ b/accera/python/accera/test/unit_tests.py @@ -172,7 +172,7 @@ def test_module(self) -> None: self.assertIsNotNone(module) with ModuleScope(module): module.Print() - module.Verify() + # module.Verify() # BUGBUG: ModuleOp::verify is private as of LLVM 15 header_filename = f"test_module_{time.time()}.hat" module.WriteHeader(header_filename) diff --git a/accera/python/lib/src/PackagingTypes.cpp b/accera/python/lib/src/PackagingTypes.cpp index 9f98cee8..29ae4b1d 100644 --- a/accera/python/lib/src/PackagingTypes.cpp +++ b/accera/python/lib/src/PackagingTypes.cpp @@ -116,7 +116,6 @@ ARM: fp16, neon, vfp3, d16, vfp4, hwdiv-arm, hwdiv "_flags"_a = value::AllocateFlags::None) .def("Print", &value::MLIRContext::print, "Prints the module") .def("Save", &value::MLIRContext::save, "filename"_a) - .def("Verify", &value::MLIRContext::verify) .def("WriteHeader", &value::MLIRContext::writeHeader, "filename"_a = std::nullopt) .def("SetMetadata", &value::MLIRContext::setMetadata) .def("GetFullMetadata", &value::MLIRContext::getFullMetadata) diff --git a/accera/python/llvm/setup.cfg b/accera/python/llvm/setup.cfg index af4bd8bf..6ec56db4 100644 --- a/accera/python/llvm/setup.cfg +++ b/accera/python/llvm/setup.cfg @@ -6,7 +6,10 @@ name = accera-llvm # Accera micro versions are at least 2 digits: # LLVM 15.0.7 -> 15.0.700 # LLVM 15.0.7-1 -> be 15.0.701 -version = 14.0.602 +# Note that the micro versions must start with a 
non-zero digit, else setup tools will +# normalize the version by removing the leading zeros +# Note: keep version in sync with Accera/setup.cfg +version = 15.0.101 author = Microsoft Research AI Compilers Team author_email = mlodev@microsoft.com summary = Accera LLVM Binaries diff --git a/accera/transforms/CMakeLists.txt b/accera/transforms/CMakeLists.txt index eca1aab3..e9036114 100644 --- a/accera/transforms/CMakeLists.txt +++ b/accera/transforms/CMakeLists.txt @@ -157,9 +157,11 @@ target_link_libraries( MLIRROCDLIR MLIRROCDLToLLVMIRTranslation MLIRStandardToLLVM - MLIRSCFToStandard + MLIRSCFToControlFlow + MLIRControlFlowToLLVM MLIRAffineToStandard MLIRAffineTransforms + MLIRAffineUtils MLIRLinalgToLLVM MLIRLinalgTransforms MLIRTargetLLVMIRExport diff --git a/accera/transforms/include/util/RangeValueUtilities.h b/accera/transforms/include/util/RangeValueUtilities.h index cc06fd1a..1c5fd8d7 100644 --- a/accera/transforms/include/util/RangeValueUtilities.h +++ b/accera/transforms/include/util/RangeValueUtilities.h @@ -79,6 +79,9 @@ class RangeValueAnalysis RangeValue resolveRangeValue(mlir::AffineApplyOp op); RangeValue resolveRangeValue(mlir::scf::ForOp op); RangeValue resolveRangeValue(mlir::Operation* op); + + void addSCFParallelOp(mlir::scf::ParallelOp op); + void addAffineParallelOp(mlir::AffineParallelOp op); }; } // namespace accera::ir::util diff --git a/accera/transforms/include/value/ValueToLLVMLoweringPass.h b/accera/transforms/include/value/ValueToLLVMLoweringPass.h index 62a73dc4..f6f8417f 100644 --- a/accera/transforms/include/value/ValueToLLVMLoweringPass.h +++ b/accera/transforms/include/value/ValueToLLVMLoweringPass.h @@ -34,9 +34,9 @@ class RewritePatternSet; namespace accera::transforms::value { -void populateValueToLLVMNonMemPatterns(mlir::LLVMTypeConverter& typeConverter, mlir::RewritePatternSet& patterns); +void populateValueToLLVMNonMemPatterns(mlir::LLVMTypeConverter& typeConverter, mlir::RewritePatternSet& patterns, 
accera::value::TargetDevice deviceInfo); void populateGlobalValueToLLVMNonMemPatterns(mlir::LLVMTypeConverter& typeConverter, mlir::RewritePatternSet& patterns); -void populateLocalValueToLLVMNonMemPatterns(mlir::LLVMTypeConverter& typeConverter, mlir::RewritePatternSet& patterns); +void populateLocalValueToLLVMNonMemPatterns(mlir::LLVMTypeConverter& typeConverter, mlir::RewritePatternSet& patterns, accera::value::TargetDevice deviceInfo); void populateValueToLLVMMemPatterns(mlir::LLVMTypeConverter& typeConverter, mlir::RewritePatternSet& patterns); void populateReshapeOpToLLVMMemPatterns(mlir::LLVMTypeConverter& typeConverter, mlir::RewritePatternSet& patterns); @@ -49,5 +49,6 @@ std::unique_ptr> createValueToLLVMPass(bool unsigned indexBitwidth, bool useAlignedAlloc, llvm::DataLayout dataLayout, + accera::value::TargetDevice deviceInfo = {}, const IntraPassSnapshotOptions& options = {}); } // namespace accera::transforms::value diff --git a/accera/transforms/src/AcceraPasses.cpp b/accera/transforms/src/AcceraPasses.cpp index e075383d..55785f5c 100644 --- a/accera/transforms/src/AcceraPasses.cpp +++ b/accera/transforms/src/AcceraPasses.cpp @@ -133,6 +133,20 @@ void simplifyAndLowerAffine(PassManagerAdaptor& pmAdaptor) pmAdaptor.addPass(createLowerAffinePass()); } +bool addGPUPasses(PassManagerAdaptor& pmAdaptor, const accera::value::ExecutionRuntime execRuntime, const AcceraPassPipelineOptions& options) +{ + auto gpuPass = createAcceraToGPUPass(execRuntime); + if (gpuPass) + { + pmAdaptor.addPass(createGPUSimplificationPass()); + pmAdaptor.addPass(value::createBarrierOptPass(options.writeBarrierGraph.getValue(), options.barrierGraphFilename.getValue())); + pmAdaptor.addPass(std::move(gpuPass)); + return true; + } + + return false; +} + void addAcceraToLLVMPassPipeline(OpPassManager& pm, const AcceraPassPipelineOptions& options) { ir::InitializeAccera(); @@ -173,7 +187,6 @@ void addAcceraToLLVMPassPipeline(OpPassManager& pm, const AcceraPassPipelineOpti 
funcOpPM.addPass(createLoopInvariantCodeMotionPass()); funcOpPM.addPass(createCSEPass()); - pmAdaptor.addPass(createConvertSCFToOpenMPPass()); pmAdaptor.addPass(value::createValueToStdPass(options.enableProfile)); pmAdaptor.addPass(value::createRangeValueOptimizePass()); pmAdaptor.addPass(createCanonicalizerPass()); @@ -183,24 +196,28 @@ void addAcceraToLLVMPassPipeline(OpPassManager& pm, const AcceraPassPipelineOpti { // The spirv lowering doesn't generate affine dialect ops, and the SPIRV dialect doesn't play nicely with them, so lower the affine ops before running the GPU lowering simplifyAndLowerAffine(pmAdaptor); + pmAdaptor.addPass(createGpuKernelOutliningPass()); + addGPUPasses(pmAdaptor, execRuntime, options); } - - pmAdaptor.addPass(createGpuKernelOutliningPass()); - auto gpuPass = createAcceraToGPUPass(execRuntime); - if (gpuPass) + else { - pmAdaptor.addPass(createGPUSimplificationPass()); - pmAdaptor.addPass(value::createBarrierOptPass(options.writeBarrierGraph.getValue(), options.barrierGraphFilename.getValue())); - pmAdaptor.addPass(std::move(gpuPass)); - } + pmAdaptor.addPass(createGpuKernelOutliningPass()); + const auto isGPU = addGPUPasses(pmAdaptor, execRuntime, options); - if (execRuntime != accera::value::ExecutionRuntime::VULKAN) - { // lowering to runtimes other than SPIRV generates affine dialect ops so optimize and lower those now simplifyAndLowerAffine(pmAdaptor); - if (execRuntime == accera::value::ExecutionRuntime::ROCM) + + if (isGPU) + { + if (execRuntime == accera::value::ExecutionRuntime::ROCM) + { + pmAdaptor.addPass(createGPUToROCDLPass()); + } + } + else { - pmAdaptor.addPass(createGPUToROCDLPass()); + // Convert to OMP when in non-GPU scenarios + pmAdaptor.addPass(createConvertSCFToOpenMPPass()); } } @@ -233,7 +250,7 @@ void addAcceraToLLVMPassPipeline(OpPassManager& pm, const AcceraPassPipelineOpti funcOpPM.addPass(createConvertVectorToSCFPass( VectorTransferToSCFOptions{} /*.setLowerPermutationMaps(true) 
.setLowerTensors(true).setUnroll(true) */)); - pmAdaptor.addPass(createLowerToCFGPass()); + pmAdaptor.addPass(createConvertSCFToCFPass()); if (execRuntime != accera::value::ExecutionRuntime::VULKAN) { @@ -255,6 +272,7 @@ void addAcceraToLLVMPassPipeline(OpPassManager& pm, const AcceraPassPipelineOpti /* indexBitwidth = */ kDeriveIndexBitwidthFromDataLayout, /* useAlignedAlloc = */ true, /* dataLayout = */ llvm::DataLayout(accera::value::GetTargetDevice(options.target).dataLayout), + /* deviceInfo = */ accera::value::GetTargetDevice(options.target), { options.dumpIntraPassIR.getValue(), options.basename + "ValueToLLVM_Subpasses" })); pmAdaptor.addPass(createCanonicalizerPass()); pmAdaptor.addPass(LLVM::createLegalizeForExportPass()); diff --git a/accera/transforms/src/affine/AffineLoopNormalize.cpp b/accera/transforms/src/affine/AffineLoopNormalize.cpp index d29a9539..2b905d82 100644 --- a/accera/transforms/src/affine/AffineLoopNormalize.cpp +++ b/accera/transforms/src/affine/AffineLoopNormalize.cpp @@ -81,7 +81,7 @@ struct AcceraAffineLoopNormalizePass : public accera::transforms::AcceraAffineLo // See \mlir\lib\Dialect\Affine\Transforms\AffineLoopNormalize.cpp op->walk([](mlir::AffineForOp affineForOp) { workaroundModifyAffineForOp(affineForOp); - mlir::normalizeAffineFor(affineForOp); + (void)mlir::normalizeAffineFor(affineForOp); }); } }; diff --git a/accera/transforms/src/affine/AffineSimplifications.cpp b/accera/transforms/src/affine/AffineSimplifications.cpp index 64d249bd..9da3dac4 100644 --- a/accera/transforms/src/affine/AffineSimplifications.cpp +++ b/accera/transforms/src/affine/AffineSimplifications.cpp @@ -16,6 +16,7 @@ #include #include #include +#include #include #include #include diff --git a/accera/transforms/src/exec/ExecutionPlanToAffineLoweringPass.cpp b/accera/transforms/src/exec/ExecutionPlanToAffineLoweringPass.cpp index dd125b54..5cf3ac80 100644 --- a/accera/transforms/src/exec/ExecutionPlanToAffineLoweringPass.cpp +++ 
b/accera/transforms/src/exec/ExecutionPlanToAffineLoweringPass.cpp @@ -705,14 +705,12 @@ std::tuple CreateCacheLoopnestHelper( { // Determine how much of the nest can be vectorized and set the vectorization info on those loops budget /= innermostLoopRange.NumIterations(); - int numVectorizedLoops = 1; for (size_t loopCounter = 1; loopCounter < loopnestInfo.fullySplitRanges.size(); ++loopCounter) { size_t loopIdx = loopnestInfo.fullySplitRanges.size() - loopCounter - 1; // Vectorize loops from the innermost to the outermost as long as we still have vector registers to work with auto loopRange = loopnestInfo.fullySplitRanges[loopIdx]; auto loopUnrollFactor = std::min(budget, loopRange.NumIterations()); InPlaceUnrollInfo inPlaceUnrollInfo{ loopUnrollFactor }; - numVectorizedLoops++; SetInPlaceUnrollInfo(cacheNestSchedule, cacheNestScheduleOrder[loopIdx], inPlaceUnrollInfo); budget /= loopUnrollFactor; if (budget <= 1) // if there is only 1 in-place op unroll left in the budget then we're done vectorizing diff --git a/accera/transforms/src/gpu/AcceraToGPUPass.cpp b/accera/transforms/src/gpu/AcceraToGPUPass.cpp index 4a9d36f5..3f91bb09 100644 --- a/accera/transforms/src/gpu/AcceraToGPUPass.cpp +++ b/accera/transforms/src/gpu/AcceraToGPUPass.cpp @@ -457,7 +457,7 @@ struct GPUIndexIntrinsicOpLowering : public ConvertOpToLLVMPattern return failure(); } - newOp = rewriter.create(loc, newOp, rewriter.getIndexType()); + newOp = rewriter.create(loc, rewriter.getIndexType(), newOp); rewriter.replaceOp(op, { newOp }); return success(); diff --git a/accera/transforms/src/util/MathUtilities.cpp b/accera/transforms/src/util/MathUtilities.cpp index 56f4f898..e4fb2942 100644 --- a/accera/transforms/src/util/MathUtilities.cpp +++ b/accera/transforms/src/util/MathUtilities.cpp @@ -6,7 +6,7 @@ #include "util/MathUtilities.h" #include -#include +#include namespace accera::transforms { @@ -24,13 +24,13 @@ mlir::Value SaturateValue(mlir::PatternRewriter& rewriter, mlir::Value value, in if 
(auto vectorType = resultType.dyn_cast()) { - minConst = rewriter.create(loc, minConst, vectorType); - maxConst = rewriter.create(loc, maxConst, vectorType); + minConst = rewriter.create(loc, minConst, vectorType); + maxConst = rewriter.create(loc, maxConst, vectorType); } auto maxCmp = rewriter.create(loc, mlir::arith::CmpIPredicate::sgt, value, minConst); - auto temp = rewriter.create(loc, maxCmp, value, minConst); + auto temp = rewriter.create(loc, maxCmp, value, minConst); auto minCmp = rewriter.create(loc, mlir::arith::CmpIPredicate::slt, temp, maxConst); - auto result = rewriter.create(loc, minCmp, temp, maxConst); + auto result = rewriter.create(loc, minCmp, temp, maxConst); return result; } diff --git a/accera/transforms/src/util/RangeValueUtilities.cpp b/accera/transforms/src/util/RangeValueUtilities.cpp index d8b25b89..de565cff 100644 --- a/accera/transforms/src/util/RangeValueUtilities.cpp +++ b/accera/transforms/src/util/RangeValueUtilities.cpp @@ -68,6 +68,25 @@ RangeValue resolveGridDimRange(Operation* op, gpu::Dimension dimId) return RangeValue(); } +mlir::APInt toAPInt(int64_t val) +{ + return mlir::APInt(RangeValue::maxBitWidth, val, true); +} + +RangeValue resolveConstantForLoopRange(mlir::APInt lb, mlir::APInt ub, mlir::APInt step) +{ + auto range = ub - lb; + auto remainder = range.srem(step); + auto largestInductionVarValue = (remainder.sgt(0)) ? 
(ub - remainder) : (ub - step); + + return RangeValue(lb, largestInductionVarValue); +} + +RangeValue resolveConstantForLoopRange(int64_t lb, int64_t ub, int64_t step) +{ + return resolveConstantForLoopRange(toAPInt(lb), toAPInt(ub), toAPInt(step)); +} + } // namespace namespace accera::ir::util @@ -172,8 +191,86 @@ RangeValue RangeValueAnalysis::getRange(Value value) const return it->second; } +void RangeValueAnalysis::addSCFParallelOp(mlir::scf::ParallelOp op) +{ + mlir::scf::ParallelOpAdaptor adaptor{ op }; + auto ivs = op.getInductionVars(); + auto lbs = adaptor.getLowerBound(); + auto ubs = adaptor.getUpperBound(); + auto steps = adaptor.getStep(); + assert(ivs.size() == lbs.size()); + assert(ivs.size() == ubs.size()); + assert(ivs.size() == steps.size()); + unsigned numDims = ivs.size(); + + for (unsigned dimIdx = 0; dimIdx < numDims; ++dimIdx) + { + auto lb = lbs[dimIdx]; + auto ub = ubs[dimIdx]; + auto step = steps[dimIdx]; + auto lbConst = lb.getDefiningOp(); + auto ubConst = ub.getDefiningOp(); + auto stepConst = step.getDefiningOp(); + if (lbConst && ubConst && stepConst) + { + auto range = resolveConstantForLoopRange(lbConst.value(), ubConst.value(), stepConst.value()); + _rangeMap.insert({ ivs[dimIdx], range }); + } + else + { + _rangeMap.insert({ ivs[dimIdx], RangeValue() }); + } + } +} + +void RangeValueAnalysis::addAffineParallelOp(mlir::AffineParallelOp op) +{ + auto constantRangesOpt = op.getConstantRanges(); + if (!constantRangesOpt.hasValue()) + { + return; + } + + unsigned numDims = op.getNumDims(); + auto ivs = op.getIVs(); + auto constantRanges = constantRangesOpt.getValue(); + auto lbValueMap = op.getLowerBoundsValueMap(); + auto steps = op.getSteps(); + + assert(numDims == ivs.size()); + assert(numDims == constantRanges.size()); + assert(numDims == lbValueMap.getNumResults()); + assert(numDims == steps.size()); + + for (unsigned dimIdx = 0; dimIdx < numDims; ++dimIdx) + { + auto expr = lbValueMap.getResult(dimIdx); + if (expr.isa()) + { + 
auto lb = expr.cast().getValue(); + auto range = resolveConstantForLoopRange(lb, lb + constantRanges[dimIdx], steps[dimIdx]); + _rangeMap.insert({ ivs[dimIdx], range }); + } + else + { + _rangeMap.insert({ ivs[dimIdx], RangeValue() }); + } + } +} + RangeValue RangeValueAnalysis::addOperation(mlir::Operation* op) { + // Special case AffineParallelOp and scf::ParallelOp as they can have multiple results, + // however we care more about their possibly-multiple IVs which may have static ranges + if (isa(op) || isa(op)) + { + mlir::TypeSwitch(op) + .Case([&](scf::ParallelOp op) { addSCFParallelOp(op); }) + .Case([&](AffineParallelOp op) { addAffineParallelOp(op); }); + // Possibly no single range to return for the op itself + return RangeValue(); + } + if (op->getNumResults() > 1) { // Only operations with 0 or 1 results can have their ranges tracked successfully currently @@ -415,14 +512,11 @@ RangeValue RangeValueAnalysis::resolveRangeValue(AffineForOp op) auto ub = op.getConstantUpperBound(); auto step = op.getStep(); - auto range = ub - lb; - auto remainder = range % step; - auto largestInductionVarValue = (remainder > 0) ? (ub - remainder) : (ub - step); - - return RangeValue(lb, largestInductionVarValue); + return resolveConstantForLoopRange(lb, ub, step); } return RangeValue(); } + RangeValue RangeValueAnalysis::resolveRangeValue(scf::ForOp op) { assert(op.getNumInductionVars() == 1); @@ -442,14 +536,11 @@ RangeValue RangeValueAnalysis::resolveRangeValue(scf::ForOp op) auto ub = upperBound.range.getUpper(); auto step = stepSize.range.getLower(); - auto range = ub - lb; - auto remainder = range.srem(step); - auto largestInductionVarValue = (remainder.sgt(0)) ? 
(ub - remainder) : (ub - step); - - return RangeValue(lb, largestInductionVarValue); + return resolveConstantForLoopRange(lb, ub, step); } return RangeValue(); } + RangeValue RangeValueAnalysis::resolveRangeValue(mlir::Operation* op) { return mlir::TypeSwitch(op) diff --git a/accera/transforms/src/value/ValueSimplifyPass.cpp b/accera/transforms/src/value/ValueSimplifyPass.cpp index b67ea6a3..b7722df7 100644 --- a/accera/transforms/src/value/ValueSimplifyPass.cpp +++ b/accera/transforms/src/value/ValueSimplifyPass.cpp @@ -11,7 +11,6 @@ #include #include -#include #include #include @@ -125,7 +124,7 @@ struct ValueSliceSimplifyPattern : public OpRewritePattern auto indexShape = index.getType().cast().getShape(); if (indexShape.size() == 0 || indexShape.size() == 1) { - resolvedOffsets[dim] = rewriter.create(loc, rewriter.create(loc, index), indexType); + resolvedOffsets[dim] = rewriter.create(loc, indexType, rewriter.create(loc, index)); } else { @@ -344,7 +343,7 @@ LogicalResult CopyOpLowering::matchAndRewrite( if (outputMemRef.getElementType().isInteger(64)) // this should really be target dependent... { (void)rewriter.create(loc, - rewriter.create(loc, input, rewriter.getIntegerType(64)), + rewriter.create(loc, rewriter.getIntegerType(64), input), output, std::vector(outputMemRef.getRank(), zero)); } diff --git a/accera/transforms/src/value/ValueToLLVMLoweringPass.cpp b/accera/transforms/src/value/ValueToLLVMLoweringPass.cpp index b7db858c..70dff286 100644 --- a/accera/transforms/src/value/ValueToLLVMLoweringPass.cpp +++ b/accera/transforms/src/value/ValueToLLVMLoweringPass.cpp @@ -21,6 +21,7 @@ #include #include +#include #include #include #include @@ -28,7 +29,6 @@ #include #include #include -#include #include #include #include @@ -66,6 +66,15 @@ using namespace accera::transforms::value; namespace { +// This is ported from Linux code time.h +// #define CLOCK_REALTIME 0 // Identifier for system-wide realtime clock. 
+// #define CLOCK_MONOTONIC 1 // Monotonic system-wide clock. +enum class ClockID +{ + ACCERA_CLOCK_REALTIME = 0, + ACCERA_CLOCK_MONOTONIC = 1, +}; + // TODO: Refactor this class and find a better place for this helper class class LLVMTypeConverterDynMem : public mlir::LLVMTypeConverter { @@ -344,6 +353,13 @@ struct GetTimeOpLowering : public ValueLLVMOpConversionPattern { using ValueLLVMOpConversionPattern::ValueLLVMOpConversionPattern; + accera::value::TargetDevice deviceInfo; + + GetTimeOpLowering(LLVMTypeConverter& converter, mlir::MLIRContext* context, accera::value::TargetDevice deviceInfo) : + ValueLLVMOpConversionPattern(converter, context), + deviceInfo(deviceInfo) + {} + LogicalResult matchAndRewrite( GetTimeOp op, OpAdaptor adaptor, @@ -371,10 +387,11 @@ struct GetTimeOpLowering : public ValueLLVMOpConversionPattern static FlatSymbolRefAttr getOrInsertClockGetTime(PatternRewriter& rewriter, ModuleOp module, - LLVM::LLVMDialect* llvmDialect) + LLVM::LLVMDialect* llvmDialect, + size_t numBits) { auto* context = module.getContext(); - auto llvmFnType = getGetTimeFunctionType(context); + auto llvmFnType = getGetTimeFunctionType(context, numBits); return getOrInsertLibraryFunction(rewriter, "clock_gettime", llvmFnType, module, llvmDialect); } @@ -394,13 +411,13 @@ struct GetTimeOpLowering : public ValueLLVMOpConversionPattern return LLVM::LLVMFunctionType::get(boolTy, { argTy }, /*isVarArg=*/false); } - static Type getGetTimeFunctionType(mlir::MLIRContext* context) + static Type getGetTimeFunctionType(mlir::MLIRContext* context, size_t numBits) { // Create a function type for clock_gettime, the signature is: // int clock_gettime(clockid_t clockid, struct timespec *tp); - auto returnTy = getIntType(context); - auto llvmClockIdTy = getClockIdType(context); - auto llvmTimespecTy = getTimeSpecType(context); + auto returnTy = getIntType(context, numBits); + auto llvmClockIdTy = getClockIdType(context, numBits); + auto llvmTimespecTy = getTimeSpecType(context, 
numBits); auto llvmTimespecPtrTy = LLVM::LLVMPointerType::get(llvmTimespecTy); return LLVM::LLVMFunctionType::get(returnTy, { llvmClockIdTy, llvmTimespecPtrTy }, /*isVarArg=*/false); } @@ -410,29 +427,27 @@ struct GetTimeOpLowering : public ValueLLVMOpConversionPattern return IntegerType::get(context, 64); } - static Type getClockIdType(mlir::MLIRContext* context) + static Type getClockIdType(mlir::MLIRContext* context, size_t numBits) { - return getIntType(context); + return getIntType(context, numBits); } - static Type getTimeSpecType(mlir::MLIRContext* context) + static Type getTimeSpecType(mlir::MLIRContext* context, size_t numBits) { // struct timespec { // time_t tv_sec; /* seconds */ // long tv_nsec; /* nanoseconds */ // }; - auto llvmIntTy = getIntType(context); + auto llvmIntTy = getIntType(context, numBits); auto llvmTimespecTy = LLVM::LLVMStructType::getLiteral(context, { llvmIntTy, llvmIntTy }, /* isPacked */ true); return llvmTimespecTy; } - static Type getIntType(mlir::MLIRContext* context) + static Type getIntType(mlir::MLIRContext* context, size_t numBits) { auto llvmI32Ty = IntegerType::get(context, 32); auto llvmI64Ty = IntegerType::get(context, 64); - const int hostBitSize = 64; // TODO:: FIXME :: This assumes that the host is always 64bit - // Should query the target hardware - auto llvmIntTy = hostBitSize == 32 ? llvmI32Ty : llvmI64Ty; + auto llvmIntTy = numBits == 32 ? 
llvmI32Ty : llvmI64Ty; return llvmIntTy; } }; @@ -757,7 +772,7 @@ struct RoundOpLowering : public ValueLLVMOpConversionPattern // Create arithmetic dialect cast ops with the expectation that other arithmetic dialect ops are getting lowered as part of this pass auto signlessOutputType = util::ToSignlessMLIRType(rewriter, op.getType()); - mlir::Value roundedSIVal = rewriter.create(op.getLoc(), roundedFPVal, signlessOutputType); + mlir::Value roundedSIVal = rewriter.create(op.getLoc(), signlessOutputType, roundedFPVal); rewriter.replaceOpWithNewOp(op, op.getType(), roundedSIVal); } return success(); @@ -766,8 +781,9 @@ struct RoundOpLowering : public ValueLLVMOpConversionPattern struct ValueToLLVMLoweringPass : public ConvertValueToLLVMBase { - ValueToLLVMLoweringPass(bool useBarePtrCallConv, bool emitCWrappers, unsigned indexBitwidth, bool useAlignedAlloc, llvm::DataLayout dataLayout, const IntraPassSnapshotOptions& snapshotteroptions = {}) : - _intrapassSnapshotter(snapshotteroptions) + ValueToLLVMLoweringPass(bool useBarePtrCallConv, bool emitCWrappers, unsigned indexBitwidth, bool useAlignedAlloc, llvm::DataLayout dataLayout, accera::value::TargetDevice deviceInfo = {}, const IntraPassSnapshotOptions& snapshotteroptions = {}) : + _intrapassSnapshotter(snapshotteroptions), + deviceInfo(deviceInfo) { this->useBarePtrCallConv = useBarePtrCallConv; this->emitCWrappers = emitCWrappers; @@ -781,6 +797,7 @@ struct ValueToLLVMLoweringPass : public ConvertValueToLLVMBase(timerRegionTypeAttr.getInt()) == TimerRegionType::enterRegion; } - // TODO: get check `TargetDeviceInfo` for the OS instead -#ifdef WIN32 - if (isEnterRegionTimer) { - auto boolTy = IntegerType::get(context, 8); - auto argTy = getPerformanceCounterType(context); - LLVMTypeConverter llvmTypeConverter(context); - Value one = rewriter.create(loc, llvmTypeConverter.convertType(rewriter.getIndexType()), rewriter.getIntegerAttr(rewriter.getIndexType(), 1)); - - auto queryPerfFrequencyFn = 
getOrInsertQueryPerfFrequency(rewriter, parentModule, llvmDialect); - Value perfFreqPtr = rewriter.create(loc, LLVM::LLVMPointerType::get(argTy), one); - auto getFreqCall = rewriter.create(loc, std::vector{ boolTy }, queryPerfFrequencyFn, ValueRange{ perfFreqPtr }); - - Value perfFreq = rewriter.create(loc, perfFreqPtr); - Value freqDoubleVal = rewriter.create(loc, doubleTy, perfFreq); - - auto queryPerfCounterFn = getOrInsertQueryPerfCounter(rewriter, parentModule, llvmDialect); - Value perfCountPtr = rewriter.create(loc, LLVM::LLVMPointerType::get(argTy), one); - auto getCounterCall = rewriter.create(loc, std::vector{ boolTy }, queryPerfCounterFn, ValueRange{ perfCountPtr }); - - [[maybe_unused]] auto getCountResult = getCounterCall.getResult(0); - [[maybe_unused]] auto getFreqResult = getFreqCall.getResult(0); - - Value perfCount = rewriter.create(loc, perfCountPtr); - Value ticksDoubleVal = rewriter.create(loc, doubleTy, perfCount); - - Value result = rewriter.create(loc, doubleTy, ticksDoubleVal, freqDoubleVal); - return result; - } - else + if (this->deviceInfo.IsWindows()) { - auto queryPerfCounterFn = getOrInsertQueryPerfCounter(rewriter, parentModule, llvmDialect); - auto queryPerfFrequencyFn = getOrInsertQueryPerfFrequency(rewriter, parentModule, llvmDialect); + if (isEnterRegionTimer) { + auto boolTy = IntegerType::get(context, 8); + auto argTy = getPerformanceCounterType(context); + LLVMTypeConverter llvmTypeConverter(context); + Value one = rewriter.create(loc, llvmTypeConverter.convertType(rewriter.getIndexType()), rewriter.getIntegerAttr(rewriter.getIndexType(), 1)); + + auto queryPerfFrequencyFn = getOrInsertQueryPerfFrequency(rewriter, parentModule, llvmDialect); + Value perfFreqPtr = rewriter.create(loc, LLVM::LLVMPointerType::get(argTy), one); + auto getFreqCall = rewriter.create(loc, std::vector{ boolTy }, queryPerfFrequencyFn, ValueRange{ perfFreqPtr }); - auto boolTy = IntegerType::get(context, 8); - auto argTy = 
getPerformanceCounterType(context); - LLVMTypeConverter llvmTypeConverter(context); - Value one = rewriter.create(loc, llvmTypeConverter.convertType(rewriter.getIndexType()), rewriter.getIntegerAttr(rewriter.getIndexType(), 1)); - - Value perfCountPtr = rewriter.create(loc, LLVM::LLVMPointerType::get(argTy), one); - auto getCounterCall = rewriter.create(loc, std::vector{ boolTy }, queryPerfCounterFn, ValueRange{ perfCountPtr }); + Value perfFreq = rewriter.create(loc, perfFreqPtr); + Value freqDoubleVal = rewriter.create(loc, doubleTy, perfFreq); - Value perfFreqPtr = rewriter.create(loc, LLVM::LLVMPointerType::get(argTy), one); - auto getFreqCall = rewriter.create(loc, std::vector{ boolTy }, queryPerfFrequencyFn, ValueRange{ perfFreqPtr }); + auto queryPerfCounterFn = getOrInsertQueryPerfCounter(rewriter, parentModule, llvmDialect); + Value perfCountPtr = rewriter.create(loc, LLVM::LLVMPointerType::get(argTy), one); + auto getCounterCall = rewriter.create(loc, std::vector{ boolTy }, queryPerfCounterFn, ValueRange{ perfCountPtr }); - [[maybe_unused]] auto getCountResult = getCounterCall.getResult(0); - [[maybe_unused]] auto getFreqResult = getFreqCall.getResult(0); + [[maybe_unused]] auto getCountResult = getCounterCall.getResult(0); + [[maybe_unused]] auto getFreqResult = getFreqCall.getResult(0); - Value perfCount = rewriter.create(loc, perfCountPtr); - Value perfFreq = rewriter.create(loc, perfFreqPtr); + Value perfCount = rewriter.create(loc, perfCountPtr); + Value ticksDoubleVal = rewriter.create(loc, doubleTy, perfCount); - Value ticksDoubleVal = rewriter.create(loc, doubleTy, perfCount); - Value freqDoubleVal = rewriter.create(loc, doubleTy, perfFreq); - Value result = rewriter.create(loc, doubleTy, ticksDoubleVal, freqDoubleVal); - return result; - } -#else - if (isEnterRegionTimer) - { - auto clockGetTimeFn = getOrInsertClockGetTime(rewriter, parentModule, llvmDialect); - - auto llvmTimespecTy = getTimeSpecType(context); - auto clockIdTy = 
getClockIdType(context); - auto intTy = getIntType(context); - - // Get a symbol reference to the gettime function, inserting it if necessary. - LLVMTypeConverter llvmTypeConverter(context); - Value zero = rewriter.create(loc, llvmTypeConverter.convertType(rewriter.getIndexType()), rewriter.getIntegerAttr(rewriter.getIndexType(), 0)); - Value zero32 = rewriter.create(loc, llvmTypeConverter.convertType(rewriter.getI32Type()), rewriter.getIntegerAttr(rewriter.getI32Type(), 0)); - Value one = rewriter.create(loc, llvmTypeConverter.convertType(rewriter.getIndexType()), rewriter.getIntegerAttr(rewriter.getIndexType(), 1)); - Value one32 = rewriter.create(loc, llvmTypeConverter.convertType(rewriter.getI32Type()), rewriter.getIntegerAttr(rewriter.getI32Type(), 1)); - Value clockId = rewriter.create(loc, clockIdTy, rewriter.getI64IntegerAttr(CLOCK_REALTIME)); - - Value timespecPtr = rewriter.create(loc, LLVM::LLVMPointerType::get(llvmTimespecTy), one); - Value secondsPtr = rewriter.create(loc, LLVM::LLVMPointerType::get(intTy), timespecPtr, ValueRange{ zero, zero32 }); - Value nanosecondsPtr = rewriter.create(loc, LLVM::LLVMPointerType::get(intTy), timespecPtr, ValueRange{ zero, one32 }); - - std::vector args{ clockId, timespecPtr }; - auto getTimeCall = rewriter.create(loc, std::vector{ getIntType(context) }, clockGetTimeFn, args); - [[maybe_unused]] auto getTimeResult = getTimeCall.getResult(0); - - Value secondsIntVal = rewriter.create(loc, secondsPtr); - Value nanosecondsIntVal = rewriter.create(loc, nanosecondsPtr); - Value secondsDoubleVal = rewriter.create(loc, doubleTy, secondsIntVal); - Value nanosecondsDoubleVal = rewriter.create(loc, doubleTy, nanosecondsIntVal); - Value divisor = rewriter.create(loc, doubleTy, rewriter.getF64FloatAttr(1.0e9)); - Value nanoseconds = rewriter.create(loc, doubleTy, nanosecondsDoubleVal, divisor); - Value totalSecondsDoubleVal = rewriter.create(loc, doubleTy, secondsDoubleVal, nanoseconds); - return totalSecondsDoubleVal; + Value 
result = rewriter.create(loc, doubleTy, ticksDoubleVal, freqDoubleVal); + return result; + } + else + { + auto queryPerfCounterFn = getOrInsertQueryPerfCounter(rewriter, parentModule, llvmDialect); + auto queryPerfFrequencyFn = getOrInsertQueryPerfFrequency(rewriter, parentModule, llvmDialect); + + auto boolTy = IntegerType::get(context, 8); + auto argTy = getPerformanceCounterType(context); + LLVMTypeConverter llvmTypeConverter(context); + Value one = rewriter.create(loc, llvmTypeConverter.convertType(rewriter.getIndexType()), rewriter.getIntegerAttr(rewriter.getIndexType(), 1)); + Value perfCountPtr = rewriter.create(loc, LLVM::LLVMPointerType::get(argTy), one); + auto getCounterCall = rewriter.create(loc, std::vector{ boolTy }, queryPerfCounterFn, ValueRange{ perfCountPtr }); + + Value perfFreqPtr = rewriter.create(loc, LLVM::LLVMPointerType::get(argTy), one); + auto getFreqCall = rewriter.create(loc, std::vector{ boolTy }, queryPerfFrequencyFn, ValueRange{ perfFreqPtr }); + [[maybe_unused]] auto getCountResult = getCounterCall.getResult(0); + [[maybe_unused]] auto getFreqResult = getFreqCall.getResult(0); + + Value perfCount = rewriter.create(loc, perfCountPtr); + Value perfFreq = rewriter.create(loc, perfFreqPtr); + + Value ticksDoubleVal = rewriter.create(loc, doubleTy, perfCount); + Value freqDoubleVal = rewriter.create(loc, doubleTy, perfFreq); + Value result = rewriter.create(loc, doubleTy, ticksDoubleVal, freqDoubleVal); + return result; + } } else { - auto clockGetTimeFn = getOrInsertClockGetTime(rewriter, parentModule, llvmDialect); - - auto llvmTimespecTy = getTimeSpecType(context); - auto clockIdTy = getClockIdType(context); - auto intTy = getIntType(context); - - // Get a symbol reference to the gettime function, inserting it if necessary. 
- LLVMTypeConverter llvmTypeConverter(context); - Value zero = rewriter.create(loc, llvmTypeConverter.convertType(rewriter.getIndexType()), rewriter.getIntegerAttr(rewriter.getIndexType(), 0)); - Value zero32 = rewriter.create(loc, llvmTypeConverter.convertType(rewriter.getI32Type()), rewriter.getIntegerAttr(rewriter.getI32Type(), 0)); - Value one = rewriter.create(loc, llvmTypeConverter.convertType(rewriter.getIndexType()), rewriter.getIntegerAttr(rewriter.getIndexType(), 1)); - Value one32 = rewriter.create(loc, llvmTypeConverter.convertType(rewriter.getI32Type()), rewriter.getIntegerAttr(rewriter.getI32Type(), 1)); - Value clockId = rewriter.create(loc, clockIdTy, rewriter.getI64IntegerAttr(CLOCK_REALTIME)); - - Value timespecPtr = rewriter.create(loc, LLVM::LLVMPointerType::get(llvmTimespecTy), one); - - std::vector args{ clockId, timespecPtr }; - auto getTimeCall = rewriter.create(loc, std::vector{ getIntType(context) }, clockGetTimeFn, args); - [[maybe_unused]] auto getTimeResult = getTimeCall.getResult(0); - - Value secondsPtr = rewriter.create(loc, LLVM::LLVMPointerType::get(intTy), timespecPtr, ValueRange{ zero, zero32 }); - Value nanosecondsPtr = rewriter.create(loc, LLVM::LLVMPointerType::get(intTy), timespecPtr, ValueRange{ zero, one32 }); - - Value secondsIntVal = rewriter.create(loc, secondsPtr); - Value nanosecondsIntVal = rewriter.create(loc, nanosecondsPtr); - Value secondsDoubleVal = rewriter.create(loc, doubleTy, secondsIntVal); - Value nanosecondsDoubleVal = rewriter.create(loc, doubleTy, nanosecondsIntVal); - Value divisor = rewriter.create(loc, doubleTy, rewriter.getF64FloatAttr(1.0e9)); - Value nanoseconds = rewriter.create(loc, doubleTy, nanosecondsDoubleVal, divisor); - Value totalSecondsDoubleVal = rewriter.create(loc, doubleTy, secondsDoubleVal, nanoseconds); - return totalSecondsDoubleVal; + if (isEnterRegionTimer) + { + auto clockGetTimeFn = getOrInsertClockGetTime(rewriter, parentModule, llvmDialect, this->deviceInfo.numBits); + + auto 
llvmTimespecTy = getTimeSpecType(context, this->deviceInfo.numBits); + auto clockIdTy = getClockIdType(context, this->deviceInfo.numBits); + auto intTy = getIntType(context, this->deviceInfo.numBits); + + // Get a symbol reference to the gettime function, inserting it if necessary. + LLVMTypeConverter llvmTypeConverter(context); + Value zero = rewriter.create(loc, llvmTypeConverter.convertType(rewriter.getIndexType()), rewriter.getIntegerAttr(rewriter.getIndexType(), 0)); + Value zero32 = rewriter.create(loc, llvmTypeConverter.convertType(rewriter.getI32Type()), rewriter.getIntegerAttr(rewriter.getI32Type(), 0)); + Value one = rewriter.create(loc, llvmTypeConverter.convertType(rewriter.getIndexType()), rewriter.getIntegerAttr(rewriter.getIndexType(), 1)); + Value one32 = rewriter.create(loc, llvmTypeConverter.convertType(rewriter.getI32Type()), rewriter.getIntegerAttr(rewriter.getI32Type(), 1)); + + Value clockId = rewriter.create(loc, clockIdTy, rewriter.getI64IntegerAttr(int64_t(ClockID::ACCERA_CLOCK_REALTIME))); + + Value timespecPtr = rewriter.create(loc, LLVM::LLVMPointerType::get(llvmTimespecTy), one); + Value secondsPtr = rewriter.create(loc, LLVM::LLVMPointerType::get(intTy), timespecPtr, ValueRange{ zero, zero32 }); + Value nanosecondsPtr = rewriter.create(loc, LLVM::LLVMPointerType::get(intTy), timespecPtr, ValueRange{ zero, one32 }); + + std::vector args{ clockId, timespecPtr }; + auto getTimeCall = rewriter.create(loc, std::vector{ getIntType(context, this->deviceInfo.numBits) }, clockGetTimeFn, args); + [[maybe_unused]] auto getTimeResult = getTimeCall.getResult(0); + + Value secondsIntVal = rewriter.create(loc, secondsPtr); + Value nanosecondsIntVal = rewriter.create(loc, nanosecondsPtr); + Value secondsDoubleVal = rewriter.create(loc, doubleTy, secondsIntVal); + Value nanosecondsDoubleVal = rewriter.create(loc, doubleTy, nanosecondsIntVal); + Value divisor = rewriter.create(loc, doubleTy, rewriter.getF64FloatAttr(1.0e9)); + Value nanoseconds = 
rewriter.create(loc, doubleTy, nanosecondsDoubleVal, divisor); + Value totalSecondsDoubleVal = rewriter.create(loc, doubleTy, secondsDoubleVal, nanoseconds); + return totalSecondsDoubleVal; + } + else + { + auto clockGetTimeFn = getOrInsertClockGetTime(rewriter, parentModule, llvmDialect, this->deviceInfo.numBits); + + auto llvmTimespecTy = getTimeSpecType(context, this->deviceInfo.numBits); + auto clockIdTy = getClockIdType(context, this->deviceInfo.numBits); + auto intTy = getIntType(context, this->deviceInfo.numBits); + + // Get a symbol reference to the gettime function, inserting it if necessary. + LLVMTypeConverter llvmTypeConverter(context); + Value zero = rewriter.create(loc, llvmTypeConverter.convertType(rewriter.getIndexType()), rewriter.getIntegerAttr(rewriter.getIndexType(), 0)); + Value zero32 = rewriter.create(loc, llvmTypeConverter.convertType(rewriter.getI32Type()), rewriter.getIntegerAttr(rewriter.getI32Type(), 0)); + Value one = rewriter.create(loc, llvmTypeConverter.convertType(rewriter.getIndexType()), rewriter.getIntegerAttr(rewriter.getIndexType(), 1)); + Value one32 = rewriter.create(loc, llvmTypeConverter.convertType(rewriter.getI32Type()), rewriter.getIntegerAttr(rewriter.getI32Type(), 1)); + Value clockId = rewriter.create(loc, clockIdTy, rewriter.getI64IntegerAttr(int64_t(ClockID::ACCERA_CLOCK_REALTIME))); + + Value timespecPtr = rewriter.create(loc, LLVM::LLVMPointerType::get(llvmTimespecTy), one); + + std::vector args{ clockId, timespecPtr }; + auto getTimeCall = rewriter.create(loc, std::vector{ getIntType(context, this->deviceInfo.numBits) }, clockGetTimeFn, args); + [[maybe_unused]] auto getTimeResult = getTimeCall.getResult(0); + + Value secondsPtr = rewriter.create(loc, LLVM::LLVMPointerType::get(intTy), timespecPtr, ValueRange{ zero, zero32 }); + Value nanosecondsPtr = rewriter.create(loc, LLVM::LLVMPointerType::get(intTy), timespecPtr, ValueRange{ zero, one32 }); + + Value secondsIntVal = rewriter.create(loc, secondsPtr); + Value 
nanosecondsIntVal = rewriter.create(loc, nanosecondsPtr); + Value secondsDoubleVal = rewriter.create(loc, doubleTy, secondsIntVal); + Value nanosecondsDoubleVal = rewriter.create(loc, doubleTy, nanosecondsIntVal); + Value divisor = rewriter.create(loc, doubleTy, rewriter.getF64FloatAttr(1.0e9)); + Value nanoseconds = rewriter.create(loc, doubleTy, nanosecondsDoubleVal, divisor); + Value totalSecondsDoubleVal = rewriter.create(loc, doubleTy, secondsDoubleVal, nanoseconds); + return totalSecondsDoubleVal; + } } -#endif } LogicalResult GetTimeOpLowering::matchAndRewrite( @@ -1856,7 +1872,7 @@ void ValueToLLVMLoweringPass::runOnModule() intermediateTarget.addLegalDialect(); RewritePatternSet patterns(&getContext()); - populateValueToLLVMNonMemPatterns(llvmTypeConverter, patterns); + populateValueToLLVMNonMemPatterns(llvmTypeConverter, patterns, this->deviceInfo); populateLinalgToLLVMConversionPatterns(llvmTypeConverter, patterns); @@ -1891,6 +1907,7 @@ void ValueToLLVMLoweringPass::runOnModule() populateStdToLLVMConversionPatterns(llvmTypeConverter, patterns); arith::populateArithmeticToLLVMConversionPatterns(llvmTypeConverter, patterns); arith::populateArithmeticExpandOpsPatterns(patterns); + cf::populateControlFlowToLLVMConversionPatterns(llvmTypeConverter, patterns); // Subset of LowerVectorToLLVMPass patterns vector::populateVectorToVectorCanonicalizationPatterns(patterns); @@ -1947,7 +1964,7 @@ void populateGlobalValueToLLVMNonMemPatterns(mlir::LLVMTypeConverter& typeConver patterns.insert(typeConverter, context); } -void populateLocalValueToLLVMNonMemPatterns(mlir::LLVMTypeConverter& typeConverter, mlir::RewritePatternSet& patterns) +void populateLocalValueToLLVMNonMemPatterns(mlir::LLVMTypeConverter& typeConverter, mlir::RewritePatternSet& patterns, accera::value::TargetDevice deviceInfo) { mlir::MLIRContext* context = patterns.getContext(); @@ -1957,19 +1974,20 @@ void populateLocalValueToLLVMNonMemPatterns(mlir::LLVMTypeConverter& typeConvert 
BitcastOpLowering, CallOpLowering, PrintFOpLowering, - GetTimeOpLowering, RangeOpLowering, VpmaddwdOpLowering, VmaxpsOpLowering, VminpsOpLowering, RoundOpLowering, MemrefAllocOpLowering>(typeConverter, context); + + patterns.insert(typeConverter, context, deviceInfo); } -void populateValueToLLVMNonMemPatterns(mlir::LLVMTypeConverter& typeConverter, mlir::RewritePatternSet& patterns) +void populateValueToLLVMNonMemPatterns(mlir::LLVMTypeConverter& typeConverter, mlir::RewritePatternSet& patterns, accera::value::TargetDevice deviceInfo) { populateGlobalValueToLLVMNonMemPatterns(typeConverter, patterns); - populateLocalValueToLLVMNonMemPatterns(typeConverter, patterns); + populateLocalValueToLLVMNonMemPatterns(typeConverter, patterns, deviceInfo); } void populateValueToLLVMMemPatterns(mlir::LLVMTypeConverter& typeConverter, mlir::RewritePatternSet& patterns) @@ -2009,9 +2027,10 @@ std::unique_ptr> createValueToLLVMPass(bool unsigned indexBitwidth, bool useAlignedAlloc, llvm::DataLayout dataLayout, + accera::value::TargetDevice deviceInfo /* = {} */, const IntraPassSnapshotOptions& options /* = {} */) { - return std::make_unique(useBasePtrCallConv, emitCWrappers, indexBitwidth, useAlignedAlloc, dataLayout, options); + return std::make_unique(useBasePtrCallConv, emitCWrappers, indexBitwidth, useAlignedAlloc, dataLayout, deviceInfo, options); } std::unique_ptr> createValueToLLVMPass() diff --git a/accera/transforms/src/value/ValueToStandardLoweringPass.cpp b/accera/transforms/src/value/ValueToStandardLoweringPass.cpp index 690df492..813ee62c 100644 --- a/accera/transforms/src/value/ValueToStandardLoweringPass.cpp +++ b/accera/transforms/src/value/ValueToStandardLoweringPass.cpp @@ -580,7 +580,7 @@ struct CastOpLowering : public OpRewritePattern #define CAST_FROM_TO_WITH_OP_IF(testFromType, testToType, castOp, conditional) \ if (fromType && toType && fromElementType.isa() && toElementType.isa() && conditional) \ { \ - mlir::Value castValue = rewriter.create(op.getLoc(), 
signlessFromValue, signlessToType); \ + mlir::Value castValue = rewriter.create(op.getLoc(), signlessToType, signlessFromValue); \ if (toType.isIntOrIndex()) \ { \ rewriter.replaceOpWithNewOp(op, toType, castValue); \ @@ -652,14 +652,14 @@ struct CastOpLowering : public OpRewritePattern auto i64IntermediateType = accera::ir::util::CloneTypeWithNewElementType(op.source().getType(), rewriter.getI64Type()); if (fromElementType.isa() && toElementType.isa()) { - auto int64Value = rewriter.create(loc, op.source(), i64IntermediateType); // index->int64 - rewriter.replaceOpWithNewOp(op, int64Value, toElementType); // int64->fp + auto int64Value = rewriter.create(loc, i64IntermediateType, op.source()); // index->int64 + rewriter.replaceOpWithNewOp(op, toElementType, int64Value); // int64->fp return success(); } if (fromElementType.isa() && toElementType.isa()) { - auto int64Value = rewriter.create(loc, op.source(), i64IntermediateType); // fp->int64 - rewriter.replaceOpWithNewOp(op, int64Value, toElementType); // int64->index + auto int64Value = rewriter.create(loc, i64IntermediateType, op.source()); // fp->int64 + rewriter.replaceOpWithNewOp(op, toElementType, int64Value); // int64->index return success(); } @@ -1238,7 +1238,7 @@ LogicalResult CopyOpLowering::matchAndRewrite( if (outputMemRef.getElementType().isInteger(64)) // this should really be target dependent... 
{ (void)rewriter.create(loc, - rewriter.create(loc, input, rewriter.getIntegerType(64)), + rewriter.create(loc, rewriter.getIntegerType(64), input), output, std::vector(outputMemRef.getRank(), zero)); } @@ -1358,7 +1358,7 @@ LogicalResult LoadOpLowering::matchAndRewrite( else { resolvedIndices.push_back( - rewriter.create(loc, rewriter.create(loc, index), indexType)); + rewriter.create(loc, indexType, rewriter.create(loc, index))); } } @@ -1388,8 +1388,8 @@ LogicalResult StoreOpLowering::matchAndRewrite( { resolvedIndices.push_back( rewriter.create(loc, - rewriter.create(loc, index), - indexType)); + indexType, + rewriter.create(loc, index))); } } @@ -1524,8 +1524,8 @@ LogicalResult OffsetOpLowering::matchAndRewrite( { resolvedOffsets.push_back( rewriter.create(loc, - rewriter.create(loc, index), - indexType)); + indexType, + rewriter.create(loc, index))); } else { @@ -1611,7 +1611,7 @@ LogicalResult SliceOpLowering::matchAndRewrite( auto indexShape = index.getType().cast().getShape(); if (indexShape.size() == 0 || indexShape.size() == 1) { - index = rewriter.create(loc, rewriter.create(loc, index), indexType); + index = rewriter.create(loc, indexType, rewriter.create(loc, index)); } else { @@ -1743,8 +1743,8 @@ LogicalResult ReorderOpLowering::matchAndRewrite( auto resultType = op.getType(); // cast to a value with type `memref` (via `memref<* x elem_type>`) - mlir::Value ptr = rewriter.create(loc, source, mlir::UnrankedMemRefType::get(elemTy, sourceType.getMemorySpace())); - auto result = rewriter.create(loc, ptr, resultType); + mlir::Value ptr = rewriter.create(loc, mlir::UnrankedMemRefType::get(elemTy, sourceType.getMemorySpace()), source); + auto result = rewriter.create(loc, resultType, ptr); rewriter.replaceOp(op, { result }); return success(); @@ -1762,8 +1762,8 @@ LogicalResult ReshapeOpLowering::matchAndRewrite( auto resultType = op.getType(); // cast to a value with type `memref` (via `memref<* x elem_type>`) - mlir::Value ptr = rewriter.create(loc, 
source, mlir::UnrankedMemRefType::get(elemTy, sourceType.getMemorySpace())); - auto result = rewriter.create(loc, ptr, resultType); + mlir::Value ptr = rewriter.create(loc, mlir::UnrankedMemRefType::get(elemTy, sourceType.getMemorySpace()), source); + auto result = rewriter.create(loc, resultType, ptr); rewriter.replaceOp(op, { result }); return success(); @@ -1932,11 +1932,11 @@ LogicalResult ReduceOpLowering::matchAndRewrite( vectorSize = vectorType.getShape()[0]; } auto stepValue = isParallelReduction ? vectorSize : 1; - auto oldInputValue = op.getInputValueVar(); auto oldInductionValue = op.getInductionValue(); auto oldTerminator = op.getBody()->getTerminator(); auto oldYieldValue = oldTerminator->getOperand(0); // TODO: add "get result value" helper to ReduceOp + bool isFloatType = initialValueType.isa(); // Check for trivial reductions of the form bin_op(arg1, arg2) if (isHorizontalReduction) @@ -1959,40 +1959,36 @@ LogicalResult ReduceOpLowering::matchAndRewrite( { using accera::ir::value::BinaryOpPredicate; auto pred = binOp.predicate(); - std::string opName = ""; + mlir::vector::CombiningKind kind; switch (pred) { case BinaryOpPredicate::ADD: - opName = "add"; + kind = mlir::vector::CombiningKind::ADD; break; case BinaryOpPredicate::MUL: - opName = "mul"; + kind = mlir::vector::CombiningKind::MUL; break; case BinaryOpPredicate::SUB: - opName = "sub"; - break; + [[fallthrough]]; default: - break; + llvm_unreachable("Unsupported binary op predicate for vector reduction"); } - if (!opName.empty()) + // We can use the init value for floating-point add and mul + if (isFloatType && (pred == BinaryOpPredicate::ADD || pred == BinaryOpPredicate::MUL)) { - // We can use the init value for floating-point add and mul - if (initialValueType.isa() && (pred == BinaryOpPredicate::ADD || pred == BinaryOpPredicate::MUL)) - { - auto result = rewriter.create(loc, op.result().getType(), rewriter.getStringAttr(opName), op.input(), op.initArg()); - rewriter.replaceOp(op, { 
result }); - } - else - { - auto result = rewriter.create(loc, op.result().getType(), rewriter.getStringAttr(opName), op.input(), llvm::None); - rewriter.replaceOp(op, { result }); - } - return success(); + auto result = rewriter.create(loc, kind, op.input(), op.initArg()); + rewriter.replaceOp(op, { result }); } + else + { + auto result = rewriter.create(loc, kind, op.input()); + rewriter.replaceOp(op, { result }); + } + return success(); } } - else if (auto selectOp = dyn_cast(yieldValueOp)) + else if (auto selectOp = dyn_cast(yieldValueOp)) { // Look for sequences like: // @@ -2022,13 +2018,13 @@ LogicalResult ReduceOpLowering::matchAndRewrite( [[fallthrough]]; case ValueCmpOpPredicate::LE: // min - result = rewriter.create(loc, op.result().getType(), rewriter.getStringAttr("min"), op.input(), llvm::None); + result = rewriter.create(loc, isFloatType ? mlir::vector::CombiningKind::MINF : mlir::vector::CombiningKind::MINSI, op.input()); break; case ValueCmpOpPredicate::GT: [[fallthrough]]; case ValueCmpOpPredicate::GE: // max - result = rewriter.create(loc, op.result().getType(), rewriter.getStringAttr("max"), op.input(), llvm::None); + result = rewriter.create(loc, isFloatType ? 
mlir::vector::CombiningKind::MAXF : mlir::vector::CombiningKind::MAXSI, op.input()); break; } @@ -2149,8 +2145,8 @@ LogicalResult ReferenceGlobalOpLowering::matchAndRewrite( rewriter.replaceOpWithNewOp( op, - getGlobalOpValue, - op.getType()); + op.getType(), + getGlobalOpValue); return success(); } @@ -2538,11 +2534,11 @@ LogicalResult ReduceMaxOpLowering::matchAndRewrite( return op.emitError("Can only reduce a rank-1 memref"); } + auto elementType = memRefType.getElementType(); mlir::Value memrefToCast = input; mlir::Value loadedVector = nullptr; if (!memRefType.getLayout().isIdentity()) { - auto elementType = memRefType.getElementType(); auto vectorType = mlir::VectorType::get(memRefType.getShape(), elementType); auto zero = rewriter.create(loc, 0); loadedVector = rewriter.create(loc, vectorType, memrefToCast, mlir::ValueRange{ zero }); @@ -2552,7 +2548,8 @@ LogicalResult ReduceMaxOpLowering::matchAndRewrite( auto castMemRefVector = rewriter.create(loc, memrefToCast); loadedVector = rewriter.create(loc, castMemRefVector, llvm::None); } - auto result = rewriter.create(loc, op.result().getType(), rewriter.getStringAttr("max"), loadedVector, llvm::None); + auto kind = elementType.isa() ? 
mlir::vector::CombiningKind::MAXF : mlir::vector::CombiningKind::MAXSI; + auto result = rewriter.create(loc, kind, loadedVector); rewriter.replaceOp(op, { result }); return success(); } @@ -2633,7 +2630,7 @@ LogicalResult ExitProfileRegionOpLowering::matchAndRewrite(ExitProfileRegionOp o rewriter.create(loc, totalTime, totalTimeRef); mlir::Value prevCount = rewriter.create(loc, countRef); - auto one = rewriter.create(loc, rewriter.getI64Type(), rewriter.getI64IntegerAttr(1)); + auto one = rewriter.create(loc, rewriter.getI64Type(), rewriter.getI64IntegerAttr(1)); mlir::Value newCount = rewriter.create(loc, vir::BinaryOpPredicate::ADD, prevCount, one); rewriter.create(loc, newCount, countRef); @@ -2701,7 +2698,7 @@ LogicalResult ReduceSumOpLowering::matchAndRewrite( auto castMemRefVector = rewriter.create(loc, memrefToCast); loadedVector = rewriter.create(loc, castMemRefVector, llvm::None); } - auto result = rewriter.create(loc, op.result().getType(), rewriter.getStringAttr("add"), loadedVector, llvm::None); + auto result = rewriter.create(loc, mlir::vector::CombiningKind::ADD, loadedVector); rewriter.replaceOp(op, { result }); return success(); } @@ -2724,7 +2721,7 @@ LogicalResult PrintOpLowering::matchAndRewrite( auto printElement = [&](mlir::Value el) { if (elementType.isF32()) { - el = rewriter.create(loc, el, rewriter.getF64Type()); + el = rewriter.create(loc, rewriter.getF64Type(), el); } rewriter.create(loc, formatStr, ValueRange{ el }, toStderr); }; diff --git a/accera/transforms/src/vectorization/VectorizationUtil.cpp b/accera/transforms/src/vectorization/VectorizationUtil.cpp index 94b4322b..b8306937 100644 --- a/accera/transforms/src/vectorization/VectorizationUtil.cpp +++ b/accera/transforms/src/vectorization/VectorizationUtil.cpp @@ -140,7 +140,7 @@ bool CanVectorizeOp(mlir::Operation* op, .Case([](mlir::memref::StoreOp) { return true; }) .Case([](mlir::AffineLoadOp) { return true; }) .Case([](mlir::AffineStoreOp) { return true; }) - 
.Case([](mlir::SelectOp) { return true; }) + .Case([](mlir::arith::SelectOp) { return true; }) .Case([](mlir::arith::ShLIOp) { return true; }) .Case([](mlir::arith::FPToSIOp) { return true; }) .Case([](mlir::arith::ExtSIOp) { return true; }) @@ -810,7 +810,7 @@ std::optional VectorizeAffineApplyOp(mlir::PatternRewriter& rewrit } std::optional VectorizeSelectOp(mlir::PatternRewriter& rewriter, - mlir::SelectOp op, + mlir::arith::SelectOp op, const VectorizedOpMap& vectorizedOps, std::vector& laneMappings, mlir::Value inductionVar, @@ -830,7 +830,7 @@ std::optional VectorizeSelectOp(mlir::PatternRewriter& rewrite } auto loc = op.getLoc(); - auto result = rewriter.create(loc, cond->GetVectorResult(), trueVal->GetVectorResult(), falseVal->GetVectorResult()); + auto result = rewriter.create(loc, cond->GetVectorResult(), trueVal->GetVectorResult(), falseVal->GetVectorResult()); return result; } @@ -1193,7 +1193,7 @@ std::optional VectorizeOp(mlir::PatternRewriter& rewriter, .Case([&](mlir::AffineApplyOp affineApplyOp) { return VectorizeAffineApplyOp(rewriter, affineApplyOp, vectorizedOps, laneMappings, inductionVar, step, vectorSize); }) - .Case([&](mlir::SelectOp selectOp) { + .Case([&](mlir::arith::SelectOp selectOp) { return VectorizeSelectOp(rewriter, selectOp, vectorizedOps, laneMappings, inductionVar, step, vectorSize); }) .Case([&](mlir::arith::ShLIOp shiftLeftOp) { @@ -2185,7 +2185,7 @@ mlir::LogicalResult vectorizeHorizontalReduction(mlir::AffineForOp affineForOp, } else { - reducedVal = rewriter.create(binOp.getLoc(), storeElementType, mlir::vector::stringifyEnum(reductionKind), vectorValToReduce, mlir::ValueRange{} /* optional accumulate values */); + reducedVal = rewriter.create(binOp.getLoc(), reductionKind, vectorValToReduce); } auto scalarValThatWasReduced = lhsLoadIsLoopSequential ? 
lhsVal : rhsVal; diff --git a/accera/value/src/Cache.cpp b/accera/value/src/Cache.cpp index 49ee22fd..ccf381e7 100644 --- a/accera/value/src/Cache.cpp +++ b/accera/value/src/Cache.cpp @@ -763,8 +763,8 @@ namespace value auto localScopeGlobalRef = builder.create(GetLocation(), globalScopeGlobalRef.getGlobal()); // Re-view the packed buffer memref to match the function argument // TODO : update the function arguments to match the packed shape - mlir::Value shapelessMemref = builder.create(loc, localScopeGlobalRef, mlir::UnrankedMemRefType::get(GetElementType(), GetInputType().getMemorySpace())); - auto reshapedMemref = builder.create(loc, shapelessMemref, GetInputType()); + mlir::Value shapelessMemref = builder.create(loc, mlir::UnrankedMemRefType::get(GetElementType(), GetInputType().getMemorySpace()), localScopeGlobalRef); + auto reshapedMemref = builder.create(loc, GetInputType(), shapelessMemref); constantInjectedArgs.insert(constantInjectedArgs.begin() + targetArgIdx, reshapedMemref); auto launchFuncOp = builder.create(GetLocation(), scheduleFuncOp, constantInjectedArgs); diff --git a/accera/value/src/MLIREmitterContext.cpp b/accera/value/src/MLIREmitterContext.cpp index b14b3a9d..04044aaf 100644 --- a/accera/value/src/MLIREmitterContext.cpp +++ b/accera/value/src/MLIREmitterContext.cpp @@ -281,7 +281,7 @@ mlir::FunctionType ToMLIRType(mlir::OpBuilder& builder, const FunctionDeclaratio if (type.isa()) { auto loc = builder.getUnknownLoc(); - return builder.create(loc, v, mlir::IndexType::get(v.getContext())); + return builder.create(loc, mlir::IndexType::get(v.getContext()), v); } // Index types fall through @@ -355,7 +355,7 @@ std::vector ToMLIRValue(mlir::OpBuilder& builder, std::vector(loc, mlirValue, indexType); + return builder.create(loc, indexType, mlirValue); } } @@ -825,7 +825,9 @@ void MLIRContext::print() const void MLIRContext::verify() const { - (void)_impl->_mlirModule.verify(); + // BUGBUG: ModuleOp::verify() is private as of LLVM 15 + // 
(void)_impl->_mlirModule.verify(); + throw utilities::LogicException(utilities::LogicExceptionErrors::notImplemented, "verify() is not implemented."); } mlir::OwningOpRef MLIRContext::cloneModule() const @@ -918,8 +920,8 @@ Scalar CreateGPUIndexOp(mlir::OpBuilder& builder, accera::ir::value::Processor i auto loc = builder.getUnknownLoc(); return Wrap( builder.create(loc, - accera::ir::util::GetGPUIndex(idxType, builder, loc), - builder.getI64Type())); + builder.getI64Type(), + accera::ir::util::GetGPUIndex(idxType, builder, loc))); } template @@ -2097,7 +2099,7 @@ Value MLIRContext::ReinterpretCastImpl(Value input, ValueType valueType) mlir::MemRefType::Builder outputTypeBuilder(inputMemrefType); outputTypeBuilder.setElementType(outputMlirElemType); mlir::MemRefType outputMemRefType = outputTypeBuilder; - auto returnVal = builder.create(loc, inputMlir, outputMemRefType); + auto returnVal = builder.create(loc, outputMemRefType, inputMlir); return Wrap(returnVal, input.GetLayout()); } @@ -2254,7 +2256,7 @@ Value MLIRContext::ReinterpretCastImpl(Value input, ValueType valueType) auto outputMemRefType = mlir::MemRefType::get({ numElementsInOutput }, outputMlirElemType, composedMap); // Fetch a new memref type after normalizing the old memref to have an identity map layout. outputMemRefType = normalizeMemRefType(outputMemRefType, builder, composedMap.getNumSymbols() /* ?? 
No idea if this is correct */); - auto returnVal = builder.create(loc, inputMlir, outputMemRefType); + auto returnVal = builder.create(loc, outputMemRefType, inputMlir); return Wrap(returnVal, MemoryLayout{ numElementsInOutput }); } @@ -2262,7 +2264,7 @@ Value MLIRContext::ReinterpretCastImpl(Value input, ValueType valueType) { // Case 3 // auto outputMemRefType = mlir::UnrankedMemRefType::get(outputMlirElemType, inputUnrankedMemrefType.getMemorySpace()); - // auto returnVal = builder.create(loc, inputMlir, outputMemRefType); + // auto returnVal = builder.create(loc, outputMemRefType, inputMlir); // TODO: This is going to assert because returnVal is of type UnrankedMemRefType and MemoryLayout doesn't support unranked // return Wrap(returnVal); @@ -2786,8 +2788,8 @@ void MLIRContext::PrintRawMemoryImpl(ViewAdapter value) auto identityLayout = mlir::MemRefLayoutAttrInterface{}; // cast to a value with type `memref` (via `memref<* x elem_type>`) - mlir::Value ptr = builder.create(loc, mem, mlir::UnrankedMemRefType::get(elemTy, memType.getMemorySpace())); - mlir::Value mlirValue = builder.create(loc, ptr, mlir::MemRefType::get({ size }, elemTy, identityLayout, memType.getMemorySpace())); + mlir::Value ptr = builder.create(loc, mlir::UnrankedMemRefType::get(elemTy, memType.getMemorySpace()), mem); + mlir::Value mlirValue = builder.create(loc, mlir::MemRefType::get({ size }, elemTy, identityLayout, memType.getMemorySpace()), ptr); [[maybe_unused]] auto op = builder.create(loc, mlirValue, /*toStderr=*/false); } @@ -3141,8 +3143,8 @@ void FillResource(ViewAdapter resourceView, Scalar fillValue) auto loc = b.getUnknownLoc(); mlir::Value memrefCasted = b.create( loc, - res, - castType); + castType, + res); auto DeclareFn = [&](const std::string& name, mlir::FunctionType fnTy) -> ir::value::ValueFuncOp { auto mod = res.getParentRegion()->getParentOfType(); @@ -3208,8 +3210,8 @@ void PrintMemref(ViewAdapter memView) } mlir::Value memrefCasted = b.create( loc, - mem, - 
mlir::UnrankedMemRefType::get(elemTy, shapedType.getMemorySpace())); + mlir::UnrankedMemRefType::get(elemTy, shapedType.getMemorySpace()), + mem); auto DeclareFn = [&](const std::string& name, mlir::FunctionType fnTy) -> ir::value::ValueFuncOp { auto mod = mem.getParentRegion()->getParentOfType(); diff --git a/accera/value/src/ScalarOperations.cpp b/accera/value/src/ScalarOperations.cpp index 1b8929e5..034a226e 100644 --- a/accera/value/src/ScalarOperations.cpp +++ b/accera/value/src/ScalarOperations.cpp @@ -251,7 +251,7 @@ namespace value Scalar Select(Scalar cmp, Scalar a, Scalar b) { std::tie(a, b) = Scalar::MakeTypeCompatible(a, b); - return ScalarOpBuilder(cmp, a, b); + return ScalarOpBuilder(cmp, a, b); } Scalar Sin(Scalar s) diff --git a/accera/value/src/TargetDevice.cpp b/accera/value/src/TargetDevice.cpp index 1a46ad67..c4767318 100644 --- a/accera/value/src/TargetDevice.cpp +++ b/accera/value/src/TargetDevice.cpp @@ -92,6 +92,12 @@ namespace value targetDevice.triple = c_windowsTriple; targetDevice.dataLayout = c_windowsDataLayout; } }, + { "avx2", [](TargetDevice& targetDevice) { + targetDevice.architecture = "x86_64"; + targetDevice.cpu = "skylake"; + targetDevice.numBits = 64; + targetDevice.features = "+avx2"; + } }, { "avx512", [](TargetDevice& targetDevice) { targetDevice.architecture = "x86_64"; targetDevice.cpu = "skylake-avx512"; diff --git a/external/llvm/0001-Merged-PR-2213-mlir-Plumb-OpenMP-dialect-attributes-.patch b/external/llvm/0001-Merged-PR-2213-mlir-Plumb-OpenMP-dialect-attributes-.patch index b3a31422..d913eadd 100644 --- a/external/llvm/0001-Merged-PR-2213-mlir-Plumb-OpenMP-dialect-attributes-.patch +++ b/external/llvm/0001-Merged-PR-2213-mlir-Plumb-OpenMP-dialect-attributes-.patch @@ -1,7 +1,7 @@ -From 08caef4cc7d6e8bc79185d0775da02552eddea9d Mon Sep 17 00:00:00 2001 +From c0e0b6c647aaa8a9c8f8167ef54f4846f25f827b Mon Sep 17 00:00:00 2001 From: Lisa Ong Date: Tue, 17 May 2022 15:16:57 +0800 -Subject: [PATCH] Merged PR 2213: [mlir] 
Plumb OpenMP dialect attributes +Subject: [PATCH 1/6] Merged PR 2213: [mlir] Plumb OpenMP dialect attributes through affine and scf lowering * Updated AffineToSCF and SCFToOpenMP to support OMP dialect attributes for num_threads, schedule_val, proc_bind, and collapse @@ -43,7 +43,7 @@ index 05d7637d52d7..8295b01f8fcd 100644 + #endif // MLIR_DIALECT_OPENMP_OPENMPDIALECT_H_ diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td -index 505e9cb22a0a..d124be11cb94 100644 +index ddeb698fb2a2..6a74eeb217bd 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -124,7 +124,8 @@ def ParallelOp : OpenMP_Op<"parallel", [AttrSizedOperandSegments, @@ -54,13 +54,13 @@ index 505e9cb22a0a..d124be11cb94 100644 + OpBuilder<(ins CArg<"ArrayRef", "{}">:$attributes)>, + OpBuilder<(ins "Value":$num_threads, "ClauseProcBindKindAttr":$proc_bind)> ]; - let parser = [{ return parseParallelOp(parser, result); }]; - let printer = [{ return printParallelOp(p, *this); }]; + let hasCustomAssemblyFormat = 1; + let hasVerifier = 1; diff --git a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp -index 8ff1134f4b7b..888a6ed20fdc 100644 +index 7c91af4c49f0..0992fc0c1f3a 100644 --- a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp +++ b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp -@@ -385,6 +385,11 @@ public: +@@ -177,6 +177,11 @@ public: SmallVector upperBoundTuple; SmallVector lowerBoundTuple; SmallVector identityVals; @@ -72,7 +72,7 @@ index 8ff1134f4b7b..888a6ed20fdc 100644 // Emit IR computing the lower and upper bound by expanding the map // expression. 
lowerBoundTuple.reserve(op.getNumDims()); -@@ -418,6 +423,7 @@ public: +@@ -210,6 +215,7 @@ public: rewriter.eraseBlock(parOp.getBody()); rewriter.inlineRegionBefore(op.region(), parOp.getRegion(), parOp.getRegion().end()); @@ -80,7 +80,7 @@ index 8ff1134f4b7b..888a6ed20fdc 100644 rewriter.replaceOp(op, parOp.getResults()); return success(); } -@@ -467,6 +473,7 @@ public: +@@ -259,6 +265,7 @@ public: reduceOp.getReductionOperator().front().getArgument(1)); rewriter.create(loc, reductionResult); } @@ -89,10 +89,10 @@ index 8ff1134f4b7b..888a6ed20fdc 100644 return success(); } diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp -index e472af257776..526909ca4224 100644 +index a9e7759aa75e..30e0e3e9ad16 100644 --- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp +++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp -@@ -363,8 +363,12 @@ struct ParallelOpLowering : public OpRewritePattern { +@@ -364,8 +364,12 @@ struct ParallelOpLowering : public OpRewritePattern { loc, rewriter.getIntegerType(64), rewriter.getI64IntegerAttr(1)); SmallVector reductionVariables; reductionVariables.reserve(parallelOp.getNumReductions()); @@ -107,7 +107,7 @@ index e472af257776..526909ca4224 100644 for (Value init : parallelOp.getInitVals()) { assert((LLVM::isCompatibleType(init.getType()) || init.getType().isa()) && -@@ -389,7 +393,19 @@ struct ParallelOpLowering : public OpRewritePattern { +@@ -390,7 +394,19 @@ struct ParallelOpLowering : public OpRewritePattern { } // Create the parallel wrapper. @@ -128,7 +128,7 @@ index e472af257776..526909ca4224 100644 { OpBuilder::InsertionGuard guard(rewriter); rewriter.createBlock(&ompParallel.region()); -@@ -405,9 +421,20 @@ struct ParallelOpLowering : public OpRewritePattern { +@@ -406,9 +422,20 @@ struct ParallelOpLowering : public OpRewritePattern { } // Replace the loop. 
@@ -150,7 +150,7 @@ index e472af257776..526909ca4224 100644 rewriter.create(loc); rewriter.inlineRegionBefore(parallelOp.getRegion(), loop.region(), -@@ -420,15 +447,21 @@ struct ParallelOpLowering : public OpRewritePattern { +@@ -421,15 +448,21 @@ struct ParallelOpLowering : public OpRewritePattern { } // Load loop results. @@ -179,7 +179,7 @@ index e472af257776..526909ca4224 100644 return success(); } }; -@@ -437,7 +470,7 @@ struct ParallelOpLowering : public OpRewritePattern { +@@ -438,7 +471,7 @@ struct ParallelOpLowering : public OpRewritePattern { static LogicalResult applyPatterns(ModuleOp module) { ConversionTarget target(*module.getContext()); target.addIllegalOp(); @@ -189,7 +189,7 @@ index e472af257776..526909ca4224 100644 RewritePatternSet patterns(module.getContext()); patterns.add(module.getContext()); diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp -index 46a2bf3019e0..d9104bf41a93 100644 +index 4ff38e2b455a..a4b6fd78e7f9 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -73,6 +73,16 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state, @@ -210,5 +210,5 @@ index 46a2bf3019e0..d9104bf41a93 100644 // Parser and printer for Operand and type list //===----------------------------------------------------------------------===// -- -2.32.1 (Apple Git-133) +2.37.1 (Apple Git-137.1) diff --git a/external/llvm/0002-Merged-PR-2237-Improved-codegen-of-vpmaddwd-instruct.patch b/external/llvm/0002-Merged-PR-2237-Improved-codegen-of-vpmaddwd-instruct.patch index 81d187e7..70ca7374 100644 --- a/external/llvm/0002-Merged-PR-2237-Improved-codegen-of-vpmaddwd-instruct.patch +++ b/external/llvm/0002-Merged-PR-2237-Improved-codegen-of-vpmaddwd-instruct.patch @@ -1,7 +1,7 @@ -From 6172c5a07d0e4ca745aeff45298d2995aa685c42 Mon Sep 17 00:00:00 2001 +From b2cecf54139212bbea1f29337c51e47fca81be5c Mon Sep 17 00:00:00 2001 From: Lisa Ong Date: 
Mon, 14 Feb 2022 16:18:47 +0800 -Subject: [PATCH] From 97c4232342b7e8b802c4557368bc016264679930 Mon Sep 17 +Subject: [PATCH 2/6] From 97c4232342b7e8b802c4557368bc016264679930 Mon Sep 17 00:00:00 2001 From: Chuck Jacobs Date: Wed, 6 Oct 2021 16:40:38 +0000 Subject: Merged PR 2237: Improved codegen of vpmaddwd instruction @@ -12,7 +12,7 @@ This PR adds another codegen path that lowers certain "multiply-like" operations 1 file changed, 574 insertions(+) diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp -index 6f1fe8195595..f7bff9795548 100644 +index 4c622568f8d0..7ee8b9be2154 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -39,6 +39,7 @@ @@ -23,7 +23,7 @@ index 6f1fe8195595..f7bff9795548 100644 #include "llvm/CodeGen/TargetLowering.h" #include "llvm/CodeGen/WinEHFuncInfo.h" #include "llvm/IR/CallingConv.h" -@@ -49340,6 +49341,350 @@ static SDValue combinePMULH(SDValue Src, EVT VT, const SDLoc &DL, +@@ -49524,6 +49525,350 @@ static SDValue combinePMULH(SDValue Src, EVT VT, const SDLoc &DL, return DAG.getNode(Opc, DL, VT, LHS, RHS); } @@ -374,7 +374,7 @@ index 6f1fe8195595..f7bff9795548 100644 // Attempt to match PMADDUBSW, which multiplies corresponding unsigned bytes // from one vector with signed bytes from another vector, adds together // adjacent pairs of 16-bit products, and saturates the result before -@@ -52432,6 +52777,233 @@ static SDValue matchPMADDWD_2(SelectionDAG &DAG, SDValue N0, SDValue N1, +@@ -52546,6 +52891,233 @@ static SDValue matchPMADDWD_2(SelectionDAG &DAG, SDValue N0, SDValue N1, PMADDBuilder); } @@ -608,7 +608,7 @@ index 6f1fe8195595..f7bff9795548 100644 // ADD(VPMADDWD(X,Y),VPMADDWD(Z,W)) -> VPMADDWD(SHUFFLE(X,Z), SHUFFLE(Y,W)) // If upper element in each pair of both VPMADDWD are zero then we can merge // the operand elements and use the implicit add of VPMADDWD. 
-@@ -52554,6 +53126,8 @@ static SDValue combineAdd(SDNode *N, SelectionDAG &DAG, +@@ -52668,6 +53240,8 @@ static SDValue combineAdd(SDNode *N, SelectionDAG &DAG, return MAdd; if (SDValue MAdd = matchPMADDWD_2(DAG, Op0, Op1, DL, VT, Subtarget)) return MAdd; @@ -618,5 +618,5 @@ index 6f1fe8195595..f7bff9795548 100644 return MAdd; -- -2.30.0.windows.1 +2.37.1 (Apple Git-137.1) diff --git a/external/llvm/0003-Fix-bad-merge.patch b/external/llvm/0003-Fix-bad-merge.patch index e21a29db..87e80c4a 100644 --- a/external/llvm/0003-Fix-bad-merge.patch +++ b/external/llvm/0003-Fix-bad-merge.patch @@ -1,7 +1,7 @@ -From 2dbd177abcf79fe183c989824863aafe5b38e1cd Mon Sep 17 00:00:00 2001 +From 5cf0dde409f5f937e22f9a6c81db8368494a63cb Mon Sep 17 00:00:00 2001 From: Lisa Ong Date: Mon, 14 Feb 2022 16:32:01 +0800 -Subject: [PATCH] From 7f9a254015c977405957fb5b2b6e2a1895f0ca69 Mon Sep 17 +Subject: [PATCH 3/6] From 7f9a254015c977405957fb5b2b6e2a1895f0ca69 Mon Sep 17 00:00:00 2001 From: Kern Handa Date: Wed, 6 Oct 2021 11:09:31 -0700 Subject: Fix bad merge @@ -32,5 +32,5 @@ index 79a062fd9735..377080b25c36 100644 } -- -2.30.0.windows.1 +2.37.1 (Apple Git-137.1) diff --git a/external/llvm/0004-Lower-memref.copy-to-memcpy-when-layouts-canonicaliz.patch b/external/llvm/0004-Lower-memref.copy-to-memcpy-when-layouts-canonicaliz.patch index e6fe7654..b82ee1e9 100644 --- a/external/llvm/0004-Lower-memref.copy-to-memcpy-when-layouts-canonicaliz.patch +++ b/external/llvm/0004-Lower-memref.copy-to-memcpy-when-layouts-canonicaliz.patch @@ -1,36 +1,34 @@ -From bc61595ab6d6f5eb6fbf3564cf00f8839250d2d7 Mon Sep 17 00:00:00 2001 +From c865623e33d24660ecc474529192914d0f87b48f Mon Sep 17 00:00:00 2001 From: Lisa Ong Date: Mon, 23 May 2022 12:39:55 +0800 -Subject: [PATCH] Lower memref.copy to memcpy when layouts canonicalize to +Subject: [PATCH 4/6] Lower memref.copy to memcpy when layouts canonicalize to identity layouts. 
A memref.cast won't work because it gets folded into memref.copy during op canonicalization. --- - mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp | 9 +++++++-- - 1 file changed, 7 insertions(+), 2 deletions(-) + mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp | 7 +++++++ + 1 file changed, 7 insertions(+) diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp -index 288c252b81bb..bec7513f7986 100644 +index 56413c415590..4402485545ad 100644 --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp -@@ -933,10 +933,15 @@ struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern { - auto srcType = op.source().getType().cast(); - auto targetType = op.target().getType().cast(); - -+ // Memref casts get folded away during CopyOp::fold, so we have to replace -+ // the operand with its canonicalized identity form, if they are equivalent -+ auto cannedSrcType = canonicalizeStridedLayout(srcType.cast()); -+ auto cannedTargetType = canonicalizeStridedLayout(targetType.cast()); +@@ -944,8 +944,15 @@ struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern { + // We can use memcpy for memrefs if they have an identity layout or are + // contiguous with an arbitrary offset. Ignore empty memrefs, which is a + // special case handled by memrefCopy. 
+ - if (srcType.hasRank() && -- srcType.cast().getLayout().isIdentity() && -+ (srcType.cast().getLayout().isIdentity() || cannedSrcType.getLayout().isIdentity()) && - targetType.hasRank() && -- targetType.cast().getLayout().isIdentity()) -+ (targetType.cast().getLayout().isIdentity() || cannedTargetType.getLayout().isIdentity())) - return lowerToMemCopyIntrinsic(op, adaptor, rewriter); - - return lowerToMemCopyFunctionCall(op, adaptor, rewriter); ++ // Memref casts get folded away during CopyOp::fold, so we have to replace ++ // the operand with its canonicalized identity form, if they are ++ // equivalent ++ auto cannedType = canonicalizeStridedLayout(memrefType); ++ + return memrefType && + (memrefType.getLayout().isIdentity() || ++ cannedType.getLayout().isIdentity() || + (memrefType.hasStaticShape() && memrefType.getNumElements() > 0 && + isStaticShapeAndContiguousRowMajor(memrefType))); + }; -- -2.32.1 (Apple Git-133) +2.37.1 (Apple Git-137.1) diff --git a/external/llvm/0005-Fix-issue-where-passed-in-op-printing-flags-were-ign.patch b/external/llvm/0005-Fix-issue-where-passed-in-op-printing-flags-were-ign.patch index 2067d5d3..04d43dce 100644 --- a/external/llvm/0005-Fix-issue-where-passed-in-op-printing-flags-were-ign.patch +++ b/external/llvm/0005-Fix-issue-where-passed-in-op-printing-flags-were-ign.patch @@ -1,14 +1,16 @@ -From 39f0a4c97f5c89d7fa815118a3230091172bc795 Mon Sep 17 00:00:00 2001 -From: Charles Jacobs -Date: Mon, 15 Aug 2022 16:00:43 -0700 -Subject: [PATCH] Fix issue where passed-in op-printing flags were ignored +From bb702476975ebfc65d445f8da7d6064a81c09666 Mon Sep 17 00:00:00 2001 +From: Chuck Jacobs +Date: Wed, 24 Aug 2022 04:13:49 +0000 +Subject: [PATCH 5/6] Merged PR 2823: Fix op-printing bug in + LocationSnapshotPass +This PR fixes an issue where the printing flags passed in to the LocationSnapshotPass were being ignored. 
--- mlir/lib/Transforms/LocationSnapshot.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Transforms/LocationSnapshot.cpp b/mlir/lib/Transforms/LocationSnapshot.cpp -index f23a3eee1511..c3c323284b18 100644 +index a042d07335bb..808f2ad2a67c 100644 --- a/mlir/lib/Transforms/LocationSnapshot.cpp +++ b/mlir/lib/Transforms/LocationSnapshot.cpp @@ -133,7 +133,7 @@ struct LocationSnapshotPass @@ -21,5 +23,5 @@ index f23a3eee1511..c3c323284b18 100644 } -- -2.32.1 (Apple Git-133) +2.37.1 (Apple Git-137.1) diff --git a/external/llvm/0006-Merged-PR-2919-More-flexible-code-generation-for-vpm.patch b/external/llvm/0006-Merged-PR-2919-More-flexible-code-generation-for-vpm.patch new file mode 100644 index 00000000..207470e3 --- /dev/null +++ b/external/llvm/0006-Merged-PR-2919-More-flexible-code-generation-for-vpm.patch @@ -0,0 +1,317 @@ +From f8fa2aac2352539b963c03336a69403e5b40fc29 Mon Sep 17 00:00:00 2001 +From: Chuck Jacobs +Date: Tue, 25 Oct 2022 02:38:08 +0000 +Subject: [PATCH 6/6] Merged PR 2919: More flexible code generation for + vpmaddwd instruction + +This change increases the number of patterns that will generate a `vpmaddwd` instruction: + +- It now handles the case where `opt` linearizes the sum-of-product instructions +- It handles some cases where one of the operands can be interpreted as 2 adjacent values broadcast to fill a vector register (this is the case with the "A" matrix in matrix-matrix multiply) +--- + llvm/lib/Target/X86/X86ISelLowering.cpp | 206 +++++++++++++++++++++--- + 1 file changed, 187 insertions(+), 19 deletions(-) + +diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp +index 7ee8b9be2154..d4ba64d64f87 100644 +--- a/llvm/lib/Target/X86/X86ISelLowering.cpp ++++ b/llvm/lib/Target/X86/X86ISelLowering.cpp +@@ -31,6 +31,7 @@ + #include "llvm/Analysis/ObjCARCUtil.h" + #include "llvm/Analysis/ProfileSummaryInfo.h" + #include "llvm/Analysis/VectorUtils.h" ++#include 
"llvm/CodeGen/ISDOpcodes.h" + #include "llvm/CodeGen/IntrinsicLowering.h" + #include "llvm/CodeGen/MachineFrameInfo.h" + #include "llvm/CodeGen/MachineFunction.h" +@@ -49527,7 +49528,7 @@ static SDValue combinePMULH(SDValue Src, EVT VT, const SDLoc &DL, + + // Helper function for PMADDUBSW / PMADDWD + // TODO: need to do the convert-load-to-broadcast-splat fixup in the function that takes the a0/a1 pair +-static SDValue getExtMulOperand(SDValue node, int index, EVT VT, const SDLoc &DL, ++static SDValue getPMADDOperand(SDValue node, int index, EVT VT, const SDLoc &DL, + SelectionDAG &DAG) { + unsigned NumElems = VT.getVectorNumElements(); + auto op = node.getOperand(index); +@@ -49634,14 +49635,38 @@ static SDValue getExtMulOperand(SDValue node, int index, EVT VT, const SDLoc &DL + // TODO: reduce to a single global load + } + ++ auto bv = DAG.getBuildVector(argVT, DL, elements); ++ auto replacement = DAG.getNode(opcode, DL, VT, bv); ++ return replacement; ++ } // end if (arg.getOpcode() == ISD::INSERT_VECTOR_ELT) ++ else if (arg.getOpcode() == ISD::VECTOR_SHUFFLE) { ++ // "broadcast-A" case: arg is vector_shuffle<...> X ++ SmallVector elements; ++ auto shuffleNode = dyn_cast(arg); ++ assert(shuffleNode); ++ auto mask = shuffleNode->getMask(); ++ auto val = shuffleNode->getOperand(0); ++ if (val.getOpcode() == ISD::EXTRACT_SUBVECTOR) { ++ val = val->getOperand(0); ++ } ++ // TODO: see if val is an extract_subvector op, and if so return its arg ++ ++ MVT elemMT = argVT.getSimpleVT().getVectorElementType(); ++ for (auto idx: mask) ++ { ++ SDValue elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, elemMT, val, ++ DAG.getIntPtrConstant(idx, DL)); ++ elements.push_back(elt); ++ } ++ + auto bv = DAG.getBuildVector(argVT, DL, elements); + auto replacement = DAG.getNode(opcode, DL, VT, bv); + return replacement; + } + return op; +- } ++ } // end if (opcode == ISD::ZERO_EXTEND || ...) + +- // look for (shuf<0,0,0,0...> (insert_elt (extend ...) 
0) undef) pattern and translate it ++ // TODO: look for (shuf<0,0,0,0...> (insert_elt (extend ...) 0) undef) pattern and translate it + // into (extend (shuf<0,0,0,0...> (insert_elt ... 0) undef)) + // or better yet: (extend (splat ...)) + if (opcode != ISD::VECTOR_SHUFFLE) +@@ -49673,7 +49698,7 @@ static SDValue getExtMulOperand(SDValue node, int index, EVT VT, const SDLoc &DL + SDValue buildVec = DAG.getSplatBuildVector(splatVT, DL, val); + SDValue newExt = isSigned ? DAG.getSExtOrTrunc(buildVec, DL, resultVT) : DAG.getZExtOrTrunc(buildVec, DL, resultVT); + return newExt; +-} // end of getExtMulOperand ++} // end of getPMADDOperand + + static void replaceWithInterleavedShuffles(ArrayRef> interleavedOps, const std::vector>& masks, SDValue val, SelectionDAG &DAG, const SDLoc &DL) { + const unsigned dotProdSize = interleavedOps.size(); +@@ -52769,10 +52794,19 @@ static SDValue matchPMADDWD_2(SelectionDAG &DAG, SDValue N0, SDValue N1, + !isPowerOf2_32(VT.getVectorNumElements())) + return SDValue(); + +- SDValue N00 = N0.getOperand(0); +- SDValue N01 = N0.getOperand(1); +- SDValue N10 = N1.getOperand(0); +- SDValue N11 = N1.getOperand(1); ++ // Nxx naming scheme: N, where phase == 0 means "even", and operand is the operand index of the MUL operation ++ // So, the "even" left-hand operand is N00, and the "odd" left-hand operand is N10 ++ ++ // "evens" ++ SDValue N00 = getPMADDOperand(N0, 0, VT, DL, DAG); ++ SDValue N01 = getPMADDOperand(N0, 1, VT, DL, DAG); ++ ++ // "odds" ++ SDValue N10 = getPMADDOperand(N1, 0, VT, DL, DAG); ++ SDValue N11 = getPMADDOperand(N1, 1, VT, DL, DAG); ++ ++ if (!N00 || !N01 || !N10 || !N11) ++ return SDValue(); + + // All inputs need to be sign extends. + // TODO: Support ZERO_EXTEND from known positive? 
+@@ -52801,6 +52835,8 @@ static SDValue matchPMADDWD_2(SelectionDAG &DAG, SDValue N0, SDValue N1, + N11.getOpcode() != ISD::BUILD_VECTOR) + return SDValue(); + ++ // TODO: verify N00, N01, N10, and N11 have the same # of operands ++ + // For each element, we need to ensure we have an odd element from one vector + // multiplied by the odd element of another vector and the even element from + // one of the same vectors being multiplied by the even element from the +@@ -52808,6 +52844,8 @@ static SDValue matchPMADDWD_2(SelectionDAG &DAG, SDValue N0, SDValue N1, + // is being performed: + // A[2 * i] * B[2 * i] + A[2 * i + 1] * B[2 * i + 1] + SDValue In0, In1; ++ bool prevIdxN00WasZero = true; ++ bool prevIdxN01WasZero = true; + for (unsigned i = 0; i != N00.getNumOperands(); ++i) { + SDValue N00Elt = N00.getOperand(i); + SDValue N01Elt = N01.getOperand(i); +@@ -52834,9 +52872,21 @@ static SDValue matchPMADDWD_2(SelectionDAG &DAG, SDValue N0, SDValue N1, + std::swap(IdxN00, IdxN10); + std::swap(IdxN01, IdxN11); + } +- // N0 indices be the even element. N1 indices must be the next odd element. +- if (IdxN00 != 2 * i || IdxN10 != 2 * i + 1 || +- IdxN01 != 2 * i || IdxN11 != 2 * i + 1) ++ ++ // N0 indices must be sequential even elements. ++ // TODO: also allow even indices to be element 0 broadcasted to fill the array ++ // ... 
which means Idx can be 0 if prev Idx was zero ++ if (IdxN00 != 2 * i && !(IdxN00 == 0 && prevIdxN00WasZero) || ++ IdxN01 != 2 * i && !(IdxN01 == 0 && prevIdxN01WasZero)) ++ { ++ return SDValue(); ++ } ++ ++ prevIdxN00WasZero = IdxN00 == 0; ++ prevIdxN01WasZero = IdxN01 == 0; ++ ++ // N1 indices must be the next (odd) element after the corresponding N0 indcex ++ if (IdxN10 != IdxN00 + 1 || IdxN11 != IdxN01 + 1) + return SDValue(); + SDValue N00In = N00Elt.getOperand(0); + SDValue N01In = N01Elt.getOperand(0); +@@ -52864,6 +52914,41 @@ static SDValue matchPMADDWD_2(SelectionDAG &DAG, SDValue N0, SDValue N1, + return SDValue(); + } + ++ // #### ++ if (prevIdxN00WasZero) // a splat, so just use the original build_vector ++ { ++ unsigned NumElems = VT.getVectorNumElements(); ++ SmallVector elements(2*NumElems); ++ EVT argVT = In0.getValueType(); ++ MVT elemMT = argVT.getSimpleVT().getVectorElementType(); ++ for(unsigned idx = 0; idx < 2*NumElems; ++idx) ++ { ++ SDValue elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, elemMT, In0, ++ DAG.getIntPtrConstant(idx%2, DL)); ++ elements[idx] = elt; ++ } ++ ++ // In0 = ++ In0 = DAG.getBuildVector(argVT, DL, elements); ++ } ++ ++ if (prevIdxN01WasZero) // a splat, so just use the original build_vector ++ { ++ unsigned NumElems = VT.getVectorNumElements(); ++ SmallVector elements(2*NumElems); ++ EVT argVT = In1.getValueType(); ++ MVT elemMT = argVT.getSimpleVT().getVectorElementType(); ++ for(unsigned idx = 0; idx < 2*NumElems; ++idx) ++ { ++ SDValue elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, elemMT, In1, ++ DAG.getIntPtrConstant(idx%2, DL)); ++ elements[idx] = elt; ++ } ++ ++ // In1 = ++ In1 = DAG.getBuildVector(argVT, DL, elements); ++ } ++ + auto PMADDBuilder = [](SelectionDAG &DAG, const SDLoc &DL, + ArrayRef Ops) { + EVT OpVT = Ops[0].getValueType(); +@@ -52967,14 +53052,14 @@ static SDValue matchPMADDWD_3(SelectionDAG &DAG, SDValue N0, SDValue N1, + return SDValue(); + + SDValue p000, p001, p010, p011, p100, p101, p110, 
p111; +- if (!(p000 = getExtMulOperand(ssatArg0.getOperand(0), 0, VT, DL, DAG))) return SDValue(); +- if (!(p001 = getExtMulOperand(ssatArg0.getOperand(0), 1, VT, DL, DAG))) return SDValue(); +- if (!(p010 = getExtMulOperand(ssatArg0.getOperand(1), 0, VT, DL, DAG))) return SDValue(); +- if (!(p011 = getExtMulOperand(ssatArg0.getOperand(1), 1, VT, DL, DAG))) return SDValue(); +- if (!(p100 = getExtMulOperand(ssatArg1.getOperand(0), 0, VT, DL, DAG))) return SDValue(); +- if (!(p101 = getExtMulOperand(ssatArg1.getOperand(0), 1, VT, DL, DAG))) return SDValue(); +- if (!(p110 = getExtMulOperand(ssatArg1.getOperand(1), 0, VT, DL, DAG))) return SDValue(); +- if (!(p111 = getExtMulOperand(ssatArg1.getOperand(1), 1, VT, DL, DAG))) return SDValue(); ++ if (!(p000 = getPMADDOperand(ssatArg0.getOperand(0), 0, VT, DL, DAG))) return SDValue(); ++ if (!(p001 = getPMADDOperand(ssatArg0.getOperand(0), 1, VT, DL, DAG))) return SDValue(); ++ if (!(p010 = getPMADDOperand(ssatArg0.getOperand(1), 0, VT, DL, DAG))) return SDValue(); ++ if (!(p011 = getPMADDOperand(ssatArg0.getOperand(1), 1, VT, DL, DAG))) return SDValue(); ++ if (!(p100 = getPMADDOperand(ssatArg1.getOperand(0), 0, VT, DL, DAG))) return SDValue(); ++ if (!(p101 = getPMADDOperand(ssatArg1.getOperand(0), 1, VT, DL, DAG))) return SDValue(); ++ if (!(p110 = getPMADDOperand(ssatArg1.getOperand(1), 0, VT, DL, DAG))) return SDValue(); ++ if (!(p111 = getPMADDOperand(ssatArg1.getOperand(1), 1, VT, DL, DAG))) return SDValue(); + + SDValue zextArgs[4] = {p000, p010, p100, p110}; + SDValue sextArgs[4] = {p001, p011, p101, p111}; +@@ -53162,6 +53247,87 @@ static SDValue combineAddOfPMADDWD(SelectionDAG &DAG, SDValue N0, SDValue N1, + return DAG.getNode(X86ISD::VPMADDWD, DL, VT, LHS, RHS); + } + ++// Attempt to turn this pattern: ++// ++// (add ++// (add X, ++// (mul (sext (build_vector)), (sext (build_vector))), ++// (mul (sext (build_vector)), (sext (build_vector)))) ++// ++// into: ++// ++// (add X, ++// (add (mul (sext 
(build_vector)), (sext (build_vector))), ++// (mul (sext (build_vector)), (sext (build_vector))))) ++// ++// or, (X + a) + b --> X + (a + b), where a and b are (mul (sext (build_vector))) ++// ++// So that the inner add can be turned into a PMADDWD ++static SDValue rebalancePotentialPMADDWD(SelectionDAG &DAG, SDValue N0, SDValue N1, ++ const SDLoc &DL, EVT VT, ++ const X86Subtarget &Subtarget) { ++ if (!Subtarget.hasSSE2()) ++ return SDValue(); ++ ++ if (!VT.isVector() || VT.getVectorElementType() != MVT::i32 || ++ VT.getVectorNumElements() < 4 || ++ !isPowerOf2_32(VT.getVectorNumElements())) ++ return SDValue(); ++ ++ // normalize N0 and N1 so N0 == (X + a) and N1 == b ++ if (N0.getOpcode() != ISD::ADD) ++ std::swap(N0, N1); ++ if (N0.getOpcode() != ISD::ADD || N1.getOpcode() != ISD::MUL) ++ return SDValue(); ++ ++ // function to verify v is a valid argument for an add that gets converted to PMADDWD ++ auto isValidPMADDWDArg = [](SDValue v) { ++ // get "a" and "b" operands ++ SDValue v0 = v.getOperand(0); ++ SDValue v1 = v.getOperand(1); ++ ++ // All inputs need to be sign extends. ++ // TODO: Support ZERO_EXTEND from known positive? ++ if (v0.getOpcode() != ISD::SIGN_EXTEND || ++ v1.getOpcode() != ISD::SIGN_EXTEND ) ++ return false; ++ ++ // Peek through the extends. ++ v0 = v0.getOperand(0); ++ v1 = v1.getOperand(0); ++ ++ // Must be extending from vXi16. ++ EVT InVT = v0.getValueType(); ++ if (InVT.getVectorElementType() != MVT::i16 || v1.getValueType() != InVT) ++ return false; ++ ++ // All inputs should be build_vectors. 
++ // TODO: also allow broadcast-A pattern ++ if ((v0.getOpcode() != ISD::BUILD_VECTOR && v0.getOpcode() != ISD::VECTOR_SHUFFLE) || ++ (v1.getOpcode() != ISD::BUILD_VECTOR && v1.getOpcode() != ISD::VECTOR_SHUFFLE)) ++ return false; ++ ++ return true; ++ }; ++ ++ // normalize N00 and N01 so N00 == X and N01 == b ++ SDValue N00 = N0.getOperand(0); ++ SDValue N01 = N0.getOperand(1); ++ if (isValidPMADDWDArg(N00)) ++ std::swap(N00, N01); ++ // now N00 = X, N01 = a, and N1 = b ++ ++ // verify everything is correct (that is, a and b ) ++ if (!isValidPMADDWDArg(N01) || !isValidPMADDWDArg(N1)) ++ return SDValue(); ++ ++ // TODO: just turn this directly into X + vpmaddwd ++ ++ // return new expression X + (a + b) == N00 + (N01 + N1) ++ auto pmaddTerm = DAG.getNode(ISD::ADD, DL, VT, N01, N1); ++ return DAG.getNode(ISD::ADD, DL, VT, N00, pmaddTerm); ++} ++ + /// CMOV of constants requires materializing constant operands in registers. + /// Try to fold those constants into an 'add' instruction to reduce instruction + /// count. We do this with CMOV rather the generic 'select' because there are +@@ -53244,6 +53410,8 @@ static SDValue combineAdd(SDNode *N, SelectionDAG &DAG, + return MAdd; + if (SDValue MAdd = combineAddOfPMADDWD(DAG, Op0, Op1, DL, VT)) + return MAdd; ++ if (SDValue MAdd = rebalancePotentialPMADDWD(DAG, Op0, Op1, DL, VT, Subtarget)) ++ return MAdd; + + // Try to synthesize horizontal adds from adds of shuffles. 
+ if (SDValue V = combineToHorizontalAddSub(N, DAG, Subtarget)) +-- +2.37.1 (Apple Git-137.1) + diff --git a/external/llvm/portfile.cmake b/external/llvm/portfile.cmake index f5faba93..4da7e637 100644 --- a/external/llvm/portfile.cmake +++ b/external/llvm/portfile.cmake @@ -1,5 +1,6 @@ # Builds LLVM for features needed by Accera -set(LLVM_VERSION llvmorg-14.0.6) +set(LLVM_VERSION 24a37a396a9bd6b73b05b4eafce8b87e7a748cf9) +set(LLVM_FRIENDLY_VERSION 15.0.0-rc1) set(VCPKG_BUILD_TYPE release) if((DEFINED ENV{LLVM_BUILD_TYPE}) AND ("$ENV{LLVM_BUILD_TYPE}" STREQUAL "debug")) @@ -20,7 +21,7 @@ vcpkg_from_github( OUT_SOURCE_PATH SOURCE_PATH REPO llvm/llvm-project REF ${LLVM_VERSION} - SHA512 d64f97754c24f32deb5f284ebbd486b3a467978b7463d622f50d5237fff91108616137b4394f1d1ce836efa59bf7bec675b6dee257a79b241c15be52d4697460 + SHA512 5fdee8487afac16033a6d2cea720dedb5f05e00f20d761307805f0a6e1fad22d6c3ce45d89112cbe70aec1a4dfe8ae72a90d45e14bc67539fbe3dc948a316d92 HEAD_REF main PATCHES 0001-Merged-PR-2213-mlir-Plumb-OpenMP-dialect-attributes-.patch @@ -28,9 +29,8 @@ vcpkg_from_github( 0003-Fix-bad-merge.patch 0004-Lower-memref.copy-to-memcpy-when-layouts-canonicaliz.patch 0005-Fix-issue-where-passed-in-op-printing-flags-were-ign.patch - 0006-Merged-PR-2822-Fix-lowering-of-MemrefCastOp-to-the-L.patch - 0007-More-flexible-code-generation-for-vpmaddwd-instructi.patch - 0008-fix-vcpkg-install-paths.patch # cf. https://github.com/microsoft/vcpkg/blob/master/ports/llvm + 0006-Merged-PR-2919-More-flexible-code-generation-for-vpm.patch + 0007-fix-vcpkg-install-paths.patch # cf. https://github.com/microsoft/vcpkg/blob/master/ports/llvm ) vcpkg_find_acquire_program(PYTHON3) @@ -64,7 +64,7 @@ vcpkg_configure_cmake( -DLLVM_INSTALL_UTILS=ON # FileCheck "-DLLVM_ENABLE_PROJECTS=mlir;lld" "-DLLVM_TARGETS_TO_BUILD=host;X86;ARM;NVPTX;AMDGPU" - -DPACKAGE_VERSION=${LLVM_VERSION} + -DPACKAGE_VERSION=${LLVM_FRIENDLY_VERSION} # Force TableGen to be built with optimization. 
This will significantly improve build time. # cf. https://github.com/microsoft/vcpkg/blob/master/ports/llvm -DLLVM_OPTIMIZED_TABLEGEN=ON diff --git a/external/llvm/vcpkg.json b/external/llvm/vcpkg.json index 2b9172f0..36e65b52 100644 --- a/external/llvm/vcpkg.json +++ b/external/llvm/vcpkg.json @@ -1,6 +1,6 @@ { "name": "accera-llvm", - "version-string": "14.0.6-2", + "version-string": "15.0.0", "description": "LLVM Compiler Infrastructure for Accera.", "homepage": "https://llvm.org", "supports": "!uwp" diff --git a/requirements.txt b/requirements.txt index b68bbbdb..01cf7244 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,7 +9,7 @@ conan<2.0.0 lit packaging pytest -hatlib==0.0.38 # keep setup.cfg in sync; prefer == (avoids picking up breaking changes) +hatlib==0.0.39 # keep setup.cfg in sync; prefer == (avoids picking up breaking changes) varname py-cpuinfo termcolor diff --git a/setup.cfg b/setup.cfg index 1197abc7..92c94d3a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -35,13 +35,13 @@ install_requires = termcolor py-cpuinfo varname - hatlib==0.0.38 + hatlib==0.0.39 numpy pyyaml tomlkit>=0.11.1, <0.11.5 accera-compilers accera-gpu - accera-llvm==14.0.602 + accera-llvm==15.0.101 # keep in sync with accera/python/llvm/setup.cfg package_dir = accera = accera/python/accera accera.hat = accera/hat/scripts