Skip to content

Commit

Permalink
Swagger changes in pebblo server APIs (daxa-ai#530)
Browse files Browse the repository at this point in the history
* Swagger changes in pebblo server APIs

* Updating models

---------

Co-authored-by: dristy.cd <“[email protected]”>
  • Loading branch information
dristysrivastava and dristy.cd authored Sep 10, 2024
1 parent 7beb672 commit 6a6265e
Show file tree
Hide file tree
Showing 5 changed files with 221 additions and 164 deletions.
19 changes: 11 additions & 8 deletions pebblo/app/api/api.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from fastapi import APIRouter, Depends

from pebblo.app.api.req_models import ReqDiscover, ReqLoaderDoc, ReqPrompt, ReqPromptGov
from pebblo.app.config.config import var_server_config_dict
from pebblo.app.service.prompt_gov import PromptGov
from pebblo.app.utils.handler_mapper import get_handler
Expand All @@ -17,34 +18,36 @@ def __init__(self, prefix: str):

@staticmethod
def discover(
data: dict, discover_obj=Depends(lambda: get_handler(handler_name="discover"))
data: ReqDiscover,
discover_obj=Depends(lambda: get_handler(handler_name="discover")),
):
# "/app/discover" API entrypoint
# Execute discover object based on a storage type
response = discover_obj.process_request(data)
response = discover_obj.process_request(data.model_dump())
return response

@staticmethod
def loader_doc(
data: dict, loader_doc_obj=Depends(lambda: get_handler(handler_name="loader"))
data: ReqLoaderDoc,
loader_doc_obj=Depends(lambda: get_handler(handler_name="loader")),
):
# "/loader/doc" API entrypoint
# Execute loader doc object based on a storage type
response = loader_doc_obj.process_request(data)
response = loader_doc_obj.process_request(data.model_dump())
return response

@staticmethod
def prompt(
data: dict, prompt_obj=Depends(lambda: get_handler(handler_name="prompt"))
data: ReqPrompt, prompt_obj=Depends(lambda: get_handler(handler_name="prompt"))
):
# "/prompt" API entrypoint
# Execute a prompt object based on a storage type
response = prompt_obj.process_request(data)
response = prompt_obj.process_request(data.model_dump())
return response

@staticmethod
def promptgov(data: dict):
def promptgov(data: ReqPromptGov):
# "/prompt/governance" API entrypoint
prompt_obj = PromptGov(data=data)
prompt_obj = PromptGov(data=data.model_dump())
response = prompt_obj.process_request()
return response
94 changes: 94 additions & 0 deletions pebblo/app/api/req_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
"""API Request Model Class"""

from typing import List, Optional, Union

from pydantic import BaseModel


class Runtime(BaseModel):
type: str = "local"
host: str
path: str
ip: Optional[str] = None
platform: str
os: str
os_version: str
language: str
language_version: str
runtime: str = "local"


class Framework(BaseModel):
name: str
version: str


class VectorDB(BaseModel):
name: Optional[str] = None
version: Optional[str] = None
location: Optional[str] = None
embedding_model: Optional[str] = None


class Model(BaseModel):
vendor: Optional[str] = None
name: Optional[str] = None


class ChainInfo(BaseModel):
name: str
model: Optional[Model] = None
vector_dbs: Optional[List[VectorDB]] = None


class ReqDiscover(BaseModel):
name: str
owner: str
description: Optional[str] = None
load_id: Optional[str] = None
runtime: Runtime
framework: Framework
chains: Optional[List[ChainInfo]] = None
plugin_version: str
client_version: Framework


class ReqLoaderDoc(BaseModel):
name: str
owner: str
docs: list[dict] = None
plugin_version: str
load_id: str
loader_details: dict
loading_end: bool
source_owner: str
classifier_location: str


class Context(BaseModel):
retrieved_from: Optional[str] = None
doc: Optional[str] = None
vector_db: str
pb_checksum: Optional[str] = None


class Prompt(BaseModel):
data: Optional[Union[list, str]] = None
entityCount: Optional[int] = None
entities: Optional[dict] = None
prompt_gov_enabled: Optional[bool] = None


class ReqPrompt(BaseModel):
name: str
context: Optional[List[Context]] = None
prompt: Optional[Prompt] = None
response: Optional[Prompt] = None
prompt_time: str
user: str
user_identities: Optional[List[str]] = None
classifier_location: str


class ReqPromptGov(BaseModel):
prompt: str
55 changes: 28 additions & 27 deletions pebblo/app/service/discovery_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,35 +124,36 @@ def _fetch_chain_details(self, app_metadata) -> list[Chain]:
logger.debug(f"Existing Chains : {chains}")

logger.debug(f"Input chains : {self.data.get('chains', [])}")
for chain in self.data.get("chains", []):
name = chain["name"]
model = chain["model"]
# vector db details
vector_db_details = []
for vector_db in chain.get("vector_dbs", []):
vector_db_obj = VectorDB(
name=vector_db.get("name"),
version=vector_db.get("version"),
location=vector_db.get("location"),
embeddingModel=vector_db.get("embedding_model"),
pkgInfo=None,
)

package_info = vector_db.get("pkg_info")
if package_info:
pkg_info_obj = PackageInfo(
projectHomePage=package_info.get("project_home_page"),
documentationUrl=package_info.get("documentation_url"),
pypiUrl=package_info.get("pypi_url"),
licenceType=package_info.get("licence_type"),
installedVia=package_info.get("installed_via"),
location=package_info.get("location"),
if self.data.get("chains") not in [None, []]:
for chain in self.data.get("chains", []):
name = chain["name"]
model = chain["model"]
# vector db details
vector_db_details = []
for vector_db in chain.get("vector_dbs", []):
vector_db_obj = VectorDB(
name=vector_db.get("name"),
version=vector_db.get("version"),
location=vector_db.get("location"),
embeddingModel=vector_db.get("embedding_model"),
pkgInfo=None,
)
vector_db_obj.pkgInfo = pkg_info_obj

vector_db_details.append(vector_db_obj)
chain_obj = Chain(name=name, model=model, vectorDbs=vector_db_details)
chains.append(chain_obj.model_dump())
package_info = vector_db.get("pkg_info")
if package_info:
pkg_info_obj = PackageInfo(
projectHomePage=package_info.get("project_home_page"),
documentationUrl=package_info.get("documentation_url"),
pypiUrl=package_info.get("pypi_url"),
licenceType=package_info.get("licence_type"),
installedVia=package_info.get("installed_via"),
location=package_info.get("location"),
)
vector_db_obj.pkgInfo = pkg_info_obj

vector_db_details.append(vector_db_obj)
chain_obj = Chain(name=name, model=model, vectorDbs=vector_db_details)
chains.append(chain_obj.model_dump())

logger.debug(f"Application Name [{self.application_name}]: Chains: {chains}")
return chains
Expand Down
55 changes: 33 additions & 22 deletions tests/app/test_daemon.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,28 @@
client = TestClient(app)


app_discover_payload = {
"name": "Test App",
"owner": "Test owner",
"description": "This is a test app.",
"runtime": {
"type": "desktop",
"host": "MacBook-Pro.local",
"path": "Test/Path",
"ip": "127.0.0.1",
"platform": "macOS-14.6.1-arm64-i386-64bit",
"os": "Darwin",
"os_version": "Darwin Kernel Version 23.6.0",
"language": "python",
"language_version": "3.11.9",
"runtime": "Mac OSX",
},
"framework": {"name": "langchain", "version": "0.2.35"},
"plugin_version": "0.1",
"client_version": {"name": "langchain_community", "version": "0.2.12"},
}


@pytest.fixture(scope="module")
def mocked_objects():
with (
Expand Down Expand Up @@ -96,13 +118,7 @@ def test_app_discover_success(mock_write_json_to_file, mock_pebblo_server_versio
Test the app discover endpoint.
"""
mock_write_json_to_file.return_value = None
app_payload = {
"name": "Test App",
"owner": "Test owner",
"description": "This is a test app.",
"plugin_version": "0.1",
}
response = client.post("/v1/app/discover", json=app_payload)
response = client.post("/v1/app/discover", json=app_discover_payload)

# Assertions
assert response.status_code == 200
Expand All @@ -115,28 +131,22 @@ def test_app_discover_validation_errors(mock_write_json_to_file):
Test the app discover endpoint with validation errors.
"""
mock_write_json_to_file.return_value = None
app = {
"owner": "Test owner",
"description": "This is a test app.",
"plugin_version": "0.1",
}
response = client.post("/v1/app/discover", json=app)
assert response.status_code == 400
assert "1 validation error for AiApp" in response.json()["message"]
app_payload = app_discover_payload.copy()
app_payload.pop("name")

response = client.post("/v1/app/discover", json=app_payload)
assert response.status_code == 422
assert "'type': 'missing', 'loc': ['body', 'name'], 'msg': 'Field required'" in str(
response.json()["detail"]
)


def test_app_discover_server_error(mock_write_json_to_file):
"""
Test the app discover endpoint with server error.
"""
mock_write_json_to_file.side_effect = Exception("Mocked exception")
app_payload = {
"name": "Test App",
"owner": "Test owner",
"description": "This is a test app.",
"plugin_version": "0.1",
}
response = client.post("/v1/app/discover", json=app_payload)
response = client.post("/v1/app/discover", json=app_discover_payload)

# Assertions
assert response.status_code == 500
Expand Down Expand Up @@ -186,6 +196,7 @@ def test_loader_doc_success(
"source_aggr_size": 306,
},
"plugin_version": "0.1.0",
"classifier_location": "local",
}
response = client.post("/v1/loader/doc", json=loader_doc)
assert response.status_code == 200
Expand Down
Loading

0 comments on commit 6a6265e

Please sign in to comment.