Skip to content

Commit

Permalink
Merge pull request #26 from martinjrobins/i23-state-ownership
Browse files Browse the repository at this point in the history
I23-state-ownership
  • Loading branch information
martinjrobins authored Apr 5, 2024
2 parents cc7fe8a + 90bd4dc commit cacc3b6
Show file tree
Hide file tree
Showing 8 changed files with 290 additions and 72 deletions.
10 changes: 4 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,11 @@ out-of-the-box with vectors and matrices from the
[nalgebra](https://nalgebra.org) crate, or you can implement your own types by
implementing the various vector and matrix traits in diffsol.

**Note**: This library is still in the early stages of development and is not
ready for production use. The API is likely to change in the future.

## Features

Currently only one solver is implemented, the Backward Differentiation Formula
(BDF) method. This is a variable step-size implicit method that is suitable for
DiffSol has two implementations of the Backward Differentiation Formula
(BDF) method, one in pure rust, the other wrapping the [Sundials](https://github.com/LLNL/sundials) IDA solver.
This method is a variable step-size implicit method that is suitable for
stiff ODEs and semi-explicit DAEs and is similar to the BDF method in MATLAB's
`ode15s` solver or the `bdf` solver in SciPy's `solve_ivp` function.

Expand Down Expand Up @@ -115,7 +113,7 @@ solver.set_problem(&mut state, &problem);
while state.t <= t {
solver.step(&mut state).unwrap();
}
let _y = solver.interpolate(&state, t);
let _y = solver.interpolate(&state, t).unwrap();
```

Note that `step` will advance the state to the next time step as chosen by the
Expand Down
8 changes: 4 additions & 4 deletions benches/solvers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,15 @@ mod robertson_ode {
fn bdf() {
let mut s = Bdf::default();
let (problem, _soln) = robertson_ode::<nalgebra::DMatrix<f64>>(false);
let _y = s.solve(&problem, 1.0);
let _y = s.solve(&problem, 4.0000e+10);
}

#[cfg(feature = "sundials")]
#[divan::bench]
fn sundials() {
let mut s = diffsol::SundialsIda::default();
let (problem, _soln) = robertson_ode::<diffsol::SundialsMatrix>(false);
let _y = s.solve(&problem, 1.0);
let _y = s.solve(&problem, 4.0000e+10);
}
}

Expand All @@ -52,7 +52,7 @@ mod robertson {
let mut s = Bdf::default();
let (problem, _soln) = robertson::<nalgebra::DMatrix<f64>>(false);
let mut root = NewtonNonlinearSolver::new(LU::default());
let _y = s.make_consistent_and_solve(&problem, 1.0, &mut root);
let _y = s.make_consistent_and_solve(&problem, 4.0000e+10, &mut root);
}

#[cfg(feature = "sundials")]
Expand All @@ -63,6 +63,6 @@ mod robertson {
let mut s = diffsol::SundialsIda::default();
let (problem, _soln) = robertson::<diffsol::SundialsMatrix>(false);
let mut root = NewtonNonlinearSolver::new(SundialsLinearSolver::new_dense());
let _y = s.make_consistent_and_solve(&problem, 1.0, &mut root);
let _y = s.make_consistent_and_solve(&problem, 4.0000e+10, &mut root);
}
}
20 changes: 10 additions & 10 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,12 @@
//!
//! let mut solver = Bdf::default();
//! let t = 0.4;
//! let mut state = OdeSolverState::new(&problem);
//! solver.set_problem(&mut state, &problem);
//! while state.t <= t {
//! solver.step(&mut state).unwrap();
//! let state = OdeSolverState::new(&problem);
//! solver.set_problem(state, &problem);
//! while solver.state().unwrap().t <= t {
//! solver.step().unwrap();
//! }
//! let y = solver.interpolate(&state, t);
//! let y = solver.interpolate(t);
//! ```
//!
//! ## DiffSL
Expand Down Expand Up @@ -203,12 +203,12 @@ mod tests {
let t = 0.4;
let y = solver.solve(&problem, t).unwrap();

let mut state = OdeSolverState::new(&problem);
solver.set_problem(&mut state, &problem);
while state.t <= t {
solver.step(&mut state).unwrap();
let state = OdeSolverState::new(&problem);
solver.set_problem(state, &problem);
while solver.state().unwrap().t <= t {
solver.step().unwrap();
}
let y2 = solver.interpolate(&state, t);
let y2 = solver.interpolate(t).unwrap();

y2.assert_eq(&y, 1e-6);
}
Expand Down
24 changes: 24 additions & 0 deletions src/nonlinear_solver/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,30 +16,54 @@ impl<V> NonLinearSolveSolution<V> {
}
}

/// A solver for the nonlinear problem `F(x) = 0`.
pub trait NonLinearSolver<C: Op> {
/// Set the problem to be solved, any previous problem is discarded.
fn set_problem(&mut self, problem: SolverProblem<C>);

/// Get a reference to the current problem, if any.
fn problem(&self) -> Option<&SolverProblem<C>>;

/// Get a mutable reference to the current problem, if any.
fn problem_mut(&mut self) -> Option<&mut SolverProblem<C>>;

/// Take the current problem, if any, and return it.
/// Any internal state of the solver is reset, and `set_problem`
/// must be called before solving the problem again.
fn take_problem(&mut self) -> Option<SolverProblem<C>>;

/// Reset the solver to its initial state.
fn reset(&mut self) {
if let Some(problem) = self.take_problem() {
self.set_problem(problem);
}
}

/// Set the time for the problem, if the op `C` has a time parameter.
fn set_time(&mut self, t: C::T) -> Result<()> {
self.problem_mut()
.ok_or_else(|| anyhow!("No problem set"))?
.t = t;
Ok(())
}

// Solve the problem `F(x) = 0` and return the solution `x`.
fn solve(&mut self, state: &C::V) -> Result<C::V> {
let mut state = state.clone();
self.solve_in_place(&mut state)?;
Ok(state)
}

// Solve the problem `F(x) = 0` in place.
fn solve_in_place(&mut self, state: &mut C::V) -> Result<()>;

// Set the maximum number of iterations for the solver.
fn set_max_iter(&mut self, max_iter: usize);

// Get the maximum number of iterations for the solver.
fn max_iter(&self) -> usize;

// Get the number of iterations taken by the solver on the last call to `solve`.
fn niter(&self) -> usize;
}

Expand Down
100 changes: 79 additions & 21 deletions src/ode_solver/bdf.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::ops::AddAssign;
use std::rc::Rc;

use anyhow::Result;
use anyhow::{anyhow, Result};
use nalgebra::{DMatrix, DVector};
use num_traits::{One, Pow, Zero};
use serde::Serialize;
Expand Down Expand Up @@ -66,6 +66,7 @@ pub struct Bdf<M: DenseMatrix<T = Eqn::T, V = Eqn::V>, Eqn: OdeEquations> {
gamma: Vec<Eqn::T>,
error_const: Vec<Eqn::T>,
statistics: BdfStatistics<Eqn::T>,
state: Option<OdeSolverState<Eqn::M>>,
}

impl<T: Scalar, Eqn: OdeEquations<T = T, V = DVector<T>, M = DMatrix<T>> + 'static> Default
Expand All @@ -90,6 +91,7 @@ impl<T: Scalar, Eqn: OdeEquations<T = T, V = DVector<T>, M = DMatrix<T>> + 'stat
error_const: vec![T::from(1.0); Self::MAX_ORDER + 1],
u: DMatrix::<T>::zeros(Self::MAX_ORDER + 1, Self::MAX_ORDER + 1),
statistics: BdfStatistics::default(),
state: None,
}
}
}
Expand Down Expand Up @@ -119,6 +121,7 @@ where
error_const: self.error_const.clone(),
u: DMatrix::zeros(Self::MAX_ORDER + 1, Self::MAX_ORDER + 1),
statistics: self.statistics.clone(),
state: self.state.clone(),
}
}
}
Expand Down Expand Up @@ -168,14 +171,14 @@ where
r
}

fn _update_step_size(&mut self, factor: Eqn::T, state: &mut OdeSolverState<Eqn::M>) {
fn _update_step_size(&mut self, factor: Eqn::T) {
//If step size h is changed then also need to update the terms in
//the first equation of page 9 of [1]:
//
//- constant c = h / (1-kappa) gamma_k term
//- lu factorisation of (M - c * J) used in newton iteration (same equation)

state.h *= factor;
self.state.as_mut().unwrap().h *= factor;
self.n_equal_steps = 0;

// update D using equations in section 3.2 of [1]
Expand All @@ -191,9 +194,11 @@ where
}
std::mem::swap(&mut self.diff, &mut self.diff_tmp);

self.nonlinear_problem_op()
.unwrap()
.set_c(state.h, &self.alpha, self.order);
self.nonlinear_problem_op().unwrap().set_c(
self.state.as_ref().unwrap().h,
&self.alpha,
self.order,
);

// reset nonlinear's linear solver problem as lu factorisation has changed
self.nonlinear_solver.as_mut().reset();
Expand Down Expand Up @@ -222,7 +227,7 @@ where
}
}

fn _predict_forward(&mut self, state: &OdeSolverState<Eqn::M>) -> Eqn::V {
fn _predict_forward(&mut self) -> Eqn::V {
let nstates = self.diff.nrows();
// predict forward to new step (eq 2 in [1])
let y_predict = {
Expand All @@ -240,7 +245,10 @@ where
}

// update time
let t_new = state.t + state.h;
let t_new = {
let state = self.state.as_ref().unwrap();
state.t + state.h
};
self.nonlinear_solver.as_mut().set_time(t_new).unwrap();
y_predict
}
Expand All @@ -251,25 +259,43 @@ where
for<'b> &'b Eqn::V: VectorRef<Eqn::V>,
for<'b> &'b Eqn::M: MatrixRef<Eqn::M>,
{
fn interpolate(&self, state: &OdeSolverState<Eqn::M>, t: Eqn::T) -> Eqn::V {
fn interpolate(&self, t: Eqn::T) -> Result<Eqn::V> {
//interpolate solution at time values t* where t-h < t* < t
//
//definition of the interpolating polynomial can be found on page 7 of [1]

// state must be set
let state = self.state.as_ref().ok_or(anyhow!("State not set"))?;

// check that t is before the current time
if t > state.t {
return Err(anyhow!("Interpolation time is after current time"));
}

let mut time_factor = Eqn::T::from(1.0);
let mut order_summation = self.diff.column(0).into_owned();
for i in 0..self.order {
let i_t = Eqn::T::from(i as f64);
time_factor *= (t - (state.t - state.h * i_t)) / (state.h * (Eqn::T::one() + i_t));
order_summation += self.diff.column(i + 1) * scale(time_factor);
}
order_summation
Ok(order_summation)
}

fn problem(&self) -> Option<&OdeSolverProblem<Eqn>> {
self.ode_problem.as_ref()
}

fn set_problem(&mut self, state: &mut OdeSolverState<Eqn::M>, problem: &OdeSolverProblem<Eqn>) {
fn state(&self) -> Option<&OdeSolverState<Eqn::M>> {
self.state.as_ref()
}

fn take_state(&mut self) -> Option<OdeSolverState<<Eqn>::M>> {
self.state.take()
}

fn set_problem(&mut self, state: OdeSolverState<Eqn::M>, problem: &OdeSolverProblem<Eqn>) {
let mut state = state;
self.ode_problem = Some(problem.clone());
let nstates = problem.eqn.nstates();
self.order = 1usize;
Expand Down Expand Up @@ -342,9 +368,12 @@ where

// update statistics
self.statistics.initial_step_size = state.h;

// store state
self.state = Some(state);
}

fn step(&mut self, state: &mut OdeSolverState<Eqn::M>) -> Result<()> {
fn step(&mut self) -> Result<()> {
// we will try and use the old jacobian unless convergence of newton iteration
// fails
// tells callable to update rhs jacobian if the jacobian is requested (by nonlinear solver)
Expand All @@ -355,7 +384,10 @@ where
let mut error_norm: Eqn::T;
let mut scale_y: Eqn::V;
let mut updated_jacobian = false;
let mut y_predict = self._predict_forward(state);
if self.state.is_none() {
return Err(anyhow!("State not set"));
}
let mut y_predict = self._predict_forward();

// loop until step is accepted
let y_new = loop {
Expand Down Expand Up @@ -400,15 +432,16 @@ where
factor = Eqn::T::from(Self::MIN_FACTOR);
}
// todo, do we need to update the linear solver problem here since we converged?
self._update_step_size(factor, state);
self._update_step_size(factor);

// if step size too small, then fail
let state = self.state.as_ref().unwrap();
if state.h < Eqn::T::from(Self::MIN_TIMESTEP) {
return Err(anyhow::anyhow!("Step size too small at t = {}", state.t));
}

// new prediction
y_predict = self._predict_forward(state);
y_predict = self._predict_forward();

// update statistics
self.statistics.number_of_error_test_failures += 1;
Expand All @@ -419,10 +452,10 @@ where
if updated_jacobian {
// newton iteration did not converge, but jacobian has already been
// evaluated so reduce step size by 0.3 (as per [1]) and try again
self._update_step_size(Eqn::T::from(0.3), state);
self._update_step_size(Eqn::T::from(0.3));

// new prediction
y_predict = self._predict_forward(state);
y_predict = self._predict_forward();

// update statistics
} else {
Expand All @@ -440,14 +473,17 @@ where
};

// take the accepted step
state.t += state.h;
state.y = y_new;
{
let state = self.state.as_mut().unwrap();
state.y = y_new;
state.t += state.h;
}

// update statistics
self.statistics.number_of_linear_solver_setups =
self.nonlinear_problem_op().unwrap().number_of_jac_evals();
self.statistics.number_of_steps += 1;
self.statistics.final_step_size = state.h;
self.statistics.final_step_size = self.state.as_ref().unwrap().h;

self._update_differences(&d);

Expand Down Expand Up @@ -502,8 +538,30 @@ where
if factor > Eqn::T::from(Self::MAX_FACTOR) {
factor = Eqn::T::from(Self::MAX_FACTOR);
}
self._update_step_size(factor, state);
self._update_step_size(factor);
}
Ok(())
}
}

#[cfg(test)]
mod test {
use crate::{
ode_solver::tests::{test_interpolate, test_no_set_problem, test_take_state},
Bdf,
};

type M = nalgebra::DMatrix<f64>;
#[test]
fn bdf_no_set_problem() {
test_no_set_problem::<M, _>(Bdf::default())
}
#[test]
fn bdf_take_state() {
test_take_state::<M, _>(Bdf::default())
}
#[test]
fn bdf_test_interpolate() {
test_interpolate::<M, _>(Bdf::default())
}
}
Loading

0 comments on commit cacc3b6

Please sign in to comment.