diff --git a/src/viam/app/provisioning_client.py b/src/viam/app/provisioning_client.py new file mode 100644 index 000000000..2606f3664 --- /dev/null +++ b/src/viam/app/provisioning_client.py @@ -0,0 +1,95 @@ +from typing import Mapping, List, Optional + +from grpclib.client import Channel + +from viam import logging +from viam.proto.provisioning import ( + CloudConfig, + GetNetworkListRequest, + GetNetworkListResponse, + GetSmartMachineStatusRequest, + GetSmartMachineStatusResponse, + NetworkInfo, + ProvisioningServiceStub, + SetNetworkCredentialsRequest, + SetSmartMachineCredentialsRequest, +) + +LOGGER = logging.getLogger(__name__) + + +class ProvisioningClient: + """gRPC client for getting and setting smart machine info. + + Constructor is used by `ViamClient` to instantiate relevant service stubs. Calls to + `ProvisioningClient` methods should be made through `ViamClient`. + + Establish a connection:: + + import asyncio + + from viam.rpc.dial import DialOptions, Credentials + from viam.app.viam_client import ViamClient + + + async def connect() -> ViamClient: + # Replace "" (including brackets) with your API key and "" with your API key ID + dial_options = DialOptions.with_api_key("", "") + return await ViamClient.create_from_dial_options(dial_options) + + + async def main(): + + # Make a ViamClient + viam_client = await connect() + # Instantiate a ProvisioningClient to run provisioning client API methods on + provisioning_client = viam_client.provisioning_client + + viam_client.close() + + if __name__ == '__main__': + asyncio.run(main()) + + """ + + def __init__(self, channel: Channel, metadata: Mapping[str, str]): + """Create a `ProvisioningClient` that maintains a connection to app. + + Args: + channel (grpclib.client.Channel): Connection to app. + metadata (Mapping[str, str]): Required authorization token to send requests to app. + """ + self._metadata = metadata + self._provisioning_client = ProvisioningServiceStub(channel) + self._channel = channel + + _provisioning_client: ProvisioningServiceStub + _metadata: Mapping[str, str] + _channel: Channel + + async def get_network_list(self) -> List[NetworkInfo]: + """Returns list of networks that are visible to the Smart Machine.""" + request = GetNetworkListRequest() + resp: GetNetworkListResponse = await self._provisioning_client.GetNetworkList(request, metadata=self._metadata) + return list(resp.networks) + + async def get_smart_machine_status(self) -> GetSmartMachineStatusResponse: + """Returns the status of the smart machine.""" + request = GetSmartMachineStatusRequest() + return await self._provisioning_client.GetSmartMachineStatus(request, metadata=self._metadata) + + async def set_network_credentials(self, network_type: str, ssid: str, psk: str) -> None: + """Sets the network credentials of the Smart Machine. + + Args: + network_type (str): The type of the network. + ssid (str): The SSID of the network. + psk (str): The network's passkey. + """ + + request = SetNetworkCredentialsRequest(type=network_type, ssid=ssid, psk=psk) + await self._provisioning_client.SetNetworkCredentials(request, metadata=self._metadata) + + async def set_smart_machine_credentials(self, cloud_config: Optional[CloudConfig] = None) -> None: + request = SetSmartMachineCredentialsRequest(cloud=cloud_config) + await self._provisioning_client.SetSmartMachineCredentials(request, metadata=self._metadata) diff --git a/src/viam/app/viam_client.py b/src/viam/app/viam_client.py index ead858998..9e964c498 100644 --- a/src/viam/app/viam_client.py +++ b/src/viam/app/viam_client.py @@ -8,6 +8,7 @@ from viam.app.billing_client import BillingClient from viam.app.data_client import DataClient from viam.app.ml_training_client import MLTrainingClient +from viam.app.provisioning_client import ProvisioningClient from viam.rpc.dial import DialOptions, _dial_app, _get_access_token LOGGER = logging.getLogger(__name__) @@ -149,6 +150,27 @@ async def main(): return BillingClient(self._channel, self._metadata) + @property + def provisioning_client(self) -> ProvisioningClient: + """Instantiate and return a `ProvisioningClient` used to make `provisioning` method calls. + To use the `ProvisioningClient`, you must first instantiate a `ViamClient`. + + :: + + async def connect() -> ViamClient: + # Replace "" (including brackets) with your API key and "" with your API key ID + dial_options = DialOptions.with_api_key("", "") + return await ViamClient.create_from_dial_options(dial_options) + + + async def main(): + viam_client = await connect() + + # Instantiate a ProvisioningClient to run provisioning API methods on + provisioning_client = viam_client.provisioning_client + """ + return ProvisioningClient(self._channel, self._metadata) + def close(self): """Close opened channels used for the various service stubs initialized.""" if self._closed: diff --git a/tests/mocks/services.py b/tests/mocks/services.py index 895b61d7e..03f8150c7 100644 --- a/tests/mocks/services.py +++ b/tests/mocks/services.py @@ -279,6 +279,18 @@ FlatTensorDataUInt64, FlatTensors, ) +from viam.proto.provisioning import ( + NetworkInfo, + ProvisioningServiceBase, + GetNetworkListRequest, + GetNetworkListResponse, + GetSmartMachineStatusRequest, + GetSmartMachineStatusResponse, + SetNetworkCredentialsRequest, + SetNetworkCredentialsResponse, + SetSmartMachineCredentialsRequest, + SetSmartMachineCredentialsResponse, +) from viam.proto.service.motion import ( Constraints, GetPlanRequest, @@ -698,6 +710,43 @@ async def do_command(self, command: Mapping[str, ValueTypes], *, timeout: Option return {"command": command} +class MockProvisioning(ProvisioningServiceBase): + def __init__( + self, + smart_machine_status: GetSmartMachineStatusResponse, + network_info: List[NetworkInfo], + ): + self.smart_machine_status = smart_machine_status + self.network_info = network_info + + async def GetNetworkList(self, stream: Stream[GetNetworkListRequest, GetNetworkListResponse]) -> None: + request = await stream.recv_message() + assert request is not None + await stream.send_message(GetNetworkListResponse(networks=self.network_info)) + + async def GetSmartMachineStatus(self, stream: Stream[GetSmartMachineStatusRequest, GetSmartMachineStatusResponse]) -> None: + request = await stream.recv_message() + assert request is not None + await stream.send_message(self.smart_machine_status) + + async def SetNetworkCredentials(self, stream: Stream[SetNetworkCredentialsRequest, SetNetworkCredentialsResponse]) -> None: + request = await stream.recv_message() + assert request is not None + self.network_type = request.type + self.ssid = request.ssid + self.psk = request.psk + await stream.send_message(SetNetworkCredentialsResponse()) + + async def SetSmartMachineCredentials( + self, + stream: Stream[SetSmartMachineCredentialsRequest, SetSmartMachineCredentialsResponse], + ) -> None: + request = await stream.recv_message() + assert request is not None + self.cloud_config = request.cloud + await stream.send_message(SetSmartMachineCredentialsResponse()) + + class MockData(DataServiceBase): def __init__( self, diff --git a/tests/test_provisioning_client.py b/tests/test_provisioning_client.py new file mode 100644 index 000000000..56b235b0f --- /dev/null +++ b/tests/test_provisioning_client.py @@ -0,0 +1,80 @@ +import pytest + +from grpclib.testing import ChannelFor + +from viam.app.provisioning_client import ProvisioningClient + +from viam.proto.provisioning import GetSmartMachineStatusResponse, NetworkInfo, ProvisioningInfo, CloudConfig + +from .mocks.services import MockProvisioning + +ID = "id" +MODEL = "model" +MANUFACTURER = "acme" +PROVISIONING_INFO = ProvisioningInfo(fragment_id=ID, model=MODEL, manufacturer=MANUFACTURER) +HAS_CREDENTIALS = True +IS_ONLINE = True +NETWORK_TYPE = "type" +SSID = "ssid" +ERROR = "error" +ERRORS = [ERROR] +PSK = "psk" +SECRET = "secret" +APP_ADDRESS = "address" +NETWORK_INFO_LATEST = NetworkInfo( + type=NETWORK_TYPE, + ssid=SSID, + security="security", + signal=12, + connected=IS_ONLINE, + last_error=ERROR, +) +NETWORK_INFO = [NETWORK_INFO_LATEST] +SMART_MACHINE_STATUS_RESPONSE = GetSmartMachineStatusResponse( + provisioning_info=PROVISIONING_INFO, + has_smart_machine_credentials=HAS_CREDENTIALS, + is_online=IS_ONLINE, + latest_connection_attempt=NETWORK_INFO_LATEST, + errors=ERRORS +) +CLOUD_CONFIG = CloudConfig(id=ID, secret=SECRET, app_address=APP_ADDRESS) + +AUTH_TOKEN = "auth_token" +PROVISIONING_SERVICE_METADATA = {"authorization": f"Bearer {AUTH_TOKEN}"} + + +@pytest.fixture(scope="function") +def service() -> MockProvisioning: + return MockProvisioning(smart_machine_status=SMART_MACHINE_STATUS_RESPONSE, network_info=NETWORK_INFO) + + +class TestClient: + @pytest.mark.asyncio + async def test_get_network_list(self, service: MockProvisioning): + async with ChannelFor([service]) as channel: + client = ProvisioningClient(channel, PROVISIONING_SERVICE_METADATA) + network_info = await client.get_network_list() + assert network_info == NETWORK_INFO + + @pytest.mark.asyncio + async def test_get_smart_machine_status(self, service: MockProvisioning): + async with ChannelFor([service]) as channel: + client = ProvisioningClient(channel, PROVISIONING_SERVICE_METADATA) + smart_machine_status = await client.get_smart_machine_status() + assert smart_machine_status == SMART_MACHINE_STATUS_RESPONSE + + @pytest.mark.asyncio + async def test_set_network_credentials(self, service: MockProvisioning): + async with ChannelFor([service]) as channel: + client = ProvisioningClient(channel, PROVISIONING_SERVICE_METADATA) + await client.set_network_credentials(network_type=NETWORK_TYPE, ssid=SSID, psk=PSK) + assert service.network_type == NETWORK_TYPE + assert service.ssid == SSID + assert service.psk == PSK + + @pytest.mark.asyncio + async def test_set_smart_machine_credentials(self, service: MockProvisioning): + async with ChannelFor([service]) as channel: + client = ProvisioningClient(channel, PROVISIONING_SERVICE_METADATA) + await client.set_smart_machine_credentials(cloud_config=CLOUD_CONFIG) + assert service.cloud_config == CLOUD_CONFIG