-
Notifications
You must be signed in to change notification settings - Fork 80
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
interpreter: Initial WGPU interpreter. (#1346)
Here's a first draft/example for a WebGPU-based interpreter for the GPU dialect, including JIT-compilation through WGSL. I only implement very specific cases of `gpu.func`, `gpu.launch_func`, `gpu.alloc` and `gpu.memcpy` here. This requires the addition of Vulkan drivers in the CI --------- Co-authored-by: Ka Wing, Li <[email protected]> Co-authored-by: Sasha Lopoukhine <[email protected]>
- Loading branch information
1 parent
0f9f42e
commit 7401a1a
Showing
6 changed files
with
315 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
from io import StringIO | ||
|
||
import pytest | ||
|
||
from xdsl.dialects import arith, builtin, func, gpu, memref, printf | ||
from xdsl.interpreter import Interpreter | ||
from xdsl.interpreters.arith import ArithFunctions | ||
from xdsl.interpreters.memref import MemrefFunctions | ||
from xdsl.interpreters.printf import PrintfFunctions | ||
from xdsl.ir import MLContext | ||
from xdsl.parser import Parser | ||
|
||
pytest.importorskip("wgpu", reason="wgpu is an optional dependency") | ||
|
||
from xdsl.interpreters.experimental.wgpu import WGPUFunctions # noqa: E402 | ||
|
||
|
||
def test_init(): | ||
mlir_source = """ | ||
builtin.module attributes {gpu.container_module} { | ||
"gpu.module"() ({ | ||
"gpu.func"() ({ | ||
^0(%arg : memref<4x4xindex>): | ||
%0 = "arith.constant"() {"value" = 2 : index} : () -> index | ||
%1 = "gpu.global_id"() {"dimension" = #gpu<dim x>} : () -> index | ||
%2 = "gpu.global_id"() {"dimension" = #gpu<dim y>} : () -> index | ||
%3 = "arith.constant"() {"value" = 4 : index} : () -> index | ||
%4 = "arith.muli"(%1, %3) : (index, index) -> index | ||
%5 = "arith.addi"(%4, %2) : (index, index) -> index | ||
"memref.store"(%5, %arg, %1, %2) {"nontemporal" = false} : (index, memref<4x4xindex>, index, index) -> () | ||
"gpu.return"() : () -> () | ||
}) {"function_type" = (memref<4x4xindex>) -> (), | ||
"gpu.kernel", | ||
"sym_name" = "fill" | ||
} : () -> () | ||
"gpu.func"() ({ | ||
^0(%arg : memref<4x4xindex>): | ||
%0 = "arith.constant"() {"value" = 1 : index} : () -> index | ||
%1 = "gpu.global_id"() {"dimension" = #gpu<dim x>} : () -> index | ||
%2 = "gpu.global_id"() {"dimension" = #gpu<dim y>} : () -> index | ||
%3 = "memref.load"(%arg, %1, %2) {"nontemporal" = false} : (memref<4x4xindex>, index, index) -> (index) | ||
%4 = "arith.addi"(%3, %0) : (index, index) -> index | ||
"memref.store"(%4, %arg, %1, %2) {"nontemporal" = false} : (index, memref<4x4xindex>, index, index) -> () | ||
"gpu.return"() : () -> () | ||
}) {"function_type" = (memref<4x4xindex>) -> (), | ||
"gpu.kernel", | ||
"sym_name" = "inc" | ||
} : () -> () | ||
"gpu.module_end"() : () -> () | ||
}) {"sym_name" = "gpu"} : () -> () | ||
func.func @main() -> index { | ||
%four = "arith.constant"() {"value" = 4 : index} : () -> index | ||
%one = "arith.constant"() {"value" = 1 : index} : () -> index | ||
%memref = "gpu.alloc"() {"alignment" = 0 : i64, "operand_segment_sizes" = array<i32: 0, 0>} : () -> memref<4x4xindex> | ||
"gpu.launch_func"(%four, %four, %one, %one, %one, %one, %memref) {"operand_segment_sizes" = array<i32: 0, 1, 1, 1, 1, 1, 1, 0, 1>, "kernel" = @gpu::@fill} : (index, index, index, index, index, index, memref<4x4xindex>) -> () | ||
"gpu.launch_func"(%four, %four, %one, %one, %one, %one, %memref) {"operand_segment_sizes" = array<i32: 0, 1, 1, 1, 1, 1, 1, 0, 1>, "kernel" = @gpu::@inc} : (index, index, index, index, index, index, memref<4x4xindex>) -> () | ||
%hmemref = "memref.alloc"() {"alignment" = 0 : i64, "operand_segment_sizes" = array<i32: 0, 0>} : () -> memref<4x4xindex> | ||
"gpu.memcpy"(%hmemref, %memref) {"operand_segment_sizes" = array<i32: 0, 1, 1>} : (memref<4x4xindex>, memref<4x4xindex>) -> () | ||
printf.print_format "Result : {}", %hmemref : memref<4x4xindex> | ||
} | ||
} | ||
""" | ||
context = MLContext() | ||
context.register_dialect(arith.Arith) | ||
context.register_dialect(memref.MemRef) | ||
context.register_dialect(builtin.Builtin) | ||
context.register_dialect(gpu.GPU) | ||
context.register_dialect(func.Func) | ||
context.register_dialect(printf.Printf) | ||
parser = Parser(context, mlir_source) | ||
module = parser.parse_module() | ||
|
||
f = StringIO("") | ||
interpreter = Interpreter(module, file=f) | ||
interpreter.register_implementations(ArithFunctions()) | ||
interpreter.register_implementations(MemrefFunctions()) | ||
interpreter.register_implementations(WGPUFunctions()) | ||
interpreter.register_implementations(PrintfFunctions()) | ||
interpreter.call_op("main", ()) | ||
assert ( | ||
"Result : [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]" | ||
in f.getvalue() | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,213 @@ | ||
from io import StringIO | ||
from typing import Any, Sequence, cast | ||
|
||
import wgpu # pyright: ignore | ||
import wgpu.utils # pyright: ignore | ||
|
||
from xdsl.dialects import gpu | ||
from xdsl.dialects.builtin import IndexType | ||
from xdsl.dialects.memref import MemRefType | ||
from xdsl.interpreter import Interpreter, InterpreterFunctions, impl, register_impls | ||
from xdsl.interpreters.experimental.wgsl_printer import WGSLPrinter | ||
from xdsl.interpreters.shaped_array import ShapedArray | ||
from xdsl.ir.core import Attribute, SSAValue | ||
from xdsl.traits import SymbolTable | ||
from xdsl.utils.hints import isa | ||
|
||
|
||
@register_impls | ||
class WGPUFunctions(InterpreterFunctions): | ||
device = cast(wgpu.GPUDevice, wgpu.utils.get_default_device()) | ||
shader_modules: dict[gpu.FuncOp, wgpu.GPUShaderModule] = {} | ||
|
||
def buffer_from_operand(self, interpreter: Interpreter, operand: SSAValue): | ||
""" | ||
Prepare a GPUBuffer from an SSA operand. | ||
memrefs are excpected to be GPUBuffers at this point. | ||
Still a helper function, because boilerplating will need to happen to forward | ||
e.g. scalar parameters! | ||
""" | ||
if isa(operand.type, MemRefType[Attribute]): | ||
value = interpreter.get_values((operand,))[0] | ||
if not isinstance(value, wgpu.GPUBuffer): | ||
raise ValueError( | ||
f"gpu.launch_func memref operand expected to be GPU-allocated" | ||
) | ||
return value | ||
raise NotImplementedError(f"{operand.type} not yet mapped to WGPU.") | ||
|
||
def prepare_bindings( | ||
self, interpreter: Interpreter, kernel_operands: Sequence[SSAValue] | ||
): | ||
""" | ||
Boilerplate preparation for arguments bindings. | ||
""" | ||
layouts: list[dict[str, Any]] = [] | ||
bindings: list[dict[str, Any]] = [] | ||
for i, o in enumerate(kernel_operands): | ||
buffer = WGPUFunctions.buffer_from_operand(self, interpreter, o) | ||
|
||
layouts.append( | ||
{ | ||
"binding": i, | ||
"visibility": wgpu.ShaderStage.COMPUTE, # pyright: ignore | ||
"buffer": { | ||
"type": wgpu.BufferBindingType.storage # pyright: ignore | ||
}, | ||
} | ||
) | ||
bindings.append( | ||
{ | ||
"binding": i, | ||
"resource": {"buffer": buffer, "offset": 0, "size": buffer.size}, | ||
} | ||
) | ||
|
||
return layouts, bindings | ||
|
||
def compile_func(self, op: gpu.FuncOp): | ||
""" | ||
Compile a gpu.func if not already done. | ||
""" | ||
if op not in self.shader_modules: | ||
wgsl_printer = WGSLPrinter() | ||
wgsl_source = StringIO("") | ||
wgsl_printer.print(op, wgsl_source) | ||
self.shader_modules[op] = cast( | ||
wgpu.GPUShaderModule, | ||
self.device.create_shader_module( # pyright: ignore | ||
code=wgsl_source.getvalue() | ||
), # pyright: ignore | ||
) | ||
|
||
@impl(gpu.AllocOp) | ||
def run_alloc( | ||
self, interpreter: Interpreter, op: gpu.AllocOp, args: tuple[Any, ...] | ||
): | ||
""" | ||
Allocate a GPUBuffer according to a gpu.alloc operation, return it as the memref | ||
value. | ||
""" | ||
if args or op.asyncToken: | ||
raise NotImplementedError( | ||
"Only synchronous, known-sized gpu.alloc implemented yet." | ||
) | ||
memref_type = cast(MemRefType[Attribute], op.result.type) | ||
match (memref_type.element_type): | ||
case IndexType(): | ||
element_size = 4 | ||
case _: | ||
raise NotImplementedError( | ||
f"The element type {memref_type.element_type} for gpu.alloc is not implemented yet." | ||
) | ||
buffer = cast( | ||
wgpu.GPUBuffer, | ||
self.device.create_buffer( # pyright: ignore | ||
size=memref_type.element_count() * element_size, | ||
usage=wgpu.BufferUsage.STORAGE # pyright: ignore | ||
| wgpu.BufferUsage.COPY_SRC, # pyright: ignore | ||
), | ||
) | ||
return (buffer,) | ||
|
||
@impl(gpu.MemcpyOp) | ||
def run_memcpy( | ||
self, interpreter: Interpreter, op: gpu.MemcpyOp, args: tuple[Any, ...] | ||
) -> tuple[()]: | ||
""" | ||
Copy buffers according to the gpu.memcpy operation. | ||
Only Device to Host copy is implemented here, to keep the first draft bearable. | ||
""" | ||
src, dst = interpreter.get_values((op.src, op.dst)) | ||
if not (isinstance(src, wgpu.GPUBuffer) and isinstance(dst, ShapedArray)): | ||
raise NotImplementedError( | ||
f"Only device to host copy is implemented for now. got {src} to {dst}" | ||
) | ||
|
||
# Get device/source view | ||
memview = cast( | ||
memoryview, self.device.queue.read_buffer(src) # pyright: ignore | ||
) | ||
dst_type = cast(MemRefType[Attribute], op.dst.type) | ||
match (dst_type.element_type): | ||
case IndexType(): | ||
format = "I" | ||
case _: | ||
raise NotImplementedError( | ||
f"copy for element type {dst_type.element_type} not yet implemented." | ||
) | ||
memview = memview.cast(format, [i.value.data for i in dst_type.shape]) | ||
for index in dst.indices(): | ||
dst.store(index, memview.__getitem__(index)) # pyright: ignore | ||
return () | ||
|
||
@impl(gpu.LaunchFuncOp) | ||
def run_launch_func( | ||
self, interpreter: Interpreter, op: gpu.LaunchFuncOp, args: tuple[Any, ...] | ||
): | ||
""" | ||
Launch a GPU kernel through the WebGPU API. | ||
""" | ||
if op.asyncToken is not None or op.asyncDependencies: | ||
raise NotImplementedError( | ||
"The WGPU interpreter does not handle asynchronous GPU regions at the moment." | ||
) | ||
|
||
gridSize = interpreter.get_values((op.gridSizeX, op.gridSizeY, op.gridSizeZ)) | ||
blockSize = interpreter.get_values( | ||
(op.blockSizeX, op.blockSizeY, op.blockSizeZ) | ||
) | ||
kernel_operands = op.kernelOperands | ||
|
||
func = SymbolTable.lookup_symbol(op, op.kernel) | ||
assert isinstance(func, gpu.FuncOp) | ||
WGPUFunctions.compile_func(self, func) | ||
shader_module = self.shader_modules[func] | ||
|
||
# Compute the dispatch number | ||
# If the func has a known block size, it's reflected in the compiled module | ||
# Otherwise, it defaults to (1,1,1) currently and we have to take this | ||
# into account | ||
if func.known_block_size: | ||
dispatch = gridSize | ||
else: | ||
dispatch = [a * b for a, b in zip(gridSize, blockSize)] | ||
|
||
layouts, bindings = WGPUFunctions.prepare_bindings( | ||
self, interpreter, kernel_operands | ||
) | ||
|
||
# All the boilerplate | ||
device = self.device | ||
# Put bindings together | ||
bind_group_layout = device.create_bind_group_layout( # pyright: ignore | ||
entries=layouts | ||
) | ||
pipeline_layout = device.create_pipeline_layout( # pyright: ignore | ||
bind_group_layouts=[bind_group_layout] | ||
) | ||
bind_group = device.create_bind_group( # pyright: ignore | ||
layout=bind_group_layout, entries=bindings # pyright: ignore | ||
) | ||
|
||
# Create and run the pipeline | ||
compute_pipeline = device.create_compute_pipeline( # pyright: ignore | ||
layout=pipeline_layout, # pyright: ignore | ||
compute={"module": shader_module, "entry_point": func.sym_name.data}, | ||
) | ||
|
||
command_encoder = device.create_command_encoder() # pyright: ignore | ||
compute_pass = command_encoder.begin_compute_pass() # pyright: ignore | ||
compute_pass.set_pipeline(compute_pipeline) # pyright: ignore | ||
compute_pass.set_bind_group( # pyright: ignore | ||
0, bind_group, [], 0, 0 | ||
) # last 2 elements not used | ||
compute_pass.dispatch_workgroups(*dispatch) # x y z # pyright: ignore | ||
compute_pass.end() # pyright: ignore | ||
device.queue.submit([command_encoder.finish()]) # pyright: ignore | ||
|
||
# gpu.launch_func has no return | ||
return () |