Commit
Update model support (#429)
Co-authored-by: Logan Adams <[email protected]>
Co-authored-by: Ammar Ahmad Awan <[email protected]>

4 people authored Jul 25, 2024
1 parent ff5e2fc commit bdf6bf2
Showing 3 changed files with 36 additions and 48 deletions.
README.md (18 changes: 9 additions & 9 deletions)
@@ -85,18 +85,18 @@ Under-the-hood MII is powered by [DeepSpeed-Inference](https://github.com/micros

# Supported Models

-MII currently supports over 20,000 models across eight popular model architectures. We plan to add additional models in the near term, if there are specific model architectures you would like supported please [file an issue](https://github.com/microsoft/DeepSpeed-MII/issues) and let us know. All current models leverage Hugging Face in our backend to provide both the model weights and the model's corresponding tokenizer. For our current release we support the following model architectures:
+MII currently supports over 37,000 models across eight popular model architectures. We plan to add additional models in the near term, if there are specific model architectures you would like supported please [file an issue](https://github.com/microsoft/DeepSpeed-MII/issues) and let us know. All current models leverage Hugging Face in our backend to provide both the model weights and the model's corresponding tokenizer. For our current release we support the following model architectures:

model family | size range | ~model count
------ | ------ | ------
-[falcon](https://huggingface.co/models?other=falcon) | 7B - 180B | 300
-[llama](https://huggingface.co/models?other=llama) | 7B - 65B | 19,000
-[llama-2](https://huggingface.co/models?other=llama-2) | 7B - 70B | 900
-[mistral](https://huggingface.co/models?other=mistral) | 7B | 6,000
-[mixtral (MoE)](https://huggingface.co/models?other=mixtral) | 8x7B | 1,100
-[opt](https://huggingface.co/models?other=opt) | 0.1B - 66B | 1,300
-[phi-2](https://huggingface.co/models?other=phi) | 2.7B | 200
-[qwen](https://huggingface.co/models?other=qwen) | 7B - 72B | 200
+[falcon](https://huggingface.co/models?other=falcon) | 7B - 180B | 500
+[llama](https://huggingface.co/models?other=llama) | 7B - 65B | 52,000
+[llama-2](https://huggingface.co/models?other=llama-2) | 7B - 70B | 1,200
+[mistral](https://huggingface.co/models?other=mistral) | 7B | 23,000
+[mixtral (MoE)](https://huggingface.co/models?other=mixtral) | 8x7B | 2,900
+[opt](https://huggingface.co/models?other=opt) | 0.1B - 66B | 2,100
+[phi-2](https://huggingface.co/models?other=phi) | 2.7B | 1,500
+[qwen](https://huggingface.co/models?other=qwen) | 7B - 72B | 500

## MII Legacy Model Support

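As context for the updated table: any of the listed checkpoints can be loaded through MII's non-persistent pipeline API. A minimal sketch, assuming `deepspeed-mii` is installed and the chosen checkpoint (here an example from the mistral family) fits in local GPU memory:

```python
# Minimal sketch: run one of the supported model architectures with MII.
# Assumes deepspeed-mii is installed and enough GPU memory is available.
import mii

pipe = mii.pipeline("mistralai/Mistral-7B-v0.1")
response = pipe(["DeepSpeed is", "Seattle is"], max_new_tokens=64)
print(response)
```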
requirements/requirements-dev.txt (3 changes: 3 additions & 0 deletions)
@@ -1,5 +1,8 @@
clang-format==16.0.2
+einops
pre-commit>=2.20.0
pytest
pytest-forked
sentencepiece
+tiktoken
+transformers-stream-generator
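The three added packages line up with the newly tested architectures: einops is imported by the falcon and phi model code, while tiktoken and transformers-stream-generator appear to be required by Qwen's remote code. A quick, hedged check of the Qwen dependency (assumes network access; note that trust_remote_code executes code from the model repository):

```python
# Hypothetical sanity check: Qwen/Qwen-7B-Chat ships a custom tokenizer that
# imports tiktoken, so this load raises an ImportError if the new dev
# dependencies are missing.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen-7B-Chat", trust_remote_code=True)
print(tok("DeepSpeed-MII supports Qwen")["input_ids"][:8])
```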
tests/test_model_support.py (63 changes: 24 additions & 39 deletions)
@@ -11,25 +11,16 @@
    CheckpointEngineBase,
    HuggingFaceCheckpointEngine,
)
-from transformers import AutoConfig, AutoModel, GenerationConfig
+from transformers import AutoConfig, AutoModelForCausalLM, GenerationConfig
from typing import Iterable, Tuple


-class RandomWeightsCheckpointEngine(CheckpointEngineBase):
-
-    # When using AutoModel.from_config() to load the model, the layer names are
-    # often missing a prefix. We default to adding "model." as the prefix, but
-    # others can be specified here.
-    layer_prefix_map = {"falcon": "transformer."}
-
-    # When using AutoModel.from_config() to load the model, the lm_head layer is
-    # not generated. We default to populating this with the
-    # "embed_tokens.weight" layer, but others can be specified here.
-    lm_head_layer_map = {"falcon": "word_embeddings.weight"}
-
+class ZeroWeightsCheckpointEngine(CheckpointEngineBase):
+    """ Generates weights with all zeros for a given model for testing purposes. """
    def __init__(self, model_name_or_path: str, auth_token: str = None) -> None:
        self.model_name_or_path = model_name_or_path
-        self.model_config = AutoConfig.from_pretrained(self.model_name_or_path)
+        self.model_config = AutoConfig.from_pretrained(self.model_name_or_path,
+                                                       trust_remote_code=True)
        if hasattr(self.model_config, "max_position_embeddings"):
            self.model_config.max_seq_length = self.model_config.max_position_embeddings
        else:
@@ -40,37 +31,21 @@ def __init__(self, model_name_or_path: str, auth_token: str = None) -> None:
        except OSError:
            self.model_config.max_seq_length = 2048

-    def _get_layer_prefix(self) -> str:
-        for model_type, prefix in self.layer_prefix_map.items():
-            if model_type in self.model_name_or_path.lower():
-                return prefix
-        return "model."
-
-    def _get_lm_head_layer(self) -> str:
-        for model_type, layer in self.lm_head_layer_map.items():
-            if model_type in self.model_name_or_path.lower():
-                return layer
-        return "embed_tokens.weight"
-
    def parameters(self) -> Iterable[Tuple[str, torch.Tensor]]:
-        layer_prefix = self._get_layer_prefix()
-        lm_head_layer = self._get_lm_head_layer()
-
        # Load with meta device is faster
        with deepspeed.OnDevice(dtype=torch.float16, device="meta"):
-            model = AutoModel.from_config(self.model_config)
+            model = AutoModelForCausalLM.from_config(self.model_config,
+                                                     trust_remote_code=True)

        for param_name, param in model.state_dict().items():
-            yield layer_prefix + param_name, torch.zeros(param.shape)
-            if param_name == lm_head_layer:
-                yield "lm_head.weight", torch.zeros(param.shape)
+            yield param_name, torch.zeros(param.shape)


@pytest.fixture(scope="module", autouse=True)
def inject_checkpoint_engine():
    # Inject the random weights checkpoint engine
    deepspeed.inference.v2.engine_factory.HuggingFaceCheckpointEngine = (
-        RandomWeightsCheckpointEngine)
+        ZeroWeightsCheckpointEngine)
    yield None
    # Restore the original checkpoint engine
    deepspeed.inference.v2.engine_factory.HuggingFaceCheckpointEngine = (
@@ -81,16 +56,26 @@ def inject_checkpoint_engine():
"model_name",
[
"tiiuae/falcon-7b",
"huggyllama/llama-7b",
"NousResearch/Llama-2-7b-hf",
"NousResearch/Hermes-2-Pro-Mistral-7B",
"cloudyu/Mixtral_11Bx2_MoE_19B",
"facebook/opt-125m",
"microsoft/phi-2",
"Qwen/Qwen-7B-Chat",
"Qwen/Qwen1.5-0.5B",
],
ids=[
"falcon",
"llama",
"llama-2",
"mistral",
"mixtral",
"opt",
"phi-2",
"qwen",
"qwen-2"
],
ids=["falcon",
"llama",
"mistral",
"mixtral",
"opt"],
)
def test_model(pipeline, query):
    outputs = pipeline(query, max_new_tokens=16)
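The ZeroWeightsCheckpointEngine above lets the test suite cover all eight architectures without downloading any checkpoints: the model skeleton is built on the meta device (shapes only, no storage), and a zero tensor is streamed for every parameter name; the module-scoped fixture then swaps this engine in for HuggingFaceCheckpointEngine so every pipeline in the file builds against synthetic weights. A standalone sketch of the same pattern, using plain PyTorch's device context in place of deepspeed.OnDevice (the model name is an arbitrary small example):

```python
# Standalone sketch of the zero-weights pattern: instantiate the architecture
# on the meta device, then materialize zeros per parameter name.
import torch
from transformers import AutoConfig, AutoModelForCausalLM

# Fetching the config requires network access (or a local HF cache).
config = AutoConfig.from_pretrained("facebook/opt-125m", trust_remote_code=True)

# torch.device(...) as a context manager requires PyTorch >= 2.0.
with torch.device("meta"):
    model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)

for name, param in model.state_dict().items():
    # param lives on the meta device; allocate real zeros matching its shape.
    weight = torch.zeros(param.shape, dtype=torch.float16)
    print(name, tuple(weight.shape))
```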
