diff --git a/CHANGELOG.md b/CHANGELOG.md index 9d2d4db..4e46a21 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -61,6 +61,10 @@ def create_node() -> NodeConfig: - Updated dynamic flow examples to use the new `transition_callback` pattern. +### Fixed + +- Fixed an issue where multiple, consecutive function calls could result in two completions. + ## [0.0.11] - 2025-01-19 ### Changed diff --git a/src/pipecat_flows/manager.py b/src/pipecat_flows/manager.py index d9878ff..497af4e 100644 --- a/src/pipecat_flows/manager.py +++ b/src/pipecat_flows/manager.py @@ -98,6 +98,7 @@ def __init__( self.adapter = create_adapter(llm) self.initialized = False self._context_aggregator = context_aggregator + self._pending_function_calls = 0 # Set up static or dynamic mode if flow_config: @@ -258,6 +259,47 @@ async def _create_transition_func( is_edge_function = bool(transition_to) or bool(transition_callback) + def decrease_pending_function_calls() -> None: + """Decrease the pending function calls counter if greater than zero.""" + if self._pending_function_calls > 0: + self._pending_function_calls -= 1 + logger.debug( + f"Function call completed: {name} (remaining: {self._pending_function_calls})" + ) + + async def on_context_updated_edge(args: Dict[str, Any], result_callback: Callable) -> None: + """Handle context updates for edge functions with transitions.""" + try: + decrease_pending_function_calls() + + # Only process transition if this was the last pending call + if self._pending_function_calls == 0: + if transition_to: # Static flow + logger.debug(f"Static transition to: {transition_to}") + await self.set_node(transition_to, self.nodes[transition_to]) + elif transition_callback: # Dynamic flow + logger.debug(f"Dynamic transition for: {name}") + await transition_callback(args, self) + # Reset counter after transition completes + self._pending_function_calls = 0 + logger.debug("Reset pending function calls counter") + else: + logger.debug( + f"Skipping transition, {self._pending_function_calls} calls still pending" + ) + except Exception as e: + logger.error(f"Error in transition: {str(e)}") + self._pending_function_calls = 0 + await result_callback( + {"status": "error", "error": str(e)}, + properties=None, # Clear properties to prevent further callbacks + ) + raise # Re-raise to prevent further processing + + async def on_context_updated_node() -> None: + """Handle context updates for node functions without transitions.""" + decrease_pending_function_calls() + async def transition_func( function_name: str, tool_call_id: str, @@ -268,6 +310,12 @@ async def transition_func( ) -> None: """Inner function that handles the actual tool invocation.""" try: + # Track pending function call + self._pending_function_calls += 1 + logger.debug( + f"Function call pending: {name} (total: {self._pending_function_calls})" + ) + # Execute handler if present if handler: result = await self._call_handler(handler, args) @@ -276,33 +324,23 @@ async def transition_func( result = {"status": "acknowledged"} logger.debug(f"Function called without handler: {name}") - if is_edge_function: - - async def on_context_updated() -> None: - try: - if transition_to: # Static flow - logger.debug(f"Static transition to: {transition_to}") - await self.set_node(transition_to, self.nodes[transition_to]) - elif transition_callback: # Dynamic flow - logger.debug(f"Dynamic transition for: {name}") - await transition_callback(args, self) - except Exception as e: - logger.error(f"Error in transition: {str(e)}") - await result_callback( - {"status": "error", "error": str(e)}, - properties=None, # Clear properties to prevent further callbacks - ) - raise # Re-raise to prevent further processing - - properties = FunctionCallResultProperties( - run_llm=False, on_context_updated=on_context_updated - ) - await result_callback(result, properties=properties) - else: - await result_callback(result) + # For edge functions, prevent LLM completion until transition (run_llm=False) + # For node functions, allow immediate completion (run_llm=True) + async def on_context_updated() -> None: + if is_edge_function: + await on_context_updated_edge(args, result_callback) + else: + await on_context_updated_node() + + properties = FunctionCallResultProperties( + run_llm=not is_edge_function, + on_context_updated=on_context_updated, + ) + await result_callback(result, properties=properties) except Exception as e: logger.error(f"Error in transition function {name}: {str(e)}") + self._pending_function_calls = 0 error_result = {"status": "error", "error": str(e)} await result_callback(error_result) diff --git a/tests/test_manager.py b/tests/test_manager.py index 8629d6d..b7668f4 100644 --- a/tests/test_manager.py +++ b/tests/test_manager.py @@ -846,7 +846,7 @@ async def test_handler(args): # Get the registered function and test it name, func = self.mock_llm.register_function.call_args[0] - async def callback(result): + async def callback(result, properties=None): self.assertEqual(result["status"], "success") self.assertEqual(result["args"], {"test": "value"})