Skip to content

Commit

Permalink
feat: update IntG to expr_dataclass
Browse files Browse the repository at this point in the history
  • Loading branch information
alexfikl committed Oct 6, 2024
1 parent 37c253c commit 5a7d513
Show file tree
Hide file tree
Showing 3 changed files with 196 additions and 187 deletions.
10 changes: 6 additions & 4 deletions pytential/linalg/direct_solver_symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,12 @@ def map_int_g(self, expr):
if name not in source_args
}

return expr.copy(target_kernel=target_kernel,
source_kernels=source_kernels,
densities=self.rec(expr.densities),
kernel_arguments=kernel_arguments)
from dataclasses import replace
return replace(expr,
target_kernel=target_kernel,
source_kernels=source_kernels,
densities=self.rec(expr.densities),
kernel_arguments=kernel_arguments)

# }}}

Expand Down
51 changes: 25 additions & 26 deletions pytential/symbolic/mappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
THE SOFTWARE.
"""

from dataclasses import replace
from functools import reduce

from pymbolic.mapper.stringifier import (
Expand Down Expand Up @@ -129,9 +130,7 @@ def map_int_g(self, expr):
if not changed:
return expr

return expr.copy(
densities=densities,
kernel_arguments=kernel_arguments)
return replace(expr, densities=densities, kernel_arguments=kernel_arguments)

def map_interpolation(self, expr):
operand = self.rec(expr.operand)
Expand Down Expand Up @@ -261,10 +260,7 @@ def map_int_g(self, expr):
if not changed:
return expr

return expr.copy(
densities=densities,
kernel_arguments=kernel_arguments,
)
return replace(expr, densities=densities, kernel_arguments=kernel_arguments)

def map_common_subexpression(self, expr):
child = self.rec(expr.child)
Expand Down Expand Up @@ -522,9 +518,9 @@ def map_product(self, expr):

def map_int_g(self, expr):
from sumpy.kernel import AxisTargetDerivative
return expr.copy(
target_kernel=AxisTargetDerivative(
self.ambient_axis, expr.target_kernel))

target_kernel = AxisTargetDerivative(self.ambient_axis, expr.target_kernel)
return replace(expr, target_kernel=target_kernel)


class DerivativeSourceAndNablaComponentCollector(
Expand Down Expand Up @@ -570,15 +566,15 @@ def map_int_g(self, expr):
raise ValueError(
"Unregularized evaluation does not support one-sided limits")

expr = expr.copy(
qbx_forced_limit=None,
densities=self.rec(expr.densities),
kernel_arguments={
name: self.rec(arg_expr)
for name, arg_expr in expr.kernel_arguments.items()
})

return expr
return replace(
expr,
qbx_forced_limit=None,
densities=self.rec(expr.densities),
kernel_arguments={
name: self.rec(arg_expr)
for name, arg_expr in expr.kernel_arguments.items()
}
)

# }}}

Expand Down Expand Up @@ -626,7 +622,7 @@ def map_num_reference_derivative(self, expr):

def map_int_g(self, expr):
if expr.target.discr_stage is None:
expr = expr.copy(target=expr.target.to_stage1())
expr = replace(expr, target=expr.target.to_stage1())

if expr.source.discr_stage is not None:
return expr
Expand All @@ -638,16 +634,18 @@ def map_int_g(self, expr):

from_dd = expr.source.to_stage1()
to_dd = from_dd.to_quad_stage2()
densities = [prim.interp(from_dd, to_dd, self.rec(density)) for
density in expr.densities]
densities = tuple(
prim.interp(from_dd, to_dd, self.rec(density)) for
density in expr.densities)

from_dd = from_dd.copy(discr_stage=self.from_discr_stage)
kernel_arguments = {
name: prim.interp(from_dd, to_dd,
self.rec(self.tagger(arg_expr)))
for name, arg_expr in expr.kernel_arguments.items()}

return expr.copy(
return replace(
expr,
densities=densities,
kernel_arguments=kernel_arguments,
source=to_dd)
Expand Down Expand Up @@ -678,7 +676,8 @@ def map_int_g(self, expr):

is_self = source_discr is target_discr

expr = expr.copy(
expr = replace(
expr,
densities=self.rec(expr.densities),
kernel_arguments={
name: self.rec(arg_expr)
Expand Down Expand Up @@ -707,8 +706,8 @@ def map_int_g(self, expr):

if expr.qbx_forced_limit == "avg":
return 0.5*(
expr.copy(qbx_forced_limit=+1)
+ expr.copy(qbx_forced_limit=-1))
replace(expr, qbx_forced_limit=+1)
+ replace(expr, qbx_forced_limit=-1))
else:
return expr

Expand Down
Loading

0 comments on commit 5a7d513

Please sign in to comment.