Skip to content

Commit

Permalink
Improved the processor I/O and enabled the question via QuestionExtBuilder
Browse files Browse the repository at this point in the history

Signed-off-by: Navneet Verma <[email protected]>
  • Loading branch information
navneet1v committed Apr 27, 2023
1 parent 453aeb6 commit 21f1fc5
Show file tree
Hide file tree
Showing 10 changed files with 452 additions and 229 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.neuralsearch.ext;

import java.io.IOException;

import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.log4j.Log4j2;

import org.opensearch.common.ParsingException;
import org.opensearch.common.io.stream.StreamOutput;
import org.opensearch.core.ParseField;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.search.SearchExtBuilder;

/**
* An extension class which will be used to read the Question Extension Object from Search request.
* We will remove this extension when we have a way to create a Natural Language Question from OpenSearch Query DSL.
*/
@Log4j2
@EqualsAndHashCode(callSuper = false)
public class QuestionExtBuilder extends SearchExtBuilder {

public static String NAME = "question_extension";

private static final ParseField QUESTION_FIELD = new ParseField("question");

@Getter
@Setter
private String question;

/**
* Returns the name of the writeable object
*/
@Override
public String getWriteableName() {
return NAME;
}

/**
* Write this into the {@linkplain StreamOutput}.
*
* @param out {@link StreamOutput}
*/
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(question);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.field(QUESTION_FIELD.getPreferredName(), question);
return builder;
}

public static QuestionExtBuilder parse(XContentParser parser) throws IOException {
final QuestionExtBuilder questionExtBuilder = new QuestionExtBuilder();
XContentParser.Token token = parser.currentToken();
String currentFieldName = null;
if (token != XContentParser.Token.START_OBJECT && (token = parser.nextToken()) != XContentParser.Token.START_OBJECT) {
throw new ParsingException(
parser.getTokenLocation(),
"Expected [" + XContentParser.Token.START_OBJECT + "] but found [" + token + "]",
parser.getTokenLocation()
);
}
while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
if (token == XContentParser.Token.FIELD_NAME) {
currentFieldName = parser.currentName();
} else if (token.isValue()) {
if (QUESTION_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
questionExtBuilder.setQuestion(parser.text());
}
} else {
throw new ParsingException(
parser.getTokenLocation(),
"Unknown key for a " + token + " in [" + currentFieldName + "].",
parser.getTokenLocation()
);
}
}

return questionExtBuilder;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.neuralsearch.search.summary.GeneratedText;
import org.opensearch.neuralsearch.util.RetryUtil;

/**
Expand Down Expand Up @@ -152,32 +153,57 @@ private List<List<Float>> buildVectorFromResponse(MLOutput mlOutput) {
return vector;
}

public String predict(final String prompt, String modelId) throws ExecutionException, InterruptedException {
final MLInput mlInput = buildMLInputForPredictCall(prompt, modelId);
/**
* Will be used to call predict API of ML Commons, to get the response for an input from a modelId.
*
* @param context to be passed to LLM
* @param modelId internal reference of OpenSearch to call LLM
* @return {@link GeneratedText}
* @throws ExecutionException
* @throws InterruptedException
*/
public GeneratedText predict(final String context, String modelId) throws ExecutionException, InterruptedException {
final MLInput mlInput = buildMLInputForPredictCall(context, modelId);
final MLOutput output = mlClient.predict(modelId, mlInput).get();

final ModelTensorOutput modelTensorOutput = (ModelTensorOutput) output;
final List<ModelTensors> tensorOutputList = modelTensorOutput.getMlModelOutputs();
for (final ModelTensors tensors : tensorOutputList) {
final List<ModelTensor> tensorsList = tensors.getMlModelTensors();
for (final ModelTensor tensor : tensorsList) {
final String error = (String) tensor.getDataAsMap().get("error");
if (StringUtils.isNotEmpty(error)) {
log.error("Error happened during the Processing of the input. Error : {}", error);
return error;
}
final List<Map<String, Object>> choices = (List<Map<String, Object>>) tensor.getDataAsMap().get("choices");
if (!CollectionUtils.isEmpty(choices)) {
for (Map<String, Object> choice : choices) {
if (StringUtils.isNotEmpty((String) choice.get("text"))) {
return (String) choice.get("text");
}
return parseModelTensorResponseForDifferentModels(tensor);
}
}
log.error("Tensors Object List is empty : " + output);
return new GeneratedText(StringUtils.EMPTY, "No Text found hence not able to summarize");
}

/**
 * Parses the ML Commons predict output tensor into a {@link GeneratedText}, handling the
 * response shapes of the supported model providers.
 *
 * <p>Recognized shapes: an {@code error} key (provider-side failure), a {@code choices} list
 * (OpenAI-style, first non-empty {@code text} wins) and a {@code results} list (Bedrock-style,
 * first non-empty {@code outputText} wins). Anything else yields a GeneratedText carrying an
 * error description instead of generated text.
 *
 * @param tensor model output tensor whose data map is inspected
 * @return generated text on success, or an empty text with an error message on failure
 */
@SuppressWarnings("unchecked") // dataAsMap values are provider-defined; casts match the documented response shapes
private GeneratedText parseModelTensorResponseForDifferentModels(final ModelTensor tensor) {
    log.info("Output from the model is : {}", tensor);
    Map<String, ?> dataAsMap = tensor.getDataAsMap();
    if (dataAsMap.containsKey("error")) {
        return new GeneratedText(StringUtils.EMPTY, "Error happened during the call. Error is : " + dataAsMap.get("error"));
    } else if (dataAsMap.containsKey("choices")) {
        // This is Open AI output
        final List<Map<String, Object>> choices = (List<Map<String, Object>>) dataAsMap.get("choices");
        if (!CollectionUtils.isEmpty(choices)) {
            for (Map<String, Object> choice : choices) {
                if (StringUtils.isNotEmpty((String) choice.get("text"))) {
                    return new GeneratedText((String) choice.get("text"), StringUtils.EMPTY);
                }
            }
        }
        return new GeneratedText(StringUtils.EMPTY, "There is no data present in the response from Open AI model");
    } else if (dataAsMap.containsKey("results")) {
        // this is for bedrock
        List<Map<String, Object>> results = (List<Map<String, Object>>) dataAsMap.get("results");
        for (Map<String, Object> result : results) {
            if (StringUtils.isNotEmpty((String) result.get("outputText"))) {
                return new GeneratedText((String) result.get("outputText"), StringUtils.EMPTY);
            }
        }
    }
    // Fixed typo ("pase" -> "parse") and the dangling concatenation in the original message.
    return new GeneratedText(StringUtils.EMPTY, "Not able to parse the response from model. Cannot find choices object.");
}

private MLInput buildMLInputForPredictCall(final String prompt, String modelId) {
Expand Down
14 changes: 11 additions & 3 deletions src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,12 @@
import org.opensearch.env.NodeEnvironment;
import org.opensearch.ingest.Processor;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.neuralsearch.ext.QuestionExtBuilder;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.processor.AppendQueryResponseProcessor;
import org.opensearch.neuralsearch.processor.SummaryProcessor;
import org.opensearch.neuralsearch.processor.GenerativeTextLLMProcessor;
import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor;
import org.opensearch.neuralsearch.processor.factory.SummaryProcessorFactory;
import org.opensearch.neuralsearch.processor.factory.GenerativeTextLLMProcessorFactory;
import org.opensearch.neuralsearch.processor.factory.TextEmbeddingProcessorFactory;
import org.opensearch.neuralsearch.query.NeuralQueryBuilder;
import org.opensearch.node.Node;
Expand Down Expand Up @@ -80,7 +81,7 @@ public Map<String, org.opensearch.search.pipeline.Processor.Factory> getProcesso
org.opensearch.search.pipeline.Processor.Parameters parameters
) {
final Map<String, org.opensearch.search.pipeline.Processor.Factory> processorsMap = new HashMap<>();
processorsMap.put(SummaryProcessor.TYPE, new SummaryProcessorFactory(getClientAccessor(parameters.client)));
processorsMap.put(GenerativeTextLLMProcessor.TYPE, new GenerativeTextLLMProcessorFactory(getClientAccessor(parameters.client)));
processorsMap.put(AppendQueryResponseProcessor.TYPE, new AppendQueryResponseProcessor.Factory());
return processorsMap;
}
Expand All @@ -101,4 +102,11 @@ private MLCommonsClientAccessor getClientAccessor(final Client client) {
return clientAccessor;
}

/**
 * Registers the {@code question_extension} search extension so a natural-language question
 * can be passed through the search request's "ext" section and parsed into a
 * {@link QuestionExtBuilder}.
 */
@Override
public List<SearchExtSpec<?>> getSearchExts() {
    final SearchExtSpec<QuestionExtBuilder> questionExtSpec = new SearchExtSpec<>(
        QuestionExtBuilder.NAME,
        input -> new QuestionExtBuilder(),
        QuestionExtBuilder::parse
    );
    return Collections.singletonList(questionExtSpec);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.neuralsearch.processor;

import java.util.ArrayList;
import java.util.List;
import java.util.Locale;

import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.extern.log4j.Log4j2;

import org.apache.commons.lang.StringUtils;
import org.opensearch.OpenSearchException;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.neuralsearch.ext.QuestionExtBuilder;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.search.summary.GeneratedText;
import org.opensearch.neuralsearch.search.summary.GenerativeTextLLMSearchResponse;
import org.opensearch.search.SearchExtBuilder;
import org.opensearch.search.SearchHit;
import org.opensearch.search.pipeline.SearchResponseProcessor;

/**
 * Search response processor which builds a context from the search hits' _source fields,
 * sends it to a Large Language Model through ML Commons and attaches the generated text
 * (a summary or an answer to a question) to the search response.
 */
@Log4j2
public class GenerativeTextLLMProcessor extends AbstractProcessor implements SearchResponseProcessor {

    public static final String TYPE = "llm_processor";
    // _source fields whose values are concatenated to build the LLM context.
    private final List<String> fields;
    private final MLCommonsClientAccessor clientAccessor;
    // Internal OpenSearch reference of the LLM to call.
    private final String modelId;
    // Use case driving the prompt shape; defaults to SUMMARY when none is configured.
    private final ContextType contextType;

    /**
     * @param tag processor tag from the pipeline definition
     * @param description processor description from the pipeline definition
     * @param mlCommonsClientAccessor accessor used to invoke the ML Commons predict API
     * @param fields _source fields used to build the LLM context
     * @param modelId id of the model to invoke
     * @param usecase optional use case name (case-insensitive, e.g. "summary" or "QandA"); null means SUMMARY
     */
    public GenerativeTextLLMProcessor(
        final String tag,
        final String description,
        final MLCommonsClientAccessor mlCommonsClientAccessor,
        final List<String> fields,
        final String modelId,
        final String usecase
    ) {
        super(description, tag);
        this.clientAccessor = mlCommonsClientAccessor;
        this.fields = fields;
        this.modelId = modelId;
        this.contextType = usecase == null ? ContextType.SUMMARY : ContextType.valueOf(usecase.toUpperCase(Locale.ROOT));
    }

    /**
     * Generates text via the LLM and returns a {@link GenerativeTextLLMSearchResponse} that wraps
     * the original response, appending to any generated text produced by earlier processors in
     * the same pipeline.
     */
    @Override
    public SearchResponse processResponse(SearchRequest searchRequest, SearchResponse searchResponse) {
        final GeneratedText generatedText = generateTextFromLLM(searchRequest, searchResponse);
        generatedText.setProcessorTag(getTag());
        generatedText.setUsecase(contextType.name);
        List<GeneratedText> generatedTexts = new ArrayList<>();
        if (searchResponse instanceof GenerativeTextLLMSearchResponse) {
            // Keep output of earlier llm_processor instances in the pipeline.
            List<GeneratedText> generatedTextList = ((GenerativeTextLLMSearchResponse) searchResponse).getGeneratedTextList();
            generatedTexts.addAll(generatedTextList);
        }
        generatedTexts.add(generatedText);

        return new GenerativeTextLLMSearchResponse(
            searchResponse.getInternalResponse(),
            searchResponse.getScrollId(),
            searchResponse.getTotalShards(),
            searchResponse.getSuccessfulShards(),
            searchResponse.getSkippedShards(),
            searchResponse.getTook().millis(),
            searchResponse.getShardFailures(),
            searchResponse.getClusters(),
            generatedTexts
        );
    }

    /**
     * Gets the type of processor
     */
    @Override
    public String getType() {
        return TYPE;
    }

    /**
     * Builds the context and calls the predict API; failures are captured in the returned
     * {@link GeneratedText} rather than propagated, so the search response is always returned.
     */
    private GeneratedText generateTextFromLLM(SearchRequest searchRequest, SearchResponse searchResponse) {
        final String context = createContextForLLM(searchRequest, searchResponse);
        try {
            log.info("Calling the Model {} with a context {}", modelId, context);
            return clientAccessor.predict(context, modelId);
        } catch (Exception e) {
            log.error("Error while calling ML Commons Predict API for context: {}", context, e);
            return new GeneratedText(
                StringUtils.EMPTY,
                String.format(
                    Locale.ROOT,
                    "Error Happened while calling the Predict API for model : %s with context: %s. Error is: %s",
                    modelId,
                    context,
                    e.getMessage()
                )
            );
        }
    }

    // Concatenates the configured fields from all hits, then lets the use case wrap the prompt.
    private String createContextForLLM(SearchRequest searchRequest, SearchResponse searchResponse) {
        final StringBuilder contextBuilder = new StringBuilder();
        createContextForPromptUsingSearchResponse(contextBuilder, searchResponse);
        return contextType.createContext(contextBuilder, searchRequest);
    }

    private void createContextForPromptUsingSearchResponse(final StringBuilder promptBuilder, final SearchResponse searchResponse) {
        for (final SearchHit hit : searchResponse.getInternalResponse().hits()) {
            for (String field : fields) {
                // Read the source map once per field instead of twice.
                final Object fieldValue = hit.getSourceAsMap().get(field);
                if (fieldValue != null) {
                    promptBuilder.append(fieldValue).append("\\n");
                }
            }
        }
    }

    /**
     * Supported prompt shapes. SUMMARY appends a summarization instruction; QANDA substitutes the
     * question supplied via {@link QuestionExtBuilder} into its template.
     */
    @AllArgsConstructor
    @Getter
    private enum ContextType {
        SUMMARY("summary", "\\nSummarize the above input for me. \\n"),
        QANDA("QandA", "By considering above input from me, answer the question: ${question}") {
            public String createContext(final StringBuilder contextBuilder, SearchRequest searchRequest) {
                // Guard: a request without a source section cannot carry the question ext.
                if (searchRequest.source() == null) {
                    throw new OpenSearchException("Search request has no source; a 'question' ext is required for the QandA use case");
                }
                final List<SearchExtBuilder> extBuilders = searchRequest.source().ext();
                String questionString = "";
                for (SearchExtBuilder builder : extBuilders) {
                    if (builder instanceof QuestionExtBuilder) {
                        questionString = ((QuestionExtBuilder) builder).getQuestion();
                    }
                }
                if (StringUtils.isEmpty(questionString)) {
                    throw new OpenSearchException("Not able to get question string from Ext Builder list: " + extBuilders);
                }

                final String updatedPrompt = getContext().replace("${question}", questionString);
                return contextBuilder.insert(0, "\"").append(updatedPrompt).append("\"").toString();
            }
        };

        // Display name reported in the GeneratedText usecase field.
        private final String name;
        // Prompt template appended (or substituted) around the hit-derived context.
        private final String context;

        public String createContext(final StringBuilder contextBuilder, SearchRequest searchRequest) {
            return contextBuilder.insert(0, "\"").append(this.context).append("\"").toString();
        }
    }

}
Loading

0 comments on commit 21f1fc5

Please sign in to comment.