Skip to content

Commit

Permalink
Add generation node string to base trial (facebook#3355)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: facebook#3355

This will help expose the methods used to generate trials to the user, e.g., via experiment.to_df, matching legacy behavior in exp_to_df.

Differential Revision: D68909600

Reviewed By: saitcakmak
  • Loading branch information
bernardbeckerman committed Feb 13, 2025
1 parent f7227a2 commit af616f0
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 6 deletions.
44 changes: 43 additions & 1 deletion ax/core/base_trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from ax.core.arm import Arm
from ax.core.data import Data
from ax.core.formatting_utils import data_and_evaluations_from_raw_data
from ax.core.generator_run import GeneratorRun
from ax.core.generator_run import GeneratorRun, GeneratorRunType
from ax.core.map_data import MapData
from ax.core.map_metric import MapMetric
from ax.core.metric import Metric, MetricFetchResult
Expand All @@ -26,13 +26,18 @@
from ax.core.types import TCandidateMetadata, TEvaluationOutcome
from ax.exceptions.core import UnsupportedError
from ax.utils.common.base import SortableBase
from ax.utils.common.constants import Keys
from pyre_extensions import none_throws


if TYPE_CHECKING:
# import as module to make sphinx-autodoc-typehints happy
from ax import core # noqa F401

MANUAL_GENERATION_METHOD_STR = "Manual"
UNKNOWN_GENERATION_METHOD_STR = "Unknown"
STATUS_QUO_GENERATION_METHOD_STR = "Status Quo"


def immutable_once_run(func: Callable) -> Callable:
"""Decorator for methods that should throw Error when
Expand Down Expand Up @@ -659,6 +664,43 @@ def mark_arm_abandoned(self, arm_name: str, reason: str | None = None) -> BaseTr
"Use `trial.mark_abandoned` if applicable."
)

@property
def generation_method_str(self) -> str:
"""Returns the generation method(s) used to generate this trial's arms,
as a human-readable string (e.g. 'Sobol', 'BoTorch', 'Manual', etc.).
Returns a comma-delimited string if multiple generation methods were used.
"""
# Use model key provided during warm-starting if present, since the
# generator run may not be present on warm-started trials.
if (
warm_start_model_key := self._properties.get(Keys.WARMSTART_TRIAL_MODEL_KEY)
) is not None:
return warm_start_model_key

generation_methods = {
none_throws(generator_run._model_key)
for generator_run in self.generator_runs
if generator_run._model_key is not None
}

# Add generator-run-type strings for non-ModelBridge generator runs.
gr_type_name_to_str = {
GeneratorRunType.MANUAL.name: MANUAL_GENERATION_METHOD_STR,
GeneratorRunType.STATUS_QUO.name: STATUS_QUO_GENERATION_METHOD_STR,
}
generation_methods |= {
gr_type_name_to_str[generator_run.generator_run_type]
for generator_run in self.generator_runs
if generator_run.generator_run_type in gr_type_name_to_str
}

return (
# Sort for deterministic output
", ".join(sorted(generation_methods))
if generation_methods
else UNKNOWN_GENERATION_METHOD_STR
)

def _mark_failed_if_past_TTL(self) -> None:
"""If trial has TTL set and is running, check if the TTL has elapsed
and mark the trial failed if so.
Expand Down
20 changes: 17 additions & 3 deletions ax/core/tests/test_batch_trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,16 @@

import numpy as np
from ax.core.arm import Arm
from ax.core.base_trial import (
MANUAL_GENERATION_METHOD_STR,
TrialStatus,
UNKNOWN_GENERATION_METHOD_STR,
)
from ax.core.batch_trial import BatchTrial, GeneratorRunStruct
from ax.core.experiment import Experiment
from ax.core.generator_run import GeneratorRun, GeneratorRunType
from ax.core.parameter import FixedParameter, ParameterType
from ax.core.search_space import SearchSpace
from ax.core.trial_status import TrialStatus
from ax.exceptions.core import UnsupportedError
from ax.runners.synthetic import SyntheticRunner
from ax.utils.common.testutils import TestCase
Expand Down Expand Up @@ -76,9 +80,12 @@ def test_BasicProperties(self) -> None:
self.batch.generator_run_structs[0].generator_run.generator_run_type,
GeneratorRunType.MANUAL.name,
)
self.assertEqual(self.batch.generation_method_str, MANUAL_GENERATION_METHOD_STR)

# Test empty arms
self.assertEqual(len(self.experiment.new_batch_trial().abandoned_arms), 0)
# Test empty trial
t = self.experiment.new_batch_trial()
self.assertEqual(len(t.abandoned_arms), 0)
self.assertEqual(t.generation_method_str, UNKNOWN_GENERATION_METHOD_STR)

def test_UndefinedSetters(self) -> None:
with self.assertRaises(NotImplementedError):
Expand Down Expand Up @@ -480,6 +487,9 @@ def test_clone_to(self, _) -> None:
# test cloning with clear_trial_type=True
new_batch_trial = batch.clone_to(clear_trial_type=True)
self.assertIsNone(new_batch_trial.trial_type)
self.assertEqual(
new_batch_trial.generation_method_str, MANUAL_GENERATION_METHOD_STR
)

def test_Runner(self) -> None:
# Verify BatchTrial without runner will fail
Expand Down Expand Up @@ -667,9 +677,12 @@ def test_TTL(self) -> None:
self.assertIn(2, self.experiment.trial_indices_by_status[TrialStatus.FAILED])

def test_get_candidate_metadata_from_all_generator_runs(self) -> None:
self.assertEqual(self.batch.generation_method_str, MANUAL_GENERATION_METHOD_STR)
gr_1 = get_generator_run()
gr_2 = get_generator_run2()
self.batch.add_generator_run(gr_1)
self.assertEqual(self.batch.generation_method_str, "Manual, Sobol")

# Arms are named when adding GR to trial, so reassign to have a GR that has
# names arms.
gr_1 = self.batch._generator_run_structs[-1].generator_run
Expand Down Expand Up @@ -719,6 +732,7 @@ def test_get_candidate_metadata_from_all_generator_runs(self) -> None:
cand_metadata_expected[arm.name],
self.batch._get_candidate_metadata(arm.name),
)
self.assertEqual(self.batch.generation_method_str, "Manual, Sobol")

def test_Sortable(self) -> None:
new_batch_trial = self.experiment.new_batch_trial()
Expand Down
12 changes: 10 additions & 2 deletions ax/core/tests/test_trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,12 @@
from unittest.mock import Mock, patch

import pandas as pd
from ax.core.base_trial import BaseTrial, TrialStatus
from ax.core.base_trial import (
BaseTrial,
MANUAL_GENERATION_METHOD_STR,
TrialStatus,
UNKNOWN_GENERATION_METHOD_STR,
)
from ax.core.data import Data
from ax.core.generator_run import GeneratorRun, GeneratorRunType
from ax.core.runner import Runner
Expand Down Expand Up @@ -87,10 +92,13 @@ def test_basic_properties(self) -> None:
self.assertEqual(
self.trial.generator_run.generator_run_type, GeneratorRunType.MANUAL.name
)
self.assertEqual(self.trial.generation_method_str, MANUAL_GENERATION_METHOD_STR)

# Test empty arms
t = self.experiment.new_trial()
with self.assertRaises(AttributeError):
self.experiment.new_trial().arm_weights
t.arm_weights
self.assertEqual(t.generation_method_str, UNKNOWN_GENERATION_METHOD_STR)

self.trial.mark_running(no_runner_required=True)
self.assertTrue(self.trial.status.is_running)
Expand Down
1 change: 1 addition & 0 deletions ax/utils/common/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,4 +93,5 @@ class Keys(str, Enum):
TASK_FEATURES = "task_features"
TRIAL_COMPLETION_TIMESTAMP = "trial_completion_timestamp"
WARM_START_REFITTING = "warm_start_refitting"
WARMSTART_TRIAL_MODEL_KEY = "generation_model_key"
X_BASELINE = "X_baseline"

0 comments on commit af616f0

Please sign in to comment.