Skip to content

Commit

Permalink
Enable support for default model id on HybridQueryBuilder (#541)
Browse files Browse the repository at this point in the history
* Enable support for default model id on HybridQueryBuilder

Signed-off-by: Varun Jain <[email protected]>

* Adding tests and updating changelog.md

Signed-off-by: Varun Jain <[email protected]>

* Optimizing code

Signed-off-by: Varun Jain <[email protected]>

* modyfing the tests4

Signed-off-by: Varun Jain <[email protected]>

* Addressing Heemin comment

Signed-off-by: Varun Jain <[email protected]>

---------

Signed-off-by: Varun Jain <[email protected]>
  • Loading branch information
vibrantvarun authored Jan 23, 2024
1 parent 7320cd3 commit 98e5534
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 5 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Bug Fixes
- Fixing multiple issues reported in #497 ([#524](https://github.com/opensearch-project/neural-search/pull/524))
- Fix Flaky test reported in #433 ([#533](https://github.com/opensearch-project/neural-search/pull/533))
- Enable support for default model id on HybridQueryBuilder ([#541](https://github.com/opensearch-project/neural-search/pull/541))
### Infrastructure
- BWC tests for Neural Search ([#515](https://github.com/opensearch-project/neural-search/pull/515))
- Github action to run integ tests in secure opensearch cluster ([#535](https://github.com/opensearch-project/neural-search/pull/535))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import java.util.stream.Collectors;

import org.apache.commons.lang.builder.EqualsBuilder;
import org.apache.lucene.search.BooleanClause.Occur;
import org.apache.lucene.search.Query;
import org.opensearch.common.lucene.search.Queries;
import org.opensearch.core.ParseField;
Expand All @@ -27,6 +28,7 @@
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.index.query.QueryShardException;
import org.opensearch.index.query.Rewriteable;
import org.opensearch.index.query.QueryBuilderVisitor;

import lombok.Getter;
import lombok.NoArgsConstructor;
Expand Down Expand Up @@ -295,4 +297,18 @@ private Collection<Query> toQueries(Collection<QueryBuilder> queryBuilders, Quer
}).filter(Objects::nonNull).collect(Collectors.toList());
return queries;
}

/**
* visit method to parse the HybridQueryBuilder by a visitor
*/
@Override
public void visit(QueryBuilderVisitor visitor) {
visitor.accept(this);
// getChildVisitor of NeuralSearchQueryVisitor return this.
// therefore any argument can be passed. Here we have used Occcur.MUST as an argument.
QueryBuilderVisitor subVisitor = visitor.getChildVisitor(Occur.MUST);
for (QueryBuilder subQueryBuilder : queries) {
subQueryBuilder.visit(subVisitor);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.junit.Before;
import org.opensearch.common.settings.Settings;
import org.opensearch.neuralsearch.BaseNeuralSearchIT;
import org.opensearch.neuralsearch.query.HybridQueryBuilder;
import org.opensearch.neuralsearch.query.NeuralQueryBuilder;

import com.google.common.primitives.Floats;
Expand Down Expand Up @@ -62,6 +63,25 @@ public void testNeuralQueryEnricherProcessor_whenNoModelIdPassed_thenSuccess() {

}

@SneakyThrows
public void testNeuralQueryEnricherProcessor_whenHybridQueryBuilderAndNoModelIdPassed_thenSuccess() {
initializeIndexIfNotExist();
String modelId = getDeployedModelId();
createSearchRequestProcessor(modelId, search_pipeline);
createPipelineProcessor(modelId, ingest_pipeline);
updateIndexSettings(index, Settings.builder().put("index.search.default_pipeline", search_pipeline));
NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder();
neuralQueryBuilder.fieldName(TEST_KNN_VECTOR_FIELD_NAME_1);
neuralQueryBuilder.queryText("Hello World");
neuralQueryBuilder.k(1);
HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder();
hybridQueryBuilder.add(neuralQueryBuilder);
Map<String, Object> response = search(index, hybridQueryBuilder, 2);

assertFalse(response.isEmpty());

}

@SneakyThrows
private void initializeIndexIfNotExist() {
if (index.equals(NeuralQueryEnricherProcessorIT.index) && !indexExists(index)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.ArrayList;
import java.util.function.Supplier;

import org.apache.lucene.search.MatchNoDocsQuery;
Expand Down Expand Up @@ -708,6 +709,13 @@ public void testBoost_whenDefaultBoostSet_thenBuildSuccessfully() {
assertNotNull(hybridQueryBuilder);
}

public void testVisit() {
HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder().add(new NeuralQueryBuilder()).add(new NeuralSparseQueryBuilder());
List<QueryBuilder> visitedQueries = new ArrayList<>();
hybridQueryBuilder.visit(createTestVisitor(visitedQueries));
assertEquals(3, visitedQueries.size());
}

private Map<String, Object> getInnerMap(Object innerObject, String queryName, String fieldName) {
if (!(innerObject instanceof Map)) {
fail("field name does not map to nested object");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,14 @@
import static java.util.Collections.emptyMap;
import static java.util.Collections.singletonMap;
import static java.util.stream.Collectors.toList;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.Weight;
import org.apache.lucene.search.BooleanClause;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryBuilderVisitor;
import static org.opensearch.neuralsearch.settings.NeuralSearchSettings.NEURAL_SEARCH_HYBRID_SEARCH_DISABLED;

import java.io.IOException;
Expand All @@ -20,11 +28,6 @@
import org.apache.lucene.document.FieldType;
import org.apache.lucene.document.TextField;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.Weight;
import org.opensearch.Version;
import org.opensearch.cluster.metadata.IndexMetadata;
import org.opensearch.common.CheckedConsumer;
Expand Down Expand Up @@ -230,4 +233,18 @@ public float getMaxScore(int upTo) {
protected static void initFeatureFlags() {
System.setProperty(NEURAL_SEARCH_HYBRID_SEARCH_DISABLED.getKey(), "true");
}

protected static QueryBuilderVisitor createTestVisitor(List<QueryBuilder> visitedQueries) {
return new QueryBuilderVisitor() {
@Override
public void accept(QueryBuilder qb) {
visitedQueries.add(qb);
}

@Override
public QueryBuilderVisitor getChildVisitor(BooleanClause.Occur occur) {
return this;
}
};
}
}

0 comments on commit 98e5534

Please sign in to comment.