Break down parallelize_llama for inference cases #402
Conversation
ghstack-source-id: d8a32ad293ce8f1fafa141e3bbfa06654db75910 Pull Request resolved: #402
Thanks for the change -- making them more modular is great!!
I had some comments -- basically I think we should make each sub-function more modular, passing in only the related arguments and configs, and leaving experimental configs and interacting flags in `parallelize_llama`.
Breaking up `parallelize_llama` into:
- `apply_tp`
- `apply_ac`
- `apply_compile`
- `apply_dp`

This is for functionality reuse in inference cases, because one would not need activation checkpointing or DP there. Can also improve code modularity and readability.
ghstack-source-id: 9aeee4c063c63eed380cac219c9c8e1eb4169f9d Pull Request resolved: #402
ghstack-source-id: 72e37e2e506af6115f9cb18179543dd6df602961 Pull Request resolved: #402
LGTM. Thanks for helping make it modularized!
```python
if job_config.model.norm_type == "fused_rmsnorm":
    raise NotImplementedError(
        "fused_rmsnorm not yet compatible with TP. Please use layernorm or rmsnorm."
    )
```
This can be removed thanks to #404.
Yeah, noticed in CI. Removed.
ghstack-source-id: fc8e221b5047337f59dea31f2c51d6168fe4fe88 Pull Request resolved: #402
Stack from ghstack (oldest at bottom):

Breaking up `parallelize_llama` into:
- `apply_tp`
- `apply_ac`
- `apply_compile`
- `apply_dp`

This is for functionality reuse in inference cases, because one would not need activation checkpointing or DP there. Can also improve code modularity and readability.
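The composition described above can be sketched as follows. The helper names (`apply_tp`, `apply_ac`, `apply_compile`, `apply_dp`) come from this PR; the function bodies, signatures, and the `training` flag below are illustrative placeholders, not torchtitan's actual implementation.

```python
# Hypothetical sketch of the modular structure this PR introduces.
# Each apply_* pass is a stub that records its application; in the real
# code these would transform the model with DTensor/TP, activation
# checkpointing, torch.compile, and FSDP respectively.

def apply_tp(model, config):
    # Tensor parallelism: needed for both training and inference.
    model.setdefault("applied", []).append("tp")
    return model

def apply_ac(model, config):
    # Activation checkpointing: a training-only memory optimization.
    model.setdefault("applied", []).append("ac")
    return model

def apply_compile(model, config):
    # Compilation of transformer blocks (e.g. torch.compile).
    model.setdefault("applied", []).append("compile")
    return model

def apply_dp(model, config):
    # Data parallelism (e.g. FSDP): training-only.
    model.setdefault("applied", []).append("dp")
    return model

def parallelize_llama(model, config, training=True):
    # Training composes all four passes; inference reuses only the
    # passes it needs, skipping activation checkpointing and DP.
    model = apply_tp(model, config)
    if training:
        model = apply_ac(model, config)
    model = apply_compile(model, config)
    if training:
        model = apply_dp(model, config)
    return model

train_model = parallelize_llama({}, {}, training=True)
infer_model = parallelize_llama({}, {}, training=False)
print(train_model["applied"])  # ['tp', 'ac', 'compile', 'dp']
print(infer_model["applied"])  # ['tp', 'compile']
```

With this split, an inference entry point can call `apply_tp` and `apply_compile` directly instead of going through the monolithic training path.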