From 88fd383f3eaa74d3f17e0d0322fe87d47a22142f Mon Sep 17 00:00:00 2001 From: Will Constable Date: Mon, 20 May 2024 11:17:01 -0700 Subject: [PATCH 1/3] Update [ghstack-poisoned] --- torchtitan/checkpoint.py | 9 +---- torchtitan/config_manager.py | 35 +++++++++++++++++++- torchtitan/parallelisms/parallelize_llama.py | 4 +-- 3 files changed, 37 insertions(+), 11 deletions(-) diff --git a/torchtitan/checkpoint.py b/torchtitan/checkpoint.py index 33fe8c05..2e1fdf67 100644 --- a/torchtitan/checkpoint.py +++ b/torchtitan/checkpoint.py @@ -26,13 +26,6 @@ from torchtitan.logging_utils import init_logger, logger -DTYPE_MAP = { - "float16": torch.float16, - "float32": torch.float32, - "bfloat16": torch.bfloat16, -} - - class IntervalType(enum.Enum): SECONDS = enum.auto() STEPS = enum.auto() @@ -141,7 +134,7 @@ def __init__( self.pg = dist.new_group(backend="gloo") self.model_weights_only = ckpt_config.model_weights_only - self.export_dtype = DTYPE_MAP[ckpt_config.export_dtype] + self.export_dtype = ckpt_config.export_dtype self.mp = None async_mode = ckpt_config.async_mode.lower() diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index 1de3c82c..99cc0746 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -9,6 +9,8 @@ from collections import defaultdict from typing import Tuple, Union +import torch + try: import tomllib except ModuleNotFoundError: @@ -16,6 +18,16 @@ from torchtitan.logging_utils import logger +DTYPE_MAP = { + "float16": torch.float16, + "float32": torch.float32, + "bfloat16": torch.bfloat16, +} + + +def torch_dtype(dtype_str: str) -> torch.dtype: + return DTYPE_MAP[dtype_str] + class JobConfig: """ @@ -207,6 +219,26 @@ def __init__(self): default=1, help="Pipeline Parallelism degree. 1 means disabled.", ) + self.parser.add_argument( + "--training.mixed_precision_param", + type=torch_dtype, + default="bfloat16", + choices=["bfloat16", "float32"], + help=""" + torch dtype to use for parameters when applying mixed precision via FSDP. + This feature only takes effect when data_parallel_degree > 1 + """, + ) + self.parser.add_argument( + "--training.mixed_precision_reduce", + type=torch_dtype, + default="float32", + choices=["float32"], + help=""" + torch dtype to use for reductions when applying mixed precision via FSDP. + This feature only takes effect when data_parallel_degree > 1 + """, + ) self.parser.add_argument( "--training.compile", action="store_true", @@ -273,8 +305,9 @@ def __init__(self): ) self.parser.add_argument( "--checkpoint.export_dtype", - type=str, + type=torch_dtype, default="float32", + choices=["float16", "bfloat16", "float32"], help=""" Converts to the specified precision when training completes and model_weights_only=true. Currently supports float32, float16, and bfloat16. diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py index 9c8d0a29..4f152587 100644 --- a/torchtitan/parallelisms/parallelize_llama.py +++ b/torchtitan/parallelisms/parallelize_llama.py @@ -209,9 +209,9 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig): if parallel_dims.dp_enabled: dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names - # TODO: Expose `reduce_dtype` as a config option. 
mp_policy = MixedPrecisionPolicy( - param_dtype=torch.bfloat16, reduce_dtype=torch.float32 + param_dtype=job_config.training.mixed_precision_param, + reduce_dtype=job_config.training.mixed_precision_param, ) ac_mode = job_config.activation_checkpoint.mode fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy} From 59447929a8e3fe46e5b10368f418c666fbf08688 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Mon, 20 May 2024 12:46:57 -0700 Subject: [PATCH 2/3] Update [ghstack-poisoned] --- torchtitan/config_manager.py | 19 +++++++++++++++---- torchtitan/parallelisms/parallelize_llama.py | 2 +- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index 99cc0746..6e6bf00a 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -24,6 +24,12 @@ "bfloat16": torch.bfloat16, } +TORCH_DTYPE_ARGS = [ + "checkpoint.export_dtype", + "training.mixed_precision_param", + "training.mixed_precision_reduce", +] + def torch_dtype(dtype_str: str) -> torch.dtype: return DTYPE_MAP[dtype_str] @@ -222,7 +228,7 @@ def __init__(self): self.parser.add_argument( "--training.mixed_precision_param", type=torch_dtype, - default="bfloat16", + default=torch_dtype("bfloat16"), choices=["bfloat16", "float32"], help=""" torch dtype to use for parameters when applying mixed precision via FSDP. @@ -232,7 +238,7 @@ def __init__(self): self.parser.add_argument( "--training.mixed_precision_reduce", type=torch_dtype, - default="float32", + default=torch_dtype("float32"), choices=["float32"], help=""" torch dtype to use for reductions when applying mixed precision via FSDP. @@ -306,7 +312,7 @@ def __init__(self): self.parser.add_argument( "--checkpoint.export_dtype", type=torch_dtype, - default="float32", + default=torch_dtype("float32"), choices=["float16", "bfloat16", "float32"], help=""" Converts to the specified precision when training completes and model_weights_only=true. 
@@ -394,6 +400,9 @@ def parse_args(self, args_list: list = sys.argv[1:]): try: with open(config_file, "rb") as f: for k, v in tomllib.load(f).items(): + for k_, v_ in v.items(): + if ".".join([k, k_]) in TORCH_DTYPE_ARGS: + v[k_] = torch_dtype(v_) # to prevent overwrite of non-specified keys args_dict[k] |= v except (FileNotFoundError, tomllib.TOMLDecodeError) as e: @@ -437,7 +446,9 @@ def parse_args_from_command_line( # aux parser to parse the command line only args, with no defaults from main parser aux_parser = argparse.ArgumentParser(argument_default=argparse.SUPPRESS) for arg, val in vars(args).items(): - if isinstance(val, bool): + if arg in TORCH_DTYPE_ARGS: + aux_parser.add_argument("--" + arg, type=torch_dtype) + elif isinstance(val, bool): aux_parser.add_argument( "--" + arg, action="store_true" if val else "store_false" ) diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py index 4f152587..cce9e9d6 100644 --- a/torchtitan/parallelisms/parallelize_llama.py +++ b/torchtitan/parallelisms/parallelize_llama.py @@ -211,7 +211,7 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig): assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names mp_policy = MixedPrecisionPolicy( param_dtype=job_config.training.mixed_precision_param, - reduce_dtype=job_config.training.mixed_precision_param, + reduce_dtype=job_config.training.mixed_precision_reduce, ) ac_mode = job_config.activation_checkpoint.mode fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy} From a4f1d9d94dc8e7840979dd7aa1f4eddb46984902 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Mon, 20 May 2024 17:18:53 -0700 Subject: [PATCH 3/3] Update [ghstack-poisoned] --- torchtitan/checkpoint.py | 4 +-- torchtitan/config_manager.py | 31 +++++--------------- torchtitan/parallelisms/parallelize_llama.py | 6 ++-- 3 files changed, 13 insertions(+), 28 deletions(-) diff --git a/torchtitan/checkpoint.py b/torchtitan/checkpoint.py index 2e1fdf67..81bdf592 100644 --- a/torchtitan/checkpoint.py +++ b/torchtitan/checkpoint.py @@ -22,7 +22,7 @@ set_optimizer_state_dict, ) from torch.distributed.checkpoint.stateful import Stateful -from torchtitan.config_manager import JobConfig +from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP from torchtitan.logging_utils import init_logger, logger @@ -134,7 +134,7 @@ def __init__( self.pg = dist.new_group(backend="gloo") self.model_weights_only = ckpt_config.model_weights_only - self.export_dtype = ckpt_config.export_dtype + self.export_dtype = TORCH_DTYPE_MAP[ckpt_config.export_dtype] self.mp = None async_mode = ckpt_config.async_mode.lower() diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index 6e6bf00a..1a3e36d4 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -18,22 +18,12 @@ from torchtitan.logging_utils import logger -DTYPE_MAP = { +TORCH_DTYPE_MAP = { "float16": torch.float16, "float32": torch.float32, "bfloat16": torch.bfloat16, } -TORCH_DTYPE_ARGS = [ - "checkpoint.export_dtype", - "training.mixed_precision_param", - "training.mixed_precision_reduce", -] - - -def torch_dtype(dtype_str: str) -> torch.dtype: - return DTYPE_MAP[dtype_str] - class JobConfig: """ @@ -227,8 +217,8 @@ def __init__(self): ) self.parser.add_argument( "--training.mixed_precision_param", - type=torch_dtype, - default=torch_dtype("bfloat16"), + type=str, + default="bfloat16", choices=["bfloat16", "float32"], help=""" torch dtype to use for parameters when applying mixed 
precision via FSDP. @@ -237,8 +227,8 @@ def __init__(self): ) self.parser.add_argument( "--training.mixed_precision_reduce", - type=torch_dtype, - default=torch_dtype("float32"), + type=str, + default="float32", choices=["float32"], help=""" torch dtype to use for reductions when applying mixed precision via FSDP. @@ -311,8 +301,8 @@ def __init__(self): ) self.parser.add_argument( "--checkpoint.export_dtype", - type=torch_dtype, - default=torch_dtype("float32"), + type=str, + default="float32", choices=["float16", "bfloat16", "float32"], help=""" Converts to the specified precision when training completes and model_weights_only=true. @@ -400,9 +390,6 @@ def parse_args(self, args_list: list = sys.argv[1:]): try: with open(config_file, "rb") as f: for k, v in tomllib.load(f).items(): - for k_, v_ in v.items(): - if ".".join([k, k_]) in TORCH_DTYPE_ARGS: - v[k_] = torch_dtype(v_) # to prevent overwrite of non-specified keys args_dict[k] |= v except (FileNotFoundError, tomllib.TOMLDecodeError) as e: @@ -446,9 +433,7 @@ def parse_args_from_command_line( # aux parser to parse the command line only args, with no defaults from main parser aux_parser = argparse.ArgumentParser(argument_default=argparse.SUPPRESS) for arg, val in vars(args).items(): - if arg in TORCH_DTYPE_ARGS: - aux_parser.add_argument("--" + arg, type=torch_dtype) - elif isinstance(val, bool): + if isinstance(val, bool): aux_parser.add_argument( "--" + arg, action="store_true" if val else "store_false" ) diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py index cce9e9d6..0bd0a966 100644 --- a/torchtitan/parallelisms/parallelize_llama.py +++ b/torchtitan/parallelisms/parallelize_llama.py @@ -28,7 +28,7 @@ from torch.utils.checkpoint import _pt2_selective_checkpoint_context_fn_gen, checkpoint -from torchtitan.config_manager import JobConfig +from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP from torchtitan.logging_utils import logger @@ -210,8 +210,8 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig): dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names mp_policy = MixedPrecisionPolicy( - param_dtype=job_config.training.mixed_precision_param, - reduce_dtype=job_config.training.mixed_precision_reduce, + param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param], + reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce], ) ac_mode = job_config.activation_checkpoint.mode fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
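
The end state after patch 3 is worth spelling out: the config values stay plain strings ("bfloat16", "float32", ...), argparse `choices` validates the spelling, and the lookup through `TORCH_DTYPE_MAP` happens only at the point of use (`parallelize_llama` for the FSDP `MixedPrecisionPolicy`, `CheckpointManager` for `export_dtype`). The sketch below is a minimal, self-contained illustration of that pattern, not the torchtitan code itself; `build_parser` and `resolve_dtype` are hypothetical stand-ins.

```python
# Minimal sketch of the "string in config, torch.dtype at point of use" pattern
# that patch 3 settles on. Illustrative only; `build_parser` and `resolve_dtype`
# are hypothetical helpers, not part of torchtitan.
import argparse

import torch

TORCH_DTYPE_MAP = {
    "float16": torch.float16,
    "float32": torch.float32,
    "bfloat16": torch.bfloat16,
}


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
    # Values are kept as strings; argparse only validates the spelling.
    parser.add_argument(
        "--training.mixed_precision_param",
        type=str,
        default="bfloat16",
        choices=["bfloat16", "float32"],
    )
    parser.add_argument(
        "--training.mixed_precision_reduce",
        type=str,
        default="float32",
        choices=["float32"],
    )
    return parser


def resolve_dtype(name: str) -> torch.dtype:
    # The only place a string becomes a torch.dtype, mirroring how
    # parallelize_llama indexes TORCH_DTYPE_MAP when it builds the
    # MixedPrecisionPolicy(param_dtype=..., reduce_dtype=...).
    return TORCH_DTYPE_MAP[name]


if __name__ == "__main__":
    args = build_parser().parse_args(["--training.mixed_precision_param", "bfloat16"])
    # argparse keeps the dotted option name as the attribute name.
    param_dtype = resolve_dtype(getattr(args, "training.mixed_precision_param"))
    reduce_dtype = resolve_dtype(getattr(args, "training.mixed_precision_reduce"))
    print(param_dtype, reduce_dtype)  # torch.bfloat16 torch.float32
```

Keeping the argparse layer string-only is what lets the TOML path in `parse_args` merge sections with `args_dict[k] |= v` and re-parse command-line overrides without any dtype-aware special-casing, which is exactly the machinery that patch 2 added and patch 3 deletes.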
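
For the job-config file path, `parse_args` loads the TOML with `tomllib` and merges each `[section]` table into the defaults keyed as `section.key`. A config exercising the new knobs would therefore look like the inline example below; the section and key names follow the argument names added in this stack, but the file contents and the small merge loop are an illustrative sketch, not code from the repo.

```python
# Illustrative only: a tiny stand-in for the TOML merge in JobConfig.parse_args.
# The [training]/[checkpoint] keys mirror the arguments added in this stack.
import tomllib  # standard library on Python 3.11+

import torch

TORCH_DTYPE_MAP = {
    "float16": torch.float16,
    "float32": torch.float32,
    "bfloat16": torch.bfloat16,
}

CONFIG = """
[training]
mixed_precision_param = "bfloat16"
mixed_precision_reduce = "float32"

[checkpoint]
export_dtype = "bfloat16"
"""

args_dict: dict[str, str] = {}
for section, table in tomllib.loads(CONFIG).items():
    for key, value in table.items():
        # Same flattening as JobConfig: "section.key" becomes the argument name.
        args_dict[f"{section}.{key}"] = value

# Strings are resolved to torch dtypes only where they are consumed.
param_dtype = TORCH_DTYPE_MAP[args_dict["training.mixed_precision_param"]]
reduce_dtype = TORCH_DTYPE_MAP[args_dict["training.mixed_precision_reduce"]]
export_dtype = TORCH_DTYPE_MAP[args_dict["checkpoint.export_dtype"]]
print(param_dtype, reduce_dtype, export_dtype)
# torch.bfloat16 torch.float32 torch.bfloat16
```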