Skip to content

Commit

Permalink
types: Fix dtype of FFP and edit sympy_dtype
Browse files Browse the repository at this point in the history
  • Loading branch information
ZoeLeibowitz committed Jan 7, 2025
1 parent 71e7eda commit edea3b7
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 8 deletions.
4 changes: 4 additions & 0 deletions devito/symbolics/extended_sympy.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,10 @@ def __str__(self):
def field(self):
return self.call

@property
def dtype(self):
return self.field.dtype

__repr__ = __str__


Expand Down
9 changes: 4 additions & 5 deletions devito/symbolics/inspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,9 +295,8 @@ def sympy_dtype(expr, base=None):
Infer the dtype of the expression.
"""
dtypes = {base} - {None}
for i in expr.free_symbols:
try:
dtypes.add(i.dtype)
except AttributeError:
pass
for i in expr.args:
dtype = getattr(i, 'dtype', None)
if dtype:
dtypes.add(dtype)
return infer_dtype(dtypes)
17 changes: 14 additions & 3 deletions tests/test_symbolics.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@
CallFromPointer, Cast, DefFunction, FieldFromPointer,
INT, FieldFromComposite, IntDiv, Namespace, Rvalue,
ReservedWord, ListInitializer, ccode, uxreplace,
retrieve_derivatives)
retrieve_derivatives, sympy_dtype)
from devito.tools import as_tuple
from devito.types import (Array, Bundle, FIndexed, LocalObject, Object,
Symbol as dSymbol)
Symbol as dSymbol, CompositeObject)
from devito.types.basic import AbstractSymbol


Expand Down Expand Up @@ -248,6 +248,17 @@ def test_field_from_pointer():
# Free symbols
assert ffp1.free_symbols == {s}

# Test dtype
f = dSymbol('f')
pfields = [(f._C_name, f._C_ctype)]
struct = CompositeObject('s1', 'myStruct', pfields)
ffp4 = FieldFromPointer(f, struct)
assert str(ffp4) == 's1->f'
assert ffp4.dtype == f.dtype
expr = 1/ffp4
dtype = sympy_dtype(expr)
assert dtype == f.dtype


def test_field_from_composite():
s = Symbol('s')
Expand Down Expand Up @@ -292,7 +303,7 @@ def test_extended_sympy_arithmetic():
# noncommutative
o = Object(name='o', dtype=c_void_p)
bar = FieldFromPointer('bar', o)
assert ccode(-1 + bar) == '-1 + o->bar'
assert ccode(-1 + bar) == 'o->bar - 1'


def test_integer_abs():
Expand Down

0 comments on commit edea3b7

Please sign in to comment.