From 5b0eaff1e293c284e345faaefc0da468c8ef153b Mon Sep 17 00:00:00 2001 From: mloubout Date: Thu, 16 Jan 2025 13:57:35 -0500 Subject: [PATCH] symbolics: use std namespace for c++ --- devito/passes/iet/languages/C.py | 8 +++ devito/passes/iet/languages/CXX.py | 7 +- devito/symbolics/extended_sympy.py | 8 ++- devito/symbolics/inspection.py | 5 +- devito/symbolics/printer.py | 102 +++++++++++------------------ tests/test_symbolics.py | 3 +- 6 files changed, 62 insertions(+), 71 deletions(-) diff --git a/devito/passes/iet/languages/C.py b/devito/passes/iet/languages/C.py index 7efdaa44ff..4285a673e1 100644 --- a/devito/passes/iet/languages/C.py +++ b/devito/passes/iet/languages/C.py @@ -1,3 +1,5 @@ +import numpy as np + from devito.ir import Call from devito.passes.iet.definitions import DataManager from devito.passes.iet.orchestration import Orchestrator @@ -42,3 +44,9 @@ class CDevitoPrinter(_DevitoPrinterBase): type_mappings = {**_DevitoPrinterBase.type_mappings, c_complex: 'float _Complex', c_double_complex: 'double _Complex'} + + _func_prefix = {**_DevitoPrinterBase._func_prefix, np.complex64: 'c', + np.complex128: 'c'} + + def _print_ImaginaryUnit(self, expr): + return '_Complex_I' diff --git a/devito/passes/iet/languages/CXX.py b/devito/passes/iet/languages/CXX.py index aa9e9118de..b261c89213 100644 --- a/devito/passes/iet/languages/CXX.py +++ b/devito/passes/iet/languages/CXX.py @@ -70,9 +70,14 @@ class CXXDevitoPrinter(_DevitoPrinterBase, CXX11CodePrinter): _default_settings = {**_DevitoPrinterBase._default_settings, **CXX11CodePrinter._default_settings} + _ns = "std::" # These cannot go through _print_xxx because they are classes not # instances - type_mappings = {c_complex: 'std::complex', + type_mappings = {**_DevitoPrinterBase.type_mappings, + c_complex: 'std::complex', c_double_complex: 'std::complex', **CXX11CodePrinter.type_mappings} + + def _print_ImaginaryUnit(self, expr): + return f'1i{self.prec_literal(expr).lower()}' diff --git a/devito/symbolics/extended_sympy.py b/devito/symbolics/extended_sympy.py index 2908c8f12c..127995d3c0 100644 --- a/devito/symbolics/extended_sympy.py +++ b/devito/symbolics/extended_sympy.py @@ -769,15 +769,17 @@ class SizeOf(DefFunction): __rargs__ = ('intype', 'stars') def __new__(cls, intype, stars=None, **kwargs): - newobj = super().__new__(cls, 'sizeof', arguments=[str(intype)], **kwargs) + stars = stars or '' + argument = Keyword(f'{intype}{stars}') + newobj = super().__new__(cls, 'sizeof', arguments=(argument,), **kwargs) newobj.intype = intype - newobj.stars = stars or '' + newobj.stars = stars return newobj @property def arguments(self): - return (self.intype,) + return self.args @property def args(self): diff --git a/devito/symbolics/inspection.py b/devito/symbolics/inspection.py index 11b95a16d3..e7f497a9d5 100644 --- a/devito/symbolics/inspection.py +++ b/devito/symbolics/inspection.py @@ -311,6 +311,9 @@ def sympy_dtype(expr, base=None): # Promote if we missed complex number, i.e f + I is_im = np.issubdtype(dtype, np.complexfloating) if expr.has(ImaginaryUnit) and not is_im: - dtype = np.promote_types(dtype, np.complex64).type + if dtype is None: + dtype = np.complex64 + else: + dtype = np.promote_types(dtype, np.complex64).type return dtype diff --git a/devito/symbolics/printer.py b/devito/symbolics/printer.py index 91711c4f27..f5e4b1a6e4 100644 --- a/devito/symbolics/printer.py +++ b/devito/symbolics/printer.py @@ -24,6 +24,10 @@ __all__ = ['ccode'] +_prec_litterals = {np.float16: 'F16', np.float32: 'F', np.complex64: 'F'} +_func_litterals = {np.float32: 'f', np.complex64: 'f', Real: 'f'} + + class _DevitoPrinterBase(C99CodePrinter): """ @@ -37,6 +41,8 @@ class _DevitoPrinterBase(C99CodePrinter): _default_settings = {'compiler': None, 'dtype': np.float32, **C99CodePrinter._default_settings} + _func_prefix = {np.float32: 'f', np.float64: 'f'} + @property def dtype(self): try: @@ -55,25 +61,27 @@ def doprint(self, expr, assign_to=None): """ return self._print(expr) - def single_prec(self, expr=None, with_f=False): - no_f = self.compiler._cpp and not with_f - if no_f and expr is not None: - return False - dtype = sympy_dtype(expr) if expr is not None else self.dtype - return any(issubclass(dtype, d) for d in [np.float32, np.complex64]) - - def half_prec(self, expr=None, with_f=False): - no_f = self.compiler._cpp and not with_f - if no_f and expr is not None: - return False - dtype = sympy_dtype(expr) if expr is not None else self.dtype - return issubclass(dtype, np.float16) - - def complex_prec(self, expr=None): - if self.compiler._cpp: - return False - dtype = sympy_dtype(expr) if expr is not None else self.dtype - return np.issubdtype(dtype, np.complexfloating) + def print_type(self, expr): + dtype = sympy_dtype(expr) if expr is not None else None + if dtype is None or np.issubdtype(dtype, np.integer): + real = any(isinstance(i, Float) for i in expr.atoms()) + stype = np.float64 if real else np.int64 + return np.result_type(dtype, stype).type + else: + return dtype or self.dtype + + def prec_literal(self, expr): + return _prec_litterals.get(self.print_type(expr), '') + + def func_literal(self, expr): + return _func_litterals.get(self.print_type(expr), '') + + def func_prefix(self, expr, abs=False): + prefix = self._func_prefix.get(self.print_type(expr), '') + if abs: + return prefix + else: + return '' if prefix == 'f' else prefix def parenthesize(self, item, level, strict=False): if isinstance(item, BooleanFunction): @@ -147,10 +155,7 @@ def _print_math_func(self, expr, nest=False, known=None): if cname not in self._prec_funcs: return super()._print_math_func(expr, nest=nest, known=known) - if self.single_prec(expr) or self.half_prec(expr): - cname = '%sf' % cname - if self.complex_prec(expr): - cname = 'c%s' % cname + cname = f'{self.func_prefix(expr)}{cname}{self.func_literal(expr)}' if nest and len(expr.args) > 2: args = ', '.join([self._print(expr.args[0]), @@ -158,7 +163,7 @@ def _print_math_func(self, expr, nest=False, known=None): else: args = ', '.join([self._print(arg) for arg in expr.args]) - return f'{cname}({args})' + return f'{self._ns}{cname}({args})' def _print_Pow(self, expr): # Completely reimplement `_print_Pow` from sympy, since it doesn't @@ -166,16 +171,17 @@ def _print_Pow(self, expr): if "Pow" in self.known_functions: return self._print_Function(expr) PREC = precedence(expr) - suffix = 'f' if self.single_prec(expr) else '' + suffix = self.func_literal(expr) + base = self._print(expr.base) if equal_valued(expr.exp, -1): return self._print_Float(Float(1.0)) + '/' + \ self.parenthesize(expr.base, PREC) elif equal_valued(expr.exp, 0.5): - return f'sqrt{suffix}({self._print(expr.base)})' + return f'{self._ns}sqrt{suffix}({base})' elif expr.exp == S.One/3 and self.standard != 'C89': - return f'cbrt{suffix}({self._print(expr.base)})' + return f'{self._ns}cbrt{suffix}({base})' else: - return f'pow{suffix}({self._print(expr.base)}, {self._print(expr.exp)})' + return f'{self._ns}pow{suffix}({base}, {self._print(expr.exp)})' def _print_Mod(self, expr): """Print a Mod as a C-like %-based operation.""" @@ -203,18 +209,8 @@ def _print_Abs(self, expr): # AOMPCC errors with abs, always use fabs if isinstance(self.compiler, AOMPCompiler): return "fabs(%s)" % self._print(expr.args[0]) - # Check if argument is an integer - if has_integer_args(*expr.args[0].args): - func = "abs" - elif self.single_prec(expr): - func = "fabsf" - elif any([isinstance(a, Real) for a in expr.args[0].args]): - # The previous condition isn't sufficient to detect case with - # Python `float`s in that case, fall back to the "default" - func = "fabsf" if self.single_prec() else "fabs" - else: - func = "fabs" - return f"{func}({self._print(expr.args[0])})" + func = f'{self.func_prefix(expr, abs=True)}abs{self.func_literal(expr)}' + return f"{self._ns}{func}({self._print(expr.args[0])})" def _print_Add(self, expr, order=None): """" @@ -264,21 +260,7 @@ def _print_Float(self, expr): if 'e' not in rv: rv = rv.rstrip('0') + "0" - if self.single_prec(): - rv = '%sF' % rv - elif self.half_prec(): - rv = '%sF16' % rv - - return rv - - def _print_ImaginaryUnit(self, expr): - if self.compiler._cpp: - if self.single_prec(with_f=True) or self.half_prec(with_f=True): - return '1if' - else: - return '1i' - else: - return '_Complex_I' + return f'{rv}{self.prec_literal(expr)}' def _print_Differentiable(self, expr): return "(%s)" % self._print(expr._expr) @@ -329,16 +311,6 @@ def _print_UnaryOp(self, expr): def _print_ComponentAccess(self, expr): return "%s.%s" % (self._print(expr.base), expr.sindex) - def _print_TrigonometricFunction(self, expr): - func_name = str(expr.func) - - if self.single_prec() or self.half_prec(): - func_name = '%sf' % func_name - if self.complex_prec(): - func_name = 'c%s' % func_name - - return '%s(%s)' % (func_name, self._print(*expr.args)) - def _print_DefFunction(self, expr): arguments = [self._print(i) for i in expr.arguments] if expr.template: diff --git a/tests/test_symbolics.py b/tests/test_symbolics.py index 7beb0c0b97..66a7b3b28c 100644 --- a/tests/test_symbolics.py +++ b/tests/test_symbolics.py @@ -298,7 +298,8 @@ def test_extended_sympy_arithmetic(): def test_integer_abs(): i1 = Dimension(name="i1") assert ccode(Abs(i1 - 1)) == "abs(i1 - 1)" - assert ccode(Abs(i1 - .5)) == "fabsf(i1 - 5.0e-1F)" + # .5 is a standard python Float, i.e an np.float64 + assert ccode(Abs(i1 - .5)) == "fabs(i1 - 5.0e-1)" assert ccode( Abs(i1 - Constant('half', dtype=np.float64, default_value=0.5)) ) == "fabs(i1 - half)"