Skip to content

Commit

Permalink
RSDK-7200: add ml training wrapper (#272)
Browse files Browse the repository at this point in the history
Signed-off-by: dependabot[bot] <[email protected]>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: viambot <[email protected]>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
  • Loading branch information
4 people authored Apr 16, 2024
1 parent 3544470 commit fd37ff3
Show file tree
Hide file tree
Showing 11 changed files with 432 additions and 135 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ update-buf: $(node_modules)
.PHONY: build-buf
build-buf: $(node_modules) clean-buf
$(buf) generate $$(./scripts/get-buf-lock-version.js buf.build/googleapis/googleapis)
$(buf) generate $$(./scripts/get-buf-lock-version.js buf.build/viamrobotics/api) --path common,component,robot,service,app,provisioning
$(buf) generate $$(./scripts/get-buf-lock-version.js buf.build/viamrobotics/api) --path common,component,robot,service,app,provisioning,tagger
$(buf) generate $$(./scripts/get-buf-lock-version.js buf.build/erdaniels/gostream)
$(buf) generate $$(./scripts/get-buf-lock-version.js buf.build/viamrobotics/goutils)

Expand Down
146 changes: 146 additions & 0 deletions src/app/ml-training-client.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
import { FakeTransportBuilder } from '@improbable-eng/grpc-web-fake-transport';
import { beforeEach, describe, expect, it, vi } from 'vitest';
import {
CancelTrainingJobRequest,
DeleteCompletedTrainingJobRequest,
GetTrainingJobRequest,
GetTrainingJobResponse,
ListTrainingJobsRequest,
ListTrainingJobsResponse,
ModelType,
SubmitTrainingJobRequest,
SubmitTrainingJobResponse,
TrainingJobMetadata,
TrainingStatus,
} from '../gen/app/mltraining/v1/ml_training_pb';
import { MLTrainingServiceClient } from '../gen/app/mltraining/v1/ml_training_pb_service';
vi.mock('../gen/app/mltraining/v1/ml_training_pb_service');
import { MlTrainingClient } from './ml-training-client';

const subject = () =>
new MlTrainingClient('fakeServiceHoost', {
transport: new FakeTransportBuilder().build(),
});

describe('MlTrainingClient tests', () => {
describe('submitTrainingJob tests', () => {
const type = ModelType.MODEL_TYPE_UNSPECIFIED;
beforeEach(() => {
vi.spyOn(MLTrainingServiceClient.prototype, 'submitTrainingJob')
// @ts-expect-error compiler is matching incorrect function signature
.mockImplementationOnce((_req: SubmitTrainingJobRequest, _md, cb) => {
const response = new SubmitTrainingJobResponse();
response.setId('fakeId');
cb(null, response);
});
});

it('submit job training job', async () => {
const response = await subject().submitTrainingJob(
'org_id',
'dataset_id',
'model_name',
'model_version',
type,
['tag1']
);
expect(response).toEqual('fakeId');
});
});

describe('getTrainingJob tests', () => {
const metadata: TrainingJobMetadata = new TrainingJobMetadata();
metadata.setId('id');
metadata.setDatasetId('dataset_id');
metadata.setOrganizationId('org_id');
metadata.setModelVersion('model_version');
metadata.setModelType(ModelType.MODEL_TYPE_UNSPECIFIED);
metadata.setStatus(TrainingStatus.TRAINING_STATUS_UNSPECIFIED);
metadata.setSyncedModelId('synced_model_id');

beforeEach(() => {
vi.spyOn(MLTrainingServiceClient.prototype, 'getTrainingJob')
// @ts-expect-error compiler is matching incorrect function signature
.mockImplementationOnce((_req: GetTrainingJobRequest, _md, cb) => {
const response = new GetTrainingJobResponse();
response.setMetadata(metadata);
cb(null, response);
});
});

it('get training job', async () => {
const response = await subject().getTrainingJob('id');
expect(response).toEqual(metadata);
});
});

describe('listTrainingJobs', () => {
const status = TrainingStatus.TRAINING_STATUS_UNSPECIFIED;
const md1 = new TrainingJobMetadata();
md1.setId('id1');
md1.setDatasetId('dataset_id1');
md1.setOrganizationId('org_id1');
md1.setModelVersion('model_version1');
md1.setModelType(ModelType.MODEL_TYPE_UNSPECIFIED);
md1.setStatus(TrainingStatus.TRAINING_STATUS_UNSPECIFIED);
md1.setSyncedModelId('synced_model_id1');
const md2 = new TrainingJobMetadata();
md1.setId('id2');
md1.setDatasetId('dataset_id2');
md1.setOrganizationId('org_id2');
md1.setModelVersion('model_version2');
md1.setModelType(ModelType.MODEL_TYPE_UNSPECIFIED);
md1.setStatus(TrainingStatus.TRAINING_STATUS_UNSPECIFIED);
md1.setSyncedModelId('synced_model_id2');
const jobs = [md1, md2];

beforeEach(() => {
vi.spyOn(MLTrainingServiceClient.prototype, 'listTrainingJobs')
// @ts-expect-error compiler is matching incorrect function signature
.mockImplementationOnce((_req: ListTrainingJobsRequest, _md, cb) => {
const response = new ListTrainingJobsResponse();
response.setJobsList(jobs);
cb(null, response);
});
});

it('list training jobs', async () => {
const response = await subject().listTrainingJobs('org_id', status);
expect(response).toEqual([md1.toObject(), md2.toObject()]);
});
});

describe('cancelTrainingJob tests', () => {
const id = 'id';
beforeEach(() => {
vi.spyOn(MLTrainingServiceClient.prototype, 'cancelTrainingJob')
// @ts-expect-error compiler is matching incorrect function signature
.mockImplementationOnce((req: CancelTrainingJobRequest, _md, cb) => {
expect(req.getId()).toStrictEqual(id);
cb(null, {});
});
});
it('cancel training job', async () => {
expect(await subject().cancelTrainingJob(id)).toStrictEqual(null);
});
});

describe('deleteCompletedTrainingJob tests', () => {
const id = 'id';
beforeEach(() => {
vi.spyOn(
MLTrainingServiceClient.prototype,
'deleteCompletedTrainingJob'
).mockImplementationOnce(
// @ts-expect-error compiler is matching incorrect function signature
(req: DeleteCompletedTrainingJobRequest, _md, cb) => {
expect(req.getId()).toStrictEqual(id);
cb(null, {});
}
);
});
it('delete completed training job', async () => {
expect(await subject().deleteCompletedTrainingJob(id)).toEqual(null);
});
});
});
96 changes: 96 additions & 0 deletions src/app/ml-training-client.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import { type RpcOptions } from '@improbable-eng/grpc-web/dist/typings/client.d';
import { MLTrainingServiceClient } from '../gen/app/mltraining/v1/ml_training_pb_service';
import pb from '../gen/app/mltraining/v1/ml_training_pb';
import { promisify } from '../utils';

type ValueOf<T> = T[keyof T];
export const { ModelType } = pb;
export type ModelType = ValueOf<typeof pb.ModelType>;
export const { TrainingStatus } = pb;
export type TrainingStatus = ValueOf<typeof pb.TrainingStatus>;

export class MlTrainingClient {
private service: MLTrainingServiceClient;

constructor(serviceHost: string, grpcOptions: RpcOptions) {
this.service = new MLTrainingServiceClient(serviceHost, grpcOptions);
}

async submitTrainingJob(
orgId: string,
datasetId: string,
modelName: string,
modelVersion: string,
modelType: ModelType,
tagsList: string[]
) {
const { service } = this;

const req = new pb.SubmitTrainingJobRequest();
req.setOrganizationId(orgId);
req.setDatasetId(datasetId);
req.setModelName(modelName);
req.setModelVersion(modelVersion);
req.setModelType(modelType);
req.setTagsList(tagsList);

const response = await promisify<
pb.SubmitTrainingJobRequest,
pb.SubmitTrainingJobResponse
>(service.submitTrainingJob.bind(service), req);
return response.getId();
}

async getTrainingJob(id: string) {
const { service } = this;

const req = new pb.GetTrainingJobRequest();
req.setId(id);

const response = await promisify<
pb.GetTrainingJobRequest,
pb.GetTrainingJobResponse
>(service.getTrainingJob.bind(service), req);
return response.getMetadata();
}

async listTrainingJobs(orgId: string, status: TrainingStatus) {
const { service } = this;

const req = new pb.ListTrainingJobsRequest();
req.setOrganizationId(orgId);
req.setStatus(status);

const response = await promisify<
pb.ListTrainingJobsRequest,
pb.ListTrainingJobsResponse
>(service.listTrainingJobs.bind(service), req);
return response.toObject().jobsList;
}

async cancelTrainingJob(id: string) {
const { service } = this;

const req = new pb.CancelTrainingJobRequest();
req.setId(id);

await promisify<pb.CancelTrainingJobRequest, pb.CancelTrainingJobResponse>(
service.cancelTrainingJob.bind(service),
req
);
return null;
}

async deleteCompletedTrainingJob(id: string) {
const { service } = this;

const req = new pb.DeleteCompletedTrainingJobRequest();
req.setId(id);

await promisify<
pb.DeleteCompletedTrainingJobRequest,
pb.DeleteCompletedTrainingJobResponse
>(service.deleteCompletedTrainingJob.bind(service), req);
return null;
}
}
102 changes: 102 additions & 0 deletions src/app/provisioning-client.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
// @vitest-environment happy-dom

import { FakeTransportBuilder } from '@improbable-eng/grpc-web-fake-transport';
import { beforeEach, expect, it, vi } from 'vitest';
import {
CloudConfig,
SetNetworkCredentialsRequest,
SetSmartMachineCredentialsRequest,
} from '../gen/provisioning/v1/provisioning_pb';
import { ProvisioningServiceClient } from '../gen/provisioning/v1/provisioning_pb_service';
import { ProvisioningClient } from './provisioning-client';

const subject = () =>
new ProvisioningClient('fakeServiceHost', {
transport: new FakeTransportBuilder().build(),
});

const testProvisioningInfo = {
fragmentId: 'id',
model: 'model',
manufacturer: 'manufacturer',
};
const testNetworkInfo = {
type: 'type',
ssid: 'ssid',
security: 'security',
signal: 999,
connected: 'true',
lastError: 'last error',
};
const testSmartMachineStatus = {
provisioningInfo: testProvisioningInfo,
hasSmartMachineCredentials: true,
isOnline: true,
latestConnectionAttempt: testNetworkInfo,
errorsList: ['error', 'err'],
};
const type = 'type';
const ssid = 'ssid';
const psk = 'psk';
const cloud = new CloudConfig();
cloud.setId('id');
cloud.setSecret('secret');
cloud.setAppAddress('app_address');

beforeEach(() => {
ProvisioningServiceClient.prototype.getSmartMachineStatus = vi
.fn()
.mockImplementation((_req, _md, cb) => {
cb(null, {
toObject: () => testSmartMachineStatus,
});
});

ProvisioningServiceClient.prototype.getNetworkList = vi
.fn()
.mockImplementation((_req, _md, cb) => {
cb(null, {
toObject: () => ({ networksList: [testNetworkInfo] }),
});
});

ProvisioningServiceClient.prototype.setNetworkCredentials = vi
.fn()
.mockImplementation((req: SetNetworkCredentialsRequest, _md, cb) => {
expect(req.getType()).toStrictEqual(type);
expect(req.getSsid()).toStrictEqual(ssid);
expect(req.getPsk()).toStrictEqual(psk);
cb(null, {});
});

ProvisioningServiceClient.prototype.setSmartMachineCredentials = vi
.fn()
.mockImplementation((req: SetSmartMachineCredentialsRequest, _md, cb) => {
expect(req.getCloud()).toStrictEqual(cloud);
cb(null, {});
});
});

it('getSmartMachineStatus', async () => {
await expect(subject().getSmartMachineStatus()).resolves.toStrictEqual(
testSmartMachineStatus
);
});

it('getNetworkList', async () => {
await expect(subject().getNetworkList()).resolves.toStrictEqual([
testNetworkInfo,
]);
});

it('setNetworkCredentials', async () => {
await expect(
subject().setNetworkCredentials(type, ssid, psk)
).resolves.toStrictEqual(undefined);
});

it('setSmartMachineCredentials', async () => {
await expect(
subject().setSmartMachineCredentials(cloud.toObject())
).resolves.toStrictEqual(undefined);
});
Loading

0 comments on commit fd37ff3

Please sign in to comment.