Skip to content

Commit

Permalink
interpreter: Initial WGPU interpreter. (#1346)
Browse files Browse the repository at this point in the history
Here's a first draft/example for a WebGPU-based interpreter for the GPU
dialect, including JIT-compilation through WGSL.
I only implement very specific cases of `gpu.func`, `gpu.launch_func`,
`gpu.alloc` and `gpu.memcpy` here.

This requires the addition of Vulkan drivers in the CI

---------

Co-authored-by: Ka Wing, Li <[email protected]>
Co-authored-by: Sasha Lopoukhine <[email protected]>
  • Loading branch information
3 people authored Jul 28, 2023
1 parent 0f9f42e commit 7401a1a
Show file tree
Hide file tree
Showing 6 changed files with 315 additions and 4 deletions.
6 changes: 6 additions & 0 deletions .github/workflows/ci-core.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@ jobs:

steps:
- uses: actions/checkout@v3

- name: Install native dependencies
uses: awalsh128/cache-apt-pkgs-action@latest
with:
packages: mesa-vulkan-drivers
version: 1.0

- name: Set up Python
uses: actions/setup-python@v4
Expand Down
6 changes: 6 additions & 0 deletions .github/workflows/ci-mlir.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@ jobs:
- uses: actions/checkout@v3
with:
path: xdsl

- name: Install native dependencies
uses: awalsh128/cache-apt-pkgs-action@latest
with:
packages: mesa-vulkan-drivers
version: 1.0

- name: Python Setup
uses: actions/setup-python@v4
Expand Down
1 change: 1 addition & 0 deletions requirements-optional.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@ nbconvert>=7.7.2,<8.0.0
# pyright version has to be fixed with `==`. The CI parses this file
# and installs the according version for typechecking.
pyright==1.1.317
wgpu==0.9.4
85 changes: 85 additions & 0 deletions tests/interpreters/test_wgpu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
from io import StringIO

import pytest

from xdsl.dialects import arith, builtin, func, gpu, memref, printf
from xdsl.interpreter import Interpreter
from xdsl.interpreters.arith import ArithFunctions
from xdsl.interpreters.memref import MemrefFunctions
from xdsl.interpreters.printf import PrintfFunctions
from xdsl.ir import MLContext
from xdsl.parser import Parser

pytest.importorskip("wgpu", reason="wgpu is an optional dependency")

from xdsl.interpreters.experimental.wgpu import WGPUFunctions # noqa: E402


def test_init():
mlir_source = """
builtin.module attributes {gpu.container_module} {
"gpu.module"() ({
"gpu.func"() ({
^0(%arg : memref<4x4xindex>):
%0 = "arith.constant"() {"value" = 2 : index} : () -> index
%1 = "gpu.global_id"() {"dimension" = #gpu<dim x>} : () -> index
%2 = "gpu.global_id"() {"dimension" = #gpu<dim y>} : () -> index
%3 = "arith.constant"() {"value" = 4 : index} : () -> index
%4 = "arith.muli"(%1, %3) : (index, index) -> index
%5 = "arith.addi"(%4, %2) : (index, index) -> index
"memref.store"(%5, %arg, %1, %2) {"nontemporal" = false} : (index, memref<4x4xindex>, index, index) -> ()
"gpu.return"() : () -> ()
}) {"function_type" = (memref<4x4xindex>) -> (),
"gpu.kernel",
"sym_name" = "fill"
} : () -> ()
"gpu.func"() ({
^0(%arg : memref<4x4xindex>):
%0 = "arith.constant"() {"value" = 1 : index} : () -> index
%1 = "gpu.global_id"() {"dimension" = #gpu<dim x>} : () -> index
%2 = "gpu.global_id"() {"dimension" = #gpu<dim y>} : () -> index
%3 = "memref.load"(%arg, %1, %2) {"nontemporal" = false} : (memref<4x4xindex>, index, index) -> (index)
%4 = "arith.addi"(%3, %0) : (index, index) -> index
"memref.store"(%4, %arg, %1, %2) {"nontemporal" = false} : (index, memref<4x4xindex>, index, index) -> ()
"gpu.return"() : () -> ()
}) {"function_type" = (memref<4x4xindex>) -> (),
"gpu.kernel",
"sym_name" = "inc"
} : () -> ()
"gpu.module_end"() : () -> ()
}) {"sym_name" = "gpu"} : () -> ()
func.func @main() -> index {
%four = "arith.constant"() {"value" = 4 : index} : () -> index
%one = "arith.constant"() {"value" = 1 : index} : () -> index
%memref = "gpu.alloc"() {"alignment" = 0 : i64, "operand_segment_sizes" = array<i32: 0, 0>} : () -> memref<4x4xindex>
"gpu.launch_func"(%four, %four, %one, %one, %one, %one, %memref) {"operand_segment_sizes" = array<i32: 0, 1, 1, 1, 1, 1, 1, 0, 1>, "kernel" = @gpu::@fill} : (index, index, index, index, index, index, memref<4x4xindex>) -> ()
"gpu.launch_func"(%four, %four, %one, %one, %one, %one, %memref) {"operand_segment_sizes" = array<i32: 0, 1, 1, 1, 1, 1, 1, 0, 1>, "kernel" = @gpu::@inc} : (index, index, index, index, index, index, memref<4x4xindex>) -> ()
%hmemref = "memref.alloc"() {"alignment" = 0 : i64, "operand_segment_sizes" = array<i32: 0, 0>} : () -> memref<4x4xindex>
"gpu.memcpy"(%hmemref, %memref) {"operand_segment_sizes" = array<i32: 0, 1, 1>} : (memref<4x4xindex>, memref<4x4xindex>) -> ()
printf.print_format "Result : {}", %hmemref : memref<4x4xindex>
}
}
"""
context = MLContext()
context.register_dialect(arith.Arith)
context.register_dialect(memref.MemRef)
context.register_dialect(builtin.Builtin)
context.register_dialect(gpu.GPU)
context.register_dialect(func.Func)
context.register_dialect(printf.Printf)
parser = Parser(context, mlir_source)
module = parser.parse_module()

f = StringIO("")
interpreter = Interpreter(module, file=f)
interpreter.register_implementations(ArithFunctions())
interpreter.register_implementations(MemrefFunctions())
interpreter.register_implementations(WGPUFunctions())
interpreter.register_implementations(PrintfFunctions())
interpreter.call_op("main", ())
assert (
"Result : [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]"
in f.getvalue()
)
8 changes: 4 additions & 4 deletions xdsl/dialects/gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,8 +322,8 @@ class MemcpyOp(IRDLOperation):
name = "gpu.memcpy"

asyncDependencies: VarOperand = var_operand_def(AsyncTokenType)
src: Operand = operand_def(memref.MemRefType)
dst: Operand = operand_def(memref.MemRefType)
src: Operand = operand_def(memref.MemRefType)

irdl_options = [AttrSizedOperandSegments()]

Expand All @@ -337,14 +337,14 @@ def __init__(
is_async: bool = False,
):
return super().__init__(
operands=[async_dependencies, source, destination],
operands=[async_dependencies, destination, source],
result_types=[[AsyncTokenType()] if is_async else []],
)

def verify_(self) -> None:
if self.src.type != self.dst.type:
raise VerifyException(
f"Expected {self.src.type}, got {self.dst.type}. gpu.memcpy source and "
f"Expected {self.dst.type}, got {self.src.type}. gpu.memcpy source and "
"destination types must match."
)

Expand Down Expand Up @@ -397,7 +397,7 @@ class FuncOp(IRDLOperation):
DenseArrayBase, attr_name="gpu.known_grid_size"
)

traits = frozenset([IsolatedFromAbove(), HasParent(ModuleOp)])
traits = frozenset([IsolatedFromAbove(), HasParent(ModuleOp), SymbolOpInterface()])

def __init__(
self,
Expand Down
213 changes: 213 additions & 0 deletions xdsl/interpreters/experimental/wgpu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
from io import StringIO
from typing import Any, Sequence, cast

import wgpu # pyright: ignore
import wgpu.utils # pyright: ignore

from xdsl.dialects import gpu
from xdsl.dialects.builtin import IndexType
from xdsl.dialects.memref import MemRefType
from xdsl.interpreter import Interpreter, InterpreterFunctions, impl, register_impls
from xdsl.interpreters.experimental.wgsl_printer import WGSLPrinter
from xdsl.interpreters.shaped_array import ShapedArray
from xdsl.ir.core import Attribute, SSAValue
from xdsl.traits import SymbolTable
from xdsl.utils.hints import isa


@register_impls
class WGPUFunctions(InterpreterFunctions):
device = cast(wgpu.GPUDevice, wgpu.utils.get_default_device())
shader_modules: dict[gpu.FuncOp, wgpu.GPUShaderModule] = {}

def buffer_from_operand(self, interpreter: Interpreter, operand: SSAValue):
"""
Prepare a GPUBuffer from an SSA operand.
memrefs are excpected to be GPUBuffers at this point.
Still a helper function, because boilerplating will need to happen to forward
e.g. scalar parameters!
"""
if isa(operand.type, MemRefType[Attribute]):
value = interpreter.get_values((operand,))[0]
if not isinstance(value, wgpu.GPUBuffer):
raise ValueError(
f"gpu.launch_func memref operand expected to be GPU-allocated"
)
return value
raise NotImplementedError(f"{operand.type} not yet mapped to WGPU.")

def prepare_bindings(
self, interpreter: Interpreter, kernel_operands: Sequence[SSAValue]
):
"""
Boilerplate preparation for arguments bindings.
"""
layouts: list[dict[str, Any]] = []
bindings: list[dict[str, Any]] = []
for i, o in enumerate(kernel_operands):
buffer = WGPUFunctions.buffer_from_operand(self, interpreter, o)

layouts.append(
{
"binding": i,
"visibility": wgpu.ShaderStage.COMPUTE, # pyright: ignore
"buffer": {
"type": wgpu.BufferBindingType.storage # pyright: ignore
},
}
)
bindings.append(
{
"binding": i,
"resource": {"buffer": buffer, "offset": 0, "size": buffer.size},
}
)

return layouts, bindings

def compile_func(self, op: gpu.FuncOp):
"""
Compile a gpu.func if not already done.
"""
if op not in self.shader_modules:
wgsl_printer = WGSLPrinter()
wgsl_source = StringIO("")
wgsl_printer.print(op, wgsl_source)
self.shader_modules[op] = cast(
wgpu.GPUShaderModule,
self.device.create_shader_module( # pyright: ignore
code=wgsl_source.getvalue()
), # pyright: ignore
)

@impl(gpu.AllocOp)
def run_alloc(
self, interpreter: Interpreter, op: gpu.AllocOp, args: tuple[Any, ...]
):
"""
Allocate a GPUBuffer according to a gpu.alloc operation, return it as the memref
value.
"""
if args or op.asyncToken:
raise NotImplementedError(
"Only synchronous, known-sized gpu.alloc implemented yet."
)
memref_type = cast(MemRefType[Attribute], op.result.type)
match (memref_type.element_type):
case IndexType():
element_size = 4
case _:
raise NotImplementedError(
f"The element type {memref_type.element_type} for gpu.alloc is not implemented yet."
)
buffer = cast(
wgpu.GPUBuffer,
self.device.create_buffer( # pyright: ignore
size=memref_type.element_count() * element_size,
usage=wgpu.BufferUsage.STORAGE # pyright: ignore
| wgpu.BufferUsage.COPY_SRC, # pyright: ignore
),
)
return (buffer,)

@impl(gpu.MemcpyOp)
def run_memcpy(
self, interpreter: Interpreter, op: gpu.MemcpyOp, args: tuple[Any, ...]
) -> tuple[()]:
"""
Copy buffers according to the gpu.memcpy operation.
Only Device to Host copy is implemented here, to keep the first draft bearable.
"""
src, dst = interpreter.get_values((op.src, op.dst))
if not (isinstance(src, wgpu.GPUBuffer) and isinstance(dst, ShapedArray)):
raise NotImplementedError(
f"Only device to host copy is implemented for now. got {src} to {dst}"
)

# Get device/source view
memview = cast(
memoryview, self.device.queue.read_buffer(src) # pyright: ignore
)
dst_type = cast(MemRefType[Attribute], op.dst.type)
match (dst_type.element_type):
case IndexType():
format = "I"
case _:
raise NotImplementedError(
f"copy for element type {dst_type.element_type} not yet implemented."
)
memview = memview.cast(format, [i.value.data for i in dst_type.shape])
for index in dst.indices():
dst.store(index, memview.__getitem__(index)) # pyright: ignore
return ()

@impl(gpu.LaunchFuncOp)
def run_launch_func(
self, interpreter: Interpreter, op: gpu.LaunchFuncOp, args: tuple[Any, ...]
):
"""
Launch a GPU kernel through the WebGPU API.
"""
if op.asyncToken is not None or op.asyncDependencies:
raise NotImplementedError(
"The WGPU interpreter does not handle asynchronous GPU regions at the moment."
)

gridSize = interpreter.get_values((op.gridSizeX, op.gridSizeY, op.gridSizeZ))
blockSize = interpreter.get_values(
(op.blockSizeX, op.blockSizeY, op.blockSizeZ)
)
kernel_operands = op.kernelOperands

func = SymbolTable.lookup_symbol(op, op.kernel)
assert isinstance(func, gpu.FuncOp)
WGPUFunctions.compile_func(self, func)
shader_module = self.shader_modules[func]

# Compute the dispatch number
# If the func has a known block size, it's reflected in the compiled module
# Otherwise, it defaults to (1,1,1) currently and we have to take this
# into account
if func.known_block_size:
dispatch = gridSize
else:
dispatch = [a * b for a, b in zip(gridSize, blockSize)]

layouts, bindings = WGPUFunctions.prepare_bindings(
self, interpreter, kernel_operands
)

# All the boilerplate
device = self.device
# Put bindings together
bind_group_layout = device.create_bind_group_layout( # pyright: ignore
entries=layouts
)
pipeline_layout = device.create_pipeline_layout( # pyright: ignore
bind_group_layouts=[bind_group_layout]
)
bind_group = device.create_bind_group( # pyright: ignore
layout=bind_group_layout, entries=bindings # pyright: ignore
)

# Create and run the pipeline
compute_pipeline = device.create_compute_pipeline( # pyright: ignore
layout=pipeline_layout, # pyright: ignore
compute={"module": shader_module, "entry_point": func.sym_name.data},
)

command_encoder = device.create_command_encoder() # pyright: ignore
compute_pass = command_encoder.begin_compute_pass() # pyright: ignore
compute_pass.set_pipeline(compute_pipeline) # pyright: ignore
compute_pass.set_bind_group( # pyright: ignore
0, bind_group, [], 0, 0
) # last 2 elements not used
compute_pass.dispatch_workgroups(*dispatch) # x y z # pyright: ignore
compute_pass.end() # pyright: ignore
device.queue.submit([command_encoder.finish()]) # pyright: ignore

# gpu.launch_func has no return
return ()

0 comments on commit 7401a1a

Please sign in to comment.