Skip to content

Commit

Permalink
Merge pull request #169 from optimas-org/feature/expose_ax_plots
Browse files Browse the repository at this point in the history
Expose `AxModelManager` methods in Service API generators and add new plotting methods
  • Loading branch information
AngelFP authored May 23, 2024
2 parents da8b1f8 + fd6307d commit 742fc71
Show file tree
Hide file tree
Showing 14 changed files with 260 additions and 57 deletions.
1 change: 0 additions & 1 deletion doc/source/api/diagnostics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,3 @@ Diagnostics
:toctree: _autosummary

ExplorationDiagnostics
AxModelManager
1 change: 1 addition & 0 deletions doc/source/api/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ This reference manual details all classes included in optimas.
evaluators
exploration
diagnostics
utils
9 changes: 9 additions & 0 deletions doc/source/api/utils.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
Utilities
=========

.. currentmodule:: optimas.utils

.. autosummary::
:toctree: _autosummary

AxModelManager
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,8 @@
"~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
"\n",
"The models provide some basic plotting methods for easy visualization, like\n",
":meth:`~optimas.diagnostics.AxModelManager.plot_contour`\n",
"and :meth:`~optimas.diagnostics.AxModelManager.plot_slice`."
":meth:`~optimas.utils.AxModelManager.plot_contour`\n",
"and :meth:`~optimas.utils.AxModelManager.plot_slice`."
]
},
{
Expand Down Expand Up @@ -243,7 +243,7 @@
"~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
"\n",
"In addition to plotting, it is also possible to evaluate the model at any\n",
"point by using the :meth:`~optimas.diagnostics.AxModelManager.evaluate_model`\n",
"point by using the :meth:`~optimas.utils.AxModelManager.evaluate_model`\n",
"method.\n",
"\n",
"In the example below, this method is used to evaluate the model in all the\n",
Expand Down
3 changes: 1 addition & 2 deletions optimas/diagnostics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from .exploration_diagnostics import ExplorationDiagnostics
from .ax_model_manager import AxModelManager

__all__ = ["ExplorationDiagnostics", "AxModelManager"]
__all__ = ["ExplorationDiagnostics"]
2 changes: 1 addition & 1 deletion optimas/diagnostics/exploration_diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from optimas.evaluators.base import Evaluator
from optimas.explorations import Exploration
from optimas.utils.other import get_df_with_selection
from optimas.diagnostics.ax_model_manager import AxModelManager
from optimas.utils.ax import AxModelManager


class ExplorationDiagnostics:
Expand Down
4 changes: 3 additions & 1 deletion optimas/explorations/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,9 @@ def _load_history(
), "Type {} not valid for `history`".format(type(history))
# Incorporate history into exploration.
if history is not None:
self.attach_evaluations(history)
self.attach_evaluations(
history, ignore_unrecognized_parameters=True
)
# When resuming an exploration, update evaluations counter.
if resume:
self._n_evals = history.size
Expand Down
34 changes: 18 additions & 16 deletions optimas/generators/ax/service/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@
TrialStatus,
)
from optimas.generators.ax.base import AxGenerator
from optimas.utils.ax import AxModelManager
from optimas.utils.ax.other import (
convert_optimas_to_ax_parameters,
convert_optimas_to_ax_objectives,
)


class AxServiceGenerator(AxGenerator):
Expand Down Expand Up @@ -124,6 +129,17 @@ def __init__(
self._parameter_constraints = parameter_constraints
self._outcome_constraints = outcome_constraints
self._ax_client = self._create_ax_client()
self._model = AxModelManager(self._ax_client)

@property
def ax_client(self) -> AxClient:
"""Get the underlying AxClient."""
return self._ax_client

@property
def model(self) -> AxModelManager:
"""Get access to the underlying model using an `AxModelManager`."""
return self._model

def _ask(self, trials: List[Trial]) -> List[Trial]:
"""Fill in the parameter values of the requested trials."""
Expand Down Expand Up @@ -213,19 +229,9 @@ def _create_ax_client(self) -> AxClient:

def _create_ax_parameters(self) -> List:
"""Create list of parameters to pass to an Ax."""
parameters = []
parameters = convert_optimas_to_ax_parameters(self.varying_parameters)
fixed_parameters = {}
for var in self._varying_parameters:
parameters.append(
{
"name": var.name,
"type": "range",
"bounds": [var.lower_bound, var.upper_bound],
"is_fidelity": var.is_fidelity,
"target_value": var.fidelity_target_value,
"value_type": var.dtype.__name__,
}
)
if var.is_fixed:
fixed_parameters[var.name] = var.default_value
# Store fixed parameters as fixed features.
Expand All @@ -234,10 +240,7 @@ def _create_ax_parameters(self) -> List:

def _create_ax_objectives(self) -> Dict[str, ObjectiveProperties]:
"""Create list of objectives to pass to an Ax."""
objectives = {}
for obj in self.objectives:
objectives[obj.name] = ObjectiveProperties(minimize=obj.minimize)
return objectives
return convert_optimas_to_ax_objectives(self.objectives)

def _create_sobol_step(self) -> GenerationStep:
"""Create a Sobol generation step with `n_init` trials."""
Expand Down Expand Up @@ -278,7 +281,6 @@ def _update_parameter(self, parameter):
generation_strategy = self._ax_client.generation_strategy
if generation_strategy._model is not None:
del generation_strategy._curr.model_spec._fitted_model
# Update parameter.
parameters = self._create_ax_parameters()
new_search_space = InstantiationBase.make_search_space(parameters, None)
self._ax_client.experiment.search_space.update_parameter(
Expand Down
3 changes: 3 additions & 0 deletions optimas/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .ax import AxModelManager

__all__ = ["AxModelManager"]
3 changes: 3 additions & 0 deletions optimas/utils/ax/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .ax_model_manager import AxModelManager

__all__ = ["AxModelManager"]
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Contains the definition of the ExplorationDiagnostics class."""
"""Contains the definition of the AxModelManager class."""

from typing import Optional, Union, List, Tuple, Dict, Any, Literal

Expand All @@ -20,7 +20,10 @@
from ax.modelbridge.registry import Models
from ax.modelbridge.torch import TorchModelBridge
from ax.core.observation import ObservationFeatures
from ax.service.utils.instantiation import ObjectiveProperties
from .other import (
convert_optimas_to_ax_parameters,
convert_optimas_to_ax_objectives,
)

ax_installed = True
except ImportError:
Expand Down Expand Up @@ -77,11 +80,12 @@ def __init__(
"The source must be an `AxClient`, a path to an AxClient json "
"file, or a pandas `DataFrame`."
)
self.ax_client.fit_model()

@property
def _model(self) -> TorchModelBridge:
"""Get the model from the AxClient instance."""
# Make sure model is fitted.
self.ax_client.fit_model()
return self.ax_client.generation_strategy.model

def _build_ax_client_from_dataframe(
Expand All @@ -103,36 +107,10 @@ def _build_ax_client_from_dataframe(
objectives.
"""
# Define parameters for AxClient
axparameters = []
for par in varying_parameters:
# Determine parameter type.
value_dtype = np.dtype(par.dtype)
if value_dtype.kind == "f":
value_type = "float"
elif value_dtype.kind == "i":
value_type = "int"
else:
raise ValueError(
"Ax range parameter can only be of type 'float'ot 'int', "
"not {var.dtype}."
)
# Create parameter dict and append to list.
axparameters.append(
{
"name": par.name,
"type": "range",
"bounds": [par.lower_bound, par.upper_bound],
"is_fidelity": par.is_fidelity,
"target_value": par.fidelity_target_value,
"value_type": value_type,
}
)
axparameters = convert_optimas_to_ax_parameters(varying_parameters)

# Define objectives for AxClient
axobjectives = {
obj.name: ObjectiveProperties(minimize=obj.minimize)
for obj in objectives
}
axobjectives = convert_optimas_to_ax_objectives(objectives)

# Create Ax client.
# We need to explicitly define a generation strategy because otherwise
Expand Down Expand Up @@ -649,3 +627,128 @@ def plot_slice(
ax.legend(frameon=False)

return fig, ax

def plot_cross_validation(
self,
metric_name: Optional[str] = None,
subplot_spec: Optional[SubplotSpec] = None,
gridspec_kw: Optional[Dict[str, Any]] = None,
errorbar_kw: Optional[Dict[str, Any]] = None,
**figure_kw,
) -> Tuple[Figure, Axes]:
"""Make a cross-validation plot for the given metric.
Parameters
----------
metric_name : str, optional.
Name of the metric to plot.
If not specified, it will take the first objective in
``self.ax_client``.
subplot_spec : SubplotSpec, optional
A matplotlib ``SubplotSpec`` in which to draw the axis.
gridspec_kw : dict, optional
Dict with keywords passed to the ``GridSpec``.
errorbar_kw : dict, optional
Dict with keywords passed to ``ax.errorbar_kw``.
**figure_kw
Additional keyword arguments to pass to ``pyplot.figure``. Only
used if no ``subplot_spec`` is given.
Returns
-------
Figure, Axes
"""
# Get metric name.
if metric_name is None:
metric_name = self.ax_client.objective_names[0]

# Evaluate model for each point in the history.
trials = self.ax_client.get_trials_data_frame()
mean, sem = self.evaluate_model(trials)

# Create figure.
gridspec_kw = dict(gridspec_kw or {})
if subplot_spec is None:
fig = plt.figure(**figure_kw)
gs = GridSpec(1, 1, **gridspec_kw)
else:
fig = plt.gcf()
gs = GridSpecFromSubplotSpec(1, 1, subplot_spec, **gridspec_kw)

# Get errorbar kwargs.
errorbar_kw = dict(errorbar_kw or {})
default_errorbar_kw = {"fmt": "o", "ms": 4, "label": "Data"}
errorbar_kw = {**default_errorbar_kw, **errorbar_kw}

# Make plot.
ax = fig.add_subplot(gs[0])
ax.errorbar(trials[metric_name], mean, yerr=sem, **errorbar_kw)
xlim = ax.get_xlim()
ylim = ax.get_ylim()
square_lims = [min(ylim[0], ylim[0]), max(xlim[1], ylim[1])]
ax.plot(
square_lims,
square_lims,
color="k",
ls="--",
label="Ideal correlation",
)
ax.set_xlim(square_lims)
ax.set_ylim(square_lims)
ax.set_xlabel("Observations")
ax.set_ylabel("Model predictions")
ax.legend(frameon=False)
return fig, ax

def plot_feature_importance(
self,
metric_name: Optional[str] = None,
subplot_spec: Optional[SubplotSpec] = None,
gridspec_kw: Optional[Dict[str, Any]] = None,
bar_kw: Optional[Dict[str, Any]] = None,
**figure_kw,
) -> Tuple[Figure, Axes]:
"""Plot the importance of each varying parameter for the given metric.
Parameters
----------
metric_name : str, optional.
Name of the metric for which to determine the importances.
If not specified, it will take the first objective in
``self.ax_client``.
subplot_spec : SubplotSpec, optional
A matplotlib ``SubplotSpec`` in which to draw the axis.
gridspec_kw : dict, optional
Dict with keywords passed to the ``GridSpec``.
bar_kw : dict, optional
Dict with keywords passed to ``ax.bar``.
**figure_kw
Additional keyword arguments to pass to ``pyplot.figure``. Only
used if no ``subplot_spec`` is given.
Returns
-------
Figure, Axes
"""
# Get metric name.
if metric_name is None:
metric_name = self.ax_client.objective_names[0]

# Get feature importances.
importances = self._model.feature_importances(metric_name)

# Create figure.
gridspec_kw = dict(gridspec_kw or {})
if subplot_spec is None:
fig = plt.figure(**figure_kw)
gs = GridSpec(1, 1, **gridspec_kw)
else:
fig = plt.gcf()
gs = GridSpecFromSubplotSpec(1, 1, subplot_spec, **gridspec_kw)
bar_kw = dict(bar_kw or {})

# Make plot.
ax = fig.add_subplot(gs[0])
ax.bar(importances.keys(), importances.values(), **bar_kw)
ax.set_ylabel(f"Importance for metric {metric_name}")
return fig, ax
49 changes: 49 additions & 0 deletions optimas/utils/ax/other.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
"""Contains the definition of various utilities for using Ax."""

from typing import List, Dict

import numpy as np
from ax.service.utils.instantiation import ObjectiveProperties

from optimas.core import VaryingParameter, Objective


def convert_optimas_to_ax_parameters(
varying_parameters: List[VaryingParameter],
) -> List[Dict]:
"""Create list of Ax parameters from optimas varying parameters."""
parameters = []
for var in varying_parameters:
# Determine parameter type.
value_dtype = np.dtype(var.dtype)
if value_dtype.kind == "f":
value_type = "float"
elif value_dtype.kind == "i":
value_type = "int"
else:
raise ValueError(
"Ax range parameter can only be of type 'float'ot 'int', "
f"not {var.dtype}."
)
# Create parameter dict and append to list.
parameters.append(
{
"name": var.name,
"type": "range",
"bounds": [var.lower_bound, var.upper_bound],
"is_fidelity": var.is_fidelity,
"target_value": var.fidelity_target_value,
"value_type": value_type,
}
)
return parameters


def convert_optimas_to_ax_objectives(
objectives: List[Objective],
) -> Dict[str, ObjectiveProperties]:
"""Create list of Ax objectives from optimas objectives."""
ax_objectives = {}
for obj in objectives:
ax_objectives[obj.name] = ObjectiveProperties(minimize=obj.minimize)
return ax_objectives
Loading

0 comments on commit 742fc71

Please sign in to comment.