
Add APIs to offload states of model, optimizer, and engine #6011

Merged: 28 commits merged into master on Sep 27, 2024

Conversation

tohtana
Contributor

@tohtana tohtana commented Aug 16, 2024

This PR adds the following APIs to offload model, optimizer, and engine states.

def offload_states(self,
                   include: Container[OffloadStateTypeEnum] = None,
                   device: OffloadDeviceEnum = OffloadDeviceEnum.cpu,
                   pin_memory: bool = True,
                   non_blocking: bool = False) -> None:
    """Move the ZeRO optimizer buffers to the specified device.

    Arguments:
        include: Optional. The set of states to offload. If not provided, all states are offloaded.
        device: Optional. The device to move the ZeRO optimizer buffers to.
        pin_memory: Optional. Whether to pin the memory of the offloaded states.
        non_blocking: Optional. Whether to offload the states asynchronously.
    """
    ...
def offload_states_back(self, non_blocking: bool = False) -> None:

Here is the typical usage.

# Offload after forward, backward, and step
model.offload_states()
# Do something requiring a lot of device memory
...
# Load states back to device memory
model.offload_states_back()
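As a minimal end-to-end sketch of where these calls sit in a training loop (the tiny model, config, and dummy loss below are placeholders of mine, not part of this PR; a ZeRO stage 3 config is assumed, matching the ZeRO-3 code paths these APIs touch):

```python
import deepspeed
import torch

# Hypothetical minimal config; values here are only for illustration.
ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "optimizer": {"type": "Adam", "params": {"lr": 1e-5}},
    "zero_optimization": {"stage": 3},
    "bf16": {"enabled": True},
}

model = torch.nn.Linear(4096, 4096)  # stand-in for a real model
engine, _, _, _ = deepspeed.initialize(model=model,
                                       model_parameters=model.parameters(),
                                       config=ds_config)

for _ in range(10):
    x = torch.randn(1, 4096, dtype=torch.bfloat16, device=engine.device)
    loss = engine(x).float().pow(2).mean()  # dummy loss
    engine.backward(loss)
    engine.step()

    # Offload after forward/backward/step, run something memory-hungry
    # (e.g. generation with a second model), then load the states back.
    engine.offload_states()
    ...  # device memory is available for other work here
    engine.offload_states_back()
```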

You can selectively offload states to balance the offloading overhead and memory saving.

model.offload_states(include=set([OffloadStateTypeEnum.hp_params, OffloadStateTypeEnum.opt_states]), device=OffloadDeviceEnum.cpu)
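A sketch of the same call with asynchronous copies (the import paths are my assumption; the enum names follow the snippet above):

```python
from deepspeed.accelerator import get_accelerator
from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum, OffloadStateTypeEnum

# Offload only the FP32 params and optimizer states; leave the low-precision
# params/grads on the device to keep the reload cheap.
model.offload_states(include={OffloadStateTypeEnum.hp_params,
                              OffloadStateTypeEnum.opt_states},
                     device=OffloadDeviceEnum.cpu,
                     pin_memory=True,
                     non_blocking=True)

# With non_blocking=True the copies may still be in flight; synchronize
# before treating the freed device memory as available.
get_accelerator().synchronize()
```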

Performance (4.3B parameters / 4x A100)

  • Environment (4x A100, benchmark script)
    • Average device-to-host transfer bandwidth: 2.45 GB/s, aggregated: 9.79 GB/s
    • Average host-to-device transfer bandwidth: 11.05 GB/s, aggregated: 44.19 GB/s
  • Memory (allocated by PyTorch)
    • Before offloading: 18.2 GB
    • After offloading: 17.7 MB
  • Time (benchmark script; each cell shows offloading time / loading time)

python output_table.py

| # | pin_memory=0 non_blocking=0 | pin_memory=0 non_blocking=1 | pin_memory=1 non_blocking=0 | pin_memory=1 non_blocking=1 |
|---|---|---|---|---|
| 1 | 4.34 / 3.42 | 4.99 / 2.37 | 6.5 / 2.42 | 6.0 / 2.39 |
| 2 | 9.9 / 3.28 | 5.1 / 2.34 | 6.21 / 2.42 | 6.25 / 2.45 |
| 3 | 9.92 / 3.19 | 6.71 / 2.35 | 6.33 / 2.38 | 5.93 / 2.42 |
| 4 | 9.55 / 2.82 | 7.11 / 2.39 | 6.9 / 2.38 | 6.5 / 2.43 |
| 5 | 4.4 / 3.35 | 6.04 / 2.41 | 6.26 / 2.41 | 6.32 / 2.47 |
| 6 | 4.4 / 3.57 | 6.58 / 2.42 | 6.88 / 2.4 | 6.35 / 2.43 |
| 7 | 9.51 / 3.12 | 6.9 / 2.39 | 6.9 / 2.39 | 6.46 / 2.4 |
| 8 | 4.77 / 3.64 | 6.69 / 2.39 | 7.39 / 2.42 | 6.56 / 2.46 |
| 9 | 9.5 / 3.07 | 7.18 / 2.42 | 6.67 / 2.39 | 7.38 / 2.46 |
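The numbers above come from the linked benchmark script; below is only a minimal sketch of how similar measurements could be taken (the accelerator calls and the `model` handle are assumptions on my part, not the benchmark script itself):

```python
import time

from deepspeed.accelerator import get_accelerator

def timed(fn):
    """Run fn and return elapsed wall-clock seconds, synchronizing around it."""
    get_accelerator().synchronize()
    start = time.time()
    fn()
    get_accelerator().synchronize()
    return time.time() - start

alloc_before = get_accelerator().memory_allocated()
offload_sec = timed(lambda: model.offload_states(pin_memory=True, non_blocking=True))
alloc_after = get_accelerator().memory_allocated()
reload_sec = timed(lambda: model.offload_states_back(non_blocking=True))

print(f"allocated: {alloc_before / 2**30:.1f} GB -> {alloc_after / 2**20:.1f} MB, "
      f"offload {offload_sec:.2f}s / load {reload_sec:.2f}s")
```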

TODO:

  • Enable offloading to NVMe storage -> NVMe support is non-trivial; I suggest adding it in a separate PR
  • [DONE] Discard buffer (and recreate it) instead of offloading. We don't need to restore the contiguous buffer for reduce.
  • [DONE] Check pin_memory improves performance or not

@tohtana
Contributor Author

tohtana commented Sep 4, 2024

@tjruwase Added the document.

@kfertakis

kfertakis commented Sep 12, 2024

Hi @tohtana ,

Thank you for your work. I've been trying the new APIs to test model offloading in a multi-model deployment (e.g., deepspeed-chat) as part of #5620. The API initially works in offloading a model and reducing GPU memory, but after bringing the model back and completing the first training iteration (i.e., after the optimiser states have been updated), I get a RuntimeError: param {} still in flight exception when trying to offload the model again. I wanted to ask whether this comes from a misuse of the API on my end, or whether you could provide some further context. I'm providing the relevant stack trace below. Thank you again.

[rank0]: Traceback (most recent call last):
[rank0]:   File "training_script.py", line 173, in gen_function
[rank0]:     self.model_engine.offload_states()
[rank0]:   File "/home/user/DeepSpeed/deepspeed/runtime/engine.py", line 3710, in offload_states
[rank0]:     self.optimizer.offload_states(include=include, device=device, pin_memory=pin_memory, non_blocking=non_blocking)
[rank0]:   File "/home/user/DeepSpeed/deepspeed/runtime/zero/stage3.py", line 2794, in offload_states
[rank0]:     self.empty_partition_cache()
[rank0]:   File "/home/user/DeepSpeed/deepspeed/runtime/zero/stage3.py", line 2785, in empty_partition_cache
[rank0]:     self.parameter_offload.empty_partition_cache()
[rank0]:   File "/home/user/DeepSpeed/deepspeed/runtime/zero/parameter_offload.py", line 181, in empty_partition_cache
[rank0]:     self.partition_all_parameters()
[rank0]:   File "/home/user/DeepSpeed/deepspeed/utils/nvtx.py", line 18, in wrapped_fn
[rank0]:     ret_val = func(*args, **kwargs)
[rank0]:   File "/home/user/DeepSpeed/deepspeed/runtime/zero/parameter_offload.py", line 159, in partition_all_parameters
[rank0]:     self.get_param_coordinator(training=self.module.training).release_and_reset_all(self.module)
[rank0]:   File "/home/user/DeepSpeed/deepspeed/utils/nvtx.py", line 18, in wrapped_fn
[rank0]:     ret_val = func(*args, **kwargs)
[rank0]:   File "/home/user/venv/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/home/user/DeepSpeed/deepspeed/runtime/zero/partitioned_param_coordinator.py", line 412, in release_and_reset_all
[rank0]:     raise RuntimeError(f"param {param.ds_summary()} still in flight")
[rank0]: RuntimeError: param {'id': 1, 'status': 'INFLIGHT', 'numel': 4198400, 'ds_numel': 4198400, 'shape': (2050, 2048), 'ds_shape': (2050, 2048), 'requires_grad': True, 'grad_shape': None, 'persist': False, 'active_sub_modules': set(), 'ds_tensor.shape': torch.Size([4198400])} still in flight

@tohtana
Contributor Author

tohtana commented Sep 12, 2024

Thank you for reporting, @kfertakis!

I have an example script showing the usage of the APIs. Can you try this?
I suspect that ZeRO3 fails to clean up the partitioning status for some models. I would like to clarify whether your issue is model-specific or not.

@kfertakis

So I tested the issue again with various models, and it seems the problem is related to model size: it does not occur for smaller models (i.e., <= 1B params, e.g., gpt2, gpt2-medium) but it does for bigger ones (e.g., OPT-1.3B, Mistral-7B). Is there anything I could do to investigate and debug it further? By the way, I should mention that I'm testing this in a single-node, single-GPU configuration (i.e., a single worker), so ZeRO3 should not have to partition data across other workers. I will also test the benchmark you referenced with an artificially larger model size.

Thanks again.

@tohtana
Contributor Author

tohtana commented Sep 17, 2024

Hi @kfertakis, I tried this example with a 4B model but it worked. Can you try this in your environment?
It would also be great if you could offer us a simple repro.

@tjruwase
Contributor

> …in flight exception when trying to offload the model again. I thus wanted to ask whether you think this has something to do with a misuse of the API from my end or if you could provide some further context. I'm providing the relevant stack trace below: Thank you again.

@tohtana, I wonder if it is useful to expose validate_device() functionality as a deepspeed utility, so that clients can check/confirm the offload status at arbitrary points in their code?

def validate_device(model, device: torch.device, include) -> None:

Similar to how see_memory_usage enables inspection of HBM/DRAM usage at any point, we could provide mechanisms for inspecting offload status. Perhaps we need something like see_offload_status that displays the mapping of params, grads, and optimizer states to {HBM, DRAM, NVMe}.

@kfertakis, I would love to get your thoughts as well on whether any of the above would be useful. Thanks!
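A rough sketch of what such a check could look like (the ds_tensor handling is an assumption of mine; a real utility would likely also need to walk the ZeRO optimizer's internal buffers):

```python
import torch

def validate_device(model: torch.nn.Module, device: torch.device, include=None) -> None:
    """Hypothetical check that offloadable parameter states reside on `device`."""
    # `include` is accepted for signature parity but ignored in this sketch.
    for name, param in model.named_parameters():
        # Under ZeRO-3 the partitioned shard is kept in `param.ds_tensor`.
        tensor = param.ds_tensor if hasattr(param, "ds_tensor") else param
        if tensor.device.type != device.type:
            raise RuntimeError(f"{name} is on {tensor.device}, expected {device}")
```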

@kfertakis

Hey, thanks for the comments.

@tohtana, I've tried the example you provided and it does seem to work, so I'm sharing a fork of the DeepSpeed-Examples repo to showcase the problem. I've modified the DeepSpeed-Chat code to use offload_states. After you prepare an environment with the right deepspeed version for the new API and also install DeepSpeed-Chat, you can run the following:

deepspeed --num_gpus=1 ./applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py --actor_model_name_or_path facebook/opt-1.3b --critic_model_name_or_path facebook/opt-1.3b --actor_zero_stage 3 --critic_zero_stage 3 --num_padding_at_beginning 1 --data_path Dahoas/rm-static --per_device_generation_batch_size 2 --per_device_training_batch_size 2 --generation_batches 1 --ppo_epochs 1 --max_answer_seq_len 512 --max_prompt_seq_len 512 --gradient_accumulation_steps 1 --actor_dropout 0.0 --deepspeed --dtype bf16 --enable_hybrid_engine --offload_test

This should lead to the RuntimeError: param {} still in flight that I mentioned. Any thoughts on this would be much appreciated.

@tjruwase thanks for the reference. Current problem aside, I can see how such helper functions could be useful in the future for ensuring consistency. Thanks.

@tohtana
Contributor Author

tohtana commented Sep 21, 2024

Hi @kfertakis, thank you for sharing the repro. It seems that the actual issue is related to ZeRO3's prefetching.

I opened #6557 as a workaround to address this issue. Can you try the branch tohtana/clean_up_prefetch_param? It also includes the offloading APIs. You can just switch to it.

@kfertakis

Hi @tohtana,

Thank you for your work. I tried your branch and the issue seems to be fixed. I will continue testing and will raise any new issues, but for now the offload_states API seems to be working as expected. Many thanks.

@kfertakis

I also wanted to ask whether the offloading functionality could be extended in the future to support the DeepSpeedCPUAdam optimizer, besides FusedAdam, so that a model whose optimizer is already offloaded to the CPU can also be offloaded. Thank you

@tohtana
Contributor Author

tohtana commented Sep 27, 2024

> I wonder if it is useful to expose validate_device() functionality as a deepspeed utility, so that clients can check/confirm the offload status at arbitrary points in their code?

@tjruwase Let me address this in another PR after this one is merged.

@tohtana tohtana added this pull request to the merge queue Sep 27, 2024
@tohtana
Contributor Author

tohtana commented Sep 27, 2024

Thank you @kfertakis for validating the fix.

> I also wanted to ask whether the offloading functionality could be extended in the future to support the DeepSpeedCPUAdam optimizer, besides FusedAdam, so that a model whose optimizer is already offloaded to the CPU can also be offloaded. Thank you

Let me consider how to do this. Please feel free to open a new issue to track it as I am going to merge this PR first.

Merged via the queue into master with commit 047bcf6 Sep 27, 2024
14 checks passed
@kfertakis

Thank you @tohtana for completing and merging the feature. I've opened two additional requests, #6595 and #6596, to track the relevant extensions we discussed above. Thanks.
