Update LLMs to support prompt logprobs use-case (#1099)
gabrielmbmb authored Jan 17, 2025
1 parent 5257600 commit 74cc09e
Showing 7 changed files with 288 additions and 70 deletions.
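
In short: `agenerate` can now also take a plain prompt string instead of a list of chat messages, and string inputs are routed to the text-generation / completions endpoints, which is what makes the prompt logprobs use-case (echoing the prompt and scoring its tokens) possible. A rough usage sketch only — the import path and model id are assumptions, and in a pipeline the LLM would normally be reached through a task rather than by calling `agenerate` directly:

import asyncio

from distilabel.models.llms import OpenAILLM

llm = OpenAILLM(model="gpt-3.5-turbo-instruct")  # assumed completions-capable model
llm.load()  # expects OPENAI_API_KEY (or an OpenAI-compatible base_url)

# A raw string takes the new completions path; `echo=True` plus `top_logprobs`
# is the prompt-logprobs use-case, and `max_new_tokens=0` (now allowed by the
# `NonNegativeInt` annotation) means "score the prompt, generate nothing".
result = asyncio.run(
    llm.agenerate(
        input="The quick brown fox",
        echo=True,
        top_logprobs=5,
        max_new_tokens=0,
    )
)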
5 changes: 2 additions & 3 deletions src/distilabel/models/base_clients/inference_endpoints.py
@@ -16,7 +16,6 @@
from typing import (
TYPE_CHECKING,
Optional,
Union,
)

from pydantic import (
@@ -143,9 +142,9 @@ def load(self) -> None: # noqa: C901
self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_id)

@property
def model_name(self) -> Union[str, None]: # type: ignore
def model_name(self) -> str:
"""Returns the model name used for the model."""
return (
return ( # type: ignore
self.model_display_name
or self._model_name
or self.model_id
63 changes: 35 additions & 28 deletions src/distilabel/models/llms/huggingface/inference_endpoints.py
@@ -273,7 +273,7 @@ def _get_structured_output(

async def _generate_with_text_generation(
self,
input: FormattedInput,
input: str,
max_new_tokens: int = 128,
repetition_penalty: Optional[float] = None,
frequency_penalty: Optional[float] = None,
@@ -287,13 +287,12 @@ async def _generate_with_text_generation(
return_full_text: bool = False,
seed: Optional[int] = None,
watermark: bool = False,
structured_output: Union[Dict[str, Any], None] = None,
) -> GenerateOutput:
input, structured_output = self._get_structured_output(input)
prompt = self.prepare_input(input)
generation: Union["TextGenerationOutput", None] = None
try:
generation = await self._aclient.text_generation( # type: ignore
prompt=prompt,
prompt=input,
max_new_tokens=max_new_tokens,
do_sample=do_sample,
typical_p=typical_p,
@@ -319,7 +318,9 @@
)
return prepare_output(
generations=[generation.generated_text] if generation else [None],
input_tokens=[compute_tokens(prompt, self._tokenizer.encode)], # type: ignore
input_tokens=[
compute_tokens(input, self._tokenizer.encode) if self._tokenizer else -1
],
output_tokens=[
generation.details.generated_tokens
if generation and generation.details
@@ -544,37 +545,43 @@ async def agenerate( # type: ignore
"""
stop_sequences = self._check_stop_sequences(stop_sequences)

if self.tokenizer_id is None:
return await self._generate_with_chat_completion(
input=input, # type: ignore
if isinstance(input, str) or self.tokenizer_id is not None:
structured_output = None
if not isinstance(input, str):
input, structured_output = self._get_structured_output(input)
input = self.prepare_input(input)

return await self._generate_with_text_generation(
input=input,
max_new_tokens=max_new_tokens,
do_sample=do_sample,
typical_p=typical_p,
repetition_penalty=repetition_penalty,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
logprobs=logprobs,
presence_penalty=presence_penalty,
seed=seed,
stop_sequences=stop_sequences,
temperature=temperature,
tool_choice=tool_choice,
tool_prompt=tool_prompt,
tools=tools,
top_logprobs=top_logprobs,
top_n_tokens=top_n_tokens,
top_p=top_p,
top_k=top_k,
stop_sequences=stop_sequences,
return_full_text=return_full_text,
seed=seed,
watermark=watermark,
structured_output=structured_output,
)

return await self._generate_with_text_generation(
input=input,
return await self._generate_with_chat_completion(
input=input, # type: ignore
max_new_tokens=max_new_tokens,
do_sample=do_sample,
typical_p=typical_p,
repetition_penalty=repetition_penalty,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
logprobs=logprobs,
presence_penalty=presence_penalty,
seed=seed,
stop_sequences=stop_sequences,
temperature=temperature,
top_n_tokens=top_n_tokens,
tool_choice=tool_choice,
tool_prompt=tool_prompt,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
top_k=top_k,
stop_sequences=stop_sequences,
return_full_text=return_full_text,
seed=seed,
watermark=watermark,
)
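
The change above can be read as a dispatch on the input type: a plain string, or any input when a `tokenizer_id` is configured, goes through `text_generation` (with chat input first rendered into a prompt locally), while everything else keeps using `chat_completion`. A condensed, illustrative view of that dispatch — helper names are taken from the diff, everything else is a sketch, and the real method forwards many more sampling parameters:

async def dispatch(llm, input, **kwargs):
    # String input, or an explicit tokenizer: use the raw text-generation route.
    if isinstance(input, str) or llm.tokenizer_id is not None:
        structured_output = None
        if not isinstance(input, str):
            # Chat-formatted input: pull out any structured-output config, then
            # render the conversation into a single prompt with the tokenizer.
            input, structured_output = llm._get_structured_output(input)
            input = llm.prepare_input(input)
        return await llm._generate_with_text_generation(
            input=input, structured_output=structured_output, **kwargs
        )
    # Chat-formatted input and no local tokenizer: chat-completion route.
    return await llm._generate_with_chat_completion(input=input, **kwargs)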
143 changes: 130 additions & 13 deletions src/distilabel/models/llms/openai.py
@@ -16,7 +16,7 @@
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple, Union

import orjson
from pydantic import PositiveInt, validate_call
from pydantic import NonNegativeInt, PositiveInt, validate_call

from distilabel import envs
from distilabel.exceptions import DistilabelOfflineBatchGenerationNotFinishedException
@@ -29,10 +29,18 @@
from openai.types import Batch as OpenAIBatch
from openai.types import FileObject as OpenAIFileObject
from openai.types.chat import ChatCompletion as OpenAIChatCompletion
from openai.types.chat.chat_completion import Choice as OpenAIChoice
from openai.types.chat.chat_completion import Choice as OpenAIChatCompletionChoice
from openai.types.completion import Completion as OpenAICompletion
from openai.types.completion_choice import (
CompletionChoice as OpenAICompletionChoice,
)

from distilabel.typing import LLMStatistics, Logprob
from distilabel.typing.models import (
LLMStatistics,
Logprob,
StandardInput,
StructuredInput,
)


_OPENAI_BATCH_API_MAX_FILE_SIZE = 100 * 1024 * 1024 # 100MB
@@ -148,15 +156,17 @@ async def agenerate( # type: ignore
self,
input: FormattedInput,
num_generations: int = 1,
max_new_tokens: int = 128,
max_new_tokens: NonNegativeInt = 128,
logprobs: bool = False,
top_logprobs: Optional[PositiveInt] = None,
echo: bool = False,
frequency_penalty: float = 0.0,
presence_penalty: float = 0.0,
temperature: float = 1.0,
top_p: float = 1.0,
stop: Optional[Union[str, List[str]]] = None,
response_format: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, Any]] = None,
) -> GenerateOutput:
"""Generates `num_generations` responses for the given input using the OpenAI async
client.
@@ -170,6 +180,8 @@
logprobs: whether to return the log probabilities or not. Defaults to `False`.
top_logprobs: the number of top log probabilities to return per output token
generated. Defaults to `None`.
echo: whether to echo the input in the response or not. It's only used if the
`input` argument is an `str`. Defaults to `False`.
frequency_penalty: the repetition penalty to use for the generation. Defaults
to `0.0`.
presence_penalty: the presence penalty to use for the generation. Defaults to
@@ -182,14 +194,115 @@
"text" or "json". Read the documentation [here](https://platform.openai.com/docs/guides/text-generation/json-mode)
for more information on how to use the JSON model from OpenAI. Defaults to None
which returns text. To return JSON, use {"type": "json_object"}.
Note:
If response_format
extra_body: an optional dictionary containing extra body parameters that will
be sent to the OpenAI API endpoint. Defaults to `None`.
Returns:
A list of lists of strings containing the generated responses for each input.
"""

if isinstance(input, str):
return await self._generate_completion(
input=input,
num_generations=num_generations,
max_new_tokens=max_new_tokens,
echo=echo,
top_logprobs=top_logprobs,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
temperature=temperature,
top_p=top_p,
extra_body=extra_body,
)

return await self._generate_chat_completion(
input=input,
num_generations=num_generations,
max_new_tokens=max_new_tokens,
logprobs=logprobs,
top_logprobs=top_logprobs,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
temperature=temperature,
top_p=top_p,
stop=stop,
response_format=response_format,
extra_body=extra_body,
)

async def _generate_completion(
self,
input: str,
num_generations: int = 1,
max_new_tokens: int = 128,
echo: bool = False,
top_logprobs: Optional[PositiveInt] = None,
frequency_penalty: float = 0.0,
presence_penalty: float = 0.0,
temperature: float = 1.0,
top_p: float = 1.0,
extra_body: Optional[Dict[str, Any]] = None,
) -> GenerateOutput:
completion = await self._aclient.completions.create(
prompt=input,
echo=echo,
model=self.model,
n=num_generations,
max_tokens=max_new_tokens,
logprobs=top_logprobs,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
temperature=temperature,
top_p=top_p,
extra_body=extra_body,
)

generations = []
logprobs = []
for choice in completion.choices:
generations.append(choice.text)
if choice_logprobs := self._get_logprobs_from_completion_choice(choice):
logprobs.append(choice_logprobs)

statistics = self._get_llm_statistics(completion)
return prepare_output(
generations=generations,
input_tokens=statistics["input_tokens"],
output_tokens=statistics["output_tokens"],
logprobs=logprobs,
)

def _get_logprobs_from_completion_choice(
self, choice: "OpenAICompletionChoice"
) -> Union[List[Union[List["Logprob"], None]], None]:
if choice.logprobs is None or choice.logprobs.top_logprobs is None:
return None

return [
[
{"token": token, "logprob": token_logprob}
for token, token_logprob in logprobs.items()
]
if logprobs is not None
else None
for logprobs in choice.logprobs.top_logprobs
]

async def _generate_chat_completion(
self,
input: Union["StandardInput", "StructuredInput"],
num_generations: int = 1,
max_new_tokens: int = 128,
logprobs: bool = False,
top_logprobs: Optional[PositiveInt] = None,
frequency_penalty: float = 0.0,
presence_penalty: float = 0.0,
temperature: float = 1.0,
top_p: float = 1.0,
stop: Optional[Union[str, List[str]]] = None,
response_format: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, Any]] = None,
) -> GenerateOutput:
structured_output = None
if isinstance(input, tuple):
input, structured_output = input
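
For reference, the legacy completions endpoint reports `choice.logprobs.top_logprobs` as one `{token: logprob}` dict per position (or `None`), and `_get_logprobs_from_completion_choice` above flattens each dict into distilabel's `Logprob` entries. A tiny illustration with made-up numbers:

# Made-up `top_logprobs` payload for a two-position completion.
top_logprobs = [
    {"Hello": -0.1, "Hi": -2.3},
    None,  # positions without logprobs stay None
]

converted = [
    [{"token": tok, "logprob": lp} for tok, lp in pos.items()]
    if pos is not None
    else None
    for pos in top_logprobs
]
# converted == [
#     [{"token": "Hello", "logprob": -0.1}, {"token": "Hi", "logprob": -2.3}],
#     None,
# ]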
@@ -215,9 +328,11 @@ async def agenerate( # type: ignore
"temperature": temperature,
"top_p": top_p,
"stop": stop,
"extra_body": extra_body,
}
# Check if it's a vision generation task, in that case "stop" cannot be used or raises
# an error in the API.

# Checks if any message contains an image, in that case "stop" cannot be used or
# raises an error in the API.
if isinstance(
[row for row in input if row["role"] == "user"][0]["content"], list
):
@@ -235,7 +350,7 @@ async def agenerate( # type: ignore
# NOTE: `instructor` doesn't work with `n` parameter, so it will always return
# only 1 choice.
statistics = self._get_llm_statistics(completion._raw_response)
if choice_logprobs := self._get_logprobs_from_choice(
if choice_logprobs := self._get_logprobs_from_chat_completion_choice(
completion._raw_response.choices[0]
):
output_logprobs = [choice_logprobs]
@@ -270,7 +385,9 @@ def _generations_from_openai_completion(
f" Finish reason was: {choice.finish_reason}"
)
generations.append(content)
if choice_logprobs := self._get_logprobs_from_choice(choice):
if choice_logprobs := self._get_logprobs_from_chat_completion_choice(
choice
):
logprobs.append(choice_logprobs)

statistics = self._get_llm_statistics(completion)
@@ -281,8 +398,8 @@ def _generations_from_openai_completion(
logprobs=logprobs,
)

def _get_logprobs_from_choice(
self, choice: "OpenAIChoice"
def _get_logprobs_from_chat_completion_choice(
self, choice: "OpenAIChatCompletionChoice"
) -> Union[List[List["Logprob"]], None]:
if choice.logprobs is None or choice.logprobs.content is None:
return None
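
Net effect of the `openai.py` changes: a plain-string input is served by the legacy `/v1/completions` endpoint, where `echo=True` combined with `logprobs=N` is what actually yields logprobs for the prompt tokens, while chat-formatted input keeps using chat completions. A minimal sketch against the raw `openai` client showing roughly the call that `_generate_completion` wraps — the model id is illustrative, `max_tokens=0` is an assumption for the "score the prompt only" case, and `echo` support varies across OpenAI-compatible servers:

from openai import AsyncOpenAI

async def prompt_logprobs(prompt: str) -> list:
    client = AsyncOpenAI()  # assumes OPENAI_API_KEY (and optionally a custom base_url)
    completion = await client.completions.create(
        model="gpt-3.5-turbo-instruct",  # illustrative; any completions-capable model
        prompt=prompt,
        echo=True,      # include the prompt tokens in the returned choice
        logprobs=5,     # top-5 logprobs per token, prompt tokens included
        max_tokens=0,   # generate nothing: only score the prompt
    )
    choice = completion.choices[0]
    # token_logprobs[i] is the logprob of the i-th echoed (or generated) token.
    return choice.logprobs.token_logprobs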