Skip to content

Commit

Permalink
adapt to changes in torch_spmd to run on mast (#7)
Browse files Browse the repository at this point in the history
  • Loading branch information
tianyu-l authored Jul 30, 2024
1 parent ba003e9 commit 273da65
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion torchtitan/parallelisms/parallelize_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,7 +531,7 @@ def parallelize_llama_torch_spmd(
assert ac_config.mode == "none", "AC not supported by torch_spmd yet"

if parallel_dims.dp_enabled:
from data_parallel import data_parallel, MixedPrecisionPolicy
from torch_spmd.data_parallel import data_parallel, MixedPrecisionPolicy

mp_policy = MixedPrecisionPolicy(
param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
Expand Down
2 changes: 1 addition & 1 deletion train_configs/debug_model.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ save_memory_snapshot_folder = "memory_snapshot"

[metrics]
log_freq = 1
enable_color_printing = true
enable_color_printing = false
enable_tensorboard = false
save_tb_folder = "tb"

Expand Down
4 changes: 2 additions & 2 deletions train_configs/llama3_8b.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ profile_freq = 100

[metrics]
log_freq = 10
enable_color_printing = false
enable_tensorboard = false
enable_color_printing = true
save_tb_folder = "tb"

[model]
Expand All @@ -36,7 +36,7 @@ data_parallel_degree = -1
tensor_parallel_degree = 1
fp8_linear = false
compile = true
dataset = "c4_mini"
dataset = "c4"
mixed_precision_param = "bfloat16"
mixed_precision_reduce = "bfloat16"

Expand Down

0 comments on commit 273da65

Please sign in to comment.