Why use RowwiseParallel for nn.Embedding instead of ColwiseParallel? #785
Thanks for the question. I think both Rowwise and Colwise would work correctly, and their memory and compute costs should be similar. It is worth noting that the vocab dimension is often much larger than the hidden dimension, so technically Rowwise has more room to shard, although in practice the TP degree is usually not that large. The difference should be in communication:
At the time the code was written, all-to-all for Colwise was not supported yet, so we ended up using Rowwise. To be honest, I haven't tested the perf difference since all-to-all is supported now. Are you interested in testing it and sharing the results with us, on your hardware and preferred config?
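For concreteness, here is a minimal sketch (not the torchtitan code verbatim) of applying the two styles to an `nn.Embedding` with the PyTorch tensor-parallel API. The mesh size, model dimensions, and output layouts are illustrative assumptions; it also assumes a recent PyTorch where `torch.distributed.tensor` is public and a process group initialized via torchrun.

```python
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Replicate, Shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

tp_mesh = init_device_mesh("cuda", (8,))  # 8-way tensor parallelism (assumed)
embedding = nn.Embedding(128_256, 4096)   # vocab dim >> hidden dim

# Rowwise: weight sharded on the vocab dim (dim 0). Each rank looks up only the
# token ids in its vocab shard; the partial results are reduced across ranks.
rowwise = parallelize_module(
    embedding,
    tp_mesh,
    RowwiseParallel(input_layouts=Replicate(), output_layouts=Shard(1)),
)

# Colwise: weight sharded on the hidden dim (dim 1). Every rank sees all token
# ids and produces its own slice of the hidden dim.
# colwise = parallelize_module(
#     nn.Embedding(128_256, 4096),
#     tp_mesh,
#     ColwiseParallel(input_layouts=Replicate(), output_layouts=Shard(1)),
# )
```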
@xffxff Thanks! This looks interesting.
You are basically right.
That could be the reason, but to verify it from the profile trace, we'll need to look at the GPU part (what you presented is the CPU part, but throughput should be GPU-bound). Note that the all-to-all / reduce-scatter communication should be exposed, so their difference could also contribute to the throughput difference (again, let's compare them in the GPU profile).
The redistribution with Colwise is just an all-to-all.
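A hedged sketch of the two output redistributions being compared, using the DTensor `redistribute` API: the Rowwise output is a partial sum, so moving it to a sequence-sharded layout is a reduce-scatter, while the Colwise output is sharded on the hidden dim, so the same move is an all-to-all. Shapes, mesh size, and device are assumptions, and a process group is assumed to be initialized.

```python
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Partial, Shard

tp_mesh = init_device_mesh("cuda", (8,))
tp = tp_mesh.size()
bsz, seq, dim = 4, 2048, 4096

# Rowwise embedding output: each rank holds a partial sum over its vocab shard.
# Redistributing to a sequence-sharded layout (Shard(1)) is a reduce-scatter.
partial_local = torch.randn(bsz, seq, dim, device="cuda")
rowwise_out = DTensor.from_local(partial_local, tp_mesh, [Partial()])
rowwise_sp = rowwise_out.redistribute(tp_mesh, [Shard(1)])  # reduce-scatter

# Colwise embedding output: each rank holds its slice of the hidden dim.
# Redistributing to Shard(1) changes the shard dim without a reduction,
# which is an all-to-all.
colwise_local = torch.randn(bsz, seq, dim // tp, device="cuda")
colwise_out = DTensor.from_local(colwise_local, tp_mesh, [Shard(2)])
colwise_sp = colwise_out.redistribute(tp_mesh, [Shard(1)])  # all-to-all
```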
@tianyu-l Thanks for getting back to me! I've added some screenshots from the GPU part, along with a few observations:
I'm not sure how to do more analysis, but I hope this helps. I'd be happy to hear any insights you might have! (I tried uploading the trace files, but they are too big; GitHub limits attachments to 25 MB.)
Thanks for all of these wonderful responses!
This is why I ended up raising the issue. I was trying to explain, in a simple way, why rowwise works, and I kept coming back to colwise being so much easier to explain.
Colwise makes the logic a bit clearer. Rowwise splits on the token (vocab) dimension, which leads to confusion about how the different shards handle tokens that are not present in their shard. From a bit of debugging it seems there is a special case for this somewhere deep in the PyTorch source code, but I could not find it.
With colwise, the embedding weight matrix is split along the model dimension, so all shards have all the tokens, just different slices of the model dim.
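For illustration, here is a minimal sketch of how a rowwise-sharded embedding can handle token ids that fall outside the local vocab shard (the "special case" suspected above): mask them locally and let the cross-rank reduction fill them in. The function name, masking strategy, and shard layout are assumptions for illustration, not the actual PyTorch implementation.

```python
import torch
import torch.nn.functional as F

def rowwise_embedding_local(token_ids, local_weight, rank, shard_size):
    # This rank owns vocab rows [rank * shard_size, (rank + 1) * shard_size).
    local_ids = token_ids - rank * shard_size
    in_shard = (local_ids >= 0) & (local_ids < shard_size)
    # Out-of-shard ids are clamped to a valid row, and their rows are zeroed,
    # so they contribute nothing to the cross-rank sum.
    local_ids = local_ids.clamp(0, shard_size - 1)
    out = F.embedding(local_ids, local_weight)
    out = out * in_shard.unsqueeze(-1)
    # `out` is a partial result; summing it across TP ranks (all-reduce, or
    # reduce-scatter for a sequence-sharded output) recovers the full embedding.
    return out
```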
https://github.com/pytorch/torchtitan/blob/main/torchtitan/parallelisms/parallelize_llama.py#L133
Can someone provide some insight?