Skip to content

Commit

Permalink
compiler: add switch for static_cast vs reinterpret_cast
Browse files: browse the repository at this point in the history
  • Loading branch information
mloubout committed Jan 30, 2025
1 parent 5972958 commit dd4d9cc
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 9 deletions.
8 changes: 5 additions & 3 deletions devito/arch/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,8 @@ def __init__(self):

fields = {'cc', 'ld'}
default_cpp = False
_cxxstd = 'c++14'
_cstd = 'c99'

def __init__(self, **kwargs):
_name = kwargs.pop('name', self.__class__.__name__)
Expand Down Expand Up @@ -256,7 +258,7 @@ def version(self):

@property
def std(self):
return 'c++14' if self._cpp else 'c99'
return self._cxxstd if self._cpp else self._cstd

def get_version(self):
result, stdout, stderr = call_capture_output((self.cc, "--version"))
Expand Down Expand Up @@ -497,7 +499,7 @@ def __init_finalize__(self, **kwargs):
language = kwargs.pop('language', configuration['language'])
platform = kwargs.pop('platform', configuration['platform'])

if platform is NvidiaDevice:
if isinstance(platform, NvidiaDevice):
self.cflags.remove(f'-std={self.std}')
# Add flags for OpenMP offloading
if language in ['C', 'openmp']:
Expand Down Expand Up @@ -565,7 +567,7 @@ def __init_finalize__(self, **kwargs):
if not configuration['safe-math']:
self.cflags.append('-ffast-math')

if platform is NvidiaDevice:
if isinstance(platform, NvidiaDevice):
self.cflags.remove(f'-std={self.std}')
elif platform is AMDGPUX:
self.cflags.remove(f'-std={self.std}')
Expand Down
3 changes: 2 additions & 1 deletion devito/operator/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1410,7 +1410,8 @@ def parse_kwargs(**kwargs):
kwargs['compiler'] = configuration['compiler'].__new_with__()

# Make sure compiler and language are compatible
if kwargs['compiler']._cpp and kwargs['language'] in ['C', 'openmp']:
if compiler is not None and kwargs['compiler']._cpp and \
kwargs['language'] in ['C', 'openmp']:
kwargs['language'] = 'CXX' if kwargs['language'] == 'C' else 'CXXopenmp'
if 'CXX' in kwargs['language'] and not kwargs['compiler']._cpp:
kwargs['compiler'] = kwargs['compiler'].__new_with__(cpp=True)
Expand Down
3 changes: 2 additions & 1 deletion devito/passes/iet/languages/CXX.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,5 +91,6 @@ def _print_Cast(self, expr):
tstr = self._print(expr._C_ctype)
if 'void' in tstr:
return super()._print_Cast(expr)
cast = f'static_cast<{tstr}{self._print(expr.stars)}>'
caster = 'reinterpret_cast' if expr.reinterpret else 'static_cast'
cast = f'{caster}<{tstr}{self._print(expr.stars)}>'
return self._print_UnaryOp(expr, op=cast, parenthesize=True)
4 changes: 2 additions & 2 deletions devito/passes/iet/languages/openacc.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,11 +236,11 @@ def place_devptr(self, iet, **kwargs):

dpf = List(body=[
self.lang.mapper['map-serial-present'](hp, tdp),
Block(body=DummyExpr(tdp, cast_mapper(tdp.dtype)(hp)))
Block(body=DummyExpr(tdp, cast_mapper(tdp.dtype)(hp, reinterpret=True)))
])

ffp = FieldFromPointer(f._C_field_dmap, f._C_symbol)
ctdp = cast_mapper((hp.dtype, '*'))(tdp)
ctdp = cast_mapper((hp.dtype, '*'))(tdp, reinterpret=True)
cast = DummyExpr(ffp, ctdp)

ret = Return(ctdp)
Expand Down
9 changes: 7 additions & 2 deletions devito/symbolics/extended_sympy.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,9 +384,9 @@ class Cast(UnaryOp):
"""

__rargs__ = ('base', )
__rkwargs__ = ('dtype', 'stars')
__rkwargs__ = ('dtype', 'stars', 'reinterpret')

def __new__(cls, base, dtype=None, stars=None, **kwargs):
def __new__(cls, base, dtype=None, stars=None, reinterpret=False, **kwargs):
try:
if issubclass(dtype, np.generic) and sympify(base).is_Number:
base = sympify(dtype(base))
Expand All @@ -397,6 +397,7 @@ def __new__(cls, base, dtype=None, stars=None, **kwargs):
obj = super().__new__(cls, base)
obj._stars = stars or ''
obj._dtype = dtype
obj._reinterpret = reinterpret
return obj

def _hashable_content(self):
Expand All @@ -412,6 +413,10 @@ def stars(self):
def dtype(self):
return self._dtype

@property
def reinterpret(self):
return self._reinterpret

@property
def _C_ctype(self):
ctype = ctypes_vector_mapper.get(self.dtype, self.dtype)
Expand Down

0 comments on commit dd4d9cc

Please sign in to comment.