From 1fb2b276919c861d1b86b668d52011888e84982d Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Tue, 22 Oct 2024 20:34:45 +0000 Subject: [PATCH] Add support for radial search on Neural query (#1235) * Add support for radial search on k-NN and Neural query types Signed-off-by: Thomas Farr * spotless Signed-off-by: Thomas Farr * Fix compile errors Signed-off-by: Thomas Farr * Fix test Signed-off-by: Thomas Farr --------- Signed-off-by: Thomas Farr (cherry picked from commit a3a3e541d76b5f39005e0fbdf7af20c403177828) Signed-off-by: github-actions[bot] --- CHANGELOG.md | 1 + .../opensearch/_types/query_dsl/KnnQuery.java | 7 +- .../_types/query_dsl/NeuralQuery.java | 86 +++++++++++++++++-- .../client/opensearch/model/VariantsTest.java | 12 +-- .../AbstractSearchTemplateRequestIT.java | 13 ++- 5 files changed, 100 insertions(+), 19 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 11b480499a..1c88e34886 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) ## [Unreleased 2.x] ### Added - Added `minScore` and `maxDistance` to `KnnQuery` ([#1166](https://github.com/opensearch-project/opensearch-java/pull/1166)) +- Added `minScore` and `maxDistance` to `NeuralQuery` ([#1235](https://github.com/opensearch-project/opensearch-java/pull/1235)) ### Dependencies - Bumps `org.ajoberstar.grgit:grgit-gradle` from 5.2.2 to 5.3.0 diff --git a/java-client/src/main/java/org/opensearch/client/opensearch/_types/query_dsl/KnnQuery.java b/java-client/src/main/java/org/opensearch/client/opensearch/_types/query_dsl/KnnQuery.java index 59f03cc1ab..48ee5e353e 100644 --- a/java-client/src/main/java/org/opensearch/client/opensearch/_types/query_dsl/KnnQuery.java +++ b/java-client/src/main/java/org/opensearch/client/opensearch/_types/query_dsl/KnnQuery.java @@ -76,6 +76,7 @@ public final float[] vector() { * Optional - The number of neighbors the search of each graph will return. * @return The number of neighbors to return. */ + @Nullable public final Integer k() { return this.k; } @@ -84,6 +85,7 @@ public final Integer k() { * Optional - The minimum score allowed for the returned search results. * @return The minimum score allowed for the returned search results. */ + @Nullable private final Float minScore() { return this.minScore; } @@ -92,6 +94,7 @@ private final Float minScore() { * Optional - The maximum distance allowed between the vector and each of the returned search results. * @return The maximum distance allowed between the vector and each ofthe returned search results. */ + @Nullable private final Float maxDistance() { return this.maxDistance; } @@ -111,8 +114,6 @@ protected void serializeInternal(JsonGenerator generator, JsonpMapper mapper) { super.serializeInternal(generator, mapper); - // TODO: Implement the rest of the serialization. - generator.writeKey("vector"); generator.writeStartArray(); for (float value : this.vector) { @@ -183,7 +184,7 @@ public Builder vector(@Nullable float[] vector) { } /** - * Required - The number of neighbors the search of each graph will return. + * Optional - The number of neighbors to return. * * @param k The number of neighbors to return. * @return This builder. diff --git a/java-client/src/main/java/org/opensearch/client/opensearch/_types/query_dsl/NeuralQuery.java b/java-client/src/main/java/org/opensearch/client/opensearch/_types/query_dsl/NeuralQuery.java index 9984f912d0..08f1453a56 100644 --- a/java-client/src/main/java/org/opensearch/client/opensearch/_types/query_dsl/NeuralQuery.java +++ b/java-client/src/main/java/org/opensearch/client/opensearch/_types/query_dsl/NeuralQuery.java @@ -26,7 +26,12 @@ public class NeuralQuery extends QueryBase implements QueryVariant { private final String field; private final String queryText; private final String queryImage; - private final int k; + @Nullable + private final Integer k; + @Nullable + private final Float minScore; + @Nullable + private final Float maxDistance; @Nullable private final String modelId; @Nullable @@ -41,7 +46,9 @@ private NeuralQuery(NeuralQuery.Builder builder) { } this.queryText = builder.queryText; this.queryImage = builder.queryImage; - this.k = ApiTypeHelper.requireNonNull(builder.k, this, "k"); + this.k = builder.k; + this.minScore = builder.minScore; + this.maxDistance = builder.maxDistance; this.modelId = builder.modelId; this.filter = builder.filter; } @@ -90,17 +97,34 @@ public final String queryImage() { } /** - * Required - The number of neighbors to return. + * Optional - The number of neighbors to return. * * @return The number of neighbors to return. */ - public final int k() { + @Nullable + public final Integer k() { return this.k; } /** - * Builder for {@link NeuralQuery}. + * Optional - The minimum score threshold for the search results + * + * @return The minimum score threshold for the search results + */ + @Nullable + public final Float minScore() { + return this.minScore; + } + + /** + * Optional - The maximum distance threshold for the search results + * + * @return The maximum distance threshold for the search results */ + @Nullable + public final Float maxDistance() { + return this.maxDistance; + } /** * Optional - The model_id field if the default model for the index or field is set. @@ -141,7 +165,17 @@ protected void serializeInternal(JsonGenerator generator, JsonpMapper mapper) { generator.write("model_id", this.modelId); } - generator.write("k", this.k); + if (this.k != null) { + generator.write("k", this.k); + } + + if (this.minScore != null) { + generator.write("min_score", this.minScore); + } + + if (this.maxDistance != null) { + generator.write("max_distance", this.maxDistance); + } if (this.filter != null) { generator.writeKey("filter"); @@ -152,7 +186,14 @@ protected void serializeInternal(JsonGenerator generator, JsonpMapper mapper) { } public Builder toBuilder() { - return toBuilder(new Builder()).field(field).queryText(queryText).queryImage(queryImage).k(k).modelId(modelId).filter(filter); + return toBuilder(new Builder()).field(field) + .queryText(queryText) + .queryImage(queryImage) + .k(k) + .minScore(minScore) + .maxDistance(maxDistance) + .modelId(modelId) + .filter(filter); } /** @@ -162,8 +203,13 @@ public static class Builder extends QueryBase.AbstractBuilder sendTemplateRequest(String index, String title, boolean suggs, boolean aggs)