[GPU] Add basic GPU support and example (#381)
Co-authored-by: SForeKeeper <[email protected]>
matrix72c and SForeKeeper authored Oct 12, 2024
1 parent 2d8c05a commit 22bb0fa
Show file tree
Hide file tree
Showing 25 changed files with 2,602 additions and 4 deletions.
1 change: 1 addition & 0 deletions examples/BuddyGPU/.gitignore
@@ -1,3 +1,4 @@
log.mlir
log.ll
log.s
matmul-cubin.mlir
40 changes: 40 additions & 0 deletions examples/BuddyGPU/README.md
@@ -0,0 +1,40 @@
# Buddy GPU Example
This example demonstrates how to use Buddy Compiler to run a simple single-kernel program on the GPU.

## Matmul
The example program is a simple matrix-multiplication kernel whose linalg definition is in `matmul.mlir`.
A transform sequence in `transform.mlir` optimizes this kernel and prepares it for execution on the GPU.
`matmul-cubin.mlir` provides a pre-lowered version of the kernel as a fallback in case the pipeline does not work.
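
For orientation, the kernel computes a plain dense matrix product; the following NumPy sketch mirrors what the test script later verifies (the sizes here are illustrative, not taken from `matmul.mlir`):
```
import numpy as np

M, K, N = 64, 64, 64  # illustrative sizes only
A = np.random.rand(M, K).astype(np.float32)
B = np.random.rand(K, N).astype(np.float32)
C = A @ B  # the GPU result is checked against np.matmul
```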

Run the following command to compile and run the program:
```
make buddy-gpu-matmul
python run-module-gpu.py --source matmul.mlir --target matmul-cubin.mlir --llvm_dir ../../llvm
```
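
Here `make buddy-gpu-matmul` lowers `matmul.mlir` to the executable `matmul-cubin.mlir`, and `run-module-gpu.py` then JIT-runs that lowered module and compares its output with NumPy.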

The output should look similar to the following (the inputs are random, so the exact values vary):
```
[[502.9141 499.7761 511.35623 ... 500.9083 505.25574 511.03818]
[499.57034 494.8066 506.427 ... 492.7868 497.22513 509.95612]
[511.2017 516.017 513.631 ... 515.5991 515.6389 521.8318 ]
...
[496.2721 496.3155 506.08054 ... 502.36798 505.94202 516.3577 ]
[512.06866 505.80127 518.81934 ... 510.64966 510.10333 531.85364]
[501.23514 500.17123 505.71808 ... 496.4447 500.5735 514.4204 ]]
[[503.26013 500.11093 511.70193 ... 501.24622 505.60373 511.38376]
[499.89877 495.13043 506.762 ... 493.1151 497.5555 510.29483]
[511.54883 516.35547 513.9717 ... 515.944 515.9865 522.1828 ]
...
[496.59937 496.63785 506.41483 ... 502.70337 506.27927 516.6994 ]
[512.4154 506.1411 519.17175 ... 510.9929 510.45322 532.2152 ]
[501.57388 500.5093 506.06213 ... 496.7807 500.91638 514.77124]]
MLIR equal to NumPy? True
```

As Tensor Cores do not support fp32 computation, the operands are converted to tf32; hence the result does not exactly match the NumPy reference.
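
To see why a tolerance is needed at all, tf32's reduced mantissa can be emulated in NumPy. This is a rough sketch (truncation rather than the hardware's rounding), showing the error staying within the `1e-3` tolerance the test script uses:
```
import numpy as np

def to_tf32(x):
    # tf32 keeps fp32's 8 exponent bits but only 10 of its 23 mantissa
    # bits; zeroing the low 13 mantissa bits emulates the precision loss.
    bits = x.astype(np.float32).view(np.uint32)
    return (bits & np.uint32(0xFFFFE000)).view(np.float32)

a = np.random.rand(64, 64).astype(np.float32)
b = np.random.rand(64, 64).astype(np.float32)
exact = a @ b
approx = to_tf32(a) @ to_tf32(b)
print(np.allclose(exact, approx, rtol=1e-3, atol=1e-3))  # typically True
```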

### Profiling
Install NVIDIA Nsight Compute first, then profile the kernel with:
```
ncu -o profile-result --set full python run-module-gpu.py --source matmul.mlir --target matmul-cubin.mlir --llvm_dir ../../llvm
```
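The `--set full` option collects Nsight Compute's complete metric set; the run writes `profile-result.ncu-rep`, which can be opened in the Nsight Compute UI.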
14 changes: 14 additions & 0 deletions examples/BuddyGPU/makefile
@@ -1,8 +1,22 @@
#!/bin/bash
BUDDY_OPT := ../../build/bin/buddy-opt
MLIR_OPT := ../../llvm/build/bin/mlir-opt
MLIR_TRANSLATE := ../../llvm/build/bin/mlir-translate
MLIR_CPU_RUNNER := ../../llvm/build/bin/mlir-cpu-runner
LLC := ../../llvm/build/bin/llc

# Apply the transform-dialect schedule only, writing the transformed IR to log.mlir.
buddy-gpu-matmul-lower:
	@${BUDDY_OPT} matmul.mlir \
		-transform-preload-library="transform-library-paths=transform.mlir" \
		-transform-interpreter="entry-point=codegen" \
		-o log.mlir

# Full pipeline: transform schedule, shared-memory optimization, bufferization,
# GPU kernel outlining, and lowering to LLVM/NVVM with the compiled fatbin for
# sm_80 embedded in matmul-cubin.mlir.
buddy-gpu-matmul:
	@${BUDDY_OPT} matmul.mlir -transform-preload-library="transform-library-paths=transform.mlir" -transform-interpreter="entry-point=codegen" | \
	${BUDDY_OPT} --pass-pipeline='builtin.module(func.func(nvgpu-optimize-shared-memory))' | \
	${BUDDY_OPT} -arith-expand -eliminate-empty-tensors -empty-tensor-to-alloc-tensor -linalg-bufferize -convert-linalg-to-affine-loops -affine-loop-fusion -affine-parallelize -lower-affine -canonicalize -func-bufferize -arith-bufferize -tensor-bufferize -buffer-deallocation -finalizing-bufferize -canonicalize | \
	${BUDDY_OPT} -gpu-launch-sink-index-computations -canonicalize -legalize-shmem-outlining -canonicalize | \
	${BUDDY_OPT} -convert-memcpy-to-gpu -gpu-async-region -canonicalize | \
	${BUDDY_OPT} -convert-scf-to-cf -memref-expand -finalize-memref-to-llvm -convert-arith-to-llvm --convert-vector-to-llvm -convert-gpu-to-nvvm='has-redux=1' | \
	${BUDDY_OPT} -llvm-request-c-wrappers -canonicalize -cse -sccp | \
	${MLIR_OPT} --test-lower-to-nvvm="cubin-chip=sm_80 cubin-features=+ptx71 cubin-format=fatbin" -o matmul-cubin.mlir
147 changes: 147 additions & 0 deletions examples/BuddyGPU/run-module-gpu.py
@@ -0,0 +1,147 @@
# ===- run-module-gpu.py --------------------------------------------------===//
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ===----------------------------------------------------------------------===//
#
# This file is a script to test whether the specified MLIR module on the GPU
# calculates the same result as NumPy.
#
# ===----------------------------------------------------------------------===//
import mlir.ir as ir
import mlir.dialects.func as func
import mlir.dialects.memref as memref
from mlir.passmanager import *
from mlir.execution_engine import *
from mlir import runtime as rt
from mlir.ir import *
import numpy as np
import ctypes
import ml_dtypes
import argparse as ap


def to_numpy(element_type: str) -> np.dtype:
    match element_type:
        case "f16":
            return np.float16
        case "f32":
            return np.float32
        case "f64":
            return np.float64
        case "i8":
            return np.int8
        case "i16":
            return np.int16
        case "i32":
            return np.int32
        case "i64":
            return np.int64
        case "bf16":
            return np.dtype("bfloat16")
        case _:
            raise ValueError(f"Unsupported type: {element_type}")


def new_ranked_memref_descriptor(nparray: np.ndarray):
    if nparray.dtype == "bfloat16":
        # bfloat16 has no ctypes equivalent; reuse the 16-bit F16 slot,
        # since only the bit width matters for the descriptor layout.
        ctp = rt.F16
    else:
        ctp = rt.as_ctype(nparray.dtype)

    if nparray.ndim == 0:
        x = rt.make_zero_d_memref_descriptor(ctp)()
        x.allocated = nparray.ctypes.data
        x.aligned = nparray.ctypes.data_as(ctypes.POINTER(ctp))
        x.offset = ctypes.c_longlong(0)
        return x

    x = rt.make_nd_memref_descriptor(nparray.ndim, ctp)()
    nbytes = nparray.nbytes
    buffer = ctypes.create_string_buffer(nbytes)
    ctypes.memmove(buffer, nparray.ctypes.data, nbytes)
    x.allocated = ctypes.cast(buffer, ctypes.c_void_p).value
    x.aligned = ctypes.cast(buffer, ctypes.POINTER(ctp))
    x.offset = ctypes.c_longlong(0)
    x.shape = nparray.ctypes.shape

    # NumPy expresses strides in bytes, whereas MLIR memref descriptors
    # specify strides in elements, so divide by the element size.
    strides_ctype_t = ctypes.c_longlong * nparray.ndim
    x.strides = strides_ctype_t(
        *[s // nparray.itemsize for s in nparray.strides]
    )
    return x
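
# An illustrative sanity check for the helper above (not executed by the script):
#   a = np.arange(12, dtype=np.float32).reshape(3, 4)
#   desc = new_ranked_memref_descriptor(a)
#   back = rt.ranked_memref_to_numpy(ctypes.pointer(desc))
#   assert np.array_equal(a, back)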


def get_memref_descriptors(args: list[Type]):
    memref_ptrs = []
    for arg in args:
        elem_type = to_numpy(str(arg.element_type))
        np_arg = np.random.rand(*arg.shape).astype(elem_type)
        memref_ptrs.append(
            ctypes.pointer(ctypes.pointer(new_ranked_memref_descriptor(np_arg)))
        )
    return memref_ptrs


def test(source, target, llvm_dir):
    with Context():
        # Parse the source module only to recover the kernel's signature.
        with open(source, "r") as f:
            module: Module = Module.parse(f.read())
        funcOp: func.FuncOp = (
            module.operation.regions[0].blocks[0].operations[0]
        )
        assert isinstance(funcOp, func.FuncOp)
        funcName = str(funcOp.name).replace('"', "")
        args_type: list[Type] = [arg.type for arg in funcOp.arguments]
        res_type = funcOp.type.results

        # The target module is the fully lowered one that actually executes.
        with open(target, "r") as f:
            newModule = Module.parse(f.read())
        # Allocate one descriptor per result and argument; results come first.
        memref_ptrs = get_memref_descriptors(res_type + args_type)

        engine = ExecutionEngine(
            newModule,
            shared_libs=[
                "/usr/lib/libomp.so",
                llvm_dir + "/build/lib/libmlir_c_runner_utils.so",
                llvm_dir + "/build/lib/libmlir_async_runtime.so",
                llvm_dir + "/build/lib/libmlir_runner_utils.so",
                llvm_dir + "/build/lib/libmlir_cuda_runtime.so",
            ],
            opt_level=3,
        )
        engine.invoke(funcName, *memref_ptrs)
        out = rt.ranked_memref_to_numpy(memref_ptrs[0][0])
        if str(res_type[0].element_type) == "bf16":
            print("Running in bf16 mode, skipping NumPy comparison.")
        else:
            print(out)
            input1 = rt.ranked_memref_to_numpy(memref_ptrs[1][0])
            input2 = rt.ranked_memref_to_numpy(memref_ptrs[2][0])
            numpy_out = np.matmul(input1, input2)
            print(numpy_out)
            print(
                f"MLIR equal to NumPy? "
                f"{np.allclose(out, numpy_out, rtol=1e-03, atol=1e-03)}"
            )


if __name__ == "__main__":
    parser = ap.ArgumentParser()
    parser.add_argument("--source", type=str, required=True)
    parser.add_argument("--target", type=str, required=True)
    parser.add_argument("--llvm_dir", type=str, required=True)
    args = parser.parse_args()
    test(args.source, args.target, args.llvm_dir)