Skip to content

Commit

Permalink
Finalize repr method for GenerationNodes (#2092)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2092

This diff does the following:
Add transition criterion, gen_unlimited_trials to the repr string for GenerationNode. This will be the final repr method for GenerationNode, plz lmk if there are other fields that y'all want represented here.

In coming diffs:
(3) final pass on all the doc strings and variables -- lots to clean up here
(5) Do a final pass of the generationStrategy/GenerationNode files to see what else can be migrated/condensed
(6) rename transiton criterion to action criterion
(7) remove conditionals for legacy usecase
( clean up any lingering todos

Reviewed By: saitcakmak

Differential Revision: D52267866

fbshipit-source-id: 476f86c3ad0a686ba8ebee28f495981aa9cb9b96
  • Loading branch information
mgarrard authored and facebook-github-bot committed Dec 19, 2023
1 parent 3f05d8e commit aea3669
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 15 deletions.
31 changes: 17 additions & 14 deletions ax/modelbridge/generation_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,15 +167,6 @@ def fitted_model(self) -> ModelBridge:
"""fitted_model from self.model_spec_to_gen_from for convenience"""
return self.model_spec_to_gen_from.fitted_model

@property
def _fitted_model(self) -> Optional[ModelBridge]:
"""Private property to return optional fitted_model from
self.model_spec_to_gen_from for convenience. If no model is fit,
will return None. If using the non-private `fitted_model` property,
and no model is fit, a UserInput error will be raised.
"""
return self.model_spec_to_gen_from._fitted_model

@property
def fixed_features(self) -> Optional[ObservationFeatures]:
"""fixed_features from self.model_spec_to_gen_from for convenience"""
Expand Down Expand Up @@ -242,6 +233,15 @@ def _unique_id(self) -> str:
"""Returns a unique id for this GenerationNode"""
return self.node_name

@property
def _fitted_model(self) -> Optional[ModelBridge]:
"""Private property to return optional fitted_model from
self.model_spec_to_gen_from for convenience. If no model is fit,
will return None. If using the non-private `fitted_model` property,
and no model is fit, a UserInput error will be raised.
"""
return self.model_spec_to_gen_from._fitted_model

def fit(
self,
experiment: Experiment,
Expand Down Expand Up @@ -546,13 +546,16 @@ def generator_run_limit(self, supress_generation_errors: bool = True) -> int:
def __repr__(self) -> str:
"String representation of this GenerationNode"
# add model specs
repr = f"{self.__class__.__name__}(model_specs="
str_rep = f"{self.__class__.__name__}(model_specs="
model_spec_str = str(self.model_specs).replace("\n", " ").replace("\t", "")
repr += model_spec_str
str_rep += model_spec_str

# add node name, gen_unlimited_trials, and transition_criteria
str_rep += f", node_name={self.node_name}"
str_rep += f", gen_unlimited_trials={str(self.gen_unlimited_trials)}"
str_rep += f", transition_criteria={str(self.transition_criteria)}"

# add node name
repr += f", node_name={self.node_name}"
return f"{repr})"
return f"{str_rep})"


@dataclass
Expand Down
14 changes: 13 additions & 1 deletion ax/modelbridge/tests/test_generation_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from logging import Logger
from unittest.mock import patch, PropertyMock

from ax.core.base_trial import TrialStatus

from ax.core.observation import ObservationFeatures
from ax.exceptions.core import UserInputError
from ax.modelbridge.cross_validation import (
Expand All @@ -19,6 +21,7 @@
from ax.modelbridge.generation_node import GenerationNode, GenerationStep
from ax.modelbridge.model_spec import FactoryFunctionModelSpec, ModelSpec
from ax.modelbridge.registry import Models
from ax.modelbridge.transition_criterion import MaxTrials
from ax.utils.common.logger import get_logger
from ax.utils.common.testutils import TestCase
from ax.utils.testing.core_stubs import get_branin_experiment
Expand Down Expand Up @@ -135,14 +138,23 @@ def test_node_string_representation(self) -> None:
model_gen_kwargs={},
),
],
gen_unlimited_trials=False,
transition_criteria=[
MaxTrials(threshold=5, only_in_statuses=[TrialStatus.RUNNING])
],
)
string_rep = str(node)
self.assertEqual(
string_rep,
(
"GenerationNode(model_specs=[ModelSpec(model_enum=GPEI,"
" model_kwargs={}, model_gen_kwargs={}, model_cv_kwargs={},"
" )], node_name=test)"
" )], node_name=test, gen_unlimited_trials=False, "
"transition_criteria=[MaxTrials({'threshold': 5, "
"'only_in_statuses': [<TrialStatus.RUNNING: 4>], "
"'not_in_statuses': None, 'transition_to': None, "
"'block_transition_if_unmet': True, 'block_gen_if_met': False})]"
")"
),
)

Expand Down

0 comments on commit aea3669

Please sign in to comment.