Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] authored and dameikle committed Jan 22, 2025
1 parent 626b222 commit e24470c
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 11 deletions.
15 changes: 9 additions & 6 deletions src/distilabel/models/llms/mlx.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@
from distilabel.models.llms.utils import compute_tokens, prepare_output
from distilabel.models.mixins.magpie import MagpieChatTemplateMixin
from distilabel.typing import (
StandardInput,
GenerateOutput,
OutlinesStructuredOutputType,
StandardInput,
)

if TYPE_CHECKING:
Expand Down Expand Up @@ -218,10 +218,11 @@ def generate( # type: ignore

output: List[str] = []
for _ in range(num_generations):

configured_processors = list(logits_processors or [])
if self.structured_output:
structured_processors = self._prepare_structured_output(self.structured_output)
structured_processors = self._prepare_structured_output(
self.structured_output
)
configured_processors.extend(structured_processors)

prompt = self.prepare_input(input)
Expand Down Expand Up @@ -257,15 +258,17 @@ def generate( # type: ignore
)
return result


def _prepare_structured_output(
self, structured_output: Optional[OutlinesStructuredOutputType] = None
self, structured_output: Optional[OutlinesStructuredOutputType] = None
) -> List[Callable]:
"""Creates the appropriate function to filter tokens to generate structured outputs."""
if structured_output is None:
return []

from distilabel.steps.tasks.structured_outputs.outlines import prepare_guided_output
from distilabel.steps.tasks.structured_outputs.outlines import (
prepare_guided_output,
)

result = prepare_guided_output(structured_output, "mlx", self._wrapped_model)
if (schema := result.get("schema")) and self.structured_output:
self.structured_output["schema"] = schema
Expand Down
29 changes: 24 additions & 5 deletions src/distilabel/steps/tasks/structured_outputs/outlines.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,17 @@
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib
import importlib.util
import inspect
Expand All @@ -15,17 +29,18 @@
)

from pydantic import BaseModel
from distilabel.errors import DistilabelUserError

from distilabel.errors import DistilabelUserError
from distilabel.steps.tasks.structured_outputs.utils import schema_as_dict

if TYPE_CHECKING: # noqa
from llama_cpp import Llama # noqa
from transformers import Pipeline # noqa
from vllm import LLM as _vLLM # noqa
from distilabel.models.llms.mlx import MlxModel #noqa
from distilabel.models.llms.mlx import MlxModel # noqa

from distilabel.typing import OutlinesStructuredOutputType # noqa

Frameworks = Literal["transformers", "llamacpp", "vllm", "mlx"]


Expand Down Expand Up @@ -106,20 +121,24 @@ def _get_logits_processor(framework: Frameworks) -> Tuple[Callable, Callable]:


def _get_tokenizer_from_model(
    llm: Union["_vLLM", "Pipeline", "Llama", "MlxModel"],
    framework: Frameworks,
) -> Callable:
    """Return an `outlines`-compatible tokenizer adapter for the given model.

    The scraped diff duplicated the parameter lines (old and new versions
    interleaved); this is the reconstructed post-commit definition.

    Args:
        llm: The framework-specific model object (`vLLM`, a `transformers`
            pipeline, a `llama_cpp.Llama`, or an MLX model wrapper).
        framework: Which backend `llm` belongs to; one of "transformers",
            "llamacpp", "vllm" or "mlx".

    Returns:
        A tokenizer object usable by `outlines` logits processors.
        NOTE(review): falls through and implicitly returns ``None`` for an
        unrecognized framework — presumably callers validate `framework`
        upstream; confirm.
    """
    if framework == "llamacpp":
        from outlines.models.llamacpp import LlamaCppTokenizer

        return LlamaCppTokenizer(llm)
    if framework == "transformers":
        from outlines.models.transformers import TransformerTokenizer

        return TransformerTokenizer(llm.tokenizer)
    if framework == "vllm":
        from outlines.models.vllm import adapt_tokenizer

        return adapt_tokenizer(llm.get_tokenizer())
    if framework == "mlx":
        # MLX models expose a `transformers`-style tokenizer, so reuse the
        # same adapter as the "transformers" branch.
        from outlines.models.transformers import TransformerTokenizer

        return TransformerTokenizer(llm.tokenizer)


Expand Down Expand Up @@ -186,4 +205,4 @@ def prepare_guided_output(
raise DistilabelUserError(
f"Invalid format '{format}'. Must be either 'json' or 'regex'.",
page="sections/how_to_guides/advanced/structured_generation/",
)
)

0 comments on commit e24470c

Please sign in to comment.