From 0319f5abc0341d96bf97254f755ff7844aa07368 Mon Sep 17 00:00:00 2001 From: Ritwik Das Date: Mon, 17 Apr 2023 23:01:55 -0700 Subject: [PATCH] Squashed commit of the following: commit 5ec0fc859f017654144b33cfad92bbae62391088 Author: Captain Jack Sparrow Date: Mon Apr 17 18:37:24 2023 +0000 Merged PR 3211: Upgrade hatlib dependency to 0.0.39 Upgrade hatlib dependency to 0.0.39 commit 38642006cbc8c4ff01c7345d018f9a8233454dbd Author: Mason Remy Date: Fri Apr 14 19:27:01 2023 +0000 Merged PR 3209: Support AffineParallelOp and scf::ParallelOp in RangeValue utils Support AffineParallelOp and scf::ParallelOp in RangeValue utils commit addb45a1a4ccb50657b822591735916be83498c5 Author: Captain Jack Sparrow Date: Wed Apr 12 17:25:02 2023 +0000 Merged PR 3207: Fix parallelization and enable file checker in tests Fix parallelization and enable file checker in tests commit 7e206532932ff603decfd46656173702ebdceff5 Author: Lisa Ong Date: Wed Apr 12 08:02:20 2023 +0000 Merged PR 3195: [LLVM 15] progressive upgrade (24a37a396a9b), disable macos builds The first of a series of progressive upgrades from LLVM 14.0.6 to LLVM 15.0.7 (and possibly beyond). Current LLVM version: https://intelligentdevices.visualstudio.com/ELL/_git/accera.llvm?version=GBaccera/llvmorg-15-24a37a396a9b&_a=history This is llvmorg-15.0.0-init, fast forwarded to about 100 "relevant" MLIR commits (actual number of commits is higher). Performance on AVX2 is verified for Windows (no regressions). **Breaking Change: macOS builds** With this upgrade we are also retiring the macOS pipelines due to lack of build resources for LLVM macos/intel Conan packages. This only affects internal developer scenarios. Public developers continue to rely on vcpkg builds. commit 2927234171f8e6c960f654909f8ec0a2c19e3c54 Author: Kern Handa Date: Fri Apr 7 17:20:42 2023 +0000 Merged PR 3172: Adds better support for compiling specifically for AVX2 targets * Plumb AVX2 flags to LLVM, with a block for macOS. 
We plan to remove official support for macOS/Intel starting from LLVM 15 due to limited build resources. * Initialize Target.HOST extensions using cpu_info * Added more AVX2 filecheck tests to catch LLVM lowering regressions before moving to LLVM 15 [MasonR] **Breaking Change**: Target.HOST no longer unconditionally enables the AVX2 extension if the underlying CPU does not support it, otherwise codegen may result in unsupported instructions. To compile for AVX2 if your host doesn't support AVX2, specify Target(""). For example, `plan = schedule.create_plan(Target("Intel 6700"))` commit 6822bcb1fd222fe5b7e7292a9f7d1f35bcf1fdce Author: Denny Sun Date: Thu Apr 6 21:47:01 2023 +0000 Merged PR 3203: Plumb target device info into llvm lowering llvm lowering now depends on some static compiler macro to check target device info, which breaks cross compilation support. ``` // TODO: get check `TargetDeviceInfo` for the OS instead ``` ``` const int hostBitSize = 64; // TODO:: FIXME :: This assumes that the host is always 64bit // Should query the target hardware auto llvmIntTy = hostBitSize == 32 ? 
llvmI32Ty : llvmI64Ty; ``` --- .../acc-gpu-runner/src/ACCGPURunnerMain.cpp | 2 +- accera/acc-opt/test/commandline.mlir | 1 + accera/acc-opt/test/vectorization.mlir | 2 +- .../src/Target/Cpp/StdDialectCppPrinter.cpp | 4 +- .../src/Target/Cpp/StdDialectCppPrinter.h | 2 +- accera/accc/accc.py | 2 + accera/ir/include/argo/ArgoOps.td | 6 + accera/ir/include/argo/ArgoStructuredOps.td | 3 + accera/ir/src/IRUtil.cpp | 21 +- accera/ir/src/argo/ArgoOps.cpp | 56 ++-- .../nest_dialect_test/IRTestVerification.cpp | 5 +- .../nest_dialect_test/LowLevelIRTests.cpp | 252 +++++++------- .../ir/test/nest_dialect_test/NestIRTests.cpp | 6 +- accera/mlirHelpers/CMakeLists.txt | 4 +- accera/mlirHelpers/src/ConvertToLLVM.cpp | 12 +- accera/python/accera/Package.py | 283 ++++++---------- accera/python/accera/Targets.py | 51 ++- accera/python/accera/test/dsl_tests.py | 219 ++++++++---- accera/python/accera/test/smoke_tests.py | 267 ++++++++++++--- accera/python/accera/test/test_utils.py | 36 +- accera/python/accera/test/unit_tests.py | 2 +- accera/python/lib/src/PackagingTypes.cpp | 1 - accera/python/llvm/setup.cfg | 5 +- accera/transforms/CMakeLists.txt | 4 +- .../include/util/RangeValueUtilities.h | 3 + .../include/value/ValueToLLVMLoweringPass.h | 5 +- accera/transforms/src/AcceraPasses.cpp | 46 ++- .../src/affine/AffineLoopNormalize.cpp | 2 +- .../src/affine/AffineSimplifications.cpp | 1 + .../ExecutionPlanToAffineLoweringPass.cpp | 2 - accera/transforms/src/gpu/AcceraToGPUPass.cpp | 2 +- accera/transforms/src/util/MathUtilities.cpp | 10 +- .../src/util/RangeValueUtilities.cpp | 111 +++++- .../src/value/ValueSimplifyPass.cpp | 5 +- .../src/value/ValueToLLVMLoweringPass.cpp | 299 +++++++++-------- .../src/value/ValueToStandardLoweringPass.cpp | 89 +++-- .../src/vectorization/VectorizationUtil.cpp | 10 +- accera/value/src/Cache.cpp | 4 +- accera/value/src/MLIREmitterContext.cpp | 30 +- accera/value/src/ScalarOperations.cpp | 2 +- accera/value/src/TargetDevice.cpp | 6 + 
...lir-Plumb-OpenMP-dialect-attributes-.patch | 34 +- ...mproved-codegen-of-vpmaddwd-instruct.patch | 14 +- external/llvm/0003-Fix-bad-merge.patch | 6 +- ...y-to-memcpy-when-layouts-canonicaliz.patch | 44 ++- ...passed-in-op-printing-flags-were-ign.patch | 14 +- ...ore-flexible-code-generation-for-vpm.patch | 317 ++++++++++++++++++ external/llvm/portfile.cmake | 12 +- external/llvm/vcpkg.json | 2 +- requirements.txt | 2 +- setup.cfg | 4 +- 51 files changed, 1531 insertions(+), 791 deletions(-) create mode 100644 external/llvm/0006-Merged-PR-2919-More-flexible-code-generation-for-vpm.patch diff --git a/accera/acc-gpu-runner/src/ACCGPURunnerMain.cpp b/accera/acc-gpu-runner/src/ACCGPURunnerMain.cpp index 850f529b..a776d7ea 100644 --- a/accera/acc-gpu-runner/src/ACCGPURunnerMain.cpp +++ b/accera/acc-gpu-runner/src/ACCGPURunnerMain.cpp @@ -119,7 +119,7 @@ void AddMLIRVulkanRunnerPasses(PassManager& passManager) passManager.addPass(createConvertGpuLaunchFuncToVulkanLaunchFuncPass()); passManager.addPass(accera::transforms::vulkan::createEmitVulkanWrapperPass()); - passManager.addPass(createLowerToCFGPass()); + passManager.addPass(createConvertSCFToCFPass()); passManager.addPass(LLVM::createLegalizeForExportPass()); LowerToLLVMOptions llvmOptions(passManager.getContext()); llvmOptions.useBarePtrCallConv = false; diff --git a/accera/acc-opt/test/commandline.mlir b/accera/acc-opt/test/commandline.mlir index d2b4d2f2..b4076ad7 100644 --- a/accera/acc-opt/test/commandline.mlir +++ b/accera/acc-opt/test/commandline.mlir @@ -8,6 +8,7 @@ // CHECK-NEXT: affine // CHECK-NEXT: arith // CHECK-NEXT: builtin +// CHECK-NEXT: cf // CHECK-NEXT: gpu // CHECK-NEXT: llvm // CHECK-NEXT: math diff --git a/accera/acc-opt/test/vectorization.mlir b/accera/acc-opt/test/vectorization.mlir index 31993ce0..f696aa4a 100644 --- a/accera/acc-opt/test/vectorization.mlir +++ b/accera/acc-opt/test/vectorization.mlir @@ -15,7 +15,7 @@ module @test_accera_vectorization attributes 
{accv.target_device_features = "-av // mlir::affine::AffineLoadOp non-sequential // mlir::affine::AffineStoreOp sequential // mlir::affine::AffineStoreOp non-sequential - // mlir::SelectOp + // mlir::arith::SelectOp // mlir::arith::ShLIOp // mlir::arith::FPToSIOp // mlir::arith::ExtSIOp diff --git a/accera/acc-translate/src/Target/Cpp/StdDialectCppPrinter.cpp b/accera/acc-translate/src/Target/Cpp/StdDialectCppPrinter.cpp index 637dc45e..41f9bca2 100644 --- a/accera/acc-translate/src/Target/Cpp/StdDialectCppPrinter.cpp +++ b/accera/acc-translate/src/Target/Cpp/StdDialectCppPrinter.cpp @@ -415,7 +415,7 @@ namespace cpp_printer return success(); } - LogicalResult StdDialectCppPrinter::printSelectOp(SelectOp selectOp) + LogicalResult StdDialectCppPrinter::printSelectOp(arith::SelectOp selectOp) { if (selectOp.getNumOperands() != 3) { @@ -728,7 +728,7 @@ namespace cpp_printer if (auto returnOp = dyn_cast(op)) return printReturnOp(returnOp); - if (auto selectOp = dyn_cast(op)) + if (auto selectOp = dyn_cast(op)) return printSelectOp(selectOp); if (auto getGlobal = dyn_cast(op)) diff --git a/accera/acc-translate/src/Target/Cpp/StdDialectCppPrinter.h b/accera/acc-translate/src/Target/Cpp/StdDialectCppPrinter.h index 3dcccf26..b99c0a94 100644 --- a/accera/acc-translate/src/Target/Cpp/StdDialectCppPrinter.h +++ b/accera/acc-translate/src/Target/Cpp/StdDialectCppPrinter.h @@ -83,7 +83,7 @@ namespace cpp_printer LogicalResult printReturnOp(ReturnOp returnOp); /// print SelectOp as ternary operator - LogicalResult printSelectOp(SelectOp selectOp); + LogicalResult printSelectOp(arith::SelectOp selectOp); /// print GetGlobalOp as a call to the global variable LogicalResult printGetGlobalOp(memref::GetGlobalOp getGlobalOp); diff --git a/accera/accc/accc.py b/accera/accc/accc.py index a104fee0..08dd2f77 100644 --- a/accera/accc/accc.py +++ b/accera/accc/accc.py @@ -23,6 +23,7 @@ class SystemTarget(Enum): HOST = "host" + AVX2 = "avx2" AVX512 = "avx512" RPI4 = "pi4" RPI3 = "pi3" @@ 
-116,6 +117,7 @@ def bstr(val): "-O3", "--march=arm", "-mcpu=arm1136jf-s", "--mtriple=armv6-linux-gnueabihf" ], SystemTarget.AVX512.value: ["-O3", "--march=x86-64", "-mcpu=skylake-avx512"], + SystemTarget.AVX2.value: ["-O3", "--march=x86-64", "-mcpu=skylake"], SystemTarget.ARM_CORTEX_M4.value: [ "-Oz", "-mcpu=cortex-m4", "--mtriple=thumbv7em-arm-none-eabi", ], diff --git a/accera/ir/include/argo/ArgoOps.td b/accera/ir/include/argo/ArgoOps.td index c9d45a18..8352bc54 100644 --- a/accera/ir/include/argo/ArgoOps.td +++ b/accera/ir/include/argo/ArgoOps.td @@ -43,6 +43,9 @@ def Argo_YieldOp : Argo_Op<"yield", [NoSideEffect, ReturnLike, Terminator]>, argo.yield %f0, %f1 : f32, f32 ``` }]; + + // Indicate that the operation has a custom parser and printer method. + let hasCustomAssemblyFormat = 1; } def Argo_EntryPointOp : Argo_Op<"entry_point", [IsolatedFromAbove, FunctionOpInterface, @@ -90,6 +93,9 @@ def Argo_EntryPointOp : Argo_Op<"entry_point", [IsolatedFromAbove, FunctionOpInt let skipDefaultBuilders = 1; + // Indicate that the operation has a custom parser and printer method. + let hasCustomAssemblyFormat = 1; + let builders = [ OpBuilder<(ins "StringRef":$entryName, "FunctionType":$type, "StringRef":$kernelName, diff --git a/accera/ir/include/argo/ArgoStructuredOps.td b/accera/ir/include/argo/ArgoStructuredOps.td index 9bed523f..eddf220d 100644 --- a/accera/ir/include/argo/ArgoStructuredOps.td +++ b/accera/ir/include/argo/ArgoStructuredOps.td @@ -246,6 +246,9 @@ def OpaqueOp : ArgoStructuredBase_Op<"opaque", let regions = (region AnyRegion:$region); + // Indicate that the operation has a custom parser and printer method. 
+ let hasCustomAssemblyFormat = 1; + let builders = [ OpBuilder< (ins "ValueRange":$args, "int64_t":$argsIn, "int64_t":$argsOut, diff --git a/accera/ir/src/IRUtil.cpp b/accera/ir/src/IRUtil.cpp index acdeeead..cd9aa47f 100644 --- a/accera/ir/src/IRUtil.cpp +++ b/accera/ir/src/IRUtil.cpp @@ -1000,14 +1000,31 @@ namespace util mlir::Operation* GetDefiningOpOrForLoop(mlir::Value val) { - if (mlir::isForInductionVar(val)) // AffineForOp + if (auto affineForOp = mlir::getForInductionVarOwner(val)) // AffineForOp { - return mlir::getForInductionVarOwner(val); + return affineForOp; } else if (auto scfForOp = mlir::scf::getForInductionVarOwner(val)) // SCFForOp { return scfForOp; } + else if (auto ivArg = val.dyn_cast()) + { + auto block = ivArg.getOwner(); + if (!block) + { + return nullptr; + } + auto parentOp = block->getParentOp(); + + // only handle AffineParallelOp and scf::ParallelOp, other block args such as function args should not return their associated ops + if (mlir::isa(parentOp) || + mlir::isa(parentOp)) + { + return parentOp; + } + return nullptr; + } else // Arbitrary other op { return val.getDefiningOp(); diff --git a/accera/ir/src/argo/ArgoOps.cpp b/accera/ir/src/argo/ArgoOps.cpp index 84c03bb7..94fb5ce0 100644 --- a/accera/ir/src/argo/ArgoOps.cpp +++ b/accera/ir/src/argo/ArgoOps.cpp @@ -128,17 +128,17 @@ static LogicalResult verify(CopyOp op) // YieldOp //===----------------------------------------------------------------------===// -static void print(OpAsmPrinter& p, argo::YieldOp op) +void YieldOp::print(OpAsmPrinter& p) { - p << op.getOperationName(); - if (op.getNumOperands() > 0) - p << ' ' << op.getOperands(); - p.printOptionalAttrDict(op->getAttrs()); - if (op.getNumOperands() > 0) - p << " : " << op.getOperandTypes(); + p << getOperationName(); + if (getNumOperands() > 0) + p << ' ' << getOperands(); + p.printOptionalAttrDict((*this)->getAttrs()); + if (getNumOperands() > 0) + p << " : " << getOperandTypes(); } -static ParseResult 
parseYieldOp(OpAsmParser& parser, OperationState& result) +ParseResult YieldOp::parse(OpAsmParser& parser, OperationState& result) { SmallVector opInfo; SmallVector types; @@ -243,33 +243,33 @@ void OpaqueOp::build( bodyBuild(odsBuilder, odsState.location, bodyBlock->getArguments()); } -static void print(OpAsmPrinter& p, OpaqueOp op) +void OpaqueOp::print(OpAsmPrinter& p) { - auto attrNames = op.argoTraitAttrNames(); + auto attrNames = argoTraitAttrNames(); llvm::StringSet<> argoTraitAttrsSet; argoTraitAttrsSet.insert(attrNames.begin(), attrNames.end()); SmallVector attrs; - for (auto attr : op->getAttrs()) + for (auto attr : (*this)->getAttrs()) if (argoTraitAttrsSet.count(attr.getName().strref()) > 0) attrs.push_back(attr); - auto dictAttr = DictionaryAttr::get(op.getContext(), attrs); - p << op.getOperationName() << " " << dictAttr; - p.printOptionalAttrDict(op->getAttrs(), attrNames); - p << " (" << op.getOperands() << ")"; - if (!op.region().empty()) + auto dictAttr = DictionaryAttr::get(getContext(), attrs); + p << getOperationName() << " " << dictAttr; + p.printOptionalAttrDict((*this)->getAttrs(), attrNames); + p << " (" << getOperands() << ")"; + if (!region().empty()) { - p.printRegion(op.region()); + p.printRegion(region()); } - auto inputTypes = op.getOperandTypes(); + auto inputTypes = getOperandTypes(); if (!inputTypes.empty()) { p << " : " << inputTypes; } } -static ParseResult parseOpaqueOp(OpAsmParser& parser, OperationState& result) +ParseResult OpaqueOp::parse(OpAsmParser& parser, OperationState& result) { SmallVector operandsInfo, regionOperandsInfo; DictionaryAttr dictAttr; @@ -340,8 +340,8 @@ void EntryPointOp::build(OpBuilder& builder, OperationState& result, StringRef e /// Parse an Argo entry_point op /// ::= `argo.entry_point` symbol-ref-id `(` argument-list `)` /// (`->` function-result-list)? function-attributes? 
-static ParseResult parseEntryPointOp(OpAsmParser& parser, - OperationState& result) +ParseResult EntryPointOp::parse(OpAsmParser& parser, + OperationState& result) { SmallVector entryArgs; SmallVector argTypes; @@ -387,17 +387,17 @@ static ParseResult parseEntryPointOp(OpAsmParser& parser, return success(); } -static void printEntryPointOp(OpAsmPrinter& p, EntryPointOp op) +void EntryPointOp::print(OpAsmPrinter& p) { p << EntryPointOp::getOperationName() << ' '; - p.printSymbolName(op.getName()); + p.printSymbolName(getName()); - FunctionType type = op.getType(); - function_interface_impl::printFunctionSignature(p, op.getOperation(), type.getInputs(), - /*isVariadic=*/false, - type.getResults()); + FunctionType type = getType(); + function_interface_impl::printFunctionSignature(p, getOperation(), type.getInputs(), + /*isVariadic=*/false, + type.getResults()); - function_interface_impl::printFunctionAttributes(p, op.getOperation(), type.getNumInputs(), type.getNumResults()); + function_interface_impl::printFunctionAttributes(p, getOperation(), type.getNumInputs(), type.getNumResults()); } static LogicalResult verify(EntryPointOp op) diff --git a/accera/ir/test/nest_dialect_test/IRTestVerification.cpp b/accera/ir/test/nest_dialect_test/IRTestVerification.cpp index 3ea0ea3c..00719dd8 100644 --- a/accera/ir/test/nest_dialect_test/IRTestVerification.cpp +++ b/accera/ir/test/nest_dialect_test/IRTestVerification.cpp @@ -16,6 +16,7 @@ #include #include +#include #include #include #include @@ -376,7 +377,7 @@ bool VerifyLowerToLLVM(mlir::OwningOpRef& module, mlir::FuncOp& funcPm.addPass(mlir::arith::createArithmeticExpandOpsPass()); // --arith-expand pm.addPass(mlir::createLowerAffinePass()); // --lower-affine - pm.addPass(mlir::createLowerToCFGPass()); // --convert-scf-to-std + pm.addPass(mlir::createConvertSCFToCFPass()); // --convert-scf-to-cf pm.addPass(mlir::createMemRefToLLVMPass()); // --convert-memref-to-llvm pm.addPass(mlir::createLowerToLLVMPass()); // 
--convert-std-to-llvm="use-bare-ptr-memref-call-conv" pm.addPass(mlir::createConvertVectorToLLVMPass()); // --convert-vector-to-llvm @@ -437,7 +438,7 @@ bool VerifyTranslateToLLVMIR(mlir::OwningOpRef& module, mlir::Fu funcPm.addPass(mlir::arith::createArithmeticExpandOpsPass()); // --arith-expand pm.addPass(mlir::createLowerAffinePass()); // --lower-affine - pm.addPass(mlir::createLowerToCFGPass()); // --convert-scf-to-std + pm.addPass(mlir::createConvertSCFToCFPass()); // --convert-scf-to-cf pm.addPass(mlir::createMemRefToLLVMPass()); // --convert-memref-to-llvm pm.addPass(mlir::createLowerToLLVMPass()); // --convert-std-to-llvm="use-bare-ptr-memref-call-conv" pm.addPass(mlir::createConvertVectorToLLVMPass()); // --convert-vector-to-llvm diff --git a/accera/ir/test/nest_dialect_test/LowLevelIRTests.cpp b/accera/ir/test/nest_dialect_test/LowLevelIRTests.cpp index c66ebc3a..8525c209 100644 --- a/accera/ir/test/nest_dialect_test/LowLevelIRTests.cpp +++ b/accera/ir/test/nest_dialect_test/LowLevelIRTests.cpp @@ -178,10 +178,10 @@ TEST_CASE("Int8Test1") auto bOdd = builder.create(loc, halfVecType, b, b, oddMask); // extend to 32 bits - auto aEvenExt = builder.create(loc, aEven, bigVecType); - auto bEvenExt = builder.create(loc, bEven, bigVecType); - auto aOddExt = builder.create(loc, aOdd, bigVecType); - auto bOddExt = builder.create(loc, bOdd, bigVecType); + auto aEvenExt = builder.create(loc, bigVecType, aEven); + auto bEvenExt = builder.create(loc, bigVecType, bEven); + auto aOddExt = builder.create(loc, bigVecType, aOdd); + auto bOddExt = builder.create(loc, bigVecType, bOdd); auto mulEven = builder.create(loc, aEvenExt, bEvenExt); auto mulOdd = builder.create(loc, aOddExt, bOddExt); @@ -189,14 +189,14 @@ TEST_CASE("Int8Test1") auto sum = builder.create(loc, mulEven, mulOdd); // Make sum be saturated - auto minI16Val = builder.create(loc, builder.create(loc, -32768, 32), bigVecType); - auto maxI16Val = builder.create(loc, builder.create(loc, 32767, 32), bigVecType); 
+ auto minI16Val = builder.create(loc, builder.create(loc, -32768, 32), bigVecType); + auto maxI16Val = builder.create(loc, builder.create(loc, 32767, 32), bigVecType); auto maxCmp = builder.create(loc, mlir::arith::CmpIPredicate::sgt, sum, minI16Val); - auto maxVal = builder.create(loc, maxCmp, sum, minI16Val); + auto maxVal = builder.create(loc, maxCmp, sum, minI16Val); auto minCmp = builder.create(loc, mlir::arith::CmpIPredicate::slt, maxVal, maxI16Val); - auto minVal = builder.create(loc, minCmp, maxVal, maxI16Val); - auto truncVal = builder.create(loc, minVal, cElemType); + auto minVal = builder.create(loc, minCmp, maxVal, maxI16Val); + auto truncVal = builder.create(loc, cElemType, minVal); return truncVal; }; @@ -279,10 +279,10 @@ TEST_CASE("Int8Test1b") auto bOdd = builder.create(loc, halfVecType, b, b, oddMask); // extend to 32 bits - auto aEvenExt = builder.create(loc, aEven, bigVecType); - auto bEvenExt = builder.create(loc, bEven, bigVecType); - auto aOddExt = builder.create(loc, aOdd, bigVecType); - auto bOddExt = builder.create(loc, bOdd, bigVecType); + auto aEvenExt = builder.create(loc, bigVecType, aEven); + auto bEvenExt = builder.create(loc, bigVecType, bEven); + auto aOddExt = builder.create(loc, bigVecType, aOdd); + auto bOddExt = builder.create(loc, bigVecType, bOdd); auto mulEven = builder.create(loc, aEvenExt, bEvenExt); auto mulOdd = builder.create(loc, aOddExt, bOddExt); @@ -290,14 +290,14 @@ TEST_CASE("Int8Test1b") auto sum = builder.create(loc, mulEven, mulOdd); // Make sum be saturated - auto minI16Val = builder.create(loc, builder.create(loc, -32768, 32), bigVecType); - auto maxI16Val = builder.create(loc, builder.create(loc, 32767, 32), bigVecType); + auto minI16Val = builder.create(loc, builder.create(loc, -32768, 32), bigVecType); + auto maxI16Val = builder.create(loc, builder.create(loc, 32767, 32), bigVecType); auto maxCmp = builder.create(loc, mlir::arith::CmpIPredicate::sgt, sum, minI16Val); - auto maxVal = builder.create(loc, 
maxCmp, sum, minI16Val); + auto maxVal = builder.create(loc, maxCmp, sum, minI16Val); auto minCmp = builder.create(loc, mlir::arith::CmpIPredicate::slt, maxVal, maxI16Val); - auto minVal = builder.create(loc, minCmp, maxVal, maxI16Val); - auto truncVal = builder.create(loc, minVal, cElemType); + auto minVal = builder.create(loc, minCmp, maxVal, maxI16Val); + auto truncVal = builder.create(loc, cElemType, minVal); return truncVal; }; @@ -378,10 +378,10 @@ TEST_CASE("Int8Test1c") auto bOdd = builder.create(loc, b, std::vector{ 1 }); // extend to 32 bits - auto aEvenExt = builder.create(loc, aEven, bigVecType); - auto bEvenExt = builder.create(loc, bEven, bigVecType); - auto aOddExt = builder.create(loc, aOdd, bigVecType); - auto bOddExt = builder.create(loc, bOdd, bigVecType); + auto aEvenExt = builder.create(loc, bigVecType, aEven); + auto bEvenExt = builder.create(loc, bigVecType, bEven); + auto aOddExt = builder.create(loc, bigVecType, aOdd); + auto bOddExt = builder.create(loc, bigVecType, bOdd); auto mulEven = builder.create(loc, aEvenExt, bEvenExt); auto mulOdd = builder.create(loc, aOddExt, bOddExt); @@ -389,14 +389,14 @@ TEST_CASE("Int8Test1c") auto sum = builder.create(loc, mulEven, mulOdd); // Make sum be saturated - auto minI16Val = builder.create(loc, builder.create(loc, -32768, 32), bigVecType); - auto maxI16Val = builder.create(loc, builder.create(loc, 32767, 32), bigVecType); + auto minI16Val = builder.create(loc, builder.create(loc, -32768, 32), bigVecType); + auto maxI16Val = builder.create(loc, builder.create(loc, 32767, 32), bigVecType); auto maxCmp = builder.create(loc, mlir::arith::CmpIPredicate::sgt, sum, minI16Val); - auto maxVal = builder.create(loc, maxCmp, sum, minI16Val); + auto maxVal = builder.create(loc, maxCmp, sum, minI16Val); auto minCmp = builder.create(loc, mlir::arith::CmpIPredicate::slt, maxVal, maxI16Val); - auto minVal = builder.create(loc, minCmp, maxVal, maxI16Val); - auto truncVal = builder.create(loc, minVal, cElemType); + 
auto minVal = builder.create(loc, minCmp, maxVal, maxI16Val); + auto truncVal = builder.create(loc, cElemType, minVal); return truncVal; }; @@ -497,10 +497,10 @@ TEST_CASE("Int8Test2") auto bOdd = builder.create(loc, halfVecType, b, b, oddMask); // extend to 32 bits - auto aEvenExt = builder.create(loc, aEven, bigVecType); - auto bEvenExt = builder.create(loc, bEven, bigVecType); - auto aOddExt = builder.create(loc, aOdd, bigVecType); - auto bOddExt = builder.create(loc, bOdd, bigVecType); + auto aEvenExt = builder.create(loc, bigVecType, aEven); + auto bEvenExt = builder.create(loc, bigVecType, bEven); + auto aOddExt = builder.create(loc, bigVecType, aOdd); + auto bOddExt = builder.create(loc, bigVecType, bOdd); auto mulEven = builder.create(loc, aEvenExt, bEvenExt); auto mulOdd = builder.create(loc, aOddExt, bOddExt); @@ -508,14 +508,14 @@ TEST_CASE("Int8Test2") auto sum = builder.create(loc, mulEven, mulOdd); // Make sum be saturated - auto minI16Val = builder.create(loc, builder.create(loc, -32768, 32), bigVecType); - auto maxI16Val = builder.create(loc, builder.create(loc, 32767, 32), bigVecType); + auto minI16Val = builder.create(loc, builder.create(loc, -32768, 32), bigVecType); + auto maxI16Val = builder.create(loc, builder.create(loc, 32767, 32), bigVecType); auto maxCmp = builder.create(loc, mlir::arith::CmpIPredicate::sgt, sum, minI16Val); - auto maxVal = builder.create(loc, maxCmp, sum, minI16Val); + auto maxVal = builder.create(loc, maxCmp, sum, minI16Val); auto minCmp = builder.create(loc, mlir::arith::CmpIPredicate::slt, maxVal, maxI16Val); - auto minVal = builder.create(loc, minCmp, maxVal, maxI16Val); - auto truncVal = builder.create(loc, minVal, truncVecType); + auto minVal = builder.create(loc, minCmp, maxVal, maxI16Val); + auto truncVal = builder.create(loc, truncVecType, minVal); return truncVal; }; @@ -529,10 +529,10 @@ TEST_CASE("Int8Test2") auto r2Even = builder.create(loc, halfResultVecType, r2, r2, evenMask); auto r2Odd = 
builder.create(loc, halfResultVecType, r2, r2, oddMask); - auto r1Sum = builder.create(loc, builder.create(loc, r1Even, resultVecType), builder.create(loc, r1Odd, resultVecType)); - auto r2Sum = builder.create(loc, builder.create(loc, r2Even, resultVecType), builder.create(loc, r2Odd, resultVecType)); - // auto r1Sum = builder.create(loc, builder.create(loc, r1Even, resultVecType), builder.create(loc, r1Odd, resultVecType)); - // auto r2Sum = builder.create(loc, builder.create(loc, r2Even, resultVecType), builder.create(loc, r2Odd, resultVecType)); + auto r1Sum = builder.create(loc, builder.create(loc, resultVecType, r1Even), builder.create(loc, resultVecType, r1Odd)); + auto r2Sum = builder.create(loc, builder.create(loc, resultVecType, r2Even), builder.create(loc, resultVecType, r2Odd)); + // auto r1Sum = builder.create(loc, builder.create(loc, resultVecType, r1Even), builder.create(loc, resultVecType, r1Odd)); + // auto r2Sum = builder.create(loc, builder.create(loc, resultVecType, r2Even), builder.create(loc, resultVecType, r2Odd)); auto finalSum = builder.create(loc, r1Sum, r2Sum); builder.create(loc, finalSum, C); }); @@ -616,10 +616,10 @@ TEST_CASE("Int8Test2b") auto bOdd = builder.create(loc, halfVecType, b, b, oddMask); // extend to 32 bits - auto aEvenExt = builder.create(loc, aEven, bigVecType); - auto bEvenExt = builder.create(loc, bEven, bigVecType); - auto aOddExt = builder.create(loc, aOdd, bigVecType); - auto bOddExt = builder.create(loc, bOdd, bigVecType); + auto aEvenExt = builder.create(loc, bigVecType, aEven); + auto bEvenExt = builder.create(loc, bigVecType, bEven); + auto aOddExt = builder.create(loc, bigVecType, aOdd); + auto bOddExt = builder.create(loc, bigVecType, bOdd); auto mulEven = builder.create(loc, aEvenExt, bEvenExt); auto mulOdd = builder.create(loc, aOddExt, bOddExt); @@ -627,14 +627,14 @@ TEST_CASE("Int8Test2b") auto sum = builder.create(loc, mulEven, mulOdd); // Make sum be saturated - auto minI16Val = builder.create(loc, 
builder.create(loc, -32768, 32), bigVecType); - auto maxI16Val = builder.create(loc, builder.create(loc, 32767, 32), bigVecType); + auto minI16Val = builder.create(loc, builder.create(loc, -32768, 32), bigVecType); + auto maxI16Val = builder.create(loc, builder.create(loc, 32767, 32), bigVecType); auto maxCmp = builder.create(loc, mlir::arith::CmpIPredicate::sgt, sum, minI16Val); - auto maxVal = builder.create(loc, maxCmp, sum, minI16Val); + auto maxVal = builder.create(loc, maxCmp, sum, minI16Val); auto minCmp = builder.create(loc, mlir::arith::CmpIPredicate::slt, maxVal, maxI16Val); - auto minVal = builder.create(loc, minCmp, maxVal, maxI16Val); - auto truncVal = builder.create(loc, minVal, truncVecType); + auto minVal = builder.create(loc, minCmp, maxVal, maxI16Val); + auto truncVal = builder.create(loc, truncVecType, minVal); return truncVal; }; @@ -648,10 +648,10 @@ TEST_CASE("Int8Test2b") auto r2Even = builder.create(loc, halfResultVecType, r2, r2, evenMask); auto r2Odd = builder.create(loc, halfResultVecType, r2, r2, oddMask); - auto r1Sum = builder.create(loc, builder.create(loc, r1Even, resultVecType), builder.create(loc, r1Odd, resultVecType)); - auto r2Sum = builder.create(loc, builder.create(loc, r2Even, resultVecType), builder.create(loc, r2Odd, resultVecType)); - // auto r1Sum = builder.create(loc, builder.create(loc, r1Even, resultVecType), builder.create(loc, r1Odd, resultVecType)); - // auto r2Sum = builder.create(loc, builder.create(loc, r2Even, resultVecType), builder.create(loc, r2Odd, resultVecType)); + auto r1Sum = builder.create(loc, builder.create(loc, resultVecType, r1Even), builder.create(loc, resultVecType, r1Odd)); + auto r2Sum = builder.create(loc, builder.create(loc, resultVecType, r2Even), builder.create(loc, resultVecType, r2Odd)); + // auto r1Sum = builder.create(loc, builder.create(loc, resultVecType, r1Even), builder.create(loc, resultVecType, r1Odd)); + // auto r2Sum = builder.create(loc, builder.create(loc, resultVecType, 
r2Even), builder.create(loc, resultVecType, r2Odd)); auto finalSum = builder.create(loc, r1Sum, r2Sum); builder.create(loc, finalSum, C); }); @@ -759,10 +759,10 @@ TEST_CASE("Int8Test3") auto bOdd = builder.create(loc, halfVecType, b, b, oddMask); // extend to 32 bits - auto aEvenExt = builder.create(loc, aEven, bigVecType); - auto bEvenExt = builder.create(loc, bEven, bigVecType); - auto aOddExt = builder.create(loc, aOdd, bigVecType); - auto bOddExt = builder.create(loc, bOdd, bigVecType); + auto aEvenExt = builder.create(loc, bigVecType, aEven); + auto bEvenExt = builder.create(loc, bigVecType, bEven); + auto aOddExt = builder.create(loc, bigVecType, aOdd); + auto bOddExt = builder.create(loc, bigVecType, bOdd); auto mulEven = builder.create(loc, aEvenExt, bEvenExt); auto mulOdd = builder.create(loc, aOddExt, bOddExt); @@ -770,14 +770,14 @@ TEST_CASE("Int8Test3") auto sum = builder.create(loc, mulEven, mulOdd); // Make sum be saturated - auto minI16Val = builder.create(loc, builder.create(loc, -32768, 32), bigVecType); - auto maxI16Val = builder.create(loc, builder.create(loc, 32767, 32), bigVecType); + auto minI16Val = builder.create(loc, builder.create(loc, -32768, 32), bigVecType); + auto maxI16Val = builder.create(loc, builder.create(loc, 32767, 32), bigVecType); auto maxCmp = builder.create(loc, mlir::arith::CmpIPredicate::sgt, sum, minI16Val); - auto maxVal = builder.create(loc, maxCmp, sum, minI16Val); + auto maxVal = builder.create(loc, maxCmp, sum, minI16Val); auto minCmp = builder.create(loc, mlir::arith::CmpIPredicate::slt, maxVal, maxI16Val); - auto minVal = builder.create(loc, minCmp, maxVal, maxI16Val); - auto truncVal = builder.create(loc, minVal, cElemType); + auto minVal = builder.create(loc, minCmp, maxVal, maxI16Val); + auto truncVal = builder.create(loc, cElemType, minVal); auto c = CC.Get(builder, i); auto cPlus = builder.create(loc, c, truncVal); @@ -901,29 +901,29 @@ TEST_CASE("Int8Test3b") auto bOdd = builder.create(loc, halfVecType, b, 
b, oddMask); // extend to 16 bits - auto aEvenExt = builder.create(loc, aEven, medVecType); - auto bEvenExt = builder.create(loc, bEven, medVecType); - auto aOddExt = builder.create(loc, aOdd, medVecType); - auto bOddExt = builder.create(loc, bOdd, medVecType); + auto aEvenExt = builder.create(loc, medVecType, aEven); + auto bEvenExt = builder.create(loc, medVecType, bEven); + auto aOddExt = builder.create(loc, medVecType, aOdd); + auto bOddExt = builder.create(loc, medVecType, bOdd); auto mulEven = builder.create(loc, aEvenExt, bEvenExt); auto mulOdd = builder.create(loc, aOddExt, bOddExt); // extend to 32 bits - auto mulEvenExt = builder.create(loc, mulEven, bigVecType); - auto mulOddExt = builder.create(loc, mulOdd, bigVecType); + auto mulEvenExt = builder.create(loc, bigVecType, mulEven); + auto mulOddExt = builder.create(loc, bigVecType, mulOdd); auto sum = builder.create(loc, mulEvenExt, mulOddExt); // Make sum be saturated - auto minI16Val = builder.create(loc, builder.create(loc, -32768, 32), bigVecType); - auto maxI16Val = builder.create(loc, builder.create(loc, 32767, 32), bigVecType); + auto minI16Val = builder.create(loc, builder.create(loc, -32768, 32), bigVecType); + auto maxI16Val = builder.create(loc, builder.create(loc, 32767, 32), bigVecType); auto maxCmp = builder.create(loc, mlir::arith::CmpIPredicate::sgt, sum, minI16Val); - auto maxVal = builder.create(loc, maxCmp, sum, minI16Val); + auto maxVal = builder.create(loc, maxCmp, sum, minI16Val); auto minCmp = builder.create(loc, mlir::arith::CmpIPredicate::slt, maxVal, maxI16Val); - auto minVal = builder.create(loc, minCmp, maxVal, maxI16Val); - auto truncVal = builder.create(loc, minVal, cElemType); + auto minVal = builder.create(loc, minCmp, maxVal, maxI16Val); + auto truncVal = builder.create(loc, cElemType, minVal); auto c = CC.Get(builder, i); auto cPlus = builder.create(loc, c, truncVal); @@ -1090,10 +1090,10 @@ TEST_CASE("Int8Test3c") auto bOdd = builder.create(loc, halfVecType, b, b, 
oddBMask); // extend to 32 bits - auto aEvenExt = builder.create(loc, aEven, bigVecType); - auto bEvenExt = builder.create(loc, bEven, bigVecType); - auto aOddExt = builder.create(loc, aOdd, bigVecType); - auto bOddExt = builder.create(loc, bOdd, bigVecType); + auto aEvenExt = builder.create(loc, bigVecType, aEven); + auto bEvenExt = builder.create(loc, bigVecType, bEven); + auto aOddExt = builder.create(loc, bigVecType, aOdd); + auto bOddExt = builder.create(loc, bigVecType, bOdd); auto mulEven = builder.create(loc, aEvenExt, bEvenExt); auto mulOdd = builder.create(loc, aOddExt, bOddExt); @@ -1101,16 +1101,16 @@ TEST_CASE("Int8Test3c") auto sum = builder.create(loc, mulEven, mulOdd); // Make sum be saturated - auto minI16Val = builder.create(loc, builder.create(loc, -32768, 32), bigVecType); - auto maxI16Val = builder.create(loc, builder.create(loc, 32767, 32), bigVecType); + auto minI16Val = builder.create(loc, builder.create(loc, -32768, 32), bigVecType); + auto maxI16Val = builder.create(loc, builder.create(loc, 32767, 32), bigVecType); auto maxCmp = builder.create(loc, mlir::arith::CmpIPredicate::sgt, sum, minI16Val); - auto maxVal = builder.create(loc, maxCmp, sum, minI16Val); + auto maxVal = builder.create(loc, maxCmp, sum, minI16Val); auto minCmp = builder.create(loc, mlir::arith::CmpIPredicate::slt, maxVal, maxI16Val); - auto minVal = builder.create(loc, minCmp, maxVal, maxI16Val); - auto truncVal = builder.create(loc, minVal, midVecType); + auto minVal = builder.create(loc, minCmp, maxVal, maxI16Val); + auto truncVal = builder.create(loc, midVecType, minVal); - auto embiggenVal = builder.create(loc, truncVal, bigVecType); + auto embiggenVal = builder.create(loc, bigVecType, truncVal); auto c = CC.Get(builder, i, j); auto cPlus = builder.create(loc, c, embiggenVal); @@ -1241,10 +1241,10 @@ TEST_CASE("Int8Test4") auto bOdd = BB.Get(builder, iOdd); // extend to 32 bits - auto aEvenExt = builder.create(loc, aEven, i32Type); - auto bEvenExt = 
builder.create(loc, bEven, i32Type); - auto aOddExt = builder.create(loc, aOdd, i32Type); - auto bOddExt = builder.create(loc, bOdd, i32Type); + auto aEvenExt = builder.create(loc, i32Type, aEven); + auto bEvenExt = builder.create(loc, i32Type, bEven); + auto aOddExt = builder.create(loc, i32Type, aOdd); + auto bOddExt = builder.create(loc, i32Type, bOdd); auto mulEven = builder.create(loc, aEvenExt, bEvenExt); auto mulOdd = builder.create(loc, aOddExt, bOddExt); @@ -1255,10 +1255,10 @@ TEST_CASE("Int8Test4") auto maxI16Val = builder.create(loc, 32767, 32); auto maxCmp = builder.create(loc, mlir::arith::CmpIPredicate::sgt, sum, minI16Val); - auto maxVal = builder.create(loc, maxCmp, sum, minI16Val); + auto maxVal = builder.create(loc, maxCmp, sum, minI16Val); auto minCmp = builder.create(loc, mlir::arith::CmpIPredicate::slt, maxVal, maxI16Val); - auto minVal = builder.create(loc, minCmp, maxVal, maxI16Val); - auto truncVal = builder.create(loc, minVal, cElemType); + auto minVal = builder.create(loc, minCmp, maxVal, maxI16Val); + auto truncVal = builder.create(loc, cElemType, minVal); auto c = CC.Get(builder, i); auto cPlus = builder.create(loc, c, truncVal); @@ -1391,8 +1391,8 @@ TEST_CASE("Int8Test5") auto b = BB.Get(builder, j, i); // extend to 32 bits - auto aExt = builder.create(loc, a, i32Type); - auto bExt = builder.create(loc, b, i32Type); + auto aExt = builder.create(loc, i32Type, a); + auto bExt = builder.create(loc, i32Type, b); auto prod = builder.create(loc, aExt, bExt); auto accumVal = builder.create(loc, accum); @@ -1406,13 +1406,13 @@ TEST_CASE("Int8Test5") // Make sum be saturated auto minI16Val = builder.create(loc, -32768, 32); auto maxCmp = builder.create(loc, mlir::arith::CmpIPredicate::sgt, sum, minI16Val); - auto maxVal = builder.create(loc, maxCmp, sum, minI16Val); + auto maxVal = builder.create(loc, maxCmp, sum, minI16Val); auto maxI16Val = builder.create(loc, 32767, 32); auto minCmp = builder.create(loc, mlir::arith::CmpIPredicate::slt, 
maxVal, maxI16Val); - auto minVal = builder.create(loc, minCmp, maxVal, maxI16Val); + auto minVal = builder.create(loc, minCmp, maxVal, maxI16Val); - auto truncVal = builder.create(loc, minVal, cElemType); + auto truncVal = builder.create(loc, cElemType, minVal); auto c = CC.Get(builder, j); auto cPlus = builder.create(loc, c, truncVal); CC.Set(builder, cPlus, j); @@ -1555,8 +1555,8 @@ TEST_CASE("Int8Test5b") auto b = BB.Get(builder, j, i); // extend to 32 bits - auto aExt = builder.create(loc, a, i32Type); - auto bExt = builder.create(loc, b, i32Type); + auto aExt = builder.create(loc, i32Type, a); + auto bExt = builder.create(loc, i32Type, b); auto mul = builder.create(loc, aExt, bExt); auto accumVal = builder.create(loc, accum); @@ -1571,10 +1571,10 @@ TEST_CASE("Int8Test5b") auto maxI16Val = builder.create(loc, 32767, 32); auto maxCmp = builder.create(loc, mlir::arith::CmpIPredicate::sgt, sum, minI16Val); - auto maxVal = builder.create(loc, maxCmp, sum, minI16Val); + auto maxVal = builder.create(loc, maxCmp, sum, minI16Val); auto minCmp = builder.create(loc, mlir::arith::CmpIPredicate::slt, maxVal, maxI16Val); - auto minVal = builder.create(loc, minCmp, maxVal, maxI16Val); - auto truncVal = builder.create(loc, minVal, cElemType); + auto minVal = builder.create(loc, minCmp, maxVal, maxI16Val); + auto truncVal = builder.create(loc, cElemType, minVal); auto c = CC.Get(builder, j); auto cPlus = builder.create(loc, c, truncVal); @@ -1724,8 +1724,8 @@ TEST_CASE("Int8Test6") auto b = BB.Get(builder, jInner, iInner); // extend to 32 bits - auto aExt = builder.create(loc, a, i32Type); - auto bExt = builder.create(loc, b, i32Type); + auto aExt = builder.create(loc, i32Type, a); + auto bExt = builder.create(loc, i32Type, b); auto mul = builder.create(loc, aExt, bExt); auto accumVal = builder.create(loc, accum); @@ -1740,10 +1740,10 @@ TEST_CASE("Int8Test6") auto maxI16Val = builder.create(loc, 32767, 32); auto maxCmp = builder.create(loc, mlir::arith::CmpIPredicate::sgt, 
sum, minI16Val); - auto maxVal = builder.create(loc, maxCmp, sum, minI16Val); + auto maxVal = builder.create(loc, maxCmp, sum, minI16Val); auto minCmp = builder.create(loc, mlir::arith::CmpIPredicate::slt, maxVal, maxI16Val); - auto minVal = builder.create(loc, minCmp, maxVal, maxI16Val); - auto truncVal = builder.create(loc, minVal, cElemType); + auto minVal = builder.create(loc, minCmp, maxVal, maxI16Val); + auto truncVal = builder.create(loc, cElemType, minVal); auto c = CC.Get(builder, jInner); auto cPlus = builder.create(loc, c, truncVal); @@ -1982,8 +1982,8 @@ TEST_CASE("Int8Test8") auto b = BB.Get(builder, jInner, kInner); // extend to 32 bits - auto aExt = builder.create(loc, a, i32Type); - auto bExt = builder.create(loc, b, i32Type); + auto aExt = builder.create(loc, i32Type, a); + auto bExt = builder.create(loc, i32Type, b); auto mul = builder.create(loc, aExt, bExt); auto accumVal = builder.create(loc, accum); @@ -1998,19 +1998,19 @@ TEST_CASE("Int8Test8") auto maxI16Val = builder.create(loc, 32767, 32); auto maxCmp = builder.create(loc, mlir::arith::CmpIPredicate::sgt, sum, minI16Val); - auto maxVal = builder.create(loc, maxCmp, sum, minI16Val); + auto maxVal = builder.create(loc, maxCmp, sum, minI16Val); auto minCmp = builder.create(loc, mlir::arith::CmpIPredicate::slt, maxVal, maxI16Val); - auto minVal = builder.create(loc, minCmp, maxVal, maxI16Val); + auto minVal = builder.create(loc, minCmp, maxVal, maxI16Val); if (cccElemType == cElemType) { - auto truncVal = builder.create(loc, minVal, i16Type); - auto expandVal = builder.create(loc, truncVal, cccElemType); + auto truncVal = builder.create(loc, i16Type, minVal); + auto expandVal = builder.create(loc, cccElemType, truncVal); CCC.Set(builder, expandVal, jInner, kInner1Count); } else { - auto truncVal = builder.create(loc, minVal, cccElemType); + auto truncVal = builder.create(loc, cccElemType, minVal); auto c = CCC.Get(builder, jInner, kInner1Count); auto cPlus = builder.create(loc, c, truncVal); 
CCC.Set(builder, cPlus, jInner, kInner1Count); @@ -2031,7 +2031,7 @@ TEST_CASE("Int8Test8") auto c = CCC.Get(builder, jInner, kInner); if (cccElemType != ccElemType) { - c = builder.create(loc, c, ccElemType); + c = builder.create(loc, ccElemType, c); } auto c2 = CC.Get(builder, iInner, jInner); @@ -2063,7 +2063,7 @@ TEST_CASE("Int8Test8") } else { - auto ccVal = builder.create(loc, CC.Get(builder, iInner, jInner), i32Type); + auto ccVal = builder.create(loc, i32Type, CC.Get(builder, iInner, jInner)); C.Set(builder, ccVal, i, j); } } @@ -2148,10 +2148,10 @@ TEST_CASE("Int16Test1") auto bOdd = builder.create(loc, halfVecType, b, b, oddMask); // extend to 32 bits - auto aEvenExt = builder.create(loc, aEven, bigVecType); - auto bEvenExt = builder.create(loc, bEven, bigVecType); - auto aOddExt = builder.create(loc, aOdd, bigVecType); - auto bOddExt = builder.create(loc, bOdd, bigVecType); + auto aEvenExt = builder.create(loc, bigVecType, aEven); + auto bEvenExt = builder.create(loc, bigVecType, bEven); + auto aOddExt = builder.create(loc, bigVecType, aOdd); + auto bOddExt = builder.create(loc, bigVecType, bOdd); auto mulEven = builder.create(loc, aEvenExt, bEvenExt); auto mulOdd = builder.create(loc, aOddExt, bOddExt); @@ -2235,10 +2235,10 @@ TEST_CASE("Int16Test1b") auto bOdd = builder.create(loc, halfVecType, b, b, oddMask); // extend to 32 bits - auto aEvenExt = builder.create(loc, aEven, bigVecType); - auto bEvenExt = builder.create(loc, bEven, bigVecType); - auto aOddExt = builder.create(loc, aOdd, bigVecType); - auto bOddExt = builder.create(loc, bOdd, bigVecType); + auto aEvenExt = builder.create(loc, bigVecType, aEven); + auto bEvenExt = builder.create(loc, bigVecType, bEven); + auto aOddExt = builder.create(loc, bigVecType, aOdd); + auto bOddExt = builder.create(loc, bigVecType, bOdd); auto mulEven = builder.create(loc, aEvenExt, bEvenExt); auto mulOdd = builder.create(loc, aOddExt, bOddExt); @@ -2328,10 +2328,10 @@ TEST_CASE("Int16Test2") auto bOdd = 
builder.create(loc, halfVecType, b, b, oddMask); // extend to 32 bits - auto aEvenExt = builder.create(loc, aEven, bigVecType); - auto bEvenExt = builder.create(loc, bEven, bigVecType); - auto aOddExt = builder.create(loc, aOdd, bigVecType); - auto bOddExt = builder.create(loc, bOdd, bigVecType); + auto aEvenExt = builder.create(loc, bigVecType, aEven); + auto bEvenExt = builder.create(loc, bigVecType, bEven); + auto aOddExt = builder.create(loc, bigVecType, aOdd); + auto bOddExt = builder.create(loc, bigVecType, bOdd); auto mulEven = builder.create(loc, aEvenExt, bEvenExt); auto mulOdd = builder.create(loc, aOddExt, bOddExt); diff --git a/accera/ir/test/nest_dialect_test/NestIRTests.cpp b/accera/ir/test/nest_dialect_test/NestIRTests.cpp index 16d7bfa4..a6a696fb 100644 --- a/accera/ir/test/nest_dialect_test/NestIRTests.cpp +++ b/accera/ir/test/nest_dialect_test/NestIRTests.cpp @@ -202,9 +202,9 @@ TEST_CASE("UnrankedMemRefTest") mlir::Value val6 = Alloca(builder, memrefType6, Value{ builder.create(builder.getUnknownLoc(), 10) }); // auto cast_3_4 = builder.create(loc, val3, val4.getType()); // cast <10xi32> -> <2x5xi32> -- illegal - [[maybe_unused]] auto cast_4_5 = builder.create(loc, val3, val5.getType()); // cast <10xi32> -> - [[maybe_unused]] auto cast_4_6 = builder.create(loc, val4, val6.getType()); // cast <2x5xi32> -> <2x?xi32> - [[maybe_unused]] auto cast_4_7 = builder.create(loc, val4, memrefType7); // cast <2x5xi32> -> <*xi32> + [[maybe_unused]] auto cast_4_5 = builder.create(loc, val5.getType(), val3); // cast <10xi32> -> + [[maybe_unused]] auto cast_4_6 = builder.create(loc, val6.getType(), val4); // cast <2x5xi32> -> <2x?xi32> + [[maybe_unused]] auto cast_4_7 = builder.create(loc, memrefType7, val4); // cast <2x5xi32> -> <*xi32> }); SECTION("Parsing") diff --git a/accera/mlirHelpers/CMakeLists.txt b/accera/mlirHelpers/CMakeLists.txt index c3f844a9..005c105d 100644 --- a/accera/mlirHelpers/CMakeLists.txt +++ b/accera/mlirHelpers/CMakeLists.txt @@ -79,9 
+79,11 @@ target_link_libraries( ${conversion_libs} ir MLIRStandardToLLVM - MLIRSCFToStandard + MLIRSCFToControlFlow + MLIRControlFlowToLLVM MLIRAffineToStandard MLIRAffineTransforms + MLIRAffineUtils MLIRExecutionEngine MLIRLinalgToLLVM MLIRLinalgTransforms diff --git a/accera/mlirHelpers/src/ConvertToLLVM.cpp b/accera/mlirHelpers/src/ConvertToLLVM.cpp index 24c29e38..2ba9b218 100644 --- a/accera/mlirHelpers/src/ConvertToLLVM.cpp +++ b/accera/mlirHelpers/src/ConvertToLLVM.cpp @@ -9,9 +9,10 @@ #include #include -#include -#include +#include #include +#include +#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" #include #include #include @@ -49,8 +50,8 @@ namespace ir // affine -> loops funcOpPM.addPass(mlir::createLowerAffinePass()); - // loops -> std - pm.addPass(mlir::createLowerToCFGPass()); + // loops -> cf + pm.addPass(mlir::createConvertSCFToCFPass()); // add custom LLVM passes addLLVMPassesFn(pm); @@ -58,6 +59,9 @@ namespace ir // linalg -> llvm pm.addPass(mlir::createConvertLinalgToLLVMPass()); + // cf -> llvm + pm.addPass(mlir::cf::createConvertControlFlowToLLVMPass()); + // another canonicalization pass pm.addPass(mlir::createCanonicalizerPass()); diff --git a/accera/python/accera/Package.py b/accera/python/accera/Package.py index 461eadf0..eb401b35 100644 --- a/accera/python/accera/Package.py +++ b/accera/python/accera/Package.py @@ -48,9 +48,7 @@ def _(arg: lang.Array): @singledispatch def _resolve_array_shape(source, arr: lang.Array): is_infinite_value = ( - arr.shape[-1].get_value() == inf - if isinstance(arr.shape[-1], DelayedParameter) - else arr.shape[-1] == inf + arr.shape[-1].get_value() == inf if isinstance(arr.shape[-1], DelayedParameter) else arr.shape[-1] == inf ) if is_infinite_value: # TODO: support shape inference for lang.Function, Callable if needed @@ -63,9 +61,7 @@ def _(source, arr: lang.Array): from .lang.IntrospectionUtilities import get_array_access_indices is_infinite_value = ( - arr.shape[-1].get_value() == inf - if 
isinstance(arr.shape[-1], DelayedParameter) - else arr.shape[-1] == inf + arr.shape[-1].get_value() == inf if isinstance(arr.shape[-1], DelayedParameter) else arr.shape[-1] == inf ) if is_infinite_value: # introspect array access index to determine dimensions of the array @@ -73,9 +69,7 @@ def _(source, arr: lang.Array): # TODO: support multiple logic fns if needed assert len(logic_fns) == 1, "Only one logic function is supported" access_indices = get_array_access_indices(arr, logic_fns[0]) - assert len(access_indices) == len( - arr.shape - ), "Access indices and shape must have the same dimensions" + assert len(access_indices) == len(arr.shape), "Access indices and shape must have the same dimensions" idx = source.get_indices().index(access_indices[-1]) # initialize the array with the new shape @@ -100,9 +94,7 @@ def _emit_module(module_to_emit, target, mode, output_dir, name): working_dir = os.path.join(output_dir, "_tmp") proj = accc.AcceraProject(output_dir=working_dir, library_name=name) - proj.module_file_sets = [ - accc.ModuleFileSet(name=name, common_module_dir=working_dir) - ] + proj.module_file_sets = [accc.ModuleFileSet(name=name, common_module_dir=working_dir)] module_to_emit.Save(proj.module_file_sets[0].generated_mlir_filepath) proj.generate_and_emit( @@ -117,9 +109,7 @@ def _emit_module(module_to_emit, target, mode, output_dir, name): # Complete the HAT file with information we have stored at this layer hat_file = hat.HATFile.Deserialize(header_path) - hat_file.dependencies.link_target = os.path.basename( - proj.module_file_sets[0].object_filepath - ) + hat_file.dependencies.link_target = os.path.basename(proj.module_file_sets[0].object_filepath) hat_file.Serialize(header_path) # copy HAT package files into output directory @@ -128,6 +118,7 @@ def _emit_module(module_to_emit, target, mode, output_dir, name): class SetActiveModule: + def __init__(self, module): self.module = module @@ -150,28 +141,20 @@ class Format(Flag): MLIR = auto() MLIR_VERBOSE = 
auto() SOURCE = auto() - DEFAULT = auto() # HAT_DYNAMIC on HOST target, HAT_STATIC otherwise - HAT_DYNAMIC = ( - HAT_PACKAGE | DYNAMIC_LIBRARY - ) #: HAT package format, dynamically linked. - HAT_STATIC = ( - HAT_PACKAGE | STATIC_LIBRARY - ) #: HAT package format, statically linked + DEFAULT = auto() # HAT_DYNAMIC on HOST target, HAT_STATIC otherwise + HAT_DYNAMIC = (HAT_PACKAGE | DYNAMIC_LIBRARY) #: HAT package format, dynamically linked. + HAT_STATIC = (HAT_PACKAGE | STATIC_LIBRARY) #: HAT package format, statically linked HAT_SOURCE = HAT_PACKAGE | SOURCE - MLIR_DYNAMIC = ( - HAT_DYNAMIC | MLIR - ) #: MLIR (debugging) package format, dynamically linked. - MLIR_STATIC = ( - HAT_STATIC | MLIR - ) #: MLIR (debugging) package format, statically linked. + MLIR_DYNAMIC = (HAT_DYNAMIC | MLIR) #: MLIR (debugging) package format, dynamically linked. + MLIR_STATIC = (HAT_STATIC | MLIR) #: MLIR (debugging) package format, statically linked. MLIR_SOURCE = HAT_SOURCE | MLIR class Mode(Enum): - RELEASE = "Release" #: Release (maximally optimized). - DEBUG = "Debug" #: Debug mode (automatically tests logical equivalence). + RELEASE = "Release" #: Release (maximally optimized). + DEBUG = "Debug" #: Debug mode (automatically tests logical equivalence). 
class _Options(Flag): - NONE = auto() # (enable auto unroll | low precision fp ops) + NONE = auto() # (enable auto unroll | low precision fp ops) DISABLE_AUTO_UNROLL = auto() HIGH_PRECISION_FLOATING_POINT_OPS = auto() @@ -185,21 +168,15 @@ def __init__(self): self._description = {} self._dynamic_dependencies = set() - def _create_gpu_utility_module( - self, compiler_options, target, mode, output_dir, name="AcceraGPUUtilities" - ): + def _create_gpu_utility_module(self, compiler_options, target, mode, output_dir, name="AcceraGPUUtilities"): gpu_utility_module = _lang_python._Module(name=name, options=compiler_options) with SetActiveModule(gpu_utility_module): gpu_init_fn = _lang_python._DeclareFunction("AcceraGPUInitialize") gpu_deinit_fn = _lang_python._DeclareFunction("AcceraGPUDeInitialize") - gpu_init_fn.public(True).decorated(False).headerDecl(True).rawPointerAPI( - True - ).addTag("rc_gpu_init") - gpu_deinit_fn.public(True).decorated(False).headerDecl(True).rawPointerAPI( - True - ).addTag("rc_gpu_deinit") + gpu_init_fn.public(True).decorated(False).headerDecl(True).rawPointerAPI(True).addTag("rc_gpu_init") + gpu_deinit_fn.public(True).decorated(False).headerDecl(True).rawPointerAPI(True).addTag("rc_gpu_deinit") # No common initialization / de-initialization at this layer, however lowering passes may add steps def empty_func(args): @@ -211,8 +188,7 @@ def empty_func(args): return _emit_module(gpu_utility_module, target, mode, output_dir, name) def _create_mapping_of_heuristic_parameters_with_possible_values( - self, - source: Union["accera.Nest", "accera.Schedule", "accera.Plan", "accera.Function", Callable] + self, source: Union["accera.Nest", "accera.Schedule", "accera.Plan", "accera.Function", Callable] ): parameter_dict = {} heuristic_parameters = source._get_heuristic_parameters() @@ -226,9 +202,7 @@ def _create_mapping_of_heuristic_parameters_with_possible_values( def add( self, - source: Union[ - "accera.Nest", "accera.Schedule", "accera.Plan", 
"accera.Function", Callable - ], + source: Union["accera.Nest", "accera.Schedule", "accera.Plan", "accera.Function", Callable], args: List[Union["accera.Dimension", "accera.Array"]] = None, base_name: str = "", parameters: Union[dict, List[dict]] = {}, @@ -255,8 +229,11 @@ def add( # Note: this does not prevent TEMP arrays from being passed as an argument to a function, but they cannot be the # api-defining arguments for the function for idx, arg in enumerate(args): - if isinstance(arg, (lang.Array, _lang_python._lang.Scalar, _lang_python._lang.Dimension)) and arg.role == _lang_python.Role.TEMP: - raise ValueError(f"Error in package.add() for function {base_name}: args includes TEMP array at positions {idx}") + if isinstance(arg, (lang.Array, _lang_python._lang.Scalar, + _lang_python._lang.Dimension)) and arg.role == _lang_python.Role.TEMP: + raise ValueError( + f"Error in package.add() for function {base_name}: args includes TEMP array at positions {idx}" + ) heuristic_parameters_dict = {} if isinstance(source, lang.Plan): @@ -267,7 +244,7 @@ def add( product_parameter_grid = [] # Create a list of delayed parameter for each possible value separately using `get_parameters_from_grid` - if heuristic_parameters_dict: + if heuristic_parameters_dict: product_parameter_grid = get_parameters_from_grid(heuristic_parameters_dict) # TODO: Add functions for product parameter grid in next PR instead of adding fns separately for @@ -275,17 +252,15 @@ def add( if parameters and not isinstance(parameters, dict): return [self._add_function(source, args, base_name, p, function_opts, auxiliary) for p in parameters] elif product_parameter_grid and not isinstance(product_parameter_grid, dict): - return [self._add_function(source, args, base_name, p, function_opts, auxiliary) for p in product_parameter_grid] + return [ + self._add_function(source, args, base_name, p, function_opts, auxiliary) for p in product_parameter_grid + ] else: - return self._add_function( - source, args, 
base_name, parameters, function_opts, auxiliary - ) + return self._add_function(source, args, base_name, parameters, function_opts, auxiliary) def _add_function( self, - source: Union[ - "accera.Nest", "accera.Schedule", "accera.Plan", "accera.Function", Callable - ], + source: Union["accera.Nest", "accera.Schedule", "accera.Plan", "accera.Function", Callable], args: List[Union["accera.Dimension", "accera.Array"]] = None, base_name: str = "", parameters: dict = {}, @@ -315,14 +290,14 @@ def _add_function( else: if isinstance(value, tuple) or isinstance(value, list): if all(isinstance(v, LoopIndex) for v in value): - param_value_dict[delayed_param._name] = str( - [x._name for x in value] - ) + param_value_dict[delayed_param._name] = str([x._name for x in value]) else: raise ValueError("Invalid value of parameters") else: param_value_dict[delayed_param._name] = str(value) - auxiliary_metadata["accera"] = {"parameters": param_value_dict} + auxiliary_metadata["accera"] = { + "parameters": param_value_dict + } def validate_target(target: Target): # can't use set because targets are mutable (therefore unhashable) @@ -347,20 +322,14 @@ def get_function_name(target: Target): base_name or token_hex(4), target, auxiliary_metadata["accera"], - ] - + [ - (a.role, a.element_type, a.shape, a.layout) - if isinstance(a, lang.Array) else (a.name, a.type) - if isinstance(a, _lang_python._lang.Dimension) else None - for a in args - ], + ] + [(a.role, a.element_type, a.shape, a.layout) if isinstance(a, lang.Array) else + (a.name, a.type) if isinstance(a, _lang_python._lang.Dimension) else None + for a in args], ) ) ).encode("utf-8") - ) - .digest() - .hex()[:16] - ) # truncate + ).digest().hex()[:16] + ) # truncate # Function names must begin with an _ or alphabetical character return f"{base_name}_{suffix}" if base_name else f"_{suffix}" @@ -381,13 +350,15 @@ def compute_arg_size_references(args, SENTINEL_VALUE=-1): arg_size_refs = [] for arg in args: if isinstance(arg, lang.Array): 
- arr_dim_mappings = [args.index(dim) if isinstance(dim, _lang_python._lang.Dimension) else SENTINEL_VALUE for dim in arg.shape] + arr_dim_mappings = [ + args.index(dim) if isinstance(dim, _lang_python._lang.Dimension) else SENTINEL_VALUE + for dim in arg.shape + ] arg_size_refs.append(arr_dim_mappings) else: arg_size_refs.append([SENTINEL_VALUE]) return arg_size_refs - # Resolve any undefined argument shapes based on the source usage pattern for arr in args: if isinstance(arr, lang.Array): @@ -400,12 +371,13 @@ def compute_arg_size_references(args, SENTINEL_VALUE=-1): if isinstance(source, lang.Plan): self._dynamic_dependencies.update(source._dynamic_dependencies) - source = source._create_function( - args, **function_opts - ) + source = source._create_function(args, **function_opts) # fall-through - arg_names = [arg.name if isinstance(arg, lang.Array) or isinstance(arg, _lang_python._lang.Dimension) else "" for arg in args] + arg_names = [ + arg.name if isinstance(arg, lang.Array) or isinstance(arg, _lang_python._lang.Dimension) else "" + for arg in args + ] arg_sizes = [arg._size_str if isinstance(arg, lang.Array) else "" for arg in args] if isinstance(source, lang.Function): @@ -414,7 +386,7 @@ def compute_arg_size_references(args, SENTINEL_VALUE=-1): # due to the fall-through, we only need to validate here validate_target(source.target) - native_array_dim_args = [arg._get_native_array() if isinstance(arg, lang.Array) else arg for arg in args ] + native_array_dim_args = [arg._get_native_array() if isinstance(arg, lang.Array) else arg for arg in args] source.name = get_function_name(source.target) source.base_name = base_name @@ -426,7 +398,7 @@ def compute_arg_size_references(args, SENTINEL_VALUE=-1): source.arg_sizes = arg_sizes source.requested_args = args self._fns[source.name] = source - return source # for composability + return source # for composability elif isinstance(source, Callable): @@ -454,7 +426,7 @@ def wrapper_fn(args): ) self._fns[name] = 
wrapped_func - return wrapped_func # for composability + return wrapped_func # for composability else: raise ValueError("Invalid type for source") @@ -467,9 +439,7 @@ def _add_functions_to_module(self, module, fail_on_error=False): wrapped_func._emit() except Exception as e: to_pop.append(name) - logging.error( - f"Compiler error when trying to build function {name}" - ) + logging.error(f"Compiler error when trying to build function {name}") logging.error(e) if fail_on_error: raise @@ -492,11 +462,8 @@ def _add_debug_utilities(self, tolerance): # only add if there are actually arguments to debug return add_debugging_functions( self, - { - name: fn_and_args - for name, fn_and_args in fns_to_add.items() - if fn_and_args[1] - }, + {name: fn_and_args + for name, fn_and_args in fns_to_add.items() if fn_and_args[1]}, atol=tolerance, ) @@ -511,10 +478,10 @@ def _generate_target_options(self, platform: Platform, mode: Mode = Mode.RELEASE host_target_device = _lang_python._GetTargetDeviceFromName("host") if platform in [ - Package.Platform.HOST, - Package.Platform.LINUX, - Package.Platform.MACOS, - Package.Platform.WINDOWS, + Package.Platform.HOST, + Package.Platform.LINUX, + Package.Platform.MACOS, + Package.Platform.WINDOWS, ]: target_device = _lang_python._GetTargetDeviceFromName(platform.value) else: @@ -536,15 +503,19 @@ def _generate_target_options(self, platform: Platform, mode: Mode = Mode.RELEASE elif target.architecture == Target.Architecture.X86_64: target_device.architecture = "x86_64" - if "AVX512" in target.extensions: - target_device.device_name = "avx512" - target_device.cpu = "skylake-avx512" - # TODO: make this functionality less hidden - avx512_feat_str = ",".join( - [f"+{feature.lower()}" for feature in target.extensions] - ) + if not target_device.is_macOS(): # macOS does not support these instructions + if "AVX512" in target.extensions: + target_device.device_name = "avx512" + target_device.cpu = "skylake-avx512" + # TODO: make this functionality less 
hidden + avx512_feat_str = ",".join([f"+{feature.lower()}" for feature in target.extensions]) - target_device.features = avx512_feat_str + target_device.features = avx512_feat_str + + elif "AVX2" in target.extensions: + target_device.device_name = "avx2" + target_device.cpu = "skylake" + target_device.features = "+avx2" elif target.architecture == Target.Architecture.X86: target_device.architecture = "x86" @@ -554,21 +525,14 @@ def _generate_target_options(self, platform: Platform, mode: Mode = Mode.RELEASE compiler_options = _lang_python.CompilerOptions() compiler_options.target_device = target_device compiler_options.debug = mode == Package.Mode.DEBUG - compiler_options.gpu_only = ( - target.category == Target.Category.GPU and target.runtime != Runtime.VULKAN - ) + compiler_options.gpu_only = (target.category == Target.Category.GPU and target.runtime != Runtime.VULKAN) BuildConfig.obj_extension = ".obj" if target_device.is_windows() else ".o" - libs = list( - filter( - None, - [ - get_library_reference(dep, platform) - for dep in self._dynamic_dependencies - ], - ) - ) + libs = list(filter( + None, + [get_library_reference(dep, platform) for dep in self._dynamic_dependencies], + )) return target, target_device, compiler_options, libs def _make_accc_options(self, options: _Options): @@ -627,9 +591,9 @@ def build( format_is_default = bool( format & Package.Format.DEFAULT - ) # store it as a boolean because we're going to turn off the actual flag + ) # store it as a boolean because we're going to turn off the actual flag if format_is_default: - format &= ~Package.Format.DEFAULT # Turn off "DEFAULT" + format &= ~Package.Format.DEFAULT # Turn off "DEFAULT" if target.runtime in [Target.Runtime.CUDA, Target.Runtime.ROCM]: format |= Package.Format.HAT_SOURCE @@ -640,9 +604,7 @@ def build( dynamic_link = bool(format & Package.Format.DYNAMIC_LIBRARY) if cross_compile and dynamic_link: - raise ValueError( - "Package.Format.DYNAMIC_LIBRARY is not supported when 
cross-compiling" - ) + raise ValueError("Package.Format.DYNAMIC_LIBRARY is not supported when cross-compiling") output_dir = output_dir or os.getcwd() working_dir = os.path.join(output_dir, "_tmp") @@ -662,44 +624,22 @@ def build( # Emit the package module if format & Package.Format.SOURCE: - output_type = ( - accc.ModuleOutputType.CUDA - if compiler_options.gpu_only - else accc.ModuleOutputType.CPP - ) + output_type = (accc.ModuleOutputType.CUDA if compiler_options.gpu_only else accc.ModuleOutputType.CPP) else: output_type = accc.ModuleOutputType.OBJECT # Emit the supporting modules supporting_hats = [] - if ( - not compiler_options.gpu_only - and output_type == accc.ModuleOutputType.OBJECT - ): + if (not compiler_options.gpu_only and output_type == accc.ModuleOutputType.OBJECT): supporting_hats.append( - Package._emit_default_module( - compiler_options, target, mode, output_dir, f"{name}_Globals" - ) + Package._emit_default_module(compiler_options, target, mode, output_dir, f"{name}_Globals") ) - if any( - fn.target.category == Target.Category.GPU - and fn.target.runtime == Target.Runtime.VULKAN - for fn in self._fns.values() - ): - supporting_hats.append( - self._create_gpu_utility_module( - compiler_options, target, mode, output_dir - ) - ) + if any(fn.target.category == Target.Category.GPU and fn.target.runtime == Target.Runtime.VULKAN + for fn in self._fns.values()): + supporting_hats.append(self._create_gpu_utility_module(compiler_options, target, mode, output_dir)) - proj = accc.AcceraProject( - output_dir=working_dir, library_name=name, output_type=output_type - ) - proj.module_file_sets = [ - accc.ModuleFileSet( - name=name, common_module_dir=working_dir, output_type=output_type - ) - ] + proj = accc.AcceraProject(output_dir=working_dir, library_name=name, output_type=output_type) + proj.module_file_sets = [accc.ModuleFileSet(name=name, common_module_dir=working_dir, output_type=output_type)] 
package_module.Save(proj.module_file_sets[0].generated_mlir_filepath) # Enable dumping of IR passes based on build format @@ -737,40 +677,28 @@ def build( # Complete the HAT file with information we have stored at this layer hat_file: hat.HATFile = hat.HATFile.Deserialize(header_path) - if format & ( - Package.Format.DYNAMIC_LIBRARY | Package.Format.STATIC_LIBRARY - ): - hat_file.dependencies.link_target = os.path.basename( - proj.module_file_sets[0].object_filepath - ) + if format & (Package.Format.DYNAMIC_LIBRARY | Package.Format.STATIC_LIBRARY): + hat_file.dependencies.link_target = os.path.basename(proj.module_file_sets[0].object_filepath) supporting_hats = map(hat.HATFile.Deserialize, supporting_hats) supporting_objs = [] supporting_decls = [] for support in supporting_hats: path = os.path - dependency_path = path.abspath( - path.join(output_dir, support.dependencies.link_target) - ) + dependency_path = path.abspath(path.join(output_dir, support.dependencies.link_target)) # Collect the supporting modules as dependencies - supporting_objs.append( - hat.LibraryReference(target_file=dependency_path) - ) + supporting_objs.append(hat.LibraryReference(target_file=dependency_path)) # Collecting the supporting code decls supporting_decls.append(support.declaration.code) # Merge the function maps - hat_file._function_table.function_map.update( - support._function_table.function_map - ) + hat_file._function_table.function_map.update(support._function_table.function_map) decl_code = hat_file.declaration.code hat_file.dependencies.dynamic = dynamic_dependencies + supporting_objs - hat_file.declaration.code = decl_code._new( - "\n".join(map(str, ["", decl_code] + supporting_decls)) - ) + hat_file.declaration.code = decl_code._new("\n".join(map(str, ["", decl_code] + supporting_decls))) for fn_name in self._fns: fn: lang.Function = self._fns[fn_name] @@ -779,28 +707,19 @@ def build( hat_func = hat_file.function_map.get(fn_name) if hat_func is None: - raise ValueError( - 
f"Couldn't find header-declared function {fn_name} in emitted HAT file" - ) + raise ValueError(f"Couldn't find header-declared function {fn_name} in emitted HAT file") hat_func.auxiliary = fn.auxiliary - if ( - fn.target.category == Target.Category.GPU - and fn.target.runtime != Target.Runtime.VULKAN - ): + if (fn.target.category == Target.Category.GPU and fn.target.runtime != Target.Runtime.VULKAN): # TODO: Remove this when the header is emitted as part of the compilation gpu_source = proj.module_file_sets[0].translated_source_filepath gpu_device_func = fn_name + "__gpu__" with open(gpu_source) as gpu_source_f: - s = re.search( - gpu_device_func + _R_GPU_LAUNCH, gpu_source_f.read() - ) + s = re.search(gpu_device_func + _R_GPU_LAUNCH, gpu_source_f.read()) if not s: raise RuntimeError("Couldn't parse emitted source code") - launch_parameters = list( - map(int, [s[n] for n in range(1, 7)]) - ) + launch_parameters = list(map(int, [s[n] for n in range(1, 7)])) dynamic_shared_mem_bytes = int(s[7]) gpu_source = os.path.split(gpu_source)[1] @@ -847,11 +766,11 @@ def build( if not cross_compile and (format & Package.Format.STATIC_LIBRARY): lib_hat_path = f"{path_root}_lib{extension}" hat.create_static_package(header_path, lib_hat_path) - + lib_hat_file = hat_file.Deserialize(lib_hat_path) lib_hat_file.dependencies.auxiliary["static"] = lib_hat_file.dependencies.link_target lib_hat_file.Serialize() - + shutil.move(lib_hat_path, header_path) if dynamic_link: @@ -863,7 +782,7 @@ def build( dyn_hat_file.Serialize() shutil.move(dyn_hat_path, header_path) - + # TODO: plumb cross-compilation of static libs return proj.module_file_sets @@ -892,9 +811,7 @@ def add_description( self._description["auxiliary"].update(other) # remove any keys marked None - keys_to_remove = [ - k for k, v in self._description["auxiliary"].items() if v is None - ] + keys_to_remove = [k for k, v in self._description["auxiliary"].items() if v is None] for k in keys_to_remove: del 
self._description["auxiliary"][k] diff --git a/accera/python/accera/Targets.py b/accera/python/accera/Targets.py index 2df3b721..ce54551e 100644 --- a/accera/python/accera/Targets.py +++ b/accera/python/accera/Targets.py @@ -6,13 +6,15 @@ import copy import cpuinfo import re + from typing import List, Union from dataclasses import dataclass, field, fields from enum import Enum, auto from ._lang_python import ScalarType, _GetKnownDeviceNames from ._lang_python._lang import ( - BLOCK_X, BLOCK_Y, BLOCK_Z, THREAD_X, THREAD_Y, THREAD_Z, WARP_X, WARP_Y, _MemorySpace, MMAShape, _ExecutionRuntime as Runtime + BLOCK_X, BLOCK_Y, BLOCK_Z, THREAD_X, THREAD_Y, THREAD_Z, WARP_X, WARP_Y, _MemorySpace, MMAShape, _ExecutionRuntime + as Runtime ) @@ -768,14 +770,14 @@ def supports( def mma_shape_to_tuple(self, mma_shape: MMAShape): return { - MMAShape.M64xN64xK1_B4 : (64, 64, 1), - MMAShape.M64xN64xK1_B2 : (64, 64, 1), - MMAShape.M32xN32xK2_B1 : (32, 32, 2), - MMAShape.M16xN16xK4_B1 : (16, 16, 4), - MMAShape.M64xN64xK4_B4 : (64, 64, 4), - MMAShape.M64xN64xK4_B2 : (64, 64, 4), - MMAShape.M32xN32xK8_B1 : (32, 32, 8), - MMAShape.M16xN16xK16_B1 : (16, 16, 16), + MMAShape.M64xN64xK1_B4: (64, 64, 1), + MMAShape.M64xN64xK1_B2: (64, 64, 1), + MMAShape.M32xN32xK2_B1: (32, 32, 2), + MMAShape.M16xN16xK4_B1: (16, 16, 4), + MMAShape.M64xN64xK4_B4: (64, 64, 4), + MMAShape.M64xN64xK4_B2: (64, 64, 4), + MMAShape.M32xN32xK8_B1: (32, 32, 8), + MMAShape.M16xN16xK16_B1: (16, 16, 16), MMAShape.M64xN64xK2_B4: (64, 64, 2), MMAShape.M64xN64xK2_B2: (64, 64, 2), MMAShape.M32xN32xK4_B1: (32, 32, 4), @@ -791,6 +793,7 @@ def compute_tensor_splits(self, mma_shape: MMAShape, num_total_passes: int = 1): return tuple(mutable_tensor_splits) +# yapf: disable MI100_TENSORCORE_INFO = TensorCoreInformation([ TensorCoreInformationEntry(shape=MMAShape.M16xN16xK4_B1, inType=ScalarType.float32, outType=ScalarType.float32), # maps to the 16x16x4 mfma instruction TensorCoreInformationEntry(shape=MMAShape.M32xN32xK2_B1, 
inType=ScalarType.float32, outType=ScalarType.float32), # maps to the 32x32x2 mfma instruction @@ -916,6 +919,14 @@ def __post_init__(self): device_name = \ self.family.lower() if self.family else (self.name.lower() if self.name else self._device_name) + + # FIXUP: determine the device name if it matches with certain known extensions + # revisit using self.family for the device_name? + if "AVX512" in self.extensions: + device_name = "avx512" + elif "AVX2" in self.extensions: + device_name = "avx2" + if device_name in _GetKnownDeviceNames(): self._device_name = device_name @@ -1045,7 +1056,7 @@ def __init__( num_threads: int = None, num_cores: int = None, vector_bytes: int = 0, - vector_registers: int = None, + vector_registers: int = 0, frequency_GHz: float = None, tensor_core_info: TensorCoreInformation = None, turbo_frequency_GHz: float = None, @@ -1060,15 +1071,25 @@ def __init__( known_name = self._try_get_known_name(known_name) if known_name == "HOST": + if not extensions: + # infer extensions from cpu information + cpu_info = cpuinfo.get_cpu_info() + extensions = [ + ext for ext in + ["MMX", "SSE", "SSE2", "SSE3", "SSSE3", "SSE4", "SSE4.1", "SSE4.2", "AVX", "AVX2", "FMA3"] + if "flags" in cpu_info and ext.lower() in cpu_info["flags"] + ] + + if "AVX2" in extensions: + vector_bytes = 32 # There are 32-bytes per full SIMD register + vector_registers = 16 # There are 16 YMM registers super().__init__( category=category or Target.Category.CPU, architecture=Target.Architecture["HOST"], - vector_bytes=32, # There are 32-bytes per full SIMD register - vector_registers=16, # There are 16 YMM registers - extensions=[ - "MMX", "SSE", "SSE2", "SSE3", "SSSE3", "SSE4", "SSE4.1", "SSE4.2", "AVX", "AVX2", "FMA3" - ] + vector_bytes=vector_bytes, + vector_registers=vector_registers, + extensions=extensions ) else: diff --git a/accera/python/accera/test/dsl_tests.py b/accera/python/accera/test/dsl_tests.py index 72eda320..e4ff0126 100644 --- 
a/accera/python/accera/test/dsl_tests.py +++ b/accera/python/accera/test/dsl_tests.py @@ -27,10 +27,9 @@ from accera import ScalarType, Array, Function, Nest, Target, Package, algorithms, cast, AllocateFlags, Role from accera.test import verifiers -from accera.test.test_utils import expectedFailure, FailedReason +from accera.test.test_utils import expectedFailure, FailedReason, avx2_cpu, get_avx_platform from accera._lang_python._lang import Dimension, EnterProfileRegion, ExitProfileRegion, PrintProfileResults - INTERNAL_FUNCTION_OPTS = { "no_inline_into": True, "public": False @@ -38,6 +37,7 @@ TEST_MODE = Package.Mode.DEBUG if DEV_MODE else Package.Mode.RELEASE TEST_FORMAT = Package.Format.MLIR_DYNAMIC if DEV_MODE else Package.Format.HAT_DYNAMIC +TEST_FORMAT_XCOMPILE = Package.Format.MLIR_STATIC TEST_PACKAGE_DIR = "test_acccgen" # Groups of types commonly used for tests @@ -75,7 +75,13 @@ def _verify_nest(self, nest, args: Tuple[Array], package_name, correctness_check # build the HAT package with verifiers.VerifyPackage(self, package_name, output_dir) as v: - package.build(package_name, format=TEST_FORMAT, mode=_get_test_mode(correctness_check_values), output_dir=output_dir, _quiet=quiet) + package.build( + package_name, + format=TEST_FORMAT, + mode=_get_test_mode(correctness_check_values), + output_dir=output_dir, + _quiet=quiet + ) if correctness_check_values: v.check_correctness( function.name, @@ -373,6 +379,7 @@ def _(): ) def test_dynamic_temp_array(self) -> None: + def make_test_fn(package, A, B, C, N): T = Array(role=Role.TEMP, element_type=A.element_type, shape=A.shape) @@ -538,7 +545,9 @@ def _(): "pre": (A_test, B_test, C_test), "post": (A_test, B_expected, C_expected), } - self._verify_nest(plan, (A, B, C), "test_array_vectorize_cast", correctness_check_values=correctness_check_values) + self._verify_nest( + plan, (A, B, C), "test_array_vectorize_cast", correctness_check_values=correctness_check_values + ) def test_interleaved_vectorize_cast(self) -> 
None: shape = (64, 32, 8, 2) @@ -643,7 +652,7 @@ def test_reinterpret_cast(self) -> None: def reinterpret_arr_as_int16(array: Array): # Assumes array is f32 - num_elements = reduce(lambda x, y: x*y, array.shape, 1) + num_elements = reduce(lambda x, y: x * y, array.shape, 1) arr_mb = array._get_memory_buffer() self.assertEqual(arr_mb.shape, [num_elements * 4]) self.assertEqual(arr_mb.element_type, ScalarType.uint8) @@ -659,7 +668,7 @@ def reinterpret_arr_as_int16(array: Array): def reinterpret_arr_as_int32(array: Array): # Assumes array is f32 - num_elements = reduce(lambda x, y: x*y, array.shape, 1) + num_elements = reduce(lambda x, y: x * y, array.shape, 1) arr_mb = array._get_memory_buffer() self.assertEqual(arr_mb.shape, [num_elements * 4]) self.assertEqual(arr_mb.element_type, ScalarType.uint8) @@ -675,7 +684,7 @@ def reinterpret_arr_as_int32(array: Array): def simple_reinterpret_arr_as_int32(array: Array): # Assumes array is f32 - num_elements = reduce(lambda x, y: x*y, array.shape, 1) + num_elements = reduce(lambda x, y: x * y, array.shape, 1) arr_as_int32 = array._reinterpret_cast(ScalarType.int32) self.assertEqual(arr_as_int32.shape, array.shape) self.assertEqual(arr_as_int32.element_type, ScalarType.int32) @@ -718,18 +727,14 @@ def main(array): output_dir = pathlib.Path(TEST_PACKAGE_DIR) / package_name with verifiers.VerifyPackage(self, package_name, output_dir): package.build( - package_name, - format=TEST_FORMAT, - mode=Package.Mode.RELEASE, - output_dir=output_dir, - _quiet=False + package_name, format=TEST_FORMAT, mode=Package.Mode.RELEASE, output_dir=output_dir, _quiet=False ) def test_heap_alloc_reinterpret_cast(self) -> None: package = Package() - temp = Array(role=Role.TEMP, element_type=ScalarType.int32, shape=(32,), flags=AllocateFlags.HEAP) - output = Array(role=Role.INPUT_OUTPUT, element_type=ScalarType.float32, shape=(32,)) + temp = Array(role=Role.TEMP, element_type=ScalarType.int32, shape=(32, ), flags=AllocateFlags.HEAP) + output = 
Array(role=Role.INPUT_OUTPUT, element_type=ScalarType.float32, shape=(32, )) nest = Nest(shape=output.shape) i, = nest.get_indices() @@ -739,18 +744,14 @@ def _(): temp_mb = temp._get_memory_buffer() cast_temp = temp_mb._reinterpret_cast(ScalarType.float32) output[i] = cast_temp[i] - + package.add(nest, args=(output, )) package_name = "test_heap_alloc_reinterpret_cast" output_dir = pathlib.Path(TEST_PACKAGE_DIR) / package_name with verifiers.VerifyPackage(self, package_name, output_dir): package.build( - package_name, - format=TEST_FORMAT, - mode=Package.Mode.RELEASE, - output_dir=output_dir, - _quiet=False + package_name, format=TEST_FORMAT, mode=Package.Mode.RELEASE, output_dir=output_dir, _quiet=False ) def test_reinterpret_cast_partially_dynamic_shape(self) -> None: @@ -769,18 +770,12 @@ def test_reinterpret_cast_partially_dynamic_shape(self) -> None: def _(): float_A = A._reinterpret_cast(ScalarType.float32) B[indices] = float_A[indices] - + package.add(nest, args=(M, N, A, B), base_name=test_name) output_dir = pathlib.Path(TEST_PACKAGE_DIR) / test_name with verifiers.VerifyPackage(self, test_name, output_dir): - package.build( - test_name, - format=TEST_FORMAT, - mode=Package.Mode.RELEASE, - output_dir=output_dir, - _quiet=False - ) + package.build(test_name, format=TEST_FORMAT, mode=Package.Mode.RELEASE, output_dir=output_dir, _quiet=False) def test_subarray(self) -> None: package = Package() @@ -885,7 +880,9 @@ def _verify_helper(self, package, test_name, function_name=None, correctness_che output_dir = pathlib.Path(TEST_PACKAGE_DIR) / test_name with verifiers.VerifyPackage(self, test_name, output_dir) as v: shutil.rmtree(output_dir, ignore_errors=True) - package.build(test_name, format=TEST_FORMAT, mode=_get_test_mode(correctness_check_values), output_dir=output_dir) + package.build( + test_name, format=TEST_FORMAT, mode=_get_test_mode(correctness_check_values), output_dir=output_dir + ) if function_name and correctness_check_values: v.check_correctness( 
function_name, @@ -1523,11 +1520,13 @@ def test_output_array_range_node2(self) -> None: output_start = Scalar(type=ScalarType.float32, role=Role.TEMP) nest1 = Nest((1, )) + @nest1.iteration_logic def _(): outputDim.set(cast(floor((limit - start) / delta), ScalarType.int64)) nest2 = Nest((1, )) + @nest2.iteration_logic def _(): output_start.set(start) @@ -1549,7 +1548,7 @@ def _(): package = Package() # BUGBUG: dim args ordered first due to issue with Debug mode package.add(nest1, args=(start, limit, delta, outputDim), base_name=f"range_get_size") - ini_start_fn = package.add(nest2, args=(start,), base_name=f"ini_start") + ini_start_fn = package.add(nest2, args=(start, ), base_name=f"ini_start") get_result_fn = package.add(nest3, args=(inputDim, output, delta), base_name=f"get_result") nest4 = Nest((1, )) @@ -1692,7 +1691,16 @@ def _create_nest(self, shape: Tuple[int], type=ScalarType.float32) -> Tuple: return Nest(shape=(M, N, S)), A, B, C - def _build_nest(self, nest, args: Tuple[Array], package_name, correctness_check_values=None, quiet=True) -> None: + def _build_nest( + self, + nest, + args: Tuple[Array], + package_name, + correctness_check_values=None, + quiet=True, + file_check_fn=None, + platform=Package.Platform.HOST + ) -> None: # helper function to build a nest so that we can focus on the logic function # create a HAT package and add the nest to it package = Package() @@ -1701,13 +1709,22 @@ def _build_nest(self, nest, args: Tuple[Array], package_name, correctness_check_ # build the HAT package output_dir = pathlib.Path(TEST_PACKAGE_DIR) / package_name with verifiers.VerifyPackage(self, package_name, output_dir) as v: - package.build(package_name, format=TEST_FORMAT, mode=_get_test_mode(correctness_check_values), output_dir=output_dir, _quiet=quiet) + package.build( + package_name, + format=TEST_FORMAT if correctness_check_values else TEST_FORMAT_XCOMPILE, + mode=_get_test_mode(correctness_check_values), + output_dir=output_dir, + platform=platform, + 
_quiet=quiet + ) if correctness_check_values: v.check_correctness( function.name, before=correctness_check_values["pre"], after=correctness_check_values["post"], ) + if file_check_fn: + file_check_fn(v) def test_signed_types(self) -> None: for t in [ScalarType.int16, ScalarType.int32, ScalarType.int64] + FLOAT_TYPES: @@ -1939,7 +1956,6 @@ def _(): self._build_nest(nest, [A, B], "test_round_intrinsic", correctness_check_values=correctness_check_values) - @expectedFailure(FailedReason.INVALID, "x86 round intrinsic not supported on MacOS", sys.platform == "darwin") def test_round_intrinsic_vectorized(self) -> None: from accera import round as accround @@ -1962,7 +1978,7 @@ def _(): j: 8 }) sched.reorder(i, j, ii, jj) - plan = sched.create_plan() + plan = sched.create_plan(Target("Intel 6700")) plan.vectorize(ii) A_test = np.random.uniform(low=-1000.0, high=1000.0, size=A.shape).astype(np.float32) @@ -1978,10 +1994,24 @@ def _(): correctness_check_values = { "pre": [A_test, B_test], "post": [A_test, B_ref] - } + } if avx2_cpu() else None + + package_name = "test_round_intrinsic_vectorized" + + def file_check_fn(v): + checker = v.file_checker(f"{package_name}_llvm.mlir") + checker.check_label("llvm.func @{{[a-z_]*}}" + package_name) + checker.check_count( + '%{{[0-9]+}} = "accintr.x86.avx.cvt.ps2dq.256"(%{{[0-9]+}}) : (vector<8xf32>) -> vector<8xi32>', 4 + ) + checker.run() self._build_nest( - plan, [A, B], "test_round_intrinsic_vectorized", correctness_check_values=correctness_check_values + plan, [A, B], + package_name, + correctness_check_values=correctness_check_values, + file_check_fn=file_check_fn, + platform=get_avx_platform() ) # TODO : fix this test - it appears to abort on just the linux buddy build machine @@ -2018,7 +2048,6 @@ def _(): # self._build_nest(nest, [A, B], "test_remainderf_intrinsic_rounding", correctness_check_values=correctness_check_values) - @expectedFailure(FailedReason.INVALID, "x86 max min intrinsics not supported on MacOS", sys.platform 
== "darwin") def test_vectorized_max_min(self) -> None: from accera import max, min @@ -2052,7 +2081,7 @@ def _(): j: 8 }) sched.reorder(i, j, ii, jj) - plan = sched.create_plan() + plan = sched.create_plan(Target("Intel 6700")) plan.vectorize(ii) function = package.add(plan, args=(A, B, C_max, C_min), base_name=fn_name) @@ -2064,18 +2093,20 @@ def _(): C_max_ref = np.maximum(A_test, B_test) C_min_ref = np.minimum(A_test, B_test) - correctness_check_values[fn_name] = { - "pre": [A_test, B_test, C_max_test, C_min_test], - "post": [A_test, B_test, C_max_ref, C_min_ref] - } + if avx2_cpu(): + correctness_check_values[fn_name] = { + "pre": [A_test, B_test, C_max_test, C_min_test], + "post": [A_test, B_test, C_max_ref, C_min_ref] + } # build the HAT package output_dir = pathlib.Path(TEST_PACKAGE_DIR) / package_name with verifiers.VerifyPackage(self, package_name, output_dir) as v: package.build( package_name, - format=TEST_FORMAT | Package.Format.MLIR_VERBOSE, + format=TEST_FORMAT if avx2_cpu() else TEST_FORMAT_XCOMPILE, mode=Package.Mode.RELEASE, + platform=get_avx_platform(), output_dir=output_dir ) for fn_name in func_names: @@ -2086,7 +2117,22 @@ def _(): after=correctness_check_values[fn_name]["post"], ) - @expectedFailure(FailedReason.INVALID, "x86 max min intrinsics not supported on MacOS", sys.platform == "darwin") + max_checker = v.file_checker(f"{package_name}_llvm.mlir") + max_checker.check_label(f'llvm.func @{function.name}') + max_checker.check_count( + '%{{[0-9]+}} = "accintr.x86.avx.max.ps.256"(%{{[0-9]+}}, %{{[0-9]+}}) : (vector<8xf32>, vector<8xf32>) -> vector<8xf32>', + 4 + ) + max_checker.run() + + min_checker = v.file_checker(f"{package_name}_llvm.mlir") + max_checker.check_label(f'llvm.func @{function.name}') + min_checker.check_count( + '%{{[0-9]+}} = "accintr.x86.avx.min.ps.256"(%{{[0-9]+}}, %{{[0-9]+}}) : (vector<8xf32>, vector<8xf32>) -> vector<8xf32>', + 4 + ) + min_checker.run() + def test_vectorized_single_max_min_block(self) -> None: # In 
this test we're trying to find the single max and single min value of a 2-D array. # To vectorize this, we'll want to compute several maxs and mins in paralle and then reduce them @@ -2134,7 +2180,7 @@ def _(): io_A_min_cache[inner_i, inner_j] = min(io_A_min_cache[inner_i, inner_j], A[i, j]) inner_sched = inner_nest.create_schedule() - inner_plan = inner_sched.create_plan() + inner_plan = inner_sched.create_plan(Target("Intel 6700")) inner_plan.vectorize(inner_i) inner_fn = package.add( inner_plan, @@ -2158,7 +2204,7 @@ def _(): outer_j: N_tile }) outer_sched.reorder(outer_i, outer_j, outer_ii, outer_iii, outer_jj) - outer_plan = outer_sched.create_plan() + outer_plan = outer_sched.create_plan(Target("Intel 6700")) outer_plan._erase_loops([outer_iii, outer_jj]) outer_fn = package.add( outer_plan, @@ -2178,7 +2224,10 @@ def _(): arr[indices] = outer_arr[indices] return package.add( - zero_nest, args=(outer_arr, arr), base_name=base_name, function_opts=INTERNAL_FUNCTION_OPTS + zero_nest.create_plan(Target("Intel 6700")), + args=(outer_arr, arr), + base_name=base_name, + function_opts=INTERNAL_FUNCTION_OPTS ) zero_max_cache_fn = _make_init_fn(package, A, io_A_max_cache, "max_cache_zeroing") @@ -2201,7 +2250,10 @@ def _(): outer_arr[0] = min(outer_arr[0], cache[indices]) return package.add( - reduce_nest, args=(cache, outer_arr), base_name=base_name, function_opts=INTERNAL_FUNCTION_OPTS + reduce_nest.create_plan(Target("Intel 6700")), + args=(cache, outer_arr), + base_name=base_name, + function_opts=INTERNAL_FUNCTION_OPTS ) reduce_max_cache_fn = _make_cache_reduce_fn(package, io_A_max_cache, A_max, "max_cache_reduce", True) @@ -2219,7 +2271,9 @@ def _(): reduce_max_cache_fn(A_max_cache, A_max) reduce_min_cache_fn(A_min_cache, A_min) - function = package.add(top_nest, args=(A, A_max, A_min), base_name=fn_name) + function = package.add( + top_nest.create_plan(Target("Intel 6700")), args=(A, A_max, A_min), base_name=fn_name + ) A_test = 
np.random.random(A.shape).astype(np.float32) A_max_test = np.random.random(A_max.shape).astype(np.float32) @@ -2228,18 +2282,20 @@ def _(): A_max_ref = np.max(A_test).reshape((1, )) A_min_ref = np.min(A_test).reshape((1, )) - correctness_check_values[fn_name] = { - "pre": [A_test, A_max_test, A_min_test], - "post": [A_test, A_max_ref, A_min_ref] - } + if avx2_cpu(): + correctness_check_values[fn_name] = { + "pre": [A_test, A_max_test, A_min_test], + "post": [A_test, A_max_ref, A_min_ref] + } # build the HAT package output_dir = pathlib.Path(TEST_PACKAGE_DIR) / package_name with verifiers.VerifyPackage(self, package_name, output_dir) as v: package.build( package_name, - format=TEST_FORMAT | Package.Format.MLIR_VERBOSE, + format=TEST_FORMAT if avx2_cpu() else TEST_FORMAT_XCOMPILE, mode=Package.Mode.RELEASE, + platform=get_avx_platform(), output_dir=output_dir ) for fn_name in func_names: @@ -2250,6 +2306,22 @@ def _(): after=correctness_check_values[fn_name]["post"], ) + max_checker = v.file_checker(f"{package_name}_llvm.mlir") + max_checker.check_label(f'llvm.func @{function.name}') + max_checker.check_count( + '%{{[0-9]+}} = "accintr.x86.avx.max.ps.256"(%{{[0-9]+}}, %{{[0-9]+}}) : (vector<8xf32>, vector<8xf32>) -> vector<8xf32>', + 4 + ) + max_checker.run() + + min_checker = v.file_checker(f"{package_name}_llvm.mlir") + max_checker.check_label(f'llvm.func @{function.name}') + min_checker.check_count( + '%{{[0-9]+}} = "accintr.x86.avx.min.ps.256"(%{{[0-9]+}}, %{{[0-9]+}}) : (vector<8xf32>, vector<8xf32>) -> vector<8xf32>', + 4 + ) + min_checker.run() + def test_intrinsics_float(self) -> None: from accera import ( abs, @@ -2361,7 +2433,9 @@ def _verify_schedule(self, schedule, args: Tuple[Array], package_name, correctne # build the HAT package with verifiers.VerifyPackage(self, package_name, output_dir) as v: - package.build(package_name, format=TEST_FORMAT, mode=_get_test_mode(correctness_check_values), output_dir=output_dir) + package.build( + package_name, 
format=TEST_FORMAT, mode=_get_test_mode(correctness_check_values), output_dir=output_dir + ) if correctness_check_values: v.check_correctness( function.name, @@ -4148,7 +4222,9 @@ def _verify_plan(self, plan, args: Tuple[Array], package_name, correctness_check # build the HAT package with verifiers.VerifyPackage(self, package_name, output_dir) as v: - package.build(package_name, format=TEST_FORMAT, mode=_get_test_mode(correctness_check_values), output_dir=output_dir) + package.build( + package_name, format=TEST_FORMAT, mode=_get_test_mode(correctness_check_values), output_dir=output_dir + ) if correctness_check_values: v.check_correctness( function.name, @@ -4351,13 +4427,21 @@ def _(): class DSLTest_07PlansVectorizationParallelization(unittest.TestCase): - def _verify_plan(self, plan, args: Tuple[int], package_name, correctness_check_values=None) -> None: + def _verify_plan(self, plan, args: Tuple[int], package_name, correctness_check_values=None, check_parallelization=False) -> None: package = Package() function = package.add(plan, args, base_name="vectorization_parallelization_test") output_dir = pathlib.Path(TEST_PACKAGE_DIR) / package_name with verifiers.VerifyPackage(self, package_name, output_dir) as v: - package.build(package_name, format=TEST_FORMAT, mode=_get_test_mode(correctness_check_values), output_dir=output_dir) + package.build( + package_name, format=TEST_FORMAT, mode=_get_test_mode(correctness_check_values), output_dir=output_dir + ) + + if check_parallelization: + checker = v.file_checker(f"{package_name}_llvm.mlir") + checker.check_label("omp.parallel") + checker.run() + if correctness_check_values: v.check_correctness( function.name, @@ -4622,6 +4706,7 @@ def _(): [A, B, C], f"test_parallelize_i_{policy}", correctness_check_values, + check_parallelization=True ) # parallelizing middle index @@ -4632,6 +4717,7 @@ def _(): [A, B, C], f"test_parallelize_ii_{policy}", correctness_check_values, + check_parallelization=True ) try: @@ -4643,6 +4729,7 
@@ def _(): [A, B, C], f"test_parallelize_i_ii_j_{policy}", correctness_check_values, + check_parallelization=True ) # partial collapsed inner indices @@ -4653,6 +4740,7 @@ def _(): [A, B, C], f"test_parallelize_ii_j_{policy}", correctness_check_values, + check_parallelization=True ) except: # BUGBUG: partial collapsed + dynamic is broken in mlir-translate since LLVM 14 @@ -4679,7 +4767,9 @@ def _verify_package(self, plan, args, package_name, correctness_check_values) -> output_dir = pathlib.Path(TEST_PACKAGE_DIR) / package_name with verifiers.VerifyPackage(self, package_name, output_dir) as v: - package.build(package_name, format=TEST_FORMAT, mode=_get_test_mode(correctness_check_values), output_dir=output_dir) + package.build( + package_name, format=TEST_FORMAT, mode=_get_test_mode(correctness_check_values), output_dir=output_dir + ) if correctness_check_values: v.check_correctness( function.name, @@ -5053,7 +5143,9 @@ def _(): # build the HAT package with verifiers.VerifyPackage(self, package_name, output_dir) as v: - package.build(package_name, format=TEST_FORMAT, mode=_get_test_mode(correctness_check_values), output_dir=output_dir) + package.build( + package_name, format=TEST_FORMAT, mode=_get_test_mode(correctness_check_values), output_dir=output_dir + ) if correctness_check_values: v.check_correctness( function.name, @@ -6464,6 +6556,7 @@ def test_autoplan(self) -> None: class DSLTest_12Profiling(unittest.TestCase): + def _verify_func( self, package, function, package_name, correctness_check_values, quiet=True, mode=TEST_MODE ) -> None: @@ -6471,7 +6564,9 @@ def _verify_func( # build the HAT package with verifiers.VerifyPackage(self, package_name, output_dir) as v: - package.build(package_name, format=TEST_FORMAT, mode=mode, output_dir=output_dir, _quiet=quiet, profile=True) + package.build( + package_name, format=TEST_FORMAT, mode=mode, output_dir=output_dir, _quiet=quiet, profile=True + ) if correctness_check_values: v.check_correctness( function.name, @@ 
-6540,7 +6635,7 @@ def _tile_logic(): EnterProfileRegion("pack_b_fn") pack_b_fn(B, B_temp, j, k) ExitProfileRegion("pack_b_fn") - + EnterProfileRegion("matmul_fn") matmul_fn(A, B, C, B_temp, i, j, k) ExitProfileRegion("matmul_fn") diff --git a/accera/python/accera/test/smoke_tests.py b/accera/python/accera/test/smoke_tests.py index 582ec90f..c2d5d067 100644 --- a/accera/python/accera/test/smoke_tests.py +++ b/accera/python/accera/test/smoke_tests.py @@ -8,11 +8,9 @@ import logging import os import pathlib -import platform import shutil import sys import unittest -from itertools import product from typing import Callable, List import numpy as np @@ -54,13 +52,13 @@ from accera._lang_python._lang import Dimension, _MemorySpace, _If, as_index from accera._lang_python._lang._gpu import Barrier from accera.samples import MatrixMultiplication -from accera.Targets import KNOWN_DEVICES +from accera.test.test_utils import expectedFailure, FailedReason from accera.test import verifiers -from accera.test.test_utils import FailedReason, expectedFailure +from accera.test.test_utils import FailedReason, expectedFailure, avx2_cpu, avx512_cpu, get_avx_platform from accera import ( - AUTO, AllocateFlags, Array, Constants, Nest, Package, Role, Scalar, ScalarType, Target, cast, MMAShape, create_dimensions, - create_parameters, fuse + AUTO, AllocateFlags, Array, Constants, Nest, Package, Role, Scalar, ScalarType, Target, cast, MMAShape, + create_dimensions, create_parameters, fuse ) from accera import min as accmin from accera import abs as accabs @@ -73,10 +71,84 @@ # TODO: Remove all @expectedFailure decorators as implementation converges with spec +def make_asm_sequence_filechecker(v: verifiers.VerifyPackage, file_name: str, interleave_loc_tags: bool = True): + + class FileCheckerWrapper: + + def __init__(self, file_checker): + self.file_checker = file_checker + + # Forward all other calls to the file_checker + def __getattr__(self, item): + return getattr(self.file_checker, item) 
+ + def scan_to_next_loc(self): + self.file_checker.check(".loc") + + def check_label(self, check_str: str, check_loc: bool = interleave_loc_tags): + self.file_checker.check_label(check_str) + if check_loc: + self.file_checker.check_next(".loc") + + def check_next(self, check_str: str, check_loc: bool = interleave_loc_tags): + self.file_checker.check_next(check_str) + if check_loc: + self.file_checker.check_next(".loc") + + def check(self, check_str: str, check_loc: bool = interleave_loc_tags): + self.file_checker.check(check_str) + if check_loc: + self.file_checker.check_next(".loc") + + def get_simple_register_regex(self): + # Matches %rbp, %rsi, %r8, %r14, etc. + # Doesn't match xmm, ymm, zmm registers + return r"%([a-z]{3}|r[0-9]+)" + + def get_vector_register_regex(self, required_prefix: str = None): + # Matches xmm, ymm, zmm registers + suffix = r"mm[0-9]+" + if required_prefix is not None: + return r"%" + required_prefix + suffix + return r"%(x|y|z)" + suffix + + def get_register_regex(self): + # Matches any register name prefixed with % + simple_register_regex = self.get_simple_register_regex() + vector_register_regex = self.get_vector_register_regex() + return r"(%s|%s)" % (simple_register_regex, vector_register_regex) + + def get_memory_position_regex(self): + optional_offset = r"(-?[0-9]+)?" 
+ # Examples: + # - any register name in isolation: %rbp + register_regex = self.get_register_regex() + + # - any register name in parentheses without an offset: (%rbp) + # - any register offset by a constant: -64(%rbp), 32(%rsi) + no_offset_register_regex = r"\(%s\)" % register_regex + possibly_offset_register = optional_offset + no_offset_register_regex + + # - or offset by a scaled value in another register: (%rbp,%rsi,8) + # - or offset by a constant and offset by a scaled value in another register: -1028(%rbp,%rsi,8) + scaled_offset_regex = r"\(%s,%s,[0-9]+\)" % (register_regex, register_regex) + possibly_constant_offset_scaled_offset_regex = optional_offset + scaled_offset_regex + + # Matches any of the above + return r"(%s|%s|%s)" % ( + register_regex, possibly_offset_register, possibly_constant_offset_scaled_offset_regex + ) + + return FileCheckerWrapper(v.file_checker(file_name)) + + class SmokeTest(unittest.TestCase): PACKAGE_FORMAT = Package.Format.MLIR_DYNAMIC if DEV_MODE else Package.Format.HAT_DYNAMIC PACKAGE_MODE = Package.Mode.RELEASE + def get_xcompile_package_format(self, matcher_fn: Callable[[], bool] = avx2_cpu) -> Package.Format: + return self.PACKAGE_FORMAT if matcher_fn() else Package.Format.MLIR_STATIC + def test_full_fusion_trivial(self) -> None: A = Array(role=Role.INPUT, shape=(16, 16)) B = Array(role=Role.INPUT, shape=(16, 16)) @@ -626,7 +698,7 @@ def _(): v.check_correctness(function.name, before=(In_test, Out_test), after=(In_test, Out_ref)) def _test_fast_exp_mlas(self, func_level_precision: bool): - from accera import fast_exp, fast_exp_mlas + from accera import fast_exp_mlas M = 64 N = 64 @@ -665,33 +737,31 @@ def _(): with verifiers.VerifyPackage(self, package_name, TEST_PACKAGE_DIR) as v: package.build( package_name, - format=self.PACKAGE_FORMAT | Package.Format.MLIR_VERBOSE, + format=self.get_xcompile_package_format(), mode=self.PACKAGE_MODE, output_dir=TEST_PACKAGE_DIR, + platform=get_avx_platform(), _opts=pkg_opt, _quiet=False 
) + checker = v.file_checker(f"{package_name}_llvm.mlir") + checker.check( + '%{{[0-9]+}} = "accintr.x86.avx.max.ps.256"(%{{[0-9]+}}, %{{[0-9]+}}) : (vector<8xf32>, vector<8xf32>) -> vector<8xf32>' + ) + checker.check( + '%{{[0-9]+}} = "llvm.intr.fmuladd"(%{{[0-9]+}}, %{{[0-9]+}}, %{{[0-9]+}}) : (vector<8xf32>, vector<8xf32>, vector<8xf32>) -> vector<8xf32>' + ) + checker.run() - v.check_correctness(function.name, before=(In_test, Out_test), after=(In_test, Out_ref)) + if avx2_cpu(): + v.check_correctness(function.name, before=(In_test, Out_test), after=(In_test, Out_ref)) - @expectedFailure( - FailedReason.INVALID, "avx2 instructions not supported on MacOS arm64", sys.platform == "darwin" - and platform.machine() == "arm64" - ) def test_fast_exp_mlas_w_func_level_precision(self): self._test_fast_exp_mlas(True) - @expectedFailure( - FailedReason.INVALID, "avx2 instructions not supported on MacOS arm64", sys.platform == "darwin" - and platform.machine() == "arm64" - ) def test_fast_exp_mlas_w_pkg_level_precision(self): self._test_fast_exp_mlas(False) - @expectedFailure( - FailedReason.INVALID, "avx2 instructions not supported on MacOS arm64", sys.platform == "darwin" - and platform.machine() == "arm64" - ) def test_fast_exp_mlas_with_3_vectors(self): from accera import fast_exp_mlas @@ -731,20 +801,30 @@ def _(): with verifiers.VerifyPackage(self, package_name, TEST_PACKAGE_DIR) as v: package.build( package_name, - format=self.PACKAGE_FORMAT | Package.Format.MLIR_VERBOSE, + format=self.get_xcompile_package_format(), mode=self.PACKAGE_MODE, output_dir=TEST_PACKAGE_DIR, + platform=get_avx_platform(), _opts=Package._Options.HIGH_PRECISION_FLOATING_POINT_OPS, _quiet=False ) - v.check_correctness(function.name, before=(In_test, Max_test, Out_test), after=(In_test, Max_test, Out_ref)) + checker = v.file_checker(f"{package_name}_llvm.mlir") + checker.check_count( + '%{{[0-9]+}} = "accintr.x86.avx.max.ps.256"(%{{[0-9]+}}, %{{[0-9]+}}) : (vector<8xf32>, vector<8xf32>) -> 
vector<8xf32>', + 3 + ) + checker.check_count( + '%{{[0-9]+}} = "llvm.intr.fmuladd"(%{{[0-9]+}}, %{{[0-9]+}}, %{{[0-9]+}}) : (vector<8xf32>, vector<8xf32>, vector<8xf32>) -> vector<8xf32>', + 3 + ) + checker.run() + if avx2_cpu(): + v.check_correctness( + function.name, before=(In_test, Max_test, Out_test), after=(In_test, Max_test, Out_ref) + ) - @expectedFailure( - FailedReason.INVALID, "avx2 instructions not supported on MacOS arm64", sys.platform == "darwin" - and platform.machine() == "arm64" - ) def test_fast_exp_sum(self): from accera import fast_exp_mlas @@ -1051,14 +1131,27 @@ def call_fn_masked_vec_fast_exp_sum(): with verifiers.VerifyPackage(self, package_name, TEST_PACKAGE_DIR) as v: package.build( package_name, - format=self.PACKAGE_FORMAT | Package.Format.MLIR_VERBOSE, + format=self.get_xcompile_package_format(), mode=self.PACKAGE_MODE, output_dir=TEST_PACKAGE_DIR, + platform=get_avx_platform(), _opts=Package._Options.HIGH_PRECISION_FLOATING_POINT_OPS, _quiet=False ) - v.check_correctness(function.name, before=(In_test, Max_test, Out_test), after=(In_test, Max_test, Out_ref)) + checker = v.file_checker(f"{package_name}_llvm.mlir") + checker.check( + '%{{[0-9]+}} = "accintr.x86.avx.max.ps.256"(%{{[0-9]+}}, %{{[0-9]+}}) : (vector<8xf32>, vector<8xf32>) -> vector<8xf32>' + ) + checker.check( + '%{{[0-9]+}} = "llvm.intr.fmuladd"(%{{[0-9]+}}, %{{[0-9]+}}, %{{[0-9]+}}) : (vector<8xf32>, vector<8xf32>, vector<8xf32>) -> vector<8xf32>' + ) + checker.run() + + if avx2_cpu(): + v.check_correctness( + function.name, before=(In_test, Max_test, Out_test), after=(In_test, Max_test, Out_ref) + ) def test_emittime_cache_mlas_matmul(self) -> None: from accera.samples.OfflineCacheMatrixMultiplication import \ @@ -1085,7 +1178,8 @@ def test_emittime_cache_mlas_matmul(self) -> None: package.build(package_name, output_dir=output_dir, mode=self.PACKAGE_MODE, format=self.PACKAGE_FORMAT) # check build and correctness - v.check_correctness(function.name, before=(A_test, 
B_test, C_test), after=(A_test, B_test, C_ref)) + if avx2_cpu(): + v.check_correctness(function.name, before=(A_test, B_test, C_test), after=(A_test, B_test, C_ref)) def test_runtime_init_cache_mlas_matmul(self) -> None: from accera.samples.OfflineCacheMatrixMultiplication import \ @@ -2443,7 +2537,10 @@ def _(): # TODO : Bug - this could occur during buffer packing at an app initialization-time and/or runtime, # but the underlying issue is probably related to the next bug with static split sizes - @expectedFailure(FailedReason.BUG, "Multiple _split_dimension's of a dynamically sized dimension with a dynamic size is not working.") + @expectedFailure( + FailedReason.BUG, + "Multiple _split_dimension's of a dynamically sized dimension with a dynamic size is not working." + ) def test_multiple_dynamic_split_dim_dynamic_size(self) -> None: test_name = "test_multiple_dynamic_split_dim_dynamic_size" @@ -2490,7 +2587,10 @@ def _(): ) # TODO : Bug - this could occur during buffer packing in an app compile-time, initialization-time or runtime - @expectedFailure(FailedReason.BUG, "Multiple _split_dimension's of a dynamically sized buffer into statically sized dimensions is not working.") + @expectedFailure( + FailedReason.BUG, + "Multiple _split_dimension's of a dynamically sized buffer into statically sized dimensions is not working." 
+ ) def test_multiple_dynamic_split_dim_static_size(self) -> None: test_name = "test_multiple_dynamic_split_dim_static_size" @@ -2508,7 +2608,7 @@ def test_multiple_dynamic_split_dim_static_size(self) -> None: Input = Array(role=Role.INPUT, element_type=ScalarType.float32, shape=(MN, )) Output = Array(role=Role.INPUT_OUTPUT, element_type=ScalarType.float32, shape=(M, N)) - nest = Nest(shape=(M, N_0, N_1, N_2)) # (M, 4, 5, 6) + nest = Nest(shape=(M, N_0, N_1, N_2)) # (M, 4, 5, 6) i, j_0, j_1, j_2 = nest.get_indices() @nest.iteration_logic @@ -2591,7 +2691,7 @@ def _(): after=(test_MN, test_M, test_N, test_input, test_output_ref) ) - # This test uses all static sizes to make sure the fix for dynamic size (test_dynamic_split_dim_static_size) + # This test uses all static sizes to make sure the fix for dynamic size (test_dynamic_split_dim_static_size) # won't regress the static size case. def test_dynamic_split_dim_all_static(self) -> None: test_name = "test_dynamic_split_dim_all_static" @@ -4612,10 +4712,6 @@ def _(): after=correctness_check_values["post"], ) - @expectedFailure( - FailedReason.INVALID, "generated x86_64 lib not readable by MacOS arm64 build tools", sys.platform == "darwin" - and platform.machine() == "arm64" - ) def test_int16_matmul_vpmaddwd_16_element_avx512(self): test_name = "test_int16_matmul_vpmaddwd_16_element_avx512" M = 240 @@ -4652,7 +4748,7 @@ def _(): schedule.reorder(i, j, k, ii, jj, kk, kkk, iii, jjj, jjjj, kkkk) # The Intel 8351N is a known Xeon Platinum with AVX-512 support - target = KNOWN_DEVICES[Target.Category.CPU]["Intel 8351N"] + target = Target("Intel 8351N") plan = schedule.create_plan(target) plan.cache(A, index=ii, element_type=ScalarType.int16, vectorize=False) plan.cache(B, index=jjjj, trigger_index=jj, layout=Array.Layout.LAST_MAJOR, vectorize=False) @@ -4660,14 +4756,101 @@ def _(): plan.vectorize(jjjj) package = Package() - function = package.add(plan, args=(A, B, C), base_name=test_name) + package.add(plan, args=(A, B, 
C), base_name=test_name) output_dir = pathlib.Path(TEST_PACKAGE_DIR) / test_name # build the HAT package with verifiers.VerifyPackage(self, test_name, output_dir) as v: - package.build(test_name, format=Package.Format.DEFAULT, mode=Package.Mode.RELEASE, output_dir=output_dir) - # Don't check correctness as we've set a target that we may not be running the tests on + package.build( + test_name, + format=self.get_xcompile_package_format(avx512_cpu), + mode=Package.Mode.RELEASE, + output_dir=output_dir, + platform=get_avx_platform() + ) + checker = v.file_checker(f"{test_name}_llvm.mlir") + checker.check( + '%{{[0-9]+}} = "accintr.x86.avx512.pmaddw.d.512"(%{{[0-9]+}}, %{{[0-9]+}}) : (vector<32xi16>, vector<32xi16>) -> vector<16xi32>' + ) + checker.run() + + asm_checker = make_asm_sequence_filechecker(v, f"{test_name}.s", interleave_loc_tags=True) + + memory_pos_regex = asm_checker.get_memory_position_regex() + zmm_register_regex = asm_checker.get_vector_register_regex("z") + + # This is computing a 6x32 region using vpmaddwd and vpaddd instructions on zmm registers + # one vpmaddwd followed by vpaddd computes a single 1x16 output for 2 elements of k + store_reg = [[f"store_reg_{row_offset}_{col_offset}" for col_offset in range(0, 32, 16)] + for row_offset in range(0, 6, 1)] + + # 6 different broadcast values from A + broadcast_reg = [f"broadcast_reg_{row_offset}" for row_offset in range(0, 6, 1)] + + # Load 4 different vectors of B data, one for each of (k,j) = (0,0), (0,16), (2,0), (2,16) + load_reg = [[f"load_reg_{k_offset}_{col_offset}" for col_offset in range(0, 32, 16)] + for k_offset in range(0, 4, 2)] + + asm_checker.check_label( + "{{.*}}LBB0_14:", check_loc=False + ) # check_loc=False because there may be comments on the following lines before the next instruction + # Don't use python f-strings because FileCheck uses {{}} as delimiters that would collide with {} delimiters in the f-str. 
This could be worked around, but %-strs make it simpler + + # Broadcast A[0,0:2] + asm_checker.check( + "vpbroadcastd {{%s}}, [[%s:%s]]" % (memory_pos_regex, broadcast_reg[0], zmm_register_regex) + ) # vpbroadcastd -1028(%rbp,%rsi,8), %zmm5 + + # Load 4 different vectors of B data, one for each of (k,j) = (0,0), (0,16), (2,0), (2,16) + asm_checker.check_next( + "vmovdqu64 {{%s}}, [[%s:%s]]" % (memory_pos_regex, load_reg[0][0], zmm_register_regex), check_loc=False + ) # vmovdqu64 -192(%rbx), %zmm13 + asm_checker.check_next( + "vmovdqu64 {{%s}}, [[%s:%s]]" % (memory_pos_regex, load_reg[0][1], zmm_register_regex), check_loc=False + ) # vmovdqu64 -128(%rbx), %zmm14 + asm_checker.check_next( + "vmovdqu64 {{%s}}, [[%s:%s]]" % (memory_pos_regex, load_reg[1][0], zmm_register_regex), check_loc=False + ) # vmovdqu64 -64(%rbx), %zmm15 + asm_checker.check_next( + "vmovdqu64 {{%s}}, [[%s:%s]]" % (memory_pos_regex, load_reg[1][1], zmm_register_regex), check_loc=False + ) # vmovdqu64 (%rbx), %zmm8 + asm_checker.check_next( + ".loc", check_loc=False + ) # We could have set check_loc=True on the previous line, but we write it this way for better clarity + + def match_vpmaddwd_vpaddd_block(broadcast_reg, load0, load1, store0, store1): + asm_checker.check_next( + "vpmaddwd [[%s]], [[%s]], [[temp_reg0:%s]]" % (load0, broadcast_reg, zmm_register_regex) + ) + asm_checker.check_next( + "vpaddd [[temp_reg0]], {{%s}}, [[%s:%s]]" % (zmm_register_regex, store0, zmm_register_regex) + ) # We don't really care what the original source register was, so don't capture it + asm_checker.check_next( + "vpmaddwd [[%s]], [[%s]], [[temp_reg1:%s]]" % (load1, broadcast_reg, zmm_register_regex) + ) + asm_checker.check_next( + "vpaddd [[temp_reg1]], {{%s}}, [[%s:%s]]" % (zmm_register_regex, store1, zmm_register_regex), + check_loc=False + ) + + for k_idx in range(2): + for row_idx in range(6): + if not (row_idx == 0 and k_idx == 0): + # First broadcast occurs before the loads, so we need to check it 
separately + asm_checker.check( + "vpbroadcastd {{%s}}, [[%s:%s]]" % + (memory_pos_regex, broadcast_reg[row_idx], zmm_register_regex) + ) # vpbroadcastd -1028(%rbp,%rsi,8), %zmm5 + match_vpmaddwd_vpaddd_block( + broadcast_reg[row_idx], load_reg[k_idx][0], load_reg[k_idx][1], store_reg[row_idx][0], + store_reg[row_idx][1] + ) + + # Jump back to the top of the compute loop + asm_checker.check_label("jne {{.*}}LBB0_14", check_loc=False) + + asm_checker.run() def test_int16_matmul_vpmaddwd_16_element_host(self): test_name = "test_int16_matmul_vpmaddwd_16_element_host" diff --git a/accera/python/accera/test/test_utils.py b/accera/python/accera/test/test_utils.py index 8a3b274b..4e22bea8 100644 --- a/accera/python/accera/test/test_utils.py +++ b/accera/python/accera/test/test_utils.py @@ -5,11 +5,13 @@ from enum import Enum from typing import Callable +import cpuinfo import unittest +import sys import numpy as np -from accera import ScalarType +from accera import Package, ScalarType class FailedReason(Enum): @@ -24,6 +26,7 @@ def expectedFailure(reason: FailedReason, msg: str, condition: bool = True) -> C "Extends the unittest.expectedFailure decorator to print failure details and takes an optional condition" def _decorator(func): + @unittest.expectedFailure def _wrapper(x): print(f"\n{reason.value}: {msg}") @@ -38,8 +41,23 @@ def _wrapper(x): return _decorator +class SkipReason(Enum): + BUG = "Bug" + TARGET_MISMATCH = "Test doesn't execute for target" + + +def skipTest(reason: SkipReason, msg: str, condition: bool = True) -> Callable: + "Wraps the unittest.skipIf decorator to skip the test with a printed reason and takes an optional condition" + + def _decorator(func): + + return unittest.skipIf(condition, f"\n{reason.value}: {msg}")(func) + + return _decorator + + def get_type_str(datatype: ScalarType): - return datatype.name + return datatype.name def accera_to_np_type(datatype: ScalarType):
return None + + +def avx2_cpu(): + cpu_info = cpuinfo.get_cpu_info() + return "flags" in cpu_info and "avx2" in cpu_info["flags"] + + +def avx512_cpu(): + cpu_info = cpuinfo.get_cpu_info() + return "flags" in cpu_info and "avx512" in cpu_info["flags"] + + +def get_avx_platform(): + return Package.Platform.HOST if sys.platform != "darwin" else Package.Platform.WINDOWS \ No newline at end of file diff --git a/accera/python/accera/test/unit_tests.py b/accera/python/accera/test/unit_tests.py index 14f3dc7b..122b32dc 100644 --- a/accera/python/accera/test/unit_tests.py +++ b/accera/python/accera/test/unit_tests.py @@ -172,7 +172,7 @@ def test_module(self) -> None: self.assertIsNotNone(module) with ModuleScope(module): module.Print() - module.Verify() + # module.Verify() # BUGBUG: ModuleOp::verify is private as of LLVM 15 header_filename = f"test_module_{time.time()}.hat" module.WriteHeader(header_filename) diff --git a/accera/python/lib/src/PackagingTypes.cpp b/accera/python/lib/src/PackagingTypes.cpp index 9f98cee8..29ae4b1d 100644 --- a/accera/python/lib/src/PackagingTypes.cpp +++ b/accera/python/lib/src/PackagingTypes.cpp @@ -116,7 +116,6 @@ ARM: fp16, neon, vfp3, d16, vfp4, hwdiv-arm, hwdiv "_flags"_a = value::AllocateFlags::None) .def("Print", &value::MLIRContext::print, "Prints the module") .def("Save", &value::MLIRContext::save, "filename"_a) - .def("Verify", &value::MLIRContext::verify) .def("WriteHeader", &value::MLIRContext::writeHeader, "filename"_a = std::nullopt) .def("SetMetadata", &value::MLIRContext::setMetadata) .def("GetFullMetadata", &value::MLIRContext::getFullMetadata) diff --git a/accera/python/llvm/setup.cfg b/accera/python/llvm/setup.cfg index af4bd8bf..6ec56db4 100644 --- a/accera/python/llvm/setup.cfg +++ b/accera/python/llvm/setup.cfg @@ -6,7 +6,10 @@ name = accera-llvm # Accera micro versions are at least 2 digits: # LLVM 15.0.7 -> 15.0.700 # LLVM 15.0.7-1 -> be 15.0.701 -version = 14.0.602 +# Note that the micro versions must start with a 
non-zero digit, else setup tools will +# normalize the version by removing the leading zeros +# Note: keep version in sync with Accera/setup.cfg +version = 15.0.101 author = Microsoft Research AI Compilers Team author_email = mlodev@microsoft.com summary = Accera LLVM Binaries diff --git a/accera/transforms/CMakeLists.txt b/accera/transforms/CMakeLists.txt index eca1aab3..e9036114 100644 --- a/accera/transforms/CMakeLists.txt +++ b/accera/transforms/CMakeLists.txt @@ -157,9 +157,11 @@ target_link_libraries( MLIRROCDLIR MLIRROCDLToLLVMIRTranslation MLIRStandardToLLVM - MLIRSCFToStandard + MLIRSCFToControlFlow + MLIRControlFlowToLLVM MLIRAffineToStandard MLIRAffineTransforms + MLIRAffineUtils MLIRLinalgToLLVM MLIRLinalgTransforms MLIRTargetLLVMIRExport diff --git a/accera/transforms/include/util/RangeValueUtilities.h b/accera/transforms/include/util/RangeValueUtilities.h index cc06fd1a..1c5fd8d7 100644 --- a/accera/transforms/include/util/RangeValueUtilities.h +++ b/accera/transforms/include/util/RangeValueUtilities.h @@ -79,6 +79,9 @@ class RangeValueAnalysis RangeValue resolveRangeValue(mlir::AffineApplyOp op); RangeValue resolveRangeValue(mlir::scf::ForOp op); RangeValue resolveRangeValue(mlir::Operation* op); + + void addSCFParallelOp(mlir::scf::ParallelOp op); + void addAffineParallelOp(mlir::AffineParallelOp op); }; } // namespace accera::ir::util diff --git a/accera/transforms/include/value/ValueToLLVMLoweringPass.h b/accera/transforms/include/value/ValueToLLVMLoweringPass.h index 62a73dc4..f6f8417f 100644 --- a/accera/transforms/include/value/ValueToLLVMLoweringPass.h +++ b/accera/transforms/include/value/ValueToLLVMLoweringPass.h @@ -34,9 +34,9 @@ class RewritePatternSet; namespace accera::transforms::value { -void populateValueToLLVMNonMemPatterns(mlir::LLVMTypeConverter& typeConverter, mlir::RewritePatternSet& patterns); +void populateValueToLLVMNonMemPatterns(mlir::LLVMTypeConverter& typeConverter, mlir::RewritePatternSet& patterns, 
accera::value::TargetDevice deviceInfo); void populateGlobalValueToLLVMNonMemPatterns(mlir::LLVMTypeConverter& typeConverter, mlir::RewritePatternSet& patterns); -void populateLocalValueToLLVMNonMemPatterns(mlir::LLVMTypeConverter& typeConverter, mlir::RewritePatternSet& patterns); +void populateLocalValueToLLVMNonMemPatterns(mlir::LLVMTypeConverter& typeConverter, mlir::RewritePatternSet& patterns, accera::value::TargetDevice deviceInfo); void populateValueToLLVMMemPatterns(mlir::LLVMTypeConverter& typeConverter, mlir::RewritePatternSet& patterns); void populateReshapeOpToLLVMMemPatterns(mlir::LLVMTypeConverter& typeConverter, mlir::RewritePatternSet& patterns); @@ -49,5 +49,6 @@ std::unique_ptr> createValueToLLVMPass(bool unsigned indexBitwidth, bool useAlignedAlloc, llvm::DataLayout dataLayout, + accera::value::TargetDevice deviceInfo = {}, const IntraPassSnapshotOptions& options = {}); } // namespace accera::transforms::value diff --git a/accera/transforms/src/AcceraPasses.cpp b/accera/transforms/src/AcceraPasses.cpp index e075383d..55785f5c 100644 --- a/accera/transforms/src/AcceraPasses.cpp +++ b/accera/transforms/src/AcceraPasses.cpp @@ -133,6 +133,20 @@ void simplifyAndLowerAffine(PassManagerAdaptor& pmAdaptor) pmAdaptor.addPass(createLowerAffinePass()); } +bool addGPUPasses(PassManagerAdaptor& pmAdaptor, const accera::value::ExecutionRuntime execRuntime, const AcceraPassPipelineOptions& options) +{ + auto gpuPass = createAcceraToGPUPass(execRuntime); + if (gpuPass) + { + pmAdaptor.addPass(createGPUSimplificationPass()); + pmAdaptor.addPass(value::createBarrierOptPass(options.writeBarrierGraph.getValue(), options.barrierGraphFilename.getValue())); + pmAdaptor.addPass(std::move(gpuPass)); + return true; + } + + return false; +} + void addAcceraToLLVMPassPipeline(OpPassManager& pm, const AcceraPassPipelineOptions& options) { ir::InitializeAccera(); @@ -173,7 +187,6 @@ void addAcceraToLLVMPassPipeline(OpPassManager& pm, const AcceraPassPipelineOpti 
funcOpPM.addPass(createLoopInvariantCodeMotionPass()); funcOpPM.addPass(createCSEPass()); - pmAdaptor.addPass(createConvertSCFToOpenMPPass()); pmAdaptor.addPass(value::createValueToStdPass(options.enableProfile)); pmAdaptor.addPass(value::createRangeValueOptimizePass()); pmAdaptor.addPass(createCanonicalizerPass()); @@ -183,24 +196,28 @@ void addAcceraToLLVMPassPipeline(OpPassManager& pm, const AcceraPassPipelineOpti { // The spirv lowering doesn't generate affine dialect ops, and the SPIRV dialect doesn't play nicely with them, so lower the affine ops before running the GPU lowering simplifyAndLowerAffine(pmAdaptor); + pmAdaptor.addPass(createGpuKernelOutliningPass()); + addGPUPasses(pmAdaptor, execRuntime, options); } - - pmAdaptor.addPass(createGpuKernelOutliningPass()); - auto gpuPass = createAcceraToGPUPass(execRuntime); - if (gpuPass) + else { - pmAdaptor.addPass(createGPUSimplificationPass()); - pmAdaptor.addPass(value::createBarrierOptPass(options.writeBarrierGraph.getValue(), options.barrierGraphFilename.getValue())); - pmAdaptor.addPass(std::move(gpuPass)); - } + pmAdaptor.addPass(createGpuKernelOutliningPass()); + const auto isGPU = addGPUPasses(pmAdaptor, execRuntime, options); - if (execRuntime != accera::value::ExecutionRuntime::VULKAN) - { // lowering to runtimes other than SPIRV generates affine dialect ops so optimize and lower those now simplifyAndLowerAffine(pmAdaptor); - if (execRuntime == accera::value::ExecutionRuntime::ROCM) + + if (isGPU) + { + if (execRuntime == accera::value::ExecutionRuntime::ROCM) + { + pmAdaptor.addPass(createGPUToROCDLPass()); + } + } + else { - pmAdaptor.addPass(createGPUToROCDLPass()); + // Convert to OMP when in non-GPU scenarios + pmAdaptor.addPass(createConvertSCFToOpenMPPass()); } } @@ -233,7 +250,7 @@ void addAcceraToLLVMPassPipeline(OpPassManager& pm, const AcceraPassPipelineOpti funcOpPM.addPass(createConvertVectorToSCFPass( VectorTransferToSCFOptions{} /*.setLowerPermutationMaps(true) 
.setLowerTensors(true).setUnroll(true) */)); - pmAdaptor.addPass(createLowerToCFGPass()); + pmAdaptor.addPass(createConvertSCFToCFPass()); if (execRuntime != accera::value::ExecutionRuntime::VULKAN) { @@ -255,6 +272,7 @@ void addAcceraToLLVMPassPipeline(OpPassManager& pm, const AcceraPassPipelineOpti /* indexBitwidth = */ kDeriveIndexBitwidthFromDataLayout, /* useAlignedAlloc = */ true, /* dataLayout = */ llvm::DataLayout(accera::value::GetTargetDevice(options.target).dataLayout), + /* deviceInfo = */ accera::value::GetTargetDevice(options.target), { options.dumpIntraPassIR.getValue(), options.basename + "ValueToLLVM_Subpasses" })); pmAdaptor.addPass(createCanonicalizerPass()); pmAdaptor.addPass(LLVM::createLegalizeForExportPass()); diff --git a/accera/transforms/src/affine/AffineLoopNormalize.cpp b/accera/transforms/src/affine/AffineLoopNormalize.cpp index d29a9539..2b905d82 100644 --- a/accera/transforms/src/affine/AffineLoopNormalize.cpp +++ b/accera/transforms/src/affine/AffineLoopNormalize.cpp @@ -81,7 +81,7 @@ struct AcceraAffineLoopNormalizePass : public accera::transforms::AcceraAffineLo // See \mlir\lib\Dialect\Affine\Transforms\AffineLoopNormalize.cpp op->walk([](mlir::AffineForOp affineForOp) { workaroundModifyAffineForOp(affineForOp); - mlir::normalizeAffineFor(affineForOp); + (void)mlir::normalizeAffineFor(affineForOp); }); } }; diff --git a/accera/transforms/src/affine/AffineSimplifications.cpp b/accera/transforms/src/affine/AffineSimplifications.cpp index 64d249bd..9da3dac4 100644 --- a/accera/transforms/src/affine/AffineSimplifications.cpp +++ b/accera/transforms/src/affine/AffineSimplifications.cpp @@ -16,6 +16,7 @@ #include #include #include +#include #include #include #include diff --git a/accera/transforms/src/exec/ExecutionPlanToAffineLoweringPass.cpp b/accera/transforms/src/exec/ExecutionPlanToAffineLoweringPass.cpp index dd125b54..5cf3ac80 100644 --- a/accera/transforms/src/exec/ExecutionPlanToAffineLoweringPass.cpp +++ 
b/accera/transforms/src/exec/ExecutionPlanToAffineLoweringPass.cpp @@ -705,14 +705,12 @@ std::tuple CreateCacheLoopnestHelper( { // Determine how much of the nest can be vectorized and set the vectorization info on those loops budget /= innermostLoopRange.NumIterations(); - int numVectorizedLoops = 1; for (size_t loopCounter = 1; loopCounter < loopnestInfo.fullySplitRanges.size(); ++loopCounter) { size_t loopIdx = loopnestInfo.fullySplitRanges.size() - loopCounter - 1; // Vectorize loops from the innermost to the outermost as long as we still have vector registers to work with auto loopRange = loopnestInfo.fullySplitRanges[loopIdx]; auto loopUnrollFactor = std::min(budget, loopRange.NumIterations()); InPlaceUnrollInfo inPlaceUnrollInfo{ loopUnrollFactor }; - numVectorizedLoops++; SetInPlaceUnrollInfo(cacheNestSchedule, cacheNestScheduleOrder[loopIdx], inPlaceUnrollInfo); budget /= loopUnrollFactor; if (budget <= 1) // if there is only 1 in-place op unroll left in the budget then we're done vectorizing diff --git a/accera/transforms/src/gpu/AcceraToGPUPass.cpp b/accera/transforms/src/gpu/AcceraToGPUPass.cpp index 4a9d36f5..3f91bb09 100644 --- a/accera/transforms/src/gpu/AcceraToGPUPass.cpp +++ b/accera/transforms/src/gpu/AcceraToGPUPass.cpp @@ -457,7 +457,7 @@ struct GPUIndexIntrinsicOpLowering : public ConvertOpToLLVMPattern return failure(); } - newOp = rewriter.create(loc, newOp, rewriter.getIndexType()); + newOp = rewriter.create(loc, rewriter.getIndexType(), newOp); rewriter.replaceOp(op, { newOp }); return success(); diff --git a/accera/transforms/src/util/MathUtilities.cpp b/accera/transforms/src/util/MathUtilities.cpp index 56f4f898..e4fb2942 100644 --- a/accera/transforms/src/util/MathUtilities.cpp +++ b/accera/transforms/src/util/MathUtilities.cpp @@ -6,7 +6,7 @@ #include "util/MathUtilities.h" #include -#include +#include namespace accera::transforms { @@ -24,13 +24,13 @@ mlir::Value SaturateValue(mlir::PatternRewriter& rewriter, mlir::Value value, in if 
(auto vectorType = resultType.dyn_cast()) { - minConst = rewriter.create(loc, minConst, vectorType); - maxConst = rewriter.create(loc, maxConst, vectorType); + minConst = rewriter.create(loc, minConst, vectorType); + maxConst = rewriter.create(loc, maxConst, vectorType); } auto maxCmp = rewriter.create(loc, mlir::arith::CmpIPredicate::sgt, value, minConst); - auto temp = rewriter.create(loc, maxCmp, value, minConst); + auto temp = rewriter.create(loc, maxCmp, value, minConst); auto minCmp = rewriter.create(loc, mlir::arith::CmpIPredicate::slt, temp, maxConst); - auto result = rewriter.create(loc, minCmp, temp, maxConst); + auto result = rewriter.create(loc, minCmp, temp, maxConst); return result; } diff --git a/accera/transforms/src/util/RangeValueUtilities.cpp b/accera/transforms/src/util/RangeValueUtilities.cpp index d8b25b89..de565cff 100644 --- a/accera/transforms/src/util/RangeValueUtilities.cpp +++ b/accera/transforms/src/util/RangeValueUtilities.cpp @@ -68,6 +68,25 @@ RangeValue resolveGridDimRange(Operation* op, gpu::Dimension dimId) return RangeValue(); } +mlir::APInt toAPInt(int64_t val) +{ + return mlir::APInt(RangeValue::maxBitWidth, val, true); +} + +RangeValue resolveConstantForLoopRange(mlir::APInt lb, mlir::APInt ub, mlir::APInt step) +{ + auto range = ub - lb; + auto remainder = range.srem(step); + auto largestInductionVarValue = (remainder.sgt(0)) ? 
(ub - remainder) : (ub - step); + + return RangeValue(lb, largestInductionVarValue); +} + +RangeValue resolveConstantForLoopRange(int64_t lb, int64_t ub, int64_t step) +{ + return resolveConstantForLoopRange(toAPInt(lb), toAPInt(ub), toAPInt(step)); +} + } // namespace namespace accera::ir::util @@ -172,8 +191,86 @@ RangeValue RangeValueAnalysis::getRange(Value value) const return it->second; } +void RangeValueAnalysis::addSCFParallelOp(mlir::scf::ParallelOp op) +{ + mlir::scf::ParallelOpAdaptor adaptor{ op }; + auto ivs = op.getInductionVars(); + auto lbs = adaptor.getLowerBound(); + auto ubs = adaptor.getUpperBound(); + auto steps = adaptor.getStep(); + assert(ivs.size() == lbs.size()); + assert(ivs.size() == ubs.size()); + assert(ivs.size() == steps.size()); + unsigned numDims = ivs.size(); + + for (unsigned dimIdx = 0; dimIdx < numDims; ++dimIdx) + { + auto lb = lbs[dimIdx]; + auto ub = ubs[dimIdx]; + auto step = steps[dimIdx]; + auto lbConst = lb.getDefiningOp(); + auto ubConst = ub.getDefiningOp(); + auto stepConst = step.getDefiningOp(); + if (lbConst && ubConst && stepConst) + { + auto range = resolveConstantForLoopRange(lbConst.value(), ubConst.value(), stepConst.value()); + _rangeMap.insert({ ivs[dimIdx], range }); + } + else + { + _rangeMap.insert({ ivs[dimIdx], RangeValue() }); + } + } +} + +void RangeValueAnalysis::addAffineParallelOp(mlir::AffineParallelOp op) +{ + auto constantRangesOpt = op.getConstantRanges(); + if (!constantRangesOpt.hasValue()) + { + return; + } + + unsigned numDims = op.getNumDims(); + auto ivs = op.getIVs(); + auto constantRanges = constantRangesOpt.getValue(); + auto lbValueMap = op.getLowerBoundsValueMap(); + auto steps = op.getSteps(); + + assert(numDims == ivs.size()); + assert(numDims == constantRanges.size()); + assert(numDims == lbValueMap.getNumResults()); + assert(numDims == steps.size()); + + for (unsigned dimIdx = 0; dimIdx < numDims; ++dimIdx) + { + auto expr = lbValueMap.getResult(dimIdx); + if (expr.isa()) + { + 
auto lb = expr.cast().getValue(); + auto range = resolveConstantForLoopRange(lb, lb + constantRanges[dimIdx], steps[dimIdx]); + _rangeMap.insert({ ivs[dimIdx], range }); + } + else + { + _rangeMap.insert({ ivs[dimIdx], RangeValue() }); + } + } +} + RangeValue RangeValueAnalysis::addOperation(mlir::Operation* op) { + // Special case AffineParallelOp and scf::ParallelOp as they can have multiple results, + // however we care more about their possibly-multiple IVs which may have static ranges + if (isa(op) || isa(op)) + { + mlir::TypeSwitch(op) + .Case([&](scf::ParallelOp op) { addSCFParallelOp(op); }) + .Case([&](AffineParallelOp op) { addAffineParallelOp(op); }); + // Possibly no single range to return for the op itself + return RangeValue(); + } + if (op->getNumResults() > 1) { // Only operations with 0 or 1 results can have their ranges tracked successfully currently @@ -415,14 +512,11 @@ RangeValue RangeValueAnalysis::resolveRangeValue(AffineForOp op) auto ub = op.getConstantUpperBound(); auto step = op.getStep(); - auto range = ub - lb; - auto remainder = range % step; - auto largestInductionVarValue = (remainder > 0) ? (ub - remainder) : (ub - step); - - return RangeValue(lb, largestInductionVarValue); + return resolveConstantForLoopRange(lb, ub, step); } return RangeValue(); } + RangeValue RangeValueAnalysis::resolveRangeValue(scf::ForOp op) { assert(op.getNumInductionVars() == 1); @@ -442,14 +536,11 @@ RangeValue RangeValueAnalysis::resolveRangeValue(scf::ForOp op) auto ub = upperBound.range.getUpper(); auto step = stepSize.range.getLower(); - auto range = ub - lb; - auto remainder = range.srem(step); - auto largestInductionVarValue = (remainder.sgt(0)) ? 
(ub - remainder) : (ub - step); - - return RangeValue(lb, largestInductionVarValue); + return resolveConstantForLoopRange(lb, ub, step); } return RangeValue(); } + RangeValue RangeValueAnalysis::resolveRangeValue(mlir::Operation* op) { return mlir::TypeSwitch(op) diff --git a/accera/transforms/src/value/ValueSimplifyPass.cpp b/accera/transforms/src/value/ValueSimplifyPass.cpp index b67ea6a3..b7722df7 100644 --- a/accera/transforms/src/value/ValueSimplifyPass.cpp +++ b/accera/transforms/src/value/ValueSimplifyPass.cpp @@ -11,7 +11,6 @@ #include #include -#include #include #include @@ -125,7 +124,7 @@ struct ValueSliceSimplifyPattern : public OpRewritePattern auto indexShape = index.getType().cast().getShape(); if (indexShape.size() == 0 || indexShape.size() == 1) { - resolvedOffsets[dim] = rewriter.create(loc, rewriter.create(loc, index), indexType); + resolvedOffsets[dim] = rewriter.create(loc, indexType, rewriter.create(loc, index)); } else { @@ -344,7 +343,7 @@ LogicalResult CopyOpLowering::matchAndRewrite( if (outputMemRef.getElementType().isInteger(64)) // this should really be target dependent... { (void)rewriter.create(loc, - rewriter.create(loc, input, rewriter.getIntegerType(64)), + rewriter.create(loc, rewriter.getIntegerType(64), input), output, std::vector(outputMemRef.getRank(), zero)); } diff --git a/accera/transforms/src/value/ValueToLLVMLoweringPass.cpp b/accera/transforms/src/value/ValueToLLVMLoweringPass.cpp index b7db858c..70dff286 100644 --- a/accera/transforms/src/value/ValueToLLVMLoweringPass.cpp +++ b/accera/transforms/src/value/ValueToLLVMLoweringPass.cpp @@ -21,6 +21,7 @@ #include #include +#include #include #include #include @@ -28,7 +29,6 @@ #include #include #include -#include #include #include #include @@ -66,6 +66,15 @@ using namespace accera::transforms::value; namespace { +// This is ported from Linux code time.h +// #define CLOCK_REALTIME 0 // Identifier for system-wide realtime clock. 
+// #define CLOCK_MONOTONIC 1 // Monotonic system-wide clock. +enum class ClockID +{ + ACCERA_CLOCK_REALTIME = 0, + ACCERA_CLOCK_MONOTONIC = 1, +}; + // TODO: Refactor this class and find a better place for this helper class class LLVMTypeConverterDynMem : public mlir::LLVMTypeConverter { @@ -344,6 +353,13 @@ struct GetTimeOpLowering : public ValueLLVMOpConversionPattern { using ValueLLVMOpConversionPattern::ValueLLVMOpConversionPattern; + accera::value::TargetDevice deviceInfo; + + GetTimeOpLowering(LLVMTypeConverter& converter, mlir::MLIRContext* context, accera::value::TargetDevice deviceInfo) : + ValueLLVMOpConversionPattern(converter, context), + deviceInfo(deviceInfo) + {} + LogicalResult matchAndRewrite( GetTimeOp op, OpAdaptor adaptor, @@ -371,10 +387,11 @@ struct GetTimeOpLowering : public ValueLLVMOpConversionPattern static FlatSymbolRefAttr getOrInsertClockGetTime(PatternRewriter& rewriter, ModuleOp module, - LLVM::LLVMDialect* llvmDialect) + LLVM::LLVMDialect* llvmDialect, + size_t numBits) { auto* context = module.getContext(); - auto llvmFnType = getGetTimeFunctionType(context); + auto llvmFnType = getGetTimeFunctionType(context, numBits); return getOrInsertLibraryFunction(rewriter, "clock_gettime", llvmFnType, module, llvmDialect); } @@ -394,13 +411,13 @@ struct GetTimeOpLowering : public ValueLLVMOpConversionPattern return LLVM::LLVMFunctionType::get(boolTy, { argTy }, /*isVarArg=*/false); } - static Type getGetTimeFunctionType(mlir::MLIRContext* context) + static Type getGetTimeFunctionType(mlir::MLIRContext* context, size_t numBits) { // Create a function type for clock_gettime, the signature is: // int clock_gettime(clockid_t clockid, struct timespec *tp); - auto returnTy = getIntType(context); - auto llvmClockIdTy = getClockIdType(context); - auto llvmTimespecTy = getTimeSpecType(context); + auto returnTy = getIntType(context, numBits); + auto llvmClockIdTy = getClockIdType(context, numBits); + auto llvmTimespecTy = getTimeSpecType(context, 
numBits); auto llvmTimespecPtrTy = LLVM::LLVMPointerType::get(llvmTimespecTy); return LLVM::LLVMFunctionType::get(returnTy, { llvmClockIdTy, llvmTimespecPtrTy }, /*isVarArg=*/false); } @@ -410,29 +427,27 @@ struct GetTimeOpLowering : public ValueLLVMOpConversionPattern return IntegerType::get(context, 64); } - static Type getClockIdType(mlir::MLIRContext* context) + static Type getClockIdType(mlir::MLIRContext* context, size_t numBits) { - return getIntType(context); + return getIntType(context, numBits); } - static Type getTimeSpecType(mlir::MLIRContext* context) + static Type getTimeSpecType(mlir::MLIRContext* context, size_t numBits) { // struct timespec { // time_t tv_sec; /* seconds */ // long tv_nsec; /* nanoseconds */ // }; - auto llvmIntTy = getIntType(context); + auto llvmIntTy = getIntType(context, numBits); auto llvmTimespecTy = LLVM::LLVMStructType::getLiteral(context, { llvmIntTy, llvmIntTy }, /* isPacked */ true); return llvmTimespecTy; } - static Type getIntType(mlir::MLIRContext* context) + static Type getIntType(mlir::MLIRContext* context, size_t numBits) { auto llvmI32Ty = IntegerType::get(context, 32); auto llvmI64Ty = IntegerType::get(context, 64); - const int hostBitSize = 64; // TODO:: FIXME :: This assumes that the host is always 64bit - // Should query the target hardware - auto llvmIntTy = hostBitSize == 32 ? llvmI32Ty : llvmI64Ty; + auto llvmIntTy = numBits == 32 ? 
llvmI32Ty : llvmI64Ty; return llvmIntTy; } }; @@ -757,7 +772,7 @@ struct RoundOpLowering : public ValueLLVMOpConversionPattern // Create arithmetic dialect cast ops with the expectation that other arithmetic dialect ops are getting lowered as part of this pass auto signlessOutputType = util::ToSignlessMLIRType(rewriter, op.getType()); - mlir::Value roundedSIVal = rewriter.create(op.getLoc(), roundedFPVal, signlessOutputType); + mlir::Value roundedSIVal = rewriter.create(op.getLoc(), signlessOutputType, roundedFPVal); rewriter.replaceOpWithNewOp(op, op.getType(), roundedSIVal); } return success(); @@ -766,8 +781,9 @@ struct RoundOpLowering : public ValueLLVMOpConversionPattern struct ValueToLLVMLoweringPass : public ConvertValueToLLVMBase { - ValueToLLVMLoweringPass(bool useBarePtrCallConv, bool emitCWrappers, unsigned indexBitwidth, bool useAlignedAlloc, llvm::DataLayout dataLayout, const IntraPassSnapshotOptions& snapshotteroptions = {}) : - _intrapassSnapshotter(snapshotteroptions) + ValueToLLVMLoweringPass(bool useBarePtrCallConv, bool emitCWrappers, unsigned indexBitwidth, bool useAlignedAlloc, llvm::DataLayout dataLayout, accera::value::TargetDevice deviceInfo = {}, const IntraPassSnapshotOptions& snapshotteroptions = {}) : + _intrapassSnapshotter(snapshotteroptions), + deviceInfo(deviceInfo) { this->useBarePtrCallConv = useBarePtrCallConv; this->emitCWrappers = emitCWrappers; @@ -781,6 +797,7 @@ struct ValueToLLVMLoweringPass : public ConvertValueToLLVMBase(timerRegionTypeAttr.getInt()) == TimerRegionType::enterRegion; } - // TODO: get check `TargetDeviceInfo` for the OS instead -#ifdef WIN32 - if (isEnterRegionTimer) { - auto boolTy = IntegerType::get(context, 8); - auto argTy = getPerformanceCounterType(context); - LLVMTypeConverter llvmTypeConverter(context); - Value one = rewriter.create(loc, llvmTypeConverter.convertType(rewriter.getIndexType()), rewriter.getIntegerAttr(rewriter.getIndexType(), 1)); - - auto queryPerfFrequencyFn = 
getOrInsertQueryPerfFrequency(rewriter, parentModule, llvmDialect); - Value perfFreqPtr = rewriter.create(loc, LLVM::LLVMPointerType::get(argTy), one); - auto getFreqCall = rewriter.create(loc, std::vector{ boolTy }, queryPerfFrequencyFn, ValueRange{ perfFreqPtr }); - - Value perfFreq = rewriter.create(loc, perfFreqPtr); - Value freqDoubleVal = rewriter.create(loc, doubleTy, perfFreq); - - auto queryPerfCounterFn = getOrInsertQueryPerfCounter(rewriter, parentModule, llvmDialect); - Value perfCountPtr = rewriter.create(loc, LLVM::LLVMPointerType::get(argTy), one); - auto getCounterCall = rewriter.create(loc, std::vector{ boolTy }, queryPerfCounterFn, ValueRange{ perfCountPtr }); - - [[maybe_unused]] auto getCountResult = getCounterCall.getResult(0); - [[maybe_unused]] auto getFreqResult = getFreqCall.getResult(0); - - Value perfCount = rewriter.create(loc, perfCountPtr); - Value ticksDoubleVal = rewriter.create(loc, doubleTy, perfCount); - - Value result = rewriter.create(loc, doubleTy, ticksDoubleVal, freqDoubleVal); - return result; - } - else + if (this->deviceInfo.IsWindows()) { - auto queryPerfCounterFn = getOrInsertQueryPerfCounter(rewriter, parentModule, llvmDialect); - auto queryPerfFrequencyFn = getOrInsertQueryPerfFrequency(rewriter, parentModule, llvmDialect); + if (isEnterRegionTimer) { + auto boolTy = IntegerType::get(context, 8); + auto argTy = getPerformanceCounterType(context); + LLVMTypeConverter llvmTypeConverter(context); + Value one = rewriter.create(loc, llvmTypeConverter.convertType(rewriter.getIndexType()), rewriter.getIntegerAttr(rewriter.getIndexType(), 1)); + + auto queryPerfFrequencyFn = getOrInsertQueryPerfFrequency(rewriter, parentModule, llvmDialect); + Value perfFreqPtr = rewriter.create(loc, LLVM::LLVMPointerType::get(argTy), one); + auto getFreqCall = rewriter.create(loc, std::vector{ boolTy }, queryPerfFrequencyFn, ValueRange{ perfFreqPtr }); - auto boolTy = IntegerType::get(context, 8); - auto argTy = 
getPerformanceCounterType(context); - LLVMTypeConverter llvmTypeConverter(context); - Value one = rewriter.create(loc, llvmTypeConverter.convertType(rewriter.getIndexType()), rewriter.getIntegerAttr(rewriter.getIndexType(), 1)); - - Value perfCountPtr = rewriter.create(loc, LLVM::LLVMPointerType::get(argTy), one); - auto getCounterCall = rewriter.create(loc, std::vector{ boolTy }, queryPerfCounterFn, ValueRange{ perfCountPtr }); + Value perfFreq = rewriter.create(loc, perfFreqPtr); + Value freqDoubleVal = rewriter.create(loc, doubleTy, perfFreq); - Value perfFreqPtr = rewriter.create(loc, LLVM::LLVMPointerType::get(argTy), one); - auto getFreqCall = rewriter.create(loc, std::vector{ boolTy }, queryPerfFrequencyFn, ValueRange{ perfFreqPtr }); + auto queryPerfCounterFn = getOrInsertQueryPerfCounter(rewriter, parentModule, llvmDialect); + Value perfCountPtr = rewriter.create(loc, LLVM::LLVMPointerType::get(argTy), one); + auto getCounterCall = rewriter.create(loc, std::vector{ boolTy }, queryPerfCounterFn, ValueRange{ perfCountPtr }); - [[maybe_unused]] auto getCountResult = getCounterCall.getResult(0); - [[maybe_unused]] auto getFreqResult = getFreqCall.getResult(0); + [[maybe_unused]] auto getCountResult = getCounterCall.getResult(0); + [[maybe_unused]] auto getFreqResult = getFreqCall.getResult(0); - Value perfCount = rewriter.create(loc, perfCountPtr); - Value perfFreq = rewriter.create(loc, perfFreqPtr); + Value perfCount = rewriter.create(loc, perfCountPtr); + Value ticksDoubleVal = rewriter.create(loc, doubleTy, perfCount); - Value ticksDoubleVal = rewriter.create(loc, doubleTy, perfCount); - Value freqDoubleVal = rewriter.create(loc, doubleTy, perfFreq); - Value result = rewriter.create(loc, doubleTy, ticksDoubleVal, freqDoubleVal); - return result; - } -#else - if (isEnterRegionTimer) - { - auto clockGetTimeFn = getOrInsertClockGetTime(rewriter, parentModule, llvmDialect); - - auto llvmTimespecTy = getTimeSpecType(context); - auto clockIdTy = 
getClockIdType(context); - auto intTy = getIntType(context); - - // Get a symbol reference to the gettime function, inserting it if necessary. - LLVMTypeConverter llvmTypeConverter(context); - Value zero = rewriter.create(loc, llvmTypeConverter.convertType(rewriter.getIndexType()), rewriter.getIntegerAttr(rewriter.getIndexType(), 0)); - Value zero32 = rewriter.create(loc, llvmTypeConverter.convertType(rewriter.getI32Type()), rewriter.getIntegerAttr(rewriter.getI32Type(), 0)); - Value one = rewriter.create(loc, llvmTypeConverter.convertType(rewriter.getIndexType()), rewriter.getIntegerAttr(rewriter.getIndexType(), 1)); - Value one32 = rewriter.create(loc, llvmTypeConverter.convertType(rewriter.getI32Type()), rewriter.getIntegerAttr(rewriter.getI32Type(), 1)); - Value clockId = rewriter.create(loc, clockIdTy, rewriter.getI64IntegerAttr(CLOCK_REALTIME)); - - Value timespecPtr = rewriter.create(loc, LLVM::LLVMPointerType::get(llvmTimespecTy), one); - Value secondsPtr = rewriter.create(loc, LLVM::LLVMPointerType::get(intTy), timespecPtr, ValueRange{ zero, zero32 }); - Value nanosecondsPtr = rewriter.create(loc, LLVM::LLVMPointerType::get(intTy), timespecPtr, ValueRange{ zero, one32 }); - - std::vector args{ clockId, timespecPtr }; - auto getTimeCall = rewriter.create(loc, std::vector{ getIntType(context) }, clockGetTimeFn, args); - [[maybe_unused]] auto getTimeResult = getTimeCall.getResult(0); - - Value secondsIntVal = rewriter.create(loc, secondsPtr); - Value nanosecondsIntVal = rewriter.create(loc, nanosecondsPtr); - Value secondsDoubleVal = rewriter.create(loc, doubleTy, secondsIntVal); - Value nanosecondsDoubleVal = rewriter.create(loc, doubleTy, nanosecondsIntVal); - Value divisor = rewriter.create(loc, doubleTy, rewriter.getF64FloatAttr(1.0e9)); - Value nanoseconds = rewriter.create(loc, doubleTy, nanosecondsDoubleVal, divisor); - Value totalSecondsDoubleVal = rewriter.create(loc, doubleTy, secondsDoubleVal, nanoseconds); - return totalSecondsDoubleVal; + Value 
result = rewriter.create(loc, doubleTy, ticksDoubleVal, freqDoubleVal); + return result; + } + else + { + auto queryPerfCounterFn = getOrInsertQueryPerfCounter(rewriter, parentModule, llvmDialect); + auto queryPerfFrequencyFn = getOrInsertQueryPerfFrequency(rewriter, parentModule, llvmDialect); + + auto boolTy = IntegerType::get(context, 8); + auto argTy = getPerformanceCounterType(context); + LLVMTypeConverter llvmTypeConverter(context); + Value one = rewriter.create(loc, llvmTypeConverter.convertType(rewriter.getIndexType()), rewriter.getIntegerAttr(rewriter.getIndexType(), 1)); + Value perfCountPtr = rewriter.create(loc, LLVM::LLVMPointerType::get(argTy), one); + auto getCounterCall = rewriter.create(loc, std::vector{ boolTy }, queryPerfCounterFn, ValueRange{ perfCountPtr }); + + Value perfFreqPtr = rewriter.create(loc, LLVM::LLVMPointerType::get(argTy), one); + auto getFreqCall = rewriter.create(loc, std::vector{ boolTy }, queryPerfFrequencyFn, ValueRange{ perfFreqPtr }); + [[maybe_unused]] auto getCountResult = getCounterCall.getResult(0); + [[maybe_unused]] auto getFreqResult = getFreqCall.getResult(0); + + Value perfCount = rewriter.create(loc, perfCountPtr); + Value perfFreq = rewriter.create(loc, perfFreqPtr); + + Value ticksDoubleVal = rewriter.create(loc, doubleTy, perfCount); + Value freqDoubleVal = rewriter.create(loc, doubleTy, perfFreq); + Value result = rewriter.create(loc, doubleTy, ticksDoubleVal, freqDoubleVal); + return result; + } } else { - auto clockGetTimeFn = getOrInsertClockGetTime(rewriter, parentModule, llvmDialect); - - auto llvmTimespecTy = getTimeSpecType(context); - auto clockIdTy = getClockIdType(context); - auto intTy = getIntType(context); - - // Get a symbol reference to the gettime function, inserting it if necessary. 
- LLVMTypeConverter llvmTypeConverter(context); - Value zero = rewriter.create(loc, llvmTypeConverter.convertType(rewriter.getIndexType()), rewriter.getIntegerAttr(rewriter.getIndexType(), 0)); - Value zero32 = rewriter.create(loc, llvmTypeConverter.convertType(rewriter.getI32Type()), rewriter.getIntegerAttr(rewriter.getI32Type(), 0)); - Value one = rewriter.create(loc, llvmTypeConverter.convertType(rewriter.getIndexType()), rewriter.getIntegerAttr(rewriter.getIndexType(), 1)); - Value one32 = rewriter.create(loc, llvmTypeConverter.convertType(rewriter.getI32Type()), rewriter.getIntegerAttr(rewriter.getI32Type(), 1)); - Value clockId = rewriter.create(loc, clockIdTy, rewriter.getI64IntegerAttr(CLOCK_REALTIME)); - - Value timespecPtr = rewriter.create(loc, LLVM::LLVMPointerType::get(llvmTimespecTy), one); - - std::vector args{ clockId, timespecPtr }; - auto getTimeCall = rewriter.create(loc, std::vector{ getIntType(context) }, clockGetTimeFn, args); - [[maybe_unused]] auto getTimeResult = getTimeCall.getResult(0); - - Value secondsPtr = rewriter.create(loc, LLVM::LLVMPointerType::get(intTy), timespecPtr, ValueRange{ zero, zero32 }); - Value nanosecondsPtr = rewriter.create(loc, LLVM::LLVMPointerType::get(intTy), timespecPtr, ValueRange{ zero, one32 }); - - Value secondsIntVal = rewriter.create(loc, secondsPtr); - Value nanosecondsIntVal = rewriter.create(loc, nanosecondsPtr); - Value secondsDoubleVal = rewriter.create(loc, doubleTy, secondsIntVal); - Value nanosecondsDoubleVal = rewriter.create(loc, doubleTy, nanosecondsIntVal); - Value divisor = rewriter.create(loc, doubleTy, rewriter.getF64FloatAttr(1.0e9)); - Value nanoseconds = rewriter.create(loc, doubleTy, nanosecondsDoubleVal, divisor); - Value totalSecondsDoubleVal = rewriter.create(loc, doubleTy, secondsDoubleVal, nanoseconds); - return totalSecondsDoubleVal; + if (isEnterRegionTimer) + { + auto clockGetTimeFn = getOrInsertClockGetTime(rewriter, parentModule, llvmDialect, this->deviceInfo.numBits); + + auto 
llvmTimespecTy = getTimeSpecType(context, this->deviceInfo.numBits); + auto clockIdTy = getClockIdType(context, this->deviceInfo.numBits); + auto intTy = getIntType(context, this->deviceInfo.numBits); + + // Get a symbol reference to the gettime function, inserting it if necessary. + LLVMTypeConverter llvmTypeConverter(context); + Value zero = rewriter.create(loc, llvmTypeConverter.convertType(rewriter.getIndexType()), rewriter.getIntegerAttr(rewriter.getIndexType(), 0)); + Value zero32 = rewriter.create(loc, llvmTypeConverter.convertType(rewriter.getI32Type()), rewriter.getIntegerAttr(rewriter.getI32Type(), 0)); + Value one = rewriter.create(loc, llvmTypeConverter.convertType(rewriter.getIndexType()), rewriter.getIntegerAttr(rewriter.getIndexType(), 1)); + Value one32 = rewriter.create(loc, llvmTypeConverter.convertType(rewriter.getI32Type()), rewriter.getIntegerAttr(rewriter.getI32Type(), 1)); + + Value clockId = rewriter.create(loc, clockIdTy, rewriter.getI64IntegerAttr(int64_t(ClockID::ACCERA_CLOCK_REALTIME))); + + Value timespecPtr = rewriter.create(loc, LLVM::LLVMPointerType::get(llvmTimespecTy), one); + Value secondsPtr = rewriter.create(loc, LLVM::LLVMPointerType::get(intTy), timespecPtr, ValueRange{ zero, zero32 }); + Value nanosecondsPtr = rewriter.create(loc, LLVM::LLVMPointerType::get(intTy), timespecPtr, ValueRange{ zero, one32 }); + + std::vector args{ clockId, timespecPtr }; + auto getTimeCall = rewriter.create(loc, std::vector{ getIntType(context, this->deviceInfo.numBits) }, clockGetTimeFn, args); + [[maybe_unused]] auto getTimeResult = getTimeCall.getResult(0); + + Value secondsIntVal = rewriter.create(loc, secondsPtr); + Value nanosecondsIntVal = rewriter.create(loc, nanosecondsPtr); + Value secondsDoubleVal = rewriter.create(loc, doubleTy, secondsIntVal); + Value nanosecondsDoubleVal = rewriter.create(loc, doubleTy, nanosecondsIntVal); + Value divisor = rewriter.create(loc, doubleTy, rewriter.getF64FloatAttr(1.0e9)); + Value nanoseconds = 
rewriter.create(loc, doubleTy, nanosecondsDoubleVal, divisor); + Value totalSecondsDoubleVal = rewriter.create(loc, doubleTy, secondsDoubleVal, nanoseconds); + return totalSecondsDoubleVal; + } + else + { + auto clockGetTimeFn = getOrInsertClockGetTime(rewriter, parentModule, llvmDialect, this->deviceInfo.numBits); + + auto llvmTimespecTy = getTimeSpecType(context, this->deviceInfo.numBits); + auto clockIdTy = getClockIdType(context, this->deviceInfo.numBits); + auto intTy = getIntType(context, this->deviceInfo.numBits); + + // Get a symbol reference to the gettime function, inserting it if necessary. + LLVMTypeConverter llvmTypeConverter(context); + Value zero = rewriter.create(loc, llvmTypeConverter.convertType(rewriter.getIndexType()), rewriter.getIntegerAttr(rewriter.getIndexType(), 0)); + Value zero32 = rewriter.create(loc, llvmTypeConverter.convertType(rewriter.getI32Type()), rewriter.getIntegerAttr(rewriter.getI32Type(), 0)); + Value one = rewriter.create(loc, llvmTypeConverter.convertType(rewriter.getIndexType()), rewriter.getIntegerAttr(rewriter.getIndexType(), 1)); + Value one32 = rewriter.create(loc, llvmTypeConverter.convertType(rewriter.getI32Type()), rewriter.getIntegerAttr(rewriter.getI32Type(), 1)); + Value clockId = rewriter.create(loc, clockIdTy, rewriter.getI64IntegerAttr(int64_t(ClockID::ACCERA_CLOCK_REALTIME))); + + Value timespecPtr = rewriter.create(loc, LLVM::LLVMPointerType::get(llvmTimespecTy), one); + + std::vector args{ clockId, timespecPtr }; + auto getTimeCall = rewriter.create(loc, std::vector{ getIntType(context, this->deviceInfo.numBits) }, clockGetTimeFn, args); + [[maybe_unused]] auto getTimeResult = getTimeCall.getResult(0); + + Value secondsPtr = rewriter.create(loc, LLVM::LLVMPointerType::get(intTy), timespecPtr, ValueRange{ zero, zero32 }); + Value nanosecondsPtr = rewriter.create(loc, LLVM::LLVMPointerType::get(intTy), timespecPtr, ValueRange{ zero, one32 }); + + Value secondsIntVal = rewriter.create(loc, secondsPtr); + Value 
nanosecondsIntVal = rewriter.create(loc, nanosecondsPtr); + Value secondsDoubleVal = rewriter.create(loc, doubleTy, secondsIntVal); + Value nanosecondsDoubleVal = rewriter.create(loc, doubleTy, nanosecondsIntVal); + Value divisor = rewriter.create(loc, doubleTy, rewriter.getF64FloatAttr(1.0e9)); + Value nanoseconds = rewriter.create(loc, doubleTy, nanosecondsDoubleVal, divisor); + Value totalSecondsDoubleVal = rewriter.create(loc, doubleTy, secondsDoubleVal, nanoseconds); + return totalSecondsDoubleVal; + } } -#endif } LogicalResult GetTimeOpLowering::matchAndRewrite( @@ -1856,7 +1872,7 @@ void ValueToLLVMLoweringPass::runOnModule() intermediateTarget.addLegalDialect(); RewritePatternSet patterns(&getContext()); - populateValueToLLVMNonMemPatterns(llvmTypeConverter, patterns); + populateValueToLLVMNonMemPatterns(llvmTypeConverter, patterns, this->deviceInfo); populateLinalgToLLVMConversionPatterns(llvmTypeConverter, patterns); @@ -1891,6 +1907,7 @@ void ValueToLLVMLoweringPass::runOnModule() populateStdToLLVMConversionPatterns(llvmTypeConverter, patterns); arith::populateArithmeticToLLVMConversionPatterns(llvmTypeConverter, patterns); arith::populateArithmeticExpandOpsPatterns(patterns); + cf::populateControlFlowToLLVMConversionPatterns(llvmTypeConverter, patterns); // Subset of LowerVectorToLLVMPass patterns vector::populateVectorToVectorCanonicalizationPatterns(patterns); @@ -1947,7 +1964,7 @@ void populateGlobalValueToLLVMNonMemPatterns(mlir::LLVMTypeConverter& typeConver patterns.insert(typeConverter, context); } -void populateLocalValueToLLVMNonMemPatterns(mlir::LLVMTypeConverter& typeConverter, mlir::RewritePatternSet& patterns) +void populateLocalValueToLLVMNonMemPatterns(mlir::LLVMTypeConverter& typeConverter, mlir::RewritePatternSet& patterns, accera::value::TargetDevice deviceInfo) { mlir::MLIRContext* context = patterns.getContext(); @@ -1957,19 +1974,20 @@ void populateLocalValueToLLVMNonMemPatterns(mlir::LLVMTypeConverter& typeConvert 
BitcastOpLowering, CallOpLowering, PrintFOpLowering, - GetTimeOpLowering, RangeOpLowering, VpmaddwdOpLowering, VmaxpsOpLowering, VminpsOpLowering, RoundOpLowering, MemrefAllocOpLowering>(typeConverter, context); + + patterns.insert(typeConverter, context, deviceInfo); } -void populateValueToLLVMNonMemPatterns(mlir::LLVMTypeConverter& typeConverter, mlir::RewritePatternSet& patterns) +void populateValueToLLVMNonMemPatterns(mlir::LLVMTypeConverter& typeConverter, mlir::RewritePatternSet& patterns, accera::value::TargetDevice deviceInfo) { populateGlobalValueToLLVMNonMemPatterns(typeConverter, patterns); - populateLocalValueToLLVMNonMemPatterns(typeConverter, patterns); + populateLocalValueToLLVMNonMemPatterns(typeConverter, patterns, deviceInfo); } void populateValueToLLVMMemPatterns(mlir::LLVMTypeConverter& typeConverter, mlir::RewritePatternSet& patterns) @@ -2009,9 +2027,10 @@ std::unique_ptr> createValueToLLVMPass(bool unsigned indexBitwidth, bool useAlignedAlloc, llvm::DataLayout dataLayout, + accera::value::TargetDevice deviceInfo /* = {} */, const IntraPassSnapshotOptions& options /* = {} */) { - return std::make_unique(useBasePtrCallConv, emitCWrappers, indexBitwidth, useAlignedAlloc, dataLayout, options); + return std::make_unique(useBasePtrCallConv, emitCWrappers, indexBitwidth, useAlignedAlloc, dataLayout, deviceInfo, options); } std::unique_ptr> createValueToLLVMPass() diff --git a/accera/transforms/src/value/ValueToStandardLoweringPass.cpp b/accera/transforms/src/value/ValueToStandardLoweringPass.cpp index 690df492..813ee62c 100644 --- a/accera/transforms/src/value/ValueToStandardLoweringPass.cpp +++ b/accera/transforms/src/value/ValueToStandardLoweringPass.cpp @@ -580,7 +580,7 @@ struct CastOpLowering : public OpRewritePattern #define CAST_FROM_TO_WITH_OP_IF(testFromType, testToType, castOp, conditional) \ if (fromType && toType && fromElementType.isa() && toElementType.isa() && conditional) \ { \ - mlir::Value castValue = rewriter.create(op.getLoc(), 
signlessFromValue, signlessToType); \ + mlir::Value castValue = rewriter.create(op.getLoc(), signlessToType, signlessFromValue); \ if (toType.isIntOrIndex()) \ { \ rewriter.replaceOpWithNewOp(op, toType, castValue); \ @@ -652,14 +652,14 @@ struct CastOpLowering : public OpRewritePattern auto i64IntermediateType = accera::ir::util::CloneTypeWithNewElementType(op.source().getType(), rewriter.getI64Type()); if (fromElementType.isa() && toElementType.isa()) { - auto int64Value = rewriter.create(loc, op.source(), i64IntermediateType); // index->int64 - rewriter.replaceOpWithNewOp(op, int64Value, toElementType); // int64->fp + auto int64Value = rewriter.create(loc, i64IntermediateType, op.source()); // index->int64 + rewriter.replaceOpWithNewOp(op, toElementType, int64Value); // int64->fp return success(); } if (fromElementType.isa() && toElementType.isa()) { - auto int64Value = rewriter.create(loc, op.source(), i64IntermediateType); // fp->int64 - rewriter.replaceOpWithNewOp(op, int64Value, toElementType); // int64->index + auto int64Value = rewriter.create(loc, i64IntermediateType, op.source()); // fp->int64 + rewriter.replaceOpWithNewOp(op, toElementType, int64Value); // int64->index return success(); } @@ -1238,7 +1238,7 @@ LogicalResult CopyOpLowering::matchAndRewrite( if (outputMemRef.getElementType().isInteger(64)) // this should really be target dependent... 
{ (void)rewriter.create(loc, - rewriter.create(loc, input, rewriter.getIntegerType(64)), + rewriter.create(loc, rewriter.getIntegerType(64), input), output, std::vector(outputMemRef.getRank(), zero)); } @@ -1358,7 +1358,7 @@ LogicalResult LoadOpLowering::matchAndRewrite( else { resolvedIndices.push_back( - rewriter.create(loc, rewriter.create(loc, index), indexType)); + rewriter.create(loc, indexType, rewriter.create(loc, index))); } } @@ -1388,8 +1388,8 @@ LogicalResult StoreOpLowering::matchAndRewrite( { resolvedIndices.push_back( rewriter.create(loc, - rewriter.create(loc, index), - indexType)); + indexType, + rewriter.create(loc, index))); } } @@ -1524,8 +1524,8 @@ LogicalResult OffsetOpLowering::matchAndRewrite( { resolvedOffsets.push_back( rewriter.create(loc, - rewriter.create(loc, index), - indexType)); + indexType, + rewriter.create(loc, index))); } else { @@ -1611,7 +1611,7 @@ LogicalResult SliceOpLowering::matchAndRewrite( auto indexShape = index.getType().cast().getShape(); if (indexShape.size() == 0 || indexShape.size() == 1) { - index = rewriter.create(loc, rewriter.create(loc, index), indexType); + index = rewriter.create(loc, indexType, rewriter.create(loc, index)); } else { @@ -1743,8 +1743,8 @@ LogicalResult ReorderOpLowering::matchAndRewrite( auto resultType = op.getType(); // cast to a value with type `memref` (via `memref<* x elem_type>`) - mlir::Value ptr = rewriter.create(loc, source, mlir::UnrankedMemRefType::get(elemTy, sourceType.getMemorySpace())); - auto result = rewriter.create(loc, ptr, resultType); + mlir::Value ptr = rewriter.create(loc, mlir::UnrankedMemRefType::get(elemTy, sourceType.getMemorySpace()), source); + auto result = rewriter.create(loc, resultType, ptr); rewriter.replaceOp(op, { result }); return success(); @@ -1762,8 +1762,8 @@ LogicalResult ReshapeOpLowering::matchAndRewrite( auto resultType = op.getType(); // cast to a value with type `memref` (via `memref<* x elem_type>`) - mlir::Value ptr = rewriter.create(loc, 
source, mlir::UnrankedMemRefType::get(elemTy, sourceType.getMemorySpace())); - auto result = rewriter.create(loc, ptr, resultType); + mlir::Value ptr = rewriter.create(loc, mlir::UnrankedMemRefType::get(elemTy, sourceType.getMemorySpace()), source); + auto result = rewriter.create(loc, resultType, ptr); rewriter.replaceOp(op, { result }); return success(); @@ -1932,11 +1932,11 @@ LogicalResult ReduceOpLowering::matchAndRewrite( vectorSize = vectorType.getShape()[0]; } auto stepValue = isParallelReduction ? vectorSize : 1; - auto oldInputValue = op.getInputValueVar(); auto oldInductionValue = op.getInductionValue(); auto oldTerminator = op.getBody()->getTerminator(); auto oldYieldValue = oldTerminator->getOperand(0); // TODO: add "get result value" helper to ReduceOp + bool isFloatType = initialValueType.isa(); // Check for trivial reductions of the form bin_op(arg1, arg2) if (isHorizontalReduction) @@ -1959,40 +1959,36 @@ LogicalResult ReduceOpLowering::matchAndRewrite( { using accera::ir::value::BinaryOpPredicate; auto pred = binOp.predicate(); - std::string opName = ""; + mlir::vector::CombiningKind kind; switch (pred) { case BinaryOpPredicate::ADD: - opName = "add"; + kind = mlir::vector::CombiningKind::ADD; break; case BinaryOpPredicate::MUL: - opName = "mul"; + kind = mlir::vector::CombiningKind::MUL; break; case BinaryOpPredicate::SUB: - opName = "sub"; - break; + [[fallthrough]]; default: - break; + llvm_unreachable("Unsupported binary op predicate for vector reduction"); } - if (!opName.empty()) + // We can use the init value for floating-point add and mul + if (isFloatType && (pred == BinaryOpPredicate::ADD || pred == BinaryOpPredicate::MUL)) { - // We can use the init value for floating-point add and mul - if (initialValueType.isa() && (pred == BinaryOpPredicate::ADD || pred == BinaryOpPredicate::MUL)) - { - auto result = rewriter.create(loc, op.result().getType(), rewriter.getStringAttr(opName), op.input(), op.initArg()); - rewriter.replaceOp(op, { 
result }); - } - else - { - auto result = rewriter.create(loc, op.result().getType(), rewriter.getStringAttr(opName), op.input(), llvm::None); - rewriter.replaceOp(op, { result }); - } - return success(); + auto result = rewriter.create(loc, kind, op.input(), op.initArg()); + rewriter.replaceOp(op, { result }); } + else + { + auto result = rewriter.create(loc, kind, op.input()); + rewriter.replaceOp(op, { result }); + } + return success(); } } - else if (auto selectOp = dyn_cast(yieldValueOp)) + else if (auto selectOp = dyn_cast(yieldValueOp)) { // Look for sequences like: // @@ -2022,13 +2018,13 @@ LogicalResult ReduceOpLowering::matchAndRewrite( [[fallthrough]]; case ValueCmpOpPredicate::LE: // min - result = rewriter.create(loc, op.result().getType(), rewriter.getStringAttr("min"), op.input(), llvm::None); + result = rewriter.create(loc, isFloatType ? mlir::vector::CombiningKind::MINF : mlir::vector::CombiningKind::MINSI, op.input()); break; case ValueCmpOpPredicate::GT: [[fallthrough]]; case ValueCmpOpPredicate::GE: // max - result = rewriter.create(loc, op.result().getType(), rewriter.getStringAttr("max"), op.input(), llvm::None); + result = rewriter.create(loc, isFloatType ? 
mlir::vector::CombiningKind::MAXF : mlir::vector::CombiningKind::MAXSI, op.input()); break; } @@ -2149,8 +2145,8 @@ LogicalResult ReferenceGlobalOpLowering::matchAndRewrite( rewriter.replaceOpWithNewOp( op, - getGlobalOpValue, - op.getType()); + op.getType(), + getGlobalOpValue); return success(); } @@ -2538,11 +2534,11 @@ LogicalResult ReduceMaxOpLowering::matchAndRewrite( return op.emitError("Can only reduce a rank-1 memref"); } + auto elementType = memRefType.getElementType(); mlir::Value memrefToCast = input; mlir::Value loadedVector = nullptr; if (!memRefType.getLayout().isIdentity()) { - auto elementType = memRefType.getElementType(); auto vectorType = mlir::VectorType::get(memRefType.getShape(), elementType); auto zero = rewriter.create(loc, 0); loadedVector = rewriter.create(loc, vectorType, memrefToCast, mlir::ValueRange{ zero }); @@ -2552,7 +2548,8 @@ LogicalResult ReduceMaxOpLowering::matchAndRewrite( auto castMemRefVector = rewriter.create(loc, memrefToCast); loadedVector = rewriter.create(loc, castMemRefVector, llvm::None); } - auto result = rewriter.create(loc, op.result().getType(), rewriter.getStringAttr("max"), loadedVector, llvm::None); + auto kind = elementType.isa() ? 
mlir::vector::CombiningKind::MAXF : mlir::vector::CombiningKind::MAXSI; + auto result = rewriter.create(loc, kind, loadedVector); rewriter.replaceOp(op, { result }); return success(); } @@ -2633,7 +2630,7 @@ LogicalResult ExitProfileRegionOpLowering::matchAndRewrite(ExitProfileRegionOp o rewriter.create(loc, totalTime, totalTimeRef); mlir::Value prevCount = rewriter.create(loc, countRef); - auto one = rewriter.create(loc, rewriter.getI64Type(), rewriter.getI64IntegerAttr(1)); + auto one = rewriter.create(loc, rewriter.getI64Type(), rewriter.getI64IntegerAttr(1)); mlir::Value newCount = rewriter.create(loc, vir::BinaryOpPredicate::ADD, prevCount, one); rewriter.create(loc, newCount, countRef); @@ -2701,7 +2698,7 @@ LogicalResult ReduceSumOpLowering::matchAndRewrite( auto castMemRefVector = rewriter.create(loc, memrefToCast); loadedVector = rewriter.create(loc, castMemRefVector, llvm::None); } - auto result = rewriter.create(loc, op.result().getType(), rewriter.getStringAttr("add"), loadedVector, llvm::None); + auto result = rewriter.create(loc, mlir::vector::CombiningKind::ADD, loadedVector); rewriter.replaceOp(op, { result }); return success(); } @@ -2724,7 +2721,7 @@ LogicalResult PrintOpLowering::matchAndRewrite( auto printElement = [&](mlir::Value el) { if (elementType.isF32()) { - el = rewriter.create(loc, el, rewriter.getF64Type()); + el = rewriter.create(loc, rewriter.getF64Type(), el); } rewriter.create(loc, formatStr, ValueRange{ el }, toStderr); }; diff --git a/accera/transforms/src/vectorization/VectorizationUtil.cpp b/accera/transforms/src/vectorization/VectorizationUtil.cpp index 94b4322b..b8306937 100644 --- a/accera/transforms/src/vectorization/VectorizationUtil.cpp +++ b/accera/transforms/src/vectorization/VectorizationUtil.cpp @@ -140,7 +140,7 @@ bool CanVectorizeOp(mlir::Operation* op, .Case([](mlir::memref::StoreOp) { return true; }) .Case([](mlir::AffineLoadOp) { return true; }) .Case([](mlir::AffineStoreOp) { return true; }) - 
.Case([](mlir::SelectOp) { return true; }) + .Case([](mlir::arith::SelectOp) { return true; }) .Case([](mlir::arith::ShLIOp) { return true; }) .Case([](mlir::arith::FPToSIOp) { return true; }) .Case([](mlir::arith::ExtSIOp) { return true; }) @@ -810,7 +810,7 @@ std::optional VectorizeAffineApplyOp(mlir::PatternRewriter& rewrit } std::optional VectorizeSelectOp(mlir::PatternRewriter& rewriter, - mlir::SelectOp op, + mlir::arith::SelectOp op, const VectorizedOpMap& vectorizedOps, std::vector& laneMappings, mlir::Value inductionVar, @@ -830,7 +830,7 @@ std::optional VectorizeSelectOp(mlir::PatternRewriter& rewrite } auto loc = op.getLoc(); - auto result = rewriter.create(loc, cond->GetVectorResult(), trueVal->GetVectorResult(), falseVal->GetVectorResult()); + auto result = rewriter.create(loc, cond->GetVectorResult(), trueVal->GetVectorResult(), falseVal->GetVectorResult()); return result; } @@ -1193,7 +1193,7 @@ std::optional VectorizeOp(mlir::PatternRewriter& rewriter, .Case([&](mlir::AffineApplyOp affineApplyOp) { return VectorizeAffineApplyOp(rewriter, affineApplyOp, vectorizedOps, laneMappings, inductionVar, step, vectorSize); }) - .Case([&](mlir::SelectOp selectOp) { + .Case([&](mlir::arith::SelectOp selectOp) { return VectorizeSelectOp(rewriter, selectOp, vectorizedOps, laneMappings, inductionVar, step, vectorSize); }) .Case([&](mlir::arith::ShLIOp shiftLeftOp) { @@ -2185,7 +2185,7 @@ mlir::LogicalResult vectorizeHorizontalReduction(mlir::AffineForOp affineForOp, } else { - reducedVal = rewriter.create(binOp.getLoc(), storeElementType, mlir::vector::stringifyEnum(reductionKind), vectorValToReduce, mlir::ValueRange{} /* optional accumulate values */); + reducedVal = rewriter.create(binOp.getLoc(), reductionKind, vectorValToReduce); } auto scalarValThatWasReduced = lhsLoadIsLoopSequential ? 
lhsVal : rhsVal; diff --git a/accera/value/src/Cache.cpp b/accera/value/src/Cache.cpp index 49ee22fd..ccf381e7 100644 --- a/accera/value/src/Cache.cpp +++ b/accera/value/src/Cache.cpp @@ -763,8 +763,8 @@ namespace value auto localScopeGlobalRef = builder.create(GetLocation(), globalScopeGlobalRef.getGlobal()); // Re-view the packed buffer memref to match the function argument // TODO : update the function arguments to match the packed shape - mlir::Value shapelessMemref = builder.create(loc, localScopeGlobalRef, mlir::UnrankedMemRefType::get(GetElementType(), GetInputType().getMemorySpace())); - auto reshapedMemref = builder.create(loc, shapelessMemref, GetInputType()); + mlir::Value shapelessMemref = builder.create(loc, mlir::UnrankedMemRefType::get(GetElementType(), GetInputType().getMemorySpace()), localScopeGlobalRef); + auto reshapedMemref = builder.create(loc, GetInputType(), shapelessMemref); constantInjectedArgs.insert(constantInjectedArgs.begin() + targetArgIdx, reshapedMemref); auto launchFuncOp = builder.create(GetLocation(), scheduleFuncOp, constantInjectedArgs); diff --git a/accera/value/src/MLIREmitterContext.cpp b/accera/value/src/MLIREmitterContext.cpp index b14b3a9d..04044aaf 100644 --- a/accera/value/src/MLIREmitterContext.cpp +++ b/accera/value/src/MLIREmitterContext.cpp @@ -281,7 +281,7 @@ mlir::FunctionType ToMLIRType(mlir::OpBuilder& builder, const FunctionDeclaratio if (type.isa()) { auto loc = builder.getUnknownLoc(); - return builder.create(loc, v, mlir::IndexType::get(v.getContext())); + return builder.create(loc, mlir::IndexType::get(v.getContext()), v); } // Index types fall through @@ -355,7 +355,7 @@ std::vector ToMLIRValue(mlir::OpBuilder& builder, std::vector(loc, mlirValue, indexType); + return builder.create(loc, indexType, mlirValue); } } @@ -825,7 +825,9 @@ void MLIRContext::print() const void MLIRContext::verify() const { - (void)_impl->_mlirModule.verify(); + // BUGBUG: ModuleOp::verify() is private as of LLVM 15 + // 
(void)_impl->_mlirModule.verify(); + throw utilities::LogicException(utilities::LogicExceptionErrors::notImplemented, "verify() is not implemented."); } mlir::OwningOpRef MLIRContext::cloneModule() const @@ -918,8 +920,8 @@ Scalar CreateGPUIndexOp(mlir::OpBuilder& builder, accera::ir::value::Processor i auto loc = builder.getUnknownLoc(); return Wrap( builder.create(loc, - accera::ir::util::GetGPUIndex(idxType, builder, loc), - builder.getI64Type())); + builder.getI64Type(), + accera::ir::util::GetGPUIndex(idxType, builder, loc))); } template @@ -2097,7 +2099,7 @@ Value MLIRContext::ReinterpretCastImpl(Value input, ValueType valueType) mlir::MemRefType::Builder outputTypeBuilder(inputMemrefType); outputTypeBuilder.setElementType(outputMlirElemType); mlir::MemRefType outputMemRefType = outputTypeBuilder; - auto returnVal = builder.create(loc, inputMlir, outputMemRefType); + auto returnVal = builder.create(loc, outputMemRefType, inputMlir); return Wrap(returnVal, input.GetLayout()); } @@ -2254,7 +2256,7 @@ Value MLIRContext::ReinterpretCastImpl(Value input, ValueType valueType) auto outputMemRefType = mlir::MemRefType::get({ numElementsInOutput }, outputMlirElemType, composedMap); // Fetch a new memref type after normalizing the old memref to have an identity map layout. outputMemRefType = normalizeMemRefType(outputMemRefType, builder, composedMap.getNumSymbols() /* ?? 
No idea if this is correct */); - auto returnVal = builder.create(loc, inputMlir, outputMemRefType); + auto returnVal = builder.create(loc, outputMemRefType, inputMlir); return Wrap(returnVal, MemoryLayout{ numElementsInOutput }); } @@ -2262,7 +2264,7 @@ Value MLIRContext::ReinterpretCastImpl(Value input, ValueType valueType) { // Case 3 // auto outputMemRefType = mlir::UnrankedMemRefType::get(outputMlirElemType, inputUnrankedMemrefType.getMemorySpace()); - // auto returnVal = builder.create(loc, inputMlir, outputMemRefType); + // auto returnVal = builder.create(loc, outputMemRefType, inputMlir); // TODO: This is going to assert because returnVal is of type UnrankedMemRefType and MemoryLayout doesn't support unranked // return Wrap(returnVal); @@ -2786,8 +2788,8 @@ void MLIRContext::PrintRawMemoryImpl(ViewAdapter value) auto identityLayout = mlir::MemRefLayoutAttrInterface{}; // cast to a value with type `memref` (via `memref<* x elem_type>`) - mlir::Value ptr = builder.create(loc, mem, mlir::UnrankedMemRefType::get(elemTy, memType.getMemorySpace())); - mlir::Value mlirValue = builder.create(loc, ptr, mlir::MemRefType::get({ size }, elemTy, identityLayout, memType.getMemorySpace())); + mlir::Value ptr = builder.create(loc, mlir::UnrankedMemRefType::get(elemTy, memType.getMemorySpace()), mem); + mlir::Value mlirValue = builder.create(loc, mlir::MemRefType::get({ size }, elemTy, identityLayout, memType.getMemorySpace()), ptr); [[maybe_unused]] auto op = builder.create(loc, mlirValue, /*toStderr=*/false); } @@ -3141,8 +3143,8 @@ void FillResource(ViewAdapter resourceView, Scalar fillValue) auto loc = b.getUnknownLoc(); mlir::Value memrefCasted = b.create( loc, - res, - castType); + castType, + res); auto DeclareFn = [&](const std::string& name, mlir::FunctionType fnTy) -> ir::value::ValueFuncOp { auto mod = res.getParentRegion()->getParentOfType(); @@ -3208,8 +3210,8 @@ void PrintMemref(ViewAdapter memView) } mlir::Value memrefCasted = b.create( loc, - mem, - 
mlir::UnrankedMemRefType::get(elemTy, shapedType.getMemorySpace())); + mlir::UnrankedMemRefType::get(elemTy, shapedType.getMemorySpace()), + mem); auto DeclareFn = [&](const std::string& name, mlir::FunctionType fnTy) -> ir::value::ValueFuncOp { auto mod = mem.getParentRegion()->getParentOfType(); diff --git a/accera/value/src/ScalarOperations.cpp b/accera/value/src/ScalarOperations.cpp index 1b8929e5..034a226e 100644 --- a/accera/value/src/ScalarOperations.cpp +++ b/accera/value/src/ScalarOperations.cpp @@ -251,7 +251,7 @@ namespace value Scalar Select(Scalar cmp, Scalar a, Scalar b) { std::tie(a, b) = Scalar::MakeTypeCompatible(a, b); - return ScalarOpBuilder(cmp, a, b); + return ScalarOpBuilder(cmp, a, b); } Scalar Sin(Scalar s) diff --git a/accera/value/src/TargetDevice.cpp b/accera/value/src/TargetDevice.cpp index 1a46ad67..c4767318 100644 --- a/accera/value/src/TargetDevice.cpp +++ b/accera/value/src/TargetDevice.cpp @@ -92,6 +92,12 @@ namespace value targetDevice.triple = c_windowsTriple; targetDevice.dataLayout = c_windowsDataLayout; } }, + { "avx2", [](TargetDevice& targetDevice) { + targetDevice.architecture = "x86_64"; + targetDevice.cpu = "skylake"; + targetDevice.numBits = 64; + targetDevice.features = "+avx2"; + } }, { "avx512", [](TargetDevice& targetDevice) { targetDevice.architecture = "x86_64"; targetDevice.cpu = "skylake-avx512"; diff --git a/external/llvm/0001-Merged-PR-2213-mlir-Plumb-OpenMP-dialect-attributes-.patch b/external/llvm/0001-Merged-PR-2213-mlir-Plumb-OpenMP-dialect-attributes-.patch index b3a31422..d913eadd 100644 --- a/external/llvm/0001-Merged-PR-2213-mlir-Plumb-OpenMP-dialect-attributes-.patch +++ b/external/llvm/0001-Merged-PR-2213-mlir-Plumb-OpenMP-dialect-attributes-.patch @@ -1,7 +1,7 @@ -From 08caef4cc7d6e8bc79185d0775da02552eddea9d Mon Sep 17 00:00:00 2001 +From c0e0b6c647aaa8a9c8f8167ef54f4846f25f827b Mon Sep 17 00:00:00 2001 From: Lisa Ong Date: Tue, 17 May 2022 15:16:57 +0800 -Subject: [PATCH] Merged PR 2213: [mlir] 
Plumb OpenMP dialect attributes +Subject: [PATCH 1/6] Merged PR 2213: [mlir] Plumb OpenMP dialect attributes through affine and scf lowering * Updated AffineToSCF and SCFToOpenMP to support OMP dialect attributes for num_threads, schedule_val, proc_bind, and collapse @@ -43,7 +43,7 @@ index 05d7637d52d7..8295b01f8fcd 100644 + #endif // MLIR_DIALECT_OPENMP_OPENMPDIALECT_H_ diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td -index 505e9cb22a0a..d124be11cb94 100644 +index ddeb698fb2a2..6a74eeb217bd 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -124,7 +124,8 @@ def ParallelOp : OpenMP_Op<"parallel", [AttrSizedOperandSegments, @@ -54,13 +54,13 @@ index 505e9cb22a0a..d124be11cb94 100644 + OpBuilder<(ins CArg<"ArrayRef", "{}">:$attributes)>, + OpBuilder<(ins "Value":$num_threads, "ClauseProcBindKindAttr":$proc_bind)> ]; - let parser = [{ return parseParallelOp(parser, result); }]; - let printer = [{ return printParallelOp(p, *this); }]; + let hasCustomAssemblyFormat = 1; + let hasVerifier = 1; diff --git a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp -index 8ff1134f4b7b..888a6ed20fdc 100644 +index 7c91af4c49f0..0992fc0c1f3a 100644 --- a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp +++ b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp -@@ -385,6 +385,11 @@ public: +@@ -177,6 +177,11 @@ public: SmallVector upperBoundTuple; SmallVector lowerBoundTuple; SmallVector identityVals; @@ -72,7 +72,7 @@ index 8ff1134f4b7b..888a6ed20fdc 100644 // Emit IR computing the lower and upper bound by expanding the map // expression. 
lowerBoundTuple.reserve(op.getNumDims()); -@@ -418,6 +423,7 @@ public: +@@ -210,6 +215,7 @@ public: rewriter.eraseBlock(parOp.getBody()); rewriter.inlineRegionBefore(op.region(), parOp.getRegion(), parOp.getRegion().end()); @@ -80,7 +80,7 @@ index 8ff1134f4b7b..888a6ed20fdc 100644 rewriter.replaceOp(op, parOp.getResults()); return success(); } -@@ -467,6 +473,7 @@ public: +@@ -259,6 +265,7 @@ public: reduceOp.getReductionOperator().front().getArgument(1)); rewriter.create(loc, reductionResult); } @@ -89,10 +89,10 @@ index 8ff1134f4b7b..888a6ed20fdc 100644 return success(); } diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp -index e472af257776..526909ca4224 100644 +index a9e7759aa75e..30e0e3e9ad16 100644 --- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp +++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp -@@ -363,8 +363,12 @@ struct ParallelOpLowering : public OpRewritePattern { +@@ -364,8 +364,12 @@ struct ParallelOpLowering : public OpRewritePattern { loc, rewriter.getIntegerType(64), rewriter.getI64IntegerAttr(1)); SmallVector reductionVariables; reductionVariables.reserve(parallelOp.getNumReductions()); @@ -107,7 +107,7 @@ index e472af257776..526909ca4224 100644 for (Value init : parallelOp.getInitVals()) { assert((LLVM::isCompatibleType(init.getType()) || init.getType().isa()) && -@@ -389,7 +393,19 @@ struct ParallelOpLowering : public OpRewritePattern { +@@ -390,7 +394,19 @@ struct ParallelOpLowering : public OpRewritePattern { } // Create the parallel wrapper. @@ -128,7 +128,7 @@ index e472af257776..526909ca4224 100644 { OpBuilder::InsertionGuard guard(rewriter); rewriter.createBlock(&ompParallel.region()); -@@ -405,9 +421,20 @@ struct ParallelOpLowering : public OpRewritePattern { +@@ -406,9 +422,20 @@ struct ParallelOpLowering : public OpRewritePattern { } // Replace the loop. 
@@ -150,7 +150,7 @@ index e472af257776..526909ca4224 100644 rewriter.create(loc); rewriter.inlineRegionBefore(parallelOp.getRegion(), loop.region(), -@@ -420,15 +447,21 @@ struct ParallelOpLowering : public OpRewritePattern { +@@ -421,15 +448,21 @@ struct ParallelOpLowering : public OpRewritePattern { } // Load loop results. @@ -179,7 +179,7 @@ index e472af257776..526909ca4224 100644 return success(); } }; -@@ -437,7 +470,7 @@ struct ParallelOpLowering : public OpRewritePattern { +@@ -438,7 +471,7 @@ struct ParallelOpLowering : public OpRewritePattern { static LogicalResult applyPatterns(ModuleOp module) { ConversionTarget target(*module.getContext()); target.addIllegalOp(); @@ -189,7 +189,7 @@ index e472af257776..526909ca4224 100644 RewritePatternSet patterns(module.getContext()); patterns.add(module.getContext()); diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp -index 46a2bf3019e0..d9104bf41a93 100644 +index 4ff38e2b455a..a4b6fd78e7f9 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -73,6 +73,16 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state, @@ -210,5 +210,5 @@ index 46a2bf3019e0..d9104bf41a93 100644 // Parser and printer for Operand and type list //===----------------------------------------------------------------------===// -- -2.32.1 (Apple Git-133) +2.37.1 (Apple Git-137.1) diff --git a/external/llvm/0002-Merged-PR-2237-Improved-codegen-of-vpmaddwd-instruct.patch b/external/llvm/0002-Merged-PR-2237-Improved-codegen-of-vpmaddwd-instruct.patch index 81d187e7..70ca7374 100644 --- a/external/llvm/0002-Merged-PR-2237-Improved-codegen-of-vpmaddwd-instruct.patch +++ b/external/llvm/0002-Merged-PR-2237-Improved-codegen-of-vpmaddwd-instruct.patch @@ -1,7 +1,7 @@ -From 6172c5a07d0e4ca745aeff45298d2995aa685c42 Mon Sep 17 00:00:00 2001 +From b2cecf54139212bbea1f29337c51e47fca81be5c Mon Sep 17 00:00:00 2001 From: Lisa Ong Date: 
Mon, 14 Feb 2022 16:18:47 +0800 -Subject: [PATCH] From 97c4232342b7e8b802c4557368bc016264679930 Mon Sep 17 +Subject: [PATCH 2/6] From 97c4232342b7e8b802c4557368bc016264679930 Mon Sep 17 00:00:00 2001 From: Chuck Jacobs Date: Wed, 6 Oct 2021 16:40:38 +0000 Subject: Merged PR 2237: Improved codegen of vpmaddwd instruction @@ -12,7 +12,7 @@ This PR adds another codegen path that lowers certain "multiply-like" operations 1 file changed, 574 insertions(+) diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp -index 6f1fe8195595..f7bff9795548 100644 +index 4c622568f8d0..7ee8b9be2154 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -39,6 +39,7 @@ @@ -23,7 +23,7 @@ index 6f1fe8195595..f7bff9795548 100644 #include "llvm/CodeGen/TargetLowering.h" #include "llvm/CodeGen/WinEHFuncInfo.h" #include "llvm/IR/CallingConv.h" -@@ -49340,6 +49341,350 @@ static SDValue combinePMULH(SDValue Src, EVT VT, const SDLoc &DL, +@@ -49524,6 +49525,350 @@ static SDValue combinePMULH(SDValue Src, EVT VT, const SDLoc &DL, return DAG.getNode(Opc, DL, VT, LHS, RHS); } @@ -374,7 +374,7 @@ index 6f1fe8195595..f7bff9795548 100644 // Attempt to match PMADDUBSW, which multiplies corresponding unsigned bytes // from one vector with signed bytes from another vector, adds together // adjacent pairs of 16-bit products, and saturates the result before -@@ -52432,6 +52777,233 @@ static SDValue matchPMADDWD_2(SelectionDAG &DAG, SDValue N0, SDValue N1, +@@ -52546,6 +52891,233 @@ static SDValue matchPMADDWD_2(SelectionDAG &DAG, SDValue N0, SDValue N1, PMADDBuilder); } @@ -608,7 +608,7 @@ index 6f1fe8195595..f7bff9795548 100644 // ADD(VPMADDWD(X,Y),VPMADDWD(Z,W)) -> VPMADDWD(SHUFFLE(X,Z), SHUFFLE(Y,W)) // If upper element in each pair of both VPMADDWD are zero then we can merge // the operand elements and use the implicit add of VPMADDWD. 
-@@ -52554,6 +53126,8 @@ static SDValue combineAdd(SDNode *N, SelectionDAG &DAG, +@@ -52668,6 +53240,8 @@ static SDValue combineAdd(SDNode *N, SelectionDAG &DAG, return MAdd; if (SDValue MAdd = matchPMADDWD_2(DAG, Op0, Op1, DL, VT, Subtarget)) return MAdd; @@ -618,5 +618,5 @@ index 6f1fe8195595..f7bff9795548 100644 return MAdd; -- -2.30.0.windows.1 +2.37.1 (Apple Git-137.1) diff --git a/external/llvm/0003-Fix-bad-merge.patch b/external/llvm/0003-Fix-bad-merge.patch index e21a29db..87e80c4a 100644 --- a/external/llvm/0003-Fix-bad-merge.patch +++ b/external/llvm/0003-Fix-bad-merge.patch @@ -1,7 +1,7 @@ -From 2dbd177abcf79fe183c989824863aafe5b38e1cd Mon Sep 17 00:00:00 2001 +From 5cf0dde409f5f937e22f9a6c81db8368494a63cb Mon Sep 17 00:00:00 2001 From: Lisa Ong Date: Mon, 14 Feb 2022 16:32:01 +0800 -Subject: [PATCH] From 7f9a254015c977405957fb5b2b6e2a1895f0ca69 Mon Sep 17 +Subject: [PATCH 3/6] From 7f9a254015c977405957fb5b2b6e2a1895f0ca69 Mon Sep 17 00:00:00 2001 From: Kern Handa Date: Wed, 6 Oct 2021 11:09:31 -0700 Subject: Fix bad merge @@ -32,5 +32,5 @@ index 79a062fd9735..377080b25c36 100644 } -- -2.30.0.windows.1 +2.37.1 (Apple Git-137.1) diff --git a/external/llvm/0004-Lower-memref.copy-to-memcpy-when-layouts-canonicaliz.patch b/external/llvm/0004-Lower-memref.copy-to-memcpy-when-layouts-canonicaliz.patch index e6fe7654..b82ee1e9 100644 --- a/external/llvm/0004-Lower-memref.copy-to-memcpy-when-layouts-canonicaliz.patch +++ b/external/llvm/0004-Lower-memref.copy-to-memcpy-when-layouts-canonicaliz.patch @@ -1,36 +1,34 @@ -From bc61595ab6d6f5eb6fbf3564cf00f8839250d2d7 Mon Sep 17 00:00:00 2001 +From c865623e33d24660ecc474529192914d0f87b48f Mon Sep 17 00:00:00 2001 From: Lisa Ong Date: Mon, 23 May 2022 12:39:55 +0800 -Subject: [PATCH] Lower memref.copy to memcpy when layouts canonicalize to +Subject: [PATCH 4/6] Lower memref.copy to memcpy when layouts canonicalize to identity layouts. 
A memref.cast won't work because it gets folded into memref.copy during op canonicalization. --- - mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp | 9 +++++++-- - 1 file changed, 7 insertions(+), 2 deletions(-) + mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp | 7 +++++++ + 1 file changed, 7 insertions(+) diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp -index 288c252b81bb..bec7513f7986 100644 +index 56413c415590..4402485545ad 100644 --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp -@@ -933,10 +933,15 @@ struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern { - auto srcType = op.source().getType().cast(); - auto targetType = op.target().getType().cast(); - -+ // Memref casts get folded away during CopyOp::fold, so we have to replace -+ // the operand with its canonicalized identity form, if they are equivalent -+ auto cannedSrcType = canonicalizeStridedLayout(srcType.cast()); -+ auto cannedTargetType = canonicalizeStridedLayout(targetType.cast()); +@@ -944,8 +944,15 @@ struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern { + // We can use memcpy for memrefs if they have an identity layout or are + // contiguous with an arbitrary offset. Ignore empty memrefs, which is a + // special case handled by memrefCopy. 
+ - if (srcType.hasRank() && -- srcType.cast().getLayout().isIdentity() && -+ (srcType.cast().getLayout().isIdentity() || cannedSrcType.getLayout().isIdentity()) && - targetType.hasRank() && -- targetType.cast().getLayout().isIdentity()) -+ (targetType.cast().getLayout().isIdentity() || cannedTargetType.getLayout().isIdentity())) - return lowerToMemCopyIntrinsic(op, adaptor, rewriter); - - return lowerToMemCopyFunctionCall(op, adaptor, rewriter); ++ // Memref casts get folded away during CopyOp::fold, so we have to replace ++ // the operand with its canonicalized identity form, if they are ++ // equivalent ++ auto cannedType = canonicalizeStridedLayout(memrefType); ++ + return memrefType && + (memrefType.getLayout().isIdentity() || ++ cannedType.getLayout().isIdentity() || + (memrefType.hasStaticShape() && memrefType.getNumElements() > 0 && + isStaticShapeAndContiguousRowMajor(memrefType))); + }; -- -2.32.1 (Apple Git-133) +2.37.1 (Apple Git-137.1) diff --git a/external/llvm/0005-Fix-issue-where-passed-in-op-printing-flags-were-ign.patch b/external/llvm/0005-Fix-issue-where-passed-in-op-printing-flags-were-ign.patch index 2067d5d3..04d43dce 100644 --- a/external/llvm/0005-Fix-issue-where-passed-in-op-printing-flags-were-ign.patch +++ b/external/llvm/0005-Fix-issue-where-passed-in-op-printing-flags-were-ign.patch @@ -1,14 +1,16 @@ -From 39f0a4c97f5c89d7fa815118a3230091172bc795 Mon Sep 17 00:00:00 2001 -From: Charles Jacobs -Date: Mon, 15 Aug 2022 16:00:43 -0700 -Subject: [PATCH] Fix issue where passed-in op-printing flags were ignored +From bb702476975ebfc65d445f8da7d6064a81c09666 Mon Sep 17 00:00:00 2001 +From: Chuck Jacobs +Date: Wed, 24 Aug 2022 04:13:49 +0000 +Subject: [PATCH 5/6] Merged PR 2823: Fix op-printing bug in + LocationSnapshotPass +This PR fixes an issue where the printing flags passed in to the LocationSnapshotPass were being ignored. 
--- mlir/lib/Transforms/LocationSnapshot.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Transforms/LocationSnapshot.cpp b/mlir/lib/Transforms/LocationSnapshot.cpp -index f23a3eee1511..c3c323284b18 100644 +index a042d07335bb..808f2ad2a67c 100644 --- a/mlir/lib/Transforms/LocationSnapshot.cpp +++ b/mlir/lib/Transforms/LocationSnapshot.cpp @@ -133,7 +133,7 @@ struct LocationSnapshotPass @@ -21,5 +23,5 @@ index f23a3eee1511..c3c323284b18 100644 } -- -2.32.1 (Apple Git-133) +2.37.1 (Apple Git-137.1) diff --git a/external/llvm/0006-Merged-PR-2919-More-flexible-code-generation-for-vpm.patch b/external/llvm/0006-Merged-PR-2919-More-flexible-code-generation-for-vpm.patch new file mode 100644 index 00000000..207470e3 --- /dev/null +++ b/external/llvm/0006-Merged-PR-2919-More-flexible-code-generation-for-vpm.patch @@ -0,0 +1,317 @@ +From f8fa2aac2352539b963c03336a69403e5b40fc29 Mon Sep 17 00:00:00 2001 +From: Chuck Jacobs +Date: Tue, 25 Oct 2022 02:38:08 +0000 +Subject: [PATCH 6/6] Merged PR 2919: More flexible code generation for + vpmaddwd instruction + +This change increases the number of patterns that will generate a `vpmaddwd` instruction: + +- It now handles the case where `opt` linearizes the sum-of-product instructions +- It handles some cases where one of the operands can be interpreted as 2 adjacent values broadcast to fill a vector register (this is the case with the "A" matrix in matrix-matrix multiply) +--- + llvm/lib/Target/X86/X86ISelLowering.cpp | 206 +++++++++++++++++++++--- + 1 file changed, 187 insertions(+), 19 deletions(-) + +diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp +index 7ee8b9be2154..d4ba64d64f87 100644 +--- a/llvm/lib/Target/X86/X86ISelLowering.cpp ++++ b/llvm/lib/Target/X86/X86ISelLowering.cpp +@@ -31,6 +31,7 @@ + #include "llvm/Analysis/ObjCARCUtil.h" + #include "llvm/Analysis/ProfileSummaryInfo.h" + #include "llvm/Analysis/VectorUtils.h" ++#include 
"llvm/CodeGen/ISDOpcodes.h" + #include "llvm/CodeGen/IntrinsicLowering.h" + #include "llvm/CodeGen/MachineFrameInfo.h" + #include "llvm/CodeGen/MachineFunction.h" +@@ -49527,7 +49528,7 @@ static SDValue combinePMULH(SDValue Src, EVT VT, const SDLoc &DL, + + // Helper function for PMADDUBSW / PMADDWD + // TODO: need to do the convert-load-to-broadcast-splat fixup in the function that takes the a0/a1 pair +-static SDValue getExtMulOperand(SDValue node, int index, EVT VT, const SDLoc &DL, ++static SDValue getPMADDOperand(SDValue node, int index, EVT VT, const SDLoc &DL, + SelectionDAG &DAG) { + unsigned NumElems = VT.getVectorNumElements(); + auto op = node.getOperand(index); +@@ -49634,14 +49635,38 @@ static SDValue getExtMulOperand(SDValue node, int index, EVT VT, const SDLoc &DL + // TODO: reduce to a single global load + } + ++ auto bv = DAG.getBuildVector(argVT, DL, elements); ++ auto replacement = DAG.getNode(opcode, DL, VT, bv); ++ return replacement; ++ } // end if (arg.getOpcode() == ISD::INSERT_VECTOR_ELT) ++ else if (arg.getOpcode() == ISD::VECTOR_SHUFFLE) { ++ // "broadcast-A" case: arg is vector_shuffle<...> X ++ SmallVector elements; ++ auto shuffleNode = dyn_cast(arg); ++ assert(shuffleNode); ++ auto mask = shuffleNode->getMask(); ++ auto val = shuffleNode->getOperand(0); ++ if (val.getOpcode() == ISD::EXTRACT_SUBVECTOR) { ++ val = val->getOperand(0); ++ } ++ // TODO: see if val is an extract_subvector op, and if so return its arg ++ ++ MVT elemMT = argVT.getSimpleVT().getVectorElementType(); ++ for (auto idx: mask) ++ { ++ SDValue elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, elemMT, val, ++ DAG.getIntPtrConstant(idx, DL)); ++ elements.push_back(elt); ++ } ++ + auto bv = DAG.getBuildVector(argVT, DL, elements); + auto replacement = DAG.getNode(opcode, DL, VT, bv); + return replacement; + } + return op; +- } ++ } // end if (opcode == ISD::ZERO_EXTEND || ...) + +- // look for (shuf<0,0,0,0...> (insert_elt (extend ...) 
0) undef) pattern and translate it ++ // TODO: look for (shuf<0,0,0,0...> (insert_elt (extend ...) 0) undef) pattern and translate it + // into (extend (shuf<0,0,0,0...> (insert_elt ... 0) undef)) + // or better yet: (extend (splat ...)) + if (opcode != ISD::VECTOR_SHUFFLE) +@@ -49673,7 +49698,7 @@ static SDValue getExtMulOperand(SDValue node, int index, EVT VT, const SDLoc &DL + SDValue buildVec = DAG.getSplatBuildVector(splatVT, DL, val); + SDValue newExt = isSigned ? DAG.getSExtOrTrunc(buildVec, DL, resultVT) : DAG.getZExtOrTrunc(buildVec, DL, resultVT); + return newExt; +-} // end of getExtMulOperand ++} // end of getPMADDOperand + + static void replaceWithInterleavedShuffles(ArrayRef> interleavedOps, const std::vector>& masks, SDValue val, SelectionDAG &DAG, const SDLoc &DL) { + const unsigned dotProdSize = interleavedOps.size(); +@@ -52769,10 +52794,19 @@ static SDValue matchPMADDWD_2(SelectionDAG &DAG, SDValue N0, SDValue N1, + !isPowerOf2_32(VT.getVectorNumElements())) + return SDValue(); + +- SDValue N00 = N0.getOperand(0); +- SDValue N01 = N0.getOperand(1); +- SDValue N10 = N1.getOperand(0); +- SDValue N11 = N1.getOperand(1); ++ // Nxx naming scheme: N, where phase == 0 means "even", and operand is the operand index of the MUL operation ++ // So, the "even" left-hand operand is N00, and the "odd" left-hand operand is N10 ++ ++ // "evens" ++ SDValue N00 = getPMADDOperand(N0, 0, VT, DL, DAG); ++ SDValue N01 = getPMADDOperand(N0, 1, VT, DL, DAG); ++ ++ // "odds" ++ SDValue N10 = getPMADDOperand(N1, 0, VT, DL, DAG); ++ SDValue N11 = getPMADDOperand(N1, 1, VT, DL, DAG); ++ ++ if (!N00 || !N01 || !N10 || !N11) ++ return SDValue(); + + // All inputs need to be sign extends. + // TODO: Support ZERO_EXTEND from known positive? 
+@@ -52801,6 +52835,8 @@ static SDValue matchPMADDWD_2(SelectionDAG &DAG, SDValue N0, SDValue N1, + N11.getOpcode() != ISD::BUILD_VECTOR) + return SDValue(); + ++ // TODO: verify N00, N01, N10, and N11 have the same # of operands ++ + // For each element, we need to ensure we have an odd element from one vector + // multiplied by the odd element of another vector and the even element from + // one of the same vectors being multiplied by the even element from the +@@ -52808,6 +52844,8 @@ static SDValue matchPMADDWD_2(SelectionDAG &DAG, SDValue N0, SDValue N1, + // is being performed: + // A[2 * i] * B[2 * i] + A[2 * i + 1] * B[2 * i + 1] + SDValue In0, In1; ++ bool prevIdxN00WasZero = true; ++ bool prevIdxN01WasZero = true; + for (unsigned i = 0; i != N00.getNumOperands(); ++i) { + SDValue N00Elt = N00.getOperand(i); + SDValue N01Elt = N01.getOperand(i); +@@ -52834,9 +52872,21 @@ static SDValue matchPMADDWD_2(SelectionDAG &DAG, SDValue N0, SDValue N1, + std::swap(IdxN00, IdxN10); + std::swap(IdxN01, IdxN11); + } +- // N0 indices be the even element. N1 indices must be the next odd element. +- if (IdxN00 != 2 * i || IdxN10 != 2 * i + 1 || +- IdxN01 != 2 * i || IdxN11 != 2 * i + 1) ++ ++ // N0 indices must be sequential even elements. ++ // TODO: also allow even indices to be element 0 broadcasted to fill the array ++ // ... 
which means Idx can be 0 if prev Idx was zero ++ if (IdxN00 != 2 * i && !(IdxN00 == 0 && prevIdxN00WasZero) || ++ IdxN01 != 2 * i && !(IdxN01 == 0 && prevIdxN01WasZero)) ++ { ++ return SDValue(); ++ } ++ ++ prevIdxN00WasZero = IdxN00 == 0; ++ prevIdxN01WasZero = IdxN01 == 0; ++ ++ // N1 indices must be the next (odd) element after the corresponding N0 indcex ++ if (IdxN10 != IdxN00 + 1 || IdxN11 != IdxN01 + 1) + return SDValue(); + SDValue N00In = N00Elt.getOperand(0); + SDValue N01In = N01Elt.getOperand(0); +@@ -52864,6 +52914,41 @@ static SDValue matchPMADDWD_2(SelectionDAG &DAG, SDValue N0, SDValue N1, + return SDValue(); + } + ++ // #### ++ if (prevIdxN00WasZero) // a splat, so just use the original build_vector ++ { ++ unsigned NumElems = VT.getVectorNumElements(); ++ SmallVector elements(2*NumElems); ++ EVT argVT = In0.getValueType(); ++ MVT elemMT = argVT.getSimpleVT().getVectorElementType(); ++ for(unsigned idx = 0; idx < 2*NumElems; ++idx) ++ { ++ SDValue elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, elemMT, In0, ++ DAG.getIntPtrConstant(idx%2, DL)); ++ elements[idx] = elt; ++ } ++ ++ // In0 = ++ In0 = DAG.getBuildVector(argVT, DL, elements); ++ } ++ ++ if (prevIdxN01WasZero) // a splat, so just use the original build_vector ++ { ++ unsigned NumElems = VT.getVectorNumElements(); ++ SmallVector elements(2*NumElems); ++ EVT argVT = In1.getValueType(); ++ MVT elemMT = argVT.getSimpleVT().getVectorElementType(); ++ for(unsigned idx = 0; idx < 2*NumElems; ++idx) ++ { ++ SDValue elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, elemMT, In1, ++ DAG.getIntPtrConstant(idx%2, DL)); ++ elements[idx] = elt; ++ } ++ ++ // In1 = ++ In1 = DAG.getBuildVector(argVT, DL, elements); ++ } ++ + auto PMADDBuilder = [](SelectionDAG &DAG, const SDLoc &DL, + ArrayRef Ops) { + EVT OpVT = Ops[0].getValueType(); +@@ -52967,14 +53052,14 @@ static SDValue matchPMADDWD_3(SelectionDAG &DAG, SDValue N0, SDValue N1, + return SDValue(); + + SDValue p000, p001, p010, p011, p100, p101, p110, 
p111; +- if (!(p000 = getExtMulOperand(ssatArg0.getOperand(0), 0, VT, DL, DAG))) return SDValue(); +- if (!(p001 = getExtMulOperand(ssatArg0.getOperand(0), 1, VT, DL, DAG))) return SDValue(); +- if (!(p010 = getExtMulOperand(ssatArg0.getOperand(1), 0, VT, DL, DAG))) return SDValue(); +- if (!(p011 = getExtMulOperand(ssatArg0.getOperand(1), 1, VT, DL, DAG))) return SDValue(); +- if (!(p100 = getExtMulOperand(ssatArg1.getOperand(0), 0, VT, DL, DAG))) return SDValue(); +- if (!(p101 = getExtMulOperand(ssatArg1.getOperand(0), 1, VT, DL, DAG))) return SDValue(); +- if (!(p110 = getExtMulOperand(ssatArg1.getOperand(1), 0, VT, DL, DAG))) return SDValue(); +- if (!(p111 = getExtMulOperand(ssatArg1.getOperand(1), 1, VT, DL, DAG))) return SDValue(); ++ if (!(p000 = getPMADDOperand(ssatArg0.getOperand(0), 0, VT, DL, DAG))) return SDValue(); ++ if (!(p001 = getPMADDOperand(ssatArg0.getOperand(0), 1, VT, DL, DAG))) return SDValue(); ++ if (!(p010 = getPMADDOperand(ssatArg0.getOperand(1), 0, VT, DL, DAG))) return SDValue(); ++ if (!(p011 = getPMADDOperand(ssatArg0.getOperand(1), 1, VT, DL, DAG))) return SDValue(); ++ if (!(p100 = getPMADDOperand(ssatArg1.getOperand(0), 0, VT, DL, DAG))) return SDValue(); ++ if (!(p101 = getPMADDOperand(ssatArg1.getOperand(0), 1, VT, DL, DAG))) return SDValue(); ++ if (!(p110 = getPMADDOperand(ssatArg1.getOperand(1), 0, VT, DL, DAG))) return SDValue(); ++ if (!(p111 = getPMADDOperand(ssatArg1.getOperand(1), 1, VT, DL, DAG))) return SDValue(); + + SDValue zextArgs[4] = {p000, p010, p100, p110}; + SDValue sextArgs[4] = {p001, p011, p101, p111}; +@@ -53162,6 +53247,87 @@ static SDValue combineAddOfPMADDWD(SelectionDAG &DAG, SDValue N0, SDValue N1, + return DAG.getNode(X86ISD::VPMADDWD, DL, VT, LHS, RHS); + } + ++// Attempt to turn this pattern: ++// ++// (add ++// (add X, ++// (mul (sext (build_vector)), (sext (build_vector))), ++// (mul (sext (build_vector)), (sext (build_vector)))) ++// ++// into: ++// ++// (add X, ++// (add (mul (sext 
(build_vector)), (sext (build_vector))), ++// (mul (sext (build_vector)), (sext (build_vector))))) ++// ++// or, (X + a) + b --> X + (a + b), where a and b are (mul (sext (build_vector))) ++// ++// So that the inner add can be turned into a PMADDWD ++static SDValue rebalancePotentialPMADDWD(SelectionDAG &DAG, SDValue N0, SDValue N1, ++ const SDLoc &DL, EVT VT, ++ const X86Subtarget &Subtarget) { ++ if (!Subtarget.hasSSE2()) ++ return SDValue(); ++ ++ if (!VT.isVector() || VT.getVectorElementType() != MVT::i32 || ++ VT.getVectorNumElements() < 4 || ++ !isPowerOf2_32(VT.getVectorNumElements())) ++ return SDValue(); ++ ++ // normalize N0 and N1 so N0 == (X + a) and N1 == b ++ if (N0.getOpcode() != ISD::ADD) ++ std::swap(N0, N1); ++ if (N0.getOpcode() != ISD::ADD || N1.getOpcode() != ISD::MUL) ++ return SDValue(); ++ ++ // function to verify v is a valid argument for an add that gets converted to PMADDWD ++ auto isValidPMADDWDArg = [](SDValue v) { ++ // get "a" and "b" operands ++ SDValue v0 = v.getOperand(0); ++ SDValue v1 = v.getOperand(1); ++ ++ // All inputs need to be sign extends. ++ // TODO: Support ZERO_EXTEND from known positive? ++ if (v0.getOpcode() != ISD::SIGN_EXTEND || ++ v1.getOpcode() != ISD::SIGN_EXTEND ) ++ return false; ++ ++ // Peek through the extends. ++ v0 = v0.getOperand(0); ++ v1 = v1.getOperand(0); ++ ++ // Must be extending from vXi16. ++ EVT InVT = v0.getValueType(); ++ if (InVT.getVectorElementType() != MVT::i16 || v1.getValueType() != InVT) ++ return false; ++ ++ // All inputs should be build_vectors. 
++ // TODO: also allow broadcast-A pattern ++ if ((v0.getOpcode() != ISD::BUILD_VECTOR && v0.getOpcode() != ISD::VECTOR_SHUFFLE) || ++ (v1.getOpcode() != ISD::BUILD_VECTOR && v1.getOpcode() != ISD::VECTOR_SHUFFLE)) ++ return false; ++ ++ return true; ++ }; ++ ++ // normalize N00 and N01 so N00 == X and N01 == b ++ SDValue N00 = N0.getOperand(0); ++ SDValue N01 = N0.getOperand(1); ++ if (isValidPMADDWDArg(N00)) ++ std::swap(N00, N01); ++ // now N00 = X, N01 = a, and N1 = b ++ ++ // verify everything is correct (that is, a and b ) ++ if (!isValidPMADDWDArg(N01) || !isValidPMADDWDArg(N1)) ++ return SDValue(); ++ ++ // TODO: just turn this directly into X + vpmaddwd ++ ++ // return new expression X + (a + b) == N00 + (N01 + N1) ++ auto pmaddTerm = DAG.getNode(ISD::ADD, DL, VT, N01, N1); ++ return DAG.getNode(ISD::ADD, DL, VT, N00, pmaddTerm); ++} ++ + /// CMOV of constants requires materializing constant operands in registers. + /// Try to fold those constants into an 'add' instruction to reduce instruction + /// count. We do this with CMOV rather the generic 'select' because there are +@@ -53244,6 +53410,8 @@ static SDValue combineAdd(SDNode *N, SelectionDAG &DAG, + return MAdd; + if (SDValue MAdd = combineAddOfPMADDWD(DAG, Op0, Op1, DL, VT)) + return MAdd; ++ if (SDValue MAdd = rebalancePotentialPMADDWD(DAG, Op0, Op1, DL, VT, Subtarget)) ++ return MAdd; + + // Try to synthesize horizontal adds from adds of shuffles. 
+ if (SDValue V = combineToHorizontalAddSub(N, DAG, Subtarget)) +-- +2.37.1 (Apple Git-137.1) + diff --git a/external/llvm/portfile.cmake b/external/llvm/portfile.cmake index f5faba93..4da7e637 100644 --- a/external/llvm/portfile.cmake +++ b/external/llvm/portfile.cmake @@ -1,5 +1,6 @@ # Builds LLVM for features needed by Accera -set(LLVM_VERSION llvmorg-14.0.6) +set(LLVM_VERSION 24a37a396a9bd6b73b05b4eafce8b87e7a748cf9) +set(LLVM_FRIENDLY_VERSION 15.0.0-rc1) set(VCPKG_BUILD_TYPE release) if((DEFINED ENV{LLVM_BUILD_TYPE}) AND ("$ENV{LLVM_BUILD_TYPE}" STREQUAL "debug")) @@ -20,7 +21,7 @@ vcpkg_from_github( OUT_SOURCE_PATH SOURCE_PATH REPO llvm/llvm-project REF ${LLVM_VERSION} - SHA512 d64f97754c24f32deb5f284ebbd486b3a467978b7463d622f50d5237fff91108616137b4394f1d1ce836efa59bf7bec675b6dee257a79b241c15be52d4697460 + SHA512 5fdee8487afac16033a6d2cea720dedb5f05e00f20d761307805f0a6e1fad22d6c3ce45d89112cbe70aec1a4dfe8ae72a90d45e14bc67539fbe3dc948a316d92 HEAD_REF main PATCHES 0001-Merged-PR-2213-mlir-Plumb-OpenMP-dialect-attributes-.patch @@ -28,9 +29,8 @@ vcpkg_from_github( 0003-Fix-bad-merge.patch 0004-Lower-memref.copy-to-memcpy-when-layouts-canonicaliz.patch 0005-Fix-issue-where-passed-in-op-printing-flags-were-ign.patch - 0006-Merged-PR-2822-Fix-lowering-of-MemrefCastOp-to-the-L.patch - 0007-More-flexible-code-generation-for-vpmaddwd-instructi.patch - 0008-fix-vcpkg-install-paths.patch # cf. https://github.com/microsoft/vcpkg/blob/master/ports/llvm + 0006-Merged-PR-2919-More-flexible-code-generation-for-vpm.patch + 0007-fix-vcpkg-install-paths.patch # cf. https://github.com/microsoft/vcpkg/blob/master/ports/llvm ) vcpkg_find_acquire_program(PYTHON3) @@ -64,7 +64,7 @@ vcpkg_configure_cmake( -DLLVM_INSTALL_UTILS=ON # FileCheck "-DLLVM_ENABLE_PROJECTS=mlir;lld" "-DLLVM_TARGETS_TO_BUILD=host;X86;ARM;NVPTX;AMDGPU" - -DPACKAGE_VERSION=${LLVM_VERSION} + -DPACKAGE_VERSION=${LLVM_FRIENDLY_VERSION} # Force TableGen to be built with optimization. 
This will significantly improve build time. # cf. https://github.com/microsoft/vcpkg/blob/master/ports/llvm -DLLVM_OPTIMIZED_TABLEGEN=ON diff --git a/external/llvm/vcpkg.json b/external/llvm/vcpkg.json index 2b9172f0..36e65b52 100644 --- a/external/llvm/vcpkg.json +++ b/external/llvm/vcpkg.json @@ -1,6 +1,6 @@ { "name": "accera-llvm", - "version-string": "14.0.6-2", + "version-string": "15.0.0", "description": "LLVM Compiler Infrastructure for Accera.", "homepage": "https://llvm.org", "supports": "!uwp" diff --git a/requirements.txt b/requirements.txt index b68bbbdb..01cf7244 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,7 +9,7 @@ conan<2.0.0 lit packaging pytest -hatlib==0.0.38 # keep setup.cfg in sync; prefer == (avoids picking up breaking changes) +hatlib==0.0.39 # keep setup.cfg in sync; prefer == (avoids picking up breaking changes) varname py-cpuinfo termcolor diff --git a/setup.cfg b/setup.cfg index 1197abc7..92c94d3a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -35,13 +35,13 @@ install_requires = termcolor py-cpuinfo varname - hatlib==0.0.38 + hatlib==0.0.39 numpy pyyaml tomlkit>=0.11.1, <0.11.5 accera-compilers accera-gpu - accera-llvm==14.0.602 + accera-llvm==15.0.101 # keep in sync with accera/python/llvm/setup.cfg package_dir = accera = accera/python/accera accera.hat = accera/hat/scripts