Added structured generation support to MlxLLM using Outlines #1108
base: develop
Changes from all commits: 626b222, e24470c, ad91b0b, 6e5d545
@@ -28,16 +28,29 @@
    validate_call,
)

from distilabel.mixins.runtime_parameters import RuntimeParameter
from distilabel.models.llms.base import LLM
from distilabel.models.llms.utils import compute_tokens, prepare_output
from distilabel.models.mixins.magpie import MagpieChatTemplateMixin
-from distilabel.typing import GenerateOutput, StandardInput
+from distilabel.typing import (
+    GenerateOutput,
+    OutlinesStructuredOutputType,
+    StandardInput,
+)

if TYPE_CHECKING:
    import mlx.nn as nn
    from mlx_lm.tokenizer_utils import TokenizerWrapper


+class MlxModel:
+    """Wrapper class providing a consistent interface for MLX models."""
+
+    def __init__(self, model: Any, tokenizer: Any):
+        self.model = model
+        self.tokenizer = tokenizer
+
+
class MlxLLM(LLM, MagpieChatTemplateMixin):
Review comment: we should also be able to pass structured output format during the class init.
"""Apple MLX LLM implementation. | ||
|
||
|
@@ -52,6 +65,8 @@ class MlxLLM(LLM, MagpieChatTemplateMixin): | |
sent to the LLM to generate an instruction or a follow up user message. Valid | ||
values are "llama3", "qwen2" or another pre-query template provided. Defaults | ||
to `None`. | ||
structured_output: a dictionary containing the structured output configuration or if more | ||
fine-grained control is needed, an instance of `OutlinesStructuredOutput`. Defaults to None. | ||
|
||
Icon: | ||
`:apple:` | ||
|
@@ -69,15 +84,41 @@ class MlxLLM(LLM, MagpieChatTemplateMixin): | |
# Call the model | ||
output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]]) | ||
``` | ||
+        Generate structured data:
+
+        ```python
+        from pathlib import Path
+        from distilabel.models.llms import MlxLLM
+        from pydantic import BaseModel
+
+        class User(BaseModel):
+            first_name: str
+            last_name: str
+
+        llm = MlxLLM(
+            path_or_hf_repo="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit",
+            structured_output={"format": "json", "schema": User},
+        )
+
+        llm.load()
+
+        # Call the model
+        output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Create a user profile for John Smith"}]])
+        ```
    """

    path_or_hf_repo: str
    tokenizer_config: Dict[str, Any] = Field(default_factory=dict)
    mlx_model_config: Dict[str, Any] = Field(default_factory=dict)
    adapter_path: Optional[str] = None
+    structured_output: Optional[RuntimeParameter[OutlinesStructuredOutputType]] = Field(
+        default=None,
+        description="The structured output format to use across all the generations.",
+    )
    _model: Optional["nn.Module"] = PrivateAttr(None)
    _tokenizer: Optional["TokenizerWrapper"] = PrivateAttr(None)
+    _wrapped_model: Optional[Any] = PrivateAttr(None)
    _mlx_generate: Optional[Callable] = PrivateAttr(None)
    _make_sampler: Optional[Callable] = PrivateAttr(None)

@@ -99,6 +140,7 @@ def load(self) -> None:
            model_config=self.mlx_model_config,
            adapter_path=self.adapter_path,
        )
+        self._wrapped_model = MlxModel(self._model, self._tokenizer)
Review comment: I would create this during the load of the class.

        if self._tokenizer.pad_token is None:
            self._tokenizer.pad_token = self._tokenizer.eos_token
@@ -193,22 +235,26 @@ def generate(  # type: ignore
            min_tokens_to_keep=min_tokens_to_keep,
            top_k=top_k,
        )
        structured_output = None
        result = []
        for input in inputs:
            if isinstance(input, tuple):
                input, structured_output = input

            output: List[str] = []
            for _ in range(num_generations):
-                if structured_output:  # will raise a NotImplementedError
-                    self._prepare_structured_output(structured_output)
+                configured_processors = list(logits_processors or [])
+                if self.structured_output:
+                    structured_processors = self._prepare_structured_output(
+                        self.structured_output
+                    )
+                    configured_processors.extend(structured_processors)

                prompt = self.prepare_input(input)
                generation = self._mlx_generate(  # type: ignore
                    prompt=prompt,
                    model=self._model,
                    tokenizer=self._tokenizer,
-                    logits_processors=logits_processors,
+                    logits_processors=configured_processors,
                    max_tokens=max_tokens,
                    sampler=sampler,
                    max_kv_size=max_kv_size,
@@ -219,7 +265,6 @@ def generate(  # type: ignore
                    quantized_kv_start=quantized_kv_start,
                    prompt_progress_callback=prompt_progress_callback,
                )
-
                output.append(generation)

            result.append(
@@ -236,3 +281,34 @@ def generate(  # type: ignore
                )
            )
        return result
+
+    def _prepare_structured_output(
+        self, structured_output: Optional[OutlinesStructuredOutputType] = None
+    ) -> List[Callable]:
+        """Creates the appropriate function to filter tokens to generate structured outputs."""
+        if structured_output is None:
+            return []
+
+        from distilabel.steps.tasks.structured_outputs.outlines import (
+            prepare_guided_output,
+        )
+
+        result = prepare_guided_output(structured_output, "mlx", self._wrapped_model)
+        if (schema := result.get("schema")) and self.structured_output:
+            self.structured_output["schema"] = schema
+
+        base_processor = result["processor"]
+
+        def mlx_processor(tokens: Any, logits: Any) -> Any:
+            # Handle both single and batch inputs uniformly
+            is_single = logits.shape[0] == 1
+            working_logits = logits[0, :] if is_single else logits[:, -1]
+
+            # Process the logits
+            logits_flat = working_logits.reshape(-1)
+            processed_logits = base_processor(tokens, logits_flat)
+
+            # Reshape back to original format
+            return processed_logits.reshape(1, -1)
+
+        return [mlx_processor]
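As context for the `mlx_processor` wrapper added above: `mlx_lm` calls each entry of `logits_processors` as a `(tokens, logits) -> logits` callable on arrays shaped roughly `(1, vocab_size)`, while Outlines' processors operate on a flat logits vector, hence the reshaping. A minimal sketch of such a callable, not part of this PR; the banned token id is purely illustrative:

```python
import mlx.core as mx


def ban_token_processor(tokens: mx.array, logits: mx.array) -> mx.array:
    # Illustrative processor: forbid one (hypothetical) token id by pushing its
    # logit to -inf. mlx_lm passes the token ids generated so far plus the
    # next-token logits and expects logits of the same shape back.
    banned_id = 0  # hypothetical token id, for illustration only
    vocab_size = logits.shape[-1]
    mask = mx.arange(vocab_size) == banned_id
    return mx.where(mask, mx.array(-float("inf")), logits)


# A list such as [ban_token_processor] is what MlxLLM.generate() forwards to
# mlx_lm through the `logits_processors` argument, alongside the Outlines-backed
# `mlx_processor` built by `_prepare_structured_output`.
```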
@@ -37,10 +37,11 @@
    from llama_cpp import Llama  # noqa
    from transformers import Pipeline  # noqa
    from vllm import LLM as _vLLM  # noqa
+    from distilabel.models.llms.mlx import MlxModel  # noqa
Review comment: I would import this class from …

-    from distilabel.typing import OutlinesStructuredOutputType  # noqa
+    from distilabel.typing import OutlinesStructuredOutputType  # noqa

-Frameworks = Literal["transformers", "llamacpp", "vllm"]
+Frameworks = Literal["transformers", "llamacpp", "vllm", "mlx"]


def _is_outlines_version_below_0_1_0() -> bool:
@@ -101,6 +102,11 @@ def _get_logits_processor(framework: Frameworks) -> Tuple[Callable, Callable]:
                "JSONLogitsProcessor",
                "RegexLogitsProcessor",
            ),
+            "mlx": (
Review comment: is mlx not implemented for outlines below 0.1?
+                "outlines.processors",
+                "JSONLogitsProcessor",
+                "RegexLogitsProcessor",
+            ),
        }

    if framework not in processors:
@@ -115,7 +121,7 @@ def _get_logits_processor(framework: Frameworks) -> Tuple[Callable, Callable]:


def _get_tokenizer_from_model(
-    llm: Union["_vLLM", "Pipeline", "Llama"],
+    llm: Union["_vLLM", "Pipeline", "Llama", "MlxModel"],
    framework: Frameworks,
) -> Callable:
    if framework == "llamacpp":
@@ -130,12 +136,16 @@ def _get_tokenizer_from_model(
        from outlines.models.vllm import adapt_tokenizer

        return adapt_tokenizer(llm.get_tokenizer())
+    if framework == "mlx":
+        from outlines.models.transformers import TransformerTokenizer
+
+        return TransformerTokenizer(llm.tokenizer)


def prepare_guided_output(
    structured_output: "OutlinesStructuredOutputType",
    framework: Frameworks,
-    llm: Union["_vLLM", "Pipeline", "Llama"],
+    llm: Union["_vLLM", "Pipeline", "Llama", "MlxModel"],
) -> Dict[str, Any]:
    """Prepares the `LLM` to generate guided output using `outlines`.

@@ -156,7 +166,6 @@ def prepare_guided_output(
        case of "json" will also include the schema as a dict, to simplify serialization
        and deserialization.
    """
-
    json_processor, regex_processor = _get_logits_processor(framework)

    format = structured_output.get("format")
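For illustration, a minimal sketch of exercising the new `"mlx"` branch directly: it wraps an `mlx_lm` model in the `MlxModel` class added by this PR and asks `prepare_guided_output` for a JSON logits processor. The model id mirrors the docstring example above; treat this as a sketch rather than documented API usage.

```python
from mlx_lm import load
from pydantic import BaseModel

from distilabel.models.llms.mlx import MlxModel
from distilabel.steps.tasks.structured_outputs.outlines import prepare_guided_output


class User(BaseModel):
    first_name: str
    last_name: str


# Load an MLX model/tokenizer pair and wrap it so outlines can reach the tokenizer.
model, tokenizer = load("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit")
wrapped = MlxModel(model, tokenizer)

result = prepare_guided_output(
    {"format": "json", "schema": User}, framework="mlx", llm=wrapped
)
processor = result["processor"]  # Outlines JSONLogitsProcessor bound to the schema
schema_dict = result.get("schema")  # JSON-serializable schema, returned for the "json" format
```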
@@ -11,11 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

+import json
import platform
from typing import Any, Dict, Generator

import pytest
+from pydantic import BaseModel

from distilabel.models.llms.mlx import MlxLLM

@@ -63,6 +64,47 @@ def test_generate(self, llm: MlxLLM) -> None:
        assert "input_tokens" in statistics
        assert "output_tokens" in statistics

+    def test_structured_generation_json(self, llm: MlxLLM) -> None:
Review comment: we have structured generation tests in …
+        class User(BaseModel):
+            first_name: str
+            last_name: str
+
+        llm.structured_output = {"format": "json", "schema": User.model_json_schema()}
+
+        responses = llm.generate(
+            inputs=[
+                [{"role": "user", "content": "Create a user profile for John Smith"}],
+            ],
+            num_generations=1,
+        )
+
+        assert len(responses) == 1
+        assert "generations" in responses[0]
+        assert "statistics" in responses[0]
+        generations = responses[0]["generations"]
+        assert len(generations) == 1
+
+        # Clean and parse the generation
+        for generation in generations:
+            # Remove the <|im_end|> tokens and clean up the string
+            cleaned_json = generation.replace("<|im_end|>", "").strip()
+            try:
+                user_data = json.loads(cleaned_json)
+                parsed_user = User(**user_data)
+                assert isinstance(parsed_user, User)
+                assert parsed_user.first_name == "John"
+                assert parsed_user.last_name == "Smith"
+            except json.JSONDecodeError as e:
+                print(f"JSON parsing error: {e}")
+                print(f"Raw generation: {cleaned_json}")
+                raise
+            except ValueError as e:
+                print(f"Validation error: {e}")
+                raise
+        statistics = responses[0]["statistics"]
+        assert "input_tokens" in statistics
+        assert "output_tokens" in statistics

    @pytest.mark.parametrize(
        "structured_output, dump",
        [
Review comment: I think we can import this from outlines.models.mlxlm.
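A heavily hedged sketch of what that suggestion could look like in `_get_tokenizer_from_model`, assuming the installed Outlines version ships an `outlines.models.mlxlm` module whose wrapper exposes an already adapted tokenizer; the `MLXLM` class name, constructor signature, and `.tokenizer` attribute are assumptions based on the reviewer's hint, not verified in this PR:

```python
# Hypothetical alternative for the "mlx" branch of _get_tokenizer_from_model,
# based on the reviewer's hint; every outlines.models.mlxlm name below is an
# assumption about that module's API rather than something this PR relies on.
if framework == "mlx":
    from outlines.models.mlxlm import MLXLM

    return MLXLM(llm.model, llm.tokenizer).tokenizer
```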