Skip to content

Commit

Permalink
Switch to vector bytes field to send vector when using gRPC
Browse files Browse the repository at this point in the history
  • Loading branch information
antas-marcin committed Dec 6, 2023
1 parent 339405e commit 9de56e3
Show file tree
Hide file tree
Showing 15 changed files with 283 additions and 20 deletions.
5 changes: 4 additions & 1 deletion src/main/java/io/weaviate/client/WeaviateClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import io.weaviate.client.base.http.impl.CommonsHttpClientImpl;
import io.weaviate.client.base.util.DbVersionProvider;
import io.weaviate.client.base.util.DbVersionSupport;
import io.weaviate.client.base.util.GrpcVersionSupport;
import io.weaviate.client.v1.auth.provider.AccessTokenProvider;
import io.weaviate.client.v1.backup.Backup;
import io.weaviate.client.v1.batch.Batch;
Expand All @@ -22,6 +23,7 @@ public class WeaviateClient {
private final Config config;
private final DbVersionProvider dbVersionProvider;
private final DbVersionSupport dbVersionSupport;
private final GrpcVersionSupport grpcVersionSupport;
private final HttpClient httpClient;
private final AccessTokenProvider tokenProvider;

Expand All @@ -38,6 +40,7 @@ public WeaviateClient(Config config, HttpClient httpClient, AccessTokenProvider
this.httpClient = httpClient;
dbVersionProvider = initDbVersionProvider();
dbVersionSupport = new DbVersionSupport(dbVersionProvider);
grpcVersionSupport = new GrpcVersionSupport(dbVersionProvider);
this.tokenProvider = tokenProvider;
}

Expand All @@ -56,7 +59,7 @@ public Data data() {

public Batch batch() {
dbVersionProvider.refresh();
return new Batch(httpClient, config, dbVersionSupport, tokenProvider, data());
return new Batch(httpClient, config, dbVersionSupport, grpcVersionSupport, tokenProvider, data());
}

public Backup backup() {
Expand Down
26 changes: 26 additions & 0 deletions src/main/java/io/weaviate/client/base/util/GrpcVersionSupport.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package io.weaviate.client.base.util;

import lombok.RequiredArgsConstructor;
import org.apache.commons.lang3.StringUtils;

@RequiredArgsConstructor
public class GrpcVersionSupport {

private final DbVersionProvider provider;

public boolean supportsVectorBytesField() {
String[] versionNumbers = StringUtils.split(provider.getVersion(), ".");
if (versionNumbers != null && versionNumbers.length >= 2) {
int major = Integer.parseInt(versionNumbers[0]);
int minor = Integer.parseInt(versionNumbers[1]);
if (major == 1 && minor == 22 && versionNumbers.length == 3) {
String patch = versionNumbers[2];
if (!patch.contains("rc") && Integer.parseInt(patch) >= 6) {
return true;
}
}
return (major == 1 && minor >= 23) || major >= 2;
}
return false;
}
}
9 changes: 6 additions & 3 deletions src/main/java/io/weaviate/client/v1/batch/Batch.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import io.weaviate.client.base.http.HttpClient;
import io.weaviate.client.base.util.BeaconPath;
import io.weaviate.client.base.util.DbVersionSupport;
import io.weaviate.client.base.util.GrpcVersionSupport;
import io.weaviate.client.v1.auth.provider.AccessTokenProvider;
import io.weaviate.client.v1.batch.api.ObjectsBatchDeleter;
import io.weaviate.client.v1.batch.api.ObjectsBatcher;
Expand All @@ -20,14 +21,16 @@ public class Batch {
private final BeaconPath beaconPath;
private final ObjectsPath objectsPath;
private final ReferencesPath referencesPath;
private final GrpcVersionSupport grpcVersionSupport;
private final Data data;

public Batch(HttpClient httpClient, Config config, DbVersionSupport dbVersionSupport,
public Batch(HttpClient httpClient, Config config, DbVersionSupport dbVersionSupport, GrpcVersionSupport grpcVersionSupport,
AccessTokenProvider tokenProvider, Data data) {
this.config = config;
this.httpClient = httpClient;
this.tokenProvider = tokenProvider;
this.beaconPath = new BeaconPath(dbVersionSupport);
this.grpcVersionSupport = grpcVersionSupport;
this.objectsPath = new ObjectsPath();
this.referencesPath = new ReferencesPath();
this.data = data;
Expand All @@ -38,7 +41,7 @@ public ObjectsBatcher objectsBatcher() {
}

public ObjectsBatcher objectsBatcher(ObjectsBatcher.BatchRetriesConfig batchRetriesConfig) {
return ObjectsBatcher.create(httpClient, config, data, objectsPath, tokenProvider, batchRetriesConfig);
return ObjectsBatcher.create(httpClient, config, data, objectsPath, tokenProvider, grpcVersionSupport, batchRetriesConfig);
}

public ObjectsBatcher objectsAutoBatcher() {
Expand All @@ -64,7 +67,7 @@ public ObjectsBatcher objectsAutoBatcher(ObjectsBatcher.AutoBatchConfig autoBatc

public ObjectsBatcher objectsAutoBatcher(ObjectsBatcher.BatchRetriesConfig batchRetriesConfig,
ObjectsBatcher.AutoBatchConfig autoBatchConfig) {
return ObjectsBatcher.createAuto(httpClient, config, data, objectsPath, tokenProvider, batchRetriesConfig, autoBatchConfig);
return ObjectsBatcher.createAuto(httpClient, config, data, objectsPath, tokenProvider, grpcVersionSupport, batchRetriesConfig, autoBatchConfig);
}

public ObjectsBatchDeleter objectsBatchDeleter() {
Expand Down
16 changes: 10 additions & 6 deletions src/main/java/io/weaviate/client/v1/batch/api/ObjectsBatcher.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import io.weaviate.client.base.grpc.GrpcClient;
import io.weaviate.client.base.http.HttpClient;
import io.weaviate.client.base.util.Assert;
import io.weaviate.client.base.util.GrpcVersionSupport;
import io.weaviate.client.grpc.protocol.v1.WeaviateProtoBase;
import io.weaviate.client.grpc.protocol.v1.WeaviateProtoBatch;
import io.weaviate.client.v1.auth.provider.AccessTokenProvider;
Expand Down Expand Up @@ -69,18 +70,20 @@ public class ObjectsBatcher extends BaseClient<ObjectGetResponse[]>
private final List<CompletableFuture<Result<ObjectGetResponse[]>>> undoneFutures;
private final boolean useGRPC;
private final AccessTokenProvider tokenProvider;
private final GrpcVersionSupport grpcVersionSupport;
private final Config config;


private ObjectsBatcher(HttpClient httpClient, Config config, Data data, ObjectsPath objectsPath,
AccessTokenProvider tokenProvider,
AccessTokenProvider tokenProvider, GrpcVersionSupport grpcVersionSupport,
BatchRetriesConfig batchRetriesConfig, AutoBatchConfig autoBatchConfig) {
super(httpClient, config);
this.config = config;
this.useGRPC = config.useGRPC();
this.tokenProvider = tokenProvider;
this.data = data;
this.objectsPath = objectsPath;
this.grpcVersionSupport = grpcVersionSupport;
this.objects = new ArrayList<>();
this.batchRetriesConfig = batchRetriesConfig;

Expand All @@ -100,18 +103,18 @@ private ObjectsBatcher(HttpClient httpClient, Config config, Data data, ObjectsP
}

public static ObjectsBatcher create(HttpClient httpClient, Config config, Data data, ObjectsPath objectsPath,
AccessTokenProvider tokenProvider,
AccessTokenProvider tokenProvider, GrpcVersionSupport grpcVersionSupport,
BatchRetriesConfig batchRetriesConfig) {
Assert.requiredNotNull(batchRetriesConfig, "batchRetriesConfig");
return new ObjectsBatcher(httpClient, config, data, objectsPath, tokenProvider, batchRetriesConfig, null);
return new ObjectsBatcher(httpClient, config, data, objectsPath, tokenProvider, grpcVersionSupport, batchRetriesConfig, null);
}

public static ObjectsBatcher createAuto(HttpClient httpClient, Config config, Data data, ObjectsPath objectsPath,
AccessTokenProvider tokenProvider,
AccessTokenProvider tokenProvider, GrpcVersionSupport grpcVersionSupport,
BatchRetriesConfig batchRetriesConfig, AutoBatchConfig autoBatchConfig) {
Assert.requiredNotNull(batchRetriesConfig, "batchRetriesConfig");
Assert.requiredNotNull(autoBatchConfig, "autoBatchConfig");
return new ObjectsBatcher(httpClient, config, data, objectsPath, tokenProvider, batchRetriesConfig, autoBatchConfig);
return new ObjectsBatcher(httpClient, config, data, objectsPath, tokenProvider, grpcVersionSupport, batchRetriesConfig, autoBatchConfig);
}


Expand Down Expand Up @@ -290,8 +293,9 @@ private Result<ObjectGetResponse[]> internalRun(List<WeaviateObject> batch) {
}

private Result<ObjectGetResponse[]> internalGrpcRun(List<WeaviateObject> batch) {
BatchObjectConverter batchObjectConverter = new BatchObjectConverter(grpcVersionSupport);
List<WeaviateProtoBatch.BatchObject> batchObjects = batch.stream()
.map(BatchObjectConverter::toBatchObject)
.map(batchObjectConverter::toBatchObject)
.collect(Collectors.toList());
WeaviateProtoBatch.BatchObjectsRequest.Builder batchObjectsRequestBuilder = WeaviateProtoBatch.BatchObjectsRequest.newBuilder();
batchObjectsRequestBuilder.addAllObjects(batchObjects);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
package io.weaviate.client.v1.batch.grpc;

import com.google.protobuf.ByteString;
import com.google.protobuf.Struct;
import com.google.protobuf.Value;
import io.weaviate.client.base.util.CrossReference;
import io.weaviate.client.base.util.GrpcVersionSupport;
import io.weaviate.client.grpc.protocol.v1.WeaviateProtoBase;
import io.weaviate.client.grpc.protocol.v1.WeaviateProtoBatch;
import io.weaviate.client.v1.data.model.WeaviateObject;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
Expand All @@ -14,12 +18,17 @@
import java.util.stream.Collectors;
import lombok.AccessLevel;
import lombok.AllArgsConstructor;
import lombok.RequiredArgsConstructor;
import lombok.ToString;
import lombok.experimental.FieldDefaults;

@RequiredArgsConstructor
public class BatchObjectConverter {

public static WeaviateProtoBatch.BatchObject toBatchObject(WeaviateObject obj) {
protected static final int BYTES_PER_FLOAT = Float.SIZE / 8;
private final GrpcVersionSupport grpcVersionSupport;

public WeaviateProtoBatch.BatchObject toBatchObject(WeaviateObject obj) {
WeaviateProtoBatch.BatchObject.Builder builder = WeaviateProtoBatch.BatchObject.newBuilder();
if (obj.getId() != null) {
builder.setUuid(obj.getId());
Expand All @@ -28,7 +37,13 @@ public static WeaviateProtoBatch.BatchObject toBatchObject(WeaviateObject obj) {
builder.setCollection(obj.getClassName());
}
if (obj.getVector() != null) {
builder.addAllVector(Arrays.asList(obj.getVector()));
if (grpcVersionSupport.supportsVectorBytesField()) {
ByteBuffer buffer = ByteBuffer.allocate(obj.getVector().length * BYTES_PER_FLOAT).order(ByteOrder.LITTLE_ENDIAN);
Arrays.stream(obj.getVector()).forEach(buffer::putFloat);
builder.setVectorBytes(ByteString.copyFrom(buffer.array()));
} else {
builder.addAllVector(Arrays.asList(obj.getVector()));
}
}
if (obj.getTenant() != null) {
builder.setTenant(obj.getTenant());
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package io.weaviate.client.base.util;

import com.jparams.junit4.JParamsTestRunner;
import com.jparams.junit4.data.DataMethod;
import static org.assertj.core.api.Assertions.assertThat;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.MockitoAnnotations;

@RunWith(JParamsTestRunner.class)
public class GrpcVersionSupportTest {

private AutoCloseable openedMocks;
@InjectMocks
private GrpcVersionSupport grpcVersionProvider;
@Mock
private DbVersionProvider dbVersionProviderMock;

@Before
public void setUp() {
openedMocks = MockitoAnnotations.openMocks(this);
}

@After
public void tearDown() throws Exception {
openedMocks.close();
}

@Test
@DataMethod(source = GrpcVersionSupportTest.class, method = "provideNotSupported")
public void shouldNotSupportVectorBytes(String dbVersion) {
Mockito.when(dbVersionProviderMock.getVersion()).thenReturn(dbVersion);

assertThat(grpcVersionProvider.supportsVectorBytesField()).isFalse();
}

public static Object[][] provideNotSupported() {
return new Object[][]{
{"0.11"},
{"1.13.9"},
{"1.22.0-rc.0"},
{"1.22.4"},
{"1.22.5"},
};
}

@Test
@DataMethod(source = GrpcVersionSupportTest.class, method = "provideSupported")
public void shouldSupportVectorBytes(String dbVersion) {
Mockito.when(dbVersionProviderMock.getVersion()).thenReturn(dbVersion);

assertThat(grpcVersionProvider.supportsVectorBytesField()).isTrue();
}

public static Object[][] provideSupported() {
return new Object[][]{
{"1.22.6"},
{"1.23.0-rc.0"},
{"1.23.10"},
{"1.30.1"},
{"1.31"},
{"2.31"},
{"10.11.12"},
};
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package io.weaviate.integration.client;

import java.util.HashMap;
import java.util.Map;
import org.testcontainers.containers.GenericContainer;
import org.testcontainers.containers.wait.strategy.Wait;

public class WeaviateContainer {

public static class DockerContainer {
private final GenericContainer<?> container;

private DockerContainer(GenericContainer<?> container) {
this.container = container;
}

public void start() {
container.start();
}

public Integer getMappedPort(int originalPort) {
return container.getMappedPort(originalPort);
}

public void stop() {
container.stop();
}
}

public static DockerContainer create(String image) {
Map<String, String> env = new HashMap<>();
env.put("AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED", "true");
env.put("QUERY_DEFAULTS_LIMIT", "20");
env.put("PERSISTENCE_DATA_PATH", "./data");
env.put("DEFAULT_VECTORIZER_MODULE", "none");
GenericContainer<?> weaviate = new GenericContainer<>(image)
.withEnv(env)
.withExposedPorts(8080, 50051)
.waitingFor(Wait.forListeningPorts(8080, 50051));
return new DockerContainer(weaviate);
}
}
Loading

0 comments on commit 9de56e3

Please sign in to comment.