solid: add option to enable selective_batching
gc-fu committed Dec 21, 2023
1 parent 861c072 commit 29b2e60
Showing 1 changed file with 19 additions and 8 deletions.
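
In short: vLLM selective batching becomes opt-in through the VLLM_ENABLE_SELECTIVE_BATCHING environment variable rather than always being patched in. A minimal sketch of the check the diff below introduces, extracted as a standalone helper (the variable name comes from this commit; the helper name is hypothetical):

import os

def vllm_selective_batching_enabled() -> bool:
    # Mirrors the check added to _optimize_post: the feature is on only when
    # the variable is set and compares equal to "true" case-insensitively;
    # unset, "false", "1", or "yes" all leave it off.
    value = os.getenv("VLLM_ENABLE_SELECTIVE_BATCHING")
    return value is not None and value.lower() == "true"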
python/llm/src/bigdl/llm/transformers/convert.py: 27 changes (19 additions & 8 deletions)
@@ -46,6 +46,7 @@
 from .utils import logger
 from typing import Union
 import numpy as np
+import os
 from bigdl.llm.utils.common import invalidInputError


@@ -375,8 +376,6 @@ def _optimize_post(model, lightweight_bmm=False):
     from bigdl.llm.transformers.models.llama import llama_attention_forward_4_31
     from bigdl.llm.transformers.models.llama import llama_rms_norm_forward
     from bigdl.llm.transformers.models.llama import llama_mlp_forward
-    from bigdl.llm.transformers.models.llama import llama_attention_selective_batching_forward_4_31
-    from bigdl.llm.transformers.models.llama import llama_model_selective_batching_forward_4_31
     from transformers.modeling_utils import PreTrainedModel
 
     # All huggingface format models are inherited from `PreTrainedModel`
@@ -385,24 +384,36 @@ def _optimize_post(model, lightweight_bmm=False):
                     "supported for further optimizations")
         return model
 
+    enable_vllm_selective_batching = os.getenv("VLLM_ENABLE_SELECTIVE_BATCHING")
+    enable_vllm_selective_batching = True if enable_vllm_selective_batching is not None \
+        and enable_vllm_selective_batching.lower()=="true" \
+        else False
     trans_version = transformers.__version__
     if version.parse(trans_version) >= version.parse("4.31.0"):
         convert_forward(
             model,
             transformers.models.llama.modeling_llama.LlamaAttention,
-            llama_attention_selective_batching_forward_4_31,)
+            llama_attention_forward_4_31,)
         convert_forward(
             model,
             transformers.models.llama.modeling_llama.LlamaRMSNorm,
             llama_rms_norm_forward,)
         convert_forward(model,
                         transformers.models.llama.modeling_llama.LlamaMLP,
                         llama_mlp_forward)
-        convert_forward(
-            model,
-            transformers.models.llama.modeling_llama.LlamaModel,
-            llama_model_selective_batching_forward_4_31,
-        )
+        if enable_vllm_selective_batching:
+            from bigdl.llm.transformers.models.llama import llama_attention_selective_batching_forward_4_31
+            from bigdl.llm.transformers.models.llama import llama_model_selective_batching_forward_4_31
+            convert_forward(
+                model,
+                transformers.models.llama.modeling_llama.LlamaModel,
+                llama_model_selective_batching_forward_4_31,
+            )
+            convert_forward(
+                model,
+                transformers.models.llama.modeling_llama.LlamaAttention,
+                llama_attention_selective_batching_forward_4_31,
+            )
     else:
         # todo implement 4.28.0 ~ 4.30.2
         pass
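
For reference, a hedged usage sketch: the variable must be set before the model is loaded, since _optimize_post reads it during conversion, and the selective-batching forwards are only registered on transformers >= 4.31.0. The loading call is bigdl-llm's usual low-bit entry point and is an assumption here, not part of this diff; the model id is hypothetical.

import os

# Export the flag before from_pretrained(): convert.py consults it while
# optimizing the freshly loaded model, so setting it afterwards has no effect.
os.environ["VLLM_ENABLE_SELECTIVE_BATCHING"] = "true"

from bigdl.llm.transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",  # hypothetical model id
    load_in_4bit=True,           # low-bit load that triggers the optimization pass
)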