From 1bea726f9b2c22c09cf2bd068579337fa30b72f6 Mon Sep 17 00:00:00 2001 From: Edward Caunt Date: Thu, 1 Aug 2024 15:50:13 +0100 Subject: [PATCH 1/3] dsl: Fix missing sympy assumptions during rebuilding --- devito/types/basic.py | 16 +++++++++++++--- tests/test_caching.py | 11 +++++++++++ tests/test_pickle.py | 18 +++++++++++++++++- tests/test_symbolics.py | 28 ++++++++++++++++++++++++++++ 4 files changed, 69 insertions(+), 4 deletions(-) diff --git a/devito/types/basic.py b/devito/types/basic.py index e2859ea07e..228af6fcd3 100644 --- a/devito/types/basic.py +++ b/devito/types/basic.py @@ -332,19 +332,29 @@ class AbstractSymbol(sympy.Symbol, Basic, Pickable, Evaluable): is_imaginary = False is_commutative = True - __rkwargs__ = ('name', 'dtype', 'is_const') + __rkwargs__ = ('name', 'dtype', 'is_const', 'assumptions0') @classmethod def _filter_assumptions(cls, **kwargs): """Extract sympy.Symbol-specific kwargs.""" + assumptions0 = kwargs.get('assumptions0', {}) + assumptions = {} - # pop predefined assumptions + # Pop predefined assumptions for key in ('real', 'imaginary', 'commutative'): kwargs.pop(key, None) - # extract sympy.Symbol-specific kwargs + assumptions0.pop(key, None) + + # Extract sympy.Symbol-specific kwargs for i in list(kwargs): if i in _assume_rules.defined_facts: assumptions[i] = kwargs.pop(i) + + # Extract any remaining unset assumptions + for k, v in assumptions0.items(): + if k not in assumptions: + assumptions[k] = v + return assumptions, kwargs def __new__(cls, *args, **kwargs): diff --git a/tests/test_caching.py b/tests/test_caching.py index f4346706ea..8dca69fa60 100644 --- a/tests/test_caching.py +++ b/tests/test_caching.py @@ -11,6 +11,7 @@ TensorFunction, TensorTimeFunction, VectorTimeFunction) from devito.types import (DeviceID, NThreadsBase, NPThreads, Object, LocalObject, Scalar, Symbol, ThreadID) +from devito.types.basic import AbstractSymbol @pytest.fixture @@ -44,6 +45,16 @@ class TestHashing: Test hashing of symbolic objects. """ + def test_abstractsymbol(self): + """Test that different Symbols have different hash values.""" + s0 = AbstractSymbol('s') + s1 = AbstractSymbol('s') + assert s0 is not s1 + assert hash(s0) == hash(s1) + + s2 = AbstractSymbol('s', nonnegative=True) + assert hash(s0) != hash(s2) + def test_constant(self): """Test that different Constants have different hash value.""" c0 = Constant(name='c') diff --git a/tests/test_pickle.py b/tests/test_pickle.py index bf1b859a75..bb1ddb4027 100644 --- a/tests/test_pickle.py +++ b/tests/test_pickle.py @@ -18,7 +18,7 @@ PointerArray, Lock, PThreadArray, SharedData, Timer, DeviceID, NPThreads, ThreadID, TempFunction, Indirection, FIndexed) -from devito.types.basic import BoundSymbol +from devito.types.basic import BoundSymbol, AbstractSymbol from devito.tools import EnrichedTuple from devito.symbolics import (IntDiv, ListInitializer, FieldFromPointer, CallFromPointer, DefFunction) @@ -29,6 +29,22 @@ @pytest.mark.parametrize('pickle', [pickle0, pickle1]) class TestBasic: + def test_abstractsymbol(self, pickle): + s0 = AbstractSymbol('s') + s1 = AbstractSymbol('s', nonnegative=True, integer=False) + + pkl_s0 = pickle.dumps(s0) + pkl_s1 = pickle.dumps(s1) + + new_s0 = pickle.loads(pkl_s0) + new_s1 = pickle.loads(pkl_s1) + + assert s0.assumptions0 == new_s0.assumptions0 + assert s1.assumptions0 == new_s1.assumptions0 + + assert s0 == new_s0 + assert s1 == new_s1 + def test_constant(self, pickle): c = Constant(name='c') assert c.data == 0. diff --git a/tests/test_symbolics.py b/tests/test_symbolics.py index 9dd0c48584..3eafe219d7 100644 --- a/tests/test_symbolics.py +++ b/tests/test_symbolics.py @@ -17,6 +17,7 @@ from devito.tools import as_tuple from devito.types import (Array, Bundle, FIndexed, LocalObject, Object, Symbol as dSymbol) +from devito.types.basic import AbstractSymbol def test_float_indices(): @@ -70,6 +71,33 @@ def test_floatification_issue_1627(dtype, expected): assert str(exprs[0]) == expected +def test_sympy_assumptions(): + """ + Ensure that AbstractSymbol assumptions are set correctly and + preserved during rebuild. + """ + s0 = AbstractSymbol('s') + s1 = AbstractSymbol('s', nonnegative=True, integer=False, real=True) + + assert s0.is_negative is None + assert s0.is_positive is None + assert s0.is_integer is None + assert s0.is_real is True + assert s1.is_negative is False + assert s1.is_positive is True + assert s1.is_integer is False + assert s1.is_real is True + + s0r = s0._rebuild() + s1r = s1._rebuild() + + assert s0.assumptions0 == s0r.assumptions0 + assert s0 == s0r + + assert s1.assumptions0 == s1r.assumptions0 + assert s1 == s1r + + def test_constant(): c = Constant(name='c') From 881cf6054efd6311136d0fe6a4693d55935db070 Mon Sep 17 00:00:00 2001 From: Edward Caunt Date: Thu, 1 Aug 2024 19:18:00 +0100 Subject: [PATCH 2/3] dsl: Enhance rebuild machinery --- devito/tools/abc.py | 8 ++++++++ devito/types/basic.py | 11 +---------- tests/test_symbolics.py | 6 ++++++ 3 files changed, 15 insertions(+), 10 deletions(-) diff --git a/devito/tools/abc.py b/devito/tools/abc.py index b943256979..162b3287d3 100644 --- a/devito/tools/abc.py +++ b/devito/tools/abc.py @@ -143,6 +143,14 @@ def __init__(self, a, b, c=4): kwargs.update({i: getattr(self, i) for i in self.__rkwargs__ if i not in kwargs}) + # If this object has SymPy assumptions associated with it, which were not + # in the kwargs, then include them + try: + assumptions = self._assumptions_orig + kwargs.update({k: v for k, v in assumptions.items() if k not in kwargs}) + except AttributeError: + pass + # Should we use a custom reconstructor? try: cls = self._rcls diff --git a/devito/types/basic.py b/devito/types/basic.py index 228af6fcd3..d1dd3dcb93 100644 --- a/devito/types/basic.py +++ b/devito/types/basic.py @@ -332,29 +332,20 @@ class AbstractSymbol(sympy.Symbol, Basic, Pickable, Evaluable): is_imaginary = False is_commutative = True - __rkwargs__ = ('name', 'dtype', 'is_const', 'assumptions0') + __rkwargs__ = ('name', 'dtype', 'is_const') @classmethod def _filter_assumptions(cls, **kwargs): """Extract sympy.Symbol-specific kwargs.""" - assumptions0 = kwargs.get('assumptions0', {}) - assumptions = {} # Pop predefined assumptions for key in ('real', 'imaginary', 'commutative'): kwargs.pop(key, None) - assumptions0.pop(key, None) # Extract sympy.Symbol-specific kwargs for i in list(kwargs): if i in _assume_rules.defined_facts: assumptions[i] = kwargs.pop(i) - - # Extract any remaining unset assumptions - for k, v in assumptions0.items(): - if k not in assumptions: - assumptions[k] = v - return assumptions, kwargs def __new__(cls, *args, **kwargs): diff --git a/tests/test_symbolics.py b/tests/test_symbolics.py index 3eafe219d7..bc7bcd95b5 100644 --- a/tests/test_symbolics.py +++ b/tests/test_symbolics.py @@ -97,6 +97,12 @@ def test_sympy_assumptions(): assert s1.assumptions0 == s1r.assumptions0 assert s1 == s1r + # Check that sympy assumptions can be changed during a rebuild + s2 = s0._rebuild(nonnegative=True, integer=False, real=True) + + assert s2.assumptions0 == s1.assumptions0 + assert s2 == s1 + def test_constant(): c = Constant(name='c') From ac89150ca70881cb3a021a93422eef57b5035263 Mon Sep 17 00:00:00 2001 From: Edward Caunt Date: Thu, 1 Aug 2024 20:00:16 +0100 Subject: [PATCH 3/3] tests: Split assumptions test --- tests/test_symbolics.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/tests/test_symbolics.py b/tests/test_symbolics.py index bc7bcd95b5..353fdc934c 100644 --- a/tests/test_symbolics.py +++ b/tests/test_symbolics.py @@ -97,7 +97,14 @@ def test_sympy_assumptions(): assert s1.assumptions0 == s1r.assumptions0 assert s1 == s1r - # Check that sympy assumptions can be changed during a rebuild + +def test_modified_sympy_assumptions(): + """ + Check that sympy assumptions can be changed during a rebuild. + """ + s0 = AbstractSymbol('s') + s1 = AbstractSymbol('s', nonnegative=True, integer=False, real=True) + s2 = s0._rebuild(nonnegative=True, integer=False, real=True) assert s2.assumptions0 == s1.assumptions0