From 1e33d1f7b92955cc96e56ac32160e2732d81f51e Mon Sep 17 00:00:00 2001 From: Mia Garrard Date: Mon, 24 Feb 2025 15:02:16 -0800 Subject: [PATCH] Add category to AnalysisCards (#3414) Summary: Pull Request resolved: https://github.com/facebook/Ax/pull/3414 We've had a few instances come up in meetings last week that I think will make having a category useful: 1. Allow us to group similar analysis together in the UI so everything isn't shown in one tab 2. Allow for easy grouping of analysis in bento notebooks Additionally, as I was implementing this, there is a TODO on base_client to find a good hueristic for which analysis to show, I think implementing that hueristic will be much easier with category. I think this is useful enough to warrant the addition because it is distinct from level- which indicates importance - and category should still be ranked by level. Reviewed By: mpolson64, Cesar-Cardoso Differential Revision: D69726341 --- ax/analysis/analysis.py | 17 +++++++++++++++++ .../healthcheck/can_generate_candidates.py | 3 ++- .../healthcheck/constraints_feasibility.py | 6 +++++- ax/analysis/healthcheck/regression_analysis.py | 3 ++- .../healthcheck/search_space_analysis.py | 3 ++- .../healthcheck/should_generate_candidates.py | 3 ++- .../tests/test_can_generate_candidates.py | 3 ++- .../tests/test_constraints_feasibility.py | 3 ++- .../tests/test_regression_analysis.py | 3 ++- .../tests/test_search_space_analysis.py | 3 ++- .../tests/test_should_generate_candidates.py | 4 +++- ax/analysis/markdown/markdown_analysis.py | 11 ++++++++++- ax/analysis/metric_summary.py | 8 +++++++- .../plotly/arm_effects/insample_effects.py | 3 ++- .../plotly/arm_effects/predicted_effects.py | 3 ++- ax/analysis/plotly/cross_validation.py | 3 ++- ax/analysis/plotly/interaction.py | 3 ++- ax/analysis/plotly/parallel_coordinates.py | 3 ++- ax/analysis/plotly/plotly_analysis.py | 2 ++ ax/analysis/plotly/progression.py | 3 ++- ax/analysis/plotly/scatter.py | 3 ++- ax/analysis/plotly/surface/contour.py | 3 ++- ax/analysis/plotly/surface/slice.py | 3 ++- .../plotly/surface/tests/test_contour.py | 3 ++- ax/analysis/plotly/surface/tests/test_slice.py | 3 ++- .../plotly/tests/test_cross_validation.py | 3 ++- .../plotly/tests/test_insample_effects.py | 3 ++- ax/analysis/plotly/tests/test_interaction.py | 3 ++- .../plotly/tests/test_parallel_coordinates.py | 3 ++- .../plotly/tests/test_predicted_effects.py | 3 ++- ax/analysis/plotly/tests/test_progression.py | 3 ++- ax/analysis/plotly/tests/test_scatter.py | 3 ++- ax/analysis/search_space_summary.py | 8 +++++++- ax/analysis/summary.py | 8 +++++++- ax/analysis/tests/test_metric_summary.py | 3 ++- ax/analysis/tests/test_search_space_summary.py | 3 ++- ax/analysis/tests/test_summary.py | 3 ++- ax/service/utils/analysis_base.py | 9 ++++++++- ax/storage/sqa_store/decoder.py | 1 + ax/storage/sqa_store/encoder.py | 1 + ax/storage/sqa_store/sqa_classes.py | 1 + ax/storage/sqa_store/tests/test_sqa_store.py | 5 ++++- 42 files changed, 131 insertions(+), 37 deletions(-) diff --git a/ax/analysis/analysis.py b/ax/analysis/analysis.py index 10fcff8e043..7ac698985ae 100644 --- a/ax/analysis/analysis.py +++ b/ax/analysis/analysis.py @@ -32,6 +32,14 @@ class AnalysisCardLevel(IntEnum): CRITICAL = 40 +class AnalysisCardCategory(IntEnum): + ERROR = 0 + ACTIONABLE = 1 + INSIGHT = 2 + DIAGNOSTIC = 3 # Equivalent to "health check" in online setting + INFO = 4 + + class AnalysisCard(Base): # Name of the analysis computed, usually the class name of the Analysis which # produced the card. Useful for grouping by when querying a large collection of @@ -44,6 +52,8 @@ class AnalysisCard(Base): title: str subtitle: str + # Level of the card with respect to its importance. Higher levels are more + # important, and will be displayed first. level: int df: pd.DataFrame # Raw data produced by the Analysis @@ -53,6 +63,9 @@ class AnalysisCard(Base): # the blob and presenting it to the user (ex. PlotlyAnalysisCard.get_figure() # decodes the blob into a go.Figure object). blob: str + # Type of the card (ex: "insight", "diagnostic"), useful for + # grouping the cards to display only one category in notebook environments. + category: int # How to interpret the blob (ex. "dataframe", "plotly", "markdown") blob_annotation = "dataframe" @@ -64,6 +77,7 @@ def __init__( level: int, df: pd.DataFrame, blob: str, + category: int, attributes: dict[str, Any] | None = None, ) -> None: self.name = name @@ -73,6 +87,7 @@ def __init__( self.df = df self.blob = blob self.attributes = {} if attributes is None else attributes + self.category = category def _ipython_display_(self) -> None: """ @@ -162,6 +177,7 @@ def _create_analysis_card( subtitle: str, level: int, df: pd.DataFrame, + category: int, ) -> AnalysisCard: """ Make an AnalysisCard from this Analysis using provided fields and @@ -175,6 +191,7 @@ def _create_analysis_card( level=level, df=df, blob=df.to_json(), + category=category, ) @property diff --git a/ax/analysis/healthcheck/can_generate_candidates.py b/ax/analysis/healthcheck/can_generate_candidates.py index 624f5fabf57..22479f516ab 100644 --- a/ax/analysis/healthcheck/can_generate_candidates.py +++ b/ax/analysis/healthcheck/can_generate_candidates.py @@ -9,7 +9,7 @@ from datetime import datetime import pandas as pd -from ax.analysis.analysis import AnalysisCardLevel +from ax.analysis.analysis import AnalysisCardCategory, AnalysisCardLevel from ax.analysis.healthcheck.healthcheck_analysis import ( HealthcheckAnalysis, @@ -94,4 +94,5 @@ def compute( } ), level=level, + category=AnalysisCardCategory.DIAGNOSTIC, ) diff --git a/ax/analysis/healthcheck/constraints_feasibility.py b/ax/analysis/healthcheck/constraints_feasibility.py index 1d21fe823a3..c4ace263836 100644 --- a/ax/analysis/healthcheck/constraints_feasibility.py +++ b/ax/analysis/healthcheck/constraints_feasibility.py @@ -9,7 +9,7 @@ import pandas as pd -from ax.analysis.analysis import AnalysisCardLevel +from ax.analysis.analysis import AnalysisCardCategory, AnalysisCardLevel from ax.analysis.healthcheck.healthcheck_analysis import ( HealthcheckAnalysis, @@ -68,6 +68,7 @@ def compute( title_status = "Success" level = AnalysisCardLevel.LOW df = pd.DataFrame({"status": [status]}) + category = AnalysisCardCategory.DIAGNOSTIC if experiment is None: raise UserInputError( @@ -83,6 +84,7 @@ def compute( subtitle=subtitle, df=df, level=level, + category=category, ) if ( @@ -97,6 +99,7 @@ def compute( subtitle=subtitle, df=df, level=level, + category=category, ) if generation_strategy is None: @@ -148,6 +151,7 @@ def compute( subtitle=subtitle, df=df, level=level, + category=category, ) diff --git a/ax/analysis/healthcheck/regression_analysis.py b/ax/analysis/healthcheck/regression_analysis.py index 2c58c79e211..1f439bcf81f 100644 --- a/ax/analysis/healthcheck/regression_analysis.py +++ b/ax/analysis/healthcheck/regression_analysis.py @@ -8,7 +8,7 @@ import json import pandas as pd -from ax.analysis.analysis import AnalysisCardLevel +from ax.analysis.analysis import AnalysisCardCategory, AnalysisCardLevel from ax.analysis.healthcheck.healthcheck_analysis import ( HealthcheckAnalysis, HealthcheckAnalysisCard, @@ -105,6 +105,7 @@ def compute( subtitle=subtitle, df=df, level=AnalysisCardLevel.LOW, + category=AnalysisCardCategory.DIAGNOSTIC, ) diff --git a/ax/analysis/healthcheck/search_space_analysis.py b/ax/analysis/healthcheck/search_space_analysis.py index 1209ea80953..5a578fb0ec2 100644 --- a/ax/analysis/healthcheck/search_space_analysis.py +++ b/ax/analysis/healthcheck/search_space_analysis.py @@ -11,7 +11,7 @@ import numpy as np import pandas as pd -from ax.analysis.analysis import AnalysisCardLevel +from ax.analysis.analysis import AnalysisCardCategory, AnalysisCardLevel from ax.analysis.healthcheck.healthcheck_analysis import ( HealthcheckAnalysis, HealthcheckAnalysisCard, @@ -104,6 +104,7 @@ def compute( df=df, level=level, attributes={"trial_index": self.trial_index}, + category=AnalysisCardCategory.DIAGNOSTIC, ) diff --git a/ax/analysis/healthcheck/should_generate_candidates.py b/ax/analysis/healthcheck/should_generate_candidates.py index 3f359f64408..f4bd033a5b3 100644 --- a/ax/analysis/healthcheck/should_generate_candidates.py +++ b/ax/analysis/healthcheck/should_generate_candidates.py @@ -8,7 +8,7 @@ import json import pandas as pd -from ax.analysis.analysis import AnalysisCardLevel +from ax.analysis.analysis import AnalysisCardCategory, AnalysisCardLevel from ax.analysis.healthcheck.healthcheck_analysis import ( HealthcheckAnalysis, @@ -57,4 +57,5 @@ def compute( ), level=AnalysisCardLevel.CRITICAL, attributes=self.attributes, + category=AnalysisCardCategory.DIAGNOSTIC, ) diff --git a/ax/analysis/healthcheck/tests/test_can_generate_candidates.py b/ax/analysis/healthcheck/tests/test_can_generate_candidates.py index 3ecd71faf23..ea6bafb112f 100644 --- a/ax/analysis/healthcheck/tests/test_can_generate_candidates.py +++ b/ax/analysis/healthcheck/tests/test_can_generate_candidates.py @@ -8,7 +8,7 @@ from datetime import datetime, timedelta import pandas as pd -from ax.analysis.analysis import AnalysisCardLevel +from ax.analysis.analysis import AnalysisCardCategory, AnalysisCardLevel from ax.analysis.healthcheck.can_generate_candidates import ( CanGenerateCandidatesAnalysis, ) @@ -34,6 +34,7 @@ def test_passes_if_can_generate(self) -> None: self.assertEqual(card.title, "Ax Candidate Generation Success") self.assertEqual(card.subtitle, "No problems found.") self.assertEqual(card.level, AnalysisCardLevel.LOW) + self.assertEqual(card.category, AnalysisCardCategory.DIAGNOSTIC) pdt.assert_frame_equal( card.df, pd.DataFrame( diff --git a/ax/analysis/healthcheck/tests/test_constraints_feasibility.py b/ax/analysis/healthcheck/tests/test_constraints_feasibility.py index 92067cca5ac..3332ecd892b 100644 --- a/ax/analysis/healthcheck/tests/test_constraints_feasibility.py +++ b/ax/analysis/healthcheck/tests/test_constraints_feasibility.py @@ -10,7 +10,7 @@ import numpy as np import pandas as pd -from ax.analysis.analysis import AnalysisCardLevel +from ax.analysis.analysis import AnalysisCardCategory, AnalysisCardLevel from ax.analysis.healthcheck.constraints_feasibility import ( constraints_feasibility, ConstraintsFeasibilityAnalysis, @@ -169,6 +169,7 @@ def test_compute(self) -> None: self.assertEqual(card.name, "ConstraintsFeasibility") self.assertEqual(card.title, "Ax Constraints Feasibility Success") self.assertEqual(card.level, AnalysisCardLevel.LOW) + self.assertEqual(card.category, AnalysisCardCategory.DIAGNOSTIC) self.assertEqual(card.subtitle, "All constraints are feasible.") df_metric_d = pd.DataFrame( diff --git a/ax/analysis/healthcheck/tests/test_regression_analysis.py b/ax/analysis/healthcheck/tests/test_regression_analysis.py index aa2987e892c..8ba21249f95 100644 --- a/ax/analysis/healthcheck/tests/test_regression_analysis.py +++ b/ax/analysis/healthcheck/tests/test_regression_analysis.py @@ -8,7 +8,7 @@ import numpy as np import pandas as pd -from ax.analysis.analysis import AnalysisCardLevel +from ax.analysis.analysis import AnalysisCardCategory, AnalysisCardLevel from ax.analysis.healthcheck.regression_analysis import RegressionAnalysis from ax.core.data import Data from ax.utils.common.testutils import TestCase @@ -42,6 +42,7 @@ def test_regression_analysis(self) -> None: and "Trial 0" in card.subtitle ) self.assertEqual(card.level, AnalysisCardLevel.LOW) + self.assertEqual(card.category, AnalysisCardCategory.DIAGNOSTIC) df = pd.DataFrame( { diff --git a/ax/analysis/healthcheck/tests/test_search_space_analysis.py b/ax/analysis/healthcheck/tests/test_search_space_analysis.py index 0c8b2f5d7aa..5a0003d9acc 100644 --- a/ax/analysis/healthcheck/tests/test_search_space_analysis.py +++ b/ax/analysis/healthcheck/tests/test_search_space_analysis.py @@ -7,7 +7,7 @@ import pandas as pd -from ax.analysis.analysis import AnalysisCardLevel +from ax.analysis.analysis import AnalysisCardCategory, AnalysisCardLevel from ax.analysis.healthcheck.search_space_analysis import ( boundary_proportions_message, search_space_boundary_proportions, @@ -35,6 +35,7 @@ def test_search_space_analysis(self) -> None: card = ssa.compute(experiment=experiment) self.assertEqual(card.level, AnalysisCardLevel.LOW) + self.assertEqual(card.category, AnalysisCardCategory.DIAGNOSTIC) self.assertEqual(card.name, "SearchSpaceAnalysis") self.assertEqual(card.title, "Ax Search Space Analysis Warning") print(card.subtitle) diff --git a/ax/analysis/healthcheck/tests/test_should_generate_candidates.py b/ax/analysis/healthcheck/tests/test_should_generate_candidates.py index ed5298ffcf9..0272bbf8b2e 100644 --- a/ax/analysis/healthcheck/tests/test_should_generate_candidates.py +++ b/ax/analysis/healthcheck/tests/test_should_generate_candidates.py @@ -7,7 +7,7 @@ from random import randint -from ax.analysis.analysis import AnalysisCardLevel +from ax.analysis.analysis import AnalysisCardCategory, AnalysisCardLevel from ax.analysis.healthcheck.healthcheck_analysis import HealthcheckStatus from ax.analysis.healthcheck.should_generate_candidates import ShouldGenerateCandidates from ax.utils.common.testutils import TestCase @@ -23,6 +23,7 @@ def test_should(self) -> None: ).compute() self.assertEqual(card.get_status(), HealthcheckStatus.PASS) self.assertEqual(card.level, AnalysisCardLevel.CRITICAL) + self.assertEqual(card.category, AnalysisCardCategory.DIAGNOSTIC) self.assertEqual(card.subtitle, "Something reassuring") self.assertEqual(card.attributes["trial_index"], trial_index) @@ -35,5 +36,6 @@ def test_should_not(self) -> None: ).compute() self.assertEqual(card.get_status(), HealthcheckStatus.WARNING) self.assertEqual(card.level, AnalysisCardLevel.CRITICAL) + self.assertEqual(card.category, AnalysisCardCategory.DIAGNOSTIC) self.assertEqual(card.subtitle, "Something concerning") self.assertEqual(card.attributes["trial_index"], trial_index) diff --git a/ax/analysis/markdown/markdown_analysis.py b/ax/analysis/markdown/markdown_analysis.py index f8a7dd57b4a..e70efd304c7 100644 --- a/ax/analysis/markdown/markdown_analysis.py +++ b/ax/analysis/markdown/markdown_analysis.py @@ -9,7 +9,13 @@ import traceback import pandas as pd -from ax.analysis.analysis import Analysis, AnalysisCard, AnalysisCardLevel, AnalysisE +from ax.analysis.analysis import ( + Analysis, + AnalysisCard, + AnalysisCardCategory, + AnalysisCardLevel, + AnalysisE, +) from ax.core.experiment import Experiment from ax.generation_strategy.generation_strategy import GenerationStrategy from IPython.display import display, Markdown @@ -47,6 +53,7 @@ def _create_markdown_analysis_card( level: int, df: pd.DataFrame, message: str, + category: int, ) -> MarkdownAnalysisCard: """ Make a MarkdownAnalysisCard from this Analysis using provided fields and @@ -60,6 +67,7 @@ def _create_markdown_analysis_card( level=level, df=df, blob=message, + category=category, ) @@ -80,4 +88,5 @@ def markdown_analysis_card_from_analysis_e( ), df=pd.DataFrame(), level=AnalysisCardLevel.DEBUG, + category=AnalysisCardCategory.ERROR, ) diff --git a/ax/analysis/metric_summary.py b/ax/analysis/metric_summary.py index e79ec9ad5f2..38d829640ae 100644 --- a/ax/analysis/metric_summary.py +++ b/ax/analysis/metric_summary.py @@ -5,7 +5,12 @@ # pyre-strict -from ax.analysis.analysis import Analysis, AnalysisCard, AnalysisCardLevel +from ax.analysis.analysis import ( + Analysis, + AnalysisCard, + AnalysisCardCategory, + AnalysisCardLevel, +) from ax.core.experiment import Experiment from ax.exceptions.core import UserInputError from ax.generation_strategy.generation_strategy import GenerationStrategy @@ -40,4 +45,5 @@ def compute( subtitle="High-level summary of the `Metric`-s in this `Experiment`", level=AnalysisCardLevel.MID, df=experiment.metric_config_summary_df, + category=AnalysisCardCategory.INFO, ) diff --git a/ax/analysis/plotly/arm_effects/insample_effects.py b/ax/analysis/plotly/arm_effects/insample_effects.py index bf2234c936a..005a1985779 100644 --- a/ax/analysis/plotly/arm_effects/insample_effects.py +++ b/ax/analysis/plotly/arm_effects/insample_effects.py @@ -9,7 +9,7 @@ from logging import Logger import pandas as pd -from ax.analysis.analysis import AnalysisCardLevel +from ax.analysis.analysis import AnalysisCardCategory, AnalysisCardLevel from ax.analysis.plotly.arm_effects.utils import ( get_predictions_by_arm, prepare_arm_effects_plot, @@ -144,6 +144,7 @@ def compute( level=level + nudge, df=df, fig=fig, + category=AnalysisCardCategory.INSIGHT, ) return card diff --git a/ax/analysis/plotly/arm_effects/predicted_effects.py b/ax/analysis/plotly/arm_effects/predicted_effects.py index 77e04d431ca..e5fbc3c5a28 100644 --- a/ax/analysis/plotly/arm_effects/predicted_effects.py +++ b/ax/analysis/plotly/arm_effects/predicted_effects.py @@ -9,7 +9,7 @@ from typing import Any import pandas as pd -from ax.analysis.analysis import AnalysisCardLevel +from ax.analysis.analysis import AnalysisCardCategory, AnalysisCardLevel from ax.analysis.plotly.arm_effects.utils import ( get_predictions_by_arm, @@ -144,6 +144,7 @@ def compute( level=level + nudge, df=df, fig=fig, + category=AnalysisCardCategory.ACTIONABLE, ) diff --git a/ax/analysis/plotly/cross_validation.py b/ax/analysis/plotly/cross_validation.py index c9dda53ede7..25fd5483455 100644 --- a/ax/analysis/plotly/cross_validation.py +++ b/ax/analysis/plotly/cross_validation.py @@ -7,7 +7,7 @@ import pandas as pd -from ax.analysis.analysis import AnalysisCardLevel +from ax.analysis.analysis import AnalysisCardCategory, AnalysisCardLevel from ax.analysis.plotly.plotly_analysis import PlotlyAnalysis, PlotlyAnalysisCard from ax.analysis.plotly.utils import select_metric @@ -124,6 +124,7 @@ def compute( level=AnalysisCardLevel.LOW.value + nudge, df=df, fig=fig, + category=AnalysisCardCategory.INSIGHT, ) diff --git a/ax/analysis/plotly/interaction.py b/ax/analysis/plotly/interaction.py index cd29d38ae71..ff5ea4bcf47 100644 --- a/ax/analysis/plotly/interaction.py +++ b/ax/analysis/plotly/interaction.py @@ -10,7 +10,7 @@ import pandas as pd import torch -from ax.analysis.analysis import AnalysisCardLevel +from ax.analysis.analysis import AnalysisCardCategory, AnalysisCardLevel from ax.analysis.plotly.plotly_analysis import PlotlyAnalysis, PlotlyAnalysisCard @@ -259,6 +259,7 @@ def compute( level=AnalysisCardLevel.MID, df=sensitivity_df, fig=fig, + category=AnalysisCardCategory.INSIGHT, ) def _get_oak_model(self, experiment: Experiment, metric_name: str) -> TorchAdapter: diff --git a/ax/analysis/plotly/parallel_coordinates.py b/ax/analysis/plotly/parallel_coordinates.py index f55ba2e8cf8..e00189db364 100644 --- a/ax/analysis/plotly/parallel_coordinates.py +++ b/ax/analysis/plotly/parallel_coordinates.py @@ -9,7 +9,7 @@ import numpy as np import pandas as pd -from ax.analysis.analysis import AnalysisCardLevel +from ax.analysis.analysis import AnalysisCardCategory, AnalysisCardLevel from ax.analysis.plotly.plotly_analysis import PlotlyAnalysis, PlotlyAnalysisCard from ax.analysis.plotly.utils import select_metric @@ -61,6 +61,7 @@ def compute( level=AnalysisCardLevel.HIGH, df=df, fig=fig, + category=AnalysisCardCategory.INSIGHT, ) diff --git a/ax/analysis/plotly/plotly_analysis.py b/ax/analysis/plotly/plotly_analysis.py index e957065b19f..31a84e5c510 100644 --- a/ax/analysis/plotly/plotly_analysis.py +++ b/ax/analysis/plotly/plotly_analysis.py @@ -47,6 +47,7 @@ def _create_plotly_analysis_card( level: int, df: pd.DataFrame, fig: go.Figure, + category: int, ) -> PlotlyAnalysisCard: """ Make a PlotlyAnalysisCard from this Analysis using provided fields and @@ -60,4 +61,5 @@ def _create_plotly_analysis_card( level=level, df=df, blob=pio.to_json(fig), + category=category, ) diff --git a/ax/analysis/plotly/progression.py b/ax/analysis/plotly/progression.py index a5558dbb2c7..aea54514c13 100644 --- a/ax/analysis/plotly/progression.py +++ b/ax/analysis/plotly/progression.py @@ -8,7 +8,7 @@ import numpy as np import plotly.express as px -from ax.analysis.analysis import AnalysisCardLevel +from ax.analysis.analysis import AnalysisCardCategory, AnalysisCardLevel from ax.analysis.plotly.plotly_analysis import PlotlyAnalysis, PlotlyAnalysisCard from ax.analysis.plotly.utils import select_metric @@ -133,6 +133,7 @@ def compute( level=AnalysisCardLevel.MID, df=df, fig=fig, + category=AnalysisCardCategory.INSIGHT, ) diff --git a/ax/analysis/plotly/scatter.py b/ax/analysis/plotly/scatter.py index 2154bc6fa14..e88827f2da7 100644 --- a/ax/analysis/plotly/scatter.py +++ b/ax/analysis/plotly/scatter.py @@ -7,7 +7,7 @@ import pandas as pd -from ax.analysis.analysis import AnalysisCardLevel +from ax.analysis.analysis import AnalysisCardCategory, AnalysisCardLevel from ax.analysis.plotly.plotly_analysis import PlotlyAnalysis, PlotlyAnalysisCard from ax.core.experiment import Experiment @@ -75,6 +75,7 @@ def compute( level=AnalysisCardLevel.HIGH, df=df, fig=fig, + category=AnalysisCardCategory.INSIGHT, ) diff --git a/ax/analysis/plotly/surface/contour.py b/ax/analysis/plotly/surface/contour.py index 60bb32a7ede..bbd8aa19211 100644 --- a/ax/analysis/plotly/surface/contour.py +++ b/ax/analysis/plotly/surface/contour.py @@ -8,7 +8,7 @@ import math import pandas as pd -from ax.analysis.analysis import AnalysisCardLevel +from ax.analysis.analysis import AnalysisCardCategory, AnalysisCardLevel from ax.analysis.plotly.plotly_analysis import PlotlyAnalysis, PlotlyAnalysisCard from ax.analysis.plotly.surface.utils import ( @@ -106,6 +106,7 @@ def compute( level=AnalysisCardLevel.LOW, df=df, fig=fig, + category=AnalysisCardCategory.INSIGHT, ) diff --git a/ax/analysis/plotly/surface/slice.py b/ax/analysis/plotly/surface/slice.py index ae6958a58a6..220cb95acd4 100644 --- a/ax/analysis/plotly/surface/slice.py +++ b/ax/analysis/plotly/surface/slice.py @@ -8,7 +8,7 @@ import math import pandas as pd -from ax.analysis.analysis import AnalysisCardLevel +from ax.analysis.analysis import AnalysisCardCategory, AnalysisCardLevel from ax.analysis.plotly.plotly_analysis import PlotlyAnalysis, PlotlyAnalysisCard from ax.analysis.plotly.surface.utils import ( @@ -93,6 +93,7 @@ def compute( level=AnalysisCardLevel.LOW, df=df, fig=fig, + category=AnalysisCardCategory.INSIGHT, ) diff --git a/ax/analysis/plotly/surface/tests/test_contour.py b/ax/analysis/plotly/surface/tests/test_contour.py index 6deec31ae4a..8e8b2951ed7 100644 --- a/ax/analysis/plotly/surface/tests/test_contour.py +++ b/ax/analysis/plotly/surface/tests/test_contour.py @@ -5,7 +5,7 @@ # pyre-strict -from ax.analysis.analysis import AnalysisCardLevel +from ax.analysis.analysis import AnalysisCardCategory, AnalysisCardLevel from ax.analysis.plotly.surface.contour import ContourPlot from ax.exceptions.core import UserInputError from ax.service.ax_client import AxClient, ObjectiveProperties @@ -71,6 +71,7 @@ def test_compute(self) -> None: "2D contour of the surrogate model's predicted outcomes for bar", ) self.assertEqual(card.level, AnalysisCardLevel.LOW) + self.assertEqual(card.category, AnalysisCardCategory.INSIGHT) self.assertEqual( {*card.df.columns}, { diff --git a/ax/analysis/plotly/surface/tests/test_slice.py b/ax/analysis/plotly/surface/tests/test_slice.py index 557a7665c37..63349c115a4 100644 --- a/ax/analysis/plotly/surface/tests/test_slice.py +++ b/ax/analysis/plotly/surface/tests/test_slice.py @@ -5,7 +5,7 @@ # pyre-strict -from ax.analysis.analysis import AnalysisCardLevel +from ax.analysis.analysis import AnalysisCardCategory, AnalysisCardLevel from ax.analysis.plotly.surface.slice import SlicePlot from ax.exceptions.core import UserInputError from ax.service.ax_client import AxClient, ObjectiveProperties @@ -61,6 +61,7 @@ def test_compute(self) -> None: "1D slice of the surrogate model's predicted outcomes for bar", ) self.assertEqual(card.level, AnalysisCardLevel.LOW) + self.assertEqual(card.category, AnalysisCardCategory.INSIGHT) self.assertEqual( {*card.df.columns}, { diff --git a/ax/analysis/plotly/tests/test_cross_validation.py b/ax/analysis/plotly/tests/test_cross_validation.py index 8cb588466b7..94bf11ca749 100644 --- a/ax/analysis/plotly/tests/test_cross_validation.py +++ b/ax/analysis/plotly/tests/test_cross_validation.py @@ -5,7 +5,7 @@ # pyre-strict -from ax.analysis.analysis import AnalysisCardLevel +from ax.analysis.analysis import AnalysisCardCategory, AnalysisCardLevel from ax.analysis.plotly.cross_validation import CrossValidationPlot from ax.core.trial import Trial from ax.exceptions.core import UserInputError @@ -57,6 +57,7 @@ def test_compute(self) -> None: "Out-of-sample predictions using leave-one-out CV", ) self.assertEqual(card.level, AnalysisCardLevel.LOW) + self.assertEqual(card.category, AnalysisCardCategory.INSIGHT) self.assertEqual( {*card.df.columns}, {"arm_name", "observed", "observed_sem", "predicted", "predicted_sem"}, diff --git a/ax/analysis/plotly/tests/test_insample_effects.py b/ax/analysis/plotly/tests/test_insample_effects.py index 1bb648d008d..1703fb36608 100644 --- a/ax/analysis/plotly/tests/test_insample_effects.py +++ b/ax/analysis/plotly/tests/test_insample_effects.py @@ -9,7 +9,7 @@ import torch -from ax.analysis.analysis import AnalysisCardLevel +from ax.analysis.analysis import AnalysisCardCategory, AnalysisCardLevel from ax.analysis.plotly.arm_effects.insample_effects import InSampleEffectsPlot from ax.analysis.plotly.arm_effects.utils import get_predictions_by_arm from ax.exceptions.core import DataRequiredError, UserInputError @@ -134,6 +134,7 @@ def test_compute_modeled_can_use_ebts_for_gs_with_non_predictive_model( ) # +2 because it's on objective, +1 because it's modeled self.assertEqual(card.level, AnalysisCardLevel.MID + 3) + self.assertEqual(card.category, AnalysisCardCategory.INSIGHT) def test_compute_modeled_can_use_ebts_for_no_gs(self) -> None: # GIVEN an experiment with a trial with data diff --git a/ax/analysis/plotly/tests/test_interaction.py b/ax/analysis/plotly/tests/test_interaction.py index 8314ae362b6..5903bffab2e 100644 --- a/ax/analysis/plotly/tests/test_interaction.py +++ b/ax/analysis/plotly/tests/test_interaction.py @@ -5,7 +5,7 @@ # pyre-strict -from ax.analysis.analysis import AnalysisCardLevel +from ax.analysis.analysis import AnalysisCardCategory, AnalysisCardLevel from ax.analysis.plotly.interaction import InteractionPlot from ax.exceptions.core import UserInputError from ax.service.ax_client import AxClient, ObjectiveProperties @@ -72,6 +72,7 @@ def test_compute(self) -> None: "slice or contour plots", ) self.assertEqual(card.level, AnalysisCardLevel.MID) + self.assertEqual(card.category, AnalysisCardCategory.INSIGHT) self.assertEqual( {*card.df.columns}, {"feature", "sensitivity"}, diff --git a/ax/analysis/plotly/tests/test_parallel_coordinates.py b/ax/analysis/plotly/tests/test_parallel_coordinates.py index 1ecb178cb62..cae2ada329a 100644 --- a/ax/analysis/plotly/tests/test_parallel_coordinates.py +++ b/ax/analysis/plotly/tests/test_parallel_coordinates.py @@ -6,7 +6,7 @@ # pyre-strict import pandas as pd -from ax.analysis.analysis import AnalysisCardLevel +from ax.analysis.analysis import AnalysisCardCategory, AnalysisCardLevel from ax.analysis.plotly.parallel_coordinates import ( _get_parameter_dimension, ParallelCoordinatesPlot, @@ -37,6 +37,7 @@ def test_compute(self) -> None: "View arm parameterizations with their respective metric values", ) self.assertEqual(card.level, AnalysisCardLevel.HIGH) + self.assertEqual(card.category, AnalysisCardCategory.INSIGHT) self.assertEqual({*card.df.columns}, {"arm_name", "branin", "x1", "x2"}) self.assertIsNotNone(card.blob) self.assertEqual(card.blob_annotation, "plotly") diff --git a/ax/analysis/plotly/tests/test_predicted_effects.py b/ax/analysis/plotly/tests/test_predicted_effects.py index da402240e67..83cb20ffbe8 100644 --- a/ax/analysis/plotly/tests/test_predicted_effects.py +++ b/ax/analysis/plotly/tests/test_predicted_effects.py @@ -9,7 +9,7 @@ import torch -from ax.analysis.analysis import AnalysisCardLevel +from ax.analysis.analysis import AnalysisCardCategory, AnalysisCardLevel from ax.analysis.plotly.arm_effects.predicted_effects import PredictedEffectsPlot from ax.analysis.plotly.arm_effects.utils import get_predictions_by_arm from ax.core.observation import ObservationFeatures @@ -126,6 +126,7 @@ def test_compute(self) -> None: ) ), ) + self.assertEqual(card.category, AnalysisCardCategory.ACTIONABLE) # AND THEN it has the right rows and columns in the dataframe self.assertEqual( {*card.df.columns}, diff --git a/ax/analysis/plotly/tests/test_progression.py b/ax/analysis/plotly/tests/test_progression.py index 9d110326019..3349d94b201 100644 --- a/ax/analysis/plotly/tests/test_progression.py +++ b/ax/analysis/plotly/tests/test_progression.py @@ -6,7 +6,7 @@ # pyre-strict import pandas as pd -from ax.analysis.analysis import AnalysisCardLevel +from ax.analysis.analysis import AnalysisCardCategory, AnalysisCardLevel from ax.analysis.plotly.progression import ( _calculate_wallclock_timeseries, ProgressionPlot, @@ -36,6 +36,7 @@ def test_compute(self) -> None: "Observe how the metric changes as each trial progresses", ) self.assertEqual(card.level, AnalysisCardLevel.MID) + self.assertEqual(card.category, AnalysisCardCategory.INSIGHT) self.assertEqual( {*card.df.columns}, {"trial_index", "arm_name", "branin_map", "progression", "wallclock_time"}, diff --git a/ax/analysis/plotly/tests/test_scatter.py b/ax/analysis/plotly/tests/test_scatter.py index e1f86659aa1..bdf2482779b 100644 --- a/ax/analysis/plotly/tests/test_scatter.py +++ b/ax/analysis/plotly/tests/test_scatter.py @@ -5,7 +5,7 @@ # pyre-strict -from ax.analysis.analysis import AnalysisCardLevel +from ax.analysis.analysis import AnalysisCardCategory, AnalysisCardLevel from ax.analysis.plotly.scatter import _prepare_data, ScatterPlot from ax.exceptions.core import DataRequiredError, UserInputError from ax.modelbridge.registry import Generators @@ -38,6 +38,7 @@ def test_compute(self) -> None: "Compare arms by their observed metric values", ) self.assertEqual(card.level, AnalysisCardLevel.HIGH) + self.assertEqual(card.category, AnalysisCardCategory.INSIGHT) self.assertEqual( {*card.df.columns}, {"arm_name", "trial_index", "branin_a", "branin_b", "is_optimal"}, diff --git a/ax/analysis/search_space_summary.py b/ax/analysis/search_space_summary.py index 66df00fdaaa..1a3a5917e1f 100644 --- a/ax/analysis/search_space_summary.py +++ b/ax/analysis/search_space_summary.py @@ -5,7 +5,12 @@ # pyre-strict -from ax.analysis.analysis import Analysis, AnalysisCard, AnalysisCardLevel +from ax.analysis.analysis import ( + Analysis, + AnalysisCard, + AnalysisCardCategory, + AnalysisCardLevel, +) from ax.core.experiment import Experiment from ax.exceptions.core import UserInputError from ax.generation_strategy.generation_strategy import GenerationStrategy @@ -41,4 +46,5 @@ def compute( subtitle="High-level summary of the `Parameter`-s in this `Experiment`", level=AnalysisCardLevel.MID, df=experiment.search_space.summary_df, + category=AnalysisCardCategory.INFO, ) diff --git a/ax/analysis/summary.py b/ax/analysis/summary.py index da0958e47b0..d1641e2b0e2 100644 --- a/ax/analysis/summary.py +++ b/ax/analysis/summary.py @@ -5,7 +5,12 @@ # pyre-strict -from ax.analysis.analysis import Analysis, AnalysisCard, AnalysisCardLevel +from ax.analysis.analysis import ( + Analysis, + AnalysisCard, + AnalysisCardCategory, + AnalysisCardLevel, +) from ax.core.experiment import Experiment from ax.exceptions.core import UserInputError from ax.generation_strategy.generation_strategy import GenerationStrategy @@ -45,4 +50,5 @@ def compute( subtitle="High-level summary of the `Trial`-s in this `Experiment`", level=AnalysisCardLevel.MID, df=experiment.to_df(omit_empty_columns=self.omit_empty_columns), + category=AnalysisCardCategory.INFO, ) diff --git a/ax/analysis/tests/test_metric_summary.py b/ax/analysis/tests/test_metric_summary.py index ef461e53b33..9412a8b7dec 100644 --- a/ax/analysis/tests/test_metric_summary.py +++ b/ax/analysis/tests/test_metric_summary.py @@ -6,7 +6,7 @@ # pyre-strict import pandas as pd -from ax.analysis.analysis import AnalysisCardLevel +from ax.analysis.analysis import AnalysisCardCategory, AnalysisCardLevel from ax.analysis.metric_summary import MetricSummary from ax.core.metric import Metric from ax.exceptions.core import UserInputError @@ -48,6 +48,7 @@ def test_compute(self) -> None: "High-level summary of the `Metric`-s in this `Experiment`", ) self.assertEqual(card.level, AnalysisCardLevel.MID) + self.assertEqual(card.category, AnalysisCardCategory.INFO) self.assertIsNotNone(card.blob) self.assertEqual(card.blob_annotation, "dataframe") diff --git a/ax/analysis/tests/test_search_space_summary.py b/ax/analysis/tests/test_search_space_summary.py index 10163dbcb65..415f94951b2 100644 --- a/ax/analysis/tests/test_search_space_summary.py +++ b/ax/analysis/tests/test_search_space_summary.py @@ -6,7 +6,7 @@ # pyre-strict import pandas as pd -from ax.analysis.analysis import AnalysisCardLevel +from ax.analysis.analysis import AnalysisCardCategory, AnalysisCardLevel from ax.analysis.search_space_summary import SearchSpaceSummary from ax.exceptions.core import UserInputError from ax.preview.api.client import Client @@ -58,6 +58,7 @@ def test_compute(self) -> None: "High-level summary of the `Parameter`-s in this `Experiment`", ) self.assertEqual(card.level, AnalysisCardLevel.MID) + self.assertEqual(card.category, AnalysisCardCategory.INFO) self.assertIsNotNone(card.blob) self.assertEqual(card.blob_annotation, "dataframe") diff --git a/ax/analysis/tests/test_summary.py b/ax/analysis/tests/test_summary.py index 119b67858df..26ef44c16f6 100644 --- a/ax/analysis/tests/test_summary.py +++ b/ax/analysis/tests/test_summary.py @@ -7,7 +7,7 @@ import numpy as np import pandas as pd -from ax.analysis.analysis import AnalysisCardLevel +from ax.analysis.analysis import AnalysisCardCategory, AnalysisCardLevel from ax.analysis.summary import Summary from ax.core.trial import Trial from ax.exceptions.core import UserInputError @@ -60,6 +60,7 @@ def test_compute(self) -> None: "High-level summary of the `Trial`-s in this `Experiment`", ) self.assertEqual(card.level, AnalysisCardLevel.MID) + self.assertEqual(card.category, AnalysisCardCategory.INFO) self.assertIsNotNone(card.blob) self.assertEqual(card.blob_annotation, "dataframe") diff --git a/ax/service/utils/analysis_base.py b/ax/service/utils/analysis_base.py index fe7204a1d19..aa7b32201d2 100644 --- a/ax/service/utils/analysis_base.py +++ b/ax/service/utils/analysis_base.py @@ -9,7 +9,13 @@ import pandas as pd -from ax.analysis.analysis import Analysis, AnalysisCard, AnalysisCardLevel, AnalysisE +from ax.analysis.analysis import ( + Analysis, + AnalysisCard, + AnalysisCardCategory, + AnalysisCardLevel, + AnalysisE, +) from ax.analysis.markdown.markdown_analysis import MarkdownAnalysisCard from ax.analysis.plotly.parallel_coordinates import ParallelCoordinatesPlot from ax.core.experiment import Experiment @@ -89,6 +95,7 @@ def compute_analyses( blob=traceback_str, df=pd.DataFrame(), level=AnalysisCardLevel.DEBUG, + category=AnalysisCardCategory.ERROR, ) ) diff --git a/ax/storage/sqa_store/decoder.py b/ax/storage/sqa_store/decoder.py index bad3318389c..162217dae4a 100644 --- a/ax/storage/sqa_store/decoder.py +++ b/ax/storage/sqa_store/decoder.py @@ -1061,6 +1061,7 @@ def analysis_card_from_sqa( if analysis_card_sqa.attributes == "" else json.loads(analysis_card_sqa.attributes) ), + category=analysis_card_sqa.category, ) card.db_id = analysis_card_sqa.id return card diff --git a/ax/storage/sqa_store/encoder.py b/ax/storage/sqa_store/encoder.py index ca39a73dea0..7bdaccdccb0 100644 --- a/ax/storage/sqa_store/encoder.py +++ b/ax/storage/sqa_store/encoder.py @@ -1103,4 +1103,5 @@ def analysis_card_to_sqa( time_created=timestamp, experiment_id=experiment_id, attributes=json.dumps(analysis_card.attributes), + category=analysis_card.category, ) diff --git a/ax/storage/sqa_store/sqa_classes.py b/ax/storage/sqa_store/sqa_classes.py index da5ec6a2ef4..90ec61cc332 100644 --- a/ax/storage/sqa_store/sqa_classes.py +++ b/ax/storage/sqa_store/sqa_classes.py @@ -351,6 +351,7 @@ class SQAAnalysisCard(Base): Integer, ForeignKey("experiment_v2.id"), nullable=False ) attributes: Column[str] = Column(Text(LONGTEXT_BYTES), nullable=False) + category: Column[int] = Column(Integer, nullable=False) class SQAExperiment(Base): diff --git a/ax/storage/sqa_store/tests/test_sqa_store.py b/ax/storage/sqa_store/tests/test_sqa_store.py index 39765bd8719..0f5b8350500 100644 --- a/ax/storage/sqa_store/tests/test_sqa_store.py +++ b/ax/storage/sqa_store/tests/test_sqa_store.py @@ -17,7 +17,7 @@ from unittest.mock import MagicMock, Mock, patch import pandas as pd -from ax.analysis.analysis import AnalysisCard, AnalysisCardLevel +from ax.analysis.analysis import AnalysisCard, AnalysisCardCategory, AnalysisCardLevel from ax.analysis.markdown.markdown_analysis import MarkdownAnalysisCard from ax.analysis.plotly.plotly_analysis import PlotlyAnalysisCard from ax.core.arm import Arm @@ -2209,6 +2209,7 @@ def test_AnalysisCard(self) -> None: df=test_df, blob="test blob", attributes={"foo": "bar"}, + category=AnalysisCardCategory.DIAGNOSTIC, ) markdown_analysis_card = MarkdownAnalysisCard( name="test_markdown_analysis_card", @@ -2218,6 +2219,7 @@ def test_AnalysisCard(self) -> None: df=test_df, blob="This is some **really cool** markdown", attributes={"foo": "baz"}, + category=AnalysisCardCategory.DIAGNOSTIC, ) plotly_analysis_card = PlotlyAnalysisCard( name="test_plotly_analysis_card", @@ -2227,6 +2229,7 @@ def test_AnalysisCard(self) -> None: df=test_df, blob=pio.to_json(go.Figure()), attributes={"foo": "bad"}, + category=AnalysisCardCategory.DIAGNOSTIC, ) with self.subTest("test_save_analysis_cards"): save_experiment(self.experiment)