forked from viamrobotics/viam-python-sdk
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
RSDK-7192 - Provisioning wrappers (viamrobotics#577)
- Loading branch information
Showing
4 changed files
with
246 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
from typing import Mapping, List, Optional | ||
|
||
from grpclib.client import Channel | ||
|
||
from viam import logging | ||
from viam.proto.provisioning import ( | ||
CloudConfig, | ||
GetNetworkListRequest, | ||
GetNetworkListResponse, | ||
GetSmartMachineStatusRequest, | ||
GetSmartMachineStatusResponse, | ||
NetworkInfo, | ||
ProvisioningServiceStub, | ||
SetNetworkCredentialsRequest, | ||
SetSmartMachineCredentialsRequest, | ||
) | ||
|
||
LOGGER = logging.getLogger(__name__) | ||
|
||
|
||
class ProvisioningClient: | ||
"""gRPC client for getting and setting smart machine info. | ||
Constructor is used by `ViamClient` to instantiate relevant service stubs. Calls to | ||
`ProvisioningClient` methods should be made through `ViamClient`. | ||
Establish a connection:: | ||
import asyncio | ||
from viam.rpc.dial import DialOptions, Credentials | ||
from viam.app.viam_client import ViamClient | ||
async def connect() -> ViamClient: | ||
# Replace "<API-KEY>" (including brackets) with your API key and "<API-KEY-ID>" with your API key ID | ||
dial_options = DialOptions.with_api_key("<API-KEY>", "<API-KEY-ID>") | ||
return await ViamClient.create_from_dial_options(dial_options) | ||
async def main(): | ||
# Make a ViamClient | ||
viam_client = await connect() | ||
# Instantiate a ProvisioningClient to run provisioning client API methods on | ||
provisioning_client = viam_client.provisioning_client | ||
viam_client.close() | ||
if __name__ == '__main__': | ||
asyncio.run(main()) | ||
""" | ||
|
||
def __init__(self, channel: Channel, metadata: Mapping[str, str]): | ||
"""Create a `ProvisioningClient` that maintains a connection to app. | ||
Args: | ||
channel (grpclib.client.Channel): Connection to app. | ||
metadata (Mapping[str, str]): Required authorization token to send requests to app. | ||
""" | ||
self._metadata = metadata | ||
self._provisioning_client = ProvisioningServiceStub(channel) | ||
self._channel = channel | ||
|
||
_provisioning_client: ProvisioningServiceStub | ||
_metadata: Mapping[str, str] | ||
_channel: Channel | ||
|
||
async def get_network_list(self) -> List[NetworkInfo]: | ||
"""Returns list of networks that are visible to the Smart Machine.""" | ||
request = GetNetworkListRequest() | ||
resp: GetNetworkListResponse = await self._provisioning_client.GetNetworkList(request, metadata=self._metadata) | ||
return list(resp.networks) | ||
|
||
async def get_smart_machine_status(self) -> GetSmartMachineStatusResponse: | ||
"""Returns the status of the smart machine.""" | ||
request = GetSmartMachineStatusRequest() | ||
return await self._provisioning_client.GetSmartMachineStatus(request, metadata=self._metadata) | ||
|
||
async def set_network_credentials(self, network_type: str, ssid: str, psk: str) -> None: | ||
"""Sets the network credentials of the Smart Machine. | ||
Args: | ||
network_type (str): The type of the network. | ||
ssid (str): The SSID of the network. | ||
psk (str): The network's passkey. | ||
""" | ||
|
||
request = SetNetworkCredentialsRequest(type=network_type, ssid=ssid, psk=psk) | ||
await self._provisioning_client.SetNetworkCredentials(request, metadata=self._metadata) | ||
|
||
async def set_smart_machine_credentials(self, cloud_config: Optional[CloudConfig] = None) -> None: | ||
request = SetSmartMachineCredentialsRequest(cloud=cloud_config) | ||
await self._provisioning_client.SetSmartMachineCredentials(request, metadata=self._metadata) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
import pytest | ||
|
||
from grpclib.testing import ChannelFor | ||
|
||
from viam.app.provisioning_client import ProvisioningClient | ||
|
||
from viam.proto.provisioning import GetSmartMachineStatusResponse, NetworkInfo, ProvisioningInfo, CloudConfig | ||
|
||
from .mocks.services import MockProvisioning | ||
|
||
ID = "id" | ||
MODEL = "model" | ||
MANUFACTURER = "acme" | ||
PROVISIONING_INFO = ProvisioningInfo(fragment_id=ID, model=MODEL, manufacturer=MANUFACTURER) | ||
HAS_CREDENTIALS = True | ||
IS_ONLINE = True | ||
NETWORK_TYPE = "type" | ||
SSID = "ssid" | ||
ERROR = "error" | ||
ERRORS = [ERROR] | ||
PSK = "psk" | ||
SECRET = "secret" | ||
APP_ADDRESS = "address" | ||
NETWORK_INFO_LATEST = NetworkInfo( | ||
type=NETWORK_TYPE, | ||
ssid=SSID, | ||
security="security", | ||
signal=12, | ||
connected=IS_ONLINE, | ||
last_error=ERROR, | ||
) | ||
NETWORK_INFO = [NETWORK_INFO_LATEST] | ||
SMART_MACHINE_STATUS_RESPONSE = GetSmartMachineStatusResponse( | ||
provisioning_info=PROVISIONING_INFO, | ||
has_smart_machine_credentials=HAS_CREDENTIALS, | ||
is_online=IS_ONLINE, | ||
latest_connection_attempt=NETWORK_INFO_LATEST, | ||
errors=ERRORS | ||
) | ||
CLOUD_CONFIG = CloudConfig(id=ID, secret=SECRET, app_address=APP_ADDRESS) | ||
|
||
AUTH_TOKEN = "auth_token" | ||
PROVISIONING_SERVICE_METADATA = {"authorization": f"Bearer {AUTH_TOKEN}"} | ||
|
||
|
||
@pytest.fixture(scope="function") | ||
def service() -> MockProvisioning: | ||
return MockProvisioning(smart_machine_status=SMART_MACHINE_STATUS_RESPONSE, network_info=NETWORK_INFO) | ||
|
||
|
||
class TestClient: | ||
@pytest.mark.asyncio | ||
async def test_get_network_list(self, service: MockProvisioning): | ||
async with ChannelFor([service]) as channel: | ||
client = ProvisioningClient(channel, PROVISIONING_SERVICE_METADATA) | ||
network_info = await client.get_network_list() | ||
assert network_info == NETWORK_INFO | ||
|
||
@pytest.mark.asyncio | ||
async def test_get_smart_machine_status(self, service: MockProvisioning): | ||
async with ChannelFor([service]) as channel: | ||
client = ProvisioningClient(channel, PROVISIONING_SERVICE_METADATA) | ||
smart_machine_status = await client.get_smart_machine_status() | ||
assert smart_machine_status == SMART_MACHINE_STATUS_RESPONSE | ||
|
||
@pytest.mark.asyncio | ||
async def test_set_network_credentials(self, service: MockProvisioning): | ||
async with ChannelFor([service]) as channel: | ||
client = ProvisioningClient(channel, PROVISIONING_SERVICE_METADATA) | ||
await client.set_network_credentials(network_type=NETWORK_TYPE, ssid=SSID, psk=PSK) | ||
assert service.network_type == NETWORK_TYPE | ||
assert service.ssid == SSID | ||
assert service.psk == PSK | ||
|
||
@pytest.mark.asyncio | ||
async def test_set_smart_machine_credentials(self, service: MockProvisioning): | ||
async with ChannelFor([service]) as channel: | ||
client = ProvisioningClient(channel, PROVISIONING_SERVICE_METADATA) | ||
await client.set_smart_machine_credentials(cloud_config=CLOUD_CONFIG) | ||
assert service.cloud_config == CLOUD_CONFIG |