
Commit

support non-pp
Meiyim committed Jan 23, 2025
1 parent 5f7a282 commit 94246f6
Showing 2 changed files with 42 additions and 19 deletions.
52 changes: 33 additions & 19 deletions paddlenlp/trainer/trainer.py
@@ -42,6 +42,7 @@
 import paddle.nn as nn
 from packaging import version
 from paddle import framework
+from paddle.distributed.fleet.meta_parallel import PipelineLayer

 try:
     from paddle.base import core
@@ -161,7 +162,11 @@
 from .utils.async_save import AsyncSaver

 try:
-    from .utils.flash_checkpoint import FlashCheckpointManager, get_fused_param_mappings
+    from .utils.flash_checkpoint import (
+        FlashCheckpointCallback,
+        FlashCheckpointManager,
+        get_fused_param_mappings,
+    )
 except (ImportError, ModuleNotFoundError):
     FlashCheckpointManager, get_fused_param_mappings = None, None
 from .utils.helper import ( # nested_truncate,
@@ -350,8 +355,6 @@ def __init__(
         )

         if self.args.pipeline_parallel_degree > 1 and self.args.use_hybrid_parallel:
-            from paddle.distributed.fleet.meta_parallel import PipelineLayer
-
             assert (isinstance(model, LoRAModel) and isinstance(model.model, PipelineLayer)) or isinstance(
                 model, PipelineLayer
             ), "Only support pipeline parallel mode when model is PipelineLayer!!!"
@@ -700,24 +703,35 @@ def create_flash_checkpoint_manager(self, unwrapped_model, resume_from_checkpoint
         """
         assert isinstance(self.model, PretrainedModel), "model should be a PretrainedModel when using flash"
         logger.info("Create flash checkpoint manager...")
-        pipeline_hooks_capacity = (
-            unwrapped_model.forward_pipeline_parallel_hook_capacity
-            + unwrapped_model.backward_pipeline_parallel_hook_capacity
-        )
-        self.flash_checkpoint_manager = FlashCheckpointManager(
-            worker_num=self.args.flash_workers_num,
-            pipeline_hooks_capacity=pipeline_hooks_capacity,
-            capacity_usage=self.args.flash_pipeline_hooks_capacity_usage,
-            ema_coef=self.args.flash_save_ema_coef,
-        )
-        for i in range(unwrapped_model.forward_pipeline_parallel_hook_capacity):
-            unwrapped_model.register_forward_pipeline_parallel_hook(
-                location=i, hook=self.flash_checkpoint_manager.flash_checkpoint_pipeline_hook
-            )
-        for i in range(unwrapped_model.backward_pipeline_parallel_hook_capacity):
-            unwrapped_model.register_backward_pipeline_parallel_hook(
-                location=i, hook=self.flash_checkpoint_manager.flash_checkpoint_pipeline_hook
-            )
+        if isinstance(unwrapped_model, PipelineLayer):
+            pipeline_hooks_capacity = (
+                unwrapped_model.forward_pipeline_parallel_hook_capacity
+                + unwrapped_model.backward_pipeline_parallel_hook_capacity
+            )
+            self.flash_checkpoint_manager = FlashCheckpointManager(
+                worker_num=self.args.flash_workers_num,
+                pipeline_hooks_capacity=pipeline_hooks_capacity,
+                capacity_usage=self.args.flash_pipeline_hooks_capacity_usage,
+                ema_coef=self.args.flash_save_ema_coef,
+            )
+            for i in range(unwrapped_model.forward_pipeline_parallel_hook_capacity):
+                unwrapped_model.register_forward_pipeline_parallel_hook(
+                    location=i, hook=self.flash_checkpoint_manager.flash_checkpoint_pipeline_hook
+                )
+            for i in range(unwrapped_model.backward_pipeline_parallel_hook_capacity):
+                unwrapped_model.register_backward_pipeline_parallel_hook(
+                    location=i, hook=self.flash_checkpoint_manager.flash_checkpoint_pipeline_hook
+                )
+        else:
+            pipeline_hooks_capacity = self.args.gradient_accumulation_steps
+            self.flash_checkpoint_manager = FlashCheckpointManager(
+                worker_num=self.args.flash_workers_num,
+                pipeline_hooks_capacity=pipeline_hooks_capacity,
+                capacity_usage=self.args.flash_pipeline_hooks_capacity_usage,
+                ema_coef=self.args.flash_save_ema_coef,
+            )
+            _callback = FlashCheckpointCallback(self.flash_checkpoint_manager)
+            self.add_callback(_callback)
         if resume_from_checkpoint is not None:
             path = _add_variant(PADDLE_OPTIMIZER_NAME, self.args.optimizer_name_suffix)
             path = os.path.join(resume_from_checkpoint, path).replace("optimizer", "ema")
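With pipeline parallelism, FlashCheckpointManager is driven by the forward/backward pipeline hooks registered on the PipelineLayer, so its capacity is the sum of the two hook counts; without pipeline parallelism there are no such hooks, so the capacity falls back to gradient_accumulation_steps and the same hook is fired once per accumulation substep through FlashCheckpointCallback. Below is a minimal, self-contained sketch of that counting idea; the Toy* classes are stand-ins rather than the PaddleNLP implementations, and it assumes on_substep_end fires once per micro-batch, which may not match the trainer's exact event timing.

    # Toy stand-ins (not the PaddleNLP classes) showing how capacity-based
    # counting works on the non-pipeline path above.


    class ToyFlashManager:
        def __init__(self, pipeline_hooks_capacity):
            self.capacity = pipeline_hooks_capacity
            self.calls = 0
            self.completed_cycles = 0

        def flash_checkpoint_pipeline_hook(self, location):
            # In PP mode this would be called by forward/backward pipeline hooks;
            # in non-PP mode the callback below calls it once per substep.
            self.calls += 1
            if self.calls % self.capacity == 0:
                self.completed_cycles += 1  # one optimizer step's worth of calls


    class ToyFlashCheckpointCallback:
        # Same shape as FlashCheckpointCallback.on_substep_end in the diff,
        # minus the TrainerCallback base class and the unused event arguments.
        def __init__(self, manager):
            self.manager = manager

        def on_substep_end(self):
            self.manager.flash_checkpoint_pipeline_hook(0)


    grad_accum_steps = 4  # plays the role of args.gradient_accumulation_steps
    manager = ToyFlashManager(pipeline_hooks_capacity=grad_accum_steps)
    callback = ToyFlashCheckpointCallback(manager)

    for _ in range(grad_accum_steps):  # one optimizer step's micro-batches (assumed)
        callback.on_substep_end()

    assert manager.completed_cycles == 1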
9 changes: 9 additions & 0 deletions paddlenlp/trainer/utils/flash_checkpoint.py
@@ -33,6 +33,7 @@
 )
 from paddle.optimizer.fusion_utils import FusionStorageHelper

+from paddlenlp.trainer.trainer_callback import TrainerCallback
 from paddlenlp.transformers.utils import device_guard
 from paddlenlp.utils.env import (
     CONFIG_NAME,
@@ -335,6 +336,14 @@ def restore_tensor_from_meta(self, tensor_meta):
         return tensor


+class FlashCheckpointCallback(TrainerCallback):
+    def __init__(self, flash_checkpoint_manager):
+        self.manager = flash_checkpoint_manager
+
+    def on_substep_end(self, args, state, control, **kwargs):
+        self.manager.flash_checkpoint_pipeline_hook(0)
+
+
 class FlashCheckpointManager:
     def __init__(self, worker_num, pipeline_hooks_capacity, capacity_usage, ema_coef=None):
         assert worker_num > 0, "worker_num must be greater than 0"
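For context on how on_substep_end reaches FlashCheckpointCallback: the Trainer keeps a list of TrainerCallback objects and broadcasts lifecycle events to each of them, and on_substep_end is, as its name suggests, raised at the end of a gradient-accumulation substep. The sketch below is an illustrative dispatcher, not PaddleNLP's actual CallbackHandler; TinyCallbackHandler and RecordingCallback are invented names.

    # Illustrative dispatcher (not PaddleNLP's actual CallbackHandler): the
    # trainer broadcasts on_substep_end to every registered callback with the
    # (args, state, control, **kwargs) signature used by the class above.


    class TinyCallbackHandler:
        def __init__(self, callbacks):
            self.callbacks = callbacks

        def on_substep_end(self, args, state, control, **kwargs):
            for cb in self.callbacks:
                cb.on_substep_end(args, state, control, **kwargs)


    class RecordingCallback:
        # Stand-in for FlashCheckpointCallback; it only records the calls.
        def __init__(self):
            self.substeps_seen = 0

        def on_substep_end(self, args, state, control, **kwargs):
            self.substeps_seen += 1


    recorder = RecordingCallback()
    handler = TinyCallbackHandler([recorder])
    for _ in range(4):  # pretend four accumulation substeps happened
        handler.on_substep_end(args=None, state=None, control=None)
    assert recorder.substeps_seen == 4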
