Skip to content

Commit

Permalink
compiler: fix dtype for mpi routines
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Jan 17, 2025
1 parent dcd0bd4 commit d217047
Show file tree
Hide file tree
Showing 8 changed files with 21 additions and 20 deletions.
6 changes: 3 additions & 3 deletions devito/mpi/routines.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from devito.mpi import MPI
from devito.symbolics import (Byref, CondNe, FieldFromPointer, FieldFromComposite,
IndexedPointer, Macro, cast_mapper, subs_op_args)
from devito.tools import (as_mapper, dtype_to_mpitype, dtype_len, dtype_to_ctype,
from devito.tools import (as_mapper, dtype_to_mpitype, dtype_len, dtype_alloc_ctype,
flatten, generator, is_integer, split)
from devito.types import (Array, Bag, Dimension, Eq, Symbol, LocalObject,
CompositeObject, CustomDimension)
Expand Down Expand Up @@ -1204,8 +1204,8 @@ def _arg_defaults(self, allocator, alias, args=None):
entry.sizes = (c_int*len(shape))(*shape)

# Allocate the send/recv buffers
size = reduce(mul, shape)*dtype_len(self.target.dtype)
ctype = dtype_to_ctype(f.dtype)
ctype, c_scale = dtype_alloc_ctype(f.dtype)
size = int(reduce(mul, shape) * c_scale) * dtype_len(self.target.dtype)
entry.bufg, bufg_memfree_args = allocator._alloc_C_libcall(size, ctype)
entry.bufs, bufs_memfree_args = allocator._alloc_C_libcall(size, ctype)

Expand Down
4 changes: 2 additions & 2 deletions devito/operator/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1118,7 +1118,7 @@ def __setstate__(self, state):
self._lib.name = soname

self._allocator = default_allocator(
'%s.%s.%s' % (self._compiler.name, self._language, self._platform)
'%s.%s.%s' % (self._compiler.__class__.name, self._language, self._platform)
)


Expand Down Expand Up @@ -1404,7 +1404,7 @@ def parse_kwargs(**kwargs):

# `allocator`
kwargs['allocator'] = default_allocator(
'%s.%s.%s' % (kwargs['compiler'].name,
'%s.%s.%s' % (kwargs['compiler'].__class__.__name__,
kwargs['language'],
kwargs['platform'])
)
Expand Down
3 changes: 2 additions & 1 deletion devito/passes/clusters/derivatives.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from functools import singledispatch

from sympy import S
import numpy as np

from devito.finite_differences import IndexDerivative
from devito.ir import Backward, Forward, Interval, IterationSpace, Queue
Expand Down Expand Up @@ -157,7 +158,7 @@ def _(expr, c, ispace, weights, reusables, mapper, **kwargs):
# NOTE: created before recurring so that we ultimately get a sound ordering
try:
s = reusables.pop()
assert s.dtype is dtype
assert np.can_cast(s.dtype, dtype)
except KeyError:
name = sregistry.make_name(prefix='r')
s = Symbol(name=name, dtype=dtype)
Expand Down
6 changes: 3 additions & 3 deletions devito/symbolics/extended_dtypes.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import ctypes
import numpy as np

from devito.symbolics.extended_sympy import ReservedWord, Cast, CastStar, ValueLimit
from devito.symbolics.extended_sympy import ReservedWord, Cast, ValueLimit
from devito.tools import (Bunch, float2, float3, float4, double2, double3, double4, # noqa
int2, int3, int4, ctypes_vector_mapper)

Expand Down Expand Up @@ -64,7 +64,7 @@ class CustomType(ReservedWord):
cls = type(v.upper(), (Cast,), {'_base_typ': v})
globals()[cls.__name__] = cls

clsp = type('%sP' % v.upper(), (CastStar,), {'base': cls})
clsp = type('%sP' % v.upper(), (Cast,), {'base': cls})
globals()[clsp.__name__] = clsp


Expand All @@ -75,7 +75,7 @@ def no_dtype(kwargs):
def cast_mapper(arg):
try:
assert len(arg) == 2 and arg[1] == '*'
return lambda v, **kw: CastStar(v, dtype=arg[0], **no_dtype(kw))
return lambda v, **kw: Cast(v, dtype=arg[0], stars=arg[1], **no_dtype(kw))
except TypeError:
return lambda v, **kw: Cast(v, dtype=arg, **no_dtype(kw))

Expand Down
8 changes: 0 additions & 8 deletions devito/symbolics/extended_sympy.py
Original file line number Diff line number Diff line change
Expand Up @@ -768,14 +768,6 @@ def __str__(self):
__repr__ = __str__


# *** Casting

class CastStar:

def __new__(cls, base, dtype=None, ase=''):
return Cast(base, dtype=dtype, stars='*')


# Some other utility objects
Null = Macro('NULL')

Expand Down
2 changes: 1 addition & 1 deletion devito/symbolics/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,7 @@ def normalize_args(args):
for k, v in args.items():
try:
retval[k] = sympify(v, strict=True)
except SympifyError:
except (TypeError, SympifyError):
continue

return retval
Expand Down
6 changes: 5 additions & 1 deletion devito/tools/dtypes_lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
# NOTE: the following is inspired by pyopencl.cltypes

mapper = {
"half": np.float16,
"int": np.int32,
"float": np.float32,
"double": np.float64
Expand Down Expand Up @@ -189,7 +190,8 @@ def dtype_to_mpitype(dtype):
np.int32: 'MPI_INT',
np.float32: 'MPI_FLOAT',
np.int64: 'MPI_LONG',
np.float64: 'MPI_DOUBLE'
np.float64: 'MPI_DOUBLE',
np.float16: 'MPI_UNSIGNED_SHORT'
}[dtype]


Expand Down Expand Up @@ -222,6 +224,8 @@ class c_restrict_void_p(ctypes.c_void_p):

ctypes_vector_mapper = {}
for base_name, base_dtype in mapper.items():
if base_dtype is np.float16:
continue
base_ctype = dtype_to_ctype(base_dtype)

for count in counts:
Expand Down
6 changes: 5 additions & 1 deletion devito/types/dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -792,17 +792,21 @@ def _halo_exchange(self):
# Gather send data
data = self._data_in_region(OWNED, d, i)
sendbuf = np.ascontiguousarray(data)
if self.dtype == np.float16:
sendbuf = sendbuf.view(np.uint16)

# Setup recv buffer
shape = self._data_in_region(HALO, d, i.flip()).shape
recvbuf = np.ndarray(shape=shape, dtype=self.dtype)
if self.dtype == np.float16:
recvbuf = recvbuf.view(np.uint16)

# Communication
comm.Sendrecv(sendbuf, dest=dest, recvbuf=recvbuf, source=source)

# Scatter received data
if recvbuf is not None and source != MPI.PROC_NULL:
self._data_in_region(HALO, d, i.flip())[:] = recvbuf
self._data_in_region(HALO, d, i.flip())[:] = recvbuf.view(self.dtype)

self._is_halo_dirty = False

Expand Down

0 comments on commit d217047

Please sign in to comment.