From f5d7255db18579b445df239bfb5e199ba492ce7a Mon Sep 17 00:00:00 2001 From: Will Constable Date: Wed, 10 Jul 2024 14:32:10 -0700 Subject: [PATCH] Update [ghstack-poisoned] --- torchtitan/config_manager.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index 86bcffd8..127c4011 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -514,6 +514,15 @@ def parse_args(self, args_list: list = sys.argv[1:]): logger.exception(f"Error details: {str(e)}") raise e + if ( + "experimental" in args_dict + and "pipeline_parallel_split_points" in args_dict["experimental"] + ): + exp = args_dict["experimental"] + exp["pipeline_parallel_split_points"] = string_list( + exp["pipeline_parallel_split_points"] + ) + # override args dict with cmd_args cmd_args_dict = self._args_to_two_level_dict(cmd_args) for section, section_args in cmd_args_dict.items():