From e144e265e869407fa78976e5a48f8af7396d1ea6 Mon Sep 17 00:00:00 2001
From: Xuye Qin
Date: Tue, 10 Dec 2024 20:35:12 +0800
Subject: [PATCH] BUG: use stream_generate in MLX (#2635)

---
 xinference/model/llm/llm_family.json        | 96 +++++++++----------
 .../model/llm/llm_family_modelscope.json    | 34 +++----
 xinference/model/llm/mlx/core.py            | 33 +++----
 xinference/model/llm/mlx/tests/test_mlx.py  |  2 +-
 4 files changed, 80 insertions(+), 85 deletions(-)

diff --git a/xinference/model/llm/llm_family.json b/xinference/model/llm/llm_family.json
index 9f53f45ae8..4003d46be5 100644
--- a/xinference/model/llm/llm_family.json
+++ b/xinference/model/llm/llm_family.json
@@ -952,7 +952,7 @@
         "model_format": "mlx",
         "model_size_in_billions": 8,
         "quantizations": [
-          "4-bit"
+          "4bit"
         ],
         "model_id": "mlx-community/Meta-Llama-3-8B-Instruct-4bit"
       },
@@ -960,7 +960,7 @@
         "model_format": "mlx",
         "model_size_in_billions": 8,
         "quantizations": [
-          "8-bit"
+          "8bit"
         ],
         "model_id": "mlx-community/Meta-Llama-3-8B-Instruct-8bit"
       },
@@ -976,7 +976,7 @@
         "model_format": "mlx",
         "model_size_in_billions": 70,
         "quantizations": [
-          "4-bit"
+          "4bit"
         ],
         "model_id": "mlx-community/Meta-Llama-3-70B-Instruct-4bit-mlx"
       },
@@ -984,7 +984,7 @@
         "model_format": "mlx",
         "model_size_in_billions": 70,
         "quantizations": [
-          "8-bit"
+          "8bit"
         ],
         "model_id": "mlx-community/Meta-Llama-3-70B-Instruct-8bit"
       },
@@ -1229,7 +1229,7 @@
         "model_format": "mlx",
         "model_size_in_billions": 8,
         "quantizations": [
-          "4-bit"
+          "4bit"
         ],
         "model_id": "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"
       },
@@ -1237,7 +1237,7 @@
         "model_format": "mlx",
         "model_size_in_billions": 8,
         "quantizations": [
-          "8-bit"
+          "8bit"
         ],
         "model_id": "mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"
       },
@@ -1253,7 +1253,7 @@
         "model_format": "mlx",
         "model_size_in_billions": 70,
         "quantizations": [
-          "4-bit"
+          "4bit"
         ],
         "model_id": "mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"
       },
@@ -1261,7 +1261,7 @@
         "model_format": "mlx",
         "model_size_in_billions": 70,
         "quantizations": [
-          "8-bit"
+          "8bit"
         ],
         "model_id": "mlx-community/Meta-Llama-3.1-70B-Instruct-8bit"
       },
@@ -2199,7 +2199,7 @@
         "model_format": "mlx",
         "model_size_in_billions": "0_5",
         "quantizations": [
-          "4-bit"
+          "4bit"
         ],
         "model_id": "Qwen/Qwen2-0.5B-Instruct-MLX"
       },
@@ -2207,7 +2207,7 @@
         "model_format": "mlx",
         "model_size_in_billions": "1_5",
         "quantizations": [
-          "4-bit"
+          "4bit"
         ],
         "model_id": "Qwen/Qwen2-1.5B-Instruct-MLX"
       },
@@ -2215,7 +2215,7 @@
         "model_format": "mlx",
         "model_size_in_billions": 7,
         "quantizations": [
-          "4-bit"
+          "4bit"
         ],
         "model_id": "Qwen/Qwen2-7B-Instruct-MLX"
       },
@@ -2223,7 +2223,7 @@
         "model_format": "mlx",
         "model_size_in_billions": 72,
         "quantizations": [
-          "4-bit"
+          "4bit"
         ],
         "model_id": "mlx-community/Qwen2-72B-Instruct-4bit"
       },
@@ -3222,7 +3222,7 @@
         "model_format": "mlx",
         "model_size_in_billions": 12,
         "quantizations": [
-          "4-bit"
+          "4bit"
         ],
         "model_id": "mlx-community/Mistral-Nemo-Instruct-2407-4bit"
       },
@@ -3230,7 +3230,7 @@
         "model_format": "mlx",
         "model_size_in_billions": 12,
         "quantizations": [
-          "8-bit"
+          "8bit"
         ],
         "model_id": "mlx-community/Mistral-Nemo-Instruct-2407-8bit"
       }
@@ -3370,7 +3370,7 @@
         "model_format": "mlx",
         "model_size_in_billions": 123,
         "quantizations": [
-          "4-bit"
+          "4bit"
         ],
         "model_id": "mlx-community/Mistral-Large-Instruct-2407-4bit"
       },
@@ -3378,7 +3378,7 @@
         "model_format": "mlx",
         "model_size_in_billions": 123,
         "quantizations": [
-          "8-bit"
+          "8bit"
         ],
         "model_id": "mlx-community/Mistral-Large-Instruct-2407-8bit"
       }
@@ -3436,7 +3436,7 @@
         "model_format": "mlx",
         "model_size_in_billions": 22,
         "quantizations": [
"4-bit" + "4bit" ], "model_id": "mlx-community/Codestral-22B-v0.1-4bit", "model_revision": "544626b38eb1c9524f0fa570ec7b29550c26b78d" @@ -3445,7 +3445,7 @@ "model_format": "mlx", "model_size_in_billions": 22, "quantizations": [ - "8-bit" + "8bit" ], "model_id": "mlx-community/Codestral-22B-v0.1-8bit", "model_revision": "0399a53970663950d57010e61a2796af524a1588" @@ -4170,7 +4170,7 @@ "model_format": "mlx", "model_size_in_billions": 6, "quantizations": [ - "4-bit" + "4bit" ], "model_id": "mlx-community/Yi-1.5-6B-Chat-4bit", "model_revision": "0177c9a12b869d6bc73f772b5a1981a7c966adb6" @@ -4179,7 +4179,7 @@ "model_format": "mlx", "model_size_in_billions": 6, "quantizations": [ - "8-bit" + "8bit" ], "model_id": "mlx-community/Yi-1.5-6B-Chat-8bit", "model_revision": "7756e65d1bf1e2e6e97aef6bc9484307225f536b" @@ -4188,7 +4188,7 @@ "model_format": "mlx", "model_size_in_billions": 9, "quantizations": [ - "4-bit" + "4bit" ], "model_id": "mlx-community/Yi-1.5-9B-Chat-4bit", "model_revision": "e15f886479c44e7d90f0ac13ace69b2319b71c2f" @@ -4197,7 +4197,7 @@ "model_format": "mlx", "model_size_in_billions": 9, "quantizations": [ - "8-bit" + "8bit" ], "model_id": "mlx-community/Yi-1.5-9B-Chat-8bit", "model_revision": "c1f742fcf3683edbe2d2c2fd1ad7ac2bb6c5ca36" @@ -4206,7 +4206,7 @@ "model_format": "mlx", "model_size_in_billions": 34, "quantizations": [ - "4-bit" + "4bit" ], "model_id": "mlx-community/Yi-1.5-34B-Chat-4bit", "model_revision": "945e3b306ef37c46ab444fdc857d1f3ea7247374" @@ -4215,7 +4215,7 @@ "model_format": "mlx", "model_size_in_billions": 34, "quantizations": [ - "8-bit" + "8bit" ], "model_id": "mlx-community/Yi-1.5-34B-Chat-8bit", "model_revision": "3c12761a2c6663f216caab6dff84b0dd29b472ac" @@ -5266,7 +5266,7 @@ "model_format": "mlx", "model_size_in_billions": 7, "quantizations": [ - "4-bit" + "4bit" ], "model_id": "mlx-community/internlm2_5-7b-chat-4bit", "model_revision": "d12097a867721978142a6048399f470a3d18beee" @@ -5275,7 +5275,7 @@ "model_format": "mlx", "model_size_in_billions": 7, "quantizations": [ - "8-bit" + "8bit" ], "model_id": "mlx-community/internlm2_5-7b-chat-8bit", "model_revision": "0ec94d61d30ab161b49c69f9bf92ec2b9986d234" @@ -5803,7 +5803,7 @@ "model_format": "mlx", "model_size_in_billions": 2, "quantizations": [ - "4-bit" + "4bit" ], "model_id": "mlx-community/gemma-2-2b-it-4bit" }, @@ -5811,7 +5811,7 @@ "model_format": "mlx", "model_size_in_billions": 2, "quantizations": [ - "8-bit" + "8bit" ], "model_id": "mlx-community/gemma-2-2b-it-8bit" }, @@ -5827,7 +5827,7 @@ "model_format": "mlx", "model_size_in_billions": 9, "quantizations": [ - "4-bit" + "4bit" ], "model_id": "mlx-community/gemma-2-9b-it-4bit" }, @@ -5835,7 +5835,7 @@ "model_format": "mlx", "model_size_in_billions": 9, "quantizations": [ - "8-bit" + "8bit" ], "model_id": "mlx-community/gemma-2-9b-it-8bit" }, @@ -5851,7 +5851,7 @@ "model_format": "mlx", "model_size_in_billions": 27, "quantizations": [ - "4-bit" + "4bit" ], "model_id": "mlx-community/gemma-2-27b-it-4bit" }, @@ -5859,7 +5859,7 @@ "model_format": "mlx", "model_size_in_billions": 27, "quantizations": [ - "8-bit" + "8bit" ], "model_id": "mlx-community/gemma-2-27b-it-8bit" }, @@ -8015,7 +8015,7 @@ "model_format": "mlx", "model_size_in_billions": "0_5", "quantizations": [ - "4-bit" + "4bit" ], "model_id": "mlx-community/Qwen2.5-0.5B-Instruct-4bit" }, @@ -8023,7 +8023,7 @@ "model_format": "mlx", "model_size_in_billions": "0_5", "quantizations": [ - "8-bit" + "8bit" ], "model_id": "mlx-community/Qwen2.5-0.5B-Instruct-8bit" }, @@ -8039,7 +8039,7 @@ 
"model_format": "mlx", "model_size_in_billions": "1_5", "quantizations": [ - "4-bit" + "4bit" ], "model_id": "mlx-community/Qwen2.5-1.5B-Instruct-4bit" }, @@ -8047,7 +8047,7 @@ "model_format": "mlx", "model_size_in_billions": "1_5", "quantizations": [ - "8-bit" + "8bit" ], "model_id": "mlx-community/Qwen2.5-1.5B-Instruct-8bit" }, @@ -8063,7 +8063,7 @@ "model_format": "mlx", "model_size_in_billions": 3, "quantizations": [ - "4-bit" + "4bit" ], "model_id": "mlx-community/Qwen2.5-3B-Instruct-4bit" }, @@ -8071,7 +8071,7 @@ "model_format": "mlx", "model_size_in_billions": 3, "quantizations": [ - "8-bit" + "8bit" ], "model_id": "mlx-community/Qwen2.5-3B-Instruct-8bit" }, @@ -8087,7 +8087,7 @@ "model_format": "mlx", "model_size_in_billions": 7, "quantizations": [ - "4-bit" + "4bit" ], "model_id": "mlx-community/Qwen2.5-7B-Instruct-4bit" }, @@ -8095,7 +8095,7 @@ "model_format": "mlx", "model_size_in_billions": 7, "quantizations": [ - "8-bit" + "8bit" ], "model_id": "mlx-community/Qwen2.5-7B-Instruct-8bit" }, @@ -8111,7 +8111,7 @@ "model_format": "mlx", "model_size_in_billions": 14, "quantizations": [ - "4-bit" + "4bit" ], "model_id": "mlx-community/Qwen2.5-14B-Instruct-4bit" }, @@ -8119,7 +8119,7 @@ "model_format": "mlx", "model_size_in_billions": 14, "quantizations": [ - "8-bit" + "8bit" ], "model_id": "mlx-community/Qwen2.5-14B-Instruct-8bit" }, @@ -8135,7 +8135,7 @@ "model_format": "mlx", "model_size_in_billions": 32, "quantizations": [ - "4-bit" + "4bit" ], "model_id": "mlx-community/Qwen2.5-32B-Instruct-4bit" }, @@ -8143,7 +8143,7 @@ "model_format": "mlx", "model_size_in_billions": 32, "quantizations": [ - "8-bit" + "8bit" ], "model_id": "mlx-community/Qwen2.5-32B-Instruct-8bit" }, @@ -8159,7 +8159,7 @@ "model_format": "mlx", "model_size_in_billions": 72, "quantizations": [ - "4-bit" + "4bit" ], "model_id": "mlx-community/Qwen2.5-72B-Instruct-4bit" }, @@ -8167,7 +8167,7 @@ "model_format": "mlx", "model_size_in_billions": 72, "quantizations": [ - "8-bit" + "8bit" ], "model_id": "mlx-community/Qwen2.5-72B-Instruct-8bit" }, @@ -8564,7 +8564,7 @@ "model_format": "mlx", "model_size_in_billions": 32, "quantizations": [ - "4-bit" + "4bit" ], "model_id": "mlx-community/Qwen_QwQ-32B-Preview_MLX-4bit" }, @@ -8572,7 +8572,7 @@ "model_format": "mlx", "model_size_in_billions": 32, "quantizations": [ - "8-bit" + "8bit" ], "model_id": "mlx-community/Qwen_QwQ-32B-Preview_MLX-8bit" }, diff --git a/xinference/model/llm/llm_family_modelscope.json b/xinference/model/llm/llm_family_modelscope.json index 40e9e8a9e8..f4438263e6 100644 --- a/xinference/model/llm/llm_family_modelscope.json +++ b/xinference/model/llm/llm_family_modelscope.json @@ -2837,7 +2837,7 @@ "model_format": "mlx", "model_size_in_billions": "0_5", "quantizations": [ - "4-bit" + "4bit" ], "model_id": "qwen/Qwen2-0.5B-Instruct-MLX", "model_hub": "modelscope" @@ -2846,7 +2846,7 @@ "model_format": "mlx", "model_size_in_billions": "1_5", "quantizations": [ - "4-bit" + "4bit" ], "model_id": "qwen/Qwen2-1.5B-Instruct-MLX", "model_hub": "modelscope" @@ -2855,7 +2855,7 @@ "model_format": "mlx", "model_size_in_billions": 7, "quantizations": [ - "4-bit" + "4bit" ], "model_id": "qwen/Qwen2-7B-Instruct-MLX", "model_hub": "modelscope" @@ -5777,7 +5777,7 @@ "model_format": "mlx", "model_size_in_billions": 3, "quantizations": [ - "4-bit" + "4bit" ], "model_id": "okwinds/Qwen2.5-3B-Instruct-MLX-4bit", "model_hub": "modelscope" @@ -5786,7 +5786,7 @@ "model_format": "mlx", "model_size_in_billions": 3, "quantizations": [ - "8-bit" + "8bit" ], "model_id": 
"okwinds/Qwen2.5-3B-Instruct-MLX-8bit", "model_hub": "modelscope" @@ -5795,7 +5795,7 @@ "model_format": "mlx", "model_size_in_billions": 7, "quantizations": [ - "4-bit" + "4bit" ], "model_id": "okwinds/Qwen2.5-7B-Instruct-MLX-4bit", "model_hub": "modelscope" @@ -5804,7 +5804,7 @@ "model_format": "mlx", "model_size_in_billions": 7, "quantizations": [ - "8-bit" + "8bit" ], "model_id": "okwinds/Qwen2.5-7B-Instruct-MLX-8bit", "model_hub": "modelscope" @@ -5813,7 +5813,7 @@ "model_format": "mlx", "model_size_in_billions": 14, "quantizations": [ - "4-bit" + "4bit" ], "model_id": "okwinds/Qwen2.5-14B-Instruct-MLX-4bit", "model_hub": "modelscope" @@ -5822,7 +5822,7 @@ "model_format": "mlx", "model_size_in_billions": 14, "quantizations": [ - "8-bit" + "8bit" ], "model_id": "okwinds/Qwen2.5-14B-Instruct-MLX-8bit", "model_hub": "modelscope" @@ -5831,7 +5831,7 @@ "model_format": "mlx", "model_size_in_billions": 32, "quantizations": [ - "2-bit" + "2bit" ], "model_id": "okwinds/Qwen2.5-32B-Instruct-MLX-2bit", "model_hub": "modelscope" @@ -5840,7 +5840,7 @@ "model_format": "mlx", "model_size_in_billions": 32, "quantizations": [ - "4-bit" + "4bit" ], "model_id": "okwinds/Qwen2.5-32B-Instruct-MLX-4bit", "model_hub": "modelscope" @@ -5849,7 +5849,7 @@ "model_format": "mlx", "model_size_in_billions": 32, "quantizations": [ - "8-bit" + "8bit" ], "model_id": "okwinds/Qwen2.5-32B-Instruct-MLX-8bit", "model_hub": "modelscope" @@ -5858,7 +5858,7 @@ "model_format": "mlx", "model_size_in_billions": 72, "quantizations": [ - "2-bit" + "2bit" ], "model_id": "okwinds/Qwen2.5-32B-Instruct-MLX-2bit", "model_hub": "modelscope" @@ -5867,7 +5867,7 @@ "model_format": "mlx", "model_size_in_billions": 72, "quantizations": [ - "4-bit" + "4bit" ], "model_id": "okwinds/Qwen2.5-72B-Instruct-MLX-4bit", "model_hub": "modelscope" @@ -5876,7 +5876,7 @@ "model_format": "mlx", "model_size_in_billions": 72, "quantizations": [ - "8-bit" + "8bit" ], "model_id": "okwinds/Qwen2.5-72B-Instruct-MLX-8bit", "model_hub": "modelscope" @@ -6296,7 +6296,7 @@ "model_format": "mlx", "model_size_in_billions": 32, "quantizations": [ - "4-bit" + "4bit" ], "model_id": "okwinds/QwQ-32B-Preview-MLX-4bit", "model_hub": "modelscope" @@ -6305,7 +6305,7 @@ "model_format": "mlx", "model_size_in_billions": 32, "quantizations": [ - "8-bit" + "8bit" ], "model_id": "okwinds/QwQ-32B-Preview-MLX-8bit", "model_hub": "modelscope" diff --git a/xinference/model/llm/mlx/core.py b/xinference/model/llm/mlx/core.py index 6cb22eafbc..0fb70bf088 100644 --- a/xinference/model/llm/mlx/core.py +++ b/xinference/model/llm/mlx/core.py @@ -192,8 +192,7 @@ def _get_prompt_cache(self, prompt, lora_name: Optional[str] = None): return prompt def _generate_stream(self, prompt: str, kwargs: MLXGenerateConfig): - import mlx.core as mx - from mlx_lm.utils import generate_step + from mlx_lm.utils import make_sampler, stream_generate model = self._model model_uid = self.model_uid @@ -212,37 +211,30 @@ def _generate_stream(self, prompt: str, kwargs: MLXGenerateConfig): prompt_token_ids = tokenizer.encode(prompt) prompt_token_ids = self._get_prompt_cache(prompt_token_ids, lora_name) - prompt_tokens = mx.array(prompt_token_ids) - input_echo_len = len(prompt_tokens) + input_echo_len = len(prompt_token_ids) i = 0 start = time.time() output = "" tokens = [] - for (token, _), i in zip( - generate_step( - prompt_tokens, + sampler = make_sampler(temp=kwargs["temperature"], top_p=kwargs["top_p"]) + for chunk_resp, i in zip( + stream_generate( model, - temp=kwargs["temperature"], + tokenizer, + 
+                prompt_token_ids,
+                max_tokens=max_tokens,
+                sampler=sampler,
                 repetition_penalty=kwargs["repetition_penalty"],
                 repetition_context_size=kwargs["repetition_context_size"],
-                top_p=kwargs["top_p"],
                 prompt_cache=self._prompt_cache.cache,  # type: ignore
             ),
             range(max_tokens),
         ):
+            token = chunk_resp.token
             tokens.append(token)
 
-            if token == tokenizer.eos_token_id or token in stop_token_ids:  # type: ignore
-                break
-
-            # Yield the last segment if streaming
-            out = tokenizer.decode(
-                token,
-                skip_special_tokens=True,
-                spaces_between_special_tokens=False,
-                clean_up_tokenization_spaces=True,
-            )
+            out = chunk_resp.text
 
             if stream:
                 # this special character is mainly for qwen
                 out = out.strip("�")
@@ -266,6 +258,9 @@
                     total_tokens=(input_echo_len + i),
                 ), completion_usage
 
+            if token == tokenizer.eos_token_id or token in stop_token_ids:  # type: ignore
+                break
+
         logger.info(
             f"Average generation speed: {i / (time.time() - start):.2f} tokens/s."
         )
diff --git a/xinference/model/llm/mlx/tests/test_mlx.py b/xinference/model/llm/mlx/tests/test_mlx.py
index 2807dabb04..cae7fb494f 100644
--- a/xinference/model/llm/mlx/tests/test_mlx.py
+++ b/xinference/model/llm/mlx/tests/test_mlx.py
@@ -32,7 +32,7 @@ def test_load_mlx(setup):
         model_engine="MLX",
         model_size_in_billions="0_5",
         model_format="mlx",
-        quantization="4-bit",
+        quantization="4bit",
     )
     assert len(client.list_models()) == 1
     model = client.get_model(model_uid)
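
Note on the mlx_lm API adopted above: `stream_generate` yields chunk objects whose `.text` field is already detokenized, which removes the per-token `tokenizer.decode` calls (a source of corrupted multi-byte output), and temperature/top_p now travel inside a `make_sampler` callable rather than as `temp=`/`top_p=` keyword arguments to `generate_step`. A minimal standalone sketch of the same streaming path, assuming mlx-lm >= 0.20 (where chunks carry `.text` and `.token`); the model id and sampling values are illustrative only:

# Sketch only: mirrors the stream_generate-based loop this patch introduces.
from mlx_lm import load, stream_generate
from mlx_lm.sample_utils import make_sampler

model, tokenizer = load("mlx-community/Qwen2.5-0.5B-Instruct-4bit")  # illustrative model

# Build prompt token ids, as core.py does via tokenizer.encode().
prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Say hello in one sentence."}],
    add_generation_prompt=True,
)

# Sampling options are bundled into a single sampler callable.
sampler = make_sampler(temp=0.7, top_p=0.9)

for chunk in stream_generate(model, tokenizer, prompt, max_tokens=64, sampler=sampler):
    # chunk.text is the already-decoded next segment; no manual decode needed.
    print(chunk.text, end="", flush=True)
print()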