Adds CachedToolCallingAgent to allow for caching of tool calls #558

Open · wants to merge 2 commits into main
77 changes: 76 additions & 1 deletion src/smolagents/agents.py
@@ -15,10 +15,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

__all__ = ["AgentMemory", "CodeAgent", "MultiStepAgent", "ToolCallingAgent"]
__all__ = ["AgentMemory", "CodeAgent", "MultiStepAgent", "ToolCallingAgent", "CachedToolCallingAgent"]

import importlib.resources
import inspect
import json
import re
import textwrap
import time
@@ -837,6 +838,80 @@ def step(self, memory_step: ActionStep) -> Union[None, Any]:
return None


class ToolCallCacheMixin:
"""Mixin class that adds caching capabilities to tool calls."""

def __init__(self, cache_size: int = 128, cached_tools: Optional[List[str]] = None, *args, **kwargs):
"""
Initialize the caching mixin.
Args:
cache_size: Maximum number of entries to keep in cache
cached_tools: List of tool names to cache. If None, all tools are cached.
"""
self.tool_cache: Dict[str, Any] = {}
self.cache_size = cache_size
self.cached_tools = set(cached_tools) if cached_tools is not None else None
super().__init__(*args, **kwargs)

def _get_cache_key(self, tool_name: str, tool_arguments: Dict) -> str:
"""Create a unique cache key from tool name and arguments."""
sorted_args = json.dumps(tool_arguments, sort_keys=True)
return f"{tool_name}:{sorted_args}"

def _get_from_cache(self, tool_name: str, tool_arguments: Dict) -> Tuple[bool, Any]:
"""Try to get a result from cache. Returns (hit, result)."""
cache_key = self._get_cache_key(tool_name, tool_arguments)
if cache_key in self.tool_cache:
return True, self.tool_cache[cache_key]
return False, None

def _add_to_cache(self, tool_name: str, tool_arguments: Dict, result: Any):
"""Add a result to the cache."""
if len(self.tool_cache) >= self.cache_size:
# Remove oldest entry if cache is full (FIFO)
self.tool_cache.pop(next(iter(self.tool_cache)))
cache_key = self._get_cache_key(tool_name, tool_arguments)
self.tool_cache[cache_key] = result

def execute_tool_call(self, tool_name: str, tool_arguments: Dict) -> Any:
"""Override of execute_tool_call that implements caching."""
# Skip caching if tool is not in cached_tools (when specified)
if self.cached_tools is not None and tool_name not in self.cached_tools:
return super().execute_tool_call(tool_name, tool_arguments)

cache_hit, cached_result = self._get_from_cache(tool_name, tool_arguments)
if cache_hit:
self.logger.log(f"Retrieved result from cache for tool '{tool_name}'", level=LogLevel.INFO)
return cached_result

result = super().execute_tool_call(tool_name, tool_arguments)
self._add_to_cache(tool_name, tool_arguments, result)
return result


class CachedToolCallingAgent(ToolCallCacheMixin, ToolCallingAgent):
"""
A version of ToolCallingAgent that includes caching functionality.
Usage:
# Cache all tools
agent = CachedToolCallingAgent(
tools=tools,
model=model,
cache_size=128 # optional, defaults to 128
)
# Cache only specific tools
agent = CachedToolCallingAgent(
tools=tools,
model=model,
cached_tools=['search', 'fetch_data'] # only cache these tools
)
"""
pass


class CodeAgent(MultiStepAgent):
"""
In this agent, the tool calls will be formulated by the LLM in code format, then parsed and executed.
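
A minimal usage sketch of the proposed class, assuming CachedToolCallingAgent is re-exported from the package top level like the other agents and that smolagents' HfApiModel and DuckDuckGoSearchTool helpers are available; everything outside this diff is illustrative, not part of the PR:

# Illustrative only: the model and tool choices below are assumptions.
from smolagents import CachedToolCallingAgent, DuckDuckGoSearchTool, HfApiModel

model = HfApiModel()
search_tool = DuckDuckGoSearchTool()

# Cache every tool call, keeping at most 64 results (FIFO eviction when full).
agent = CachedToolCallingAgent(tools=[search_tool], model=model, cache_size=64)

# Cache only the listed tools; anything else always executes.
# The entry must match the tool's registered name (assumed to be "web_search" here).
selective_agent = CachedToolCallingAgent(
    tools=[search_tool],
    model=model,
    cached_tools=["web_search"],
)

# Repeated tool calls with the same name and identical JSON-serialized
# arguments are served from the in-memory cache instead of re-executing.
agent.run("What is the capital of France?")

Since the cache key is built with json.dumps(tool_arguments, sort_keys=True), only tools whose arguments are JSON-serializable can be cached; the cache itself is a plain dict on the agent instance, so it persists across calls on that instance but not across processes.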