Skip to content

Commit

Permalink
Add version check on client init
Browse files Browse the repository at this point in the history
  • Loading branch information
tellet-q committed Dec 20, 2024
1 parent 593a397 commit cb829b3
Show file tree
Hide file tree
Showing 4 changed files with 238 additions and 11 deletions.
1 change: 1 addition & 0 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ dependencies {

testImplementation "io.grpc:grpc-testing:${grpcVersion}"
testImplementation "org.junit.jupiter:junit-jupiter-api:${jUnitVersion}"
testImplementation "org.junit.jupiter:junit-jupiter-params:${jUnitVersion}"
testImplementation "org.mockito:mockito-core:3.4.0"
testImplementation "org.slf4j:slf4j-nop:${slf4jVersion}"
testImplementation "org.testcontainers:qdrant:${testcontainersVersion}"
Expand Down
77 changes: 66 additions & 11 deletions src/main/java/io/qdrant/client/QdrantGrpcClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,10 @@
import io.grpc.Deadline;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.qdrant.client.grpc.CollectionsGrpc;
import io.qdrant.client.grpc.*;
import io.qdrant.client.grpc.CollectionsGrpc.CollectionsFutureStub;
import io.qdrant.client.grpc.PointsGrpc;
import io.qdrant.client.grpc.PointsGrpc.PointsFutureStub;
import io.qdrant.client.grpc.QdrantGrpc;
import io.qdrant.client.grpc.QdrantGrpc.QdrantFutureStub;
import io.qdrant.client.grpc.SnapshotsGrpc;
import io.qdrant.client.grpc.SnapshotsGrpc.SnapshotsFutureStub;
import java.time.Duration;
import java.util.concurrent.TimeUnit;
Expand Down Expand Up @@ -45,7 +42,7 @@ public class QdrantGrpcClient implements AutoCloseable {
* @return a new instance of {@link Builder}
*/
public static Builder newBuilder(ManagedChannel channel) {
return new Builder(channel, false);
return new Builder(channel, false, true);
}

/**
Expand All @@ -56,7 +53,21 @@ public static Builder newBuilder(ManagedChannel channel) {
* @return a new instance of {@link Builder}
*/
public static Builder newBuilder(ManagedChannel channel, boolean shutdownChannelOnClose) {
return new Builder(channel, shutdownChannelOnClose);
return new Builder(channel, shutdownChannelOnClose, true);
}

/**
* Creates a new builder to build a client.
*
* @param channel The channel for communication.
* @param shutdownChannelOnClose Whether the channel is shutdown on client close.
* @param checkCompatibility Whether to check compatibility between client's and server's
* versions.
* @return a new instance of {@link Builder}
*/
public static Builder newBuilder(
ManagedChannel channel, boolean shutdownChannelOnClose, boolean checkCompatibility) {
return new Builder(channel, shutdownChannelOnClose, checkCompatibility);
}

/**
Expand All @@ -66,7 +77,7 @@ public static Builder newBuilder(ManagedChannel channel, boolean shutdownChannel
* @return a new instance of {@link Builder}
*/
public static Builder newBuilder(String host) {
return new Builder(host, 6334, true);
return new Builder(host, 6334, true, true);
}

/**
Expand All @@ -77,7 +88,7 @@ public static Builder newBuilder(String host) {
* @return a new instance of {@link Builder}
*/
public static Builder newBuilder(String host, int port) {
return new Builder(host, port, true);
return new Builder(host, port, true, true);
}

/**
Expand All @@ -90,7 +101,23 @@ public static Builder newBuilder(String host, int port) {
* @return a new instance of {@link Builder}
*/
public static Builder newBuilder(String host, int port, boolean useTransportLayerSecurity) {
return new Builder(host, port, useTransportLayerSecurity);
return new Builder(host, port, useTransportLayerSecurity, true);
}

/**
* Creates a new builder to build a client.
*
* @param host The host to connect to.
* @param port The port to connect to.
* @param useTransportLayerSecurity Whether the client uses Transport Layer Security (TLS) to
* secure communications. Running without TLS should only be used for testing purposes.
* @param checkCompatibility Whether to check compatibility between client's and server's
* versions.
* @return a new instance of {@link Builder}
*/
public static Builder newBuilder(
String host, int port, boolean useTransportLayerSecurity, boolean checkCompatibility) {
return new Builder(host, port, useTransportLayerSecurity, checkCompatibility);
}

/**
Expand Down Expand Up @@ -168,17 +195,24 @@ public static class Builder {
@Nullable private CallCredentials callCredentials;
@Nullable private Duration timeout;

Builder(ManagedChannel channel, boolean shutdownChannelOnClose) {
Builder(ManagedChannel channel, boolean shutdownChannelOnClose, boolean checkCompatibility) {
this.channel = channel;
this.shutdownChannelOnClose = shutdownChannelOnClose;
String clientVersion = Builder.class.getPackage().getImplementationVersion();
if (checkCompatibility) {
checkVersionsCompatibility(clientVersion);
}
}

Builder(String host, int port, boolean useTransportLayerSecurity) {
Builder(String host, int port, boolean useTransportLayerSecurity, boolean checkCompatibility) {
String clientVersion = Builder.class.getPackage().getImplementationVersion();
String javaVersion = System.getProperty("java.version");
String userAgent = "java-client/" + clientVersion + " java/" + javaVersion;
this.channel = createChannel(host, port, useTransportLayerSecurity, userAgent);
this.shutdownChannelOnClose = true;
if (checkCompatibility) {
checkVersionsCompatibility(clientVersion);
}
}

/**
Expand Down Expand Up @@ -238,5 +272,26 @@ private static ManagedChannel createChannel(

return channelBuilder.build();
}

private void checkVersionsCompatibility(String clientVersion) {
try {
String serverVersion =
QdrantGrpc.newBlockingStub(this.channel)
.healthCheck(QdrantOuterClass.HealthCheckRequest.getDefaultInstance())
.getVersion();
if (!VersionsCompatibilityChecker.isCompatible(clientVersion, serverVersion)) {
System.out.println(
"Qdrant client version "
+ clientVersion
+ " is incompatible with server version "
+ serverVersion
+ ". Major versions should match and minor version difference must not exceed 1. "
+ "Set check_version=False to skip version check.");
}
} catch (Exception e) {
System.out.println(
"Failed to obtain server version. Unable to check client-server compatibility. Set checkCompatibility=False to skip version check.");
}
}
}
}
96 changes: 96 additions & 0 deletions src/main/java/io/qdrant/client/VersionsCompatibilityChecker.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
package io.qdrant.client;

import java.util.ArrayList;
import java.util.Arrays;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

class Version {
private final int major;
private final int minor;
private final String rest;

public Version(int major, int minor, String rest) {
this.major = major;
this.minor = minor;
this.rest = rest;
}

public int getMajor() {
return major;
}

public int getMinor() {
return minor;
}

public String getRest() {
return rest;
}
}

/** Utility class to check compatibility between server's and client's versions. */
public class VersionsCompatibilityChecker {
private static final Logger logger = LoggerFactory.getLogger(VersionsCompatibilityChecker.class);

/** Default constructor. */
public VersionsCompatibilityChecker() {}

private static Version parseVersion(String version) throws IllegalArgumentException {
if (version.isEmpty()) {
throw new IllegalArgumentException("Version is None");
}

try {
String[] parts = version.split("\\.");
int major = parts.length > 0 ? Integer.parseInt(parts[0]) : 0;
int minor = parts.length > 1 ? Integer.parseInt(parts[1]) : 0;
String rest =
parts.length > 2
? String.join(".", new ArrayList<>(Arrays.asList(parts).subList(2, parts.length)))
: "";

return new Version(major, minor, rest);
} catch (Exception e) {
throw new IllegalArgumentException(
"Unable to parse version, expected format: x.y.z, found: " + version, e);
}
}

/**
* Compares server's and client's versions.
*
* @param clientVersion The client's version.
* @param serverVersion The server's version.
* @return True if the versions are compatible, false otherwise.
*/
public static boolean isCompatible(String clientVersion, String serverVersion) {
if (clientVersion.isEmpty()) {
logger.warn("Unable to compare with client version {}", clientVersion);
return false;
}

if (serverVersion.isEmpty()) {
logger.warn("Unable to compare with server version {}", serverVersion);
return false;
}

if (clientVersion.equals(serverVersion)) {
return true;
}

try {
Version parsedServerVersion = parseVersion(serverVersion);
Version parsedClientVersion = parseVersion(clientVersion);

int majorDiff = Math.abs(parsedServerVersion.getMajor() - parsedClientVersion.getMajor());
if (majorDiff >= 1) {
return false;
}
return Math.abs(parsedServerVersion.getMinor() - parsedClientVersion.getMinor()) <= 1;
} catch (IllegalArgumentException e) {
logger.warn("Unable to compare versions: {}", e.getMessage());
return false;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package io.qdrant.client;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;

import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.stream.Stream;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;

public class VersionsCompatibilityCheckerTest {
private static Stream<Object[]> validVersionProvider() {
return Stream.of(
new Object[] {"1.2.3", 1, 2, "3"},
new Object[] {"1.2.3-alpha", 1, 2, "3-alpha"},
new Object[] {"1.2", 1, 2, ""},
new Object[] {"1", 1, 0, ""},
new Object[] {"1.", 1, 0, ""});
}

@ParameterizedTest
@MethodSource("validVersionProvider")
public void testParseVersion_validVersion(
String versionStr, int expectedMajor, int expectedMinor, String expectedRest)
throws Exception {
Method method =
VersionsCompatibilityChecker.class.getDeclaredMethod("parseVersion", String.class);
method.setAccessible(true);
Version version = (Version) method.invoke(null, versionStr);
assertEquals(expectedMajor, version.getMajor());
assertEquals(expectedMinor, version.getMinor());
assertEquals(expectedRest, version.getRest());
}

private static Stream<String> invalidVersionProvider() {
return Stream.of("v1.12.0", "", ".1", ".1.", "1.null.1", "null.0.1", null);
}

@ParameterizedTest
@MethodSource("invalidVersionProvider")
public void testParseVersion_invalidVersion(String versionStr) throws Exception {
Method method =
VersionsCompatibilityChecker.class.getDeclaredMethod("parseVersion", String.class);
method.setAccessible(true);
assertThrows(
InvocationTargetException.class,
() -> method.invoke(null, versionStr));
}

private static Stream<Object[]> versionCompatibilityProvider() {
return Stream.of(
new Object[] {"1.9.3.dev0", "2.8.1.dev12-something", false},
new Object[] {"1.9", "2.8", false},
new Object[] {"1", "2", false},
new Object[] {"1.9.0", "2.9.0", false},
new Object[] {"1.1.0", "1.2.9", true},
new Object[] {"1.2.7", "1.1.8.dev0", true},
new Object[] {"1.2.1", "1.2.29", true},
new Object[] {"1.2.0", "1.2.0", true},
new Object[] {"1.2.0", "1.4.0", false},
new Object[] {"1.4.0", "1.2.0", false},
new Object[] {"1.9.0", "3.7.0", false},
new Object[] {"3.0.0", "1.0.0", false},
new Object[] {"", "1.0.0", false},
new Object[] {"1.0.0", "", false},
new Object[] {"", "", false});
}

@ParameterizedTest
@MethodSource("versionCompatibilityProvider")
public void testIsCompatible(String clientVersion, String serverVersion, boolean expected) {
assertEquals(expected, VersionsCompatibilityChecker.isCompatible(clientVersion, serverVersion));
}
}

0 comments on commit cb829b3

Please sign in to comment.