
Optimizer offloading through weight-only offload #867

Open
hanzhi713 wants to merge 3 commits into main from the weight-only-offload-cleanup branch

Conversation

hanzhi713 (Member):

This PR requires jax >= 0.4.34 and != 0.4.35: it works on jax 0.4.34, is broken on jax 0.4.35 due to a libtpu bug, and worked on nightly jax 0.4.36 as of 10/30.

This PR enables optimizer offloading. The approach used here is weight-only offloading, which builds on the same building blocks as activation offloading (aka remat offload). When offloading is enabled, optimizer states are stored in CPU pinned memory. Before the optimizer is applied to calculate updates, the optimizer states are moved from CPU memory to HBM via jax.device_put. After the update, the new optimizer states are moved back from HBM to CPU.
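
For context, a minimal sketch of this pattern on a recent jax with memory-kind shardings; the single-axis mesh, the replicated shardings, and the stand-in update rule below are illustrative assumptions, not this PR's actual implementation:

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Assumed setup: a single-axis mesh and replicated shardings that differ only
# in memory kind ("device" = HBM, "pinned_host" = CPU pinned memory).
mesh = Mesh(jax.devices(), ("data",))
on_hbm = NamedSharding(mesh, PartitionSpec(), memory_kind="device")
on_host = NamedSharding(mesh, PartitionSpec(), memory_kind="pinned_host")

def offloaded_update(opt_state, grads):
    # Optimizer states live in pinned host memory between steps; bring them
    # into HBM so XLA can compute on them.
    opt_state = jax.device_put(opt_state, on_hbm)
    # Stand-in for the real optimizer transformation (e.g. Adam moment updates).
    new_state = jax.tree.map(lambda s, g: 0.9 * s + 0.1 * g, opt_state, grads)
    # Move the updated states back to pinned host memory to free HBM.
    return jax.device_put(new_state, on_host)
```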

An alternative approach is host computation, in which the optimizer transformations are computed on the CPU: before the computation starts, gradients and weights are transferred to the CPU, and afterwards their new values are transferred back to HBM. This method has a lower HBM footprint, but it is roughly 2x~3x slower due to slow CPU computation. It is also very buggy.

TLDR: to be merged after upgrading jax to 0.4.36.

ruomingp (Contributor) left a comment:

Nice!

Comment on lines 2048 to 2050
Only wrap the optimizer that you actually want to offload with this function to avoid
unnecessary overhead. This is usually the optimizer that occupies the most HBM. For example,
when you have chained optimizers:
ruomingp (Contributor):

Where does the overhead come from? Is it from the states of clip_by_global_norm being offloaded? If so, could we use regular expressions to specify which states to offload?

@hanzhi713 hanzhi713 requested a review from ruomingp December 4, 2024 21:11
ruomingp (Contributor) left a comment:

A question about device_put...

hanzhi713 (Member, Author) replied:

> A question about device_put...

Before the optimizer can be invoked, the offloaded optimizer states need to be transferred to the device memory space. If we remove these device_put calls, we get errors like "xxx is not supported on pinned_host memory space", where xxx is some XLA primitive operation such as add (I forget the exact error message, but it is something like this).
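
To illustrate the point, a hedged sketch (the mesh and shardings are assumptions for illustration, not this PR's code):

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(jax.devices(), ("data",))
host = NamedSharding(mesh, PartitionSpec(), memory_kind="pinned_host")
hbm = NamedSharding(mesh, PartitionSpec(), memory_kind="device")

@jax.jit
def step(state):
    # Without this transfer, XLA would place the `add` below in the pinned_host
    # memory space and fail, with an error along the lines described above.
    state = jax.device_put(state, hbm)
    return state + 1.0

offloaded = jax.device_put(jnp.zeros(8), host)  # state held in CPU pinned memory
updated = step(offloaded)
```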

ruomingp (Contributor) left a comment:

Thanks for the clarification on device_put calls. Could you add a comment on why it's necessary? Also two suggestions...

@hanzhi713 hanzhi713 requested a review from ruomingp December 5, 2024 19:16
@hanzhi713 hanzhi713 force-pushed the weight-only-offload-cleanup branch from 3adee99 to 5ea7bb4 on January 28, 2025 00:13
@hanzhi713 hanzhi713 requested a review from a team as a code owner January 28, 2025 00:13
@hanzhi713 hanzhi713 enabled auto-merge January 28, 2025 00:16
@hanzhi713 hanzhi713 requested a review from ruomingp January 28, 2025 00:19
hanzhi713 (Member, Author):

@markblee Can you take a look at the pytype errors in CI? Should I change Nested type to include tuple/namedtuple or should I just ignore the errors?

markblee (Contributor):

> @markblee Can you take a look at the pytype errors in CI? Should I change Nested type to include tuple/namedtuple or should I just ignore the errors?

Maybe we can relax the partition fn return type to include named tuple? Changing Nested may have undesirable impact elsewhere.

hanzhi713 (Member, Author):

> Maybe we can relax the partition fn return type to include named tuple? Changing Nested may have undesirable impact elsewhere.

Thanks. I included the return type in Nested[...]. Just NamedTuple wouldn't work.

@hanzhi713 hanzhi713 requested a review from markblee January 28, 2025 05:23