diff --git a/devito/mpi/routines.py b/devito/mpi/routines.py index 8da418bfde..67158a621c 100644 --- a/devito/mpi/routines.py +++ b/devito/mpi/routines.py @@ -16,7 +16,7 @@ from devito.mpi import MPI from devito.symbolics import (Byref, CondNe, FieldFromPointer, FieldFromComposite, IndexedPointer, Macro, cast_mapper, subs_op_args) -from devito.tools import (as_mapper, dtype_to_mpitype, dtype_len, dtype_to_ctype, +from devito.tools import (as_mapper, dtype_to_mpitype, dtype_len, dtype_alloc_ctype, flatten, generator, is_integer, split) from devito.types import (Array, Bag, Dimension, Eq, Symbol, LocalObject, CompositeObject, CustomDimension) @@ -1204,8 +1204,8 @@ def _arg_defaults(self, allocator, alias, args=None): entry.sizes = (c_int*len(shape))(*shape) # Allocate the send/recv buffers - size = reduce(mul, shape)*dtype_len(self.target.dtype) - ctype = dtype_to_ctype(f.dtype) + ctype, c_scale = dtype_alloc_ctype(f.dtype) + size = int(reduce(mul, shape) * c_scale) * dtype_len(self.target.dtype) entry.bufg, bufg_memfree_args = allocator._alloc_C_libcall(size, ctype) entry.bufs, bufs_memfree_args = allocator._alloc_C_libcall(size, ctype) diff --git a/devito/operator/operator.py b/devito/operator/operator.py index c4ce4dc9aa..38566e6a52 100644 --- a/devito/operator/operator.py +++ b/devito/operator/operator.py @@ -1118,7 +1118,7 @@ def __setstate__(self, state): self._lib.name = soname self._allocator = default_allocator( - '%s.%s.%s' % (self._compiler.name, self._language, self._platform) + '%s.%s.%s' % (self._compiler.__class__.__name__, self._language, self._platform) ) @@ -1404,7 +1404,7 @@ def parse_kwargs(**kwargs): # `allocator` kwargs['allocator'] = default_allocator( - '%s.%s.%s' % (kwargs['compiler'].name, + '%s.%s.%s' % (kwargs['compiler'].__class__.__name__, kwargs['language'], kwargs['platform']) ) diff --git a/devito/passes/clusters/derivatives.py b/devito/passes/clusters/derivatives.py index f8f339aa1e..5af92a3208 100644 --- 
a/devito/passes/clusters/derivatives.py +++ b/devito/passes/clusters/derivatives.py @@ -1,6 +1,7 @@ from functools import singledispatch from sympy import S +import numpy as np from devito.finite_differences import IndexDerivative from devito.ir import Backward, Forward, Interval, IterationSpace, Queue @@ -157,7 +158,7 @@ def _(expr, c, ispace, weights, reusables, mapper, **kwargs): # NOTE: created before recurring so that we ultimately get a sound ordering try: s = reusables.pop() - assert s.dtype is dtype + assert np.can_cast(s.dtype, dtype) except KeyError: name = sregistry.make_name(prefix='r') s = Symbol(name=name, dtype=dtype) diff --git a/devito/symbolics/extended_dtypes.py b/devito/symbolics/extended_dtypes.py index 0789c7b947..089b454b72 100644 --- a/devito/symbolics/extended_dtypes.py +++ b/devito/symbolics/extended_dtypes.py @@ -1,7 +1,7 @@ import ctypes import numpy as np -from devito.symbolics.extended_sympy import ReservedWord, Cast, CastStar, ValueLimit +from devito.symbolics.extended_sympy import ReservedWord, Cast, ValueLimit from devito.tools import (Bunch, float2, float3, float4, double2, double3, double4, # noqa int2, int3, int4, ctypes_vector_mapper) @@ -64,7 +64,7 @@ class CustomType(ReservedWord): cls = type(v.upper(), (Cast,), {'_base_typ': v}) globals()[cls.__name__] = cls - clsp = type('%sP' % v.upper(), (CastStar,), {'base': cls}) + clsp = type('%sP' % v.upper(), (Cast,), {'base': cls}) globals()[clsp.__name__] = clsp @@ -75,7 +75,7 @@ def no_dtype(kwargs): def cast_mapper(arg): try: assert len(arg) == 2 and arg[1] == '*' - return lambda v, **kw: CastStar(v, dtype=arg[0], **no_dtype(kw)) + return lambda v, **kw: Cast(v, dtype=arg[0], stars=arg[1], **no_dtype(kw)) except TypeError: return lambda v, **kw: Cast(v, dtype=arg, **no_dtype(kw)) diff --git a/devito/symbolics/extended_sympy.py b/devito/symbolics/extended_sympy.py index 1b01c6db3f..eb466604e7 100644 --- a/devito/symbolics/extended_sympy.py +++ b/devito/symbolics/extended_sympy.py 
@@ -768,14 +768,6 @@ def __str__(self): __repr__ = __str__ -# *** Casting - -class CastStar: - - def __new__(cls, base, dtype=None, ase=''): - return Cast(base, dtype=dtype, stars='*') - - # Some other utility objects Null = Macro('NULL') diff --git a/devito/symbolics/manipulation.py b/devito/symbolics/manipulation.py index 6ca746adcc..36f702a03b 100644 --- a/devito/symbolics/manipulation.py +++ b/devito/symbolics/manipulation.py @@ -393,7 +393,7 @@ def normalize_args(args): for k, v in args.items(): try: retval[k] = sympify(v, strict=True) - except SympifyError: + except (TypeError, SympifyError): continue return retval diff --git a/devito/tools/dtypes_lowering.py b/devito/tools/dtypes_lowering.py index a6ce289324..d8e2b0723b 100644 --- a/devito/tools/dtypes_lowering.py +++ b/devito/tools/dtypes_lowering.py @@ -21,6 +21,7 @@ # NOTE: the following is inspired by pyopencl.cltypes mapper = { + "half": np.float16, "int": np.int32, "float": np.float32, "double": np.float64 @@ -189,7 +190,8 @@ def dtype_to_mpitype(dtype): np.int32: 'MPI_INT', np.float32: 'MPI_FLOAT', np.int64: 'MPI_LONG', - np.float64: 'MPI_DOUBLE' + np.float64: 'MPI_DOUBLE', + np.float16: 'MPI_UNSIGNED_SHORT' }[dtype] @@ -222,6 +224,8 @@ class c_restrict_void_p(ctypes.c_void_p): ctypes_vector_mapper = {} for base_name, base_dtype in mapper.items(): + if base_dtype is np.float16: + continue base_ctype = dtype_to_ctype(base_dtype) for count in counts: diff --git a/devito/types/dense.py b/devito/types/dense.py index b05beb656c..6676284848 100644 --- a/devito/types/dense.py +++ b/devito/types/dense.py @@ -792,17 +792,21 @@ def _halo_exchange(self): # Gather send data data = self._data_in_region(OWNED, d, i) sendbuf = np.ascontiguousarray(data) + if self.dtype == np.float16: + sendbuf = sendbuf.view(np.uint16) # Setup recv buffer shape = self._data_in_region(HALO, d, i.flip()).shape recvbuf = np.ndarray(shape=shape, dtype=self.dtype) + if self.dtype == np.float16: + recvbuf = recvbuf.view(np.uint16) # 
Communication comm.Sendrecv(sendbuf, dest=dest, recvbuf=recvbuf, source=source) # Scatter received data if recvbuf is not None and source != MPI.PROC_NULL: - self._data_in_region(HALO, d, i.flip())[:] = recvbuf + self._data_in_region(HALO, d, i.flip())[:] = recvbuf.view(self.dtype) self._is_halo_dirty = False