
Commit

After further review round
Signed-off-by: br3no <[email protected]>
br3no committed Nov 6, 2024
1 parent a5977b2 commit 390d04c
Showing 15 changed files with 504 additions and 326 deletions.
430 changes: 252 additions & 178 deletions src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java

Large diffs are not rendered by default.
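Because the MLCommonsClientAccessor diff itself is not rendered here, the following is a minimal sketch of what the new InferenceRequest value object plausibly looks like, reconstructed only from the builder calls visible in the hunks below (inputTexts, inputObjects, queryText, mlAlgoParams). Field types, nesting, and the MLAlgoParams import path are assumptions, not the committed implementation.

import java.util.List;
import java.util.Map;

import org.opensearch.ml.common.input.parameter.MLAlgoParams; // assumed import path for the ML Commons params interface

/**
 * Hypothetical reconstruction of the request object this commit introduces, based only on the
 * call sites shown below. The committed class is declared inside MLCommonsClientAccessor and
 * may use different field types, validation, or Lombok-generated builders.
 */
public class InferenceRequest {

    private final String modelId;                   // required; passed to the Builder constructor
    private final List<String> inputTexts;          // text-embedding / sparse-encoding paths
    private final Map<String, String> inputObjects; // map-based path (e.g. text + image inputs)
    private final String queryText;                 // similarity / rerank path
    private final MLAlgoParams mlAlgoParams;        // e.g. AsymmetricTextEmbeddingParameters

    private InferenceRequest(Builder builder) {
        this.modelId = builder.modelId;
        this.inputTexts = builder.inputTexts;
        this.inputObjects = builder.inputObjects;
        this.queryText = builder.queryText;
        this.mlAlgoParams = builder.mlAlgoParams;
    }

    // Getters omitted for brevity.

    public static class Builder {
        private final String modelId;
        private List<String> inputTexts;
        private Map<String, String> inputObjects;
        private String queryText;
        private MLAlgoParams mlAlgoParams;

        public Builder(String modelId) { this.modelId = modelId; }

        public Builder inputTexts(List<String> inputTexts) { this.inputTexts = inputTexts; return this; }

        public Builder inputObjects(Map<String, String> inputObjects) { this.inputObjects = inputObjects; return this; }

        public Builder queryText(String queryText) { this.queryText = queryText; return this; }

        public Builder mlAlgoParams(MLAlgoParams mlAlgoParams) { this.mlAlgoParams = mlAlgoParams; return this; }

        public InferenceRequest build() { return new InferenceRequest(this); }
    }
}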

@@ -14,6 +14,7 @@
import org.opensearch.env.Environment;
import org.opensearch.ingest.IngestDocument;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor.InferenceRequest;
import org.opensearch.neuralsearch.util.TokenWeightUtil;

import lombok.extern.log4j.Log4j2;
@@ -48,17 +49,19 @@ public void doExecute(
List<String> inferenceList,
BiConsumer<IngestDocument, Exception> handler
) {
mlCommonsClientAccessor.inferenceSentencesWithMapResult(this.modelId, inferenceList, ActionListener.wrap(resultMaps -> {
setVectorFieldsToDocument(ingestDocument, ProcessMap, TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps));
handler.accept(ingestDocument, null);
}, e -> { handler.accept(null, e); }));
mlCommonsClientAccessor.inferenceSentencesWithMapResult(
new InferenceRequest.Builder(this.modelId).inputTexts(inferenceList).build(),
ActionListener.wrap(resultMaps -> {
setVectorFieldsToDocument(ingestDocument, ProcessMap, TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps));
handler.accept(ingestDocument, null);
}, e -> { handler.accept(null, e); })
);
}

@Override
public void doBatchExecute(List<String> inferenceList, Consumer<List<?>> handler, Consumer<Exception> onException) {
mlCommonsClientAccessor.inferenceSentencesWithMapResult(
this.modelId,
inferenceList,
new InferenceRequest.Builder(this.modelId).inputTexts(inferenceList).build(),
ActionListener.wrap(resultMaps -> handler.accept(TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps)), onException)
);
}
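For orientation, here is an assumed summary of the accessor surface after the refactor, written as a hypothetical interface against the InferenceRequest sketch above: every overload now takes a single request object plus an ActionListener. The listener payload types are inferred from how the callbacks are consumed in these hunks (token-weight maps, embedding vectors, similarity scores) and may not match the committed declarations exactly.

import java.util.List;
import java.util.Map;

import org.opensearch.core.action.ActionListener;

// Hypothetical summary of the post-refactor call surface; method names come from the call
// sites in this commit, while the listener generics are inferred and may differ.
public interface InferenceCallSurface {

    // Dense embeddings for a batch of texts (text embedding processor, batch execute).
    void inferenceSentences(InferenceRequest request, ActionListener<List<List<Float>>> listener);

    // Single embedding for a map of named inputs (text + image path, neural query rewrite).
    void inferenceSentencesMap(InferenceRequest request, ActionListener<List<Float>> listener);

    // Raw map results, e.g. token-weight maps for sparse encoding.
    void inferenceSentencesWithMapResult(InferenceRequest request, ActionListener<List<Map<String, ?>>> listener);

    // Similarity scores of each input text against the query text (reranking).
    void inferenceSimilarity(InferenceRequest request, ActionListener<List<Float>> listener);
}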
@@ -18,17 +18,23 @@
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;

import lombok.extern.log4j.Log4j2;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor.InferenceRequest;

/**
* This processor is used for user input data text embedding processing, model_id can be used to indicate which model user use,
* and field_map can be used to indicate which fields needs text embedding and the corresponding keys for the text embedding results.
* This processor is used for user input data text embedding processing, model_id can be used to
* indicate which model user use, and field_map can be used to indicate which fields needs text
* embedding and the corresponding keys for the text embedding results.
*/
@Log4j2
public final class TextEmbeddingProcessor extends InferenceProcessor {

public static final String TYPE = "text_embedding";
public static final String LIST_TYPE_NESTED_MAP_KEY = "knn";

private static final AsymmetricTextEmbeddingParameters PASSAGE_PARAMETERS = AsymmetricTextEmbeddingParameters.builder()
.embeddingContentType(EmbeddingContentType.PASSAGE)
.build();

public TextEmbeddingProcessor(
String tag,
String description,
@@ -50,9 +56,7 @@ public void doExecute(
BiConsumer<IngestDocument, Exception> handler
) {
mlCommonsClientAccessor.inferenceSentences(
this.modelId,
inferenceList,
AsymmetricTextEmbeddingParameters.builder().embeddingContentType(EmbeddingContentType.PASSAGE).build(),
new InferenceRequest.Builder(this.modelId).inputTexts(inferenceList).mlAlgoParams(PASSAGE_PARAMETERS).build(),
ActionListener.wrap(vectors -> {
setVectorFieldsToDocument(ingestDocument, ProcessMap, vectors);
handler.accept(ingestDocument, null);
@@ -62,6 +66,9 @@

@Override
public void doBatchExecute(List<String> inferenceList, Consumer<List<?>> handler, Consumer<Exception> onException) {
mlCommonsClientAccessor.inferenceSentences(this.modelId, inferenceList, ActionListener.wrap(handler::accept, onException));
mlCommonsClientAccessor.inferenceSentences(
new InferenceRequest.Builder(this.modelId).inputTexts(inferenceList).mlAlgoParams(PASSAGE_PARAMETERS).build(),
ActionListener.wrap(handler::accept, onException)
);
}
}
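One design point worth noting: the commit precomputes the asymmetric-embedding parameters once per class instead of rebuilding them on every request. The ingest side above tags inputs as PASSAGE, while the NeuralQueryBuilder hunk further down tags the search string as QUERY. A minimal sketch of that pairing follows; the import paths and the request wiring are assumptions, only the builder calls themselves appear verbatim in the diff.

import java.util.List;

// Import paths assumed from ML Commons conventions.
import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters;
import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters.EmbeddingContentType;

public final class AsymmetricEmbeddingExample {

    // Ingest side (TextEmbeddingProcessor above): document text is embedded as a passage.
    private static final AsymmetricTextEmbeddingParameters PASSAGE_PARAMETERS = AsymmetricTextEmbeddingParameters.builder()
        .embeddingContentType(EmbeddingContentType.PASSAGE)
        .build();

    // Query side (NeuralQueryBuilder below): the search string is embedded as a query.
    private static final AsymmetricTextEmbeddingParameters QUERY_PARAMETERS = AsymmetricTextEmbeddingParameters.builder()
        .embeddingContentType(EmbeddingContentType.QUERY)
        .build();

    public static void main(String[] args) {
        // Hypothetical wiring, reusing the InferenceRequest sketch from earlier; "my-model-id"
        // and the sample text are placeholders.
        InferenceRequest ingestRequest = new InferenceRequest.Builder("my-model-id")
            .inputTexts(List.of("some passage text"))
            .mlAlgoParams(PASSAGE_PARAMETERS)
            .build();
    }
}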
@@ -25,6 +25,7 @@
import com.google.common.annotations.VisibleForTesting;

import lombok.extern.log4j.Log4j2;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor.InferenceRequest;
import org.opensearch.neuralsearch.util.ProcessorDocumentUtils;

/**
@@ -113,10 +114,13 @@ public void execute(final IngestDocument ingestDocument, final BiConsumer<Ingest
if (inferenceMap.isEmpty()) {
handler.accept(ingestDocument, null);
} else {
mlCommonsClientAccessor.inferenceSentences(this.modelId, inferenceMap, ActionListener.wrap(vectors -> {
setVectorFieldsToDocument(ingestDocument, vectors);
handler.accept(ingestDocument, null);
}, e -> { handler.accept(null, e); }));
mlCommonsClientAccessor.inferenceSentencesMap(
new InferenceRequest.Builder(this.modelId).inputObjects(inferenceMap).build(),
ActionListener.wrap(vectors -> {
setVectorFieldsToDocument(ingestDocument, vectors);
handler.accept(ingestDocument, null);
}, e -> { handler.accept(null, e); })
);
}
} catch (Exception e) {
handler.accept(null, e);
@@ -12,6 +12,7 @@
import org.opensearch.action.search.SearchResponse;
import org.opensearch.core.action.ActionListener;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor.InferenceRequest;
import org.opensearch.neuralsearch.processor.factory.RerankProcessorFactory;
import org.opensearch.neuralsearch.processor.rerank.context.ContextSourceFetcher;
import org.opensearch.neuralsearch.processor.rerank.context.DocumentContextSourceFetcher;
@@ -73,9 +74,9 @@ public void rescoreSearchResponse(
List<?> ctxList = (List<?>) ctxObj;
List<String> contexts = ctxList.stream().map(str -> (String) str).collect(Collectors.toList());
mlCommonsClientAccessor.inferenceSimilarity(
modelId,
(String) rerankingContext.get(QueryContextSourceFetcher.QUERY_TEXT_FIELD),
contexts,
new InferenceRequest.Builder(modelId).queryText((String) rerankingContext.get(QueryContextSourceFetcher.QUERY_TEXT_FIELD))
.inputTexts(contexts)
.build(),
listener
);
}
@@ -57,11 +57,12 @@
import lombok.Setter;
import lombok.experimental.Accessors;
import lombok.extern.log4j.Log4j2;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor.InferenceRequest;

/**
* NeuralQueryBuilder is responsible for producing "neural" query types. A "neural" query type is a wrapper around a
* k-NN vector query. It uses a ML language model to produce a dense vector from a query string that is then used as
* the query vector for the k-NN search.
* NeuralQueryBuilder is responsible for producing "neural" query types. A "neural" query type is a
* wrapper around a k-NN vector query. It uses a ML language model to produce a dense vector from a
* query string that is then used as the query vector for the k-NN search.
*/

@Log4j2
@@ -86,6 +87,9 @@ public class NeuralQueryBuilder extends AbstractQueryBuilder<NeuralQueryBuilder>
static final ParseField K_FIELD = new ParseField("k");

private static final int DEFAULT_K = 10;
private static final AsymmetricTextEmbeddingParameters QUERY_PARAMETERS = AsymmetricTextEmbeddingParameters.builder()
.embeddingContentType(EmbeddingContentType.QUERY)
.build();

private static MLCommonsClientAccessor ML_CLIENT;

@@ -335,10 +339,8 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) {
inferenceInput.put(INPUT_IMAGE, queryImage());
}
queryRewriteContext.registerAsyncAction(
((client, actionListener) -> ML_CLIENT.inferenceSentences(
modelId(),
inferenceInput,
AsymmetricTextEmbeddingParameters.builder().embeddingContentType(EmbeddingContentType.QUERY).build(),
((client, actionListener) -> ML_CLIENT.inferenceSentencesMap(
new InferenceRequest.Builder(modelId()).inputObjects(inferenceInput).mlAlgoParams(QUERY_PARAMETERS).build(),
ActionListener.wrap(floatList -> {
vectorSetOnce.set(vectorAsListToArray(floatList));
actionListener.onResponse(null);
Expand Down Expand Up @@ -368,8 +370,12 @@ protected Query doToQuery(QueryShardContext queryShardContext) {

@Override
protected boolean doEquals(NeuralQueryBuilder obj) {
if (this == obj) return true;
if (obj == null || getClass() != obj.getClass()) return false;
if (this == obj) {
return true;
}
if (obj == null || getClass() != obj.getClass()) {
return false;
}
EqualsBuilder equalsBuilder = new EqualsBuilder();
equalsBuilder.append(fieldName, obj.fieldName);
equalsBuilder.append(queryText, obj.queryText);
@@ -37,6 +37,7 @@
import org.opensearch.index.query.QueryRewriteContext;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor.InferenceRequest;
import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil;
import org.opensearch.neuralsearch.util.TokenWeightUtil;

@@ -341,8 +342,7 @@ private BiConsumer<Client, ActionListener<?>> getModelInferenceAsync(SetOnce<Map
// it splits the tokens using a threshold defined by a ratio of the maximum score of tokens, updating the token set
// accordingly.
return ((client, actionListener) -> ML_CLIENT.inferenceSentencesWithMapResult(
modelId(),
List.of(queryText),
new InferenceRequest.Builder(modelId()).inputTexts(List.of(queryText)).build(),
ActionListener.wrap(mapResultList -> {
Map<String, Float> queryTokens = TokenWeightUtil.fetchListOfTokenWeightMap(mapResultList).get(0);
if (Objects.nonNull(twoPhaseSharedQueryToken)) {