clip_grad_norm does not work with Tensor Parallelism #2420
Thanks @nathan-az for creating this. If your hypothesis is correct, it may be because we are not currently parallelizing our RMSNorm layers via SequenceParallel. @acisseJZhong was actually just looking into this -- it seems our current RMSNorm definition is not compatible with SequenceParallel (a bit weird, because we are really just calling into the nn.functional version). She may be able to share more here.
I did manage to at least change the error by replacing the RMSNorm with the Python implementation noted in the SequenceParallel docs, and adding the following to the TP plan:

```python
"norm": SequenceParallel(),
"layers.*.sa_norm": SequenceParallel(),
"layers.*.mlp_norm": SequenceParallel(),
```

But this brought a new error in the forward pass of the RoPE application. I tried a few things from there, but had no luck 😓
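For reference, a minimal sketch of the kind of pure-Python RMSNorm referred to above (my own illustration based on the SequenceParallel docs, not torchtune's or torchtitan's actual implementation):

```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    """Plain-PyTorch RMSNorm so SequenceParallel can treat it like any regular nn.Module."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Compute in fp32 for numerical stability, then cast back to the input dtype.
        x_fp32 = x.float()
        normed = x_fp32 * torch.rsqrt(x_fp32.pow(2).mean(-1, keepdim=True) + self.eps)
        return (normed * self.weight.float()).type_as(x)
```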
Hi @nathan-az, the default llama3 TP plan does not apply Sequence Parallel to the norm layers, only Tensor Parallel to the linear layers, so it should not parallelize the norm layers at all. What error are you seeing when you turn on clip_grad_norm? If you want to try out enabling SequenceParallel, you could try the following parallelism plan.
Let me know how it goes! From my past experience, torchtune's RMSNorm implementation might not work with Sequence Parallel; if you also encounter this, feel free to try out Torchtitan's RMSNorm. Please let us know if that works!
Thanks @acisseJZhong - I've shared the full stacktrace from enabling clip_grad_norm. Good news is that your tensor parallel plan above seemed to work (at least it didn't throw any errors during training - I can't confirm numerical correctness), even with torchtune's RMSNorm. I'm willing to try this out on a fork, but I don't have capacity to add tests / ensure correctness. Do you have a branch/PR I can follow?
@acisseJZhong just an FYI: the above TP plan doesn't appear to work with FP8 training - at least the way I've implemented it in this PR (largely adapted from existing float8 TP code). Error below:
TP plan with FP8 TP utils (should be identical to yours, but substituting the float8 parallel styles when FP8 is enabled):

```python
if enable_float8:
    rowwise_parallel, colwise_parallel, prepare_module_input = (
        Float8RowwiseParallel,
        Float8ColwiseParallel,
        PrepareFloat8ModuleInput,
    )
else:
    rowwise_parallel, colwise_parallel, prepare_module_input = (
        RowwiseParallel,
        ColwiseParallel,
        PrepareModuleInput,
    )

base_tp_plan = {
    "tok_embeddings": RowwiseParallel(
        input_layouts=Replicate(),
        output_layouts=Shard(1),
    ),
    "norm": SequenceParallel(),
    "output": ColwiseParallel(input_layouts=Shard(1), output_layouts=Replicate()),
    "layers.*.sa_norm": SequenceParallel(),
    "layers.*.attn": prepare_module_input(
        input_layouts=(Shard(1), None),
        desired_input_layouts=(Replicate(), None),
    ),
    "layers.*.attn.q_proj": colwise_parallel(),
    "layers.*.attn.k_proj": colwise_parallel(),
    "layers.*.attn.v_proj": colwise_parallel(),
    "layers.*.attn.output_proj": rowwise_parallel(output_layouts=Shard(1)),
    "layers.*.mlp_norm": SequenceParallel(),
    "layers.*.mlp": prepare_module_input(
        input_layouts=(Shard(1),),
        desired_input_layouts=(Replicate(),),
    ),
    "layers.*.mlp.w1": colwise_parallel(),
    "layers.*.mlp.w2": rowwise_parallel(output_layouts=Shard(1)),
    "layers.*.mlp.w3": colwise_parallel(),
}
```

The config I used to reproduce:

```yaml
output_dir: /tmp/torchtune/llama3_1_8B/full # /tmp may be deleted by your system. Change it to your preference.
# Tokenizer
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: outputs/base_model/Llama-3.1-8B-Instruct/original/tokenizer.model
  max_seq_len: 256

# Dataset
dataset:
  _component_: torchtune.datasets.alpaca_dataset
  packed: True # True increases speed
seed: null
shuffle: True

tensor_parallel_dim: 4
tensor_parallel_plan:
  _component_: torchtune.models.llama3.base_llama_tp_plan

# Model Arguments
model:
  _component_: torchtune.models.llama3_1.llama3_1_8b

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: outputs/base_model/Llama-3.1-8B-Instruct
  checkpoint_files: [
    model-00001-of-00004.safetensors,
    model-00002-of-00004.safetensors,
    model-00003-of-00004.safetensors,
    model-00004-of-00004.safetensors
  ]
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: LLAMA3
resume_from_checkpoint: False

# Fine-tuning arguments
batch_size: 1
epochs: 1
optimizer:
  _component_: torchao.prototype.low_bit_optim.AdamW8bit
  lr: 2e-5
  # fused: True
loss:
  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: 2
clip_grad_norm: null
compile: False # torch.compile the model + loss, True increases speed + decreases memory
optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1
gradient_accumulation_steps: 4 # Use to increase effective batch size

# Training env
device: cuda

# Memory management
enable_activation_checkpointing: True # True reduces memory
enable_activation_offloading: False # True reduces memory
custom_sharded_layers: [] # ['tok_embeddings', 'output'] # Layers to shard separately (useful for large vocab size models). Lower Memory, but lower speed.

# Reduced precision
dtype: bf16
enable_fp8_training: true

# Logging
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
```

I believe this is the result of the changes to the plan, because FP8 training works with the old plan and no grad clipping; if I simply swap in the new plan (but keep clipping off and all other settings unchanged), it yields the error. I understand FP8 training isn't merged, so this probably isn't a priority, but if you have any inkling as to why the above occurs, that would be greatly appreciated - I have no clue how to even begin debugging an error like that, and if it's addressable we can merge it before (or inside) that PR.
@vkuzo in case you have any thoughts as the resident float8 expert.
The solution may be to use the approach from torchtitan.
cc: @tianyu-l
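In case it helps anyone experimenting here, a rough sketch of what DTensor-aware clipping can look like once all parameters are covered by the TP plan (my own assumption about the kind of fix meant above, not a confirmed implementation): clip_grad_norm_ can compute a distributed norm over DTensor gradients, but the total norm it returns is itself a DTensor and needs to be materialized.

```python
import torch
from torch.distributed.tensor import DTensor  # public import path in recent PyTorch


def clip_grad_norm_with_tp(model: torch.nn.Module, max_norm: float) -> torch.Tensor:
    # When every gradient is a DTensor (i.e. the norm layers are also parallelized),
    # clip_grad_norm_ produces a correct sharded norm; materialize it with
    # full_tensor() before logging or comparing against a plain float.
    total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    if isinstance(total_norm, DTensor):
        total_norm = total_norm.full_tensor()
    return total_norm
```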
For float8 training, we've tested our float8 TP code where activations and gradients are in row-major memory format. Is there anything above which would invalidate that assumption?
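Not an answer, but a quick way to sanity-check that assumption in the recipe (a hypothetical helper, not existing code): walk the gradients and flag anything that is not contiguous/row-major, unwrapping DTensor gradients to their local shards first.

```python
import torch
from torch.distributed.tensor import DTensor


def report_non_row_major_grads(model: torch.nn.Module) -> None:
    # Flags any gradient whose local data is not stored in contiguous (row-major) layout.
    for name, param in model.named_parameters():
        grad = param.grad
        if grad is None:
            continue
        local = grad.to_local() if isinstance(grad, DTensor) else grad
        if not local.is_contiguous():
            print(f"{name}: non-contiguous grad, stride={local.stride()}, shape={tuple(local.shape)}")
```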
Some insights from my testing, comparing TP with FSDP:
The above are insights from an experimental branch I'm working off, with a combination of features not yet on main. I'll get a PR in first. After that I'm happy to create issues and look more deeply into TP, since I understand TP is still a relatively new feature.
One thing that might be relevant is what dtype you are accumulating gradients in. If using microbatching/gradient accumulation, accumulating gradients in fp32 is generally helpful/required, but it requires more memory. (Maybe stochastic rounding could work in place of fp32, but I am not sure what support looks like in PyTorch.) In FSDP2, this depends on the mixed precision policy you set.
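For concreteness, a sketch of the FSDP2 knob being alluded to, assuming a recent PyTorch where fully_shard and MixedPrecisionPolicy are exposed under torch.distributed.fsdp and a process group is already initialized (the toy model is illustrative only):

```python
import torch
import torch.nn as nn
from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard

# Stand-in for the transformer; the real model comes from the recipe.
model = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 16))

# Keep params/compute in bf16, but reduce gradients in fp32 so the sharded grads
# (and their accumulation across microbatches) are held in fp32.
mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)
for layer in model:
    fully_shard(layer, mp_policy=mp_policy)
fully_shard(model, mp_policy=mp_policy)
```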
Another aspect may be this gradient scaling, which scales with world size rather than DP size. If the motivation is to undo some FSDP2 normalization, it should probably scale with the DP size (or divide by TP size afterwards). Unfortunately I lack the capacity to look into this any further, so I'm sticking with FSDP (/ HSDP) for now, but I think there are decent leads in this thread!
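To make that concrete, a hedged sketch of the adjustment being suggested (the helper name and the exact scaling expression are illustrative, not the recipe's actual code):

```python
import torch
import torch.distributed as dist


def scale_grads_(model: torch.nn.Module, scaler: float) -> None:
    # Multiply every gradient in place by `scaler`.
    for p in model.parameters():
        if p.grad is not None:
            p.grad.mul_(scaler)


# If the current code effectively does `scale_grads_(model, world_size / num_tokens)`
# to undo a mean over data-parallel ranks, then with TP enabled the factor should be
# based on the data-parallel group size instead, e.g.:
#   dp_size = dist.get_world_size() // tp_size
#   scale_grads_(model, dp_size / num_tokens)
```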
I did have time to at least run some experiments to help future debugging on this branch. Functional changes:

Config:

```yaml
batch_size: 1
checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: outputs/base_model/Llama-3.3-70B-Instruct
  checkpoint_files:
    filename_format: model-{}-of-{}.safetensors
    max_filename: '00030'
  model_type: LLAMA3
  output_dir: ${output_dir}
  recipe_checkpoint: null
clip_grad_norm: null
compile: true
custom_sharded_layers: []
dataset:
  _component_: torchtune.datasets.slimorca_dataset
  packed: true
  train_on_input: true
device: cuda
dtype: bf16
enable_activation_checkpointing: true
enable_activation_offloading: true
epochs: 1
fsdp_cpu_offload: false
gradient_accumulation_steps: 1
log_every_n_steps: 1
log_peak_memory_stats: true
loss:
  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: 50
metric_logger:
  _component_: torchtune.training.metric_logging.MLFlowLogger
model:
  _component_: torchtune.models.llama3_3.llama3_3_70b
optimizer:
  _component_: torchao.prototype.low_bit_optim.AdamW8bit
  lr: 2.0e-06
optimizer_in_bwd: false
output_dir: outputs
resume_from_checkpoint: false
seed: 100
shuffle: true
tensor_parallel_dim: 1
tensor_parallel_plan:
  _component_: torchtune.models.llama3.base_llama_tp_plan
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  max_seq_len: 4096
  path: outputs/base_model/Llama-3.3-70B-Instruct/original/tokenizer.model
```

All tests use 8 devices. FSDP tests are run with tensor parallelism disabled (tensor_parallel_dim: 1).

**Overall**
We see seemingly close grad norms and loss, but TP loss is consistently higher, and we note some spikes in TP norms.

**TP: with/without grad norm clipping**
Clipped TP norms appear to spike higher than unclipped.

**TP vs FSDP: no clipping**
Perhaps most concerning to me is that without clipping, there is still a pretty substantial difference in the grad norm size between the FSDP and TP case.

This could be nothing, but my branch has no functional differences in the no-clipping case, so I'm a bit surprised to see such variance in the grad norms, and to not see exactly (or nearly exactly) equal losses between FSDP and TP. I'm not sure if there is some functional difference in the training loop, an inaccuracy in the gradient norm calculation, or a precision issue with gradients as @awgu suggested. Either way, I just hope these experiments serve as a useful starting point if this needs debugging 😄
Currently it appears tensor parallelism is not compatible with grad norm clipping. I did some testing (printing non-DTensor parameter names in the training loop) and I believe it's due to the normalisation layers.
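For anyone who wants to repeat that check, a small sketch along these lines (not the exact code used) lists the parameters that remain plain tensors after the TP plan is applied:

```python
import torch
from torch.distributed.tensor import DTensor


def print_non_dtensor_params(model: torch.nn.Module) -> None:
    # Anything not covered by the TP plan stays a plain tensor after parallelization;
    # under the default llama3 plan this appears to be the RMSNorm weights.
    for name, param in model.named_parameters():
        if not isinstance(param, DTensor):
            print(f"non-DTensor parameter: {name}")
```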
The above can be reproduced by taking the standard LLaMA 3.1 8B_full.yaml config and setting both a tensor_parallel_dim greater than 1 and a non-null clip_grad_norm. It would be great if this could be fixed, due to instability in the initial training steps.
I'm unable to contribute a fix here, as I'm not familiar enough with tensor parallelism; however, I suppose the fix will probably involve tweaking the TP plan to ensure that the norm layers are also DTensors.