Skip to content

Commit

Permalink
Merge pull request #87 from pipecat-ai/mb/fix-consecutive-function-calls
Browse files Browse the repository at this point in the history
Fix an issue where multiple, consecutive function calls result in repeat completions
  • Loading branch information
markbackman authored Jan 29, 2025
2 parents 036ec7e + 7123cf8 commit e82dafc
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 25 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ def create_node() -> NodeConfig:

- Updated dynamic flow examples to use the new `transition_callback` pattern.

### Fixed

- Fixed an issue where multiple, consecutive function calls could result in two completions.

## [0.0.11] - 2025-01-19

### Changed
Expand Down
86 changes: 62 additions & 24 deletions src/pipecat_flows/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ def __init__(
self.adapter = create_adapter(llm)
self.initialized = False
self._context_aggregator = context_aggregator
self._pending_function_calls = 0

# Set up static or dynamic mode
if flow_config:
Expand Down Expand Up @@ -258,6 +259,47 @@ async def _create_transition_func(

is_edge_function = bool(transition_to) or bool(transition_callback)

def decrease_pending_function_calls() -> None:
"""Decrease the pending function calls counter if greater than zero."""
if self._pending_function_calls > 0:
self._pending_function_calls -= 1
logger.debug(
f"Function call completed: {name} (remaining: {self._pending_function_calls})"
)

async def on_context_updated_edge(args: Dict[str, Any], result_callback: Callable) -> None:
"""Handle context updates for edge functions with transitions."""
try:
decrease_pending_function_calls()

# Only process transition if this was the last pending call
if self._pending_function_calls == 0:
if transition_to: # Static flow
logger.debug(f"Static transition to: {transition_to}")
await self.set_node(transition_to, self.nodes[transition_to])
elif transition_callback: # Dynamic flow
logger.debug(f"Dynamic transition for: {name}")
await transition_callback(args, self)
# Reset counter after transition completes
self._pending_function_calls = 0
logger.debug("Reset pending function calls counter")
else:
logger.debug(
f"Skipping transition, {self._pending_function_calls} calls still pending"
)
except Exception as e:
logger.error(f"Error in transition: {str(e)}")
self._pending_function_calls = 0
await result_callback(
{"status": "error", "error": str(e)},
properties=None, # Clear properties to prevent further callbacks
)
raise # Re-raise to prevent further processing

async def on_context_updated_node() -> None:
"""Handle context updates for node functions without transitions."""
decrease_pending_function_calls()

async def transition_func(
function_name: str,
tool_call_id: str,
Expand All @@ -268,6 +310,12 @@ async def transition_func(
) -> None:
"""Inner function that handles the actual tool invocation."""
try:
# Track pending function call
self._pending_function_calls += 1
logger.debug(
f"Function call pending: {name} (total: {self._pending_function_calls})"
)

# Execute handler if present
if handler:
result = await self._call_handler(handler, args)
Expand All @@ -276,33 +324,23 @@ async def transition_func(
result = {"status": "acknowledged"}
logger.debug(f"Function called without handler: {name}")

if is_edge_function:

async def on_context_updated() -> None:
try:
if transition_to: # Static flow
logger.debug(f"Static transition to: {transition_to}")
await self.set_node(transition_to, self.nodes[transition_to])
elif transition_callback: # Dynamic flow
logger.debug(f"Dynamic transition for: {name}")
await transition_callback(args, self)
except Exception as e:
logger.error(f"Error in transition: {str(e)}")
await result_callback(
{"status": "error", "error": str(e)},
properties=None, # Clear properties to prevent further callbacks
)
raise # Re-raise to prevent further processing

properties = FunctionCallResultProperties(
run_llm=False, on_context_updated=on_context_updated
)
await result_callback(result, properties=properties)
else:
await result_callback(result)
# For edge functions, prevent LLM completion until transition (run_llm=False)
# For node functions, allow immediate completion (run_llm=True)
async def on_context_updated() -> None:
if is_edge_function:
await on_context_updated_edge(args, result_callback)
else:
await on_context_updated_node()

properties = FunctionCallResultProperties(
run_llm=not is_edge_function,
on_context_updated=on_context_updated,
)
await result_callback(result, properties=properties)

except Exception as e:
logger.error(f"Error in transition function {name}: {str(e)}")
self._pending_function_calls = 0
error_result = {"status": "error", "error": str(e)}
await result_callback(error_result)

Expand Down
2 changes: 1 addition & 1 deletion tests/test_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -846,7 +846,7 @@ async def test_handler(args):
# Get the registered function and test it
name, func = self.mock_llm.register_function.call_args[0]

async def callback(result):
async def callback(result, properties=None):
self.assertEqual(result["status"], "success")
self.assertEqual(result["args"], {"test": "value"})

Expand Down

0 comments on commit e82dafc

Please sign in to comment.