Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add generation node string to base trial #3355

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 43 additions & 1 deletion ax/core/base_trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from ax.core.arm import Arm
from ax.core.data import Data
from ax.core.formatting_utils import data_and_evaluations_from_raw_data
from ax.core.generator_run import GeneratorRun
from ax.core.generator_run import GeneratorRun, GeneratorRunType
from ax.core.map_data import MapData
from ax.core.map_metric import MapMetric
from ax.core.metric import Metric, MetricFetchResult
Expand All @@ -26,13 +26,18 @@
from ax.core.types import TCandidateMetadata, TEvaluationOutcome
from ax.exceptions.core import UnsupportedError
from ax.utils.common.base import SortableBase
from ax.utils.common.constants import Keys
from pyre_extensions import none_throws


if TYPE_CHECKING:
# import as module to make sphinx-autodoc-typehints happy
from ax import core # noqa F401

MANUAL_GENERATION_METHOD_STR = "Manual"
UNKNOWN_GENERATION_METHOD_STR = "Unknown"
STATUS_QUO_GENERATION_METHOD_STR = "Status Quo"


def immutable_once_run(func: Callable) -> Callable:
"""Decorator for methods that should throw Error when
Expand Down Expand Up @@ -659,6 +664,43 @@ def mark_arm_abandoned(self, arm_name: str, reason: str | None = None) -> BaseTr
"Use `trial.mark_abandoned` if applicable."
)

@property
def generation_method_str(self) -> str:
"""Returns the generation method(s) used to generate this trial's arms,
as a human-readable string (e.g. 'Sobol', 'BoTorch', 'Manual', etc.).
Returns a comma-delimited string if multiple generation methods were used.
"""
# Use model key provided during warm-starting if present, since the
# generator run may not be present on warm-started trials.
if (
warm_start_model_key := self._properties.get(Keys.WARMSTART_TRIAL_MODEL_KEY)
) is not None:
return warm_start_model_key

generation_methods = {
none_throws(generator_run._model_key)
for generator_run in self.generator_runs
if generator_run._model_key is not None
}

# Add generator-run-type strings for non-ModelBridge generator runs.
gr_type_name_to_str = {
GeneratorRunType.MANUAL.name: MANUAL_GENERATION_METHOD_STR,
GeneratorRunType.STATUS_QUO.name: STATUS_QUO_GENERATION_METHOD_STR,
}
generation_methods |= {
gr_type_name_to_str[generator_run.generator_run_type]
for generator_run in self.generator_runs
if generator_run.generator_run_type in gr_type_name_to_str
}

return (
# Sort for deterministic output
", ".join(sorted(generation_methods))
if generation_methods
else UNKNOWN_GENERATION_METHOD_STR
)

def _mark_failed_if_past_TTL(self) -> None:
"""If trial has TTL set and is running, check if the TTL has elapsed
and mark the trial failed if so.
Expand Down
20 changes: 17 additions & 3 deletions ax/core/tests/test_batch_trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,16 @@

import numpy as np
from ax.core.arm import Arm
from ax.core.base_trial import (
MANUAL_GENERATION_METHOD_STR,
TrialStatus,
UNKNOWN_GENERATION_METHOD_STR,
)
from ax.core.batch_trial import BatchTrial, GeneratorRunStruct
from ax.core.experiment import Experiment
from ax.core.generator_run import GeneratorRun, GeneratorRunType
from ax.core.parameter import FixedParameter, ParameterType
from ax.core.search_space import SearchSpace
from ax.core.trial_status import TrialStatus
from ax.exceptions.core import UnsupportedError
from ax.runners.synthetic import SyntheticRunner
from ax.utils.common.testutils import TestCase
Expand Down Expand Up @@ -76,9 +80,12 @@ def test_BasicProperties(self) -> None:
self.batch.generator_run_structs[0].generator_run.generator_run_type,
GeneratorRunType.MANUAL.name,
)
self.assertEqual(self.batch.generation_method_str, MANUAL_GENERATION_METHOD_STR)

# Test empty arms
self.assertEqual(len(self.experiment.new_batch_trial().abandoned_arms), 0)
# Test empty trial
t = self.experiment.new_batch_trial()
self.assertEqual(len(t.abandoned_arms), 0)
self.assertEqual(t.generation_method_str, UNKNOWN_GENERATION_METHOD_STR)

def test_UndefinedSetters(self) -> None:
with self.assertRaises(NotImplementedError):
Expand Down Expand Up @@ -480,6 +487,9 @@ def test_clone_to(self, _) -> None:
# test cloning with clear_trial_type=True
new_batch_trial = batch.clone_to(clear_trial_type=True)
self.assertIsNone(new_batch_trial.trial_type)
self.assertEqual(
new_batch_trial.generation_method_str, MANUAL_GENERATION_METHOD_STR
)

def test_Runner(self) -> None:
# Verify BatchTrial without runner will fail
Expand Down Expand Up @@ -667,9 +677,12 @@ def test_TTL(self) -> None:
self.assertIn(2, self.experiment.trial_indices_by_status[TrialStatus.FAILED])

def test_get_candidate_metadata_from_all_generator_runs(self) -> None:
self.assertEqual(self.batch.generation_method_str, MANUAL_GENERATION_METHOD_STR)
gr_1 = get_generator_run()
gr_2 = get_generator_run2()
self.batch.add_generator_run(gr_1)
self.assertEqual(self.batch.generation_method_str, "Manual, Sobol")

# Arms are named when adding GR to trial, so reassign to have a GR that has
# names arms.
gr_1 = self.batch._generator_run_structs[-1].generator_run
Expand Down Expand Up @@ -719,6 +732,7 @@ def test_get_candidate_metadata_from_all_generator_runs(self) -> None:
cand_metadata_expected[arm.name],
self.batch._get_candidate_metadata(arm.name),
)
self.assertEqual(self.batch.generation_method_str, "Manual, Sobol")

def test_Sortable(self) -> None:
new_batch_trial = self.experiment.new_batch_trial()
Expand Down
12 changes: 10 additions & 2 deletions ax/core/tests/test_trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,12 @@
from unittest.mock import Mock, patch

import pandas as pd
from ax.core.base_trial import BaseTrial, TrialStatus
from ax.core.base_trial import (
BaseTrial,
MANUAL_GENERATION_METHOD_STR,
TrialStatus,
UNKNOWN_GENERATION_METHOD_STR,
)
from ax.core.data import Data
from ax.core.generator_run import GeneratorRun, GeneratorRunType
from ax.core.runner import Runner
Expand Down Expand Up @@ -87,10 +92,13 @@ def test_basic_properties(self) -> None:
self.assertEqual(
self.trial.generator_run.generator_run_type, GeneratorRunType.MANUAL.name
)
self.assertEqual(self.trial.generation_method_str, MANUAL_GENERATION_METHOD_STR)

# Test empty arms
t = self.experiment.new_trial()
with self.assertRaises(AttributeError):
self.experiment.new_trial().arm_weights
t.arm_weights
self.assertEqual(t.generation_method_str, UNKNOWN_GENERATION_METHOD_STR)

self.trial.mark_running(no_runner_required=True)
self.assertTrue(self.trial.status.is_running)
Expand Down
1 change: 1 addition & 0 deletions ax/utils/common/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,4 +93,5 @@ class Keys(str, Enum):
TASK_FEATURES = "task_features"
TRIAL_COMPLETION_TIMESTAMP = "trial_completion_timestamp"
WARM_START_REFITTING = "warm_start_refitting"
WARMSTART_TRIAL_MODEL_KEY = "generation_model_key"
X_BASELINE = "X_baseline"