diff --git a/devito/passes/clusters/aliases.py b/devito/passes/clusters/aliases.py index 6b0e26bc2f..9c5ecde279 100644 --- a/devito/passes/clusters/aliases.py +++ b/devito/passes/clusters/aliases.py @@ -262,13 +262,11 @@ def callback(self, clusters, prefix, xtracted=None): if made: idx = processed.index(g[0]) - for n, c in enumerate(g, -len(g)): processed[processed.index(c)] = made.pop(n) + processed[idx:idx] = made xtracted.extend(made) - while made: - processed.insert(idx, made.pop(-1)) return processed diff --git a/devito/passes/clusters/misc.py b/devito/passes/clusters/misc.py index a6c4f36644..d581accb9c 100644 --- a/devito/passes/clusters/misc.py +++ b/devito/passes/clusters/misc.py @@ -97,7 +97,10 @@ def callback(self, clusters, prefix): # Lifted scalar clusters cannot be guarded # as they would not be in the scope of the guarded clusters - guards = {} if c.guards and c.is_scalar else c.guards + if c.is_scalar: + guards = {} + else: + guards = c.guards lifted.append(c.rebuild(ispace=ispace, properties=properties, guards=guards)) diff --git a/examples/performance/00_overview.ipynb b/examples/performance/00_overview.ipynb index 08fd3aa49c..f6374b577d 100644 --- a/examples/performance/00_overview.ipynb +++ b/examples/performance/00_overview.ipynb @@ -491,8 +491,6 @@ "name": "stdout", "output_type": "stream", "text": [ - "float r1 = 1.0F/h_y;\n", - "\n", "START(section0)\n", "#pragma omp parallel num_threads(nthreads)\n", "{\n", @@ -510,6 +508,8 @@ "}\n", "STOP(section0,timers)\n", "\n", + "float r1 = 1.0F/h_y;\n", + "\n", "for (int time = time_m, t0 = (time)%(2), t1 = (time + 1)%(2); time <= time_M; time += 1, t0 = (time)%(2), t1 = (time + 1)%(2))\n", "{\n", " START(section1)\n", @@ -1207,8 +1207,6 @@ " _MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_ON);\n", " _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON);\n", "\n", - " float r1 = 1.0F/h_y;\n", - "\n", " START(section0)\n", " #pragma omp parallel num_threads(nthreads)\n", " {\n", @@ -1227,6 +1225,8 @@ " }\n", " STOP(section0,timers)\n", "\n", + " float r1 = 1.0F/h_y;\n", + "\n", " for (int time = time_m, t0 = (time)%(2), t1 = (time + 1)%(2); time <= time_M; time += 1, t0 = (time)%(2), t1 = (time + 1)%(2))\n", " {\n", " START(section1)\n", @@ -1319,8 +1319,6 @@ "name": "stdout", "output_type": "stream", "text": [ - "float r1 = 1.0F/h_y;\n", - "\n", "START(section0)\n", "#pragma omp parallel num_threads(nthreads)\n", "{\n", @@ -1339,6 +1337,8 @@ "}\n", "STOP(section0,timers)\n", "\n", + "float r1 = 1.0F/h_y;\n", + "\n", "for (int time = time_m, t0 = (time)%(2), t1 = (time + 1)%(2); time <= time_M; time += 1, t0 = (time)%(2), t1 = (time + 1)%(2))\n", "{\n", " START(section1)\n", @@ -1495,8 +1495,6 @@ " _MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_ON);\n", " _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON);\n", "\n", - " float r1 = 1.0F/h_y;\n", - "\n", " START(section0)\n", " #pragma omp parallel num_threads(nthreads)\n", " {\n", @@ -1515,6 +1513,8 @@ " }\n", " STOP(section0,timers)\n", "\n", + " float r1 = 1.0F/h_y;\n", + "\n", " for (int time = time_m, t0 = (time)%(2), t1 = (time + 1)%(2); time <= time_M; time += 1, t0 = (time)%(2), t1 = (time + 1)%(2))\n", " {\n", " START(section1)\n", @@ -1633,9 +1633,6 @@ " _MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_ON);\n", " _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON);\n", "\n", - " float r1 = 1.0F/h_x;\n", - " float r2 = 1.0F/h_y;\n", - "\n", " START(section0)\n", " #pragma omp parallel num_threads(nthreads)\n", " {\n", @@ -1654,6 +1651,9 @@ " }\n", " STOP(section0,timers)\n", "\n", + " float r1 = 1.0F/h_x;\n", + " float r2 = 1.0F/h_y;\n", + "\n", " for (int time = time_m, t0 = (time)%(2), t1 = (time + 1)%(2); time <= time_M; time += 1, t0 = (time)%(2), t1 = (time + 1)%(2))\n", " {\n", " START(section1)\n", @@ -1730,8 +1730,22 @@ } ], "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, "language_info": { - "name": "python" + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.8" } }, "nbformat": 4, diff --git a/tests/test_dse.py b/tests/test_dse.py index 989e22929a..5ba871ec4d 100644 --- a/tests/test_dse.py +++ b/tests/test_dse.py @@ -11,7 +11,7 @@ SparseTimeFunction, Dimension, SubDimension, ConditionalDimension, DefaultDimension, Grid, Operator, norm, grad, div, dimensions, switchconfig, configuration, - first_derivative, solve, transpose, Abs, cos, + first_derivative, solve, transpose, Abs, cos, exp, sin, sqrt, floor, Ge, Lt, Derivative) from devito.exceptions import InvalidArgument, InvalidOperator from devito.ir import (Conditional, DummyEq, Expression, Iteration, FindNodes, @@ -327,6 +327,23 @@ def test_implicit_only(self): assert_structure(op, ['t,x,y', 't'], 'txy') assert trees[1].dimensions == [time] + def test_scalar_cond(self): + grid = Grid(shape=(5, 5)) + time = grid.time_dim + u = TimeFunction(name="u", grid=grid, time_order=1) + bt = ConditionalDimension(name="bt", parent=time, condition=Ge(time, 2)) + + W = (1 - exp(-(time - 5)/5)) + eqns = [Eq(u.forward, 1), + Eq(u.forward, u.forward * (1 - W) + W * u, implicit_dims=bt)] + op = Operator(eqns) + + trees = retrieve_iteration_tree(op) + + assert len(trees) == 2 + assert_structure(op, ['t', 't,x,y', 't,x,y'], 'txyxy') + assert trees[0].dimensions == [time] + class TestAliases: @@ -2108,8 +2125,8 @@ def test_sum_of_nested_derivatives(self, expr, exp_arrays, exp_ops): # Also check against expected operation count to make sure # all redundancies have been detected correctly - for i, exp in enumerate(as_tuple(exp_ops[n])): - assert summary[('section%d' % i, None)].ops == exp + for i, expected in enumerate(as_tuple(exp_ops[n])): + assert summary[('section%d' % i, None)].ops == expected def test_derivatives_from_different_levels(self): """