Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.x] Add Support for Hybrid Query Type #857

Merged
merged 2 commits into from
Feb 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)

## [Unreleased 2.x]
### Added
- Add support for Hybrid query type ([#850](https://github.com/opensearch-project/opensearch-java/pull/850))

### Dependencies
- Bumps `org.ajoberstar.grgit:grgit-gradle` from 5.2.0 to 5.2.2
Expand Down
19 changes: 19 additions & 0 deletions guides/search.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,25 @@ for (int i = 0; i < searchResponse.hits().hits().size(); i++) {
}
```

### Search documents using a hybrid query
```java
Query searchQuery = Query.of(
h -> h.hybrid(
q -> q.queries(Arrays.asList(
new MatchQuery.Builder().field("text").query(FieldValue.of("Text for document 2")).build().toQuery(),
new TermQuery.Builder().field("passage_text").value(FieldValue.of("Foo bar")).build().toQuery(),
new NeuralQuery.Builder().field("passage_embedding").queryText("Hi world").modelId("bQ1J8ooBpBj3wT4HVUsb").k(100).build().toQuery()
)
)
)
);
SearchRequest searchRequest = new SearchRequest.Builder().query(searchQuery).build();
SearchResponse<IndexData> searchResponse = client.search(searchRequest, IndexData.class);
for (var hit : searchResponse.hits().hits()) {
LOGGER.info("Found {} with score {}", hit.source(), hit.score());
}
```

### Search documents using suggesters

[AppData](../samples/src/main/java/org/opensearch/client/samples/util/AppData.java) refers to the sample data class used in the below samples.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
package org.opensearch.client.opensearch._types.query_dsl;

import jakarta.json.stream.JsonGenerator;
import java.util.List;
import java.util.function.Function;
import org.opensearch.client.json.JsonpDeserializer;
import org.opensearch.client.json.JsonpMapper;
import org.opensearch.client.json.ObjectBuilderDeserializer;
import org.opensearch.client.json.ObjectDeserializer;
import org.opensearch.client.util.ApiTypeHelper;
import org.opensearch.client.util.ObjectBuilder;

public class HybridQuery extends QueryBase implements QueryVariant {
private final List<Query> queries;

private HybridQuery(HybridQuery.Builder builder) {
super(builder);
this.queries = ApiTypeHelper.unmodifiable(builder.queries);
}

public static HybridQuery of(Function<HybridQuery.Builder, ObjectBuilder<HybridQuery>> fn) {
return fn.apply(new HybridQuery.Builder()).build();
}

/**
* Required - list of search queries.
*
* @return list of queries provided under hybrid clause.
*/
public final List<Query> queries() {
return this.queries;
}

@Override
protected void serializeInternal(JsonGenerator generator, JsonpMapper mapper) {
super.serializeInternal(generator, mapper);
generator.writeKey("queries");
generator.writeStartArray();
for (Query item0 : this.queries) {
item0.serialize(generator, mapper);
}
generator.writeEnd();
}

@Override
public Query.Kind _queryKind() {
return Query.Kind.Hybrid;
}

public HybridQuery.Builder toBuilder() {
return new HybridQuery.Builder().queries(queries);
}

public static class Builder extends QueryBase.AbstractBuilder<HybridQuery.Builder> implements ObjectBuilder<HybridQuery> {
private List<Query> queries;

/**
* API name: {@code hybrid}
* <p>
* Adds all elements of <code>list</code> to <code>hybrid</code>.
*/
public final HybridQuery.Builder queries(List<Query> list) {
this.queries = _listAddAll(this.queries, list);
return this;
}

/**
* API name: {@code hybrid}
* <p>
* Adds one or more values to <code>hybrid</code>.
*/
public final HybridQuery.Builder queries(Query value, Query... values) {
this.queries = _listAdd(this.queries, value, values);
return this;
}

/**
* API name: {@code hybrid}
* <p>
* Adds a value to <code>hybrid</code> using a builder lambda.
*/
public final HybridQuery.Builder queries(Function<Query.Builder, ObjectBuilder<Query>> fn) {
return queries(fn.apply(new Query.Builder()).build());
}

@Override
protected Builder self() {
return this;
}

@Override
public HybridQuery build() {
_checkSingleUse();
return new HybridQuery(this);
}
}

public static final JsonpDeserializer<HybridQuery> _DESERIALIZER = ObjectBuilderDeserializer.lazy(
HybridQuery.Builder::new,
HybridQuery::setupHybridQueryDeserializer
);

protected static void setupHybridQueryDeserializer(ObjectDeserializer<HybridQuery.Builder> op) {
setupQueryBaseDeserializer(op);
op.add(HybridQuery.Builder::queries, JsonpDeserializer.arrayDeserializer(Query._DESERIALIZER), "queries");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,8 @@ public enum Kind implements JsonEnum {

Neural("neural"),

Hybrid("hybrid"),

ParentId("parent_id"),

Percolate("percolate"),
Expand Down Expand Up @@ -725,6 +727,23 @@ public NeuralQuery neural() {
return TaggedUnionUtils.get(this, Kind.Neural);
}

/**
* Is this variant instance of kind {@code hybrid}?
*/
public boolean isHybrid() {
return _kind == Kind.Hybrid;
}

/**
* Get the {@code hybrid} variant value.
*
* @throws IllegalStateException
* if the current variant is not of the {@code hybrid} kind.
*/
public HybridQuery hybrid() {
return TaggedUnionUtils.get(this, Kind.Hybrid);
}

/**
* Is this variant instance of kind {@code parent_id}?
*/
Expand Down Expand Up @@ -1510,6 +1529,16 @@ public ObjectBuilder<Query> neural(Function<NeuralQuery.Builder, ObjectBuilder<N
return this.neural(fn.apply(new NeuralQuery.Builder()).build());
}

public ObjectBuilder<Query> hybrid(HybridQuery v) {
this._kind = Kind.Hybrid;
this._value = v;
return this;
}

public ObjectBuilder<Query> hybrid(Function<HybridQuery.Builder, ObjectBuilder<HybridQuery>> fn) {
return this.hybrid(fn.apply(new HybridQuery.Builder()).build());
}

public ObjectBuilder<Query> parentId(ParentIdQuery v) {
this._kind = Kind.ParentId;
this._value = v;
Expand Down Expand Up @@ -1818,6 +1847,7 @@ protected static void setupQueryDeserializer(ObjectDeserializer<Builder> op) {
op.add(Builder::multiMatch, MultiMatchQuery._DESERIALIZER, "multi_match");
op.add(Builder::nested, NestedQuery._DESERIALIZER, "nested");
op.add(Builder::neural, NeuralQuery._DESERIALIZER, "neural");
op.add(Builder::hybrid, HybridQuery._DESERIALIZER, "hybrid");
op.add(Builder::parentId, ParentIdQuery._DESERIALIZER, "parent_id");
op.add(Builder::percolate, PercolateQuery._DESERIALIZER, "percolate");
op.add(Builder::pinned, PinnedQuery._DESERIALIZER, "pinned");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,13 @@ public static NeuralQuery.Builder neural() {
return new NeuralQuery.Builder();
}

/**
* Creates a builder for the {@link HybridQuery nested} {@code Query} variant.
*/
public static HybridQuery.Builder hybrid() {
return new HybridQuery.Builder();
}

/**
* Creates a builder for the {@link ParentIdQuery parent_id} {@code Query}
* variant.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package org.opensearch.client.opensearch._types.query_dsl;

import java.util.Arrays;
import org.junit.Test;
import org.opensearch.client.opensearch._types.FieldValue;
import org.opensearch.client.opensearch.model.ModelTestCase;

public class HybridQueryTest extends ModelTestCase {
@Test
public void toBuilder() {
HybridQuery origin = new HybridQuery.Builder().queries(
Arrays.asList(
new TermQuery.Builder().field("passage_text").value(FieldValue.of("Foo bar")).build().toQuery(),
new NeuralQuery.Builder().field("passage_embedding")
.queryText("Hi world")
.modelId("bQ1J8ooBpBj3wT4HVUsb")
.k(100)
.build()
.toQuery(),
new KnnQuery.Builder().field("passage_embedding").vector(new float[] { 0.01f, 0.02f }).k(2).build().toQuery()
)
).build();
HybridQuery copied = origin.toBuilder().build();

assertEquals(toJson(copied), toJson(origin));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,17 @@

package org.opensearch.client.opensearch.model;

import java.util.Arrays;
import org.junit.Test;
import org.opensearch.client.json.JsonData;
import org.opensearch.client.opensearch._types.FieldValue;
import org.opensearch.client.opensearch._types.mapping.Property;
import org.opensearch.client.opensearch._types.mapping.TypeMapping;
import org.opensearch.client.opensearch._types.query_dsl.KnnQuery;
import org.opensearch.client.opensearch._types.query_dsl.NeuralQuery;
import org.opensearch.client.opensearch._types.query_dsl.Query;
import org.opensearch.client.opensearch._types.query_dsl.QueryBuilders;
import org.opensearch.client.opensearch._types.query_dsl.TermQuery;
import org.opensearch.client.opensearch.core.SearchRequest;
import org.opensearch.client.opensearch.indices.GetMappingResponse;

Expand Down Expand Up @@ -243,4 +248,57 @@ public void testNeuralQueryFromJson() {
assertEquals("bQ1J8ooBpBj3wT4HVUsb", searchRequest.query().neural().modelId());
assertEquals(100, searchRequest.query().neural().k());
}

@Test
public void testHybridQuery() {

Query query = Query.of(
h -> h.hybrid(
q -> q.queries(
Arrays.asList(
new TermQuery.Builder().field("passage_text").value(FieldValue.of("Foo bar")).build().toQuery(),
new NeuralQuery.Builder().field("passage_embedding")
.queryText("Hi world")
.modelId("bQ1J8ooBpBj3wT4HVUsb")
.k(100)
.build()
.toQuery(),
new KnnQuery.Builder().field("passage_embedding").vector(new float[] { 0.01f, 0.02f }).k(2).build().toQuery()
)
)
)
);
SearchRequest searchRequest = SearchRequest.of(s -> s.query(query));
assertEquals("passage_text", searchRequest.query().hybrid().queries().get(0).term().field());
assertEquals("Foo bar", searchRequest.query().hybrid().queries().get(0).term().value().stringValue());
assertEquals("passage_embedding", searchRequest.query().hybrid().queries().get(1).neural().field());
assertEquals("Hi world", searchRequest.query().hybrid().queries().get(1).neural().queryText());
assertEquals("bQ1J8ooBpBj3wT4HVUsb", searchRequest.query().hybrid().queries().get(1).neural().modelId());
assertEquals(100, searchRequest.query().hybrid().queries().get(1).neural().k());
assertEquals("passage_embedding", searchRequest.query().hybrid().queries().get(2).knn().field());
assertEquals(2, searchRequest.query().hybrid().queries().get(2).knn().vector().length);
assertEquals(2, searchRequest.query().hybrid().queries().get(2).knn().k());
}

@Test
public void testHybridQueryFromJson() {

String json = "{\"query\""
+ ":{\"hybrid\":{\"queries\":[{\"term\":{\"passage_text\":\"Foo bar\"}},"
+ "{\"neural\":{\"passage_embedding\":{\"query_text\":\"Hi world\",\"model_id\":\"bQ1J8ooBpBj3wT4HVUsb\",\"k\":100}}},"
+ "{\"knn\":{\"passage_embedding\":{\"vector\":[0.01,0.02],\"k\":2}}}]}},\"size\":10"
+ "}";

SearchRequest searchRequest = ModelTestCase.fromJson(json, SearchRequest.class, mapper);

assertEquals("passage_text", searchRequest.query().hybrid().queries().get(0).term().field());
assertEquals("Foo bar", searchRequest.query().hybrid().queries().get(0).term().value().stringValue());
assertEquals("passage_embedding", searchRequest.query().hybrid().queries().get(1).neural().field());
assertEquals("Hi world", searchRequest.query().hybrid().queries().get(1).neural().queryText());
assertEquals("bQ1J8ooBpBj3wT4HVUsb", searchRequest.query().hybrid().queries().get(1).neural().modelId());
assertEquals(100, searchRequest.query().hybrid().queries().get(1).neural().k());
assertEquals("passage_embedding", searchRequest.query().hybrid().queries().get(2).knn().field());
assertEquals(2, searchRequest.query().hybrid().queries().get(2).knn().vector().length);
assertEquals(2, searchRequest.query().hybrid().queries().get(2).knn().k());
}
}
Loading
Loading