Added structured generation support to MlxLLM using Outlines
dameikle committed Jan 22, 2025
1 parent 43d1bb1 commit 626b222
Showing 2 changed files with 72 additions and 33 deletions.
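
For context, the new structured_output field is meant to be used along these lines (a minimal usage sketch; the model repo, the schema, and the path_or_hf_repo field name are illustrative assumptions, not taken from this commit):

from pydantic import BaseModel

from distilabel.models.llms.mlx import MlxLLM


class Character(BaseModel):
    name: str
    role: str


# structured_output follows OutlinesStructuredOutputType: a dict with a
# "format" key ("json" or "regex") and a "schema" key.
llm = MlxLLM(
    path_or_hf_repo="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit",  # assumed field name
    structured_output={"format": "json", "schema": Character},
)
llm.load()
result = llm.generate(
    inputs=[[{"role": "user", "content": "Invent a fantasy character."}]],
    num_generations=1,
)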
63 changes: 56 additions & 7 deletions src/distilabel/models/llms/mlx.py
@@ -28,16 +28,29 @@
validate_call,
)

from distilabel.mixins.runtime_parameters import RuntimeParameter
from distilabel.models.llms.base import LLM
from distilabel.models.llms.utils import compute_tokens, prepare_output
from distilabel.models.mixins.magpie import MagpieChatTemplateMixin
from distilabel.typing import GenerateOutput, StandardInput
from distilabel.typing import (
StandardInput,
GenerateOutput,
OutlinesStructuredOutputType,
)

if TYPE_CHECKING:
import mlx.nn as nn
from mlx_lm.tokenizer_utils import TokenizerWrapper


class MlxModel:
"""Wrapper class providing a consistent interface for MLX models."""

def __init__(self, model: Any, tokenizer: Any):
self.model = model
self.tokenizer = tokenizer


class MlxLLM(LLM, MagpieChatTemplateMixin):
"""Apple MLX LLM implementation.
@@ -75,9 +88,13 @@ class MlxLLM(LLM, MagpieChatTemplateMixin):
tokenizer_config: Dict[str, Any] = Field(default_factory=dict)
mlx_model_config: Dict[str, Any] = Field(default_factory=dict)
adapter_path: Optional[str] = None

structured_output: Optional[RuntimeParameter[OutlinesStructuredOutputType]] = Field(
default=None,
description="The structured output format to use across all the generations.",
)
_model: Optional["nn.Module"] = PrivateAttr(None)
_tokenizer: Optional["TokenizerWrapper"] = PrivateAttr(None)
_wrapped_model: Optional[Any] = PrivateAttr(None)
_mlx_generate: Optional[Callable] = PrivateAttr(None)
_make_sampler: Optional[Callable] = PrivateAttr(None)

@@ -99,6 +116,7 @@ def load(self) -> None:
model_config=self.mlx_model_config,
adapter_path=self.adapter_path,
)
self._wrapped_model = MlxModel(self._model, self._tokenizer)

if self._tokenizer.pad_token is None:
self._tokenizer.pad_token = self._tokenizer.eos_token
@@ -193,22 +211,25 @@ def generate( # type: ignore
min_tokens_to_keep=min_tokens_to_keep,
top_k=top_k,
)
structured_output = None
result = []
for input in inputs:
if isinstance(input, tuple):
input, structured_output = input

output: List[str] = []
for _ in range(num_generations):
if structured_output: # will raise a NotImplementedError
self._prepare_structured_output(structured_output)

configured_processors = list(logits_processors or [])
if self.structured_output:
structured_processors = self._prepare_structured_output(self.structured_output)
configured_processors.extend(structured_processors)

prompt = self.prepare_input(input)
generation = self._mlx_generate( # type: ignore
prompt=prompt,
model=self._model,
tokenizer=self._tokenizer,
logits_processors=logits_processors,
logits_processors=configured_processors,
max_tokens=max_tokens,
sampler=sampler,
max_kv_size=max_kv_size,
@@ -219,7 +240,6 @@
quantized_kv_start=quantized_kv_start,
prompt_progress_callback=prompt_progress_callback,
)

output.append(generation)

result.append(
@@ -236,3 +256,32 @@
)
)
return result


def _prepare_structured_output(
self, structured_output: Optional[OutlinesStructuredOutputType] = None
) -> List[Callable]:
"""Creates the appropriate function to filter tokens to generate structured outputs."""
if structured_output is None:
return []

from distilabel.steps.tasks.structured_outputs.outlines import prepare_guided_output
result = prepare_guided_output(structured_output, "mlx", self._wrapped_model)
if (schema := result.get("schema")) and self.structured_output:
self.structured_output["schema"] = schema

base_processor = result["processor"]

def mlx_processor(tokens: Any, logits: Any) -> Any:
# Handle both single and batch inputs uniformly
is_single = logits.shape[0] == 1
working_logits = logits[0, :] if is_single else logits[:, -1]

# Process the logits
logits_flat = working_logits.reshape(-1)
processed_logits = base_processor(tokens, logits_flat)

# Reshape back to original format
return processed_logits.reshape(1, -1)

return [mlx_processor]
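
The mlx_processor wrapper above bridges a shape mismatch: mlx_lm calls logits processors with a 2-D (1, vocab_size) array, while the Outlines processors operate on a flat (vocab_size,) vector, so the logits are flattened on the way in and restored on the way out. A minimal sketch of that contract (the vocab size and the no-op processor are stand-ins; the shapes are assumed from the code above):

import mlx.core as mx


def outlines_style_processor(tokens, logits):
    # Stand-in for JSONLogitsProcessor/RegexLogitsProcessor: expects 1-D logits.
    assert logits.ndim == 1
    return logits


logits = mx.zeros((1, 32000))    # what mlx_lm hands the processor
flat = logits[0, :].reshape(-1)  # what the Outlines processor sees
restored = outlines_style_processor([], flat).reshape(1, -1)
assert restored.shape == (1, 32000)  # back to the shape mlx_lm expects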
42 changes: 16 additions & 26 deletions src/distilabel/steps/tasks/structured_outputs/outlines.py
@@ -1,17 +1,3 @@
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib
import importlib.util
import inspect
@@ -29,18 +15,18 @@
)

from pydantic import BaseModel

from distilabel.errors import DistilabelUserError

from distilabel.steps.tasks.structured_outputs.utils import schema_as_dict

if TYPE_CHECKING: # noqa
from llama_cpp import Llama # noqa
from transformers import Pipeline # noqa
from vllm import LLM as _vLLM # noqa
from distilabel.models.llms.mlx import MlxModel #noqa

from distilabel.typing import OutlinesStructuredOutputType # noqa

Frameworks = Literal["transformers", "llamacpp", "vllm"]
from distilabel.typing import OutlinesStructuredOutputType # noqa
Frameworks = Literal["transformers", "llamacpp", "vllm", "mlx"]


def _is_outlines_version_below_0_1_0() -> bool:
@@ -101,6 +87,11 @@ def _get_logits_processor(framework: Frameworks) -> Tuple[Callable, Callable]:
"JSONLogitsProcessor",
"RegexLogitsProcessor",
),
"mlx": (
"outlines.processors",
"JSONLogitsProcessor",
"RegexLogitsProcessor",
),
}

if framework not in processors:
@@ -115,27 +106,27 @@ def _get_logits_processor(framework: Frameworks) -> Tuple[Callable, Callable]:


def _get_tokenizer_from_model(
llm: Union["_vLLM", "Pipeline", "Llama"],
framework: Frameworks,
llm: Union["_vLLM", "Pipeline", "Llama", "MlxModel"],
framework: Frameworks,
) -> Callable:
if framework == "llamacpp":
from outlines.models.llamacpp import LlamaCppTokenizer

return LlamaCppTokenizer(llm)
if framework == "transformers":
from outlines.models.transformers import TransformerTokenizer

return TransformerTokenizer(llm.tokenizer)
if framework == "vllm":
from outlines.models.vllm import adapt_tokenizer

return adapt_tokenizer(llm.get_tokenizer())
if framework == "mlx":
from outlines.models.transformers import TransformerTokenizer
return TransformerTokenizer(llm.tokenizer)


def prepare_guided_output(
structured_output: "OutlinesStructuredOutputType",
framework: Frameworks,
llm: Union["_vLLM", "Pipeline", "Llama"],
llm: Union["_vLLM", "Pipeline", "Llama", "MlxModel"],
) -> Dict[str, Any]:
"""Prepares the `LLM` to generate guided output using `outlines`.
@@ -156,7 +147,6 @@ def prepare_guided_output(
case of "json" will also include the schema as a dict, to simplify serialization
and deserialization.
"""

json_processor, regex_processor = _get_logits_processor(framework)

format = structured_output.get("format")
@@ -196,4 +186,4 @@
raise DistilabelUserError(
f"Invalid format '{format}'. Must be either 'json' or 'regex'.",
page="sections/how_to_guides/advanced/structured_generation/",
)
)
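
Taken together, the new "mlx" branch can also be exercised directly (a sketch assuming a model loaded with mlx_lm; the repo name is illustrative and none of the code below appears in the diff):

from mlx_lm import load

from distilabel.models.llms.mlx import MlxModel
from distilabel.steps.tasks.structured_outputs.outlines import prepare_guided_output

model, tokenizer = load("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit")  # assumed repo
guided = prepare_guided_output(
    {"format": "regex", "schema": r"[0-9]{4}"},  # keys assumed per OutlinesStructuredOutputType
    "mlx",
    MlxModel(model, tokenizer),
)
processor = guided["processor"]  # an Outlines RegexLogitsProcessor, ready for mlx_lm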
