From 95221f9eb59699c2bd97c09701ac8f4b7d141e2f Mon Sep 17 00:00:00 2001 From: Aliia Khasanova Date: Tue, 4 Feb 2025 11:25:14 +0100 Subject: [PATCH] OpenXLA-specific changes --- BUILD | 891 +++++++++++++++++ include/triton/Conversion/MLIRTypes.h | 13 +- include/triton/Dialect/Triton/IR/TritonOps.td | 7 +- lib/Analysis/Allocation.cpp | 6 + lib/Analysis/AxisInfo.cpp | 2 +- lib/Analysis/Utility.cpp | 9 +- .../TritonToTritonGPU/TritonGPUConversion.cpp | 12 + lib/Dialect/Triton/IR/Ops.cpp | 2 +- lib/Dialect/TritonGPU/IR/Ops.cpp | 13 +- .../TritonGPU/Transforms/AccelerateMatmul.cpp | 50 +- .../Transforms/OptimizeDotOperands.cpp | 1 + .../Pipeliner/MatmulLoopPipeline.cpp | 9 +- .../Pipeliner/PipeliningUtility.cpp | 2 +- lib/Dialect/TritonGPU/Transforms/Prefetch.cpp | 26 +- lib/Dialect/TritonGPU/Transforms/Utility.cpp | 36 +- lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp | 5 +- .../Transforms/FenceInsertion.cpp | 3 +- python/BUILD | 79 ++ python/test/regression/BUILD | 26 + python/test/unit/BUILD | 183 ++++ python/test/unit/language/test_core.py | 19 + python/triton/_C/include | 2 +- python/triton/backends/__init__.py | 7 +- python/triton/language/core.py | 8 +- python/triton/runtime/build.py | 37 - test/BUILD | 70 ++ test/Conversion/tritongpu_to_llvm.mlir | 20 + test/TritonGPU/accelerate-matmul.mlir | 18 + test/TritonGPU/canonicalize.mlir | 17 + test/TritonGPU/dot-operands.mlir | 5 +- test/TritonGPU/prefetch.mlir | 20 + third_party/amd/BUILD | 266 ++++++ .../TritonAMDGPUToLLVM/BufferOpsEmitter.cpp | 6 +- .../ElementwiseOpToLLVM.cpp | 16 +- .../AccelerateAMDMatmul.cpp | 2 +- .../lib/TritonAMDGPUTransforms/MfmaGroup.cpp | 15 +- third_party/f2reduce/BUILD | 31 + third_party/nvidia/BUILD | 316 ++++++ third_party/nvidia/backend/BUILD | 30 + third_party/nvidia/backend/cuda_utils.cc | 897 ++++++++++++++++++ third_party/nvidia/backend/driver.c | 421 -------- third_party/nvidia/backend/driver.py | 494 ++-------- .../include/Dialect/NVGPU/IR/NVGPUOps.td | 9 + third_party/nvidia/language/cuda/BUILD | 13 + .../lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp | 35 +- .../SharedToDotOperandMMAv2OrV3.cpp | 1 + .../DecomposeUnsupportedConversions.cpp | 100 -- .../lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM.cpp | 1 + .../DotOpToLLVM/MMAv2.cpp | 16 +- .../DotOpToLLVM/MMAv5.cpp | 12 +- .../DotOpToLLVM/WGMMA.cpp | 7 +- .../ElementwiseOpToLLVM.cpp | 145 ++- .../lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp | 1 + .../TensorPtrOpsToLLVM.cpp | 1 + .../lib/TritonNVIDIAGPUToLLVM/Utility.cpp | 1 + .../lib/TritonNVIDIAGPUToLLVM/Utility.h | 2 - third_party/proton/BUILD | 130 +++ third_party/proton/proton/_C/include | 2 +- unittest/BUILD | 144 +++ 59 files changed, 3630 insertions(+), 1082 deletions(-) create mode 100644 BUILD create mode 100644 python/BUILD create mode 100644 python/test/regression/BUILD create mode 100644 python/test/unit/BUILD delete mode 100644 python/triton/runtime/build.py create mode 100644 test/BUILD create mode 100644 third_party/amd/BUILD create mode 100644 third_party/f2reduce/BUILD create mode 100644 third_party/nvidia/BUILD create mode 100644 third_party/nvidia/backend/BUILD create mode 100644 third_party/nvidia/backend/cuda_utils.cc delete mode 100644 third_party/nvidia/backend/driver.c create mode 100644 third_party/nvidia/language/cuda/BUILD delete mode 100644 third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DecomposeUnsupportedConversions.cpp create mode 100644 third_party/proton/BUILD create mode 100644 unittest/BUILD diff --git a/BUILD b/BUILD new file mode 100644 index 0000000000000..a84af7503682e --- /dev/null +++ b/BUILD @@ -0,0 +1,891 @@ +# This package imports OpenAI's Triton (https://github.com/openai/triton). +# +# There are two versions of Triton in google3 at the moment. The older version +# can be found at //third_party/py/triton. This is the MLIR-based version close +# to head. We expect to transition users to this version in the following +# weeks. +# +# There is no SLA associated with this package and it may get broken by LLVM +# imports at any time. + +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") +# copybara:uncomment load("//tools/build_defs/license:license.bzl", "license") + +package( + # copybara:uncomment_begin + # default_applicable_licenses = [":license"], + # default_compatible_with = ["//buildenv/target:non_prod"], + # default_visibility = [ + # # Add your project here if you need to depend on Triton's C++ sources. + # # Add a point of contact we can reach out to when needed in the comment. + # # + # # If you need to use the Python fronted, add your project to + # # google3/third_party/py/triton/BUILD instead. + # # + # # By adding your project here, you agree to the Triton SLA: go/triton-google3-sla + # "//third_party/py/jax:__subpackages__", # cjfj@ + # "//third_party/tensorflow/compiler/xla:__subpackages__", # bchetioui@ + # "//platforms/xla/experimental/gpu:__subpackages__", # csigg@ + # # Triton-internal visibility + # "//:__subpackages__", + # ], + # copybara:uncomment_end_and_comment_begin + default_visibility = ["//visibility:public"], + # copybara:comment_end + # TODO(csigg): fix and remove + features = [ + "-parse_headers", + "-use_header_modules", + ], +) + +# copybara:uncomment_begin +# license(name = "license") +# +# licenses(["notice"]) +# +# exports_files(["LICENSE"]) +# copybara:uncomment_end + +config_setting( + name = "compiler_is_msvc", + flag_values = { + # copybara:comment_begin + "@bazel_tools" + + # copybara:comment_end + "//tools/cpp:compiler": "msvc-cl", + }, +) + +# TODO(csigg): fix, enable error upstream, remove. +_no_unused_variable = select({ + ":compiler_is_msvc": [], + "//conditions:default": ["-Wno-unused-variable"], +}) + +td_library( + name = "td_files", + srcs = glob(["include/triton/**/*.td"]), + includes = ["include"], + deps = [ + "@llvm-project//mlir:ArithOpsTdFiles", + "@llvm-project//mlir:CastInterfacesTdFiles", + "@llvm-project//mlir:ControlFlowInterfacesTdFiles", + "@llvm-project//mlir:DestinationStyleOpInterfaceTdFiles", + "@llvm-project//mlir:FunctionInterfacesTdFiles", + "@llvm-project//mlir:InferTypeOpInterfaceTdFiles", + "@llvm-project//mlir:LLVMOpsTdFiles", + "@llvm-project//mlir:OpBaseTdFiles", + "@llvm-project//mlir:PassBaseTdFiles", + "@llvm-project//mlir:SideEffectInterfacesTdFiles", + "@llvm-project//mlir:ViewLikeInterfaceTdFiles", + ], +) + +gentbl_cc_library( + name = "triton_attr_inc_gen", + tbl_outs = [ + ( + ["--gen-attrdef-decls"], + "include/triton/Dialect/Triton/IR/TritonAttrDefs.h.inc", + ), + ( + ["--gen-attrdef-defs"], + "include/triton/Dialect/Triton/IR/TritonAttrDefs.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/Triton/IR/TritonAttrDefs.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_dialect_inc_gen", + tbl_outs = [ + ( + ["--gen-dialect-decls"], + "include/triton/Dialect/Triton/IR/Dialect.h.inc", + ), + ( + ["--gen-dialect-defs"], + "include/triton/Dialect/Triton/IR/Dialect.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/Triton/IR/TritonDialect.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_interfaces_inc_gen", + tbl_outs = [ + ( + ["--gen-attr-interface-decls"], + "include/triton/Dialect/Triton/IR/AttrInterfaces.h.inc", + ), + ( + ["--gen-attr-interface-defs"], + "include/triton/Dialect/Triton/IR/AttrInterfaces.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/Triton/IR/TritonInterfaces.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_ops_inc_gen", + tbl_outs = [ + ( + ["--gen-enum-decls"], + "include/triton/Dialect/Triton/IR/OpsEnums.h.inc", + ), + ( + ["--gen-enum-defs"], + "include/triton/Dialect/Triton/IR/OpsEnums.cpp.inc", + ), + ( + ["--gen-op-decls"], + "include/triton/Dialect/Triton/IR/Ops.h.inc", + ), + ( + ["--gen-op-defs"], + "include/triton/Dialect/Triton/IR/Ops.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/Triton/IR/TritonOps.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_types_inc_gen", + tbl_outs = [ + ( + ["--gen-typedef-decls"], + "include/triton/Dialect/Triton/IR/Types.h.inc", + ), + ( + ["--gen-typedef-defs"], + "include/triton/Dialect/Triton/IR/Types.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/Triton/IR/TritonTypes.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_transforms_inc_gen", + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=Triton", + ], + "include/triton/Dialect/Triton/Transforms/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/Triton/Transforms/Passes.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_combine_inc_gen", + # The generated file is #included without relative path. + strip_include_prefix = "lib/Dialect/Triton/Transforms", + tbl_outs = [ + ( + ["--gen-rewriters"], + "lib/Dialect/Triton/Transforms/TritonCombine.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "lib/Dialect/Triton/Transforms/Combine.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_gpu_attr_inc_gen", + tbl_outs = [ + ( + ["--gen-attrdef-decls"], + "include/triton/Dialect/TritonGPU/IR/AttrDefs.h.inc", + ), + ( + ["--gen-attrdef-defs"], + "include/triton/Dialect/TritonGPU/IR/AttrDefs.cpp.inc", + ), + ( + ["--gen-enum-decls"], + "include/triton/Dialect/TritonGPU/IR/OpsEnums.h.inc", + ), + ( + ["--gen-enum-defs"], + "include/triton/Dialect/TritonGPU/IR/OpsEnums.cpp.inc", + ), + ( + ["--gen-attr-interface-decls"], + "include/triton/Dialect/TritonGPU/IR/AttrInterfaces.h.inc", + ), + ( + ["--gen-attr-interface-defs"], + "include/triton/Dialect/TritonGPU/IR/AttrInterfaces.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_gpu_dialect_inc_gen", + tbl_outs = [ + ( + ["--gen-dialect-decls"], + "include/triton/Dialect/TritonGPU/IR/Dialect.h.inc", + ), + ( + ["--gen-dialect-defs"], + "include/triton/Dialect/TritonGPU/IR/Dialect.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_gpu_ops_inc_gen", + tbl_outs = [ + ( + ["--gen-op-decls"], + "include/triton/Dialect/TritonGPU/IR/Ops.h.inc", + ), + ( + ["--gen-op-defs"], + "include/triton/Dialect/TritonGPU/IR/Ops.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_gpu_types_inc_gen", + tbl_outs = [ + ( + ["--gen-typedef-decls"], + "include/triton/Dialect/TritonGPU/IR/Types.h.inc", + ), + ( + ["--gen-typedef-defs"], + "include/triton/Dialect/TritonGPU/IR/Types.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonGPU/IR/TritonGPUTypes.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_gpu_type_interfaces_inc_gen", + tbl_outs = [ + ( + ["--gen-type-interface-decls"], + "include/triton/Dialect/TritonGPU/IR/TypeInterfaces.h.inc", + ), + ( + ["--gen-type-interface-defs"], + "include/triton/Dialect/TritonGPU/IR/TypeInterfaces.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonGPU/IR/TritonGPUTypeInterfaces.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_gpu_transforms_inc_gen", + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=TritonGPU", + ], + "include/triton/Dialect/TritonGPU/Transforms/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonGPU/Transforms/Passes.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_nvidia_gpu_dialect_inc_gen", + tbl_outs = [ + ( + ["--gen-dialect-decls"], + "include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.h.inc", + ), + ( + ["--gen-dialect-defs"], + "include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUDialect.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_nvidia_gpu_ops_inc_gen", + tbl_outs = [ + ( + ["--gen-op-decls"], + "include/triton/Dialect/TritonNvidiaGPU/IR/Ops.h.inc", + ), + ( + ["--gen-op-defs"], + "include/triton/Dialect/TritonNvidiaGPU/IR/Ops.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_nvidia_gpu_transforms_inc_gen", + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=TritonNvidiaGPU", + ], + "include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_conversion_triton_to_triton_gpu_passes_inc_gen", + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=TritonToTritonGPU", + ], + "include/triton/Conversion/TritonToTritonGPU/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Conversion/TritonToTritonGPU/Passes.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_target_llvmir_passes_inc_gen", + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=TritonLLVMIR", + ], + "include/triton/Target/LLVMIR/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Target/LLVMIR/Passes.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_conversion_triton_gpu_to_llvm_pass_inc_gen", + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=TritonGPUToLLVM", + ], + "include/triton/Conversion/TritonGPUToLLVM/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Conversion/TritonGPUToLLVM/Passes.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_op_interfaces_inc_gen", + tbl_outs = [ + ( + ["--gen-op-interface-decls"], + "include/triton/Dialect/Triton/IR/OpInterfaces.h.inc", + ), + ( + ["--gen-op-interface-defs"], + "include/triton/Dialect/Triton/IR/OpInterfaces.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/Triton/IR/TritonOpInterfaces.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_nvidia_gpu_attr_inc_gen", + tbl_outs = [ + ( + ["--gen-attrdef-decls"], + "include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.h.inc", + ), + ( + ["--gen-attrdef-defs"], + "include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.cpp.inc", + ), + ( + ["--gen-enum-decls"], + "include/triton/Dialect/TritonNvidiaGPU/IR/OpsEnums.h.inc", + ), + ( + ["--gen-enum-defs"], + "include/triton/Dialect/TritonNvidiaGPU/IR/OpsEnums.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.td", + deps = ["td_files"], +) + +cc_library( + name = "TritonDialects", + srcs = glob([ + "lib/Dialect/Triton/IR/*.cpp", + "lib/Dialect/TritonGPU/IR/*.cpp", + "lib/Dialect/TritonNvidiaGPU/IR/*.cpp", + "lib/Tools/*.cpp", + # There are so many interdependencies between Dialect and Analysis that we're just compiling + # everything in a single unit. + "lib/Analysis/*.cpp", + ]) + [ + "include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h", # Avoid circular dependency. + "include/triton/Conversion/TritonGPUToLLVM/Utility.h", # Avoid circular dependency. + "include/triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h", # Avoid circular dependency. + "lib/Dialect/TritonGPU/Transforms/Utility.cpp", # Avoid circular dependency. + ], + hdrs = glob([ + "include/triton/Dialect/Triton/IR/*.h", + "include/triton/Dialect/TritonGPU/IR/*.h", + "include/triton/Dialect/TritonNvidiaGPU/IR/*.h", + "include/triton/Tools/*.h", + # There are so many interdependencies between Dialect and Analysis that we're just compiling + # everything in a single unit. + "include/triton/Analysis/*.h", + ]) + [ + "include/triton/Dialect/TritonGPU/Transforms/Utility.h", # Avoid circular dependency. + # What is this lone header doing rooted under Conversion? Best to add it to Dialect, but + # it would be better if upstream moved it there. + "include/triton/Conversion/MLIRTypes.h", + ], + copts = select({ + ":compiler_is_msvc": [], + "//conditions:default": [ + "-Wno-unused-variable", + "-Wno-logical-op-parentheses", + ], + }), + includes = ["include"], + deps = [ + ":triton_nvidia_gpu_attr_inc_gen", + ":triton_dialect_inc_gen", + ":triton_gpu_attr_inc_gen", + ":triton_gpu_dialect_inc_gen", + ":triton_gpu_ops_inc_gen", + ":triton_gpu_types_inc_gen", + ":triton_gpu_type_interfaces_inc_gen", + ":triton_interfaces_inc_gen", + ":triton_nvidia_gpu_dialect_inc_gen", + ":triton_nvidia_gpu_ops_inc_gen", + ":triton_op_interfaces_inc_gen", + ":triton_ops_inc_gen", + ":triton_types_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:ControlFlowDialect", + "@llvm-project//mlir:ControlFlowInterfaces", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:FunctionInterfaces", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:InliningUtils", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:MathDialect", + "@llvm-project//mlir:UBDialect", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@triton//third_party/nvidia:NVGPUDialect", + # The following is added to make Utility compile + ":TritonTools", + "@llvm-project//mlir:LLVMCommonConversion", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + "@triton//third_party/f2reduce", + ], +) + +cc_library( + name = "TritonTransforms", + srcs = glob(["lib/Dialect/Triton/Transforms/*.cpp"]), + hdrs = glob(["include/triton/Dialect/Triton/Transforms/*.h"]), + copts = _no_unused_variable, + deps = [ + ":TritonDialects", + ":triton_combine_inc_gen", + ":triton_transforms_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ControlFlowDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:SCFUtils", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + ], + alwayslink = True, # TritonDialect uses getCanonicalizationPatterns(). +) + +cc_library( + name = "TritonGPUTransforms", + srcs = glob( + [ + "lib/Dialect/TritonGPU/Transforms/*.cpp", + "lib/Dialect/TritonGPU/Transforms/*.h", + "lib/Dialect/TritonGPU/Transforms/Pipeliner/*.cpp", + "lib/Dialect/TritonGPU/Transforms/Pipeliner/*.h", + ], + exclude = ["lib/Dialect/TritonGPU/Transforms/Utility.cpp"], + ), + hdrs = glob( + [ + "include/triton/Dialect/TritonGPU/Transforms/*.h", + ], + exclude = ["include/triton/Dialect/TritonGPU/Transforms/Utility.h"], + ) + [ + "include/triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h", + "include/triton/Tools/Sys/GetEnv.hpp", + ], + copts = select({ + ":compiler_is_msvc": [], + "//conditions:default": [ + "-Wno-logical-op-parentheses", + "-Wno-reorder-ctor", + "-Wno-return-type", + "-Wno-unused-variable", + "-Wno-string-conversion", + ], + }), + deps = [ + ":TritonDialects", + ":TritonGPUToLLVM", + ":triton_gpu_transforms_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:InferTypeOpInterface", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:SCFTransforms", + "@llvm-project//mlir:SCFUtils", + "@llvm-project//mlir:SideEffectInterfaces", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + "@llvm-project//mlir:UBDialect", + ], +) + +cc_library( + name = "TritonGPUToLLVM", + srcs = glob([ + "lib/Conversion/TritonGPUToLLVM/*.h", + "lib/Conversion/TritonGPUToLLVM/**/*.cpp", + ]), + hdrs = glob([ + "include/triton/Tools/Sys/*.hpp", + "include/triton/Conversion/TritonGPUToLLVM/*.h", + ]), + copts = select({ + "//conditions:default": [ + "-Wno-unused-variable", + ], + }), + includes = ["include"], + deps = [ + ":TritonDialects", + ":triton_conversion_triton_gpu_to_llvm_pass_inc_gen", + ":triton_gpu_attr_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:ControlFlowDialect", + "@llvm-project//mlir:DataLayoutInterfaces", + "@llvm-project//mlir:FunctionInterfaces", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMCommonConversion", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:NVVMDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + ], +) + +cc_library( + name = "TritonNvidiaGPUTransforms", + srcs = glob([ + "lib/Dialect/TritonNvidiaGPU/Transforms/*.cpp", + ]) + [ + "@triton//test:lib/Dialect/TritonGPU/TestTC05MMAPipeline.cpp", + ], + hdrs = glob([ + "include/triton/Dialect/TritonNvidiaGPU/Transforms/*.h", + ]), + copts = select({ + ":compiler_is_msvc": [], + "//conditions:default": [ + "-Wno-ctad-maybe-unsupported", + "-Wno-logical-op-parentheses", + "-Wno-non-virtual-dtor", + "-Wno-return-type", + "-Wno-unused-variable", + ], + }), + includes = ["include"], + deps = [ + ":TritonDialects", + ":TritonGPUTransforms", + ":TritonTools", + ":triton_gpu_transforms_inc_gen", + ":triton_nvidia_gpu_transforms_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:SCFTransforms", + "@llvm-project//mlir:SCFUtils", + "@llvm-project//mlir:SideEffectInterfaces", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + ], +) + +cc_library( + name = "TritonToTritonGPU", + srcs = glob([ + "lib/Conversion/TritonToTritonGPU/*.h", + "lib/Conversion/TritonToTritonGPU/*.cpp", + ]), + hdrs = glob(["include/triton/Conversion/TritonToTritonGPU/*.h"]), + includes = ["include"], + deps = [ + ":TritonDialects", + ":TritonGPUTransforms", + ":triton_conversion_triton_to_triton_gpu_passes_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:ControlFlowDialect", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:IndexDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + "@triton//third_party/proton:ProtonIRDialect", + ], +) + +cc_library( + name = "TritonLLVMIR", + srcs = glob([ + "lib/Target/LLVMIR/*.cpp", + "lib/Target/LLVMIR/*.h", + ]), + hdrs = glob(["include/triton/Target/LLVMIR/*.h"]), + copts = _no_unused_variable, + deps = [ + ":TritonTransforms", + ":triton_target_llvmir_passes_inc_gen", + "@llvm-project//llvm:Analysis", + "@llvm-project//llvm:BinaryFormat", + "@llvm-project//llvm:Core", + "@llvm-project//llvm:IPO", + "@llvm-project//llvm:IRReader", + "@llvm-project//llvm:InstCombine", + "@llvm-project//llvm:Linker", + "@llvm-project//llvm:MC", + "@llvm-project//llvm:Passes", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:Target", + "@llvm-project//mlir:ArithToLLVM", + "@llvm-project//mlir:BuiltinToLLVMIRTranslation", + "@llvm-project//mlir:ConversionPasses", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:IndexToLLVM", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:LLVMIRTransforms", + "@llvm-project//mlir:LLVMToLLVMIRTranslation", + "@llvm-project//mlir:NVVMToLLVMIRTranslation", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:ROCDLToLLVMIRTranslation", + "@llvm-project//mlir:SCFToControlFlow", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:ToLLVMIRTranslation", + "@llvm-project//mlir:Transforms", + ], +) + +cc_library( + name = "TritonPTX", + srcs = glob([ + "lib/Target/PTX/*.cpp", + ]), + hdrs = glob(["include/triton/Target/PTX/*.h"]), + deps = ["@llvm-project//llvm:Support"], +) + +cc_library( + name = "TritonHSACO", + srcs = glob([ + "lib/Target/HSACO/*.cpp", + ]), + hdrs = glob(["include/triton/Target/HSACO/*.h"]), + deps = [ + ":TritonLLVMIR", + ":TritonTools", + "@llvm-project//llvm:Core", + "@llvm-project//llvm:ExecutionEngine", + "@llvm-project//llvm:MC", + "@llvm-project//llvm:Scalar", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:Target", + "@llvm-project//llvm:TransformUtils", + "@llvm-project//mlir:ExecutionEngine", + "@llvm-project//mlir:ExecutionEngineUtils", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:LLVMToLLVMIRTranslation", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:ToLLVMIRTranslation", + ], +) + +cc_library( + name = "TritonTools", + hdrs = ["include/triton/Tools/Sys/GetEnv.hpp"], +) + +cc_library( + name = "AllPassesAndDialects", + srcs = [ + "include/triton/Conversion/TritonToTritonGPU/Passes.h", + "include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h", + ], + hdrs = ["bin/RegisterTritonDialects.h"], + includes = ["."], # because it includes third_party/nvidia/include/Dialect/NVGPU/IR/Dialect.h + deps = [ + ":TritonDialects", + ":TritonGPUToLLVM", + ":TritonGPUTransforms", + ":TritonLLVMIR", + ":TritonNvidiaGPUTransforms", + ":TritonToTritonGPU", + ":TritonTransforms", + ":triton_conversion_triton_to_triton_gpu_passes_inc_gen", + ":triton_nvidia_gpu_transforms_inc_gen", + "@llvm-project//mlir:AllPassesAndDialects", + "@triton//test:TritonTestAnalysis", + "@triton//third_party/amd:TritonAMDGPU", + "@triton//third_party/amd:TritonAMDGPUToLLVM", + "@triton//third_party/amd:TritonAMDGPUTransforms", + "@triton//third_party/nvidia:NVGPUDialect", + "@triton//third_party/nvidia:NVGPUToLLVM", + "@triton//third_party/nvidia:TritonNVIDIAGPUToLLVM", + "@triton//third_party/proton:ProtonIRDialect", + ], +) + +cc_binary( + name = "triton-opt", + srcs = [ + "bin/triton-opt.cpp", + ], + deps = [ + ":AllPassesAndDialects", + "@llvm-project//mlir:MlirOptLib", + ], +) + +cc_binary( + name = "triton-llvm-opt", + srcs = [ + "bin/triton-llvm-opt.cpp", + "lib/Target/LLVMIR/LLVMPasses.h", + ], + deps = [ + ":TritonLLVMIR", + "@llvm-project//llvm:CodeGen", + "@llvm-project//llvm:Core", + "@llvm-project//llvm:IRReader", + "@llvm-project//llvm:Option", + "@llvm-project//llvm:Passes", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:TargetParser", + ], +) + +# See go/triton-debug for usage. +cc_binary( + name = "triton-reduce", + srcs = ["bin/triton-reduce.cpp"], + deps = [ + ":AllPassesAndDialects", + "@llvm-project//mlir:MlirReduceLib", + "@triton//third_party/amd:TritonAMDGPU", + "@triton//third_party/amd:TritonAMDGPUToLLVM", + ], +) + +cc_binary( + name = "triton-tensor-layout", + srcs = ["bin/triton-tensor-layout.cpp"], + deps = [ + ":AllPassesAndDialects", + ":TritonDialects", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:AsmParser", + "@llvm-project//mlir:IR", + ], +) + +filegroup( + name = "metadata-file", + srcs = ["METADATA"], +) diff --git a/include/triton/Conversion/MLIRTypes.h b/include/triton/Conversion/MLIRTypes.h index afa1aa989e6ec..96d60cdc0636c 100644 --- a/include/triton/Conversion/MLIRTypes.h +++ b/include/triton/Conversion/MLIRTypes.h @@ -28,15 +28,16 @@ inline Type bf16Ty(MLIRContext *ctx) { return BFloat16Type::get(ctx); } inline bool isFloat(Type type) { return type.isF32() || type.isF64() || type.isF16() || type.isF128() || - type.isBF16() || type.isFloat8E4M3B11FNUZ() || type.isFloat8E4M3FN() || - type.isFloat8E4M3FNUZ() || type.isFloat8E5M2() || - type.isFloat8E5M2FNUZ(); + type.isBF16() || + llvm::isa(type); } inline bool isFloat8(Type type) { - return type.isFloat8E4M3B11FNUZ() || type.isFloat8E4M3FN() || - type.isFloat8E4M3FNUZ() || type.isFloat8E5M2() || - type.isFloat8E5M2FNUZ(); + return llvm::isa(type); } inline bool isInt(Type type) { return type.isIntOrFloat() && !isFloat(type); } diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td index 6d72beaca6680..ba359f044bc56 100644 --- a/include/triton/Dialect/Triton/IR/TritonOps.td +++ b/include/triton/Dialect/Triton/IR/TritonOps.td @@ -1105,7 +1105,12 @@ def CallOp : TT_Op<"call", [CallOpInterface, /*MemRefsNormalizable, */DeclareOpI MutableOperandRange getArgOperandsMutable() { return getOperandsMutable(); } - + Attribute removeArgAttrsAttr() { return nullptr; } + Attribute removeResAttrsAttr() { return nullptr; } + ArrayAttr getArgAttrsAttr() { return nullptr; } + ArrayAttr getResAttrsAttr() { return nullptr; } + void setArgAttrsAttr(ArrayAttr) { return; } + void setResAttrsAttr(ArrayAttr) { return; } }]; let assemblyFormat = [{ diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp index 14563810e2614..9e45ebe6aa005 100644 --- a/lib/Analysis/Allocation.cpp +++ b/lib/Analysis/Allocation.cpp @@ -123,6 +123,12 @@ ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy, std::tie(scratchConfig.inVec, scratchConfig.outVec) = getScratchCvtInOutVecLengths(srcTy, dstTy); + // We can't write a longer vector than the shape of shared memory. + // This shape might be smaller than the tensor shape in case we decided to + // do the conversion in multiple iterations. + unsigned contiguousShapeDim = scratchConfig.repShape[scratchConfig.order[0]]; + scratchConfig.inVec = std::min(scratchConfig.inVec, contiguousShapeDim); + scratchConfig.outVec = std::min(scratchConfig.outVec, contiguousShapeDim); // No padding is required if the tensor is 1-D, or if all dimensions except // the first accessed dimension have a size of 1. diff --git a/lib/Analysis/AxisInfo.cpp b/lib/Analysis/AxisInfo.cpp index ee7c8f1ceb995..5db9d01a262f6 100644 --- a/lib/Analysis/AxisInfo.cpp +++ b/lib/Analysis/AxisInfo.cpp @@ -935,7 +935,7 @@ class ShROpAxisInfoVisitor final : public BinaryOpVisitorImpl { // Treat [2^n,2^n+1,...]'s divisibility as 1 instead of 2^n lhsDivisibility = 1; } - return std::max(1, lhsDivisibility / (1 << shift)); + return std::max(1, lhsDivisibility / (int64_t(1) << shift)); } int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index 01c2aef7d431e..56891c49100be 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -750,14 +750,14 @@ bool supportMMA(triton::DotOp op, int version) { return false; if (!(numWarps % 4 == 0 && retShapePerCTA[rank - 2] % 64 == 0 && retShapePerCTA[rank - 1] % 8 == 0 && - (aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FN() || + (llvm::isa(aElemTy) || aElemTy.isInteger(8) || aElemTy.isF16() || aElemTy.isBF16() || aElemTy.isF32()))) { return false; } // We cannot use MMA_V3 if we need to accumulate in F32 within the MMA op. if (op.getMaxNumImpreciseAcc() < 32 && - (aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FN()) && + (llvm::isa(aElemTy)) && cast(op.getType()).getElementType().isF32()) { return false; } @@ -778,8 +778,9 @@ bool supportMMA(Value value, int version) { cast(value.getType()).getElementType(); // FP8 is not natively supported on all mma versions but it can always be // promoted to fp16 therefore we can always support it. - bool isFP8 = elemTy.isFloat8E5M2() || elemTy.isFloat8E4M3FN() || - elemTy.isFloat8E5M2FNUZ() || elemTy.isFloat8E4M3FNUZ(); + bool isFP8 = + llvm::isa(elemTy); return isFP8 || elemTy.isF16() || elemTy.isBF16() || (elemTy.isF32() && version >= 2) || (elemTy.isInteger(8) && version >= 2); diff --git a/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp b/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp index 06e75ee18d599..035991a5d44f2 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp @@ -57,6 +57,12 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context, addArgumentMaterialization([&](OpBuilder &builder, RankedTensorType tensorType, ValueRange inputs, Location loc) -> Value { + // Allows partial TTIR to TTGIR conversion by materializing a conversion for + // remaining arguments that have been converted to a new type. + // We use this to rewrite triton_xla.sparse_dot in a separate pass after + // 'convert-triton-to-tritongpu'. + return builder.create(loc, tensorType, + inputs); llvm_unreachable("Argument rematerialization should not happen in Triton " "-> TritonGPU conversion"); return {}; @@ -66,6 +72,12 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context, // convert origValue to newValue addSourceMaterialization([&](OpBuilder &builder, RankedTensorType tensorType, ValueRange inputs, Location loc) -> Value { + // Allows partial TTIR to TTGIR conversion by materializing a conversion for + // remaining uses of values that have been converted to a new type. + // We use this to rewrite triton_xla.sparse_dot in a separate pass after + // 'convert-triton-to-tritongpu'. + return builder.create(loc, tensorType, + inputs); llvm_unreachable("Source rematerialization should not happen in Triton -> " "TritonGPU Conversion"); return {}; diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp index d8ed0492ce91e..3178481c5a0ad 100644 --- a/lib/Dialect/Triton/IR/Ops.cpp +++ b/lib/Dialect/Triton/IR/Ops.cpp @@ -899,7 +899,7 @@ void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name, if (argAttrs.empty()) return; assert(type.getNumInputs() == argAttrs.size()); - function_interface_impl::addArgAndResultAttrs( + call_interface_impl::addArgAndResultAttrs( builder, state, argAttrs, /*resultAttrs=*/std::nullopt, getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name)); } diff --git a/lib/Dialect/TritonGPU/IR/Ops.cpp b/lib/Dialect/TritonGPU/IR/Ops.cpp index b831699bcb261..9e6837dfc6fa3 100644 --- a/lib/Dialect/TritonGPU/IR/Ops.cpp +++ b/lib/Dialect/TritonGPU/IR/Ops.cpp @@ -151,6 +151,11 @@ struct CanonicalizeConvertFromAlloc auto convert = op.getSrc().getDefiningOp(); if (!convert) return failure(); + // LocalAllocOp lowering doesn't support going from DotOperandEncoding + // to SharedEncoding, so we want to keep this layout conversion. + if (mlir::isa( + convert.getSrc().getType().getEncoding())) + return failure(); rewriter.replaceOpWithNewOp( op, op->getResult(0).getType(), convert.getSrc()); return mlir::success(); @@ -213,13 +218,13 @@ struct CanonicalizeConvertFromConvert // heuristic to accommodate fused attention. auto srcType = op.getSrc().getType(); auto dstType = op.getType(); - if (mlir::isa(dstType.getEncoding()) && - mlir::isa(srcType.getEncoding())) + if (mlir::isa_and_nonnull(dstType.getEncoding()) && + mlir::isa_and_nonnull(srcType.getEncoding())) return failure(); // for hopper MMAv3 - if (mlir::isa(dstType.getEncoding()) && - mlir::isa(srcType.getEncoding()) && + if (mlir::isa_and_nonnull(dstType.getEncoding()) && + mlir::isa_and_nonnull(srcType.getEncoding()) && llvm::any_of(op.getResult().getUsers(), [](Operation *dot) { return dot->hasTrait(); })) { diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp index 6d7632b1b7884..c66c9f4ae232c 100644 --- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp +++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp @@ -21,8 +21,6 @@ namespace mlir { namespace triton { namespace gpu { -namespace { - // Get the highest version supported for the hardware and the dot. static int getMMAVersionSafe(int computeCapability, DotOp op) { // List supported mma version in order of preference. @@ -47,8 +45,8 @@ static int getMMAVersionSafe(int computeCapability, DotOp op) { return 0; } -SmallVector warpsPerTileV2(DotOp dotOp, const ArrayRef shape, - int numWarps) { +SmallVector +warpsPerTileV2(Operation *dotOp, const ArrayRef shape, int numWarps) { auto rank = shape.size(); // Early exit for batched matmul if (rank == 3) @@ -112,10 +110,10 @@ SmallVector warpsPerTileV2(DotOp dotOp, const ArrayRef shape, } SmallVector -warpsPerTileV3(DotOp dotOp, const ArrayRef shape, int numWarps, +warpsPerTileV3(Operation *dotOp, const ArrayRef shape, int numWarps, const SmallVector &instrShape) { SetVector slices; - mlir::getForwardSlice(dotOp.getResult(), &slices); + mlir::getForwardSlice(dotOp->getResult(0), &slices); // Contains a chained dot. We prefer to assign warps to one axis // to facilitate use cases like flash attention, allowing reductions within // the same warp. @@ -170,11 +168,26 @@ static Value getSharedMemoryMMAOperand(Value v, mlir::PatternRewriter &rewriter, auto newType = MemDescType::get(argType.getShape(), argType.getElementType(), newLayout, SharedMemorySpace); rewriter.setInsertionPointAfterValue(arg); + + // LocalAllocOp lowering doesn't support going from DotOperandEncoding + // to SharedEncoding. + if (auto dotOpEnc = mlir::dyn_cast( + argType.getEncoding())) { + // Create a layout conversion from DotOperandEncoding to BlockedEncoding + // then pass it to the LocalAllocOp. + auto newArgType = RankedTensorType::get( + argType.getShape(), argType.getElementType(), dotOpEnc.getParent()); + auto dotOperandToBlockedCvt = + rewriter.create(arg.getLoc(), newArgType, arg); + return rewriter.create(arg.getLoc(), newType, + dotOperandToBlockedCvt); + } + return rewriter.create(arg.getLoc(), newType, arg); } SmallVector -getWarpsPerTile(DotOp dotOp, const ArrayRef shape, int version, +getWarpsPerTile(Operation* dotOp, const ArrayRef shape, int version, int numWarps, const SmallVector &instrShape) { switch (version) { case 2: @@ -188,6 +201,16 @@ getWarpsPerTile(DotOp dotOp, const ArrayRef shape, int version, } static bool bwdFilter(Operation *op) { + // Dot operand layout assignment to Predicates are not currently supported + // during lowering from TritonGPU to LLVM in Triton for MMA cases. This + // condition limits visibility of the original bit-width so that predicate + // are not considered, hence, kwidth can never be = 32. + if (isa(op)) { + Type srcType = getElementTypeOrSelf(op->getOperand(0)); + if (srcType.isInteger(1)) + return false; + } + return op->getNumOperands() == 1 && (isa(op) || isPureUnaryInlineAsm(op) || @@ -207,7 +230,7 @@ static bool bwdFilter(Operation *op) { // result, kwidth can be the bitwidth of the lower precision primitive. // Conversely, in the downcasting scenario, no reordering is performed, // making it directory use the lower precision primitive. -static int computeOrigBitWidth(Value x) { +int computeOrigBitWidth(Value x) { int finalBitWidth = getElementTypeOrSelf(x).getIntOrFloatBitWidth(); int origBitWidth = finalBitWidth; SetVector slice; @@ -227,6 +250,9 @@ static int computeOrigBitWidth(Value x) { } return origBitWidth; } +// Move anonymous namespace down, so getWarpsPerTile is visible to the sparsity +// extension. +namespace { class BlockedToMMA : public mlir::OpRewritePattern { int computeCapability; @@ -632,7 +658,8 @@ static void decomposeMixedModeDotOp(ModuleOp mod, int computeCapability) { NvidiaMmaEncodingAttr mmaLayout = dyn_cast(D.getType().getEncoding()); if (mmaLayout) { - bool isNativeFP8 = AElType.isFloat8E5M2() || AElType.isFloat8E4M3FN(); + bool isNativeFP8 = + llvm::isa(AElType); // promote operands for sm < 89 since fp8 mma is not natively supported // promote operands for sm >= 90 when mma is not v3 if (!isNativeFP8 || @@ -1018,6 +1045,11 @@ class TritonGPUAccelerateMatmulPass } }; +Value getSharedMemMMAOperand(Value v, mlir::PatternRewriter &rewriter, + int opIdx, bool allowTranspose) { + return getSharedMemoryMMAOperand(v, rewriter, opIdx, allowTranspose); +} + } // namespace gpu } // namespace triton } // namespace mlir diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp index 5b268b154241e..f6e3c0fd8c6c3 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -285,6 +285,7 @@ class HoistLayoutConversion : public OpRewritePattern { if (!foundInitializer) return failure(); + rewriter.setInsertionPointAfter(src); SmallVector newOperands; for (auto operand : src->getOperands()) { // We checked earlier that all operands are ranked tensors. diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp index 47ed1b232423e..06a5952d845a0 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp @@ -132,6 +132,7 @@ static int createAsyncCopy(scf::ForOp forOp, tt::LoadOp loadOp, Value alloc, Value zero = builder.createWithStage( forOp.getLoc(), stage, clusterId, 0, 32); + // Replace the load with insert/extract slice. builder.setInsertionPoint(loadOp); Location loc = loadOp.getLoc(); @@ -527,7 +528,8 @@ assignMemoryLayouts(scf::ForOp &forOp, bool isTMALoad = isa(op); - loadsToPipeline.insert(&op); + // TODO: b/381421713 - Uncomment this once pipelining is fixed. + // loadsToPipeline.insert(&op); LoadInfo loadInfo; for (auto use : users) { if (use->hasTrait()) { @@ -566,6 +568,11 @@ assignMemoryLayouts(scf::ForOp &forOp, getBlockedEncoding(loadOp, axisInfoAnalysis); } } + + // TODO: b/381421713 - Remove this once pipelining is fixed. + if (!loadInfo.sharedEncoding) continue; + loadsToPipeline.insert(&op); + loadToInfo[&op] = loadInfo; } // Make sure all loads in loadsToPipeline are in loadToInfo. diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp index 6530abd3ab650..121395ea253bd 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp @@ -255,7 +255,7 @@ mlir::triton::maybeGetStageCluster(Operation *op) { } std::pair mlir::triton::getStageCluster(Operation *op) { auto res = maybeGetStageCluster(op); - assert(res.has_value() || "Operation is missing stage & cluster attribute"); + assert(res.has_value() && "Operation is missing stage & cluster attribute"); return *res; } diff --git a/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp b/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp index f7795ebbf7ba7..da9245b783c5f 100644 --- a/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp @@ -116,7 +116,7 @@ Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrologue, // opIdx: 0 => a, 1 => b auto type = cast(v.getType()); SmallVector shape{type.getShape().begin(), type.getShape().end()}; - SmallVector offset{0, 0}; + SmallVector offset(shape.size(), 0); Type elementType = type.getElementType(); // k => (prefetchWidth, k - prefetchWidth) @@ -141,8 +141,14 @@ Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrologue, type.getMutableMemory(), type.getAllocShape()), v, offsetsVal); + // We need to assign kwidth to zero in the case where the parent layout is + // Blocked, otherwise the verifier emits a failure. The parent layout is + // Blocked only when Tensor Cores are disabled. + int kwidth = dyn_cast(dotEncoding) + ? 0 + : prefetchWidth / 8; auto dotOperandEnc = triton::gpu::DotOperandEncodingAttr::get( - builder.getContext(), opIdx, dotEncoding, prefetchWidth / 8); + builder.getContext(), opIdx, dotEncoding, kwidth); Value prefetchSlice = builder.create( v.getLoc(), RankedTensorType::get(shape, elementType, dotOperandEnc), newSmem); @@ -191,6 +197,22 @@ LogicalResult Prefetcher::initialize() { break; if (!op->getResult(0).hasOneUse()) break; + // Similar to issues faced in HoistLayoutConversion pattern in + // OptimizeDotOperands.cpp, we can't propagate through type casts from + // predicates as they aren't supported in Triton when encoded with dot_op + // layout. + if (isa(op)) { + Type srcType = getElementTypeOrSelf(op->getOperand(0)); + if (srcType.isInteger(1)) + break; + } + // Propagation through ExpandDims is currently not supported. This blindly + // replaces the encoding with dot encoding & but ExpandDims requires a + // SliceEncoding. This could be rewritten to support it somehow, but I + // don't think it's trivial & it's currently crashing. + if (isa(op)) { + break; + } rets.push_back(op->getOperand(0)); if (auto cvt = dyn_cast(op)) { foundConvertFromShared = true; diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp index 21b8e059ca298..e7e568beb624d 100644 --- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp @@ -45,9 +45,9 @@ SmallVector mmaVersionToInstrShape(int version, SmallVector validN; // MMAv3 with larger instruction shape is preferred. - if (eltType.isFloat8E5M2() || eltType.isFloat8E4M3FN() || - eltType.isFloat8E4M3FNUZ() || eltType.isF16() || eltType.isBF16() || - eltType.isF32()) { + if (llvm::isa(eltType) || + eltType.isF16() || eltType.isBF16() || eltType.isF32()) { validN.assign({256, 248, 240, 232, 224, 216, 208, 200, 192, 184, 176, 168, 160, 152, 144, 136, 128, 120, 112, 104, 96, 88, 80, 72, 64, 56, 48, 40, 32, 24, 16, 8}); @@ -1004,18 +1004,26 @@ getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible) { } else { if (!isa(user)) return std::nullopt; - auto dotOpEnc = dyn_cast( - cast(user->getResult(0).getType()) - .getEncoding()); - if (!dotOpEnc) + auto enc = + cast(user->getResult(0).getType()).getEncoding(); + if (isa(enc)) { + auto srcTy = cast(val.getType()); + auto CTALayout = ttg::getCTALayout(srcTy.getEncoding()); + auto order = ttg::getOrder(srcTy.getEncoding()); + unsigned bitWidth = srcTy.getElementType().getIntOrFloatBitWidth(); + tempAttr = ttg::SharedEncodingAttr::get( + val.getContext(), cast(enc), + srcTy.getShape(), order, CTALayout, bitWidth, /*needTrans=*/false); + } else if (enc.getAbstractAttribute().getName().str() == + "triton.gpu.sparse_dot_meta_encoding") { + auto srcTy = cast(val.getType()); + tempAttr = ttg::SharedEncodingAttr::get( + val.getContext(), /*vec=*/1, /*perPhase=*/1, /*maxPhase=*/1, + ttg::getOrder(srcTy.getEncoding()), + ttg::getCTALayout(srcTy.getEncoding())); + } else { return std::nullopt; - auto srcTy = cast(val.getType()); - auto CTALayout = ttg::getCTALayout(srcTy.getEncoding()); - auto order = ttg::getOrder(srcTy.getEncoding()); - unsigned bitWidth = srcTy.getElementType().getIntOrFloatBitWidth(); - tempAttr = ttg::SharedEncodingAttr::get( - val.getContext(), dotOpEnc, srcTy.getShape(), order, CTALayout, - bitWidth, /*needTrans=*/false); + } } // Check that the shared encodings needed by the users are compatible. if (attr != nullptr && attr != tempAttr) { diff --git a/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp b/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp index a171d89339967..baaec513bfb44 100644 --- a/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp +++ b/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp @@ -77,8 +77,9 @@ bool WarpGroupDotOp::needsPartialAccumulator() { const auto &d = getD(); auto aTensorTy = cast(a.getType()); auto aElTy = cast(a.getType()).getElementType(); - bool isFP8 = aElTy.isFloat8E5M2() || aElTy.isFloat8E4M3FN() || - aElTy.isFloat8E5M2FNUZ() || aElTy.isFloat8E4M3FNUZ(); + bool isFP8 = + llvm::isa(aElTy); bool accFP32 = cast(d.getType()).getElementType().isF32(); uint32_t maxNumImpreciseAcc = getMaxNumImpreciseAcc(); diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp index fb0e7f6fdb189..1e4e6a0a95e73 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp @@ -44,7 +44,8 @@ struct FenceInsertionPass return; ModuleOp mod = getOperation(); mod.walk([&](Operation *op) { - if (!isa(op)) + if (!isa(op) && + op->getName().getStringRef() != "triton_xla.sparse_dot") return WalkResult::advance(); OpBuilder builder(op); auto a = op->getOperand(0); diff --git a/python/BUILD b/python/BUILD new file mode 100644 index 0000000000000..247b8cda2103b --- /dev/null +++ b/python/BUILD @@ -0,0 +1,79 @@ +# NOTE: Do not depend on any targets from this directory, +# but use //third_party/py/triton instead. + +load("@pybind11_bazel//:build_defs.bzl", "pybind_extension") + +package( + default_applicable_licenses = ["//:license"], + default_visibility = [ + "//third_party/py/triton:__pkg__", + "@triton//python:__subpackages__", + ], +) + +cc_library( + name = "passes", + hdrs = ["src/passes.h"], + includes = ["src"], + visibility = ["@triton//third_party:__subpackages__"], +) + +pybind_extension( + name = "libtriton", + srcs = [ + "src/interpreter.cc", + "src/ir.cc", + "src/llvm.cc", + "src/main.cc", + "src/passes.cc", + ], + copts = ["-DTRITON_BACKENDS_TUPLE=(nvidia)"], + deps = [ + ":passes", + "@llvm-project//llvm:Core", + "@llvm-project//llvm:IPO", + "@llvm-project//llvm:IRReader", + "@llvm-project//llvm:InstCombine", + "@llvm-project//llvm:Instrumentation", + "@llvm-project//llvm:Linker", + "@llvm-project//llvm:MC", + "@llvm-project//llvm:Passes", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:Target", + "@llvm-project//mlir:BuiltinToLLVMIRTranslation", + "@llvm-project//mlir:BytecodeWriter", + "@llvm-project//mlir:ControlFlowDialect", + "@llvm-project//mlir:ConversionPasses", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:IndexDialect", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:LLVMIRTransforms", + "@llvm-project//mlir:LLVMToLLVMIRTranslation", + "@llvm-project//mlir:NVVMToLLVMIRTranslation", + "@llvm-project//mlir:Parser", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:ToLLVMIRTranslation", + "@llvm-project//mlir:Transforms", + "@llvm-project//mlir:UBDialect", + "//:TritonDialects", + "//:TritonGPUToLLVM", + "//:TritonGPUTransforms", + "//:TritonHSACO", + "//:TritonLLVMIR", + "//:TritonNvidiaGPUTransforms", + "//:TritonPTX", + "//:TritonToTritonGPU", + "//:TritonTools", + "//:TritonTransforms", + "@triton//third_party/nvidia:triton_nvidia", + "@triton//third_party/proton:ProtonIRDialect", + ], +) + +filegroup( + name = "files", + srcs = glob( + include = ["triton/**/*.py"], + ), +) diff --git a/python/test/regression/BUILD b/python/test/regression/BUILD new file mode 100644 index 0000000000000..a88f4eeae1f85 --- /dev/null +++ b/python/test/regression/BUILD @@ -0,0 +1,26 @@ +load("//third_party/py/pytest:pytest_defs.bzl", "pytest_multi_tests") + +package( + default_applicable_licenses = ["//:license"], +) + +pytest_multi_tests( + name = "tests", + size = "large", + srcs = ["conftest.py"], + shard_count = 10, + tags = [ + "config-cuda-only", + "requires-gpu-sm80", + ], + tests = glob( + include = ["test_*.py"], + exclude = [ + "test_performance.py", #TODO(b/321005767): fix failing test + ], + ), + deps = [ + "//third_party/py/torch:pytorch", + "//third_party/py/triton", + ], +) diff --git a/python/test/unit/BUILD b/python/test/unit/BUILD new file mode 100644 index 0000000000000..72a7c6f33c0b7 --- /dev/null +++ b/python/test/unit/BUILD @@ -0,0 +1,183 @@ +load("//third_party/py/pytest:pytest_defs.bzl", "pytest_multi_tests", "pytest_test") + +package( + default_applicable_licenses = ["//:license"], +) + +_requires_gpu_sm80 = [ + "config-cuda-only", + "requires-gpu-sm80", +] + +_requires_config_cuda = select( + {"@local_config_cuda//cuda:using_blaze_config_cuda": []}, + no_match_error = "Requires --config=cuda", +) + +EXCLUDE_TESTS = [ + "language/test_reproducer.py", # this is not an actual test, but a tool for running reproducers + "language/test_subprocess.py", # TODO(b/320224484): fix failing test + "runtime/test_launch.py", # TODO(b/320226169): fix failing tests + "tools/test_aot.py", # TODO(b/320224484): fix failing test + "tools/test_disasm.py", # TODO(b/320224484): fix failing test + "runtime/test_cublas.py", # TODO(b/346755023): fix failing test + "test_debug.py", # TODO(b/374733875): fix failing test. Also see b/374733872. + "language/test_compile_only.py", # TODO(b/394497996): enable test, when CUDA version in g3 supports Blackwell +] + +# Runs all python tests on H100 +pytest_multi_tests( + name = "hopper", + size = "large", + srcs = [ + "conftest.py", + "language/conftest.py", + "language/test_core.py", + "language/test_mxfp.py", + ], + name_suffix = "_h100", + shard_count = 10, + tags = [ + "config-cuda-only", + "requires-gpu-sm90", + ], + target_compatible_with = _requires_config_cuda, + tests = glob( + include = ["**/test_*.py"], + exclude = EXCLUDE_TESTS + [ + "language/test_core.py", + "language/test_pipeliner.py", # TODO(b/362458006): fix failing test + "cuda/test_experimental_tma.py", # TODO(b/362458006): fix failing test + ], + ), + deps = [ + "//third_party/py/torch:pytorch", + "//third_party/py/triton", + ], +) + +# Shard test_core more, as it is otherwise very slow to run. +pytest_test( + name = "cuda/language/test_core_h100", + size = "large", + srcs = [ + "conftest.py", + "language/conftest.py", + ], + shard_count = 40, + tags = [ + "config-cuda-only", + "requires-gpu-sm90", + ], + target_compatible_with = _requires_config_cuda, + tests = ["language/test_core.py"], + deps = [ + "//third_party/py/torch:pytorch", + "//third_party/py/triton", + ], +) + +pytest_multi_tests( + name = "language", + size = "large", + srcs = [ + "conftest.py", + "language/conftest.py", + "language/test_core.py", + "language/test_mxfp.py", + ], + shard_count = 10, + tags = _requires_gpu_sm80, + target_compatible_with = _requires_config_cuda, + tests = glob( + include = ["language/**/test_*.py"], + exclude = EXCLUDE_TESTS + ["language/test_core.py"], + ), + deps = [ + "//third_party/py/torch:pytorch", + "//third_party/py/triton", + ], +) + +# Shard test_core more, as it is otherwise very slow to run. +pytest_test( + name = "language/test_core", + size = "large", + srcs = [ + "conftest.py", + "language/conftest.py", + ], + shard_count = 40, + tags = _requires_gpu_sm80, + target_compatible_with = _requires_config_cuda, + tests = ["language/test_core.py"], + deps = [ + "//third_party/py/torch:pytorch", + "//third_party/py/triton", + ], +) + +pytest_multi_tests( + name = "instrumentation", + size = "large", + srcs = ["conftest.py"], + shard_count = 10, + tags = _requires_gpu_sm80, + target_compatible_with = _requires_config_cuda, + tests = glob( + include = ["instrumentation/**/test_*.py"], + exclude = EXCLUDE_TESTS, + ), + deps = [ + "//third_party/py/torch:pytorch", + "//third_party/py/triton", + ], +) + +pytest_multi_tests( + name = "runtime", + srcs = ["conftest.py"], + tags = _requires_gpu_sm80, + target_compatible_with = _requires_config_cuda, + tests = glob( + include = ["runtime/**/test_*.py"], + exclude = EXCLUDE_TESTS, + ), + deps = [ + "//third_party/py/torch:pytorch", + "//third_party/py/triton", + ], +) + +pytest_multi_tests( + name = "tools", + size = "large", + shard_count = 10, + tags = _requires_gpu_sm80, + target_compatible_with = _requires_config_cuda, + tests = glob( + include = ["tools/**/test_*.py"], + exclude = EXCLUDE_TESTS, + ), + deps = [ + "//third_party/py/torch:pytorch", + "//third_party/py/triton", + ], +) + +pytest_multi_tests( + name = "unit", + size = "large", + srcs = ["conftest.py"], + shard_count = 10, + tags = _requires_gpu_sm80, + target_compatible_with = _requires_config_cuda, + tests = glob( + include = ["test_*.py"], + exclude = EXCLUDE_TESTS, + ), + deps = [ + "//third_party/py/torch:pytorch", + "//third_party/py/triton", + ], +) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 0c4ee23cfcbf3..c8bdf5f29ffc5 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -4194,6 +4194,25 @@ def _kernel(out): kernel[(1, )](out) assert torch.all(out == out_ref) +@pytest.mark.interpreter +def test_dot_on_broadcast(device): + @triton.jit + def _kernel(a, b, out): + a_offsets = tl.arange(0, 64)[:, None] * 32 + tl.arange(0, 32)[None, :] + lhs = tl.load(a + a_offsets, mask=a_offsets < 32 * 64) + rhs = tl.load(b) + rhs_bc = tl.broadcast_to(rhs, [32, 32]) + c = tl.dot(lhs, rhs_bc) + out_ptr = out + tl.arange(0, 64)[:, None] * 32 + tl.arange(0, 32)[None, :] + tl.store(out_ptr, c) + + a = torch.ones((64, 32), dtype=getattr(torch, 'float32'), device=device) + b = torch.tensor([1.0], dtype=getattr(torch, 'float32'), device=device) + out_ref = torch.matmul(a, torch.broadcast_to(b, (32, 32))) + out = torch.zeros((64, 32), dtype=getattr(torch, 'float32'), device=device) + _kernel[(1, )](a, b, out, num_stages=1, num_warps=4) + assert torch.all(out == out_ref) + # --------------- # test arange diff --git a/python/triton/_C/include b/python/triton/_C/include index b85a409837d1b..8a5dba6c4b560 120000 --- a/python/triton/_C/include +++ b/python/triton/_C/include @@ -1 +1 @@ -../../../include/ \ No newline at end of file +../../../include \ No newline at end of file diff --git a/python/triton/backends/__init__.py b/python/triton/backends/__init__.py index 92ba144ba97b2..f9bab523bf6ce 100644 --- a/python/triton/backends/__init__.py +++ b/python/triton/backends/__init__.py @@ -46,5 +46,8 @@ def _discover_backends(): _find_concrete_subclasses(driver, DriverBase)) return backends - -backends = _discover_backends() +from triton.backends.nvidia.driver import CudaDriver +from triton.backends.nvidia.compiler import CUDABackend +backends = { + "nvidia": Backend(CUDABackend, CudaDriver) +} diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 3ad797f527beb..ddcde1f6e61c9 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -794,7 +794,7 @@ def __str__(self) -> str: @builtin def __add__(self, other, _builder=None): - return add(self, other, sanitize_overflow=True, _builder=_builder) + return add(self, other, sanitize_overflow=False, _builder=_builder) @builtin def __radd__(self, other, _builder=None): @@ -810,7 +810,7 @@ def __rsub__(self, other, _builder=None): @builtin def __mul__(self, other, _builder=None): - return mul(self, other, sanitize_overflow=True, _builder=_builder) + return mul(self, other, sanitize_overflow=False, _builder=_builder) @builtin def __rmul__(self, other, _builder=None): @@ -2177,7 +2177,7 @@ def where(condition, x, y, _builder=None): @builtin -def add(x, y, sanitize_overflow: constexpr = True, _builder=None): +def add(x, y, sanitize_overflow: constexpr = False, _builder=None): x = _unwrap_if_constexpr(x) y = _unwrap_if_constexpr(y) return semantic.add(x, y, sanitize_overflow, _builder) @@ -2191,7 +2191,7 @@ def sub(x, y, sanitize_overflow: constexpr = True, _builder=None): @builtin -def mul(x, y, sanitize_overflow: constexpr = True, _builder=None): +def mul(x, y, sanitize_overflow: constexpr = False, _builder=None): x = _unwrap_if_constexpr(x) y = _unwrap_if_constexpr(y) return semantic.mul(x, y, sanitize_overflow, _builder) diff --git a/python/triton/runtime/build.py b/python/triton/runtime/build.py deleted file mode 100644 index 1b76548d43a77..0000000000000 --- a/python/triton/runtime/build.py +++ /dev/null @@ -1,37 +0,0 @@ -import sysconfig -import os -import shutil -import subprocess - - -def _build(name, src, srcdir, library_dirs, include_dirs, libraries): - suffix = sysconfig.get_config_var('EXT_SUFFIX') - so = os.path.join(srcdir, '{name}{suffix}'.format(name=name, suffix=suffix)) - # try to avoid setuptools if possible - cc = os.environ.get("CC") - if cc is None: - # TODO: support more things here. - clang = shutil.which("clang") - gcc = shutil.which("gcc") - cc = gcc if gcc is not None else clang - if cc is None: - raise RuntimeError("Failed to find C compiler. Please specify via CC environment variable.") - # This function was renamed and made public in Python 3.10 - if hasattr(sysconfig, 'get_default_scheme'): - scheme = sysconfig.get_default_scheme() - else: - scheme = sysconfig._get_default_scheme() - # 'posix_local' is a custom scheme on Debian. However, starting Python 3.10, the default install - # path changes to include 'local'. This change is required to use triton with system-wide python. - if scheme == 'posix_local': - scheme = 'posix_prefix' - py_include_dir = sysconfig.get_paths(scheme=scheme)["include"] - custom_backend_dirs = set(os.getenv(var) for var in ('TRITON_CUDACRT_PATH', 'TRITON_CUDART_PATH')) - include_dirs = include_dirs + [srcdir, py_include_dir, *custom_backend_dirs] - # for -Wno-psabi, see https://gcc.gnu.org/bugzilla/show_bug.cgi?id=111047 - cc_cmd = [cc, src, "-O3", "-shared", "-fPIC", "-Wno-psabi", "-o", so] - cc_cmd += [f'-l{lib}' for lib in libraries] - cc_cmd += [f"-L{dir}" for dir in library_dirs] - cc_cmd += [f"-I{dir}" for dir in include_dirs if dir is not None] - subprocess.check_call(cc_cmd, stdout=subprocess.DEVNULL) - return so diff --git a/test/BUILD b/test/BUILD new file mode 100644 index 0000000000000..5e6bece13f7bf --- /dev/null +++ b/test/BUILD @@ -0,0 +1,70 @@ +# copybara:uncomment_begin +# load("//third_party/llvm/build_defs:lit.bzl", "glob_lit_tests") +# load("//tools/build_defs/build_test:build_test.bzl", "build_test") +# +# package( +# default_applicable_licenses = ["//:license"], +# default_compatible_with = ["//buildenv/target:non_prod"], +# default_visibility = ["//:__subpackages__"], +# ) +# +# glob_lit_tests( +# name = "all_tests", +# data = [ +# "@llvm-project//llvm:FileCheck", +# "@llvm-project//llvm:opt", +# "@llvm-project//mlir:mlir-translate", +# "//:triton-llvm-opt", +# "//:triton-opt", +# "//:triton-tensor-layout", +# ], +# driver = "@llvm-project//mlir:run_lit.sh", +# exclude = [ +# # broken, offending change reverted in +# # https://github.com/triton-lang/triton/commit/3ed479f2f91a1d94dacb547115d357f5ce3219d8 +# "Conversion/reduce_to_llvm.mlir", +# "Conversion/amd/dedup-by-constancy.mlir", # AMD-specific, broken +# "TritonGPU/amd/amd-instruction-sched.mlir", # AMD-specific, broken with -debug-only. +# "TritonGPU/optimize_epilogue.mlir", # TODO: b/346283526 - AMD-specific, triggering UBSAN +# ], +# test_file_exts = [ +# "mlir", +# "ll", +# ], +# ) +# +# build_test( +# name = "build_test", +# allow_empty_target = False, +# targets = [ +# "//:TritonDialects", +# "//:TritonGPUToLLVM", +# "//:TritonGPUTransforms", +# "//:TritonLLVMIR", +# "//:TritonPTX", +# "//:TritonToTritonGPU", +# "//:TritonTools", +# "//:TritonTransforms", +# "//:triton-opt", +# ], +# ) +# copybara:uncomment_end + +cc_library( + name = "TritonTestAnalysis", + srcs = glob(["lib/Analysis/*.cpp"]), + deps = [ + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFToControlFlow", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + "//:TritonDialects", + "@triton//third_party/nvidia:TritonNVIDIAGPUToLLVM", + ], +) + +exports_files(srcs = [ + "lib/Dialect/TritonGPU/TestTC05MMAPipeline.cpp", +]) diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index 18fa950efd8f0..53eef45374799 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -1269,6 +1269,26 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { // ----- +// Regression test for https://github.com/triton-lang/triton/issues/5745 +#linear = #ttg.linear<{register = [[0, 1], [0, 2]], lane = [[0, 0], [0, 0], [0, 0], [0, 0], [0, 0]], warp = [[1, 0], [2, 0], [4, 0]], block = []}> +#linear1 = #ttg.linear<{register = [[0, 2]], lane = [[0, 0], [0, 0], [0, 0], [0, 0], [1, 0]], warp = [[2, 0], [4, 0], [0, 1]], block = []}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} { + // CHECK-LABEL: linear_layout_with_multiple_iterations + tt.func @linear_layout_with_multiple_iterations(%src: tensor<8x4xbf16, #linear>) { + %cvt = ttg.convert_layout %src : tensor<8x4xbf16, #linear> -> tensor<8x4xbf16, #linear1> + // CHECK: inline_asm{{.*}}st.shared.v2 + // CHECK: nvvm.barrier0 + // CHECK: llvm.load + // CHECK: nvvm.barrier0 + // CHECK: inline_asm{{.*}}st.shared.v2 + // CHECK: nvvm.barrier0 + // CHECK: llvm.load + tt.return + } +} + +// ----- + #blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> #shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> #mma = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [2, 2], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> diff --git a/test/TritonGPU/accelerate-matmul.mlir b/test/TritonGPU/accelerate-matmul.mlir index 417b8db398f92..1d028cc0fd774 100644 --- a/test/TritonGPU/accelerate-matmul.mlir +++ b/test/TritonGPU/accelerate-matmul.mlir @@ -350,3 +350,21 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num- tt.return %result : tensor<128x128xf32, #blocked> } } + +// ----- + +// CHECK-DAG: #[[$BLOCKED:.*]] = #ttg.blocked +// CHECK-DAG: #mma = #ttg.nvidia_mma<{versionMajor = 3 +#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func @local_alloc_dot_operand(%in0: tensor<64x256xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> {tt.divisibility = 16 : i32}, %in1: f32, %in2: tensor<64x32xf32, #blocked>) -> (tensor<64x32xf32, #blocked>) { + // CHECK-LABEL: local_alloc_dot_operand + %splat_in1 = tt.splat %in1 : f32 -> tensor<256x32xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> + // CHECK: %[[LHS_LOCAL_ALLOC:.*]] = ttg.local_alloc + // CHECK: %[[RHS_CVT:.*]] = ttg.convert_layout {{.*}} #ttg.dot_op<{{.*}}> -> {{.*}} #[[$BLOCKED]] + // CHECK: %[[RHS_LOCAL_ALLOC:.*]] = ttg.local_alloc %[[RHS_CVT]] + // CHECK: ttng.warp_group_dot %[[LHS_LOCAL_ALLOC]], %[[RHS_LOCAL_ALLOC]] + %res = tt.dot %in0, %splat_in1, %in2, inputPrecision = tf32 : tensor<64x256xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<256x32xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<64x32xf32, #blocked> + tt.return %res : tensor<64x32xf32, #blocked> + } +} diff --git a/test/TritonGPU/canonicalize.mlir b/test/TritonGPU/canonicalize.mlir index 73ae6abe361db..67d6202fd693f 100644 --- a/test/TritonGPU/canonicalize.mlir +++ b/test/TritonGPU/canonicalize.mlir @@ -178,3 +178,20 @@ tt.func @infer_trans(%arg0: tensor<32x32xf32, #linear>) -> tensor<32x32xf32, #bl } } + +// ----- + +// CHECK: #[[$BLOCKED:.*]] = #ttg.blocked +#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#shared1 = #ttg.shared<{vec = 4, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func @cvt_from_dot_op_into_local_allow_not_canonicalized(%in: tensor<256x32xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>) -> !ttg.memdesc<256x32xf32, #shared1, #smem> { + // CHECK-LABEL: cvt_from_dot_op_into_local_allow_not_canonicalized + %cvt_in = ttg.convert_layout %in : tensor<256x32xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<256x32xf32, #blocked> + %alloc = ttg.local_alloc %cvt_in : (tensor<256x32xf32, #blocked>) -> !ttg.memdesc<256x32xf32, #shared1, #smem> + // CHECK: %[[ALLOC:.*]] = ttg.local_alloc {{.*}} (tensor<{{.*}}, #[[$BLOCKED]]{{.*}}>) -> + tt.return %alloc : !ttg.memdesc<256x32xf32, #shared1, #smem> + } +} // end module + diff --git a/test/TritonGPU/dot-operands.mlir b/test/TritonGPU/dot-operands.mlir index c18b2d222ccca..afdb0ce68bfea 100644 --- a/test/TritonGPU/dot-operands.mlir +++ b/test/TritonGPU/dot-operands.mlir @@ -340,6 +340,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { // CHECK: tt.func @propagate_dot_op_to_constant_above_for() // CHECK: arith.constant dense<1.000000e+00> : tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + // CHECK: tt.elementwise_inline_asm + // CHECK: scf.for + // CHECK: tt.dot tt.func @propagate_dot_op_to_constant_above_for() -> tensor<32x128xf32, #mma> { %cst = arith.constant dense<1.000000e+00> : tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> %cst_0 = arith.constant dense<1.000000e+00> : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> @@ -347,8 +350,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ %c0_i32 = arith.constant 0 : i32 %c32_i32 = arith.constant 32 : i32 %c128_i32 = arith.constant 128 : i32 + %0 = tt.elementwise_inline_asm "cvt.rna.tf32.f32 $0, $1;" {constraints = "=r,r", packed_element = 1 : i32, pure = true} %cst : tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> %loop:1 = scf.for %arg2 = %c0_i32 to %c128_i32 step %c32_i32 iter_args(%arg0 = %cst_1) -> (tensor<32x128xf32, #mma>) : i32 { - %0 = tt.elementwise_inline_asm "cvt.rna.tf32.f32 $0, $1;" {constraints = "=r,r", packed_element = 1 : i32, pure = true} %cst : tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> %1 = ttg.convert_layout %0 : tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> %2 = ttg.convert_layout %cst_0 : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> %3 = tt.dot %2, %1, %arg0, inputPrecision = tf32 : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x128xf32, #mma> diff --git a/test/TritonGPU/prefetch.mlir b/test/TritonGPU/prefetch.mlir index 208516b3bfabd..481e982cd8cbf 100644 --- a/test/TritonGPU/prefetch.mlir +++ b/test/TritonGPU/prefetch.mlir @@ -244,3 +244,23 @@ tt.func @matmul_loop_mixed_amd(%lb : index, %ub : index, %step : index, %A : !tt tt.return %loop#4 : tensor<128x128xf32, #C> } } // end module + + // ----- + +// CHECK: tt.func @matmul_loop_on_blocked_layout +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0]}> +#shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:80", "ttg.threads-per-warp" = 32 : i32} { + tt.func @matmul_loop_on_blocked_layout(%arg_lhs: !ttg.memdesc<16x512xf32, #shared, #smem, mutable>, %arg_rhs: !ttg.memdesc<512x32xf32, #shared, #smem, mutable>, %arg_init: tensor<16x32xf32, #blocked>, %itr_val : i32) -> (tensor<16x32xf32, #blocked>) { + %loop:3 = scf.for %itr = %itr_val to %itr_val step %itr_val iter_args(%init = %arg_init, %lhs = %arg_lhs, %rhs = %arg_rhs) -> (tensor<16x32xf32, #blocked>, !ttg.memdesc<16x512xf32, #shared, #smem, mutable>, !ttg.memdesc<512x32xf32, #shared, #smem, mutable>) : i32 { + %lhs_ll = ttg.local_load %lhs : !ttg.memdesc<16x512xf32, #shared, #smem, mutable> -> tensor<16x512xf32, #blocked> + %lhs_ll_cvt = ttg.convert_layout %lhs_ll : tensor<16x512xf32, #blocked> -> tensor<16x512xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> + %rhs_ll = ttg.local_load %rhs : !ttg.memdesc<512x32xf32, #shared, #smem, mutable> -> tensor<512x32xf32, #blocked> + %rhs_ll_cvt = ttg.convert_layout %rhs_ll : tensor<512x32xf32, #blocked> -> tensor<512x32xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> + %res = tt.dot %lhs_ll_cvt, %rhs_ll_cvt, %init : tensor<16x512xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<512x32xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<16x32xf32, #blocked> + scf.yield %res, %lhs, %rhs : tensor<16x32xf32, #blocked>, !ttg.memdesc<16x512xf32, #shared, #smem, mutable>, !ttg.memdesc<512x32xf32, #shared, #smem, mutable> + } + tt.return %loop#0 : tensor<16x32xf32, #blocked> + } +} // end module diff --git a/third_party/amd/BUILD b/third_party/amd/BUILD new file mode 100644 index 0000000000000..6364535899427 --- /dev/null +++ b/third_party/amd/BUILD @@ -0,0 +1,266 @@ +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") + +package( + # copybara:uncomment_begin + # default_applicable_licenses = ["//:license"], + # default_compatible_with = ["//buildenv/target:non_prod"], + # default_visibility = [ + # "//third_party/tensorflow/compiler/xla/backends/gpu/codegen/triton:__subpackages__", + # "//:__subpackages__", + # ], + # copybara:uncomment_end_and_comment_begin + default_visibility = ["//visibility:public"], + # copybara:comment_end +) + +# TODO(csigg): fix, enable error upstream, remove. +_no_unused_variable = select({ + "//:compiler_is_msvc": [], + "//conditions:default": ["-Wno-unused-variable"], +}) + +cc_library( + name = "TritonAMDGPUTransforms", + srcs = glob( + [ + "lib/TritonAMDGPUTransforms/**/*.h", + "lib/TritonAMDGPUTransforms/**/*.cpp", + ], + exclude = [ + "lib/TritonAMDGPUTransforms/MfmaGroup.cpp", # Avoid circular dependency. + ], + ) + [ + # Work around dependencies on private headers. + "lib/TritonAMDGPUToLLVM/SchedInstructions.h", + "lib/TritonAMDGPUToLLVM/TargetInfo.h", + "lib/TritonAMDGPUToLLVM/Utility.h", + ], + hdrs = glob([ + "include/TritonAMDGPUTransforms/**/*.h", + ]), + copts = _no_unused_variable, + includes = [ + "include", + "lib/TritonAMDGPUTransforms", + ], + deps = [ + ":TritonAMDGPU", + ":TritonAMDGPUToLLVM", + ":triton_conversion_amdgpu_transforms_passes_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:TargetParser", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:ControlFlowDialect", + "@llvm-project//mlir:ConvertToLLVM", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:InferTypeOpInterface", + "@llvm-project//mlir:LLVMCommonConversion", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:ROCDLDialect", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:SideEffectInterfaces", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + "//:TritonDialects", + "//:TritonGPUToLLVM", + "//:TritonGPUTransforms", + ], +) + +cc_library( + name = "TritonAMDGPU", + srcs = glob([ + "lib/Dialect/TritonAMDGPU/**/*.h", + "lib/Dialect/TritonAMDGPU/**/*.cpp", + ]), + hdrs = glob([ + "include/Dialect/TritonAMDGPU/**/*.h", + ]), + includes = [ + "..", + "include", + ], + deps = [ + ":triton_amdgpu_attr_def_inc_gen", + ":triton_amdgpu_dialect_inc_gen", + ":triton_amdgpu_ops_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:TensorDialect", + "//:TritonDialects", + "//:TritonGPUToLLVM", + ], +) + +cc_library( + name = "TritonAMDGPUToLLVM", + srcs = glob([ + "lib/TritonAMDGPUToLLVM/**/*.h", + "lib/TritonAMDGPUToLLVM/**/*.cpp", + # TritonAMDGPUToLLVM and TritonAMDGPUDialectToLLVM have interdependencies, easiest way to + # deal with circular dependencies is to just compile both in a single unit. + "lib/TritonAMDGPUDialectToLLVM/**/*.h", + "lib/TritonAMDGPUDialectToLLVM/**/*.cpp", + ]) + [ + "include/TritonAMDGPUTransforms/MfmaGroup.h", # Avoid circular dependency. + "lib/TritonAMDGPUTransforms/MfmaGroup.cpp", # Avoid circular dependency. + ], + hdrs = glob([ + "include/TritonAMDGPUToLLVM/**/*.h", + ]), + copts = _no_unused_variable + ["-Wno-implicit-fallthrough"], + includes = [ + "include", + "lib/TritonAMDGPUToLLVM", + ], + deps = [ + ":TritonAMDGPU", + ":triton_conversion_amdgpu_to_llvm_passes_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:TargetParser", + "@llvm-project//mlir:AMDGPUDialect", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:ArithToLLVM", + "@llvm-project//mlir:ControlFlowToLLVM", + "@llvm-project//mlir:ConvertToLLVM", + "@llvm-project//mlir:GPUToNVVMTransforms", + "@llvm-project//mlir:GPUToROCDLTransforms", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:IndexDialect", + "@llvm-project//mlir:LLVMCommonConversion", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:MathToLLVM", + "@llvm-project//mlir:NVVMDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:ROCDLDialect", + "@llvm-project//mlir:SCFToControlFlow", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + "@llvm-project//mlir:UBToLLVM", + "//:TritonDialects", + "//:TritonGPUToLLVM", + "@triton//third_party/proton:TritonProtonToLLVM", + ], +) + +td_library( + name = "td_files", + srcs = glob(["include/**/*.td"]), + includes = ["include"], + deps = ["//:td_files"], +) + +gentbl_cc_library( + name = "triton_amdgpu_ops_inc_gen", + tbl_outs = [ + ( + [ + "--gen-llvmir-conversions", + ], + "include/Dialect/TritonAMDGPU/IR/OpsConversions.inc", + ), + ( + [ + "--gen-op-decls", + ], + "include/Dialect/TritonAMDGPU/IR/Ops.h.inc", + ), + ( + [ + "--gen-op-defs", + ], + "include/Dialect/TritonAMDGPU/IR/Ops.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td", + deps = [":td_files"], +) + +gentbl_cc_library( + name = "triton_amdgpu_dialect_inc_gen", + tbl_outs = [ + ( + [ + "--gen-dialect-decls", + "--dialect=amdgpu", + ], + "include/Dialect/TritonAMDGPU/IR/Dialect.h.inc", + ), + ( + [ + "--gen-dialect-defs", + "--dialect=amdgpu", + ], + "include/Dialect/TritonAMDGPU/IR/Dialect.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/Dialect/TritonAMDGPU/IR/TritonAMDGPUDialect.td", + deps = [":td_files"], +) + +gentbl_cc_library( + name = "triton_amdgpu_attr_def_inc_gen", + tbl_outs = [ + ( + ["--gen-attrdef-decls"], + "include/Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.h.inc", + ), + ( + ["--gen-attrdef-defs"], + "include/Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.cpp.inc", + ), + ( + ["--gen-enum-decls"], + "include/Dialect/TritonAMDGPU/IR/TritonAMDGPUEnums.h.inc", + ), + ( + ["--gen-enum-defs"], + "include/Dialect/TritonAMDGPU/IR/TritonAMDGPUEnums.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.td", + deps = [":td_files"], +) + +gentbl_cc_library( + name = "triton_conversion_amdgpu_to_llvm_passes_inc_gen", + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=TritonAMDGPUToLLVM", + ], + "include/TritonAMDGPUToLLVM/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/TritonAMDGPUToLLVM/Passes.td", + deps = [":td_files"], +) + +gentbl_cc_library( + name = "triton_conversion_amdgpu_transforms_passes_inc_gen", + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=TritonAMDGPU", + ], + "include/TritonAMDGPUTransforms/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/TritonAMDGPUTransforms/Passes.td", + deps = [":td_files"], +) diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.cpp index 6d79bd7aae08c..5eaad3d350064 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.cpp @@ -1,9 +1,9 @@ -#include "PatternTritonGPUOpToLLVM.h" -#include "TargetInfo.h" -#include "Utility.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/IR/PatternMatch.h" +#include "PatternTritonGPUOpToLLVM.h" +#include "TargetInfo.h" +#include "Utility.h" #include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h" #include "triton/Dialect/Triton/IR/Dialect.h" diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp index 35a2e1a34bcda..4ff658ab0b7a1 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -1019,17 +1019,17 @@ struct FpToFpOpConversion return outVals; } size_t numElements = 4; - if (srcElementType.isFloat8E4M3FN() || dstElementType.isFloat8E4M3FN() || - srcElementType.isFloat8E4M3FNUZ() || - dstElementType.isFloat8E4M3FNUZ() || - srcElementType.isFloat8E5M2FNUZ() || - dstElementType.isFloat8E5M2FNUZ()) { + if (llvm::isa(srcElementType) || + llvm::isa(dstElementType)) { numElements = 2; } bool useFP16IntermediateSrc = - srcElementType.isF32() && !(isaFamily == AMD::ISAFamily::CDNA3 && - (dstElementType.isFloat8E4M3FNUZ() || - dstElementType.isFloat8E5M2FNUZ())); + srcElementType.isF32() && + !(isaFamily == AMD::ISAFamily::CDNA3 && + (llvm::isa( + dstElementType))); bool isDstFP32 = dstElementType.isF32(); Type srcType = useFP16IntermediateSrc ? f16_ty : srcElementType; Type dstType = isDstFP32 ? f16_ty : dstElementType; diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp index 7ea13142a76c2..005089aaf7ac0 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp @@ -416,7 +416,7 @@ class BlockedToMFMA : public OpRewritePattern { // store instructions, except for fp8 matmul kernels due to regression // TODO (lixun): investigate the regression and enable this feature again auto aElemTy = mfmaInstr.getElementTypeA(); - bool isFP8 = aElemTy.isFloat8E5M2FNUZ() || aElemTy.isFloat8E4M3FNUZ(); + bool isFP8 = llvm::isa(aElemTy); bool isTransposed = isChainDot(dotOp) || !isFP8; mfmaEnc = ttg::AMDMfmaEncodingAttr::get( oldRetType.getContext(), diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/MfmaGroup.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/MfmaGroup.cpp index 4979ee005b9f0..f96e70e727e2c 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/MfmaGroup.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/MfmaGroup.cpp @@ -20,19 +20,24 @@ static MfmaTypeId chooseAppropriateMfmaId(mlir::Type dataTypeA, if (dataTypeA.isInteger(8) && dataTypeB.isInteger(8)) { return MfmaTypeId::I8TyId; } - if (dataTypeA.isFloat8E4M3FNUZ() && dataTypeB.isFloat8E4M3FNUZ()) { + if (llvm::isa(dataTypeA) && + llvm::isa(dataTypeB)) { return MfmaTypeId::Fp8Fp8TyId; } - if (dataTypeA.isFloat8E4M3FNUZ() && dataTypeB.isFloat8E5M2FNUZ()) { + if (llvm::isa(dataTypeA) && + llvm::isa(dataTypeB)) { return MfmaTypeId::Fp8Bf8TyId; } - if (dataTypeA.isFloat8E5M2FNUZ() && dataTypeB.isFloat8E4M3FNUZ()) { + if (llvm::isa(dataTypeA) && + llvm::isa(dataTypeB)) { return MfmaTypeId::Bf8Fp8TyId; } - if (dataTypeA.isFloat8E5M2FNUZ() && dataTypeB.isFloat8E5M2FNUZ()) { + if (llvm::isa(dataTypeA) && + llvm::isa(dataTypeB)) { return MfmaTypeId::Bf8Bf8TyId; } - if (dataTypeA.isFloat8E5M2() && dataTypeB.isFloat8E5M2()) { + if (llvm::isa(dataTypeA) && + llvm::isa(dataTypeB)) { return MfmaTypeId::Fp16TyId; } llvm_unreachable("Unsupported input argument type."); diff --git a/third_party/f2reduce/BUILD b/third_party/f2reduce/BUILD new file mode 100644 index 0000000000000..93829539e1b97 --- /dev/null +++ b/third_party/f2reduce/BUILD @@ -0,0 +1,31 @@ +# copybara:uncomment load("//tools/build_defs/license:license.bzl", "license") + +package( + # copybara:uncomment_begin + # default_applicable_licenses = ["//:license"], + # default_compatible_with = ["//buildenv/target:non_prod"], + # default_visibility = [ + # "//:__subpackages__", + # ], + # copybara:uncomment_end_and_comment_begin + default_visibility = ["//visibility:public"], + # copybara:comment_end +) + +# copybara:uncomment_begin +# license( +# name = "license", +# license_text = "LICENCE.txt", +# ) +# +# licenses(["notice"]) +# +# exports_files(["LICENCE.txt"]) +# copybara:uncomment_end + +cc_library( + name = "f2reduce", + srcs = ["f2reduce.cpp"], + hdrs = ["f2reduce.h"], + # copybara:uncomment strip_include_prefix = "/third_party/triton", +) diff --git a/third_party/nvidia/BUILD b/third_party/nvidia/BUILD new file mode 100644 index 0000000000000..e0cf10e7c2492 --- /dev/null +++ b/third_party/nvidia/BUILD @@ -0,0 +1,316 @@ +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") +load("@pybind11_bazel//:build_defs.bzl", "pybind_library") + +package( + # copybara:uncomment_begin + # default_applicable_licenses = ["//:license"], + # default_compatible_with = ["//buildenv/target:non_prod"], + # default_visibility = [ + # "//third_party/tensorflow/compiler/xla/backends/gpu:__subpackages__", + # "//third_party/tensorflow/compiler/xla/pjrt:__subpackages__", + # "//:__subpackages__", + # ], + # copybara:uncomment_end_and_comment_begin + default_visibility = ["//visibility:public"], + # copybara:comment_end +) + +pybind_library( + name = "cublas_headers", + hdrs = glob([ + "include/*.h", + ]), + deps = ["@local_config_cuda//cuda:cuda_headers"], +) + +pybind_library( + name = "triton_nvidia", + srcs = [ + "triton_nvidia.cc", + ], + compatible_with = [], + # copybara:uncomment_begin + # visibility = [ + # "@triton//python:__subpackages__", + # ], + # copybara:uncomment_end + deps = [ + ":NVGPUDialect", + ":NVGPUToLLVM", + ":TritonNVIDIAGPUToLLVM", + ":cublas_headers", + "@llvm-project//llvm:Core", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:NVVMToLLVMIRTranslation", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Transforms", + "//:TritonDialects", + "//:TritonGPUToLLVM", + "//:TritonNvidiaGPUTransforms", + "@triton//python:passes", + ], +) + +cc_library( + name = "NVGPUToLLVM", + srcs = glob([ + "lib/NVGPUToLLVM/*.cpp", + ]), + hdrs = glob([ + "include/NVGPUToLLVM/*.h", + ]), + # copybara:uncomment_begin + # compatible_with = ["//buildenv/target:non_prod"], + # copybara:uncomment_end + copts = select({ + "//conditions:default": [ + "-Wno-unused-variable", + "-Wno-return-type", + ], + }), + includes = [ + "..", + "include", + ], + deps = [ + ":NVGPUDialect", + ":TritonNVIDIAGPUToLLVM", + ":triton_conversion_nvgpu_to_llvm_passes_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:NVVMDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + "//:TritonDialects", + "//:TritonGPUToLLVM", + ], +) + +cc_library( + name = "TritonNVIDIAGPUToLLVM", + srcs = glob([ + "lib/TritonNVIDIAGPUToLLVM/*.h", + "lib/TritonNVIDIAGPUToLLVM/**/*.cpp", + ]), + hdrs = glob([ + "include/TritonNVIDIAGPUToLLVM/*.h", + "include/triton/Conversion/TritonGPUToLLVM/*.h", + ]) + [ + "lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAHelpers.h", + "lib/TritonNVIDIAGPUToLLVM/TargetInfo.h", + "lib/TritonNVIDIAGPUToLLVM/Utility.h", + ], + # copybara:uncomment_begin + # compatible_with = ["//buildenv/target:non_prod"], + # copybara:uncomment_end + copts = select({ + "//conditions:default": [ + "-Wno-reorder-ctor", + "-Wno-unused-variable", + ], + }), + includes = [ + "..", + "include", + "lib/TritonNVIDIAGPUToLLVM", + "lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM", + ], + deps = [ + ":NVGPUDialect", + ":triton_conversion_triton_nvidia_gpu_to_llvm_passes_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:ArithToLLVM", + "@llvm-project//mlir:ControlFlowDialect", + "@llvm-project//mlir:ControlFlowToLLVM", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:GPUToNVVMTransforms", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:IndexDialect", + "@llvm-project//mlir:LLVMCommonConversion", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:MathToLLVM", + "@llvm-project//mlir:NVVMDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFToControlFlow", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + "@llvm-project//mlir:UBToLLVM", + "//:TritonDialects", + "//:TritonGPUToLLVM", + "//:TritonNvidiaGPUTransforms", + "//:triton_gpu_attr_inc_gen", + "@triton//third_party/proton:TritonProtonToLLVM", + ], +) + +gentbl_cc_library( + name = "triton_conversion_nvgpu_to_llvm_passes_inc_gen", + # copybara:uncomment_begin + # compatible_with = ["//buildenv/target:non_prod"], + # copybara:uncomment_end + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=NVGPUToLLVM", + ], + "include/NVGPUToLLVM/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/NVGPUToLLVM/Passes.td", + deps = ["//:td_files"], +) + +gentbl_cc_library( + name = "triton_conversion_triton_nvidia_gpu_to_llvm_passes_inc_gen", + # copybara:uncomment_begin + # compatible_with = ["//buildenv/target:non_prod"], + # copybara:uncomment_end + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=TritonNVIDIAGPUToLLVM", + ], + "include/TritonNVIDIAGPUToLLVM/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/TritonNVIDIAGPUToLLVM/Passes.td", + deps = ["//:td_files"], +) + +td_library( + name = "td_files", + srcs = glob(["include/Dialect/NVGPU/IR/*.td"]), + includes = ["include"], + deps = [ + "@llvm-project//mlir:ArithOpsTdFiles", + "@llvm-project//mlir:CastInterfacesTdFiles", + "@llvm-project//mlir:ControlFlowInterfacesTdFiles", + "@llvm-project//mlir:DestinationStyleOpInterfaceTdFiles", + "@llvm-project//mlir:FunctionInterfacesTdFiles", + "@llvm-project//mlir:InferTypeOpInterfaceTdFiles", + "@llvm-project//mlir:LLVMOpsTdFiles", + "@llvm-project//mlir:OpBaseTdFiles", + "@llvm-project//mlir:PassBaseTdFiles", + "@llvm-project//mlir:SideEffectInterfacesTdFiles", + "@llvm-project//mlir:ViewLikeInterfaceTdFiles", + ], +) + +gentbl_cc_library( + name = "nvgpu_ops_inc_gen", + tbl_outs = [ + ( + ["--gen-llvmir-conversions"], + "include/Dialect/NVGPU/IR/OpsConversions.inc", + ), + ( + ["--gen-op-decls"], + "include/Dialect/NVGPU/IR/Ops.h.inc", + ), + ( + ["--gen-op-defs"], + "include/Dialect/NVGPU/IR/Ops.cpp.inc", + ), + ( + ["--gen-enum-decls"], + "include/Dialect/NVGPU/IR/OpsEnums.h.inc", + ), + ( + ["--gen-enum-defs"], + "include/Dialect/NVGPU/IR/OpsEnums.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/Dialect/NVGPU/IR/NVGPUOps.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "nvgpu_attr_inc_gen", + tbl_outs = [ + ( + ["--gen-attrdef-decls"], + "include/Dialect/NVGPU/IR/NVGPUAttrDefs.h.inc", + ), + ( + ["--gen-attrdef-defs"], + "include/Dialect/NVGPU/IR/NVGPUAttrDefs.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/Dialect/NVGPU/IR/NVGPUAttrDefs.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "nvgpu_dialect_inc_gen", + tbl_outs = [ + ( + ["--gen-dialect-decls"], + "include/Dialect/NVGPU/IR/Dialect.h.inc", + ), + ( + ["--gen-dialect-defs"], + "include/Dialect/NVGPU/IR/Dialect.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/Dialect/NVGPU/IR/NVGPUDialect.td", + deps = ["td_files"], +) + +cc_library( + name = "NVGPUDialect", + srcs = glob([ + "lib/Dialect/NVGPU/IR/*.cpp", + ]), + hdrs = glob([ + "include/Dialect/NVGPU/IR/*.h", + ]), + copts = select({ + "//:compiler_is_msvc": [], + "//conditions:default": [ + "-Wno-unused-variable", + "-Wno-logical-op-parentheses", + ], + }), + includes = [ + "..", # because nvidia/include/Dialect/NVGPU/IR/Dialect.h.inc + "../..", # because third_party/nvidia/include/Dialect/NVGPU/IR/Dialect.h + "include", + ], + deps = [ + ":nvgpu_attr_inc_gen", + ":nvgpu_dialect_inc_gen", + ":nvgpu_ops_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:ControlFlowDialect", + "@llvm-project//mlir:ControlFlowInterfaces", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:FunctionInterfaces", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:InliningUtils", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:MathDialect", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + # The following is added to make Utility compile + "//:TritonTools", + "@llvm-project//mlir:LLVMCommonConversion", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + ], +) diff --git a/third_party/nvidia/backend/BUILD b/third_party/nvidia/backend/BUILD new file mode 100644 index 0000000000000..a5b34aa5c29b9 --- /dev/null +++ b/third_party/nvidia/backend/BUILD @@ -0,0 +1,30 @@ +load("@pybind11_bazel//:build_defs.bzl", "pybind_extension") + +package( + default_applicable_licenses = ["//:license"], + default_visibility = [ + "//third_party/py/triton:__subpackages__", + ], +) + +pybind_extension( + name = "cuda_utils", + srcs = ["cuda_utils.cc"], + visibility = [ + "//learning/deepmind/jax/triton/ops:__subpackages__", + "//third_party/py/triton:__subpackages__", + ], + deps = [ + "//platforms/gpus/cuda/dynamic_libcuda", + "@local_config_cuda//cuda:cuda_headers", + "@local_config_cuda//cuda:cuda_runtime", + "@llvm-project//llvm:Support", + ], +) + +filegroup( + name = "files", + srcs = glob( + include = ["**/*.py"], + ), +) diff --git a/third_party/nvidia/backend/cuda_utils.cc b/third_party/nvidia/backend/cuda_utils.cc new file mode 100644 index 0000000000000..2c9defd219243 --- /dev/null +++ b/third_party/nvidia/backend/cuda_utils.cc @@ -0,0 +1,897 @@ +#define PY_SSIZE_T_CLEAN +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "cuda.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" + +namespace { + +struct UniquePyObjectDeleter { + void operator()(PyObject* obj) { Py_DECREF(obj); } +}; +// A unique_ptr for PyObjects that automatically calls Py_DECREF once it goes +// out of scope. +using UniquePyObjectPtr = std::unique_ptr; + +// Raise a python exception if the CUDA result code is not CUDA_SUCCESS. +// Can be called even on threads that do not hold Python's Global Interpreter +// Lock (GIL), as the function will acquire one if needed. +inline bool gpuAssert(CUresult code, const char* file, int line) { + if (code == CUDA_SUCCESS) + return true; + const char* error = nullptr; + cuGetErrorString(code, &error); + PyGILState_STATE gil_state = PyGILState_Ensure(); + PyErr_Format(PyExc_RuntimeError, "Triton Error [CUDA]: %s", error); + PyGILState_Release(gil_state); + return false; +} + +// To be used only *outside* a Py_{BEGIN,END}_ALLOW_THREADS block. +#define CUDA_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }} + +#define CUDA_CHECK_AND_RETURN_NULL(ans) \ + do { \ + if (!gpuAssert((ans), __FILE__, __LINE__)) \ + return NULL; \ + } while (0) + +// To be used inside a Py_{BEGIN,END}_ALLOW_THREADS block. +#define CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(ans) \ + do { \ + if (!gpuAssert((ans), __FILE__, __LINE__)) { \ + PyEval_RestoreThread(_save); \ + return NULL; \ + } \ + } while (0) + +// Used to check if functions exist in old CUDA driver versions. +#define INITIALIZE_FUNCTION_POINTER_IF_NULL(funcPointer, initializerFunction) \ + do { \ + if ((funcPointer) == NULL) { \ + (funcPointer) = (initializerFunction)(); \ + if ((funcPointer) == NULL) { \ + return NULL; \ + } \ + } \ + } while (0) + +using cuLaunchKernelEx_t = CUresult (*)(const CUlaunchConfig* config, + CUfunction f, void** kernelParams, + void** extra); + +// Dynamically load the handle to cuLaunchKernelEx. +cuLaunchKernelEx_t getLaunchKernelExHandle() { + // Open the shared library + void* handle = dlopen("libcuda.so.1", RTLD_LAZY); + if (!handle) { + PyErr_SetString(PyExc_RuntimeError, "Failed to open libcuda.so"); + return nullptr; + } + // Clear any existing error + dlerror(); + auto cuLaunchKernelExHandle = + reinterpret_cast(dlsym(handle, "cuLaunchKernelEx")); + // Check for errors + if (const char* dlsym_error = dlerror()) { + PyErr_Format(PyExc_RuntimeError, + "Failed to retrieve cuLaunchKernelEx from libcuda.so: %s", + dlsym_error); + return nullptr; + } + return cuLaunchKernelExHandle; +} + +// Configuration with all the information necessary to launch a compiled +// Triton kernel using the CUDA driver API. +struct TritonLaunchConfig { + // Represents CUDA's 3D ID structure of grids and clusters + struct Dim { + int x; + int y; + int z; + constexpr int size() const { return x * y * z; } + }; + Dim grid; // Number of clusters per grid + Dim cluster; // Number of blocks per cluster + int num_warps; // number of warps per block + int shared_memory; // Size of shared memory in bytes to allocate + CUstream stream; // CUDA Stream on which to launch the kernel + CUfunction function; // Pointer to the kernel to launch + void** params; // Parameters to pass to the kernel +}; + +// Launch a CUDA kernel with the given parameters. Raises a Python exception +// if the kernel launch fails. +PyObject* launchKernel(const TritonLaunchConfig& config) { + // Launching the kernel might take a while and does not use Python APIs, so + // we can release the Global Interpreter Lock so other threads can use Python + // APIs if needed. + Py_BEGIN_ALLOW_THREADS; + const auto& grid = config.grid; + const auto& cluster = config.cluster; + if (grid.size() == 0) { + PyEval_RestoreThread(_save); + Py_RETURN_NONE; + } + if (cluster.size() == 1) { + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuLaunchKernel( + config.function, grid.x, grid.y, grid.z, 32 * config.num_warps, 1, 1, + config.shared_memory, config.stream, config.params, 0)); + } else { + CUlaunchAttribute launchAttr[2]; + launchAttr[0].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION; + launchAttr[0].value.clusterDim.x = cluster.x; + launchAttr[0].value.clusterDim.y = cluster.y; + launchAttr[0].value.clusterDim.z = cluster.z; + launchAttr[1].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE; + launchAttr[1].value.clusterSchedulingPolicyPreference = + CU_CLUSTER_SCHEDULING_POLICY_SPREAD; + CUlaunchConfig cu_config; + cu_config.gridDimX = grid.x * cluster.x; + cu_config.gridDimY = grid.y * cluster.y; + cu_config.gridDimZ = grid.z * cluster.z; + cu_config.blockDimX = 32 * config.num_warps; + cu_config.blockDimY = 1; + cu_config.blockDimZ = 1; + cu_config.sharedMemBytes = config.shared_memory; + cu_config.hStream = config.stream; + cu_config.attrs = launchAttr; + cu_config.numAttrs = 2; + // cuLaunchKernelEx was added in CUDA 12, so load it dynamically to be + // able to link on CUDA 11 and earlier. + static cuLaunchKernelEx_t cuLaunchKernelExHandle = + getLaunchKernelExHandle(); + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( + cuLaunchKernelExHandle(&cu_config, config.function, config.params, 0)); + } + Py_END_ALLOW_THREADS; + Py_RETURN_NONE; +} + +// Interface used by various PyObject extractors to extract obj into a memory +// location pointed by ptr. Returns true if extraction succeeded, and false +// otherwise. +using ExtractorType = bool (*)(PyObject* obj, void* ptr); + +// Extract a CUDA device pointer from a pointer-like PyObject obj, and store +// it to the memory location pointed by ptr. +bool extractPointer(PyObject* obj, void* ptr) { + auto dev_ptr = static_cast(ptr); + if (obj == Py_None) { + *dev_ptr = static_cast(0); // valid nullptr + return true; + } + if (PyLong_Check(obj)) { + *dev_ptr = PyLong_AsUnsignedLongLong(obj); + return true; + } + UniquePyObjectPtr ret(PyObject_CallMethod(obj, "data_ptr", nullptr)); + if (!ret.get()) { + PyErr_Format(PyExc_TypeError, + "Pointer argument must be either uint64 or have data_ptr " + "method, but got %R", + obj); + return false; + } + if (!PyLong_Check(ret.get())) { + PyErr_SetString(PyExc_TypeError, + "data_ptr method of Pointer object must return 64-bit int"); + return false; + } + *dev_ptr = PyLong_AsUnsignedLongLong(ret.get()); + if (PyErr_Occurred()) { + return false; + } + if (*dev_ptr == 0) { + return true; // valid nullptr + } + CUresult status = cuPointerGetAttribute( + dev_ptr, CU_POINTER_ATTRIBUTE_DEVICE_POINTER, *dev_ptr); + if (status == CUDA_ERROR_INVALID_VALUE) { + PyErr_Format(PyExc_ValueError, + "Pointer argument cannot be accessed from Triton " + "(cpu tensor?)"); + return false; + } else if (status != CUDA_SUCCESS) { + CUDA_CHECK(status); + return false; + } + return true; +} + +// For a given type T, maps to the Python API with signature `U (*)(PyObject*)` +// that can extract values of that type from a PyObject. Note that the return +// type U is not guaranteed to be the same as T, but it can be explicitly casted +// to T. +template +constexpr auto kValueFunction = nullptr; +template +constexpr auto + kValueFunction && + std::is_signed_v && sizeof(T) <= 4>> = + PyLong_AsLong; +template <> +constexpr auto kValueFunction = PyLong_AsLongLong; +template +constexpr auto kValueFunction< + T, std::enable_if_t && std::is_unsigned_v && + sizeof(T) <= 4>> = PyLong_AsUnsignedLong; +template <> +constexpr auto kValueFunction = PyLong_AsUnsignedLongLong; +template +constexpr auto + kValueFunction>> = + PyFloat_AsDouble; + +// Extract a value of type T from obj and store it into memory pointed by ptr. +// Returns true if extraction succeeded, and false otherwise. +template +bool extractValue(PyObject* obj, void* ptr) { + *static_cast(ptr) = static_cast(kValueFunction(obj)); + return PyErr_Occurred() == nullptr; +} + +// Contains information necessary for extracting a certain type from a PyObject. +struct ExtractionInfo { + // Prefixes of types reprs supported by the extractor. + llvm::SmallVector supported_type_repr_prefixes; + std::size_t size; // Size required by the extracted value. + ExtractorType extractor; // Function to call to extract the value. + + // Builds an ExtractionInfo for a given type T and a list of type reprs that + // are backed by that type. + template + static ExtractionInfo build( + std::initializer_list supported_type_reprs, + ExtractorType extractor = extractValue) { + return {supported_type_reprs, sizeof(T), extractor}; + } + + // Checks if the extractor supports extracting a given type repr. + bool supports(llvm::StringRef type_repr) const { + return llvm::any_of( + supported_type_repr_prefixes, + [&](llvm::StringRef prefix) { return type_repr.starts_with(prefix); }); + } +}; + +// Array of supported extractors +const ExtractionInfo kExtractionInfos[]{ + ExtractionInfo::build({"'i8'"}), + ExtractionInfo::build({"'i16'"}), + ExtractionInfo::build({"'i1'", "'i32'"}), + ExtractionInfo::build({"'i64'"}), + ExtractionInfo::build({"'u8'"}), + ExtractionInfo::build({"'u16'"}), + ExtractionInfo::build({"'u1'", "'u32'"}), + ExtractionInfo::build({"'u64'"}), + ExtractionInfo::build({"'fp16'", "'bf16'", "'fp32'", "'f32'"}), + ExtractionInfo::build({"'fp64'"}), + // Note: types are e.g. '*fp32', so no closing quote is intentional. + ExtractionInfo::build({"'*"}, extractPointer), + ExtractionInfo{ + {"None", "'none'"}, 0, nullptr}, // Represent constexprs as None +}; + +// Finds an extractor that supports a given type_repr in the extractor list. +// Returns nullopt if no such extractor is found. +std::optional findExtractor(llvm::StringRef type_repr) { + constexpr std::size_t kNumExtractors = std::size(kExtractionInfos); + static_assert(kNumExtractors < std::numeric_limits::max(), + "Not enough bits in a byte to store the extractor index"); + for (const auto& [idx, info] : llvm::enumerate(kExtractionInfos)) { + if (info.supports(type_repr)) return idx; + } + return std::nullopt; +} + +PyDoc_STRVAR(buildSignatureMetadata__doc__, + R"(buildSignatureMetadata(signature_iterator) -> bytes + +Build a metadata object describing the signature of a kernel. + +This can then be passed as the signature_metadata parameter to the launch() +function. + +:param signature: list of types describing the signature of a kernel, + specialized parameters should be represented with None +:type signature: sequence or iterable +:return: an opaque metadata object which can then be passed to launch() +:rtype: bytes +)"); +PyObject* buildSignatureMetadata(PyObject* self, PyObject* args) { + PyObject* signature = nullptr; + if (!PyArg_ParseTuple(args, "O", &signature)) { + return nullptr; + } + if (!PyIter_Check(signature)) { + PyErr_Format(PyExc_TypeError, + "expected signature to be an iterable, got %R", signature); + return nullptr; + } + + llvm::SmallVector signature_metadata; + while (UniquePyObjectPtr obj_type{PyIter_Next(signature)}) { + UniquePyObjectPtr repr(PyObject_Repr(obj_type.get())); + if (!repr) { + return nullptr; + } + UniquePyObjectPtr repr_str( + PyUnicode_AsEncodedString(repr.get(), "utf-8", "~E~")); + if (!repr_str) { + return nullptr; + } + const char* repr_bytes = PyBytes_AsString(repr_str.get()); + if (!repr_bytes) { + return nullptr; + } + std::optional extractor_idx = findExtractor(repr_bytes); + if (!extractor_idx.has_value()) { + PyErr_Format(PyExc_TypeError, + "unexpected type %R in kernel signature, dir: %R", + obj_type.get(), PyObject_Dir(obj_type.get())); + return nullptr; + } + signature_metadata.push_back(extractor_idx.value()); + } + if (PyErr_Occurred()) { + return nullptr; + } + + return PyBytes_FromStringAndSize(signature_metadata.data(), + signature_metadata.size()); +} + +// Launch a Python callable hook with metadata passed as parameters. +bool launchHook(PyObject* hook, PyObject* metadata) { + if (hook == Py_None) { + return true; + } + UniquePyObjectPtr args(Py_BuildValue("(O)", metadata)); + if (!args) { + return false; + } + UniquePyObjectPtr ret(PyObject_CallObject(hook, args.get())); + return static_cast(ret); +} + +static void ensureCudaContext() { + CUcontext pctx; + CUDA_CHECK(cuCtxGetCurrent(&pctx)); + if (!pctx) { + // Ensure device context. + CUdevice device; + CUDA_CHECK(cuDeviceGet(&device, 0)); + CUDA_CHECK(cuDevicePrimaryCtxRetain(&pctx, device)); + CUDA_CHECK(cuCtxSetCurrent(pctx)); + } +} + +PyDoc_STRVAR( + launch__doc__, + R"(launch(gridDimX, gridDimY, gridDimZ, stream, kernel, packed_metadata, launch_metadata, launch_enter_hook, launch_exit_hook, kernel_arg_types, global_scratch, kernel_args) + +Launch a kernel on an Nvidia GPU. + +:param gridDimX: X dimension of the grid +:type gridDimX: signed integer +:param gridDimY: Y dimension of the grid +:type gridDimY: signed integer +:param gridDimZ: Z dimension of the grid +:type gridDimZ: signed integer +:param stream: CUDA Stream to launch on +:type stream: unsigned long integer (pointer) +:param kernel: CUDA kernel to launch +:type kernel: unsigned long integer (pointer) +:param packed_metadata: Kernel metadata, including in sequence: + number of warps, number of CTAs, required bytes of shared memory, + cluster dimensions x, y, and z +:type packed_metadata: 6-tuple +:param hook_args: arguments to pass to the enter and exit hooks +:type hook_args: object +:param launch_enter_hook: hook to call just before launching the kernel +:type launch_enter_hook: callable +:param launch_exit_hook: hook to call just after launching the kernel +:type launch_exit_hook: callable +:param signature_metadata: matadata built from build_signature_metadata +:type signature_metadata: bytes +:param global_scratch: pointer to global scratch memory +:type global_scratch: pointer +:param kernel_args: kernel parameters +:type kernel_args: tuple + +:raises RuntimeError: on kernel launch failure +)"); +PyObject* launch(PyObject* self, PyObject* args) { + ensureCudaContext(); + TritonLaunchConfig config{}; + auto& grid = config.grid; + auto& cluster = config.cluster; + // PyObject* kernel_metadata = nullptr; + PyObject* hook_args = nullptr; + PyObject* launch_enter_hook = nullptr; + PyObject* launch_exit_hook = nullptr; + PyBytesObject* signature_metadata_bytes = nullptr; + PyObject* kernel_args = nullptr; + PyObject* global_scratch = nullptr; + int num_ctas = 0; + if (!PyArg_ParseTuple(args, "iiiKK(iiiiii)OOOSOO", &grid.x, &grid.y, &grid.z, + &config.stream, &config.function, &config.num_warps, + &num_ctas, &config.shared_memory, &cluster.x, + &cluster.y, &cluster.z, &hook_args, &launch_enter_hook, + &launch_exit_hook, &signature_metadata_bytes, + &global_scratch, &kernel_args)) { + return nullptr; + } + if (num_ctas != cluster.size()) { + PyErr_Format( + PyExc_ValueError, + "Expected cluster dimensions (%d, %d, %d) to have a total size of %d", + cluster.x, cluster.y, cluster.z, num_ctas); + return nullptr; + } + llvm::ArrayRef signature_metadata( + PyBytes_AS_STRING(signature_metadata_bytes), + PyBytes_GET_SIZE(signature_metadata_bytes)); + UniquePyObjectPtr fast_kernel_args(PySequence_Fast( + kernel_args, "Expected kernel_args to be a sequence or iterable")); + if (!fast_kernel_args) { + return nullptr; + } + llvm::ArrayRef kernel_args_data( + PySequence_Fast_ITEMS(fast_kernel_args.get()), + PySequence_Fast_GET_SIZE(fast_kernel_args.get())); + + if (signature_metadata.size() != kernel_args_data.size()) { + PyErr_Format(PyExc_TypeError, + "Expected kernel to have %d parameters, but got %d", + signature_metadata.size(), kernel_args_data.size()); + return nullptr; + } + + // +1 for the global scratch pointer. + std::size_t num_params = signature_metadata.size() + 1; + // Use alloca to set up kernel parameters on the stack and avoid dynamic + // memory allocations. + config.params = static_cast(alloca(num_params * sizeof(void*))); + // This loop has to stay in the same function that owns params, since we are + // using alloca to allocate pointers to it on the stack of the function. + std::size_t params_idx = 0; + for (const auto& [converter_idx, arg] : + llvm::zip(signature_metadata, kernel_args_data)) { + if (converter_idx >= std::size(kExtractionInfos)) { + PyErr_SetString(PyExc_ValueError, "corrupted signature metadata"); + return nullptr; + } + const ExtractionInfo& extraction_info = kExtractionInfos[converter_idx]; + if (extraction_info.size == 0) { + continue; // skip adding constexpr parameters + } + config.params[params_idx] = alloca(extraction_info.size); + if (!extraction_info.extractor(arg, config.params[params_idx])) { + return nullptr; + } + ++params_idx; + } + config.params[params_idx] = alloca(sizeof(void*)); + if (!extractPointer(global_scratch, config.params[params_idx])) { + return nullptr; + } + + if (!launchHook(launch_enter_hook, hook_args)) { + return nullptr; + } + + if(!launchKernel(config)) { + return nullptr; + } + + if (!launchHook(launch_exit_hook, hook_args)) { + return nullptr; + } + + Py_RETURN_NONE; +} + +} // namespace + +static PyObject *getDeviceProperties(PyObject *self, PyObject *args) { + int device_id; + if (!PyArg_ParseTuple(args, "i", &device_id)) + return NULL; + // Get device handle + CUdevice device; + cuDeviceGet(&device, device_id); + + // create a struct to hold device properties + int max_shared_mem; + int max_num_regs; + int multiprocessor_count; + int warp_size; + int sm_clock_rate; + int mem_clock_rate; + int mem_bus_width; + CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute( + &max_shared_mem, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, + device)); + CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute( + &max_num_regs, CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_BLOCK, device)); + CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute( + &multiprocessor_count, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, device)); + CUDA_CHECK_AND_RETURN_NULL( + cuDeviceGetAttribute(&warp_size, CU_DEVICE_ATTRIBUTE_WARP_SIZE, device)); + CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute( + &sm_clock_rate, CU_DEVICE_ATTRIBUTE_CLOCK_RATE, device)); + CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute( + &mem_clock_rate, CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE, device)); + CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute( + &mem_bus_width, CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH, device)); + + return Py_BuildValue("{s:i, s:i, s:i, s:i, s:i, s:i, s:i}", "max_shared_mem", + max_shared_mem, "max_num_regs", max_num_regs, + "multiprocessor_count", multiprocessor_count, "warpSize", + warp_size, "sm_clock_rate", sm_clock_rate, + "mem_clock_rate", mem_clock_rate, "mem_bus_width", + mem_bus_width); +} + +static PyObject *loadBinary(PyObject *self, PyObject *args) { + const char *name; + const char *data; + Py_ssize_t data_size; + int shared; + CUdevice device; + if (!PyArg_ParseTuple(args, "ss#ii", &name, &data, &data_size, &shared, + &device)) { + return NULL; + } + CUfunction fun; + CUmodule mod; + int32_t n_regs = 0; + int32_t n_spills = 0; + // create driver handles + CUcontext pctx = 0; + + Py_BEGIN_ALLOW_THREADS; + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuCtxGetCurrent(&pctx)); + if (!pctx) { + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuDeviceGet(&device, 0)); + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( + cuDevicePrimaryCtxRetain(&pctx, device)); + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuCtxSetCurrent(pctx)); + } + + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuModuleLoadData(&mod, data)); + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( + cuModuleGetFunction(&fun, mod, name)); + // get allocated registers and spilled registers from the function + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( + cuFuncGetAttribute(&n_regs, CU_FUNC_ATTRIBUTE_NUM_REGS, fun)); + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( + cuFuncGetAttribute(&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun)); + n_spills /= 4; + // set dynamic shared memory if necessary + int shared_optin; + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuDeviceGetAttribute( + &shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, + device)); + if (shared > 49152 && shared_optin > 49152) { + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( + cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED)); + int shared_total, shared_static; + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuDeviceGetAttribute( + &shared_total, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR, + device)); + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuFuncGetAttribute( + &shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun)); + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( + cuFuncSetAttribute(fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, + shared_optin - shared_static)); + } + Py_END_ALLOW_THREADS; + + if (PyErr_Occurred()) { + return NULL; + } + return Py_BuildValue("(KKii)", (uint64_t)mod, (uint64_t)fun, n_regs, + n_spills); +} + +typedef CUresult (*cuOccupancyMaxActiveClusters_t)( + int *numClusters, CUfunction func, const CUlaunchConfig *config); + +#if CUDA_VERSION >= 12000 +typedef CUresult (*cuTensorMapEncodeTiled_t)( + CUtensorMap *tensorMap, CUtensorMapDataType tensorDataType, + cuuint32_t tensorRank, void *globalAddress, const cuuint64_t *globalDim, + const cuuint64_t *globalStrides, const cuuint32_t *boxDim, + const cuuint32_t *elementStrides, CUtensorMapInterleave interleave, + CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion, + CUtensorMapFloatOOBfill oobFill); +#endif + +#define defineGetFunctionHandle(name, symbolName) \ + static symbolName##_t name() { \ + /* Open the shared library */ \ + void *libHandle = dlopen("libcuda.so.1", RTLD_LAZY); \ + if (!libHandle) { \ + PyErr_SetString(PyExc_RuntimeError, "Failed to open libcuda.so.1"); \ + return NULL; \ + } \ + /* Clear any existing error */ \ + dlerror(); \ + symbolName##_t funcHandle = (symbolName##_t)dlsym(libHandle, #symbolName); \ + /* Check for errors */ \ + const char *err = dlerror(); \ + if (err) { \ + PyErr_SetString(PyExc_RuntimeError, \ + "Failed to retrieve " #symbolName " from libcuda.so.1"); \ + dlclose(libHandle); \ + return NULL; \ + } \ + return funcHandle; \ + } + +defineGetFunctionHandle(getCuOccupancyMaxActiveClustersHandle, + cuOccupancyMaxActiveClusters); + +#if CUDA_VERSION >= 12000 +defineGetFunctionHandle(getCuTensorMapEncodeTiledHandle, + cuTensorMapEncodeTiled); +#endif + +static PyObject *occupancyMaxActiveClusters(PyObject *self, PyObject *args) { + int clusterDimX = -1, clusterDimY = -1, clusterDimZ = -1, + maxActiveClusters = -1; + int shared = 0; + CUfunction func; + + if (!PyArg_ParseTuple(args, "Kiiii", &func, &shared, &clusterDimX, + &clusterDimY, &clusterDimZ)) { + return NULL; + } + + // Let each SM have one block + int maxActiveBlocks = 1; + Py_BEGIN_ALLOW_THREADS; + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuFuncSetAttribute( + func, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared)); + Py_END_ALLOW_THREADS; + + CUlaunchAttribute launchAttr[1]; + launchAttr[0].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION; + launchAttr[0].value.clusterDim.x = clusterDimX; + launchAttr[0].value.clusterDim.y = clusterDimY; + launchAttr[0].value.clusterDim.z = clusterDimZ; + CUlaunchConfig config; + config.gridDimX = clusterDimX; + config.gridDimY = maxActiveBlocks * clusterDimY; + config.gridDimZ = clusterDimZ; + config.blockDimX = 128; + config.blockDimY = 1; + config.blockDimZ = 1; + config.sharedMemBytes = shared; + config.hStream = 0; + config.numAttrs = 1; + config.attrs = launchAttr; + + static cuOccupancyMaxActiveClusters_t cuOccupancyMaxActiveClusters = NULL; + INITIALIZE_FUNCTION_POINTER_IF_NULL(cuOccupancyMaxActiveClusters, + getCuOccupancyMaxActiveClustersHandle); + + Py_BEGIN_ALLOW_THREADS; + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuFuncSetAttribute( + func, CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED, 1)); + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( + cuOccupancyMaxActiveClusters(&maxActiveClusters, func, &config)); + Py_END_ALLOW_THREADS; + return PyLong_FromLong(maxActiveClusters); +} + +static PyObject *setPrintfFifoSize(PyObject *self, PyObject *args) { + long size; + if (!PyArg_ParseTuple(args, "l", &size)) { + return NULL; + } + if (size < 0) { + PyErr_SetString(PyExc_ValueError, "fifo size must be non-negative"); + return NULL; + } + + Py_BEGIN_ALLOW_THREADS; + + // Ensure we have an active context. + CUcontext ctx = NULL; + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuCtxGetCurrent(&ctx)); + if (!ctx) { + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( + cuDevicePrimaryCtxRetain(&ctx, /*device=*/0)); + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuCtxSetCurrent(ctx)); + } + + // We can't set the fifo size after running a kernel that calls printf. This + // is true even if the set() call is a nop and the new size is the same as the + // old size. + // + // This is unfriendly, so check if the old size matches the new size, and skip + // the set() call if so. + size_t oldSize = 0; + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( + cuCtxGetLimit(&oldSize, CU_LIMIT_PRINTF_FIFO_SIZE)); + if (oldSize != size) { + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( + cuCtxSetLimit(CU_LIMIT_PRINTF_FIFO_SIZE, size)); + } + + Py_END_ALLOW_THREADS; + Py_INCREF(Py_None); + return Py_None; +} + +// Simple helper to experiment creating TMA descriptors on the host. +// This is a useful to test TMA operations independently. +static PyObject *fill1DTMADescriptor(PyObject *self, PyObject *args) { +#if CUDA_VERSION < 12000 + return NULL; +#else + unsigned long long global_address; + uint64_t dim; + uint32_t tensorDim; + int elementSize; + unsigned long long desc_address; + if (!PyArg_ParseTuple(args, "KKiiK", &global_address, &dim, &tensorDim, + &elementSize, &desc_address)) { + return NULL; + } + uint64_t dims[1] = {dim}; + uint64_t globalStrides[1] = {dim * elementSize}; + uint32_t boxDim[1] = {tensorDim}; + uint32_t elementStrides[1] = {1}; + CUtensorMapDataType type; + switch (elementSize) { + case 1: + type = CU_TENSOR_MAP_DATA_TYPE_UINT8; + break; + case 2: + type = CU_TENSOR_MAP_DATA_TYPE_UINT16; + break; + case 4: + type = CU_TENSOR_MAP_DATA_TYPE_UINT32; + break; + default: + PyErr_SetString(PyExc_ValueError, "elementSize must be 1, 2, or 4"); + return NULL; + } + assert((elementSize * tensorDim) >= 32 && "block size too small."); + int rank = 1; + static cuTensorMapEncodeTiled_t cuTensorMapEncodeTiled = NULL; + INITIALIZE_FUNCTION_POINTER_IF_NULL(cuTensorMapEncodeTiled, + getCuTensorMapEncodeTiledHandle); + CUDA_CHECK_AND_RETURN_NULL(cuTensorMapEncodeTiled( + (CUtensorMap *)desc_address, type, rank, (void *)global_address, dims, + globalStrides, boxDim, elementStrides, CU_TENSOR_MAP_INTERLEAVE_NONE, + CU_TENSOR_MAP_SWIZZLE_NONE, CU_TENSOR_MAP_L2_PROMOTION_NONE, + CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE)); + Py_INCREF(Py_None); + return Py_None; +#endif +} + +// Simple helper to experiment creating TMA descriptors on the host. +// This is a useful to test TMA operations independently. +static PyObject *fill2DTMADescriptor(PyObject *self, PyObject *args) { +#if CUDA_VERSION < 12000 + return NULL; +#else + unsigned long long global_address; + uint64_t dims[2]; + uint32_t tensorDims[2]; + int elementSize; + unsigned long long desc_address; + if (!PyArg_ParseTuple(args, "KKKiiiK", &global_address, &dims[1], &dims[0], + &tensorDims[1], &tensorDims[0], &elementSize, + &desc_address)) { + return NULL; + } + uint64_t globalStrides[2] = {dims[0] * elementSize, + dims[0] * dims[1] * elementSize}; + uint32_t elementStrides[2] = {1, 1}; + CUtensorMapDataType type; + switch (elementSize) { + case 1: + type = CU_TENSOR_MAP_DATA_TYPE_UINT8; + break; + case 2: + type = CU_TENSOR_MAP_DATA_TYPE_UINT16; + break; + case 4: + type = CU_TENSOR_MAP_DATA_TYPE_UINT32; + break; + default: + PyErr_SetString(PyExc_ValueError, "elementSize must be 1, 2, or 4"); + } + int rank = 2; + // Swizzling should be picked in codegen but since we need to set it on the + // descriptor we rely on a convention between this function and codegen. + CUtensorMapSwizzle swizzle = CU_TENSOR_MAP_SWIZZLE_128B; + uint32_t contigDimSizeInByte = elementSize * tensorDims[0]; + if (contigDimSizeInByte >= 128) { + swizzle = CU_TENSOR_MAP_SWIZZLE_128B; + } else if (contigDimSizeInByte >= 64) { + swizzle = CU_TENSOR_MAP_SWIZZLE_64B; + } else if (contigDimSizeInByte >= 32) { + swizzle = CU_TENSOR_MAP_SWIZZLE_32B; + } else { + assert(false && "block size too small."); + } + // The bounding box inner dimension must be less than or equal to the swizzle + // size. + // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html#group__CUDA__TENSOR__MEMORY_1ga7c7d2aaac9e49294304e755e6f341d7 + // We clamp the block size and the codegen will emit multiple copy operations. + if (contigDimSizeInByte > 128) { + tensorDims[0] = 128 / elementSize; + } + static cuTensorMapEncodeTiled_t cuTensorMapEncodeTiled = NULL; + INITIALIZE_FUNCTION_POINTER_IF_NULL(cuTensorMapEncodeTiled, + getCuTensorMapEncodeTiledHandle); + CUDA_CHECK_AND_RETURN_NULL(cuTensorMapEncodeTiled( + (CUtensorMap *)desc_address, type, rank, (void *)global_address, dims, + globalStrides, tensorDims, elementStrides, CU_TENSOR_MAP_INTERLEAVE_NONE, + swizzle, CU_TENSOR_MAP_L2_PROMOTION_L2_128B, + CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE)); + Py_INCREF(Py_None); + return Py_None; +#endif +} + +static PyMethodDef ModuleMethods[] = { + {"load_binary", loadBinary, METH_VARARGS, + "Load provided cubin into CUDA driver"}, + {"get_device_properties", getDeviceProperties, METH_VARARGS, + "Get the properties for a given device"}, + {"cuOccupancyMaxActiveClusters", occupancyMaxActiveClusters, METH_VARARGS, + "Python interface for cuOccupancyMaxActiveClusters function"}, + {"set_printf_fifo_size", setPrintfFifoSize, METH_VARARGS, + "Python interface for cuCtxSetLimit(CU_LIMIT_PRINTF_FIFO_SIZE, x), which " + "controls how many bytes can be streamed from kernels before data starts " + "being dropped. This inherits all the limitations of this call; in " + "particular it's an error to change this value after launching any kernel " + "that calls printf()."}, + {"fill_1d_tma_descriptor", fill1DTMADescriptor, METH_VARARGS, "doc"}, + {"fill_2d_tma_descriptor", fill2DTMADescriptor, METH_VARARGS, "doc"}, + {"build_signature_metadata", buildSignatureMetadata, METH_VARARGS, + buildSignatureMetadata__doc__}, + {"launch", launch, METH_VARARGS, launch__doc__}, + {NULL, NULL, 0, NULL} // sentinel +}; + +static struct PyModuleDef ModuleDef = {PyModuleDef_HEAD_INIT, "cuda_utils", + NULL, // documentation + -1, // size + ModuleMethods}; + +PyMODINIT_FUNC PyInit_cuda_utils(void) { + PyObject *m = PyModule_Create(&ModuleDef); + if (m == NULL) { + return NULL; + } + + PyModule_AddFunctions(m, ModuleMethods); + + return m; +} diff --git a/third_party/nvidia/backend/driver.c b/third_party/nvidia/backend/driver.c deleted file mode 100644 index 12deb0d1e7a30..0000000000000 --- a/third_party/nvidia/backend/driver.c +++ /dev/null @@ -1,421 +0,0 @@ -#include "cuda.h" -#include -#include -#define PY_SSIZE_T_CLEAN -#include - -// Raises a Python exception and returns false if code is not CUDA_SUCCESS. -static bool gpuAssert(CUresult code, const char *file, int line) { - if (code == CUDA_SUCCESS) - return true; - - const char *prefix = "Triton Error [CUDA]: "; - const char *str; - cuGetErrorString(code, &str); - char err[1024] = {0}; - strcat(err, prefix); - strcat(err, str); - PyGILState_STATE gil_state; - gil_state = PyGILState_Ensure(); - PyErr_SetString(PyExc_RuntimeError, err); - PyGILState_Release(gil_state); - return false; -} - -// To be used only *outside* a Py_{BEGIN,END}_ALLOW_THREADS block. -#define CUDA_CHECK_AND_RETURN_NULL(ans) \ - do { \ - if (!gpuAssert((ans), __FILE__, __LINE__)) \ - return NULL; \ - } while (0) - -// To be used inside a Py_{BEGIN,END}_ALLOW_THREADS block. -#define CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(ans) \ - do { \ - if (!gpuAssert((ans), __FILE__, __LINE__)) { \ - PyEval_RestoreThread(_save); \ - return NULL; \ - } \ - } while (0) - -// Used to check if functions exist in old CUDA driver versions. -#define INITIALIZE_FUNCTION_POINTER_IF_NULL(funcPointer, initializerFunction) \ - do { \ - if ((funcPointer) == NULL) { \ - (funcPointer) = (initializerFunction)(); \ - if ((funcPointer) == NULL) { \ - return NULL; \ - } \ - } \ - } while (0) - -static PyObject *getDeviceProperties(PyObject *self, PyObject *args) { - int device_id; - if (!PyArg_ParseTuple(args, "i", &device_id)) - return NULL; - // Get device handle - CUdevice device; - cuDeviceGet(&device, device_id); - - // create a struct to hold device properties - int max_shared_mem; - int max_num_regs; - int multiprocessor_count; - int warp_size; - int sm_clock_rate; - int mem_clock_rate; - int mem_bus_width; - CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute( - &max_shared_mem, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, - device)); - CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute( - &max_num_regs, CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_BLOCK, device)); - CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute( - &multiprocessor_count, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, device)); - CUDA_CHECK_AND_RETURN_NULL( - cuDeviceGetAttribute(&warp_size, CU_DEVICE_ATTRIBUTE_WARP_SIZE, device)); - CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute( - &sm_clock_rate, CU_DEVICE_ATTRIBUTE_CLOCK_RATE, device)); - CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute( - &mem_clock_rate, CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE, device)); - CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute( - &mem_bus_width, CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH, device)); - - return Py_BuildValue("{s:i, s:i, s:i, s:i, s:i, s:i, s:i}", "max_shared_mem", - max_shared_mem, "max_num_regs", max_num_regs, - "multiprocessor_count", multiprocessor_count, "warpSize", - warp_size, "sm_clock_rate", sm_clock_rate, - "mem_clock_rate", mem_clock_rate, "mem_bus_width", - mem_bus_width); -} - -static PyObject *loadBinary(PyObject *self, PyObject *args) { - const char *name; - const char *data; - Py_ssize_t data_size; - int shared; - int device; - if (!PyArg_ParseTuple(args, "ss#ii", &name, &data, &data_size, &shared, - &device)) { - return NULL; - } - CUfunction fun; - CUmodule mod; - int32_t n_regs = 0; - int32_t n_spills = 0; - // create driver handles - CUcontext pctx = 0; - - Py_BEGIN_ALLOW_THREADS; - CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuCtxGetCurrent(&pctx)); - if (!pctx) { - CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( - cuDevicePrimaryCtxRetain(&pctx, device)); - CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuCtxSetCurrent(pctx)); - } - - CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuModuleLoadData(&mod, data)); - CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( - cuModuleGetFunction(&fun, mod, name)); - // get allocated registers and spilled registers from the function - CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( - cuFuncGetAttribute(&n_regs, CU_FUNC_ATTRIBUTE_NUM_REGS, fun)); - CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( - cuFuncGetAttribute(&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun)); - n_spills /= 4; - // set dynamic shared memory if necessary - int shared_optin; - CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuDeviceGetAttribute( - &shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, - device)); - if (shared > 49152 && shared_optin > 49152) { - CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( - cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED)); - int shared_total, shared_static; - CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuDeviceGetAttribute( - &shared_total, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR, - device)); - CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuFuncGetAttribute( - &shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun)); - CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( - cuFuncSetAttribute(fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, - shared_optin - shared_static)); - } - Py_END_ALLOW_THREADS; - - if (PyErr_Occurred()) { - return NULL; - } - return Py_BuildValue("(KKii)", (uint64_t)mod, (uint64_t)fun, n_regs, - n_spills); -} - -typedef CUresult (*cuOccupancyMaxActiveClusters_t)( - int *numClusters, CUfunction func, const CUlaunchConfig *config); - -typedef CUresult (*cuTensorMapEncodeTiled_t)( - CUtensorMap *tensorMap, CUtensorMapDataType tensorDataType, - cuuint32_t tensorRank, void *globalAddress, const cuuint64_t *globalDim, - const cuuint64_t *globalStrides, const cuuint32_t *boxDim, - const cuuint32_t *elementStrides, CUtensorMapInterleave interleave, - CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion, - CUtensorMapFloatOOBfill oobFill); - -#define defineGetFunctionHandle(name, symbolName) \ - static symbolName##_t name() { \ - /* Open the shared library */ \ - void *libHandle = dlopen("libcuda.so.1", RTLD_LAZY); \ - if (!libHandle) { \ - PyErr_SetString(PyExc_RuntimeError, "Failed to open libcuda.so.1"); \ - return NULL; \ - } \ - /* Clear any existing error */ \ - dlerror(); \ - symbolName##_t funcHandle = (symbolName##_t)dlsym(libHandle, #symbolName); \ - /* Check for errors */ \ - const char *err = dlerror(); \ - if (err) { \ - PyErr_SetString(PyExc_RuntimeError, \ - "Failed to retrieve " #symbolName " from libcuda.so.1"); \ - dlclose(libHandle); \ - return NULL; \ - } \ - return funcHandle; \ - } - -defineGetFunctionHandle(getCuOccupancyMaxActiveClustersHandle, - cuOccupancyMaxActiveClusters); - -defineGetFunctionHandle(getCuTensorMapEncodeTiledHandle, - cuTensorMapEncodeTiled); - -static PyObject *occupancyMaxActiveClusters(PyObject *self, PyObject *args) { - int clusterDimX = -1, clusterDimY = -1, clusterDimZ = -1, - maxActiveClusters = -1; - int shared = 0; - CUfunction func; - - if (!PyArg_ParseTuple(args, "Kiiii", &func, &shared, &clusterDimX, - &clusterDimY, &clusterDimZ)) { - return NULL; - } - - // Let each SM have one block - int maxActiveBlocks = 1; - Py_BEGIN_ALLOW_THREADS; - CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuFuncSetAttribute( - func, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared)); - Py_END_ALLOW_THREADS; - - CUlaunchAttribute launchAttr[1]; - launchAttr[0].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION; - launchAttr[0].value.clusterDim.x = clusterDimX; - launchAttr[0].value.clusterDim.y = clusterDimY; - launchAttr[0].value.clusterDim.z = clusterDimZ; - CUlaunchConfig config; - config.gridDimX = clusterDimX; - config.gridDimY = maxActiveBlocks * clusterDimY; - config.gridDimZ = clusterDimZ; - config.blockDimX = 128; - config.blockDimY = 1; - config.blockDimZ = 1; - config.sharedMemBytes = shared; - config.hStream = 0; - config.numAttrs = 1; - config.attrs = launchAttr; - - static cuOccupancyMaxActiveClusters_t cuOccupancyMaxActiveClusters = NULL; - INITIALIZE_FUNCTION_POINTER_IF_NULL(cuOccupancyMaxActiveClusters, - getCuOccupancyMaxActiveClustersHandle); - - Py_BEGIN_ALLOW_THREADS; - CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuFuncSetAttribute( - func, CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED, 1)); - CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( - cuOccupancyMaxActiveClusters(&maxActiveClusters, func, &config)); - Py_END_ALLOW_THREADS; - return PyLong_FromLong(maxActiveClusters); -} - -static PyObject *setPrintfFifoSize(PyObject *self, PyObject *args) { - long size; - if (!PyArg_ParseTuple(args, "l", &size)) { - return NULL; - } - if (size < 0) { - PyErr_SetString(PyExc_ValueError, "fifo size must be non-negative"); - return NULL; - } - - Py_BEGIN_ALLOW_THREADS; - - // Ensure we have an active context. - CUcontext ctx = NULL; - CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuCtxGetCurrent(&ctx)); - if (!ctx) { - CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( - cuDevicePrimaryCtxRetain(&ctx, /*device=*/0)); - CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuCtxSetCurrent(ctx)); - } - - // We can't set the fifo size after running a kernel that calls printf. This - // is true even if the set() call is a nop and the new size is the same as the - // old size. - // - // This is unfriendly, so check if the old size matches the new size, and skip - // the set() call if so. - size_t oldSize = 0; - CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( - cuCtxGetLimit(&oldSize, CU_LIMIT_PRINTF_FIFO_SIZE)); - if (oldSize != size) { - CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( - cuCtxSetLimit(CU_LIMIT_PRINTF_FIFO_SIZE, size)); - } - - Py_END_ALLOW_THREADS; - Py_INCREF(Py_None); - return Py_None; -} - -// Simple helper to experiment creating TMA descriptors on the host. -// This is a useful to test TMA operations independently. -static PyObject *fill1DTMADescriptor(PyObject *self, PyObject *args) { - unsigned long long global_address; - uint64_t dim; - uint32_t tensorDim; - int elementSize; - unsigned long long desc_address; - if (!PyArg_ParseTuple(args, "KKiiK", &global_address, &dim, &tensorDim, - &elementSize, &desc_address)) { - return NULL; - } - uint64_t dims[1] = {dim}; - uint64_t globalStrides[1] = {dim * elementSize}; - uint32_t boxDim[1] = {tensorDim}; - uint32_t elementStrides[1] = {1}; - CUtensorMapDataType type; - switch (elementSize) { - case 1: - type = CU_TENSOR_MAP_DATA_TYPE_UINT8; - break; - case 2: - type = CU_TENSOR_MAP_DATA_TYPE_UINT16; - break; - case 4: - type = CU_TENSOR_MAP_DATA_TYPE_UINT32; - break; - default: - PyErr_SetString(PyExc_ValueError, "elementSize must be 1, 2, or 4"); - return NULL; - } - assert((elementSize * tensorDim) >= 32 && "block size too small."); - int rank = 1; - static cuTensorMapEncodeTiled_t cuTensorMapEncodeTiled = NULL; - INITIALIZE_FUNCTION_POINTER_IF_NULL(cuTensorMapEncodeTiled, - getCuTensorMapEncodeTiledHandle); - CUDA_CHECK_AND_RETURN_NULL(cuTensorMapEncodeTiled( - (CUtensorMap *)desc_address, type, rank, (void *)global_address, dims, - globalStrides, boxDim, elementStrides, CU_TENSOR_MAP_INTERLEAVE_NONE, - CU_TENSOR_MAP_SWIZZLE_NONE, CU_TENSOR_MAP_L2_PROMOTION_NONE, - CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE)); - Py_INCREF(Py_None); - return Py_None; -} - -// Simple helper to experiment creating TMA descriptors on the host. -// This is a useful to test TMA operations independently. -static PyObject *fill2DTMADescriptor(PyObject *self, PyObject *args) { - unsigned long long global_address; - uint64_t dims[2]; - uint32_t tensorDims[2]; - int elementSize; - unsigned long long desc_address; - if (!PyArg_ParseTuple(args, "KKKiiiK", &global_address, &dims[1], &dims[0], - &tensorDims[1], &tensorDims[0], &elementSize, - &desc_address)) { - return NULL; - } - uint64_t globalStrides[2] = {dims[0] * elementSize, - dims[0] * dims[1] * elementSize}; - uint32_t elementStrides[2] = {1, 1}; - CUtensorMapDataType type; - switch (elementSize) { - case 1: - type = CU_TENSOR_MAP_DATA_TYPE_UINT8; - break; - case 2: - type = CU_TENSOR_MAP_DATA_TYPE_UINT16; - break; - case 4: - type = CU_TENSOR_MAP_DATA_TYPE_UINT32; - break; - default: - PyErr_SetString(PyExc_ValueError, "elementSize must be 1, 2, or 4"); - } - int rank = 2; - // Swizzling should be picked in codegen but since we need to set it on the - // descriptor we rely on a convention between this function and codegen. - CUtensorMapSwizzle swizzle = CU_TENSOR_MAP_SWIZZLE_128B; - uint32_t contigDimSizeInByte = elementSize * tensorDims[0]; - if (contigDimSizeInByte >= 128) { - swizzle = CU_TENSOR_MAP_SWIZZLE_128B; - } else if (contigDimSizeInByte >= 64) { - swizzle = CU_TENSOR_MAP_SWIZZLE_64B; - } else if (contigDimSizeInByte >= 32) { - swizzle = CU_TENSOR_MAP_SWIZZLE_32B; - } else { - assert(false && "block size too small."); - } - // The bounding box inner dimension must be less than or equal to the swizzle - // size. - // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html#group__CUDA__TENSOR__MEMORY_1ga7c7d2aaac9e49294304e755e6f341d7 - // We clamp the block size and the codegen will emit multiple copy operations. - if (contigDimSizeInByte > 128) { - tensorDims[0] = 128 / elementSize; - } - static cuTensorMapEncodeTiled_t cuTensorMapEncodeTiled = NULL; - INITIALIZE_FUNCTION_POINTER_IF_NULL(cuTensorMapEncodeTiled, - getCuTensorMapEncodeTiledHandle); - CUDA_CHECK_AND_RETURN_NULL(cuTensorMapEncodeTiled( - (CUtensorMap *)desc_address, type, rank, (void *)global_address, dims, - globalStrides, tensorDims, elementStrides, CU_TENSOR_MAP_INTERLEAVE_NONE, - swizzle, CU_TENSOR_MAP_L2_PROMOTION_L2_128B, - CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE)); - Py_INCREF(Py_None); - return Py_None; -} - -static PyMethodDef ModuleMethods[] = { - {"load_binary", loadBinary, METH_VARARGS, - "Load provided cubin into CUDA driver"}, - {"get_device_properties", getDeviceProperties, METH_VARARGS, - "Get the properties for a given device"}, - {"cuOccupancyMaxActiveClusters", occupancyMaxActiveClusters, METH_VARARGS, - "Python interface for cuOccupancyMaxActiveClusters function"}, - {"set_printf_fifo_size", setPrintfFifoSize, METH_VARARGS, - "Python interface for cuCtxSetLimit(CU_LIMIT_PRINTF_FIFO_SIZE, x), which " - "controls how many bytes can be streamed from kernels before data starts " - "being dropped. This inherits all the limitations of this call; in " - "particular it's an error to change this value after launching any kernel " - "that calls printf()."}, - {"fill_1d_tma_descriptor", fill1DTMADescriptor, METH_VARARGS, "doc"}, - {"fill_2d_tma_descriptor", fill2DTMADescriptor, METH_VARARGS, "doc"}, - - {NULL, NULL, 0, NULL} // sentinel -}; - -static struct PyModuleDef ModuleDef = {PyModuleDef_HEAD_INIT, "cuda_utils", - NULL, // documentation - -1, // size - ModuleMethods}; - -PyMODINIT_FUNC PyInit_cuda_utils(void) { - PyObject *m = PyModule_Create(&ModuleDef); - if (m == NULL) { - return NULL; - } - - PyModule_AddFunctions(m, ModuleMethods); - - return m; -} diff --git a/third_party/nvidia/backend/driver.py b/third_party/nvidia/backend/driver.py index d088ec0927daf..2e98e52a377be 100644 --- a/third_party/nvidia/backend/driver.py +++ b/third_party/nvidia/backend/driver.py @@ -1,20 +1,15 @@ +from collections.abc import Callable import functools import os -import sysconfig -import hashlib import subprocess -import tempfile -from pathlib import Path -from triton.runtime.build import _build -from triton.runtime.cache import get_cache_manager from triton.runtime import _allocation from triton.backends.compiler import GPUTarget from triton.backends.driver import GPUDriver +from ._C import cuda_utils dirname = os.path.dirname(os.path.realpath(__file__)) include_dir = [os.path.join(dirname, "include")] libdevice_dir = os.path.join(dirname, "lib") -libraries = ['cuda'] @functools.lru_cache() @@ -47,26 +42,6 @@ def library_dirs(): return [libdevice_dir, *libcuda_dirs()] -def compile_module_from_src(src, name): - key = hashlib.sha256(src.encode("utf-8")).hexdigest() - cache = get_cache_manager(key) - ext = sysconfig.get_config_var("EXT_SUFFIX").split(".")[-1] - cache_path = cache.get_file(f"{name}.{ext}") - if cache_path is None: - with tempfile.TemporaryDirectory() as tmpdir: - src_path = os.path.join(tmpdir, "main.c") - with open(src_path, "w") as f: - f.write(src) - so = _build(name, src_path, tmpdir, library_dirs(), include_dir, libraries) - with open(so, "rb") as f: - cache_path = cache.put(f.read(), f"{name}.{ext}", binary=True) - import importlib.util - spec = importlib.util.spec_from_file_location(name, cache_path) - mod = importlib.util.module_from_spec(spec) - spec.loader.exec_module(mod) - return mod - - # ------------------------ # Utils # ------------------------ @@ -80,13 +55,12 @@ def __new__(cls): return cls.instance def __init__(self): - mod = compile_module_from_src(Path(os.path.join(dirname, "driver.c")).read_text(), "cuda_utils") - self.load_binary = mod.load_binary - self.get_device_properties = mod.get_device_properties - self.cuOccupancyMaxActiveClusters = mod.cuOccupancyMaxActiveClusters - self.set_printf_fifo_size = mod.set_printf_fifo_size - self.fill_1d_tma_descriptor = mod.fill_1d_tma_descriptor - self.fill_2d_tma_descriptor = mod.fill_2d_tma_descriptor + self.load_binary = cuda_utils.load_binary + self.get_device_properties = cuda_utils.get_device_properties + self.cuOccupancyMaxActiveClusters = cuda_utils.cuOccupancyMaxActiveClusters + self.set_printf_fifo_size = cuda_utils.set_printf_fifo_size + self.fill_1d_tma_descriptor = cuda_utils.fill_1d_tma_descriptor + self.fill_2d_tma_descriptor = cuda_utils.fill_2d_tma_descriptor # ------------------------ @@ -95,7 +69,7 @@ def __init__(self): def ty_to_cpp(ty): - if ty[0] == '*': + if ty[0] == '*' or ty == "none": return "CUdeviceptr" return { "i1": "int32_t", @@ -117,386 +91,82 @@ def ty_to_cpp(ty): }[ty] -def make_launcher(constants, signature): +def flatten_tuples(xs): + """Recursively flattens tuple elements in xs.""" + for x in xs: + if isinstance(x, tuple): + yield from flatten_tuples(x) + else: + yield x + + +def make_launcher(constants : dict[int, str], signature : dict[int, any]) -> Callable[..., None]: + # Here, signature can look like: + # {'_0': 'i32', + # 'Ptrs': (), + # '_1': 'constexpr', + # 'values': '[*f32, constexpr]', + # 'out_tuple': 'constexpr'} + # We want to remove the constexprs, flatten the tuples, and remove any more + # constexprs. If we remove them all at the end, we won't be able to remove + # entire tuples that are a single constexpr. If we remove them before + # flattening, we will miss mixed-tuples. So we do it twice. def _serialize_signature(sig): if isinstance(sig, tuple): return ','.join(map(_serialize_signature, sig)) return sig - - def _extracted_type(ty): - if isinstance(ty, tuple): - val = ','.join(map(_extracted_type, ty)) - return f"[{val}]" - if ty[0] == '*': - return "PyObject*" - if ty in ("constexpr", "nvTmaDesc"): - return "PyObject*" - return ty_to_cpp(ty) - - def format_of(ty): - if isinstance(ty, tuple): - val = ''.join(map(format_of, ty)) - return f"({val})" - if ty[0] == '*': - return "O" - if ty in ("constexpr", "nvTmaDesc"): - return "O" - return { - "float": "f", - "double": "d", - "long": "l", - "int8_t": "b", - "int16_t": "h", - "int32_t": "i", - "int64_t": "L", - "uint8_t": "B", - "uint16_t": "H", - "uint32_t": "I", - "uint64_t": "K", - }[ty_to_cpp(ty)] - - args_format = ''.join([format_of(ty) for ty in signature.values()]) - format = "iiiKKpOOOOO" + args_format + + # Remember & remove all the constexpr before flattening. + constant_indices_before_flattening = {i for i, [k, v] in enumerate(signature.items()) if v == 'constexpr'} + # constant_indices_before_flattening = [2, 4] + signature = {k: v for k, v in signature.items() if v != 'constexpr'} + # signature = {'_0': 'i32', 'Ptrs': (), 'values': '[*f32, constexpr]'} + + # Flatten. signature = ','.join(map(_serialize_signature, signature.values())) + # signature = 'i32,,*f32,constexpr' signature = list(filter(bool, signature.split(','))) - signature = {i: s for i, s in enumerate(signature)} - args_list = ', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else '' - # Record the end of regular arguments; - # subsequent arguments are architecture-specific descriptors, such as tensor descriptors for CUDA. - arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items() if ty != "constexpr") - internal_args_list = [] - for i, ty in signature.items(): - if ty[0] == "*": - internal_args_list.append(f"ptr_info{i}.dev_ptr") - elif ty == "nvTmaDesc": - # Note: we have to dereference the pointer - internal_args_list.append(f"*tma_ptr{i}") - elif ty != "constexpr": - internal_args_list.append(f"_arg{i}") - params = range(len(signature)) - - # generate glue code - newline = '\n ' - ptr_decls = [ - f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" - for i, ty in signature.items() - if ty[0] == "*" - ] - tma_decls = [ - f"CUtensorMap* tma_ptr{i} = getTmaDesc(_arg{i}); if (!tma_ptr{i}) return NULL;" for i, ty in signature.items() - if ty == "nvTmaDesc" - ] - params = [f"&arg{i}" for i, ty in signature.items() if ty != "constexpr"] - params.append("&global_scratch") - src = f""" -#include \"cuda.h\" -#include -#include -#include - -static inline void gpuAssert(CUresult code, const char *file, int line) -{{ - if (code != CUDA_SUCCESS) - {{ - const char* prefix = "Triton Error [CUDA]: "; - const char* str; - cuGetErrorString(code, &str); - char err[1024] = {{0}}; - strcat(err, prefix); - strcat(err, str); - PyGILState_STATE gil_state; - gil_state = PyGILState_Ensure(); - PyErr_SetString(PyExc_RuntimeError, err); - PyGILState_Release(gil_state); - }} -}} - -#define CUDA_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }} - -typedef CUresult (*cuLaunchKernelEx_t)(const CUlaunchConfig* config, CUfunction f, void** kernelParams, void** extra); - -static cuLaunchKernelEx_t getLaunchKernelExHandle() {{ - // Open the shared library - void* handle = dlopen("libcuda.so.1", RTLD_LAZY); - if (!handle) {{ - PyErr_SetString(PyExc_RuntimeError, "Failed to open libcuda.so.1"); - return NULL; - }} - // Clear any existing error - dlerror(); - cuLaunchKernelEx_t cuLaunchKernelExHandle = (cuLaunchKernelEx_t)dlsym(handle, "cuLaunchKernelEx"); - // Check for errors - const char *dlsym_error = dlerror(); - if (dlsym_error) {{ - PyErr_SetString(PyExc_RuntimeError, "Failed to retrieve cuLaunchKernelEx from libcuda.so.1"); - return NULL; - }} - return cuLaunchKernelExHandle; -}} - -static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, CUstream stream, CUfunction function, CUdeviceptr global_scratch{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{ - void *params[] = {{ {', '.join(params)} }}; - if (gridX*gridY*gridZ > 0) {{ - if ((num_ctas == 1) && (0 == launch_cooperative_grid)) {{ - CUDA_CHECK(cuLaunchKernel(function, gridX, gridY, gridZ, 32*num_warps, 1, 1, shared_memory, stream, params, 0)); - }} else if ((num_ctas == 1) && (0 != launch_cooperative_grid)) {{ - CUlaunchAttribute launchAttr[1]; - CUlaunchAttribute coopAttr = {{ .id = CU_LAUNCH_ATTRIBUTE_COOPERATIVE, .value = 1}}; - launchAttr[0] = coopAttr; - - CUlaunchConfig config; - config.gridDimX = gridX; - config.gridDimY = gridY; - config.gridDimZ = gridZ; - config.blockDimX = 32 * num_warps; - config.blockDimY = 1; - config.blockDimZ = 1; - config.sharedMemBytes = shared_memory; - config.hStream = stream; - config.attrs = launchAttr; - config.numAttrs = 1; - - static cuLaunchKernelEx_t cuLaunchKernelExHandle = NULL; - if (cuLaunchKernelExHandle == NULL) {{ - cuLaunchKernelExHandle = getLaunchKernelExHandle(); - }} - CUDA_CHECK(cuLaunchKernelExHandle(&config, function, params, 0)); - - }} else {{ - CUlaunchAttribute launchAttr[3]; - launchAttr[0].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION; - launchAttr[0].value.clusterDim.x = clusterDimX; - launchAttr[0].value.clusterDim.y = clusterDimY; - launchAttr[0].value.clusterDim.z = clusterDimZ; - launchAttr[1].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE; - launchAttr[1].value.clusterSchedulingPolicyPreference = CU_CLUSTER_SCHEDULING_POLICY_SPREAD; - - unsigned numAttrs = 2; - if (0 != launch_cooperative_grid) {{ - CUlaunchAttribute coopAttr = {{ .id = CU_LAUNCH_ATTRIBUTE_COOPERATIVE, .value = 1}}; - launchAttr[2] = coopAttr; - numAttrs = 3; - }} - - CUlaunchConfig config; - config.gridDimX = gridX * clusterDimX; - config.gridDimY = gridY * clusterDimY; - config.gridDimZ = gridZ * clusterDimZ; - config.blockDimX = 32 * num_warps; - config.blockDimY = 1; - config.blockDimZ = 1; - config.sharedMemBytes = shared_memory; - config.hStream = stream; - config.attrs = launchAttr; - config.numAttrs = numAttrs; - static cuLaunchKernelEx_t cuLaunchKernelExHandle = NULL; - if (cuLaunchKernelExHandle == NULL) {{ - cuLaunchKernelExHandle = getLaunchKernelExHandle(); - }} - CUDA_CHECK(cuLaunchKernelExHandle(&config, function, params, 0)); - }} - }} -}} - -typedef struct _DevicePtrInfo {{ - CUdeviceptr dev_ptr; - bool valid; -}} DevicePtrInfo; - -static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{ - DevicePtrInfo ptr_info; - ptr_info.dev_ptr = 0; - ptr_info.valid = true; - if (PyLong_Check(obj)) {{ - ptr_info.dev_ptr = PyLong_AsUnsignedLongLong(obj); - return ptr_info; - }} - if (obj == Py_None) {{ - // valid nullptr - return ptr_info; - }} - PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr"); - if(ptr){{ - PyObject *empty_tuple = PyTuple_New(0); - PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL); - Py_DECREF(empty_tuple); - Py_DECREF(ptr); - if (!PyLong_Check(ret)) {{ - PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int"); - ptr_info.valid = false; - return ptr_info; - }} - ptr_info.dev_ptr = PyLong_AsUnsignedLongLong(ret); - if(!ptr_info.dev_ptr) - return ptr_info; - uint64_t dev_ptr; - int status = cuPointerGetAttribute(&dev_ptr, CU_POINTER_ATTRIBUTE_DEVICE_POINTER, ptr_info.dev_ptr); - if (status == CUDA_ERROR_INVALID_VALUE) {{ - PyErr_Format(PyExc_ValueError, - "Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx); - ptr_info.valid = false; - }} else if (status != CUDA_SUCCESS) {{ - CUDA_CHECK(status); // Catch any other cuda API errors - ptr_info.valid = false; - }} - ptr_info.dev_ptr = dev_ptr; - Py_DECREF(ret); // Thanks ChatGPT! - return ptr_info; - }} - PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method"); - ptr_info.valid = false; - return ptr_info; -}} - -static inline CUtensorMap* getTmaDesc(PyObject *obj) {{ - if (sizeof(CUtensorMap*) != 8) {{ - PyErr_SetString(PyExc_SystemError, "getTmaDesc() requires 64-bit compilation"); - return NULL; - }} - - PyObject *method_handle = PyObject_GetAttrString(obj, "tma_desc_cpu_ptr"); - if (!method_handle) {{ - PyErr_SetString(PyExc_TypeError, "tma_desc_cpu_ptr() method does not exist"); - return NULL; - }} - - PyObject *empty_tuple = PyTuple_New(0); - if (!empty_tuple) {{ - Py_DECREF(method_handle); - PyErr_SetString(PyExc_SystemError, "Internal Python error!"); - return NULL; - }} - PyObject *method_ret = PyObject_Call(method_handle, empty_tuple, NULL); - Py_DECREF(empty_tuple); - Py_DECREF(method_handle); - if (!method_ret) {{ - PyErr_SetString(PyExc_SystemError, "Internal Python error!"); - return NULL; - }} - - if (!PyLong_Check(method_ret)) {{ - PyErr_SetString(PyExc_TypeError, "tma_desc_cpu_ptr() must return 64-bit int"); - Py_DECREF(method_ret); - return NULL; - }} - - uint64_t ptr_as_uint = PyLong_AsUnsignedLongLong(method_ret); - Py_DECREF(method_ret); - if (!ptr_as_uint) {{ - PyErr_SetString(PyExc_ValueError, "received NULL ptr from tma_desc_cpu_ptr()"); - return NULL; - }} - if (ptr_as_uint % 64 != 0) {{ - PyErr_SetString(PyExc_ValueError, "tma_desc_cpu_ptr() must be 64-byte aligned"); - return NULL; - }} - - return (CUtensorMap*)(ptr_as_uint); -}} - -static void ensureCudaContext() {{ - CUcontext pctx; - CUDA_CHECK(cuCtxGetCurrent(&pctx)); - if (!pctx) {{ - // Ensure device context. - CUdevice device; - CUDA_CHECK(cuDeviceGet(&device, 0)); - CUDA_CHECK(cuDevicePrimaryCtxRetain(&pctx, device)); - CUDA_CHECK(cuCtxSetCurrent(pctx)); - }} -}} - -static PyObject* launch(PyObject* self, PyObject* args) {{ - // ensure cuda context is valid before calling any CUDA APIs, e.g. before getPointer calls cuPointerGetAttributes - ensureCudaContext(); - - int gridX, gridY, gridZ; - uint64_t _stream; - uint64_t _function; - int launch_cooperative_grid; - PyObject *launch_enter_hook = NULL; - PyObject *launch_exit_hook = NULL; - PyObject *kernel_metadata = NULL; - PyObject *launch_metadata = NULL; - PyObject *global_scratch_obj = NULL; - {newline.join([f"{_extracted_type(ty)} _arg{i};" for i, ty in signature.items()])} - if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, - &_stream, &_function, &launch_cooperative_grid, &global_scratch_obj, - &kernel_metadata, &launch_metadata, - &launch_enter_hook, &launch_exit_hook{args_list})) {{ - return NULL; - }} - - int num_warps, num_ctas, shared_memory, clusterDimX, clusterDimY, clusterDimZ; - if (!PyArg_ParseTuple(kernel_metadata, \"iiiiii\", &num_warps, &num_ctas, &shared_memory, &clusterDimX, &clusterDimY, &clusterDimZ)) {{ - PyErr_SetString(PyExc_TypeError, "kernel_metadata must be a tuple"); - return NULL; - }} - - // extract launch metadata - if (launch_enter_hook != Py_None){{ - PyObject* args = Py_BuildValue("(O)", launch_metadata); - PyObject* ret = PyObject_CallObject(launch_enter_hook, args); - Py_DECREF(args); - if (!ret) - return NULL; - }} - - CUdeviceptr global_scratch = 0; - if (global_scratch_obj != Py_None) {{ - DevicePtrInfo global_scratch_info = getPointer(global_scratch_obj, -1); - if (!global_scratch_info.valid) {{ - return NULL; - }} - global_scratch = global_scratch_info.dev_ptr; - }} - - // raise exception asap - {newline.join(ptr_decls)} - {newline.join(tma_decls)} - Py_BEGIN_ALLOW_THREADS; - _launch(gridX, gridY, gridZ, num_warps, num_ctas, launch_cooperative_grid, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (CUstream)_stream, (CUfunction)_function, global_scratch{', ' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''}); - Py_END_ALLOW_THREADS; - if (PyErr_Occurred()) {{ - return NULL; - }} - - if(launch_exit_hook != Py_None){{ - PyObject* args = Py_BuildValue("(O)", launch_metadata); - PyObject* ret = PyObject_CallObject(launch_exit_hook, args); - Py_DECREF(args); - if (!ret) - return NULL; - - }} - - Py_RETURN_NONE; -}} - -static PyMethodDef ModuleMethods[] = {{ - {{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}}, - {{NULL, NULL, 0, NULL}} // sentinel -}}; - -static struct PyModuleDef ModuleDef = {{ - PyModuleDef_HEAD_INIT, - \"__triton_launcher\", - NULL, //documentation - -1, //size - ModuleMethods -}}; - -PyMODINIT_FUNC PyInit___triton_launcher(void) {{ - PyObject *m = PyModule_Create(&ModuleDef); - if(m == NULL) {{ - return NULL; - }} - PyModule_AddFunctions(m, ModuleMethods); - return m; -}} -""" - return src + # signature = ['i32', '*f32', 'constexpr'] + + # Remove any constexprs after flattening. + constant_indices_after_flattening = {i for i, s in enumerate(signature) if s == 'constexpr'} + # constant_indices_after_flattening = [2] + signature = {i: s for i, s in enumerate(signature) if s != 'constexpr'} + # signature = {0: 'i32', 1: '*f32'} + + signature_metadata = cuda_utils.build_signature_metadata( + ty for ty in signature.values()) + + def wrapper(grid_dim_x: int, grid_dim_y: int, grid_dim_z: int, + stream: int, kernel: int, global_scratch: any, + packed_metadata: tuple[int, int, int, int, int, int], + hook_args: any, + launch_enter_hook: Callable[..., None], + launch_exit_hook: Callable[..., None], + *args: any) -> None: + # Given the example above, args would look something like: + # args = [8, (), 5, (3, 4), (2, 2, 2)] + # constant_indices_before_flattening = [2, 4] + # Remove constantexprs before flattening: + non_const_args = [arg + for idx, arg in enumerate(args) + if idx not in constant_indices_before_flattening + ] + # non_const_args = [8, (), (3, 4)] + non_const_args = flatten_tuples(non_const_args) + # non_const_args = [8, 3, 4] + # constant_indices_after_flattening = [2] + non_const_args = [arg + for idx, arg in enumerate(non_const_args) + if idx not in constant_indices_after_flattening + ] + # non_const_args = [8, 3] + cuda_utils.launch(grid_dim_x, grid_dim_y, grid_dim_z, stream, kernel, + packed_metadata, hook_args, launch_enter_hook, + launch_exit_hook, signature_metadata, global_scratch, + non_const_args) + return wrapper class CudaLauncher(object): @@ -506,9 +176,7 @@ def __init__(self, src, metadata): arg_idx = lambda x: (src.fn.arg_names.index(x), ) if isinstance(x, str) else x constants = {arg_idx(idx): value for idx, value in constants.items()} signature = {idx: value for idx, value in src.signature.items()} - src = make_launcher(constants, signature) - mod = compile_module_from_src(src, "__triton_launcher") - self.launch = mod.launch + self.launch = make_launcher(constants, signature) self.global_scratch_size = metadata.global_scratch_size self.global_scratch_align = metadata.global_scratch_align self.launch_cooperative_grid = metadata.launch_cooperative_grid @@ -520,7 +188,7 @@ def __call__(self, gridX, gridY, gridZ, stream, function, *args): global_scratch = _allocation._allocator(alloc_size, self.global_scratch_align, stream) else: global_scratch = None - self.launch(gridX, gridY, gridZ, stream, function, self.launch_cooperative_grid, global_scratch, *args) + self.launch(gridX, gridY, gridZ, stream, function, global_scratch, *args) class CudaDriver(GPUDriver): @@ -551,7 +219,7 @@ def is_active(): import torch return torch.cuda.is_available() and (torch.version.hip is None) except ImportError: - return False + return True def get_benchmarker(self): from triton.testing import do_bench diff --git a/third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUOps.td b/third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUOps.td index 458913dba5953..e31317b05d000 100644 --- a/third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUOps.td +++ b/third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUOps.td @@ -113,6 +113,15 @@ def NVGPU_WGMMAOp : NVGPU_Op<"wgmma", []> { let assemblyFormat = "$opA `,` $opB `,` $useC (`,` $opC^)? attr-dict `:` functional-type(operands, $res)"; } +def NVGPU_SparseWGMMAOp : NVGPU_Op<"wgmma_sp", []> { + let arguments = (ins WGMMA_OperandType:$opA, I32:$metaA, WGMMA_OperandType:$opB, LLVM_AnyStruct:$opC, + I32Attr:$m, I32Attr:$n, I32Attr:$k, + WGMMA_EltTypeAttr:$eltTypeC, WGMMA_EltTypeAttr:$eltTypeA, WGMMA_EltTypeAttr:$eltTypeB, + WGMMA_LayoutAttr:$layoutA, WGMMA_LayoutAttr:$layoutB); + let results = (outs LLVM_AnyStruct:$res); + let assemblyFormat = "$opA `meta` $metaA `,` $opB `,` $opC attr-dict `:` functional-type(operands, $res)"; +} + def NVGPU_FenceAsyncSharedOp : NVGPU_Op<"fence_async_shared", []> { let arguments = (ins BoolAttr:$bCluster); let assemblyFormat = "attr-dict"; diff --git a/third_party/nvidia/language/cuda/BUILD b/third_party/nvidia/language/cuda/BUILD new file mode 100644 index 0000000000000..55e6ec8795c1a --- /dev/null +++ b/third_party/nvidia/language/cuda/BUILD @@ -0,0 +1,13 @@ +package( + default_applicable_licenses = ["//:license"], + default_visibility = [ + "//third_party/py/triton:__subpackages__", + ], +) + +filegroup( + name = "files", + srcs = glob( + include = ["**/*.py"], + ), +) diff --git a/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp b/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp index 7a1f518c8b8ff..00c6fb2ea5dae 100644 --- a/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp +++ b/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp @@ -10,6 +10,7 @@ #include "nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.h" #include "llvm/Support/ErrorHandling.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" using namespace mlir; using namespace mlir::triton; @@ -438,10 +439,36 @@ class WGMMAWaitGroupOpPattern : public OpRewritePattern { Constraints getOutputConstraints(ttn::WGMMAWaitGroupOp op) const { auto outputStructType = cast(op.getType()); - uint32_t numOutputRegs = outputStructType.getBody().size(); - std::string output = - outputStructType.getBody().front().isF32() ? "=f" : "=r"; - return Constraints(numOutputRegs, output); + std::vector outputConstraints; + outputConstraints.reserve(outputStructType.getBody().size()); + for (mlir::Type type : outputStructType.getBody()) { + if (type.isF32()) { + outputConstraints.push_back("=f"); + continue; + } else if (type.isF64()) { + outputConstraints.push_back("=d"); + continue; + } + unsigned bitwidth = isa(type) ? + 64 : type.getIntOrFloatBitWidth(); + switch (bitwidth) { + case 1: + outputConstraints.push_back("=b"); + break; + case 16: + outputConstraints.push_back("=h"); + break; + case 32: + outputConstraints.push_back("=r"); + break; + case 64: + outputConstraints.push_back("=l"); + break; + default: + assert(false && "unsupported bitwidth"); + } + } + return outputConstraints; } OperandsAndConstraints diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp index 14691634c7908..3c2ef60f40775 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp @@ -2,6 +2,7 @@ #include "Utility.h" #include "mlir/Support/LLVM.h" #include "third_party/nvidia/include/Dialect/NVGPU/IR/Dialect.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" using namespace mlir; diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DecomposeUnsupportedConversions.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DecomposeUnsupportedConversions.cpp deleted file mode 100644 index d9c4909c421cf..0000000000000 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DecomposeUnsupportedConversions.cpp +++ /dev/null @@ -1,100 +0,0 @@ -#include "TritonNVIDIAGPUToLLVM/Passes.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "triton/Analysis/Utility.h" -#include "triton/Conversion/TritonGPUToLLVM/Patterns.h" -#include "triton/Dialect/Triton/IR/Dialect.h" -#include "triton/Dialect/TritonGPU/IR/Dialect.h" - -using namespace mlir; - -namespace mlir { -namespace triton { -#define GEN_PASS_DEF_DECOMPOSEUNSUPPORTEDNVIDIACONVERSIONS -#include "TritonNVIDIAGPUToLLVM/Passes.h.inc" -} // namespace triton -} // namespace mlir - -namespace { - -using namespace mlir; -using namespace triton; -using namespace triton::gpu; - -// Loading from Hopper shared memory layout to dot operand is not supported. We -// need to break it down and use a different shared layout. This would mostly -// happen when TMAs are used with MMAV2 and will cause poor performance. -class DecomposeLocalLoadToDotOperand - : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(triton::gpu::LocalLoadOp op, - PatternRewriter &rewriter) const override { - - auto dstDotOp = dyn_cast( - op.getType().getEncoding()); - MemDescType srcType = op.getSrc().getType(); - auto sharedEncoding = dyn_cast(srcType.getEncoding()); - if (!dstDotOp || !sharedEncoding || !sharedEncoding.getHasLeadingOffset()) - return failure(); - RankedTensorType type = op.getType(); - auto parentEnc = dstDotOp.getParent(); - int numWarps = triton::gpu::getNumWarpsPerCTA(parentEnc); - int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp( - op->getParentOfType()); - int numCTAs = triton::gpu::getNumCTAs(parentEnc); - auto blockEncoding = getDefaultBlockedEncoding( - op.getContext(), type.getShape(), numWarps, threadsPerWarp, numCTAs); - auto tmpType = RankedTensorType::get(type.getShape(), type.getElementType(), - blockEncoding); - Value load = - rewriter.create(op.getLoc(), tmpType, op.getSrc()); - auto newSharedDescTy = MemDescType::get( - type.getShape(), type.getElementType(), - triton::gpu::SharedEncodingAttr::get( - op.getContext(), dstDotOp, type.getShape(), - triton::gpu::getOrder(parentEnc), - triton::gpu::getCTALayout(parentEnc), type.getElementType()), - srcType.getMemorySpace()); - auto tmp = rewriter.create( - op.getLoc(), newSharedDescTy, load); - auto newConvert = - rewriter.create(op.getLoc(), type, tmp); - rewriter.replaceOp(op, newConvert); - return success(); - } -}; - -struct DecomposeUnsupportedConversions - : public mlir::triton::impl::DecomposeUnsupportedNVIDIAConversionsBase< - DecomposeUnsupportedConversions> { - void runOnOperation() override { - // FIXME [Dot LL] - // Remove the decomposeTensorCoreToDotLayoutConversion class entirely after - // we have enabled the new layout conversion for all the cases. - auto nvidiaShortCutFn = [&](RankedTensorType srcTy, - RankedTensorType dstTy) { return true; }; - ModuleOp mod = getOperation(); - triton::gpu::decomposeSplatOpToSharedLayoutConversion(mod); - triton::gpu::decomposeTensorCoreToDotLayoutConversion(mod, - nvidiaShortCutFn); - triton::gpu::decomposeBlockedToDotLayoutConversion(mod); - - mlir::RewritePatternSet patterns(&getContext()); - patterns.add(&getContext()); - if (mlir::applyPatternsGreedily(mod, std::move(patterns)).failed()) { - signalPassFailure(); - } - } -}; -} // namespace - -namespace mlir::triton::NVIDIA { - -std::unique_ptr> -createDecomposeUnsupportedConversionsPass() { - return std::make_unique(); -} - -} // namespace mlir::triton::NVIDIA diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM.cpp index 4c58405740d53..5f476bc6601a0 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM.cpp @@ -2,6 +2,7 @@ #include "Utility.h" #include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" using namespace mlir; using namespace mlir::triton; diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp index c5ec00097d93a..e6b388c77f628 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp @@ -299,17 +299,17 @@ TensorCoreType getMmaType(triton::DotOp op) { return TensorCoreType::FP32_FP16_FP16_FP32; if (aTy.getElementType().isBF16() && bTy.getElementType().isBF16()) return TensorCoreType::FP32_BF16_BF16_FP32; - if (aTy.getElementType().isFloat8E5M2() && - bTy.getElementType().isFloat8E5M2()) + if (llvm::isa(aTy.getElementType()) && + llvm::isa(bTy.getElementType())) return TensorCoreType::FP32_FP8E5M2_FP8E5M2_FP32; - if (aTy.getElementType().isFloat8E5M2() && - bTy.getElementType().isFloat8E4M3FN()) + if (llvm::isa(aTy.getElementType()) && + llvm::isa(bTy.getElementType())) return TensorCoreType::FP32_FP8E5M2_FP8E4M3FN_FP32; - if (aTy.getElementType().isFloat8E4M3FN() && - bTy.getElementType().isFloat8E5M2()) + if (llvm::isa(aTy.getElementType()) && + llvm::isa(bTy.getElementType())) return TensorCoreType::FP32_FP8E4M3FN_FP8E5M2_FP32; - if (aTy.getElementType().isFloat8E4M3FN() && - bTy.getElementType().isFloat8E4M3FN()) + if (llvm::isa(aTy.getElementType()) && + llvm::isa(bTy.getElementType())) return TensorCoreType::FP32_FP8E4M3FN_FP8E4M3FN_FP32; if (aTy.getElementType().isF32() && bTy.getElementType().isF32() && op.getInputPrecision() == InputPrecision::TF32) diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv5.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv5.cpp index c347cad9880c8..705fa5371f023 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv5.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv5.cpp @@ -3,6 +3,7 @@ #include "PatternTritonGPUOpToLLVM.h" #include "Utility.h" #include "mlir/Support/LLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" #include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" using namespace mlir; @@ -60,7 +61,8 @@ enum class mxfpKind { mxf8f6f4 = 0, mxf4 = 1, mxf4nvf4 = 2 }; inline mxfpKind getMXFPKind(ScaleDotElemType typeA, ScaleDotElemType typeB, Type scaleAType, Type scaleBType) { if (typeA == ScaleDotElemType::E2M1 && typeB == ScaleDotElemType::E2M1) { - if (scaleAType.isFloat8E4M3FN() && scaleBType.isFloat8E4M3FN()) { + if (llvm::isa(scaleAType) && + llvm::isa(scaleBType)) { return mxfpKind::mxf4nvf4; } return mxfpKind::mxf4; @@ -100,10 +102,11 @@ static Value createInstDescriptor(ConversionPatternRewriter &rewriter, return 1; if (type.isF32()) return 2; - if (type.isFloat8E4M3FN()) + if (llvm::isa(type)) return 0; - if (type.isFloat8E5M2()) + if (llvm::isa(type)) return 1; + llvm_unreachable("Unsupported type."); }; static_assert(sizeof(TCGen5InstructionDescriptor) == 4, @@ -224,7 +227,8 @@ static void createGen5MMA(ConversionPatternRewriter &rewriter, Location loc, opcode += "f16"; else if (srcElementTy.isF32()) opcode += "tf32"; - else if (srcElementTy.isFloat8E4M3FN() || srcElementTy.isFloat8E5M2()) + else if (llvm::isa(srcElementTy) || + llvm::isa(srcElementTy)) opcode += "f8f6f4"; else assert(0 && "Unsupported type."); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp index 7450bc3f4e1b4..e76d9df2d7f73 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp @@ -24,6 +24,7 @@ #include "MMAHelpers.h" #include "Utility.h" #include "mlir/Support/LLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" using namespace mlir; using namespace mlir::triton; @@ -59,9 +60,9 @@ triton::nvgpu::WGMMAEltType getMmaOperandType(Value a, bool allowTF32) { return triton::nvgpu::WGMMAEltType::tf32; } else if (aTy.isInteger(8)) { return triton::nvgpu::WGMMAEltType::s8; - } else if (aTy.isFloat8E5M2()) { + } else if (llvm::isa(aTy)) { return triton::nvgpu::WGMMAEltType::e5m2; - } else if (aTy.isFloat8E4M3FN()) { + } else if (llvm::isa(aTy)) { return triton::nvgpu::WGMMAEltType::e4m3; } else { llvm::report_fatal_error("Unsupported mma operand type found"); @@ -91,7 +92,7 @@ int64_t getSwizzlingFromLayout(const SharedEncodingAttr &layout, return swizzlingByteWidth; } -static Value createDescriptor(ConversionPatternRewriter &rewriter, Location loc, +Value createDescriptor(ConversionPatternRewriter &rewriter, Location loc, int64_t swizzling, uint32_t stride) { static_assert(sizeof(SMEMDescriptor) == 8, "Descriptor size should be 64 bits."); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp index d489d0a1b1f43..97698f75a30da 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -285,7 +285,8 @@ static ConverterT makeConverterFromPtx(const std::string &ptxAsm, Type inType, outVecWidthBits](Location loc, ConversionPatternRewriter &rewriter, const SmallVector &v) -> SmallVector { int numElements = v.size(); - assert(numElements == 4 || numElements == 2 && "invalid vector size"); + assert(numElements == 8 || numElements == 4 || + numElements == 2 && "invalid vector size"); auto ctx = rewriter.getContext(); int inBitwidth = inType.getIntOrFloatBitWidth(); @@ -407,10 +408,10 @@ struct FpToFpOpConversion ptx = "cvt.rz.f16.f32"; break; default: - llvm::errs() << "WARNING: unsupported rounding mode for f32->f16 " - "conversion: " - << stringifyRoundingMode(rounding) << "\n"; - llvm_unreachable(""); + llvm::report_fatal_error( + "WARNING: unsupported rounding mode for f32->f16 " + "conversion: " + stringifyRoundingMode(rounding) + + "\n"); } auto &cvt = *builder.create(ptx.str()); auto res = builder.newOperand("=h"); @@ -466,12 +467,12 @@ struct FpToFpOpConversion llvm::errs() << "\n"; llvm::report_fatal_error("Unsupported rounding mode for conversion."); } - if (computeCapability < 89 && - (srcTy.isFloat8E4M3FN() || dstTy.isFloat8E4M3FN())) { - llvm::errs() << "Conversion from/to f8e4m3nv is only supported on " - "compute capability >= 89" - << "\n"; - llvm_unreachable(""); + if (computeCapability < 89 && (llvm::isa(srcTy) || + llvm::isa(dstTy))) { + llvm::report_fatal_error( + "Conversion from/to f8e4m3nv is only supported on " + "compute capability >= 89" + "\n"); } auto convDesc = srcMap.lookup(key); return {makeConverterFromPtx( @@ -489,16 +490,17 @@ struct FpToFpOpConversion auto dstElementType = getElementType(op.getResult()); auto roundingMode = op.getRounding(); - if (dstElementType.isFloat8E5M2() || dstElementType.isFloat8E4M3FN()) { + if (llvm::isa( + dstElementType)) { assert(roundingMode.has_value() && "Rounding mode must be specified for convertsions to fp8"); // For now only RTNE is supported for conversions from fp16 to fp8 if (!srcElementType.isF32() && roundingMode.value() != RoundingMode::RTNE) { - llvm::errs() << "Unsupported rounding mode for conversion to fp8: " - << stringifyRoundingMode(roundingMode.value()) << "\n"; - llvm_unreachable(""); + llvm::report_fatal_error( + "Unsupported rounding mode for conversion to fp8: " + + stringifyRoundingMode(roundingMode.value()) + "\n"); } } @@ -526,8 +528,9 @@ struct FpToFpOpConversion bool useFP16IntermediateSrc = srcElementType.isF32() && - (!(computeCapability >= 90 && (dstElementType.isFloat8E4M3FN() || - dstElementType.isFloat8E5M2())) || + (!(computeCapability >= 90 && + (llvm::isa( + dstElementType))) || roundingMode.value() == RoundingMode::RTZ); bool isDstFP32 = dstElementType.isF32(); Type srcType = useFP16IntermediateSrc ? f16_ty : srcElementType; @@ -692,6 +695,114 @@ struct SIToFPOpConversion : ElementwiseOpConversionBase(typeConverter, axisAnalysisPass, benefit), computeCapability(computeCapability) {} + LogicalResult matchAndRewrite( + arith::SIToFPOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (succeeded(matchAndRewriteInt4ToBf16Conversion(op, rewriter))) { + return success(); + } + return Base::matchAndRewrite(op, adaptor, rewriter); + } + + // Matches subgraph of convert 8xi4 to 8xbf16 and rewrites it to inline PTX. + LogicalResult matchAndRewriteInt4ToBf16Conversion( + arith::SIToFPOp op, ConversionPatternRewriter &rewriter) const { + if (computeCapability < 90) return failure(); + Type inElemTy = getElementType(op.getIn()); + Type outElemTy = getElementType(op.getOut()); + if (!inElemTy.isInteger(8) || !outElemTy.isBF16()) return failure(); + FailureOr unpack = matchInt4Unpack(op.getIn()); + if (failed(unpack)) return failure(); + + Location loc = op.getLoc(); + Value src = rewriter.getRemappedValue(unpack.value()); + auto structTy = dyn_cast(src.getType()); + if (!structTy || structTy.getBody().size() % 4 != 0) return failure(); + auto isInt8 = [](Type type) { return type.isInteger(8); }; + if (!all_of(structTy.getBody(), isInt8)) return failure(); + + const LLVMTypeConverter *typeConverter = getTypeConverter(); + assert(inElemTy == typeConverter->convertType(inElemTy)); + assert(outElemTy == typeConverter->convertType(outElemTy)); + + const std::string S4_to_Bf16_sm90 = R"({ + .reg .b32 r<4>, mi, mf; + mov.b32 mi, 0x43404340 - 0x00080008; + mov.b32 mf, 0x43404340; + // Shift 4-bit inputs to 16-bit boundary. + shr.u32 r1, $4, 4; + shr.u32 r2, $4, 8; + shr.u32 r3, $4, 12; + // Sign-extend from 4 bits is equivalent to (x ^ 0x8) - 0x8. + lop3.b32 r0, $4, 0x000f000f, 0x00080008, (0xf0 & 0xcc) ^ 0xaa; + lop3.b32 r1, r1, 0x000f000f, 0x00080008, (0xf0 & 0xcc) ^ 0xaa; + lop3.b32 r2, r2, 0x000f000f, 0x00080008, (0xf0 & 0xcc) ^ 0xaa; + lop3.b32 r3, r3, 0x000f000f, 0x00080008, (0xf0 & 0xcc) ^ 0xaa; + // Interger-add magic number (minus bias from sign-extend above). + add.s16x2 r0, r0, mi; + add.s16x2 r1, r1, mi; + add.s16x2 r2, r2, mi; + add.s16x2 r3, r3, mi; + // Float-subtract magic number. + sub.bf16x2 r0, r0, mf; + sub.bf16x2 r1, r1, mf; + sub.bf16x2 r2, r2, mf; + sub.bf16x2 r3, r3, mf; + // Shuffle results into correct order. + prmt.b32 $0, r1, r0, 0x5410; + prmt.b32 $1, r3, r2, 0x5410; + prmt.b32 $2, r1, r0, 0x7632; + prmt.b32 $3, r3, r2, 0x7632; + })"; + + SmallVector resultVals; + SmallVector unpackedVals = unpackLLElements(loc, src, rewriter); + auto cvtFunc = makeConverterFromPtx(S4_to_Bf16_sm90, inElemTy, outElemTy); + for (ValueRange operands = unpackedVals; !operands.empty(); + operands = operands.drop_front(4)) { + SmallVector inVals = { + operands[0], operands[1], operands[2], operands[3], + // Repeat operands so that cvtFunc produces 8 outputs. + operands[0], operands[1], operands[2], operands[3]}; + auto outVals = cvtFunc(loc, rewriter, inVals); + assert(inVals.size() == outVals.size()); + resultVals.append(outVals.begin(), outVals.end()); + } + + resultVals = maybeDeduplicate(op, resultVals); + Value view = + packLLElements(loc, typeConverter, resultVals, rewriter, op.getType()); + rewriter.replaceOp(op, view); + + return success(); + } + + // Returns the source if value is the result of an 2xi4 -> 2xi8 unpack + // sequence. + static FailureOr matchInt4Unpack(Value value) { + auto reshape = value.getDefiningOp(); + if (!reshape) return failure(); + auto join = reshape.getSrc().getDefiningOp(); + if (!join) return failure(); + auto shrHi = join.getLhs().getDefiningOp(); + if (!shrHi || !isConst4(shrHi.getRhs())) return failure(); + auto shrLo = join.getRhs().getDefiningOp(); + if (!shrLo || !isConst4(shrLo.getRhs())) return failure(); + auto shlLo = shrLo.getLhs().getDefiningOp(); + if (!shlLo || !isConst4(shlLo.getRhs())) return failure(); + if (shrHi.getLhs() != shlLo.getLhs()) return failure(); + return shrHi.getLhs(); + } + + // Returns true if the value is equal to 4. + static bool isConst4(Value v) { + auto constOp = v.getDefiningOp(); + if (!constOp) return false; + auto attr = mlir::dyn_cast(constOp.getValue()); + if (!attr || !attr.isSplat()) return false; + return attr.getSplatValue().getLimitedValue() == 4; + }; + SmallVector createDestOps(arith::SIToFPOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter, Type elemTy, MultipleOperandsRange operands, diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp index 7c4a9e5b92dff..08c7346bb4a1d 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp @@ -6,6 +6,7 @@ #include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "llvm/Support/MathExtras.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" using namespace mlir; diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TensorPtrOpsToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TensorPtrOpsToLLVM.cpp index c117eb176431c..61ab6ca7901d5 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TensorPtrOpsToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TensorPtrOpsToLLVM.cpp @@ -23,6 +23,7 @@ #include "PatternTritonGPUOpToLLVM.h" #include "Utility.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" using namespace mlir; using namespace mlir::triton; diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.cpp index 044d420b1cb32..fb731dd1fda22 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.cpp @@ -2,6 +2,7 @@ #include "Dialect/NVGPU/IR/Dialect.h" #include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" namespace mlir { namespace LLVM { diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.h b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.h index 37241f76791b1..d41a28eeb462d 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.h +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.h @@ -3,8 +3,6 @@ #include "nvidia/include/TritonNVIDIAGPUToLLVM/PTXAsmFormat.h" -#include "triton/Conversion/TritonGPUToLLVM/Utility.h" - #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "third_party/nvidia/include/Dialect/NVGPU/IR/Dialect.h" diff --git a/third_party/proton/BUILD b/third_party/proton/BUILD new file mode 100644 index 0000000000000..783718497934a --- /dev/null +++ b/third_party/proton/BUILD @@ -0,0 +1,130 @@ +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") + +package( + # copybara:uncomment_begin + # default_applicable_licenses = ["//:license"], + # default_compatible_with = ["//buildenv/target:non_prod"], + # default_visibility = [ + # "//:__subpackages__", + # ], + # copybara:uncomment_end_and_comment_begin + default_visibility = ["//visibility:public"], + # copybara:comment_end +) + +td_library( + name = "td_files", + srcs = glob(["dialect/include/Dialect/Proton/IR/*.td"]), + includes = ["dialect/include"], + deps = [ + "@llvm-project//mlir:OpBaseTdFiles", + "//:td_files", + ], +) + +gentbl_cc_library( + name = "proton_ir_attr_inc_gen", + tbl_outs = [ + ( + ["--gen-attrdef-decls"], + "dialect/include/Dialect/Proton/IR/ProtonAttrDefs.h.inc", + ), + ( + ["--gen-attrdef-defs"], + "dialect/include/Dialect/Proton/IR/ProtonAttrDefs.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "dialect/include/Dialect/Proton/IR/ProtonAttrDefs.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "proton_ir_dialect_inc_gen", + tbl_outs = [ + ( + ["--gen-dialect-decls"], + "dialect/include/Dialect/Proton/IR/Dialect.h.inc", + ), + ( + ["--gen-dialect-defs"], + "dialect/include/Dialect/Proton/IR/Dialect.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "dialect/include/Dialect/Proton/IR/ProtonDialect.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "proton_ir_ops_inc_gen", + tbl_outs = [ + ( + ["--gen-enum-decls"], + "dialect/include/Dialect/Proton/IR/OpsEnums.h.inc", + ), + ( + ["--gen-enum-defs"], + "dialect/include/Dialect/Proton/IR/OpsEnums.cpp.inc", + ), + ( + ["--gen-op-decls"], + "dialect/include/Dialect/Proton/IR/Ops.h.inc", + ), + ( + ["--gen-op-defs"], + "dialect/include/Dialect/Proton/IR/Ops.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "dialect/include/Dialect/Proton/IR/ProtonOps.td", + deps = ["td_files"], +) + +cc_library( + name = "ProtonIRDialect", + srcs = glob([ + "dialect/lib/Dialect/Proton/IR/*.cpp", + ]), + hdrs = glob([ + "dialect/include/Dialect/Proton/IR/*.h", + ]), + includes = [ + "..", # because proton/dialect/include/Dialect/Proton/IR/Dialect.h.inc + "dialect/include", + ], + deps = [ + ":proton_ir_attr_inc_gen", + ":proton_ir_dialect_inc_gen", + ":proton_ir_ops_inc_gen", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:FunctionInterfaces", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:Support", + "//:TritonDialects", + ], +) + +cc_library( + name = "TritonProtonToLLVM", + srcs = glob([ + "dialect/lib/TritonProtonToLLVM/*.cpp", + ]), + hdrs = glob([ + "dialect/include/TritonProtonToLLVM/*.h", + ]), + includes = [ + ], + deps = [ + ":ProtonIRDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:FunctionInterfaces", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMCommonConversion", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:Support", + "//:TritonDialects", + "//:TritonGPUToLLVM", + ], +) diff --git a/third_party/proton/proton/_C/include b/third_party/proton/proton/_C/include index fe4f4a1aa9bdc..4400934bdf78a 120000 --- a/third_party/proton/proton/_C/include +++ b/third_party/proton/proton/_C/include @@ -1 +1 @@ -../../csrc/include/ \ No newline at end of file +../../csrc/include \ No newline at end of file diff --git a/unittest/BUILD b/unittest/BUILD new file mode 100644 index 0000000000000..4cbadcfa4655b --- /dev/null +++ b/unittest/BUILD @@ -0,0 +1,144 @@ +load("//tools/build_defs/build_test:build_test.bzl", "build_test") + +package( + default_applicable_licenses = ["//:license"], + default_compatible_with = ["//buildenv/target:non_prod"], + default_visibility = ["//:__subpackages__"], +) + +cc_test( + name = "AnalysisTest", + srcs = glob(["Analysis/*.cpp"]), + deps = [ + "//testing/base/public:gunit_main", + "@llvm-project//llvm:Support", + "//:TritonDialects", + ], +) + +cc_test( + name = "DialectTestCatchAll", + srcs = glob( + [ + "Dialect/**/*.cpp", + ], + exclude = [ + "Dialect/TritonGPU/DialectTest.cpp", + "Dialect/TritonGPU/LinearLayoutConversionsTest.cpp", + "Dialect/TritonGPU/SwizzleTest.cpp", + ], + ), + copts = select({ + "//:compiler_is_msvc": [], + "//conditions:default": [ + "-Wno-unused-variable", + ], + }), + deps = [ + "//testing/base/public:gunit_main", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:AsmParser", + "@llvm-project//mlir:IR", + "//:TritonDialects", + ], +) + +cc_test( + name = "DialectTest", + srcs = [ + "Dialect/TritonGPU/DialectTest.cpp", + ], + copts = select({ + "//:compiler_is_msvc": [], + "//conditions:default": [ + "-Wno-unused-variable", + ], + }), + deps = [ + "//testing/base/public:gunit_main", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:AsmParser", + "@llvm-project//mlir:IR", + "//:TritonDialects", + ], +) + +cc_test( + name = "LinearLayoutConversionsTest", + srcs = [ + "Dialect/TritonGPU/LinearLayoutConversionsTest.cpp", + ], + copts = select({ + "//:compiler_is_msvc": [], + "//conditions:default": [ + "-Wno-unused-variable", + ], + }), + deps = [ + "//testing/base/public:gunit_main", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:AsmParser", + "@llvm-project//mlir:IR", + "//:TritonDialects", + ], +) + +cc_test( + name = "SwizzleTest", + srcs = [ + "Dialect/TritonGPU/SwizzleTest.cpp", + ], + copts = select({ + "//:compiler_is_msvc": [], + "//conditions:default": [ + "-Wno-unused-variable", + ], + }), + deps = [ + "//testing/base/public:gunit_main", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:AsmParser", + "@llvm-project//mlir:IR", + "//:TritonDialects", + ], +) + +cc_test( + name = "ConversionTest", + srcs = glob( + [ + "Conversion/**/*.cpp", + "Conversion/**/*.h", + ], + exclude = [ + "Conversion/TritonGPUToLLVM/EmitIndicesTest.cpp", + "Conversion/TritonGPUToLLVM/DumpLayout.cpp", + "Conversion/TritonGPUToLLVM/DumpLayout.h", + ], + ), + copts = select({ + "//:compiler_is_msvc": [], + "//conditions:default": [ + "-Wno-unused-variable", + ], + }), + deps = [ + "//testing/base/public:gunit_main", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:IR", + "//:TritonDialects", + "//:TritonNvidiaGPUTransforms", + "@triton//third_party/nvidia:TritonNVIDIAGPUToLLVM", + ], +) + +build_test( + name = "build_test", + allow_empty_target = False, + targets = [ + ":ConversionTest", + ":AnalysisTest", + ":DialectTest", + ], +)