Skip to content

Commit

Permalink
feat: update to new pymbolic
Browse files Browse the repository at this point in the history
  • Loading branch information
alexfikl committed Dec 17, 2024
1 parent 43aef15 commit e0f8b7d
Show file tree
Hide file tree
Showing 5 changed files with 239 additions and 169 deletions.
7 changes: 6 additions & 1 deletion pytential/symbolic/mappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,12 @@ def map_common_subexpression(self, expr):
# {{{ FlattenMapper

class FlattenMapper(FlattenMapperBase, IdentityMapper):
pass
def map_int_g(self, expr):
densities, kernel_arguments, changed = rec_int_g_arguments(self, expr)
if not changed:
return expr

return replace(expr, densities=densities, kernel_arguments=kernel_arguments)


def flatten(expr):
Expand Down
136 changes: 74 additions & 62 deletions pytential/symbolic/pde/system_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,27 +21,21 @@
"""

import logging
import warnings
from collections.abc import Mapping, Sequence
from dataclasses import dataclass
from typing import Any
from dataclasses import dataclass, replace
from typing import Any, cast

import numpy as np
import pymbolic
import sumpy.symbolic as sym
from pytools import \
generate_nonnegative_integer_tuples_summing_to_at_most as gnitstam
import sumpy.symbolic as sp
from pytools import generate_nonnegative_integer_tuples_summing_to_at_most as gnitstam
from pytools import memoize_on_first_arg
from sumpy.kernel import (AxisSourceDerivative, AxisTargetDerivative,
DirectionalSourceDerivative, ExpressionKernel,
Kernel, KernelWrapper, TargetPointMultiplier)
from pymbolic.typing import ExpressionT, ArithmeticExpressionT
from pymbolic.typing import ArithmeticExpression

import pytential
from pytential.symbolic.mappers import IdentityMapper
from pytential.symbolic.primitives import (DEFAULT_SOURCE, IntG,
NodeCoordinateComponent,
hashable_kernel_args)
from pytential import sym
from pytential.symbolic.mappers import IdentityMapper, flatten
from pytential.utils import chop, solve_from_lu

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -72,8 +66,8 @@ class RewriteFailedError(RuntimeError):


def rewrite_using_base_kernel(
exprs: Sequence[ExpressionT],
base_kernel: Kernel = _NO_ARG_SENTINEL) -> list[ExpressionT]:
exprs: Sequence[ArithmeticExpression],
base_kernel: Kernel = _NO_ARG_SENTINEL) -> list[ArithmeticExpression]:
"""
Rewrites a list of expressions with :class:`~pytential.symbolic.primitives.IntG`
objects using *base_kernel*.
Expand Down Expand Up @@ -101,7 +95,7 @@ def rewrite_using_base_kernel(
raise NotImplementedError

mapper = RewriteUsingBaseKernelMapper(base_kernel)
return [mapper(expr) for expr in exprs]
return [cast(ArithmeticExpression, mapper(expr)) for expr in exprs]


class RewriteUsingBaseKernelMapper(IdentityMapper):
Expand Down Expand Up @@ -131,8 +125,9 @@ def map_int_g(self, expr):
self.base_kernel) for new_int_g in new_int_gs)


def _get_sympy_kernel_expression(expr: ArithmeticExpressionT,
kernel_arguments: Mapping[str, Any]) -> sym.Basic:
def _get_sympy_kernel_expression(
expr: ArithmeticExpression,
kernel_arguments: Mapping[str, Any]) -> sp.Basic:
"""Convert a :mod:`pymbolic` expression to :mod:`sympy` expression
after substituting kernel arguments.
Expand All @@ -149,22 +144,22 @@ def _get_sympy_kernel_expression(expr: ArithmeticExpressionT,


def _monom_to_expr(monom: Sequence[int],
variables: Sequence[sym.Basic | ArithmeticExpressionT]
) -> sym.Basic | ArithmeticExpressionT:
variables: Sequence[sp.Basic | ArithmeticExpression]
) -> sp.Basic | ArithmeticExpression:
"""Convert a monomial to an expression using given variables.
For example, ``[3, 2, 1]`` with variables ``[x, y, z]`` is converted to
``x^3 y^2 z``.
"""
prod: ArithmeticExpressionT = 1
prod: ArithmeticExpression = 1
for i, nrepeats in enumerate(monom):
for _ in range(nrepeats):
prod *= variables[i]

return prod


def convert_target_transformation_to_source(int_g: IntG) -> list[IntG]:
def convert_target_transformation_to_source(int_g: sym.IntG) -> list[sym.IntG]:
r"""Convert an ``IntG`` with :class:`~sumpy.kernel.AxisTargetDerivative`
or :class:`~sumpy.kernel.TargetPointMultiplier` to a list
of ``IntG``\ s without them and only source dependent transformations.
Expand All @@ -182,8 +177,9 @@ def convert_target_transformation_to_source(int_g: IntG) -> list[IntG]:

knl = int_g.target_kernel
if not knl.is_translation_invariant:
warnings.warn(f"Translation variant kernel ({knl}) found.",
stacklevel=2)
from warnings import warn

warn(f"Translation variant kernel ({knl}) found.", stacklevel=2)
return [int_g]

# we use a symbol for d = (x - y)
Expand All @@ -204,7 +200,9 @@ def convert_target_transformation_to_source(int_g: IntG) -> list[IntG]:
expr = expr.diff(ds[knl.axis])
found = True
else:
warnings.warn(
from warnings import warn

warn(
f"Unknown target kernel ({knl}) found. "
"Returning IntG expression unchanged.", stacklevel=2)
return [int_g]
Expand All @@ -213,9 +211,9 @@ def convert_target_transformation_to_source(int_g: IntG) -> list[IntG]:
if not found:
return [int_g]

int_g = int_g.copy(target_kernel=knl)
int_g = replace(int_g, target_kernel=knl)

sources_pymbolic = [NodeCoordinateComponent(i) for i in range(knl.dim)]
sources_pymbolic = sym.nodes(knl.dim).as_vector()
expr = expr.expand()
# Now the expr is an Add and looks like
# u_{d[0], d[1]}(d, y)*d[0]*y[1] + u(d, y) * d[1]
Expand Down Expand Up @@ -255,7 +253,7 @@ def convert_target_transformation_to_source(int_g: IntG) -> list[IntG]:
for _ in range(nrepeats):
knl = AxisSourceDerivative(axis, knl)
new_source_kernels.append(knl)
new_int_g = int_g.copy(source_kernels=new_source_kernels)
new_int_g = replace(int_g, source_kernels=tuple(new_source_kernels))

(monom, coeff,) = remaining_factors.terms()[0]
# Now from d[0]*y[1], we separate the two terms
Expand All @@ -266,89 +264,98 @@ def convert_target_transformation_to_source(int_g: IntG) -> list[IntG]:
* conv(coeff)
# since d/d(d) = - d/d(y), we multiply by -1 to get source derivatives
density_multiplier *= (-1)**int(sum(nrepeats for _, nrepeats in derivatives))
new_int_gs = _multiply_int_g(new_int_g, sym.sympify(expr_multiplier),
new_int_gs = _multiply_int_g(new_int_g, sp.sympify(expr_multiplier),
density_multiplier)
result.extend(new_int_gs)
return result


def _multiply_int_g(int_g: IntG, expr_multiplier: sym.Basic,
density_multiplier: ArithmeticExpressionT) -> list[IntG]:
def _multiply_int_g(
int_g: sym.IntG,
expr_multiplier: sp.Basic,
density_multiplier: ArithmeticExpression) -> list[sym.IntG]:
"""Multiply the expression in ``IntG`` with the *expr_multiplier*
which is a symbolic (:mod:`sympy` or :mod:`symengine`) expression and
multiply the densities with *density_multiplier* which is a :mod:`pymbolic`
expression.
"""
from pymbolic import substitute

result = []

base_kernel = int_g.target_kernel.get_base_kernel()
sym_d = sym.make_sym_vector("d", base_kernel.dim)
sym_d = sp.make_sym_vector("d", base_kernel.dim)
base_kernel_expr = _get_sympy_kernel_expression(base_kernel.expression,
int_g.kernel_arguments)
subst = {pymbolic.var(f"d{i}"): pymbolic.var("d")[i] for i in
subst = {sym.var(f"d{i}"): sym.var("d")[i] for i in
range(base_kernel.dim)}
conv = sym.SympyToPymbolicMapper()
conv = sp.SympyToPymbolicMapper()

if expr_multiplier == 1:
# if there's no expr_multiplier, only multiply the densities
return [int_g.copy(densities=tuple(density*density_multiplier
for density in int_g.densities))]
return [replace(
int_g,
densities=tuple(density*density_multiplier for density in int_g.densities))
]

for knl, density in zip(int_g.source_kernels, int_g.densities, strict=True):
if expr_multiplier == 1:
new_knl = knl.get_base_kernel()
else:
new_expr = conv(knl.postprocess_at_source(base_kernel_expr, sym_d)
* expr_multiplier)
new_expr = pymbolic.substitute(new_expr, subst)
new_knl = ExpressionKernel(knl.dim, new_expr,
new_expr = substitute(new_expr, subst)
new_knl = ExpressionKernel(knl.dim, flatten(new_expr),
knl.get_base_kernel().global_scaling_const,
knl.is_complex_valued)
result.append(int_g.copy(target_kernel=new_knl,
result.append(replace(
int_g,
target_kernel=new_knl,
densities=(density*density_multiplier,),
source_kernels=(new_knl,)))
source_kernels=(new_knl,)
))
return result


def rewrite_int_g_using_base_kernel(
int_g: IntG, base_kernel: ExpressionKernel) -> ArithmeticExpressionT:
int_g: sym.IntG, base_kernel: ExpressionKernel) -> ArithmeticExpression:
r"""Rewrite an ``IntG`` to an expression with ``IntG``\ s having the
base kernel *base_kernel*.
"""
result: ArithmeticExpressionT = 0
result: ArithmeticExpression = 0
for knl, density in zip(int_g.source_kernels, int_g.densities, strict=True):
result += _rewrite_int_g_using_base_kernel(
int_g.copy(source_kernels=(knl,), densities=(density,)),
replace(int_g, source_kernels=(knl,), densities=(density,)),
base_kernel)

return result


def _rewrite_int_g_using_base_kernel(
int_g: IntG, base_kernel: ExpressionKernel) -> ArithmeticExpressionT:
int_g: sym.IntG, base_kernel: ExpressionKernel) -> ArithmeticExpression:
r"""Rewrites an ``IntG`` with only one source kernel to an expression with
``IntG``\ s having the base kernel *base_kernel*.
"""
target_kernel = int_g.target_kernel.replace_base_kernel(base_kernel)
dim = target_kernel.dim

result: ArithmeticExpressionT = 0
result: ArithmeticExpression = 0

density, = int_g.densities
source_kernel, = int_g.source_kernels
deriv_relation = get_deriv_relation_kernel(source_kernel.get_base_kernel(),
base_kernel, hashable_kernel_arguments=(
hashable_kernel_args(int_g.kernel_arguments)))
sym.hashable_kernel_args(int_g.kernel_arguments)))

const = deriv_relation.const
# NOTE: we set a dofdesc here to force the evaluation of this integral
# on the source instead of the target when using automatic tagging
# see :meth:`pytential.symbolic.mappers.LocationTagger._default_dofdesc`
if int_g.source.geometry is None:
dd = int_g.source.copy(geometry=DEFAULT_SOURCE)
dd = int_g.source.copy(geometry=sym.DEFAULT_SOURCE)
else:
dd = int_g.source
const *= pytential.sym.integral(dim, dim-1, density, dofdesc=dd)
const *= sym.integral(dim, dim-1, density, dofdesc=dd)

if const != 0 and target_kernel != target_kernel.get_base_kernel():
# There might be some TargetPointMultipliers hanging around.
Expand Down Expand Up @@ -377,8 +384,13 @@ def _rewrite_int_g_using_base_kernel(
for _ in range(val):
knl = AxisSourceDerivative(d, knl)
c *= -1
result += int_g.copy(source_kernels=(knl,), target_kernel=target_kernel,
densities=(density * c,), kernel_arguments=new_kernel_args)
result += replace(
int_g,
source_kernels=(knl,),
target_kernel=target_kernel,
densities=(density * c,),
kernel_arguments=new_kernel_args)

return result


Expand All @@ -394,9 +406,9 @@ class DerivRelation:
.. autoattribute:: linear_combination
"""

const: ArithmeticExpressionT
const: ArithmeticExpression
"""A constant to add to the combination."""
linear_combination: Sequence[tuple[tuple[int, ...], ArithmeticExpressionT]]
linear_combination: Sequence[tuple[tuple[int, ...], ArithmeticExpression]]
"""A list of pairs ``(mi, coeffs)``."""


Expand Down Expand Up @@ -432,7 +444,7 @@ def get_deriv_relation(
res = []
for knl in kernels:
res.append(get_deriv_relation_kernel(knl, base_kernel,
hashable_kernel_arguments=hashable_kernel_args(kernel_arguments),
hashable_kernel_arguments=sym.hashable_kernel_args(kernel_arguments),
tol=tol, order=order))
return res

Expand Down Expand Up @@ -460,14 +472,14 @@ def get_deriv_relation_kernel(
order=order,
hashable_kernel_arguments=hashable_kernel_arguments)
dim = base_kernel.dim
sym_vec = sym.make_sym_vector("d", dim)
sympy_conv = sym.SympyToPymbolicMapper()
sym_vec = sp.make_sym_vector("d", dim)
sympy_conv = sp.SympyToPymbolicMapper()

expr = _get_sympy_kernel_expression(kernel.expression, kernel_arguments)
vec = []
for i in range(len(mis)):
vec.append(evalf(expr.xreplace(dict(zip(sym_vec, rand[:, i], strict=True)))))
vec = sym.Matrix(vec)
vec = sp.Matrix(vec)
result = []
const = 0
logger.debug("%s = ", kernel)
Expand All @@ -493,8 +505,8 @@ def get_deriv_relation_kernel(

@dataclass
class LUFactorization:
L: sym.Matrix
U: sym.Matrix
L: sp.Matrix
U: sp.Matrix
perm: Sequence[tuple[int, int]]


Expand Down Expand Up @@ -539,8 +551,8 @@ def _get_base_kernel_matrix_lu_factorization(
rand: np.ndarray = rng.integers(1, 10**15, size=(dim, len(mis))).astype(object)
for i in range(rand.shape[0]):
for j in range(rand.shape[1]):
rand[i, j] = sym.sympify(rand[i, j])/10**15
sym_vec = sym.make_sym_vector("d", dim)
rand[i, j] = sp.sympify(rand[i, j])/10**15
sym_vec = sp.make_sym_vector("d", dim)

base_expr = _get_sympy_kernel_expression(base_kernel.expression,
dict(hashable_kernel_arguments))
Expand All @@ -560,7 +572,7 @@ def _get_base_kernel_matrix_lu_factorization(
row.append(1)
mat.append(row)

sym_mat = sym.Matrix(mat)
sym_mat = sp.Matrix(mat)
failed = False
try:
L, U, perm = sym_mat.LUdecomposition()
Expand All @@ -569,7 +581,7 @@ def _get_base_kernel_matrix_lu_factorization(
# and sympy returns U with last row zero
failed = True

if not sym.USE_SYMENGINE and all(expr == 0 for expr in U[-1, :]):
if not sp.USE_SYMENGINE and all(expr == 0 for expr in U[-1, :]):
failed = True

if failed:
Expand Down
2 changes: 1 addition & 1 deletion pytential/symbolic/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -1625,7 +1625,7 @@ class IntG(Expression):
derivatives attached. k-th elements represents the k-th source derivative
operator above.
"""
densities: tuple[Expression, ...]
densities: tuple[ArithmeticExpression, ...]
"""A tuple of density expressions. Length of this tuple must match the length
of the *source_kernels* arguments.
"""
Expand Down
Loading

0 comments on commit e0f8b7d

Please sign in to comment.