Tensor-parallel like FeedForward to lower memory requirements #10623
base: main
Conversation
```python
# Apply feedforward pass sequentially since this is intended for memory optimization on a single GPU
for proj_in, proj_out in zip(self.proj_in, self.proj_out):
    out = proj_in(hidden_states)
    out = self.dropout(out)
```
I think dropout is probably incorrect here. As we split the embed dimension, applying dropout on each split will cause `num_split` times more features to be dropped. I think dividing the original dropout rate by `num_split` should have an equivalent effect to the normal feedforward 🤔
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
thanks @a-r-r-o-w
It might be possible, I can try (but no promises it will be neat). We will probably still require the modeling change to split the embedding dim of the linear layers. I think this is something I can prio later as it's not too important, and since I have to still fake-tensor-parallelize our

```python
def _apply_tensor_parallel_ptd(
    device_mesh: torch.distributed.device_mesh.DeviceMesh, transformer: LTXVideoTransformer3DModel
) -> None:
    from torch.distributed.tensor import Replicate, Shard
    from torch.distributed.tensor.parallel import parallelize_module
    from torch.distributed.tensor.parallel.style import (
        ColwiseParallel,
        PrepareModuleOutput,
        RowwiseParallel,
        SequenceParallel,
    )

    transformer_plan = {
        # ===== Condition embeddings =====
        "time_embed.emb.timestep_embedder.linear_1": ColwiseParallel(),
        "time_embed.emb.timestep_embedder.linear_2": RowwiseParallel(output_layouts=Shard(-1)),
        "time_embed.linear": ColwiseParallel(input_layouts=Shard(-1), output_layouts=Replicate()),
        "time_embed": PrepareModuleOutput(output_layouts=(Replicate(), Shard(-1)), desired_output_layouts=(Replicate(), Replicate())),
        "caption_projection.linear_1": ColwiseParallel(),
        "caption_projection.linear_2": RowwiseParallel(),
        "rope": PrepareModuleOutput(output_layouts=(Replicate(), Replicate()), desired_output_layouts=(Shard(1), Shard(1)), use_local_output=False),
        # ===== =====
    }

    for block in transformer.transformer_blocks:
        block_plan = {}

        # ===== Attention =====
        # 8 all-to-all, 3 all-reduce
        block_plan["attn1.to_q"] = ColwiseParallel(use_local_output=False)
        block_plan["attn1.to_k"] = ColwiseParallel(use_local_output=False)
        block_plan["attn1.to_v"] = ColwiseParallel(use_local_output=False)
        block_plan["attn1.norm_q"] = SequenceParallel()
        block_plan["attn1.norm_k"] = SequenceParallel()
        block_plan["attn1.to_out.0"] = RowwiseParallel(input_layouts=Shard(1))
        block_plan["attn2.to_q"] = ColwiseParallel(use_local_output=False)
        block_plan["attn2.to_k"] = ColwiseParallel(use_local_output=False)
        block_plan["attn2.to_v"] = ColwiseParallel(use_local_output=False)
        block_plan["attn2.norm_q"] = SequenceParallel()
        block_plan["attn2.norm_k"] = SequenceParallel()
        block_plan["attn2.to_out.0"] = RowwiseParallel(input_layouts=Shard(1))
        # ===== =====

        block_plan["ff.net.0.proj"] = ColwiseParallel()
        block_plan["ff.net.2"] = RowwiseParallel()

        parallelize_module(block, device_mesh, block_plan)

    parallelize_module(transformer, device_mesh, transformer_plan)
```
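For context, this is roughly how a plan like the one above would be driven across multiple processes. This is just a sketch on my side: the process-group setup and the checkpoint id are assumptions, not part of this PR.

```python
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

from diffusers import LTXVideoTransformer3DModel

# Hypothetical launch sketch (e.g. via `torchrun --nproc-per-node=2`); the checkpoint id is an assumption.
dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank())

# One-dimensional mesh over all available ranks for tensor parallelism.
device_mesh = init_device_mesh("cuda", (dist.get_world_size(),))

transformer = LTXVideoTransformer3DModel.from_pretrained(
    "Lightricks/LTX-Video", subfolder="transformer", torch_dtype=torch.bfloat16
).to("cuda")

_apply_tensor_parallel_ptd(device_mesh, transformer)
```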
In tensor parallelism, we split the internal layers of modules either column-wise or row-wise across multiple GPUs, perform individual forward passes for each split, and then all-reduce with a sum to gather the outputs. We can do the same thing sequentially on a single GPU, which avoids ever materializing the full intermediate tensor.
Typically, there is a 4x-8x expansion in the intermediate hidden dimension of the FFs. This results in a 4x-8x larger intermediate tensor being created compared to the input tensor. Given that we have large models like HunyuanVideo now, the sequence length can be very large (> 2**14 for decent frames/resolution), so FFs end up allocating a lot of additional memory -- this is much worse for training without gradient checkpointing (or partial gradient checkpointing #10611), I think, because all intermediate tensors need to be stored. We can, however, get rid of allocating a large intermediate tensor.
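For illustration, here is a minimal sketch of the idea (my own simplified rendition, not the actual modeling change in this PR; class and attribute names are illustrative): the up-projection is split column-wise, the down-projection row-wise, the splits are run one after another, and the partial outputs are summed, so only a `1/num_splits` slice of the intermediate activation is alive at any time.

```python
import torch
import torch.nn as nn


class ChunkedFeedForward(nn.Module):
    # Sketch of a tensor-parallel-style FeedForward executed sequentially on one GPU.
    def __init__(self, dim: int, inner_dim: int, num_splits: int = 4):
        super().__init__()
        assert inner_dim % num_splits == 0
        split_dim = inner_dim // num_splits
        # Column-wise split of the up-projection, row-wise split of the down-projection.
        # Down-projection bias is omitted here; in a real sharded implementation it must be added only once.
        self.proj_in = nn.ModuleList([nn.Linear(dim, split_dim) for _ in range(num_splits)])
        self.proj_out = nn.ModuleList([nn.Linear(split_dim, dim, bias=False) for _ in range(num_splits)])
        self.act = nn.GELU()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Sequential equivalent of the multi-GPU all-reduce: accumulate partial outputs with a sum.
        output = torch.zeros_like(hidden_states)
        for proj_in, proj_out in zip(self.proj_in, self.proj_out):
            # Only a (batch, seq_len, inner_dim / num_splits) intermediate exists here,
            # instead of the full (batch, seq_len, inner_dim) tensor.
            output = output + proj_out(self.act(proj_in(hidden_states)))
        return output
```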
There are some gotchas, however. Applying this directly to an arbitrary model will most likely not show any memory savings. In order for this to have any effect, we first need to optimize the memory usage of a model to the point where the `FeedForward` hidden dim expansion actually starts to affect the peak memory required. There are many ways to reach the tipping point where FFs end up causing the peaks (sometimes multiple of the below-mentioned techniques need to be combined to reach that point).

With the latest group offloading support upcoming, we know that VRAM usage can be reduced significantly without any penalty to throughput, given adequate CPU RAM. This reduces the memory peaks coming from model weights. The main cause of the spiky points on the memory trace is now the large intermediate activations, either in attention or in the feedforward layers.
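To check whether the FF expansion is actually what sets the peak on a given setup, one can bracket the forward pass with PyTorch's peak-memory counters. This is just a generic measurement sketch, not the reproducer used for the numbers below.

```python
import torch


def measure_peak_memory(fn, *args, **kwargs):
    # Generic helper: reports the peak CUDA memory allocated while running `fn`.
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    with torch.no_grad():
        out = fn(*args, **kwargs)
    peak_gb = torch.cuda.max_memory_allocated() / 1024**3
    print(f"peak allocated: {peak_gb:.2f} GB")
    return out
```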
We can reduce these memory peaks by making use of ideas from tensor- and sequence-parallelism and applying them sequentially for the single-GPU case. This PR implements the tensor-parallel equivalent of FFs. Will do a follow-up for a sequence-parallelism based optimization in the near future after some more important things are taken care of.
Onto the benchmark numbers!
Benchmark
Benchmark results:
To make it easier to parse, as a table:
minimal reproducer with memory trace
Results:
@DN6 @yiyixuxu Would like a first-pass review before making further changes, to gather feedback on what should be changed. Can add docs and think about how to expose a single API for applying memory optimizations so that we don't confuse users.
cc @bghira (As an admirer and power-user of SimpleTuner, I think some of the latest optimizations will benefit training as well [group offloading with cuda stream prefetching + this]. Would love to see how low we can go when training the biggest available models with negligible impact on speed.)