clip_grad_norm does not work with Tensor Parallelism #2420

Open
nathan-az opened this issue Feb 22, 2025 · 12 comments
Labels: bug, distributed

@nathan-az
Contributor

[rank2]:   File "/opt/conda/lib/python3.11/site-packages/torchtune/config/_parse.py", line 99, in wrapper
[rank2]:     sys.exit(recipe_main(conf))
[rank2]:              ^^^^^^^^^^^^^^^^^
[rank2]:   File "/opt/conda/lib/python3.11/site-packages/recipes/full_finetune_distributed.py", line 948, in recipe_main
[rank2]:     recipe.train()
[rank2]:   File "/opt/conda/lib/python3.11/site-packages/recipes/full_finetune_distributed.py", line 836, in train
[rank2]:     grad_norm = torch.nn.utils.clip_grad_norm_(
[rank2]:                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/opt/conda/lib/python3.11/site-packages/torch/nn/utils/clip_grad.py", line 38, in _no_grad_wrapper
[rank2]:     return func(*args, **kwargs)
[rank2]:            ^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/opt/conda/lib/python3.11/site-packages/torch/nn/utils/clip_grad.py", line 219, in clip_grad_norm_
[rank2]:     total_norm = _get_total_norm(grads, norm_type, error_if_nonfinite, foreach)
[rank2]:                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/opt/conda/lib/python3.11/site-packages/torch/nn/utils/clip_grad.py", line 38, in _no_grad_wrapper
[rank2]:     return func(*args, **kwargs)
[rank2]:            ^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/opt/conda/lib/python3.11/site-packages/torch/nn/utils/clip_grad.py", line 91, in _get_total_norm
[rank2]:     norms.extend(torch._foreach_norm(device_tensors, norm_type))
[rank2]:                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/opt/conda/lib/python3.11/site-packages/torch/_compile.py", line 51, in inner
[rank2]:     return disable_fn(*args, **kwargs)
[rank2]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 764, in _fn
[rank2]:     return fn(*args, **kwargs)
[rank2]:            ^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/tensor/_api.py", line 348, in __torch_dispatch__
[rank2]:     return DTensor._op_dispatcher.dispatch(
[rank2]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/tensor/_dispatch.py", line 168, in dispatch
[rank2]:     op_info = self.unwrap_to_op_info(op_call, args, kwargs)
[rank2]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/tensor/_dispatch.py", line 373, in unwrap_to_op_info
[rank2]:     self._try_replicate_spec_for_scalar_tensor(op_call, arg, mesh)
[rank2]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/tensor/_dispatch.py", line 474, in _try_replicate_spec_for_scalar_tensor
[rank2]:     raise RuntimeError(
[rank2]: RuntimeError: aten._foreach_norm.Scalar: got mixed torch.Tensor and DTensor, need to convert all torch.Tensor to DTensor before calling distributed operators!

Currently, it appears that tensor parallelism is not compatible with grad norm clipping. I did some testing (printing non-DTensor parameter names in the training loop) and I believe it's due to the normalisation layers.

The above can be reproduced by taking the standard LLaMA 3.1 8B_full.yaml config and setting:

tensor_parallel_dim: 4  # set to device count
tensor_parallel_plan:
  _component_: torchtune.models.llama3.base_llama_tp_plan
clip_grad_norm: 1.0

It would be great if this could be fixed, since gradient clipping helps with instability in the initial training steps.

I'm unable to contribute a fix here as I'm not familiar enough with tensor parallelism; however, I suspect the fix will involve tweaking the TP plan so that the norm layers are also DTensors.
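
For reference, the check I ran was roughly the following (an illustrative helper, not part of the recipe):

from torch import nn
from torch.distributed.tensor import DTensor

def report_non_dtensor_params(model: nn.Module) -> None:
    # Anything the TP plan doesn't cover (e.g. the RMSNorm weights) stays a plain
    # torch.Tensor, which is what clip_grad_norm_ then trips over.
    for name, param in model.named_parameters():
        if not isinstance(param, DTensor):
            print(f"non-DTensor parameter: {name}")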

@ebsmothers
Contributor

Thanks @nathan-az for creating this. If your hypothesis is correct, it may be because we are not currently parallelizing our RMSNorm layers via SequenceParallel. @acisseJZhong was actually just looking into this -- it seems our current RMSNorm definition is not compatible with SequenceParallel (a bit weird, because we are really just calling into the nn.functional version). She may be able to share more here.

ebsmothers added the bug and distributed labels on Feb 23, 2025
@nathan-az
Contributor Author

nathan-az commented Feb 23, 2025

it may be because we are not currently parallelizing our RMSNorm layers via SequenceParallel

I did manage to at least change the error by replacing the RMSNorm with the Python implementation as noted in the SequenceParallel docs, and adding the following to the TP plan:

    "norm": SequenceParallel(),
    "layers.*.sa_norm": SequenceParallel(),
    "layers.*.mlp_norm": SequenceParallel(),

But this brought a new error message in the forward pass of the RoPE application. I tried a few things from here, but had no luck 😓

[rank3]:   File "/app/torchtune/recipes/full_finetune_distributed.py", line 959, in recipe_main
[rank3]:     recipe.train()
[rank3]:   File "/app/torchtune/recipes/full_finetune_distributed.py", line 792, in train
[rank3]:     logits = self._model(**batch)
[rank3]:              ^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
[rank3]:     return self._call_impl(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
[rank3]:     return forward_call(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/app/torchtune/torchtune/modules/transformer.py", line 636, in forward
[rank3]:     h = layer(
[rank3]:         ^^^^^^
[rank3]:   File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
[rank3]:     return self._call_impl(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
[rank3]:     return forward_call(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py", line 171, in forward
[rank3]:     return self.checkpoint_fn(  # type: ignore[misc]
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/opt/conda/lib/python3.11/site-packages/torch/_compile.py", line 51, in inner
[rank3]:     return disable_fn(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 764, in _fn
[rank3]:     return fn(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/opt/conda/lib/python3.11/site-packages/torch/utils/checkpoint.py", line 495, in checkpoint
[rank3]:     ret = function(*args, **kwargs)
[rank3]:           ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
[rank3]:     return self._call_impl(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
[rank3]:     return forward_call(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/app/torchtune/torchtune/modules/transformer.py", line 122, in forward
[rank3]:     attn_out = self.attn(h, h, mask=mask, input_pos=input_pos)
[rank3]:                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
[rank3]:     return self._call_impl(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
[rank3]:     return forward_call(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/app/torchtune/torchtune/modules/attention.py", line 245, in forward
[rank3]:     q = self.pos_embeddings(q, input_pos=input_pos)
[rank3]:         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
[rank3]:     return self._call_impl(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
[rank3]:     return forward_call(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/app/torchtune/torchtune/models/llama3_1/_position_embeddings.py", line 177, in forward
[rank3]:     rope_cache = rope_cache.view(-1, xshaped.size(1), 1, xshaped.size(3), 2)
[rank3]:                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: RuntimeError: shape '[-1, 1024, 1, 64, 2]' is invalid for input of size 65536

@acisseJZhong
Contributor

acisseJZhong commented Feb 24, 2025

Hi @nathan-az, the default llama3 TP plan does not apply SequenceParallel to the norm layers, only tensor parallelism to the linear layers. Since it does not parallelize the norm layers, what error are you seeing when turning on clip_grad_norm?

If you want to try enabling SequenceParallel, you could use the following parallelism plan:

from torch.distributed.tensor import Replicate, Shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    PrepareModuleInput,
    RowwiseParallel,
    SequenceParallel,
)

BASE_LLAMA_TP_SP_PLAN = {
    "tok_embeddings": RowwiseParallel(
        input_layouts=Replicate(),
        output_layouts=Shard(1),
    ),
    "norm": SequenceParallel(),
    "output": ColwiseParallel(input_layouts=Shard(1), output_layouts=Replicate()),
    "layers.*.sa_norm": SequenceParallel(),
    "layers.*.attn": PrepareModuleInput(
        input_layouts=(Shard(1), None),
        desired_input_layouts=(Replicate(), None),
    ),
    "layers.*.attn.q_proj": ColwiseParallel(),
    "layers.*.attn.k_proj": ColwiseParallel(),
    "layers.*.attn.v_proj": ColwiseParallel(),
    "layers.*.attn.output_proj": RowwiseParallel(output_layouts=Shard(1)),
    "layers.*.mlp_norm": SequenceParallel(),
    "layers.*.mlp": PrepareModuleInput(
        input_layouts=(Shard(1),),
        desired_input_layouts=(Replicate(),),
    ),
    "layers.*.mlp.w1": ColwiseParallel(),
    "layers.*.mlp.w2": RowwiseParallel(output_layouts=Shard(1)),
    "layers.*.mlp.w3": ColwiseParallel(),
}

Let me know how it goes! In my past experience, torchtune's RMSNorm implementation might not work with SequenceParallel; if you also encounter this, feel free to try out torchtitan's RMSNorm. Please let us know if that works!
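
For reference, a minimal pure-PyTorch RMSNorm along those lines would look roughly like this (a sketch, not the exact torchtitan code):

import torch
from torch import nn

class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Upcast to fp32 for the norm computation, then cast back to the input dtype.
        x_fp32 = x.float()
        x_normed = x_fp32 * torch.rsqrt(x_fp32.pow(2).mean(-1, keepdim=True) + self.eps)
        return (x_normed * self.weight.float()).type_as(x)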

@nathan-az
Contributor Author

Thanks @acisseJZhong - the full stacktrace from enabling clip_grad_norm is in the "details" expand at the top of this issue. The error is due to a mix of Tensor and DTensor. I believe the norm layers are the cause, because I logged a dictionary mapping parameter type to name and saw that only the norm layers weren't DTensors. So my main motivation isn't to further parallelise the norm layers (these, and any associated collectives, should be very cheap I believe) - it's just to find a way to enable gradient norm clipping, and I thought including the norms in the TP plan would achieve this passively 😄!

The good news is that your tensor parallel plan above seemed to work (at least it didn't throw any errors during training - I can't confirm numerical correctness), even with the torchtune RMSNorm implementation.

I'm willing to try this out on a fork, but I don't have capacity to add tests / ensure correctness. Do you have a branch/PR I can follow?

@nathan-az
Contributor Author

nathan-az commented Feb 24, 2025

@acisseJZhong just an FYI, the above TP plan doesn't appear to work with FP8 training - at least not the way I've implemented it in this PR, which is largely taken from torchtitan.

Error below:

[rank3]:   File "/app/torchtune_custom/recipes/full_finetune_distributed.py", line 995, in recipe_main
[rank3]:     recipe.train()
[rank3]:   File "/app/torchtune_custom/recipes/full_finetune_distributed.py", line 862, in train
[rank3]:     current_loss.backward()
[rank3]:   File "/opt/conda/lib/python3.11/site-packages/torch/_tensor.py", line 648, in backward
[rank3]:     torch.autograd.backward(
[rank3]:   File "/opt/conda/lib/python3.11/site-packages/torch/autograd/__init__.py", line 353, in backward
[rank3]:     _engine_run_backward(
[rank3]:   File "/opt/conda/lib/python3.11/site-packages/torch/autograd/graph.py", line 824, in _engine_run_backward
[rank3]:     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/opt/conda/lib/python3.11/site-packages/torch/autograd/function.py", line 307, in apply
[rank3]:     return user_fn(self, *args)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/opt/conda/lib/python3.11/site-packages/torchao/float8/float8_linear.py", line 86, in backward
[rank3]:     grad_weight = torch.mm(
[rank3]:                   ^^^^^^^^^
[rank3]:   File "/opt/conda/lib/python3.11/site-packages/torch/_compile.py", line 51, in inner
[rank3]:     return disable_fn(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 764, in _fn
[rank3]:     return fn(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/tensor/_api.py", line 348, in __torch_dispatch__
[rank3]:     return DTensor._op_dispatcher.dispatch(
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/tensor/_dispatch.py", line 183, in dispatch
[rank3]:     self.redistribute_local_args(
[rank3]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/tensor/_dispatch.py", line 319, in redistribute_local_args
[rank3]:     resharded_local_tensor = redistribute_local_tensor(
[rank3]:                              ^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/tensor/_redistribute.py", line 239, in redistribute_local_tensor
[rank3]:     new_local_tensor = shard_spec._to_new_shard_dim(
[rank3]:                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/tensor/placement_types.py", line 330, in _to_new_shard_dim
[rank3]:     local_tensor = local_tensor.contiguous()
[rank3]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/opt/conda/lib/python3.11/site-packages/torchao/float8/float8_tensor.py", line 375, in __torch_dispatch__
[rank3]:     raise NotImplementedError(f"attempting to run {func}, this is not supported")
[rank3]: NotImplementedError: attempting to run aten.contiguous.default, this is not supported

TP plan with the FP8 TP utils (should be identical to yours, but substituting the torchao Float8ColwiseParallel and Float8RowwiseParallel utils in all but the embedding and output layers):

    if enable_float8:
        rowwise_parallel, colwise_parallel, prepare_module_input = (
            Float8RowwiseParallel,
            Float8ColwiseParallel,
            PrepareFloat8ModuleInput,
        )
    else:
        rowwise_parallel, colwise_parallel, prepare_module_input = (
            RowwiseParallel,
            ColwiseParallel,
            PrepareModuleInput,
        )


    base_tp_plan = {
        "tok_embeddings": RowwiseParallel(
            input_layouts=Replicate(),
            output_layouts=Shard(1),
        ),
        "norm": SequenceParallel(),
        "output": ColwiseParallel(input_layouts=Shard(1), output_layouts=Replicate()),
        "layers.*.sa_norm": SequenceParallel(),
        "layers.*.attn": prepare_module_input(
            input_layouts=(Shard(1), None),
            desired_input_layouts=(Replicate(), None),
        ),
        "layers.*.attn.q_proj": colwise_parallel(),
        "layers.*.attn.k_proj": colwise_parallel(),
        "layers.*.attn.v_proj": colwise_parallel(),
        "layers.*.attn.output_proj": rowwise_parallel(output_layouts=Shard(1)),
        "layers.*.mlp_norm": SequenceParallel(),
        "layers.*.mlp": prepare_module_input(
            input_layouts=(Shard(1),),
            desired_input_layouts=(Replicate(),),
        ),
        "layers.*.mlp.w1": colwise_parallel(),
        "layers.*.mlp.w2": rowwise_parallel(output_layouts=Shard(1)),
        "layers.*.mlp.w3": colwise_parallel(),
    }
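
For context, a plan dict like this ultimately gets consumed by torch's parallelize_module, roughly as in the sketch below (the mesh and model names are placeholders; the recipe wires this up from tensor_parallel_dim and tensor_parallel_plan):

from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module

# Placeholder: a single TP group over 4 GPUs.
tp_mesh = init_device_mesh("cuda", (4,), mesh_dim_names=("tp",))
parallelize_module(model, tp_mesh, parallelize_plan=base_tp_plan)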

The config I used to reproduce:

output_dir: /tmp/torchtune/llama3_1_8B/full # /tmp may be deleted by your system. Change it to your preference.

# Tokenizer
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: outputs/base_model/Llama-3.1-8B-Instruct/original/tokenizer.model
  max_seq_len: 256

# Dataset
dataset:
  _component_: torchtune.datasets.alpaca_dataset
  packed: True  # True increases speed
seed: null
shuffle: True

tensor_parallel_dim: 4
tensor_parallel_plan:
  _component_: torchtune.models.llama3.base_llama_tp_plan

# Model Arguments
model:
  _component_: torchtune.models.llama3_1.llama3_1_8b

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: outputs/base_model/Llama-3.1-8B-Instruct
  checkpoint_files: [
    model-00001-of-00004.safetensors,
    model-00002-of-00004.safetensors,
    model-00003-of-00004.safetensors,
    model-00004-of-00004.safetensors
  ]
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: LLAMA3
resume_from_checkpoint: False

# Fine-tuning arguments
batch_size: 1
epochs: 1

optimizer:
  _component_: torchao.prototype.low_bit_optim.AdamW8bit
  lr: 2e-5
  # fused: True
loss:
  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: 2
clip_grad_norm: null
compile: False  # torch.compile the model + loss, True increases speed + decreases memory
optimizer_in_bwd: False  # True saves memory. Requires gradient_accumulation_steps=1
gradient_accumulation_steps: 4  # Use to increase effective batch size

# Training env
device: cuda

# Memory management
enable_activation_checkpointing: True  # True reduces memory
enable_activation_offloading: False  # True reduces memory
custom_sharded_layers: [] # ['tok_embeddings', 'output']  # Layers to shard separately (useful for large vocab size models). Lower Memory, but lower speed.

# Reduced precision
dtype: bf16

enable_fp8_training: true

# Logging
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True

I believe it is the result of the changes to the plan, because FP8 training works with the old plan and no grad clipping. If I simply swap in the new plan (but keep clipping off and all other settings unchanged) it yields the error.

I understand FP8 training isn't merged, so this probably isn't a priority, but if you have any inkling as to why the above occurs, that would be greatly appreciated - I have no clue how to even begin debugging an error like that, and if it's addressable we can merge a fix before (or inside) that PR.

@nathan-az
Contributor Author

@vkuzo in case you have any thoughts as the resident torch.ao expert :)

@awgu

awgu commented Feb 24, 2025

The solution may be to use the implicit_replication context around clip_grad_norm_ (code pointer).

from torch.distributed.tensor.experimental import implicit_replication

with implicit_replication():
    clip_grad_norm_(...)

cc: @tianyu-l

@vkuzo

vkuzo commented Feb 24, 2025

For float8 training, we've tested our float8 TP code where activations and gradients are in row-major memory format. Is there anything ^ which would invalidate that assumption?

@nathan-az
Contributor Author

nathan-az commented Feb 25, 2025

implicit_replication seems to work! With the original TP plan I see no error, and despite enormous (gargantuan) gradient norms, loss is stable and not devolving to NaN.

Some insights from my testing. I'm comparing TP with 8n accumulation steps against FSDP with n gradient accumulation steps, using the same seed, learning rate, and precision. My understanding is that these should be equivalent in terms of examples trained on per optimizer step (8 devices).

  • FP8 in TP seems slower than BF16 - note that neither is using compile since I am using activation offloading and these don't seem compatible in TP
  • grad norms are considerably higher in TP - in the hundreds - which is confusing given the setups should be equivalent?
  • FSDP loss drops much faster
  • I think tokens-per-second logging is incorrect when using TP (it still multiplies by world_size but needs to divide by tp_size, if I understand correctly? - see the sketch after this list)
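
Roughly the adjustment I mean (illustrative names, not the recipe's actual logging code):

# Ranks in the same TP group see the same batch, so the number of distinct
# data-parallel replicas is world_size // tp_size, not world_size.
dp_size = world_size // tp_size
tokens_per_second = num_tokens_per_rank * dp_size / time_per_step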

The above are insights from an experimental branch I'm working off, with a combination of features not on main, so I don't expect any fixes or for anybody to run off and debug.

I'll get a PR in for using the implicit_replication context manager because that's simple and looks like it's working. Then I'll slowly work through my other open PRs (FP8 support and HSDP support) which appear stable.
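
Roughly the shape of the change I have in mind (names as in the recipe's train loop; a sketch, not the final PR):

import torch
from torch.distributed.tensor import DTensor
from torch.distributed.tensor.experimental import implicit_replication

if self._clip_grad_norm is not None:
    with implicit_replication():
        grad_norm = torch.nn.utils.clip_grad_norm_(
            self._model.parameters(),
            max_norm=float(self._clip_grad_norm),
        )
    # Under TP the returned norm is a DTensor; convert it for logging.
    if isinstance(grad_norm, DTensor):
        grad_norm = grad_norm.full_tensor()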

After that I'm happy to create issues and look more deeply into TP, since I understand TP is still a relatively new feature.

@awgu

awgu commented Feb 25, 2025

One thing that might be relevant is what dtype you are accumulating gradients in. If using microbatching/gradient accumulation, accumulating gradients in fp32 is generally helpful/required, but it requires more memory. (Maybe stochastic rounding could work in place of fp32, but I am not sure what support looks like in PyTorch.)

In FSDP2, if you set reduce_dtype=torch.float32 while also setting set_requires_gradient_sync(False), then those accumulation backwards will include logic to upcast the gradients to fp32 and accumulate in fp32 (naively without fusion).
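
A rough sketch of what that looks like with the FSDP2 APIs (placeholder model and loop; assumes the fully_shard / MixedPrecisionPolicy entry points and a hypothetical loss_fn):

import torch
from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard

# Keep params/compute in bf16 but reduce (and accumulate) gradients in fp32.
mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)
for layer in model.layers:
    fully_shard(layer, mp_policy=mp_policy)
fully_shard(model, mp_policy=mp_policy)

for i, microbatch in enumerate(microbatches):
    # Only reduce-scatter gradients on the last microbatch; on earlier ones FSDP2
    # accumulates the unsharded grads, upcast to fp32 per reduce_dtype.
    model.set_requires_gradient_sync(i == len(microbatches) - 1)
    loss = loss_fn(model(**microbatch), microbatch["labels"])  # placeholder loss
    loss.backward()
optimizer.step()
optimizer.zero_grad()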

@nathan-az
Contributor Author

Another aspect may be this gradient scaling, which scales with world size rather than DP size. If the motivation is to undo some FSDP2 normalization, it should probably scale with the DP size (or divide by TP size afterwards). Unfortunately I lack the capacity to look into this any further, so I'm sticking with FSDP (/HSDP) for now, but I think there are decent leads in this thread!
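
To illustrate the point in plain PyTorch (placeholder names; the actual scaling lives in the linked recipe code):

# If the scaling is meant to undo FSDP's gradient averaging across data-parallel
# ranks, the factor should be the DP degree rather than the full world size.
dp_size = world_size // tp_size
for p in model.parameters():
    if p.grad is not None:
        p.grad.mul_(dp_size)  # rather than world_size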

@nathan-az
Contributor Author

I did have time to at least run some experiments to help future debugging on this branch. Functional changes:

  • scale grads by DP size, not world size (I believe this is strictly correct and there is currently a bug?)
  • perform all grad norm clipping in the implicit_replication context manager

Config:

batch_size: 1
checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: outputs/base_model/Llama-3.3-70B-Instruct
  checkpoint_files:
    filename_format: model-{}-of-{}.safetensors
    max_filename: '00030'
  model_type: LLAMA3
  output_dir: ${output_dir}
  recipe_checkpoint: null
clip_grad_norm: null
compile: true
custom_sharded_layers: []
dataset:
  _component_: torchtune.datasets.slimorca_dataset
  packed: true
  train_on_input: true
device: cuda
dtype: bf16
enable_activation_checkpointing: true
enable_activation_offloading: true
epochs: 1
fsdp_cpu_offload: false
gradient_accumulation_steps: 1
log_every_n_steps: 1
log_peak_memory_stats: true
loss:
  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: 50
metric_logger:
  _component_: torchtune.training.metric_logging.MLFlowLogger
model:
  _component_: torchtune.models.llama3_3.llama3_3_70b
optimizer:
  _component_: torchao.prototype.low_bit_optim.AdamW8bit
  lr: 2.0e-06
optimizer_in_bwd: false
output_dir: outputs
resume_from_checkpoint: false
seed: 100
shuffle: true
tensor_parallel_dim: 1
tensor_parallel_plan:
  _component_: torchtune.models.llama3.base_llama_tp_plan
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  max_seq_len: 4096
  path: outputs/base_model/Llama-3.3-70B-Instruct/original/tokenizer.model

All tests use 8 devices. FSDP tests are run with gradient_accumulation_steps: 1 and compile: true; TP tests are run with gradient_accumulation_steps: 8 and compile: false. Other than that, clipping is alternated between null and 1.0. Runs are seeded, so I expected near-identical results.

Overall

We see seemingly close grad norms and loss, but TP loss is consistently higher, and we note some spikes in TP norms.

[image: grad norm and loss curves, TP vs FSDP]

TP: with/without grad norm clipping

Clipped TP norms appear to spike higher than unclipped.

[image: TP grad norms, with vs without clipping]

TP vs FSDP: no clipping

Perhaps most concerning to me is that, without clipping, there is still a pretty substantial difference in grad norm size between the FSDP and TP cases.

[image: grad norms, TP vs FSDP, no clipping]

This could be nothing, but my branch has no functional differences in the no-clipping case, so I'm a bit surprised to see such variance in the grad norms and to not see exactly (or nearer to exactly) equal losses between FSDP and TP. I'm not sure if there is some functional difference in the training loop, an inaccuracy in the gradient norm calculation, or a precision issue with gradients as @awgu suggested. Either way, I just hope these experiments serve as a useful starting point if this needs debugging 😄
