Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Radial Search #1166

Merged
merged 8 commits into from
Oct 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ This section is for maintaining a changelog for all breaking changes for the cli
## [Unreleased 2.x]

### Added
- Added `minScore` and `maxDistance` to `KnnQuery` ([#1166](https://github.com/opensearch-project/opensearch-java/pull/1166))

### Dependencies
- Bumps `org.junit:junit-bom` from 5.10.3 to 5.11.0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,12 @@
public class KnnQuery extends QueryBase implements QueryVariant {
private final String field;
private final float[] vector;
private final int k;
@Nullable
private final Integer k;
@Nullable
private final Float minScore;
@Nullable
private final Float maxDistance;
@Nullable
private final Query filter;

Expand All @@ -32,7 +37,9 @@ private KnnQuery(Builder builder) {

this.field = ApiTypeHelper.requireNonNull(builder.field, this, "field");
this.vector = ApiTypeHelper.requireNonNull(builder.vector, this, "vector");
this.k = ApiTypeHelper.requireNonNull(builder.k, this, "k");
this.k = builder.k;
this.minScore = builder.minScore;
this.maxDistance = builder.maxDistance;
this.filter = builder.filter;
}

Expand Down Expand Up @@ -66,13 +73,29 @@ public final float[] vector() {
}

/**
* Required - The number of neighbors the search of each graph will return.
* Optional - The number of neighbors the search of each graph will return.
* @return The number of neighbors to return.
*/
public final int k() {
public final Integer k() {
return this.k;
}

/**
* Optional - The minimum score allowed for the returned search results.
* @return The minimum score allowed for the returned search results.
*/
private final Float minScore() {
return this.minScore;
}

/**
* Optional - The maximum distance allowed between the vector and each of the returned search results.
* @return The maximum distance allowed between the vector and each ofthe returned search results.
*/
private final Float maxDistance() {
return this.maxDistance;
}

/**
* Optional - A query to filter the results of the query.
* @return The filter query.
Expand All @@ -97,7 +120,17 @@ protected void serializeInternal(JsonGenerator generator, JsonpMapper mapper) {
}
generator.writeEnd();

generator.write("k", this.k);
if (this.k != null) {
generator.write("k", this.k);
}

if (this.minScore != null) {
generator.write("min_score", this.minScore);
}

if (this.maxDistance != null) {
generator.write("max_distance", this.maxDistance);
}

if (this.filter != null) {
generator.writeKey("filter");
Expand All @@ -108,7 +141,7 @@ protected void serializeInternal(JsonGenerator generator, JsonpMapper mapper) {
}

public Builder toBuilder() {
return toBuilder(new Builder()).field(field).vector(vector).k(k).filter(filter);
return toBuilder(new Builder()).field(field).vector(vector).k(k).minScore(minScore).maxDistance(maxDistance).filter(filter);
}

/**
Expand All @@ -122,6 +155,10 @@ public static class Builder extends QueryBase.AbstractBuilder<Builder> implement
@Nullable
private Integer k;
@Nullable
private Float minScore;
@Nullable
private Float maxDistance;
@Nullable
private Query filter;

/**
Expand Down Expand Up @@ -156,6 +193,28 @@ public Builder k(@Nullable Integer k) {
return this;
}

/**
* Optional - The minimum score allowed for the returned search results.
*
* @param minScore The minimum score allowed for the returned search results.
* @return This builder.
*/
public Builder minScore(@Nullable Float minScore) {
this.minScore = minScore;
return this;
}

/**
* Optional - The maximum distance allowed between the vector and each of the returned search results.
*
* @param maxDistance The maximum distance allowed between the vector and each ofthe returned search results.
* @return This builder.
*/
public Builder maxDistance(@Nullable Float maxDistance) {
this.maxDistance = maxDistance;
return this;
}

/**
* Optional - A query to filter the results of the knn query.
*
Expand Down Expand Up @@ -201,6 +260,8 @@ protected static void setupKnnQueryDeserializer(ObjectDeserializer<Builder> op)
b.vector(vector);
}, JsonpDeserializer.arrayDeserializer(JsonpDeserializer.floatDeserializer()), "vector");
op.add(Builder::k, JsonpDeserializer.integerDeserializer(), "k");
op.add(Builder::minScore, JsonpDeserializer.floatDeserializer(), "min_score");
op.add(Builder::maxDistance, JsonpDeserializer.floatDeserializer(), "max_distance");
op.add(Builder::filter, Query._DESERIALIZER, "filter");

op.setKey(Builder::field, JsonpDeserializer.stringDeserializer());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
public class KnnQueryTest extends ModelTestCase {
@Test
public void toBuilder() {
KnnQuery origin = new KnnQuery.Builder().field("field").vector(new float[] { 1.0f }).k(1).build();
KnnQuery origin = new KnnQuery.Builder().field("field").vector(new float[] { 1.0f }).k(1).minScore(0.0f).maxDistance(1.0f).build();
KnnQuery copied = origin.toBuilder().build();

assertEquals(toJson(copied), toJson(origin));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ public void testHybridQuery() {
assertEquals(100, searchRequest.query().hybrid().queries().get(1).neural().k());
assertEquals("passage_embedding", searchRequest.query().hybrid().queries().get(2).knn().field());
assertEquals(2, searchRequest.query().hybrid().queries().get(2).knn().vector().length);
assertEquals(2, searchRequest.query().hybrid().queries().get(2).knn().k());
assertEquals(Integer.valueOf(2), searchRequest.query().hybrid().queries().get(2).knn().k());
}

@Test
Expand All @@ -304,6 +304,6 @@ public void testHybridQueryFromJson() {
assertEquals(100, searchRequest.query().hybrid().queries().get(1).neural().k());
assertEquals("passage_embedding", searchRequest.query().hybrid().queries().get(2).knn().field());
assertEquals(2, searchRequest.query().hybrid().queries().get(2).knn().vector().length);
assertEquals(2, searchRequest.query().hybrid().queries().get(2).knn().k());
assertEquals(Integer.valueOf(2), searchRequest.query().hybrid().queries().get(2).knn().k());
}
}
Loading