diff --git a/haystack_experimental/components/generators/chat/__init__.py b/haystack_experimental/components/generators/chat/__init__.py new file mode 100644 index 00000000..594bd56e --- /dev/null +++ b/haystack_experimental/components/generators/chat/__init__.py @@ -0,0 +1,11 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from haystack_experimental.components.generators.chat.openai import ( # noqa: I001 (otherwise we end up with partial imports) + OpenAIChatGenerator, +) + +__all__ = [ + "OpenAIChatGenerator", +] diff --git a/haystack_experimental/components/generators/chat/openai.py b/haystack_experimental/components/generators/chat/openai.py new file mode 100644 index 00000000..445421ba --- /dev/null +++ b/haystack_experimental/components/generators/chat/openai.py @@ -0,0 +1,406 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import json +from typing import Any, Callable, Dict, List, Optional, Union + +from haystack import component, default_from_dict, logging +from haystack.components.generators.chat.openai import OpenAIChatGenerator as OpenAIChatGeneratorBase +from haystack.dataclasses import StreamingChunk +from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace +from openai import Stream +from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessage +from openai.types.chat.chat_completion import Choice +from openai.types.chat.chat_completion_chunk import Choice as ChunkChoice + +from haystack_experimental.dataclasses import ChatMessage, Tool, ToolCall + +logger = logging.getLogger(__name__) + + +def _convert_message_to_openai_format(message: ChatMessage) -> Dict[str, Any]: + """ + Convert a message to the format expected by OpenAI's Chat API. + """ + text_contents = message.texts + tool_calls = message.tool_calls + tool_call_results = message.tool_call_results + + if not text_contents and not tool_calls and not tool_call_results: + raise ValueError("A `ChatMessage` must contain at least one `TextContent`, `ToolCall`, or `ToolCallResult`.") + elif len(text_contents) + len(tool_call_results) > 1: + raise ValueError("A `ChatMessage` can only contain one `TextContent` or one `ToolCallResult`.") + + openai_msg: Dict[str, Any] = {"role": message._role.value} + + if tool_call_results: + result = tool_call_results[0] + if result.origin.id is None: + raise ValueError("`ToolCall` must have a non-null `id` attribute to be used with OpenAI.") + openai_msg["content"] = result.result + openai_msg["tool_call_id"] = result.origin.id + return openai_msg + + if text_contents: + openai_msg["content"] = text_contents[0] + if tool_calls: + openai_tool_calls = [] + for tc in tool_calls: + if tc.id is None: + raise ValueError("`ToolCall` must have a non-null `id` attribute to be used with OpenAI.") + openai_tool_calls.append( + { + "id": tc.id, + "type": "function", + "function": {"name": tc.tool_name, "arguments": json.dumps(tc.arguments)}, + } + ) + openai_msg["tool_calls"] = openai_tool_calls + return openai_msg + + +@component +class OpenAIChatGenerator(OpenAIChatGeneratorBase): + """ + Completes chats using OpenAI's large language models (LLMs). + + It works with the gpt-4 and gpt-3.5-turbo models and supports streaming responses + from OpenAI API. It uses [ChatMessage](https://docs.haystack.deepset.ai/docs/data-classes#chatmessage) + format in input and output. 
+
+    You can customize how the text is generated by passing parameters to the
+    OpenAI API. Use the `**generation_kwargs` argument when you initialize
+    the component or when you run it. Any parameter that works with
+    `openai.ChatCompletion.create` will work here too.
+
+    For details on OpenAI API parameters, see
+    [OpenAI documentation](https://platform.openai.com/docs/api-reference/chat).
+
+    ### Usage example
+
+    ```python
+    from haystack_experimental.components.generators.chat import OpenAIChatGenerator
+    from haystack_experimental.dataclasses import ChatMessage
+
+    messages = [ChatMessage.from_user("What's Natural Language Processing?")]
+
+    client = OpenAIChatGenerator()
+    response = client.run(messages)
+    print(response)
+    ```
+    Output:
+    ```
+    {'replies': [
+        ChatMessage(_role=<ChatRole.ASSISTANT: 'assistant'>,
+            _content=[TextContent(text='Natural Language Processing (NLP) is a field of artificial ...')],
+            _meta={'model': 'gpt-3.5-turbo-0125', 'index': 0, 'finish_reason': 'stop',
+            'usage': {'completion_tokens': 71, 'prompt_tokens': 13, 'total_tokens': 84}}
+            )
+        ]
+    }
+    ```
+    """
+
+    def __init__(  # noqa: PLR0913
+        self,
+        api_key: Secret = Secret.from_env_var("OPENAI_API_KEY"),
+        model: str = "gpt-3.5-turbo",
+        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
+        api_base_url: Optional[str] = None,
+        organization: Optional[str] = None,
+        generation_kwargs: Optional[Dict[str, Any]] = None,
+        timeout: Optional[float] = None,
+        max_retries: Optional[int] = None,
+        tools: Optional[List[Tool]] = None,
+        tools_strict: bool = False,
+    ):
+        """
+        Creates an instance of OpenAIChatGenerator. Unless specified otherwise in `model`, uses OpenAI's GPT-3.5.
+
+        Before initializing the component, you can set the 'OPENAI_TIMEOUT' and 'OPENAI_MAX_RETRIES'
+        environment variables to override the `timeout` and `max_retries` parameters respectively
+        in the OpenAI client.
+
+        :param api_key: The OpenAI API key.
+            You can set it with an environment variable `OPENAI_API_KEY`, or pass with this parameter
+            during initialization.
+        :param model: The name of the model to use.
+        :param streaming_callback: A callback function that is called when a new token is received from the stream.
+            The callback function accepts [StreamingChunk](https://docs.haystack.deepset.ai/docs/data-classes#streamingchunk)
+            as an argument.
+        :param api_base_url: An optional base URL.
+        :param organization: Your organization ID, defaults to `None`. See
+            [production best practices](https://platform.openai.com/docs/guides/production-best-practices/setting-up-your-organization).
+        :param generation_kwargs: Other parameters to use for the model. These parameters are sent directly to
+            the OpenAI endpoint. See OpenAI [documentation](https://platform.openai.com/docs/api-reference/chat) for
+            more details.
+            Some of the supported parameters:
+            - `max_tokens`: The maximum number of tokens the output text can have.
+            - `temperature`: What sampling temperature to use. Higher values mean the model will take more risks.
+                Try 0.9 for more creative applications and 0 (argmax sampling) for ones with a well-defined answer.
+            - `top_p`: An alternative to sampling with temperature, called nucleus sampling, where the model
+                considers the results of the tokens with top_p probability mass. For example, 0.1 means only the tokens
+                comprising the top 10% probability mass are considered.
+            - `n`: How many completions to generate for each prompt. For example, if the LLM gets 3 prompts and n is 2,
+                it will generate two completions for each of the three prompts, ending up with 6 completions in total.
+            - `stop`: One or more sequences after which the LLM should stop generating tokens.
+            - `presence_penalty`: The penalty applied if a token has already appeared in the text at least once,
+                regardless of how often. Bigger values mean the model is less likely to repeat the same token.
+            - `frequency_penalty`: The penalty applied to a token in proportion to how often it has already appeared
+                in the text. Bigger values mean the model is less likely to repeat the same token.
+            - `logit_bias`: Add a logit bias to specific tokens. The keys of the dictionary are tokens, and the
+                values are the bias to add to that token.
+        :param timeout:
+            Timeout for OpenAI client calls. If not set, it defaults to either the
+            `OPENAI_TIMEOUT` environment variable or 30 seconds.
+        :param max_retries:
+            Maximum number of retries to contact OpenAI after an internal error.
+            If not set, it defaults to either the `OPENAI_MAX_RETRIES` environment variable or 5.
+        :param tools:
+            A list of tools for which the model can prepare calls.
+        :param tools_strict:
+            Whether to enable strict schema adherence for tool calls. If set to `True`, the model will follow exactly
+            the schema provided in the `parameters` field of the tool definition, but this may increase latency.
+        """
+        self.tools = tools
+        self.tools_strict = tools_strict
+
+        super(OpenAIChatGenerator, self).__init__(
+            api_key=api_key,
+            model=model,
+            streaming_callback=streaming_callback,
+            api_base_url=api_base_url,
+            organization=organization,
+            generation_kwargs=generation_kwargs,
+            timeout=timeout,
+            max_retries=max_retries,
+        )
+
+    def to_dict(self) -> Dict[str, Any]:
+        """
+        Serialize this component to a dictionary.
+
+        :returns:
+            The serialized component as a dictionary.
+        """
+        serialized = super(OpenAIChatGenerator, self).to_dict()
+        serialized["init_parameters"]["tools"] = [tool.to_dict() for tool in self.tools] if self.tools else None
+        serialized["init_parameters"]["tools_strict"] = self.tools_strict
+        return serialized
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "OpenAIChatGenerator":
+        """
+        Deserialize this component from a dictionary.
+
+        :param data: The dictionary representation of this component.
+        :returns:
+            The deserialized component instance.
+        """
+        deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"])
+        init_params = data.get("init_parameters", {})
+        serialized_callback_handler = init_params.get("streaming_callback")
+        if serialized_callback_handler:
+            data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler)
+
+        tools = init_params.get("tools")
+        if tools:
+            init_params["tools"] = [Tool.from_dict(tool) for tool in tools]
+
+        return default_from_dict(cls, data)
+
+    @component.output_types(replies=List[ChatMessage])
+    def run(  # noqa: PLR0913
+        self,
+        messages: List[ChatMessage],
+        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
+        generation_kwargs: Optional[Dict[str, Any]] = None,
+        tools: Optional[List[Tool]] = None,
+        tools_strict: Optional[bool] = None,
+    ):
+        """
+        Invokes chat completion based on the provided messages and generation parameters.
+
+        :param messages: A list of ChatMessage instances representing the input messages.
+        :param streaming_callback: A callback function that is called when a new token is received from the stream.
+ :param generation_kwargs: Additional keyword arguments for text generation. These parameters will + override the parameters passed during component initialization. + For details on OpenAI API parameters, see + [OpenAI documentation](https://platform.openai.com/docs/api-reference/chat/create). + :param tools: + A list of tools for which the model can prepare calls. If set, it will override the `tools` parameter set + during component initialization. + :param tools_strict: + Whether to enable strict schema adherence for tool calls. If set to `True`, the model will follow exactly + the schema provided in the `parameters` field of the tool definition, but this may increase latency. + If set, it will override the `tools_strict` parameter set during component initialization. + + :returns: + A list containing the generated responses as ChatMessage instances. + """ + + # update generation kwargs by merging with the generation kwargs passed to the run method + generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})} + + # check if streaming_callback is passed + streaming_callback = streaming_callback or self.streaming_callback + + # adapt ChatMessage(s) to the format expected by the OpenAI API + openai_formatted_messages = [_convert_message_to_openai_format(message) for message in messages] + + tools = tools or self.tools + tools_strict = tools_strict if tools_strict is not None else self.tools_strict + + openai_tools = None + if tools: + openai_tools = [{"type": "function", "function": {**t.tool_spec, "strict": tools_strict}} for t in tools] + + chat_completion: Union[Stream[ChatCompletionChunk], ChatCompletion] = self.client.chat.completions.create( + model=self.model, + messages=openai_formatted_messages, # type: ignore[arg-type] # openai expects list of specific message types + stream=streaming_callback is not None, + tools=openai_tools, # type: ignore[arg-type] + **generation_kwargs, + ) + + completions: List[ChatMessage] = [] + # if streaming is enabled, the completion is a Stream of ChatCompletionChunk + if isinstance(chat_completion, Stream): + num_responses = generation_kwargs.pop("n", 1) + if num_responses > 1: + raise ValueError("Cannot stream multiple responses, please set n=1.") + chunks: List[StreamingChunk] = [] + chunk = None + + # pylint: disable=not-an-iterable + for chunk in chat_completion: + if chunk.choices and streaming_callback: + chunk_delta: StreamingChunk = self._convert_chat_completion_chunk_to_streaming_chunk(chunk) + chunks.append(chunk_delta) + streaming_callback(chunk_delta) # invoke callback with the chunk_delta + completions = [self._convert_streaming_chunks_to_chat_message(chunk, chunks)] + # if streaming is disabled, the completion is a ChatCompletion + elif isinstance(chat_completion, ChatCompletion): + completions = [ + self._convert_chat_completion_to_chat_message(chat_completion, choice) + for choice in chat_completion.choices + ] + + # before returning, do post-processing of the completions + for message in completions: + self._check_finish_reason(message) + + return {"replies": completions} + + def _convert_streaming_chunks_to_chat_message(self, chunk: Any, chunks: List[StreamingChunk]) -> ChatMessage: + """ + Connects the streaming chunks into a single ChatMessage. + + :param chunk: The last chunk returned by the OpenAI API. + :param chunks: The list of all `StreamingChunk` objects. 
+        :return: The combined ChatMessage.
+        """
+
+        text = "".join([chunk.content for chunk in chunks])
+        tool_calls = []
+
+        # if it's a tool call, we need to build the payload dict from all the chunks
+        if chunks[0].meta.get("tool_calls"):
+            tools_len = len(chunks[0].meta.get("tool_calls", []))
+
+            payloads = [{"arguments": "", "name": ""} for _ in range(tools_len)]
+            for chunk_payload in chunks:
+                deltas = chunk_payload.meta.get("tool_calls") or []
+
+                # deltas is a list of ChoiceDeltaToolCall or ChoiceDeltaFunctionCall
+                for i, delta in enumerate(deltas):
+                    payloads[i]["id"] = delta.id or payloads[i].get("id", "")
+                    if delta.function:
+                        payloads[i]["name"] += delta.function.name or ""
+                        payloads[i]["arguments"] += delta.function.arguments or ""
+
+            for payload in payloads:
+                arguments_str = payload["arguments"]
+                try:
+                    arguments = json.loads(arguments_str)
+                    tool_calls.append(ToolCall(id=payload["id"], tool_name=payload["name"], arguments=arguments))
+                except json.JSONDecodeError:
+                    logger.warning(
+                        "OpenAI returned a malformed JSON string for tool call arguments. This tool call "
+                        "will be skipped. To always generate a valid JSON, set `tools_strict` to `True`. "
+                        "Tool call ID: {_id}, Tool name: {_name}, Arguments: {_arguments}",
+                        _id=payload["id"],
+                        _name=payload["name"],
+                        _arguments=arguments_str,
+                    )
+
+        meta = {
+            "model": chunk.model,
+            "index": 0,
+            "finish_reason": chunk.choices[0].finish_reason,
+            "usage": {},  # we don't have usage data for streaming responses
+        }
+
+        return ChatMessage.from_assistant(text=text, tool_calls=tool_calls, meta=meta)
+
+    def _convert_chat_completion_to_chat_message(self, completion: ChatCompletion, choice: Choice) -> ChatMessage:
+        """
+        Converts the non-streaming response from the OpenAI API to a ChatMessage.
+
+        :param completion: The completion returned by the OpenAI API.
+        :param choice: The choice returned by the OpenAI API.
+        :return: The ChatMessage.
+        """
+        message: ChatCompletionMessage = choice.message
+        text = message.content or ""
+        tool_calls = []
+        if openai_tool_calls := message.tool_calls:
+            for openai_tc in openai_tool_calls:
+                arguments_str = openai_tc.function.arguments
+                try:
+                    arguments = json.loads(arguments_str)
+                    tool_calls.append(ToolCall(id=openai_tc.id, tool_name=openai_tc.function.name, arguments=arguments))
+                except json.JSONDecodeError:
+                    logger.warning(
+                        "OpenAI returned a malformed JSON string for tool call arguments. This tool call "
+                        "will be skipped. To always generate a valid JSON, set `tools_strict` to `True`. "
+                        "Tool call ID: {_id}, Tool name: {_name}, Arguments: {_arguments}",
+                        _id=openai_tc.id,
+                        _name=openai_tc.function.name,
+                        _arguments=arguments_str,
+                    )
+
+        chat_message = ChatMessage.from_assistant(text=text, tool_calls=tool_calls)
+        chat_message._meta.update(
+            {
+                "model": completion.model,
+                "index": choice.index,
+                "finish_reason": choice.finish_reason,
+                "usage": dict(completion.usage or {}),
+            }
+        )
+        return chat_message
+
+    def _convert_chat_completion_chunk_to_streaming_chunk(self, chunk: ChatCompletionChunk) -> StreamingChunk:
+        """
+        Converts the streaming response chunk from the OpenAI API to a StreamingChunk.
+
+        :param chunk: The chunk returned by the OpenAI API.
+        :return: The StreamingChunk.
+ """ + # we stream the content of the chunk if it's not a tool or function call + choice: ChunkChoice = chunk.choices[0] + content = choice.delta.content or "" + chunk_message = StreamingChunk(content) + # but save the tool calls and function call in the meta if they are present + # and then connect the chunks in the _convert_streaming_chunks_to_chat_message method + chunk_message.meta.update( + { + "model": chunk.model, + "index": choice.index, + "tool_calls": choice.delta.tool_calls, + "finish_reason": choice.finish_reason, + } + ) + return chunk_message diff --git a/haystack_experimental/dataclasses/chat_message.py b/haystack_experimental/dataclasses/chat_message.py index ed1d833a..9eec35b2 100644 --- a/haystack_experimental/dataclasses/chat_message.py +++ b/haystack_experimental/dataclasses/chat_message.py @@ -190,9 +190,6 @@ def from_assistant( :param meta: Additional metadata associated with the message. :returns: A new ChatMessage instance. """ - if not text and not tool_calls: - raise ValueError("At least one of `text` or `tool_calls` must be provided.") - content: List[ChatMessageContentT] = [] if text: content.append(TextContent(text=text)) diff --git a/test/components/generators/chat/test_openai.py b/test/components/generators/chat/test_openai.py new file mode 100644 index 00000000..afb5410b --- /dev/null +++ b/test/components/generators/chat/test_openai.py @@ -0,0 +1,642 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +from unittest.mock import MagicMock, patch +import pytest + +from typing import Iterator +import logging +import os +import json +from datetime import datetime + +from openai import OpenAIError +from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessage, ChatCompletionMessageToolCall +from openai.types.chat.chat_completion import Choice +from openai.types.chat.chat_completion_message_tool_call import Function +from openai.types.chat import chat_completion_chunk +from openai import Stream + +from haystack.components.generators.utils import print_streaming_chunk +from haystack.dataclasses import StreamingChunk +from haystack.utils.auth import Secret +from haystack_experimental.dataclasses import ChatMessage, Tool, ToolCall, ChatRole, TextContent +from haystack_experimental.components.generators.chat.openai import OpenAIChatGenerator, _convert_message_to_openai_format, OpenAIChatGeneratorBase + + + +@pytest.fixture +def chat_messages(): + return [ + ChatMessage.from_system("You are a helpful assistant"), + ChatMessage.from_user("What's the capital of France"), + ] + +class MockStream(Stream[ChatCompletionChunk]): + def __init__(self, mock_chunk: ChatCompletionChunk, client=None, *args, **kwargs): + client = client or MagicMock() + super().__init__(client=client, *args, **kwargs) + self.mock_chunk = mock_chunk + + def __stream__(self) -> Iterator[ChatCompletionChunk]: + # Yielding only one ChatCompletionChunk object + yield self.mock_chunk + +@pytest.fixture +def mock_chat_completion_chunk(): + """ + Mock the OpenAI API completion chunk response and reuse it for tests + """ + + with patch("openai.resources.chat.completions.Completions.create") as mock_chat_completion_create: + completion = ChatCompletionChunk( + id="foo", + model="gpt-4", + object="chat.completion.chunk", + choices=[ + chat_completion_chunk.Choice( + finish_reason="stop", logprobs=None, index=0, delta=chat_completion_chunk.ChoiceDelta(content="Hello", role="assistant") + ) + ], + created=int(datetime.now().timestamp()), 
+            usage={"prompt_tokens": 57, "completion_tokens": 40, "total_tokens": 97},
+        )
+        mock_chat_completion_create.return_value = MockStream(completion, cast_to=None, response=None, client=None)
+        yield mock_chat_completion_create
+
+@pytest.fixture
+def mock_chat_completion():
+    """
+    Mock the OpenAI API completion response and reuse it for tests
+    """
+    with patch("openai.resources.chat.completions.Completions.create") as mock_chat_completion_create:
+        completion = ChatCompletion(
+            id="foo",
+            model="gpt-4",
+            object="chat.completion",
+            choices=[
+                Choice(
+                    finish_reason="stop",
+                    logprobs=None,
+                    index=0,
+                    message=ChatCompletionMessage(content="Hello world!", role="assistant"),
+                )
+            ],
+            created=int(datetime.now().timestamp()),
+            usage={"prompt_tokens": 57, "completion_tokens": 40, "total_tokens": 97},
+        )
+
+        mock_chat_completion_create.return_value = completion
+        yield mock_chat_completion_create
+
+@pytest.fixture
+def mock_chat_completion_chunk_with_tools():
+    """
+    Mock the OpenAI API completion chunk response and reuse it for tests
+    """
+    with patch("openai.resources.chat.completions.Completions.create") as mock_chat_completion_create:
+        completion = ChatCompletionChunk(
+            id="foo",
+            model="gpt-4",
+            object="chat.completion.chunk",
+            choices=[
+                chat_completion_chunk.Choice(
+                    finish_reason="tool_calls",
+                    logprobs=None,
+                    index=0,
+                    delta=chat_completion_chunk.ChoiceDelta(
+                        role="assistant",
+                        tool_calls=[
+                            chat_completion_chunk.ChoiceDeltaToolCall(
+                                index=0,
+                                id="123",
+                                type="function",
+                                function=chat_completion_chunk.ChoiceDeltaToolCallFunction(
+                                    name="weather", arguments='{"city": "Paris"}'
+                                ),
+                            )
+                        ],
+                    ),
+                )
+            ],
+            created=int(datetime.now().timestamp()),
+            usage={"prompt_tokens": 57, "completion_tokens": 40, "total_tokens": 97},
+        )
+        mock_chat_completion_create.return_value = MockStream(completion, cast_to=None, response=None, client=None)
+        yield mock_chat_completion_create
+
+@pytest.fixture
+def tools():
+    tool_parameters = {
+        "type": "object",
+        "properties": {"city": {"type": "string"}},
+        "required": ["city"],
+    }
+    tool = Tool(
+        name="weather",
+        description="useful to determine the weather in a given location",
+        parameters=tool_parameters,
+        function=lambda x: x,
+    )
+
+    return [tool]
+
+
+class TestOpenAIChatGenerator:
+    def test_init_default(self, monkeypatch):
+        monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
+        component = OpenAIChatGenerator()
+        assert component.client.api_key == "test-api-key"
+        assert component.model == "gpt-3.5-turbo"
+        assert component.streaming_callback is None
+        assert not component.generation_kwargs
+        assert component.client.timeout == 30
+        assert component.client.max_retries == 5
+        assert component.tools is None
+        assert not component.tools_strict
+
+    def test_init_fail_wo_api_key(self, monkeypatch):
+        monkeypatch.delenv("OPENAI_API_KEY", raising=False)
+        with pytest.raises(ValueError):
+            OpenAIChatGenerator()
+
+    def test_init_with_parameters(self, monkeypatch):
+        tool = Tool(name="name", description="description", parameters={"x": {"type": "string"}}, function=lambda x: x)
+
+        monkeypatch.setenv("OPENAI_TIMEOUT", "100")
+        monkeypatch.setenv("OPENAI_MAX_RETRIES", "10")
+        component = OpenAIChatGenerator(
+            api_key=Secret.from_token("test-api-key"),
+            model="gpt-4",
+            streaming_callback=print_streaming_chunk,
+            api_base_url="test-base-url",
+            generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
+            timeout=40.0,
+            max_retries=1,
+            tools=[tool],
+            tools_strict=True,
+        )
+        assert component.client.api_key == "test-api-key"
+        assert component.model == "gpt-4"
+        assert component.streaming_callback is print_streaming_chunk
+        assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}
+        assert component.client.timeout == 40.0
+        assert component.client.max_retries == 1
+        assert component.tools == [tool]
+        assert component.tools_strict
+
+    def test_init_with_parameters_and_env_vars(self, monkeypatch):
+        monkeypatch.setenv("OPENAI_TIMEOUT", "100")
+        monkeypatch.setenv("OPENAI_MAX_RETRIES", "10")
+        component = OpenAIChatGenerator(
+            api_key=Secret.from_token("test-api-key"),
+            model="gpt-4",
+            streaming_callback=print_streaming_chunk,
+            api_base_url="test-base-url",
+            generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
+        )
+        assert component.client.api_key == "test-api-key"
+        assert component.model == "gpt-4"
+        assert component.streaming_callback is print_streaming_chunk
+        assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}
+        assert component.client.timeout == 100.0
+        assert component.client.max_retries == 10
+
+    def test_to_dict_default(self, monkeypatch):
+        monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
+        component = OpenAIChatGenerator()
+        data = component.to_dict()
+        assert data == {
+            "type": "haystack_experimental.components.generators.chat.openai.OpenAIChatGenerator",
+            "init_parameters": {
+                "api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"},
+                "model": "gpt-3.5-turbo",
+                "organization": None,
+                "streaming_callback": None,
+                "api_base_url": None,
+                "generation_kwargs": {},
+                "tools": None,
+                "tools_strict": False,
+            },
+        }
+
+    def test_to_dict_with_parameters(self, monkeypatch):
+        tool = Tool(name="name", description="description", parameters={"x": {"type": "string"}}, function=print)
+
+        monkeypatch.setenv("ENV_VAR", "test-api-key")
+        component = OpenAIChatGenerator(
+            api_key=Secret.from_env_var("ENV_VAR"),
+            model="gpt-4",
+            streaming_callback=print_streaming_chunk,
+            api_base_url="test-base-url",
+            generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
+            tools=[tool],
+            tools_strict=True,
+        )
+        data = component.to_dict()
+
+        assert data == {
+            "type": "haystack_experimental.components.generators.chat.openai.OpenAIChatGenerator",
+            "init_parameters": {
+                "api_key": {"env_vars": ["ENV_VAR"], "strict": True, "type": "env_var"},
+                "model": "gpt-4",
+                "organization": None,
+                "api_base_url": "test-base-url",
+                "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
+                "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
+                "tools": [
+                    {
+                        "description": "description",
+                        "function": "builtins.print",
+                        "name": "name",
+                        "parameters": {"x": {"type": "string"}},
+                    },
+                ],
+                "tools_strict": True,
+            },
+        }
+
+    def test_to_dict_with_lambda_streaming_callback(self, monkeypatch):
+        monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
+        component = OpenAIChatGenerator(
+            model="gpt-4",
+            streaming_callback=lambda x: x,
+            api_base_url="test-base-url",
+            generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
+        )
+        data = component.to_dict()
+        assert data == {
+            "type": "haystack_experimental.components.generators.chat.openai.OpenAIChatGenerator",
+            "init_parameters": {
+                "api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"},
+                "model": "gpt-4",
+                "organization": None,
+                "api_base_url": "test-base-url",
+                "streaming_callback": "test_openai.<lambda>",
+                "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
+                "tools": None,
+                "tools_strict": False,
+            },
+        }
+
+    def test_from_dict(self, monkeypatch):
+        monkeypatch.setenv("OPENAI_API_KEY", "fake-api-key")
+        data = {
+            "type": "haystack_experimental.components.generators.chat.openai.OpenAIChatGenerator",
+            "init_parameters": {
+                "api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"},
+                "model": "gpt-4",
+                "api_base_url": "test-base-url",
+                "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
+                "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
+                "tools": [
+                    {
+                        "description": "description",
+                        "function": "builtins.print",
+                        "name": "name",
+                        "parameters": {"x": {"type": "string"}},
+                    },
+                ],
+                "tools_strict": True,
+            },
+        }
+        component = OpenAIChatGenerator.from_dict(data)
+
+        assert isinstance(component, OpenAIChatGenerator)
+        assert component.model == "gpt-4"
+        assert component.streaming_callback is print_streaming_chunk
+        assert component.api_base_url == "test-base-url"
+        assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}
+        assert component.api_key == Secret.from_env_var("OPENAI_API_KEY")
+        assert component.tools == [
+            Tool(name="name", description="description", parameters={"x": {"type": "string"}}, function=print)
+        ]
+        assert component.tools_strict
+
+    def test_from_dict_fail_wo_env_var(self, monkeypatch):
+        monkeypatch.delenv("OPENAI_API_KEY", raising=False)
+        data = {
+            "type": "haystack_experimental.components.generators.chat.openai.OpenAIChatGenerator",
+            "init_parameters": {
+                "api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"},
+                "model": "gpt-4",
+                "organization": None,
+                "api_base_url": "test-base-url",
+                "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
+                "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
+            },
+        }
+        with pytest.raises(ValueError):
+            OpenAIChatGenerator.from_dict(data)
+
+    def test_run(self, chat_messages, mock_chat_completion):
+        component = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key"))
+        response = component.run(chat_messages)
+
+        # check that the component returns the correct ChatMessage response
+        assert isinstance(response, dict)
+        assert "replies" in response
+        assert isinstance(response["replies"], list)
+        assert len(response["replies"]) == 1
+        assert all(isinstance(reply, ChatMessage) for reply in response["replies"])
+
+    def test_run_with_params(self, chat_messages, mock_chat_completion):
+        component = OpenAIChatGenerator(
+            api_key=Secret.from_token("test-api-key"), generation_kwargs={"max_tokens": 10, "temperature": 0.5}
+        )
+        response = component.run(chat_messages)
+
+        # check that the component calls the OpenAI API with the correct parameters
+        _, kwargs = mock_chat_completion.call_args
+        assert kwargs["max_tokens"] == 10
+        assert kwargs["temperature"] == 0.5
+
+        # check that the component returns the correct response
+        assert isinstance(response, dict)
+        assert "replies" in response
+        assert isinstance(response["replies"], list)
+        assert len(response["replies"]) == 1
+        assert all(isinstance(reply, ChatMessage) for reply in response["replies"])
+
+    def test_run_with_params_streaming(self, chat_messages, mock_chat_completion_chunk):
+        streaming_callback_called = False
+
+        def streaming_callback(chunk: StreamingChunk) -> None:
+            nonlocal streaming_callback_called
+            streaming_callback_called = True
+
+        component = OpenAIChatGenerator(
+            api_key=Secret.from_token("test-api-key"), streaming_callback=streaming_callback
+        )
+        response = component.run(chat_messages)
+
+        # check we called the streaming callback
+        assert streaming_callback_called
+
+        # check that the component still returns the correct response
+        assert isinstance(response, dict)
+        assert "replies" in response
+        assert isinstance(response["replies"], list)
+        assert len(response["replies"]) == 1
+        assert all(isinstance(reply, ChatMessage) for reply in response["replies"])
+        assert "Hello" in response["replies"][0].text  # see mock_chat_completion_chunk
+
+    def test_run_with_streaming_callback_in_run_method(self, chat_messages, mock_chat_completion_chunk):
+        streaming_callback_called = False
+
+        def streaming_callback(chunk: StreamingChunk) -> None:
+            nonlocal streaming_callback_called
+            streaming_callback_called = True
+
+        component = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key"))
+        response = component.run(chat_messages, streaming_callback=streaming_callback)
+
+        # check we called the streaming callback
+        assert streaming_callback_called
+
+        # check that the component still returns the correct response
+        assert isinstance(response, dict)
+        assert "replies" in response
+        assert isinstance(response["replies"], list)
+        assert len(response["replies"]) == 1
+        assert all(isinstance(reply, ChatMessage) for reply in response["replies"])
+        assert "Hello" in response["replies"][0].text  # see mock_chat_completion_chunk
+
+    def test_check_abnormal_completions(self, caplog):
+        caplog.set_level(logging.INFO)
+        component = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key"))
+        messages = [
+            ChatMessage.from_assistant(
+                "", meta={"finish_reason": "content_filter" if i % 2 == 0 else "length", "index": i}
+            )
+            for i in range(4)
+        ]
+
+        for m in messages:
+            component._check_finish_reason(m)
+
+        # check truncation warning
+        message_template = (
+            "The completion for index {index} has been truncated before reaching a natural stopping point. "
+            "Increase the max_tokens parameter to allow for longer completions."
+        )
+
+        for index in [1, 3]:
+            assert caplog.records[index].message == message_template.format(index=index)
+
+        # check content filter warning
+        message_template = "The completion for index {index} has been truncated due to the content filter."
+        for index in [0, 2]:
+            assert caplog.records[index].message == message_template.format(index=index)
+
+    @pytest.mark.skipif(
+        not os.environ.get("OPENAI_API_KEY", None),
+        reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
+    )
+    @pytest.mark.integration
+    def test_live_run(self):
+        chat_messages = [ChatMessage.from_user("What's the capital of France")]
+        component = OpenAIChatGenerator(generation_kwargs={"n": 1})
+        results = component.run(chat_messages)
+        assert len(results["replies"]) == 1
+        message: ChatMessage = results["replies"][0]
+        assert "Paris" in message.text
+        assert "gpt-3.5" in message.meta["model"]
+        assert message.meta["finish_reason"] == "stop"
+
+    @pytest.mark.skipif(
+        not os.environ.get("OPENAI_API_KEY", None),
+        reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
+    )
+    @pytest.mark.integration
+    def test_live_run_wrong_model(self, chat_messages):
+        component = OpenAIChatGenerator(model="something-obviously-wrong")
+        with pytest.raises(OpenAIError):
+            component.run(chat_messages)
+
+    @pytest.mark.skipif(
+        not os.environ.get("OPENAI_API_KEY", None),
+        reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
+    )
+    @pytest.mark.integration
+    def test_live_run_streaming(self):
+        class Callback:
+            def __init__(self):
+                self.responses = ""
+                self.counter = 0
+
+            def __call__(self, chunk: StreamingChunk) -> None:
+                self.counter += 1
+                self.responses += chunk.content if chunk.content else ""
+
+        callback = Callback()
+        component = OpenAIChatGenerator(streaming_callback=callback)
+        results = component.run([ChatMessage.from_user("What's the capital of France?")])
+
+        assert len(results["replies"]) == 1
+        message: ChatMessage = results["replies"][0]
+        assert "Paris" in message.text
+        assert "gpt-3.5" in message.meta["model"]
+        assert message.meta["finish_reason"] == "stop"
+
+        assert callback.counter > 1
+        assert "Paris" in callback.responses
+
+    def test_convert_message_to_openai_format(self):
+        message = ChatMessage.from_system("You are good assistant")
+        assert _convert_message_to_openai_format(message) == {"role": "system", "content": "You are good assistant"}
+
+        message = ChatMessage.from_user("I have a question")
+        assert _convert_message_to_openai_format(message) == {"role": "user", "content": "I have a question"}
+
+        message = ChatMessage.from_assistant(text="I have an answer", meta={"finish_reason": "stop"})
+        assert _convert_message_to_openai_format(message) == {"role": "assistant", "content": "I have an answer"}
+
+        message = ChatMessage.from_assistant(
+            tool_calls=[ToolCall(id="123", tool_name="weather", arguments={"city": "Paris"})]
+        )
+        assert _convert_message_to_openai_format(message) == {
+            "role": "assistant",
+            "tool_calls": [
+                {"id": "123", "type": "function", "function": {"name": "weather", "arguments": '{"city": "Paris"}'}}
+            ],
+        }
+
+        tool_result = json.dumps({"weather": "sunny", "temperature": "25"})
+        message = ChatMessage.from_tool(
+            tool_result=tool_result, origin=ToolCall(id="123", tool_name="weather", arguments={"city": "Paris"})
+        )
+        assert _convert_message_to_openai_format(message) == {
+            "role": "tool",
+            "content": tool_result,
+            "tool_call_id": "123",
+        }
+
+    def test_convert_message_to_openai_invalid(self):
+        message = ChatMessage(_role=ChatRole.ASSISTANT, _content=[])
+        with pytest.raises(ValueError):
+            _convert_message_to_openai_format(message)
+
+        message = ChatMessage(
+            _role=ChatRole.ASSISTANT,
+            _content=[TextContent(text="I have an answer"), TextContent(text="I have another answer")],
+        )
+        with pytest.raises(ValueError):
+            _convert_message_to_openai_format(message)
+
+        tool_call_null_id = ToolCall(id=None, tool_name="weather", arguments={"city": "Paris"})
+        message = ChatMessage.from_assistant(tool_calls=[tool_call_null_id])
+        with pytest.raises(ValueError):
+            _convert_message_to_openai_format(message)
+
+        message = ChatMessage.from_tool(tool_result="result", origin=tool_call_null_id)
+        with pytest.raises(ValueError):
+            _convert_message_to_openai_format(message)
+
+    def test_run_with_tools(self, tools):
+        with patch("openai.resources.chat.completions.Completions.create") as mock_chat_completion_create:
+            completion = ChatCompletion(
+                id="foo",
+                model="gpt-4",
+                object="chat.completion",
+                choices=[
+                    Choice(
+                        finish_reason="tool_calls",
+                        logprobs=None,
+                        index=0,
+                        message=ChatCompletionMessage(
+                            role="assistant",
+                            tool_calls=[
+                                ChatCompletionMessageToolCall(
+                                    id="123",
+                                    type="function",
+                                    function=Function(name="weather", arguments='{"city": "Paris"}'),
+                                )
+                            ],
+                        ),
+                    )
+                ],
+                created=int(datetime.now().timestamp()),
+                usage={"prompt_tokens": 57, "completion_tokens": 40, "total_tokens": 97},
+            )
+            mock_chat_completion_create.return_value = completion
+
+            component = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key"), tools=tools)
+            response = component.run([ChatMessage.from_user("What's the weather like in Paris?")])
+
+        assert len(response["replies"]) == 1
+        message = response["replies"][0]
+
+        assert message.tool_calls
+        tool_call = message.tool_call
+        assert isinstance(tool_call, ToolCall)
+        assert tool_call.tool_name == "weather"
+        assert tool_call.arguments == {"city": "Paris"}
+        assert message.meta["finish_reason"] == "tool_calls"
+
+    def test_run_with_tools_streaming(self, mock_chat_completion_chunk_with_tools, tools):
+        streaming_callback_called = False
+
+        def streaming_callback(chunk: StreamingChunk) -> None:
+            nonlocal streaming_callback_called
+            streaming_callback_called = True
+
+        component = OpenAIChatGenerator(
+            api_key=Secret.from_token("test-api-key"), streaming_callback=streaming_callback
+        )
+        chat_messages = [ChatMessage.from_user("What's the weather like in Paris?")]
+        response = component.run(chat_messages, tools=tools)
+
+        # check we called the streaming callback
+        assert streaming_callback_called
+
+        # check that the component still returns the correct response
+        assert isinstance(response, dict)
+        assert "replies" in response
+        assert isinstance(response["replies"], list)
+        assert len(response["replies"]) == 1
+        assert all(isinstance(reply, ChatMessage) for reply in response["replies"])
+
+        message = response["replies"][0]
+
+        assert message.tool_calls
+        tool_call = message.tool_call
+        assert isinstance(tool_call, ToolCall)
+        assert tool_call.tool_name == "weather"
+        assert tool_call.arguments == {"city": "Paris"}
+        assert message.meta["finish_reason"] == "tool_calls"
+
+    def test_invalid_tool_call_json(self, tools, caplog):
+        caplog.set_level(logging.WARNING)
+
+        with patch("openai.resources.chat.completions.Completions.create") as mock_create:
+            mock_create.return_value = ChatCompletion(
+                id="test",
+                model="gpt-4",
+                object="chat.completion",
+                choices=[
+                    Choice(
+                        finish_reason="tool_calls",
+                        index=0,
+                        message=ChatCompletionMessage(
+                            role="assistant",
+                            tool_calls=[
+                                ChatCompletionMessageToolCall(
+                                    id="1",
+                                    type="function",
+                                    function=Function(name="weather", arguments='"invalid": "json"'),
+                                )
+                            ],
+                        ),
+                    )
+                ],
+                created=1234567890,
+                usage={"prompt_tokens": 50, "completion_tokens": 30, "total_tokens": 80},
+            )
+
+            component = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key"), tools=tools)
+            response = component.run([ChatMessage.from_user("What's the weather in Paris?")])
+
+        assert len(response["replies"]) == 1
+        message = response["replies"][0]
+        assert len(message.tool_calls) == 0
+        assert "OpenAI returned a malformed JSON string for tool call arguments" in caplog.text
+
+    @pytest.mark.skipif(
+        not os.environ.get("OPENAI_API_KEY", None),
+        reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
+    )
+    @pytest.mark.integration
+    def test_live_run_with_tools(self, tools):
+        chat_messages = [ChatMessage.from_user("What's the weather like in Paris?")]
+        component = OpenAIChatGenerator(tools=tools)
+        results = component.run(chat_messages)
+        assert len(results["replies"]) == 1
+        message = results["replies"][0]
+
+        assert message.tool_calls
+        tool_call = message.tool_call
+        assert isinstance(tool_call, ToolCall)
+        assert tool_call.tool_name == "weather"
+        assert tool_call.arguments == {"city": "Paris"}
+        assert message.meta["finish_reason"] == "tool_calls"
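
For reference, a minimal sketch of the tool-calling round trip these tests exercise: the model prepares a `ToolCall`, the caller executes the tool locally, and the result is sent back via `ChatMessage.from_tool` for a final answer. It assumes a valid `OPENAI_API_KEY`; `get_weather` is an illustrative stand-in and not part of the diff.

```python
from haystack_experimental.components.generators.chat import OpenAIChatGenerator
from haystack_experimental.dataclasses import ChatMessage, Tool

# Hypothetical tool function, for illustration only.
def get_weather(city: str) -> str:
    return f"Sunny in {city}"

weather_tool = Tool(
    name="weather",
    description="Returns the weather in a given city",
    parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
    function=get_weather,
)

generator = OpenAIChatGenerator(tools=[weather_tool])
messages = [ChatMessage.from_user("What's the weather like in Paris?")]
reply = generator.run(messages)["replies"][0]

if reply.tool_call:
    # Execute the tool locally and feed the result back for the final answer.
    result = get_weather(**reply.tool_call.arguments)
    messages += [reply, ChatMessage.from_tool(tool_result=result, origin=reply.tool_call)]
    print(generator.run(messages)["replies"][0].text)
```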