From 0573fce1a4c09cdae0c7e3e224ee35adf6bec5d5 Mon Sep 17 00:00:00 2001
From: Will Constable
Date: Mon, 20 May 2024 12:46:58 -0700
Subject: [PATCH] Update (base update)

[ghstack-poisoned]
---
 torchtitan/config_manager.py                 | 19 +++++++++++++++----
 torchtitan/parallelisms/parallelize_llama.py |  2 +-
 2 files changed, 16 insertions(+), 5 deletions(-)

diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py
index 23c4bab2..7484eb93 100644
--- a/torchtitan/config_manager.py
+++ b/torchtitan/config_manager.py
@@ -24,6 +24,12 @@
     "bfloat16": torch.bfloat16,
 }
 
+TORCH_DTYPE_ARGS = [
+    "checkpoint.export_dtype",
+    "training.mixed_precision_param",
+    "training.mixed_precision_reduce",
+]
+
 
 def torch_dtype(dtype_str: str) -> torch.dtype:
     return DTYPE_MAP[dtype_str]
@@ -284,7 +290,7 @@ def __init__(self):
         self.parser.add_argument(
             "--training.mixed_precision_param",
             type=torch_dtype,
-            default="bfloat16",
+            default=torch_dtype("bfloat16"),
             choices=["bfloat16", "float32"],
             help="""
                 torch dtype to use for parameters when applying mixed precision via FSDP.
@@ -294,7 +300,7 @@ def __init__(self):
         self.parser.add_argument(
             "--training.mixed_precision_reduce",
             type=torch_dtype,
-            default="float32",
+            default=torch_dtype("float32"),
             choices=["float32"],
             help="""
                 torch dtype to use for reductions when applying mixed precision via FSDP.
@@ -368,7 +374,7 @@ def __init__(self):
         self.parser.add_argument(
             "--checkpoint.export_dtype",
             type=torch_dtype,
-            default="float32",
+            default=torch_dtype("float32"),
             choices=["float16", "bfloat16", "float32"],
             help="""
                 Converts to the specified precision when training completes and model_weights_only=true.
@@ -456,6 +462,9 @@ def parse_args(self, args_list: list = sys.argv[1:]):
         try:
             with open(config_file, "rb") as f:
                 for k, v in tomllib.load(f).items():
+                    for k_, v_ in v.items():
+                        if ".".join([k, k_]) in TORCH_DTYPE_ARGS:
+                            v[k_] = torch_dtype(v_)
                     # to prevent overwrite of non-specified keys
                     args_dict[k] |= v
         except (FileNotFoundError, tomllib.TOMLDecodeError) as e:
@@ -499,7 +508,9 @@ def parse_args_from_command_line(
         # aux parser to parse the command line only args, with no defaults from main parser
         aux_parser = argparse.ArgumentParser(argument_default=argparse.SUPPRESS)
         for arg, val in vars(args).items():
-            if isinstance(val, bool):
+            if arg in TORCH_DTYPE_ARGS:
+                aux_parser.add_argument("--" + arg, type=torch_dtype)
+            elif isinstance(val, bool):
                 aux_parser.add_argument(
                     "--" + arg, action="store_true" if val else "store_false"
                 )
diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py
index e06507fa..8cecdb4a 100644
--- a/torchtitan/parallelisms/parallelize_llama.py
+++ b/torchtitan/parallelisms/parallelize_llama.py
@@ -393,7 +393,7 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
         assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names
         mp_policy = MixedPrecisionPolicy(
             param_dtype=job_config.training.mixed_precision_param,
-            reduce_dtype=job_config.training.mixed_precision_param,
+            reduce_dtype=job_config.training.mixed_precision_reduce,
         )
         ac_mode = job_config.activation_checkpoint.mode
         fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
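
Note (not part of the patch): a minimal standalone sketch of the conversion step this
patch adds to parse_args. Dtype strings loaded from a TOML config are rewritten in
place to torch.dtype objects, so they match the torch.dtype values now produced as
argparse defaults. The `config` dict below is a hypothetical stand-in for the dict
tomllib.load(f) returns; DTYPE_MAP, TORCH_DTYPE_ARGS, and torch_dtype mirror the
patch, with the float16/float32 entries of DTYPE_MAP assumed from the choices lists
shown above.

import torch

# Mirrors DTYPE_MAP / TORCH_DTYPE_ARGS / torch_dtype from the patch.
DTYPE_MAP = {
    "float16": torch.float16,
    "float32": torch.float32,
    "bfloat16": torch.bfloat16,
}

TORCH_DTYPE_ARGS = [
    "checkpoint.export_dtype",
    "training.mixed_precision_param",
    "training.mixed_precision_reduce",
]

def torch_dtype(dtype_str: str) -> torch.dtype:
    return DTYPE_MAP[dtype_str]

# Hypothetical parsed TOML, shaped like tomllib.load's result:
# one sub-dict per [section].
config = {
    "training": {
        "mixed_precision_param": "bfloat16",
        "mixed_precision_reduce": "float32",
        "batch_size": 8,  # non-dtype keys pass through untouched
    },
}

# In-place conversion, as in the patch: "section.key" names found in
# TORCH_DTYPE_ARGS have their string values replaced by torch.dtype objects.
# Rebinding values for existing keys is safe while iterating items().
for k, v in config.items():
    for k_, v_ in v.items():
        if ".".join([k, k_]) in TORCH_DTYPE_ARGS:
            v[k_] = torch_dtype(v_)

assert config["training"]["mixed_precision_param"] is torch.bfloat16
assert config["training"]["batch_size"] == 8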