Added structured generation support to MlxLLM using Outlines #1108

Open · wants to merge 4 commits into base: develop

90 changes: 83 additions & 7 deletions src/distilabel/models/llms/mlx.py
@@ -28,16 +28,29 @@
validate_call,
)

from distilabel.mixins.runtime_parameters import RuntimeParameter
from distilabel.models.llms.base import LLM
from distilabel.models.llms.utils import compute_tokens, prepare_output
from distilabel.models.mixins.magpie import MagpieChatTemplateMixin
from distilabel.typing import GenerateOutput, StandardInput
from distilabel.typing import (
GenerateOutput,
OutlinesStructuredOutputType,
StandardInput,
)

if TYPE_CHECKING:
import mlx.nn as nn
from mlx_lm.tokenizer_utils import TokenizerWrapper


class MlxModel:

Review comment: I think we can import this from `outlines.models.mlxlm`

"""Wrapper class providing a consistent interface for MLX models."""

def __init__(self, model: Any, tokenizer: Any):
self.model = model
self.tokenizer = tokenizer
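Picking up the review comment above (and the matching one on `outlines.py` further down): a minimal sketch of what reusing outlines' own wrapper could look like. The module path and constructor signature are assumptions for `outlines >= 0.1.0`, not something confirmed by this diff.

```python
# Hedged sketch, not part of the PR diff: reuse outlines' MLX wrapper instead
# of defining MlxModel locally. Module path and constructor are assumed for
# outlines >= 0.1.0 and may differ between releases.
from mlx_lm import load
from outlines.models.mlxlm import MLXLM  # assumed location of the wrapper

model, tokenizer = load("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit")
wrapped_model = MLXLM(model, tokenizer)  # stands in for MlxModel(model, tokenizer)
```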


class MlxLLM(LLM, MagpieChatTemplateMixin):

Review comment: we should also be able to pass structured output format during the class init.

"""Apple MLX LLM implementation.

@@ -52,6 +65,8 @@ class MlxLLM(LLM, MagpieChatTemplateMixin):
sent to the LLM to generate an instruction or a follow up user message. Valid
values are "llama3", "qwen2" or another pre-query template provided. Defaults
to `None`.
structured_output: a dictionary containing the structured output configuration or, if more
fine-grained control is needed, an instance of `OutlinesStructuredOutput`. Defaults to `None`.

Icon:
`:apple:`
@@ -69,15 +84,41 @@ class MlxLLM(LLM, MagpieChatTemplateMixin):
# Call the model
output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
```

Generate structured data:

```python
from pathlib import Path
from distilabel.models.llms import MlxLLM
from pydantic import BaseModel

class User(BaseModel):
first_name: str
last_name: str

llm = MlxLLM(
path_or_hf_repo="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit",
structured_output={"format": "json", "schema": User},
)

llm.load()

# Call the model
output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Create a user profile for John Smith"}]])
```
"""

path_or_hf_repo: str
tokenizer_config: Dict[str, Any] = Field(default_factory=dict)
mlx_model_config: Dict[str, Any] = Field(default_factory=dict)
adapter_path: Optional[str] = None

structured_output: Optional[RuntimeParameter[OutlinesStructuredOutputType]] = Field(
default=None,
description="The structured output format to use across all the generations.",
)
_model: Optional["nn.Module"] = PrivateAttr(None)
_tokenizer: Optional["TokenizerWrapper"] = PrivateAttr(None)
_wrapped_model: Optional[Any] = PrivateAttr(None)
_mlx_generate: Optional[Callable] = PrivateAttr(None)
_make_sampler: Optional[Callable] = PrivateAttr(None)

@@ -99,6 +140,7 @@ def load(self) -> None:
model_config=self.mlx_model_config,
adapter_path=self.adapter_path,
)
self._wrapped_model = MlxModel(self._model, self._tokenizer)

Review comment: I would create this during the load of the class.


if self._tokenizer.pad_token is None:
self._tokenizer.pad_token = self._tokenizer.eos_token
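On the review comment above about doing this at load time: a self-contained sketch of the pattern, with the structured-output processors built once in `load()` and reused by `generate()`. Everything here except the `structured_output` field and the `_prepare_structured_output` name is hypothetical and not part of this PR.

```python
from typing import Callable, Dict, List, Optional


class StructuredLoadSketch:
    """Illustrative stand-in, not the real MlxLLM: shows processors being
    prepared once at load time instead of on every generate() call."""

    def __init__(self, structured_output: Optional[Dict] = None) -> None:
        self.structured_output = structured_output
        self._structured_processors: List[Callable] = []

    def _prepare_structured_output(self, config: Dict) -> List[Callable]:
        # Placeholder for the method added in this PR.
        return []

    def load(self) -> None:
        # ... load and wrap the MLX model/tokenizer as in the PR ...
        if self.structured_output:
            self._structured_processors = self._prepare_structured_output(
                self.structured_output
            )

    def generate(self, logits_processors: Optional[List[Callable]] = None) -> List[str]:
        # Reuse what load() prepared instead of rebuilding it per input.
        configured = list(logits_processors or []) + self._structured_processors
        _ = configured  # passed to mlx_lm's generate in the real implementation
        return []
```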
@@ -193,22 +235,26 @@ def generate( # type: ignore
min_tokens_to_keep=min_tokens_to_keep,
top_k=top_k,
)
structured_output = None
result = []
for input in inputs:
if isinstance(input, tuple):
input, structured_output = input

output: List[str] = []
for _ in range(num_generations):
if structured_output: # will raise a NotImplementedError
self._prepare_structured_output(structured_output)
configured_processors = list(logits_processors or [])
if self.structured_output:
structured_processors = self._prepare_structured_output(
self.structured_output
)
configured_processors.extend(structured_processors)

prompt = self.prepare_input(input)
generation = self._mlx_generate( # type: ignore
prompt=prompt,
model=self._model,
tokenizer=self._tokenizer,
logits_processors=logits_processors,
logits_processors=configured_processors,
max_tokens=max_tokens,
sampler=sampler,
max_kv_size=max_kv_size,
Expand All @@ -219,7 +265,6 @@ def generate( # type: ignore
quantized_kv_start=quantized_kv_start,
prompt_progress_callback=prompt_progress_callback,
)

output.append(generation)

result.append(
@@ -236,3 +281,34 @@
)
)
return result

def _prepare_structured_output(
self, structured_output: Optional[OutlinesStructuredOutputType] = None
) -> List[Callable]:
"""Creates the appropriate function to filter tokens to generate structured outputs."""
if structured_output is None:
return []

from distilabel.steps.tasks.structured_outputs.outlines import (
prepare_guided_output,
)

result = prepare_guided_output(structured_output, "mlx", self._wrapped_model)
if (schema := result.get("schema")) and self.structured_output:
self.structured_output["schema"] = schema

base_processor = result["processor"]

def mlx_processor(tokens: Any, logits: Any) -> Any:
# Handle both single and batch inputs uniformly
is_single = logits.shape[0] == 1
working_logits = logits[0, :] if is_single else logits[:, -1]

# Process the logits
logits_flat = working_logits.reshape(-1)
processed_logits = base_processor(tokens, logits_flat)

# Reshape back to original format
return processed_logits.reshape(1, -1)

return [mlx_processor]
19 changes: 14 additions & 5 deletions src/distilabel/steps/tasks/structured_outputs/outlines.py
@@ -37,10 +37,11 @@
from llama_cpp import Llama # noqa
from transformers import Pipeline # noqa
from vllm import LLM as _vLLM # noqa
from distilabel.models.llms.mlx import MlxModel # noqa

Review comment: I would import this class from `outlines.models.mlxlm` to avoid code duplication


from distilabel.typing import OutlinesStructuredOutputType # noqa
from distilabel.typing import OutlinesStructuredOutputType # noqa

Frameworks = Literal["transformers", "llamacpp", "vllm"]
Frameworks = Literal["transformers", "llamacpp", "vllm", "mlx"]


def _is_outlines_version_below_0_1_0() -> bool:
@@ -101,6 +102,11 @@ def _get_logits_processor(framework: Frameworks) -> Tuple[Callable, Callable]:
"JSONLogitsProcessor",
"RegexLogitsProcessor",
),
"mlx": (

Review comment: is mlx not implemented for outlines below 0.1?

"outlines.processors",
"JSONLogitsProcessor",
"RegexLogitsProcessor",
),
}
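On the question in the review comment above: if the `outlines.processors` module is indeed what makes MLX support possible, one option is an explicit guard that reuses the module's existing version helper. The 0.1.0 cutoff is an assumption; this excerpt is meant to sit inside `_get_logits_processor` and is not part of the diff.

```python
# Hedged sketch, not part of the diff: fail early for "mlx" on older outlines
# releases that lack the `outlines.processors` module (0.1.0 cutoff assumed).
if framework == "mlx" and _is_outlines_version_below_0_1_0():
    raise ImportError(
        "Structured generation with MLX requires `outlines>=0.1.0`, "
        "which provides the `outlines.processors` module."
    )
```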

if framework not in processors:
@@ -115,7 +121,7 @@ def _get_logits_processor(framework: Frameworks) -> Tuple[Callable, Callable]:


def _get_tokenizer_from_model(
llm: Union["_vLLM", "Pipeline", "Llama"],
llm: Union["_vLLM", "Pipeline", "Llama", "MlxModel"],
framework: Frameworks,
) -> Callable:
if framework == "llamacpp":
Expand All @@ -130,12 +136,16 @@ def _get_tokenizer_from_model(
from outlines.models.vllm import adapt_tokenizer

return adapt_tokenizer(llm.get_tokenizer())
if framework == "mlx":
from outlines.models.transformers import TransformerTokenizer

return TransformerTokenizer(llm.tokenizer)


def prepare_guided_output(
structured_output: "OutlinesStructuredOutputType",
framework: Frameworks,
llm: Union["_vLLM", "Pipeline", "Llama"],
llm: Union["_vLLM", "Pipeline", "Llama", "MlxModel"],
) -> Dict[str, Any]:
"""Prepares the `LLM` to generate guided output using `outlines`.

@@ -156,7 +166,6 @@
case of "json" will also include the schema as a dict, to simplify serialization
and deserialization.
"""

json_processor, regex_processor = _get_logits_processor(framework)

format = structured_output.get("format")
44 changes: 43 additions & 1 deletion tests/unit/models/llms/test_mlx.py
@@ -11,11 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import platform
from typing import Any, Dict, Generator

import pytest
from pydantic import BaseModel

from distilabel.models.llms.mlx import MlxLLM

@@ -63,6 +64,47 @@ def test_generate(self, llm: MlxLLM) -> None:
assert "input_tokens" in statistics
assert "output_tokens" in statistics

def test_structured_generation_json(self, llm: MlxLLM) -> None:

Review comment: we have structured generation tests in tests/unit/steps/tasks/structured_outputs/test_outlines.py; could you add/integrate this test there?

class User(BaseModel):
first_name: str
last_name: str

llm.structured_output = {"format": "json", "schema": User.model_json_schema()}

responses = llm.generate(
inputs=[
[{"role": "user", "content": "Create a user profile for John Smith"}],
],
num_generations=1,
)

assert len(responses) == 1
assert "generations" in responses[0]
assert "statistics" in responses[0]
generations = responses[0]["generations"]
assert len(generations) == 1

# Clean and parse the generation
for generation in generations:
# Remove the <|im_end|> tokens and clean up the string
cleaned_json = generation.replace("<|im_end|>", "").strip()
try:
user_data = json.loads(cleaned_json)
parsed_user = User(**user_data)
assert isinstance(parsed_user, User)
assert parsed_user.first_name == "John"
assert parsed_user.last_name == "Smith"
except json.JSONDecodeError as e:
print(f"JSON parsing error: {e}")
print(f"Raw generation: {cleaned_json}")
raise
except ValueError as e:
print(f"Validation error: {e}")
raise
statistics = responses[0]["statistics"]
assert "input_tokens" in statistics
assert "output_tokens" in statistics

@pytest.mark.parametrize(
"structured_output, dump",
[