Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Improve builder implementation for solver settings #332

Merged
merged 1 commit into from
Jan 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 43 additions & 15 deletions pywr-cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@ use crate::tracing::setup_tracing;
use ::tracing::info;
use anyhow::{bail, Context, Result};
use clap::{Parser, Subcommand, ValueEnum};
#[cfg(feature = "cbc")]
use pywr_core::solvers::{CbcSolver, CbcSolverSettings, CbcSolverSettingsBuilder};
#[cfg(feature = "ipm-ocl")]
use pywr_core::solvers::{ClIpmF32Solver, ClIpmF64Solver, ClIpmSolverSettings};
use pywr_core::solvers::{ClpSolver, ClpSolverSettings};
use pywr_core::solvers::{ClpSolver, ClpSolverSettings, ClpSolverSettingsBuilder};
#[cfg(feature = "highs")]
use pywr_core::solvers::{HighsSolver, HighsSolverSettings};
use pywr_core::solvers::{HighsSolver, HighsSolverSettings, HighsSolverSettingsBuilder};
#[cfg(feature = "ipm-simd")]
use pywr_core::solvers::{SimdIpmF64Solver, SimdIpmSolverSettings};
use pywr_core::test_utils::make_random_model;
Expand All @@ -25,6 +27,8 @@ enum Solver {
Clp,
#[cfg(feature = "highs")]
Highs,
#[cfg(feature = "cbc")]
Cbc,
#[cfg(feature = "ipm-ocl")]
CLIPMF32,
#[cfg(feature = "ipm-ocl")]
Expand All @@ -39,6 +43,8 @@ impl Display for Solver {
Solver::Clp => write!(f, "clp"),
#[cfg(feature = "highs")]
Solver::Highs => write!(f, "highs"),
#[cfg(feature = "cbc")]
Solver::Cbc => write!(f, "cbc"),
#[cfg(feature = "ipm-ocl")]
Solver::CLIPMF32 => write!(f, "clipmf32"),
#[cfg(feature = "ipm-ocl")]
Expand Down Expand Up @@ -85,9 +91,6 @@ enum Commands {
data_path: Option<PathBuf>,
#[arg(short, long)]
output_path: Option<PathBuf>,
/// Use multiple threads for simulation.
#[arg(short, long, default_value_t = false)]
parallel: bool,
/// The number of threads to use in parallel simulation.
#[arg(short, long, default_value_t = 1)]
threads: usize,
Expand All @@ -102,9 +105,6 @@ enum Commands {
data_path: Option<PathBuf>,
#[arg(short, long)]
output_path: Option<PathBuf>,
/// Use multiple threads for simulation.
#[arg(short, long, default_value_t = false)]
parallel: bool,
/// The number of threads to use in parallel simulation.
#[arg(short, long, default_value_t = 1)]
threads: usize,
Expand Down Expand Up @@ -139,15 +139,13 @@ fn main() -> Result<()> {
solver,
data_path,
output_path,
parallel: _,
threads: _,
} => run(model, solver, data_path.as_deref(), output_path.as_deref()),
threads,
} => run(model, solver, data_path.as_deref(), output_path.as_deref(), *threads),
Commands::RunMulti {
model,
solver,
data_path,
output_path,
parallel: _,
threads: _,
} => run_multi(model, solver, data_path.as_deref(), output_path.as_deref()),
Commands::RunRandom {
Expand Down Expand Up @@ -252,17 +250,43 @@ fn handle_conversion_errors(errors: &[ComponentConversionError], stop_on_error:
Ok(())
}

fn run(path: &Path, solver: &Solver, data_path: Option<&Path>, output_path: Option<&Path>) {
fn run(path: &Path, solver: &Solver, data_path: Option<&Path>, output_path: Option<&Path>, threads: usize) {
let data = std::fs::read_to_string(path).unwrap();
let data_path = data_path.or_else(|| path.parent());
let schema_v2: PywrModel = serde_json::from_str(data.as_str()).unwrap();

let model = schema_v2.build_model(data_path, output_path).unwrap();

match *solver {
Solver::Clp => model.run::<ClpSolver>(&ClpSolverSettings::default()),
Solver::Clp => {
let mut settings_builder = ClpSolverSettingsBuilder::default();
if threads > 1 {
settings_builder = settings_builder.parallel();
settings_builder = settings_builder.threads(threads);
}
let settings = settings_builder.build();
model.run::<ClpSolver>(&settings)
}
#[cfg(feature = "cbc")]
Solver::Cbc => {
let mut settings_builder = CbcSolverSettingsBuilder::default();
if threads > 1 {
settings_builder = settings_builder.parallel();
settings_builder = settings_builder.threads(threads);
}
let settings = settings_builder.build();
model.run::<CbcSolver>(&settings)
}
#[cfg(feature = "highs")]
Solver::Highs => model.run::<HighsSolver>(&HighsSolverSettings::default()),
Solver::Highs => {
let mut settings_builder = HighsSolverSettingsBuilder::default();
if threads > 1 {
settings_builder = settings_builder.parallel();
settings_builder = settings_builder.threads(threads);
}
let settings = settings_builder.build();
model.run::<HighsSolver>(&settings)
}
#[cfg(feature = "ipm-ocl")]
Solver::CLIPMF32 => model.run_multi_scenario::<ClIpmF32Solver>(&ClIpmSolverSettings::default()),
#[cfg(feature = "ipm-ocl")]
Expand All @@ -285,6 +309,8 @@ fn run_multi(path: &Path, solver: &Solver, data_path: Option<&Path>, output_path
Solver::Clp => model.run::<ClpSolver>(&ClpSolverSettings::default()),
#[cfg(feature = "highs")]
Solver::Highs => model.run::<HighsSolver>(&HighsSolverSettings::default()),
#[cfg(feature = "cbc")]
Solver::Cbc => model.run::<CbcSolver>(&CbcSolverSettings::default()),
#[cfg(feature = "ipm-ocl")]
Solver::CLIPMF32 => model.run_multi_scenario::<ClIpmF32Solver>(&ClIpmSolverSettings::default()),
#[cfg(feature = "ipm-ocl")]
Expand All @@ -303,6 +329,8 @@ fn run_random(num_systems: usize, density: usize, num_scenarios: usize, solver:
Solver::Clp => model.run::<ClpSolver>(&ClpSolverSettings::default()),
#[cfg(feature = "highs")]
Solver::Highs => model.run::<HighsSolver>(&HighsSolverSettings::default()),
#[cfg(feature = "cbc")]
Solver::Cbc => model.run::<CbcSolver>(&CbcSolverSettings::default()),
#[cfg(feature = "ipm-ocl")]
Solver::CLIPMF32 => model.run_multi_scenario::<ClIpmF32Solver>(&ClIpmSolverSettings::default()),
#[cfg(feature = "ipm-ocl")]
Expand Down
8 changes: 4 additions & 4 deletions pywr-core/src/solvers/cbc/settings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ impl CbcSolverSettings {
///
/// let mut builder = CbcSolverSettingsBuilder::default();
///
/// builder.parallel();
/// builder = builder.parallel();
/// let settings = builder.build();
///
/// ```
Expand All @@ -56,18 +56,18 @@ pub struct CbcSolverSettingsBuilder {
}

impl CbcSolverSettingsBuilder {
pub fn parallel(&mut self) -> &mut Self {
pub fn parallel(mut self) -> Self {
self.parallel = true;
self
}

pub fn threads(&mut self, threads: usize) -> &mut Self {
pub fn threads(mut self, threads: usize) -> Self {
self.threads = threads;
self
}

/// Construct a [`CbcSolverSettings`] from the builder.
pub fn build(&self) -> CbcSolverSettings {
pub fn build(self) -> CbcSolverSettings {
CbcSolverSettings {
parallel: self.parallel,
threads: self.threads,
Expand Down
8 changes: 4 additions & 4 deletions pywr-core/src/solvers/clp/settings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ impl ClpSolverSettings {
///
/// let mut builder = ClpSolverSettingsBuilder::default();
///
/// builder.parallel();
/// builder = builder.parallel();
/// let settings = builder.build();
///
/// ```
Expand All @@ -56,18 +56,18 @@ pub struct ClpSolverSettingsBuilder {
}

impl ClpSolverSettingsBuilder {
pub fn parallel(&mut self) -> &mut Self {
pub fn parallel(mut self) -> Self {
self.parallel = true;
self
}

pub fn threads(&mut self, threads: usize) -> &mut Self {
pub fn threads(mut self, threads: usize) -> Self {
self.threads = threads;
self
}

/// Construct a [`ClpSolverSettings`] from the builder.
pub fn build(&self) -> ClpSolverSettings {
pub fn build(self) -> ClpSolverSettings {
ClpSolverSettings {
parallel: self.parallel,
threads: self.threads,
Expand Down
8 changes: 4 additions & 4 deletions pywr-core/src/solvers/highs/settings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ impl HighsSolverSettings {
/// let settings = HighsSolverSettingsBuilder::default().parallel().threads(4).build();
///
/// let mut builder = HighsSolverSettingsBuilder::default();
/// builder.parallel();
/// builder = builder.parallel();
/// let settings = builder.build();
///
/// ```
Expand All @@ -55,18 +55,18 @@ pub struct HighsSolverSettingsBuilder {
}

impl HighsSolverSettingsBuilder {
pub fn parallel(&mut self) -> &mut Self {
pub fn parallel(mut self) -> Self {
self.parallel = true;
self
}

pub fn threads(&mut self, threads: usize) -> &mut Self {
pub fn threads(mut self, threads: usize) -> Self {
self.threads = threads;
self
}

/// Construct a [`HighsSolverSettings`] from the builder.
pub fn build(&self) -> HighsSolverSettings {
pub fn build(self) -> HighsSolverSettings {
HighsSolverSettings {
parallel: self.parallel,
threads: self.threads,
Expand Down
16 changes: 8 additions & 8 deletions pywr-core/src/solvers/ipm_ocl/settings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,43 +90,43 @@ impl Default for ClIpmSolverSettingsBuilder {
}

impl ClIpmSolverSettingsBuilder {
pub fn num_chunks(&mut self, num_chunks: NonZeroUsize) -> &mut Self {
pub fn num_chunks(mut self, num_chunks: NonZeroUsize) -> Self {
self.num_chunks = num_chunks;
self
}

pub fn parallel(&mut self) -> &mut Self {
pub fn parallel(mut self) -> Self {
self.parallel = true;
self
}

pub fn threads(&mut self, threads: usize) -> &mut Self {
pub fn threads(mut self, threads: usize) -> Self {
self.threads = threads;
self
}

pub fn primal_feasibility(&mut self, tolerance: f64) -> &mut Self {
pub fn primal_feasibility(mut self, tolerance: f64) -> Self {
self.tolerances.primal_feasibility = tolerance;
self
}

pub fn dual_feasibility(&mut self, tolerance: f64) -> &mut Self {
pub fn dual_feasibility(mut self, tolerance: f64) -> Self {
self.tolerances.dual_feasibility = tolerance;
self
}

pub fn optimality(&mut self, tolerance: f64) -> &mut Self {
pub fn optimality(mut self, tolerance: f64) -> Self {
self.tolerances.optimality = tolerance;
self
}

pub fn max_iterations(&mut self, max_iterations: NonZeroUsize) -> &mut Self {
pub fn max_iterations(mut self, max_iterations: NonZeroUsize) -> Self {
self.max_iterations = max_iterations;
self
}

/// Construct a [`ClIpmSolverSettings`] from the builder.
pub fn build(&self) -> ClIpmSolverSettings {
pub fn build(self) -> ClIpmSolverSettings {
ClIpmSolverSettings {
parallel: self.parallel,
threads: self.threads,
Expand Down
20 changes: 11 additions & 9 deletions pywr-core/src/solvers/ipm_simd/settings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,12 @@ where
/// let settings: SimdIpmSolverSettings<f64, 4> = SimdIpmSolverSettingsBuilder::default().parallel().threads(4).build();
///
/// let mut builder = SimdIpmSolverSettingsBuilder::default();
/// builder.max_iterations(NonZero::new(50).unwrap());
/// builder = builder.max_iterations(NonZero::new(50).unwrap());
/// let settings: SimdIpmSolverSettings<f64, 4> = builder.build();
///
/// builder.parallel();
/// let mut builder = SimdIpmSolverSettingsBuilder::default();
/// builder = builder.max_iterations(NonZero::new(50).unwrap());
/// builder = builder.parallel();
/// let settings: SimdIpmSolverSettings<f64, 4> = builder.build();
///
/// ```
Expand Down Expand Up @@ -112,37 +114,37 @@ where
LaneCount<N>: SupportedLaneCount,
T: SimdElement + From<f64>,
{
pub fn parallel(&mut self) -> &mut Self {
pub fn parallel(mut self) -> Self {
self.parallel = true;
self
}

pub fn threads(&mut self, threads: usize) -> &mut Self {
pub fn threads(mut self, threads: usize) -> Self {
self.threads = threads;
self
}

pub fn primal_feasibility(&mut self, tolerance: f64) -> &mut Self {
pub fn primal_feasibility(mut self, tolerance: f64) -> Self {
self.tolerances.primal_feasibility = Simd::<T, N>::splat(tolerance.into());
self
}

pub fn dual_feasibility(&mut self, tolerance: f64) -> &mut Self {
pub fn dual_feasibility(mut self, tolerance: f64) -> Self {
self.tolerances.dual_feasibility = Simd::<T, N>::splat(tolerance.into());
self
}

pub fn optimality(&mut self, tolerance: f64) -> &mut Self {
pub fn optimality(mut self, tolerance: f64) -> Self {
self.tolerances.optimality = Simd::<T, N>::splat(tolerance.into());
self
}

pub fn max_iterations(&mut self, max_iterations: NonZeroUsize) -> &mut Self {
pub fn max_iterations(mut self, max_iterations: NonZeroUsize) -> Self {
self.max_iterations = max_iterations;
self
}
/// Construct a [`SimdIpmSolverSettings`] from the builder.
pub fn build(&self) -> SimdIpmSolverSettings<T, N> {
pub fn build(self) -> SimdIpmSolverSettings<T, N> {
SimdIpmSolverSettings {
parallel: self.parallel,
threads: self.threads,
Expand Down
8 changes: 4 additions & 4 deletions pywr-python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -187,15 +187,15 @@ fn build_clp_settings(kwargs: Option<&Bound<'_, PyDict>>) -> PyResult<ClpSolverS
if let Some(kwargs) = kwargs {
if let Ok(value) = kwargs.get_item("threads") {
if let Some(threads) = value {
builder.threads(threads.extract::<usize>()?);
builder = builder.threads(threads.extract::<usize>()?);
}
kwargs.del_item("threads")?;
}

if let Ok(value) = kwargs.get_item("parallel") {
if let Some(parallel) = value {
if parallel.extract::<bool>()? {
builder.parallel();
builder = builder.parallel();
}
}
kwargs.del_item("parallel")?;
Expand All @@ -219,15 +219,15 @@ fn build_highs_settings(kwargs: Option<&Bound<'_, PyDict>>) -> PyResult<HighsSol
if let Some(kwargs) = kwargs {
if let Ok(value) = kwargs.get_item("threads") {
if let Some(threads) = value {
builder.threads(threads.extract::<usize>()?);
builder = builder.threads(threads.extract::<usize>()?);
}
kwargs.del_item("threads")?;
}

if let Ok(value) = kwargs.get_item("parallel") {
if let Some(parallel) = value {
if parallel.extract::<bool>()? {
builder.parallel();
builder = builder.parallel();
}
}
kwargs.del_item("parallel")?;
Expand Down
Loading