From 829d125e2915978aab0d2fdeb147a673d9f1754b Mon Sep 17 00:00:00 2001 From: Takeru Ohta Date: Thu, 25 Jun 2020 23:16:15 +0900 Subject: [PATCH] Redesign optuna solver interface. --- kurobako_solvers/scripts/optuna_solver.py | 84 +++++----- kurobako_solvers/src/optuna.rs | 182 ++++------------------ 2 files changed, 69 insertions(+), 197 deletions(-) diff --git a/kurobako_solvers/scripts/optuna_solver.py b/kurobako_solvers/scripts/optuna_solver.py index a6c42ad..b6e6f6a 100755 --- a/kurobako_solvers/scripts/optuna_solver.py +++ b/kurobako_solvers/scripts/optuna_solver.py @@ -1,26 +1,22 @@ #! /usr/bin/env python3 import argparse +import json + from kurobako import solver from kurobako.solver.optuna import OptunaSolverFactory import optuna +import optuna.integration +import optuna.pruners +import optuna.samplers ## ## (1) Parse command-line arguments ## parser = argparse.ArgumentParser() -parser.add_argument("--sampler", choices=["tpe", "random", "skopt", "cma-es"], default="tpe") -parser.add_argument("--tpe-startup-trials", type=int, default=10) -parser.add_argument("--tpe-ei-candidates", type=int, default=24) -parser.add_argument("--tpe-prior-weight", type=float, default=1.0) -parser.add_argument("--skopt-base-estimator", choices=["GP", "RF", "ET", "GBRT"], default="GP") -parser.add_argument("--pruner", choices=["median", "asha", "nop", "hyperband"], default="median") -parser.add_argument("--median-startup-trials", type=int, default=5) -parser.add_argument("--median-warmup-steps", type=int, default=0) -parser.add_argument("--asha-min-resource", type=int, default=1) -parser.add_argument("--asha-reduction-factor", type=int, default=4) -parser.add_argument("--hyperband-min-resource", type=int, default=1) -parser.add_argument("--hyperband-reduction-factor", type=int, default=3) -parser.add_argument("--hyperband-n-brackets", type=int, default=4) +parser.add_argument("--sampler", type=str, default="TPESampler") +parser.add_argument("--sampler-kwargs", type=str, default="{}") +parser.add_argument("--pruner", type=str, default="MedianPruner") +parser.add_argument("--pruner-kwargs", type=str, default="{}") parser.add_argument("--loglevel", choices=["debug", "info", "warning", "error"]) parser.add_argument("--direction", choices=["minimize", "maximize"], default="minimize") parser.add_argument("--use-discrete-uniform", action="store_true") @@ -41,41 +37,35 @@ def create_study(seed): elif args.loglevel == "error": optuna.logging.set_verbosity(optuna.logging.ERROR) - if args.sampler == "random": - sampler = optuna.samplers.RandomSampler(seed=seed) - elif args.sampler == "tpe": - sampler = optuna.samplers.TPESampler( - n_startup_trials=args.tpe_startup_trials, - n_ei_candidates=args.tpe_ei_candidates, - prior_weight=args.tpe_prior_weight, - seed=seed, - ) - elif args.sampler == "skopt": - skopt_kwargs = {"base_estimator": args.skopt_base_estimator} - sampler = optuna.integration.SkoptSampler(skopt_kwargs=skopt_kwargs) - elif args.sampler == "cma-es": - sampler = optuna.samplers.CmaEsSampler(seed=seed) - else: - raise ValueError("Unknown sampler: {}".format(args.sampler)) + # Sampler. + sampler_cls = getattr( + optuna.samplers, args.sampler, getattr(optuna.integration, args.sampler, None) + ) + if sampler_cls is None: + raise ValueError("Unknown sampler: {}.".format(args.sampler)) + + sampler_kwargs = json.loads(args.sampler_kwargs) + try: + sampler_kwargs["seed"] = seed + sampler = sampler_cls(**sampler_kwargs) + except: + del sampler_kwargs["seed"] + sampler = sampler_cls(**sampler_kwargs) + + # Pruner. + pruner_cls = getattr( + optuna.pruners, args.pruner, getattr(optuna.integration, args.pruner, None) + ) + if pruner_cls is None: + raise ValueError("Unknown pruner: {}.".format(args.pruner)) - if args.pruner == "median": - pruner = optuna.pruners.MedianPruner( - n_startup_trials=args.median_startup_trials, n_warmup_steps=args.median_warmup_steps - ) - elif args.pruner == "asha": - pruner = optuna.pruners.SuccessiveHalvingPruner( - min_resource=args.asha_min_resource, reduction_factor=args.asha_reduction_factor - ) - elif args.pruner == "hyperband": - pruner = optuna.pruners.HyperbandPruner( - min_resource=args.hyperband_min_resource, - reduction_factor=args.hyperband_reduction_factor, - n_brackets=args.hyperband_n_brackets, - ) - elif args.pruner == "nop": - pruner = optuna.pruners.NopPruner() - else: - raise ValueError("Unknown pruner: {}".format(args.pruner)) + pruner_kwargs = json.loads(args.pruner_kwargs) + try: + pruner_kwargs["seed"] = seed + pruner = pruner_cls(**pruner_kwargs) + except: + del pruner_kwargs["seed"] + pruner = pruner_cls(**pruner_kwargs) return optuna.create_study(sampler=sampler, pruner=pruner, direction=args.direction) diff --git a/kurobako_solvers/src/optuna.rs b/kurobako_solvers/src/optuna.rs index b808a3e..e963e4b 100644 --- a/kurobako_solvers/src/optuna.rs +++ b/kurobako_solvers/src/optuna.rs @@ -36,29 +36,6 @@ mod defaults { } define!(loglevel, is_loglevel, String, "warning".to_owned()); - define!(sampler, is_sampler, String, "tpe".to_owned()); - define!(tpe_startup_trials, is_tpe_startup_trials, usize, 10); - define!(tpe_ei_candidates, is_tpe_ei_candidates, usize, 24); - define!(tpe_prior_weight, is_tpe_prior_weight, f64, 1.0); - define!( - skopt_base_estimator, - is_skopt_base_estimator, - String, - "GP".to_owned() - ); - define!(pruner, is_pruner, String, "median".to_owned()); - define!(median_startup_trials, is_median_startup_trials, usize, 5); - define!(median_warmup_steps, is_median_warmup_steps, usize, 0); - define!(asha_min_resource, is_asha_min_resource, usize, 1); - define!(asha_reduction_factor, is_asha_reduction_factor, usize, 4); - define!(hyperband_min_resource, is_hyperband_min_resource, usize, 1); - define!( - hyperband_reduction_factor, - is_hyperband_reduction_factor, - usize, - 3 - ); - define!(hyperband_n_brackets, is_hyperband_n_brackets, usize, 4); } /// Recipe of `OptunaSolver`. @@ -76,84 +53,29 @@ pub struct OptunaSolverRecipe { #[serde(default = "defaults::loglevel")] pub loglevel: String, - /// Sampler type. - #[structopt( - long, - default_value = "tpe", - possible_values = &["tpe", "random", "skopt", "cma-es"] - )] - #[serde(skip_serializing_if = "defaults::is_sampler")] - #[serde(default = "defaults::sampler")] - pub sampler: String, - - #[structopt(long, default_value = "10")] - #[serde(skip_serializing_if = "defaults::is_tpe_startup_trials")] - #[serde(default = "defaults::tpe_startup_trials")] - pub tpe_startup_trials: usize, - - #[structopt(long, default_value = "24")] - #[serde(skip_serializing_if = "defaults::is_tpe_ei_candidates")] - #[serde(default = "defaults::tpe_ei_candidates")] - pub tpe_ei_candidates: usize, - - #[structopt(long, default_value = "1.0")] - #[serde(skip_serializing_if = "defaults::is_tpe_prior_weight")] - #[serde(default = "defaults::tpe_prior_weight")] - pub tpe_prior_weight: f64, - - #[structopt( - long, - default_value = "GP", - possible_values = &["GP", "RF", "ET", "GBRT"] - )] - #[serde(skip_serializing_if = "defaults::is_skopt_base_estimator")] - #[serde(default = "defaults::skopt_base_estimator")] - pub skopt_base_estimator: String, - - /// Pruner type. - #[structopt( - long, - default_value = "median", - possible_values = &["median", "asha", "nop", "hyperband"] - )] - #[serde(skip_serializing_if = "defaults::is_pruner")] - #[serde(default = "defaults::pruner")] - pub pruner: String, - - #[structopt(long, default_value = "5")] - #[serde(skip_serializing_if = "defaults::is_median_startup_trials")] - #[serde(default = "defaults::median_startup_trials")] - pub median_startup_trials: usize, - - #[structopt(long, default_value = "0")] - #[serde(skip_serializing_if = "defaults::is_median_warmup_steps")] - #[serde(default = "defaults::median_warmup_steps")] - pub median_warmup_steps: usize, - - #[structopt(long, default_value = "1")] - #[serde(skip_serializing_if = "defaults::is_asha_min_resource")] - #[serde(default = "defaults::asha_min_resource")] - pub asha_min_resource: usize, - - #[structopt(long, default_value = "4")] - #[serde(skip_serializing_if = "defaults::is_asha_reduction_factor")] - #[serde(default = "defaults::asha_reduction_factor")] - pub asha_reduction_factor: usize, + /// Sampler class name (e.g., "TPESampler"). + #[structopt(long)] + #[serde(skip_serializing_if = "Option::is_none")] + #[serde(default)] + pub sampler: Option, - #[structopt(long, default_value = "1")] - #[serde(skip_serializing_if = "defaults::is_hyperband_min_resource")] - #[serde(default = "defaults::hyperband_min_resource")] - pub hyperband_min_resource: usize, + /// Sampler arguments (e.g., "{\"seed\": 10}"). + #[structopt(long)] + #[serde(skip_serializing_if = "Option::is_none")] + #[serde(default)] + pub sampler_kwargs: Option, - #[structopt(long, default_value = "3")] - #[serde(skip_serializing_if = "defaults::is_hyperband_reduction_factor")] - #[serde(default = "defaults::hyperband_reduction_factor")] - pub hyperband_reduction_factor: usize, + /// Pruner class name (e.g., "MedianPruner"). + #[structopt(long)] + #[serde(skip_serializing_if = "Option::is_none")] + #[serde(default)] + pub pruner: Option, - #[structopt(long, default_value = "4")] - #[serde(skip_serializing_if = "defaults::is_hyperband_n_brackets")] - #[serde(default = "defaults::hyperband_n_brackets")] - pub hyperband_n_brackets: usize, + /// Pruner arguments (e.g., "{\"n_warmup_steps\": 10}"). + #[structopt(long)] + #[serde(skip_serializing_if = "Option::is_none")] + #[serde(default)] + pub pruner_kwargs: Option, /// Sets optimization direction to "maximize". /// @@ -171,58 +93,18 @@ impl OptunaSolverRecipe { fn build_args(&self) -> Vec { let mut args = Vec::new(); add_arg(&mut args, "--loglevel", &self.loglevel); - add_arg(&mut args, "--sampler", &self.sampler); - add_arg( - &mut args, - "--tpe-startup-trials", - &self.tpe_startup_trials.to_string(), - ); - add_arg( - &mut args, - "--tpe-ei-candidates", - &self.tpe_ei_candidates.to_string(), - ); - add_arg( - &mut args, - "--tpe-prior-weight", - &self.tpe_prior_weight.to_string(), - ); - add_arg(&mut args, "--pruner", &self.pruner); - add_arg( - &mut args, - "--median-startup-trials", - &self.median_startup_trials.to_string(), - ); - add_arg( - &mut args, - "--median-warmup-steps", - &self.median_warmup_steps.to_string(), - ); - add_arg( - &mut args, - "--asha-min-resource", - &self.asha_min_resource.to_string(), - ); - add_arg( - &mut args, - "--asha-reduction-factor", - &self.asha_reduction_factor.to_string(), - ); - add_arg( - &mut args, - "--hyperband-min-resource", - &self.hyperband_min_resource.to_string(), - ); - add_arg( - &mut args, - "--hyperband-reduction-factor", - &self.hyperband_reduction_factor.to_string(), - ); - add_arg( - &mut args, - "--hyperband-n-brackets", - &self.hyperband_n_brackets.to_string(), - ); + if let Some(v) = &self.sampler { + add_arg(&mut args, "--sampler", v); + } + if let Some(v) = &self.sampler_kwargs { + add_arg(&mut args, "--sampler-kwargs", v); + } + if let Some(v) = &self.pruner { + add_arg(&mut args, "--pruner", v); + } + if let Some(v) = &self.pruner_kwargs { + add_arg(&mut args, "--pruner-kwargs", v); + } if self.maximize { args.push("--direction".to_owned()); args.push("maximize".to_owned());