From e3b27a0750b7ca29ca9b4babf3b5bbb4b33484d7 Mon Sep 17 00:00:00 2001 From: Juntao Wang <37624318+DaoDaoNoCode@users.noreply.github.com> Date: Thu, 27 Feb 2025 10:50:53 -0500 Subject: [PATCH] Add teacher/judge section to ilab form (#3777) * Add teacher/judge section to ilab form * Resolve conflicts, address feedback * update zod validation for URL * Address comments * address feedback * change run name --- .../pages/pipelines/modelCustomizationForm.ts | 22 ++ .../pipelines/modelCustomizationForm.cy.ts | 10 +- frontend/src/api/k8s/secrets.ts | 18 ++ .../content/createRun/submitUtils.ts | 11 +- .../__tests__/validationUtils.spec.ts | 48 ++++ .../validationUtils.ts | 58 ++-- .../modelCustomizationForm/useIlabPipeline.ts | 13 +- .../modelCustomization/FineTunePage.tsx | 27 +- .../modelCustomization/FineTunePageFooter.tsx | 247 ++++++++++++------ .../ModelCustomizationForm.tsx | 26 +- .../global/modelCustomization/const.ts | 4 + .../teacherJudgeSection/JudgeModelSection.tsx | 95 +++++++ .../TeacherJudgeInputComponents.tsx | 103 ++++++++ .../TeacherModelSection.tsx | 95 +++++++ .../global/modelCustomization/utils.ts | 62 +++++ 15 files changed, 719 insertions(+), 120 deletions(-) create mode 100644 frontend/src/concepts/pipelines/content/modelCustomizationForm/modelCustomizationFormSchema/__tests__/validationUtils.spec.ts create mode 100644 frontend/src/pages/pipelines/global/modelCustomization/teacherJudgeSection/JudgeModelSection.tsx create mode 100644 frontend/src/pages/pipelines/global/modelCustomization/teacherJudgeSection/TeacherJudgeInputComponents.tsx create mode 100644 frontend/src/pages/pipelines/global/modelCustomization/teacherJudgeSection/TeacherModelSection.tsx create mode 100644 frontend/src/pages/pipelines/global/modelCustomization/utils.ts diff --git a/frontend/src/__tests__/cypress/cypress/pages/pipelines/modelCustomizationForm.ts b/frontend/src/__tests__/cypress/cypress/pages/pipelines/modelCustomizationForm.ts index 7c709bc6b3..ed647bedd6 100644 --- a/frontend/src/__tests__/cypress/cypress/pages/pipelines/modelCustomizationForm.ts +++ b/frontend/src/__tests__/cypress/cypress/pages/pipelines/modelCustomizationForm.ts @@ -40,4 +40,26 @@ class ModelCustomizationFormGlobal { } } +class TeacherModelSection { + findEndpointInput() { + return cy.findByTestId('teacher-endpoint-input'); + } + + findModelNameInput() { + return cy.findByTestId('teacher-model-name-input'); + } +} + +class JudgeModelSection { + findEndpointInput() { + return cy.findByTestId('judge-endpoint-input'); + } + + findModelNameInput() { + return cy.findByTestId('judge-model-name-input'); + } +} + export const modelCustomizationFormGlobal = new ModelCustomizationFormGlobal(); +export const teacherModelSection = new TeacherModelSection(); +export const judgeModelSection = new JudgeModelSection(); diff --git a/frontend/src/__tests__/cypress/cypress/tests/mocked/pipelines/modelCustomizationForm.cy.ts b/frontend/src/__tests__/cypress/cypress/tests/mocked/pipelines/modelCustomizationForm.cy.ts index 4a17d5a06b..3b6bf3afc8 100644 --- a/frontend/src/__tests__/cypress/cypress/tests/mocked/pipelines/modelCustomizationForm.cy.ts +++ b/frontend/src/__tests__/cypress/cypress/tests/mocked/pipelines/modelCustomizationForm.cy.ts @@ -1,5 +1,9 @@ /* eslint-disable camelcase */ -import { modelCustomizationFormGlobal } from '~/__tests__/cypress/cypress/pages/pipelines/modelCustomizationForm'; +import { + judgeModelSection, + modelCustomizationFormGlobal, + teacherModelSection, +} from '~/__tests__/cypress/cypress/pages/pipelines/modelCustomizationForm'; import { buildMockPipeline, buildMockPipelines, @@ -46,6 +50,10 @@ describe('Model Customization Form', () => { modelCustomizationFormGlobal.visit(projectName); cy.wait('@getAllPipelines'); cy.wait('@getAllPipelineVersions'); + teacherModelSection.findEndpointInput().type('http://test.com'); + teacherModelSection.findModelNameInput().type('test'); + judgeModelSection.findEndpointInput().type('http://test.com'); + judgeModelSection.findModelNameInput().type('test'); modelCustomizationFormGlobal.findSubmitButton().should('not.be.disabled'); }); it('Should not submit', () => { diff --git a/frontend/src/api/k8s/secrets.ts b/frontend/src/api/k8s/secrets.ts index c6e9ea17c9..b5992aaf69 100644 --- a/frontend/src/api/k8s/secrets.ts +++ b/frontend/src/api/k8s/secrets.ts @@ -55,6 +55,24 @@ export const assembleSecret = ( }; }; +export const assembleSecretTeacher = ( + projectName: string, + data: Record, + secretName?: string, +): SecretKind => { + const k8sName = secretName || `teacher-secret-${genRandomChars()}`; + return assembleSecret(projectName, data, 'generic', k8sName); +}; + +export const assembleSecretJudge = ( + projectName: string, + data: Record, + secretName?: string, +): SecretKind => { + const k8sName = secretName || `judge-secret-${genRandomChars()}`; + return assembleSecret(projectName, data, 'generic', k8sName); +}; + export const assembleISSecretBody = ( assignableData: Record, ): [Record, string] => { diff --git a/frontend/src/concepts/pipelines/content/createRun/submitUtils.ts b/frontend/src/concepts/pipelines/content/createRun/submitUtils.ts index 13d6b95338..c5f6ccbcb1 100644 --- a/frontend/src/concepts/pipelines/content/createRun/submitUtils.ts +++ b/frontend/src/concepts/pipelines/content/createRun/submitUtils.ts @@ -27,6 +27,7 @@ import { convertPeriodicTimeToSeconds, convertToDate } from '~/utilities/time'; const createRun = async ( formData: SafeRunFormData, createPipelineRun: PipelineAPIs['createPipelineRun'], + dryRun?: boolean, ): Promise => { /* eslint-disable camelcase */ const data: CreatePipelineRunKFData = { @@ -44,7 +45,7 @@ const createRun = async ( }; /* eslint-enable camelcase */ - return createPipelineRun({}, data); + return createPipelineRun({ dryRun }, data); }; export const convertDateDataToKFDateTime = (dateData?: RunDateTime): DateTimeKF | null => { @@ -58,6 +59,7 @@ export const convertDateDataToKFDateTime = (dateData?: RunDateTime): DateTimeKF const createRecurringRun = async ( formData: SafeRunFormData, createPipelineRecurringRun: PipelineAPIs['createPipelineRecurringRun'], + dryRun?: boolean, ): Promise => { if (formData.runType.type !== RunTypeOption.SCHEDULED) { return Promise.reject(new Error('Cannot create a schedule with incomplete data.')); @@ -109,13 +111,14 @@ const createRecurringRun = async ( }; /* eslint-enable camelcase */ - return createPipelineRecurringRun({}, data); + return createPipelineRecurringRun({ dryRun }, data); }; /** Returns the relative path to navigate to from the namespace qualified route */ export const handleSubmit = ( formData: RunFormData, api: PipelineAPIs, + dryRun?: boolean, ): Promise => { if (!isFilledRunFormData(formData)) { throw new Error('Form data was incomplete.'); @@ -123,9 +126,9 @@ export const handleSubmit = ( switch (formData.runType.type) { case RunTypeOption.ONE_TRIGGER: - return createRun(formData, api.createPipelineRun); + return createRun(formData, api.createPipelineRun, dryRun); case RunTypeOption.SCHEDULED: - return createRecurringRun(formData, api.createPipelineRecurringRun); + return createRecurringRun(formData, api.createPipelineRecurringRun, dryRun); default: // eslint-disable-next-line no-console console.error('Unknown run type', formData.runType); diff --git a/frontend/src/concepts/pipelines/content/modelCustomizationForm/modelCustomizationFormSchema/__tests__/validationUtils.spec.ts b/frontend/src/concepts/pipelines/content/modelCustomizationForm/modelCustomizationFormSchema/__tests__/validationUtils.spec.ts new file mode 100644 index 0000000000..059baa45ec --- /dev/null +++ b/frontend/src/concepts/pipelines/content/modelCustomizationForm/modelCustomizationFormSchema/__tests__/validationUtils.spec.ts @@ -0,0 +1,48 @@ +import { ModelCustomizationEndpointType } from '~/concepts/pipelines/content/modelCustomizationForm/modelCustomizationFormSchema/types'; +import { + TeacherJudgeFormData, + teacherJudgeModel, +} from '~/concepts/pipelines/content/modelCustomizationForm/modelCustomizationFormSchema/validationUtils'; + +describe('TeacherJudgeSchema', () => { + it('should validate when it is public without token', () => { + const field: TeacherJudgeFormData = { + endpointType: ModelCustomizationEndpointType.PUBLIC, + apiToken: '', + modelName: 'test', + endpoint: 'http://test.com', + }; + const result = teacherJudgeModel.safeParse(field); + expect(result.success).toBe(true); + }); + it('should error when it is private without token', () => { + const field: TeacherJudgeFormData = { + endpointType: ModelCustomizationEndpointType.PRIVATE, + apiToken: '', + modelName: 'test', + endpoint: 'http://test.com', + }; + const result = teacherJudgeModel.safeParse(field); + expect(result.success).toBe(false); + }); + it('should validate when it is private with token', () => { + const field: TeacherJudgeFormData = { + endpointType: ModelCustomizationEndpointType.PRIVATE, + apiToken: 'test', + modelName: 'test', + endpoint: 'http://test.com', + }; + const result = teacherJudgeModel.safeParse(field); + expect(result.success).toBe(true); + }); + it('should error when the endpoint is not a uri', () => { + const field: TeacherJudgeFormData = { + endpointType: ModelCustomizationEndpointType.PRIVATE, + apiToken: 'test', + modelName: 'test', + endpoint: 'not a uri', + }; + const result = teacherJudgeModel.safeParse(field); + expect(result.success).toBe(false); + }); +}); diff --git a/frontend/src/concepts/pipelines/content/modelCustomizationForm/modelCustomizationFormSchema/validationUtils.ts b/frontend/src/concepts/pipelines/content/modelCustomizationForm/modelCustomizationFormSchema/validationUtils.ts index d7ba134be7..38a6296831 100644 --- a/frontend/src/concepts/pipelines/content/modelCustomizationForm/modelCustomizationFormSchema/validationUtils.ts +++ b/frontend/src/concepts/pipelines/content/modelCustomizationForm/modelCustomizationFormSchema/validationUtils.ts @@ -1,36 +1,47 @@ import { z } from 'zod'; import { ModelCustomizationEndpointType, ModelCustomizationRunType } from './types'; -export const uriFieldSchema = z.string().refine( - (value) => { - if (!value) { - return true; - } - try { - return !!new URL(value); - } catch (e) { - return false; - } - }, - { message: 'Invalid URI' }, -); +export const uriFieldSchemaBase = ( + isOptional: boolean, +): z.ZodEffects => + z.string().refine( + (value) => { + if (!value) { + return !!isOptional; + } + try { + return !!new URL(value); + } catch (e) { + return false; + } + }, + { message: 'Invalid URI' }, + ); export const baseModelSchema = z.object({ registryName: z.string(), name: z.string(), version: z.string(), - inputStorageLocationUri: uriFieldSchema, + inputStorageLocationUri: uriFieldSchemaBase(true), }); -export const teacherJudgeModel = z.object({ - endpointType: z.enum([ - ModelCustomizationEndpointType.PUBLIC, - ModelCustomizationEndpointType.PRIVATE, - ]), - endpoint: uriFieldSchema, - username: z.string().min(1, 'Username is required'), - password: z.string().min(1, 'Password is required'), +const teacherJudgeBaseSchema = z.object({ + endpoint: uriFieldSchemaBase(false), + modelName: z.string().trim().min(1, 'Model name is required'), +}); +const teacherJudgePublicSchema = teacherJudgeBaseSchema.extend({ + endpointType: z.literal(ModelCustomizationEndpointType.PUBLIC), + apiToken: z.string(), }); +const teacherJudgePrivateSchema = teacherJudgeBaseSchema.extend({ + endpointType: z.literal(ModelCustomizationEndpointType.PRIVATE), + apiToken: z.string().trim().min(1, 'Token is required'), +}); + +export const teacherJudgeModel = z.discriminatedUnion('endpointType', [ + teacherJudgePrivateSchema, + teacherJudgePublicSchema, +]); export const numericFieldSchema = z .object({ @@ -113,8 +124,11 @@ export const fineTunedModelDetailsSchema = z.object({ export const modelCustomizationFormSchema = z.object({ projectName: z.object({ value: z.string().min(1, { message: 'Project is required' }) }), baseModel: baseModelSchema, + teacher: teacherJudgeModel, + judge: teacherJudgeModel, }); export type ModelCustomizationFormData = z.infer; export type BaseModelFormData = z.infer; +export type TeacherJudgeFormData = z.infer; diff --git a/frontend/src/concepts/pipelines/content/modelCustomizationForm/useIlabPipeline.ts b/frontend/src/concepts/pipelines/content/modelCustomizationForm/useIlabPipeline.ts index 7759754470..863dfdd740 100644 --- a/frontend/src/concepts/pipelines/content/modelCustomizationForm/useIlabPipeline.ts +++ b/frontend/src/concepts/pipelines/content/modelCustomizationForm/useIlabPipeline.ts @@ -1,11 +1,16 @@ import React from 'react'; import { ILAB_PIPELINE_NAME } from '~/pages/pipelines/global/modelCustomization/const'; -import { FetchState } from '~/utilities/useFetchState'; import { useLatestPipelineVersion } from '~/concepts/pipelines/apiHooks/useLatestPipelineVersion'; import { usePipelineByName } from '~/concepts/pipelines/apiHooks/usePipelineByName'; -import { PipelineVersionKF } from '~/concepts/pipelines/kfTypes'; +import { PipelineKF, PipelineVersionKF } from '~/concepts/pipelines/kfTypes'; -export const useIlabPipeline = (): FetchState => { +export const useIlabPipeline = (): { + ilabPipeline: PipelineKF | null; + ilabPipelineVersion: PipelineVersionKF | null; + loaded: boolean; + loadError: Error | undefined; + refresh: () => Promise; +} => { const [ilabPipeline, ilabPipelineLoaded, ilabPipelineLoadError, refreshIlabPipeline] = usePipelineByName(ILAB_PIPELINE_NAME); const [ @@ -21,5 +26,5 @@ export const useIlabPipeline = (): FetchState => { return refreshIlabPipelineVersion(); }, [refreshIlabPipeline, refreshIlabPipelineVersion]); - return [ilabPipelineVersion, loaded, loadError, refresh]; + return { ilabPipeline, ilabPipelineVersion, loaded, loadError, refresh }; }; diff --git a/frontend/src/pages/pipelines/global/modelCustomization/FineTunePage.tsx b/frontend/src/pages/pipelines/global/modelCustomization/FineTunePage.tsx index 89e656a3a9..e428518025 100644 --- a/frontend/src/pages/pipelines/global/modelCustomization/FineTunePage.tsx +++ b/frontend/src/pages/pipelines/global/modelCustomization/FineTunePage.tsx @@ -10,15 +10,27 @@ import { ModelCustomizationFormData } from '~/concepts/pipelines/content/modelCu import { UpdateObjectAtPropAndValue } from '~/pages/projects/types'; import FineTunePageFooter from '~/pages/pipelines/global/modelCustomization/FineTunePageFooter'; import BaseModelSection from '~/pages/pipelines/global/modelCustomization/baseModelSection/BaseModelSection'; +import TeacherModelSection from '~/pages/pipelines/global/modelCustomization/teacherJudgeSection/TeacherModelSection'; +import JudgeModelSection from '~/pages/pipelines/global/modelCustomization/teacherJudgeSection/JudgeModelSection'; +import { PipelineKF, PipelineVersionKF } from '~/concepts/pipelines/kfTypes'; type FineTunePageProps = { isInvalid: boolean; onSuccess: () => void; data: ModelCustomizationFormData; setData: UpdateObjectAtPropAndValue; + ilabPipeline: PipelineKF | null; + ilabPipelineVersion: PipelineVersionKF | null; }; -const FineTunePage: React.FC = ({ isInvalid, onSuccess, data, setData }) => { +const FineTunePage: React.FC = ({ + isInvalid, + onSuccess, + data, + setData, + ilabPipeline, + ilabPipelineVersion, +}) => { const projectDetailsDescription = 'This project is used for running your pipeline'; const { project } = usePipelinesAPI(); @@ -41,8 +53,19 @@ const FineTunePage: React.FC = ({ isInvalid, onSuccess, data, data={data.baseModel} setData={(baseModelData) => setData('baseModel', baseModelData)} /> + setData('teacher', teacherData)} + /> + setData('judge', judgeData)} /> - + ); diff --git a/frontend/src/pages/pipelines/global/modelCustomization/FineTunePageFooter.tsx b/frontend/src/pages/pipelines/global/modelCustomization/FineTunePageFooter.tsx index fb1ffa3a6c..226f24538a 100644 --- a/frontend/src/pages/pipelines/global/modelCustomization/FineTunePageFooter.tsx +++ b/frontend/src/pages/pipelines/global/modelCustomization/FineTunePageFooter.tsx @@ -1,4 +1,11 @@ -import { ActionList, ActionListItem, Button, Stack, StackItem } from '@patternfly/react-core'; +import { + ActionList, + ActionListItem, + Alert, + Button, + Stack, + StackItem, +} from '@patternfly/react-core'; import * as React from 'react'; import { useNavigate } from 'react-router-dom'; import { ModelCustomizationFormData } from '~/concepts/pipelines/content/modelCustomizationForm/modelCustomizationFormSchema/validationUtils'; @@ -12,16 +19,41 @@ import { NotificationWatcherContext, NotificationWatcherResponse, } from '~/concepts/notificationWatcher/NotificationWatcherContext'; -import { RuntimeStateKF } from '~/concepts/pipelines/kfTypes'; +import { + PipelineKF, + PipelineRecurringRunKF, + PipelineRunKF, + PipelineVersionKF, + RuntimeStateKF, +} from '~/concepts/pipelines/kfTypes'; +import { + createTeacherJudgeSecrets, + translateIlabFormToTeacherJudge, +} from '~/pages/pipelines/global/modelCustomization/utils'; +import { genRandomChars } from '~/utilities/string'; +import { RunTypeOption } from '~/concepts/pipelines/content/createRun/types'; type FineTunePageFooterProps = { isInvalid: boolean; onSuccess: () => void; data: ModelCustomizationFormData; + ilabPipeline: PipelineKF | null; + ilabPipelineVersion: PipelineVersionKF | null; }; -// eslint-disable-next-line @typescript-eslint/no-unused-vars -- TODO remove this line when start using `data` -const FineTunePageFooter: React.FC = ({ isInvalid, onSuccess, data }) => { +type FineTunePageFooterSubmitPresetValues = { + teacherSecretName?: string; + judgeSecretName?: string; +}; + +const FineTunePageFooter: React.FC = ({ + isInvalid, + onSuccess, + data, + ilabPipeline, + ilabPipelineVersion, +}) => { + const [error, setError] = React.useState(); const [isSubmitting, setIsSubmitting] = React.useState(false); const { api } = usePipelinesAPI(); const { registerNotification } = React.useContext(NotificationWatcherContext); @@ -31,10 +63,119 @@ const FineTunePageFooter: React.FC = ({ isInvalid, onSu const contextPath = globalPipelineRunsRoute(namespace); // TODO: translate data to `RunFormData` - const [runFormData] = useRunFormData(null, {}); + const [runFormData] = useRunFormData(null, { + nameDesc: { + name: `lab-${genRandomChars()}`, + description: '', + }, + runType: { type: RunTypeOption.ONE_TRIGGER }, + pipeline: ilabPipeline, + version: ilabPipelineVersion, + }); + + const onSubmit = async (dryRun: boolean, presetValues?: FineTunePageFooterSubmitPresetValues) => { + const { teacherSecretName, judgeSecretName } = presetValues || {}; + const [teacherSecret, judgeSecret] = await createTeacherJudgeSecrets( + namespace, + data.teacher, + data.judge, + dryRun, + teacherSecretName, + judgeSecretName, + ); + const run = await handleSubmit( + { + ...runFormData, + params: { + ...runFormData.params, + ...translateIlabFormToTeacherJudge( + teacherSecret.metadata.name, + judgeSecret.metadata.name, + ), + }, + }, + api, + dryRun, + ); + return { run, teacherSecret, judgeSecret }; + }; + + const afterSubmit = (resource: PipelineRunKF | PipelineRecurringRunKF) => { + const runId = isRunSchedule(resource) ? resource.recurring_run_id : resource.run_id; + notification.info('InstructLab run started', `Run for ${resource.display_name} started`, [ + { + title: 'View run details', + onClick: () => { + navigate(`${contextPath}/${runId}`); + }, + }, + ]); + registerNotification({ + callback: (signal: AbortSignal) => + api + .getPipelineRun({ signal }, runId) + .then((response): NotificationWatcherResponse => { + if (response.state === RuntimeStateKF.SUCCEEDED) { + return { + status: 'success', + title: `${resource.display_name} successfully completed`, + message: `Your new model, ${resource.display_name}, is within the model registry`, + actions: [ + { + title: 'View in model registry', + onClick: () => { + // TODO: navigate to model registry + }, + }, + ], + }; + } + if (response.state === RuntimeStateKF.FAILED) { + return { + status: 'error', + title: `${resource.display_name} has failed`, + message: `Your run ${resource.display_name} has failed`, + actions: [ + { + title: 'View run details', + onClick: () => { + navigate(`${contextPath}/${runId}`); + }, + }, + ], + }; + } + if ( + response.state === RuntimeStateKF.RUNNING || + response.state === RuntimeStateKF.PENDING + ) { + return { status: 'repoll' }; + } + // Stop on any other state + return { status: 'stop' }; + }) + .catch((e) => { + // eslint-disable-next-line no-console + console.error('Error calling api.getPipelineRun', e); + return { status: 'stop' }; + }), + }); + }; + + const handleError = (e: Error) => { + setIsSubmitting(false); + setError(e); + }; return ( + {error && ( + + + {error.message} + + + )} @@ -43,86 +184,24 @@ const FineTunePageFooter: React.FC = ({ isInvalid, onSu data-testid="model-customization-submit-button" isDisabled={isInvalid || isSubmitting} onClick={() => { + setError(undefined); setIsSubmitting(true); - - handleSubmit(runFormData, api) - .then((resource) => { - const runId = isRunSchedule(resource) - ? resource.recurring_run_id - : resource.run_id; - - notification.info( - 'InstructLab run started', - `Run for ${resource.display_name} started`, - [ - { - title: 'View run details', - onClick: () => { - navigate(`${contextPath}/${runId}`); - }, - }, - ], - ); - - registerNotification({ - callback: (signal: AbortSignal) => - api - .getPipelineRun({ signal }, runId) - .then((response): NotificationWatcherResponse => { - if (response.state === RuntimeStateKF.SUCCEEDED) { - return { - status: 'success', - title: `${resource.display_name} successfully completed`, - message: `Your new model, ${resource.display_name}, is within the model registry`, - actions: [ - { - title: 'View in model registry', - onClick: () => { - // TODO: navigate to model registry - }, - }, - ], - }; - } - if (response.state === RuntimeStateKF.FAILED) { - return { - status: 'error', - title: `${resource.display_name} has failed`, - message: `Your run ${resource.display_name} has failed`, - actions: [ - { - title: 'View run details', - onClick: () => { - navigate(`${contextPath}/${runId}`); - }, - }, - ], - }; - } - if ( - response.state === RuntimeStateKF.RUNNING || - response.state === RuntimeStateKF.PENDING - ) { - return { status: 'repoll' }; - } - // Stop on any other state - return { status: 'stop' }; - }) - .catch((e) => { - // eslint-disable-next-line no-console - console.error('Error calling api.getPipelineRun', e); - return { status: 'stop' }; - }), - }); - - onSuccess(); - }) - .catch(() => { - // TODO: show error in the form - }) - .finally(() => { - setIsSubmitting(false); - }); + // dry-run network calls first + onSubmit(true) + .then(({ teacherSecret, judgeSecret }) => + // get the dry-run values and do the real network calls + onSubmit(false, { + teacherSecretName: teacherSecret.metadata.name, + judgeSecretName: judgeSecret.metadata.name, + }) + .then(({ run }) => { + afterSubmit(run); + setIsSubmitting(false); + onSuccess(); + }) + .catch(handleError), + ) + .catch(handleError); }} isLoading={isSubmitting} > diff --git a/frontend/src/pages/pipelines/global/modelCustomization/ModelCustomizationForm.tsx b/frontend/src/pages/pipelines/global/modelCustomization/ModelCustomizationForm.tsx index db8d39738f..108e891092 100644 --- a/frontend/src/pages/pipelines/global/modelCustomization/ModelCustomizationForm.tsx +++ b/frontend/src/pages/pipelines/global/modelCustomization/ModelCustomizationForm.tsx @@ -18,6 +18,7 @@ import { import { usePipelinesAPI } from '~/concepts/pipelines/context'; import { modelCustomizationRootPath } from '~/routes'; import { useIlabPipeline } from '~/concepts/pipelines/content/modelCustomizationForm/useIlabPipeline'; +import { ModelCustomizationEndpointType } from '~/concepts/pipelines/content/modelCustomizationForm/modelCustomizationFormSchema/types'; import FineTunePage from './FineTunePage'; import { BASE_MODEL_INPUT_STORAGE_LOCATION_URI_KEY, @@ -27,7 +28,12 @@ import { const ModelCustomizationForm: React.FC = () => { const { project } = usePipelinesAPI(); - const [ilabPipeline, ilabPipelineLoaded, ilabPipelineLoadError] = useIlabPipeline(); + const { + ilabPipeline, + ilabPipelineVersion, + loaded: ilabPipelineLoaded, + loadError: ilabPipelineLoadError, + } = useIlabPipeline(); const [searchParams] = useSearchParams(); @@ -40,6 +46,18 @@ const ModelCustomizationForm: React.FC = () => { version: 'myModel-v0.0.2', inputStorageLocationUri: searchParams.get(BASE_MODEL_INPUT_STORAGE_LOCATION_URI_KEY) ?? '', }, + teacher: { + endpointType: ModelCustomizationEndpointType.PUBLIC, + apiToken: '', + endpoint: '', + modelName: '', + }, + judge: { + endpointType: ModelCustomizationEndpointType.PUBLIC, + apiToken: '', + endpoint: '', + modelName: '', + }, }); const validation = useValidation(data, modelCustomizationFormSchema); @@ -73,12 +91,14 @@ const ModelCustomizationForm: React.FC = () => { onSuccess={() => navigate( `/pipelines/${encodeURIComponent(project.metadata.name)}/${encodeURIComponent( - ilabPipeline?.pipeline_id ?? '', - )}/${encodeURIComponent(ilabPipeline?.pipeline_version_id ?? '')}/view`, + ilabPipelineVersion?.pipeline_id ?? '', + )}/${encodeURIComponent(ilabPipelineVersion?.pipeline_version_id ?? '')}/view`, ) } data={data} setData={setData} + ilabPipeline={ilabPipeline} + ilabPipelineVersion={ilabPipelineVersion} /> diff --git a/frontend/src/pages/pipelines/global/modelCustomization/const.ts b/frontend/src/pages/pipelines/global/modelCustomization/const.ts index b7f58f9bdb..7b7ccfbe6c 100644 --- a/frontend/src/pages/pipelines/global/modelCustomization/const.ts +++ b/frontend/src/pages/pipelines/global/modelCustomization/const.ts @@ -1,11 +1,15 @@ export enum FineTunePageSections { PROJECT_DETAILS = 'fine-tune-section-project-details', BASE_MODEL = 'fine-tune-section-base-model', + TEACHER_MODEL = 'fine-tune-section-teacher-model', + JUDGE_MODEL = 'fine-tune-section-judge-model', } export const fineTunePageSectionTitles: Record = { [FineTunePageSections.PROJECT_DETAILS]: 'Project details', [FineTunePageSections.BASE_MODEL]: 'Base model', + [FineTunePageSections.TEACHER_MODEL]: 'Teacher model', + [FineTunePageSections.JUDGE_MODEL]: 'Judge model', }; export const ILAB_PIPELINE_NAME = 'instructlab'; diff --git a/frontend/src/pages/pipelines/global/modelCustomization/teacherJudgeSection/JudgeModelSection.tsx b/frontend/src/pages/pipelines/global/modelCustomization/teacherJudgeSection/JudgeModelSection.tsx new file mode 100644 index 0000000000..f91cb91ee5 --- /dev/null +++ b/frontend/src/pages/pipelines/global/modelCustomization/teacherJudgeSection/JudgeModelSection.tsx @@ -0,0 +1,95 @@ +import React from 'react'; +import { Button, FormGroup, FormSection, Radio } from '@patternfly/react-core'; +import { + FineTunePageSections, + fineTunePageSectionTitles, +} from '~/pages/pipelines/global/modelCustomization/const'; +import { ModelCustomizationEndpointType } from '~/concepts/pipelines/content/modelCustomizationForm/modelCustomizationFormSchema/types'; +import { + JudgeEndpointInput, + JudgeModelNameInput, + JudgeTokenInput, +} from '~/pages/pipelines/global/modelCustomization/teacherJudgeSection/TeacherJudgeInputComponents'; +import { TeacherJudgeFormData } from '~/concepts/pipelines/content/modelCustomizationForm/modelCustomizationFormSchema/validationUtils'; + +type JudgeModelSectionProps = { + data: TeacherJudgeFormData; + setData: (data: TeacherJudgeFormData) => void; +}; + +const JudgeModelSection: React.FC = ({ data, setData }) => ( + + {/* TODO: add link to judge model */} +
+ Select or create a connection to specify the judge model to deploy for use in model + evaluation.{' '} + +
+ + { + setData({ + ...data, + endpointType: ModelCustomizationEndpointType.PUBLIC, + }); + }} + body={ + data.endpointType === ModelCustomizationEndpointType.PUBLIC && ( + <> + setData({ ...data, endpoint: value })} + /> + setData({ ...data, modelName: value })} + /> + + ) + } + /> + { + setData({ + ...data, + endpointType: ModelCustomizationEndpointType.PRIVATE, + }); + }} + body={ + data.endpointType === ModelCustomizationEndpointType.PRIVATE && ( + <> + setData({ ...data, endpoint: value })} + /> + setData({ ...data, apiToken: value })} + /> + setData({ ...data, modelName: value })} + /> + + ) + } + /> + +
+); + +export default JudgeModelSection; diff --git a/frontend/src/pages/pipelines/global/modelCustomization/teacherJudgeSection/TeacherJudgeInputComponents.tsx b/frontend/src/pages/pipelines/global/modelCustomization/teacherJudgeSection/TeacherJudgeInputComponents.tsx new file mode 100644 index 0000000000..bffa6b78f5 --- /dev/null +++ b/frontend/src/pages/pipelines/global/modelCustomization/teacherJudgeSection/TeacherJudgeInputComponents.tsx @@ -0,0 +1,103 @@ +import React from 'react'; +import { FormGroup, TextInput } from '@patternfly/react-core'; +import PasswordInput from '~/components/PasswordInput'; + +type TeacherJudgeInputBaseProps = TeacherJudgeInputProps & { + label: string; + fieldId: string; + isPasswordType?: boolean; +}; + +type TeacherJudgeInputProps = { + value: string; + setValue: (value: string) => void; +}; + +export const TeacherEndpointInput: React.FC = ({ value, setValue }) => ( + +); + +export const TeacherModelNameInput: React.FC = ({ value, setValue }) => ( + +); + +export const TeacherTokenInput: React.FC = ({ value, setValue }) => ( + +); + +export const JudgeEndpointInput: React.FC = ({ value, setValue }) => ( + +); + +export const JudgeModelNameInput: React.FC = ({ value, setValue }) => ( + +); + +export const JudgeTokenInput: React.FC = ({ value, setValue }) => ( + +); + +const TeacherJudgeInputBase: React.FC = ({ + label, + fieldId, + value, + setValue, + isPasswordType, +}) => ( + <> + + {isPasswordType ? ( + setValue(newValue)} + /> + ) : ( + setValue(newValue)} + /> + )} + + +); diff --git a/frontend/src/pages/pipelines/global/modelCustomization/teacherJudgeSection/TeacherModelSection.tsx b/frontend/src/pages/pipelines/global/modelCustomization/teacherJudgeSection/TeacherModelSection.tsx new file mode 100644 index 0000000000..7854f78086 --- /dev/null +++ b/frontend/src/pages/pipelines/global/modelCustomization/teacherJudgeSection/TeacherModelSection.tsx @@ -0,0 +1,95 @@ +import React from 'react'; +import { Button, FormGroup, FormSection, Radio } from '@patternfly/react-core'; +import { + FineTunePageSections, + fineTunePageSectionTitles, +} from '~/pages/pipelines/global/modelCustomization/const'; +import { ModelCustomizationEndpointType } from '~/concepts/pipelines/content/modelCustomizationForm/modelCustomizationFormSchema/types'; +import { + TeacherEndpointInput, + TeacherModelNameInput, + TeacherTokenInput, +} from '~/pages/pipelines/global/modelCustomization/teacherJudgeSection/TeacherJudgeInputComponents'; +import { TeacherJudgeFormData } from '~/concepts/pipelines/content/modelCustomizationForm/modelCustomizationFormSchema/validationUtils'; + +type TeacherModelSectionProps = { + data: TeacherJudgeFormData; + setData: (data: TeacherJudgeFormData) => void; +}; + +const TeacherModelSection: React.FC = ({ data, setData }) => ( + + {/* TODO: add link to teacher model */} +
+ Select or create a connection to specify the teacher model to deploy for use in synthetic data + generation (SDG).{' '} + +
+ + { + setData({ + ...data, + endpointType: ModelCustomizationEndpointType.PUBLIC, + }); + }} + body={ + data.endpointType === ModelCustomizationEndpointType.PUBLIC && ( + <> + setData({ ...data, endpoint: value })} + /> + setData({ ...data, modelName: value })} + /> + + ) + } + /> + { + setData({ + ...data, + endpointType: ModelCustomizationEndpointType.PRIVATE, + }); + }} + body={ + data.endpointType === ModelCustomizationEndpointType.PRIVATE && ( + <> + setData({ ...data, endpoint: value })} + /> + setData({ ...data, apiToken: value })} + /> + setData({ ...data, modelName: value })} + /> + + ) + } + /> + +
+); + +export default TeacherModelSection; diff --git a/frontend/src/pages/pipelines/global/modelCustomization/utils.ts b/frontend/src/pages/pipelines/global/modelCustomization/utils.ts new file mode 100644 index 0000000000..e4f36cb0bc --- /dev/null +++ b/frontend/src/pages/pipelines/global/modelCustomization/utils.ts @@ -0,0 +1,62 @@ +import { assembleSecretJudge, assembleSecretTeacher, createSecret } from '~/api'; +import { ModelCustomizationEndpointType } from '~/concepts/pipelines/content/modelCustomizationForm/modelCustomizationFormSchema/types'; +import { TeacherJudgeFormData } from '~/concepts/pipelines/content/modelCustomizationForm/modelCustomizationFormSchema/validationUtils'; +import { SecretKind } from '~/k8sTypes'; + +export const createTeacherJudgeSecrets = ( + projectName: string, + teacherData: TeacherJudgeFormData, + judgeData: TeacherJudgeFormData, + dryRun: boolean, + teacherSecretName?: string, + judgeSecretName?: string, +): Promise => + Promise.all([ + createSecret( + assembleSecretTeacher( + projectName, + { + /* eslint-disable camelcase */ + api_token: + teacherData.endpointType === ModelCustomizationEndpointType.PRIVATE + ? teacherData.apiToken.trim() + : '', + endpoint: teacherData.endpoint.trim(), + model_name: teacherData.modelName.trim(), + /* eslint-enable camelcase */ + }, + teacherSecretName, + ), + { dryRun }, + ), + createSecret( + assembleSecretJudge( + projectName, + { + /* eslint-disable camelcase */ + api_token: + judgeData.endpointType === ModelCustomizationEndpointType.PRIVATE + ? judgeData.apiToken.trim() + : '', + endpoint: judgeData.endpoint.trim(), + model_name: judgeData.modelName.trim(), + /* eslint-enable camelcase */ + }, + judgeSecretName, + ), + { dryRun }, + ), + ]); + +export const translateIlabFormToTeacherJudge = ( + teacherSecretName: string, + judgeSecretName: string, +): { + teacher_secret: string; + judge_secret: string; +} => ({ + /* eslint-disable camelcase */ + teacher_secret: teacherSecretName, + judge_secret: judgeSecretName, + /* eslint-enable camelcase */ +});