Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.x] Add option to select and provision pretrained text embedding models #138

Merged
merged 1 commit into from
Apr 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 42 additions & 1 deletion common/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,12 @@
* SPDX-License-Identifier: Apache-2.0
*/

import { TemplateNode, WORKFLOW_STATE } from './interfaces';
import {
MODEL_ALGORITHM,
PRETRAINED_MODEL_FORMAT,
PretrainedSentenceTransformer,
WORKFLOW_STATE,
} from './interfaces';

export const PLUGIN_ID = 'flow-framework';

Expand Down Expand Up @@ -52,6 +57,42 @@ export const SEARCH_MODELS_NODE_API_PATH = `${BASE_MODEL_NODE_API_PATH}/search`;
*/
export const CREATE_INGEST_PIPELINE_STEP_TYPE = 'create_ingest_pipeline';
export const CREATE_INDEX_STEP_TYPE = 'create_index';
export const REGISTER_LOCAL_PRETRAINED_MODEL_STEP_TYPE =
'register_local_pretrained_model';

/**
* ML PLUGIN PRETRAINED MODELS
* (based off of https://opensearch.org/docs/latest/ml-commons-plugin/pretrained-models/#sentence-transformers)
*/
export const ROBERTA_SENTENCE_TRANSFORMER = {
name: 'huggingface/sentence-transformers/all-distilroberta-v1',
shortenedName: 'all-distilroberta-v1',
description: 'A sentence transformer from Hugging Face',
format: PRETRAINED_MODEL_FORMAT.TORCH_SCRIPT,
algorithm: MODEL_ALGORITHM.TEXT_EMBEDDING,
version: '1.0.1',
vectorDimensions: 768,
} as PretrainedSentenceTransformer;

export const MPNET_SENTENCE_TRANSFORMER = {
name: 'huggingface/sentence-transformers/all-mpnet-base-v2',
shortenedName: 'all-mpnet-base-v2',
description: 'A sentence transformer from Hugging Face',
format: PRETRAINED_MODEL_FORMAT.TORCH_SCRIPT,
algorithm: MODEL_ALGORITHM.TEXT_EMBEDDING,
version: '1.0.1',
vectorDimensions: 768,
} as PretrainedSentenceTransformer;

export const BERT_SENTENCE_TRANSFORMER = {
name: 'huggingface/sentence-transformers/msmarco-distilbert-base-tas-b',
shortenedName: 'msmarco-distilbert-base-tas-b',
description: 'A sentence transformer from Hugging Face',
format: PRETRAINED_MODEL_FORMAT.TORCH_SCRIPT,
algorithm: MODEL_ALGORITHM.TEXT_EMBEDDING,
version: '1.0.2',
vectorDimensions: 768,
} as PretrainedSentenceTransformer;

/**
* MISCELLANEOUS
Expand Down
90 changes: 85 additions & 5 deletions common/interfaces.ts
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,16 @@ export type CreateIndexNode = TemplateNode & {
};
};

export type RegisterPretrainedModelNode = TemplateNode & {
user_inputs: {
name: string;
description: string;
model_format: string;
version: string;
deploy: boolean;
};
};

export type TemplateEdge = {
source: string;
dest: string;
Expand Down Expand Up @@ -130,9 +140,83 @@ export enum USE_CASE {
/**
********** ML PLUGIN TYPES/INTERFACES **********
*/

// Based off of https://github.com/opensearch-project/ml-commons/blob/main/common/src/main/java/org/opensearch/ml/common/model/MLModelState.java
export enum MODEL_STATE {
REGISTERED = 'Registered',
REGISTERING = 'Registering',
DEPLOYING = 'Deploying',
DEPLOYED = 'Deployed',
PARTIALLY_DEPLOYED = 'Partially deployed',
UNDEPLOYED = 'Undeployed',
DEPLOY_FAILED = 'Deploy failed',
}

// Based off of https://github.com/opensearch-project/ml-commons/blob/main/common/src/main/java/org/opensearch/ml/common/FunctionName.java
export enum MODEL_ALGORITHM {
LINEAR_REGRESSION = 'Linear regression',
KMEANS = 'K-means',
AD_LIBSVM = 'AD LIBSVM',
SAMPLE_ALGO = 'Sample algorithm',
LOCAL_SAMPLE_CALCULATOR = 'Local sample calculator',
FIT_RCF = 'Fit RCF',
BATCH_RCF = 'Batch RCF',
ANOMALY_LOCALIZATION = 'Anomaly localization',
RCF_SUMMARIZE = 'RCF summarize',
LOGISTIC_REGRESSION = 'Logistic regression',
TEXT_EMBEDDING = 'Text embedding',
METRICS_CORRELATION = 'Metrics correlation',
REMOTE = 'Remote',
SPARSE_ENCODING = 'Sparse encoding',
SPARSE_TOKENIZE = 'Sparse tokenize',
TEXT_SIMILARITY = 'Text similarity',
QUESTION_ANSWERING = 'Question answering',
AGENT = 'Agent',
}

export enum MODEL_CATEGORY {
DEPLOYED = 'Deployed',
PRETRAINED = 'Pretrained',
}

export enum PRETRAINED_MODEL_FORMAT {
TORCH_SCRIPT = 'TORCH_SCRIPT',
}

export type PretrainedModel = {
name: string;
shortenedName: string;
description: string;
format: PRETRAINED_MODEL_FORMAT;
algorithm: MODEL_ALGORITHM;
version: string;
};

export type PretrainedSentenceTransformer = PretrainedModel & {
vectorDimensions: number;
};

export type ModelConfig = {
modelType?: string;
embeddingDimension?: number;
};

export type Model = {
id: string;
algorithm: string;
name: string;
algorithm: MODEL_ALGORITHM;
state: MODEL_STATE;
modelConfig?: ModelConfig;
};

export type ModelDict = {
[modelId: string]: Model;
};

export type ModelFormValue = {
id: string;
category?: MODEL_CATEGORY;
algorithm?: MODEL_ALGORITHM;
};

/**
Expand Down Expand Up @@ -171,7 +255,3 @@ export enum WORKFLOW_RESOURCE_TYPE {
export type WorkflowDict = {
[workflowId: string]: Workflow;
};

export type ModelDict = {
[modelId: string]: Model;
};
17 changes: 13 additions & 4 deletions public/app.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,21 @@ export const FlowFrameworkDashboardsApp = (props: Props) => {
<Workflows {...routeProps} />
)}
/>
{/* Defaulting to Workflows page */}
{/*
Defaulting to Workflows page. The pathname will need to be updated
to handle the redirection and get the router props consistent.
*/}
<Route
path={`${APP_PATH.HOME}`}
render={(routeProps: RouteComponentProps<WorkflowsRouterProps>) => (
<Workflows {...routeProps} />
)}
render={(routeProps: RouteComponentProps<WorkflowsRouterProps>) => {
if (props.history.location.pathname !== APP_PATH.WORKFLOWS) {
props.history.replace({
...history,
pathname: APP_PATH.WORKFLOWS,
});
}
return <Workflows {...routeProps} />;
}}
/>
</Switch>
</EuiPageTemplate>
Expand Down
2 changes: 1 addition & 1 deletion public/component_types/interfaces.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import { COMPONENT_CATEGORY, COMPONENT_CLASS } from '../utils';
/**
* ************ Types *************************
*/
export type FieldType = 'string' | 'json' | 'select';
export type FieldType = 'string' | 'json' | 'select' | 'model';
export type SelectType = 'model';
export type FieldValue = string | {};
export type ComponentFormValues = FormikValues;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,10 @@ export class TextEmbeddingTransformer extends MLTransformer {
this.inputs = [];
this.createFields = [
{
label: 'Model ID',
id: 'modelId',
type: 'select',
selectType: 'model',
helpText: 'The deployed text embedding model to use for embedding.',
label: 'Text Embedding Model',
id: 'model',
type: 'model',
helpText: 'A text embedding model for embedding text.',
helpLink:
'https://opensearch.org/docs/latest/ml-commons-plugin/integrating-ml-models/#choosing-a-model',
},
Expand All @@ -36,7 +35,6 @@ export class TextEmbeddingTransformer extends MLTransformer {
helpLink:
'https://opensearch.org/docs/latest/ingest-pipelines/processors/text-embedding/',
},

{
label: 'Vector Field',
id: 'vectorField',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,12 @@ interface ComponentDetailsProps {
export function ComponentDetails(props: ComponentDetailsProps) {
return (
<EuiPanel paddingSize="m">
{props.isDeprovisionable ? (
{/* TODO: determine if we need this view if we want the workspace to remain
readonly once provisioned */}
{/* {props.isDeprovisionable ? (
<ProvisionedComponentInputs />
) : props.selectedComponent ? (
) : */}
{props.selectedComponent ? (
<ComponentInputs
selectedComponent={props.selectedComponent}
onFormChange={props.onFormChange}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ export function ComponentInputs(props: ComponentInputsProps) {
<EuiTitle size="m">
<h2>{props.selectedComponent.data.label || ''}</h2>
</EuiTitle>
<EuiText color="subdued">
{props.selectedComponent.data.description}
</EuiText>
<NewOrExistingTabs
selectedTabId={selectedTabId}
setSelectedTabId={setSelectedTabId}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import React from 'react';
import { EuiFlexItem, EuiSpacer } from '@elastic/eui';
import { TextField, JsonField, SelectField } from './input_fields';
import { TextField, JsonField, SelectField, ModelField } from './input_fields';
import { IComponentField } from '../../../../common';

/**
Expand Down Expand Up @@ -54,6 +54,19 @@ export function InputFieldList(props: InputFieldListProps) {
);
break;
}
case 'model': {
el = (
<EuiFlexItem key={idx}>
<ModelField
field={field}
componentId={props.componentId}
onFormChange={props.onFormChange}
/>
<EuiSpacer size={INPUT_FIELD_SPACER_SIZE} />
</EuiFlexItem>
);
break;
}
case 'json': {
el = (
<EuiFlexItem key={idx}>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@
export { TextField } from './text_field';
export { JsonField } from './json_field';
export { SelectField } from './select_field';
export { ModelField } from './model_field';
Loading
Loading