Skip to content

Commit

Permalink
Add option to select and provision pretrained text embedding models (#…
Browse files Browse the repository at this point in the history
…137) (#138)

Signed-off-by: Tyler Ohlsen <[email protected]>
(cherry picked from commit 76584f4)

Co-authored-by: Tyler Ohlsen <[email protected]>
  • Loading branch information
opensearch-trigger-bot[bot] and ohltyler authored Apr 18, 2024
1 parent e3ec37f commit e805ea7
Show file tree
Hide file tree
Showing 18 changed files with 527 additions and 65 deletions.
43 changes: 42 additions & 1 deletion common/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,12 @@
* SPDX-License-Identifier: Apache-2.0
*/

import { TemplateNode, WORKFLOW_STATE } from './interfaces';
import {
MODEL_ALGORITHM,
PRETRAINED_MODEL_FORMAT,
PretrainedSentenceTransformer,
WORKFLOW_STATE,
} from './interfaces';

export const PLUGIN_ID = 'flow-framework';

Expand Down Expand Up @@ -52,6 +57,42 @@ export const SEARCH_MODELS_NODE_API_PATH = `${BASE_MODEL_NODE_API_PATH}/search`;
*/
export const CREATE_INGEST_PIPELINE_STEP_TYPE = 'create_ingest_pipeline';
export const CREATE_INDEX_STEP_TYPE = 'create_index';
export const REGISTER_LOCAL_PRETRAINED_MODEL_STEP_TYPE =
'register_local_pretrained_model';

/**
* ML PLUGIN PRETRAINED MODELS
* (based off of https://opensearch.org/docs/latest/ml-commons-plugin/pretrained-models/#sentence-transformers)
*/
export const ROBERTA_SENTENCE_TRANSFORMER = {
name: 'huggingface/sentence-transformers/all-distilroberta-v1',
shortenedName: 'all-distilroberta-v1',
description: 'A sentence transformer from Hugging Face',
format: PRETRAINED_MODEL_FORMAT.TORCH_SCRIPT,
algorithm: MODEL_ALGORITHM.TEXT_EMBEDDING,
version: '1.0.1',
vectorDimensions: 768,
} as PretrainedSentenceTransformer;

export const MPNET_SENTENCE_TRANSFORMER = {
name: 'huggingface/sentence-transformers/all-mpnet-base-v2',
shortenedName: 'all-mpnet-base-v2',
description: 'A sentence transformer from Hugging Face',
format: PRETRAINED_MODEL_FORMAT.TORCH_SCRIPT,
algorithm: MODEL_ALGORITHM.TEXT_EMBEDDING,
version: '1.0.1',
vectorDimensions: 768,
} as PretrainedSentenceTransformer;

export const BERT_SENTENCE_TRANSFORMER = {
name: 'huggingface/sentence-transformers/msmarco-distilbert-base-tas-b',
shortenedName: 'msmarco-distilbert-base-tas-b',
description: 'A sentence transformer from Hugging Face',
format: PRETRAINED_MODEL_FORMAT.TORCH_SCRIPT,
algorithm: MODEL_ALGORITHM.TEXT_EMBEDDING,
version: '1.0.2',
vectorDimensions: 768,
} as PretrainedSentenceTransformer;

/**
* MISCELLANEOUS
Expand Down
90 changes: 85 additions & 5 deletions common/interfaces.ts
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,16 @@ export type CreateIndexNode = TemplateNode & {
};
};

export type RegisterPretrainedModelNode = TemplateNode & {
user_inputs: {
name: string;
description: string;
model_format: string;
version: string;
deploy: boolean;
};
};

export type TemplateEdge = {
source: string;
dest: string;
Expand Down Expand Up @@ -130,9 +140,83 @@ export enum USE_CASE {
/**
********** ML PLUGIN TYPES/INTERFACES **********
*/

// Based off of https://github.com/opensearch-project/ml-commons/blob/main/common/src/main/java/org/opensearch/ml/common/model/MLModelState.java
export enum MODEL_STATE {
REGISTERED = 'Registered',
REGISTERING = 'Registering',
DEPLOYING = 'Deploying',
DEPLOYED = 'Deployed',
PARTIALLY_DEPLOYED = 'Partially deployed',
UNDEPLOYED = 'Undeployed',
DEPLOY_FAILED = 'Deploy failed',
}

// Based off of https://github.com/opensearch-project/ml-commons/blob/main/common/src/main/java/org/opensearch/ml/common/FunctionName.java
export enum MODEL_ALGORITHM {
LINEAR_REGRESSION = 'Linear regression',
KMEANS = 'K-means',
AD_LIBSVM = 'AD LIBSVM',
SAMPLE_ALGO = 'Sample algorithm',
LOCAL_SAMPLE_CALCULATOR = 'Local sample calculator',
FIT_RCF = 'Fit RCF',
BATCH_RCF = 'Batch RCF',
ANOMALY_LOCALIZATION = 'Anomaly localization',
RCF_SUMMARIZE = 'RCF summarize',
LOGISTIC_REGRESSION = 'Logistic regression',
TEXT_EMBEDDING = 'Text embedding',
METRICS_CORRELATION = 'Metrics correlation',
REMOTE = 'Remote',
SPARSE_ENCODING = 'Sparse encoding',
SPARSE_TOKENIZE = 'Sparse tokenize',
TEXT_SIMILARITY = 'Text similarity',
QUESTION_ANSWERING = 'Question answering',
AGENT = 'Agent',
}

export enum MODEL_CATEGORY {
DEPLOYED = 'Deployed',
PRETRAINED = 'Pretrained',
}

export enum PRETRAINED_MODEL_FORMAT {
TORCH_SCRIPT = 'TORCH_SCRIPT',
}

export type PretrainedModel = {
name: string;
shortenedName: string;
description: string;
format: PRETRAINED_MODEL_FORMAT;
algorithm: MODEL_ALGORITHM;
version: string;
};

export type PretrainedSentenceTransformer = PretrainedModel & {
vectorDimensions: number;
};

export type ModelConfig = {
modelType?: string;
embeddingDimension?: number;
};

export type Model = {
id: string;
algorithm: string;
name: string;
algorithm: MODEL_ALGORITHM;
state: MODEL_STATE;
modelConfig?: ModelConfig;
};

export type ModelDict = {
[modelId: string]: Model;
};

export type ModelFormValue = {
id: string;
category?: MODEL_CATEGORY;
algorithm?: MODEL_ALGORITHM;
};

/**
Expand Down Expand Up @@ -171,7 +255,3 @@ export enum WORKFLOW_RESOURCE_TYPE {
export type WorkflowDict = {
[workflowId: string]: Workflow;
};

export type ModelDict = {
[modelId: string]: Model;
};
17 changes: 13 additions & 4 deletions public/app.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,21 @@ export const FlowFrameworkDashboardsApp = (props: Props) => {
<Workflows {...routeProps} />
)}
/>
{/* Defaulting to Workflows page */}
{/*
Defaulting to Workflows page. The pathname will need to be updated
to handle the redirection and get the router props consistent.
*/}
<Route
path={`${APP_PATH.HOME}`}
render={(routeProps: RouteComponentProps<WorkflowsRouterProps>) => (
<Workflows {...routeProps} />
)}
render={(routeProps: RouteComponentProps<WorkflowsRouterProps>) => {
if (props.history.location.pathname !== APP_PATH.WORKFLOWS) {
props.history.replace({
...history,
pathname: APP_PATH.WORKFLOWS,
});
}
return <Workflows {...routeProps} />;
}}
/>
</Switch>
</EuiPageTemplate>
Expand Down
2 changes: 1 addition & 1 deletion public/component_types/interfaces.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import { COMPONENT_CATEGORY, COMPONENT_CLASS } from '../utils';
/**
* ************ Types *************************
*/
export type FieldType = 'string' | 'json' | 'select';
export type FieldType = 'string' | 'json' | 'select' | 'model';
export type SelectType = 'model';
export type FieldValue = string | {};
export type ComponentFormValues = FormikValues;
Expand Down
10 changes: 4 additions & 6 deletions public/component_types/transformer/text_embedding_transformer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,10 @@ export class TextEmbeddingTransformer extends MLTransformer {
this.inputs = [];
this.createFields = [
{
label: 'Model ID',
id: 'modelId',
type: 'select',
selectType: 'model',
helpText: 'The deployed text embedding model to use for embedding.',
label: 'Text Embedding Model',
id: 'model',
type: 'model',
helpText: 'A text embedding model for embedding text.',
helpLink:
'https://opensearch.org/docs/latest/ml-commons-plugin/integrating-ml-models/#choosing-a-model',
},
Expand All @@ -36,7 +35,6 @@ export class TextEmbeddingTransformer extends MLTransformer {
helpLink:
'https://opensearch.org/docs/latest/ingest-pipelines/processors/text-embedding/',
},

{
label: 'Vector Field',
id: 'vectorField',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,12 @@ interface ComponentDetailsProps {
export function ComponentDetails(props: ComponentDetailsProps) {
return (
<EuiPanel paddingSize="m">
{props.isDeprovisionable ? (
{/* TODO: determine if we need this view if we want the workspace to remain
readonly once provisioned */}
{/* {props.isDeprovisionable ? (
<ProvisionedComponentInputs />
) : props.selectedComponent ? (
) : */}
{props.selectedComponent ? (
<ComponentInputs
selectedComponent={props.selectedComponent}
onFormChange={props.onFormChange}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ export function ComponentInputs(props: ComponentInputsProps) {
<EuiTitle size="m">
<h2>{props.selectedComponent.data.label || ''}</h2>
</EuiTitle>
<EuiText color="subdued">
{props.selectedComponent.data.description}
</EuiText>
<NewOrExistingTabs
selectedTabId={selectedTabId}
setSelectedTabId={setSelectedTabId}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import React from 'react';
import { EuiFlexItem, EuiSpacer } from '@elastic/eui';
import { TextField, JsonField, SelectField } from './input_fields';
import { TextField, JsonField, SelectField, ModelField } from './input_fields';
import { IComponentField } from '../../../../common';

/**
Expand Down Expand Up @@ -54,6 +54,19 @@ export function InputFieldList(props: InputFieldListProps) {
);
break;
}
case 'model': {
el = (
<EuiFlexItem key={idx}>
<ModelField
field={field}
componentId={props.componentId}
onFormChange={props.onFormChange}
/>
<EuiSpacer size={INPUT_FIELD_SPACER_SIZE} />
</EuiFlexItem>
);
break;
}
case 'json': {
el = (
<EuiFlexItem key={idx}>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@
export { TextField } from './text_field';
export { JsonField } from './json_field';
export { SelectField } from './select_field';
export { ModelField } from './model_field';
Loading

0 comments on commit e805ea7

Please sign in to comment.