Skip to content

Commit

Permalink
Add category to AnalysisCards
Browse files Browse the repository at this point in the history
Summary:
We've had a few instances come up in meetings last week that I think will make having a category useful:
1. Allow us to group similar analysis together in the UI so everything isn't shown in one tab
2. Allow for easy grouping of analysis in bento notebooks

Additionally, as I was implementing this, there is a TODO on base_client to find a good hueristic for which analysis to show, I think implementing that hueristic will be much easier with category.

I think this is useful enough to warrant the addition because it is distinct from level- which indicates importance - and category should still be ranked by level.

Reviewed By: mpolson64, Cesar-Cardoso

Differential Revision: D69726341
  • Loading branch information
mgarrard authored and facebook-github-bot committed Feb 24, 2025
1 parent e05db76 commit 4a5e0bb
Show file tree
Hide file tree
Showing 42 changed files with 131 additions and 37 deletions.
17 changes: 17 additions & 0 deletions ax/analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,14 @@ class AnalysisCardLevel(IntEnum):
CRITICAL = 40


class AnalysisCardCategory(IntEnum):
ERROR = 0
ACTIONABLE = 1
INSIGHT = 2
DIAGNOSTIC = 3 # Equivalent to "health check" in online setting
INFO = 4


class AnalysisCard(Base):
# Name of the analysis computed, usually the class name of the Analysis which
# produced the card. Useful for grouping by when querying a large collection of
Expand All @@ -44,6 +52,8 @@ class AnalysisCard(Base):
title: str
subtitle: str

# Level of the card with respect to its importance. Higher levels are more
# important, and will be displayed first.
level: int

df: pd.DataFrame # Raw data produced by the Analysis
Expand All @@ -53,6 +63,9 @@ class AnalysisCard(Base):
# the blob and presenting it to the user (ex. PlotlyAnalysisCard.get_figure()
# decodes the blob into a go.Figure object).
blob: str
# Type of the card (ex: "insight", "diagnostic"), useful for
# grouping the cards to display only one category in notebook environments.
category: int
# How to interpret the blob (ex. "dataframe", "plotly", "markdown")
blob_annotation = "dataframe"

Expand All @@ -64,6 +77,7 @@ def __init__(
level: int,
df: pd.DataFrame,
blob: str,
category: int,
attributes: dict[str, Any] | None = None,
) -> None:
self.name = name
Expand All @@ -73,6 +87,7 @@ def __init__(
self.df = df
self.blob = blob
self.attributes = {} if attributes is None else attributes
self.category = category

def _ipython_display_(self) -> None:
"""
Expand Down Expand Up @@ -162,6 +177,7 @@ def _create_analysis_card(
subtitle: str,
level: int,
df: pd.DataFrame,
category: int,
) -> AnalysisCard:
"""
Make an AnalysisCard from this Analysis using provided fields and
Expand All @@ -175,6 +191,7 @@ def _create_analysis_card(
level=level,
df=df,
blob=df.to_json(),
category=category,
)

@property
Expand Down
3 changes: 2 additions & 1 deletion ax/analysis/healthcheck/can_generate_candidates.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from datetime import datetime

import pandas as pd
from ax.analysis.analysis import AnalysisCardLevel
from ax.analysis.analysis import AnalysisCardCategory, AnalysisCardLevel

from ax.analysis.healthcheck.healthcheck_analysis import (
HealthcheckAnalysis,
Expand Down Expand Up @@ -94,4 +94,5 @@ def compute(
}
),
level=level,
category=AnalysisCardCategory.DIAGNOSTIC,
)
6 changes: 5 additions & 1 deletion ax/analysis/healthcheck/constraints_feasibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import pandas as pd

from ax.analysis.analysis import AnalysisCardLevel
from ax.analysis.analysis import AnalysisCardCategory, AnalysisCardLevel

from ax.analysis.healthcheck.healthcheck_analysis import (
HealthcheckAnalysis,
Expand Down Expand Up @@ -68,6 +68,7 @@ def compute(
title_status = "Success"
level = AnalysisCardLevel.LOW
df = pd.DataFrame({"status": [status]})
category = AnalysisCardCategory.DIAGNOSTIC

if experiment is None:
raise UserInputError(
Expand All @@ -83,6 +84,7 @@ def compute(
subtitle=subtitle,
df=df,
level=level,
category=category,
)

if (
Expand All @@ -97,6 +99,7 @@ def compute(
subtitle=subtitle,
df=df,
level=level,
category=category,
)

if generation_strategy is None:
Expand Down Expand Up @@ -148,6 +151,7 @@ def compute(
subtitle=subtitle,
df=df,
level=level,
category=category,
)


Expand Down
3 changes: 2 additions & 1 deletion ax/analysis/healthcheck/regression_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import json

import pandas as pd
from ax.analysis.analysis import AnalysisCardLevel
from ax.analysis.analysis import AnalysisCardCategory, AnalysisCardLevel
from ax.analysis.healthcheck.healthcheck_analysis import (
HealthcheckAnalysis,
HealthcheckAnalysisCard,
Expand Down Expand Up @@ -105,6 +105,7 @@ def compute(
subtitle=subtitle,
df=df,
level=AnalysisCardLevel.LOW,
category=AnalysisCardCategory.DIAGNOSTIC,
)


Expand Down
3 changes: 2 additions & 1 deletion ax/analysis/healthcheck/search_space_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import numpy as np
import pandas as pd

from ax.analysis.analysis import AnalysisCardLevel
from ax.analysis.analysis import AnalysisCardCategory, AnalysisCardLevel
from ax.analysis.healthcheck.healthcheck_analysis import (
HealthcheckAnalysis,
HealthcheckAnalysisCard,
Expand Down Expand Up @@ -104,6 +104,7 @@ def compute(
df=df,
level=level,
attributes={"trial_index": self.trial_index},
category=AnalysisCardCategory.DIAGNOSTIC,
)


Expand Down
3 changes: 2 additions & 1 deletion ax/analysis/healthcheck/should_generate_candidates.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import json

import pandas as pd
from ax.analysis.analysis import AnalysisCardLevel
from ax.analysis.analysis import AnalysisCardCategory, AnalysisCardLevel

from ax.analysis.healthcheck.healthcheck_analysis import (
HealthcheckAnalysis,
Expand Down Expand Up @@ -57,4 +57,5 @@ def compute(
),
level=AnalysisCardLevel.CRITICAL,
attributes=self.attributes,
category=AnalysisCardCategory.DIAGNOSTIC,
)
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from datetime import datetime, timedelta

import pandas as pd
from ax.analysis.analysis import AnalysisCardLevel
from ax.analysis.analysis import AnalysisCardCategory, AnalysisCardLevel
from ax.analysis.healthcheck.can_generate_candidates import (
CanGenerateCandidatesAnalysis,
)
Expand All @@ -34,6 +34,7 @@ def test_passes_if_can_generate(self) -> None:
self.assertEqual(card.title, "Ax Candidate Generation Success")
self.assertEqual(card.subtitle, "No problems found.")
self.assertEqual(card.level, AnalysisCardLevel.LOW)
self.assertEqual(card.category, AnalysisCardCategory.DIAGNOSTIC)
pdt.assert_frame_equal(
card.df,
pd.DataFrame(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import numpy as np
import pandas as pd

from ax.analysis.analysis import AnalysisCardLevel
from ax.analysis.analysis import AnalysisCardCategory, AnalysisCardLevel
from ax.analysis.healthcheck.constraints_feasibility import (
constraints_feasibility,
ConstraintsFeasibilityAnalysis,
Expand Down Expand Up @@ -169,6 +169,7 @@ def test_compute(self) -> None:
self.assertEqual(card.name, "ConstraintsFeasibility")
self.assertEqual(card.title, "Ax Constraints Feasibility Success")
self.assertEqual(card.level, AnalysisCardLevel.LOW)
self.assertEqual(card.category, AnalysisCardCategory.DIAGNOSTIC)
self.assertEqual(card.subtitle, "All constraints are feasible.")

df_metric_d = pd.DataFrame(
Expand Down
3 changes: 2 additions & 1 deletion ax/analysis/healthcheck/tests/test_regression_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import numpy as np
import pandas as pd
from ax.analysis.analysis import AnalysisCardLevel
from ax.analysis.analysis import AnalysisCardCategory, AnalysisCardLevel
from ax.analysis.healthcheck.regression_analysis import RegressionAnalysis
from ax.core.data import Data
from ax.utils.common.testutils import TestCase
Expand Down Expand Up @@ -42,6 +42,7 @@ def test_regression_analysis(self) -> None:
and "Trial 0" in card.subtitle
)
self.assertEqual(card.level, AnalysisCardLevel.LOW)
self.assertEqual(card.category, AnalysisCardCategory.DIAGNOSTIC)

df = pd.DataFrame(
{
Expand Down
3 changes: 2 additions & 1 deletion ax/analysis/healthcheck/tests/test_search_space_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@


import pandas as pd
from ax.analysis.analysis import AnalysisCardLevel
from ax.analysis.analysis import AnalysisCardCategory, AnalysisCardLevel
from ax.analysis.healthcheck.search_space_analysis import (
boundary_proportions_message,
search_space_boundary_proportions,
Expand Down Expand Up @@ -35,6 +35,7 @@ def test_search_space_analysis(self) -> None:
card = ssa.compute(experiment=experiment)

self.assertEqual(card.level, AnalysisCardLevel.LOW)
self.assertEqual(card.category, AnalysisCardCategory.DIAGNOSTIC)
self.assertEqual(card.name, "SearchSpaceAnalysis")
self.assertEqual(card.title, "Ax Search Space Analysis Warning")
print(card.subtitle)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from random import randint

from ax.analysis.analysis import AnalysisCardLevel
from ax.analysis.analysis import AnalysisCardCategory, AnalysisCardLevel
from ax.analysis.healthcheck.healthcheck_analysis import HealthcheckStatus
from ax.analysis.healthcheck.should_generate_candidates import ShouldGenerateCandidates
from ax.utils.common.testutils import TestCase
Expand All @@ -23,6 +23,7 @@ def test_should(self) -> None:
).compute()
self.assertEqual(card.get_status(), HealthcheckStatus.PASS)
self.assertEqual(card.level, AnalysisCardLevel.CRITICAL)
self.assertEqual(card.category, AnalysisCardCategory.DIAGNOSTIC)
self.assertEqual(card.subtitle, "Something reassuring")
self.assertEqual(card.attributes["trial_index"], trial_index)

Expand All @@ -35,5 +36,6 @@ def test_should_not(self) -> None:
).compute()
self.assertEqual(card.get_status(), HealthcheckStatus.WARNING)
self.assertEqual(card.level, AnalysisCardLevel.CRITICAL)
self.assertEqual(card.category, AnalysisCardCategory.DIAGNOSTIC)
self.assertEqual(card.subtitle, "Something concerning")
self.assertEqual(card.attributes["trial_index"], trial_index)
11 changes: 10 additions & 1 deletion ax/analysis/markdown/markdown_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,13 @@
import traceback

import pandas as pd
from ax.analysis.analysis import Analysis, AnalysisCard, AnalysisCardLevel, AnalysisE
from ax.analysis.analysis import (
Analysis,
AnalysisCard,
AnalysisCardCategory,
AnalysisCardLevel,
AnalysisE,
)
from ax.core.experiment import Experiment
from ax.generation_strategy.generation_strategy import GenerationStrategy
from IPython.display import display, Markdown
Expand Down Expand Up @@ -47,6 +53,7 @@ def _create_markdown_analysis_card(
level: int,
df: pd.DataFrame,
message: str,
category: int,
) -> MarkdownAnalysisCard:
"""
Make a MarkdownAnalysisCard from this Analysis using provided fields and
Expand All @@ -60,6 +67,7 @@ def _create_markdown_analysis_card(
level=level,
df=df,
blob=message,
category=category,
)


Expand All @@ -80,4 +88,5 @@ def markdown_analysis_card_from_analysis_e(
),
df=pd.DataFrame(),
level=AnalysisCardLevel.DEBUG,
category=AnalysisCardCategory.ERROR,
)
8 changes: 7 additions & 1 deletion ax/analysis/metric_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,12 @@

# pyre-strict

from ax.analysis.analysis import Analysis, AnalysisCard, AnalysisCardLevel
from ax.analysis.analysis import (
Analysis,
AnalysisCard,
AnalysisCardCategory,
AnalysisCardLevel,
)
from ax.core.experiment import Experiment
from ax.exceptions.core import UserInputError
from ax.generation_strategy.generation_strategy import GenerationStrategy
Expand Down Expand Up @@ -40,4 +45,5 @@ def compute(
subtitle="High-level summary of the `Metric`-s in this `Experiment`",
level=AnalysisCardLevel.MID,
df=experiment.metric_config_summary_df,
category=AnalysisCardCategory.INFO,
)
3 changes: 2 additions & 1 deletion ax/analysis/plotly/arm_effects/insample_effects.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from logging import Logger

import pandas as pd
from ax.analysis.analysis import AnalysisCardLevel
from ax.analysis.analysis import AnalysisCardCategory, AnalysisCardLevel
from ax.analysis.plotly.arm_effects.utils import (
get_predictions_by_arm,
prepare_arm_effects_plot,
Expand Down Expand Up @@ -144,6 +144,7 @@ def compute(
level=level + nudge,
df=df,
fig=fig,
category=AnalysisCardCategory.INSIGHT,
)
return card

Expand Down
3 changes: 2 additions & 1 deletion ax/analysis/plotly/arm_effects/predicted_effects.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from typing import Any

import pandas as pd
from ax.analysis.analysis import AnalysisCardLevel
from ax.analysis.analysis import AnalysisCardCategory, AnalysisCardLevel

from ax.analysis.plotly.arm_effects.utils import (
get_predictions_by_arm,
Expand Down Expand Up @@ -144,6 +144,7 @@ def compute(
level=level + nudge,
df=df,
fig=fig,
category=AnalysisCardCategory.ACTIONABLE,
)


Expand Down
3 changes: 2 additions & 1 deletion ax/analysis/plotly/cross_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@


import pandas as pd
from ax.analysis.analysis import AnalysisCardLevel
from ax.analysis.analysis import AnalysisCardCategory, AnalysisCardLevel

from ax.analysis.plotly.plotly_analysis import PlotlyAnalysis, PlotlyAnalysisCard
from ax.analysis.plotly.utils import select_metric
Expand Down Expand Up @@ -124,6 +124,7 @@ def compute(
level=AnalysisCardLevel.LOW.value + nudge,
df=df,
fig=fig,
category=AnalysisCardCategory.INSIGHT,
)


Expand Down
3 changes: 2 additions & 1 deletion ax/analysis/plotly/interaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import pandas as pd
import torch
from ax.analysis.analysis import AnalysisCardLevel
from ax.analysis.analysis import AnalysisCardCategory, AnalysisCardLevel

from ax.analysis.plotly.plotly_analysis import PlotlyAnalysis, PlotlyAnalysisCard

Expand Down Expand Up @@ -259,6 +259,7 @@ def compute(
level=AnalysisCardLevel.MID,
df=sensitivity_df,
fig=fig,
category=AnalysisCardCategory.INSIGHT,
)

def _get_oak_model(self, experiment: Experiment, metric_name: str) -> TorchAdapter:
Expand Down
Loading

0 comments on commit 4a5e0bb

Please sign in to comment.