From edea3b73c22f0d2bae997fc109ce915da9508a82 Mon Sep 17 00:00:00 2001
From: ZoeLeibowitz <zoeleibowitz12@gmail.com>
Date: Tue, 7 Jan 2025 14:02:02 +0000
Subject: [PATCH] types: Fix dtype of FFP and edit sympy_dtype

---
 devito/symbolics/extended_sympy.py |  4 ++++
 devito/symbolics/inspection.py     |  9 ++++-----
 tests/test_symbolics.py            | 17 ++++++++++++++---
 3 files changed, 22 insertions(+), 8 deletions(-)

diff --git a/devito/symbolics/extended_sympy.py b/devito/symbolics/extended_sympy.py
index 4087bbc72c..a31fc9986e 100644
--- a/devito/symbolics/extended_sympy.py
+++ b/devito/symbolics/extended_sympy.py
@@ -252,6 +252,10 @@ def __str__(self):
     def field(self):
         return self.call
 
+    @property
+    def dtype(self):
+        return self.field.dtype
+
     __repr__ = __str__
 
 
diff --git a/devito/symbolics/inspection.py b/devito/symbolics/inspection.py
index 437d48fff0..411faee26c 100644
--- a/devito/symbolics/inspection.py
+++ b/devito/symbolics/inspection.py
@@ -295,9 +295,8 @@ def sympy_dtype(expr, base=None):
     Infer the dtype of the expression.
     """
     dtypes = {base} - {None}
-    for i in expr.free_symbols:
-        try:
-            dtypes.add(i.dtype)
-        except AttributeError:
-            pass
+    for i in expr.args:
+        dtype = getattr(i, 'dtype', None)
+        if dtype:
+            dtypes.add(dtype)
     return infer_dtype(dtypes)
diff --git a/tests/test_symbolics.py b/tests/test_symbolics.py
index 7beb0c0b97..61fb0daef0 100644
--- a/tests/test_symbolics.py
+++ b/tests/test_symbolics.py
@@ -13,10 +13,10 @@
                               CallFromPointer, Cast, DefFunction, FieldFromPointer,
                               INT, FieldFromComposite, IntDiv, Namespace, Rvalue,
                               ReservedWord, ListInitializer, ccode, uxreplace,
-                              retrieve_derivatives)
+                              retrieve_derivatives, sympy_dtype)
 from devito.tools import as_tuple
 from devito.types import (Array, Bundle, FIndexed, LocalObject, Object,
-                          Symbol as dSymbol)
+                          Symbol as dSymbol, CompositeObject)
 from devito.types.basic import AbstractSymbol
 
 
@@ -248,6 +248,17 @@ def test_field_from_pointer():
     # Free symbols
     assert ffp1.free_symbols == {s}
 
+    # Test dtype
+    f = dSymbol('f')
+    pfields = [(f._C_name, f._C_ctype)]
+    struct = CompositeObject('s1', 'myStruct', pfields)
+    ffp4 = FieldFromPointer(f, struct)
+    assert str(ffp4) == 's1->f'
+    assert ffp4.dtype == f.dtype
+    expr = 1/ffp4
+    dtype = sympy_dtype(expr)
+    assert dtype == f.dtype
+
 
 def test_field_from_composite():
     s = Symbol('s')
@@ -292,7 +303,7 @@ def test_extended_sympy_arithmetic():
     # noncommutative
     o = Object(name='o', dtype=c_void_p)
     bar = FieldFromPointer('bar', o)
-    assert ccode(-1 + bar) == '-1 + o->bar'
+    assert ccode(-1 + bar) == 'o->bar - 1'
 
 
 def test_integer_abs():