diff --git a/ax/analysis/__init__.py b/ax/analysis/__init__.py index bea822c15c2..bfc9d7b28ac 100644 --- a/ax/analysis/__init__.py +++ b/ax/analysis/__init__.py @@ -11,8 +11,16 @@ AnalysisCardLevel, display_cards, ) +from ax.analysis.metric_summary import MetricSummary from ax.analysis.summary import Summary from ax.analysis.markdown import * # noqa from ax.analysis.plotly import * # noqa -__all__ = ["Analysis", "AnalysisCard", "AnalysisCardLevel", "display_cards", "Summary"] +__all__ = [ + "Analysis", + "AnalysisCard", + "AnalysisCardLevel", + "display_cards", + "MetricSummary", + "Summary", +] diff --git a/ax/analysis/metric_summary.py b/ax/analysis/metric_summary.py new file mode 100644 index 00000000000..e79ec9ad5f2 --- /dev/null +++ b/ax/analysis/metric_summary.py @@ -0,0 +1,43 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from ax.analysis.analysis import Analysis, AnalysisCard, AnalysisCardLevel +from ax.core.experiment import Experiment +from ax.exceptions.core import UserInputError +from ax.generation_strategy.generation_strategy import GenerationStrategy + + +class MetricSummary(Analysis): + """ + Creates a dataframe with information about each metric in the + experiment. The resulting dataframe has one row per metric, and the + following columns: + - Name: the name of the metric. + - Type: the metric subclass (e.g., Metric, BraninMetric). + - Goal: the goal for this metric, based on the optimization + config (minimize, maximize, constrain or track). + - Bound: the bound of this metric (e.g., "<=10.0") if it is being used + as part of an ObjectiveThreshold or OutcomeConstraint. + - Lower is Better: whether the user prefers this metric to be lower, + if provided. 
+ """ + + def compute( + self, + experiment: Experiment | None = None, + generation_strategy: GenerationStrategy | None = None, + ) -> AnalysisCard: + if experiment is None: + raise UserInputError( + "`MetricSummary` analysis requires an `Experiment` input" + ) + return self._create_analysis_card( + title=f"MetricSummary for `{experiment.name}`", + subtitle="High-level summary of the `Metric`-s in this `Experiment`", + level=AnalysisCardLevel.MID, + df=experiment.metric_config_summary_df, + ) diff --git a/ax/analysis/tests/test_metric_summary.py b/ax/analysis/tests/test_metric_summary.py new file mode 100644 index 00000000000..39f31aa97ab --- /dev/null +++ b/ax/analysis/tests/test_metric_summary.py @@ -0,0 +1,86 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import pandas as pd +from ax.analysis.analysis import AnalysisCardLevel +from ax.analysis.metric_summary import MetricSummary +from ax.core.metric import Metric +from ax.exceptions.core import UserInputError +from ax.preview.api.client import Client +from ax.preview.api.configs import ExperimentConfig +from ax.utils.common.testutils import TestCase + + +class TestMetricSummary(TestCase): + def test_compute(self) -> None: + client = Client() + client.configure_experiment( + experiment_config=ExperimentConfig( + name="test_experiment", + parameters=[], + ) + ) + client.configure_optimization( + objective="foo, bar", outcome_constraints=["baz <= 0.0", "foo >= 1.0"] + ) + # TODO: Debug error raised by + # client.configure_metrics(metrics=[IMetric(name="qux")]) + + client._experiment._tracking_metrics = {"qux": Metric(name="qux")} + + analysis = MetricSummary() + + with self.assertRaisesRegex(UserInputError, "requires an `Experiment`"): + analysis.compute() + + experiment = client._experiment + card = analysis.compute(experiment=experiment) + + # Test metadata 
+ self.assertEqual(card.name, "MetricSummary") + self.assertEqual(card.title, "MetricSummary for `test_experiment`") + self.assertEqual( + card.subtitle, + "High-level summary of the `Metric`-s in this `Experiment`", + ) + self.assertEqual(card.level, AnalysisCardLevel.MID) + self.assertIsNotNone(card.blob) + self.assertEqual(card.blob_annotation, "dataframe") + + # Test dataframe for accuracy + self.assertEqual( + {*card.df.columns}, + { + "Goal", + "Name", + "Type", + "Lower is Better", + "Bound", + }, + ) + expected = pd.DataFrame( + { + "Name": ["bar", "foo", "baz", "qux"], + "Type": ["MapMetric", "MapMetric", "MapMetric", "Metric"], + "Goal": pd.Series( + ["maximize", "maximize", "constrain", "track"], + dtype=pd.CategoricalDtype( + categories=[ + "minimize", + "maximize", + "constrain", + "track", + "None", + ], + ordered=True, + ), + ), + "Bound": ["None", ">= 1.0", "<= 0.0", "None"], + "Lower is Better": ["None", "None", "None", "None"], + } + ) + pd.testing.assert_frame_equal(card.df, expected) diff --git a/ax/analysis/tests/test_summary.py b/ax/analysis/tests/test_summary.py index d788ab06bf5..119b67858df 100644 --- a/ax/analysis/tests/test_summary.py +++ b/ax/analysis/tests/test_summary.py @@ -104,7 +104,7 @@ def test_compute(self) -> None: }, } ) - self.assertTrue(card.df.equals(expected)) + pd.testing.assert_frame_equal(card.df, expected) # Test without omitting empty columns analysis_no_omit = Summary(omit_empty_columns=False) diff --git a/sphinx/source/analysis.rst b/sphinx/source/analysis.rst index ac998f96cdf..02db96177ec 100644 --- a/sphinx/source/analysis.rst +++ b/sphinx/source/analysis.rst @@ -151,6 +151,14 @@ Summary :undoc-members: :show-inheritance: +MetricSummary +~~~~~~~~~~~~~ + +.. automodule:: ax.analysis.metric_summary + :members: + :undoc-members: + :show-inheritance: + Interaction ~~~~~~~~~~~