From 768dec022df60f677728c604661751d29afbddf3 Mon Sep 17 00:00:00 2001
From: Stephan Fitzpatrick
Date: Thu, 2 Jan 2025 22:39:29 -0800
Subject: [PATCH] Ensure system messages are always passed first (#28)

* ensure system messages are always passed first

* address regression with memory and streaming
---
 promptic.py            |  32 ++++++++--
 tests/test_promptic.py | 135 ++++++++++++++++++++++++++++++++++++++++-
 uv.lock                |   6 +-
 3 files changed, 162 insertions(+), 11 deletions(-)

diff --git a/promptic.py b/promptic.py
index 16d089b..d2f97da 100644
--- a/promptic.py
+++ b/promptic.py
@@ -10,7 +10,7 @@
 from jsonschema import validate as validate_json_schema
 from pydantic import BaseModel
 
-__version__ = "3.0.0"
+__version__ = "4.0.0"
 
 SystemPrompt = Optional[Union[str, List[str], List[Dict[str, str]]]]
 
@@ -286,7 +286,9 @@ def wrapper(*args, **kwargs):
 
             self.logger.debug(f"{return_type = }")
 
-            messages = [{"content": prompt_text, "role": "user"}]
+            # Create the user message
+            user_message = {"content": prompt_text, "role": "user"}
+            messages = [user_message]
 
             if self.system:
                 if isinstance(self.system, str):
@@ -301,12 +303,30 @@ def wrapper(*args, **kwargs):
                 elif isinstance(self.system[0], dict):
                     messages = self.system + messages
 
-            # Store the user message in state before making the API call
+            # Store messages in state if enabled
             if self.state:
+                if self.system:
+                    # Add system messages to state if they're not already there
+                    state_messages = self.state.get_messages()
+                    if not state_messages or state_messages[0]["role"] != "system":
+                        if isinstance(self.system, str):
+                            self.state.add_message(
+                                {"content": self.system, "role": "system"}
+                            )
+                        elif isinstance(self.system, list):
+                            if isinstance(self.system[0], str):
+                                for msg in self.system:
+                                    self.state.add_message(
+                                        {"content": msg, "role": "system"}
+                                    )
+                            elif isinstance(self.system[0], dict):
+                                for msg in self.system:
+                                    self.state.add_message(msg)
+
+                # Store user message before starting stream or regular completion
+                self.state.add_message(user_message)
                 history = self.state.get_messages()
-                self.state.add_message(messages[-1])
-                if history:  # Add previous history if it exists
-                    messages = history + messages
+                messages = history
 
             # Add tools if any are registered
             tools = None
diff --git a/tests/test_promptic.py b/tests/test_promptic.py
index d0d4567..d9c3b20 100644
--- a/tests/test_promptic.py
+++ b/tests/test_promptic.py
@@ -480,7 +480,6 @@ def test_memory_with_streaming(model):
         memory=True,
         state=state,
         stream=True,
-        debug=True,
         temperature=0,
         timeout=5,
     )
@@ -500,7 +499,7 @@ def simple_conversation(input_text: str) -> str:
     response = "".join(list(response_stream))
 
     # Verify first exchange is stored (both user and assistant messages)
-    assert len(state.get_messages()) == 2
+    assert len(state.get_messages()) == 2, f"Messages: {state.get_messages()}"
     assert state.get_messages()[0]["role"] == "user"
     assert state.get_messages()[1]["role"] == "assistant"
     assert state.get_messages()[1]["content"] == response
@@ -974,3 +973,135 @@ def chat(message):
     long_message = "Please analyze: " + ("lorem ipsum " * 100)
     result = chat(long_message)
     assert isinstance(result, str)
+
+
+@pytest.mark.parametrize("model", CHEAP_MODELS)
+def test_system_prompt_order(model):
+    """Test that system prompts are always first in the message list"""
+    state = State()
+    system_prompt = "You are a helpful test assistant"
+
+    @retry(
+        wait=wait_exponential(multiplier=1, min=4, max=10),
+        retry=retry_if_exception_type(ERRORS),
+    )
+    @llm(
+        model=model,
+        system=system_prompt,
+        state=state,
+        memory=True,
+        temperature=0,
+        timeout=5,
+    )
+    def chat(message):
+        """Chat: {message}"""
+
+    # First interaction
+    chat("Hello")
+    messages = state.get_messages()
+
+    # Verify system message is first
+    assert messages[0]["role"] == "system"
+    assert messages[0]["content"] == system_prompt
+
+    # Second interaction should still have system message first
+    chat("How are you?")
+    messages = state.get_messages()
+    assert messages[0]["role"] == "system"
+    assert messages[0]["content"] == system_prompt
+
+    # Test with list of system prompts
+    state.clear()
+    system_prompts = [
+        "You are a helpful assistant",
+        "You always provide concise answers",
+    ]
+
+    @retry(
+        wait=wait_exponential(multiplier=1, min=4, max=10),
+        retry=retry_if_exception_type(ERRORS),
+    )
+    @llm(
+        model=model,
+        system=system_prompts,
+        state=state,
+        memory=True,
+        temperature=0,
+        timeout=5,
+    )
+    def chat2(message):
+        """Chat: {message}"""
+
+    chat2("Hello")
+    messages = state.get_messages()
+
+    # Verify both system messages are first, in order
+    assert messages[0]["role"] == "system"
+    assert messages[0]["content"] == system_prompts[0]
+    assert messages[1]["role"] == "system"
+    assert messages[1]["content"] == system_prompts[1]
+
+
+@pytest.mark.parametrize("model", CHEAP_MODELS)
+def test_message_order_with_memory(model):
+    """Test that messages maintain correct order with memory enabled"""
+    state = State()
+    system_prompts = [
+        "You are a helpful assistant",
+        "You always provide concise answers",
+        "You speak in a formal tone",
+    ]
+
+    @retry(
+        wait=wait_exponential(multiplier=1, min=4, max=10),
+        retry=retry_if_exception_type(ERRORS),
+    )
+    @llm(
+        model=model,
+        system=system_prompts,
+        state=state,
+        memory=True,
+        temperature=0,
+        timeout=5,
+        debug=True,
+    )
+    def chat(message):
+        """Chat: {message}"""
+
+    # First interaction
+    chat("Hello")
+    messages = state.get_messages()
+
+    # Check initial message order
+    assert len(messages) == 5  # 3 system + 1 user + 1 assistant
+    assert messages[0]["role"] == "system"
+    assert messages[0]["content"] == system_prompts[0]
+    assert messages[1]["role"] == "system"
+    assert messages[1]["content"] == system_prompts[1]
+    assert messages[2]["role"] == "system"
+    assert messages[2]["content"] == system_prompts[2]
+    assert messages[3]["role"] == "user"
+    assert "Hello" in messages[3]["content"]
+    assert messages[4]["role"] == "assistant"
+
+    # Second interaction
+    chat("How are you?")
+    messages = state.get_messages()
+
+    # Check message order after second interaction
+    assert len(messages) == 7  # 3 system + 2 user + 2 assistant
+    # System messages should still be first
+    assert messages[0]["role"] == "system"
+    assert messages[0]["content"] == system_prompts[0]
+    assert messages[1]["role"] == "system"
+    assert messages[1]["content"] == system_prompts[1]
+    assert messages[2]["role"] == "system"
+    assert messages[2]["content"] == system_prompts[2]
+    # First interaction messages
+    assert messages[3]["role"] == "user"
+    assert "Hello" in messages[3]["content"]
+    assert messages[4]["role"] == "assistant"
+    # Second interaction messages
+    assert messages[5]["role"] == "user"
+    assert "How are you?" in messages[5]["content"]
+    assert messages[6]["role"] == "assistant"
diff --git a/uv.lock b/uv.lock
index fd32036..d8a04d8 100644
--- a/uv.lock
+++ b/uv.lock
@@ -196,7 +196,7 @@ name = "click"
 version = "8.1.7"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
-    { name = "colorama", marker = "platform_system == 'Windows'" },
+    { name = "colorama", marker = "sys_platform == 'win32'" },
 ]
 sdist = { url = "https://files.pythonhosted.org/packages/96/d3/f04c7bfcf5c1862a2a5b845c6b2b360488cf47af55dfa79c98f6a6bf98b5/click-8.1.7.tar.gz", hash = "sha256:ca9853ad459e787e2192211578cc907e7594e294c7ccc834310722b41b9ca6de", size = 336121 }
 wheels = [
@@ -739,7 +739,7 @@ wheels = [
 
 [[package]]
 name = "promptic"
-version = "2.3.1"
+version = "4.0.0"
 source = { editable = "." }
 dependencies = [
     { name = "jsonschema" },
@@ -1212,7 +1212,7 @@ name = "tqdm"
 version = "4.67.0"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
-    { name = "colorama", marker = "platform_system == 'Windows'" },
+    { name = "colorama", marker = "sys_platform == 'win32'" },
 ]
 sdist = { url = "https://files.pythonhosted.org/packages/e8/4f/0153c21dc5779a49a0598c445b1978126b1344bab9ee71e53e44877e14e0/tqdm-4.67.0.tar.gz", hash = "sha256:fe5a6f95e6fe0b9755e9469b77b9c3cf850048224ecaa8293d7d2d31f97d869a", size = 169739 }
 wheels = [
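
For reference (not part of the commit): a minimal usage sketch of the ordering guarantee the new tests assert. It assumes promptic exposes llm and State at the package root and uses a placeholder model name.

    # Sketch only: with memory enabled, the system prompt is stored ahead of the
    # conversation, so state.get_messages() always begins with the system message.
    # Assumptions: the import path "from promptic import State, llm" and the
    # placeholder model name "gpt-4o-mini".
    from promptic import State, llm

    state = State()

    @llm(
        model="gpt-4o-mini",  # placeholder model name
        system="You are a helpful test assistant",
        state=state,
        memory=True,
    )
    def chat(message):
        """Chat: {message}"""

    chat("Hello")
    chat("How are you?")

    messages = state.get_messages()
    assert messages[0]["role"] == "system"  # system prompt stays first
    assert [m["role"] for m in messages[1:]] == [
        "user",
        "assistant",
        "user",
        "assistant",
    ]  # conversation history follows in order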