Skip to content

Commit

Permalink
compiler: Check sympy_type returns a floating point type
Browse files Browse the repository at this point in the history
  • Loading branch information
JDBetteridge committed Jan 10, 2025
1 parent 71e7eda commit 457d716
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 1 deletion.
3 changes: 3 additions & 0 deletions devito/symbolics/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,10 @@ def compiler(self):
return self._settings['compiler']

def single_prec(self, expr=None):
# Extract the dtype of the expression
dtype = sympy_dtype(expr) if expr is not None else self.dtype
# Check that the dtype is a floating point type
dtype = dtype if np.issubdtype(dtype, np.floating) else self.dtype
return dtype in [np.float32, np.float16]

def parenthesize(self, item, level, strict=False):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_symbolics.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ def test_cos_vs_cosf():

# Doesn't make much sense, but it's legal
c = dSymbol('c', dtype=np.int32)
assert ccode(cos(c)) == "cos(c)"
assert ccode(cos(c)) == "cosf(c)"


def test_intdiv():
Expand Down

0 comments on commit 457d716

Please sign in to comment.