Skip to content

Commit

Permalink
Add support for radial search on Neural query (#1235)
Browse files Browse the repository at this point in the history
* Add support for radial search on k-NN and Neural query types

Signed-off-by: Thomas Farr <[email protected]>

* spotless

Signed-off-by: Thomas Farr <[email protected]>

* Fix compile errors

Signed-off-by: Thomas Farr <[email protected]>

* Fix test

Signed-off-by: Thomas Farr <[email protected]>

---------

Signed-off-by: Thomas Farr <[email protected]>
  • Loading branch information
Xtansia authored Oct 22, 2024
1 parent 96f1688 commit a3a3e54
Show file tree
Hide file tree
Showing 5 changed files with 101 additions and 20 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ This section is for maintaining a changelog for all breaking changes for the cli

### Added
- Added `minScore` and `maxDistance` to `KnnQuery` ([#1166](https://github.com/opensearch-project/opensearch-java/pull/1166))
- Added `minScore` and `maxDistance` to `NeuralQuery` ([#1235](https://github.com/opensearch-project/opensearch-java/pull/1235))

### Dependencies

Expand Down Expand Up @@ -568,4 +569,4 @@ This section is for maintaining a changelog for all breaking changes for the cli
[2.5.0]: https://github.com/opensearch-project/opensearch-java/compare/v2.4.0...v2.5.0
[2.4.0]: https://github.com/opensearch-project/opensearch-java/compare/v2.3.0...v2.4.0
[2.3.0]: https://github.com/opensearch-project/opensearch-java/compare/v2.2.0...v2.3.0
[2.2.0]: https://github.com/opensearch-project/opensearch-java/compare/v2.1.0...v2.2.0
[2.2.0]: https://github.com/opensearch-project/opensearch-java/compare/v2.1.0...v2.2.0
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ public final float[] vector() {
* Optional - The number of neighbors the search of each graph will return.
* @return The number of neighbors to return.
*/
@Nullable
public final Integer k() {
return this.k;
}
Expand All @@ -84,6 +85,7 @@ public final Integer k() {
* Optional - The minimum score allowed for the returned search results.
* @return The minimum score allowed for the returned search results.
*/
@Nullable
private final Float minScore() {
return this.minScore;
}
Expand All @@ -92,6 +94,7 @@ private final Float minScore() {
* Optional - The maximum distance allowed between the vector and each of the returned search results.
* @return The maximum distance allowed between the vector and each ofthe returned search results.
*/
@Nullable
private final Float maxDistance() {
return this.maxDistance;
}
Expand All @@ -111,8 +114,6 @@ protected void serializeInternal(JsonGenerator generator, JsonpMapper mapper) {

super.serializeInternal(generator, mapper);

// TODO: Implement the rest of the serialization.

generator.writeKey("vector");
generator.writeStartArray();
for (float value : this.vector) {
Expand Down Expand Up @@ -183,7 +184,7 @@ public Builder vector(@Nullable float[] vector) {
}

/**
* Required - The number of neighbors the search of each graph will return.
* Optional - The number of neighbors to return.
*
* @param k The number of neighbors to return.
* @return This builder.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,12 @@ public class NeuralQuery extends QueryBase implements QueryVariant {
private final String field;
private final String queryText;
private final String queryImage;
private final int k;
@Nullable
private final Integer k;
@Nullable
private final Float minScore;
@Nullable
private final Float maxDistance;
@Nullable
private final String modelId;
@Nullable
Expand All @@ -41,7 +46,9 @@ private NeuralQuery(NeuralQuery.Builder builder) {
}
this.queryText = builder.queryText;
this.queryImage = builder.queryImage;
this.k = ApiTypeHelper.requireNonNull(builder.k, this, "k");
this.k = builder.k;
this.minScore = builder.minScore;
this.maxDistance = builder.maxDistance;
this.modelId = builder.modelId;
this.filter = builder.filter;
}
Expand Down Expand Up @@ -90,17 +97,34 @@ public final String queryImage() {
}

/**
* Required - The number of neighbors to return.
* Optional - The number of neighbors to return.
*
* @return The number of neighbors to return.
*/
public final int k() {
@Nullable
public final Integer k() {
return this.k;
}

/**
* Builder for {@link NeuralQuery}.
* Optional - The minimum score threshold for the search results
*
* @return The minimum score threshold for the search results
*/
@Nullable
public final Float minScore() {
return this.minScore;
}

/**
* Optional - The maximum distance threshold for the search results
*
* @return The maximum distance threshold for the search results
*/
@Nullable
public final Float maxDistance() {
return this.maxDistance;
}

/**
* Optional - The model_id field if the default model for the index or field is set.
Expand Down Expand Up @@ -141,7 +165,17 @@ protected void serializeInternal(JsonGenerator generator, JsonpMapper mapper) {
generator.write("model_id", this.modelId);
}

generator.write("k", this.k);
if (this.k != null) {
generator.write("k", this.k);
}

if (this.minScore != null) {
generator.write("min_score", this.minScore);
}

if (this.maxDistance != null) {
generator.write("max_distance", this.maxDistance);
}

if (this.filter != null) {
generator.writeKey("filter");
Expand All @@ -152,7 +186,14 @@ protected void serializeInternal(JsonGenerator generator, JsonpMapper mapper) {
}

public Builder toBuilder() {
return toBuilder(new Builder()).field(field).queryText(queryText).queryImage(queryImage).k(k).modelId(modelId).filter(filter);
return toBuilder(new Builder()).field(field)
.queryText(queryText)
.queryImage(queryImage)
.k(k)
.minScore(minScore)
.maxDistance(maxDistance)
.modelId(modelId)
.filter(filter);
}

/**
Expand All @@ -162,8 +203,13 @@ public static class Builder extends QueryBase.AbstractBuilder<NeuralQuery.Builde
private String field;
private String queryText;
private String queryImage;
@Nullable
private Integer k;
@Nullable
private Float minScore;
@Nullable
private Float maxDistance;
@Nullable
private String modelId;
@Nullable
private Query filter;
Expand Down Expand Up @@ -216,7 +262,7 @@ public NeuralQuery.Builder modelId(@Nullable String modelId) {
}

/**
* Required - The number of neighbors to return.
* Optional - The number of neighbors to return.
*
* @param k The number of neighbors to return.
* @return This builder.
Expand All @@ -226,6 +272,28 @@ public NeuralQuery.Builder k(@Nullable Integer k) {
return this;
}

/**
* Optional - The minimum score threshold for the search results
*
* @param minScore The minimum score threshold for the search results
* @return This builder.
*/
public NeuralQuery.Builder minScore(@Nullable Float minScore) {
this.minScore = minScore;
return this;
}

/**
* Optional - The maximum distance threshold for the search results
*
* @param maxDistance The maximum distance threshold for the search results
* @return This builder.
*/
public NeuralQuery.Builder maxDistance(@Nullable Float maxDistance) {
this.maxDistance = maxDistance;
return this;
}

/**
* Optional - A query to filter the results of the knn query.
*
Expand Down Expand Up @@ -267,6 +335,8 @@ protected static void setupNeuralQueryDeserializer(ObjectDeserializer<NeuralQuer
op.add(NeuralQuery.Builder::queryImage, JsonpDeserializer.stringDeserializer(), "query_image");
op.add(NeuralQuery.Builder::modelId, JsonpDeserializer.stringDeserializer(), "model_id");
op.add(NeuralQuery.Builder::k, JsonpDeserializer.integerDeserializer(), "k");
op.add(NeuralQuery.Builder::minScore, JsonpDeserializer.floatDeserializer(), "min_score");
op.add(NeuralQuery.Builder::maxDistance, JsonpDeserializer.floatDeserializer(), "max_distance");
op.add(NeuralQuery.Builder::filter, Query._DESERIALIZER, "filter");

op.setKey(NeuralQuery.Builder::field, JsonpDeserializer.stringDeserializer());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ public void testNeuralQuery() {
assertEquals("passage_embedding", searchRequest.query().neural().field());
assertEquals("Hi world", searchRequest.query().neural().queryText());
assertEquals("bQ1J8ooBpBj3wT4HVUsb", searchRequest.query().neural().modelId());
assertEquals(100, searchRequest.query().neural().k());
assertEquals((Integer) 100, searchRequest.query().neural().k());
}

@Test
Expand Down Expand Up @@ -251,7 +251,7 @@ public void testNeuralQueryFromJson() {
searchRequest.query().neural().queryImage()
);
assertEquals("bQ1J8ooBpBj3wT4HVUsb", searchRequest.query().neural().modelId());
assertEquals(100, searchRequest.query().neural().k());
assertEquals((Integer) 100, searchRequest.query().neural().k());
}

@Test
Expand Down Expand Up @@ -279,10 +279,10 @@ public void testHybridQuery() {
assertEquals("passage_embedding", searchRequest.query().hybrid().queries().get(1).neural().field());
assertEquals("Hi world", searchRequest.query().hybrid().queries().get(1).neural().queryText());
assertEquals("bQ1J8ooBpBj3wT4HVUsb", searchRequest.query().hybrid().queries().get(1).neural().modelId());
assertEquals(100, searchRequest.query().hybrid().queries().get(1).neural().k());
assertEquals((Integer) 100, searchRequest.query().hybrid().queries().get(1).neural().k());
assertEquals("passage_embedding", searchRequest.query().hybrid().queries().get(2).knn().field());
assertEquals(2, searchRequest.query().hybrid().queries().get(2).knn().vector().length);
assertEquals(Integer.valueOf(2), searchRequest.query().hybrid().queries().get(2).knn().k());
assertEquals((Integer) 2, searchRequest.query().hybrid().queries().get(2).knn().k());
}

@Test
Expand All @@ -301,9 +301,9 @@ public void testHybridQueryFromJson() {
assertEquals("passage_embedding", searchRequest.query().hybrid().queries().get(1).neural().field());
assertEquals("Hi world", searchRequest.query().hybrid().queries().get(1).neural().queryText());
assertEquals("bQ1J8ooBpBj3wT4HVUsb", searchRequest.query().hybrid().queries().get(1).neural().modelId());
assertEquals(100, searchRequest.query().hybrid().queries().get(1).neural().k());
assertEquals((Integer) 100, searchRequest.query().hybrid().queries().get(1).neural().k());
assertEquals("passage_embedding", searchRequest.query().hybrid().queries().get(2).knn().field());
assertEquals(2, searchRequest.query().hybrid().queries().get(2).knn().vector().length);
assertEquals(Integer.valueOf(2), searchRequest.query().hybrid().queries().get(2).knn().k());
assertEquals((Integer) 2, searchRequest.query().hybrid().queries().get(2).knn().k());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import java.util.List;
import java.util.Map;
import org.junit.Test;
import org.opensearch.Version;
import org.opensearch.client.json.JsonData;
import org.opensearch.client.opensearch._types.Refresh;
import org.opensearch.client.opensearch._types.mapping.Property;
Expand Down Expand Up @@ -89,6 +90,14 @@ public void testTemplateSearchAggregations() throws Exception {

@Test
public void testMultiSearchTemplate() throws Exception {
Integer expectedSuccessStatus = null;
Integer expectedFailureStatus = null;

if (getServerVersion().onOrAfter(Version.V_2_18_0)) {
expectedSuccessStatus = 200;
expectedFailureStatus = 404;
}

var index = "test-msearch-template";
createDocuments(index);

Expand Down Expand Up @@ -120,11 +129,11 @@ public void testMultiSearchTemplate() throws Exception {
assertEquals(2, searchResponse.responses().size());
var response = searchResponse.responses().get(0);
assertTrue(response.isResult());
assertNull(response.result().status());
assertEquals(expectedSuccessStatus, response.result().status());
assertEquals(4, response.result().hits().hits().size());
var failureResponse = searchResponse.responses().get(1);
assertTrue(failureResponse.isFailure());
assertNull(failureResponse.failure().status());
assertEquals(expectedFailureStatus, failureResponse.failure().status());
}

private SearchTemplateResponse<SimpleDoc> sendTemplateRequest(String index, String title, boolean suggs, boolean aggs)
Expand Down

0 comments on commit a3a3e54

Please sign in to comment.