
[RFC] Integration of Distributed Inference into TorchChat #1376

Open
mreso opened this issue Nov 14, 2024 · 4 comments
Assignees
Labels
Distributed — Issues related to all things distributed
RFC — Request for Comment
triaged — This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@mreso
Contributor

mreso commented Nov 14, 2024

🚀 The feature, motivation and pitch

Overview
The goal of this RFC is to discuss the integration of distributed inference into TorchChat. Distributed inference leverages tensor parallelism, pipeline parallelism, or a combination of both to support larger model sizes that do not fit on a single accelerator. Through parallelization, each model shard runs in its own worker process. The processes can be spawned either at the script level (e.g. via torchrun) or from within the main script. For online use cases like chat/server, the processes need to coordinate fetching and sharing the user input, depending on where the processes get spawned (a minimal sketch of this coordination follows the design goals below). Synchronization points between the processes should be minimized for optimal performance.
The design goals of the integration are:

  • Support all CLI features of TorchChat (generate, chat, server)
  • Minimize code duplication
  • Maintain TorchChat's copy/pasteability
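
As a concrete illustration of the input coordination mentioned above, here is a minimal sketch assuming a torch.distributed process group has already been initialized; the function name is hypothetical and not part of TorchChat today. Only rank 0 reads the user input and then broadcasts it so every tensor/pipeline-parallel shard works on the same prompt, keeping a single synchronization point per user turn.

```python
# Hypothetical sketch: rank 0 reads the prompt, all ranks return the same
# string, so each shard runs prefill/decode on identical input.
import torch.distributed as dist

def get_user_prompt_on_all_ranks() -> str:
    payload = [None]
    if dist.get_rank() == 0:
        payload[0] = input("User: ")
    # Single synchronization point per user turn; prefill and decode then
    # proceed independently on each worker.
    dist.broadcast_object_list(payload, src=0)
    return payload[0]
```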

Alternatives

Option 1: Integrate at Model Level
While the usage of a tensor parallel model in PyTorch is largely transparent, the current pipeline parallel API differs significantly from the usage of a local model. This option hides the distributed inference from the Generator class by introducing the distributed inference inside a torchchat.model.Model derivative. The DistributedModel(torchchat.model.Model) class would implement methods like call() and forward() and handle distribution to the worker processes internally (a rough sketch follows the pros/cons below).

  • Pros:
    • Code reuse high
    • Transparent use of distributed model
    • Virtually no changes in main Generator and OpenAiApiGenerator necessary
  • Cons:
    • In this scenario, sampling happens in the main script, so the return value of the model (logits) needs to be transferred between processes (i.e. moved to shared GPU memory)
    • As the Generator is unaware of the parallelism, the subprocesses would need to be spawned inside the model itself, which is kind of ugly
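
A rough sketch of Option 1 follows; all class structure, method names, and signatures below are illustrative assumptions, not an existing TorchChat API beyond torchchat.model.Model itself.

```python
# Illustrative sketch only: the Model subclass owns the shard workers and
# hides them from the Generator, which only ever sees logits coming back.
import torch
from torchchat.model import Model  # base class named in this RFC

class DistributedModel(Model):
    """Hypothetical wrapper; the base Model constructor details are elided."""

    def setup_distributed(self, world_size: int) -> None:
        # Spawning the shard workers inside the model is the part the cons
        # above call "kind of ugly".
        self.workers = self._spawn_workers(world_size)   # hypothetical helper

    def forward(self, tokens: torch.Tensor, input_pos=None) -> torch.Tensor:
        # Ship inputs to the shard workers, run the pipeline stages, and move
        # the resulting logits back so sampling can stay in the main script
        # (the cross-process transfer called out in the cons).
        self._dispatch_to_workers(tokens, input_pos)      # hypothetical helper
        return self._collect_logits()                     # hypothetical helper
```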

Option 2: Abstract Base Class for Generator
Introduce a base class Generator which contains the common portions of the generation process, like getting and preparing input from the user. LocalGenerator and DistributedGenerator get introduced to handle the specifics. The split between base and derivatives can be made at multiple levels, specifically High: Generator.generate, Mid: Generator.decode_n_tokens/prefill, Low: Generator.decode_one_token/prefill (a sketch of one possible split follows the pros/cons below).

  • Pros:
    • Introduces abstraction in the generation process
    • High code reuse
    • Subprocess creation for parallel workers can be on main script level
    • Added complexity stays mostly separate from local generation
  • Cons:
    • Splitting the Generator out of the main generate.py file will hurt copy/pasteability
    • OpenAiApiGenerator (currently inherits from Generator) will require additional changes to work with distributed inference
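
One possible shape of Option 2, split at the "Low" level; the method names mirror the ones listed above, but the hierarchy itself is only an illustrative assumption.

```python
# Illustrative hierarchy: the base class owns the generation loop, the
# derivatives own how a single prefill/decode step is executed.
from abc import ABC, abstractmethod

class GeneratorBase(ABC):
    """Shared loop: prompt handling, decode loop, detokenization."""

    def generate(self, prompt_tokens, max_new_tokens: int):
        cur_token = self.prefill(prompt_tokens)
        for _ in range(max_new_tokens):
            cur_token = self.decode_one_token(cur_token)
            yield cur_token

    @abstractmethod
    def prefill(self, prompt_tokens):
        """Run the prompt through the model and return the first new token."""

    @abstractmethod
    def decode_one_token(self, cur_token):
        """Produce the next token from the previous one (KV cache assumed)."""

class LocalGenerator(GeneratorBase):
    ...  # runs the model in-process, as generate.py does today

class DistributedGenerator(GeneratorBase):
    ...  # coordinates the tensor/pipeline-parallel workers for each step
```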

Option 2b: Integrate at Low Level of Generator without base class
This approach skips the creation of a base class, directly inherits DistributedGenerator(Generator), and adds functionality for distributed inference in the main generate.py file (see the sketch after the pros/cons below).

  • Pros:
    • Fully reuses the functionality from existing Generator
    • Subprocess creation for parallel workers can be on main script level
    • Maintains copy/pasteability
  • Cons:
    • Some changes necessary in generate.py
    • OpenAiApiGenerator (inherits from Generator) will require additional changes to work with distributed inference
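
A minimal sketch of Option 2b; the import path, constructor arguments, and method signature are assumptions for illustration rather than the exact torchchat API.

```python
# Illustrative only: DistributedGenerator subclasses the existing Generator
# and overrides just the pieces that have to talk to the parallel workers.
from torchchat.generate import Generator  # assumed location of the class

class DistributedGenerator(Generator):
    def __init__(self, *args, rank: int = 0, world_size: int = 1, **kwargs):
        super().__init__(*args, **kwargs)
        self.rank = rank
        self.world_size = world_size

    def decode_one_token(self, *args, **kwargs):
        # Same contract as the parent method; the distributed version would
        # additionally drive the pipeline-parallel schedule around this call.
        return super().decode_one_token(*args, **kwargs)
```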

cc @Jack-Khuu @byjlw @lessw2020

Additional context

No response

RFC (Optional)

No response

@mreso mreso self-assigned this Nov 14, 2024
@mreso mreso added the RFC Request for Comment label Nov 14, 2024
@Jack-Khuu Jack-Khuu added the Distributed Issues related to all things distributed label Nov 16, 2024
@Jack-Khuu
Contributor

Thanks for spinning this up!! Some initial thoughts (some of which we've chatted about offline, but resharing)

For Option 1, while not requiring changes to the Generator is really tempting, making the Model instance manage distribution/subprocesses itself is a curious pattern. My gut says that this might come back to bite us, either with network costs or with the complexity of process management being buried too deep (not at the script level).

For Option 2, my main reservation with regard to a Generator base class (or classes) with code in it is that the hierarchy makes it harder to pull code snippets out of the repo. With an abstract Generator base class, we keep the "copy/pasteability", but we're not really reusing code at that point.

For Option 2b, it requires the least amount of refactoring (though we may refactor in H1, so refactoring isn't inherently bad), and the distributed logic is decently colocated, making it easier to read/learn/copy.

@byjlw
Contributor

byjlw commented Nov 19, 2024

I put this in the Slack channel but am also posting it here.

I wanted to show something I've been discussing with Jack. This is the direction we want to go. Specific details and what goes exactly where could change, but wanted to open the discussion. I think this will help you make an informed decision about how to integrate distributed with the current server and CLI.
We'll end up with three-ish modules/pip packages that allow for different levels of abstraction based on the use case.
The question is where distributed lives. I think it should probably be part of torchchat-core, or it could be another module, torchchat-distributed, that depends on core. The real tradeoff is dependencies and size: if the size is small and there are few additional dependencies, then core probably makes the most sense.

[Screenshot: proposed module/package structure (Screenshot 2024-11-19 at 1.54.31 PM.png); the image upload did not complete]

@byjlw
Contributor

byjlw commented Nov 19, 2024

I think option 2b makes the most sense right now, given we have refactoring coming down the line.
Right now the API/server isn't particularly clean, since things are duplicated between generate and the API. In the short/medium term we want the CLI and API to make the exact same calls to the generate class, with all the logic living inside that class.

@mreso
Contributor Author

mreso commented Nov 20, 2024

Thanks @byjlw. The implementation of 2b lives in #1382.

@Jack-Khuu Jack-Khuu added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Dec 17, 2024