From e5841a71ef364eddeeedb0f3e0fd13ca4a454d2a Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Thu, 18 Jul 2024 17:57:30 +0200 Subject: [PATCH] Improve OpenAPITool corner cases handling (missing operationId, servers under paths, etc) (#37) * Improve corner cases handling * Add unit tests for server order resolution and missing operationId --- .../tools/openapi/_schema_conversion.py | 19 +++- .../components/tools/openapi/types.py | 107 ++++++++---------- .../openapi/test_openapi_client_edge_cases.py | 29 ++++- .../tools/openapi/test_openapi_tool.py | 21 ++++ test/test_files/yaml/openapi_edge_cases.yml | 38 +++++++ 5 files changed, 149 insertions(+), 65 deletions(-) diff --git a/haystack_experimental/components/tools/openapi/_schema_conversion.py b/haystack_experimental/components/tools/openapi/_schema_conversion.py index 1d0766a5..4a986752 100644 --- a/haystack_experimental/components/tools/openapi/_schema_conversion.py +++ b/haystack_experimental/components/tools/openapi/_schema_conversion.py @@ -7,7 +7,11 @@ from haystack.lazy_imports import LazyImport -from haystack_experimental.components.tools.openapi.types import OpenAPISpecification +from haystack_experimental.components.tools.openapi.types import ( + VALID_HTTP_METHODS, + OpenAPISpecification, + path_to_operation_id, +) with LazyImport("Run 'pip install jsonref'") as jsonref_import: # pylint: disable=import-error @@ -96,11 +100,14 @@ def _openapi_to_functions( f"at least {MIN_REQUIRED_OPENAPI_SPEC_VERSION}." ) functions: List[Dict[str, Any]] = [] - for paths in service_openapi_spec["paths"].values(): - for path_spec in paths.values(): - function_dict = parse_endpoint_fn(path_spec, parameters_name) - if function_dict: - functions.append(function_dict) + for path, path_value in service_openapi_spec["paths"].items(): + for path_key, operation_spec in path_value.items(): + if path_key.lower() in VALID_HTTP_METHODS: + if "operationId" not in operation_spec: + operation_spec["operationId"] = path_to_operation_id(path, path_key) + function_dict = parse_endpoint_fn(operation_spec, parameters_name) + if function_dict: + functions.append(function_dict) return functions diff --git a/haystack_experimental/components/tools/openapi/types.py b/haystack_experimental/components/tools/openapi/types.py index 2562daa2..77e22421 100644 --- a/haystack_experimental/components/tools/openapi/types.py +++ b/haystack_experimental/components/tools/openapi/types.py @@ -23,10 +23,24 @@ ] +def path_to_operation_id(path: str, http_method: str = "get") -> str: + """ + Converts a path to an operationId. + + :param path: The path to convert. + :param http_method: The HTTP method to use for the operationId. + :returns: The operationId. + """ + if http_method.lower() not in VALID_HTTP_METHODS: + raise ValueError(f"Invalid HTTP method: {http_method}") + return path.replace("/", "_").lstrip("_").rstrip("_") + "_" + http_method.lower() + + class LLMProvider(Enum): """ LLM providers supported by `OpenAPITool`. """ + OPENAI = "openai" ANTHROPIC = "anthropic" COHERE = "cohere" @@ -50,18 +64,18 @@ def from_str(string: str) -> "LLMProvider": @dataclass class Operation: """ - Represents an operation in an OpenAPI specification - - See https://spec.openapis.org/oas/latest.html#paths-object for details. - Path objects can contain multiple operations, each with a unique combination of path and method. - - :param path: Path of the operation. - :param method: HTTP method of the operation. - :param operation_dict: Operation details from OpenAPI spec - :param spec_dict: The encompassing OpenAPI specification. - :param security_requirements: A list of security requirements for the operation. - :param request_body: Request body details. - :param parameters: Parameters for the operation. + Represents an operation in an OpenAPI specification + + See https://spec.openapis.org/oas/latest.html#paths-object for details. + Path objects can contain multiple operations, each with a unique combination of path and method. + + :param path: Path of the operation. + :param method: HTTP method of the operation. + :param operation_dict: Operation details from OpenAPI spec + :param spec_dict: The encompassing OpenAPI specification. + :param security_requirements: A list of security requirements for the operation. + :param request_body: Request body details. + :param parameters: Parameters for the operation. """ path: str @@ -105,8 +119,12 @@ def get_server(self, server_index: int = 0) -> str: :returns: The server URL. :raises ValueError: If no servers are found in the specification. """ - servers = self.operation_dict.get("servers", []) or self.spec_dict.get( - "servers", [] + # servers can be defined at the operation level, path level, or at the root level + # search for servers in the following order: operation, path, root + servers = ( + self.operation_dict.get("servers", []) + or self.spec_dict.get("paths", {}).get(self.path, {}).get("servers", []) + or self.spec_dict.get("servers", []) ) if not servers: raise ValueError("No servers found in the provided specification.") @@ -136,11 +154,7 @@ def __init__(self, spec_dict: Dict[str, Any]): f"Invalid OpenAPI specification, expected a dictionary: {spec_dict}" ) # just a crude sanity check, by no means a full validation - if ( - "openapi" not in spec_dict - or "paths" not in spec_dict - or "servers" not in spec_dict - ): + if "openapi" not in spec_dict or "paths" not in spec_dict: raise ValueError( "Invalid OpenAPI specification format. See https://swagger.io/specification/ for details.", spec_dict, @@ -201,51 +215,30 @@ def from_url(cls, url: str) -> "OpenAPISpecification": ) from e return cls.from_str(content) - def find_operation_by_id( - self, op_id: str, method: Optional[str] = None - ) -> Operation: + def find_operation_by_id(self, op_id: str) -> Operation: """ Find an Operation by operationId. :param op_id: The operationId of the operation. - :param method: The HTTP method of the operation. :returns: The matching operation :raises ValueError: If no operation is found with the given operationId. """ - for path, path_item in self.spec_dict.get("paths", {}).items(): - op: Operation = self.get_operation_item(path, path_item, method) - if op_id in op.operation_dict.get("operationId", ""): - return self.get_operation_item(path, path_item, method) - raise ValueError( - f"No operation found with operationId {op_id}, method {method}" - ) - - def get_operation_item( - self, path: str, path_item: Dict[str, Any], method: Optional[str] = None - ) -> Operation: - """ - Gets a particular Operation item from the OpenAPI specification given the path and method. - - :param path: The path of the operation. - :param path_item: The path item from the OpenAPI specification. - :param method: The HTTP method of the operation. - :returns: The operation - """ - if method: - operation_dict = path_item.get(method.lower(), {}) - if not operation_dict: - raise ValueError( - f"No operation found for method {method} at path {path}" - ) - return Operation(path, method.lower(), operation_dict, self.spec_dict) - if len(path_item) == 1: - method, operation_dict = next(iter(path_item.items())) - return Operation(path, method, operation_dict, self.spec_dict) - if len(path_item) > 1: - raise ValueError( - f"Multiple operations found at path {path}, method parameter is required." - ) - raise ValueError(f"No operations found at path {path} and method {method}") + for path, path_value in self.spec_dict.get("paths", {}).items(): + operations = { + method: operation_dict + for method, operation_dict in path_value.items() + if method.lower() in VALID_HTTP_METHODS + } + + for method, operation_dict in operations.items(): + if ( + operation_dict.get( + "operationId", path_to_operation_id(path, method) + ) + == op_id + ): + return Operation(path, method, operation_dict, self.spec_dict) + raise ValueError(f"No operation found with operationId {op_id}") def get_security_schemes(self) -> Dict[str, Dict[str, Any]]: """ diff --git a/test/components/tools/openapi/test_openapi_client_edge_cases.py b/test/components/tools/openapi/test_openapi_client_edge_cases.py index f6272baa..d8ced45d 100644 --- a/test/components/tools/openapi/test_openapi_client_edge_cases.py +++ b/test/components/tools/openapi/test_openapi_client_edge_cases.py @@ -5,7 +5,7 @@ import pytest -from haystack_experimental.components.tools.openapi._openapi import OpenAPIServiceClient, ClientConfiguration +from haystack_experimental.components.tools.openapi._openapi import ClientConfiguration, OpenAPIServiceClient from test.components.tools.openapi.conftest import FastAPITestClient, create_openapi_spec @@ -26,4 +26,29 @@ def test_missing_operation_id(self, test_files_path): with pytest.raises(ValueError, match="No operation found with operationId"): client.invoke(payload) - # TODO: Add more tests for edge cases + def test_missing_operation_id_in_operation(self, test_files_path): + """ + Test that the tool definition is generated correctly when the operationId is missing in the specification. + """ + config = ClientConfiguration(openapi_spec=create_openapi_spec(test_files_path / "yaml" / "openapi_edge_cases.yml"), + request_sender=FastAPITestClient(None)) + + tools = config.get_tools_definitions(), + tool_def = tools[0][0] + assert tool_def["type"] == "function" + assert tool_def["function"]["name"] == "missing-operation-id_get" + + def test_servers_order(self, test_files_path): + """ + Test that servers defined in different locations in the specification are used correctly. + """ + + config = ClientConfiguration(openapi_spec=create_openapi_spec(test_files_path / "yaml" / "openapi_edge_cases.yml"), + request_sender=FastAPITestClient(None)) + + op = config.openapi_spec.find_operation_by_id("servers-order-path") + assert op.get_server() == "https://inpath.example.com" + op = config.openapi_spec.find_operation_by_id("servers-order-operation") + assert op.get_server() == "https://inoperation.example.com" + op = config.openapi_spec.find_operation_by_id("missing-operation-id_get") + assert op.get_server() == "http://localhost" diff --git a/test/components/tools/openapi/test_openapi_tool.py b/test/components/tools/openapi/test_openapi_tool.py index 5119ba61..b8274199 100644 --- a/test/components/tools/openapi/test_openapi_tool.py +++ b/test/components/tools/openapi/test_openapi_tool.py @@ -201,3 +201,24 @@ def test_run_live_cohere(self): assert isinstance(json_response, dict) except json.JSONDecodeError: pytest.fail("Response content is not valid JSON") + + @pytest.mark.integration + @pytest.mark.parametrize("provider", ["openai", "anthropic", "cohere"]) + def test_run_live_meteo_forecast(self, provider: str): + tool = OpenAPITool( + generator_api=LLMProvider.from_str(provider), + spec="https://raw.githubusercontent.com/open-meteo/open-meteo/main/openapi.yml" + ) + results = tool.run(messages=[ChatMessage.from_user( + "weather forecast for latitude 52.52 and longitude 13.41 and set hourly=temperature_2m")]) + + assert isinstance(results["service_response"], list) + assert len(results["service_response"]) == 1 + assert isinstance(results["service_response"][0], ChatMessage) + + try: + json_response = json.loads(results["service_response"][0].content) + assert isinstance(json_response, dict) + assert "hourly" in json_response + except json.JSONDecodeError: + pytest.fail("Response content is not valid JSON") diff --git a/test/test_files/yaml/openapi_edge_cases.yml b/test/test_files/yaml/openapi_edge_cases.yml index cef304c5..e49e43b1 100644 --- a/test/test_files/yaml/openapi_edge_cases.yml +++ b/test/test_files/yaml/openapi_edge_cases.yml @@ -8,6 +8,44 @@ paths: /missing-operation-id: get: summary: Missing operationId + parameters: + - name: name + in: path + required: true + schema: + type: string responses: '200': description: OK + + /servers-order-in-path: + servers: + - url: https://inpath.example.com + get: + summary: Servers order + operationId: servers-order-path + parameters: + - name: name + in: path + required: true + schema: + type: string + responses: + '200': + description: OK + + /servers-order-in-operation: + get: + summary: Servers order + operationId: servers-order-operation + parameters: + - name: name + in: path + required: true + schema: + type: string + responses: + '200': + description: OK + servers: + - url: https://inoperation.example.com