Skip to content

Commit

Permalink
Merge pull request #4 from rustodeans/main
Browse files Browse the repository at this point in the history
Use a reference of problem for set_problem()
  • Loading branch information
martinjrobins authored Mar 23, 2024
2 parents 8d45e9c + 490b815 commit 0fc656d
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 11 deletions.
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ mod tests {
let y = solver.solve(&problem, t).unwrap();

let mut state = OdeSolverState::new(&problem);
solver.set_problem(&mut state, problem);
solver.set_problem(&mut state, &problem);
while state.t <= t {
solver.step(&mut state).unwrap();
}
Expand Down
7 changes: 3 additions & 4 deletions src/ode_solver/bdf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -259,10 +259,9 @@ where
fn problem(&self) -> Option<&OdeSolverProblem<Eqn>> {
self.ode_problem.as_ref()
}

fn set_problem(&mut self, state: &mut OdeSolverState<Eqn::M>, problem: OdeSolverProblem<Eqn>) {
self.ode_problem = Some(problem);
let problem = self.ode_problem.as_ref().unwrap();

fn set_problem(&mut self, state: &mut OdeSolverState<Eqn::M>, problem: &OdeSolverProblem<Eqn>) {
self.ode_problem = Some(problem.clone());
let nstates = problem.eqn.nstates();
self.order = 1usize;
self.n_equal_steps = 0;
Expand Down
10 changes: 4 additions & 6 deletions src/ode_solver/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,12 @@ pub mod diffsl;

pub trait OdeSolverMethod<Eqn: OdeEquations> {
fn problem(&self) -> Option<&OdeSolverProblem<Eqn>>;
fn set_problem(&mut self, state: &mut OdeSolverState<Eqn::M>, problem: OdeSolverProblem<Eqn>);
fn set_problem(&mut self, state: &mut OdeSolverState<Eqn::M>, problem: &OdeSolverProblem<Eqn>);
fn step(&mut self, state: &mut OdeSolverState<Eqn::M>) -> Result<()>;
fn interpolate(&self, state: &OdeSolverState<Eqn::M>, t: Eqn::T) -> Eqn::V;
fn solve(&mut self, problem: &OdeSolverProblem<Eqn>, t: Eqn::T) -> Result<Eqn::V> {
let problem = problem.clone();
let mut state = OdeSolverState::new(&problem);
self.set_problem(&mut state, problem);
self.set_problem(&mut state, &problem);
while state.t <= t {
self.step(&mut state)?;
}
Expand All @@ -44,9 +43,8 @@ pub trait OdeSolverMethod<Eqn: OdeEquations> {
t: Eqn::T,
root_solver: &mut RS,
) -> Result<Eqn::V> {
let problem = problem.clone();
let mut state = OdeSolverState::new_consistent(&problem, root_solver)?;
self.set_problem(&mut state, problem);
self.set_problem(&mut state, &problem);
while state.t <= t {
self.step(&mut state)?;
}
Expand Down Expand Up @@ -254,7 +252,7 @@ mod tests {
Eqn: OdeEquations<M = M, T = M::T, V = M::V>,
{
let mut state = OdeSolverState::new_consistent(&problem, &mut root_solver).unwrap();
method.set_problem(&mut state, problem);
method.set_problem(&mut state, &problem);
for point in solution.solution_points.iter() {
while state.t < point.t {
method.step(&mut state).unwrap();
Expand Down

0 comments on commit 0fc656d

Please sign in to comment.