Skip to content

Commit

Permalink
[3.0] Update neural-search for OpenSearch 3.0 compatibility (#1141)
Browse files Browse the repository at this point in the history
* Make java compile task pass for main code and tests

Signed-off-by: Martin Gaievski <[email protected]>

* Adopting change in client class package name

Signed-off-by: Martin Gaievski <[email protected]>

* Fixed unit tests

Signed-off-by: Martin Gaievski <[email protected]>

* Fixed integ tests for: hq explain

Signed-off-by: Martin Gaievski <[email protected]>

* Fixed wrapped bool queries for latest core, fixed failing tests

Signed-off-by: Martin Gaievski <[email protected]>

* Adjust bwc versions after 2.19 release

Signed-off-by: Martin Gaievski <[email protected]>

* Reverting back the secure testing flag, refactor BWC tests

Signed-off-by: Martin Gaievski <[email protected]>

* Multiple minor changes: drop precommit for win CI, refactor BWC, added changelog

Signed-off-by: Martin Gaievski <[email protected]>

---------

Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski authored Feb 17, 2025
1 parent 0769ad7 commit c36ca15
Show file tree
Hide file tree
Showing 49 changed files with 381 additions and 344 deletions.
22 changes: 0 additions & 22 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -117,28 +117,6 @@ jobs:
name: coverage-report-${{ matrix.os }}-${{ matrix.java }}
path: ./jacocoTestReport.xml

Precommit-neural-search-windows:
strategy:
matrix:
java: [21, 23]
os: [windows-latest]

name: Pre-commit Windows
runs-on: ${{ matrix.os }}

steps:
- uses: actions/checkout@v4

- name: Setup Java ${{ matrix.java }}
uses: actions/setup-java@v4
with:
distribution: 'temurin'
java-version: ${{ matrix.java }}

- name: Run build
run: |
./gradlew precommit --parallel
Precommit-codecov:
needs: Precommit-neural-search-linux
strategy:
Expand Down
8 changes: 4 additions & 4 deletions .github/workflows/backwards_compatibility_tests_workflow.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ jobs:
Restart-Upgrade-BWCTests-NeuralSearch:
strategy:
matrix:
java: [21, 23]
os: [ubuntu-latest,windows-latest]
bwc_version : ["2.9.0","2.10.0","2.11.0","2.12.0","2.13.0","2.14.0","2.15.0","2.16.0","2.17.0","2.18.0","2.19.0-SNAPSHOT"]
java: [ 21, 23 ]
os: [ubuntu-latest]
bwc_version : ["2.9.0","2.10.0","2.11.0","2.12.0","2.13.0","2.14.0","2.15.0","2.16.0","2.17.0","2.18.0","2.19.0","2.20.0-SNAPSHOT"]
opensearch_version : [ "3.0.0-SNAPSHOT" ]

name: NeuralSearch Restart-Upgrade BWC Tests
Expand All @@ -42,7 +42,7 @@ jobs:
matrix:
java: [21, 23]
os: [ubuntu-latest,windows-latest]
bwc_version: [ "2.19.0-SNAPSHOT" ]
bwc_version: [ "2.20.0-SNAPSHOT" ]
opensearch_version: [ "3.0.0-SNAPSHOT" ]

name: NeuralSearch Rolling-Upgrade BWC Tests
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/test_aggregations.yml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ jobs:
Check-neural-search-windows:
strategy:
matrix:
java: [21, 23]
java: [23]
os: [windows-latest]

name: Integ Tests Windows
Expand Down
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
## [Unreleased 3.0](https://github.com/opensearch-project/neural-search/compare/2.x...HEAD)
### Features
### Enhancements
- Set neural-search plugin 3.0.0 baseline JDK version to JDK-2 ([#838](https://github.com/opensearch-project/neural-search/pull/838))
- Set neural-search plugin 3.0.0 baseline JDK version to JDK-21 ([#838](https://github.com/opensearch-project/neural-search/pull/838))
### Bug Fixes
### Infrastructure
- [3.0] Update neural-search for OpenSearch 3.0 compatibility ([#1141](https://github.com/opensearch-project/neural-search/pull/1141))
### Documentation
### Maintenance
### Refactoring
Expand Down
4 changes: 2 additions & 2 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ import java.util.concurrent.Callable

buildscript {
ext {
opensearch_version = System.getProperty("opensearch.version", "3.0.0-SNAPSHOT")
buildVersionQualifier = System.getProperty("build.version_qualifier", "")
opensearch_version = System.getProperty("opensearch.version", "3.0.0-alpha1-SNAPSHOT")
buildVersionQualifier = System.getProperty("build.version_qualifier", "alpha1")
isSnapshot = "true" == System.getProperty("build.snapshot", "true")
version_tokens = opensearch_version.tokenize('-')
opensearch_build = version_tokens[0] + '.0'
Expand Down
2 changes: 1 addition & 1 deletion gradle.properties
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
# https://github.com/opensearch-project/OpenSearch/blob/main/libs/core/src/main/java/org/opensearch/Version.java .
# Wire compatibility of OpenSearch works like: a 3.x version is compatible with the 2.(latest-major) version.
# Therefore, to run rolling-upgrade BWC Test on local machine the BWC version here should be set 2.(latest-major).
systemProp.bwc.version=2.19.0-SNAPSHOT
systemProp.bwc.version=2.20.0-SNAPSHOT
systemProp.bwc.bundle.version=2.19.0

# For fixing Spotless check with Java 17
Expand Down
5 changes: 5 additions & 0 deletions gradle/wrapper/gradle-wrapper.properties
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
#
# Copyright OpenSearch Contributors
# SPDX-License-Identifier: Apache-2.0
#

distributionBase=GRADLE_USER_HOME
distributionPath=wrapper/dists
distributionSha256Sum=2ab88d6de2c23e6adae7363ae6e29cbdd2a709e992929b48b6530fd0c7133bd6
Expand Down
4 changes: 4 additions & 0 deletions gradlew
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
#!/bin/sh
#
# Copyright OpenSearch Contributors
# SPDX-License-Identifier: Apache-2.0
#

#
# Copyright © 2015-2021 the original authors.
Expand Down
3 changes: 3 additions & 0 deletions gradlew.bat
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
@rem
@rem Copyright OpenSearch Contributors
@rem SPDX-License-Identifier: Apache-2.0
@rem
@rem Copyright 2015 the original author or authors.
@rem
@rem Licensed under the Apache License, Version 2.0 (the "License");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,11 @@
import java.nio.file.Path;
import java.util.Locale;
import java.util.Optional;
import java.util.Set;

import org.junit.Before;
import org.opensearch.common.settings.Settings;
import org.opensearch.ml.common.model.MLModelState;
import org.opensearch.neuralsearch.BaseNeuralSearchIT;
import static org.opensearch.neuralsearch.util.TestUtils.NEURAL_SEARCH_BWC_PREFIX;
import static org.opensearch.neuralsearch.util.TestUtils.OLD_CLUSTER;
Expand All @@ -23,6 +26,8 @@

public abstract class AbstractRollingUpgradeTestCase extends BaseNeuralSearchIT {

private static final Set<MLModelState> READY_FOR_INFERENCE_STATES = Set.of(MLModelState.LOADED, MLModelState.DEPLOYED);

@Before
protected String getIndexNameForTest() {
// Creating index name by concatenating "neural-bwc-" prefix with test method name
Expand Down Expand Up @@ -159,4 +164,24 @@ protected void createPipelineForTextChunkingProcessor(String pipelineName) throw
);
createPipelineProcessor(requestBody, pipelineName, "", null);
}

/**
 * Checks whether the given model state allows inference requests.
 *
 * @param mlModelState current state reported by ML Commons for the model
 * @return {@code true} if the state is one of the ready-for-inference states
 * @throws Exception declared for symmetry with the other model helpers in this base class
 */
protected boolean isModelReadyForInference(final MLModelState mlModelState) throws Exception {
    // LOADED and DEPLOYED (per READY_FOR_INFERENCE_STATES) both mean the model can serve requests
    return READY_FOR_INFERENCE_STATES.stream().anyMatch(readyState -> readyState == mlModelState);
}

/**
 * Polls the model state until it is ready for inference, or gives up.
 * Polls every 2 seconds for at most 30 attempts (~60 seconds total).
 *
 * @param modelId id of the model to wait for
 * @throws RuntimeException if the model is not ready after all attempts
 * @throws Exception if reading the model state fails or the wait is interrupted
 */
protected void waitForModelToLoad(String modelId) throws Exception {
    final int maxAttempts = 30; // Maximum number of attempts
    final long pollIntervalMillis = 2 * 1000L; // Time to wait between attempts

    int attempt = 0;
    while (attempt < maxAttempts) {
        MLModelState state = getModelState(modelId);
        if (isModelReadyForInference(state)) {
            logger.info("Model {} is now loaded after {} attempts", modelId, attempt + 1);
            return;
        }
        logger.info("Waiting for model {} to load. Current state: {}. Attempt {}/{}", modelId, state, attempt + 1, maxAttempts);
        Thread.sleep(pollIntervalMillis);
        attempt++;
    }
    throw new RuntimeException("Model " + modelId + " failed to load after " + maxAttempts + " attempts");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -91,20 +91,4 @@ public void testBatchIngestion_SparseEncodingProcessor_E2EFlow() throws Exceptio
throw new IllegalStateException("Unexpected value: " + getClusterType());
}
}

/**
 * Polls the model state until it reaches LOADED, or gives up.
 * Polls every 2 seconds for at most 30 attempts (~60 seconds total).
 *
 * @param modelId id of the model to wait for
 * @throws RuntimeException if the model is not LOADED after all attempts
 * @throws Exception if reading the model state fails or the wait is interrupted
 */
private void waitForModelToLoad(String modelId) throws Exception {
    final int maxAttempts = 30; // Maximum number of attempts
    final int waitTimeInSeconds = 2; // Time to wait between attempts

    // 1-based attempt counter so log output matches "Attempt N/30"
    for (int attempt = 1; attempt <= maxAttempts; attempt++) {
        MLModelState state = getModelState(modelId);
        if (MLModelState.LOADED == state) {
            logger.info("Model {} is now loaded after {} attempts", modelId, attempt);
            return;
        }
        logger.info("Waiting for model {} to load. Current state: {}. Attempt {}/{}", modelId, state, attempt, maxAttempts);
        Thread.sleep(1000L * waitTimeInSeconds);
    }
    throw new RuntimeException("Model " + modelId + " failed to load after " + maxAttempts + " attempts");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import java.util.Optional;
import java.util.function.Supplier;

import org.opensearch.client.Client;
import org.opensearch.transport.client.Client;
import org.opensearch.cluster.metadata.IndexNameExpressionResolver;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Setting;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ private List<CompoundTopDocs> getCompoundTopDocs(CombineScoresDto combineScoresD
* @return max score
*/
private float maxScoreForShard(CompoundTopDocs updatedTopDocs, boolean isSortEnabled) {
if (updatedTopDocs.getTotalHits().value == 0 || updatedTopDocs.getScoreDocs().isEmpty()) {
if (updatedTopDocs.getTotalHits().value() == 0 || updatedTopDocs.getScoreDocs().isEmpty()) {
return MAX_SCORE_WHEN_NO_HITS_FOUND;
}
if (isSortEnabled) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ private void combineShardScores(
final CompoundTopDocs compoundQueryTopDocs,
final Sort sort
) {
if (Objects.isNull(compoundQueryTopDocs) || compoundQueryTopDocs.getTotalHits().value == 0) {
if (Objects.isNull(compoundQueryTopDocs) || compoundQueryTopDocs.getTotalHits().value() == 0) {
return;
}
List<TopDocs> topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs();
Expand Down Expand Up @@ -292,7 +292,7 @@ private void updateQueryTopDocsWithCombinedScores(
boolean isSortingEnabled
) {
// - max number of hits will be the same which are passed from QueryPhase
long maxHits = compoundQueryTopDocs.getTotalHits().value;
long maxHits = compoundQueryTopDocs.getTotalHits().value();
// - update query search results with normalized scores
compoundQueryTopDocs.setScoreDocs(
getCombinedScoreDocs(
Expand All @@ -309,7 +309,7 @@ private void updateQueryTopDocsWithCombinedScores(

private TotalHits getTotalHits(final List<TopDocs> topDocsPerSubQuery, final long maxHits) {
TotalHits.Relation totalHits = TotalHits.Relation.EQUAL_TO;
if (topDocsPerSubQuery.stream().anyMatch(topDocs -> topDocs.totalHits.relation == TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO)) {
if (topDocsPerSubQuery.stream().anyMatch(topDocs -> topDocs.totalHits.relation() == TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO)) {
totalHits = TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO;
}
return new TotalHits(maxHits, totalHits);
Expand Down Expand Up @@ -343,7 +343,7 @@ private List<ExplanationDetails> explainByShard(
final CompoundTopDocs compoundQueryTopDocs,
final Sort sort
) {
if (Objects.isNull(compoundQueryTopDocs) || compoundQueryTopDocs.getTotalHits().value == 0) {
if (Objects.isNull(compoundQueryTopDocs) || compoundQueryTopDocs.getTotalHits().value() == 0) {
return List.of();
}
// create map of normalized scores results returned from the single shard
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ public void rerank(
final ActionListener<SearchResponse> listener
) {
try {
if (searchResponse.getHits().getTotalHits().value == 0) {
if (searchResponse.getHits().getTotalHits().value() == 0) {
listener.onResponse(searchResponse);
return;
}
Expand Down
25 changes: 23 additions & 2 deletions src/main/java/org/opensearch/neuralsearch/query/HybridQuery.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.Callable;
import java.util.stream.Collectors;

import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
Expand Down Expand Up @@ -42,6 +43,26 @@ public final class HybridQuery extends Query implements Iterable<Query> {
* @param filterQueries list of filters that will be applied to each sub query. Each filter from the list is added as bool "filter" clause. If this is null sub queries will be executed as is
*/
public HybridQuery(final Collection<Query> subQueries, final List<Query> filterQueries, final HybridQueryContext hybridQueryContext) {
    // Delegate to the clause-based constructor; each filter query becomes a FILTER-occur clause.
    // A null filter list is passed through unchanged (sub-queries execute as-is).
    this(
        subQueries,
        hybridQueryContext,
        Objects.isNull(filterQueries)
            ? null
            : filterQueries.stream()
                .map(filterQuery -> new BooleanClause(filterQuery, BooleanClause.Occur.FILTER))
                .collect(Collectors.toList())
    );
}

/**
* Create new instance of hybrid query object based on collection of sub queries and boolean clauses that are used as filters for each sub-query
* @param subQueries
* @param hybridQueryContext
* @param booleanClauses
*/
public HybridQuery(
final Collection<Query> subQueries,
final HybridQueryContext hybridQueryContext,
final List<BooleanClause> booleanClauses
) {
Objects.requireNonNull(subQueries, "collection of queries must not be null");
if (subQueries.isEmpty()) {
throw new IllegalArgumentException("collection of queries must not be empty");
Expand All @@ -50,14 +71,14 @@ public HybridQuery(final Collection<Query> subQueries, final List<Query> filterQ
if (Objects.nonNull(paginationDepth) && paginationDepth == 0) {
throw new IllegalArgumentException("pagination_depth must not be zero");
}
if (Objects.isNull(filterQueries) || filterQueries.isEmpty()) {
if (Objects.isNull(booleanClauses) || booleanClauses.isEmpty()) {
this.subQueries = new ArrayList<>(subQueries);
} else {
List<Query> modifiedSubQueries = new ArrayList<>();
for (Query subQuery : subQueries) {
BooleanQuery.Builder builder = new BooleanQuery.Builder();
builder.add(subQuery, BooleanClause.Occur.MUST);
filterQueries.forEach(filterQuery -> builder.add(filterQuery, BooleanClause.Occur.FILTER));
booleanClauses.forEach(filterQuery -> builder.add(booleanClauses));
modifiedSubQueries.add(builder.build());
}
this.subQueries = modifiedSubQueries;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ public HybridQueryScorer(final Weight weight, final List<Scorer> subScorers) thr
}

HybridQueryScorer(final Weight weight, final List<Scorer> subScorers, final ScoreMode scoreMode) throws IOException {
super(weight);
super();
this.subScorers = Collections.unmodifiableList(subScorers);
this.numSubqueries = subScorers.size();
this.subScorersPQ = initializeSubScorersPQ();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,24 +109,6 @@ private Void addScoreSupplier(Weight weight, HybridQueryExecutorCollector<LeafRe
return null;
}

/**
 * Create the scorer used to score our associated Query.
 *
 * @param context the {@link LeafReaderContext} for which to return the
 *        {@link Scorer}.
 * @return scorer of hybrid query that contains scorers of each sub-query,
 *         null if there are no matches in any sub-query
 * @throws IOException
 */
@Override
public Scorer scorer(LeafReaderContext context) throws IOException {
    ScorerSupplier supplier = scorerSupplier(context);
    if (supplier != null) {
        supplier.setTopLevelScoringClause();
        // no budget constraint on the number of docs to score
        return supplier.get(Long.MAX_VALUE);
    }
    // no sub-query produced a supplier for this segment
    return null;
}

/**
* Check if weight object can be cached
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
package org.opensearch.neuralsearch.query;

import lombok.Getter;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
Expand Down Expand Up @@ -45,8 +44,8 @@ public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float bo
}

@Override
public Query rewrite(IndexReader reader) throws IOException {
Query rewritten = knnQuery.rewrite(reader);
public Query rewrite(IndexSearcher indexSearcher) throws IOException {
Query rewritten = knnQuery.rewrite(indexSearcher);
if (rewritten == knnQuery) {
return this;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.Query;
import org.opensearch.Version;
import org.opensearch.client.Client;
import org.opensearch.transport.client.Client;
import org.opensearch.common.SetOnce;
import org.opensearch.common.collect.Tuple;
import org.opensearch.core.ParseField;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ public class HybridDisiWrapper extends DisiWrapper {
private final int subQueryIndex;

/**
 * Wraps a sub-query scorer together with the index of the sub-query it belongs to.
 * Scrape artifact removed: the pre-change one-arg {@code super(scorer)} call was
 * duplicated above the post-change two-arg call; only the Lucene 10 form is kept.
 *
 * @param scorer scorer of an individual sub-query
 * @param subQueryIndex position of the sub-query within the hybrid query
 */
public HybridDisiWrapper(Scorer scorer, int subQueryIndex) {
    // second argument is presumably the DisiWrapper "impacts" flag — TODO confirm against Lucene 10
    super(scorer, false);
    this.subQueryIndex = subQueryIndex;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -166,13 +166,13 @@ private HybridQueryScorer getHybridQueryScorer(final Scorable scorer) throws IOE
return (HybridQueryScorer) scorer;
}
for (Scorable.ChildScorable childScorable : scorer.getChildren()) {
HybridQueryScorer hybridQueryScorer = getHybridQueryScorer(childScorable.child);
HybridQueryScorer hybridQueryScorer = getHybridQueryScorer(childScorable.child());
if (Objects.nonNull(hybridQueryScorer)) {
log.debug(
String.format(
Locale.ROOT,
"found hybrid query scorer, it's child of scorer %s",
childScorable.child.getClass().getSimpleName()
childScorable.child().getClass().getSimpleName()
)
);
return hybridQueryScorer;
Expand Down Expand Up @@ -289,7 +289,7 @@ private void initializeLeafFieldComparators(LeafReaderContext context, int subQu
private void initializeComparators(LeafReaderContext context, int subQueryNumber) throws IOException {
// as all segments are sorted in the same way, enough to check only the 1st segment for indexSort
if (searchSortPartOfIndexSort == null) {
Sort indexSort = context.reader().getMetaData().getSort();
Sort indexSort = context.reader().getMetaData().sort();
searchSortPartOfIndexSort = canEarlyTerminate(sort, indexSort);
if (searchSortPartOfIndexSort) {
firstComparator.disableSkipping();
Expand Down
Loading

0 comments on commit c36ca15

Please sign in to comment.