Skip to content

Commit

Permalink
Improve OpenAPITool corner cases handling (missing operationId, servers under paths, etc) (#37)
Browse files Browse the repository at this point in the history

* Improve corner cases handling

* Add unit tests for server order resolution and missing operationId
  • Loading branch information
vblagoje authored Jul 18, 2024
1 parent 6c62db7 commit e5841a7
Show file tree
Hide file tree
Showing 5 changed files with 149 additions and 65 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@

from haystack.lazy_imports import LazyImport

from haystack_experimental.components.tools.openapi.types import OpenAPISpecification
from haystack_experimental.components.tools.openapi.types import (
VALID_HTTP_METHODS,
OpenAPISpecification,
path_to_operation_id,
)

with LazyImport("Run 'pip install jsonref'") as jsonref_import:
# pylint: disable=import-error
Expand Down Expand Up @@ -96,11 +100,14 @@ def _openapi_to_functions(
f"at least {MIN_REQUIRED_OPENAPI_SPEC_VERSION}."
)
functions: List[Dict[str, Any]] = []
for paths in service_openapi_spec["paths"].values():
for path_spec in paths.values():
function_dict = parse_endpoint_fn(path_spec, parameters_name)
if function_dict:
functions.append(function_dict)
for path, path_value in service_openapi_spec["paths"].items():
for path_key, operation_spec in path_value.items():
if path_key.lower() in VALID_HTTP_METHODS:
if "operationId" not in operation_spec:
operation_spec["operationId"] = path_to_operation_id(path, path_key)
function_dict = parse_endpoint_fn(operation_spec, parameters_name)
if function_dict:
functions.append(function_dict)
return functions


Expand Down
107 changes: 50 additions & 57 deletions haystack_experimental/components/tools/openapi/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,24 @@
]


def path_to_operation_id(path: str, http_method: str = "get") -> str:
    """
    Derive an operationId from a path and an HTTP method.

    Every slash in the path becomes an underscore, leading and trailing
    underscores are dropped, and the lowercased method is appended after
    a single underscore (e.g. "/missing-operation-id" with "get" gives
    "missing-operation-id_get").

    :param path: The path to convert.
    :param http_method: The HTTP method to use for the operationId.
    :returns: The operationId.
    :raises ValueError: If the lowercased method is not in `VALID_HTTP_METHODS`.
    """
    method = http_method.lower()
    if method not in VALID_HTTP_METHODS:
        # Report the method exactly as the caller passed it.
        raise ValueError(f"Invalid HTTP method: {http_method}")
    stem = path.replace("/", "_").strip("_")
    return f"{stem}_{method}"


class LLMProvider(Enum):
"""
LLM providers supported by `OpenAPITool`.
"""

OPENAI = "openai"
ANTHROPIC = "anthropic"
COHERE = "cohere"
Expand All @@ -50,18 +64,18 @@ def from_str(string: str) -> "LLMProvider":
@dataclass
class Operation:
"""
Represents an operation in an OpenAPI specification
See https://spec.openapis.org/oas/latest.html#paths-object for details.
Path objects can contain multiple operations, each with a unique combination of path and method.
:param path: Path of the operation.
:param method: HTTP method of the operation.
:param operation_dict: Operation details from OpenAPI spec
:param spec_dict: The encompassing OpenAPI specification.
:param security_requirements: A list of security requirements for the operation.
:param request_body: Request body details.
:param parameters: Parameters for the operation.
Represents an operation in an OpenAPI specification
See https://spec.openapis.org/oas/latest.html#paths-object for details.
Path objects can contain multiple operations, each with a unique combination of path and method.
:param path: Path of the operation.
:param method: HTTP method of the operation.
:param operation_dict: Operation details from OpenAPI spec
:param spec_dict: The encompassing OpenAPI specification.
:param security_requirements: A list of security requirements for the operation.
:param request_body: Request body details.
:param parameters: Parameters for the operation.
"""

path: str
Expand Down Expand Up @@ -105,8 +119,12 @@ def get_server(self, server_index: int = 0) -> str:
:returns: The server URL.
:raises ValueError: If no servers are found in the specification.
"""
servers = self.operation_dict.get("servers", []) or self.spec_dict.get(
"servers", []
# servers can be defined at the operation level, path level, or at the root level
# search for servers in the following order: operation, path, root
servers = (
self.operation_dict.get("servers", [])
or self.spec_dict.get("paths", {}).get(self.path, {}).get("servers", [])
or self.spec_dict.get("servers", [])
)
if not servers:
raise ValueError("No servers found in the provided specification.")
Expand Down Expand Up @@ -136,11 +154,7 @@ def __init__(self, spec_dict: Dict[str, Any]):
f"Invalid OpenAPI specification, expected a dictionary: {spec_dict}"
)
# just a crude sanity check, by no means a full validation
if (
"openapi" not in spec_dict
or "paths" not in spec_dict
or "servers" not in spec_dict
):
if "openapi" not in spec_dict or "paths" not in spec_dict:
raise ValueError(
"Invalid OpenAPI specification format. See https://swagger.io/specification/ for details.",
spec_dict,
Expand Down Expand Up @@ -201,51 +215,30 @@ def from_url(cls, url: str) -> "OpenAPISpecification":
) from e
return cls.from_str(content)

def find_operation_by_id(
self, op_id: str, method: Optional[str] = None
) -> Operation:
def find_operation_by_id(self, op_id: str) -> Operation:
"""
Find an Operation by operationId.
:param op_id: The operationId of the operation.
:param method: The HTTP method of the operation.
:returns: The matching operation
:raises ValueError: If no operation is found with the given operationId.
"""
for path, path_item in self.spec_dict.get("paths", {}).items():
op: Operation = self.get_operation_item(path, path_item, method)
if op_id in op.operation_dict.get("operationId", ""):
return self.get_operation_item(path, path_item, method)
raise ValueError(
f"No operation found with operationId {op_id}, method {method}"
)

def get_operation_item(
self, path: str, path_item: Dict[str, Any], method: Optional[str] = None
) -> Operation:
"""
Gets a particular Operation item from the OpenAPI specification given the path and method.
:param path: The path of the operation.
:param path_item: The path item from the OpenAPI specification.
:param method: The HTTP method of the operation.
:returns: The operation
"""
if method:
operation_dict = path_item.get(method.lower(), {})
if not operation_dict:
raise ValueError(
f"No operation found for method {method} at path {path}"
)
return Operation(path, method.lower(), operation_dict, self.spec_dict)
if len(path_item) == 1:
method, operation_dict = next(iter(path_item.items()))
return Operation(path, method, operation_dict, self.spec_dict)
if len(path_item) > 1:
raise ValueError(
f"Multiple operations found at path {path}, method parameter is required."
)
raise ValueError(f"No operations found at path {path} and method {method}")
for path, path_value in self.spec_dict.get("paths", {}).items():
operations = {
method: operation_dict
for method, operation_dict in path_value.items()
if method.lower() in VALID_HTTP_METHODS
}

for method, operation_dict in operations.items():
if (
operation_dict.get(
"operationId", path_to_operation_id(path, method)
)
== op_id
):
return Operation(path, method, operation_dict, self.spec_dict)
raise ValueError(f"No operation found with operationId {op_id}")

def get_security_schemes(self) -> Dict[str, Dict[str, Any]]:
"""
Expand Down
29 changes: 27 additions & 2 deletions test/components/tools/openapi/test_openapi_client_edge_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import pytest

from haystack_experimental.components.tools.openapi._openapi import OpenAPIServiceClient, ClientConfiguration
from haystack_experimental.components.tools.openapi._openapi import ClientConfiguration, OpenAPIServiceClient
from test.components.tools.openapi.conftest import FastAPITestClient, create_openapi_spec


Expand All @@ -26,4 +26,29 @@ def test_missing_operation_id(self, test_files_path):
with pytest.raises(ValueError, match="No operation found with operationId"):
client.invoke(payload)

# TODO: Add more tests for edge cases
def test_missing_operation_id_in_operation(self, test_files_path):
    """
    Test that the tool definition is generated correctly when the operationId is missing in the specification.

    The first tool definition produced from the edge-cases spec must be a
    function whose name is derived from the path and HTTP method.
    """
    config = ClientConfiguration(
        openapi_spec=create_openapi_spec(test_files_path / "yaml" / "openapi_edge_cases.yml"),
        request_sender=FastAPITestClient(None),
    )

    # No trailing comma here: it would wrap the result in a 1-tuple and
    # force the confusing double index `tools[0][0]`.
    tools = config.get_tools_definitions()
    tool_def = tools[0]
    assert tool_def["type"] == "function"
    assert tool_def["function"]["name"] == "missing-operation-id_get"

def test_servers_order(self, test_files_path):
    """
    Test that servers defined in different locations in the specification are used correctly.

    Each operationId in the edge-cases spec is paired with the server URL
    that `get_server()` is expected to return for it.
    """
    spec_file = test_files_path / "yaml" / "openapi_edge_cases.yml"
    config = ClientConfiguration(
        openapi_spec=create_openapi_spec(spec_file),
        request_sender=FastAPITestClient(None),
    )

    expected_servers = (
        # servers declared on the path item
        ("servers-order-path", "https://inpath.example.com"),
        # servers declared on the operation itself
        ("servers-order-operation", "https://inoperation.example.com"),
        # no servers on operation or path; presumably the root-level entry
        ("missing-operation-id_get", "http://localhost"),
    )
    for operation_id, server_url in expected_servers:
        op = config.openapi_spec.find_operation_by_id(operation_id)
        assert op.get_server() == server_url
21 changes: 21 additions & 0 deletions test/components/tools/openapi/test_openapi_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,3 +201,24 @@ def test_run_live_cohere(self):
assert isinstance(json_response, dict)
except json.JSONDecodeError:
pytest.fail("Response content is not valid JSON")

@pytest.mark.integration
@pytest.mark.parametrize("provider", ["openai", "anthropic", "cohere"])
def test_run_live_meteo_forecast(self, provider: str):
    """
    Run OpenAPITool against the Open-Meteo spec for each LLM provider.

    The tool must return exactly one ChatMessage whose content is a JSON
    object containing an "hourly" key.
    """
    meteo_tool = OpenAPITool(
        generator_api=LLMProvider.from_str(provider),
        spec="https://raw.githubusercontent.com/open-meteo/open-meteo/main/openapi.yml",
    )
    prompt = ChatMessage.from_user(
        "weather forecast for latitude 52.52 and longitude 13.41 and set hourly=temperature_2m"
    )
    results = meteo_tool.run(messages=[prompt])

    service_response = results["service_response"]
    assert isinstance(service_response, list)
    assert len(service_response) == 1
    reply = service_response[0]
    assert isinstance(reply, ChatMessage)

    # Only the parse itself can raise JSONDecodeError; the content checks
    # run on the success path.
    try:
        json_response = json.loads(reply.content)
    except json.JSONDecodeError:
        pytest.fail("Response content is not valid JSON")
    else:
        assert isinstance(json_response, dict)
        assert "hourly" in json_response
38 changes: 38 additions & 0 deletions test/test_files/yaml/openapi_edge_cases.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,44 @@ paths:
/missing-operation-id:
get:
summary: Missing operationId
parameters:
- name: name
in: path
required: true
schema:
type: string
responses:
'200':
description: OK

/servers-order-in-path:
servers:
- url: https://inpath.example.com
get:
summary: Servers order
operationId: servers-order-path
parameters:
- name: name
in: path
required: true
schema:
type: string
responses:
'200':
description: OK

/servers-order-in-operation:
get:
summary: Servers order
operationId: servers-order-operation
parameters:
- name: name
in: path
required: true
schema:
type: string
responses:
'200':
description: OK
servers:
- url: https://inoperation.example.com

0 comments on commit e5841a7

Please sign in to comment.