From 6a6265e83213667dc809b259f407f2583ff17639 Mon Sep 17 00:00:00 2001 From: Dristy Srivastava <58721149+dristysrivastava@users.noreply.github.com> Date: Wed, 11 Sep 2024 04:27:39 +0530 Subject: [PATCH] Swagger changes in pebblo server APIs (#530) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Swagger changes in pebblo server APIs * Updating models --------- Co-authored-by: dristy.cd <“dristysrivastava91@gmail.com”> --- pebblo/app/api/api.py | 19 +-- pebblo/app/api/req_models.py | 94 ++++++++++++++ pebblo/app/service/discovery_service.py | 55 ++++---- tests/app/test_daemon.py | 55 ++++---- tests/app/test_prompt_api.py | 162 ++++++++---------------- 5 files changed, 221 insertions(+), 164 deletions(-) create mode 100644 pebblo/app/api/req_models.py diff --git a/pebblo/app/api/api.py b/pebblo/app/api/api.py index 2b1b3f77..ffd69cab 100644 --- a/pebblo/app/api/api.py +++ b/pebblo/app/api/api.py @@ -1,5 +1,6 @@ from fastapi import APIRouter, Depends +from pebblo.app.api.req_models import ReqDiscover, ReqLoaderDoc, ReqPrompt, ReqPromptGov from pebblo.app.config.config import var_server_config_dict from pebblo.app.service.prompt_gov import PromptGov from pebblo.app.utils.handler_mapper import get_handler @@ -17,34 +18,36 @@ def __init__(self, prefix: str): @staticmethod def discover( - data: dict, discover_obj=Depends(lambda: get_handler(handler_name="discover")) + data: ReqDiscover, + discover_obj=Depends(lambda: get_handler(handler_name="discover")), ): # "/app/discover" API entrypoint # Execute discover object based on a storage type - response = discover_obj.process_request(data) + response = discover_obj.process_request(data.model_dump()) return response @staticmethod def loader_doc( - data: dict, loader_doc_obj=Depends(lambda: get_handler(handler_name="loader")) + data: ReqLoaderDoc, + loader_doc_obj=Depends(lambda: get_handler(handler_name="loader")), ): # "/loader/doc" API entrypoint # Execute loader doc object based on a storage type - response = loader_doc_obj.process_request(data) + response = loader_doc_obj.process_request(data.model_dump()) return response @staticmethod def prompt( - data: dict, prompt_obj=Depends(lambda: get_handler(handler_name="prompt")) + data: ReqPrompt, prompt_obj=Depends(lambda: get_handler(handler_name="prompt")) ): # "/prompt" API entrypoint # Execute a prompt object based on a storage type - response = prompt_obj.process_request(data) + response = prompt_obj.process_request(data.model_dump()) return response @staticmethod - def promptgov(data: dict): + def promptgov(data: ReqPromptGov): # "/prompt/governance" API entrypoint - prompt_obj = PromptGov(data=data) + prompt_obj = PromptGov(data=data.model_dump()) response = prompt_obj.process_request() return response diff --git a/pebblo/app/api/req_models.py b/pebblo/app/api/req_models.py new file mode 100644 index 00000000..d7baba38 --- /dev/null +++ b/pebblo/app/api/req_models.py @@ -0,0 +1,94 @@ +"""API Request Model Class""" + +from typing import List, Optional, Union + +from pydantic import BaseModel + + +class Runtime(BaseModel): + type: str = "local" + host: str + path: str + ip: Optional[str] = None + platform: str + os: str + os_version: str + language: str + language_version: str + runtime: str = "local" + + +class Framework(BaseModel): + name: str + version: str + + +class VectorDB(BaseModel): + name: Optional[str] = None + version: Optional[str] = None + location: Optional[str] = None + embedding_model: Optional[str] = None + + +class Model(BaseModel): + vendor: Optional[str] = None + name: Optional[str] = None + + +class ChainInfo(BaseModel): + name: str + model: Optional[Model] = None + vector_dbs: Optional[List[VectorDB]] = None + + +class ReqDiscover(BaseModel): + name: str + owner: str + description: Optional[str] = None + load_id: Optional[str] = None + runtime: Runtime + framework: Framework + chains: Optional[List[ChainInfo]] = None + plugin_version: str + client_version: Framework + + +class ReqLoaderDoc(BaseModel): + name: str + owner: str + docs: list[dict] = None + plugin_version: str + load_id: str + loader_details: dict + loading_end: bool + source_owner: str + classifier_location: str + + +class Context(BaseModel): + retrieved_from: Optional[str] = None + doc: Optional[str] = None + vector_db: str + pb_checksum: Optional[str] = None + + +class Prompt(BaseModel): + data: Optional[Union[list, str]] = None + entityCount: Optional[int] = None + entities: Optional[dict] = None + prompt_gov_enabled: Optional[bool] = None + + +class ReqPrompt(BaseModel): + name: str + context: Optional[List[Context]] = None + prompt: Optional[Prompt] = None + response: Optional[Prompt] = None + prompt_time: str + user: str + user_identities: Optional[List[str]] = None + classifier_location: str + + +class ReqPromptGov(BaseModel): + prompt: str diff --git a/pebblo/app/service/discovery_service.py b/pebblo/app/service/discovery_service.py index 5f9022c6..7f66cda7 100644 --- a/pebblo/app/service/discovery_service.py +++ b/pebblo/app/service/discovery_service.py @@ -124,35 +124,36 @@ def _fetch_chain_details(self, app_metadata) -> list[Chain]: logger.debug(f"Existing Chains : {chains}") logger.debug(f"Input chains : {self.data.get('chains', [])}") - for chain in self.data.get("chains", []): - name = chain["name"] - model = chain["model"] - # vector db details - vector_db_details = [] - for vector_db in chain.get("vector_dbs", []): - vector_db_obj = VectorDB( - name=vector_db.get("name"), - version=vector_db.get("version"), - location=vector_db.get("location"), - embeddingModel=vector_db.get("embedding_model"), - pkgInfo=None, - ) - - package_info = vector_db.get("pkg_info") - if package_info: - pkg_info_obj = PackageInfo( - projectHomePage=package_info.get("project_home_page"), - documentationUrl=package_info.get("documentation_url"), - pypiUrl=package_info.get("pypi_url"), - licenceType=package_info.get("licence_type"), - installedVia=package_info.get("installed_via"), - location=package_info.get("location"), + if self.data.get("chains") not in [None, []]: + for chain in self.data.get("chains", []): + name = chain["name"] + model = chain["model"] + # vector db details + vector_db_details = [] + for vector_db in chain.get("vector_dbs", []): + vector_db_obj = VectorDB( + name=vector_db.get("name"), + version=vector_db.get("version"), + location=vector_db.get("location"), + embeddingModel=vector_db.get("embedding_model"), + pkgInfo=None, ) - vector_db_obj.pkgInfo = pkg_info_obj - vector_db_details.append(vector_db_obj) - chain_obj = Chain(name=name, model=model, vectorDbs=vector_db_details) - chains.append(chain_obj.model_dump()) + package_info = vector_db.get("pkg_info") + if package_info: + pkg_info_obj = PackageInfo( + projectHomePage=package_info.get("project_home_page"), + documentationUrl=package_info.get("documentation_url"), + pypiUrl=package_info.get("pypi_url"), + licenceType=package_info.get("licence_type"), + installedVia=package_info.get("installed_via"), + location=package_info.get("location"), + ) + vector_db_obj.pkgInfo = pkg_info_obj + + vector_db_details.append(vector_db_obj) + chain_obj = Chain(name=name, model=model, vectorDbs=vector_db_details) + chains.append(chain_obj.model_dump()) logger.debug(f"Application Name [{self.application_name}]: Chains: {chains}") return chains diff --git a/tests/app/test_daemon.py b/tests/app/test_daemon.py index 2c611905..4a4ffbdb 100644 --- a/tests/app/test_daemon.py +++ b/tests/app/test_daemon.py @@ -14,6 +14,28 @@ client = TestClient(app) +app_discover_payload = { + "name": "Test App", + "owner": "Test owner", + "description": "This is a test app.", + "runtime": { + "type": "desktop", + "host": "MacBook-Pro.local", + "path": "Test/Path", + "ip": "127.0.0.1", + "platform": "macOS-14.6.1-arm64-i386-64bit", + "os": "Darwin", + "os_version": "Darwin Kernel Version 23.6.0", + "language": "python", + "language_version": "3.11.9", + "runtime": "Mac OSX", + }, + "framework": {"name": "langchain", "version": "0.2.35"}, + "plugin_version": "0.1", + "client_version": {"name": "langchain_community", "version": "0.2.12"}, +} + + @pytest.fixture(scope="module") def mocked_objects(): with ( @@ -96,13 +118,7 @@ def test_app_discover_success(mock_write_json_to_file, mock_pebblo_server_versio Test the app discover endpoint. """ mock_write_json_to_file.return_value = None - app_payload = { - "name": "Test App", - "owner": "Test owner", - "description": "This is a test app.", - "plugin_version": "0.1", - } - response = client.post("/v1/app/discover", json=app_payload) + response = client.post("/v1/app/discover", json=app_discover_payload) # Assertions assert response.status_code == 200 @@ -115,14 +131,14 @@ def test_app_discover_validation_errors(mock_write_json_to_file): Test the app discover endpoint with validation errors. """ mock_write_json_to_file.return_value = None - app = { - "owner": "Test owner", - "description": "This is a test app.", - "plugin_version": "0.1", - } - response = client.post("/v1/app/discover", json=app) - assert response.status_code == 400 - assert "1 validation error for AiApp" in response.json()["message"] + app_payload = app_discover_payload.copy() + app_payload.pop("name") + + response = client.post("/v1/app/discover", json=app_payload) + assert response.status_code == 422 + assert "'type': 'missing', 'loc': ['body', 'name'], 'msg': 'Field required'" in str( + response.json()["detail"] + ) def test_app_discover_server_error(mock_write_json_to_file): @@ -130,13 +146,7 @@ def test_app_discover_server_error(mock_write_json_to_file): Test the app discover endpoint with server error. """ mock_write_json_to_file.side_effect = Exception("Mocked exception") - app_payload = { - "name": "Test App", - "owner": "Test owner", - "description": "This is a test app.", - "plugin_version": "0.1", - } - response = client.post("/v1/app/discover", json=app_payload) + response = client.post("/v1/app/discover", json=app_discover_payload) # Assertions assert response.status_code == 500 @@ -186,6 +196,7 @@ def test_loader_doc_success( "source_aggr_size": 306, }, "plugin_version": "0.1.0", + "classifier_location": "local", } response = client.post("/v1/loader/doc", json=loader_doc) assert response.status_code == 200 diff --git a/tests/app/test_prompt_api.py b/tests/app/test_prompt_api.py index 90aae200..c9756f98 100644 --- a/tests/app/test_prompt_api.py +++ b/tests/app/test_prompt_api.py @@ -10,6 +10,33 @@ app.include_router(router_instance.router) client = TestClient(app) +test_payload = { + "name": "Test App", + "context": [ + { + "retrieved_from": "test_data.pdf", + "doc": "Patient SSN: 222-85-4836", + "vector_db": "TestDB", + }, + { + "retrieved_from": "test_data1.pdf", + "doc": "Patient SSN: 222-85-4836", + "vector_db": "TestDB", + }, + ], + "prompt": { + "data": "What is John's SSN", + "entities": {}, + "topics": {}, + "entityCount": 0, + "prompt_gov_enabled": True, + }, + "response": {"data": "Patient SSN is 222-85-4836"}, + "prompt_time": "2024-04-17T15:03:18.177368", + "user": "Test Owner", + "user_identities": ["test_group@test.com"], + "classifier_location": "local", +} @pytest.fixture @@ -25,32 +52,6 @@ def test_app_prompt_success(mock_write_json_to_file): Test the app prompt endpoint. """ mock_write_json_to_file.return_value = None - test_payload = { - "name": "Test App", - "context": [ - { - "retrieved_from": "test_data.pdf", - "doc": "Patient SSN: 222-85-4836", - "vector_db": "TestDB", - }, - { - "retrieved_from": "test_data1.pdf", - "doc": "Patient SSN: 222-85-4836", - "vector_db": "TestDB", - }, - ], - "prompt": { - "data": "What is John's SSN", - "entities": {}, - "topics": {}, - "entityCount": 0, - "prompt_gov_enabled": True, - }, - "response": {"data": "Patient SSN is 222-85-4836"}, - "prompt_time": "2024-04-17T15:03:18.177368", - "user": "Test Owner", - "user_identities": ["test_group@test.com"], - } response = client.post("/v1/prompt", json=test_payload) assert response.status_code == 200 @@ -71,35 +72,28 @@ def test_app_prompt_validation_errors(mock_write_json_to_file): Test the app prompt endpoint with validation errors. """ mock_write_json_to_file.return_value = None - test_payload = { - "name": "Test App", - "context": [ - { - "retrieved_from": "test_data.pdf", - "doc": "This is test doc.", - }, - { - "retrieved_from": "test_data1.pdf", - "vector_db": "TestDB", - }, - ], - "prompt": { - "data": "What is Sachin's Passport ID?", - "entities": {}, - "entityCount": 0, - "prompt_gov_enabled": True, + test_error_payload = test_payload.copy() + test_error_payload["context"] = [ + { + "retrieved_from": "test_data.pdf", + "doc": "This is test doc.", }, - "response": {"data": "His passport ID is 5484880UA."}, - "user_identities": ["test_group@test.com"], - } - response = client.post("/v1/prompt", json=test_payload) - assert response.status_code == 400 - assert response.json()["message"] == ( - "1 validation error for RetrievalContext\n" - "vector_db\n" - " Input should be a valid string [type=string_type, input_value=None, input_type=NoneType]\n" - " For further information visit https://errors.pydantic.dev/2.8/v/string_type" - ) + { + "retrieved_from": "test_data1.pdf", + "vector_db": "TestDB", + }, + ] + response = client.post("/v1/prompt", json=test_error_payload) + assert response.status_code == 422 + assert response.json()["detail"] == [ + { + "type": "missing", + "loc": ["body", "context", 0, "vector_db"], + "msg": "Field required", + "input": {"retrieved_from": "test_data.pdf", "doc": "This is test doc."}, + "url": "https://errors.pydantic.dev/2.8/v/missing", + } + ] def test_app_prompt_validation_errors_single_missing_field(mock_write_json_to_file): @@ -107,35 +101,14 @@ def test_app_prompt_validation_errors_single_missing_field(mock_write_json_to_fi Test the app prompt endpoint with validation errors. """ mock_write_json_to_file.return_value = None - test_payload = { - "name": "Test App", - "context": [ - { - "retrieved_from": "test_data.pdf", - "doc": "This is test doc.", - "vector_db": "TestDB", - }, - { - "retrieved_from": "test_data1.pdf", - "doc": "This is test1 doc.", - "vector_db": "TestDB", - }, - ], - "prompt": { - "data": "What is Sachin's Passport ID?", - "entities": {}, - "entityCount": 0, - "prompt_gov_enabled": True, - }, - "response": {"data": "His passport ID is 5484880UA."}, - "user": "Test Owner", - "user_identities": ["test_group@test.com"], - } - response = client.post("/v1/prompt", json=test_payload) - assert response.status_code == 400 + test_error_payload = test_payload.copy() + test_error_payload.pop("prompt_time") + response = client.post("/v1/prompt", json=test_error_payload) + assert response.status_code == 422 assert ( - "1 validation error for RetrievalData\n" "prompt_time\n" - ) in response.json()["message"] + "'type': 'missing', 'loc': ['body', 'prompt_time'], 'msg': 'Field required'" + in str(response.json()["detail"]) + ) def test_app_prompt_server_error(mock_write_json_to_file): @@ -143,31 +116,6 @@ def test_app_prompt_server_error(mock_write_json_to_file): Test the app prompt endpoint with server error. """ mock_write_json_to_file.side_effect = Exception("Mocked exception") - test_payload = { - "name": "Test App", - "context": [ - { - "retrieved_from": "test_data.pdf", - "doc": "This is test doc.", - "vector_db": "TestDB", - }, - { - "retrieved_from": "test_data1.pdf", - "doc": "This is test1 doc.", - "vector_db": "TestDB", - }, - ], - "prompt": { - "data": "What is Sachin's Passport ID?", - "entities": {}, - "entityCount": 0, - "prompt_gov_enabled": True, - }, - "response": {"data": "His passport ID is 5484880UA."}, - "prompt_time": "2024-04-17T15:03:18.177368", - "user": "Test Owner", - "user_identities": ["test_group@test.com"], - } response = client.post("/v1/prompt", json=test_payload) # Assertions