Skip to content

Commit

Permalink
Merge pull request #29 from devitocodes/deriv_fix
Browse files Browse the repository at this point in the history
conditions: Improve robustness of derivative parsing
  • Loading branch information
EdCaunt authored Apr 22, 2024
2 parents 3991446 + b862428 commit 6be9d33
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 7 deletions.
18 changes: 12 additions & 6 deletions schism/conditions/boundary_conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,13 +208,19 @@ def sub_basis(self, basis_map):
except KeyError:
# Should never end up here
raise ValueError("No basis generated for required function")
if type(deriv.deriv_order) != dv.types.utils.DimensionTuple:
d_o = (deriv.deriv_order,)

if type(deriv.deriv_order) is int:
# Derivative taken wrt single dimension.
b_derivs = [(deriv.dims[0], deriv.deriv_order)]
else:
d_o = deriv.deriv_order
# Derivs to take of the basis
b_derivs = tuple([(deriv.dims[i], d_o[i])
for i in range(len(deriv.dims))])
d_order = tuple(o for o in deriv.deriv_order if o != 0)
if len(deriv.dims) != len(d_order):
raise ValueError("Derivatives specified as"
" Derivative(f, x, x) are not compatible"
" with Schism. Use the"
" Derivative(f, (x, 2)) or f.dx2"
" conventions instead")
b_derivs = [(d, o) for d, o in zip(deriv.dims, d_order)]
reps[deriv] = basis.deriv(b_derivs)
return sp.simplify(self._mod_lhs.subs(reps))

Expand Down
12 changes: 11 additions & 1 deletion tests/test_boundary_conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,8 +180,19 @@ def test_derivative_dims(self, bc, ans):
'x**2*d_f_2d(2, 0)/2 + x*y*d_f_2d(1, 1) '
+ '+ x*d_f_2d(1, 0) + y**2*d_f_2d(0, 2)/2 '
+ '+ y*d_f_2d(0, 1) + d_f_2d(0, 0)'),
(dv.Eq(dv.Derivative(f, x, y), 0), {f: basisf2D},
'd_f_2d(1, 1)'),
(dv.Eq(dv.Derivative(f, y, x), 0), {f: basisf2D},
'd_f_2d(1, 1)'),
(dv.Eq(dv.Derivative(f, (x, 1), (y, 1)), 0),
{f: basisf2D},
'd_f_2d(1, 1)'),
(dv.Eq(f.laplace, 0), {f: basisf2D},
'd_f_2d(0, 2) + d_f_2d(2, 0)'),
(dv.Eq(dv.Derivative(f, (x, 2))
+ dv.Derivative(f, (y, 2)), 0),
{f: basisf2D},
'd_f_2d(0, 2) + d_f_2d(2, 0)'),
(dv.Eq(dv.div(v), 0),
{v[0]: basisvx, v[1]: basisvy},
'x*d_vx(2, 0) + x*d_vy(1, 1) + y*d_vx(1, 1) '
Expand Down Expand Up @@ -215,7 +226,6 @@ def test_basis_substitution(self, bc, basis_map, ans):
"""
condition = SingleCondition(bc)
expr = condition.sub_basis(basis_map)
print(expr)
assert str(expr) == ans

@pytest.mark.parametrize('bc, funcs',
Expand Down

0 comments on commit 6be9d33

Please sign in to comment.