From 2f4a982b819510d164b1f90cc8e8a69e987116d7 Mon Sep 17 00:00:00 2001 From: Alexandru Fikl Date: Wed, 29 Jan 2025 15:51:22 +0200 Subject: [PATCH] fix: pass immutable arg_id_to_dtype to InKernelCallables --- examples/python/call-external.py | 8 ++-- loopy/kernel/function_interface.py | 23 ++++++----- loopy/library/function.py | 14 ++++--- loopy/library/random123.py | 19 ++++++---- loopy/library/reduction.py | 15 +++++--- loopy/target/c/__init__.py | 35 +++++++++-------- loopy/target/opencl.py | 61 ++++++++++++++++++------------ loopy/target/pyopencl.py | 13 ++++--- test/library_for_test.py | 7 ++-- test/testlib.py | 9 +++-- 10 files changed, 122 insertions(+), 82 deletions(-) diff --git a/examples/python/call-external.py b/examples/python/call-external.py index ad5615c7b..a2c19b855 100644 --- a/examples/python/call-external.py +++ b/examples/python/call-external.py @@ -1,4 +1,5 @@ import numpy as np +from constantdict import constantdict import loopy as lp from loopy.diagnostic import LoopyError @@ -30,9 +31,10 @@ def with_types(self, arg_id_to_dtype, callables_table): "types") return (self.copy(name_in_target=name_in_target, - arg_id_to_dtype={0: vec_dtype, - 1: vec_dtype, - -1: vec_dtype}), + arg_id_to_dtype=constantdict({ + 0: vec_dtype, + 1: vec_dtype, + -1: vec_dtype})), callables_table) def with_descrs(self, arg_id_to_descr, callables_table): diff --git a/loopy/kernel/function_interface.py b/loopy/kernel/function_interface.py index 799e5d91b..c0e67ffcc 100644 --- a/loopy/kernel/function_interface.py +++ b/loopy/kernel/function_interface.py @@ -565,9 +565,10 @@ def with_types(self, arg_id_to_dtype, callables_table): "the function %s." % (self.name)) def with_descrs(self, arg_id_to_descr, clbl_inf_ctx): + new_arg_id_to_descr = constantdict(arg_id_to_descr).mutate() + new_arg_id_to_descr[-1] = ValueArgDescriptor() - arg_id_to_descr[-1] = ValueArgDescriptor() - return (self.copy(arg_id_to_descr=arg_id_to_descr), + return (self.copy(arg_id_to_descr=new_arg_id_to_descr.finish()), clbl_inf_ctx) def get_hw_axes_sizes(self, arg_id_to_arg, space, callables_table): @@ -782,6 +783,7 @@ def with_descrs(self, arg_id_to_descr, clbl_inf_ctx): # arg_id_to_descr expressions provided are from the caller's namespace, # need to register + new_arg_id_to_descr = constantdict(arg_id_to_descr).mutate() kw_to_pos, pos_to_kw = get_kw_pos_association(self.subkernel) kw_to_callee_idx = {arg.name: i @@ -789,7 +791,7 @@ def with_descrs(self, arg_id_to_descr, clbl_inf_ctx): new_args = self.subkernel.args[:] - for arg_id, descr in arg_id_to_descr.items(): + for arg_id, descr in new_arg_id_to_descr.items(): if isinstance(arg_id, int): arg_id = pos_to_kw[arg_id] @@ -837,20 +839,20 @@ def with_descrs(self, arg_id_to_descr, clbl_inf_ctx): for arg in subkernel.args: kw = arg.name if isinstance(arg, ArrayBase): - arg_id_to_descr[kw] = ( + new_arg_id_to_descr[kw] = ( ArrayArgDescriptor(shape=arg.shape, dim_tags=arg.dim_tags, address_space=arg.address_space)) else: assert isinstance(arg, ValueArg) - arg_id_to_descr[kw] = ValueArgDescriptor() + new_arg_id_to_descr[kw] = ValueArgDescriptor() - arg_id_to_descr[kw_to_pos[kw]] = arg_id_to_descr[kw] + new_arg_id_to_descr[kw_to_pos[kw]] = new_arg_id_to_descr[kw] # }}} return (self.copy(subkernel=subkernel, - arg_id_to_descr=constantdict(arg_id_to_descr)), + arg_id_to_descr=new_arg_id_to_descr.finish()), clbl_inf_ctx) def with_added_arg(self, arg_dtype, arg_descr): @@ -868,6 +870,7 @@ def with_added_arg(self, arg_dtype, arg_descr): arg_id_to_dtype = {} else: arg_id_to_dtype = dict(self.arg_id_to_dtype) + if self.arg_id_to_descr is None: arg_id_to_descr = {} else: @@ -879,8 +882,8 @@ def with_added_arg(self, arg_dtype, arg_descr): arg_id_to_descr[kw_to_pos[var_name]] = arg_descr return (self.copy(subkernel=subknl, - arg_id_to_dtype=arg_id_to_dtype, - arg_id_to_descr=arg_id_to_descr), + arg_id_to_dtype=constantdict(arg_id_to_dtype), + arg_id_to_descr=constantdict(arg_id_to_descr)), var_name) else: @@ -902,7 +905,7 @@ def with_packing_for_args(self): address_space=AddressSpace.GLOBAL) return self.copy(subkernel=self.subkernel, - arg_id_to_descr=arg_id_to_descr) + arg_id_to_descr=constantdict(arg_id_to_descr)) def get_used_hw_axes(self, callables_table): gsize, lsize = self.subkernel.get_grid_size_upper_bounds(callables_table, diff --git a/loopy/library/function.py b/loopy/library/function.py index 8b61ad41a..9840c2571 100644 --- a/loopy/library/function.py +++ b/loopy/library/function.py @@ -26,6 +26,7 @@ from typing import TYPE_CHECKING import numpy as np +from constantdict import constantdict from loopy.diagnostic import LoopyError from loopy.kernel.function_interface import ScalarCallable @@ -38,21 +39,22 @@ class MakeTupleCallable(ScalarCallable): def with_types(self, arg_id_to_dtype, callables_table): - new_arg_id_to_dtype = arg_id_to_dtype.copy() + new_arg_id_to_dtype = constantdict(arg_id_to_dtype).mutate() for i in range(len(arg_id_to_dtype)): if i in arg_id_to_dtype and arg_id_to_dtype[i] is not None: new_arg_id_to_dtype[-i-1] = new_arg_id_to_dtype[i] - return (self.copy(arg_id_to_dtype=new_arg_id_to_dtype, - name_in_target="loopy_make_tuple"), callables_table) + return (self.copy(arg_id_to_dtype=new_arg_id_to_dtype.finish(), + name_in_target="loopy_make_tuple"), + callables_table) def with_descrs(self, arg_id_to_descr, callables_table): from loopy.kernel.function_interface import ValueArgDescriptor new_arg_id_to_descr = {(id, ValueArgDescriptor()): - (-id-1, ValueArgDescriptor()) for id in arg_id_to_descr.keys()} + (-id-1, ValueArgDescriptor()) for id in arg_id_to_descr} return ( - self.copy(arg_id_to_descr=new_arg_id_to_descr), + self.copy(arg_id_to_descr=constantdict(new_arg_id_to_descr)), callables_table) @@ -63,7 +65,7 @@ def with_types(self, arg_id_to_dtype, callables_table): if dtype is not None} new_arg_id_to_dtype[-1] = NumpyType(np.int32) - return (self.copy(arg_id_to_dtype=new_arg_id_to_dtype), + return (self.copy(arg_id_to_dtype=constantdict(new_arg_id_to_dtype)), callables_table) def emit_call(self, expression_to_code_mapper, expression, target): diff --git a/loopy/library/random123.py b/loopy/library/random123.py index f65fa7600..85c0839af 100644 --- a/loopy/library/random123.py +++ b/loopy/library/random123.py @@ -29,6 +29,7 @@ from typing import TYPE_CHECKING import numpy as np +from constantdict import constantdict from mako.template import Template from pymbolic.typing import not_none @@ -221,8 +222,8 @@ def with_types(self, arg_id_to_dtype, callables_table): new_arg_id_to_dtype = {-1: ctr_dtype, -2: ctr_dtype, 0: ctr_dtype, 1: key_dtype} return ( - self.copy(arg_id_to_dtype=new_arg_id_to_dtype, - name_in_target=fn+"_gen"), + self.copy(arg_id_to_dtype=constantdict(new_arg_id_to_dtype), + name_in_target=fn+"_gen"), callables_table) elif name == fn + "_f32": @@ -230,18 +231,22 @@ def with_types(self, arg_id_to_dtype, callables_table): rng_variant.width), -2: ctr_dtype, 0: ctr_dtype, 1: key_dtype} - return self.copy(arg_id_to_dtype=new_arg_id_to_dtype, - name_in_target=name), callables_table + return ( + self.copy(arg_id_to_dtype=constantdict(new_arg_id_to_dtype), + name_in_target=name), + callables_table) elif name == fn + "_f64": new_arg_id_to_dtype = {-1: target.vector_dtype(NumpyType(np.float64), rng_variant.width), -2: ctr_dtype, 0: ctr_dtype, 1: key_dtype} - return self.copy(arg_id_to_dtype=new_arg_id_to_dtype, - name_in_target=name), callables_table + return ( + self.copy(arg_id_to_dtype=constantdict(new_arg_id_to_dtype), + name_in_target=name), + callables_table) - return (self.copy(arg_id_to_dtype=arg_id_to_dtype), + return (self.copy(arg_id_to_dtype=constantdict(arg_id_to_dtype)), callables_table) def generate_preambles(self, target): diff --git a/loopy/library/reduction.py b/loopy/library/reduction.py index 6ddc3fb86..d27dadee2 100644 --- a/loopy/library/reduction.py +++ b/loopy/library/reduction.py @@ -27,6 +27,7 @@ from typing import TYPE_CHECKING import numpy as np +from constantdict import constantdict from pymbolic import var from pymbolic.primitives import expr_dataclass @@ -580,21 +581,25 @@ def with_types(self, arg_id_to_dtype, callables_table): index_dtype = arg_id_to_dtype[1] result_dtypes = self.name.reduction_op.result_dtypes(scalar_dtype, # pylint: disable=no-member index_dtype) - new_arg_id_to_dtype = arg_id_to_dtype.copy() + + new_arg_id_to_dtype = constantdict(arg_id_to_dtype).mutate() new_arg_id_to_dtype[-1] = result_dtypes[0] new_arg_id_to_dtype[-2] = result_dtypes[1] name_in_target = self.name.reduction_op.prefix(scalar_dtype, # pylint: disable=no-member index_dtype) + "_op" - return self.copy(arg_id_to_dtype=new_arg_id_to_dtype, - name_in_target=name_in_target), callables_table + return (self.copy(arg_id_to_dtype=new_arg_id_to_dtype.finish(), + name_in_target=name_in_target), + callables_table) def with_descrs(self, arg_id_to_descr, callables_table): from loopy.kernel.function_interface import ValueArgDescriptor - new_arg_id_to_descr = arg_id_to_descr.copy() + + new_arg_id_to_descr = constantdict(arg_id_to_descr).mutate() new_arg_id_to_descr[-1] = ValueArgDescriptor() + return ( - self.copy(arg_id_to_descr=arg_id_to_descr), + self.copy(arg_id_to_descr=new_arg_id_to_descr.finish()), callables_table) diff --git a/loopy/target/c/__init__.py b/loopy/target/c/__init__.py index 06cc208ed..c170fb323 100644 --- a/loopy/target/c/__init__.py +++ b/loopy/target/c/__init__.py @@ -28,6 +28,7 @@ from typing import TYPE_CHECKING, Any, Sequence, cast import numpy as np +from constantdict import constantdict import pymbolic.primitives as p from cgen import ( @@ -530,7 +531,7 @@ def with_types(self, arg_id_to_dtype, callables_table): # the types provided aren't mature enough to specialize the # callable return ( - self.copy(arg_id_to_dtype=arg_id_to_dtype), + self.copy(arg_id_to_dtype=constantdict(arg_id_to_dtype)), callables_table) dtype = arg_id_to_dtype[0].numpy_dtype @@ -563,9 +564,9 @@ def with_types(self, arg_id_to_dtype, callables_table): return ( self.copy(name_in_target=name, - arg_id_to_dtype={ + arg_id_to_dtype=constantdict({ 0: NumpyType(dtype), - -1: NumpyType(result_dtype)}), + -1: NumpyType(result_dtype)})), callables_table) # binary functions @@ -580,7 +581,7 @@ def with_types(self, arg_id_to_dtype, callables_table): # the types provided aren't mature enough to specialize the # callable return ( - self.copy(arg_id_to_dtype=arg_id_to_dtype), + self.copy(arg_id_to_dtype=constantdict(arg_id_to_dtype)), callables_table) dtype = np.result_type(*[ @@ -607,7 +608,7 @@ def with_types(self, arg_id_to_dtype, callables_table): dtype = NumpyType(dtype) return ( self.copy(name_in_target=name, - arg_id_to_dtype={-1: dtype, 0: dtype, 1: dtype}), + arg_id_to_dtype=constantdict({-1: dtype, 0: dtype, 1: dtype})), callables_table) elif name in ["max", "min"]: @@ -620,7 +621,7 @@ def with_types(self, arg_id_to_dtype, callables_table): # the types provided aren't resolved enough to specialize the # callable return ( - self.copy(arg_id_to_dtype=arg_id_to_dtype), + self.copy(arg_id_to_dtype=constantdict(arg_id_to_dtype)), callables_table) dtype = np.result_type(*[ @@ -632,9 +633,10 @@ def with_types(self, arg_id_to_dtype, callables_table): return ( self.copy(name_in_target=f"lpy_{name}_{dtype.name}", - arg_id_to_dtype={-1: NumpyType(dtype), - 0: NumpyType(dtype), - 1: NumpyType(dtype)}), + arg_id_to_dtype=constantdict({ + -1: NumpyType(dtype), + 0: NumpyType(dtype), + 1: NumpyType(dtype)})), callables_table) elif name == "isnan": for id in arg_id_to_dtype: @@ -645,7 +647,7 @@ def with_types(self, arg_id_to_dtype, callables_table): # the types provided aren't mature enough to specialize the # callable return ( - self.copy(arg_id_to_dtype=arg_id_to_dtype), + self.copy(arg_id_to_dtype=constantdict(arg_id_to_dtype)), callables_table) dtype = arg_id_to_dtype[0].numpy_dtype @@ -662,9 +664,9 @@ def with_types(self, arg_id_to_dtype, callables_table): return ( self.copy( name_in_target=name, - arg_id_to_dtype={ + arg_id_to_dtype=constantdict({ 0: NumpyType(dtype), - -1: NumpyType(np.int32)}), + -1: NumpyType(np.int32)})), callables_table) def generate_preambles(self, target): @@ -713,7 +715,7 @@ def with_types(self, arg_id_to_dtype, callables_table): # the types provided aren't mature enough to specialize the # callable return ( - self.copy(arg_id_to_dtype=arg_id_to_dtype), + self.copy(arg_id_to_dtype=constantdict(arg_id_to_dtype)), callables_table) if not arg_id_to_dtype[0].is_integral(): @@ -738,9 +740,10 @@ def with_types(self, arg_id_to_dtype, callables_table): return ( self.copy(name_in_target=name_in_target, - arg_id_to_dtype={-1: arg_id_to_dtype[1], - 0: NumpyType(np.int32), - 1: arg_id_to_dtype[1]}), + arg_id_to_dtype=constantdict({ + -1: arg_id_to_dtype[1], + 0: NumpyType(np.int32), + 1: arg_id_to_dtype[1]})), callables_table) else: raise NotImplementedError(f"with_types for '{name}'") diff --git a/loopy/target/opencl.py b/loopy/target/opencl.py index d14dd9e30..3fe951c4e 100644 --- a/loopy/target/opencl.py +++ b/loopy/target/opencl.py @@ -27,6 +27,7 @@ from typing import TYPE_CHECKING, Literal, Sequence import numpy as np +from constantdict import constantdict from pymbolic import var from pytools import memoize_method @@ -208,7 +209,7 @@ def with_types(self, arg_id_to_dtype, callables_table): if 0 not in arg_id_to_dtype or arg_id_to_dtype[0] is None: return ( - self.copy(arg_id_to_dtype=arg_id_to_dtype), + self.copy(arg_id_to_dtype=constantdict(arg_id_to_dtype)), callables_table) dtype = arg_id_to_dtype[0].numpy_dtype @@ -217,9 +218,10 @@ def with_types(self, arg_id_to_dtype, callables_table): # OpenCL C 2.2, Section 6.13.3: abs returns *u*gentype from loopy.types import to_unsigned_dtype return (self.copy(name_in_target=name, - arg_id_to_dtype={ + arg_id_to_dtype=constantdict({ 0: NumpyType(dtype), - -1: NumpyType(to_unsigned_dtype(dtype))}), + -1: NumpyType(to_unsigned_dtype(dtype)) + })), callables_table) elif dtype.kind == "f": name = "fabs" @@ -237,7 +239,7 @@ def with_types(self, arg_id_to_dtype, callables_table): if 0 not in arg_id_to_dtype or arg_id_to_dtype[0] is None: return ( - self.copy(arg_id_to_dtype=arg_id_to_dtype), + self.copy(arg_id_to_dtype=constantdict(arg_id_to_dtype)), callables_table) dtype = arg_id_to_dtype[0] @@ -251,8 +253,10 @@ def with_types(self, arg_id_to_dtype, callables_table): return ( self.copy(name_in_target=name, - arg_id_to_dtype={0: NumpyType(dtype), -1: - NumpyType(dtype)}), + arg_id_to_dtype=constantdict({ + 0: NumpyType(dtype), + -1: NumpyType(dtype) + })), callables_table) # }}} @@ -270,7 +274,7 @@ def with_types(self, arg_id_to_dtype, callables_table): if 0 not in arg_id_to_dtype or 1 not in arg_id_to_dtype or ( arg_id_to_dtype[0] is None or arg_id_to_dtype[1] is None): return ( - self.copy(arg_id_to_dtype=arg_id_to_dtype), + self.copy(arg_id_to_dtype=constantdict(arg_id_to_dtype)), callables_table) dtype = np.result_type(*[ @@ -283,7 +287,9 @@ def with_types(self, arg_id_to_dtype, callables_table): dtype = NumpyType(dtype) return ( self.copy(name_in_target=name, - arg_id_to_dtype={-1: dtype, 0: dtype, 1: dtype}), + arg_id_to_dtype=constantdict({ + -1: dtype, 0: dtype, 1: dtype + })), callables_table) elif name in ["max", "min"]: @@ -292,7 +298,7 @@ def with_types(self, arg_id_to_dtype, callables_table): raise LoopyError("%s can take only 2 arguments." % name) if 0 not in arg_id_to_dtype or 1 not in arg_id_to_dtype: return ( - self.copy(arg_id_to_dtype=arg_id_to_dtype), + self.copy(arg_id_to_dtype=constantdict(arg_id_to_dtype)), callables_table) common_dtype = np.result_type(*[ dtype.numpy_dtype for id, dtype in arg_id_to_dtype.items() @@ -305,7 +311,9 @@ def with_types(self, arg_id_to_dtype, callables_table): dtype = NumpyType(common_dtype) return ( self.copy(name_in_target=name, - arg_id_to_dtype={-1: dtype, 0: dtype, 1: dtype}), + arg_id_to_dtype=constantdict({ + -1: dtype, 0: dtype, 1: dtype + })), callables_table) else: # Unsupported type. @@ -322,14 +330,15 @@ def with_types(self, arg_id_to_dtype, callables_table): # the types provided aren't mature enough to specialize the # callable return ( - self.copy(arg_id_to_dtype=arg_id_to_dtype), + self.copy(arg_id_to_dtype=constantdict(arg_id_to_dtype)), callables_table) dtype = arg_id_to_dtype[0] scalar_dtype, _offset, _field_name = dtype.numpy_dtype.fields["s0"] return ( - self.copy(name_in_target=name, arg_id_to_dtype={-1: - NumpyType(scalar_dtype), 0: dtype, 1: dtype}), + self.copy(name_in_target=name, arg_id_to_dtype=constantdict({ + -1: NumpyType(scalar_dtype), 0: dtype, 1: dtype + })), callables_table) elif name == "pow": @@ -352,8 +361,11 @@ def with_types(self, arg_id_to_dtype, callables_table): return ( self.copy(name_in_target=name, - arg_id_to_dtype={-1: result_dtype, - 0: common_dtype, 1: common_dtype}), + arg_id_to_dtype=constantdict({ + -1: result_dtype, + 0: common_dtype, + 1: common_dtype + })), callables_table) elif name in _CL_SIMPLE_MULTI_ARG_FUNCTIONS: @@ -368,7 +380,7 @@ def with_types(self, arg_id_to_dtype, callables_table): # the types provided aren't mature enough to specialize the # callable return ( - self.copy(arg_id_to_dtype=arg_id_to_dtype), + self.copy(arg_id_to_dtype=constantdict(arg_id_to_dtype)), callables_table) dtype = np.result_type(*[ @@ -379,8 +391,9 @@ def with_types(self, arg_id_to_dtype, callables_table): raise LoopyError("%s does not support complex numbers" % name) - updated_arg_id_to_dtype = {id: NumpyType(dtype) for id in range(-1, - num_args)} + updated_arg_id_to_dtype = constantdict({ + id: NumpyType(dtype) for id in range(-1, num_args) + }) return ( self.copy(name_in_target=name, @@ -400,23 +413,23 @@ def with_types(self, arg_id_to_dtype, callables_table): # the types provided aren't mature enough to specialize the # callable return ( - self.copy(arg_id_to_dtype=arg_id_to_dtype), + self.copy(arg_id_to_dtype=constantdict(arg_id_to_dtype)), callables_table) - updated_arg_id_to_dtype = {id: NumpyType(dtype) for id in - range(count)} + updated_arg_id_to_dtype = {id: NumpyType(dtype) for id in range(count)} updated_arg_id_to_dtype[-1] = OpenCLTarget().vector_dtype( NumpyType(dtype), count) return ( - self.copy(name_in_target="(%s%d) " % (base_tp_name, count), - arg_id_to_dtype=updated_arg_id_to_dtype), + self.copy( + name_in_target="(%s%d) " % (base_tp_name, count), + arg_id_to_dtype=constantdict(updated_arg_id_to_dtype)), callables_table) # does not satisfy any of the conditions needed for specialization. # hence just returning a copy of the callable. return ( - self.copy(arg_id_to_dtype=arg_id_to_dtype), + self.copy(arg_id_to_dtype=constantdict(arg_id_to_dtype)), callables_table) diff --git a/loopy/target/pyopencl.py b/loopy/target/pyopencl.py index 9add453d7..ae923f1fe 100644 --- a/loopy/target/pyopencl.py +++ b/loopy/target/pyopencl.py @@ -30,6 +30,7 @@ from warnings import warn import numpy as np +from constantdict import constantdict import pymbolic.primitives as p from cgen import ( @@ -98,7 +99,7 @@ def with_types(self, arg_id_to_dtype, callables_table): # the types provided aren't mature enough to specialize the # callable return ( - self.copy(arg_id_to_dtype=arg_id_to_dtype), + self.copy(arg_id_to_dtype=constantdict(arg_id_to_dtype)), callables_table) dtype = arg_id_to_dtype[0] @@ -114,8 +115,10 @@ def with_types(self, arg_id_to_dtype, callables_table): return ( self.copy(name_in_target=f"{tpname}_{name}", - arg_id_to_dtype={0: dtype, -1: NumpyType( - np.dtype(dtype.numpy_dtype.type(0).real))}), + arg_id_to_dtype=constantdict({ + 0: dtype, + -1: NumpyType(np.dtype(dtype.numpy_dtype.type(0).real)) + })), callables_table) if name in ["real", "imag", "conj"]: @@ -124,7 +127,7 @@ def with_types(self, arg_id_to_dtype, callables_table): return ( self.copy( name_in_target=f"_lpy_{name}_{tpname}", - arg_id_to_dtype={0: dtype, -1: dtype}), + arg_id_to_dtype=constantdict({0: dtype, -1: dtype})), callables_table) if name in ["sqrt", "exp", "log", @@ -142,7 +145,7 @@ def with_types(self, arg_id_to_dtype, callables_table): return ( self.copy(name_in_target=f"{tpname}_{name}", - arg_id_to_dtype={0: dtype, -1: dtype}), + arg_id_to_dtype=constantdict({0: dtype, -1: dtype})), callables_table) # fall back to pure OpenCL for real-valued arguments diff --git a/test/library_for_test.py b/test/library_for_test.py index 47bca082a..2b24e1595 100644 --- a/test/library_for_test.py +++ b/test/library_for_test.py @@ -1,4 +1,5 @@ import numpy as np +from constantdict import constantdict import loopy as lp @@ -8,7 +9,7 @@ def with_types(self, arg_id_to_dtype, callables): if len(arg_id_to_dtype) != 0: raise RuntimeError("'f' cannot take any inputs.") - return (self.copy(arg_id_to_dtype=arg_id_to_dtype, + return (self.copy(arg_id_to_dtype=constantdict(arg_id_to_dtype), name_in_target="f"), callables) @@ -16,7 +17,7 @@ def with_descrs(self, arg_id_to_descr, callables): if len(arg_id_to_descr) != 0: raise RuntimeError("'f' cannot take any inputs.") - return (self.copy(arg_id_to_descr=arg_id_to_descr), + return (self.copy(arg_id_to_descr=constantdict(arg_id_to_descr)), callables) def generate_preambles(self, target): @@ -39,7 +40,7 @@ def with_types(self, arg_id_to_dtype, callables): if input_dtype.numpy_dtype != np.float32: raise RuntimeError("'f' only supports f32.") - return (self.copy(arg_id_to_dtype=arg_id_to_dtype, + return (self.copy(arg_id_to_dtype=constantdict(arg_id_to_dtype), name_in_target="f"), callables) diff --git a/test/testlib.py b/test/testlib.py index f8a491ac1..a7179d796 100644 --- a/test/testlib.py +++ b/test/testlib.py @@ -1,4 +1,5 @@ import numpy as np +from constantdict import constantdict import loopy as lp @@ -28,7 +29,7 @@ def with_types(self, arg_id_to_dtype, callables_table): # the types provided aren't mature enough to specialize the # callable return ( - self.copy(arg_id_to_dtype=arg_id_to_dtype), + self.copy(arg_id_to_dtype=constantdict(arg_id_to_dtype)), callables_table) dtype = arg_id_to_dtype[0].numpy_dtype @@ -48,8 +49,10 @@ def with_types(self, arg_id_to_dtype, callables_table): from loopy.types import NumpyType return ( self.copy(name_in_target=name_in_target, - arg_id_to_dtype={0: NumpyType(dtype), -1: - NumpyType(dtype)}), + arg_id_to_dtype=constantdict({ + 0: NumpyType(dtype), + -1: NumpyType(dtype) + })), callables_table)