Skip to content

Commit

Permalink
RSDK-7189: add data wrappers (viamrobotics#603)
Browse files Browse the repository at this point in the history
  • Loading branch information
purplenicole730 authored May 3, 2024
1 parent 7acf717 commit 7b34a10
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 8 deletions.
58 changes: 54 additions & 4 deletions src/viam/app/data_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Any, List, Mapping, Optional, Sequence, Tuple
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple

from google.protobuf.struct_pb2 import Struct
from grpclib.client import Channel, Stream
Expand All @@ -23,6 +23,7 @@
BoundingBoxLabelsByFilterRequest,
BoundingBoxLabelsByFilterResponse,
CaptureMetadata,
ConfigureDatabaseUserRequest,
DataRequest,
DataServiceStub,
DeleteBinaryDataByFilterRequest,
Expand All @@ -43,6 +44,10 @@
RemoveTagsFromBinaryDataByIDsResponse,
TabularDataByFilterRequest,
TabularDataByFilterResponse,
TabularDataByMQLRequest,
TabularDataByMQLResponse,
TabularDataBySQLRequest,
TabularDataBySQLResponse,
TagsByFilterRequest,
TagsByFilterResponse,
)
Expand Down Expand Up @@ -73,7 +78,7 @@
StreamingDataCaptureUploadResponse,
UploadMetadata,
)
from viam.utils import create_filter, datetime_to_timestamp, struct_to_dict
from viam.utils import ValueTypes, create_filter, datetime_to_timestamp, struct_to_dict

LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -248,6 +253,44 @@ async def tabular_data_by_filter(
LOGGER.error(f"Failed to write tabular data to file {dest}", exc_info=e)
return data, response.count, response.last

async def tabular_data_by_sql(self, organization_id: str, sql_query: str) -> List[Dict[str, ValueTypes]]:
"""Obtain unified tabular data and metadata, queried with SQL.
::
data = await data_client.tabular_data_by_sql(org_id="<your-org-id>", sql_query="<sql-query>")
Args:
organization_id (str): The ID of the organization that owns the data.
sql_query (str): The SQL query to run.
Returns:
List[Dict[str, ValueTypes]]: An array of data objects.
"""
request = TabularDataBySQLRequest(organization_id=organization_id, sql_query=sql_query)
response: TabularDataBySQLResponse = await self._data_client.TabularDataBySQL(request, metadata=self._metadata)
return [struct_to_dict(struct) for struct in response.data]

async def tabular_data_by_mql(self, organization_id: str, mql_binary: List[bytes]) -> List[Dict[str, ValueTypes]]:
"""Obtain unified tabular data and metadata, queried with MQL.
::
data = await data_client.tabular_data_by_mql(org_id="<your-org-id>", mql_binary=[<mql-bytes-1>, <mql-bytes-2>])
Args:
organization_id (str): The ID of the organization that owns the data.
mql_binary (List[bytes]):The MQL query to run as a list of BSON documents.
Returns:
List[Dict[str, ValueTypes]]: An array of data objects.
"""
request = TabularDataByMQLRequest(organization_id=organization_id, mql_binary=mql_binary)
response: TabularDataByMQLResponse = await self._data_client.TabularDataByMQL(request, metadata=self._metadata)
return [struct_to_dict(struct) for struct in response.data]

async def binary_data_by_filter(
self,
filter: Optional[Filter] = None,
Expand Down Expand Up @@ -733,9 +776,16 @@ async def get_database_connection(self, organization_id: str) -> str:
response: GetDatabaseConnectionResponse = await self._data_client.GetDatabaseConnection(request, metadata=self._metadata)
return response.hostname

# TODO(RSDK-5569): implement
async def configure_database_user(self, organization_id: str, password: str) -> None:
raise NotImplementedError()
"""Configure a database user for the Viam organization's MongoDB Atlas Data Federation instance. It can also be used to reset the
password of the existing database user.
Args:
organization_id (str): The ID of the organization.
password (str): The password of the user.
"""
request = ConfigureDatabaseUserRequest(organization_id=organization_id, password=password)
await self._data_client.ConfigureDatabaseUser(request, metadata=self._metadata)

async def create_dataset(self, name: str, organization_id: str) -> str:
"""Create a new dataset.
Expand Down
16 changes: 13 additions & 3 deletions tests/mocks/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -753,13 +753,15 @@ class MockData(DataServiceBase):
def __init__(
self,
tabular_response: List[DataClient.TabularData],
tabular_query_response: List[Dict[str, ValueTypes]],
binary_response: List[DataClient.BinaryData],
delete_remove_response: int,
tags_response: List[str],
bbox_labels_response: List[str],
hostname_response: str,
):
self.tabular_response = tabular_response
self.tabular_query_response = tabular_query_response
self.binary_response = binary_response
self.delete_remove_response = delete_remove_response
self.tags_response = tags_response
Expand Down Expand Up @@ -916,7 +918,11 @@ async def GetDatabaseConnection(self, stream: Stream[GetDatabaseConnectionReques
await stream.send_message(GetDatabaseConnectionResponse(hostname=self.hostname_response))

async def ConfigureDatabaseUser(self, stream: Stream[ConfigureDatabaseUserRequest, ConfigureDatabaseUserResponse]) -> None:
raise NotImplementedError()
request = await stream.recv_message()
assert request is not None
self.organization_id = request.organization_id
self.password = request.password
await stream.send_message(ConfigureDatabaseUserResponse())

async def AddBinaryDataToDatasetByIDs(
self, stream: Stream[AddBinaryDataToDatasetByIDsRequest, AddBinaryDataToDatasetByIDsResponse]
Expand All @@ -937,10 +943,14 @@ async def RemoveBinaryDataFromDatasetByIDs(
await stream.send_message(RemoveBinaryDataFromDatasetByIDsResponse())

async def TabularDataBySQL(self, stream: Stream[TabularDataBySQLRequest, TabularDataBySQLResponse]) -> None:
raise NotImplementedError()
request = await stream.recv_message()
assert request is not None
await stream.send_message(TabularDataBySQLResponse(data=[dict_to_struct(dict) for dict in self.tabular_query_response]))

async def TabularDataByMQL(self, stream: Stream[TabularDataByMQLRequest, TabularDataByMQLResponse]) -> None:
raise NotImplementedError()
request = await stream.recv_message()
assert request is not None
await stream.send_message(TabularDataByMQLResponse(data=[dict_to_struct(dict) for dict in self.tabular_query_response]))


class MockDataset(DatasetServiceBase):
Expand Down
27 changes: 26 additions & 1 deletion tests/test_data_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
LOCATION_IDS = [LOCATION_ID]
ORG_ID = "organization_id"
ORG_IDS = [ORG_ID]
PASSWORD = "password"
MIME_TYPE = "mime_type"
MIME_TYPES = [MIME_TYPE]
URI = "some.robot.uri"
Expand Down Expand Up @@ -70,6 +71,8 @@
y_max_normalized=0.3,
)
BBOXES = [BBOX]
SQL_QUERY = "sql_query"
MQL_BINARY = [b"mql_binary"]
TABULAR_DATA = {"key": "value"}
TABULAR_METADATA = CaptureMetadata(
organization_id=ORG_ID,
Expand Down Expand Up @@ -97,6 +100,9 @@
)

TABULAR_RESPONSE = [DataClient.TabularData(TABULAR_DATA, TABULAR_METADATA, START_DATETIME, END_DATETIME)]
TABULAR_QUERY_RESPONSE = [
{"key1": 1, "key2": "2", "key3": [1, 2, 3], "key4": {"key4sub1": 1}},
]
BINARY_RESPONSE = [DataClient.BinaryData(BINARY_DATA, BINARY_METADATA)]
DELETE_REMOVE_RESPONSE = 1
TAGS_RESPONSE = ["tag"]
Expand All @@ -110,6 +116,7 @@
def service() -> MockData:
return MockData(
tabular_response=TABULAR_RESPONSE,
tabular_query_response=TABULAR_QUERY_RESPONSE,
binary_response=BINARY_RESPONSE,
delete_remove_response=DELETE_REMOVE_RESPONSE,
tags_response=TAGS_RESPONSE,
Expand Down Expand Up @@ -143,6 +150,20 @@ async def test_tabular_data_by_filter(self, service: MockData):
assert last_response != ""
self.assert_filter(filter=service.filter)

@pytest.mark.asyncio
async def test_tabular_data_by_sql(self, service: MockData):
async with ChannelFor([service]) as channel:
client = DataClient(channel, DATA_SERVICE_METADATA)
response = await client.tabular_data_by_sql(ORG_ID, SQL_QUERY)
assert response == TABULAR_QUERY_RESPONSE

@pytest.mark.asyncio
async def test_tabular_data_by_mql(self, service: MockData):
async with ChannelFor([service]) as channel:
client = DataClient(channel, DATA_SERVICE_METADATA)
response = await client.tabular_data_by_mql(ORG_ID, MQL_BINARY)
assert response == TABULAR_QUERY_RESPONSE

@pytest.mark.asyncio
async def test_binary_data_by_filter(self, service: MockData):
async with ChannelFor([service]) as channel:
Expand Down Expand Up @@ -283,7 +304,11 @@ async def test_get_database_connection(self, service: MockData):

@pytest.mark.asyncio
async def test_configure_database_user(self, service: MockData):
assert True
async with ChannelFor([service]) as channel:
client = DataClient(channel, DATA_SERVICE_METADATA)
await client.configure_database_user(ORG_ID, PASSWORD)
assert service.organization_id == ORG_ID
assert service.password == PASSWORD

@pytest.mark.asyncio
async def test_add_binary_data_to_dataset_by_ids(self, service: MockData):
Expand Down

0 comments on commit 7b34a10

Please sign in to comment.