diff --git a/devito/passes/clusters/misc.py b/devito/passes/clusters/misc.py index d581accb9c..5acc0b78be 100644 --- a/devito/passes/clusters/misc.py +++ b/devito/passes/clusters/misc.py @@ -97,7 +97,8 @@ def callback(self, clusters, prefix): # Lifted scalar clusters cannot be guarded # as they would not be in the scope of the guarded clusters - if c.is_scalar: + # unless the guard is for an outer dimension + if c.is_scalar and not (prefix[:-1] and c.guards): guards = {} else: guards = c.guards diff --git a/tests/test_dse.py b/tests/test_dse.py index 98af49c27a..bde9c1b27e 100644 --- a/tests/test_dse.py +++ b/tests/test_dse.py @@ -1218,7 +1218,7 @@ def test_catch_best_invariant_v2(self): assert len(arrays) == 4 exprs = FindNodes(Expression).visit(op) - sqrt_exprs = exprs[2:4] + sqrt_exprs = exprs[:2] assert all(e.write in arrays for e in sqrt_exprs) assert all(e.expr.rhs.is_Pow for e in sqrt_exprs) assert all(e.write._mem_heap and not e.write._mem_external for e in sqrt_exprs)