From ac31baa8958f4e3011efe37f38b3d70aa577171c Mon Sep 17 00:00:00 2001
From: Stephen Xie <49465499+StephenXie@users.noreply.github.com>
Date: Fri, 31 Jan 2025 11:50:17 -0800
Subject: [PATCH 1/8] Add initial LoRA to ppo fsdp actor

---
 .gitignore                           |  2 ++
 verl/trainer/config/ppo_trainer.yaml |  6 ++++++
 verl/workers/fsdp_workers.py         | 29 +++++++++++++++++++++++++---
 3 files changed, 34 insertions(+), 3 deletions(-)

diff --git a/.gitignore b/.gitignore
index bce77bf0..e49dbd2e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -111,3 +111,5 @@ tests/e2e/toy_examples/deepspeed/synchronous/output.txt
 
 # vim
 *.swp
+
+.venv/
\ No newline at end of file
diff --git a/verl/trainer/config/ppo_trainer.yaml b/verl/trainer/config/ppo_trainer.yaml
index b294a7cb..55d5e3aa 100644
--- a/verl/trainer/config/ppo_trainer.yaml
+++ b/verl/trainer/config/ppo_trainer.yaml
@@ -48,6 +48,9 @@ actor_rollout_ref:
       grad_offload: False
       optimizer_offload: False
       fsdp_size: -1
+    lora_rank: 0 # Set to positive value to enable LoRA (e.g., 32)
+    lora_alpha: 16 # LoRA scaling factor
+    target_modules: all-linear # Target modules for LoRA adaptation
   ref:
     fsdp_config:
       param_offload: False
@@ -110,6 +113,9 @@ critic:
         # transformer_layer_cls_to_wrap: None
         min_num_params: 0
       fsdp_size: -1
+    lora_rank: 0 # Set to positive value to enable LoRA (e.g., 32)
+    lora_alpha: 16 # LoRA scaling factor
+    target_modules: all-linear # Target modules for LoRA adaptation
   ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size}
   ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu
   ppo_micro_batch_size_per_gpu: null
diff --git a/verl/workers/fsdp_workers.py b/verl/workers/fsdp_workers.py
index 1d3ce76e..8c2d9836 100644
--- a/verl/workers/fsdp_workers.py
+++ b/verl/workers/fsdp_workers.py
@@ -38,12 +38,22 @@ from verl.utils.model import compute_position_id_with_mask
 from verl.utils.flops_counter import FlopsCounter
 from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager
-
+from peft import LoraConfig, get_peft_model
 from codetiming import Timer
 
 logger = logging.getLogger(__file__)
 logger.setLevel(os.getenv('VERL_PPO_LOGGING_LEVEL', 'WARN'))
 
+def convert_to_regular_types(obj):
+    """Convert Hydra configs and other special types to regular Python types."""
+    from omegaconf import ListConfig, DictConfig
+    if isinstance(obj, (ListConfig, DictConfig)):
+        return {k: convert_to_regular_types(v) for k, v in obj.items()} if isinstance(obj, DictConfig) else list(obj)
+    elif isinstance(obj, (list, tuple)):
+        return [convert_to_regular_types(x) for x in obj]
+    elif isinstance(obj, dict):
+        return {k: convert_to_regular_types(v) for k, v in obj.items()}
+    return obj
 
 def create_device_mesh(world_size, fsdp_size):
     if fsdp_size < 0 or fsdp_size >= world_size:
@@ -209,6 +219,18 @@ def _build_model_optimizer(self,
             if enable_gradient_checkpointing:
                 actor_module.gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant': False})
 
+        if self.config.actor.get('lora_rank', 0) > 0:
+            print("Applying LoRA to actor module")
+            actor_module.enable_input_require_grads()
+            # Convert config to regular Python types before creating PEFT model
+            lora_config = {
+                'task_type': TaskType.CAUSAL_LM,
+                'r': self.config.actor.lora_rank,
+                'lora_alpha': self.config.actor.lora_alpha,
+                'target_modules': convert_to_regular_types(self.config.actor.target_modules),
+                'bias': "none"
+            }
+            actor_module = get_peft_model(actor_module, LoraConfig(**lora_config))
 
         torch.distributed.barrier()
 
         if self.rank == 0:
@@ -229,7 +251,7 @@ def _build_model_optimizer(self,
         mixed_precision = MixedPrecision(param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype)
 
-        auto_wrap_policy = get_fsdp_wrap_policy(module=actor_module, config=fsdp_config.get('wrap_policy', None))
+        auto_wrap_policy = get_fsdp_wrap_policy(module=actor_module, config=fsdp_config.get('wrap_policy', None), is_lora=self.config.actor.get('lora_rank', 0) > 0)
 
         if self._is_rollout and self.config.rollout.name == 'hf':
             # TODO(zhangchi.usc1992, shengguangming) fix me. Current, auto_wrap_policy causes HFRollout to hang in Gemma
@@ -290,7 +312,6 @@ def _build_rollout(self):
         dp = self.world_size // infer_tp
         assert self.world_size % infer_tp == 0, f'rollout world_size: {self.world_size} is not divisible by infer_tp: {infer_tp}'
         rollout_device_mesh = init_device_mesh('cuda', mesh_shape=(dp, infer_tp), mesh_dim_names=['dp', 'infer_tp'])
-
         if self.config.rollout.name == 'hf':
             from verl.workers.rollout import HFRollout
             from verl.workers.sharding_manager import BaseShardingManager
@@ -367,7 +388,9 @@ def init_model(self):
                                                 actor_optimizer=self.actor_optimizer)
 
         if self._is_rollout:
+            #TODO merge adapter
             self.rollout, self.rollout_sharding_manager = self._build_rollout()
+            #TODO unmerge adapter
 
         if self._is_ref:
             self.ref_module_fsdp = self._build_model_optimizer(model_path=self.config.model.path,

From 76b1875b804c022406f9cab724e04f9a5c03f22d Mon Sep 17 00:00:00 2001
From: Tony Lian
Date: Mon, 3 Feb 2025 09:49:19 -0800
Subject: [PATCH 2/8] Add LoRA in critic and adjust the config format

---
 verl/trainer/config/ppo_trainer.yaml |  6 ++---
 verl/workers/fsdp_workers.py         | 40 ++++++++++++++++++++++------
 2 files changed, 35 insertions(+), 11 deletions(-)

diff --git a/verl/trainer/config/ppo_trainer.yaml b/verl/trainer/config/ppo_trainer.yaml
index 55d5e3aa..a473e88a 100644
--- a/verl/trainer/config/ppo_trainer.yaml
+++ b/verl/trainer/config/ppo_trainer.yaml
@@ -18,6 +18,9 @@ actor_rollout_ref:
     override_config: { }
     enable_gradient_checkpointing: True
     use_remove_padding: False
+    lora_rank: 0 # Set to positive value to enable LoRA (e.g., 32)
+    lora_alpha: 16 # LoRA scaling factor
+    target_modules: all-linear # Target modules for LoRA adaptation
   actor:
     strategy: fsdp # This is for backward-compatibility
     ppo_mini_batch_size: 256
@@ -48,9 +51,6 @@ actor_rollout_ref:
       grad_offload: False
       optimizer_offload: False
       fsdp_size: -1
-    lora_rank: 0 # Set to positive value to enable LoRA (e.g., 32)
-    lora_alpha: 16 # LoRA scaling factor
-    target_modules: all-linear # Target modules for LoRA adaptation
   ref:
     fsdp_config:
       param_offload: False
diff --git a/verl/workers/fsdp_workers.py b/verl/workers/fsdp_workers.py
index 8c2d9836..1c2c76b5 100644
--- a/verl/workers/fsdp_workers.py
+++ b/verl/workers/fsdp_workers.py
@@ -38,7 +38,7 @@ from verl.utils.model import compute_position_id_with_mask
 from verl.utils.flops_counter import FlopsCounter
 from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager
-from peft import LoraConfig, get_peft_model
+from peft import LoraConfig, TaskType, get_peft_model
 from codetiming import Timer
 
 logger = logging.getLogger(__file__)
@@ -147,6 +147,8 @@ def __init__(self, config: DictConfig, role: str):
                                                                               self.ulysses_sequence_parallel_size)
             self.config.ref.log_prob_micro_batch_size_per_gpu = self.config.ref.log_prob_micro_batch_size
 
+        self._is_lora = self.config.model.get('lora_rank', 0) > 0
+
     def _build_model_optimizer(self,
                                model_path,
                                fsdp_config,
@@ -219,15 +221,15 @@ def _build_model_optimizer(self,
             if enable_gradient_checkpointing:
                 actor_module.gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant': False})
 
-        if self.config.actor.get('lora_rank', 0) > 0:
+        if self._is_lora:
             print("Applying LoRA to actor module")
             actor_module.enable_input_require_grads()
             # Convert config to regular Python types before creating PEFT model
             lora_config = {
                 'task_type': TaskType.CAUSAL_LM,
-                'r': self.config.actor.lora_rank,
-                'lora_alpha': self.config.actor.lora_alpha,
-                'target_modules': convert_to_regular_types(self.config.actor.target_modules),
+                'r': self.config.model.lora_rank,
+                'lora_alpha': self.config.model.lora_alpha,
+                'target_modules': convert_to_regular_types(self.config.model.target_modules),
                 'bias': "none"
             }
             actor_module = get_peft_model(actor_module, LoraConfig(**lora_config))
@@ -251,7 +253,7 @@ def _build_model_optimizer(self,
         mixed_precision = MixedPrecision(param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype)
 
-        auto_wrap_policy = get_fsdp_wrap_policy(module=actor_module, config=fsdp_config.get('wrap_policy', None), is_lora=self.config.actor.get('lora_rank', 0) > 0)
+        auto_wrap_policy = get_fsdp_wrap_policy(module=actor_module, config=fsdp_config.get('wrap_policy', None), is_lora=self.config.model.get('lora_rank', 0) > 0)
 
         if self._is_rollout and self.config.rollout.name == 'hf':
             # TODO(zhangchi.usc1992, shengguangming) fix me. Current, auto_wrap_policy causes HFRollout to hang in Gemma
@@ -388,9 +390,15 @@ def init_model(self):
                                                 actor_optimizer=self.actor_optimizer)
 
         if self._is_rollout:
-            #TODO merge adapter
+            if self._is_lora:
+                print("Merge adapter")
+                print(self.actor_module_fsdp._fsdp_wrapped_module)
+                import ipdb; ipdb.set_trace()
+                self.actor_module_fsdp.merge_adapter()
             self.rollout, self.rollout_sharding_manager = self._build_rollout()
-            #TODO unmerge adapter
+            if self._is_lora:
+                print("Unmerge adapter")
+                self.actor_module_fsdp.unmerge_adapter()
 
         if self._is_ref:
             self.ref_module_fsdp = self._build_model_optimizer(model_path=self.config.model.path,
@@ -615,6 +623,8 @@ def __init__(self, config):
             self.config.ppo_micro_batch_size_per_gpu = self.config.ppo_micro_batch_size
             self.config.forward_micro_batch_size_per_gpu = self.config.forward_micro_batch_size
             assert self.config.ppo_mini_batch_size % self.config.ppo_micro_batch_size_per_gpu == 0
+
+        self._is_lora = self.config.model.get('lora_rank', 0) > 0
 
     def _build_critic_model_optimizer(self, config):
         # the following line is necessary
@@ -676,6 +686,20 @@ def _build_critic_model_optimizer(self, config):
 
             if config.model.get('enable_gradient_checkpointing', False):
                 critic_module.gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant': False})
+
+        if self._is_lora:
+            print("Applying LoRA to critic module")
+            critic_module.enable_input_require_grads()
+            # Convert config to regular Python types before creating PEFT model
+            lora_config = {
+                'task_type': TaskType.CAUSAL_LM,
+                'r': self.config.model.lora_rank,
+                'lora_alpha': self.config.model.lora_alpha,
+                'target_modules': convert_to_regular_types(self.config.model.target_modules),
+                'bias': "none"
+            }
+            critic_module = get_peft_model(critic_module, LoraConfig(**lora_config))
+
         if self.rank == 0:
             print_model_size(critic_module)

From b2a28245b503440e151f22e8481fed256b134e07 Mon Sep 17 00:00:00 2001
From: Tony Lian
Date: Tue, 4 Feb 2025 21:11:14 -0800
Subject: [PATCH 3/8] Update peft implementation

---
 verl/workers/fsdp_workers.py               |  8 --------
 verl/workers/sharding_manager/fsdp_vllm.py | 17 ++++++++++++++++-
 2 files changed, 16 insertions(+), 9 deletions(-)
diff --git a/verl/workers/fsdp_workers.py b/verl/workers/fsdp_workers.py
index 1c2c76b5..001341a7 100644
--- a/verl/workers/fsdp_workers.py
+++ b/verl/workers/fsdp_workers.py
@@ -390,15 +390,7 @@ def init_model(self):
                                                 actor_optimizer=self.actor_optimizer)
 
         if self._is_rollout:
-            if self._is_lora:
-                print("Merge adapter")
-                print(self.actor_module_fsdp._fsdp_wrapped_module)
-                import ipdb; ipdb.set_trace()
-                self.actor_module_fsdp.merge_adapter()
             self.rollout, self.rollout_sharding_manager = self._build_rollout()
-            if self._is_lora:
-                print("Unmerge adapter")
-                self.actor_module_fsdp.unmerge_adapter()
 
         if self._is_ref:
             self.ref_module_fsdp = self._build_model_optimizer(model_path=self.config.model.path,
diff --git a/verl/workers/sharding_manager/fsdp_vllm.py b/verl/workers/sharding_manager/fsdp_vllm.py
index 19490f4e..46de0daf 100644
--- a/verl/workers/sharding_manager/fsdp_vllm.py
+++ b/verl/workers/sharding_manager/fsdp_vllm.py
@@ -15,9 +15,11 @@
 import os
 import logging
 import torch
+from peft import PeftModel
 from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp.api import ShardingStrategy, ShardedStateDictConfig, StateDictType, FullStateDictConfig
 from torch.distributed.device_mesh import DeviceMesh
+from collections import OrderedDict
 
 from verl.third_party.vllm import LLM
 from verl.third_party.vllm import parallel_state as vllm_ps
@@ -68,13 +70,26 @@ def __init__(self,
 
     def __enter__(self):
         log_gpu_memory_usage('Before state_dict() in sharding manager memory', logger=logger)
-        params = self.module.state_dict()
+
+        if isinstance(self.module._fsdp_wrapped_module, PeftModel):
+            # the model to sync weights to is a vLLM model (not a peft model), so we need to merge the adapters
+            with FSDP.summon_full_params(self.module):
+                self.module._fsdp_wrapped_module.merge_adapter()
+            params = self.module._fsdp_wrapped_module.base_model.model.state_dict()
+            # FIXME: use more rigorous way to filter out the adapter weights
+            params = OrderedDict((k.replace(".base_layer.", "."), v) for k, v in params.items() if not ".lora_" in k)
+        else:
+            params = self.module.state_dict()
+
         log_gpu_memory_usage('After state_dict() in sharding manager memory', logger=logger)
         # Copy, not share memory
         load_format = 'hf' if self.full_params else 'dtensor'
         self.inference_engine.sync_model_weights(params, load_format=load_format)
         log_gpu_memory_usage('After sync model weights in sharding manager', logger=logger)
 
+        if isinstance(self.module._fsdp_wrapped_module, PeftModel):
+            with FSDP.summon_full_params(self.module):
+                self.module._fsdp_wrapped_module.unmerge_adapter()
         del params
         torch.cuda.empty_cache()
         log_gpu_memory_usage('After del state_dict and empty_cache in sharding manager', logger=logger)

From 50be77d00b2a99cb2d564b5b55187261aeb266cd Mon Sep 17 00:00:00 2001
From: Tony Lian
Date: Tue, 4 Feb 2025 21:35:36 -0800
Subject: [PATCH 4/8] Update get_fsdp_wrap_policy for the critic

---
 verl/workers/fsdp_workers.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/verl/workers/fsdp_workers.py b/verl/workers/fsdp_workers.py
index 001341a7..1edb4784 100644
--- a/verl/workers/fsdp_workers.py
+++ b/verl/workers/fsdp_workers.py
@@ -710,7 +710,7 @@ def _build_critic_model_optimizer(self, config):
         mixed_precision = MixedPrecision(param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype)
 
-        auto_wrap_policy = get_fsdp_wrap_policy(module=critic_module, config=self.config.model.fsdp_config.wrap_policy)
+        auto_wrap_policy = get_fsdp_wrap_policy(module=critic_module, config=self.config.model.fsdp_config.wrap_policy, is_lora=self.config.model.get('lora_rank', 0) > 0)
 
         log_gpu_memory_usage('Before critic FSDP', logger=None)

From 25fd586ded433568348347269eb1bdaffdf369dc Mon Sep 17 00:00:00 2001
From: Jiayi Pan
Date: Wed, 5 Feb 2025 06:49:08 +0000
Subject: [PATCH 5/8] actor ref lora 2in1 wip

---
 verl/workers/fsdp_workers.py | 58 ++++++++++++++++++++++--------------
 1 file changed, 35 insertions(+), 23 deletions(-)

diff --git a/verl/workers/fsdp_workers.py b/verl/workers/fsdp_workers.py
index 1edb4784..d128ed32 100644
--- a/verl/workers/fsdp_workers.py
+++ b/verl/workers/fsdp_workers.py
@@ -108,6 +108,8 @@ def __init__(self, config: DictConfig, role: str):
 
         self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh)
 
+        self._is_lora = self.config.model.get('lora_rank', 0) > 0
+
         self.role = role
         assert self.role in ['actor', 'rollout', 'ref', 'actor_rollout', 'actor_rollout_ref']
 
@@ -147,7 +149,6 @@ def __init__(self, config: DictConfig, role: str):
                                                                               self.ulysses_sequence_parallel_size)
             self.config.ref.log_prob_micro_batch_size_per_gpu = self.config.ref.log_prob_micro_batch_size
 
-        self._is_lora = self.config.model.get('lora_rank', 0) > 0
 
     def _build_model_optimizer(self,
                                model_path,
@@ -491,36 +492,47 @@ def generate_sequences(self, prompts: DataProto):
         return output
 
     @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
-    def compute_log_prob(self, data: DataProto):
+    def compute_log_prob(self, data: DataProto, no_lora=False):
+        # when no_lora is True, we use the actor without lora applied to calculate the log_prob
+        # which is mostly used for ref log_prob calculation
         assert self._is_actor
-        data = data.to('cuda')
-        # we should always recompute old_log_probs when it is HybridEngine
-        data.meta_info['micro_batch_size'] = self.config.rollout.log_prob_micro_batch_size_per_gpu
-        data.meta_info['max_token_len'] = self.config.rollout.log_prob_max_token_len_per_gpu
-        data.meta_info['use_dynamic_bsz'] = self.config.rollout.log_prob_use_dynamic_bsz
-        data.meta_info['temperature'] = self.config.rollout.temperature
-        # perform recompute log_prob
-        with self.ulysses_sharding_manager:
-            data = self.ulysses_sharding_manager.preprocess_data(data)
-            output = self.actor.compute_log_prob(data=data)
-            output = DataProto.from_dict(tensors={'old_log_probs': output},
-                                         meta_info={'temperature': self.config.rollout.temperature})
-            output = self.ulysses_sharding_manager.postprocess_data(output)
 
-        output = output.to('cpu')
+        from contextlib import nullcontext
+        adapter_ctx = self.actor.actor_module.disable_adapter() if no_lora else nullcontext()
+        with adapter_ctx:
+            data = data.to('cuda')
+            # we should always recompute old_log_probs when it is HybridEngine
+            data.meta_info['micro_batch_size'] = self.config.rollout.log_prob_micro_batch_size_per_gpu
+            data.meta_info['max_token_len'] = self.config.rollout.log_prob_max_token_len_per_gpu
+            data.meta_info['use_dynamic_bsz'] = self.config.rollout.log_prob_use_dynamic_bsz
+            data.meta_info['temperature'] = self.config.rollout.temperature
+            # perform recompute log_prob
+            with self.ulysses_sharding_manager:
+                data = self.ulysses_sharding_manager.preprocess_data(data)
+                output = self.actor.compute_log_prob(data=data)
+                output = DataProto.from_dict(tensors={'old_log_probs': output},
+                                             meta_info={'temperature': self.config.rollout.temperature})
+                output = self.ulysses_sharding_manager.postprocess_data(output)
 
-        # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes
-        # unshard the root FSDP module
-        if self.world_size > 1:
-            self.actor.actor_module._handle.reshard(True)
+            output = output.to('cpu')
 
-        torch.cuda.empty_cache()
-        return output
+            # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes
+            # unshard the root FSDP module
+            if self.world_size > 1:
+                self.actor.actor_module._handle.reshard(True)
+
+            torch.cuda.empty_cache()
+            return output
 
     @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
     def compute_ref_log_prob(self, data: DataProto):
         assert self._is_ref
-
+        if self._is_lora:
+            # TODO
+            pass
+            # if _is_lora, actor without lora applied is the ref
+            # return self.compute_log_prob(data, no_lora=True)
+        # else:
+        # otherwise, the class have a standalone ref model
         data = data.to('cuda')
 
         micro_batch_size = self.config.ref.log_prob_micro_batch_size_per_gpu

From e01412a967f8022ad542d80ac1468a08c6eff84a Mon Sep 17 00:00:00 2001
From: Jiayi Pan
Date: Wed, 5 Feb 2025 07:02:27 +0000
Subject: [PATCH 6/8] minor fix

---
 verl/workers/sharding_manager/fsdp_vllm.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/verl/workers/sharding_manager/fsdp_vllm.py b/verl/workers/sharding_manager/fsdp_vllm.py
index 46de0daf..b640e1a0 100644
--- a/verl/workers/sharding_manager/fsdp_vllm.py
+++ b/verl/workers/sharding_manager/fsdp_vllm.py
@@ -74,7 +74,7 @@ def __enter__(self):
         if isinstance(self.module._fsdp_wrapped_module, PeftModel):
             # the model to sync weights to is a vLLM model (not a peft model), so we need to merge the adapters
             with FSDP.summon_full_params(self.module):
-                self.module._fsdp_wrapped_module.merge_adapter()
+                self.module.merge_adapter()
             params = self.module._fsdp_wrapped_module.base_model.model.state_dict()
             # FIXME: use more rigorous way to filter out the adapter weights
             params = OrderedDict((k.replace(".base_layer.", "."), v) for k, v in params.items() if not ".lora_" in k)

From a0be6d967554dd281c6b3bcc9890f82cbd8f82cc Mon Sep 17 00:00:00 2001
From: Jiayi Pan
Date: Wed, 5 Feb 2025 07:18:49 +0000
Subject: [PATCH 7/8] clean up fsdp lora logic

---
 verl/workers/sharding_manager/fsdp_vllm.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/verl/workers/sharding_manager/fsdp_vllm.py b/verl/workers/sharding_manager/fsdp_vllm.py
index b640e1a0..5f162eda 100644
--- a/verl/workers/sharding_manager/fsdp_vllm.py
+++ b/verl/workers/sharding_manager/fsdp_vllm.py
@@ -89,7 +89,7 @@ def __enter__(self):
 
         if isinstance(self.module._fsdp_wrapped_module, PeftModel):
             with FSDP.summon_full_params(self.module):
-                self.module._fsdp_wrapped_module.unmerge_adapter()
+                self.module.unmerge_adapter()
         del params
         torch.cuda.empty_cache()
         log_gpu_memory_usage('After del state_dict and empty_cache in sharding manager', logger=logger)

From cbc9e5d02d5d909f36f32829037f8381b869f8bf Mon Sep 17 00:00:00 2001
From: Jiayi Pan
Date: Wed, 5 Feb 2025 08:05:51 +0000
Subject: [PATCH 8/8] actor ref lora 2in1

---
 verl/trainer/ppo/ray_trainer.py | 13 ++++++++++---
 verl/workers/fsdp_workers.py    |  9 +++++----
 2 files changed, 15 insertions(+), 7 deletions(-)

diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py
index a2251ec9..3df6ac9a 100644
--- a/verl/trainer/ppo/ray_trainer.py
+++ b/verl/trainer/ppo/ray_trainer.py
@@ -323,6 +323,10 @@ def __init__(self,
         self.use_rm = Role.RewardModel in role_worker_mapping
         self.ray_worker_group_cls = ray_worker_group_cls
 
+        # if ref_in_actor is True, the reference policy will be actor without lora applied
+        self.ref_in_actor = config.actor_rollout_ref.model.get('lora_rank', 0) > 0
+
+
         # define KL control
         if self.use_reference_policy:
             if config.algorithm.kl_ctrl.type == 'fixed':
@@ -474,7 +478,7 @@ def init_workers(self):
             raise NotImplementedError
 
         # create reference policy if needed
-        if self.use_reference_policy:
+        if self.use_reference_policy and not self.ref_in_actor:
             resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy)
             ref_policy_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RefPolicy],
                                                   config=self.config.actor_rollout_ref,
@@ -506,7 +510,7 @@ def init_workers(self):
             self.critic_wg = all_wg['critic']
             self.critic_wg.init_model()
 
-        if self.use_reference_policy:
+        if self.use_reference_policy and not self.ref_in_actor:
             self.ref_policy_wg = all_wg['ref']
             self.ref_policy_wg.init_model()
 
@@ -614,7 +618,10 @@ def fit(self):
                 if self.use_reference_policy:
                     # compute reference log_prob
                     with _timer('ref', timing_raw):
-                        ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
+                        if not self.ref_in_actor:
+                            ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
+                        else:
+                            ref_log_prob = self.actor_rollout_wg.compute_ref_log_prob(batch)
                         batch = batch.union(ref_log_prob)
 
                 # compute values
diff --git a/verl/workers/fsdp_workers.py b/verl/workers/fsdp_workers.py
index d128ed32..033cded7 100644
--- a/verl/workers/fsdp_workers.py
+++ b/verl/workers/fsdp_workers.py
@@ -525,12 +525,13 @@ def compute_log_prob(self, data: DataProto, no_lora=False):
 
     @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
     def compute_ref_log_prob(self, data: DataProto):
-        assert self._is_ref
         if self._is_lora:
-            # TODO
-            pass
             # if _is_lora, actor without lora applied is the ref
-            # return self.compute_log_prob(data, no_lora=True)
+            data = self.compute_log_prob(data, no_lora=True)
+            # this old_log_probs is in fact ref_log_prob
+            data = DataProto.from_dict(tensors={'ref_log_prob': data.batch['old_log_probs']})
+            return data
+        assert self._is_ref
         # else:
         # otherwise, the class have a standalone ref model
         data = data.to('cuda')
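
Standalone sketch (not part of the patches above) of the adapter merge-and-export step that PATCH 3 adds to fsdp_vllm.py, shown without FSDP so it can be run directly. It uses only PEFT calls the patches themselves rely on (get_peft_model, merge_adapter, unmerge_adapter); the checkpoint name is an arbitrary small example, r and lora_alpha follow the values suggested in the ppo_trainer.yaml comments, and target_modules="all-linear" assumes a peft release that supports that shorthand.

from collections import OrderedDict

from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM

# Example checkpoint only; any small causal LM works the same way.
base = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
base.enable_input_require_grads()  # mirrors the patch; only matters with gradient checkpointing

lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=32,                          # lora_rank in ppo_trainer.yaml (e.g., 32)
    lora_alpha=16,                 # lora_alpha in ppo_trainer.yaml
    target_modules="all-linear",   # target_modules in ppo_trainer.yaml
    bias="none",
)
model = get_peft_model(base, lora_config)

# Merge the LoRA weights into the base layers, then build the state dict that would be
# synced to the rollout engine: drop lora_* tensors and strip the ".base_layer." infix
# that PEFT inserts around the wrapped linear layers, so keys match the plain HF model.
model.merge_adapter()
params = model.base_model.model.state_dict()
params = OrderedDict(
    (k.replace(".base_layer.", "."), v) for k, v in params.items() if ".lora_" not in k)
print(f"{len(params)} tensors ready to sync to the inference engine")

# Unmerge afterwards so subsequent training steps keep updating only the adapter weights.
model.unmerge_adapter()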