Commit 8e36b3e: misc

kinman0224 committed Feb 15, 2025
1 parent: d545ba7
Showing 3 changed files with 6 additions and 15 deletions.
19 changes: 6 additions & 13 deletions verl/models/qwen2/megatron/modeling_qwen2_megatron.py
@@ -26,7 +26,6 @@
 from megatron.core import tensor_parallel
 from megatron.core import ModelParallelConfig
 from torch import nn
-from torch.nn import init
 from transformers.modeling_outputs import BaseModelOutputWithPast
 from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
 from transformers.models.qwen2.modeling_qwen2 import CausalLMOutputWithPast
@@ -324,8 +323,8 @@ def forward(
         batch_size, sequence_length = input_ids.shape

         # remove padding here
-        input_ids, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(input_ids.unsqueeze(dim=-1),
-                                                                          attention_mask)  # (total_nnz, 1)
+        input_ids, indices, cu_seqlens, max_seqlen_in_batch, *_ = unpad_input(input_ids.unsqueeze(dim=-1),
+                                                                              attention_mask)  # (total_nnz, 1)

         # pad input_ids to multiple of tp for all tp ranks
         # TODO: for better performance, the sp padding should be removed at each layer. Not sure the performance gap
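Note on the unpacking change above (not part of the commit itself): `unpad_input` from `flash_attn.bert_padding` historically returned a 4-tuple, while newer flash-attn releases (2.7+, if I recall correctly) append an extra element, so unpacking with a trailing `*_` accepts both signatures. A minimal sketch of the tolerant unpacking, assuming flash-attn is installed:

import torch
from flash_attn.bert_padding import unpad_input

def remove_padding(input_ids: torch.Tensor, attention_mask: torch.Tensor):
    """Flatten valid tokens; tolerate 4- or 5-tuple returns from unpad_input."""
    input_ids_rmpad, indices, cu_seqlens, max_seqlen_in_batch, *_ = unpad_input(
        input_ids.unsqueeze(dim=-1), attention_mask)  # (total_nnz, 1)
    return input_ids_rmpad, indices, cu_seqlens, max_seqlen_in_batch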
@@ -482,8 +481,6 @@ def forward(self,
         """
         if self.pre_process:
-            # if torch.cuda.current_device() == 0:
-            #     print(f'rank {torch.cuda.current_device()}: input_ids shape before embedding: {input_ids.shape}')
             inputs_embeds = self.embed_tokens(input_ids)  # (1, total_nnz) -> (1, total_nnz, hidden_size)

             # vocab parallel embedding will not do sequence parallel reduce-scatter in open source megatron
@@ -493,8 +490,6 @@
             if self.megatron_config.sequence_parallel:
                 inputs_embeds = tensor_parallel.scatter_to_sequence_parallel_region(inputs_embeds)

-            # if torch.cuda.current_device() == 0:
-            #     print(f'rank {torch.cuda.current_device()}: input_embeds shape after embedding: {inputs_embeds.shape}')
             hidden_states = inputs_embeds
         else:
             # self.hidden_states should be passed by Megatron
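Aside (not from the diff): the context above relies on Megatron-Core's sequence-parallel scatter, since the vocab-parallel embedding alone does not shard the sequence dimension in open-source Megatron. A minimal sketch of what that call does, assuming the tensor is already in Megatron's sequence-first layout and model-parallel state has been initialized:

import torch
from megatron.core import tensor_parallel

def to_sequence_parallel(inputs_embeds: torch.Tensor) -> torch.Tensor:
    # Splits dim 0 evenly across the tensor-parallel group, so each rank keeps
    # total_nnz_padded // tp_size tokens; this is why forward() pads total_nnz
    # to a multiple of the TP size beforehand.
    return tensor_parallel.scatter_to_sequence_parallel_region(inputs_embeds)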
@@ -586,9 +581,9 @@ def forward(
         # In the first pp, input_ids will be used, in other pp layers hidden_states will be used inside self.model
         batch_size, sequence_length = input_ids.shape
         # remove padding here
-        input_ids_rmpad, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(input_ids.unsqueeze(dim=-1),
-                                                                                attention_mask)  # (total_nnz, 1)
-        # print(f'input_ids.shape = {input_ids.shape}, input_ids_rmpad.shape = {input_ids_rmpad.shape}, indices.shape = {indices.shape}, cu_seqlens[-1] = {cu_seqlens[-1]}')
+        input_ids_rmpad, indices, cu_seqlens, max_seqlen_in_batch, *_ = unpad_input(input_ids.unsqueeze(dim=-1),
+                                                                                    attention_mask)  # (total_nnz, 1)
+
         # pad input_ids to multiple of tp for all tp ranks
         # TODO: for better performance, the sp padding should be removed at each layer. Not sure the performance gap
         if self.megatron_config.sequence_parallel:
@@ -607,15 +602,13 @@ def forward(
             hidden_states = outputs
             # print(f'hidden_states.shape = {hidden_states.shape}')  # torch.Size([4, 32, 4096])
             logits = self._forward_head(hidden_states)
-            # print(f'logits.shape = {logits.shape}')
             logits = torch.squeeze(logits, dim=1)  # remove the artificial batch dimension  # torch.Size([8, 32, 16])

             # remove padding from sequence parallel
             if self.megatron_config.sequence_parallel:
                 totol_nnz = cu_seqlens[-1]
                 logits = logits[:totol_nnz]  # (total_nnz_padded)
             # add removed padding back. If input is already rmpad, we let the caller pad_input
-            # print(f'logits.shape = {logits.shape}, indices.shape = {indices.shape}, batch_size = {batch_size}, seq_len = {sequence_length}')
             logits = pad_input(logits, indices, batch_size,
                                seqlen=sequence_length)  # (batch_size, sequence_length, vocab_size)
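Aside (not from the diff): `pad_input` is the inverse of `unpad_input` — it scatters the (total_nnz, ...) rows back into a dense (batch, seqlen, ...) tensor using the saved `indices`. A small self-contained round-trip sketch, assuming flash-attn is installed:

import torch
from flash_attn.bert_padding import pad_input, unpad_input

batch, seqlen, vocab = 2, 6, 16
attention_mask = torch.tensor([[1, 1, 1, 0, 0, 0],
                               [1, 1, 1, 1, 1, 0]])
# indices marks where the valid tokens live in the flattened (batch * seqlen) layout
_, indices, _, _, *_ = unpad_input(torch.zeros(batch, seqlen, 1), attention_mask)
logits_rmpad = torch.randn(int(attention_mask.sum()), vocab)      # (total_nnz, vocab)
logits = pad_input(logits_rmpad, indices, batch, seqlen=seqlen)   # (batch, seqlen, vocab)
assert logits.shape == (batch, seqlen, vocab)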

@@ -660,4 +653,4 @@ def forward(
             output.logits = torch.squeeze(output.logits, dim=-1)
             return output
         else:
-            return output
+            return output
1 change: 0 additions & 1 deletion
@@ -282,7 +282,6 @@ def mistral_megatron_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -
     'LlamaForCausalLM': llama_megatron_core_te_weight_loader,  # use te backend for open-source megatron
     'LLaMAForCausalLM': llama_megatron_core_te_weight_loader,
     'MistralForCausalLM': mistral_megatron_weight_loader,
-    'Qwen2ForCausalLM': llama_megatron_weight_loader,
 }


1 change: 0 additions & 1 deletion
@@ -282,7 +282,6 @@ def mistral_megatron_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -
     'LlamaForCausalLM': llama_megatron_weight_loader,  # use te backend for open-source megatron
     'LLaMAForCausalLM': llama_megatron_weight_loader,
     'MistralForCausalLM': mistral_megatron_weight_loader,
-    'Qwen2ForCausalLM': llama_megatron_weight_loader,
 }
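Aside (not from the diff, hypothetical names): both dicts above are architecture-to-loader registries keyed by what look like HF `architectures` strings, and this commit drops the `Qwen2ForCausalLM` entry from each variant. A generic sketch of how such a registry is typically resolved:

from typing import Callable, Dict

def _placeholder_loader(actor_weights: dict, vllm_model: object) -> None:
    """Stand-in for a real per-architecture Megatron -> vLLM weight loader."""

WEIGHT_LOADER_REGISTRY: Dict[str, Callable] = {
    'LlamaForCausalLM': _placeholder_loader,
    'MistralForCausalLM': _placeholder_loader,
}

def get_weight_loader(arch: str) -> Callable:
    if arch not in WEIGHT_LOADER_REGISTRY:
        raise ValueError(f'Architecture {arch} is not registered for Megatron weight loading.')
    return WEIGHT_LOADER_REGISTRY[arch]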

