From 8461369ebb94c0abeff02c6f05d865373be1f1d1 Mon Sep 17 00:00:00 2001 From: Tyler Ohlsen Date: Fri, 30 Aug 2024 09:46:53 -0700 Subject: [PATCH] Prefill dimension if known remote model found; onboard search connectors API (#330) * onboard search connectors api; add auto-fetching of dimensions Signed-off-by: Tyler Ohlsen * Reset to undefined if no values found Signed-off-by: Tyler Ohlsen * add comment Signed-off-by: Tyler Ohlsen * minor tuning of select autofill in maps Signed-off-by: Tyler Ohlsen --------- Signed-off-by: Tyler Ohlsen --- common/constants.ts | 27 ++++++++ common/interfaces.ts | 16 +++++ .../input_fields/map_field.tsx | 7 ++ .../input_fields/model_field.tsx | 2 +- .../select_with_custom_options.tsx | 16 +++-- .../processor_inputs/ml_processor_inputs.tsx | 2 +- .../workflow_inputs/workflow_inputs.tsx | 11 ---- .../workflows/new_workflow/new_workflow.tsx | 4 +- .../new_workflow/quick_configure_inputs.tsx | 45 ++++++++++++- public/route_service.ts | 19 ++++++ public/store/reducers/index.ts | 2 +- .../{models_reducer.ts => ml_reducer.ts} | 49 ++++++++++++-- public/store/store.ts | 4 +- server/cluster/ml_plugin.ts | 13 +++- server/routes/helpers.ts | 24 +++++++ server/routes/ml_routes_service.ts | 65 ++++++++++++++++++- 16 files changed, 272 insertions(+), 34 deletions(-) rename public/store/reducers/{models_reducer.ts => ml_reducer.ts} (52%) diff --git a/common/constants.ts b/common/constants.ts index 8167231f..7027e3ce 100644 --- a/common/constants.ts +++ b/common/constants.ts @@ -22,7 +22,9 @@ export const FLOW_FRAMEWORK_SEARCH_WORKFLOW_STATE_ROUTE = `${FLOW_FRAMEWORK_WORK */ export const ML_API_ROUTE_PREFIX = '/_plugins/_ml'; export const ML_MODEL_ROUTE_PREFIX = `${ML_API_ROUTE_PREFIX}/models`; +export const ML_CONNECTOR_ROUTE_PREFIX = `${ML_API_ROUTE_PREFIX}/connectors`; export const ML_SEARCH_MODELS_ROUTE = `${ML_MODEL_ROUTE_PREFIX}/_search`; +export const ML_SEARCH_CONNECTORS_ROUTE = `${ML_CONNECTOR_ROUTE_PREFIX}/_search`; /** * NODE APIs @@ -51,7 +53,32 @@ export const GET_PRESET_WORKFLOWS_NODE_API_PATH = `${BASE_WORKFLOW_NODE_API_PATH // ML Plugin node APIs export const BASE_MODEL_NODE_API_PATH = `${BASE_NODE_API_PATH}/model`; +export const BASE_CONNECTOR_NODE_API_PATH = `${BASE_NODE_API_PATH}/connector`; export const SEARCH_MODELS_NODE_API_PATH = `${BASE_MODEL_NODE_API_PATH}/search`; +export const SEARCH_CONNECTORS_NODE_API_PATH = `${BASE_CONNECTOR_NODE_API_PATH}/search`; + +/** + * Remote model dimensions. Used for attempting to pre-fill dimension size + * based on the specified remote model from a remote service, if found + */ + +// Cohere +export const COHERE_DIMENSIONS = { + [`embed-english-v3.0`]: 1024, + [`embed-english-light-v3.0`]: 384, + [`embed-multilingual-v3.0`]: 1024, + [`embed-multilingual-light-v3.0`]: 384, + [`embed-english-v2.0`]: 4096, + [`embed-english-light-v2.0`]: 1024, + [`embed-multilingual-v2.0`]: 768, +}; + +// OpenAI +export const OPENAI_DIMENSIONS = { + [`text-embedding-3-small`]: 1536, + [`text-embedding-3-large`]: 3072, + [`text-embedding-ada-002`]: 1536, +}; /** * Various constants pertaining to Workflow configs diff --git a/common/interfaces.ts b/common/interfaces.ts index ee247f47..f674e5c9 100644 --- a/common/interfaces.ts +++ b/common/interfaces.ts @@ -393,6 +393,11 @@ export type ModelInterface = { output: { [key: string]: ModelOutput }; }; +export type ConnectorParameters = { + model?: string; + dimensions?: number; +}; + export type Model = { id: string; name: string; @@ -400,12 +405,23 @@ export type Model = { state: MODEL_STATE; modelConfig?: ModelConfig; interface?: ModelInterface; + connectorId?: string; +}; + +export type Connector = { + id: string; + name: string; + parameters?: ConnectorParameters; }; export type ModelDict = { [modelId: string]: Model; }; +export type ConnectorDict = { + [connectorId: string]: Connector; +}; + export type ModelFormValue = { id: string; algorithm?: MODEL_ALGORITHM; diff --git a/public/pages/workflow_detail/workflow_inputs/input_fields/map_field.tsx b/public/pages/workflow_detail/workflow_inputs/input_fields/map_field.tsx index 1cdc3aef..058d51e5 100644 --- a/public/pages/workflow_detail/workflow_inputs/input_fields/map_field.tsx +++ b/public/pages/workflow_detail/workflow_inputs/input_fields/map_field.tsx @@ -108,6 +108,9 @@ export function MapField(props: MapFieldProps) { fieldPath={`${props.fieldPath}.${idx}.key`} options={props.keyOptions as any[]} placeholder={props.keyPlaceholder || 'Input'} + autofill={ + props.keyOptions?.length === 1 && idx === 0 + } /> ) : ( ) : ( page. We don't // re-fetch here as it could overload client-side if user clicks back and forth / // keeps re-rendering this component (and subsequently re-fetching data) as they're building flows - const models = useSelector((state: AppState) => state.models.models); + const models = useSelector((state: AppState) => state.ml.models); const { errors, touched } = useFormikContext(); diff --git a/public/pages/workflow_detail/workflow_inputs/input_fields/select_with_custom_options.tsx b/public/pages/workflow_detail/workflow_inputs/input_fields/select_with_custom_options.tsx index e4b65aa5..b4ff7c09 100644 --- a/public/pages/workflow_detail/workflow_inputs/input_fields/select_with_custom_options.tsx +++ b/public/pages/workflow_detail/workflow_inputs/input_fields/select_with_custom_options.tsx @@ -13,6 +13,7 @@ interface SelectWithCustomOptionsProps { fieldPath: string; placeholder: string; options: any[]; + autofill: boolean; } /** @@ -30,13 +31,14 @@ export function SelectWithCustomOptions(props: SelectWithCustomOptionsProps) { // default to the top option. by default, this will re-trigger this hook with a populated // value, to then finally update the displayed option. useEffect(() => { - const formValue = getIn(values, props.fieldPath); - if (!isEmpty(formValue)) { - setSelectedOption([{ label: getIn(values, props.fieldPath) }]); - } else { - if (props.options.length > 0) { - setFieldTouched(props.fieldPath, true); - setFieldValue(props.fieldPath, props.options[0].label); + if (props.autofill) { + const formValue = getIn(values, props.fieldPath); + if (!isEmpty(formValue)) { + setSelectedOption([{ label: getIn(values, props.fieldPath) }]); + } else { + if (props.options.length > 0) { + setFieldValue(props.fieldPath, props.options[0].label); + } } } }, [getIn(values, props.fieldPath)]); diff --git a/public/pages/workflow_detail/workflow_inputs/processor_inputs/ml_processor_inputs.tsx b/public/pages/workflow_detail/workflow_inputs/processor_inputs/ml_processor_inputs.tsx index 52ce7017..9a4dc9f6 100644 --- a/public/pages/workflow_detail/workflow_inputs/processor_inputs/ml_processor_inputs.tsx +++ b/public/pages/workflow_detail/workflow_inputs/processor_inputs/ml_processor_inputs.tsx @@ -53,7 +53,7 @@ interface MLProcessorInputsProps { * output map configuration forms, respectively. */ export function MLProcessorInputs(props: MLProcessorInputsProps) { - const models = useSelector((state: AppState) => state.models.models); + const models = useSelector((state: AppState) => state.ml.models); const { values, setFieldValue, setFieldTouched } = useFormikContext< WorkflowFormValues >(); diff --git a/public/pages/workflow_detail/workflow_inputs/workflow_inputs.tsx b/public/pages/workflow_detail/workflow_inputs/workflow_inputs.tsx index 2279ccb9..6b074f31 100644 --- a/public/pages/workflow_detail/workflow_inputs/workflow_inputs.tsx +++ b/public/pages/workflow_detail/workflow_inputs/workflow_inputs.tsx @@ -9,7 +9,6 @@ import { debounce, isEmpty, isEqual } from 'lodash'; import { EuiButton, EuiButtonEmpty, - EuiCallOut, EuiFlexGroup, EuiFlexItem, EuiHorizontalRule, @@ -100,7 +99,6 @@ export function WorkflowInputs(props: WorkflowInputsProps) { setFieldValue, values, touched, - dirty, } = useFormikContext(); const dispatch = useAppDispatch(); const dataSourceId = getDataSourceId(); @@ -650,15 +648,6 @@ export function WorkflowInputs(props: WorkflowInputsProps) { )} - {onIngest && - dirty && - hasProvisionedSearchResources(props.workflow) && ( - - )} {onIngestAndUnprovisioned && ( <> diff --git a/public/pages/workflows/new_workflow/new_workflow.tsx b/public/pages/workflows/new_workflow/new_workflow.tsx index 2ae8208f..f0c2a537 100644 --- a/public/pages/workflows/new_workflow/new_workflow.tsx +++ b/public/pages/workflows/new_workflow/new_workflow.tsx @@ -24,6 +24,7 @@ import { useAppDispatch, getWorkflowPresets, searchModels, + searchConnectors, } from '../../../store'; import { enrichPresetWorkflowWithUiMetadata } from './utils'; import { getDataSourceId } from '../../../utils'; @@ -56,11 +57,12 @@ export function NewWorkflow(props: NewWorkflowProps) { // on initial load: // 1. fetch the workflow presets persisted on server-side - // 2. fetch the ML models. these may be used in quick-create views when selecting a preset, + // 2. fetch the ML models and connectors. these may be used in quick-create views when selecting a preset, // so we optimize by fetching once at the top-level here. useEffect(() => { dispatch(getWorkflowPresets()); dispatch(searchModels({ apiBody: FETCH_ALL_QUERY, dataSourceId })); + dispatch(searchConnectors({ apiBody: FETCH_ALL_QUERY, dataSourceId })); }, []); // initial hook to populate all workflows diff --git a/public/pages/workflows/new_workflow/quick_configure_inputs.tsx b/public/pages/workflows/new_workflow/quick_configure_inputs.tsx index 0bd09472..1ce8fd68 100644 --- a/public/pages/workflows/new_workflow/quick_configure_inputs.tsx +++ b/public/pages/workflows/new_workflow/quick_configure_inputs.tsx @@ -16,8 +16,10 @@ import { EuiCompressedFieldNumber, } from '@elastic/eui'; import { + COHERE_DIMENSIONS, MODEL_STATE, Model, + OPENAI_DIMENSIONS, QuickConfigureFields, WORKFLOW_TYPE, } from '../../../../common'; @@ -35,7 +37,7 @@ const DEFAULT_IMAGE_FIELD = 'my_image'; // Dynamic component to allow optional input configuration fields for different use cases. // Hooks back to the parent component with such field values export function QuickConfigureInputs(props: QuickConfigureInputsProps) { - const models = useSelector((state: AppState) => state.models.models); + const { models, connectors } = useSelector((state: AppState) => state.ml); // Deployed models state const [deployedModels, setDeployedModels] = useState([]); @@ -88,6 +90,45 @@ export function QuickConfigureInputs(props: QuickConfigureInputsProps) { props.setFields(fieldValues); }, [fieldValues]); + // Try to pre-fill the dimensions based on the chosen model + useEffect(() => { + const selectedModel = deployedModels.find( + (model) => model.id === fieldValues.embeddingModelId + ); + if (selectedModel?.connectorId !== undefined) { + const connector = connectors[selectedModel.connectorId]; + if (connector !== undefined) { + // some APIs allow specifically setting the dimensions at runtime, + // so we check for that first. + if (connector.parameters?.dimensions !== undefined) { + setFieldValues({ + ...fieldValues, + embeddingLength: connector.parameters?.dimensions, + }); + } else if (connector.parameters?.model !== undefined) { + const dimensions = + // @ts-ignore + COHERE_DIMENSIONS[connector.parameters?.model] || + // @ts-ignore + (OPENAI_DIMENSIONS[connector.parameters?.model] as + | number + | undefined); + if (dimensions !== undefined) { + setFieldValues({ + ...fieldValues, + embeddingLength: dimensions, + }); + } + } else { + setFieldValues({ + ...fieldValues, + embeddingLength: undefined, + }); + } + } + } + }, [fieldValues.embeddingModelId, deployedModels, connectors]); + return ( <> {(props.workflowType === WORKFLOW_TYPE.SEMANTIC_SEARCH || @@ -196,7 +237,7 @@ export function QuickConfigureInputs(props: QuickConfigureInputsProps) { Promise; + searchConnectors: ( + body: {}, + dataSourceId?: string + ) => Promise; + simulatePipeline: ( body: { pipeline: IngestPipelineConfig; @@ -342,6 +348,19 @@ export function configureRoutes(core: CoreStart): RouteService { return e as HttpFetchError; } }, + searchConnectors: async (body: {}, dataSourceId?: string) => { + try { + const url = dataSourceId + ? `${BASE_NODE_API_PATH}/${dataSourceId}/connector/search` + : SEARCH_CONNECTORS_NODE_API_PATH; + const response = await core.http.post<{ respString: string }>(url, { + body: JSON.stringify(body), + }); + return response; + } catch (e: any) { + return e as HttpFetchError; + } + }, simulatePipeline: async ( body: { pipeline: IngestPipelineConfig; diff --git a/public/store/reducers/index.ts b/public/store/reducers/index.ts index 9aeff271..d28b3193 100644 --- a/public/store/reducers/index.ts +++ b/public/store/reducers/index.ts @@ -6,4 +6,4 @@ export * from './opensearch_reducer'; export * from './workflows_reducer'; export * from './presets_reducer'; -export * from './models_reducer'; +export * from './ml_reducer'; diff --git a/public/store/reducers/models_reducer.ts b/public/store/reducers/ml_reducer.ts similarity index 52% rename from public/store/reducers/models_reducer.ts rename to public/store/reducers/ml_reducer.ts index 600cbf9d..53791869 100644 --- a/public/store/reducers/models_reducer.ts +++ b/public/store/reducers/ml_reducer.ts @@ -4,7 +4,7 @@ */ import { createAsyncThunk, createSlice } from '@reduxjs/toolkit'; -import { ModelDict } from '../../../common'; +import { ConnectorDict, ModelDict } from '../../../common'; import { HttpFetchError } from '../../../../../src/core/public'; import { getRouteService } from '../../services'; @@ -12,10 +12,13 @@ const initialState = { loading: false, errorMessage: '', models: {} as ModelDict, + connectors: {} as ConnectorDict, }; const MODELS_ACTION_PREFIX = 'models'; -const SEARCH_MODELS_ACTION = `${MODELS_ACTION_PREFIX}/searchModels`; +const CONNECTORS_ACTION_PREFIX = 'connectors'; +const SEARCH_MODELS_ACTION = `${MODELS_ACTION_PREFIX}/search`; +const SEARCH_CONNECTORS_ACTION = `${CONNECTORS_ACTION_PREFIX}/search`; export const searchModels = createAsyncThunk( SEARCH_MODELS_ACTION, @@ -37,8 +40,30 @@ export const searchModels = createAsyncThunk( } ); -const modelsSlice = createSlice({ - name: 'models', +export const searchConnectors = createAsyncThunk( + SEARCH_CONNECTORS_ACTION, + async ( + { apiBody, dataSourceId }: { apiBody: {}; dataSourceId?: string }, + { rejectWithValue } + ) => { + const response: + | any + | HttpFetchError = await getRouteService().searchConnectors( + apiBody, + dataSourceId + ); + if (response instanceof HttpFetchError) { + return rejectWithValue( + 'Error searching connectors: ' + response.body.message + ); + } else { + return response; + } + } +); + +const mlSlice = createSlice({ + name: 'ml', initialState, reducers: {}, extraReducers: (builder) => { @@ -48,6 +73,10 @@ const modelsSlice = createSlice({ state.loading = true; state.errorMessage = ''; }) + .addCase(searchConnectors.pending, (state, action) => { + state.loading = true; + state.errorMessage = ''; + }) // Fulfilled states .addCase(searchModels.fulfilled, (state, action) => { const { models } = action.payload as { models: ModelDict }; @@ -55,12 +84,22 @@ const modelsSlice = createSlice({ state.loading = false; state.errorMessage = ''; }) + .addCase(searchConnectors.fulfilled, (state, action) => { + const { connectors } = action.payload as { connectors: ConnectorDict }; + state.connectors = connectors; + state.loading = false; + state.errorMessage = ''; + }) // Rejected states .addCase(searchModels.rejected, (state, action) => { state.errorMessage = action.payload as string; state.loading = false; + }) + .addCase(searchConnectors.rejected, (state, action) => { + state.errorMessage = action.payload as string; + state.loading = false; }); }, }); -export const modelsReducer = modelsSlice.reducer; +export const mlReducer = mlSlice.reducer; diff --git a/public/store/store.ts b/public/store/store.ts index 3d0c4214..c3a61818 100644 --- a/public/store/store.ts +++ b/public/store/store.ts @@ -10,13 +10,13 @@ import { opensearchReducer, workflowsReducer, presetsReducer, - modelsReducer, + mlReducer, } from './reducers'; const rootReducer = combineReducers({ workflows: workflowsReducer, presets: presetsReducer, - models: modelsReducer, + ml: mlReducer, opensearch: opensearchReducer, }); diff --git a/server/cluster/ml_plugin.ts b/server/cluster/ml_plugin.ts index 4601a9dd..2d13e749 100644 --- a/server/cluster/ml_plugin.ts +++ b/server/cluster/ml_plugin.ts @@ -3,7 +3,10 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { ML_SEARCH_MODELS_ROUTE } from '../../common'; +import { + ML_SEARCH_CONNECTORS_ROUTE, + ML_SEARCH_MODELS_ROUTE, +} from '../../common'; /** * Used during the plugin's setup() lifecycle phase to register various client actions @@ -24,4 +27,12 @@ export function mlPlugin(Client: any, config: any, components: any) { needBody: true, method: 'POST', }); + + mlClient.searchConnectors = ca({ + url: { + fmt: ML_SEARCH_CONNECTORS_ROUTE, + }, + needBody: true, + method: 'POST', + }); } diff --git a/server/routes/helpers.ts b/server/routes/helpers.ts index 0c5735ba..b66c64dd 100644 --- a/server/routes/helpers.ts +++ b/server/routes/helpers.ts @@ -4,6 +4,8 @@ */ import { + Connector, + ConnectorDict, DEFAULT_NEW_WORKFLOW_STATE_TYPE, INDEX_NOT_FOUND_EXCEPTION, MODEL_ALGORITHM, @@ -126,12 +128,34 @@ export function getModelsFromResponses(modelHits: SearchHit[]): ModelDict { modelHit._source?.model_config?.embedding_dimension, }, interface: modelInterface, + connectorId: modelHit._source?.connector_id, } as Model; } }); return modelDict; } +export function getConnectorsFromResponses( + modelHits: SearchHit[] +): ConnectorDict { + const connectorDict = {} as ConnectorDict; + modelHits.forEach((connectorHit: SearchHit) => { + const connectorId = connectorHit._id; + + // in case of schema changes from ML plugin, this may crash. That is ok, as the error + // produced will help expose the root cause + connectorDict[connectorId] = { + id: connectorId, + name: connectorHit._source?.name, + parameters: { + model: connectorHit._source?.parameters?.model, + dimensions: connectorHit._source?.parameters.dimensions, + }, + } as Connector; + }); + return connectorDict; +} + // Convert the workflow state into a readable/presentable state on frontend export function getWorkflowStateFromResponse( state: typeof WORKFLOW_STATE | undefined diff --git a/server/routes/ml_routes_service.ts b/server/routes/ml_routes_service.ts index 9f6e6639..c1f7bc5b 100644 --- a/server/routes/ml_routes_service.ts +++ b/server/routes/ml_routes_service.ts @@ -11,8 +11,17 @@ import { OpenSearchDashboardsRequest, OpenSearchDashboardsResponseFactory, } from '../../../../src/core/server'; -import { SEARCH_MODELS_NODE_API_PATH, BASE_NODE_API_PATH, SearchHit } from '../../common'; -import { generateCustomError, getModelsFromResponses } from './helpers'; +import { + SEARCH_MODELS_NODE_API_PATH, + BASE_NODE_API_PATH, + SearchHit, + SEARCH_CONNECTORS_NODE_API_PATH, +} from '../../common'; +import { + generateCustomError, + getConnectorsFromResponses, + getModelsFromResponses, +} from './helpers'; import { getClientBasedOnDataSource } from '../utils/helpers'; /** @@ -44,6 +53,27 @@ export function registerMLRoutes( }, mlRoutesService.searchModels ); + router.post( + { + path: SEARCH_CONNECTORS_NODE_API_PATH, + validate: { + body: schema.any(), + }, + }, + mlRoutesService.searchConnectors + ); + router.post( + { + path: `${BASE_NODE_API_PATH}/{data_source_id}/connector/search`, + validate: { + body: schema.any(), + params: schema.object({ + data_source_id: schema.string(), + }), + }, + }, + mlRoutesService.searchConnectors + ); } export class MLRoutesService { @@ -82,4 +112,35 @@ export class MLRoutesService { return generateCustomError(res, err); } }; + + searchConnectors = async ( + context: RequestHandlerContext, + req: OpenSearchDashboardsRequest, + res: OpenSearchDashboardsResponseFactory + ): Promise> => { + const body = req.body; + try { + const { data_source_id = '' } = req.params as { data_source_id?: string }; + const callWithRequest = getClientBasedOnDataSource( + context, + this.dataSourceEnabled, + req, + data_source_id, + this.client + ); + const connectorsResponse = await callWithRequest( + 'mlClient.searchConnectors', + { + body, + } + ); + + const connectorHits = connectorsResponse.hits.hits as SearchHit[]; + const connectorDict = getConnectorsFromResponses(connectorHits); + + return res.ok({ body: { connectors: connectorDict } }); + } catch (err: any) { + return generateCustomError(res, err); + } + }; }