Skip to content

Commit

Permalink
Ensure system messages are always passed first (#28)
Browse files Browse the repository at this point in the history
* ensure system messages are always passed first

* address regression with memory and streaming
  • Loading branch information
knowsuchagency authored Jan 3, 2025
1 parent 62ec197 commit 768dec0
Show file tree
Hide file tree
Showing 3 changed files with 162 additions and 11 deletions.
32 changes: 26 additions & 6 deletions promptic.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from jsonschema import validate as validate_json_schema
from pydantic import BaseModel

__version__ = "3.0.0"
__version__ = "4.0.0"

SystemPrompt = Optional[Union[str, List[str], List[Dict[str, str]]]]

Expand Down Expand Up @@ -286,7 +286,9 @@ def wrapper(*args, **kwargs):

self.logger.debug(f"{return_type = }")

messages = [{"content": prompt_text, "role": "user"}]
# Create the user message
user_message = {"content": prompt_text, "role": "user"}
messages = [user_message]

if self.system:
if isinstance(self.system, str):
Expand All @@ -301,12 +303,30 @@ def wrapper(*args, **kwargs):
elif isinstance(self.system[0], dict):
messages = self.system + messages

# Store the user message in state before making the API call
# Store messages in state if enabled
if self.state:
if self.system:
# Add system messages to state if they're not already there
state_messages = self.state.get_messages()
if not state_messages or state_messages[0]["role"] != "system":
if isinstance(self.system, str):
self.state.add_message(
{"content": self.system, "role": "system"}
)
elif isinstance(self.system, list):
if isinstance(self.system[0], str):
for msg in self.system:
self.state.add_message(
{"content": msg, "role": "system"}
)
elif isinstance(self.system[0], dict):
for msg in self.system:
self.state.add_message(msg)

# Store user message before starting stream or regular completion
self.state.add_message(user_message)
history = self.state.get_messages()
self.state.add_message(messages[-1])
if history: # Add previous history if it exists
messages = history + messages
messages = history

# Add tools if any are registered
tools = None
Expand Down
135 changes: 133 additions & 2 deletions tests/test_promptic.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,7 +480,6 @@ def test_memory_with_streaming(model):
memory=True,
state=state,
stream=True,
debug=True,
temperature=0,
timeout=5,
)
Expand All @@ -500,7 +499,7 @@ def simple_conversation(input_text: str) -> str:
response = "".join(list(response_stream))

# Verify first exchange is stored (both user and assistant messages)
assert len(state.get_messages()) == 2
assert len(state.get_messages()) == 2, f"Messages: {state.get_messages()}"
assert state.get_messages()[0]["role"] == "user"
assert state.get_messages()[1]["role"] == "assistant"
assert state.get_messages()[1]["content"] == response
Expand Down Expand Up @@ -974,3 +973,135 @@ def chat(message):
long_message = "Please analyze: " + ("lorem ipsum " * 100)
result = chat(long_message)
assert isinstance(result, str)


@pytest.mark.parametrize("model", CHEAP_MODELS)
def test_system_prompt_order(model):
"""Test that system prompts are always first in the message list"""
state = State()
system_prompt = "You are a helpful test assistant"

@retry(
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(ERRORS),
)
@llm(
model=model,
system=system_prompt,
state=state,
memory=True,
temperature=0,
timeout=5,
)
def chat(message):
"""Chat: {message}"""

# First interaction
chat("Hello")
messages = state.get_messages()

# Verify system message is first
assert messages[0]["role"] == "system"
assert messages[0]["content"] == system_prompt

# Second interaction should still have system message first
chat("How are you?")
messages = state.get_messages()
assert messages[0]["role"] == "system"
assert messages[0]["content"] == system_prompt

# Test with list of system prompts
state.clear()
system_prompts = [
"You are a helpful assistant",
"You always provide concise answers",
]

@retry(
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(ERRORS),
)
@llm(
model=model,
system=system_prompts,
state=state,
memory=True,
temperature=0,
timeout=5,
)
def chat2(message):
"""Chat: {message}"""

chat2("Hello")
messages = state.get_messages()

# Verify both system messages are first, in order
assert messages[0]["role"] == "system"
assert messages[0]["content"] == system_prompts[0]
assert messages[1]["role"] == "system"
assert messages[1]["content"] == system_prompts[1]


@pytest.mark.parametrize("model", CHEAP_MODELS)
def test_message_order_with_memory(model):
"""Test that messages maintain correct order with memory enabled"""
state = State()
system_prompts = [
"You are a helpful assistant",
"You always provide concise answers",
"You speak in a formal tone",
]

@retry(
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(ERRORS),
)
@llm(
model=model,
system=system_prompts,
state=state,
memory=True,
temperature=0,
timeout=5,
debug=True,
)
def chat(message):
"""Chat: {message}"""

# First interaction
chat("Hello")
messages = state.get_messages()

# Check initial message order
assert len(messages) == 5 # 3 system + 1 user + 1 assistant
assert messages[0]["role"] == "system"
assert messages[0]["content"] == system_prompts[0]
assert messages[1]["role"] == "system"
assert messages[1]["content"] == system_prompts[1]
assert messages[2]["role"] == "system"
assert messages[2]["content"] == system_prompts[2]
assert messages[3]["role"] == "user"
assert "Hello" in messages[3]["content"]
assert messages[4]["role"] == "assistant"

# Second interaction
chat("How are you?")
messages = state.get_messages()

# Check message order after second interaction
assert len(messages) == 7 # 3 system + 2 user + 2 assistant
# System messages should still be first
assert messages[0]["role"] == "system"
assert messages[0]["content"] == system_prompts[0]
assert messages[1]["role"] == "system"
assert messages[1]["content"] == system_prompts[1]
assert messages[2]["role"] == "system"
assert messages[2]["content"] == system_prompts[2]
# First interaction messages
assert messages[3]["role"] == "user"
assert "Hello" in messages[3]["content"]
assert messages[4]["role"] == "assistant"
# Second interaction messages
assert messages[5]["role"] == "user"
assert "How are you?" in messages[5]["content"]
assert messages[6]["role"] == "assistant"
6 changes: 3 additions & 3 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 768dec0

Please sign in to comment.