diff --git a/optimas/gen_functions.py b/optimas/gen_functions.py index 7879a7e7..d13c209f 100644 --- a/optimas/gen_functions.py +++ b/optimas/gen_functions.py @@ -96,6 +96,7 @@ def persistent_generator(H, persis_info, gen_specs, libE_info): # Check how many simulations have returned n = len(calc_in["sim_id"]) # Feed the latest simulation results to the generator + trials = [] for i in range(n): trial_index = int(calc_in["trial_index"][i]) trial_status = calc_in["trial_status"][i] @@ -107,8 +108,11 @@ def persistent_generator(H, persis_info, gen_specs, libE_info): y = calc_in[par.name][i] ev = Evaluation(parameter=par, value=y) trial.complete_evaluation(ev) - # Register trial with unknown SEM - generator.tell_trials([trial]) + trials.append(trial) + + # Register trials with unknown SEM + generator.tell_trials(trials) + # Set the number of points to generate to that number: number_of_gen_points = min(n + n_failed_gens, max_evals - n_gens) n_failed_gens = 0 diff --git a/optimas/generators/__init__.py b/optimas/generators/__init__.py index 237fb921..4a74b424 100644 --- a/optimas/generators/__init__.py +++ b/optimas/generators/__init__.py @@ -22,7 +22,7 @@ from .grid_sampling import GridSamplingGenerator from .line_sampling import LineSamplingGenerator from .random_sampling import RandomSamplingGenerator - +from .external import ExternalGenerator __all__ = [ "AxSingleFidelityGenerator", @@ -32,4 +32,5 @@ "GridSamplingGenerator", "LineSamplingGenerator", "RandomSamplingGenerator", + "ExternalGenerator", ] diff --git a/optimas/generators/external.py b/optimas/generators/external.py new file mode 100644 index 00000000..8cb71e1c --- /dev/null +++ b/optimas/generators/external.py @@ -0,0 +1,25 @@ +"""Contains the definition of an external generator.""" + +from .base import Generator + + +class ExternalGenerator(Generator): + """Supports a generator in the CAMPA generator standard.""" + + def __init__( + self, + ext_gen, + **kwargs, + ): + super().__init__( + **kwargs, + ) + self.gen = ext_gen + + def ask(self, n_trials): + """Request the next set of points to evaluate.""" + return self.gen.ask(n_trials) + + def tell(self, trials): + """Send the results of evaluations to the generator.""" + self.gen.tell(trials)