-
Notifications
You must be signed in to change notification settings - Fork 40
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
RSDK-7200: add ml training wrapper (#272)
Signed-off-by: dependabot[bot] <[email protected]> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: viambot <[email protected]> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
- Loading branch information
1 parent
3544470
commit fd37ff3
Showing
11 changed files
with
432 additions
and
135 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,146 @@ | ||
import { FakeTransportBuilder } from '@improbable-eng/grpc-web-fake-transport'; | ||
import { beforeEach, describe, expect, it, vi } from 'vitest'; | ||
import { | ||
CancelTrainingJobRequest, | ||
DeleteCompletedTrainingJobRequest, | ||
GetTrainingJobRequest, | ||
GetTrainingJobResponse, | ||
ListTrainingJobsRequest, | ||
ListTrainingJobsResponse, | ||
ModelType, | ||
SubmitTrainingJobRequest, | ||
SubmitTrainingJobResponse, | ||
TrainingJobMetadata, | ||
TrainingStatus, | ||
} from '../gen/app/mltraining/v1/ml_training_pb'; | ||
import { MLTrainingServiceClient } from '../gen/app/mltraining/v1/ml_training_pb_service'; | ||
vi.mock('../gen/app/mltraining/v1/ml_training_pb_service'); | ||
import { MlTrainingClient } from './ml-training-client'; | ||
|
||
const subject = () => | ||
new MlTrainingClient('fakeServiceHoost', { | ||
transport: new FakeTransportBuilder().build(), | ||
}); | ||
|
||
describe('MlTrainingClient tests', () => { | ||
describe('submitTrainingJob tests', () => { | ||
const type = ModelType.MODEL_TYPE_UNSPECIFIED; | ||
beforeEach(() => { | ||
vi.spyOn(MLTrainingServiceClient.prototype, 'submitTrainingJob') | ||
// @ts-expect-error compiler is matching incorrect function signature | ||
.mockImplementationOnce((_req: SubmitTrainingJobRequest, _md, cb) => { | ||
const response = new SubmitTrainingJobResponse(); | ||
response.setId('fakeId'); | ||
cb(null, response); | ||
}); | ||
}); | ||
|
||
it('submit job training job', async () => { | ||
const response = await subject().submitTrainingJob( | ||
'org_id', | ||
'dataset_id', | ||
'model_name', | ||
'model_version', | ||
type, | ||
['tag1'] | ||
); | ||
expect(response).toEqual('fakeId'); | ||
}); | ||
}); | ||
|
||
describe('getTrainingJob tests', () => { | ||
const metadata: TrainingJobMetadata = new TrainingJobMetadata(); | ||
metadata.setId('id'); | ||
metadata.setDatasetId('dataset_id'); | ||
metadata.setOrganizationId('org_id'); | ||
metadata.setModelVersion('model_version'); | ||
metadata.setModelType(ModelType.MODEL_TYPE_UNSPECIFIED); | ||
metadata.setStatus(TrainingStatus.TRAINING_STATUS_UNSPECIFIED); | ||
metadata.setSyncedModelId('synced_model_id'); | ||
|
||
beforeEach(() => { | ||
vi.spyOn(MLTrainingServiceClient.prototype, 'getTrainingJob') | ||
// @ts-expect-error compiler is matching incorrect function signature | ||
.mockImplementationOnce((_req: GetTrainingJobRequest, _md, cb) => { | ||
const response = new GetTrainingJobResponse(); | ||
response.setMetadata(metadata); | ||
cb(null, response); | ||
}); | ||
}); | ||
|
||
it('get training job', async () => { | ||
const response = await subject().getTrainingJob('id'); | ||
expect(response).toEqual(metadata); | ||
}); | ||
}); | ||
|
||
describe('listTrainingJobs', () => { | ||
const status = TrainingStatus.TRAINING_STATUS_UNSPECIFIED; | ||
const md1 = new TrainingJobMetadata(); | ||
md1.setId('id1'); | ||
md1.setDatasetId('dataset_id1'); | ||
md1.setOrganizationId('org_id1'); | ||
md1.setModelVersion('model_version1'); | ||
md1.setModelType(ModelType.MODEL_TYPE_UNSPECIFIED); | ||
md1.setStatus(TrainingStatus.TRAINING_STATUS_UNSPECIFIED); | ||
md1.setSyncedModelId('synced_model_id1'); | ||
const md2 = new TrainingJobMetadata(); | ||
md1.setId('id2'); | ||
md1.setDatasetId('dataset_id2'); | ||
md1.setOrganizationId('org_id2'); | ||
md1.setModelVersion('model_version2'); | ||
md1.setModelType(ModelType.MODEL_TYPE_UNSPECIFIED); | ||
md1.setStatus(TrainingStatus.TRAINING_STATUS_UNSPECIFIED); | ||
md1.setSyncedModelId('synced_model_id2'); | ||
const jobs = [md1, md2]; | ||
|
||
beforeEach(() => { | ||
vi.spyOn(MLTrainingServiceClient.prototype, 'listTrainingJobs') | ||
// @ts-expect-error compiler is matching incorrect function signature | ||
.mockImplementationOnce((_req: ListTrainingJobsRequest, _md, cb) => { | ||
const response = new ListTrainingJobsResponse(); | ||
response.setJobsList(jobs); | ||
cb(null, response); | ||
}); | ||
}); | ||
|
||
it('list training jobs', async () => { | ||
const response = await subject().listTrainingJobs('org_id', status); | ||
expect(response).toEqual([md1.toObject(), md2.toObject()]); | ||
}); | ||
}); | ||
|
||
describe('cancelTrainingJob tests', () => { | ||
const id = 'id'; | ||
beforeEach(() => { | ||
vi.spyOn(MLTrainingServiceClient.prototype, 'cancelTrainingJob') | ||
// @ts-expect-error compiler is matching incorrect function signature | ||
.mockImplementationOnce((req: CancelTrainingJobRequest, _md, cb) => { | ||
expect(req.getId()).toStrictEqual(id); | ||
cb(null, {}); | ||
}); | ||
}); | ||
it('cancel training job', async () => { | ||
expect(await subject().cancelTrainingJob(id)).toStrictEqual(null); | ||
}); | ||
}); | ||
|
||
describe('deleteCompletedTrainingJob tests', () => { | ||
const id = 'id'; | ||
beforeEach(() => { | ||
vi.spyOn( | ||
MLTrainingServiceClient.prototype, | ||
'deleteCompletedTrainingJob' | ||
).mockImplementationOnce( | ||
// @ts-expect-error compiler is matching incorrect function signature | ||
(req: DeleteCompletedTrainingJobRequest, _md, cb) => { | ||
expect(req.getId()).toStrictEqual(id); | ||
cb(null, {}); | ||
} | ||
); | ||
}); | ||
it('delete completed training job', async () => { | ||
expect(await subject().deleteCompletedTrainingJob(id)).toEqual(null); | ||
}); | ||
}); | ||
}); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
import { type RpcOptions } from '@improbable-eng/grpc-web/dist/typings/client.d'; | ||
import { MLTrainingServiceClient } from '../gen/app/mltraining/v1/ml_training_pb_service'; | ||
import pb from '../gen/app/mltraining/v1/ml_training_pb'; | ||
import { promisify } from '../utils'; | ||
|
||
type ValueOf<T> = T[keyof T]; | ||
export const { ModelType } = pb; | ||
export type ModelType = ValueOf<typeof pb.ModelType>; | ||
export const { TrainingStatus } = pb; | ||
export type TrainingStatus = ValueOf<typeof pb.TrainingStatus>; | ||
|
||
export class MlTrainingClient { | ||
private service: MLTrainingServiceClient; | ||
|
||
constructor(serviceHost: string, grpcOptions: RpcOptions) { | ||
this.service = new MLTrainingServiceClient(serviceHost, grpcOptions); | ||
} | ||
|
||
async submitTrainingJob( | ||
orgId: string, | ||
datasetId: string, | ||
modelName: string, | ||
modelVersion: string, | ||
modelType: ModelType, | ||
tagsList: string[] | ||
) { | ||
const { service } = this; | ||
|
||
const req = new pb.SubmitTrainingJobRequest(); | ||
req.setOrganizationId(orgId); | ||
req.setDatasetId(datasetId); | ||
req.setModelName(modelName); | ||
req.setModelVersion(modelVersion); | ||
req.setModelType(modelType); | ||
req.setTagsList(tagsList); | ||
|
||
const response = await promisify< | ||
pb.SubmitTrainingJobRequest, | ||
pb.SubmitTrainingJobResponse | ||
>(service.submitTrainingJob.bind(service), req); | ||
return response.getId(); | ||
} | ||
|
||
async getTrainingJob(id: string) { | ||
const { service } = this; | ||
|
||
const req = new pb.GetTrainingJobRequest(); | ||
req.setId(id); | ||
|
||
const response = await promisify< | ||
pb.GetTrainingJobRequest, | ||
pb.GetTrainingJobResponse | ||
>(service.getTrainingJob.bind(service), req); | ||
return response.getMetadata(); | ||
} | ||
|
||
async listTrainingJobs(orgId: string, status: TrainingStatus) { | ||
const { service } = this; | ||
|
||
const req = new pb.ListTrainingJobsRequest(); | ||
req.setOrganizationId(orgId); | ||
req.setStatus(status); | ||
|
||
const response = await promisify< | ||
pb.ListTrainingJobsRequest, | ||
pb.ListTrainingJobsResponse | ||
>(service.listTrainingJobs.bind(service), req); | ||
return response.toObject().jobsList; | ||
} | ||
|
||
async cancelTrainingJob(id: string) { | ||
const { service } = this; | ||
|
||
const req = new pb.CancelTrainingJobRequest(); | ||
req.setId(id); | ||
|
||
await promisify<pb.CancelTrainingJobRequest, pb.CancelTrainingJobResponse>( | ||
service.cancelTrainingJob.bind(service), | ||
req | ||
); | ||
return null; | ||
} | ||
|
||
async deleteCompletedTrainingJob(id: string) { | ||
const { service } = this; | ||
|
||
const req = new pb.DeleteCompletedTrainingJobRequest(); | ||
req.setId(id); | ||
|
||
await promisify< | ||
pb.DeleteCompletedTrainingJobRequest, | ||
pb.DeleteCompletedTrainingJobResponse | ||
>(service.deleteCompletedTrainingJob.bind(service), req); | ||
return null; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
// @vitest-environment happy-dom | ||
|
||
import { FakeTransportBuilder } from '@improbable-eng/grpc-web-fake-transport'; | ||
import { beforeEach, expect, it, vi } from 'vitest'; | ||
import { | ||
CloudConfig, | ||
SetNetworkCredentialsRequest, | ||
SetSmartMachineCredentialsRequest, | ||
} from '../gen/provisioning/v1/provisioning_pb'; | ||
import { ProvisioningServiceClient } from '../gen/provisioning/v1/provisioning_pb_service'; | ||
import { ProvisioningClient } from './provisioning-client'; | ||
|
||
const subject = () => | ||
new ProvisioningClient('fakeServiceHost', { | ||
transport: new FakeTransportBuilder().build(), | ||
}); | ||
|
||
const testProvisioningInfo = { | ||
fragmentId: 'id', | ||
model: 'model', | ||
manufacturer: 'manufacturer', | ||
}; | ||
const testNetworkInfo = { | ||
type: 'type', | ||
ssid: 'ssid', | ||
security: 'security', | ||
signal: 999, | ||
connected: 'true', | ||
lastError: 'last error', | ||
}; | ||
const testSmartMachineStatus = { | ||
provisioningInfo: testProvisioningInfo, | ||
hasSmartMachineCredentials: true, | ||
isOnline: true, | ||
latestConnectionAttempt: testNetworkInfo, | ||
errorsList: ['error', 'err'], | ||
}; | ||
const type = 'type'; | ||
const ssid = 'ssid'; | ||
const psk = 'psk'; | ||
const cloud = new CloudConfig(); | ||
cloud.setId('id'); | ||
cloud.setSecret('secret'); | ||
cloud.setAppAddress('app_address'); | ||
|
||
beforeEach(() => { | ||
ProvisioningServiceClient.prototype.getSmartMachineStatus = vi | ||
.fn() | ||
.mockImplementation((_req, _md, cb) => { | ||
cb(null, { | ||
toObject: () => testSmartMachineStatus, | ||
}); | ||
}); | ||
|
||
ProvisioningServiceClient.prototype.getNetworkList = vi | ||
.fn() | ||
.mockImplementation((_req, _md, cb) => { | ||
cb(null, { | ||
toObject: () => ({ networksList: [testNetworkInfo] }), | ||
}); | ||
}); | ||
|
||
ProvisioningServiceClient.prototype.setNetworkCredentials = vi | ||
.fn() | ||
.mockImplementation((req: SetNetworkCredentialsRequest, _md, cb) => { | ||
expect(req.getType()).toStrictEqual(type); | ||
expect(req.getSsid()).toStrictEqual(ssid); | ||
expect(req.getPsk()).toStrictEqual(psk); | ||
cb(null, {}); | ||
}); | ||
|
||
ProvisioningServiceClient.prototype.setSmartMachineCredentials = vi | ||
.fn() | ||
.mockImplementation((req: SetSmartMachineCredentialsRequest, _md, cb) => { | ||
expect(req.getCloud()).toStrictEqual(cloud); | ||
cb(null, {}); | ||
}); | ||
}); | ||
|
||
it('getSmartMachineStatus', async () => { | ||
await expect(subject().getSmartMachineStatus()).resolves.toStrictEqual( | ||
testSmartMachineStatus | ||
); | ||
}); | ||
|
||
it('getNetworkList', async () => { | ||
await expect(subject().getNetworkList()).resolves.toStrictEqual([ | ||
testNetworkInfo, | ||
]); | ||
}); | ||
|
||
it('setNetworkCredentials', async () => { | ||
await expect( | ||
subject().setNetworkCredentials(type, ssid, psk) | ||
).resolves.toStrictEqual(undefined); | ||
}); | ||
|
||
it('setSmartMachineCredentials', async () => { | ||
await expect( | ||
subject().setSmartMachineCredentials(cloud.toObject()) | ||
).resolves.toStrictEqual(undefined); | ||
}); |
Oops, something went wrong.