Skip to content

Commit

Permalink
RSDK-7192 - Provisioning wrappers (viamrobotics#577)
Browse files Browse the repository at this point in the history
  • Loading branch information
stuqdog authored Apr 11, 2024
1 parent 313ca7d commit 1e31db8
Show file tree
Hide file tree
Showing 4 changed files with 246 additions and 0 deletions.
95 changes: 95 additions & 0 deletions src/viam/app/provisioning_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
from typing import Mapping, List, Optional

from grpclib.client import Channel

from viam import logging
from viam.proto.provisioning import (
CloudConfig,
GetNetworkListRequest,
GetNetworkListResponse,
GetSmartMachineStatusRequest,
GetSmartMachineStatusResponse,
NetworkInfo,
ProvisioningServiceStub,
SetNetworkCredentialsRequest,
SetSmartMachineCredentialsRequest,
)

LOGGER = logging.getLogger(__name__)


class ProvisioningClient:
"""gRPC client for getting and setting smart machine info.
Constructor is used by `ViamClient` to instantiate relevant service stubs. Calls to
`ProvisioningClient` methods should be made through `ViamClient`.
Establish a connection::
import asyncio
from viam.rpc.dial import DialOptions, Credentials
from viam.app.viam_client import ViamClient
async def connect() -> ViamClient:
# Replace "<API-KEY>" (including brackets) with your API key and "<API-KEY-ID>" with your API key ID
dial_options = DialOptions.with_api_key("<API-KEY>", "<API-KEY-ID>")
return await ViamClient.create_from_dial_options(dial_options)
async def main():
# Make a ViamClient
viam_client = await connect()
# Instantiate a ProvisioningClient to run provisioning client API methods on
provisioning_client = viam_client.provisioning_client
viam_client.close()
if __name__ == '__main__':
asyncio.run(main())
"""

def __init__(self, channel: Channel, metadata: Mapping[str, str]):
"""Create a `ProvisioningClient` that maintains a connection to app.
Args:
channel (grpclib.client.Channel): Connection to app.
metadata (Mapping[str, str]): Required authorization token to send requests to app.
"""
self._metadata = metadata
self._provisioning_client = ProvisioningServiceStub(channel)
self._channel = channel

_provisioning_client: ProvisioningServiceStub
_metadata: Mapping[str, str]
_channel: Channel

async def get_network_list(self) -> List[NetworkInfo]:
"""Returns list of networks that are visible to the Smart Machine."""
request = GetNetworkListRequest()
resp: GetNetworkListResponse = await self._provisioning_client.GetNetworkList(request, metadata=self._metadata)
return list(resp.networks)

async def get_smart_machine_status(self) -> GetSmartMachineStatusResponse:
"""Returns the status of the smart machine."""
request = GetSmartMachineStatusRequest()
return await self._provisioning_client.GetSmartMachineStatus(request, metadata=self._metadata)

async def set_network_credentials(self, network_type: str, ssid: str, psk: str) -> None:
"""Sets the network credentials of the Smart Machine.
Args:
network_type (str): The type of the network.
ssid (str): The SSID of the network.
psk (str): The network's passkey.
"""

request = SetNetworkCredentialsRequest(type=network_type, ssid=ssid, psk=psk)
await self._provisioning_client.SetNetworkCredentials(request, metadata=self._metadata)

async def set_smart_machine_credentials(self, cloud_config: Optional[CloudConfig] = None) -> None:
request = SetSmartMachineCredentialsRequest(cloud=cloud_config)
await self._provisioning_client.SetSmartMachineCredentials(request, metadata=self._metadata)
22 changes: 22 additions & 0 deletions src/viam/app/viam_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from viam.app.billing_client import BillingClient
from viam.app.data_client import DataClient
from viam.app.ml_training_client import MLTrainingClient
from viam.app.provisioning_client import ProvisioningClient
from viam.rpc.dial import DialOptions, _dial_app, _get_access_token

LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -149,6 +150,27 @@ async def main():

return BillingClient(self._channel, self._metadata)

@property
def provisioning_client(self) -> ProvisioningClient:
"""Instantiate and return a `ProvisioningClient` used to make `provisioning` method calls.
To use the `ProvisioningClient`, you must first instantiate a `ViamClient`.
::
async def connect() -> ViamClient:
# Replace "<API-KEY>" (including brackets) with your API key and "<API-KEY-ID>" with your API key ID
dial_options = DialOptions.with_api_key("<API-KEY>", "<API-KEY-ID>")
return await ViamClient.create_from_dial_options(dial_options)
async def main():
viam_client = await connect()
# Instantiate a ProvisioningClient to run provisioning API methods on
provisioning_client = viam_client.provisioning_client
"""
return ProvisioningClient(self._channel, self._metadata)

def close(self):
"""Close opened channels used for the various service stubs initialized."""
if self._closed:
Expand Down
49 changes: 49 additions & 0 deletions tests/mocks/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,18 @@
FlatTensorDataUInt64,
FlatTensors,
)
from viam.proto.provisioning import (
NetworkInfo,
ProvisioningServiceBase,
GetNetworkListRequest,
GetNetworkListResponse,
GetSmartMachineStatusRequest,
GetSmartMachineStatusResponse,
SetNetworkCredentialsRequest,
SetNetworkCredentialsResponse,
SetSmartMachineCredentialsRequest,
SetSmartMachineCredentialsResponse,
)
from viam.proto.service.motion import (
Constraints,
GetPlanRequest,
Expand Down Expand Up @@ -698,6 +710,43 @@ async def do_command(self, command: Mapping[str, ValueTypes], *, timeout: Option
return {"command": command}


class MockProvisioning(ProvisioningServiceBase):
def __init__(
self,
smart_machine_status: GetSmartMachineStatusResponse,
network_info: List[NetworkInfo],
):
self.smart_machine_status = smart_machine_status
self.network_info = network_info

async def GetNetworkList(self, stream: Stream[GetNetworkListRequest, GetNetworkListResponse]) -> None:
request = await stream.recv_message()
assert request is not None
await stream.send_message(GetNetworkListResponse(networks=self.network_info))

async def GetSmartMachineStatus(self, stream: Stream[GetSmartMachineStatusRequest, GetSmartMachineStatusResponse]) -> None:
request = await stream.recv_message()
assert request is not None
await stream.send_message(self.smart_machine_status)

async def SetNetworkCredentials(self, stream: Stream[SetNetworkCredentialsRequest, SetNetworkCredentialsResponse]) -> None:
request = await stream.recv_message()
assert request is not None
self.network_type = request.type
self.ssid = request.ssid
self.psk = request.psk
await stream.send_message(SetNetworkCredentialsResponse())

async def SetSmartMachineCredentials(
self,
stream: Stream[SetSmartMachineCredentialsRequest, SetSmartMachineCredentialsResponse],
) -> None:
request = await stream.recv_message()
assert request is not None
self.cloud_config = request.cloud
await stream.send_message(SetSmartMachineCredentialsResponse())


class MockData(DataServiceBase):
def __init__(
self,
Expand Down
80 changes: 80 additions & 0 deletions tests/test_provisioning_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import pytest

from grpclib.testing import ChannelFor

from viam.app.provisioning_client import ProvisioningClient

from viam.proto.provisioning import GetSmartMachineStatusResponse, NetworkInfo, ProvisioningInfo, CloudConfig

from .mocks.services import MockProvisioning

ID = "id"
MODEL = "model"
MANUFACTURER = "acme"
PROVISIONING_INFO = ProvisioningInfo(fragment_id=ID, model=MODEL, manufacturer=MANUFACTURER)
HAS_CREDENTIALS = True
IS_ONLINE = True
NETWORK_TYPE = "type"
SSID = "ssid"
ERROR = "error"
ERRORS = [ERROR]
PSK = "psk"
SECRET = "secret"
APP_ADDRESS = "address"
NETWORK_INFO_LATEST = NetworkInfo(
type=NETWORK_TYPE,
ssid=SSID,
security="security",
signal=12,
connected=IS_ONLINE,
last_error=ERROR,
)
NETWORK_INFO = [NETWORK_INFO_LATEST]
SMART_MACHINE_STATUS_RESPONSE = GetSmartMachineStatusResponse(
provisioning_info=PROVISIONING_INFO,
has_smart_machine_credentials=HAS_CREDENTIALS,
is_online=IS_ONLINE,
latest_connection_attempt=NETWORK_INFO_LATEST,
errors=ERRORS
)
CLOUD_CONFIG = CloudConfig(id=ID, secret=SECRET, app_address=APP_ADDRESS)

AUTH_TOKEN = "auth_token"
PROVISIONING_SERVICE_METADATA = {"authorization": f"Bearer {AUTH_TOKEN}"}


@pytest.fixture(scope="function")
def service() -> MockProvisioning:
return MockProvisioning(smart_machine_status=SMART_MACHINE_STATUS_RESPONSE, network_info=NETWORK_INFO)


class TestClient:
@pytest.mark.asyncio
async def test_get_network_list(self, service: MockProvisioning):
async with ChannelFor([service]) as channel:
client = ProvisioningClient(channel, PROVISIONING_SERVICE_METADATA)
network_info = await client.get_network_list()
assert network_info == NETWORK_INFO

@pytest.mark.asyncio
async def test_get_smart_machine_status(self, service: MockProvisioning):
async with ChannelFor([service]) as channel:
client = ProvisioningClient(channel, PROVISIONING_SERVICE_METADATA)
smart_machine_status = await client.get_smart_machine_status()
assert smart_machine_status == SMART_MACHINE_STATUS_RESPONSE

@pytest.mark.asyncio
async def test_set_network_credentials(self, service: MockProvisioning):
async with ChannelFor([service]) as channel:
client = ProvisioningClient(channel, PROVISIONING_SERVICE_METADATA)
await client.set_network_credentials(network_type=NETWORK_TYPE, ssid=SSID, psk=PSK)
assert service.network_type == NETWORK_TYPE
assert service.ssid == SSID
assert service.psk == PSK

@pytest.mark.asyncio
async def test_set_smart_machine_credentials(self, service: MockProvisioning):
async with ChannelFor([service]) as channel:
client = ProvisioningClient(channel, PROVISIONING_SERVICE_METADATA)
await client.set_smart_machine_credentials(cloud_config=CLOUD_CONFIG)
assert service.cloud_config == CLOUD_CONFIG

0 comments on commit 1e31db8

Please sign in to comment.