Optimizer offloading through weight-only offload #867
base: main
Conversation
Nice!
axlearn/common/optimizers.py (outdated)
Only wrap the optimizer that you actually want to offload with this function, to avoid unnecessary overhead. This is usually the optimizer that occupies the most HBM. For example, when you have chained optimizers:
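For illustration, a hedged sketch of what such a chained setup might look like. All names here are assumptions: `offload_optimizer` stands in for the wrapper added by this PR, and the config constructors may not match the actual axlearn API.

```python
# Hypothetical usage sketch; these names are assumptions and may not match
# the axlearn config API exactly.
from axlearn.common.optimizers import adamw_optimizer, chain, clip_by_global_norm

optimizer = chain(
    clip_by_global_norm(max_norm=1.0),  # negligible state; keep on HBM
    offload_optimizer(                  # Adam moments dominate HBM usage
        adamw_optimizer(learning_rate=1e-3, b1=0.9, b2=0.95, weight_decay=1e-4)
    ),
)
```

Wrapping only the inner AdamW keeps the small clipping state on device, so only the large first/second-moment buffers pay the transfer cost.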
Where does the overhead come from? Is it from the states of clip_by_global_norm
being offloaded? If so, could we use regular expressions to specify which states to offload?
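A purely hypothetical sketch of the regex suggestion above (no such helper or API exists in this PR; it only illustrates the idea):

```python
import re

# Hypothetical: decide per state path whether to offload, so that e.g. only
# Adam's mu/nu moments go to host memory while small states such as
# clip_by_global_norm counters stay on device.
def should_offload(state_path: str, offload_regex: str = r".*/(mu|nu)") -> bool:
    return re.fullmatch(offload_regex, state_path) is not None
```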
A question about device_put...

Before the optimizer can be invoked, the offloaded optimizer states need to be transferred to device memory space. If we remove these device_put calls, we will get errors.
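A minimal sketch of the transfer pattern being described, using jax memory kinds (this assumes a jax version with memory-kind support, per the version constraints below; the shape and single-device sharding are illustrative only):

```python
import jax
import jax.numpy as jnp

dev = jax.devices()[0]
pinned_host = jax.sharding.SingleDeviceSharding(dev, memory_kind="pinned_host")
hbm = jax.sharding.SingleDeviceSharding(dev, memory_kind="device")

# Optimizer state lives in CPU pinned memory between steps.
state = jax.device_put(jnp.zeros((1024, 1024)), pinned_host)

state_on_hbm = jax.device_put(state, hbm)  # host -> HBM before the update
# ... compute the optimizer update against state_on_hbm ...
state = jax.device_put(state_on_hbm, pinned_host)  # HBM -> host afterwards
```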
Thanks for the clarification on the device_put calls. Could you add a comment on why it's necessary? Also, two suggestions...
Force-pushed from 3adee99 to 5ea7bb4
@markblee Can you take a look at the pytype errors in CI? Should I change the Nested type to include tuple/namedtuple, or should I just ignore the errors?
Maybe we can relax the partition fn return type to include named tuples? Changing Nested may have undesirable impact elsewhere.
Thanks. I included the return type in Nested[...]. Just NamedTuple wouldn't work.
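For context, a hedged sketch of the kind of typing change being discussed, against a simplified stand-in for Nested (axlearn's actual definition differs):

```python
from typing import Mapping, Tuple, TypeVar, Union

T = TypeVar("T")
# Simplified illustration: widen the union behind Nested with tuples
# (NamedTuple instances are tuples) so partition fns may return them.
Nested = Union[T, Tuple["Nested[T]", ...], Mapping[str, "Nested[T]"]]
```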
This PR requires jax >= 0.4.34 and != 0.4.35: it works on jax 0.4.34, but is broken on jax 0.4.35 due to a libtpu bug. It also worked on nightly jax 0.4.36 as of 10/30.
This PR represents an effort to enable optimizer offloading. The approach used here is weight-only offloading, built on the same building blocks as activation offloading (aka remat offload). When offloading is enabled, optimizer states are stored in CPU pinned memory. Before the optimizer is applied to compute updates, its states are moved from CPU memory to HBM via jax.device_put, and the new optimizer states are moved back from HBM to CPU afterwards.

An alternative to this PR is host computation, meaning that the optimizer transformations themselves are computed on CPU: gradients and weights are transferred to CPU before the computation, and their new values are transferred back to HBM afterwards. That method has a lower HBM footprint, but it is 2x~3x slower due to slow CPU computation, and it is also quite buggy.
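To make the approach concrete, a minimal sketch in plain optax terms, not the PR's actual implementation (axlearn uses its own partitioned gradient transformations); `pinned_host` and `hbm` are shardings with the corresponding memory kinds, as in the device_put sketch above:

```python
import jax
import optax

def offload_optimizer(inner: optax.GradientTransformation,
                      pinned_host, hbm) -> optax.GradientTransformation:
    """Keeps `inner`'s states in host memory between steps (illustrative)."""

    def init(params):
        # Initialize on device, then park the states in CPU pinned memory.
        return jax.device_put(inner.init(params), pinned_host)

    def update(updates, state, params=None):
        state = jax.device_put(state, hbm)  # host -> HBM before the update
        updates, state = inner.update(updates, state, params)
        return updates, jax.device_put(state, pinned_host)  # HBM -> host

    return optax.GradientTransformation(init, update)
```

Under this scheme only the transfers leave HBM, not the optimizer math itself, which is why it avoids the CPU-compute slowdown of the host-computation alternative.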
TLDR: to be merged after upgrading jax to 0.4.36.