From dd4d9ccc7fd05239ed3894b88bc0a6fb2bcddca0 Mon Sep 17 00:00:00 2001
From: mloubout
Date: Tue, 28 Jan 2025 10:40:21 -0500
Subject: [PATCH] compiler: add switch for static_cast vs reinterpret_cast

---
 devito/arch/compiler.py                | 8 +++++---
 devito/operator/operator.py            | 3 ++-
 devito/passes/iet/languages/CXX.py     | 3 ++-
 devito/passes/iet/languages/openacc.py | 4 ++--
 devito/symbolics/extended_sympy.py     | 9 +++++++--
 5 files changed, 18 insertions(+), 9 deletions(-)

diff --git a/devito/arch/compiler.py b/devito/arch/compiler.py
index 01cf9c3261..66b3c880d4 100644
--- a/devito/arch/compiler.py
+++ b/devito/arch/compiler.py
@@ -181,6 +181,8 @@ def __init__(self):
 
     fields = {'cc', 'ld'}
     default_cpp = False
+    _cxxstd = 'c++14'
+    _cstd = 'c99'
 
     def __init__(self, **kwargs):
         _name = kwargs.pop('name', self.__class__.__name__)
@@ -256,7 +258,7 @@ def version(self):
 
     @property
     def std(self):
-        return 'c++14' if self._cpp else 'c99'
+        return self._cxxstd if self._cpp else self._cstd
 
     def get_version(self):
         result, stdout, stderr = call_capture_output((self.cc, "--version"))
@@ -497,7 +499,7 @@ def __init_finalize__(self, **kwargs):
         language = kwargs.pop('language', configuration['language'])
         platform = kwargs.pop('platform', configuration['platform'])
 
-        if platform is NvidiaDevice:
+        if isinstance(platform, NvidiaDevice):
             self.cflags.remove(f'-std={self.std}')
             # Add flags for OpenMP offloading
             if language in ['C', 'openmp']:
@@ -565,7 +567,7 @@ def __init_finalize__(self, **kwargs):
         if not configuration['safe-math']:
             self.cflags.append('-ffast-math')
 
-        if platform is NvidiaDevice:
+        if isinstance(platform, NvidiaDevice):
             self.cflags.remove(f'-std={self.std}')
         elif platform is AMDGPUX:
             self.cflags.remove(f'-std={self.std}')
diff --git a/devito/operator/operator.py b/devito/operator/operator.py
index 2dd6eace02..0dc74d2c3e 100644
--- a/devito/operator/operator.py
+++ b/devito/operator/operator.py
@@ -1410,7 +1410,8 @@ def parse_kwargs(**kwargs):
         kwargs['compiler'] = configuration['compiler'].__new_with__()
 
     # Make sure compiler and language are compatible
-    if kwargs['compiler']._cpp and kwargs['language'] in ['C', 'openmp']:
+    if compiler is not None and kwargs['compiler']._cpp and \
+            kwargs['language'] in ['C', 'openmp']:
         kwargs['language'] = 'CXX' if kwargs['language'] == 'C' else 'CXXopenmp'
     if 'CXX' in kwargs['language'] and not kwargs['compiler']._cpp:
         kwargs['compiler'] = kwargs['compiler'].__new_with__(cpp=True)
diff --git a/devito/passes/iet/languages/CXX.py b/devito/passes/iet/languages/CXX.py
index 9089962b49..5554308a0c 100644
--- a/devito/passes/iet/languages/CXX.py
+++ b/devito/passes/iet/languages/CXX.py
@@ -91,5 +91,6 @@ def _print_Cast(self, expr):
         tstr = self._print(expr._C_ctype)
         if 'void' in tstr:
             return super()._print_Cast(expr)
-        cast = f'static_cast<{tstr}{self._print(expr.stars)}>'
+        caster = 'reinterpret_cast' if expr.reinterpret else 'static_cast'
+        cast = f'{caster}<{tstr}{self._print(expr.stars)}>'
         return self._print_UnaryOp(expr, op=cast, parenthesize=True)
diff --git a/devito/passes/iet/languages/openacc.py b/devito/passes/iet/languages/openacc.py
index 25d8e6e478..5d80428eb6 100644
--- a/devito/passes/iet/languages/openacc.py
+++ b/devito/passes/iet/languages/openacc.py
@@ -236,11 +236,11 @@ def place_devptr(self, iet, **kwargs):
 
             dpf = List(body=[
                 self.lang.mapper['map-serial-present'](hp, tdp),
-                Block(body=DummyExpr(tdp, cast_mapper(tdp.dtype)(hp)))
+                Block(body=DummyExpr(tdp, cast_mapper(tdp.dtype)(hp, reinterpret=True)))
             ])
 
             ffp = FieldFromPointer(f._C_field_dmap, f._C_symbol)
-            ctdp = cast_mapper((hp.dtype, '*'))(tdp)
+            ctdp = cast_mapper((hp.dtype, '*'))(tdp, reinterpret=True)
             cast = DummyExpr(ffp, ctdp)
             ret = Return(ctdp)
 
diff --git a/devito/symbolics/extended_sympy.py b/devito/symbolics/extended_sympy.py
index cb1d835c08..8281902e78 100644
--- a/devito/symbolics/extended_sympy.py
+++ b/devito/symbolics/extended_sympy.py
@@ -384,9 +384,9 @@ class Cast(UnaryOp):
     """
 
     __rargs__ = ('base', )
-    __rkwargs__ = ('dtype', 'stars')
+    __rkwargs__ = ('dtype', 'stars', 'reinterpret')
 
-    def __new__(cls, base, dtype=None, stars=None, **kwargs):
+    def __new__(cls, base, dtype=None, stars=None, reinterpret=False, **kwargs):
         try:
             if issubclass(dtype, np.generic) and sympify(base).is_Number:
                 base = sympify(dtype(base))
@@ -397,6 +397,7 @@ def __new__(cls, base, dtype=None, stars=None, **kwargs):
         obj = super().__new__(cls, base)
         obj._stars = stars or ''
         obj._dtype = dtype
+        obj._reinterpret = reinterpret
         return obj
 
     def _hashable_content(self):
@@ -412,6 +413,10 @@ def stars(self):
     def dtype(self):
         return self._dtype
 
+    @property
+    def reinterpret(self):
+        return self._reinterpret
+
     @property
     def _C_ctype(self):
         ctype = ctypes_vector_mapper.get(self.dtype, self.dtype)