Skip to content

Commit

Permalink
Remove force CPU offload.
Browse files Browse the repository at this point in the history
Signed-off-by: Dennis Liu <[email protected]>
  • Loading branch information
Victarry committed Feb 13, 2025
1 parent 3ffd732 commit 6a2d88a
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions transformer_engine/pytorch/cpu_offload.py
Original file line number Diff line number Diff line change
@@ -381,7 +381,6 @@ def bulk_offload_group(self, group_to_offload):
if self.tensor_need_offloading_checker(tensor_on_device):
state = SynchronizedGroupOffloadHandler.offload(tensor_on_device)
self.tensor_tag_to_state[tensor_tag] = state
tensor_on_device.data = torch.Tensor() # Force to release memory

def synchronize_on_group_commit_forward(self, current_group):
"""Synchronize on group commit forward."""
@@ -517,9 +516,14 @@ def tensor_need_offloading_checker_all(tensor):
"mentioned what to offload (weights/activations)"
)

cpu_offload_handler = AsyncDoubleBufferGroupOffloadHandler(
# cpu_offload_handler = AsyncDoubleBufferGroupOffloadHandler(
# num_offload_group=num_layers,
# num_model_group=model_layers,
# tensor_need_offloading_checker=tensor_need_offloading_checker,
# )

cpu_offload_handler = SynchronizedGroupOffloadHandler(
num_offload_group=num_layers,
num_model_group=model_layers,
tensor_need_offloading_checker=tensor_need_offloading_checker,
)

Expand Down

0 comments on commit 6a2d88a

Please sign in to comment.