From e16ef18596d81077cf46f0d623b67e7f02d58cd2 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Tue, 30 Jul 2024 13:18:29 +0100 Subject: [PATCH] feat: Add `InMemoryChatMessageStore` and `ChatMessageStore` (#49) * Add InMemoryChatMessageStore and ChatMessageStore without filters * Minor fix * PR feedback, simplify, use ABC instead of Protocol * Simplify test files * Fix header --------- Co-authored-by: Julian Risch --- .../chat_message_stores/__init__.py | 7 ++ .../chat_message_stores/in_memory.py | 86 +++++++++++++++++++ .../chat_message_stores/types.py | 74 ++++++++++++++++ test/chat_message_stores/__init__.py | 3 + .../test_in_memory_chat_message_store.py | 86 +++++++++++++++++++ test/components/__init__.py | 2 +- 6 files changed, 257 insertions(+), 1 deletion(-) create mode 100644 haystack_experimental/chat_message_stores/__init__.py create mode 100644 haystack_experimental/chat_message_stores/in_memory.py create mode 100644 haystack_experimental/chat_message_stores/types.py create mode 100644 test/chat_message_stores/__init__.py create mode 100644 test/chat_message_stores/test_in_memory_chat_message_store.py diff --git a/haystack_experimental/chat_message_stores/__init__.py b/haystack_experimental/chat_message_stores/__init__.py new file mode 100644 index 00000000..a0a07a14 --- /dev/null +++ b/haystack_experimental/chat_message_stores/__init__.py @@ -0,0 +1,7 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from haystack_experimental.chat_message_stores.in_memory import InMemoryChatMessageStore + +_all_ = ["InMemoryChatMessageStore"] diff --git a/haystack_experimental/chat_message_stores/in_memory.py b/haystack_experimental/chat_message_stores/in_memory.py new file mode 100644 index 00000000..2be43d7a --- /dev/null +++ b/haystack_experimental/chat_message_stores/in_memory.py @@ -0,0 +1,86 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Dict, Iterable, List + +from haystack import default_from_dict, default_to_dict, logging +from haystack.dataclasses import ChatMessage + +from haystack_experimental.chat_message_stores.types import ChatMessageStore + +logger = logging.getLogger(__name__) + + +class InMemoryChatMessageStore(ChatMessageStore): + """ + Stores chat messages in-memory. + """ + + def __init__( + self, + ): + """ + Initializes the InMemoryChatMessageStore. + """ + self.messages = [] + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. + """ + return default_to_dict( + self, + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "InMemoryChatMessageStore": + """ + Deserializes the component from a dictionary. + + :param data: + The dictionary to deserialize from. + :returns: + The deserialized component. + """ + return default_from_dict(cls, data) + + def count_messages(self) -> int: + """ + Returns the number of chat messages stored. + + :returns: The number of messages. + """ + return len(self.messages) + + def write_messages(self, messages: List[ChatMessage]) -> int: + """ + Writes chat messages to the ChatMessageStore. + + :param messages: A list of ChatMessages to write. + :returns: The number of messages written. + + :raises ValueError: If messages is not a list of ChatMessages. + """ + if not isinstance(messages, Iterable) or any(not isinstance(message, ChatMessage) for message in messages): + raise ValueError("Please provide a list of ChatMessages.") + + self.messages.extend(messages) + return len(messages) + + def delete_messages(self) -> None: + """ + Deletes all stored chat messages. + """ + self.messages = [] + + def retrieve(self) -> List[ChatMessage]: + """ + Retrieves all stored chat messages. + + :returns: A list of chat messages. + """ + return self.messages diff --git a/haystack_experimental/chat_message_stores/types.py b/haystack_experimental/chat_message_stores/types.py new file mode 100644 index 00000000..5300a1eb --- /dev/null +++ b/haystack_experimental/chat_message_stores/types.py @@ -0,0 +1,74 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from abc import ABC, abstractmethod +from typing import Any, Dict, List + +from haystack import logging +from haystack.dataclasses import ChatMessage + +logger = logging.getLogger(__name__) + + +class ChatMessageStore(ABC): + """ + Stores ChatMessages to be used by the components of a Pipeline. + + Classes implementing this protocol might store ChatMessages either in durable storage or in memory. They might + allow specialized components (e.g. retrievers) to perform retrieval on them, either by embedding, by keyword, + hybrid, and so on, depending on the backend used. + + In order to write or retrieve chat messages, consider using a ChatMessageWriter or ChatMessageRetriever. + """ + + @abstractmethod + def to_dict(self) -> Dict[str, Any]: + """ + Serializes this store to a dictionary. + + :returns: The serialized store as a dictionary. + """ + + @classmethod + @abstractmethod + def from_dict(cls, data: Dict[str, Any]) -> "ChatMessageStore": + """ + Deserializes the store from a dictionary. + + :param data: The dictionary to deserialize from. + :returns: The deserialized store. + """ + + @abstractmethod + def count_messages(self) -> int: + """ + Returns the number of chat messages stored. + + :returns: The number of messages. + """ + + @abstractmethod + def write_messages(self, messages: List[ChatMessage]) -> int: + """ + Writes chat messages to the ChatMessageStore. + + :param messages: A list of ChatMessages to write. + :returns: The number of messages written. + + :raises ValueError: If messages is not a list of ChatMessages. + """ + + @abstractmethod + def delete_messages(self) -> None: + """ + Deletes all stored chat messages. + """ + + @abstractmethod + def retrieve(self) -> List[ChatMessage]: + """ + Retrieves all stored chat messages. + + :returns: A list of chat messages. + """ diff --git a/test/chat_message_stores/__init__.py b/test/chat_message_stores/__init__.py new file mode 100644 index 00000000..c1764a6e --- /dev/null +++ b/test/chat_message_stores/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 diff --git a/test/chat_message_stores/test_in_memory_chat_message_store.py b/test/chat_message_stores/test_in_memory_chat_message_store.py new file mode 100644 index 00000000..47ef9e85 --- /dev/null +++ b/test/chat_message_stores/test_in_memory_chat_message_store.py @@ -0,0 +1,86 @@ +from haystack.dataclasses import ChatMessage + +from haystack_experimental.chat_message_stores.in_memory import InMemoryChatMessageStore + + +class TestInMemoryChatMessageStore: + + def test_init(self): + """ + Test that the InMemoryChatMessageStore can be initialized and that it works as expected. + """ + store = InMemoryChatMessageStore() + assert store.count_messages() == 0 + assert store.retrieve() == [] + assert store.write_messages([]) == 0 + assert not store.delete_messages() + + def test_to_dict(self): + """ + Test that the InMemoryChatMessageStore can be serialized to a dictionary. + """ + store = InMemoryChatMessageStore() + assert store.to_dict() == { + "init_parameters": {}, + "type": "haystack_experimental.chat_message_stores.in_memory.InMemoryChatMessageStore" + } + + def test_from_dict(self): + """ + Test that the InMemoryChatMessageStore can be deserialized from a dictionary. + """ + data = { + "init_parameters": {}, + "type": "haystack_experimental.chat_message_stores.in_memory.InMemoryChatMessageStore" + } + store = InMemoryChatMessageStore.from_dict(data) + assert store.to_dict() == data + + def test_count_messages(self): + """ + Test that the InMemoryChatMessageStore can count the number of messages in the store correctly. + """ + store = InMemoryChatMessageStore() + assert store.count_messages() == 0 + store.write_messages(messages=[ChatMessage.from_user(content="Hello, how can I help you?")]) + assert store.count_messages() == 1 + store.write_messages(messages=[ChatMessage.from_user(content="Hallo, wie kann ich Ihnen helfen?")]) + assert store.count_messages() == 2 + store.write_messages(messages=[ChatMessage.from_user(content="Hola, ¿cómo puedo ayudarte?")]) + assert store.count_messages() == 3 + + def test_retrieve(self): + """ + Test that the InMemoryChatMessageStore can retrieve all messages from the store correctly. + """ + store = InMemoryChatMessageStore() + assert store.retrieve() == [] + store.write_messages(messages=[ChatMessage.from_user(content="Hello, how can I help you?")]) + assert store.retrieve() == [ChatMessage.from_user(content="Hello, how can I help you?")] + store.write_messages(messages=[ChatMessage.from_user(content="Hallo, wie kann ich Ihnen helfen?")]) + assert store.retrieve() == [ + ChatMessage.from_user(content="Hello, how can I help you?"), + ChatMessage.from_user(content="Hallo, wie kann ich Ihnen helfen?"), + ] + store.write_messages(messages=[ChatMessage.from_user(content="Hola, ¿cómo puedo ayudarte?")]) + assert store.retrieve() == [ + ChatMessage.from_user(content="Hello, how can I help you?"), + ChatMessage.from_user(content="Hallo, wie kann ich Ihnen helfen?"), + ChatMessage.from_user(content="Hola, ¿cómo puedo ayudarte?"), + ] + + def test_delete_messages(self): + """ + Test that the InMemoryChatMessageStore can delete all messages from the store correctly. + """ + store = InMemoryChatMessageStore() + assert store.count_messages() == 0 + store.write_messages(messages=[ChatMessage.from_user(content="Hello, how can I help you?")]) + assert store.count_messages() == 1 + store.delete_messages() + assert store.count_messages() == 0 + store.write_messages(messages=[ChatMessage.from_user(content="Hallo, wie kann ich Ihnen helfen?")]) + store.write_messages(messages=[ChatMessage.from_user(content="Hola, ¿cómo puedo ayudarte?")]) + assert store.count_messages() == 2 + store.delete_messages() + assert store.count_messages() == 0 diff --git a/test/components/__init__.py b/test/components/__init__.py index 3f4ac9d8..c1764a6e 100644 --- a/test/components/__init__.py +++ b/test/components/__init__.py @@ -1,3 +1,3 @@ # SPDX-FileCopyrightText: 2022-present deepset GmbH # -# SPDX-License-Identifier: Apache-2.0 \ No newline at end of file +# SPDX-License-Identifier: Apache-2.0