[PPO] feat: Add LoRA support for PPO #205

Status: Draft. Wants to merge 8 commits into base: main.
2 changes: 2 additions & 0 deletions .gitignore
@@ -111,3 +111,5 @@ tests/e2e/toy_examples/deepspeed/synchronous/output.txt

# vim
*.swp

.venv/
6 changes: 6 additions & 0 deletions verl/trainer/config/ppo_trainer.yaml
@@ -18,6 +18,9 @@ actor_rollout_ref:
override_config: { }
enable_gradient_checkpointing: True
use_remove_padding: False
lora_rank: 0 # Set to positive value to enable LoRA (e.g., 32)
lora_alpha: 16 # LoRA scaling factor
target_modules: all-linear # Target modules for LoRA adaptation
actor:
strategy: fsdp # This is for backward-compatibility
ppo_mini_batch_size: 256
@@ -110,6 +113,9 @@ critic:
# transformer_layer_cls_to_wrap: None
min_num_params: 0
fsdp_size: -1
lora_rank: 0 # Set to positive value to enable LoRA (e.g., 32)
lora_alpha: 16 # LoRA scaling factor
target_modules: all-linear # Target modules for LoRA adaptation
ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size}
ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu
ppo_micro_batch_size_per_gpu: null
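For context on how these new keys are consumed: a positive lora_rank is the single switch that enables LoRA, and (per the trainer change below) it also makes the actor double as the reference policy. A minimal sketch of that gating, with illustrative values rather than anything taken from this PR:

# Minimal sketch (illustrative values, not PR code) of how the new config keys gate LoRA.
from omegaconf import OmegaConf

cfg = OmegaConf.create({
    "actor_rollout_ref": {
        "model": {"lora_rank": 32, "lora_alpha": 16, "target_modules": "all-linear"}
    }
})

# Mirrors the checks added in ray_trainer.py and fsdp_workers.py:
is_lora = cfg.actor_rollout_ref.model.get("lora_rank", 0) > 0   # True -> PEFT adapters applied
ref_in_actor = is_lora                                           # True -> no standalone ref worker
print(is_lora, ref_in_actor)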
13 changes: 10 additions & 3 deletions verl/trainer/ppo/ray_trainer.py
@@ -323,6 +323,10 @@ def __init__(self,
self.use_rm = Role.RewardModel in role_worker_mapping
self.ray_worker_group_cls = ray_worker_group_cls

# if ref_in_actor is True, the reference policy is the actor with its LoRA adapters disabled
self.ref_in_actor = config.actor_rollout_ref.model.get('lora_rank', 0) > 0


# define KL control
if self.use_reference_policy:
if config.algorithm.kl_ctrl.type == 'fixed':
@@ -474,7 +478,7 @@ def init_workers(self):
raise NotImplementedError

# create reference policy if needed
if self.use_reference_policy:
if self.use_reference_policy and not self.ref_in_actor:
resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy)
ref_policy_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RefPolicy],
config=self.config.actor_rollout_ref,
@@ -506,7 +510,7 @@ def init_workers(self):
self.critic_wg = all_wg['critic']
self.critic_wg.init_model()

if self.use_reference_policy:
if self.use_reference_policy and not self.ref_in_actor:
self.ref_policy_wg = all_wg['ref']
self.ref_policy_wg.init_model()

@@ -614,7 +618,10 @@ def fit(self):
if self.use_reference_policy:
# compute reference log_prob
with _timer('ref', timing_raw):
ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
if not self.ref_in_actor:
ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
else:
ref_log_prob = self.actor_rollout_wg.compute_ref_log_prob(batch)
batch = batch.union(ref_log_prob)

# compute values
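A note on why the actor can stand in for the reference policy when LoRA is enabled (background on PEFT behavior, not code from this PR): PeftModel.disable_adapter() is a context manager that temporarily bypasses the LoRA layers, so a forward pass inside it uses only the frozen base weights, which are exactly the original reference policy. A minimal sketch, with placeholder names:

# Sketch only; peft_model and input_ids are placeholders.
from peft import PeftModel

def ref_forward(peft_model: PeftModel, input_ids):
    # With adapters disabled, the forward pass runs on the frozen base weights,
    # i.e. the pre-RL reference policy.
    with peft_model.disable_adapter():
        return peft_model(input_ids=input_ids).logits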
104 changes: 78 additions & 26 deletions verl/workers/fsdp_workers.py
@@ -38,12 +38,22 @@
from verl.utils.model import compute_position_id_with_mask
from verl.utils.flops_counter import FlopsCounter
from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager

from peft import LoraConfig, TaskType, get_peft_model
from codetiming import Timer

logger = logging.getLogger(__file__)
logger.setLevel(os.getenv('VERL_PPO_LOGGING_LEVEL', 'WARN'))

def convert_to_regular_types(obj):
"""Convert Hydra configs and other special types to regular Python types."""
from omegaconf import ListConfig, DictConfig
if isinstance(obj, (ListConfig, DictConfig)):
return {k: convert_to_regular_types(v) for k, v in obj.items()} if isinstance(obj, DictConfig) else list(obj)
elif isinstance(obj, (list, tuple)):
return [convert_to_regular_types(x) for x in obj]
elif isinstance(obj, dict):
return {k: convert_to_regular_types(v) for k, v in obj.items()}
return obj

def create_device_mesh(world_size, fsdp_size):
if fsdp_size < 0 or fsdp_size >= world_size:
@@ -98,6 +108,8 @@ def __init__(self, config: DictConfig, role: str):

self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh)

self._is_lora = self.config.model.get('lora_rank', 0) > 0

self.role = role
assert self.role in ['actor', 'rollout', 'ref', 'actor_rollout', 'actor_rollout_ref']

@@ -137,6 +149,7 @@ def __init__(self, config: DictConfig, role: str):
self.ulysses_sequence_parallel_size)
self.config.ref.log_prob_micro_batch_size_per_gpu = self.config.ref.log_prob_micro_batch_size


def _build_model_optimizer(self,
model_path,
fsdp_config,
@@ -209,6 +222,18 @@ def _build_model_optimizer(self,

if enable_gradient_checkpointing:
actor_module.gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant': False})
if self._is_lora:
print("Applying LoRA to actor module")
actor_module.enable_input_require_grads()
# Convert config to regular Python types before creating PEFT model
lora_config = {
'task_type': TaskType.CAUSAL_LM,
'r': self.config.model.lora_rank,
'lora_alpha': self.config.model.lora_alpha,
'target_modules': convert_to_regular_types(self.config.model.target_modules),
'bias': "none"
}
actor_module = get_peft_model(actor_module, LoraConfig(**lora_config))
torch.distributed.barrier()

if self.rank == 0:
@@ -229,7 +254,7 @@

mixed_precision = MixedPrecision(param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype)

auto_wrap_policy = get_fsdp_wrap_policy(module=actor_module, config=fsdp_config.get('wrap_policy', None))
auto_wrap_policy = get_fsdp_wrap_policy(module=actor_module, config=fsdp_config.get('wrap_policy', None), is_lora=self.config.model.get('lora_rank', 0) > 0)

if self._is_rollout and self.config.rollout.name == 'hf':
# TODO(zhangchi.usc1992, shengguangming) fix me. Current, auto_wrap_policy causes HFRollout to hang in Gemma
@@ -290,7 +315,6 @@ def _build_rollout(self):
dp = self.world_size // infer_tp
assert self.world_size % infer_tp == 0, f'rollout world_size: {self.world_size} is not divisible by infer_tp: {infer_tp}'
rollout_device_mesh = init_device_mesh('cuda', mesh_shape=(dp, infer_tp), mesh_dim_names=['dp', 'infer_tp'])

if self.config.rollout.name == 'hf':
from verl.workers.rollout import HFRollout
from verl.workers.sharding_manager import BaseShardingManager
@@ -468,36 +492,48 @@ def generate_sequences(self, prompts: DataProto):
return output

@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def compute_log_prob(self, data: DataProto):
def compute_log_prob(self, data: DataProto, no_lora=False):
# when no_lora is True, compute the log_prob with the actor's LoRA adapters disabled;
# this is mainly used for the reference log_prob calculation
assert self._is_actor
data = data.to('cuda')
# we should always recompute old_log_probs when it is HybridEngine
data.meta_info['micro_batch_size'] = self.config.rollout.log_prob_micro_batch_size_per_gpu
data.meta_info['max_token_len'] = self.config.rollout.log_prob_max_token_len_per_gpu
data.meta_info['use_dynamic_bsz'] = self.config.rollout.log_prob_use_dynamic_bsz
data.meta_info['temperature'] = self.config.rollout.temperature
# perform recompute log_prob
with self.ulysses_sharding_manager:
data = self.ulysses_sharding_manager.preprocess_data(data)
output = self.actor.compute_log_prob(data=data)
output = DataProto.from_dict(tensors={'old_log_probs': output},
meta_info={'temperature': self.config.rollout.temperature})
output = self.ulysses_sharding_manager.postprocess_data(output)
from contextlib import nullcontext
adapter_ctx = self.actor.actor_module.disable_adapter() if no_lora else nullcontext()
with adapter_ctx:
data = data.to('cuda')
# we should always recompute old_log_probs when it is HybridEngine
data.meta_info['micro_batch_size'] = self.config.rollout.log_prob_micro_batch_size_per_gpu
data.meta_info['max_token_len'] = self.config.rollout.log_prob_max_token_len_per_gpu
data.meta_info['use_dynamic_bsz'] = self.config.rollout.log_prob_use_dynamic_bsz
data.meta_info['temperature'] = self.config.rollout.temperature
# perform recompute log_prob
with self.ulysses_sharding_manager:
data = self.ulysses_sharding_manager.preprocess_data(data)
output = self.actor.compute_log_prob(data=data)
output = DataProto.from_dict(tensors={'old_log_probs': output},
meta_info={'temperature': self.config.rollout.temperature})
output = self.ulysses_sharding_manager.postprocess_data(output)

output = output.to('cpu')
output = output.to('cpu')

# https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes
# unshard the root FSDP module
if self.world_size > 1:
self.actor.actor_module._handle.reshard(True)
# https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes
# unshard the root FSDP module
if self.world_size > 1:
self.actor.actor_module._handle.reshard(True)

torch.cuda.empty_cache()
return output
torch.cuda.empty_cache()
return output

@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def compute_ref_log_prob(self, data: DataProto):
if self._is_lora:
# when LoRA is enabled, the actor with adapters disabled serves as the reference policy
data = self.compute_log_prob(data, no_lora=True)
# the returned old_log_probs are in fact the reference log-probs
data = DataProto.from_dict(tensors={'ref_log_prob': data.batch['old_log_probs']})
return data
assert self._is_ref

# otherwise, this worker holds a standalone reference model
data = data.to('cuda')

micro_batch_size = self.config.ref.log_prob_micro_batch_size_per_gpu
@@ -592,6 +628,8 @@ def __init__(self, config):
self.config.ppo_micro_batch_size_per_gpu = self.config.ppo_micro_batch_size
self.config.forward_micro_batch_size_per_gpu = self.config.forward_micro_batch_size
assert self.config.ppo_mini_batch_size % self.config.ppo_micro_batch_size_per_gpu == 0

self._is_lora = self.config.model.get('lora_rank', 0) > 0

def _build_critic_model_optimizer(self, config):
# the following line is necessary
@@ -653,6 +691,20 @@ def _build_critic_model_optimizer(self, config):

if config.model.get('enable_gradient_checkpointing', False):
critic_module.gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant': False})

if self._is_lora:
print("Applying LoRA to critic module")
critic_module.enable_input_require_grads()
# Convert config to regular Python types before creating PEFT model
lora_config = {
'task_type': TaskType.CAUSAL_LM,
'r': self.config.model.lora_rank,
'lora_alpha': self.config.model.lora_alpha,
'target_modules': convert_to_regular_types(self.config.model.target_modules),
'bias': "none"
}
critic_module = get_peft_model(critic_module, LoraConfig(**lora_config))

if self.rank == 0:
print_model_size(critic_module)

@@ -671,7 +723,7 @@ def _build_critic_model_optimizer(self, config):

mixed_precision = MixedPrecision(param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype)

auto_wrap_policy = get_fsdp_wrap_policy(module=critic_module, config=self.config.model.fsdp_config.wrap_policy)
auto_wrap_policy = get_fsdp_wrap_policy(module=critic_module, config=self.config.model.fsdp_config.wrap_policy, is_lora=self.config.model.get('lora_rank', 0) > 0)

log_gpu_memory_usage('Before critic FSDP', logger=None)

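The diff also passes a new is_lora flag into get_fsdp_wrap_policy. A plausible reason, stated here as an assumption rather than something spelled out in the PR: FSDP flattens parameters per wrapped unit and generally expects uniform requires_grad inside each flat group, while LoRA freezes the base model and trains only the adapter layers, so those trainable layers need to be wrapped separately. A rough sketch of such a policy using stock torch.distributed.fsdp.wrap helpers (not verl's actual implementation):

# Rough sketch, NOT verl's get_fsdp_wrap_policy; base_policy is whatever transformer-block
# policy would be used without LoRA.
import functools
from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy, _or_policy

def lora_aware_wrap_policy(base_policy):
    def is_trainable_leaf(module):
        # Wrap leaf modules whose weight is trainable (e.g., lora_A / lora_B linears)
        # so frozen and trainable parameters do not share one flat parameter group.
        return (len(list(module.children())) == 0
                and getattr(module, "weight", None) is not None
                and module.weight.requires_grad)

    lora_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=is_trainable_leaf)
    return functools.partial(_or_policy, policies=[lora_policy, base_policy])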
17 changes: 16 additions & 1 deletion verl/workers/sharding_manager/fsdp_vllm.py
@@ -15,9 +15,11 @@
import os
import logging
import torch
from peft import PeftModel
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.api import ShardingStrategy, ShardedStateDictConfig, StateDictType, FullStateDictConfig
from torch.distributed.device_mesh import DeviceMesh
from collections import OrderedDict

from verl.third_party.vllm import LLM
from verl.third_party.vllm import parallel_state as vllm_ps
@@ -68,13 +70,26 @@ def __init__(self,

def __enter__(self):
log_gpu_memory_usage('Before state_dict() in sharding manager memory', logger=logger)
params = self.module.state_dict()

if isinstance(self.module._fsdp_wrapped_module, PeftModel):
# the model to sync weights to is a vLLM model (not a peft model), so we need to merge the adapters
with FSDP.summon_full_params(self.module):

[Review comment (Collaborator)] Summoning full params may cause OOM. @PeterSH6 Is there a better approach that can merge LoRA weights in sharded form, or at least one parameter after another, to support large models?

self.module.merge_adapter()

[Review comment] merge_adapter does not yield the same model structure as the original base model. Is there any other way to merge that recovers the original base model structure?

params = self.module._fsdp_wrapped_module.base_model.model.state_dict()
# FIXME: use a more rigorous way to filter out the adapter weights
params = OrderedDict((k.replace(".base_layer.", "."), v) for k, v in params.items() if ".lora_" not in k)
else:
params = self.module.state_dict()

log_gpu_memory_usage('After state_dict() in sharding manager memory', logger=logger)
# Copy, not share memory
load_format = 'hf' if self.full_params else 'dtensor'
self.inference_engine.sync_model_weights(params, load_format=load_format)
log_gpu_memory_usage('After sync model weights in sharding manager', logger=logger)

if isinstance(self.module._fsdp_wrapped_module, PeftModel):
with FSDP.summon_full_params(self.module):
self.module.unmerge_adapter()
del params
torch.cuda.empty_cache()
log_gpu_memory_usage('After del state_dict and empty_cache in sharding manager', logger=logger)
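For readers unfamiliar with PEFT's naming, the key filtering above relies on the fact that LoRA wraps each targeted linear layer: the original weight moves under a .base_layer. prefix and new lora_A / lora_B tensors appear next to it. A small standalone sketch of that rename-and-filter step, with illustrative key names rather than ones taken from a real checkpoint:

# Standalone sketch of the filtering above; key names are illustrative of PEFT's convention.
from collections import OrderedDict
import torch

peft_style = OrderedDict({
    "model.layers.0.self_attn.q_proj.base_layer.weight": torch.zeros(8, 8),
    "model.layers.0.self_attn.q_proj.lora_A.default.weight": torch.zeros(4, 8),
    "model.layers.0.self_attn.q_proj.lora_B.default.weight": torch.zeros(8, 4),
})

# Drop adapter tensors and strip ".base_layer." so keys match the plain HF layout vLLM expects.
clean = OrderedDict((k.replace(".base_layer.", "."), v)
                    for k, v in peft_style.items() if ".lora_" not in k)
print(list(clean))  # ['model.layers.0.self_attn.q_proj.weight']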