From 21f1fc56dfd5373c8aecca96ee239ec8f2215400 Mon Sep 17 00:00:00 2001 From: Navneet Verma Date: Wed, 26 Apr 2023 23:52:33 -0700 Subject: [PATCH] Improved the processor I/O and enabled the question via QuestionExtBuilder Signed-off-by: Navneet Verma --- .../neuralsearch/ext/QuestionExtBuilder.java | 91 +++++++++++ .../ml/MLCommonsClientAccessor.java | 56 +++++-- .../neuralsearch/plugin/NeuralSearch.java | 14 +- .../processor/GenerativeTextLLMProcessor.java | 151 ++++++++++++++++++ .../processor/SummaryProcessor.java | 125 --------------- .../GenerativeTextLLMProcessorFactory.java | 53 ++++++ .../factory/SummaryProcessorFactory.java | 36 ----- .../search/summary/GeneratedText.java | 74 +++++++++ ...a => GenerativeTextLLMSearchResponse.java} | 42 +++-- .../search/summary/ResultsSummary.java | 39 ----- 10 files changed, 452 insertions(+), 229 deletions(-) create mode 100644 src/main/java/org/opensearch/neuralsearch/ext/QuestionExtBuilder.java create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/GenerativeTextLLMProcessor.java delete mode 100644 src/main/java/org/opensearch/neuralsearch/processor/SummaryProcessor.java create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/factory/GenerativeTextLLMProcessorFactory.java delete mode 100644 src/main/java/org/opensearch/neuralsearch/processor/factory/SummaryProcessorFactory.java create mode 100644 src/main/java/org/opensearch/neuralsearch/search/summary/GeneratedText.java rename src/main/java/org/opensearch/neuralsearch/search/summary/{SummarySearchResponse.java => GenerativeTextLLMSearchResponse.java} (58%) delete mode 100644 src/main/java/org/opensearch/neuralsearch/search/summary/ResultsSummary.java diff --git a/src/main/java/org/opensearch/neuralsearch/ext/QuestionExtBuilder.java b/src/main/java/org/opensearch/neuralsearch/ext/QuestionExtBuilder.java new file mode 100644 index 000000000..4ee69284a --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/ext/QuestionExtBuilder.java @@ -0,0 +1,91 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.ext; + +import java.io.IOException; + +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.Setter; +import lombok.extern.log4j.Log4j2; + +import org.opensearch.common.ParsingException; +import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.core.ParseField; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.search.SearchExtBuilder; + +/** + * An extension class which will be used to read the Question Extension Object from Search request. + * We will remove this extension when we have a way to create a Natural Language Question from OpenSearch Query DSL. + */ +@Log4j2 +@EqualsAndHashCode(callSuper = false) +public class QuestionExtBuilder extends SearchExtBuilder { + + public static String NAME = "question_extension"; + + private static final ParseField QUESTION_FIELD = new ParseField("question"); + + @Getter + @Setter + private String question; + + /** + * Returns the name of the writeable object + */ + @Override + public String getWriteableName() { + return NAME; + } + + /** + * Write this into the {@linkplain StreamOutput}. 
+ * + * @param out {@link StreamOutput} + */ + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(question); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.field(QUESTION_FIELD.getPreferredName(), question); + return builder; + } + + public static QuestionExtBuilder parse(XContentParser parser) throws IOException { + final QuestionExtBuilder questionExtBuilder = new QuestionExtBuilder(); + XContentParser.Token token = parser.currentToken(); + String currentFieldName = null; + if (token != XContentParser.Token.START_OBJECT && (token = parser.nextToken()) != XContentParser.Token.START_OBJECT) { + throw new ParsingException( + parser.getTokenLocation(), + "Expected [" + XContentParser.Token.START_OBJECT + "] but found [" + token + "]", + parser.getTokenLocation() + ); + } + while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) { + if (token == XContentParser.Token.FIELD_NAME) { + currentFieldName = parser.currentName(); + } else if (token.isValue()) { + if (QUESTION_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { + questionExtBuilder.setQuestion(parser.text()); + } + } else { + throw new ParsingException( + parser.getTokenLocation(), + "Unknown key for a " + token + " in [" + currentFieldName + "].", + parser.getTokenLocation() + ); + } + } + + return questionExtBuilder; + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java index 8413f8276..e0d8e01da 100644 --- a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java +++ b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java @@ -31,6 +31,7 @@ import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.neuralsearch.search.summary.GeneratedText; import org.opensearch.neuralsearch.util.RetryUtil; /** @@ -152,8 +153,17 @@ private List> buildVectorFromResponse(MLOutput mlOutput) { return vector; } - public String predict(final String prompt, String modelId) throws ExecutionException, InterruptedException { - final MLInput mlInput = buildMLInputForPredictCall(prompt, modelId); + /** + * Will be used to call predict API of ML Commons, to get the response for an input from a modelId. + * + * @param context to be passed to LLM + * @param modelId internal reference of OpenSearch to call LLM + * @return {@link GeneratedText} + * @throws ExecutionException + * @throws InterruptedException + */ + public GeneratedText predict(final String context, String modelId) throws ExecutionException, InterruptedException { + final MLInput mlInput = buildMLInputForPredictCall(context, modelId); final MLOutput output = mlClient.predict(modelId, mlInput).get(); final ModelTensorOutput modelTensorOutput = (ModelTensorOutput) output; @@ -161,23 +171,39 @@ public String predict(final String prompt, String modelId) throws ExecutionExcep for (final ModelTensors tensors : tensorOutputList) { final List tensorsList = tensors.getMlModelTensors(); for (final ModelTensor tensor : tensorsList) { - final String error = (String) tensor.getDataAsMap().get("error"); - if (StringUtils.isNotEmpty(error)) { - log.error("Error happened during the Processing of the input. 
Error : {}", error); - return error; - } - final List> choices = (List>) tensor.getDataAsMap().get("choices"); - if (!CollectionUtils.isEmpty(choices)) { - for (Map choice : choices) { - if (StringUtils.isNotEmpty((String) choice.get("text"))) { - return (String) choice.get("text"); - } + return parseModelTensorResponseForDifferentModels(tensor); + } + } + log.error("Tensors Object List is empty : " + output); + return new GeneratedText(StringUtils.EMPTY, "No Text found hence not able to summarize"); + } + + private GeneratedText parseModelTensorResponseForDifferentModels(final ModelTensor tensor) { + log.info("Output from the model is : {}", tensor); + Map dataAsMap = tensor.getDataAsMap(); + if (dataAsMap.containsKey("error")) { + return new GeneratedText(StringUtils.EMPTY, "Error happened during the call. Error is : " + dataAsMap.get("error")); + } else if (tensor.getDataAsMap().containsKey("choices")) { + final List> choices = (List>) tensor.getDataAsMap().get("choices"); + // This is Open AI output + if (!CollectionUtils.isEmpty(choices)) { + for (Map choice : choices) { + if (StringUtils.isNotEmpty((String) choice.get("text"))) { + return new GeneratedText((String) choice.get("text"), StringUtils.EMPTY); } } } + return new GeneratedText(StringUtils.EMPTY, "There is no data present in the response from Open AI model"); + } else if (dataAsMap.containsKey("results")) { + // this is for bedrock + List> results = (List>) dataAsMap.get("results"); + for (Map result : results) { + if (StringUtils.isNotEmpty((String) result.get("outputText"))) { + return new GeneratedText((String) result.get("outputText"), StringUtils.EMPTY); + } + } } - log.error("No Choice object found as ML Output is : " + output); - return "No Text found hence not able to summarize"; + return new GeneratedText(StringUtils.EMPTY, "Not able to pase the response from model. 
Cannot find choices " + "object, "); } private MLInput buildMLInputForPredictCall(final String prompt, String modelId) { diff --git a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java index 90472c643..3aade54cf 100644 --- a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java +++ b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java @@ -21,11 +21,12 @@ import org.opensearch.env.NodeEnvironment; import org.opensearch.ingest.Processor; import org.opensearch.ml.client.MachineLearningNodeClient; +import org.opensearch.neuralsearch.ext.QuestionExtBuilder; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; import org.opensearch.neuralsearch.processor.AppendQueryResponseProcessor; -import org.opensearch.neuralsearch.processor.SummaryProcessor; +import org.opensearch.neuralsearch.processor.GenerativeTextLLMProcessor; import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor; -import org.opensearch.neuralsearch.processor.factory.SummaryProcessorFactory; +import org.opensearch.neuralsearch.processor.factory.GenerativeTextLLMProcessorFactory; import org.opensearch.neuralsearch.processor.factory.TextEmbeddingProcessorFactory; import org.opensearch.neuralsearch.query.NeuralQueryBuilder; import org.opensearch.node.Node; @@ -80,7 +81,7 @@ public Map getProcesso org.opensearch.search.pipeline.Processor.Parameters parameters ) { final Map processorsMap = new HashMap<>(); - processorsMap.put(SummaryProcessor.TYPE, new SummaryProcessorFactory(getClientAccessor(parameters.client))); + processorsMap.put(GenerativeTextLLMProcessor.TYPE, new GenerativeTextLLMProcessorFactory(getClientAccessor(parameters.client))); processorsMap.put(AppendQueryResponseProcessor.TYPE, new AppendQueryResponseProcessor.Factory()); return processorsMap; } @@ -101,4 +102,11 @@ private MLCommonsClientAccessor getClientAccessor(final Client client) { return clientAccessor; } + @Override + public List> getSearchExts() { + return Collections.singletonList( + new SearchExtSpec<>(QuestionExtBuilder.NAME, input -> new QuestionExtBuilder(), QuestionExtBuilder::parse) + ); + } + } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/GenerativeTextLLMProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/GenerativeTextLLMProcessor.java new file mode 100644 index 000000000..475b2afb8 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/GenerativeTextLLMProcessor.java @@ -0,0 +1,151 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor; + +import java.util.ArrayList; +import java.util.List; +import java.util.Locale; + +import lombok.AllArgsConstructor; +import lombok.Getter; +import lombok.extern.log4j.Log4j2; + +import org.apache.commons.lang.StringUtils; +import org.opensearch.OpenSearchException; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.neuralsearch.ext.QuestionExtBuilder; +import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; +import org.opensearch.neuralsearch.search.summary.GeneratedText; +import org.opensearch.neuralsearch.search.summary.GenerativeTextLLMSearchResponse; +import org.opensearch.search.SearchExtBuilder; +import org.opensearch.search.SearchHit; +import org.opensearch.search.pipeline.SearchResponseProcessor; + +@Log4j2 +public class GenerativeTextLLMProcessor extends 
AbstractProcessor implements SearchResponseProcessor { + + public static final String TYPE = "llm_processor"; + private final List fields; + private final MLCommonsClientAccessor clientAccessor; + private final String modelId; + private final ContextType contextType; + + public GenerativeTextLLMProcessor( + final String tag, + final String description, + final MLCommonsClientAccessor mlCommonsClientAccessor, + final List fields, + final String modelId, + final String usecase + ) { + super(description, tag); + this.clientAccessor = mlCommonsClientAccessor; + this.fields = fields; + this.modelId = modelId; + this.contextType = usecase == null ? ContextType.SUMMARY : ContextType.valueOf(usecase.toUpperCase(Locale.ROOT)); + } + + @Override + public SearchResponse processResponse(SearchRequest searchRequest, SearchResponse searchResponse) { + final GeneratedText generatedText = generateTextFromLLM(searchRequest, searchResponse); + generatedText.setProcessorTag(getTag()); + generatedText.setUsecase(contextType.name); + List generatedTexts = new ArrayList<>(); + if (searchResponse instanceof GenerativeTextLLMSearchResponse) { + List generatedTextList = ((GenerativeTextLLMSearchResponse) searchResponse).getGeneratedTextList(); + generatedTexts.addAll(generatedTextList); + } + generatedTexts.add(generatedText); + + return new GenerativeTextLLMSearchResponse( + searchResponse.getInternalResponse(), + searchResponse.getScrollId(), + searchResponse.getTotalShards(), + searchResponse.getSuccessfulShards(), + searchResponse.getSkippedShards(), + searchResponse.getTook().millis(), + searchResponse.getShardFailures(), + searchResponse.getClusters(), + generatedTexts + ); + } + + /** + * Gets the type of processor + */ + @Override + public String getType() { + return TYPE; + } + + private GeneratedText generateTextFromLLM(SearchRequest searchRequest, SearchResponse searchResponse) { + final String context = createContextForLLM(searchRequest, searchResponse); + try { + log.info("Calling the Model {} with a context {}", modelId, context); + return clientAccessor.predict(context, modelId); + } catch (Exception e) { + log.error("Error while calling ML Commons Predict API for context: {}", context, e); + return new GeneratedText( + StringUtils.EMPTY, + String.format( + Locale.ROOT, + "Error Happened while calling the Predict API for model : %s with context: %s. Error is: %s", + modelId, + context, + e.getMessage() + ) + ); + } + } + + private String createContextForLLM(SearchRequest searchRequest, SearchResponse searchResponse) { + final StringBuilder contextBuilder = new StringBuilder(); + createContextForPromptUsingSearchResponse(contextBuilder, searchResponse); + return contextType.createContext(contextBuilder, searchRequest); + } + + private void createContextForPromptUsingSearchResponse(final StringBuilder promptBuilder, final SearchResponse searchResponse) { + for (final SearchHit hit : searchResponse.getInternalResponse().hits()) { + for (String field : fields) { + if (hit.getSourceAsMap().get(field) != null) { + promptBuilder.append(hit.getSourceAsMap().get(field)).append("\\n"); + } + } + } + } + + @AllArgsConstructor + @Getter + private enum ContextType { + SUMMARY("summary", "\\nSummarize the above input for me. 
\\n"), + QANDA("QandA", "By considering above input from me, answer the question: ${question}") { + public String createContext(final StringBuilder contextBuilder, SearchRequest searchRequest) { + final List extBuilders = searchRequest.source().ext(); + String questionString = ""; + for (SearchExtBuilder builder : extBuilders) { + if (builder instanceof QuestionExtBuilder) { + questionString = ((QuestionExtBuilder) builder).getQuestion(); + } + } + if (StringUtils.isEmpty(questionString)) { + throw new OpenSearchException("Not able to get question string from Ext Builder list: " + extBuilders); + } + + final String updatedPrompt = getContext().replace("${question}", questionString); + return contextBuilder.insert(0, "\"").append(updatedPrompt).append("\"").toString(); + } + }; + + private final String name; + private final String context; + + public String createContext(final StringBuilder contextBuilder, SearchRequest searchRequest) { + return contextBuilder.insert(0, "\"").append(this.context).append("\"").toString(); + } + } + +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/SummaryProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/SummaryProcessor.java deleted file mode 100644 index 6555d9c05..000000000 --- a/src/main/java/org/opensearch/neuralsearch/processor/SummaryProcessor.java +++ /dev/null @@ -1,125 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.neuralsearch.processor; - -import java.util.List; -import java.util.Locale; - -import lombok.AllArgsConstructor; -import lombok.Getter; -import lombok.extern.log4j.Log4j2; - -import org.apache.commons.lang.StringUtils; -import org.opensearch.action.search.SearchRequest; -import org.opensearch.action.search.SearchResponse; -import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; -import org.opensearch.neuralsearch.search.summary.ResultsSummary; -import org.opensearch.neuralsearch.search.summary.SummarySearchResponse; -import org.opensearch.search.SearchHit; -import org.opensearch.search.pipeline.SearchResponseProcessor; - -@Log4j2 -public class SummaryProcessor extends AbstractProcessor implements SearchResponseProcessor { - - public static final String TYPE = "summary_processor"; - private final List fields; - private final MLCommonsClientAccessor clientAccessor; - private final String modelId; - private final Prompt prompt; - - public SummaryProcessor( - final String tag, - final String description, - final MLCommonsClientAccessor mlCommonsClientAccessor, - final List fields, - final String modelId, - final String promptType - ) { - super(description, tag); - this.clientAccessor = mlCommonsClientAccessor; - this.fields = fields; - this.modelId = modelId; - if (promptType == null) { - prompt = Prompt.SUMMARY; - } else { - prompt = Prompt.valueOf(promptType.toUpperCase(Locale.ROOT)); - } - } - - @Override - public SearchResponse processResponse(SearchRequest searchRequest, SearchResponse searchResponse) { - final ResultsSummary summary = createSummary(searchRequest, searchResponse); - return new SummarySearchResponse( - searchResponse.getInternalResponse(), - searchResponse.getScrollId(), - searchResponse.getTotalShards(), - searchResponse.getSuccessfulShards(), - searchResponse.getSkippedShards(), - searchResponse.getTook().millis(), - searchResponse.getShardFailures(), - searchResponse.getClusters(), - summary - ); - } - - /** - * Gets the type of processor - */ - @Override - public String getType() { - return TYPE; - } - 
- private ResultsSummary createSummary(SearchRequest searchRequest, SearchResponse searchResponse) { - final String prompt = createPromptForLLM(searchRequest, searchResponse); - try { - log.info("Calling the Model {} with a prompt {}", modelId, prompt); - String summary = clientAccessor.predict(prompt, modelId); - return new ResultsSummary(summary, StringUtils.EMPTY); - } catch (Exception e) { - log.error("Error while calling ML Commons Predict API, ", e); - return new ResultsSummary(StringUtils.EMPTY, "Error Happened while calling the Summary Response. " + e.getMessage()); - } - } - - private String createPromptForLLM(SearchRequest searchRequest, SearchResponse searchResponse) { - final StringBuilder promptBuilder = new StringBuilder(); - createContextForPromptUsingSearchResponse(promptBuilder, searchResponse); - return prompt.createPrompt(promptBuilder, searchRequest); - } - - private void createContextForPromptUsingSearchResponse(final StringBuilder promptBuilder, final SearchResponse searchResponse) { - for (final SearchHit hit : searchResponse.getInternalResponse().hits()) { - for (String field : fields) { - if (hit.getSourceAsMap().get(field) != null) { - promptBuilder.append(hit.getSourceAsMap().get(field)).append("\\n"); - } - } - } - } - - @AllArgsConstructor - @Getter - private enum Prompt { - SUMMARY("summary", "\\n Summarize the above input for me. \\n"), - QUESTION("question", "By considering " + "above input from me, answer the \\n question: ${question}") { - public String createPrompt(final StringBuilder context, SearchRequest searchRequest) { - // Find a way in which we can get the query produced by a user - final String queryString = searchRequest.source().query().toString().replace("\"", "\\\""); - final String updatedPrompt = getPrompt().replace("${question}", queryString); - return context.insert(0, "\"").append(updatedPrompt).append("\"").toString(); - } - }; - - private final String name; - private final String prompt; - - public String createPrompt(final StringBuilder context, SearchRequest searchRequest) { - return context.insert(0, "\"").append(prompt).append("\"").toString(); - } - } - -} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/GenerativeTextLLMProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/GenerativeTextLLMProcessorFactory.java new file mode 100644 index 000000000..fc66aee0e --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/GenerativeTextLLMProcessorFactory.java @@ -0,0 +1,53 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor.factory; + +import static org.opensearch.ingest.ConfigurationUtils.readOptionalStringProperty; +import static org.opensearch.ingest.ConfigurationUtils.readStringProperty; + +import java.util.List; +import java.util.Map; + +import org.apache.commons.lang.StringUtils; +import org.opensearch.core.ParseField; +import org.opensearch.ingest.ConfigurationUtils; +import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; +import org.opensearch.neuralsearch.processor.GenerativeTextLLMProcessor; +import org.opensearch.search.pipeline.Processor; + +/** + * A Factory class for creating {@link GenerativeTextLLMProcessor} + */ +public class GenerativeTextLLMProcessorFactory implements Processor.Factory { + + private static final ParseField MODEL_ID = new ParseField("modelId"); + private static final ParseField USE_CASE = new ParseField("usecase"); + + private final 
MLCommonsClientAccessor clientAccessor; + + public GenerativeTextLLMProcessorFactory(final MLCommonsClientAccessor clientAccessor) { + this.clientAccessor = clientAccessor; + } + + @Override + public Processor create(Map registry, String processorTag, String description, Map config) { + final List fields = ConfigurationUtils.readList( + GenerativeTextLLMProcessor.TYPE, + processorTag, + config, + ParseField.CommonFields.FIELDS.getPreferredName() + ); + final String modelId = readStringProperty(GenerativeTextLLMProcessor.TYPE, processorTag, config, MODEL_ID.getPreferredName()); + final String usecase = readOptionalStringProperty( + GenerativeTextLLMProcessor.TYPE, + processorTag, + config, + USE_CASE.getPreferredName() + ); + final String tag = StringUtils.isEmpty(processorTag) ? modelId : processorTag; + return new GenerativeTextLLMProcessor(tag, description, clientAccessor, fields, modelId, usecase); + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/SummaryProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/SummaryProcessorFactory.java deleted file mode 100644 index c9b1d2b91..000000000 --- a/src/main/java/org/opensearch/neuralsearch/processor/factory/SummaryProcessorFactory.java +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.neuralsearch.processor.factory; - -import static org.opensearch.ingest.ConfigurationUtils.readOptionalStringProperty; -import static org.opensearch.ingest.ConfigurationUtils.readStringProperty; - -import java.util.List; -import java.util.Map; - -import org.opensearch.ingest.ConfigurationUtils; -import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; -import org.opensearch.neuralsearch.processor.SummaryProcessor; -import org.opensearch.search.pipeline.Processor; - -/** - * A Factory class for creating {@link SummaryProcessor} - */ -public class SummaryProcessorFactory implements Processor.Factory { - private final MLCommonsClientAccessor clientAccessor; - - public SummaryProcessorFactory(final MLCommonsClientAccessor clientAccessor) { - this.clientAccessor = clientAccessor; - } - - @Override - public Processor create(Map registry, String processorTag, String description, Map config) { - final List fields = ConfigurationUtils.readList(SummaryProcessor.TYPE, processorTag, config, "fields"); - final String modelId = readStringProperty(SummaryProcessor.TYPE, processorTag, config, "modelId"); - final String promptType = readOptionalStringProperty(SummaryProcessor.TYPE, processorTag, config, "prompt"); - return new SummaryProcessor(processorTag, description, clientAccessor, fields, modelId, promptType); - } -} diff --git a/src/main/java/org/opensearch/neuralsearch/search/summary/GeneratedText.java b/src/main/java/org/opensearch/neuralsearch/search/summary/GeneratedText.java new file mode 100644 index 000000000..08df6e89a --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/search/summary/GeneratedText.java @@ -0,0 +1,74 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.search.summary; + +import java.io.IOException; + +import lombok.AllArgsConstructor; +import lombok.Getter; +import lombok.Setter; + +import org.apache.commons.lang.StringUtils; +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.common.io.stream.Writeable; +import 
org.opensearch.core.xcontent.ToXContentFragment; +import org.opensearch.core.xcontent.XContentBuilder; + +/** + * This class holds the summary of the results which are returned by OpenSearch and will be sent back to the customer. + * The summary will be obtained from LLM summary models. + */ +@Getter +@AllArgsConstructor +public class GeneratedText implements ToXContentFragment, Writeable { + private final String value; + private final String error; + @Setter + private String processorTag; + + @Setter + private String usecase; + + public GeneratedText(StreamInput in) throws IOException { + processorTag = in.readString(); + usecase = in.readString(); + value = in.readOptionalString(); + error = in.readOptionalString(); + } + + public GeneratedText(final String value, final String error) { + this.value = value; + this.error = error; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (StringUtils.isNotEmpty(value)) { + builder.field("value", value); + } else if (StringUtils.isNotEmpty(error)) { + builder.field("error", error); + } + builder.field("processorTag", processorTag); + builder.field("usecase", usecase); + builder.endObject(); + return builder; + } + + /** + * Write this into the {@linkplain StreamOutput}. + * + * @param out + */ + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(processorTag); + out.writeString(usecase); + out.writeOptionalString(value); + out.writeOptionalString(error); + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/search/summary/SummarySearchResponse.java b/src/main/java/org/opensearch/neuralsearch/search/summary/GenerativeTextLLMSearchResponse.java similarity index 58% rename from src/main/java/org/opensearch/neuralsearch/search/summary/SummarySearchResponse.java rename to src/main/java/org/opensearch/neuralsearch/search/summary/GenerativeTextLLMSearchResponse.java index a1f0e4497..d35b7b6bc 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/summary/SummarySearchResponse.java +++ b/src/main/java/org/opensearch/neuralsearch/search/summary/GenerativeTextLLMSearchResponse.java @@ -6,26 +6,42 @@ package org.opensearch.neuralsearch.search.summary; import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import lombok.Getter; +import lombok.Setter; import lombok.extern.log4j.Log4j2; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.search.SearchResponseSections; import org.opensearch.action.search.ShardSearchFailure; import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.core.ParseField; import org.opensearch.core.xcontent.XContentBuilder; @Log4j2 -public class SummarySearchResponse extends SearchResponse { +public class GenerativeTextLLMSearchResponse extends SearchResponse { - private final ResultsSummary resultsSummary; + @Getter + @Setter + private List generatedTextList; - public SummarySearchResponse(StreamInput in) throws IOException { + private static final ParseField GENERATED_TEXT = new ParseField("generatedText"); + + public GenerativeTextLLMSearchResponse(StreamInput in) throws IOException { super(in); - resultsSummary = new ResultsSummary(); + generatedTextList = in.readList(GeneratedText::new); } - public SummarySearchResponse( + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeList(generatedTextList); + } + + 
public GenerativeTextLLMSearchResponse( SearchResponseSections internalResponse, String scrollId, int totalShards, @@ -34,14 +50,14 @@ public SummarySearchResponse( long tookInMillis, ShardSearchFailure[] shardFailures, Clusters clusters, - ResultsSummary summary + List generatedTextList ) { super(internalResponse, scrollId, totalShards, successfulShards, skippedShards, tookInMillis, shardFailures, clusters); - resultsSummary = summary; - + this.generatedTextList = new ArrayList<>(); + this.generatedTextList.addAll(generatedTextList); } - public SummarySearchResponse( + public GenerativeTextLLMSearchResponse( SearchResponseSections internalResponse, String scrollId, int totalShards, @@ -63,14 +79,18 @@ public SummarySearchResponse( clusters, pointInTimeId ); - resultsSummary = new ResultsSummary(); + this.generatedTextList = new ArrayList<>(); } @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); + builder.startArray(GENERATED_TEXT.getPreferredName()); + for (GeneratedText generatedText : generatedTextList) { + generatedText.toXContent(builder, params); + } + builder.endArray(); innerToXContent(builder, params); - resultsSummary.toXContent(builder, params); builder.endObject(); return builder; } diff --git a/src/main/java/org/opensearch/neuralsearch/search/summary/ResultsSummary.java b/src/main/java/org/opensearch/neuralsearch/search/summary/ResultsSummary.java deleted file mode 100644 index a5e0f8346..000000000 --- a/src/main/java/org/opensearch/neuralsearch/search/summary/ResultsSummary.java +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.neuralsearch.search.summary; - -import java.io.IOException; - -import lombok.Getter; - -import org.opensearch.core.xcontent.ToXContentFragment; -import org.opensearch.core.xcontent.XContentBuilder; - -/** - * This class holds the summary of the results which are returned by OpenSearch and will be sent back to the customer. - * The summary will be obtained from LLM summary models. - */ -@Getter -public class ResultsSummary implements ToXContentFragment { - private final String summary; - private final String error; - - public ResultsSummary() { - summary = "This is my summary"; - error = null; - } - - public ResultsSummary(final String summary, final String error) { - this.summary = summary; - this.error = error; - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.field("summary", summary); - return builder; - } -}
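Reviewer note: below is a minimal end-to-end sketch of how the new response processor and search extension introduced in this patch are expected to be exercised. The pipeline name, index name, and model id are placeholders, and the search-pipeline endpoints shown are assumed from the OpenSearch search-pipeline feature this processor registers with; only the processor type (llm_processor), its configuration keys (fields, modelId, usecase), the extension name (question_extension) with its question field, and the generatedText response section come from this change.

Create a search pipeline that runs the processor on search responses (QANDA usecase shown; omitting "usecase" defaults to SUMMARY):

    PUT /_search/pipeline/llm_pipeline
    {
      "response_processors": [
        {
          "llm_processor": {
            "fields": ["title", "body"],
            "modelId": "<deployed-llm-model-id>",
            "usecase": "qanda"
          }
        }
      ]
    }

Search with the question extension (required by the QANDA usecase; the SUMMARY usecase does not read it):

    GET /my-index/_search?search_pipeline=llm_pipeline
    {
      "query": { "match": { "body": "neural search basics" } },
      "ext": {
        "question_extension": {
          "question": "What is neural search?"
        }
      }
    }

The processor concatenates the configured source fields of the returned hits into the LLM context, calls the ML Commons predict API for the given modelId, and appends a top-level "generatedText" array to the search response; each entry carries either "value" or "error" together with the "processorTag" and "usecase" that produced it.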