Skip to content

Commit

Permalink
backend: (JAX) add some syntax sugar for typed compiled functions (#3058
Browse files Browse the repository at this point in the history
)

PR 2/2, this one makes the compilation a bit nicer to use.

---------

Co-authored-by: Markus Böck <[email protected]>
  • Loading branch information
superlopuh and zero9178 authored Aug 20, 2024
1 parent 8fc70cd commit bc1529e
Show file tree
Hide file tree
Showing 2 changed files with 204 additions and 2 deletions.
147 changes: 147 additions & 0 deletions tests/backend/test_jax_executable.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import re

import jax
import pytest

Expand Down Expand Up @@ -34,6 +36,50 @@ def test_abs():
array(2, dtype=jax.numpy.int32)
]

@executable
def abs_tuple(a: jax.Array) -> tuple[jax.Array]: ...

assert abs_tuple(array(-2, dtype=jax.numpy.int32)) == (
array(2, dtype=jax.numpy.int32),
)
assert abs_tuple(array(0, dtype=jax.numpy.int32)) == (
array(0, dtype=jax.numpy.int32),
)
assert abs_tuple(array(2, dtype=jax.numpy.int32)) == (
array(2, dtype=jax.numpy.int32),
)

@executable
def abs_one(a: jax.Array) -> jax.Array: ...

assert abs_one(array(-2, dtype=jax.numpy.int32)) == array(2, dtype=jax.numpy.int32)
assert abs_one(array(0, dtype=jax.numpy.int32)) == array(0, dtype=jax.numpy.int32)
assert abs_one(array(2, dtype=jax.numpy.int32)) == array(2, dtype=jax.numpy.int32)


def test_add_sub():
TI32 = TensorType(i32, ())

main_op = func.FuncOp("main", ((TI32, TI32), (TI32, TI32)))
with ImplicitBuilder(main_op.body) as (arg0, arg1):
res0 = stablehlo.AddOp(arg0, arg1).result
res1 = stablehlo.SubtractOp(arg0, arg1).result
func.Return(res0, res1)

module = ModuleOp([main_op])

executable = JaxExecutable.compile(module)

def a(i: int) -> jax.Array:
return array(i, dtype=jax.numpy.int32)

assert executable.execute([a(-2), a(-3)]) == [a(-5), a(1)]

@executable
def add_sub_tuple(a: jax.Array, b: jax.Array) -> tuple[jax.Array, jax.Array]: ...

assert add_sub_tuple(a(-2), a(-3)) == (a(-5), a(1))


def test_no_main():
with pytest.raises(ValueError, match="No `main` function in module"):
Expand Down Expand Up @@ -65,3 +111,104 @@ class SymNameOp(IRDLOperation):

with pytest.raises(ValueError, match="`main` operation is not a `func.func`"):
JaxExecutable.compile(module)


def test_parameter_count_mismatch():
TI32 = TensorType(i32, ())

main_op = func.FuncOp("main", ((TI32,), (TI32,)))
with ImplicitBuilder(main_op.body) as (arg,):
res = stablehlo.AbsOp(arg).result
func.Return(res)

module = ModuleOp([main_op])
executable = JaxExecutable.compile(module)

with pytest.raises(
ValueError,
match="Number of parameters .* does not match the number of operand types",
):

@executable
def abs_two_params(a: jax.Array, b: jax.Array) -> jax.Array: ... # pyright: ignore[reportUnusedFunction]


def test_parameter_annotation():
TI32 = TensorType(i32, ())

main_op = func.FuncOp("main", ((TI32,), (TI32,)))
with ImplicitBuilder(main_op.body) as (arg,):
res = stablehlo.AbsOp(arg).result
func.Return(res)

module = ModuleOp([main_op])
executable = JaxExecutable.compile(module)

with pytest.raises(
NotImplementedError, match="Parameter .* is not annotated as jnp.ndarray"
):

@executable
def abs_wrong_annotation(a: int) -> jax.Array: ... # pyright: ignore[reportUnusedFunction]


def test_return_annotation_tuple_type():
TI32 = TensorType(i32, ())

main_op = func.FuncOp("main", ((TI32,), (TI32,)))
with ImplicitBuilder(main_op.body) as (arg,):
res = stablehlo.AbsOp(arg).result
func.Return(res)

module = ModuleOp([main_op])
executable = JaxExecutable.compile(module)

with pytest.raises(
NotImplementedError,
match=re.escape(
"Return annotation is must be jnp.ndarray or a tuple of jnp.ndarray, got tuple[int]."
),
):

@executable # pyright: ignore[reportArgumentType, reportGeneralTypeIssues]
def abs_wrong_tuple_type(a: jax.Array) -> tuple[int]: ... # pyright: ignore[reportUnusedFunction]


def test_return_annotation_single():
TI32 = TensorType(i32, ())

main_op = func.FuncOp("main", ((TI32,), (TI32,)))
with ImplicitBuilder(main_op.body) as (arg,):
res = stablehlo.AbsOp(arg).result
func.Return(res)

module = ModuleOp([main_op])
executable = JaxExecutable.compile(module)

with pytest.raises(
NotImplementedError,
match="Return annotation is must be jnp.ndarray or a tuple of jnp.ndarray",
):

@executable # pyright: ignore[reportArgumentType, reportGeneralTypeIssues]
def abs_wrong_single_type(a: jax.Array) -> int: ... # pyright: ignore[reportUnusedFunction]


def test_return_value_count_mismatch():
TI32 = TensorType(i32, ())

main_op = func.FuncOp("main", ((TI32,), (TI32, TI32)))
with ImplicitBuilder(main_op.body) as (arg,):
res = stablehlo.AbsOp(arg).result
func.Return(res, res)

module = ModuleOp([main_op])
executable = JaxExecutable.compile(module)

with pytest.raises(
ValueError,
match="Number of return values .* does not match the stub's return annotation",
):

@executable
def abs_return_count_mismatch(a: jax.Array) -> jax.Array: ... # pyright: ignore[reportUnusedFunction]
59 changes: 57 additions & 2 deletions xdsl/backend/jax_executable.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from collections.abc import Sequence
from typing import Any
from collections.abc import Callable, Sequence
from inspect import signature
from typing import Any, ParamSpec, TypeVar, cast, get_args, get_origin

import jax.numpy as jnp
import numpy as np
Expand All @@ -14,6 +15,9 @@
from xdsl.dialects.func import FuncOp
from xdsl.traits import SymbolTable

P = ParamSpec("P")
R = TypeVar("R", bound=tuple[jnp.ndarray, ...] | jnp.ndarray)

# JAX DTypeLike is currently broken
# The np.dtype annotation in jax does not specify the generic parameter
DTypeLike = (
Expand Down Expand Up @@ -85,3 +89,54 @@ def compile(module: ModuleOp) -> "JaxExecutable":
client = xla_bridge.backends()["cpu"]
loaded = client.compile(bytecode)
return JaxExecutable(func_op.function_type, loaded)

def __call__(self, stub: Callable[P, R]) -> Callable[P, R]:
func_type = self.main_type
loaded = self.loaded_executable

operand_types = func_type.inputs.data
result_types = func_type.outputs.data

sig = signature(stub)

if len(sig.parameters) != len(operand_types):
raise ValueError(
f"Number of parameters ({len(sig.parameters)}) does not match the number of operand types ({len(operand_types)})"
)

# Check that all parameters are annotated as jnp.ndarray
for param in sig.parameters.values():
if param.annotation != jnp.ndarray:
raise NotImplementedError(
f"Parameter {param.name} is not annotated as jnp.ndarray"
)

# Check return annotation
sig_return = sig.return_annotation
sig_return_origin = get_origin(sig_return)

if sig_return_origin is not tuple:
if sig.return_annotation is not jnp.ndarray:
raise NotImplementedError(
f"Return annotation is must be jnp.ndarray or a tuple of jnp.ndarray, got {sig_return}."
)
if len(result_types) != 1:
raise ValueError(
f"Number of return values ({len(result_types)}) does not match the stub's return annotation"
)

def func(*args: P.args, **kwargs: P.kwargs) -> R:
result = loaded.execute(args)
return result[0]
else:
return_args = get_args(sig_return)
if not all(return_arg is jnp.ndarray for return_arg in return_args):
raise NotImplementedError(
f"Return annotation is must be jnp.ndarray or a tuple of jnp.ndarray, got {sig_return}."
)

def func(*args: P.args, **kwargs: P.kwargs) -> R:
result = loaded.execute(args)
return cast(R, tuple(result))

return func

0 comments on commit bc1529e

Please sign in to comment.