Skip to content

Commit

Permalink
fix: pass immutable arg_id_to_dtype to InKernelCallables
Browse files (browse the repository at this point in the history)
  • Loading branch information
alexfikl authored and inducer committed Jan 31, 2025
1 parent 771c0bc commit 2f4a982
Show file tree
Hide file tree
Showing 10 changed files with 122 additions and 82 deletions.
8 changes: 5 additions & 3 deletions examples/python/call-external.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
from constantdict import constantdict

import loopy as lp
from loopy.diagnostic import LoopyError
Expand Down Expand Up @@ -30,9 +31,10 @@ def with_types(self, arg_id_to_dtype, callables_table):
"types")

return (self.copy(name_in_target=name_in_target,
arg_id_to_dtype={0: vec_dtype,
1: vec_dtype,
-1: vec_dtype}),
arg_id_to_dtype=constantdict({
0: vec_dtype,
1: vec_dtype,
-1: vec_dtype})),
callables_table)

def with_descrs(self, arg_id_to_descr, callables_table):
Expand Down
23 changes: 13 additions & 10 deletions loopy/kernel/function_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,9 +565,10 @@ def with_types(self, arg_id_to_dtype, callables_table):
"the function %s." % (self.name))

def with_descrs(self, arg_id_to_descr, clbl_inf_ctx):
    """Return a copy of this callable carrying argument descriptors.

    The entry for id ``-1`` is always set to a ``ValueArgDescriptor``.
    The mapping passed in by the caller is never modified: it is copied
    into a ``constantdict``, updated through that copy's mutation proxy,
    and frozen again before being stored on the returned callable.

    :arg arg_id_to_descr: mapping from argument id to descriptor.
    :arg clbl_inf_ctx: passed through unchanged.
    :returns: a tuple ``(new_callable, clbl_inf_ctx)``.
    """
    # Work on a private copy so the caller's (possibly immutable) mapping
    # is left untouched.
    new_arg_id_to_descr = constantdict(arg_id_to_descr).mutate()
    new_arg_id_to_descr[-1] = ValueArgDescriptor()

    return (self.copy(arg_id_to_descr=new_arg_id_to_descr.finish()),
            clbl_inf_ctx)

def get_hw_axes_sizes(self, arg_id_to_arg, space, callables_table):
Expand Down Expand Up @@ -782,14 +783,15 @@ def with_descrs(self, arg_id_to_descr, clbl_inf_ctx):
# arg_id_to_descr expressions provided are from the caller's namespace,
# need to register

new_arg_id_to_descr = constantdict(arg_id_to_descr).mutate()
kw_to_pos, pos_to_kw = get_kw_pos_association(self.subkernel)

kw_to_callee_idx = {arg.name: i
for i, arg in enumerate(self.subkernel.args)}

new_args = self.subkernel.args[:]

for arg_id, descr in arg_id_to_descr.items():
for arg_id, descr in new_arg_id_to_descr.items():
if isinstance(arg_id, int):
arg_id = pos_to_kw[arg_id]

Expand Down Expand Up @@ -837,20 +839,20 @@ def with_descrs(self, arg_id_to_descr, clbl_inf_ctx):
for arg in subkernel.args:
kw = arg.name
if isinstance(arg, ArrayBase):
arg_id_to_descr[kw] = (
new_arg_id_to_descr[kw] = (
ArrayArgDescriptor(shape=arg.shape,
dim_tags=arg.dim_tags,
address_space=arg.address_space))
else:
assert isinstance(arg, ValueArg)
arg_id_to_descr[kw] = ValueArgDescriptor()
new_arg_id_to_descr[kw] = ValueArgDescriptor()

arg_id_to_descr[kw_to_pos[kw]] = arg_id_to_descr[kw]
new_arg_id_to_descr[kw_to_pos[kw]] = new_arg_id_to_descr[kw]

# }}}

return (self.copy(subkernel=subkernel,
arg_id_to_descr=constantdict(arg_id_to_descr)),
arg_id_to_descr=new_arg_id_to_descr.finish()),
clbl_inf_ctx)

def with_added_arg(self, arg_dtype, arg_descr):
Expand All @@ -868,6 +870,7 @@ def with_added_arg(self, arg_dtype, arg_descr):
arg_id_to_dtype = {}
else:
arg_id_to_dtype = dict(self.arg_id_to_dtype)

if self.arg_id_to_descr is None:
arg_id_to_descr = {}
else:
Expand All @@ -879,8 +882,8 @@ def with_added_arg(self, arg_dtype, arg_descr):
arg_id_to_descr[kw_to_pos[var_name]] = arg_descr

return (self.copy(subkernel=subknl,
arg_id_to_dtype=arg_id_to_dtype,
arg_id_to_descr=arg_id_to_descr),
arg_id_to_dtype=constantdict(arg_id_to_dtype),
arg_id_to_descr=constantdict(arg_id_to_descr)),
var_name)

else:
Expand All @@ -902,7 +905,7 @@ def with_packing_for_args(self):
address_space=AddressSpace.GLOBAL)

return self.copy(subkernel=self.subkernel,
arg_id_to_descr=arg_id_to_descr)
arg_id_to_descr=constantdict(arg_id_to_descr))

def get_used_hw_axes(self, callables_table):
gsize, lsize = self.subkernel.get_grid_size_upper_bounds(callables_table,
Expand Down
14 changes: 8 additions & 6 deletions loopy/library/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from typing import TYPE_CHECKING

import numpy as np
from constantdict import constantdict

from loopy.diagnostic import LoopyError
from loopy.kernel.function_interface import ScalarCallable
Expand All @@ -38,21 +39,22 @@

class MakeTupleCallable(ScalarCallable):
    """Callable whose target name is ``loopy_make_tuple``.

    Both specialization methods return a fresh copy of the callable holding
    an immutable ``constantdict``; neither modifies the mapping it is given.
    """

    def with_types(self, arg_id_to_dtype, callables_table):
        """Return a copy specialized to the given argument dtypes.

        For every id ``i`` in ``range(len(arg_id_to_dtype))`` that is present
        with a dtype other than *None*, that dtype is also stored under id
        ``-i-1``.

        :returns: a tuple ``(new_callable, callables_table)``; the table is
            passed through unchanged.
        """
        new_arg_id_to_dtype = constantdict(arg_id_to_dtype).mutate()
        for i in range(len(arg_id_to_dtype)):
            # Ids that are missing or still untyped are skipped.
            if i in arg_id_to_dtype and arg_id_to_dtype[i] is not None:
                new_arg_id_to_dtype[-i-1] = new_arg_id_to_dtype[i]

        return (self.copy(arg_id_to_dtype=new_arg_id_to_dtype.finish(),
                          name_in_target="loopy_make_tuple"),
                callables_table)

    def with_descrs(self, arg_id_to_descr, callables_table):
        """Return a copy carrying descriptors derived from *arg_id_to_descr*.

        Only the keys of *arg_id_to_descr* are used; the incoming descriptor
        values are ignored.

        :returns: a tuple ``(new_callable, callables_table)``; the table is
            passed through unchanged.
        """
        from loopy.kernel.function_interface import ValueArgDescriptor

        # NOTE(review): this builds a dict whose keys AND values are
        # ``(id, ValueArgDescriptor())`` tuples rather than a flat
        # ``id -> descriptor`` mapping. It looks like the intent was
        # ``{arg_id: ValueArgDescriptor(), -arg_id-1: ValueArgDescriptor()}``
        # -- confirm against consumers of ``arg_id_to_descr`` before changing.
        new_arg_id_to_descr = {
            (arg_id, ValueArgDescriptor()): (-arg_id-1, ValueArgDescriptor())
            for arg_id in arg_id_to_descr}

        return (
                self.copy(arg_id_to_descr=constantdict(new_arg_id_to_descr)),
                callables_table)


Expand All @@ -63,7 +65,7 @@ def with_types(self, arg_id_to_dtype, callables_table):
if dtype is not None}
new_arg_id_to_dtype[-1] = NumpyType(np.int32)

return (self.copy(arg_id_to_dtype=new_arg_id_to_dtype),
return (self.copy(arg_id_to_dtype=constantdict(new_arg_id_to_dtype)),
callables_table)

def emit_call(self, expression_to_code_mapper, expression, target):
Expand Down
19 changes: 12 additions & 7 deletions loopy/library/random123.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from typing import TYPE_CHECKING

import numpy as np
from constantdict import constantdict
from mako.template import Template

from pymbolic.typing import not_none
Expand Down Expand Up @@ -221,27 +222,31 @@ def with_types(self, arg_id_to_dtype, callables_table):
new_arg_id_to_dtype = {-1: ctr_dtype, -2: ctr_dtype, 0: ctr_dtype, 1:
key_dtype}
return (
self.copy(arg_id_to_dtype=new_arg_id_to_dtype,
name_in_target=fn+"_gen"),
self.copy(arg_id_to_dtype=constantdict(new_arg_id_to_dtype),
name_in_target=fn+"_gen"),
callables_table)

elif name == fn + "_f32":
new_arg_id_to_dtype = {-1: target.vector_dtype(NumpyType(np.float32),
rng_variant.width),
-2: ctr_dtype, 0: ctr_dtype, 1:
key_dtype}
return self.copy(arg_id_to_dtype=new_arg_id_to_dtype,
name_in_target=name), callables_table
return (
self.copy(arg_id_to_dtype=constantdict(new_arg_id_to_dtype),
name_in_target=name),
callables_table)

elif name == fn + "_f64":
new_arg_id_to_dtype = {-1: target.vector_dtype(NumpyType(np.float64),
rng_variant.width),
-2: ctr_dtype, 0: ctr_dtype, 1:
key_dtype}
return self.copy(arg_id_to_dtype=new_arg_id_to_dtype,
name_in_target=name), callables_table
return (
self.copy(arg_id_to_dtype=constantdict(new_arg_id_to_dtype),
name_in_target=name),
callables_table)

return (self.copy(arg_id_to_dtype=arg_id_to_dtype),
return (self.copy(arg_id_to_dtype=constantdict(arg_id_to_dtype)),
callables_table)

def generate_preambles(self, target):
Expand Down
15 changes: 10 additions & 5 deletions loopy/library/reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from typing import TYPE_CHECKING

import numpy as np
from constantdict import constantdict

from pymbolic import var
from pymbolic.primitives import expr_dataclass
Expand Down Expand Up @@ -580,21 +581,25 @@ def with_types(self, arg_id_to_dtype, callables_table):
index_dtype = arg_id_to_dtype[1]
result_dtypes = self.name.reduction_op.result_dtypes(scalar_dtype, # pylint: disable=no-member
index_dtype)
new_arg_id_to_dtype = arg_id_to_dtype.copy()

new_arg_id_to_dtype = constantdict(arg_id_to_dtype).mutate()
new_arg_id_to_dtype[-1] = result_dtypes[0]
new_arg_id_to_dtype[-2] = result_dtypes[1]
name_in_target = self.name.reduction_op.prefix(scalar_dtype, # pylint: disable=no-member
index_dtype) + "_op"

return self.copy(arg_id_to_dtype=new_arg_id_to_dtype,
name_in_target=name_in_target), callables_table
return (self.copy(arg_id_to_dtype=new_arg_id_to_dtype.finish(),
name_in_target=name_in_target),
callables_table)

def with_descrs(self, arg_id_to_descr, callables_table):
    """Return a copy of this callable carrying argument descriptors.

    The entry for id ``-1`` is set to a ``ValueArgDescriptor`` on a private
    copy of *arg_id_to_descr*, and that updated copy (not the mapping passed
    in) is what gets stored on the returned callable.

    :arg arg_id_to_descr: mapping from argument id to descriptor; it is not
        modified.
    :arg callables_table: passed through unchanged.
    :returns: a tuple ``(new_callable, callables_table)``.
    """
    from loopy.kernel.function_interface import ValueArgDescriptor

    new_arg_id_to_descr = constantdict(arg_id_to_descr).mutate()
    new_arg_id_to_descr[-1] = ValueArgDescriptor()

    # Store the updated mapping; passing the original ``arg_id_to_descr``
    # here would silently drop the ``-1`` entry set above.
    return (
            self.copy(arg_id_to_descr=new_arg_id_to_descr.finish()),
            callables_table)


Expand Down
35 changes: 19 additions & 16 deletions loopy/target/c/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from typing import TYPE_CHECKING, Any, Sequence, cast

import numpy as np
from constantdict import constantdict

import pymbolic.primitives as p
from cgen import (
Expand Down Expand Up @@ -530,7 +531,7 @@ def with_types(self, arg_id_to_dtype, callables_table):
# the types provided aren't mature enough to specialize the
# callable
return (
self.copy(arg_id_to_dtype=arg_id_to_dtype),
self.copy(arg_id_to_dtype=constantdict(arg_id_to_dtype)),
callables_table)

dtype = arg_id_to_dtype[0].numpy_dtype
Expand Down Expand Up @@ -563,9 +564,9 @@ def with_types(self, arg_id_to_dtype, callables_table):

return (
self.copy(name_in_target=name,
arg_id_to_dtype={
arg_id_to_dtype=constantdict({
0: NumpyType(dtype),
-1: NumpyType(result_dtype)}),
-1: NumpyType(result_dtype)})),
callables_table)

# binary functions
Expand All @@ -580,7 +581,7 @@ def with_types(self, arg_id_to_dtype, callables_table):
# the types provided aren't mature enough to specialize the
# callable
return (
self.copy(arg_id_to_dtype=arg_id_to_dtype),
self.copy(arg_id_to_dtype=constantdict(arg_id_to_dtype)),
callables_table)

dtype = np.result_type(*[
Expand All @@ -607,7 +608,7 @@ def with_types(self, arg_id_to_dtype, callables_table):
dtype = NumpyType(dtype)
return (
self.copy(name_in_target=name,
arg_id_to_dtype={-1: dtype, 0: dtype, 1: dtype}),
arg_id_to_dtype=constantdict({-1: dtype, 0: dtype, 1: dtype})),
callables_table)
elif name in ["max", "min"]:

Expand All @@ -620,7 +621,7 @@ def with_types(self, arg_id_to_dtype, callables_table):
# the types provided aren't resolved enough to specialize the
# callable
return (
self.copy(arg_id_to_dtype=arg_id_to_dtype),
self.copy(arg_id_to_dtype=constantdict(arg_id_to_dtype)),
callables_table)

dtype = np.result_type(*[
Expand All @@ -632,9 +633,10 @@ def with_types(self, arg_id_to_dtype, callables_table):

return (
self.copy(name_in_target=f"lpy_{name}_{dtype.name}",
arg_id_to_dtype={-1: NumpyType(dtype),
0: NumpyType(dtype),
1: NumpyType(dtype)}),
arg_id_to_dtype=constantdict({
-1: NumpyType(dtype),
0: NumpyType(dtype),
1: NumpyType(dtype)})),
callables_table)
elif name == "isnan":
for id in arg_id_to_dtype:
Expand All @@ -645,7 +647,7 @@ def with_types(self, arg_id_to_dtype, callables_table):
# the types provided aren't mature enough to specialize the
# callable
return (
self.copy(arg_id_to_dtype=arg_id_to_dtype),
self.copy(arg_id_to_dtype=constantdict(arg_id_to_dtype)),
callables_table)

dtype = arg_id_to_dtype[0].numpy_dtype
Expand All @@ -662,9 +664,9 @@ def with_types(self, arg_id_to_dtype, callables_table):
return (
self.copy(
name_in_target=name,
arg_id_to_dtype={
arg_id_to_dtype=constantdict({
0: NumpyType(dtype),
-1: NumpyType(np.int32)}),
-1: NumpyType(np.int32)})),
callables_table)

def generate_preambles(self, target):
Expand Down Expand Up @@ -713,7 +715,7 @@ def with_types(self, arg_id_to_dtype, callables_table):
# the types provided aren't mature enough to specialize the
# callable
return (
self.copy(arg_id_to_dtype=arg_id_to_dtype),
self.copy(arg_id_to_dtype=constantdict(arg_id_to_dtype)),
callables_table)

if not arg_id_to_dtype[0].is_integral():
Expand All @@ -738,9 +740,10 @@ def with_types(self, arg_id_to_dtype, callables_table):

return (
self.copy(name_in_target=name_in_target,
arg_id_to_dtype={-1: arg_id_to_dtype[1],
0: NumpyType(np.int32),
1: arg_id_to_dtype[1]}),
arg_id_to_dtype=constantdict({
-1: arg_id_to_dtype[1],
0: NumpyType(np.int32),
1: arg_id_to_dtype[1]})),
callables_table)
else:
raise NotImplementedError(f"with_types for '{name}'")
Expand Down
Loading

0 comments on commit 2f4a982

Please sign in to comment.