Skip to content

Commit

Permalink
feat: primer on ode modelling using diffsol (#106)
Browse files Browse the repository at this point in the history
* bug: take_state now fully releases the problem
* bug: reit tstop and root finder if state mutated
  • Loading branch information
martinjrobins authored Nov 12, 2024
1 parent b2035bd commit 8d3b3df
Show file tree
Hide file tree
Showing 48 changed files with 1,447 additions and 391 deletions.
33 changes: 14 additions & 19 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,43 +11,38 @@ repository = "https://github.com/martinjrobins/diffsol"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[features]
default = ["nalgebra", "faer"]
default = ["nalgebra", "faer", "diffsl"]
faer = []
nalgebra = []
sundials = ["suitesparse_sys", "bindgen", "cc"]
suitesparse = ["suitesparse_sys"]
diffsl-cranelift = ["diffsl-no-llvm", "diffsl"]
diffsl = [ ]
diffsl = ["dep:diffsl"]
diffsl-llvm = []
diffsl-llvm13 = ["diffsl13-0", "diffsl-llvm", "diffsl"]
diffsl-llvm14 = ["diffsl14-0", "diffsl-llvm", "diffsl"]
diffsl-llvm15 = ["diffsl15-0", "diffsl-llvm", "diffsl"]
diffsl-llvm16 = ["diffsl16-0", "diffsl-llvm", "diffsl"]
diffsl-llvm17 = ["diffsl17-0", "diffsl-llvm", "diffsl"]
diffsl-llvm13 = ["diffsl/llvm13-0", "diffsl", "diffsl-llvm"]
diffsl-llvm14 = ["diffsl/llvm14-0", "diffsl", "diffsl-llvm"]
diffsl-llvm15 = ["diffsl/llvm15-0", "diffsl", "diffsl-llvm"]
diffsl-llvm16 = ["diffsl/llvm16-0", "diffsl", "diffsl-llvm"]
diffsl-llvm17 = ["diffsl/llvm17-0", "diffsl", "diffsl-llvm"]

[dependencies]
nalgebra = "0.33"
nalgebra = "0.33.2"
nalgebra-sparse = { version = "0.10", features = ["io"] }
num-traits = "0.2.17"
serde = { version = "1.0.196", features = ["derive"] }
diffsl-no-llvm = { package = "diffsl", version = "=0.2.0", optional = true }
diffsl13-0 = { package = "diffsl", version = "=0.2.0", features = ["llvm13-0"], optional = true }
diffsl14-0 = { package = "diffsl", version = "=0.2.0", features = ["llvm14-0"], optional = true }
diffsl15-0 = { package = "diffsl", version = "=0.2.0", features = ["llvm15-0"], optional = true }
diffsl16-0 = { package = "diffsl", version = "=0.2.0", features = ["llvm16-0"], optional = true }
diffsl17-0 = { package = "diffsl", version = "=0.2.0", features = ["llvm17-0"], optional = true }
serde = { version = "1.0.215", features = ["derive"] }
diffsl = { package = "diffsl", version = "0.2.2", optional = true }
petgraph = "0.6.4"
faer = "0.19.4"
suitesparse_sys = { version = "0.1.3", optional = true }
thiserror = "1.0.63"
thiserror = "2.0.3"

[dev-dependencies]
insta = { version = "1.34.0", features = ["yaml"] }
insta = { version = "1.41.1", features = ["yaml"] }
criterion = { version = "0.5.1" }
skeptic = "0.13.7"

[build-dependencies]
bindgen = { version = "0.70.1", optional = true }
cc = { version = "1.0.99", optional = true }
cc = { version = "1.2.0", optional = true }

[[bench]]
name = "ode_solvers"
Expand Down
10 changes: 10 additions & 0 deletions book/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
[package]
name = "book"
version = "0.1.0"
edition = "2021"

[dependencies]
diffsol = { path = "..", features = ["diffsl"] }
nalgebra = "0.33.2"
faer = "0.19.4"
plotly = "0.10.0"
6 changes: 3 additions & 3 deletions book/book.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ mathjax-support = true

[preprocessor.keeper]
command = "mdbook-keeper"
manifest_dir = ".."
externs = ["diffsol", "nalgebra", "faer"]
manifest_dir = "."
externs = ["diffsol", "nalgebra", "faer", "plotly"]
# is_workspace = true
build_features = ["diffsl-cranelift"]
# build_features = ["plotly"]

35 changes: 22 additions & 13 deletions book/src/SUMMARY.md
Original file line number Diff line number Diff line change
@@ -1,17 +1,26 @@
# Summary

- [Specifying the problem](./specifying_the_problem.md)
- [ODE equations](./ode_equations.md)
- [Mass matrix](./mass_matrix.md)
- [Root finding](./root_finding.md)
- [Forward Sensitivity](./forward_sensitivity.md)
- [Custom Problem Structs](./custom_problem_structs.md)
- [Non-linear functions](./non_linear_functions.md)
- [Constant functions](./constant_functions.md)
- [Linear functions](./linear_functions.md)
- [Putting it all together](./putting_it_all_together.md)
- [DiffSL](./diffsl.md)
- [Sparse problems](./sparse_problems.md)
- [Modelling with ODEs](./primer/modelling_with_odes.md)
- [Explicit First Order ODEs](./primer/first_order_odes.md)
- [Example: Population Dynamics](./primer/population_dynamics.md)
- [Higher Order ODEs](./primer/higher_order_odes.md)
- [Example: Spring-mass systems](./primer/spring_mass_systems.md)
- [Discrete Events](./primer/discrete_events.md)
- [Example: Compartmental models of Drug Delivery](./primer/compartmental_models_of_drug_delivery.md)
- [Example: Bouncing Ball](./primer/bouncing_ball.md)
- [DAEs via the Mass Matrix](./primer/the_mass_matrix.md)
- [Example: Electrical Circuits](./primer/electrical_circuits.md)
- [DiffSol APIs for specifying problems](./specify/specifying_the_problem.md)
- [ODE equations](./specify/ode_equations.md)
- [Mass matrix](./specify/mass_matrix.md)
- [Root finding](./specify/root_finding.md)
- [Forward Sensitivity](./specify/forward_sensitivity.md)
- [Custom Problem Structs](./specify/custom/custom_problem_structs.md)
- [Non-linear functions](./specify/custom/non_linear_functions.md)
- [Constant functions](./specify/custom/constant_functions.md)
- [Linear functions](./specify/custom/linear_functions.md)
- [Putting it all together](./specify/custom/putting_it_all_together.md)
- [DiffSL](./specify/diffsl.md)
- [Sparse problems](./specify/sparse_problems.md)
- [Choosing a solver](./choosing_a_solver.md)
- [Initialisation](./initialisation.md)
- [Solving the problem](./solving_the_problem.md)
Expand Down
58 changes: 0 additions & 58 deletions book/src/diffsl.md
Original file line number Diff line number Diff line change
@@ -1,59 +1 @@
# DiffSL

Thus far we have used Rust code to specify the problem we want to solve. This is fine if you are using DiffSol from Rust, but what if you want to use DiffSol from a higher-level language like Python or R?
For this usecase we have designed a Domain Specific Language (DSL) called DiffSL that can be used to specify the problem. DiffSL is not a general purpose language but is tightly constrained to
the specification of a system of ordinary differential equations. It features a relativly simple syntax that consists of writing a series of tensors (dense or sparse) that represent the equations of the system.
For more detail on the syntax of DiffSL see the [DiffSL book](https://martinjrobins.github.io/diffsl/). This section will focus on how to use DiffSL to specify a problem in DiffSol.


## DiffSL Context

The main struct that is used to specify a problem in DiffSL is the [`DiffSlContext`](https://docs.rs/diffsol/latest/diffsol/ode_solver/diffsl/struct.DiffSlContext.html) struct. Creating this struct
Just-In-Time (JIT) compiles your DiffSL code into a form that can be executed efficiently by DiffSol.

```rust
# fn main() {
use diffsol::{DiffSl, CraneliftModule};
type M = nalgebra::DMatrix<f64>;
type CG = CraneliftModule;

let eqn = DiffSl::<M, CG>::compile("
in = [r, k]
r { 1.0 }
k { 1.0 }
u { 0.1 }
F { r * u * (1.0 - u / k) }
out { u }
").unwrap();
# }
```

Once you have created the `DiffSlContext` struct you can use it to create a problem using the `build_diffsl` method on the [`OdeBuilder`](https://docs.rs/diffsol/latest/diffsol/ode_solver/builder/struct.OdeBuilder.html) struct.


```rust
# fn main() {
# use diffsol::{DiffSl, CraneliftModule};
use diffsol::{OdeBuilder, Bdf, OdeSolverMethod, OdeSolverState};
# type M = nalgebra::DMatrix<f64>;
# type CG = CraneliftModule;


# let eqn = DiffSl::<M, CG>::compile("
# in = [r, k]
# r { 1.0 }
# k { 1.0 }
# u { 0.1 }
# F { r * u * (1.0 - u / k) }
# out { u }
# ").unwrap();
let problem = OdeBuilder::new()
.rtol(1e-6)
.p([1.0, 10.0])
.build_from_eqn(eqn).unwrap();
let mut solver = Bdf::default();
let t = 0.4;
let state = OdeSolverState::new(&problem, &solver).unwrap();
let _soln = solver.solve(&problem, state, t).unwrap();
# }
```
Empty file added book/src/lib.rs
Empty file.
145 changes: 145 additions & 0 deletions book/src/primer/bouncing_ball.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
# Example: Bouncing Ball

Modelling a bouncing ball is a simple and intuitive example of a system with discrete events. The ball is dropped from a height \\(h\\) and bounces off the ground with a coefficient of restitution \\(e\\). When the ball hits the ground, its velocity is reversed and scaled by the coefficient of restitution, and the ball rises and then continues to fall until it hits the ground again. This process repeats until halted.

The second order ODE that describes the motion of the ball is given by:

\\[
\frac{d^2x}{dt^2} = -g
\\]

where \\(x\\) is the position of the ball, \\(t\\) is time, and \\(g\\) is the acceleration due to gravity. We can rewrite this as a system of two first order ODEs by introducing a new variable for the velocity of the ball:

\\[
\begin{align*}
\frac{dx}{dt} &= v \\\\
\frac{dv}{dt} &= -g
\end{align*}
\\]

where \\(v = \frac{dx}{dt}\\) is the velocity of the ball. This is a system of two first order ODEs, which can be written in vector form as:

\\[
\frac{d\mathbf{y}}{dt} = \mathbf{f}(\mathbf{y}, t)
\\]

where

\\[
\mathbf{y} = \begin{bmatrix} x \\\\ v \end{bmatrix}
\\]

and

\\[
\mathbf{f}(\mathbf{y}, t) = \begin{bmatrix} v \\\\ -g \end{bmatrix}
\\]

The initial conditions for the ball, including the height from which it is dropped and its initial velocity, are given by:

\\[
\mathbf{y}(0) = \begin{bmatrix} h \\\\ 0 \end{bmatrix}
\\]

When the ball hits the ground, we need to update the velocity of the ball according to the coefficient of restitution, which is the ratio of the velocity after the bounce to the velocity before the bounce. The velocity after the bounce \\(v'\\) is given by:

\\[
v' = -e v
\\]

where \\(e\\) is the coefficient of restitution. However, to implement this in our ODE solver, we need to detect when the ball hits the ground. We can do this by using DiffSol's event handling feature, which allows us to specify a function that is equal to zero when the event occurs, i.e. when the ball hits the ground. This function \\(g(\mathbf{y}, t)\\) is called an event or root function, and for our bouncing ball problem, it is given by:

\\[
g(\mathbf{y}, t) = x
\\]

where \\(x\\) is the position of the ball. When the ball hits the ground, the event function will be zero and DiffSol will stop the integration, and we can update the velocity of the ball accordingly.

In code, the bouncing ball problem can be solved using DiffSol as follows:

```rust
# fn main() {
# use std::fs;
use diffsol::{
DiffSl, CraneliftModule, OdeBuilder, Bdf, OdeSolverState, OdeSolverMethod,
OdeSolverStopReason,
};
use plotly::{
Plot, Scatter, common::Mode, layout::Layout, layout::Axis
};
type M = nalgebra::DMatrix<f64>;
type CG = CraneliftModule;

let eqn = DiffSl::<M, CG>::compile("
g { 9.81 } h { 10.0 }
u_i {
x = h,
v = 0,
}
F_i {
v,
-g,
}
stop {
x,
}
").unwrap();

let e = 0.8;
let problem = OdeBuilder::new().build_from_eqn(eqn).unwrap();
let mut solver = Bdf::default();
let state = OdeSolverState::new(&problem, &solver).unwrap();
solver.set_problem(state, &problem).unwrap();

let mut x = Vec::new();
let mut v = Vec::new();
let mut t = Vec::new();
let final_time = 10.0;

// save the initial state
x.push(solver.state().unwrap().y[0]);
v.push(solver.state().unwrap().y[1]);
t.push(0.0);

// solve until the final time is reached
solver.set_stop_time(final_time).unwrap();
loop {
match solver.step() {
Ok(OdeSolverStopReason::InternalTimestep) => (),
Ok(OdeSolverStopReason::RootFound(t)) => {
// get the state when the event occurred
let mut y = solver.interpolate(t).unwrap();

// update the velocity of the ball
y[1] *= -e;

// make sure the ball is above the ground
y[0] = y[0].max(f64::EPSILON);

// set the state to the updated state
solver.state_mut().unwrap().y.copy_from(&y);
solver.state_mut().unwrap().dy[0] = y[1];
*solver.state_mut().unwrap().t = t;
},
Ok(OdeSolverStopReason::TstopReached) => break,
Err(_) => panic!("unexpected solver error"),
}
x.push(solver.state().unwrap().y[0]);
v.push(solver.state().unwrap().y[1]);
t.push(solver.state().unwrap().t);
}
let mut plot = Plot::new();
let x = Scatter::new(t.clone(), x).mode(Mode::Lines).name("x");
let v = Scatter::new(t, v).mode(Mode::Lines).name("v");
plot.add_trace(x);
plot.add_trace(v);

let layout = Layout::new()
.x_axis(Axis::new().title("t"))
.y_axis(Axis::new());
plot.set_layout(layout);
let plot_html = plot.to_inline_html(Some("bouncing-ball"));
# fs::write("../src/primer/images/bouncing-ball.html", plot_html).expect("Unable to write file");
# }
```
{{#include images/bouncing-ball.html}}
Loading

0 comments on commit 8d3b3df

Please sign in to comment.