Skip to content

Commit

Permalink
symbolics: use std namespace for c++
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Jan 16, 2025
1 parent 3cf2f60 commit 5b0eaff
Show file tree
Hide file tree
Showing 6 changed files with 62 additions and 71 deletions.
8 changes: 8 additions & 0 deletions devito/passes/iet/languages/C.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import numpy as np

from devito.ir import Call
from devito.passes.iet.definitions import DataManager
from devito.passes.iet.orchestration import Orchestrator
Expand Down Expand Up @@ -42,3 +44,9 @@ class CDevitoPrinter(_DevitoPrinterBase):
type_mappings = {**_DevitoPrinterBase.type_mappings,
c_complex: 'float _Complex',
c_double_complex: 'double _Complex'}

_func_prefix = {**_DevitoPrinterBase._func_prefix, np.complex64: 'c',
np.complex128: 'c'}

def _print_ImaginaryUnit(self, expr):
return '_Complex_I'
7 changes: 6 additions & 1 deletion devito/passes/iet/languages/CXX.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,14 @@ class CXXDevitoPrinter(_DevitoPrinterBase, CXX11CodePrinter):

_default_settings = {**_DevitoPrinterBase._default_settings,
**CXX11CodePrinter._default_settings}
_ns = "std::"

# These cannot go through _print_xxx because they are classes not
# instances
type_mappings = {c_complex: 'std::complex<float>',
type_mappings = {**_DevitoPrinterBase.type_mappings,
c_complex: 'std::complex<float>',
c_double_complex: 'std::complex<float>',
**CXX11CodePrinter.type_mappings}

def _print_ImaginaryUnit(self, expr):
return f'1i{self.prec_literal(expr).lower()}'
8 changes: 5 additions & 3 deletions devito/symbolics/extended_sympy.py
Original file line number Diff line number Diff line change
Expand Up @@ -769,15 +769,17 @@ class SizeOf(DefFunction):
__rargs__ = ('intype', 'stars')

def __new__(cls, intype, stars=None, **kwargs):
newobj = super().__new__(cls, 'sizeof', arguments=[str(intype)], **kwargs)
stars = stars or ''
argument = Keyword(f'{intype}{stars}')
newobj = super().__new__(cls, 'sizeof', arguments=(argument,), **kwargs)
newobj.intype = intype
newobj.stars = stars or ''
newobj.stars = stars

return newobj

@property
def arguments(self):
return (self.intype,)
return self.args

@property
def args(self):
Expand Down
5 changes: 4 additions & 1 deletion devito/symbolics/inspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,9 @@ def sympy_dtype(expr, base=None):
# Promote if we missed complex number, i.e f + I
is_im = np.issubdtype(dtype, np.complexfloating)
if expr.has(ImaginaryUnit) and not is_im:
dtype = np.promote_types(dtype, np.complex64).type
if dtype is None:
dtype = np.complex64
else:
dtype = np.promote_types(dtype, np.complex64).type

return dtype
102 changes: 37 additions & 65 deletions devito/symbolics/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@
__all__ = ['ccode']


_prec_litterals = {np.float16: 'F16', np.float32: 'F', np.complex64: 'F'}
_func_litterals = {np.float32: 'f', np.complex64: 'f', Real: 'f'}


class _DevitoPrinterBase(C99CodePrinter):

"""
Expand All @@ -37,6 +41,8 @@ class _DevitoPrinterBase(C99CodePrinter):
_default_settings = {'compiler': None, 'dtype': np.float32,
**C99CodePrinter._default_settings}

_func_prefix = {np.float32: 'f', np.float64: 'f'}

@property
def dtype(self):
try:
Expand All @@ -55,25 +61,27 @@ def doprint(self, expr, assign_to=None):
"""
return self._print(expr)

def single_prec(self, expr=None, with_f=False):
no_f = self.compiler._cpp and not with_f
if no_f and expr is not None:
return False
dtype = sympy_dtype(expr) if expr is not None else self.dtype
return any(issubclass(dtype, d) for d in [np.float32, np.complex64])

def half_prec(self, expr=None, with_f=False):
no_f = self.compiler._cpp and not with_f
if no_f and expr is not None:
return False
dtype = sympy_dtype(expr) if expr is not None else self.dtype
return issubclass(dtype, np.float16)

def complex_prec(self, expr=None):
if self.compiler._cpp:
return False
dtype = sympy_dtype(expr) if expr is not None else self.dtype
return np.issubdtype(dtype, np.complexfloating)
def print_type(self, expr):
dtype = sympy_dtype(expr) if expr is not None else None
if dtype is None or np.issubdtype(dtype, np.integer):
real = any(isinstance(i, Float) for i in expr.atoms())
stype = np.float64 if real else np.int64
return np.result_type(dtype, stype).type
else:
return dtype or self.dtype

def prec_literal(self, expr):
return _prec_litterals.get(self.print_type(expr), '')

def func_literal(self, expr):
return _func_litterals.get(self.print_type(expr), '')

def func_prefix(self, expr, abs=False):
prefix = self._func_prefix.get(self.print_type(expr), '')
if abs:
return prefix
else:
return '' if prefix == 'f' else prefix

def parenthesize(self, item, level, strict=False):
if isinstance(item, BooleanFunction):
Expand Down Expand Up @@ -147,35 +155,33 @@ def _print_math_func(self, expr, nest=False, known=None):
if cname not in self._prec_funcs:
return super()._print_math_func(expr, nest=nest, known=known)

if self.single_prec(expr) or self.half_prec(expr):
cname = '%sf' % cname
if self.complex_prec(expr):
cname = 'c%s' % cname
cname = f'{self.func_prefix(expr)}{cname}{self.func_literal(expr)}'

if nest and len(expr.args) > 2:
args = ', '.join([self._print(expr.args[0]),
self._print_math_func(cls(*expr.args[1:]))])
else:
args = ', '.join([self._print(arg) for arg in expr.args])

return f'{cname}({args})'
return f'{self._ns}{cname}({args})'

def _print_Pow(self, expr):
# Completely reimplement `_print_Pow` from sympy, since it doesn't
# correctly handle precision
if "Pow" in self.known_functions:
return self._print_Function(expr)
PREC = precedence(expr)
suffix = 'f' if self.single_prec(expr) else ''
suffix = self.func_literal(expr)
base = self._print(expr.base)
if equal_valued(expr.exp, -1):
return self._print_Float(Float(1.0)) + '/' + \
self.parenthesize(expr.base, PREC)
elif equal_valued(expr.exp, 0.5):
return f'sqrt{suffix}({self._print(expr.base)})'
return f'{self._ns}sqrt{suffix}({base})'
elif expr.exp == S.One/3 and self.standard != 'C89':
return f'cbrt{suffix}({self._print(expr.base)})'
return f'{self._ns}cbrt{suffix}({base})'
else:
return f'pow{suffix}({self._print(expr.base)}, {self._print(expr.exp)})'
return f'{self._ns}pow{suffix}({base}, {self._print(expr.exp)})'

def _print_Mod(self, expr):
"""Print a Mod as a C-like %-based operation."""
Expand Down Expand Up @@ -203,18 +209,8 @@ def _print_Abs(self, expr):
# AOMPCC errors with abs, always use fabs
if isinstance(self.compiler, AOMPCompiler):
return "fabs(%s)" % self._print(expr.args[0])
# Check if argument is an integer
if has_integer_args(*expr.args[0].args):
func = "abs"
elif self.single_prec(expr):
func = "fabsf"
elif any([isinstance(a, Real) for a in expr.args[0].args]):
# The previous condition isn't sufficient to detect case with
# Python `float`s in that case, fall back to the "default"
func = "fabsf" if self.single_prec() else "fabs"
else:
func = "fabs"
return f"{func}({self._print(expr.args[0])})"
func = f'{self.func_prefix(expr, abs=True)}abs{self.func_literal(expr)}'
return f"{self._ns}{func}({self._print(expr.args[0])})"

def _print_Add(self, expr, order=None):
""""
Expand Down Expand Up @@ -264,21 +260,7 @@ def _print_Float(self, expr):
if 'e' not in rv:
rv = rv.rstrip('0') + "0"

if self.single_prec():
rv = '%sF' % rv
elif self.half_prec():
rv = '%sF16' % rv

return rv

def _print_ImaginaryUnit(self, expr):
if self.compiler._cpp:
if self.single_prec(with_f=True) or self.half_prec(with_f=True):
return '1if'
else:
return '1i'
else:
return '_Complex_I'
return f'{rv}{self.prec_literal(expr)}'

def _print_Differentiable(self, expr):
return "(%s)" % self._print(expr._expr)
Expand Down Expand Up @@ -329,16 +311,6 @@ def _print_UnaryOp(self, expr):
def _print_ComponentAccess(self, expr):
return "%s.%s" % (self._print(expr.base), expr.sindex)

def _print_TrigonometricFunction(self, expr):
func_name = str(expr.func)

if self.single_prec() or self.half_prec():
func_name = '%sf' % func_name
if self.complex_prec():
func_name = 'c%s' % func_name

return '%s(%s)' % (func_name, self._print(*expr.args))

def _print_DefFunction(self, expr):
arguments = [self._print(i) for i in expr.arguments]
if expr.template:
Expand Down
3 changes: 2 additions & 1 deletion tests/test_symbolics.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,8 @@ def test_extended_sympy_arithmetic():
def test_integer_abs():
i1 = Dimension(name="i1")
assert ccode(Abs(i1 - 1)) == "abs(i1 - 1)"
assert ccode(Abs(i1 - .5)) == "fabsf(i1 - 5.0e-1F)"
# .5 is a standard python Float, i.e an np.float64
assert ccode(Abs(i1 - .5)) == "fabs(i1 - 5.0e-1)"
assert ccode(
Abs(i1 - Constant('half', dtype=np.float64, default_value=0.5))
) == "fabs(i1 - half)"
Expand Down

0 comments on commit 5b0eaff

Please sign in to comment.