Skip to content

Commit

Permalink
IndexSum: Factor out common factors if possible
Browse files Browse the repository at this point in the history
  • Loading branch information
pbrubeck committed Jan 26, 2025
1 parent 36e0e5d commit 4713c06
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 8 deletions.
4 changes: 2 additions & 2 deletions test/test_apply_function_pullbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,8 @@ def test_apply_single_function_pullbacks_triangle3d():
vc: as_vector(Jinv[j, i] * rvc[j], i),
t: rt,
s: as_tensor([[rs[0], rs[1], rs[2]], [rs[1], rs[3], rs[4]], [rs[2], rs[4], rs[5]]]),
cov2t: as_tensor(Jinv[k, i] * rcov2t[k, l] * Jinv[l, j], (i, j)),
contra2t: as_tensor((1.0 / detJ) ** 2 * J[i, k] * rcontra2t[k, l] * J[j, l], (i, j)),
cov2t: as_tensor(Jinv[k, i] * (rcov2t[k, l] * Jinv[l, j]), (i, j)),
contra2t: (1.0 / detJ) ** 2 * as_tensor(J[i, k] * (rcontra2t[k, l] * J[j, l]), (i, j)),
# Mixed elements become a bit more complicated
uml2: as_vector([ruml2[0] / detJ, ruml2[1] / detJ]),
um: rum,
Expand Down
20 changes: 17 additions & 3 deletions ufl/indexsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#
# SPDX-License-Identifier: LGPL-3.0-or-later

from ufl.algebra import Product
from ufl.constantvalue import Zero
from ufl.core.expr import Expr, ufl_err_str
from ufl.core.multiindex import MultiIndex
Expand All @@ -21,7 +22,7 @@
class IndexSum(Operator):
"""Index sum."""

__slots__ = ("_dimension", "ufl_free_indices", "ufl_index_dimensions")
__slots__ = ("_dimension", "_initialised", "ufl_free_indices", "ufl_index_dimensions")

def __new__(cls, summand, index):
"""Create a new IndexSum."""
Expand All @@ -33,21 +34,33 @@ def __new__(cls, summand, index):
if len(index) != 1:
raise ValueError(f"Expecting a single Index but got {len(index)}.")

(j,) = index
# Simplification to zero
if isinstance(summand, Zero):
sh = summand.ufl_shape
(j,) = index
fi = summand.ufl_free_indices
fid = summand.ufl_index_dimensions
pos = fi.index(j.count())
fi = fi[:pos] + fi[pos + 1 :]
fid = fid[:pos] + fid[pos + 1 :]
return Zero(sh, fi, fid)

return Operator.__new__(cls)
# Factor out common factors
if isinstance(summand, Product):
a, b = summand.ufl_operands
if j.count() not in a.ufl_free_indices:
return Product(a, IndexSum(b, index))
elif j.count() not in b.ufl_free_indices:
return Product(b, IndexSum(a, index))

self = Operator.__new__(cls)
self._initialised = False
return self

def __init__(self, summand, index):
"""Initialise."""
if self._initialised:
return
(j,) = index
fi = summand.ufl_free_indices
fid = summand.ufl_index_dimensions
Expand All @@ -56,6 +69,7 @@ def __init__(self, summand, index):
self.ufl_free_indices = fi[:pos] + fi[pos + 1 :]
self.ufl_index_dimensions = fid[:pos] + fid[pos + 1 :]
Operator.__init__(self, (summand, index))
self._initialised = True

def index(self):
"""Get index."""
Expand Down
6 changes: 3 additions & 3 deletions ufl/pullback.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ def apply(self, expr):
# Apply transform "row-wise" to TensorElement(PiolaMapped, ...)
*k, i, j, m, n = indices(len(expr.ufl_shape) + 2)
kmn = (*k, m, n)
return as_tensor((1.0 / detJ) ** 2 * J[i, m] * expr[kmn] * J[j, n], (*k, i, j))
return (1.0 / detJ) ** 2 * as_tensor(J[i, m] * (expr[kmn] * J[j, n]), (*k, i, j))

def physical_value_shape(self, element, domain) -> typing.Tuple[int, ...]:
"""Get the physical value shape when this pull back is applied to an element on a domain.
Expand Down Expand Up @@ -313,7 +313,7 @@ def apply(self, expr):
# Apply transform "row-wise" to TensorElement(PiolaMapped, ...)
*k, i, j, m, n = indices(len(expr.ufl_shape) + 2)
kmn = (*k, m, n)
return as_tensor(K[m, i] * expr[kmn] * K[n, j], (*k, i, j))
return as_tensor(K[m, i] * (expr[kmn] * K[n, j]), (*k, i, j))

def physical_value_shape(self, element, domain) -> typing.Tuple[int, ...]:
"""Get the physical value shape when this pull back is applied to an element on a domain.
Expand Down Expand Up @@ -358,7 +358,7 @@ def apply(self, expr):
# Apply transform "row-wise" to TensorElement(PiolaMapped, ...)
*k, i, j, m, n = indices(len(expr.ufl_shape) + 2)
kmn = (*k, m, n)
return as_tensor((1.0 / detJ) * K[m, i] * expr[kmn] * J[j, n], (*k, i, j))
return (1.0 / detJ) * as_tensor(K[m, i] * (expr[kmn] * J[j, n]), (*k, i, j))

def physical_value_shape(self, element, domain) -> typing.Tuple[int, ...]:
"""Get the physical value shape when this pull back is applied to an element.
Expand Down

0 comments on commit 4713c06

Please sign in to comment.