From 49a284f9870cca32c8c2a019a0cfdd14055e1402 Mon Sep 17 00:00:00 2001 From: mloubout Date: Fri, 17 Jan 2025 15:47:07 -0500 Subject: [PATCH] compiler: fix dtype for mpi routines --- devito/mpi/routines.py | 6 +++--- devito/operator/operator.py | 4 ++-- devito/passes/clusters/derivatives.py | 3 ++- devito/symbolics/extended_dtypes.py | 6 +++--- devito/symbolics/extended_sympy.py | 10 +--------- devito/symbolics/manipulation.py | 2 +- devito/symbolics/printer.py | 3 ++- devito/tools/dtypes_lowering.py | 6 +++++- devito/types/array.py | 2 ++ devito/types/dense.py | 6 +++++- tests/test_pickle.py | 4 ++-- 11 files changed, 28 insertions(+), 24 deletions(-) diff --git a/devito/mpi/routines.py b/devito/mpi/routines.py index 8da418bfde..67158a621c 100644 --- a/devito/mpi/routines.py +++ b/devito/mpi/routines.py @@ -16,7 +16,7 @@ from devito.mpi import MPI from devito.symbolics import (Byref, CondNe, FieldFromPointer, FieldFromComposite, IndexedPointer, Macro, cast_mapper, subs_op_args) -from devito.tools import (as_mapper, dtype_to_mpitype, dtype_len, dtype_to_ctype, +from devito.tools import (as_mapper, dtype_to_mpitype, dtype_len, dtype_alloc_ctype, flatten, generator, is_integer, split) from devito.types import (Array, Bag, Dimension, Eq, Symbol, LocalObject, CompositeObject, CustomDimension) @@ -1204,8 +1204,8 @@ def _arg_defaults(self, allocator, alias, args=None): entry.sizes = (c_int*len(shape))(*shape) # Allocate the send/recv buffers - size = reduce(mul, shape)*dtype_len(self.target.dtype) - ctype = dtype_to_ctype(f.dtype) + ctype, c_scale = dtype_alloc_ctype(f.dtype) + size = int(reduce(mul, shape) * c_scale) * dtype_len(self.target.dtype) entry.bufg, bufg_memfree_args = allocator._alloc_C_libcall(size, ctype) entry.bufs, bufs_memfree_args = allocator._alloc_C_libcall(size, ctype) diff --git a/devito/operator/operator.py b/devito/operator/operator.py index c4ce4dc9aa..38566e6a52 100644 --- a/devito/operator/operator.py +++ b/devito/operator/operator.py @@ -1118,7 +1118,7 @@ def __setstate__(self, state): self._lib.name = soname self._allocator = default_allocator( - '%s.%s.%s' % (self._compiler.name, self._language, self._platform) + '%s.%s.%s' % (self._compiler.__class__.name, self._language, self._platform) ) @@ -1404,7 +1404,7 @@ def parse_kwargs(**kwargs): # `allocator` kwargs['allocator'] = default_allocator( - '%s.%s.%s' % (kwargs['compiler'].name, + '%s.%s.%s' % (kwargs['compiler'].__class__.__name__, kwargs['language'], kwargs['platform']) ) diff --git a/devito/passes/clusters/derivatives.py b/devito/passes/clusters/derivatives.py index f8f339aa1e..5af92a3208 100644 --- a/devito/passes/clusters/derivatives.py +++ b/devito/passes/clusters/derivatives.py @@ -1,6 +1,7 @@ from functools import singledispatch from sympy import S +import numpy as np from devito.finite_differences import IndexDerivative from devito.ir import Backward, Forward, Interval, IterationSpace, Queue @@ -157,7 +158,7 @@ def _(expr, c, ispace, weights, reusables, mapper, **kwargs): # NOTE: created before recurring so that we ultimately get a sound ordering try: s = reusables.pop() - assert s.dtype is dtype + assert np.can_cast(s.dtype, dtype) except KeyError: name = sregistry.make_name(prefix='r') s = Symbol(name=name, dtype=dtype) diff --git a/devito/symbolics/extended_dtypes.py b/devito/symbolics/extended_dtypes.py index 0789c7b947..089b454b72 100644 --- a/devito/symbolics/extended_dtypes.py +++ b/devito/symbolics/extended_dtypes.py @@ -1,7 +1,7 @@ import ctypes import numpy as np -from devito.symbolics.extended_sympy import ReservedWord, Cast, CastStar, ValueLimit +from devito.symbolics.extended_sympy import ReservedWord, Cast, ValueLimit from devito.tools import (Bunch, float2, float3, float4, double2, double3, double4, # noqa int2, int3, int4, ctypes_vector_mapper) @@ -64,7 +64,7 @@ class CustomType(ReservedWord): cls = type(v.upper(), (Cast,), {'_base_typ': v}) globals()[cls.__name__] = cls - clsp = type('%sP' % v.upper(), (CastStar,), {'base': cls}) + clsp = type('%sP' % v.upper(), (Cast,), {'base': cls}) globals()[clsp.__name__] = clsp @@ -75,7 +75,7 @@ def no_dtype(kwargs): def cast_mapper(arg): try: assert len(arg) == 2 and arg[1] == '*' - return lambda v, **kw: CastStar(v, dtype=arg[0], **no_dtype(kw)) + return lambda v, **kw: Cast(v, dtype=arg[0], stars=arg[1], **no_dtype(kw)) except TypeError: return lambda v, **kw: Cast(v, dtype=arg, **no_dtype(kw)) diff --git a/devito/symbolics/extended_sympy.py b/devito/symbolics/extended_sympy.py index 1b01c6db3f..9ea9611f63 100644 --- a/devito/symbolics/extended_sympy.py +++ b/devito/symbolics/extended_sympy.py @@ -768,14 +768,6 @@ def __str__(self): __repr__ = __str__ -# *** Casting - -class CastStar: - - def __new__(cls, base, dtype=None, ase=''): - return Cast(base, dtype=dtype, stars='*') - - # Some other utility objects Null = Macro('NULL') @@ -789,7 +781,7 @@ def __new__(cls, intype, stars=None, **kwargs): stars = stars or '' argument = Keyword(f'{intype}{stars}') newobj = super().__new__(cls, 'sizeof', arguments=(argument,), **kwargs) - newobj.intype = intype + newobj.intype = Cast.__process_dtype__(intype) newobj.stars = stars return newobj diff --git a/devito/symbolics/manipulation.py b/devito/symbolics/manipulation.py index f5992ac8be..80389ead08 100644 --- a/devito/symbolics/manipulation.py +++ b/devito/symbolics/manipulation.py @@ -393,7 +393,7 @@ def normalize_args(args): for k, v in args.items(): try: retval[k] = sympify(v, strict=True) - except SympifyError: + except (TypeError, SympifyError): continue return retval diff --git a/devito/symbolics/printer.py b/devito/symbolics/printer.py index ffb91b7cf5..6ec07e05c3 100644 --- a/devito/symbolics/printer.py +++ b/devito/symbolics/printer.py @@ -213,7 +213,8 @@ def _print_Abs(self, expr): # Unary function, single argument arg = expr.args[0] # AOMPCC errors with abs, always use fabs - if isinstance(self.compiler, AOMPCompiler): + if isinstance(self.compiler, AOMPCompiler) and \ + not np.issubdtype(self._prec(expr), np.integer): return "fabs(%s)" % self._print(arg) func = f'{self.func_prefix(arg, abs=True)}abs{self.func_literal(arg)}' return f"{self._ns}{func}({self._print(arg)})" diff --git a/devito/tools/dtypes_lowering.py b/devito/tools/dtypes_lowering.py index a6ce289324..d8e2b0723b 100644 --- a/devito/tools/dtypes_lowering.py +++ b/devito/tools/dtypes_lowering.py @@ -21,6 +21,7 @@ # NOTE: the following is inspired by pyopencl.cltypes mapper = { + "half": np.float16, "int": np.int32, "float": np.float32, "double": np.float64 @@ -189,7 +190,8 @@ def dtype_to_mpitype(dtype): np.int32: 'MPI_INT', np.float32: 'MPI_FLOAT', np.int64: 'MPI_LONG', - np.float64: 'MPI_DOUBLE' + np.float64: 'MPI_DOUBLE', + np.float16: 'MPI_UNSIGNED_SHORT' }[dtype] @@ -222,6 +224,8 @@ class c_restrict_void_p(ctypes.c_void_p): ctypes_vector_mapper = {} for base_name, base_dtype in mapper.items(): + if base_dtype is np.float16: + continue base_ctype = dtype_to_ctype(base_dtype) for count in counts: diff --git a/devito/types/array.py b/devito/types/array.py index 62c1b62f49..cdf66db315 100644 --- a/devito/types/array.py +++ b/devito/types/array.py @@ -204,6 +204,8 @@ class ArrayMapped(Array): (_C_field_dmap, c_void_p), (_C_field_size, c_uint64)]})) + _C_typedata = 'struct ' + _C_structname + class ArrayObject(ArrayBasic): diff --git a/devito/types/dense.py b/devito/types/dense.py index b05beb656c..6676284848 100644 --- a/devito/types/dense.py +++ b/devito/types/dense.py @@ -792,17 +792,21 @@ def _halo_exchange(self): # Gather send data data = self._data_in_region(OWNED, d, i) sendbuf = np.ascontiguousarray(data) + if self.dtype == np.float16: + sendbuf = sendbuf.view(np.uint16) # Setup recv buffer shape = self._data_in_region(HALO, d, i.flip()).shape recvbuf = np.ndarray(shape=shape, dtype=self.dtype) + if self.dtype == np.float16: + recvbuf = recvbuf.view(np.uint16) # Communication comm.Sendrecv(sendbuf, dest=dest, recvbuf=recvbuf, source=source) # Scatter received data if recvbuf is not None and source != MPI.PROC_NULL: - self._data_in_region(HALO, d, i.flip())[:] = recvbuf + self._data_in_region(HALO, d, i.flip())[:] = recvbuf.view(self.dtype) self._is_halo_dirty = False diff --git a/tests/test_pickle.py b/tests/test_pickle.py index ef47e917fb..fc33e98965 100644 --- a/tests/test_pickle.py +++ b/tests/test_pickle.py @@ -567,8 +567,8 @@ def test_equation(self, pickle): eq = Eq(f, f+1, implicit_dims=xs) - pkl_eq = pickle0.dumps(eq) - new_eq = pickle0.loads(pkl_eq) + pkl_eq = pickle.dumps(eq) + new_eq = pickle.loads(pkl_eq) assert new_eq.lhs.name == f.name assert str(new_eq.rhs) == 'f(x) + 1'