Skip to content

Commit

Permalink
Fix hashing for interface kwarg
Browse files Browse the repository at this point in the history
  • Loading branch information
connorjward committed Feb 6, 2025
1 parent 6592cdf commit 2507977
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 20 deletions.
17 changes: 1 addition & 16 deletions firedrake/interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,14 @@
import ufl
import finat.ufl
from ufl.algorithms import extract_arguments, extract_coefficients, replace
from ufl.algorithms.signature import compute_expression_signature
from ufl.domain import as_domain, extract_unique_domain

from pyop2 import op2
from pyop2.caching import memory_and_disk_cache

from finat.element_factory import create_element, as_fiat_cell
from tsfc import compile_expression_dual_evaluation
from tsfc.ufl_utils import extract_firedrake_constants
from tsfc.ufl_utils import extract_firedrake_constants, hash_expr

import gem
import finat
Expand Down Expand Up @@ -1452,20 +1451,6 @@ def __init__(self, glob):
self.ufl_domain = lambda: None


def hash_expr(expr):
"""Return a numbering-invariant hash of a UFL expression.
:arg expr: A UFL expression.
:returns: A numbering-invariant hash for the expression.
"""
domain_numbering = {d: i for i, d in enumerate(ufl.domain.extract_domains(expr))}
coefficient_numbering = {c: i for i, c in enumerate(extract_coefficients(expr))}
constant_numbering = {c: i for i, c in enumerate(extract_firedrake_constants(expr))}
return compute_expression_signature(
expr, {**domain_numbering, **coefficient_numbering, **constant_numbering}
)


class VomOntoVomWrapper(object):
"""Utility class for interpolating from one ``VertexOnlyMesh`` to it's
intput ordering ``VertexOnlyMesh``, or vice versa.
Expand Down
19 changes: 16 additions & 3 deletions firedrake/tsfc_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,17 @@
from os import path, environ, getuid, makedirs
import tempfile
import collections
import functools

from tsfc.kernel_interface.firedrake_loopy import KernelBuilder
import ufl
import finat.ufl
from ufl import conj, Form, ZeroBaseForm
from .ufl_expr import TestFunction

from tsfc import compile_form as original_tsfc_compile_form
from tsfc.parameters import PARAMETERS as tsfc_default_parameters
from tsfc.ufl_utils import extract_firedrake_constants
from tsfc.ufl_utils import extract_firedrake_constants, hash_expr

from pyop2 import op2
from pyop2.caching import memory_and_disk_cache, default_parallel_hashkey
Expand Down Expand Up @@ -57,7 +59,7 @@ def tsfc_compile_form_hashkey(form, prefix, parameters, interface, diagonal):
form.signature(),
prefix,
utils.tuplify(parameters),
type(interface).__name__,
_make_interface_key(interface),
diagonal,
)

Expand Down Expand Up @@ -142,7 +144,7 @@ def _compile_form_hashkey(form, name, parameters=None, split=True, interface=Non
name,
utils.tuplify(parameters),
split,
type(interface).__name__,
_make_interface_key(interface),
diagonal,
)

Expand Down Expand Up @@ -311,3 +313,14 @@ def extract_numbered_coefficients(expr, numbers):
else:
coefficients.append(coeff)
return coefficients


def _make_interface_key(interface):
if interface:
# Passing interface here is a small hack done in patch.py. What
# really matters for caching is what is used in the 'dont_split' kwarg.
assert isinstance(interface, functools.partial)
assert interface.func is KernelBuilder
return tuple(map(hash_expr, interface.keywords["dont_split"]))
else:
return None
24 changes: 23 additions & 1 deletion tsfc/ufl_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,15 @@
from ufl import as_tensor, indices, replace
from ufl.algorithms import compute_form_data as ufl_compute_form_data
from ufl.algorithms import estimate_total_polynomial_degree
from ufl.algorithms.analysis import extract_arguments, extract_type
from ufl.algorithms.analysis import extract_arguments, extract_coefficients, extract_type
from ufl.algorithms.apply_function_pullbacks import apply_function_pullbacks
from ufl.algorithms.apply_algebra_lowering import apply_algebra_lowering
from ufl.algorithms.apply_derivatives import apply_derivatives
from ufl.algorithms.apply_geometry_lowering import apply_geometry_lowering
from ufl.algorithms.apply_restrictions import apply_restrictions
from ufl.algorithms.comparison_checker import do_comparison_check
from ufl.algorithms.remove_complex_nodes import remove_complex_nodes
from ufl.algorithms.signature import compute_expression_signature
from ufl.corealg.map_dag import map_expr_dag
from ufl.corealg.multifunction import MultiFunction
from ufl.geometry import QuadratureWeight
Expand Down Expand Up @@ -495,3 +496,24 @@ class TSFCConstantMixin:

def __init__(self):
pass


def hash_expr(expr: ufl.core.expr.Expr) -> str:
"""Return a numbering-invariant hash of a UFL expression.
Parameters
----------
expr :
A UFL expression.
Returns
-------
str :
A numbering-invariant hash for the expression.
"""
domain_numbering = {d: i for i, d in enumerate(ufl.domain.extract_domains(expr))}
coefficient_numbering = {c: i for i, c in enumerate(extract_coefficients(expr))}
constant_numbering = {c: i for i, c in enumerate(extract_firedrake_constants(expr))}
return compute_expression_signature(
expr, {**domain_numbering, **coefficient_numbering, **constant_numbering}
)

0 comments on commit 2507977

Please sign in to comment.