Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose AxModelManager methods in Service API generators and add new plotting methods #169

Merged
merged 40 commits into from
May 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
7a51a45
Expose Ax plots in Service API generators
AngelFP Jan 27, 2024
090e689
Support MOO in plots
AngelFP Jan 27, 2024
b3485b4
Update tests
AngelFP Jan 27, 2024
19dc5d9
Add missing type hints
AngelFP Jan 27, 2024
fa0839b
Implement `plot_feature_importance`
AngelFP Jan 28, 2024
8b5e104
Add `ax_client` property
AngelFP Jan 28, 2024
976a3a2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 29, 2024
8bbfb9a
Merge branch 'main' into feature/expose_ax_plots
AngelFP Feb 5, 2024
1af79b6
Uncomment tests
AngelFP Feb 5, 2024
16f7b63
Increase test trials
AngelFP Feb 5, 2024
64968e6
Improve detection of parameter `value_type`.
AngelFP Feb 6, 2024
377979c
Make sure generation strategy moves to next step when adding external…
AngelFP Feb 6, 2024
6461543
Merge branch 'main' into feature/expose_ax_plots
AngelFP Feb 6, 2024
5a4e973
Improve handling of when best model parameters can not be found
AngelFP Feb 6, 2024
5de111c
Ignore unrecognized parameters when loading history
AngelFP Feb 6, 2024
2d58e4d
Merge branch 'main' into feature/expose_ax_plots
AngelFP Feb 15, 2024
ac20ce5
Fix test
AngelFP Feb 15, 2024
84e59fe
Remove unnecessary and duplicated plots from test
AngelFP Feb 15, 2024
a67fafc
Merge branch 'main' into feature/expose_ax_plots
AngelFP Mar 5, 2024
8b81df0
Merge branch 'main' into feature/expose_ax_plots
AngelFP Mar 7, 2024
e982ed1
Move `AxModelManager` to `utils.ax`
AngelFP Mar 7, 2024
8729d36
Update imports
AngelFP Mar 7, 2024
86c8d5d
Add new plots to `AxModelManager`
AngelFP Mar 7, 2024
87842f2
Make `AxServiceGenerator` inherit from `AxModelManager` to expose plo…
AngelFP Mar 7, 2024
3734f14
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 7, 2024
580ea9a
Expose `AxModelManager` on `utils`
AngelFP Mar 7, 2024
8cc41f2
Update docs
AngelFP Mar 7, 2024
f523e03
Fix docs
AngelFP Mar 7, 2024
b951d33
Update test
AngelFP Mar 7, 2024
f77fc89
Return AxClient with fitted model
AngelFP Mar 7, 2024
59e8ca3
Create common utils for converting parameters from optimas to ax
AngelFP Mar 7, 2024
0fe5672
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 7, 2024
21194da
Merge branch 'main' into feature/expose_ax_plots
AngelFP Apr 18, 2024
2b0fb21
Expose `AxModelManager` utilities in `model` property, instead of usi…
AngelFP Apr 18, 2024
2238a53
Update test
AngelFP Apr 18, 2024
fec0bda
Add docstring
AngelFP Apr 18, 2024
e3d5218
Merge branch 'main' into feature/expose_ax_plots
AngelFP Apr 23, 2024
a26380b
Merge branch 'main' into feature/expose_ax_plots
AngelFP May 10, 2024
671bcc4
Merge branch 'main' into feature/expose_ax_plots
AngelFP May 22, 2024
fd6307d
Merge branch 'main' into feature/expose_ax_plots
delaossa May 23, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion doc/source/api/diagnostics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,3 @@ Diagnostics
:toctree: _autosummary

ExplorationDiagnostics
AxModelManager
1 change: 1 addition & 0 deletions doc/source/api/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ This reference manual details all classes included in optimas.
evaluators
exploration
diagnostics
utils
9 changes: 9 additions & 0 deletions doc/source/api/utils.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
Utilities
=========

.. currentmodule:: optimas.utils

.. autosummary::
:toctree: _autosummary

AxModelManager
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,8 @@
"~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
"\n",
"The models provide some basic plotting methods for easy visualization, like\n",
":meth:`~optimas.diagnostics.AxModelManager.plot_contour`\n",
"and :meth:`~optimas.diagnostics.AxModelManager.plot_slice`."
":meth:`~optimas.utils.AxModelManager.plot_contour`\n",
"and :meth:`~optimas.utils.AxModelManager.plot_slice`."
]
},
{
Expand Down Expand Up @@ -243,7 +243,7 @@
"~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
"\n",
"In addition to plotting, it is also possible to evaluate the model at any\n",
"point by using the :meth:`~optimas.diagnostics.AxModelManager.evaluate_model`\n",
"point by using the :meth:`~optimas.utils.AxModelManager.evaluate_model`\n",
"method.\n",
"\n",
"In the example below, this method is used to evaluate the model in all the\n",
Expand Down
3 changes: 1 addition & 2 deletions optimas/diagnostics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from .exploration_diagnostics import ExplorationDiagnostics
from .ax_model_manager import AxModelManager

__all__ = ["ExplorationDiagnostics", "AxModelManager"]
__all__ = ["ExplorationDiagnostics"]
2 changes: 1 addition & 1 deletion optimas/diagnostics/exploration_diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from optimas.evaluators.base import Evaluator
from optimas.explorations import Exploration
from optimas.utils.other import get_df_with_selection
from optimas.diagnostics.ax_model_manager import AxModelManager
from optimas.utils.ax import AxModelManager


class ExplorationDiagnostics:
Expand Down
4 changes: 3 additions & 1 deletion optimas/explorations/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,9 @@ def _load_history(
), "Type {} not valid for `history`".format(type(history))
# Incorporate history into exploration.
if history is not None:
self.attach_evaluations(history)
self.attach_evaluations(
history, ignore_unrecognized_parameters=True
)
# When resuming an exploration, update evaluations counter.
if resume:
self._n_evals = history.size
Expand Down
34 changes: 18 additions & 16 deletions optimas/generators/ax/service/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@
TrialStatus,
)
from optimas.generators.ax.base import AxGenerator
from optimas.utils.ax import AxModelManager
from optimas.utils.ax.other import (
convert_optimas_to_ax_parameters,
convert_optimas_to_ax_objectives,
)


class AxServiceGenerator(AxGenerator):
Expand Down Expand Up @@ -124,6 +129,17 @@ def __init__(
self._parameter_constraints = parameter_constraints
self._outcome_constraints = outcome_constraints
self._ax_client = self._create_ax_client()
self._model = AxModelManager(self._ax_client)

@property
def ax_client(self) -> AxClient:
"""Get the underlying AxClient."""
return self._ax_client

@property
def model(self) -> AxModelManager:
"""Get access to the underlying model using an `AxModelManager`."""
return self._model

def _ask(self, trials: List[Trial]) -> List[Trial]:
"""Fill in the parameter values of the requested trials."""
Expand Down Expand Up @@ -213,19 +229,9 @@ def _create_ax_client(self) -> AxClient:

def _create_ax_parameters(self) -> List:
"""Create list of parameters to pass to an Ax."""
parameters = []
parameters = convert_optimas_to_ax_parameters(self.varying_parameters)
fixed_parameters = {}
for var in self._varying_parameters:
parameters.append(
{
"name": var.name,
"type": "range",
"bounds": [var.lower_bound, var.upper_bound],
"is_fidelity": var.is_fidelity,
"target_value": var.fidelity_target_value,
"value_type": var.dtype.__name__,
}
)
if var.is_fixed:
fixed_parameters[var.name] = var.default_value
# Store fixed parameters as fixed features.
Expand All @@ -234,10 +240,7 @@ def _create_ax_parameters(self) -> List:

def _create_ax_objectives(self) -> Dict[str, ObjectiveProperties]:
"""Create list of objectives to pass to an Ax."""
objectives = {}
for obj in self.objectives:
objectives[obj.name] = ObjectiveProperties(minimize=obj.minimize)
return objectives
return convert_optimas_to_ax_objectives(self.objectives)

def _create_sobol_step(self) -> GenerationStep:
"""Create a Sobol generation step with `n_init` trials."""
Expand Down Expand Up @@ -278,7 +281,6 @@ def _update_parameter(self, parameter):
generation_strategy = self._ax_client.generation_strategy
if generation_strategy._model is not None:
del generation_strategy._curr.model_spec._fitted_model
# Update parameter.
parameters = self._create_ax_parameters()
new_search_space = InstantiationBase.make_search_space(parameters, None)
self._ax_client.experiment.search_space.update_parameter(
Expand Down
3 changes: 3 additions & 0 deletions optimas/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .ax import AxModelManager

__all__ = ["AxModelManager"]
3 changes: 3 additions & 0 deletions optimas/utils/ax/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .ax_model_manager import AxModelManager

__all__ = ["AxModelManager"]
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Contains the definition of the ExplorationDiagnostics class."""
"""Contains the definition of the AxModelManager class."""

from typing import Optional, Union, List, Tuple, Dict, Any, Literal

Expand All @@ -20,7 +20,10 @@
from ax.modelbridge.registry import Models
from ax.modelbridge.torch import TorchModelBridge
from ax.core.observation import ObservationFeatures
from ax.service.utils.instantiation import ObjectiveProperties
from .other import (
convert_optimas_to_ax_parameters,
convert_optimas_to_ax_objectives,
)

ax_installed = True
except ImportError:
Expand Down Expand Up @@ -77,11 +80,12 @@ def __init__(
"The source must be an `AxClient`, a path to an AxClient json "
"file, or a pandas `DataFrame`."
)
self.ax_client.fit_model()

@property
def _model(self) -> TorchModelBridge:
"""Get the model from the AxClient instance."""
# Make sure model is fitted.
self.ax_client.fit_model()
return self.ax_client.generation_strategy.model

def _build_ax_client_from_dataframe(
Expand All @@ -103,36 +107,10 @@ def _build_ax_client_from_dataframe(
objectives.
"""
# Define parameters for AxClient
axparameters = []
for par in varying_parameters:
# Determine parameter type.
value_dtype = np.dtype(par.dtype)
if value_dtype.kind == "f":
value_type = "float"
elif value_dtype.kind == "i":
value_type = "int"
else:
raise ValueError(
"Ax range parameter can only be of type 'float'ot 'int', "
"not {var.dtype}."
)
# Create parameter dict and append to list.
axparameters.append(
{
"name": par.name,
"type": "range",
"bounds": [par.lower_bound, par.upper_bound],
"is_fidelity": par.is_fidelity,
"target_value": par.fidelity_target_value,
"value_type": value_type,
}
)
axparameters = convert_optimas_to_ax_parameters(varying_parameters)

# Define objectives for AxClient
axobjectives = {
obj.name: ObjectiveProperties(minimize=obj.minimize)
for obj in objectives
}
axobjectives = convert_optimas_to_ax_objectives(objectives)

# Create Ax client.
# We need to explicitly define a generation strategy because otherwise
Expand Down Expand Up @@ -649,3 +627,128 @@ def plot_slice(
ax.legend(frameon=False)

return fig, ax

def plot_cross_validation(
self,
metric_name: Optional[str] = None,
subplot_spec: Optional[SubplotSpec] = None,
gridspec_kw: Optional[Dict[str, Any]] = None,
errorbar_kw: Optional[Dict[str, Any]] = None,
**figure_kw,
) -> Tuple[Figure, Axes]:
"""Make a cross-validation plot for the given metric.

Parameters
----------
metric_name : str, optional.
Name of the metric to plot.
If not specified, it will take the first objective in
``self.ax_client``.
subplot_spec : SubplotSpec, optional
A matplotlib ``SubplotSpec`` in which to draw the axis.
gridspec_kw : dict, optional
Dict with keywords passed to the ``GridSpec``.
errorbar_kw : dict, optional
Dict with keywords passed to ``ax.errorbar_kw``.
**figure_kw
Additional keyword arguments to pass to ``pyplot.figure``. Only
used if no ``subplot_spec`` is given.

Returns
-------
Figure, Axes
"""
# Get metric name.
if metric_name is None:
metric_name = self.ax_client.objective_names[0]

# Evaluate model for each point in the history.
trials = self.ax_client.get_trials_data_frame()
mean, sem = self.evaluate_model(trials)

# Create figure.
gridspec_kw = dict(gridspec_kw or {})
if subplot_spec is None:
fig = plt.figure(**figure_kw)
gs = GridSpec(1, 1, **gridspec_kw)
else:
fig = plt.gcf()
gs = GridSpecFromSubplotSpec(1, 1, subplot_spec, **gridspec_kw)

# Get errorbar kwargs.
errorbar_kw = dict(errorbar_kw or {})
default_errorbar_kw = {"fmt": "o", "ms": 4, "label": "Data"}
errorbar_kw = {**default_errorbar_kw, **errorbar_kw}

# Make plot.
ax = fig.add_subplot(gs[0])
ax.errorbar(trials[metric_name], mean, yerr=sem, **errorbar_kw)
xlim = ax.get_xlim()
ylim = ax.get_ylim()
square_lims = [min(ylim[0], ylim[0]), max(xlim[1], ylim[1])]
ax.plot(
square_lims,
square_lims,
color="k",
ls="--",
label="Ideal correlation",
)
ax.set_xlim(square_lims)
ax.set_ylim(square_lims)
ax.set_xlabel("Observations")
ax.set_ylabel("Model predictions")
ax.legend(frameon=False)
return fig, ax

def plot_feature_importance(
self,
metric_name: Optional[str] = None,
subplot_spec: Optional[SubplotSpec] = None,
gridspec_kw: Optional[Dict[str, Any]] = None,
bar_kw: Optional[Dict[str, Any]] = None,
**figure_kw,
) -> Tuple[Figure, Axes]:
"""Plot the importance of each varying parameter for the given metric.

Parameters
----------
metric_name : str, optional.
Name of the metric for which to determine the importances.
If not specified, it will take the first objective in
``self.ax_client``.
subplot_spec : SubplotSpec, optional
A matplotlib ``SubplotSpec`` in which to draw the axis.
gridspec_kw : dict, optional
Dict with keywords passed to the ``GridSpec``.
bar_kw : dict, optional
Dict with keywords passed to ``ax.bar``.
**figure_kw
Additional keyword arguments to pass to ``pyplot.figure``. Only
used if no ``subplot_spec`` is given.

Returns
-------
Figure, Axes
"""
# Get metric name.
if metric_name is None:
metric_name = self.ax_client.objective_names[0]

# Get feature importances.
importances = self._model.feature_importances(metric_name)

# Create figure.
gridspec_kw = dict(gridspec_kw or {})
if subplot_spec is None:
fig = plt.figure(**figure_kw)
gs = GridSpec(1, 1, **gridspec_kw)
else:
fig = plt.gcf()
gs = GridSpecFromSubplotSpec(1, 1, subplot_spec, **gridspec_kw)
bar_kw = dict(bar_kw or {})

# Make plot.
ax = fig.add_subplot(gs[0])
ax.bar(importances.keys(), importances.values(), **bar_kw)
ax.set_ylabel(f"Importance for metric {metric_name}")
return fig, ax
49 changes: 49 additions & 0 deletions optimas/utils/ax/other.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
"""Contains the definition of various utilities for using Ax."""

from typing import List, Dict

import numpy as np
from ax.service.utils.instantiation import ObjectiveProperties

from optimas.core import VaryingParameter, Objective


def convert_optimas_to_ax_parameters(
varying_parameters: List[VaryingParameter],
) -> List[Dict]:
"""Create list of Ax parameters from optimas varying parameters."""
parameters = []
for var in varying_parameters:
# Determine parameter type.
value_dtype = np.dtype(var.dtype)
if value_dtype.kind == "f":
value_type = "float"
elif value_dtype.kind == "i":
value_type = "int"
else:
raise ValueError(
"Ax range parameter can only be of type 'float'ot 'int', "
f"not {var.dtype}."
)
# Create parameter dict and append to list.
parameters.append(
{
"name": var.name,
"type": "range",
"bounds": [var.lower_bound, var.upper_bound],
"is_fidelity": var.is_fidelity,
"target_value": var.fidelity_target_value,
"value_type": value_type,
}
)
return parameters


def convert_optimas_to_ax_objectives(
objectives: List[Objective],
) -> Dict[str, ObjectiveProperties]:
"""Create list of Ax objectives from optimas objectives."""
ax_objectives = {}
for obj in objectives:
ax_objectives[obj.name] = ObjectiveProperties(minimize=obj.minimize)
return ax_objectives
Loading