Skip to content

Commit

Permalink
Internal change.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 603537825
  • Loading branch information
georgedahl authored and copybara-github committed Feb 2, 2024
1 parent 325af9f commit d32e5a9
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 15 deletions.
18 changes: 11 additions & 7 deletions init2winit/hyperparameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def build_hparams(model_name,
hparam_file,
hparam_overrides,
input_pipeline_hps=None,
allow_unrecognized_hparams=False):
allowed_unrecognized_hparams=None):
"""Build experiment hyperparameters.
Args:
Expand All @@ -115,11 +115,12 @@ def build_hparams(model_name,
string encoding of this hyperparameter override dict. Note that this is
applied after the hyperparameter file overrides.
input_pipeline_hps: a dict of hyperparameters for performance tuning.
allow_unrecognized_hparams: if set, hparam overrides are allowed to
introduce new hparam keys that will most likely be ignored. Downgrading
unrecognized hparams from an error to a warning can be useful when trying
to tune using a shared search space over multiple workloads that don't all
support the same set of hyperparameters.
allowed_unrecognized_hparams: An optional list of hparam keys that hparam
overrides are allowed to introduce. There is no guaranteed these new
hparam keys will be ignored. Downgrading an explicit list of unrecognized
hparams from an error to a warning can be useful when trying to tune using
a shared search space over multiple workloads that don't all support the
same set of hyperparameters.
Returns:
A ConfigDict of experiment hyperparameters.
Expand Down Expand Up @@ -186,10 +187,13 @@ def build_hparams(model_name,
merged['optimizer'] != hparam_overrides['optimizer']):
merged['opt_hparams'] = {}
hparam_overrides = expand_dot_keys(hparam_overrides)
if allow_unrecognized_hparams:
if allowed_unrecognized_hparams:
new_keys = [k for k in hparam_overrides if k not in merged]
if new_keys:
logging.warning('Unrecognized top-level hparams: %s', new_keys)
if any(k not in allowed_unrecognized_hparams for k in new_keys):
raise ValueError(
f'Unrecognized top-level hparams not in allowlist: {new_keys}')
with merged.unlocked():
merged.update(hparam_overrides)
else:
Expand Down
13 changes: 7 additions & 6 deletions init2winit/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,10 @@
'hparam_file', None, 'Optional path to hparam json file for overriding '
'hyperparameters. Hyperparameters are loaded before '
'applying --hparam_overrides.')
flags.DEFINE_bool(
'allow_unrecognized_hparams', False,
'Downgrades unrecognized hparam override keys from an error to a warning.')
flags.DEFINE_list(
'allowed_unrecognized_hparams', [],
'Downgrades unrecognized hparam override keys from an error to a warning '
'for the supplied list of keys.')
flags.DEFINE_string(
'training_metrics_config', '',
'JSON representation of the training metrics config.')
Expand Down Expand Up @@ -189,7 +190,7 @@ def _run(
early_stopping_min_steps,
eval_steps,
hparam_file,
allow_unrecognized_hparams,
allowed_unrecognized_hparams,
hparam_overrides,
initializer_name,
model_name,
Expand Down Expand Up @@ -221,7 +222,7 @@ def _run(
hparam_file=hparam_file,
hparam_overrides=hparam_overrides,
input_pipeline_hps=input_pipeline_hps,
allow_unrecognized_hparams=allow_unrecognized_hparams)
allowed_unrecognized_hparams=allowed_unrecognized_hparams)

# Note that one should never tune an RNG seed!!! The seed is only included in
# the hparams for convenience of running hparam trials with multiple seeds per
Expand Down Expand Up @@ -349,7 +350,7 @@ def main(unused_argv):
early_stopping_min_steps=FLAGS.early_stopping_min_steps,
eval_steps=eval_steps,
hparam_file=FLAGS.hparam_file,
allow_unrecognized_hparams=FLAGS.allow_unrecognized_hparams,
allowed_unrecognized_hparams=FLAGS.allowed_unrecognized_hparams,
hparam_overrides=FLAGS.hparam_overrides,
initializer_name=FLAGS.initializer,
model_name=FLAGS.model,
Expand Down
4 changes: 2 additions & 2 deletions init2winit/test_hyperparameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,15 +64,15 @@ def test_unrecognized_override(self):
dataset_name='lm1b_v2',
hparam_file=None,
hparam_overrides=hps_overrides,
allow_unrecognized_hparams=False,
allowed_unrecognized_hparams=[],
)
merged_hps = hyperparameters.build_hparams(
model_name='transformer',
initializer_name='noop',
dataset_name='lm1b_v2',
hparam_file=None,
hparam_overrides=hps_overrides,
allow_unrecognized_hparams=True,
allowed_unrecognized_hparams=['lr_hparamsTYPO'],
)
expected_added_field = {'base_lr': 77.0}
self.assertEqual(merged_hps.lr_hparamsTYPO.to_dict(), expected_added_field)
Expand Down

0 comments on commit d32e5a9

Please sign in to comment.