
Why use RowwiseParallel for nn.Embedding instead of ColwiseParallel? #785

Open
corey-lambda opened this issue Jan 10, 2025 · 5 comments

corey-lambda commented Jan 10, 2025

Colwise makes the logic a bit more clear. Rowwise splits on the token dimension, leading to confusion about how the different shards handle tokens that are not present within their shard. From a bit of debugging it seems there is a special case for this somewhere deep in the PyTorch source code, but I could not find it.

With colwise, the embedding weight matrix is split on the model dim dimension, so all shards have all the tokens, just different parts of the model dim.

https://github.com/pytorch/torchtitan/blob/main/torchtitan/parallelisms/parallelize_llama.py#L133

    parallelize_module(
        model,
        tp_mesh,
        {
            "tok_embeddings": RowwiseParallel(
                input_layouts=Replicate(),
                output_layouts=Shard(1),
            ),
            # ... remaining module plans elided ...
        },
    )
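
For comparison, a Colwise version of the same mapping might look like the sketch below. This is only my illustration, not code from torchtitan; `model` and `tp_mesh` are as in the snippet above, and using Shard(1) as the output layout is my assumption about how a sequence-sharded output would be requested:

    # Sketch only (not from torchtitan): with ColwiseParallel the embedding
    # weight is split on the model/hidden dim, so every shard holds rows for
    # every token; output_layouts=Shard(1) asks for the output to be sharded
    # on the sequence dim, as in the Rowwise plan above.
    from torch.distributed.tensor import Replicate, Shard
    from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module

    parallelize_module(
        model,
        tp_mesh,
        {
            "tok_embeddings": ColwiseParallel(
                input_layouts=Replicate(),
                output_layouts=Shard(1),
            ),
        },
    )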

Can someone provide some insight?

@tianyu-l tianyu-l added the question Further information is requested label Jan 10, 2025
@tianyu-l tianyu-l self-assigned this Jan 10, 2025
tianyu-l (Contributor) commented Jan 11, 2025

Thanks for the question.

I think both Rowwise and Colwise would work correctly, and it seems their performance in memory and computation time should be similar. It is worth noting that the vocab dimension is often much larger than the hidden dimension, so technically Rowwise has more room to shard, although in practice the TP degree is unlikely to be that large.

The difference should be in communication:

  • With Rowwise, after the layer the embeddings are stored in a _MaskPartial placement and require a reduce-scatter to become sharded on the sequence dimension in Sequence Parallel (or an all-reduce in vanilla Tensor Parallel).
  • With Colwise, the embeddings are sharded on the hidden dim, which requires an all-to-all to become sharded on the sequence dim (or an all-gather in vanilla Tensor Parallel); see the sketch after this list.
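
To make the Colwise side of that redistribution concrete, here is a minimal sketch using the public DTensor API. It is my own illustration, not torchtitan code; it assumes a torchrun launch and toy sizes divisible by the world size:

    # Minimal sketch: go from a hidden-dim-sharded activation, Shard(2), to a
    # sequence-dim-sharded one, Shard(1) -- the step discussed above that
    # needs the all-to-all. Launch with torchrun.
    import torch
    import torch.distributed as dist
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor import Shard, distribute_tensor

    dist.init_process_group("gloo")  # or "nccl" on GPUs
    mesh = init_device_mesh("cpu", (dist.get_world_size(),))

    torch.manual_seed(0)              # same "global" tensor on every rank
    full = torch.randn(2, 8, 16)      # pretend embedding output: [batch, seq, hidden]

    colwise_out = distribute_tensor(full, mesh, [Shard(2)])  # Colwise: hidden dim sharded
    sp_ready = colwise_out.redistribute(mesh, [Shard(1)])    # redistribute to the sequence dim
    print(sp_ready.placements, sp_ready.to_local().shape)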

At the time the code was written, all-to-all for Colwise was not supported yet, and we ended up using Rowwise.

To be honest I haven't tested the perf difference since all-to-all became supported. Are you interested in testing and sharing the results with us, on your hardware and preferred config?

xffxff commented Jan 12, 2025

I tested on 8 H100 GPUs using llama3_8b.toml, setting the TP size to 8 and the batch size to 4.

Here are the results for memory usage and throughput:
[screenshot]

Both Colwise and Rowwise had the same max active memory usage, but Colwise had higher max reserved memory usage, and Colwise also achieved slightly better throughput.

Colwise profile: [screenshot]

Rowwise profile: [screenshot]

It looks like the logic of aten::embedding is more complicated with Rowwise than with Colwise. It might be because the embeddings are sharded on the vocab dim, so each local shard only holds part of the vocabulary, which means we can't just do an index_select. I noticed that aten::lt and aten::ge were involved, which might be used to check whether a token can be looked up in the local embedding shard. Could this be why Colwise is slightly faster?
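
As an illustration only (a toy sketch of the idea, not DTensor's actual _MaskPartial kernel), a vocab-sharded lookup needs those range checks plus a zeroed-out result that is later summed across shards, roughly like this:

    # Toy sketch of a vocab-sharded (Rowwise) embedding lookup. The lt/ge-style
    # range check replaces the plain index_select a replicated embedding would use.
    import torch

    def local_vocab_lookup(tokens, local_weight, vocab_start, vocab_end):
        # Tokens owned by this shard lie in [vocab_start, vocab_end).
        in_range = (tokens >= vocab_start) & (tokens < vocab_end)        # aten::ge / aten::lt
        local_ids = (tokens - vocab_start).clamp(0, local_weight.size(0) - 1)
        out = local_weight[local_ids]   # lookup into the local shard only
        out[~in_range] = 0              # zero rows for tokens this shard doesn't own
        return out  # summing across shards (all-reduce / reduce-scatter) gives the full result

    # Example: vocab of 16 split into 2 shards of 8 rows, hidden dim 4; this is shard 0.
    shard0_weight = torch.randn(8, 4)
    tokens = torch.tensor([[1, 9, 3, 15]])
    print(local_vocab_lookup(tokens, shard0_weight, vocab_start=0, vocab_end=8).shape)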

Also, it looks like the logic of redistribution with Colwise is more complicated than with Rowwise.

tianyu-l (Contributor) commented:

@xffxff Thanks! This looks interesting.

It looks like the logic of aten::embedding is more complicated with Rowwise than with Colwise. It might be because the embeddings are sharded on the vocab dim, so each local shard only holds part of the vocabulary, which means we can't just do an index_select. I noticed that aten::lt and aten::ge were involved, which might be used to check whether a token can be looked up in the local embedding shard.

You are basically right.

Could this be why Colwise is slightly faster?

That could be the reason, but to verify it from the profile trace, we'll need to look at the GPU part (what you presented is the CPU part, but the throughput should be GPU-bound). Note that the communication all-to-all / reduce-scatter should be exposed, so their difference could also contribute to the throughput difference (again let's compare them in the GPU profile).

Also, it looks like the logic of redistribution with Colwise is more complicated than with Rowwise.

The redistribution with Colwise is just a Shard(2) to Shard(1) (hidden dim to sequence dim) redistribution of the underlying DTensor. But it seems to incur multiple contiguous calls. cc: @wanchaol

xffxff commented Jan 13, 2025

let's compare them in the GPU profile.

@tianyu-l Thanks for getting back to me!

Added some screenshots from the GPU part, some observations:

  • The index_select kernel took 0.104 ms in the Colwise profile, but it took 0.788 ms in the Rowwise profile
  • The all-to-all took 7.142 ms in the Colwise profile, but the reduce-scatter in the Rowwise profile only took 1.456 ms.

I'm not sure how to do more analysis, but hope this helps. I’d be happy to hear any insights you might have!

(I tried uploading the trace files, but they are too big -- GitHub limits file sizes to 25 MB.)

Colwise profile with the GPU part: [screenshots]

Rowwise profile with the GPU part: [screenshots]

corey-lambda (Author) commented:

Thanks for all of these wonderful responses!

It looks like the logic of aten::embedding is more complicated with Rowwise than with Colwise. It might be because the embeddings are sharded on the vocab dim, so each local shard only holds part of the vocabulary, which means we can't just do an index_select. I noticed that aten::lt and aten::ge were involved, which might be used to check whether a token can be looked up in the local embedding shard.

This is why I ended up raising the issue. I was trying to explain in a simple way why Rowwise works, and kept coming back to Colwise being so much easier to explain.
