From 2725fb4145b35857b0b550c8c7fc586eab88176a Mon Sep 17 00:00:00 2001 From: Mark Backman Date: Sun, 19 Jan 2025 08:21:16 -0500 Subject: [PATCH] Bump the minimum pipecat-ai version, update tests --- pyproject.toml | 2 +- tests/test_manager.py | 335 ++++++++++++++++++++++++++++++++++-------- 2 files changed, 272 insertions(+), 65 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 3a98910..3cfffa9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,7 +19,7 @@ classifiers = [ "Topic :: Multimedia :: Sound/Audio", ] dependencies = [ - "pipecat-ai>=0.0.50", + "pipecat-ai>=0.0.53", "loguru~=0.7.2", ] diff --git a/tests/test_manager.py b/tests/test_manager.py index f2b1724..c9d1b6e 100644 --- a/tests/test_manager.py +++ b/tests/test_manager.py @@ -46,6 +46,14 @@ async def asyncSetUp(self): self.mock_llm = MagicMock(spec=OpenAILLMService) self.mock_tts = AsyncMock() + # Create mock context aggregator + self.mock_context_aggregator = MagicMock() + self.mock_context_aggregator.user = MagicMock() + self.mock_context_aggregator.user.return_value = MagicMock() + self.mock_context_aggregator.user.return_value.get_context_frame = MagicMock( + return_value=MagicMock() + ) + # Sample node configurations self.sample_node = { "role_messages": [{"role": "system", "content": "You are a helpful test assistant."}], @@ -77,6 +85,7 @@ async def test_static_flow_initialization(self): flow_manager = FlowManager( task=self.mock_task, llm=self.mock_llm, + context_aggregator=self.mock_context_aggregator, tts=self.mock_tts, flow_config=self.static_flow_config, ) @@ -95,12 +104,12 @@ async def test_static_flow_initialization(self): # Verify the initial node was set self.assertEqual(flow_manager.current_node, "start") - # Verify the messages were queued with UpdateFrame + # Verify the messages were queued calls = self.mock_task.queue_frames.call_args_list - self.assertEqual(len(calls), 1) # Should be called once + self.assertEqual(len(calls), 2) # Should be called twice (context update and completion) - # Get the frames from the first call - frames = calls[0][0][0] # First call, first argument, which is the list of frames + # Get the frames from the first call (context update) + frames = calls[0][0][0] update_frames = [f for f in frames if isinstance(f, LLMMessagesUpdateFrame)] self.assertEqual(len(update_frames), 1) @@ -117,6 +126,7 @@ async def test_dynamic_flow_initialization(self): flow_manager = FlowManager( task=self.mock_task, llm=self.mock_llm, + context_aggregator=self.mock_context_aggregator, tts=self.mock_tts, transition_callback=mock_transition_callback, ) @@ -148,12 +158,15 @@ async def test_dynamic_flow_initialization(self): # Set initial node await flow_manager.set_node("initial", initial_node) - # Verify frames were queued - self.mock_task.queue_frames.assert_called_once() - frames = self.mock_task.queue_frames.call_args[0][0] + # Verify frames were queued twice (context update and completion trigger) + self.assertEqual(self.mock_task.queue_frames.call_count, 2) - # Should have exactly one UpdateFrame (since it's first node) - update_frames = [f for f in frames if isinstance(f, LLMMessagesUpdateFrame)] + # Get the first call (context update) + first_call = self.mock_task.queue_frames.call_args_list[0] + first_frames = first_call[0][0] + + # Verify UpdateFrame in first call + update_frames = [f for f in first_frames if isinstance(f, LLMMessagesUpdateFrame)] self.assertEqual(len(update_frames), 1, "Should have exactly one UpdateFrame") # Verify message content @@ -162,17 +175,11 @@ async def test_dynamic_flow_initialization(self): self.assertEqual(actual_messages, expected_messages) async def test_static_flow_transitions(self): - """Test transitions in static flows. - - Verifies that: - 1. Static transitions correctly change the current node - 2. Node configuration is properly processed during transition - 3. Messages are sent using AppendFrame for non-initial nodes - """ - # Setup flow manager with static configuration + """Test transitions in static flows.""" flow_manager = FlowManager( task=self.mock_task, llm=self.mock_llm, + context_aggregator=self.mock_context_aggregator, tts=self.mock_tts, flow_config=self.static_flow_config, ) @@ -191,11 +198,14 @@ async def test_static_flow_transitions(self): self.assertEqual(flow_manager.current_node, "next_node") # Verify frame handling - self.mock_task.queue_frames.assert_called_once() - frames = self.mock_task.queue_frames.call_args[0][0] + self.assertEqual(self.mock_task.queue_frames.call_count, 2) + + # Get the first call (context update) + first_call = self.mock_task.queue_frames.call_args_list[0] + first_frames = first_call[0][0] - # Should have exactly one AppendFrame and one SetToolsFrame - append_frames = [f for f in frames if isinstance(f, LLMMessagesAppendFrame)] + # Should have exactly one AppendFrame + append_frames = [f for f in first_frames if isinstance(f, LLMMessagesAppendFrame)] self.assertEqual(len(append_frames), 1, "Should have exactly one AppendFrame") async def test_dynamic_flow_transitions(self): @@ -213,6 +223,7 @@ async def test_dynamic_flow_transitions(self): flow_manager = FlowManager( task=self.mock_task, llm=self.mock_llm, + context_aggregator=self.mock_context_aggregator, tts=self.mock_tts, transition_callback=mock_transition_callback, ) @@ -237,7 +248,11 @@ async def test_dynamic_flow_transitions(self): async def test_node_validation(self): """Test node configuration validation.""" - flow_manager = FlowManager(task=self.mock_task, llm=self.mock_llm) + flow_manager = FlowManager( + task=self.mock_task, + llm=self.mock_llm, + context_aggregator=self.mock_context_aggregator, + ) await flow_manager.initialize() # Test missing task_messages @@ -254,7 +269,11 @@ async def test_node_validation(self): async def test_function_registration(self): """Test function registration with LLM.""" - flow_manager = FlowManager(task=self.mock_task, llm=self.mock_llm) + flow_manager = FlowManager( + task=self.mock_task, + llm=self.mock_llm, + context_aggregator=self.mock_context_aggregator, + ) await flow_manager.initialize() # Reset mock to clear initialization calls @@ -271,7 +290,12 @@ async def test_function_registration(self): async def test_action_execution(self): """Test execution of pre and post actions.""" - flow_manager = FlowManager(task=self.mock_task, llm=self.mock_llm, tts=self.mock_tts) + flow_manager = FlowManager( + task=self.mock_task, + llm=self.mock_llm, + context_aggregator=self.mock_context_aggregator, + tts=self.mock_tts, + ) await flow_manager.initialize() # Create node config with actions @@ -302,7 +326,11 @@ async def test_error_handling(self): 2. Initialization fails properly when task queue fails 3. Node setting fails when task queue fails """ - flow_manager = FlowManager(task=self.mock_task, llm=self.mock_llm) + flow_manager = FlowManager( + task=self.mock_task, + llm=self.mock_llm, + context_aggregator=self.mock_context_aggregator, + ) # Test setting node before initialization with self.assertRaises(FlowTransitionError): @@ -322,7 +350,11 @@ async def test_error_handling(self): async def test_state_management(self): """Test state management across nodes.""" - flow_manager = FlowManager(task=self.mock_task, llm=self.mock_llm) + flow_manager = FlowManager( + task=self.mock_task, + llm=self.mock_llm, + context_aggregator=self.mock_context_aggregator, + ) await flow_manager.initialize() # Set state data @@ -338,7 +370,11 @@ async def test_state_management(self): async def test_multiple_function_registration(self): """Test registration of multiple functions.""" - flow_manager = FlowManager(task=self.mock_task, llm=self.mock_llm) + flow_manager = FlowManager( + task=self.mock_task, + llm=self.mock_llm, + context_aggregator=self.mock_context_aggregator, + ) await flow_manager.initialize() # Create node config with multiple functions @@ -365,7 +401,11 @@ async def test_multiple_function_registration(self): async def test_initialize_already_initialized(self): """Test initializing an already initialized flow manager.""" - flow_manager = FlowManager(task=self.mock_task, llm=self.mock_llm) + flow_manager = FlowManager( + task=self.mock_task, + llm=self.mock_llm, + context_aggregator=self.mock_context_aggregator, + ) await flow_manager.initialize() # Try to initialize again @@ -375,7 +415,11 @@ async def test_initialize_already_initialized(self): async def test_register_action(self): """Test registering custom actions.""" - flow_manager = FlowManager(task=self.mock_task, llm=self.mock_llm) + flow_manager = FlowManager( + task=self.mock_task, + llm=self.mock_llm, + context_aggregator=self.mock_context_aggregator, + ) async def custom_action(action): pass @@ -385,7 +429,11 @@ async def custom_action(action): async def test_call_handler_variations(self): """Test different handler signature variations.""" - flow_manager = FlowManager(task=self.mock_task, llm=self.mock_llm) + flow_manager = FlowManager( + task=self.mock_task, + llm=self.mock_llm, + context_aggregator=self.mock_context_aggregator, + ) await flow_manager.initialize() # Test handler with args @@ -404,7 +452,11 @@ async def handler_no_args(): async def test_transition_func_error_handling(self): """Test error handling in transition functions.""" - flow_manager = FlowManager(task=self.mock_task, llm=self.mock_llm) + flow_manager = FlowManager( + task=self.mock_task, + llm=self.mock_llm, + context_aggregator=self.mock_context_aggregator, + ) await flow_manager.initialize() async def error_handler(args): @@ -430,7 +482,11 @@ async def result_callback(result): async def test_node_validation_edge_cases(self): """Test edge cases in node validation.""" - flow_manager = FlowManager(task=self.mock_task, llm=self.mock_llm) + flow_manager = FlowManager( + task=self.mock_task, + llm=self.mock_llm, + context_aggregator=self.mock_context_aggregator, + ) await flow_manager.initialize() # Test function with missing name @@ -479,7 +535,10 @@ async def failing_transition(function_name, args, flow_manager): raise ValueError("Transition error") flow_manager = FlowManager( - task=self.mock_task, llm=self.mock_llm, transition_callback=failing_transition + task=self.mock_task, + llm=self.mock_llm, + context_aggregator=self.mock_context_aggregator, + transition_callback=failing_transition, ) await flow_manager.initialize() @@ -494,7 +553,11 @@ async def result_callback(result): async def test_register_function_error_handling(self): """Test error handling in function registration.""" - flow_manager = FlowManager(task=self.mock_task, llm=self.mock_llm) + flow_manager = FlowManager( + task=self.mock_task, + llm=self.mock_llm, + context_aggregator=self.mock_context_aggregator, + ) await flow_manager.initialize() # Mock LLM to raise error on register_function @@ -506,7 +569,11 @@ async def test_register_function_error_handling(self): async def test_action_execution_error_handling(self): """Test error handling in action execution.""" - flow_manager = FlowManager(task=self.mock_task, llm=self.mock_llm) + flow_manager = FlowManager( + task=self.mock_task, + llm=self.mock_llm, + context_aggregator=self.mock_context_aggregator, + ) await flow_manager.initialize() # Create node config with actions that will fail @@ -530,7 +597,11 @@ async def test_action_execution_error_handling(self): async def test_update_llm_context_error_handling(self): """Test error handling in LLM context updates.""" - flow_manager = FlowManager(task=self.mock_task, llm=self.mock_llm) + flow_manager = FlowManager( + task=self.mock_task, + llm=self.mock_llm, + context_aggregator=self.mock_context_aggregator, + ) await flow_manager.initialize() # Mock task to raise error on queue_frames @@ -543,7 +614,11 @@ async def test_update_llm_context_error_handling(self): async def test_handler_callback_completion(self): """Test handler completion callback and logging.""" - flow_manager = FlowManager(task=self.mock_task, llm=self.mock_llm) + flow_manager = FlowManager( + task=self.mock_task, + llm=self.mock_llm, + context_aggregator=self.mock_context_aggregator, + ) await flow_manager.initialize() async def test_handler(args): @@ -564,7 +639,11 @@ async def result_callback(result): async def test_handler_removal_all_formats(self): """Test handler removal from different function configurations.""" - flow_manager = FlowManager(task=self.mock_task, llm=self.mock_llm) + flow_manager = FlowManager( + task=self.mock_task, + llm=self.mock_llm, + context_aggregator=self.mock_context_aggregator, + ) await flow_manager.initialize() async def dummy_handler(args): @@ -596,7 +675,11 @@ async def dummy_handler(args): async def test_function_declarations_processing(self): """Test processing of function declarations format.""" - flow_manager = FlowManager(task=self.mock_task, llm=self.mock_llm) + flow_manager = FlowManager( + task=self.mock_task, + llm=self.mock_llm, + context_aggregator=self.mock_context_aggregator, + ) await flow_manager.initialize() async def test_handler(args): @@ -636,7 +719,11 @@ async def test_handler(args): async def test_function_token_handling_main_module(self): """Test handling of __function__: tokens when function is in main module.""" - flow_manager = FlowManager(task=self.mock_task, llm=self.mock_llm) + flow_manager = FlowManager( + task=self.mock_task, + llm=self.mock_llm, + context_aggregator=self.mock_context_aggregator, + ) await flow_manager.initialize() # Define test handler in main module @@ -673,7 +760,11 @@ async def test_handler_main(args): async def test_function_token_handling_not_found(self): """Test error handling when function is not found in any module.""" - flow_manager = FlowManager(task=self.mock_task, llm=self.mock_llm) + flow_manager = FlowManager( + task=self.mock_task, + llm=self.mock_llm, + context_aggregator=self.mock_context_aggregator, + ) await flow_manager.initialize() node_config = { @@ -698,7 +789,11 @@ async def test_function_token_handling_not_found(self): async def test_function_token_execution(self): """Test that functions registered with __function__: token work when called.""" - flow_manager = FlowManager(task=self.mock_task, llm=self.mock_llm) + flow_manager = FlowManager( + task=self.mock_task, + llm=self.mock_llm, + context_aggregator=self.mock_context_aggregator, + ) await flow_manager.initialize() # Define and register test handler @@ -745,14 +840,12 @@ async def callback(result): delattr(sys.modules["__main__"], "test_handler") async def test_role_message_inheritance(self): - """Test that role messages are properly handled between nodes. - - Verifies: - 1. Role messages are included in first node - 2. Role messages are included in subsequent nodes - 3. Messages are combined correctly - """ - flow_manager = FlowManager(task=self.mock_task, llm=self.mock_llm) + """Test that role messages are properly handled between nodes.""" + flow_manager = FlowManager( + task=self.mock_task, + llm=self.mock_llm, + context_aggregator=self.mock_context_aggregator, + ) await flow_manager.initialize() # First node with role messages @@ -770,8 +863,8 @@ async def test_role_message_inheritance(self): # Set first node and verify UpdateFrame await flow_manager.set_node("first", first_node) - first_calls = self.mock_task.queue_frames.call_args_list[-1] - first_frames = first_calls[0][0] + first_call = self.mock_task.queue_frames.call_args_list[0] # Get first call + first_frames = first_call[0][0] update_frames = [f for f in first_frames if isinstance(f, LLMMessagesUpdateFrame)] self.assertEqual(len(update_frames), 1) @@ -784,8 +877,8 @@ async def test_role_message_inheritance(self): await flow_manager.set_node("second", second_node) # Verify AppendFrame for second node - second_calls = self.mock_task.queue_frames.call_args_list[-1] - second_frames = second_calls[0][0] + first_call = self.mock_task.queue_frames.call_args_list[0] # Get first call + second_frames = first_call[0][0] append_frames = [f for f in second_frames if isinstance(f, LLMMessagesAppendFrame)] self.assertEqual(len(append_frames), 1) @@ -793,14 +886,12 @@ async def test_role_message_inheritance(self): self.assertEqual(append_frames[0].messages, second_node["task_messages"]) async def test_frame_type_selection(self): - """Test that correct frame types are used based on node order. - - Verifies: - 1. First node uses UpdateFrame - 2. Subsequent nodes use AppendFrame - 3. Frame content is correct - """ - flow_manager = FlowManager(task=self.mock_task, llm=self.mock_llm) + """Test that correct frame types are used based on node order.""" + flow_manager = FlowManager( + task=self.mock_task, + llm=self.mock_llm, + context_aggregator=self.mock_context_aggregator, + ) await flow_manager.initialize() test_node = { @@ -810,7 +901,8 @@ async def test_frame_type_selection(self): # First node should use UpdateFrame await flow_manager.set_node("first", test_node) - first_frames = self.mock_task.queue_frames.call_args[0][0] + first_call = self.mock_task.queue_frames.call_args_list[0] # Get first call + first_frames = first_call[0][0] self.assertTrue( any(isinstance(f, LLMMessagesUpdateFrame) for f in first_frames), "First node should use UpdateFrame", @@ -825,7 +917,8 @@ async def test_frame_type_selection(self): # Second node should use AppendFrame await flow_manager.set_node("second", test_node) - second_frames = self.mock_task.queue_frames.call_args[0][0] + first_call = self.mock_task.queue_frames.call_args_list[0] # Get first call + second_frames = first_call[0][0] self.assertTrue( any(isinstance(f, LLMMessagesAppendFrame) for f in second_frames), "Subsequent nodes should use AppendFrame", @@ -834,3 +927,117 @@ async def test_frame_type_selection(self): any(isinstance(f, LLMMessagesUpdateFrame) for f in second_frames), "Subsequent nodes should not use UpdateFrame", ) + + async def test_edge_vs_node_function_behavior(self): + """Test different completion behavior for edge and node functions.""" + flow_manager = FlowManager( + task=self.mock_task, + llm=self.mock_llm, + context_aggregator=self.mock_context_aggregator, + ) + await flow_manager.initialize() + + # Create test functions + async def test_handler(args): + return {"status": "success"} + + # Create node with both types of functions + node_config = { + "task_messages": [{"role": "system", "content": "Test"}], + "functions": [ + { + "type": "function", + "function": { + "name": "node_function", + "handler": test_handler, + "description": "Node function", + "parameters": {"type": "object", "properties": {}}, + }, + }, + { + "type": "function", + "function": { + "name": "edge_function", + "handler": test_handler, + "description": "Edge function", + "parameters": {"type": "object", "properties": {}}, + "transition_to": "next_node", + }, + }, + ], + } + + await flow_manager.set_node("test", node_config) + + # Get the registered functions + node_func = None + edge_func = None + for args in self.mock_llm.register_function.call_args_list: + name = args[0][0] + func = args[0][1] + if name == "node_function": + node_func = func + elif name == "edge_function": + edge_func = func + + # Test node function + self.mock_task.queue_frames.reset_mock() + node_result = None + node_properties = None + + async def node_callback(result, *, properties=None): + nonlocal node_result, node_properties + node_result = result + node_properties = properties + + await node_func("node_function", "id1", {}, None, None, node_callback) + # Node function should not set run_llm=False + self.assertTrue(node_properties is None or node_properties.run_llm is not False) + + # Test edge function + self.mock_task.queue_frames.reset_mock() + edge_result = None + edge_properties = None + + async def edge_callback(result, *, properties=None): + nonlocal edge_result, edge_properties + edge_result = result + edge_properties = properties + + await edge_func("edge_function", "id2", {}, None, None, edge_callback) + # Edge function should set run_llm=False + self.assertTrue(edge_properties is not None and edge_properties.run_llm is False) + + async def test_completion_timing(self): + """Test that completions occur at the right time.""" + flow_manager = FlowManager( + task=self.mock_task, + llm=self.mock_llm, + context_aggregator=self.mock_context_aggregator, + ) + await flow_manager.initialize() + + # Test initial node setup + self.mock_task.queue_frames.reset_mock() + await flow_manager.set_node( + "initial", + { + "task_messages": [{"role": "system", "content": "Test"}], + "functions": [], + }, + ) + # Should see two calls: context update and completion trigger + self.assertEqual(self.mock_task.queue_frames.call_count, 2) + + # Add next node to flow manager's nodes + next_node = { + "task_messages": [{"role": "system", "content": "Next test"}], + "functions": [], + } + flow_manager.nodes["next"] = next_node + + # Test node transition + self.mock_task.queue_frames.reset_mock() + await flow_manager._handle_static_transition("next", {}, flow_manager) + # Should see two calls: context update and completion trigger + self.assertEqual(self.mock_task.queue_frames.call_count, 2)