Skip to content

Commit

Permalink
Merge branch 'more-efficient-initial' into better-rk
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-kidger committed May 22, 2023
2 parents 2d34bac + 3829d4d commit 3eef349
Showing 1 changed file with 23 additions and 14 deletions.
37 changes: 23 additions & 14 deletions diffrax/step_size_controller/adaptive.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,33 +21,42 @@ def _select_initial_step(
t0: Scalar,
y0: PyTree,
args: PyTree,
func: Callable[[Scalar, PyTree, PyTree], PyTree],
func: Callable[[PyTree[AbstractTerm], Scalar, PyTree, PyTree], PyTree],
error_order: Scalar,
rtol: Scalar,
atol: Scalar,
norm: Callable[[PyTree], Scalar],
) -> Scalar:
f0 = func(terms, t0, y0, args)
scale = (atol + ω(y0).call(jnp.abs) * rtol).ω
d0 = norm((y0**ω / scale**ω).ω)
d1 = norm((f0**ω / scale**ω).ω)

_cond = (d0 < 1e-5) | (d1 < 1e-5)
_d1 = jnp.where(_cond, 1, d1)
h0 = jnp.where(_cond, 1e-6, 0.01 * (d0 / _d1))
def fn(carry):
t, y, _h0, _d1, _f, _ = carry
f = func(terms, t, y, args)
return t, y, _h0, _d1, _f, f

def intermediate(carry):
_, _, _, _, _, f0 = carry
d0 = norm((y0**ω / scale**ω).ω)
d1 = norm((f0**ω / scale**ω).ω)
_cond = (d0 < 1e-5) | (d1 < 1e-5)
_d1 = jnp.where(_cond, 1, d1)
h0 = jnp.where(_cond, 1e-6, 0.01 * (d0 / _d1))
t1 = t0 + h0
y1 = (y0**ω + h0 * f0**ω).ω
return t1, y1, h0, d1, f0, f0

t1 = t0 + h0
y1 = (y0**ω + h0 * f0**ω).ω
f1 = func(terms, t1, y1, args)
scale = (atol + ω(y0).call(jnp.abs) * rtol).ω
dummy_h = t0
dummy_d = eqxi.eval_empty(norm, y0)
dummy_f = eqxi.eval_empty(lambda: func(terms, t0, y0, args))
_, _, h0, d1, f0, f1 = eqxi.scan_trick(
fn, [intermediate], (t0, y0, dummy_h, dummy_d, dummy_f, dummy_f)
)
d2 = norm(((f1**ω - f0**ω) / scale**ω).ω) / h0

max_d = jnp.maximum(d1, d2)
h1 = jnp.where(
max_d <= 1e-15,
jnp.maximum(1e-6, h0 * 1e-3),
(0.01 / max_d) ** (1 / error_order),
)

return jnp.minimum(100 * h0, h1)


Expand Down

0 comments on commit 3eef349

Please sign in to comment.