
Commit 0573fce

Update (base update)

[ghstack-poisoned]

wconstab committed May 20, 2024
1 parent 3f8f1a3
Showing 2 changed files with 16 additions and 5 deletions.
torchtitan/config_manager.py (15 additions, 4 deletions)
@@ -24,6 +24,12 @@
     "bfloat16": torch.bfloat16,
 }
 
+TORCH_DTYPE_ARGS = [
+    "checkpoint.export_dtype",
+    "training.mixed_precision_param",
+    "training.mixed_precision_reduce",
+]
+
 
 def torch_dtype(dtype_str: str) -> torch.dtype:
     return DTYPE_MAP[dtype_str]
@@ -284,7 +290,7 @@ def __init__(self):
         self.parser.add_argument(
             "--training.mixed_precision_param",
             type=torch_dtype,
-            default="bfloat16",
+            default=torch_dtype("bfloat16"),
             choices=["bfloat16", "float32"],
             help="""
                 torch dtype to use for parameters when applying mixed precision via FSDP.
@@ -294,7 +300,7 @@ def __init__(self):
         self.parser.add_argument(
             "--training.mixed_precision_reduce",
             type=torch_dtype,
-            default="float32",
+            default=torch_dtype("float32"),
             choices=["float32"],
             help="""
                 torch dtype to use for reductions when applying mixed precision via FSDP.
@@ -368,7 +374,7 @@ def __init__(self):
         self.parser.add_argument(
             "--checkpoint.export_dtype",
             type=torch_dtype,
-            default="float32",
+            default=torch_dtype("float32"),
             choices=["float16", "bfloat16", "float32"],
             help="""
                 Converts to the specified precision when training completes and model_weights_only=true.
@@ -456,6 +462,9 @@ def parse_args(self, args_list: list = sys.argv[1:]):
         try:
             with open(config_file, "rb") as f:
                 for k, v in tomllib.load(f).items():
+                    for k_, v_ in v.items():
+                        if ".".join([k, k_]) in TORCH_DTYPE_ARGS:
+                            v[k_] = torch_dtype(v_)
                     # to prevent overwrite of non-specified keys
                     args_dict[k] |= v
         except (FileNotFoundError, tomllib.TOMLDecodeError) as e:
@@ -499,7 +508,9 @@ def parse_args_from_command_line(
         # aux parser to parse the command line only args, with no defaults from main parser
         aux_parser = argparse.ArgumentParser(argument_default=argparse.SUPPRESS)
         for arg, val in vars(args).items():
-            if isinstance(val, bool):
+            if arg in TORCH_DTYPE_ARGS:
+                aux_parser.add_argument("--" + arg, type=torch_dtype)
+            elif isinstance(val, bool):
                 aux_parser.add_argument(
                     "--" + arg, action="store_true" if val else "store_false"
                 )
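The net effect in config_manager.py is that dtype-valued options carry real torch.dtype objects on every path: type=torch_dtype converts command-line strings, the defaults are pre-converted with torch_dtype(...), and the new loop in parse_args converts matching keys loaded from the TOML config before each section is merged into args_dict. Below is a minimal sketch of the TOML conversion path, reusing names from the diff; the float16/float32 entries of DTYPE_MAP and the sample config dict are illustrative assumptions, inferred from the choices lists above.

import torch

# Mirrors torchtitan/config_manager.py; the float16/float32 entries are
# assumed here, based on the choices=[...] lists in the diff above.
DTYPE_MAP = {
    "float16": torch.float16,
    "float32": torch.float32,
    "bfloat16": torch.bfloat16,
}

TORCH_DTYPE_ARGS = [
    "checkpoint.export_dtype",
    "training.mixed_precision_param",
    "training.mixed_precision_reduce",
]

def torch_dtype(dtype_str: str) -> torch.dtype:
    return DTYPE_MAP[dtype_str]

# tomllib.load() yields one dict per section; a hypothetical [training]
# table becomes {"training": {"mixed_precision_param": "bfloat16"}}.
config = {"training": {"mixed_precision_param": "bfloat16"}}
for k, v in config.items():
    for k_, v_ in v.items():
        # "training" + "mixed_precision_param" -> "training.mixed_precision_param"
        if ".".join([k, k_]) in TORCH_DTYPE_ARGS:
            v[k_] = torch_dtype(v_)

assert config["training"]["mixed_precision_param"] is torch.bfloat16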
torchtitan/parallelisms/parallelize_llama.py (1 addition, 1 deletion)
@@ -393,7 +393,7 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
         assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names
         mp_policy = MixedPrecisionPolicy(
             param_dtype=job_config.training.mixed_precision_param,
-            reduce_dtype=job_config.training.mixed_precision_param,
+            reduce_dtype=job_config.training.mixed_precision_reduce,
         )
         ac_mode = job_config.activation_checkpoint.mode
         fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
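The one-line change in parallelize_llama.py fixes a copy-paste bug: reduce_dtype was populated from mixed_precision_param, so a configured training.mixed_precision_reduce was silently ignored and gradient reduction ran in the parameter dtype. A short sketch of the corrected policy wiring, assuming the FSDP2-style import torchtitan uses (torch.distributed._composable.fsdp); the literal dtypes stand in for the job_config values:

import torch
from torch.distributed._composable.fsdp import MixedPrecisionPolicy

# After the fix the two dtypes are independent: parameters are cast to
# bfloat16 for compute, while gradients are reduced in float32, matching
# the choices=["float32"] restriction on mixed_precision_reduce above.
mp_policy = MixedPrecisionPolicy(
    param_dtype=torch.bfloat16,   # job_config.training.mixed_precision_param
    reduce_dtype=torch.float32,   # job_config.training.mixed_precision_reduce
)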
