-
Notifications
You must be signed in to change notification settings - Fork 870
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Distributed inference example #890
base: main
Are you sure you want to change the base?
Conversation
Thanks @angeloskath! This is very timely, as I was looking for such an example for a couple days. |
Amazing, time to buy a second m2 ultra:p |
@angeloskath, please correct me if I am wrong. By looking at the implementation, it seems like we are sharding vertically. For o_proj, we have to wait for all nodes to complete the forward pass before moving on to the next layer. This would create a bottleneck as the slowest node would slow down the entire process. Would it be better to shard by layers instead? Edit: |
I think you might be correct here @mzbac! Tho I would like to also benchmark @angeloskath approach. I have been researching this topic for weeks to support it on FastMLX. And according to the paper I read and Accelerate docs, layer group sharding is the best approach for distributed inference and training. But requires every single node / machine to have quick access to model weights/shard on device. |
They are just different approaches really. Pipelining gives perfect scaling in throughput but not latency. This means that if you are running evaluations or simply running batch generations, then it is perfect. But it will still take the same amount of time to see the first output. Basically for a single generation and assuming the model fits on one device it doesn't provide any speedup. Another way to say it is that the tokens per second per client are not sped up. The aggregate ones scale pretty much perfectly though. The approach in this PR is called model parallelism or tensor parallelism. The goal is to reduce latency as well as throughput. However it depends heavily on the latency of the interconnect. So given ethernet this will probably not achieve speedups (we are looking into it). Indeed, we need to communicate |
@angeloskath, thank you for the detailed explanation. I may try to get another M2 Ultra and test it via the Thunderbolt 4 connection :) |
IMO this is exactly what we need in the long run. In the short term, the hype is around the 400B llama - but that will fade eventually. Latency optimization is what I think fits with the overall MLX ethos. |
I tried clustering one M2 Ultra 192GB with another M2 Ultra 128GB, splitting the weights to 160GB and 67GB (not tensor parallelism) for llama3 405b. I got around 0.3 t/s, but I expected it to be closer to 1 or 2 t/s. I'm not sure if this is related to mlx or some system-level issue. ps: |
Was this over WiFi or thunderbolt 4 @mzbac ? |
TB4, I did run some tests and I feel there may be a memory issue when the memory consumption reaches a certain limit by mlx causes the token per second to slow down to 0.x. I am not exactly sure what the issue is, but sharding across deepseek coder v2 4bit was working fine (60+ vram and up to 1xx ram cache). |
Which OS are you on? A couple things that might help:
The |
Maybe putting more on the 128GB machine will help also. Like 140 and 87 or something. |
@awni Thanks for the pointers. I will try to upgrade macOS, currently, it's on version 14.5. |
Just to share the update, upgrading to macOs 15.0 helped solve the memory issue, and now I am able to run 405B 4-bit around 3.4 t/s - not bad at all. |
Nice!! Did you keep the sharding you had or rebalance it? I wonder if we could make it faster with a more even balance 🤔 . But 3.4 t/s is a great start. Only faster from here 💪 |
I added a bit more weight to the 128GB machine as you suggested in my layer sharding configuration: |
06162b8
to
fbbf173
Compare
any update to speed since? got my hands on two 192gbs and getting ready to run some tests over the weekend |
nothing in the mlx-sharding part. I am still waiting for MLX to support pipeline parallelism in MPI. Once that is supported, there may be some performance improvements compared to using gRPC. |
48d5bf4
to
e648f9a
Compare
0f40077
to
9d7e80b
Compare
LFG 🚀🔥 |
1c2825a
to
a14db45
Compare
a14db45
to
8e3d9f3
Compare
Simply distributed inference on top of ml-explore/mlx#1270 . Again a draft PR so we can iterate on the design. This communication will be very latency bound (probably impractical) so no need to be particularly excited yet.