diff --git a/test/test_apply_function_pullbacks.py b/test/test_apply_function_pullbacks.py index a3c009a77..7eec82325 100755 --- a/test/test_apply_function_pullbacks.py +++ b/test/test_apply_function_pullbacks.py @@ -151,8 +151,8 @@ def test_apply_single_function_pullbacks_triangle3d(): vc: as_vector(Jinv[j, i] * rvc[j], i), t: rt, s: as_tensor([[rs[0], rs[1], rs[2]], [rs[1], rs[3], rs[4]], [rs[2], rs[4], rs[5]]]), - cov2t: as_tensor(Jinv[k, i] * rcov2t[k, l] * Jinv[l, j], (i, j)), - contra2t: as_tensor((1.0 / detJ) ** 2 * J[i, k] * rcontra2t[k, l] * J[j, l], (i, j)), + cov2t: as_tensor(Jinv[k, i] * (rcov2t[k, l] * Jinv[l, j]), (i, j)), + contra2t: (1.0 / detJ) ** 2 * as_tensor(J[i, k] * (rcontra2t[k, l] * J[j, l]), (i, j)), # Mixed elements become a bit more complicated uml2: as_vector([ruml2[0] / detJ, ruml2[1] / detJ]), um: rum, diff --git a/ufl/indexsum.py b/ufl/indexsum.py index 357cd8dea..189c7715b 100644 --- a/ufl/indexsum.py +++ b/ufl/indexsum.py @@ -6,6 +6,7 @@ # # SPDX-License-Identifier: LGPL-3.0-or-later +from ufl.algebra import Product from ufl.constantvalue import Zero from ufl.core.expr import Expr, ufl_err_str from ufl.core.multiindex import MultiIndex @@ -21,7 +22,7 @@ class IndexSum(Operator): """Index sum.""" - __slots__ = ("_dimension", "ufl_free_indices", "ufl_index_dimensions") + __slots__ = ("_dimension", "_initialised", "ufl_free_indices", "ufl_index_dimensions") def __new__(cls, summand, index): """Create a new IndexSum.""" @@ -33,10 +34,10 @@ def __new__(cls, summand, index): if len(index) != 1: raise ValueError(f"Expecting a single Index but got {len(index)}.") + (j,) = index # Simplification to zero if isinstance(summand, Zero): sh = summand.ufl_shape - (j,) = index fi = summand.ufl_free_indices fid = summand.ufl_index_dimensions pos = fi.index(j.count()) @@ -44,10 +45,22 @@ def __new__(cls, summand, index): fid = fid[:pos] + fid[pos + 1 :] return Zero(sh, fi, fid) - return Operator.__new__(cls) + # Factor out common factors + if isinstance(summand, Product): + a, b = summand.ufl_operands + if j.count() not in a.ufl_free_indices: + return Product(a, IndexSum(b, index)) + elif j.count() not in b.ufl_free_indices: + return Product(b, IndexSum(a, index)) + + self = Operator.__new__(cls) + self._initialised = False + return self def __init__(self, summand, index): """Initialise.""" + if self._initialised: + return (j,) = index fi = summand.ufl_free_indices fid = summand.ufl_index_dimensions @@ -56,6 +69,7 @@ def __init__(self, summand, index): self.ufl_free_indices = fi[:pos] + fi[pos + 1 :] self.ufl_index_dimensions = fid[:pos] + fid[pos + 1 :] Operator.__init__(self, (summand, index)) + self._initialised = True def index(self): """Get index.""" diff --git a/ufl/pullback.py b/ufl/pullback.py index 0dc780e52..b71c010dc 100644 --- a/ufl/pullback.py +++ b/ufl/pullback.py @@ -270,7 +270,7 @@ def apply(self, expr): # Apply transform "row-wise" to TensorElement(PiolaMapped, ...) *k, i, j, m, n = indices(len(expr.ufl_shape) + 2) kmn = (*k, m, n) - return as_tensor((1.0 / detJ) ** 2 * J[i, m] * expr[kmn] * J[j, n], (*k, i, j)) + return (1.0 / detJ) ** 2 * as_tensor(J[i, m] * (expr[kmn] * J[j, n]), (*k, i, j)) def physical_value_shape(self, element, domain) -> typing.Tuple[int, ...]: """Get the physical value shape when this pull back is applied to an element on a domain. @@ -313,7 +313,7 @@ def apply(self, expr): # Apply transform "row-wise" to TensorElement(PiolaMapped, ...) *k, i, j, m, n = indices(len(expr.ufl_shape) + 2) kmn = (*k, m, n) - return as_tensor(K[m, i] * expr[kmn] * K[n, j], (*k, i, j)) + return as_tensor(K[m, i] * (expr[kmn] * K[n, j]), (*k, i, j)) def physical_value_shape(self, element, domain) -> typing.Tuple[int, ...]: """Get the physical value shape when this pull back is applied to an element on a domain. @@ -358,7 +358,7 @@ def apply(self, expr): # Apply transform "row-wise" to TensorElement(PiolaMapped, ...) *k, i, j, m, n = indices(len(expr.ufl_shape) + 2) kmn = (*k, m, n) - return as_tensor((1.0 / detJ) * K[m, i] * expr[kmn] * J[j, n], (*k, i, j)) + return (1.0 / detJ) * as_tensor(K[m, i] * (expr[kmn] * J[j, n]), (*k, i, j)) def physical_value_shape(self, element, domain) -> typing.Tuple[int, ...]: """Get the physical value shape when this pull back is applied to an element.