Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add category to AnalysisCards #3414

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions ax/analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,14 @@ class AnalysisCardLevel(IntEnum):
CRITICAL = 40


class AnalysisCardCategory(IntEnum):
ERROR = 0
ACTIONABLE = 1
INSIGHT = 2
DIAGNOSTIC = 3 # Equivalent to "health check" in online setting
INFO = 4


class AnalysisCard(Base):
# Name of the analysis computed, usually the class name of the Analysis which
# produced the card. Useful for grouping by when querying a large collection of
Expand All @@ -44,6 +52,8 @@ class AnalysisCard(Base):
title: str
subtitle: str

# Level of the card with respect to its importance. Higher levels are more
# important, and will be displayed first.
level: int

df: pd.DataFrame # Raw data produced by the Analysis
Expand All @@ -53,6 +63,9 @@ class AnalysisCard(Base):
# the blob and presenting it to the user (ex. PlotlyAnalysisCard.get_figure()
# decodes the blob into a go.Figure object).
blob: str
# Type of the card (ex: "insight", "diagnostic"), useful for
# grouping the cards to display only one category in notebook environments.
category: int
# How to interpret the blob (ex. "dataframe", "plotly", "markdown")
blob_annotation = "dataframe"

Expand All @@ -64,6 +77,7 @@ def __init__(
level: int,
df: pd.DataFrame,
blob: str,
category: int,
attributes: dict[str, Any] | None = None,
) -> None:
self.name = name
Expand All @@ -73,6 +87,7 @@ def __init__(
self.df = df
self.blob = blob
self.attributes = {} if attributes is None else attributes
self.category = category

def _ipython_display_(self) -> None:
"""
Expand Down Expand Up @@ -162,6 +177,7 @@ def _create_analysis_card(
subtitle: str,
level: int,
df: pd.DataFrame,
category: int,
) -> AnalysisCard:
"""
Make an AnalysisCard from this Analysis using provided fields and
Expand All @@ -175,6 +191,7 @@ def _create_analysis_card(
level=level,
df=df,
blob=df.to_json(),
category=category,
)

@property
Expand Down
3 changes: 2 additions & 1 deletion ax/analysis/healthcheck/can_generate_candidates.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from datetime import datetime

import pandas as pd
from ax.analysis.analysis import AnalysisCardLevel
from ax.analysis.analysis import AnalysisCardCategory, AnalysisCardLevel

from ax.analysis.healthcheck.healthcheck_analysis import (
HealthcheckAnalysis,
Expand Down Expand Up @@ -94,4 +94,5 @@ def compute(
}
),
level=level,
category=AnalysisCardCategory.DIAGNOSTIC,
)
6 changes: 5 additions & 1 deletion ax/analysis/healthcheck/constraints_feasibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import pandas as pd

from ax.analysis.analysis import AnalysisCardLevel
from ax.analysis.analysis import AnalysisCardCategory, AnalysisCardLevel

from ax.analysis.healthcheck.healthcheck_analysis import (
HealthcheckAnalysis,
Expand Down Expand Up @@ -68,6 +68,7 @@ def compute(
title_status = "Success"
level = AnalysisCardLevel.LOW
df = pd.DataFrame({"status": [status]})
category = AnalysisCardCategory.DIAGNOSTIC

if experiment is None:
raise UserInputError(
Expand All @@ -83,6 +84,7 @@ def compute(
subtitle=subtitle,
df=df,
level=level,
category=category,
)

if (
Expand All @@ -97,6 +99,7 @@ def compute(
subtitle=subtitle,
df=df,
level=level,
category=category,
)

if generation_strategy is None:
Expand Down Expand Up @@ -148,6 +151,7 @@ def compute(
subtitle=subtitle,
df=df,
level=level,
category=category,
)


Expand Down
3 changes: 2 additions & 1 deletion ax/analysis/healthcheck/regression_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import json

import pandas as pd
from ax.analysis.analysis import AnalysisCardLevel
from ax.analysis.analysis import AnalysisCardCategory, AnalysisCardLevel
from ax.analysis.healthcheck.healthcheck_analysis import (
HealthcheckAnalysis,
HealthcheckAnalysisCard,
Expand Down Expand Up @@ -105,6 +105,7 @@ def compute(
subtitle=subtitle,
df=df,
level=AnalysisCardLevel.LOW,
category=AnalysisCardCategory.DIAGNOSTIC,
)


Expand Down
3 changes: 2 additions & 1 deletion ax/analysis/healthcheck/search_space_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import numpy as np
import pandas as pd

from ax.analysis.analysis import AnalysisCardLevel
from ax.analysis.analysis import AnalysisCardCategory, AnalysisCardLevel
from ax.analysis.healthcheck.healthcheck_analysis import (
HealthcheckAnalysis,
HealthcheckAnalysisCard,
Expand Down Expand Up @@ -104,6 +104,7 @@ def compute(
df=df,
level=level,
attributes={"trial_index": self.trial_index},
category=AnalysisCardCategory.DIAGNOSTIC,
)


Expand Down
3 changes: 2 additions & 1 deletion ax/analysis/healthcheck/should_generate_candidates.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import json

import pandas as pd
from ax.analysis.analysis import AnalysisCardLevel
from ax.analysis.analysis import AnalysisCardCategory, AnalysisCardLevel

from ax.analysis.healthcheck.healthcheck_analysis import (
HealthcheckAnalysis,
Expand Down Expand Up @@ -57,4 +57,5 @@ def compute(
),
level=AnalysisCardLevel.CRITICAL,
attributes=self.attributes,
category=AnalysisCardCategory.DIAGNOSTIC,
)
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from datetime import datetime, timedelta

import pandas as pd
from ax.analysis.analysis import AnalysisCardLevel
from ax.analysis.analysis import AnalysisCardCategory, AnalysisCardLevel
from ax.analysis.healthcheck.can_generate_candidates import (
CanGenerateCandidatesAnalysis,
)
Expand All @@ -34,6 +34,7 @@ def test_passes_if_can_generate(self) -> None:
self.assertEqual(card.title, "Ax Candidate Generation Success")
self.assertEqual(card.subtitle, "No problems found.")
self.assertEqual(card.level, AnalysisCardLevel.LOW)
self.assertEqual(card.category, AnalysisCardCategory.DIAGNOSTIC)
pdt.assert_frame_equal(
card.df,
pd.DataFrame(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import numpy as np
import pandas as pd

from ax.analysis.analysis import AnalysisCardLevel
from ax.analysis.analysis import AnalysisCardCategory, AnalysisCardLevel
from ax.analysis.healthcheck.constraints_feasibility import (
constraints_feasibility,
ConstraintsFeasibilityAnalysis,
Expand Down Expand Up @@ -169,6 +169,7 @@ def test_compute(self) -> None:
self.assertEqual(card.name, "ConstraintsFeasibility")
self.assertEqual(card.title, "Ax Constraints Feasibility Success")
self.assertEqual(card.level, AnalysisCardLevel.LOW)
self.assertEqual(card.category, AnalysisCardCategory.DIAGNOSTIC)
self.assertEqual(card.subtitle, "All constraints are feasible.")

df_metric_d = pd.DataFrame(
Expand Down
3 changes: 2 additions & 1 deletion ax/analysis/healthcheck/tests/test_regression_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import numpy as np
import pandas as pd
from ax.analysis.analysis import AnalysisCardLevel
from ax.analysis.analysis import AnalysisCardCategory, AnalysisCardLevel
from ax.analysis.healthcheck.regression_analysis import RegressionAnalysis
from ax.core.data import Data
from ax.utils.common.testutils import TestCase
Expand Down Expand Up @@ -42,6 +42,7 @@ def test_regression_analysis(self) -> None:
and "Trial 0" in card.subtitle
)
self.assertEqual(card.level, AnalysisCardLevel.LOW)
self.assertEqual(card.category, AnalysisCardCategory.DIAGNOSTIC)

df = pd.DataFrame(
{
Expand Down
3 changes: 2 additions & 1 deletion ax/analysis/healthcheck/tests/test_search_space_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@


import pandas as pd
from ax.analysis.analysis import AnalysisCardLevel
from ax.analysis.analysis import AnalysisCardCategory, AnalysisCardLevel
from ax.analysis.healthcheck.search_space_analysis import (
boundary_proportions_message,
search_space_boundary_proportions,
Expand Down Expand Up @@ -35,6 +35,7 @@ def test_search_space_analysis(self) -> None:
card = ssa.compute(experiment=experiment)

self.assertEqual(card.level, AnalysisCardLevel.LOW)
self.assertEqual(card.category, AnalysisCardCategory.DIAGNOSTIC)
self.assertEqual(card.name, "SearchSpaceAnalysis")
self.assertEqual(card.title, "Ax Search Space Analysis Warning")
print(card.subtitle)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from random import randint

from ax.analysis.analysis import AnalysisCardLevel
from ax.analysis.analysis import AnalysisCardCategory, AnalysisCardLevel
from ax.analysis.healthcheck.healthcheck_analysis import HealthcheckStatus
from ax.analysis.healthcheck.should_generate_candidates import ShouldGenerateCandidates
from ax.utils.common.testutils import TestCase
Expand All @@ -23,6 +23,7 @@ def test_should(self) -> None:
).compute()
self.assertEqual(card.get_status(), HealthcheckStatus.PASS)
self.assertEqual(card.level, AnalysisCardLevel.CRITICAL)
self.assertEqual(card.category, AnalysisCardCategory.DIAGNOSTIC)
self.assertEqual(card.subtitle, "Something reassuring")
self.assertEqual(card.attributes["trial_index"], trial_index)

Expand All @@ -35,5 +36,6 @@ def test_should_not(self) -> None:
).compute()
self.assertEqual(card.get_status(), HealthcheckStatus.WARNING)
self.assertEqual(card.level, AnalysisCardLevel.CRITICAL)
self.assertEqual(card.category, AnalysisCardCategory.DIAGNOSTIC)
self.assertEqual(card.subtitle, "Something concerning")
self.assertEqual(card.attributes["trial_index"], trial_index)
11 changes: 10 additions & 1 deletion ax/analysis/markdown/markdown_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,13 @@
import traceback

import pandas as pd
from ax.analysis.analysis import Analysis, AnalysisCard, AnalysisCardLevel, AnalysisE
from ax.analysis.analysis import (
Analysis,
AnalysisCard,
AnalysisCardCategory,
AnalysisCardLevel,
AnalysisE,
)
from ax.core.experiment import Experiment
from ax.generation_strategy.generation_strategy import GenerationStrategy
from IPython.display import display, Markdown
Expand Down Expand Up @@ -47,6 +53,7 @@ def _create_markdown_analysis_card(
level: int,
df: pd.DataFrame,
message: str,
category: int,
) -> MarkdownAnalysisCard:
"""
Make a MarkdownAnalysisCard from this Analysis using provided fields and
Expand All @@ -60,6 +67,7 @@ def _create_markdown_analysis_card(
level=level,
df=df,
blob=message,
category=category,
)


Expand All @@ -80,4 +88,5 @@ def markdown_analysis_card_from_analysis_e(
),
df=pd.DataFrame(),
level=AnalysisCardLevel.DEBUG,
category=AnalysisCardCategory.ERROR,
)
8 changes: 7 additions & 1 deletion ax/analysis/metric_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,12 @@

# pyre-strict

from ax.analysis.analysis import Analysis, AnalysisCard, AnalysisCardLevel
from ax.analysis.analysis import (
Analysis,
AnalysisCard,
AnalysisCardCategory,
AnalysisCardLevel,
)
from ax.core.experiment import Experiment
from ax.exceptions.core import UserInputError
from ax.generation_strategy.generation_strategy import GenerationStrategy
Expand Down Expand Up @@ -40,4 +45,5 @@ def compute(
subtitle="High-level summary of the `Metric`-s in this `Experiment`",
level=AnalysisCardLevel.MID,
df=experiment.metric_config_summary_df,
category=AnalysisCardCategory.INFO,
)
3 changes: 2 additions & 1 deletion ax/analysis/plotly/arm_effects/insample_effects.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from logging import Logger

import pandas as pd
from ax.analysis.analysis import AnalysisCardLevel
from ax.analysis.analysis import AnalysisCardCategory, AnalysisCardLevel
from ax.analysis.plotly.arm_effects.utils import (
get_predictions_by_arm,
prepare_arm_effects_plot,
Expand Down Expand Up @@ -144,6 +144,7 @@ def compute(
level=level + nudge,
df=df,
fig=fig,
category=AnalysisCardCategory.INSIGHT,
)
return card

Expand Down
3 changes: 2 additions & 1 deletion ax/analysis/plotly/arm_effects/predicted_effects.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from typing import Any

import pandas as pd
from ax.analysis.analysis import AnalysisCardLevel
from ax.analysis.analysis import AnalysisCardCategory, AnalysisCardLevel

from ax.analysis.plotly.arm_effects.utils import (
get_predictions_by_arm,
Expand Down Expand Up @@ -144,6 +144,7 @@ def compute(
level=level + nudge,
df=df,
fig=fig,
category=AnalysisCardCategory.ACTIONABLE,
)


Expand Down
3 changes: 2 additions & 1 deletion ax/analysis/plotly/cross_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@


import pandas as pd
from ax.analysis.analysis import AnalysisCardLevel
from ax.analysis.analysis import AnalysisCardCategory, AnalysisCardLevel

from ax.analysis.plotly.plotly_analysis import PlotlyAnalysis, PlotlyAnalysisCard
from ax.analysis.plotly.utils import select_metric
Expand Down Expand Up @@ -124,6 +124,7 @@ def compute(
level=AnalysisCardLevel.LOW.value + nudge,
df=df,
fig=fig,
category=AnalysisCardCategory.INSIGHT,
)


Expand Down
3 changes: 2 additions & 1 deletion ax/analysis/plotly/interaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import pandas as pd
import torch
from ax.analysis.analysis import AnalysisCardLevel
from ax.analysis.analysis import AnalysisCardCategory, AnalysisCardLevel

from ax.analysis.plotly.plotly_analysis import PlotlyAnalysis, PlotlyAnalysisCard

Expand Down Expand Up @@ -259,6 +259,7 @@ def compute(
level=AnalysisCardLevel.MID,
df=sensitivity_df,
fig=fig,
category=AnalysisCardCategory.INSIGHT,
)

def _get_oak_model(self, experiment: Experiment, metric_name: str) -> TorchAdapter:
Expand Down
Loading