Skip to content

Commit

Permalink
Merge pull request #20 from sile/redesign-optuna-solver-interface
Browse files Browse the repository at this point in the history
Redesign optuna solver interface.
  • Loading branch information
sile authored Jun 26, 2020
2 parents debca7b + 829d125 commit c73877a
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 197 deletions.
84 changes: 37 additions & 47 deletions kurobako_solvers/scripts/optuna_solver.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,22 @@
#! /usr/bin/env python3
import argparse
import json

from kurobako import solver
from kurobako.solver.optuna import OptunaSolverFactory
import optuna
import optuna.integration
import optuna.pruners
import optuna.samplers

##
## (1) Parse command-line arguments
##
parser = argparse.ArgumentParser()
parser.add_argument("--sampler", choices=["tpe", "random", "skopt", "cma-es"], default="tpe")
parser.add_argument("--tpe-startup-trials", type=int, default=10)
parser.add_argument("--tpe-ei-candidates", type=int, default=24)
parser.add_argument("--tpe-prior-weight", type=float, default=1.0)
parser.add_argument("--skopt-base-estimator", choices=["GP", "RF", "ET", "GBRT"], default="GP")
parser.add_argument("--pruner", choices=["median", "asha", "nop", "hyperband"], default="median")
parser.add_argument("--median-startup-trials", type=int, default=5)
parser.add_argument("--median-warmup-steps", type=int, default=0)
parser.add_argument("--asha-min-resource", type=int, default=1)
parser.add_argument("--asha-reduction-factor", type=int, default=4)
parser.add_argument("--hyperband-min-resource", type=int, default=1)
parser.add_argument("--hyperband-reduction-factor", type=int, default=3)
parser.add_argument("--hyperband-n-brackets", type=int, default=4)
parser.add_argument("--sampler", type=str, default="TPESampler")
parser.add_argument("--sampler-kwargs", type=str, default="{}")
parser.add_argument("--pruner", type=str, default="MedianPruner")
parser.add_argument("--pruner-kwargs", type=str, default="{}")
parser.add_argument("--loglevel", choices=["debug", "info", "warning", "error"])
parser.add_argument("--direction", choices=["minimize", "maximize"], default="minimize")
parser.add_argument("--use-discrete-uniform", action="store_true")
Expand All @@ -41,41 +37,35 @@ def create_study(seed):
elif args.loglevel == "error":
optuna.logging.set_verbosity(optuna.logging.ERROR)

if args.sampler == "random":
sampler = optuna.samplers.RandomSampler(seed=seed)
elif args.sampler == "tpe":
sampler = optuna.samplers.TPESampler(
n_startup_trials=args.tpe_startup_trials,
n_ei_candidates=args.tpe_ei_candidates,
prior_weight=args.tpe_prior_weight,
seed=seed,
)
elif args.sampler == "skopt":
skopt_kwargs = {"base_estimator": args.skopt_base_estimator}
sampler = optuna.integration.SkoptSampler(skopt_kwargs=skopt_kwargs)
elif args.sampler == "cma-es":
sampler = optuna.samplers.CmaEsSampler(seed=seed)
else:
raise ValueError("Unknown sampler: {}".format(args.sampler))
# Sampler.
sampler_cls = getattr(
optuna.samplers, args.sampler, getattr(optuna.integration, args.sampler, None)
)
if sampler_cls is None:
raise ValueError("Unknown sampler: {}.".format(args.sampler))

sampler_kwargs = json.loads(args.sampler_kwargs)
try:
sampler_kwargs["seed"] = seed
sampler = sampler_cls(**sampler_kwargs)
except:
del sampler_kwargs["seed"]
sampler = sampler_cls(**sampler_kwargs)

# Pruner.
pruner_cls = getattr(
optuna.pruners, args.pruner, getattr(optuna.integration, args.pruner, None)
)
if pruner_cls is None:
raise ValueError("Unknown pruner: {}.".format(args.pruner))

if args.pruner == "median":
pruner = optuna.pruners.MedianPruner(
n_startup_trials=args.median_startup_trials, n_warmup_steps=args.median_warmup_steps
)
elif args.pruner == "asha":
pruner = optuna.pruners.SuccessiveHalvingPruner(
min_resource=args.asha_min_resource, reduction_factor=args.asha_reduction_factor
)
elif args.pruner == "hyperband":
pruner = optuna.pruners.HyperbandPruner(
min_resource=args.hyperband_min_resource,
reduction_factor=args.hyperband_reduction_factor,
n_brackets=args.hyperband_n_brackets,
)
elif args.pruner == "nop":
pruner = optuna.pruners.NopPruner()
else:
raise ValueError("Unknown pruner: {}".format(args.pruner))
pruner_kwargs = json.loads(args.pruner_kwargs)
try:
pruner_kwargs["seed"] = seed
pruner = pruner_cls(**pruner_kwargs)
except:
del pruner_kwargs["seed"]
pruner = pruner_cls(**pruner_kwargs)

return optuna.create_study(sampler=sampler, pruner=pruner, direction=args.direction)

Expand Down
182 changes: 32 additions & 150 deletions kurobako_solvers/src/optuna.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,29 +36,6 @@ mod defaults {
}

define!(loglevel, is_loglevel, String, "warning".to_owned());
define!(sampler, is_sampler, String, "tpe".to_owned());
define!(tpe_startup_trials, is_tpe_startup_trials, usize, 10);
define!(tpe_ei_candidates, is_tpe_ei_candidates, usize, 24);
define!(tpe_prior_weight, is_tpe_prior_weight, f64, 1.0);
define!(
skopt_base_estimator,
is_skopt_base_estimator,
String,
"GP".to_owned()
);
define!(pruner, is_pruner, String, "median".to_owned());
define!(median_startup_trials, is_median_startup_trials, usize, 5);
define!(median_warmup_steps, is_median_warmup_steps, usize, 0);
define!(asha_min_resource, is_asha_min_resource, usize, 1);
define!(asha_reduction_factor, is_asha_reduction_factor, usize, 4);
define!(hyperband_min_resource, is_hyperband_min_resource, usize, 1);
define!(
hyperband_reduction_factor,
is_hyperband_reduction_factor,
usize,
3
);
define!(hyperband_n_brackets, is_hyperband_n_brackets, usize, 4);
}

/// Recipe of `OptunaSolver`.
Expand All @@ -76,84 +53,29 @@ pub struct OptunaSolverRecipe {
#[serde(default = "defaults::loglevel")]
pub loglevel: String,

/// Sampler type.
#[structopt(
long,
default_value = "tpe",
possible_values = &["tpe", "random", "skopt", "cma-es"]
)]
#[serde(skip_serializing_if = "defaults::is_sampler")]
#[serde(default = "defaults::sampler")]
pub sampler: String,

#[structopt(long, default_value = "10")]
#[serde(skip_serializing_if = "defaults::is_tpe_startup_trials")]
#[serde(default = "defaults::tpe_startup_trials")]
pub tpe_startup_trials: usize,

#[structopt(long, default_value = "24")]
#[serde(skip_serializing_if = "defaults::is_tpe_ei_candidates")]
#[serde(default = "defaults::tpe_ei_candidates")]
pub tpe_ei_candidates: usize,

#[structopt(long, default_value = "1.0")]
#[serde(skip_serializing_if = "defaults::is_tpe_prior_weight")]
#[serde(default = "defaults::tpe_prior_weight")]
pub tpe_prior_weight: f64,

#[structopt(
long,
default_value = "GP",
possible_values = &["GP", "RF", "ET", "GBRT"]
)]
#[serde(skip_serializing_if = "defaults::is_skopt_base_estimator")]
#[serde(default = "defaults::skopt_base_estimator")]
pub skopt_base_estimator: String,

/// Pruner type.
#[structopt(
long,
default_value = "median",
possible_values = &["median", "asha", "nop", "hyperband"]
)]
#[serde(skip_serializing_if = "defaults::is_pruner")]
#[serde(default = "defaults::pruner")]
pub pruner: String,

#[structopt(long, default_value = "5")]
#[serde(skip_serializing_if = "defaults::is_median_startup_trials")]
#[serde(default = "defaults::median_startup_trials")]
pub median_startup_trials: usize,

#[structopt(long, default_value = "0")]
#[serde(skip_serializing_if = "defaults::is_median_warmup_steps")]
#[serde(default = "defaults::median_warmup_steps")]
pub median_warmup_steps: usize,

#[structopt(long, default_value = "1")]
#[serde(skip_serializing_if = "defaults::is_asha_min_resource")]
#[serde(default = "defaults::asha_min_resource")]
pub asha_min_resource: usize,

#[structopt(long, default_value = "4")]
#[serde(skip_serializing_if = "defaults::is_asha_reduction_factor")]
#[serde(default = "defaults::asha_reduction_factor")]
pub asha_reduction_factor: usize,
/// Sampler class name (e.g., "TPESampler").
#[structopt(long)]
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(default)]
pub sampler: Option<String>,

#[structopt(long, default_value = "1")]
#[serde(skip_serializing_if = "defaults::is_hyperband_min_resource")]
#[serde(default = "defaults::hyperband_min_resource")]
pub hyperband_min_resource: usize,
/// Sampler arguments (e.g., "{\"seed\": 10}").
#[structopt(long)]
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(default)]
pub sampler_kwargs: Option<String>,

#[structopt(long, default_value = "3")]
#[serde(skip_serializing_if = "defaults::is_hyperband_reduction_factor")]
#[serde(default = "defaults::hyperband_reduction_factor")]
pub hyperband_reduction_factor: usize,
/// Pruner class name (e.g., "MedianPruner").
#[structopt(long)]
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(default)]
pub pruner: Option<String>,

#[structopt(long, default_value = "4")]
#[serde(skip_serializing_if = "defaults::is_hyperband_n_brackets")]
#[serde(default = "defaults::hyperband_n_brackets")]
pub hyperband_n_brackets: usize,
/// Pruner arguments (e.g., "{\"n_warmup_steps\": 10}").
#[structopt(long)]
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(default)]
pub pruner_kwargs: Option<String>,

/// Sets optimization direction to "maximize".
///
Expand All @@ -171,58 +93,18 @@ impl OptunaSolverRecipe {
fn build_args(&self) -> Vec<String> {
let mut args = Vec::new();
add_arg(&mut args, "--loglevel", &self.loglevel);
add_arg(&mut args, "--sampler", &self.sampler);
add_arg(
&mut args,
"--tpe-startup-trials",
&self.tpe_startup_trials.to_string(),
);
add_arg(
&mut args,
"--tpe-ei-candidates",
&self.tpe_ei_candidates.to_string(),
);
add_arg(
&mut args,
"--tpe-prior-weight",
&self.tpe_prior_weight.to_string(),
);
add_arg(&mut args, "--pruner", &self.pruner);
add_arg(
&mut args,
"--median-startup-trials",
&self.median_startup_trials.to_string(),
);
add_arg(
&mut args,
"--median-warmup-steps",
&self.median_warmup_steps.to_string(),
);
add_arg(
&mut args,
"--asha-min-resource",
&self.asha_min_resource.to_string(),
);
add_arg(
&mut args,
"--asha-reduction-factor",
&self.asha_reduction_factor.to_string(),
);
add_arg(
&mut args,
"--hyperband-min-resource",
&self.hyperband_min_resource.to_string(),
);
add_arg(
&mut args,
"--hyperband-reduction-factor",
&self.hyperband_reduction_factor.to_string(),
);
add_arg(
&mut args,
"--hyperband-n-brackets",
&self.hyperband_n_brackets.to_string(),
);
if let Some(v) = &self.sampler {
add_arg(&mut args, "--sampler", v);
}
if let Some(v) = &self.sampler_kwargs {
add_arg(&mut args, "--sampler-kwargs", v);
}
if let Some(v) = &self.pruner {
add_arg(&mut args, "--pruner", v);
}
if let Some(v) = &self.pruner_kwargs {
add_arg(&mut args, "--pruner-kwargs", v);
}
if self.maximize {
args.push("--direction".to_owned());
args.push("maximize".to_owned());
Expand Down

0 comments on commit c73877a

Please sign in to comment.