From 55a5dd9dcddd22c95d5bca124b7d6c61729a3247 Mon Sep 17 00:00:00 2001 From: Runyu Lu Date: Mon, 27 May 2024 07:08:08 +0000 Subject: [PATCH 1/3] dist runtime opt source --- colossalai/inference/batch_bucket.py | 64 ++++++++ colossalai/inference/config.py | 16 +- colossalai/inference/core/request_handler.py | 6 +- colossalai/inference/core/rpc_engine.py | 54 +++++-- colossalai/inference/executor/rpc_worker.py | 152 ++++++++++++++++--- colossalai/inference/utils.py | 67 ++++++++ 6 files changed, 314 insertions(+), 45 deletions(-) diff --git a/colossalai/inference/batch_bucket.py b/colossalai/inference/batch_bucket.py index f8571c0ca030..a344a9579cde 100644 --- a/colossalai/inference/batch_bucket.py +++ b/colossalai/inference/batch_bucket.py @@ -521,3 +521,67 @@ def fd_inter_tensor(self) -> None: def __repr__(self) -> str: return f"(sequences_dict={self._sequences_dict}, is_prompts={self.is_prompts})" + + +class RPCBatchBucket(BatchBucket): + def __init__(self, *args, **argv): + self.is_rpc = True + super().__init__(*args, **argv) + + # For compatibility + def get_1D_inputs(self) -> List[int]: + assert len(self._sequences_dict) > 0, "No sequence in the batch" + first_seq = next(iter(self._sequences_dict.values())) # not exactly the first sequence + if first_seq.output_len == 0: + # Assume prefill stage + assert all( + seq.output_len == 0 for seq in self._sequences_dict.values() + ), "Sequence stage (Prefill/Decoding) must be the same in the batch" + out_li = [] + seq_ids = sorted(self._sequences_indexes.keys(), key=lambda x: self._sequences_indexes[x]) + for seq_id in seq_ids: + seq: Sequence = self._sequences_dict[seq_id] + out_li.extend(seq.input_token_id) + return out_li + else: + # Assume decoding stage + if self.use_spec_dec: + # For Speculative Decoding + # the number of tokens to be verified in parallel plus the correct token in the last step + return self.get_1D_inputs_spec_dec(self.num_tokens_to_verify + 1) + assert all( + seq.output_len > 0 for seq in self._sequences_dict.values() + ), "Sequence stage (Prefill/Decoding) must be the same in the batch" + assert self.is_compact, "BatchBucket is not compact" + out = [0] * self.current_batch_size + for seq_id, index_in_b in self._sequences_indexes.items(): + seq: Sequence = self._sequences_dict[seq_id] + out[index_in_b] = seq.output_token_id[-1] + return out + + # For compatibility + def get_sequence_lengths(self) -> List[int]: + assert self.is_compact # Debug usage + sequence_lengths = self.seq_lengths[: self.current_batch_size] + return sequence_lengths + + def get_1D_inputs_spec_dec(self, n: int) -> List[int]: + # Used for main model verification in **Decoding Stage** + # `n` is the number of tokens to be verified, + # and so that prepare the last `n` tokens of each sequence as the inputs + assert len(self._sequences_dict) > 0, "No sequence in the batch" + assert all( + seq.output_len >= n for seq in self._sequences_dict.values() + ), "Sequence output tokens must be greater than or equal to the number of tokens to be verified." 
+ out_li = [] + seq_ids = sorted(self._sequences_indexes.keys(), key=lambda x: self._sequences_indexes[x]) + for seq_id in seq_ids: + seq: Sequence = self._sequences_dict[seq_id] + out_li.extend(seq.output_token_id[-n:]) + return out_li + + # For compatibility + def get_block_table_tensor(self) -> torch.Tensor: + assert self.is_compact # Debug usage + block_table = self.block_tables[: self.current_batch_size] + return block_table diff --git a/colossalai/inference/config.py b/colossalai/inference/config.py index 61bc7c8abc9c..c5f7b61ce786 100644 --- a/colossalai/inference/config.py +++ b/colossalai/inference/config.py @@ -87,8 +87,12 @@ class InputMetaData(RPC_PARAM): def to_rpc_param(self) -> Dict[str, any]: return { - "block_tables": self.block_tables.tolist(), - "sequence_lengths": self.sequence_lengths.tolist(), + "block_tables": self.block_tables.tolist() + if isinstance(self.block_tables, torch.Tensor) + else self.block_tables, + "sequence_lengths": self.sequence_lengths.tolist() + if isinstance(self.block_tables, torch.Tensor) + else self.sequence_lengths, "batch_size": self.batch_size, "is_prompts": self.is_prompts, "use_cuda_kernel": self.use_cuda_kernel, @@ -113,10 +117,14 @@ def from_rpc_param(rpc_dict: Dict[str, any]) -> "InputMetaData": return InputMetaData( block_tables=torch.tensor( rpc_dict["block_tables"], dtype=torch.int, device=get_accelerator().get_current_device() - ), + ) + if isinstance(rpc_dict["block_tables"], list) + else rpc_dict["block_tables"], sequence_lengths=torch.tensor( rpc_dict["sequence_lengths"], dtype=torch.int, device=get_accelerator().get_current_device() - ), + ) + if isinstance(rpc_dict["sequence_lengths"], list) + else rpc_dict["sequence_lengths"], batch_size=rpc_dict["batch_size"], is_prompts=rpc_dict["is_prompts"], use_cuda_kernel=rpc_dict["use_cuda_kernel"], diff --git a/colossalai/inference/core/request_handler.py b/colossalai/inference/core/request_handler.py index 5085c55558b4..4041fc88a528 100644 --- a/colossalai/inference/core/request_handler.py +++ b/colossalai/inference/core/request_handler.py @@ -4,7 +4,7 @@ from transformers.configuration_utils import PretrainedConfig from transformers.generation import GenerationConfig -from colossalai.inference.batch_bucket import BatchBucket +from colossalai.inference.batch_bucket import BatchBucket, RPCBatchBucket from colossalai.inference.config import InferenceConfig from colossalai.inference.flash_decoding_utils import FDIntermTensors from colossalai.inference.kv_cache import KVCacheManager, RPCKVCacheManager @@ -376,7 +376,7 @@ def __init__(self, inference_config: InferenceConfig, model_config: PretrainedCo # TODO In the continuous batching scenario, the batch size may be greater than max_batch_size, # which may cause bugs and this issue should be fixed later. 
- self.running_bb = BatchBucket( + self.running_bb = RPCBatchBucket( num_heads=model_config.num_attention_heads // inference_config.tp_size, head_dim=head_dim, max_batch_size=self.max_batch_size, @@ -386,7 +386,7 @@ def __init__(self, inference_config: InferenceConfig, model_config: PretrainedCo fd_interm_tensor=None, dtype=self.dtype, ) - self.prefill_bb = BatchBucket( + self.prefill_bb = RPCBatchBucket( num_heads=model_config.num_attention_heads // inference_config.tp_size, head_dim=head_dim, max_batch_size=self.max_batch_size, diff --git a/colossalai/inference/core/rpc_engine.py b/colossalai/inference/core/rpc_engine.py index 439c4b0b5fff..0b4cfe37ab9f 100644 --- a/colossalai/inference/core/rpc_engine.py +++ b/colossalai/inference/core/rpc_engine.py @@ -11,7 +11,7 @@ from transformers import AutoConfig, PreTrainedTokenizer, PreTrainedTokenizerFast from transformers.configuration_utils import PretrainedConfig -from colossalai.inference.batch_bucket import BatchBucket +from colossalai.inference.batch_bucket import RPCBatchBucket from colossalai.inference.config import InferenceConfig, InputMetaData from colossalai.inference.executor.rpc_worker import rpcWorkerService from colossalai.inference.utils import find_available_ports @@ -161,8 +161,16 @@ def init_workers(self): raise Exception("conn error!") self.logger.info(f"Build RPC Connection Success! Begin to load model...") asyncio.run(self.init_worker_env()) + self._init_worker_forward() self.logger.info(f"init dist env over") + def _init_worker_forward(self): + """ + Async wrappers for forward, because it will be invoked many times. + """ + assert len(self.workers) == self.tp_size, "init workers first" + self.worker_forwards = [rpyc.async_(worker.execute_model_forward) for worker in self.workers] + async def async_parallel_wrapper(self, f, *args, **kwargs): async_res = rpyc.async_(f)(*args, **kwargs) await asyncio.to_thread(async_res.wait) @@ -209,7 +217,8 @@ async def _init_device_cache(self, alloc_shape: Tuple[int, int, int, int]): def init_device_cache(self, alloc_shape: Tuple[Tuple[int, ...], Tuple[int, ...]]): asyncio.run(self._init_device_cache(alloc_shape)) - def prepare_input(self, batch: BatchBucket) -> Tuple[List[int], InputMetaData]: + def prepare_input(self, batch: RPCBatchBucket) -> Tuple[List[int], InputMetaData]: + assert batch.is_rpc, "the batch must be RPCBatchBucket" input_ids = batch.get_1D_inputs() sequence_lengths = batch.get_sequence_lengths() @@ -219,7 +228,7 @@ def prepare_input(self, batch: BatchBucket) -> Tuple[List[int], InputMetaData]: n_tokens = batch.current_batch_size if batch.use_spec_dec: n_tokens = batch.num_tokens_to_verify + 1 - assert n_tokens == input_ids.size(0) + assert n_tokens == len(input_ids) n_tokens = n_tokens * batch.current_batch_size batch_token_ids = None @@ -251,20 +260,38 @@ def prepare_input(self, batch: BatchBucket) -> Tuple[List[int], InputMetaData]: batch_token_ids=batch_token_ids, ) - return input_ids.tolist(), input_meta_data + return input_ids, input_meta_data + + async def async_parallel_forward(self, async_f, *args, **kwargs): + async_res = async_f(*args, **kwargs) + await asyncio.to_thread(async_res.wait) + assert async_res.ready + return async_res.value async def step_(self, input_token_ids, input_meta_data: InputMetaData): assert len(self.workers) == self.tp_size, "init workers first" - init_tasks = [ - self.async_parallel_wrapper( - worker.execute_model_forward, - input_token_ids, - input_meta_data.to_rpc_param(), - self.generation_config_dict, - ) - for worker in 
self.workers - ] + init_tasks = [] + for rank, async_forward in enumerate(self.worker_forwards): + if rank == 0: + init_tasks.append( + self.async_parallel_forward( + async_forward, + input_token_ids, + input_meta_data.to_rpc_param(), + self.generation_config_dict, + ) + ) + else: + init_tasks.append( + self.async_parallel_forward( + async_forward, + None, + None, + None, + ) + ) + ret = await asyncio.gather(*init_tasks) return ret[0] @@ -277,7 +304,6 @@ def step(self) -> List[str]: next_tokens = asyncio.run(self.step_(input_token_ids, input_meta_data)) # update the request_handler - next_tokens = torch.tensor(next_tokens, dtype=torch.int) self.request_handler.append_next_tokens(next_tokens) finished_sequences = self.request_handler.update() return finished_sequences diff --git a/colossalai/inference/executor/rpc_worker.py b/colossalai/inference/executor/rpc_worker.py index 913b8667dcf9..4ac6026da434 100644 --- a/colossalai/inference/executor/rpc_worker.py +++ b/colossalai/inference/executor/rpc_worker.py @@ -1,4 +1,4 @@ -from typing import List, Tuple, Union +from typing import List, Optional, Tuple, Union import rpyc import torch @@ -18,7 +18,7 @@ model_policy_map, ) from colossalai.inference.sampler import search_tokens -from colossalai.inference.utils import get_model_size +from colossalai.inference.utils import Timer, get_model_size from colossalai.interface import ModelWrapper from colossalai.logging import get_dist_logger from colossalai.pipeline.stage_manager import PipelineStageManager @@ -51,6 +51,12 @@ class rpcWorkerService(rpyc.Service): def exposed_init_dist_env(self, rank, world_size, master_address, master_port): logger.info(f"init process group for rank {rank}") colossalai.launch(rank=rank, world_size=world_size, port=master_port, host=master_address) + self.rank = rank + + # profiling only, remove later + self.t_prepare = Timer("[Timer] prepare the data") + self.t_exe = Timer("[Timer] execute the model forward") + self.t_sampler = Timer("[Timer] sampler time") logger.info(f"init process group done for rank {rank}") def exposed_init_model( @@ -98,38 +104,50 @@ def exposed_init_cache(self, alloc_shape: Tuple[Tuple[int, ...], Tuple[int, ...] 
logger.info("physical cache init over") def exposed_execute_model_forward( - self, input_token_ids_param: List[int], input_meta_data_param: dict, generation_config_param: dict + self, + input_token_ids_param: Optional[List[int]] = None, + input_meta_data_param: Optional[dict] = None, + generation_config_param: Optional[dict] = None, ): # prepare the data for model forward - input_meta_data = InputMetaData.from_rpc_param(input_meta_data_param) - input_meta_data.fd_inter_tensor = self.fd_inter_tensor + with self.t_prepare: + input_token_ids, input_meta_data, generation_config = self._broadcast_param_to_all_workers( + input_token_ids_param=input_token_ids_param, + input_meta_data_param=input_meta_data_param, + generation_config_param=generation_config_param, + ) + if input_meta_data.is_prompts: n_tokens = input_meta_data.sequence_lengths.sum().item() else: n_tokens = input_meta_data.batch_size - input_token_ids = torch.tensor(input_token_ids_param, dtype=torch.int, device=self.device) # execute the model - logits = self.model( - input_token_ids, - self.output_tensor[:n_tokens], - input_meta_data, - self.k_cache, - self.v_cache, - ) + with self.t_exe: + logits = self.model( + input_token_ids, + self.output_tensor[:n_tokens], + input_meta_data, + self.k_cache, + self.v_cache, + ) - # sampler - if self.inference_config.pad_input: - logits = logits[:, -1, :] - next_tokens = search_tokens( - generation_config_param, - logits, - input_meta_data.is_prompts, - input_meta_data.batch_token_ids, - ) + if self.rank == 0: + with self.t_sampler: + # sampler + if self.inference_config.pad_input: + logits = logits[:, -1, :] + next_tokens = search_tokens( + generation_config, + logits, + input_meta_data.is_prompts, + input_meta_data.batch_token_ids, + ) - # return the tokens generated to scheduler - return next_tokens.tolist() + # return the tokens generated to scheduler + # only rank 0 need to pass the data back + # to reduce the overhead of rpc param passing + return next_tokens.cpu() def _init_output_tensor(self): alloc_shape = ( @@ -166,6 +184,84 @@ def _init_fd_tensor(self): self.fd_inter_tensor = fd_inter_tensor + def _broadcast_param_to_all_workers( + self, + input_token_ids_param: Optional[List[int]] = None, + input_meta_data_param: Optional[dict] = None, + generation_config_param: Optional[dict] = None, + ): + if self.rank == 0: + input_meta_data = InputMetaData.from_rpc_param(input_meta_data_param) + input_meta_data.fd_inter_tensor = self.fd_inter_tensor + input_token_ids = torch.tensor(input_token_ids_param, dtype=torch.int, device=self.device) + generation_config = generation_config_param + + if dist.get_world_size() > 1: + broadcast_list = {} + for k, v in input_meta_data_param.items(): + if not isinstance(v, List): + broadcast_list[k] = v + + # Pass the tensor shape and type in advance for + # other workers to prepare the empty tensor and async transport tensors + broadcast_list["block_tables"] = ( + input_meta_data.block_tables.size(), + input_meta_data.block_tables.dtype, + ) + broadcast_list["sequence_lengths"] = ( + input_meta_data.sequence_lengths.size(), + input_meta_data.sequence_lengths.dtype, + ) + broadcast_list["input_token_ids"] = (input_token_ids.size(), input_token_ids.dtype) + + # Generation Config Param + broadcast_list["generation_config"] = generation_config_param + + # send some meta data and some tensor shape + torch.distributed.broadcast_object_list([broadcast_list], src=self.rank) + + # send the real tensor + torch.distributed.broadcast(input_meta_data.block_tables, 
src=self.rank) + torch.distributed.broadcast(input_meta_data.sequence_lengths, src=self.rank) + torch.distributed.broadcast(input_token_ids, src=self.rank) + + else: + assert input_meta_data_param is None, "Input Must Be None" + + # recv the meta data + recv_list = [None] + torch.distributed.broadcast_object_list(recv_list, src=0) + input_meta_data_param = recv_list[0] + + generation_config = input_meta_data_param["generation_config"] + + blocktable_shape, blocktable_type = input_meta_data_param["block_tables"] + blocktables = torch.empty(blocktable_shape, dtype=blocktable_type, device=self.device) + sequence_lengths_shape, sequence_lengths_type = input_meta_data_param["sequence_lengths"] + sequence_lengths = torch.empty(sequence_lengths_shape, dtype=sequence_lengths_type, device=self.device) + input_token_ids_shape, input_token_ids_type = input_meta_data_param["input_token_ids"] + input_token_ids = torch.empty(input_token_ids_shape, dtype=input_token_ids_type, device=self.device) + + # recv the real tensor + async1 = torch.distributed.broadcast(blocktables, src=0, async_op=True) + async2 = torch.distributed.broadcast(sequence_lengths, src=0, async_op=True) + async3 = torch.distributed.broadcast(input_token_ids, src=0, async_op=True) + + input_meta_data_param["sequence_lengths"] = sequence_lengths + input_meta_data_param["blocktables"] = blocktables + + input_meta_data = InputMetaData.from_rpc_param(input_meta_data_param) + input_meta_data.fd_inter_tensor = self.fd_inter_tensor + + async1.wait() + async2.wait() + async3.wait() + + input_meta_data.block_tables = blocktables + input_meta_data.sequence_lengths = sequence_lengths + + return input_token_ids, input_meta_data, generation_config + def _init_model(self, model_or_path: Union[nn.Module, str], model_policy: Policy = None): """ Shard model or/and Load weight @@ -304,3 +400,11 @@ def exposed_compute_only_for_test(self): logger.info(f"Worker rank {dist_rank}: Sum after all_reduce: {data.item()}") return data.item() + + def __del__(self): + """ + profiling only, remove later + """ + del self.t_prepare + del self.t_exe + del self.t_sampler diff --git a/colossalai/inference/utils.py b/colossalai/inference/utils.py index 072bedec3587..aaf80181dad7 100644 --- a/colossalai/inference/utils.py +++ b/colossalai/inference/utils.py @@ -113,3 +113,70 @@ def find_available_ports(num: int): print(f"An OS error occurred: {e}") raise RuntimeError("Error finding available ports") return free_ports + + +""" +below just for profiling temporarily, will removed before merge +""" +import time +from contextlib import asynccontextmanager, contextmanager + + +@contextmanager +def timer(name=""): + # (@lry89757) will remove later + start_time = time.time() + try: + yield + finally: + end_time = time.time() + elapsed_time = end_time - start_time + print(f"{name} took {elapsed_time:.6f} seconds") + + +class Timer: + # (@lry89757) will remove later + def __init__(self, name=""): + print(f"init timer, {name}") + self.name = name + self.times = [] + + def __enter__(self): + self.start_time = time.time() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + end_time = time.time() + elapsed_time = end_time - self.start_time + self.times.append(elapsed_time) + print(f"{self.name} took {elapsed_time:.6f} seconds") + self.print_info() + + def print_info(self): + average_prefill_time = self.times[0] + print(f"{self.name} prefill average time: {average_prefill_time:.6f} seconds") + if len(self.times) > 1: + average_decoding_time = sum(self.times[1:]) / 
len(self.times[1:]) + print(f"{self.name} decoding average time: {average_decoding_time:.6f} seconds") + + def __del__(self): + if self.times: + average_prefill_time = self.times[0] + print(f"{self.name} prefill average time: {average_prefill_time:.6f} seconds") + if len(self.times) > 1: + average_decoding_time = sum(self.times[1:]) / len(self.times[1:]) + print(f"{self.name} decoding average time: {average_decoding_time:.6f} seconds") + else: + print(f"{self.name} no timings recorded") + + +@asynccontextmanager +async def async_timer(name=""): + # (@lry89757) will remove later + start_time = time.time() + try: + yield + finally: + end_time = time.time() + elapsed_time = end_time - start_time + print(f"{name} took {elapsed_time:.6f} seconds") From 509f3a62ab9db5b0b31e2fa1b42081f198740f75 Mon Sep 17 00:00:00 2001 From: Runyu Lu Date: Wed, 5 Jun 2024 04:02:22 +0000 Subject: [PATCH 2/3] tmp save for profiling --- colossalai/inference/batch_bucket.py | 1 + colossalai/inference/config.py | 27 +++--- colossalai/inference/core/engine.py | 85 +++++++++++++++---- colossalai/inference/core/rpc_engine.py | 67 +++++++++++++-- colossalai/inference/executor/rpc_worker.py | 93 ++++++++++++++------- colossalai/inference/utils.py | 4 +- 6 files changed, 206 insertions(+), 71 deletions(-) diff --git a/colossalai/inference/batch_bucket.py b/colossalai/inference/batch_bucket.py index a344a9579cde..f9589091af64 100644 --- a/colossalai/inference/batch_bucket.py +++ b/colossalai/inference/batch_bucket.py @@ -526,6 +526,7 @@ def __repr__(self) -> str: class RPCBatchBucket(BatchBucket): def __init__(self, *args, **argv): self.is_rpc = True + self.device = "cpu" super().__init__(*args, **argv) # For compatibility diff --git a/colossalai/inference/config.py b/colossalai/inference/config.py index c5f7b61ce786..5d819a22efb5 100644 --- a/colossalai/inference/config.py +++ b/colossalai/inference/config.py @@ -87,12 +87,14 @@ class InputMetaData(RPC_PARAM): def to_rpc_param(self) -> Dict[str, any]: return { - "block_tables": self.block_tables.tolist() - if isinstance(self.block_tables, torch.Tensor) - else self.block_tables, - "sequence_lengths": self.sequence_lengths.tolist() - if isinstance(self.block_tables, torch.Tensor) - else self.sequence_lengths, + "block_tables": self.block_tables, + # "block_tables": self.block_tables.tolist() + # if isinstance(self.block_tables, torch.Tensor) + # else self.block_tables, + "sequence_lengths": self.sequence_lengths, + # "sequence_lengths": self.sequence_lengths.tolist() + # if isinstance(self.block_tables, torch.Tensor) + # else self.sequence_lengths, "batch_size": self.batch_size, "is_prompts": self.is_prompts, "use_cuda_kernel": self.use_cuda_kernel, @@ -114,17 +116,14 @@ def from_rpc_param(rpc_dict: Dict[str, any]) -> "InputMetaData": from colossalai.accelerator import get_accelerator dtype = getattr(torch, rpc_dict["dtype"]) + device = get_accelerator().get_current_device() return InputMetaData( - block_tables=torch.tensor( - rpc_dict["block_tables"], dtype=torch.int, device=get_accelerator().get_current_device() - ) + block_tables=torch.tensor(rpc_dict["block_tables"], dtype=torch.int, device=device) if isinstance(rpc_dict["block_tables"], list) - else rpc_dict["block_tables"], - sequence_lengths=torch.tensor( - rpc_dict["sequence_lengths"], dtype=torch.int, device=get_accelerator().get_current_device() - ) + else rpc_dict["block_tables"].to(device), + sequence_lengths=torch.tensor(rpc_dict["sequence_lengths"], dtype=torch.int, device=device) if 
isinstance(rpc_dict["sequence_lengths"], list) - else rpc_dict["sequence_lengths"], + else rpc_dict["sequence_lengths"].to(device), batch_size=rpc_dict["batch_size"], is_prompts=rpc_dict["is_prompts"], use_cuda_kernel=rpc_dict["use_cuda_kernel"], diff --git a/colossalai/inference/core/engine.py b/colossalai/inference/core/engine.py index 96c2b15ee16e..6f5f5ba6470d 100644 --- a/colossalai/inference/core/engine.py +++ b/colossalai/inference/core/engine.py @@ -1,4 +1,5 @@ import time +from contextlib import nullcontext from itertools import count from typing import Dict, List, Optional, Tuple, Type, Union @@ -24,7 +25,7 @@ from colossalai.inference.sampler import search_tokens from colossalai.inference.spec import Drafter, GlideInput from colossalai.inference.struct import Sequence -from colossalai.inference.utils import get_model_size +from colossalai.inference.utils import Timer, get_model_size from colossalai.interface import ModelWrapper from colossalai.logging import get_dist_logger from colossalai.pipeline.stage_manager import PipelineStageManager @@ -103,6 +104,30 @@ def __init__( self.use_glide = False self.n_spec_tokens = self.inference_config.max_n_spec_tokens + # profiling only, remove later + self.timing = False + + self.t_prepare = Timer("[Timer] prepare the data 1") if self.timing else nullcontext() + self.t_exe = Timer("[Timer] execute the model forward") if self.timing else nullcontext() + self.t_sampler = Timer("[Timer] sampler time") if self.timing else nullcontext() + + self.profiling = False + self.profiler = ( + torch.profiler.profile( + record_shapes=True, + with_stack=True, + with_modules=True, + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + # schedule=torch.profiler.schedule(wait=0, warmup=N_WARMUP_STEPS, active=1), + # on_trace_ready=torch.profiler.tensorboard_trace_handler(f"./tb_log_{args.batch_size}_" + args.mode), + ) + if self.profiling + else nullcontext() + ) + self._verify_args() def init_model(self, model_or_path: Union[nn.Module, str], model_policy: Union[Policy, Type[Policy]] = None): @@ -517,6 +542,7 @@ def generate( prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None, return_token_ids: bool = False, generation_config: Optional[GenerationConfig] = None, + step_list: Optional[List[int]] = None, ) -> List[str]: """ Executing the inference step. @@ -559,7 +585,11 @@ def generate( output_seqs_list += self.steps_spec_dec() else: while self.request_handler.check_unfinished_seqs(): + a = time.perf_counter() output_seqs_list += self.step() + b = time.perf_counter() + if isinstance(step_list, list): + step_list.append(b - a) output_seqs_list = sorted(output_seqs_list, key=lambda x: int(x.request_id)) @@ -574,6 +604,19 @@ def generate( else: return output_str + def __del__(self): + if self.timing: + del self.t_prepare + del self.t_exe + del self.t_sampler + self.record() + + def record(self): + if self.profiling: + file = "/home/lurunyu/projects/ColossalAI/test_trace_non_rpc.json" + self.profiler.export_chrome_trace(file) + self.logger.info(f"trace has been saved into {file}") + @property def has_prompt_template(self) -> bool: """ """ @@ -741,23 +784,31 @@ def step(self) -> List[str]: List[str]: Decoded finished sequences generated by one step. 
""" - batch = self.request_handler.schedule() + with self.profiler: + with self.t_prepare: + batch = self.request_handler.schedule() - input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch) + input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch) - if input_meta_data.use_cuda_graph: - model_executable = self.graph_runners[input_meta_data.batch_size] - else: - model_executable = self.model - - # TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported. - logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache) - if self.inference_config.pad_input: - logits = logits[:, -1, :] - next_tokens = search_tokens( - self.generation_config, logits, input_meta_data.is_prompts, batch_token_ids=input_meta_data.batch_token_ids - ) - self.request_handler.append_next_tokens(next_tokens) - finished_sequences = self.request_handler.update() + if input_meta_data.use_cuda_graph: + model_executable = self.graph_runners[input_meta_data.batch_size] + else: + model_executable = self.model + + with self.t_exe: + # TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported. + logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache) + + with self.t_sampler: + if self.inference_config.pad_input: + logits = logits[:, -1, :] + next_tokens = search_tokens( + self.generation_config, + logits, + input_meta_data.is_prompts, + batch_token_ids=input_meta_data.batch_token_ids, + ) + self.request_handler.append_next_tokens(next_tokens) + finished_sequences = self.request_handler.update() return finished_sequences diff --git a/colossalai/inference/core/rpc_engine.py b/colossalai/inference/core/rpc_engine.py index 0b4cfe37ab9f..36a1ee394946 100644 --- a/colossalai/inference/core/rpc_engine.py +++ b/colossalai/inference/core/rpc_engine.py @@ -1,4 +1,7 @@ import asyncio +import concurrent +import pickle +from contextlib import nullcontext from itertools import count from time import sleep from typing import List, Tuple, Union @@ -14,7 +17,7 @@ from colossalai.inference.batch_bucket import RPCBatchBucket from colossalai.inference.config import InferenceConfig, InputMetaData from colossalai.inference.executor.rpc_worker import rpcWorkerService -from colossalai.inference.utils import find_available_ports +from colossalai.inference.utils import Timer, find_available_ports from colossalai.logging import get_dist_logger from colossalai.shardformer.policies.base_policy import Policy @@ -119,8 +122,21 @@ def __init__( self.counter = count() self._verify_args() + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) + + self.timer = False + self.t_prepare = Timer("[Timer] prepare the data 2") if self.timer else nullcontext() + self.t_exe = Timer("[Timer] execute rpc worker") if self.timer else nullcontext() + # self.t_sampler = Timer("[Timer] sampler time") + self.logger.info("engine init over ") + def __del__(self): + if self.timer: + del self.t_prepare + del self.t_exe + def _verify_args(self) -> None: """Verify the input args""" if not isinstance(self.inference_config, InferenceConfig): @@ -268,7 +284,7 @@ async def async_parallel_forward(self, async_f, *args, **kwargs): assert async_res.ready return async_res.value - async def step_(self, input_token_ids, input_meta_data: InputMetaData): + async def step_async(self, input_token_ids, input_meta_data: InputMetaData): assert len(self.workers) == 
self.tp_size, "init workers first" init_tasks = [] @@ -277,9 +293,9 @@ async def step_(self, input_token_ids, input_meta_data: InputMetaData): init_tasks.append( self.async_parallel_forward( async_forward, - input_token_ids, - input_meta_data.to_rpc_param(), - self.generation_config_dict, + pickle.dumps(input_token_ids), + pickle.dumps(input_meta_data.to_rpc_param()), + pickle.dumps(self.generation_config_dict), ) ) else: @@ -296,12 +312,45 @@ async def step_(self, input_token_ids, input_meta_data: InputMetaData): return ret[0] + def step_(self, input_token_ids, input_meta_data: InputMetaData): + assert len(self.workers) == self.tp_size, "init workers first" + init_tasks = [] + with concurrent.futures.ThreadPoolExecutor(max_workers=len(self.workers)) as executor: + for rank, worker in enumerate(self.workers): + if rank == 0: + init_tasks.append( + executor.submit( + worker.execute_model_forward, + pickle.dumps(input_token_ids), + pickle.dumps(input_meta_data.to_rpc_param()), + pickle.dumps(self.generation_config_dict), + ) + ) + else: + init_tasks.append( + executor.submit( + worker.execute_model_forward, + None, + None, + None, + ) + ) + + concurrent.futures.wait(init_tasks) + results = [future.result() for future in init_tasks] + return results[0] + def step(self) -> List[str]: - batch = self.request_handler.schedule() + with self.t_prepare: + batch = self.request_handler.schedule() + + input_token_ids, input_meta_data = self.prepare_input(batch) - input_token_ids, input_meta_data = self.prepare_input(batch) - # TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported. - next_tokens = asyncio.run(self.step_(input_token_ids, input_meta_data)) + with self.t_exe: + # TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported. 
+ next_tokens = self.loop.run_until_complete(self.step_async(input_token_ids, input_meta_data)) + # with self.t_exe: + # next_tokens = self.step_(input_token_ids, input_meta_data) # update the request_handler self.request_handler.append_next_tokens(next_tokens) diff --git a/colossalai/inference/executor/rpc_worker.py b/colossalai/inference/executor/rpc_worker.py index 4ac6026da434..a85801726376 100644 --- a/colossalai/inference/executor/rpc_worker.py +++ b/colossalai/inference/executor/rpc_worker.py @@ -1,3 +1,5 @@ +import pickle +from contextlib import nullcontext from typing import List, Optional, Tuple, Union import rpyc @@ -54,9 +56,29 @@ def exposed_init_dist_env(self, rank, world_size, master_address, master_port): self.rank = rank # profiling only, remove later - self.t_prepare = Timer("[Timer] prepare the data") - self.t_exe = Timer("[Timer] execute the model forward") - self.t_sampler = Timer("[Timer] sampler time") + self.timing = False + + self.t_prepare = Timer("[Timer] prepare the data 1") if self.timing else nullcontext() + self.t_exe = Timer("[Timer] execute the model forward") if self.timing else nullcontext() + self.t_sampler = Timer("[Timer] sampler time") if self.timing else nullcontext() + + self.profiling = False + self.profiler = ( + torch.profiler.profile( + record_shapes=True, + with_stack=True, + with_modules=True, + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + # schedule=torch.profiler.schedule(wait=0, repeat=1, active=1), + # on_trace_ready=torch.profiler.tensorboard_trace_handler(f"./tb_log_{args.batch_size}_" + args.mode), + ) + if self.profiling + else nullcontext() + ) + logger.info(f"init process group done for rank {rank}") def exposed_init_model( @@ -109,28 +131,34 @@ def exposed_execute_model_forward( input_meta_data_param: Optional[dict] = None, generation_config_param: Optional[dict] = None, ): - # prepare the data for model forward - with self.t_prepare: - input_token_ids, input_meta_data, generation_config = self._broadcast_param_to_all_workers( - input_token_ids_param=input_token_ids_param, - input_meta_data_param=input_meta_data_param, - generation_config_param=generation_config_param, - ) + with self.profiler: + # prepare the data for model forward + with self.t_prepare: + input_token_ids, input_meta_data, generation_config = self._broadcast_param_to_all_workers( + input_token_ids_param=input_token_ids_param, + input_meta_data_param=input_meta_data_param, + generation_config_param=generation_config_param, + ) - if input_meta_data.is_prompts: - n_tokens = input_meta_data.sequence_lengths.sum().item() - else: - n_tokens = input_meta_data.batch_size - - # execute the model - with self.t_exe: - logits = self.model( - input_token_ids, - self.output_tensor[:n_tokens], - input_meta_data, - self.k_cache, - self.v_cache, - ) + if input_meta_data.is_prompts: + n_tokens = input_meta_data.sequence_lengths.sum().item() + else: + n_tokens = input_meta_data.batch_size + + # execute the model + with self.t_exe: + logits = self.model( + input_token_ids, + self.output_tensor[:n_tokens], + input_meta_data, + self.k_cache, + self.v_cache, + ) + + if self.profiling: + self.profiler.step() + + self.record() if self.rank == 0: with self.t_sampler: @@ -191,6 +219,10 @@ def _broadcast_param_to_all_workers( generation_config_param: Optional[dict] = None, ): if self.rank == 0: + input_token_ids_param = pickle.loads(input_token_ids_param) + input_meta_data_param = pickle.loads(input_meta_data_param) + 
generation_config_param = pickle.loads(generation_config_param) + input_meta_data = InputMetaData.from_rpc_param(input_meta_data_param) input_meta_data.fd_inter_tensor = self.fd_inter_tensor input_token_ids = torch.tensor(input_token_ids_param, dtype=torch.int, device=self.device) @@ -199,7 +231,7 @@ def _broadcast_param_to_all_workers( if dist.get_world_size() > 1: broadcast_list = {} for k, v in input_meta_data_param.items(): - if not isinstance(v, List): + if not isinstance(v, torch.Tensor): broadcast_list[k] = v # Pass the tensor shape and type in advance for @@ -248,7 +280,7 @@ def _broadcast_param_to_all_workers( async3 = torch.distributed.broadcast(input_token_ids, src=0, async_op=True) input_meta_data_param["sequence_lengths"] = sequence_lengths - input_meta_data_param["blocktables"] = blocktables + input_meta_data_param["block_tables"] = blocktables input_meta_data = InputMetaData.from_rpc_param(input_meta_data_param) input_meta_data.fd_inter_tensor = self.fd_inter_tensor @@ -257,9 +289,6 @@ def _broadcast_param_to_all_workers( async2.wait() async3.wait() - input_meta_data.block_tables = blocktables - input_meta_data.sequence_lengths = sequence_lengths - return input_token_ids, input_meta_data, generation_config def _init_model(self, model_or_path: Union[nn.Module, str], model_policy: Policy = None): @@ -408,3 +437,9 @@ def __del__(self): del self.t_prepare del self.t_exe del self.t_sampler + + def record(self): + if self.profiling: + file = "/home/lurunyu/projects/ColossalAI/test_trace_rpc.json" + self.profiler.export_chrome_trace(file) + logger.info(f"trace has been saved into {file}") diff --git a/colossalai/inference/utils.py b/colossalai/inference/utils.py index aaf80181dad7..a241fd98c49e 100644 --- a/colossalai/inference/utils.py +++ b/colossalai/inference/utils.py @@ -149,8 +149,8 @@ def __exit__(self, exc_type, exc_val, exc_tb): end_time = time.time() elapsed_time = end_time - self.start_time self.times.append(elapsed_time) - print(f"{self.name} took {elapsed_time:.6f} seconds") - self.print_info() + # print(f"{self.name} took {elapsed_time:.6f} seconds") + # self.print_info() def print_info(self): average_prefill_time = self.times[0] From 01ca9b813308e22d27ba87323f744a80c5668575 Mon Sep 17 00:00:00 2001 From: Runyu Lu Date: Tue, 30 Jul 2024 07:48:39 +0000 Subject: [PATCH 3/3] remove timer --- colossalai/inference/core/rpc_engine.py | 46 +------------- colossalai/inference/executor/rpc_worker.py | 60 +++++++----------- colossalai/inference/utils.py | 67 --------------------- 3 files changed, 23 insertions(+), 150 deletions(-) diff --git a/colossalai/inference/core/rpc_engine.py b/colossalai/inference/core/rpc_engine.py index 1fb27e6c85cd..4677418a350e 100644 --- a/colossalai/inference/core/rpc_engine.py +++ b/colossalai/inference/core/rpc_engine.py @@ -1,7 +1,5 @@ import asyncio -import concurrent import pickle -from contextlib import nullcontext from itertools import count from time import sleep from typing import List, Tuple, Union @@ -17,7 +15,7 @@ from colossalai.inference.batch_bucket import RPCBatchBucket from colossalai.inference.config import InferenceConfig, InputMetaData from colossalai.inference.executor.rpc_worker import rpcWorkerService -from colossalai.inference.utils import Timer, find_available_ports +from colossalai.inference.utils import find_available_ports from colossalai.logging import get_dist_logger from colossalai.shardformer.policies.base_policy import Policy @@ -126,18 +124,8 @@ def __init__( self.loop = asyncio.new_event_loop() 
asyncio.set_event_loop(self.loop) - self.timer = False - self.t_prepare = Timer("[Timer] prepare the data 2") if self.timer else nullcontext() - self.t_exe = Timer("[Timer] execute rpc worker") if self.timer else nullcontext() - # self.t_sampler = Timer("[Timer] sampler time") - self.logger.info("engine init over ") - def __del__(self): - if self.timer: - del self.t_prepare - del self.t_exe - def _verify_args(self) -> None: """Verify the input args""" if not isinstance(self.inference_config, InferenceConfig): @@ -313,34 +301,6 @@ async def step_async(self, input_token_ids, input_meta_data: InputMetaData): return ret[0] - def step_(self, input_token_ids, input_meta_data: InputMetaData): - assert len(self.workers) == self.tp_size, "init workers first" - init_tasks = [] - with concurrent.futures.ThreadPoolExecutor(max_workers=len(self.workers)) as executor: - for rank, worker in enumerate(self.workers): - if rank == 0: - init_tasks.append( - executor.submit( - worker.execute_model_forward, - pickle.dumps(input_token_ids), - pickle.dumps(input_meta_data.to_rpc_param()), - pickle.dumps(self.generation_config_dict), - ) - ) - else: - init_tasks.append( - executor.submit( - worker.execute_model_forward, - None, - None, - None, - ) - ) - - concurrent.futures.wait(init_tasks) - results = [future.result() for future in init_tasks] - return results[0] - def step(self) -> List[str]: with self.t_prepare: batch = self.request_handler.schedule() @@ -350,8 +310,6 @@ def step(self) -> List[str]: with self.t_exe: # TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported. next_tokens = self.loop.run_until_complete(self.step_async(input_token_ids, input_meta_data)) - # with self.t_exe: - # next_tokens = self.step_(input_token_ids, input_meta_data) # update the request_handler self.request_handler.append_next_tokens(next_tokens) @@ -360,7 +318,7 @@ def step(self) -> List[str]: def kill_workers(self): """ - I don't find a good way to implicit invoke self.kill_workers + NOTE(@lry89757) Don't find a good way to implicit invoke self.kill_workers """ assert len(self.workers) != 0 for proc in self.worker_processes: diff --git a/colossalai/inference/executor/rpc_worker.py b/colossalai/inference/executor/rpc_worker.py index 4e84ec8f0bcd..85f5758aebb1 100644 --- a/colossalai/inference/executor/rpc_worker.py +++ b/colossalai/inference/executor/rpc_worker.py @@ -55,13 +55,6 @@ def exposed_init_dist_env(self, rank, world_size, master_address, master_port): colossalai.launch(rank=rank, world_size=world_size, port=master_port, host=master_address) self.rank = rank - # profiling only, remove later - self.timing = False - - self.t_prepare = Timer("[Timer] prepare the data 1") if self.timing else nullcontext() - self.t_exe = Timer("[Timer] execute the model forward") if self.timing else nullcontext() - self.t_sampler = Timer("[Timer] sampler time") if self.timing else nullcontext() - self.profiling = False self.profiler = ( torch.profiler.profile( @@ -133,12 +126,11 @@ def exposed_execute_model_forward( ): with self.profiler: # prepare the data for model forward - with self.t_prepare: - input_token_ids, input_meta_data, generation_config = self._broadcast_param_to_all_workers( - input_token_ids_param=input_token_ids_param, - input_meta_data_param=input_meta_data_param, - generation_config_param=generation_config_param, - ) + input_token_ids, input_meta_data, generation_config = self._broadcast_param_to_all_workers( + input_token_ids_param=input_token_ids_param, + 
input_meta_data_param=input_meta_data_param, + generation_config_param=generation_config_param, + ) if input_meta_data.is_prompts: n_tokens = input_meta_data.sequence_lengths.sum().item() @@ -146,14 +138,13 @@ def exposed_execute_model_forward( n_tokens = input_meta_data.batch_size # execute the model - with self.t_exe: - logits = self.model( - input_token_ids, - self.output_tensor[:n_tokens], - input_meta_data, - self.k_cache, - self.v_cache, - ) + logits = self.model( + input_token_ids, + self.output_tensor[:n_tokens], + input_meta_data, + self.k_cache, + self.v_cache, + ) if self.profiling: self.profiler.step() @@ -161,16 +152,15 @@ def exposed_execute_model_forward( self.record() if self.rank == 0: - with self.t_sampler: - # sampler - if self.inference_config.pad_input: - logits = logits[:, -1, :] - next_tokens = search_tokens( - generation_config, - logits, - input_meta_data.is_prompts, - input_meta_data.batch_token_ids, - ) + # sampler + if self.inference_config.pad_input: + logits = logits[:, -1, :] + next_tokens = search_tokens( + generation_config, + logits, + input_meta_data.is_prompts, + input_meta_data.batch_token_ids, + ) # return the tokens generated to scheduler # only rank 0 need to pass the data back @@ -432,14 +422,6 @@ def exposed_compute_only_for_test(self): return data.item() - def __del__(self): - """ - profiling only, remove later - """ - del self.t_prepare - del self.t_exe - del self.t_sampler - def record(self): if self.profiling: file = "/home/lurunyu/projects/ColossalAI/test_trace_rpc.json" diff --git a/colossalai/inference/utils.py b/colossalai/inference/utils.py index c7ff9a6a0437..d0851e362318 100644 --- a/colossalai/inference/utils.py +++ b/colossalai/inference/utils.py @@ -194,70 +194,3 @@ def get_model_type(model_or_path: Union[nn.Module, str, DiffusionPipeline]): """ else: return ModelType.UNKNOWN - - -""" -below just for profiling temporarily, will removed before merge -""" -import time -from contextlib import asynccontextmanager, contextmanager - - -@contextmanager -def timer(name=""): - # (@lry89757) will remove later - start_time = time.time() - try: - yield - finally: - end_time = time.time() - elapsed_time = end_time - start_time - print(f"{name} took {elapsed_time:.6f} seconds") - - -class Timer: - # (@lry89757) will remove later - def __init__(self, name=""): - print(f"init timer, {name}") - self.name = name - self.times = [] - - def __enter__(self): - self.start_time = time.time() - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - end_time = time.time() - elapsed_time = end_time - self.start_time - self.times.append(elapsed_time) - # print(f"{self.name} took {elapsed_time:.6f} seconds") - # self.print_info() - - def print_info(self): - average_prefill_time = self.times[0] - print(f"{self.name} prefill average time: {average_prefill_time:.6f} seconds") - if len(self.times) > 1: - average_decoding_time = sum(self.times[1:]) / len(self.times[1:]) - print(f"{self.name} decoding average time: {average_decoding_time:.6f} seconds") - - def __del__(self): - if self.times: - average_prefill_time = self.times[0] - print(f"{self.name} prefill average time: {average_prefill_time:.6f} seconds") - if len(self.times) > 1: - average_decoding_time = sum(self.times[1:]) / len(self.times[1:]) - print(f"{self.name} decoding average time: {average_decoding_time:.6f} seconds") - else: - print(f"{self.name} no timings recorded") - - -@asynccontextmanager -async def async_timer(name=""): - # (@lry89757) will remove later - start_time = time.time() - try: 
- yield - finally: - end_time = time.time() - elapsed_time = end_time - start_time - print(f"{name} took {elapsed_time:.6f} seconds")
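
Note on the broadcast scheme this series introduces: the main runtime change is the two-phase input broadcast in `_broadcast_param_to_all_workers`. Rank 0 receives the RPC parameters, sends a small metadata dict (scalars plus each tensor's shape and dtype) via `torch.distributed.broadcast_object_list`, and then broadcasts the tensors themselves; the other ranks allocate empty buffers from the metadata and receive them with `async_op=True`, so the bulk data moves over the process group instead of the RPC channel. The standalone sketch below illustrates the same pattern outside the patch; the function and variable names, the gloo backend, and the torchrun launch are assumptions made for the example, not part of ColossalAI.

import torch
import torch.distributed as dist


def broadcast_inputs(input_token_ids=None, block_tables=None, device="cpu"):
    """Two-phase broadcast: object metadata first, then the raw tensors."""
    if dist.get_rank() == 0:
        # Rank 0 owns the real inputs; describe each tensor before sending it.
        meta = {
            "input_token_ids": (tuple(input_token_ids.size()), input_token_ids.dtype),
            "block_tables": (tuple(block_tables.size()), block_tables.dtype),
        }
        dist.broadcast_object_list([meta], src=0)
        dist.broadcast(input_token_ids, src=0)
        dist.broadcast(block_tables, src=0)
        return input_token_ids, block_tables

    # Other ranks learn shapes/dtypes first, then allocate and receive.
    recv = [None]
    dist.broadcast_object_list(recv, src=0)
    meta = recv[0]
    input_token_ids = torch.empty(meta["input_token_ids"][0], dtype=meta["input_token_ids"][1], device=device)
    block_tables = torch.empty(meta["block_tables"][0], dtype=meta["block_tables"][1], device=device)
    # Overlap the two receives, mirroring the async_op=True usage in rpc_worker.py.
    work_ids = dist.broadcast(input_token_ids, src=0, async_op=True)
    work_bt = dist.broadcast(block_tables, src=0, async_op=True)
    work_ids.wait()
    work_bt.wait()
    return input_token_ids, block_tables


if __name__ == "__main__":
    # gloo keeps the sketch CPU-only; the real worker runs on NCCL/GPU.
    dist.init_process_group(backend="gloo")
    if dist.get_rank() == 0:
        ids = torch.arange(16, dtype=torch.int)
        tables = torch.zeros((4, 8), dtype=torch.int)
        broadcast_inputs(ids, tables)
    else:
        ids, tables = broadcast_inputs()
        print(f"rank {dist.get_rank()} received ids {tuple(ids.shape)}, block_tables {tuple(tables.shape)}")
    dist.destroy_process_group()

Run with `torchrun --nproc_per_node=2 sketch.py` (file name hypothetical). The design point is that the pickled object list stays tiny and only rank 0 pays the RPC deserialization cost, while the tensor payload travels once over the collective backend.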