From d32e5a9d491315de7590cb79d80681227a4c22ec Mon Sep 17 00:00:00 2001 From: "George E. Dahl" Date: Thu, 1 Feb 2024 18:10:03 -0800 Subject: [PATCH] Internal change. PiperOrigin-RevId: 603537825 --- init2winit/hyperparameters.py | 18 +++++++++++------- init2winit/main.py | 13 +++++++------ init2winit/test_hyperparameters.py | 4 ++-- 3 files changed, 20 insertions(+), 15 deletions(-) diff --git a/init2winit/hyperparameters.py b/init2winit/hyperparameters.py index 66acd4ce..081bb2e9 100644 --- a/init2winit/hyperparameters.py +++ b/init2winit/hyperparameters.py @@ -102,7 +102,7 @@ def build_hparams(model_name, hparam_file, hparam_overrides, input_pipeline_hps=None, - allow_unrecognized_hparams=False): + allowed_unrecognized_hparams=None): """Build experiment hyperparameters. Args: @@ -115,11 +115,12 @@ def build_hparams(model_name, string encoding of this hyperparameter override dict. Note that this is applied after the hyperparameter file overrides. input_pipeline_hps: a dict of hyperparameters for performance tuning. - allow_unrecognized_hparams: if set, hparam overrides are allowed to - introduce new hparam keys that will most likely be ignored. Downgrading - unrecognized hparams from an error to a warning can be useful when trying - to tune using a shared search space over multiple workloads that don't all - support the same set of hyperparameters. + allowed_unrecognized_hparams: An optional list of hparam keys that hparam + overrides are allowed to introduce. There is no guaranteed these new + hparam keys will be ignored. Downgrading an explicit list of unrecognized + hparams from an error to a warning can be useful when trying to tune using + a shared search space over multiple workloads that don't all support the + same set of hyperparameters. Returns: A ConfigDict of experiment hyperparameters. @@ -186,10 +187,13 @@ def build_hparams(model_name, merged['optimizer'] != hparam_overrides['optimizer']): merged['opt_hparams'] = {} hparam_overrides = expand_dot_keys(hparam_overrides) - if allow_unrecognized_hparams: + if allowed_unrecognized_hparams: new_keys = [k for k in hparam_overrides if k not in merged] if new_keys: logging.warning('Unrecognized top-level hparams: %s', new_keys) + if any(k not in allowed_unrecognized_hparams for k in new_keys): + raise ValueError( + f'Unrecognized top-level hparams not in allowlist: {new_keys}') with merged.unlocked(): merged.update(hparam_overrides) else: diff --git a/init2winit/main.py b/init2winit/main.py index b923fdf1..1347aee1 100644 --- a/init2winit/main.py +++ b/init2winit/main.py @@ -136,9 +136,10 @@ 'hparam_file', None, 'Optional path to hparam json file for overriding ' 'hyperparameters. Hyperparameters are loaded before ' 'applying --hparam_overrides.') -flags.DEFINE_bool( - 'allow_unrecognized_hparams', False, - 'Downgrades unrecognized hparam override keys from an error to a warning.') +flags.DEFINE_list( + 'allowed_unrecognized_hparams', [], + 'Downgrades unrecognized hparam override keys from an error to a warning ' + 'for the supplied list of keys.') flags.DEFINE_string( 'training_metrics_config', '', 'JSON representation of the training metrics config.') @@ -189,7 +190,7 @@ def _run( early_stopping_min_steps, eval_steps, hparam_file, - allow_unrecognized_hparams, + allowed_unrecognized_hparams, hparam_overrides, initializer_name, model_name, @@ -221,7 +222,7 @@ def _run( hparam_file=hparam_file, hparam_overrides=hparam_overrides, input_pipeline_hps=input_pipeline_hps, - allow_unrecognized_hparams=allow_unrecognized_hparams) + allowed_unrecognized_hparams=allowed_unrecognized_hparams) # Note that one should never tune an RNG seed!!! The seed is only included in # the hparams for convenience of running hparam trials with multiple seeds per @@ -349,7 +350,7 @@ def main(unused_argv): early_stopping_min_steps=FLAGS.early_stopping_min_steps, eval_steps=eval_steps, hparam_file=FLAGS.hparam_file, - allow_unrecognized_hparams=FLAGS.allow_unrecognized_hparams, + allowed_unrecognized_hparams=FLAGS.allowed_unrecognized_hparams, hparam_overrides=FLAGS.hparam_overrides, initializer_name=FLAGS.initializer, model_name=FLAGS.model, diff --git a/init2winit/test_hyperparameters.py b/init2winit/test_hyperparameters.py index e614e1ff..b7abf251 100644 --- a/init2winit/test_hyperparameters.py +++ b/init2winit/test_hyperparameters.py @@ -64,7 +64,7 @@ def test_unrecognized_override(self): dataset_name='lm1b_v2', hparam_file=None, hparam_overrides=hps_overrides, - allow_unrecognized_hparams=False, + allowed_unrecognized_hparams=[], ) merged_hps = hyperparameters.build_hparams( model_name='transformer', @@ -72,7 +72,7 @@ def test_unrecognized_override(self): dataset_name='lm1b_v2', hparam_file=None, hparam_overrides=hps_overrides, - allow_unrecognized_hparams=True, + allowed_unrecognized_hparams=['lr_hparamsTYPO'], ) expected_added_field = {'base_lr': 77.0} self.assertEqual(merged_hps.lr_hparamsTYPO.to_dict(), expected_added_field)