diff --git a/Makefile b/Makefile index 857a9c998..1dac2ace7 100644 --- a/Makefile +++ b/Makefile @@ -55,7 +55,7 @@ update-buf: $(node_modules) .PHONY: build-buf build-buf: $(node_modules) clean-buf $(buf) generate $$(./scripts/get-buf-lock-version.js buf.build/googleapis/googleapis) - $(buf) generate $$(./scripts/get-buf-lock-version.js buf.build/viamrobotics/api) --path common,component,robot,service,app,provisioning + $(buf) generate $$(./scripts/get-buf-lock-version.js buf.build/viamrobotics/api) --path common,component,robot,service,app,provisioning,tagger $(buf) generate $$(./scripts/get-buf-lock-version.js buf.build/erdaniels/gostream) $(buf) generate $$(./scripts/get-buf-lock-version.js buf.build/viamrobotics/goutils) diff --git a/src/app/ml-training-client.test.ts b/src/app/ml-training-client.test.ts new file mode 100644 index 000000000..11e66cd3b --- /dev/null +++ b/src/app/ml-training-client.test.ts @@ -0,0 +1,146 @@ +import { FakeTransportBuilder } from '@improbable-eng/grpc-web-fake-transport'; +import { beforeEach, describe, expect, it, vi } from 'vitest'; +import { + CancelTrainingJobRequest, + DeleteCompletedTrainingJobRequest, + GetTrainingJobRequest, + GetTrainingJobResponse, + ListTrainingJobsRequest, + ListTrainingJobsResponse, + ModelType, + SubmitTrainingJobRequest, + SubmitTrainingJobResponse, + TrainingJobMetadata, + TrainingStatus, +} from '../gen/app/mltraining/v1/ml_training_pb'; +import { MLTrainingServiceClient } from '../gen/app/mltraining/v1/ml_training_pb_service'; +vi.mock('../gen/app/mltraining/v1/ml_training_pb_service'); +import { MlTrainingClient } from './ml-training-client'; + +const subject = () => + new MlTrainingClient('fakeServiceHoost', { + transport: new FakeTransportBuilder().build(), + }); + +describe('MlTrainingClient tests', () => { + describe('submitTrainingJob tests', () => { + const type = ModelType.MODEL_TYPE_UNSPECIFIED; + beforeEach(() => { + vi.spyOn(MLTrainingServiceClient.prototype, 'submitTrainingJob') + // @ts-expect-error compiler is matching incorrect function signature + .mockImplementationOnce((_req: SubmitTrainingJobRequest, _md, cb) => { + const response = new SubmitTrainingJobResponse(); + response.setId('fakeId'); + cb(null, response); + }); + }); + + it('submit job training job', async () => { + const response = await subject().submitTrainingJob( + 'org_id', + 'dataset_id', + 'model_name', + 'model_version', + type, + ['tag1'] + ); + expect(response).toEqual('fakeId'); + }); + }); + + describe('getTrainingJob tests', () => { + const metadata: TrainingJobMetadata = new TrainingJobMetadata(); + metadata.setId('id'); + metadata.setDatasetId('dataset_id'); + metadata.setOrganizationId('org_id'); + metadata.setModelVersion('model_version'); + metadata.setModelType(ModelType.MODEL_TYPE_UNSPECIFIED); + metadata.setStatus(TrainingStatus.TRAINING_STATUS_UNSPECIFIED); + metadata.setSyncedModelId('synced_model_id'); + + beforeEach(() => { + vi.spyOn(MLTrainingServiceClient.prototype, 'getTrainingJob') + // @ts-expect-error compiler is matching incorrect function signature + .mockImplementationOnce((_req: GetTrainingJobRequest, _md, cb) => { + const response = new GetTrainingJobResponse(); + response.setMetadata(metadata); + cb(null, response); + }); + }); + + it('get training job', async () => { + const response = await subject().getTrainingJob('id'); + expect(response).toEqual(metadata); + }); + }); + + describe('listTrainingJobs', () => { + const status = TrainingStatus.TRAINING_STATUS_UNSPECIFIED; + const md1 = new TrainingJobMetadata(); + md1.setId('id1'); + md1.setDatasetId('dataset_id1'); + md1.setOrganizationId('org_id1'); + md1.setModelVersion('model_version1'); + md1.setModelType(ModelType.MODEL_TYPE_UNSPECIFIED); + md1.setStatus(TrainingStatus.TRAINING_STATUS_UNSPECIFIED); + md1.setSyncedModelId('synced_model_id1'); + const md2 = new TrainingJobMetadata(); + md1.setId('id2'); + md1.setDatasetId('dataset_id2'); + md1.setOrganizationId('org_id2'); + md1.setModelVersion('model_version2'); + md1.setModelType(ModelType.MODEL_TYPE_UNSPECIFIED); + md1.setStatus(TrainingStatus.TRAINING_STATUS_UNSPECIFIED); + md1.setSyncedModelId('synced_model_id2'); + const jobs = [md1, md2]; + + beforeEach(() => { + vi.spyOn(MLTrainingServiceClient.prototype, 'listTrainingJobs') + // @ts-expect-error compiler is matching incorrect function signature + .mockImplementationOnce((_req: ListTrainingJobsRequest, _md, cb) => { + const response = new ListTrainingJobsResponse(); + response.setJobsList(jobs); + cb(null, response); + }); + }); + + it('list training jobs', async () => { + const response = await subject().listTrainingJobs('org_id', status); + expect(response).toEqual([md1.toObject(), md2.toObject()]); + }); + }); + + describe('cancelTrainingJob tests', () => { + const id = 'id'; + beforeEach(() => { + vi.spyOn(MLTrainingServiceClient.prototype, 'cancelTrainingJob') + // @ts-expect-error compiler is matching incorrect function signature + .mockImplementationOnce((req: CancelTrainingJobRequest, _md, cb) => { + expect(req.getId()).toStrictEqual(id); + cb(null, {}); + }); + }); + it('cancel training job', async () => { + expect(await subject().cancelTrainingJob(id)).toStrictEqual(null); + }); + }); + + describe('deleteCompletedTrainingJob tests', () => { + const id = 'id'; + beforeEach(() => { + vi.spyOn( + MLTrainingServiceClient.prototype, + 'deleteCompletedTrainingJob' + ).mockImplementationOnce( + // @ts-expect-error compiler is matching incorrect function signature + (req: DeleteCompletedTrainingJobRequest, _md, cb) => { + expect(req.getId()).toStrictEqual(id); + cb(null, {}); + } + ); + }); + it('delete completed training job', async () => { + expect(await subject().deleteCompletedTrainingJob(id)).toEqual(null); + }); + }); +}); diff --git a/src/app/ml-training-client.ts b/src/app/ml-training-client.ts new file mode 100644 index 000000000..38d02cab3 --- /dev/null +++ b/src/app/ml-training-client.ts @@ -0,0 +1,96 @@ +import { type RpcOptions } from '@improbable-eng/grpc-web/dist/typings/client.d'; +import { MLTrainingServiceClient } from '../gen/app/mltraining/v1/ml_training_pb_service'; +import pb from '../gen/app/mltraining/v1/ml_training_pb'; +import { promisify } from '../utils'; + +type ValueOf = T[keyof T]; +export const { ModelType } = pb; +export type ModelType = ValueOf; +export const { TrainingStatus } = pb; +export type TrainingStatus = ValueOf; + +export class MlTrainingClient { + private service: MLTrainingServiceClient; + + constructor(serviceHost: string, grpcOptions: RpcOptions) { + this.service = new MLTrainingServiceClient(serviceHost, grpcOptions); + } + + async submitTrainingJob( + orgId: string, + datasetId: string, + modelName: string, + modelVersion: string, + modelType: ModelType, + tagsList: string[] + ) { + const { service } = this; + + const req = new pb.SubmitTrainingJobRequest(); + req.setOrganizationId(orgId); + req.setDatasetId(datasetId); + req.setModelName(modelName); + req.setModelVersion(modelVersion); + req.setModelType(modelType); + req.setTagsList(tagsList); + + const response = await promisify< + pb.SubmitTrainingJobRequest, + pb.SubmitTrainingJobResponse + >(service.submitTrainingJob.bind(service), req); + return response.getId(); + } + + async getTrainingJob(id: string) { + const { service } = this; + + const req = new pb.GetTrainingJobRequest(); + req.setId(id); + + const response = await promisify< + pb.GetTrainingJobRequest, + pb.GetTrainingJobResponse + >(service.getTrainingJob.bind(service), req); + return response.getMetadata(); + } + + async listTrainingJobs(orgId: string, status: TrainingStatus) { + const { service } = this; + + const req = new pb.ListTrainingJobsRequest(); + req.setOrganizationId(orgId); + req.setStatus(status); + + const response = await promisify< + pb.ListTrainingJobsRequest, + pb.ListTrainingJobsResponse + >(service.listTrainingJobs.bind(service), req); + return response.toObject().jobsList; + } + + async cancelTrainingJob(id: string) { + const { service } = this; + + const req = new pb.CancelTrainingJobRequest(); + req.setId(id); + + await promisify( + service.cancelTrainingJob.bind(service), + req + ); + return null; + } + + async deleteCompletedTrainingJob(id: string) { + const { service } = this; + + const req = new pb.DeleteCompletedTrainingJobRequest(); + req.setId(id); + + await promisify< + pb.DeleteCompletedTrainingJobRequest, + pb.DeleteCompletedTrainingJobResponse + >(service.deleteCompletedTrainingJob.bind(service), req); + return null; + } +} diff --git a/src/app/provisioning-client.test.ts b/src/app/provisioning-client.test.ts new file mode 100644 index 000000000..ac42dddbc --- /dev/null +++ b/src/app/provisioning-client.test.ts @@ -0,0 +1,102 @@ +// @vitest-environment happy-dom + +import { FakeTransportBuilder } from '@improbable-eng/grpc-web-fake-transport'; +import { beforeEach, expect, it, vi } from 'vitest'; +import { + CloudConfig, + SetNetworkCredentialsRequest, + SetSmartMachineCredentialsRequest, +} from '../gen/provisioning/v1/provisioning_pb'; +import { ProvisioningServiceClient } from '../gen/provisioning/v1/provisioning_pb_service'; +import { ProvisioningClient } from './provisioning-client'; + +const subject = () => + new ProvisioningClient('fakeServiceHost', { + transport: new FakeTransportBuilder().build(), + }); + +const testProvisioningInfo = { + fragmentId: 'id', + model: 'model', + manufacturer: 'manufacturer', +}; +const testNetworkInfo = { + type: 'type', + ssid: 'ssid', + security: 'security', + signal: 999, + connected: 'true', + lastError: 'last error', +}; +const testSmartMachineStatus = { + provisioningInfo: testProvisioningInfo, + hasSmartMachineCredentials: true, + isOnline: true, + latestConnectionAttempt: testNetworkInfo, + errorsList: ['error', 'err'], +}; +const type = 'type'; +const ssid = 'ssid'; +const psk = 'psk'; +const cloud = new CloudConfig(); +cloud.setId('id'); +cloud.setSecret('secret'); +cloud.setAppAddress('app_address'); + +beforeEach(() => { + ProvisioningServiceClient.prototype.getSmartMachineStatus = vi + .fn() + .mockImplementation((_req, _md, cb) => { + cb(null, { + toObject: () => testSmartMachineStatus, + }); + }); + + ProvisioningServiceClient.prototype.getNetworkList = vi + .fn() + .mockImplementation((_req, _md, cb) => { + cb(null, { + toObject: () => ({ networksList: [testNetworkInfo] }), + }); + }); + + ProvisioningServiceClient.prototype.setNetworkCredentials = vi + .fn() + .mockImplementation((req: SetNetworkCredentialsRequest, _md, cb) => { + expect(req.getType()).toStrictEqual(type); + expect(req.getSsid()).toStrictEqual(ssid); + expect(req.getPsk()).toStrictEqual(psk); + cb(null, {}); + }); + + ProvisioningServiceClient.prototype.setSmartMachineCredentials = vi + .fn() + .mockImplementation((req: SetSmartMachineCredentialsRequest, _md, cb) => { + expect(req.getCloud()).toStrictEqual(cloud); + cb(null, {}); + }); +}); + +it('getSmartMachineStatus', async () => { + await expect(subject().getSmartMachineStatus()).resolves.toStrictEqual( + testSmartMachineStatus + ); +}); + +it('getNetworkList', async () => { + await expect(subject().getNetworkList()).resolves.toStrictEqual([ + testNetworkInfo, + ]); +}); + +it('setNetworkCredentials', async () => { + await expect( + subject().setNetworkCredentials(type, ssid, psk) + ).resolves.toStrictEqual(undefined); +}); + +it('setSmartMachineCredentials', async () => { + await expect( + subject().setSmartMachineCredentials(cloud.toObject()) + ).resolves.toStrictEqual(undefined); +}); diff --git a/src/provisioning/client.ts b/src/app/provisioning-client.ts similarity index 58% rename from src/provisioning/client.ts rename to src/app/provisioning-client.ts index fae683c37..e693558a4 100644 --- a/src/provisioning/client.ts +++ b/src/app/provisioning-client.ts @@ -1,28 +1,33 @@ -import { type Options } from '../types'; +import { type RpcOptions } from '@improbable-eng/grpc-web/dist/typings/client.d'; import pb from '../gen/provisioning/v1/provisioning_pb'; import { ProvisioningServiceClient } from '../gen/provisioning/v1/provisioning_pb_service'; -import { RobotClient } from '../robot'; -import type { Provisioning } from './provisioning'; import { promisify } from '../utils'; -import { type CloudConfig, encodeCloudConfig } from './types'; -export class ProvisioningClient implements Provisioning { - private client: ProvisioningServiceClient; - private readonly options: Options; +export type CloudConfig = pb.CloudConfig.AsObject; - constructor(client: RobotClient, options: Options = {}) { - this.client = client.createServiceClient(ProvisioningServiceClient); - this.options = options; - } +const encodeCloudConfig = (obj: CloudConfig): pb.CloudConfig => { + const result = new pb.CloudConfig(); + result.setId(obj.id); + result.setSecret(obj.secret); + result.setAppAddress(obj.appAddress); + return result; +}; + +export class ProvisioningClient { + private service: ProvisioningServiceClient; - private get service() { - return this.client; + constructor(serviceHost: string, grpcOptions: RpcOptions = {}) { + this.service = new ProvisioningServiceClient(serviceHost, grpcOptions); } + /** + * Get the status of the Smart Machine. + * + * @returns The Smart Machine status + */ async getSmartMachineStatus() { const { service } = this; const request = new pb.GetSmartMachineStatusRequest(); - this.options.requestLogger?.(request); const response = await promisify< pb.GetSmartMachineStatusRequest, @@ -31,6 +36,14 @@ export class ProvisioningClient implements Provisioning { return response.toObject(); } + /** + * Set the network credentials of the Smart Machine, so it can connect to the + * internet. + * + * @param type - The type of network + * @param ssid - The SSID of the network + * @param psk - The network's passkey + */ async setNetworkCredentials(type: string, ssid: string, psk: string) { const { service } = this; const request = new pb.SetNetworkCredentialsRequest(); @@ -38,14 +51,18 @@ export class ProvisioningClient implements Provisioning { request.setSsid(ssid); request.setPsk(psk); - this.options.requestLogger?.(request); - await promisify< pb.SetNetworkCredentialsRequest, pb.SetNetworkCredentialsResponse >(service.setNetworkCredentials.bind(service), request); } + /** + * Set the Viam credentials of the smart machine credentials, so it connect to + * the Cloud. + * + * @param cloud - The configuration of the Cloud + */ async setSmartMachineCredentials(cloud?: CloudConfig) { const { service } = this; const request = new pb.SetSmartMachineCredentialsRequest(); @@ -53,20 +70,21 @@ export class ProvisioningClient implements Provisioning { request.setCloud(encodeCloudConfig(cloud)); } - this.options.requestLogger?.(request); - await promisify< pb.SetSmartMachineCredentialsRequest, pb.SetSmartMachineCredentialsResponse >(service.setSmartMachineCredentials.bind(service), request); } + /** + * Get the networks that are visible to the Smart Machine. + * + * @returns A list of networks + */ async getNetworkList() { const { service } = this; const request = new pb.GetNetworkListRequest(); - this.options.requestLogger?.(request); - const response = await promisify< pb.GetNetworkListRequest, pb.GetNetworkListResponse diff --git a/src/app/viam-client.test.ts b/src/app/viam-client.test.ts index 17a839361..dedd2b147 100644 --- a/src/app/viam-client.test.ts +++ b/src/app/viam-client.test.ts @@ -14,6 +14,8 @@ vi.mock('./viam-transport', () => { }); import { DataClient } from './data-client'; import { createViamClient, type ViamClientOptions } from './viam-client'; +import { MlTrainingClient } from './ml-training-client'; +import { ProvisioningClient } from './provisioning-client'; describe('ViamClient', () => { let options: ViamClientOptions | undefined; @@ -43,6 +45,8 @@ describe('ViamClient', () => { testCredential ); expect(client.dataClient).toBeInstanceOf(DataClient); + expect(client.mlTrainingClient).toBeInstanceOf(MlTrainingClient); + expect(client.provisioningClient).toBeInstanceOf(ProvisioningClient); }); it('create client with an api key credential and a custom service host', async () => { @@ -55,6 +59,8 @@ describe('ViamClient', () => { testCredential ); expect(client.dataClient).toBeInstanceOf(DataClient); + expect(client.mlTrainingClient).toBeInstanceOf(MlTrainingClient); + expect(client.provisioningClient).toBeInstanceOf(ProvisioningClient); }); it('create client with an access token', async () => { @@ -66,5 +72,7 @@ describe('ViamClient', () => { testAccessToken ); expect(client.dataClient).toBeInstanceOf(DataClient); + expect(client.mlTrainingClient).toBeInstanceOf(MlTrainingClient); + expect(client.provisioningClient).toBeInstanceOf(ProvisioningClient); }); }); diff --git a/src/app/viam-client.ts b/src/app/viam-client.ts index dc7dcd850..61bc23440 100644 --- a/src/app/viam-client.ts +++ b/src/app/viam-client.ts @@ -5,6 +5,8 @@ import { type AccessToken, } from './viam-transport'; import { DataClient } from './data-client'; +import { MlTrainingClient } from './ml-training-client'; +import { ProvisioningClient } from './provisioning-client'; export interface ViamClientOptions { serviceHost?: string; @@ -31,6 +33,8 @@ export class ViamClient { private serviceHost: string; public dataClient: DataClient | undefined; + public mlTrainingClient: MlTrainingClient | undefined; + public provisioningClient: ProvisioningClient | undefined; constructor(transportFactory: grpc.TransportFactory, serviceHost: string) { this.transportFactory = transportFactory; @@ -40,5 +44,10 @@ export class ViamClient { public connect() { const grpcOptions = { transport: this.transportFactory }; this.dataClient = new DataClient(this.serviceHost, grpcOptions); + this.mlTrainingClient = new MlTrainingClient(this.serviceHost, grpcOptions); + this.provisioningClient = new ProvisioningClient( + this.serviceHost, + grpcOptions + ); } } diff --git a/src/main.ts b/src/main.ts index c47552f6c..80500acf9 100644 --- a/src/main.ts +++ b/src/main.ts @@ -40,7 +40,7 @@ export type { } from './app/viam-transport'; /** - * Raw Protobuf interfaces for a Data service. + * Raw Protobuf interfaces for Data. * * Generated with https://github.com/improbable-eng/grpc-web * @@ -55,6 +55,37 @@ export { type FilterOptions, } from './app/data-client'; +/** + * Raw Protobuf interfaces for ML Training. + * + * Generated with https://github.com/improbable-eng/grpc-web + * + * @deprecated Use {@link MlTrainingClient} instead. + * @alpha + * @group Raw Protobufs + */ +export { default as mlTrainingApi } from './gen/app/mltraining/v1/ml_training_pb'; +export { + type MlTrainingClient, + ModelType, + TrainingStatus, +} from './app/ml-training-client'; + +/** + * Raw Protobuf interfaces for Provisioning. + * + * Generated with https://github.com/improbable-eng/grpc-web + * + * @deprecated Use {@link ProvisioningClient} instead. + * @alpha + * @group Raw Protobufs + */ +export { default as provisioningApi } from './gen/provisioning/v1/provisioning_pb'; +export { + type CloudConfig, + type ProvisioningClient, +} from './app/provisioning-client'; + /** * Raw Protobuf interfaces for an Arm component. * diff --git a/src/provisioning/client.test.ts b/src/provisioning/client.test.ts deleted file mode 100644 index b8c8a7faf..000000000 --- a/src/provisioning/client.test.ts +++ /dev/null @@ -1,67 +0,0 @@ -// @vitest-environment happy-dom - -import { beforeEach, expect, it, vi } from 'vitest'; -import { ProvisioningServiceClient } from '../gen/provisioning/v1/provisioning_pb_service'; -import { RobotClient } from '../robot'; -import { ProvisioningClient } from './client'; - -let provisioning: ProvisioningClient; - -const testProvisioningInfo = { - fragmentId: 'id', - model: 'model', - manufacturer: 'manufacturer', -}; -const testNetworkInfo = { - type: 'type', - ssid: 'ssid', - security: 'security', - signal: 999, - connected: 'true', - lastError: 'last error', -}; -const testSmartMachineStatus = { - provisioningInfo: testProvisioningInfo, - hasSmartMachineCredentials: true, - isOnline: true, - latestConnectionAttempt: testNetworkInfo, - errorsList: ['error', 'err'], -}; - -beforeEach(() => { - RobotClient.prototype.createServiceClient = vi - .fn() - .mockImplementation( - () => new ProvisioningServiceClient('test-provisioning') - ); - - ProvisioningServiceClient.prototype.getSmartMachineStatus = vi - .fn() - .mockImplementation((_req, _md, cb) => { - cb(null, { - toObject: () => testSmartMachineStatus, - }); - }); - - ProvisioningServiceClient.prototype.getNetworkList = vi - .fn() - .mockImplementation((_req, _md, cb) => { - cb(null, { - toObject: () => ({ networksList: [testNetworkInfo] }), - }); - }); - - provisioning = new ProvisioningClient(new RobotClient('host')); -}); - -it('getSmartMachineStatus', async () => { - await expect(provisioning.getSmartMachineStatus()).resolves.toStrictEqual( - testSmartMachineStatus - ); -}); - -it('getNetworkList', async () => { - await expect(provisioning.getNetworkList()).resolves.toStrictEqual([ - testNetworkInfo, - ]); -}); diff --git a/src/provisioning/provisioning.ts b/src/provisioning/provisioning.ts deleted file mode 100644 index c43f8c96e..000000000 --- a/src/provisioning/provisioning.ts +++ /dev/null @@ -1,31 +0,0 @@ -import type { CloudConfig, NetworkInfo, SmartMachineStatus } from './types'; - -export interface Provisioning { - /** Get the status of the Smart Machine */ - getSmartMachineStatus: () => Promise; - - /** - * Set the network credentials of the Smart Machine, so it can connect to the - * internet. - * - * @param type - The type of network. - * @param ssid - The SSID of the network. - * @param psk - The network's passkey. - */ - setNetworkCredentials: ( - type: string, - ssid: string, - psk: string - ) => Promise; - - /** - * Set the Viam credentials of the smart machine credentials, so it connect to - * the Cloud. - * - * @param cloud - The configuration of the Cloud. - */ - setSmartMachineCredentials: (cloud?: CloudConfig) => Promise; - - /** Get the networks that are visible to the Smart Machine. */ - getNetworkList: () => Promise; -} diff --git a/src/provisioning/types.ts b/src/provisioning/types.ts deleted file mode 100644 index b04fd266a..000000000 --- a/src/provisioning/types.ts +++ /dev/null @@ -1,15 +0,0 @@ -import pb from '../gen/provisioning/v1/provisioning_pb'; - -export type SmartMachineStatus = pb.GetSmartMachineStatusResponse.AsObject; -export type NetworkInfo = pb.NetworkInfo.AsObject; -export type CloudConfig = pb.CloudConfig.AsObject; - -export const encodeCloudConfig = ( - obj: pb.CloudConfig.AsObject -): pb.CloudConfig => { - const result = new pb.CloudConfig(); - result.setId(obj.id); - result.setSecret(obj.secret); - result.setAppAddress(obj.appAddress); - return result; -};