support vllm-0.6.6 #214

Open · wants to merge 19 commits into main
48 changes: 47 additions & 1 deletion chatlearn/models/base_module.py
@@ -18,6 +18,7 @@
from itertools import cycle
import math
import os
import torch

import ray
import ray.util.collective as col
@@ -751,12 +752,26 @@ def send_recv_parameter(self, rank, group_name, func, pipe_stage=0):

    def alltoall_routed_expert_parameter(self, pipe_stage=0):
        assert self._synchronizer is not None

        comm_group = self.tensor_and_expert_parallel_group()
        rank = torch.distributed.get_rank(group=comm_group)
        world_size = torch.distributed.get_world_size(group=comm_group)
        for name, param in self._parameters_to_sync[pipe_stage]:
            param, state = self._synchronizer.alltoall_routed_experts(
                name,
                param,
                comm_group,
                self.name,
                rank,
                world_size
            )
            if state:
                self._expert_sync_buffer.pop(name, "Not Found.")
                self._expert_sync_buffer[name] = param
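
For context, the synchronizer call above ultimately rides on an all-to-all across the tensor-and-expert-parallel group. A minimal sketch of that collective with plain torch.distributed (not ChatLearn's synchronizer; the dim-0 shard layout and even divisibility are illustrative assumptions):

# Illustrative only: swap per-rank expert shards inside one parallel group.
import torch
import torch.distributed as dist

def exchange_expert_shards(param: torch.Tensor, group=None) -> torch.Tensor:
    """Split `param` along dim 0 into world_size chunks; rank i receives chunk i from every peer."""
    world_size = dist.get_world_size(group=group)
    inputs = list(param.chunk(world_size, dim=0))  # assumes dim 0 divisible by world_size
    outputs = [torch.empty_like(t) for t in inputs]
    dist.all_to_all(outputs, inputs, group=group)  # one collective per parameter
    return torch.cat(outputs, dim=0)

# Run under e.g. `torchrun --nproc_per_node=2` with an initialized NCCL process group.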
@@ -829,6 +844,7 @@ def broadcast_parameter_two_stage(self, to_rank, buffer_rank, rank, src_rank, gr
                parameters_to_sync = self._parameters_to_recv[to_rank]
            else:
                parameters_to_sync = self._parameters_to_send
        else:
            del self._sync_buffer
            self._sync_buffer = defaultdict(list)
@@ -888,9 +904,12 @@ def tensor_generator():
        )
        dense_bucket_num = 0
        sparse_bucket_num = 0
        for bucket_or_tensor, is_dense in bucket_generator:
            if is_dense:
                index = 0 if stage2 else (to_rank % self.tp_num_mapping)
                all_buffers = coalesced_comm_dense_two_stage(
                    bucket_or_tensor, col.broadcast, rank,
                    extra_args=(src_rank, group_name), tensor_changed=tensor_changed,
@@ -903,11 +922,38 @@ def tensor_generator():
                    del value
                    self._sync_buffer[key] += cpu_value
                del all_buffers
                dense_bucket_num += 1
            else:
                col.broadcast(bucket_or_tensor, src_rank, group_name)
                sparse_bucket_num += 1

        if stage2:
            self._logger.info("finished stage2 comm")
        else:
            self._logger.info("finished stage1 comm")

        debug_rank_0(f"{self.name} Got dense_buckets {dense_bucket_num}, sparse_bucket {sparse_bucket_num}", self._logger)

        self.empty_cache()
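
The dense path above amortizes per-tensor launch overhead by coalescing each bucket into a single collective. A minimal single-stage sketch of that idea with plain torch.distributed (ChatLearn's coalesced_comm_dense_two_stage layers its two-stage routing scheme on top of this):

# Sketch of a coalesced dense-bucket broadcast; not the two-stage ChatLearn path.
import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

def broadcast_bucket(bucket, src_rank, group=None):
    """Broadcast a list of same-dtype tensors with a single collective call."""
    flat = _flatten_dense_tensors(bucket)            # one contiguous buffer
    dist.broadcast(flat, src=src_rank, group=group)  # single launch for the whole bucket
    for t, synced in zip(bucket, _unflatten_dense_tensors(flat, bucket)):
        t.copy_(synced)                              # write results back in place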
38 changes: 15 additions & 23 deletions chatlearn/models/vllm/hooks/__init__.py
@@ -19,26 +19,18 @@
from .. import is_vllm_v2


if importlib.util.find_spec("vllm"):

    from chatlearn.utils.constant import CURRENT_VLLM_VERSION, VLLMVersion

    if CURRENT_VLLM_VERSION == VLLMVersion.v_0_3_0:
        from chatlearn.models.vllm.hooks.vllm_0_3_0 import *
    elif CURRENT_VLLM_VERSION == VLLMVersion.v_0_5_1:
        from chatlearn.models.vllm.hooks.vllm_0_5_1 import *
    elif CURRENT_VLLM_VERSION == VLLMVersion.v_0_6_3:
        from chatlearn.models.vllm.hooks.vllm_0_6_3 import *
    elif CURRENT_VLLM_VERSION == VLLMVersion.v_0_6_6:
        from .vllm_0_6_6 import *
    else:
        raise RuntimeError(
            f"vLLM version expected in {list(member.value for member in VLLMVersion)}, "
            f"got {CURRENT_VLLM_VERSION}.")
62 changes: 0 additions & 62 deletions chatlearn/models/vllm/hooks/input_preprocess.py

This file was deleted.

21 changes: 21 additions & 0 deletions chatlearn/models/vllm/hooks/vllm_0_3_0/__init__.py
@@ -0,0 +1,21 @@
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Additional hooks of vllm-0.3.0."""

from ... import is_vllm_v2

assert not is_vllm_v2(), "vLLM-0.3.0 only supports vLLM Module v1. Set env `ENABLE_VLLM_V2=False`."

from . import sampler
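
The assert implies is_vllm_v2 is driven by the ENABLE_VLLM_V2 environment variable; a hedged sketch of that toggle (the real implementation in chatlearn/models/vllm/__init__.py and its default may differ):

# Assumed shape of the v1/v2 toggle referenced by the assert above.
import os

def is_vllm_v2() -> bool:
    # Anything except an explicit "false"/"0" counts as enabled (assumed default).
    return os.getenv("ENABLE_VLLM_V2", "True").lower() not in ("false", "0")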
23 changes: 23 additions & 0 deletions chatlearn/models/vllm/hooks/vllm_0_5_1/__init__.py
@@ -0,0 +1,23 @@
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Additional hooks of vllm-0.5.1."""

from ... import is_vllm_v2

assert not is_vllm_v2(), "vLLM-0.5.1 only supports vLLM Module v1. Set env `ENABLE_VLLM_V2=False`."

from . import llm_engine
from . import logits_processor
from . import worker
29 changes: 29 additions & 0 deletions chatlearn/models/vllm/hooks/vllm_0_6_3/__init__.py
@@ -0,0 +1,29 @@
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Additional hooks of vllm-0.6.3."""

from ... import is_vllm_v2
from . import format_device_name
from . import input_preprocess

if is_vllm_v2():
    from . import async_llm_engine
    from . import llm
    from . import loader
    from . import ray_gpu_executor
    from . import worker_base
else:
    from . import llm_engine
    from . import logits_processor
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Hooks of vllm-0.6.3 del init_ray_cluster in AsyncLLMEngine."""
"""del init_ray_cluster in AsyncLLMEngine."""

from typing import Dict, Optional

55 changes: 55 additions & 0 deletions chatlearn/models/vllm/hooks/vllm_0_6_3/input_preprocess.py
@@ -0,0 +1,55 @@
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Hooks of vllm-0.6.3 input preprocess to pass prompt text."""

# pylint: disable=unused-import,unused-argument
from vllm.inputs import preprocess
from vllm.inputs.parse import parse_singleton_prompt

def extract_prompt_components(
        self,
        prompt,
        request_id,
        lora_request=None):
    '''
    Extract the components of any single encoder or decoder input prompt.

    Arguments:

    * prompt: single encoder or decoder input prompt
    * request_id
    * lora_request: only valid for decoder prompts

    Returns:

    * prompt
    * prompt_token_ids
    * multi_modal_data
    * mm_processor_kwargs (request-level input processor/mapper overrides)
    '''
    parsed = parse_singleton_prompt(prompt)

    assert parsed["type"] == "tokens", \
        f"prompt_token_ids must be provided when adding a request to the scheduler, got prompt {prompt}"

    prompt_text = parsed["content"]["prompt"]
    prompt_token_ids = parsed["content"]["prompt_token_ids"]
    multi_modal_data = parsed["content"].get("multi_modal_data")
    mm_processor_kwargs = parsed["content"].get("mm_processor_kwargs")

    return (prompt_text, prompt_token_ids, multi_modal_data,
            mm_processor_kwargs)

preprocess.InputPreprocessor._extract_prompt_components = extract_prompt_components
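
With this patch, requests must arrive pre-tokenized: parse_singleton_prompt classifies a dict carrying prompt_token_ids as a "tokens" prompt, the only shape the assert accepts, and the hook reads content["prompt"] directly, so callers include the raw text alongside the ids. An illustrative payload (token ids made up):

tokens_prompt = {
    "prompt_token_ids": [101, 2009, 2003, 1037, 3231, 102],  # example ids only
    "prompt": "it is a test",                                # original text, kept by the hook
}
# e.g. llm.generate(tokens_prompt, sampling_params)  # llm: an initialized vLLM engine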
30 changes: 30 additions & 0 deletions chatlearn/models/vllm/hooks/vllm_0_6_3/llm_engine.py
@@ -0,0 +1,30 @@
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Hooks of vllm-0.5.1 llm_engine remove __reduce__ function."""

import inspect

# pylint: disable=unused-import,wildcard-import,unused-argument
from vllm.engine import llm_engine


# vLLM's LLMEngine.__reduce__ raises RuntimeError so the engine is never pickled;
# drop it so the LLMEngine can be referenced in the closure used to initialize
# Ray worker actors.
source = inspect.getsource(llm_engine.LLMEngine.__reduce__)
if 'RuntimeError' in source:
    del llm_engine.LLMEngine.__reduce__
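
A toy illustration (plain Python, not vLLM) of why removing __reduce__ restores default picklability:

# A __reduce__ that raises blocks pickling; deleting it from the class
# restores Python's default reduce protocol.
import pickle

class Engine:
    def __reduce__(self):
        raise RuntimeError("Engine should not be pickled!")

del Engine.__reduce__  # same trick as the hook above
clone = pickle.loads(pickle.dumps(Engine()))
print(type(clone).__name__)  # -> Engine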
@@ -73,7 +73,6 @@ def init(self, load_config):

loader.DummyModelLoader.__init__ = init

# add ckpt loading of megatron format
def load_model(self, *, model_config,
               device_config,