From 7ea10b831e9c002f6535b463c7460c7fa6881851 Mon Sep 17 00:00:00 2001 From: Anush Date: Sat, 2 Mar 2024 20:15:13 +0530 Subject: [PATCH] feat: sparse, multiple vectors support (#16) * feat: sparse, multiple vectors support * refactor: Qdrant options * chore: formatting * chore: removed redundant validator * test: multiple dense, multiple sparse * docs: Updated README.md --- README.md | 187 ++++++++++--- src/main/java/io/qdrant/spark/Qdrant.java | 34 +-- .../io/qdrant/spark/QdrantDataWriter.java | 56 +--- src/main/java/io/qdrant/spark/QdrantGrpc.java | 14 +- .../java/io/qdrant/spark/QdrantOptions.java | 100 +++++-- .../io/qdrant/spark/QdrantPayloadHandler.java | 28 ++ .../io/qdrant/spark/QdrantPointIdHandler.java | 33 +++ .../io/qdrant/spark/QdrantVectorHandler.java | 75 ++++++ src/test/java/io/qdrant/spark/TestQdrant.java | 55 +--- .../java/io/qdrant/spark/TestQdrantGrpc.java | 2 +- .../io/qdrant/spark/TestQdrantOptions.java | 4 +- src/test/python/conftest.py | 5 + src/test/python/test_qdrant_ingest.py | 122 ++++++++- src/test/python/users.json | 252 ++++++++++++++++++ 14 files changed, 757 insertions(+), 210 deletions(-) create mode 100644 src/main/java/io/qdrant/spark/QdrantPayloadHandler.java create mode 100644 src/main/java/io/qdrant/spark/QdrantPointIdHandler.java create mode 100644 src/main/java/io/qdrant/spark/QdrantVectorHandler.java diff --git a/README.md b/README.md index 645ae95..36221b2 100644 --- a/README.md +++ b/README.md @@ -26,24 +26,16 @@ This will build and store the fat JAR in the `target` directory by default. For use with Java and Scala projects, the package can be found [here](https://central.sonatype.com/artifact/io.qdrant/spark). -```xml - - io.qdrant - spark - 2.0.1 - -``` - ## Usage 📝 -### Creating a Spark session (Single-node) with Qdrant support 🌟 +### Creating a Spark session (Single-node) with Qdrant support ```python from pyspark.sql import SparkSession spark = SparkSession.builder.config( "spark.jars", - "spark-2.0.1.jar", # specify the downloaded JAR file + "spark-2.1.0.jar", # specify the downloaded JAR file ) .master("local[*]") .appName("qdrant") @@ -52,30 +44,150 @@ spark = SparkSession.builder.config( ### Loading data 📊 -To load data into Qdrant, a collection has to be created beforehand with the appropriate vector dimensions and configurations. +> [!IMPORTANT] +> Before loading the data using this connector, a collection has to be [created](https://qdrant.tech/documentation/concepts/collections/#create-a-collection) in advance with the appropriate vector dimensions and configurations. + +The connector supports ingesting multiple named/unnamed, dense/sparse vectors. + +
+ Unnamed/Default vector + +```python + + .write + .format("io.qdrant.spark.Qdrant") + .option("qdrant_url", ) + .option("collection_name", ) + .option("embedding_field", ) # Expected to be a field of type ArrayType(FloatType) + .option("schema", .schema.json()) + .mode("append") + .save() +``` + +
+ +
+ Named vector + +```python + + .write + .format("io.qdrant.spark.Qdrant") + .option("qdrant_url", ) + .option("collection_name", ) + .option("embedding_field", ) # Expected to be a field of type ArrayType(FloatType) + .option("vector_name", ) + .option("schema", .schema.json()) + .mode("append") + .save() +``` + +> #### NOTE +> +> The `embedding_field` and `vector_name` options are maintained for backward compatibility. It is recommended to use `vector_fields` and `vector_names` for named vectors as shown below. + +
+ +
+ Multiple named vectors + +```python + + .write + .format("io.qdrant.spark.Qdrant") + .option("qdrant_url", "") + .option("collection_name", "") + .option("vector_fields", ",") + .option("vector_names", ",") + .option("schema", .schema.json()) + .mode("append") + .save() +``` + +
+ +
+ Sparse vectors + +```python + + .write + .format("io.qdrant.spark.Qdrant") + .option("qdrant_url", "") + .option("collection_name", "") + .option("sparse_vector_value_fields", "") + .option("sparse_vector_index_fields", "") + .option("sparse_vector_names", "") + .option("schema", .schema.json()) + .mode("append") + .save() +``` + +
+ +
+ Multiple sparse vectors + +```python + + .write + .format("io.qdrant.spark.Qdrant") + .option("qdrant_url", "") + .option("collection_name", "") + .option("sparse_vector_value_fields", ",") + .option("sparse_vector_index_fields", ",") + .option("sparse_vector_names", ",") + .option("schema", .schema.json()) + .mode("append") + .save() +``` + +
+ +
+ Combination of named dense and sparse vectors + +```python + + .write + .format("io.qdrant.spark.Qdrant") + .option("qdrant_url", "") + .option("collection_name", "") + .option("vector_fields", ",") + .option("vector_names", ",") + .option("sparse_vector_value_fields", ",") + .option("sparse_vector_index_fields", ",") + .option("sparse_vector_names", ",") + .option("schema", .schema.json()) + .mode("append") + .save() +``` + +
+ +
+ No vectors - Entire dataframe is stored as payload ```python - - .write - .format("io.qdrant.spark.Qdrant") - .option("qdrant_url", ) - .option("collection_name", ) - .option("embedding_field", ) # Expected to be a field of type ArrayType(FloatType) - .option("schema", .schema.json()) - .mode("append") - .save() + + .write + .format("io.qdrant.spark.Qdrant") + .option("qdrant_url", "") + .option("collection_name", "") + .option("schema", .schema.json()) + .mode("append") + .save() ``` -- By default, UUIDs are generated for each row. If you need to use custom IDs, you can do so by setting the `id_field` option. -- An API key can be set using the `api_key` option to make authenticated requests. +
## Databricks -You can use the `qdrant-spark` connector as a library in Databricks to ingest data into Qdrant. +You can use the connector as a library in Databricks to ingest data into Qdrant. - Go to the `Libraries` section in your cluster dashboard. - Select `Install New` to open the library installation modal. -- Search for `io.qdrant:spark:2.0.1` in the Maven packages and click `Install`. +- Search for `io.qdrant:spark:2.1.0` in the Maven packages and click `Install`. Screenshot 2024-01-05 at 17 20 01 (1) @@ -85,17 +197,22 @@ Qdrant supports all the Spark data types. The appropriate types are mapped based ## Options and Spark types 🛠️ -| Option | Description | DataType | Required | -| :---------------- | :------------------------------------------------------------------------ | :--------------------- | :------- | -| `qdrant_url` | GRPC URL of the Qdrant instance. Eg: | `StringType` | ✅ | -| `collection_name` | Name of the collection to write data into | `StringType` | ✅ | -| `embedding_field` | Name of the field holding the embeddings | `ArrayType(FloatType)` | ✅ | -| `schema` | JSON string of the dataframe schema | `StringType` | ✅ | -| `id_field` | Name of the field holding the point IDs. Default: Generates a random UUId | `StringType` | ❌ | -| `batch_size` | Max size of the upload batch. Default: 100 | `IntType` | ❌ | -| `retries` | Number of upload retries. Default: 3 | `IntType` | ❌ | -| `api_key` | Qdrant API key to be sent in the header. Default: null | `StringType` | ❌ | -| `vector_name` | Name of the vector in the collection. Default: null | `StringType` | ❌ | +| Option | Description | Column DataType | Required | +| :--------------------------- | :------------------------------------------------------------------ | :---------------------------- | :------- | +| `qdrant_url` | GRPC URL of the Qdrant instance. Eg: | - | ✅ | +| `collection_name` | Name of the collection to write data into | - | ✅ | +| `schema` | JSON string of the dataframe schema | - | ✅ | +| `embedding_field` | Name of the column holding the embeddings | `ArrayType(FloatType)` | ❌ | +| `id_field` | Name of the column holding the point IDs. Default: Random UUID | `StringType` or `IntegerType` | ❌ | +| `batch_size` | Max size of the upload batch. Default: 64 | - | ❌ | +| `retries` | Number of upload retries. Default: 3 | - | ❌ | +| `api_key` | Qdrant API key for authentication | - | ❌ | +| `vector_name` | Name of the vector in the collection. | - | ❌ | +| `vector_fields` | Comma-separated names of columns holding the vectors. | `ArrayType(FloatType)` | ❌ | +| `vector_names` | Comma-separated names of vectors in the collection. | - | ❌ | +| `sparse_vector_index_fields` | Comma-separated names of columns holding the sparse vector indices. | `ArrayType(IntegerType)` | ❌ | +| `sparse_vector_value_fields` | Comma-separated names of columns holding the sparse vector values. | `ArrayType(FloatType)` | ❌ | +| `sparse_vector_names` | Comma-separated names of the sparse vectors in the collection. | - | ❌ | ## LICENSE 📜 diff --git a/src/main/java/io/qdrant/spark/Qdrant.java b/src/main/java/io/qdrant/spark/Qdrant.java index 0428016..f77e39a 100644 --- a/src/main/java/io/qdrant/spark/Qdrant.java +++ b/src/main/java/io/qdrant/spark/Qdrant.java @@ -1,7 +1,5 @@ package io.qdrant.spark; -import java.util.Arrays; -import java.util.List; import java.util.Map; import org.apache.spark.sql.connector.catalog.Table; import org.apache.spark.sql.connector.catalog.TableProvider; @@ -17,8 +15,7 @@ */ public class Qdrant implements TableProvider, DataSourceRegister { - private final String[] requiredFields = - new String[] {"schema", "collection_name", "embedding_field", "qdrant_url"}; + private final String[] requiredFields = new String[] {"schema", "collection_name", "qdrant_url"}; /** * Returns the short name of the data source. @@ -44,11 +41,9 @@ public StructType inferSchema(CaseInsensitiveStringMap options) { } } StructType schema = (StructType) StructType.fromJson(options.get("schema")); - validateOptions(options, schema); return schema; } - ; /** * Returns a table for the data source based on the provided schema, partitioning, and properties. @@ -64,31 +59,4 @@ public Table getTable( QdrantOptions options = new QdrantOptions(properties); return new QdrantCluster(options, schema); } - - /** - * Checks if the required options are present in the provided options and chekcs if the specified - * id_field and embedding_field are present in the provided schema. - * - * @param options The options to check. - * @param schema The schema to check. - */ - void validateOptions(CaseInsensitiveStringMap options, StructType schema) { - - List fieldNames = Arrays.asList(schema.fieldNames()); - - if (options.containsKey("id_field")) { - String idField = options.get("id_field").toString(); - - if (!fieldNames.contains(idField)) { - throw new IllegalArgumentException("Specified 'id_field' is not present in the schema"); - } - } - - String embeddingField = options.get("embedding_field").toString(); - - if (!fieldNames.contains(embeddingField)) { - throw new IllegalArgumentException( - "Specified 'embedding_field' is not present in the schema"); - } - } } diff --git a/src/main/java/io/qdrant/spark/QdrantDataWriter.java b/src/main/java/io/qdrant/spark/QdrantDataWriter.java index 6848547..f54bdce 100644 --- a/src/main/java/io/qdrant/spark/QdrantDataWriter.java +++ b/src/main/java/io/qdrant/spark/QdrantDataWriter.java @@ -1,25 +1,16 @@ package io.qdrant.spark; -import static io.qdrant.client.PointIdFactory.id; -import static io.qdrant.client.VectorFactory.vector; -import static io.qdrant.client.VectorsFactory.namedVectors; -import static io.qdrant.client.VectorsFactory.vectors; -import static io.qdrant.spark.QdrantValueFactory.value; - import io.qdrant.client.grpc.JsonWithInt.Value; +import io.qdrant.client.grpc.Points.PointId; import io.qdrant.client.grpc.Points.PointStruct; +import io.qdrant.client.grpc.Points.Vectors; import java.io.Serializable; import java.net.URL; import java.util.ArrayList; -import java.util.Collections; -import java.util.HashMap; import java.util.Map; -import java.util.UUID; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.connector.write.DataWriter; import org.apache.spark.sql.connector.write.WriterCommitMessage; -import org.apache.spark.sql.types.DataType; -import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -40,7 +31,7 @@ public class QdrantDataWriter implements DataWriter, Serializable { private final ArrayList points = new ArrayList<>(); - public QdrantDataWriter(QdrantOptions options, StructType schema) throws Exception { + public QdrantDataWriter(QdrantOptions options, StructType schema) { this.options = options; this.schema = schema; this.qdrantUrl = options.qdrantUrl; @@ -50,44 +41,17 @@ public QdrantDataWriter(QdrantOptions options, StructType schema) throws Excepti @Override public void write(InternalRow record) { PointStruct.Builder pointBuilder = PointStruct.newBuilder(); - Map payload = new HashMap<>(); - - if (this.options.idField == null) { - pointBuilder.setId(id(UUID.randomUUID())); - } - for (StructField field : this.schema.fields()) { - int fieldIndex = this.schema.fieldIndex(field.name()); - if (this.options.idField != null && field.name().equals(this.options.idField)) { - DataType dataType = field.dataType(); - switch (dataType.typeName()) { - case "string": - pointBuilder.setId(id(UUID.fromString(record.getString(fieldIndex)))); - break; + PointId pointId = QdrantPointIdHandler.preparePointId(record, this.schema, this.options); + pointBuilder.setId(pointId); - case "integer": - case "long": - pointBuilder.setId(id(record.getInt(fieldIndex))); - break; - - default: - throw new IllegalArgumentException("Point ID should be of type string or integer"); - } - - } else if (field.name().equals(this.options.embeddingField)) { - float[] embeddings = record.getArray(fieldIndex).toFloatArray(); - if (options.vectorName != null) { - pointBuilder.setVectors( - namedVectors(Collections.singletonMap(options.vectorName, vector(embeddings)))); - } else { - pointBuilder.setVectors(vectors(embeddings)); - } - } else { - payload.put(field.name(), value(record, field, fieldIndex)); - } - } + Vectors vectors = QdrantVectorHandler.prepareVectors(record, this.schema, this.options); + pointBuilder.setVectors(vectors); + Map payload = + QdrantPayloadHandler.preparePayload(record, this.schema, this.options); pointBuilder.putAllPayload(payload); + this.points.add(pointBuilder.build()); if (this.points.size() >= this.options.batchSize) { diff --git a/src/main/java/io/qdrant/spark/QdrantGrpc.java b/src/main/java/io/qdrant/spark/QdrantGrpc.java index 7627291..4447b1b 100644 --- a/src/main/java/io/qdrant/spark/QdrantGrpc.java +++ b/src/main/java/io/qdrant/spark/QdrantGrpc.java @@ -8,7 +8,6 @@ import java.net.URL; import java.util.List; import java.util.concurrent.ExecutionException; -import javax.annotation.Nullable; /** A class that provides methods to interact with Qdrant REST API. */ public class QdrantGrpc implements Serializable { @@ -21,20 +20,15 @@ public class QdrantGrpc implements Serializable { * @param apiKey The API key to authenticate with Qdrant. * @throws MalformedURLException If the URL is invalid. */ - public QdrantGrpc(URL url, @Nullable String apiKey) throws MalformedURLException { + public QdrantGrpc(URL url, String apiKey) throws MalformedURLException { String host = url.getHost(); int port = url.getPort() == -1 ? 6334 : url.getPort(); boolean useTls = url.getProtocol().equalsIgnoreCase("https"); - QdrantGrpcClient.Builder qdrantGrpcClientBuilder = - QdrantGrpcClient.newBuilder(host, port, useTls); - - if (apiKey != null) { - qdrantGrpcClientBuilder.withApiKey(apiKey); - } - - this.client = new QdrantClient(qdrantGrpcClientBuilder.build()); + this.client = + new QdrantClient( + QdrantGrpcClient.newBuilder(host, port, useTls).withApiKey(apiKey).build()); } /** diff --git a/src/main/java/io/qdrant/spark/QdrantOptions.java b/src/main/java/io/qdrant/spark/QdrantOptions.java index 5334c0c..e38b100 100644 --- a/src/main/java/io/qdrant/spark/QdrantOptions.java +++ b/src/main/java/io/qdrant/spark/QdrantOptions.java @@ -1,38 +1,86 @@ package io.qdrant.spark; import java.io.Serializable; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; import java.util.Map; +import java.util.Objects; -/** This class represents the options for connecting to a Qdrant instance. */ public class QdrantOptions implements Serializable { - public String qdrantUrl; - public String apiKey; - public String collectionName; - public String embeddingField; - public String idField; - public String vectorName; - public int batchSize = 100; - public int retries = 3; - - /** - * Constructor for QdrantOptions. - * - * @param options A map of options for connecting to a Qdrant instance. - */ + private static final int DEFAULT_BATCH_SIZE = 64; + private static final int DEFAULT_RETRIES = 3; + + public final String qdrantUrl; + public final String apiKey; + public final String collectionName; + public final String idField; + public final int batchSize; + public final int retries; + public final String embeddingField; + public final String vectorName; + public final String[] sparseVectorValueFields; + public final String[] sparseVectorIndexFields; + public final String[] sparseVectorNames; + public final String[] vectorFields; + public final String[] vectorNames; + public final List payloadFieldsToSkip; + public QdrantOptions(Map options) { - this.qdrantUrl = options.get("qdrant_url"); - this.collectionName = options.get("collection_name"); - this.embeddingField = options.get("embedding_field"); - this.idField = options.get("id_field"); - this.apiKey = options.get("api_key"); - this.vectorName = options.get("vector_name"); - - if (options.containsKey("batch_size")) { - this.batchSize = Integer.parseInt(options.get("batch_size")); + Objects.requireNonNull(options); + + qdrantUrl = options.get("qdrant_url"); + collectionName = options.get("collection_name"); + batchSize = + Integer.parseInt(options.getOrDefault("batch_size", String.valueOf(DEFAULT_BATCH_SIZE))); + retries = Integer.parseInt(options.getOrDefault("retries", String.valueOf(DEFAULT_RETRIES))); + idField = options.getOrDefault("id_field", ""); + apiKey = options.getOrDefault("api_key", ""); + embeddingField = options.getOrDefault("embedding_field", ""); + vectorName = options.getOrDefault("vector_name", ""); + + sparseVectorValueFields = parseArray(options.get("sparse_vector_value_fields")); + sparseVectorIndexFields = parseArray(options.get("sparse_vector_index_fields")); + sparseVectorNames = parseArray(options.get("sparse_vector_names")); + vectorFields = parseArray(options.get("vector_fields")); + vectorNames = parseArray(options.get("vector_names")); + + validateSparseVectorFields(); + validateVectorFields(); + + payloadFieldsToSkip = new ArrayList<>(); + payloadFieldsToSkip.add(idField); + payloadFieldsToSkip.add(embeddingField); + payloadFieldsToSkip.addAll(Arrays.asList(sparseVectorValueFields)); + payloadFieldsToSkip.addAll(Arrays.asList(sparseVectorIndexFields)); + payloadFieldsToSkip.addAll(Arrays.asList(sparseVectorNames)); + payloadFieldsToSkip.addAll(Arrays.asList(vectorFields)); + payloadFieldsToSkip.addAll(Arrays.asList(vectorNames)); + } + + private String[] parseArray(String input) { + if (input != null) { + String[] parts = input.split(","); + for (int i = 0; i < parts.length; i++) { + parts[i] = parts[i].trim(); + } + return parts; + } else { + return new String[0]; + } + } + + private void validateSparseVectorFields() { + if (sparseVectorValueFields.length != sparseVectorIndexFields.length + || sparseVectorValueFields.length != sparseVectorNames.length) { + throw new IllegalArgumentException( + "Sparse vector value fields, index fields, and names should have the same length"); } + } - if (options.containsKey("retries")) { - this.retries = Integer.parseInt(options.get("retries")); + private void validateVectorFields() { + if (vectorFields.length != vectorNames.length) { + throw new IllegalArgumentException("Vector fields and names should have the same length"); } } } diff --git a/src/main/java/io/qdrant/spark/QdrantPayloadHandler.java b/src/main/java/io/qdrant/spark/QdrantPayloadHandler.java new file mode 100644 index 0000000..95e3d28 --- /dev/null +++ b/src/main/java/io/qdrant/spark/QdrantPayloadHandler.java @@ -0,0 +1,28 @@ +package io.qdrant.spark; + +import static io.qdrant.spark.QdrantValueFactory.value; + +import io.qdrant.client.grpc.JsonWithInt.Value; +import java.util.HashMap; +import java.util.Map; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +public class QdrantPayloadHandler { + static Map preparePayload( + InternalRow record, StructType schema, QdrantOptions options) { + + Map payload = new HashMap<>(); + for (StructField field : schema.fields()) { + + if (options.payloadFieldsToSkip.contains(field.name())) { + continue; + } + int fieldIndex = schema.fieldIndex(field.name()); + payload.put(field.name(), value(record, field, fieldIndex)); + } + + return payload; + } +} diff --git a/src/main/java/io/qdrant/spark/QdrantPointIdHandler.java b/src/main/java/io/qdrant/spark/QdrantPointIdHandler.java new file mode 100644 index 0000000..d90d97a --- /dev/null +++ b/src/main/java/io/qdrant/spark/QdrantPointIdHandler.java @@ -0,0 +1,33 @@ +package io.qdrant.spark; + +import static io.qdrant.client.PointIdFactory.id; + +import io.qdrant.client.grpc.Points.PointId; +import java.util.UUID; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.StructType; + +public class QdrantPointIdHandler { + static PointId preparePointId(InternalRow record, StructType schema, QdrantOptions options) { + String idField = options.idField; + + if (idField.isEmpty()) { + return id(UUID.randomUUID()); + } + + int idFieldIndex = schema.fieldIndex(idField); + DataType idFieldType = schema.fields()[idFieldIndex].dataType(); + switch (idFieldType.typeName()) { + case "string": + return id(UUID.fromString(record.getString(idFieldIndex))); + + case "integer": + case "long": + return id(record.getInt(idFieldIndex)); + + default: + throw new IllegalArgumentException("Point ID should be of type string or integer"); + } + } +} diff --git a/src/main/java/io/qdrant/spark/QdrantVectorHandler.java b/src/main/java/io/qdrant/spark/QdrantVectorHandler.java new file mode 100644 index 0000000..c894043 --- /dev/null +++ b/src/main/java/io/qdrant/spark/QdrantVectorHandler.java @@ -0,0 +1,75 @@ +package io.qdrant.spark; + +import static io.qdrant.client.VectorFactory.vector; +import static io.qdrant.client.VectorsFactory.namedVectors; + +import com.google.common.primitives.Floats; +import com.google.common.primitives.Ints; +import io.qdrant.client.grpc.Points.Vector; +import io.qdrant.client.grpc.Points.Vectors; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.types.StructType; + +public class QdrantVectorHandler { + static Vectors prepareVectors(InternalRow record, StructType schema, QdrantOptions options) { + + Vectors.Builder vectorsBuilder = Vectors.newBuilder(); + Vectors sparseVectors = prepareSparseVectors(record, schema, options); + Vectors denseVectors = prepareDenseVectors(record, schema, options); + + vectorsBuilder.mergeFrom(sparseVectors).mergeFrom(denseVectors); + + if (options.embeddingField.isEmpty()) { + return vectorsBuilder.build(); + } + + int vectorFieldIndex = schema.fieldIndex(options.embeddingField); + float[] embeddings = record.getArray(vectorFieldIndex).toFloatArray(); + + // The vector name defaults to "" + return vectorsBuilder + .mergeFrom(namedVectors(Collections.singletonMap(options.vectorName, vector(embeddings)))) + .build(); + } + + private static Vectors prepareSparseVectors( + InternalRow record, StructType schema, QdrantOptions options) { + Map sparseVectors = new HashMap<>(); + + for (int i = 0; i < options.sparseVectorNames.length; i++) { + String sparseVectorName = options.sparseVectorNames[i]; + String sparseVectorValueField = options.sparseVectorValueFields[i]; + String sparseVectorIndexField = options.sparseVectorIndexFields[i]; + int sparseVectorValueFieldIndex = schema.fieldIndex(sparseVectorValueField); + int sparseVectorIndexFieldIndex = schema.fieldIndex(sparseVectorIndexField); + List sparseVectorValues = + Floats.asList(record.getArray(sparseVectorValueFieldIndex).toFloatArray()); + List sparseVectorIndices = + Ints.asList(record.getArray(sparseVectorIndexFieldIndex).toIntArray()); + + sparseVectors.put(sparseVectorName, vector(sparseVectorValues, sparseVectorIndices)); + } + + return namedVectors(sparseVectors); + } + + private static Vectors prepareDenseVectors( + InternalRow record, StructType schema, QdrantOptions options) { + Map denseVectors = new HashMap<>(); + + for (int i = 0; i < options.vectorNames.length; i++) { + String vectorName = options.vectorNames[i]; + String vectorField = options.vectorFields[i]; + int vectorFieldIndex = schema.fieldIndex(vectorField); + float[] vectorValues = record.getArray(vectorFieldIndex).toFloatArray(); + + denseVectors.put(vectorName, vector(vectorValues)); + } + + return namedVectors(denseVectors); + } +} diff --git a/src/test/java/io/qdrant/spark/TestQdrant.java b/src/test/java/io/qdrant/spark/TestQdrant.java index ce2e1c8..949dca6 100644 --- a/src/test/java/io/qdrant/spark/TestQdrant.java +++ b/src/test/java/io/qdrant/spark/TestQdrant.java @@ -57,47 +57,16 @@ public void testGetTable() { Assert.assertTrue(qdrant.getTable(schema, null, dataSourceOptions) instanceof QdrantCluster); } - @Test() - public void testCheckRequiredOptions() { - Qdrant qdrant = new Qdrant(); - StructType schema = - new StructType() - .add("id", DataTypes.StringType) - .add("embedding", DataTypes.createArrayType(DataTypes.FloatType)); - Map options = new HashMap<>(); - options.put("schema", schema.json()); - options.put("collection_name", "test_collection"); - options.put("embedding_field", "embedding"); - options.put("qdrant_url", "http://localhost:6334"); - CaseInsensitiveStringMap dataSourceOptions = new CaseInsensitiveStringMap(options); - qdrant.validateOptions(dataSourceOptions, schema); - } - - @Test(expected = IllegalArgumentException.class) - public void testCheckRequiredOptionsMissingIdField() { - Qdrant qdrant = new Qdrant(); - StructType schema = - new StructType().add("embedding", DataTypes.createArrayType(DataTypes.FloatType)); - Map options = new HashMap<>(); - options.put("schema", schema.json()); - options.put("collection_name", "test_collection"); - options.put("embedding_field", "embedding"); - options.put("qdrant_url", "http://localhost:6334"); - options.put("id_field", "id"); - CaseInsensitiveStringMap dataSourceOptions = new CaseInsensitiveStringMap(options); - qdrant.validateOptions(dataSourceOptions, schema); - } - - @Test(expected = IllegalArgumentException.class) - public void testCheckRequiredOptionsMissingEmbeddingField() { - Qdrant qdrant = new Qdrant(); - StructType schema = new StructType().add("id", DataTypes.StringType); - Map options = new HashMap<>(); - options.put("schema", schema.json()); - options.put("collection_name", "test_collection"); - options.put("qdrant_url", "http://localhost:6334"); - options.put("id_field", "id"); - CaseInsensitiveStringMap dataSourceOptions = new CaseInsensitiveStringMap(options); - qdrant.inferSchema(dataSourceOptions); - } + // @Test(expected = IllegalArgumentException.class) + // public void testCheckRequiredOptionsMissingEmbeddingField() { + // Qdrant qdrant = new Qdrant(); + // StructType schema = new StructType().add("id", DataTypes.StringType); + // Map options = new HashMap<>(); + // options.put("schema", schema.json()); + // options.put("collection_name", "test_collection"); + // options.put("qdrant_url", "http://localhost:6334"); + // options.put("id_field", "id"); + // CaseInsensitiveStringMap dataSourceOptions = new CaseInsensitiveStringMap(options); + // qdrant.inferSchema(dataSourceOptions); + // } } diff --git a/src/test/java/io/qdrant/spark/TestQdrantGrpc.java b/src/test/java/io/qdrant/spark/TestQdrantGrpc.java index abde5df..865bbad 100644 --- a/src/test/java/io/qdrant/spark/TestQdrantGrpc.java +++ b/src/test/java/io/qdrant/spark/TestQdrantGrpc.java @@ -48,7 +48,7 @@ public void setup() throws InterruptedException, ExecutionException { @Test public void testUploadBatch() throws Exception { String qdrantUrl = String.join("", "http://", qdrant.getGrpcHostAddress()); - QdrantGrpc qdrantGrpc = new QdrantGrpc(new URL(qdrantUrl), null); + QdrantGrpc qdrantGrpc = new QdrantGrpc(new URL(qdrantUrl), ""); List points = new ArrayList<>(); diff --git a/src/test/java/io/qdrant/spark/TestQdrantOptions.java b/src/test/java/io/qdrant/spark/TestQdrantOptions.java index 2aaa2c7..ecf2cb7 100644 --- a/src/test/java/io/qdrant/spark/TestQdrantOptions.java +++ b/src/test/java/io/qdrant/spark/TestQdrantOptions.java @@ -26,7 +26,7 @@ public void testQdrantOptions() { assertEquals("my-id-field", qdrantOptions.idField); // Test default values - assertEquals(100, qdrantOptions.batchSize); - assertEquals(3, qdrantOptions.retries); + assertEquals(qdrantOptions.batchSize, 64); + assertEquals(qdrantOptions.retries, 3); } } diff --git a/src/test/python/conftest.py b/src/test/python/conftest.py index fcc5728..f71918f 100644 --- a/src/test/python/conftest.py +++ b/src/test/python/conftest.py @@ -83,9 +83,14 @@ def qdrant() -> Qdrant: size=QDRANT_EMBEDDING_DIM, distance=QDRANT_DISTANCE, ), + "another_dense": models.VectorParams( + size=QDRANT_EMBEDDING_DIM, + distance=QDRANT_DISTANCE, + ), }, sparse_vectors_config={ "sparse": models.SparseVectorParams(), + "another_sparse": models.SparseVectorParams(), }, ) diff --git a/src/test/python/test_qdrant_ingest.py b/src/test/python/test_qdrant_ingest.py index a1af5c3..9175c9a 100644 --- a/src/test/python/test_qdrant_ingest.py +++ b/src/test/python/test_qdrant_ingest.py @@ -6,6 +6,7 @@ input_file_path = Path(__file__).with_name("users.json") + def test_upsert_unnamed_vectors(qdrant: Qdrant, spark_session: SparkSession): df = spark_session.read.option("multiline", "true").json(str(input_file_path)) df.write.format("io.qdrant.spark.Qdrant").option( @@ -13,7 +14,7 @@ def test_upsert_unnamed_vectors(qdrant: Qdrant, spark_session: SparkSession): qdrant.url, ).option("collection_name", qdrant.collection_name).option( "embedding_field", "dense_vector" - ).option("schema", df.schema.json()).mode("append").save() + ).mode("append").option("schema", df.schema.json()).save() qdrant.client.count(qdrant.collection_name) == df.count() @@ -25,27 +26,120 @@ def test_upsert_named_vectors(qdrant: Qdrant, spark_session: SparkSession): qdrant.url, ).option("collection_name", qdrant.collection_name).option( "embedding_field", "dense_vector" - ).option("schema", df.schema.json()).option("vector_name", "dense").mode( + ).option("vector_name", "dense").option("schema", df.schema.json()).mode( + "append" + ).save() + + qdrant.client.count(qdrant.collection_name) == df.count() + + +def test_upsert_multiple_named_dense_vectors( + qdrant: Qdrant, spark_session: SparkSession +): + df = spark_session.read.option("multiline", "true").json(str(input_file_path)) + df.write.format("io.qdrant.spark.Qdrant").option( + "qdrant_url", + qdrant.url, + ).option("collection_name", qdrant.collection_name).option( + "vector_fields", "dense_vector,dense_vector" + ).option("vector_names", "dense,another_dense").option( + "schema", df.schema.json() + ).mode("append").save() + + qdrant.client.count(qdrant.collection_name) == df.count() + + +def test_upsert_sparse_vectors(qdrant: Qdrant, spark_session: SparkSession): + df = spark_session.read.option("multiline", "true").json(str(input_file_path)) + df.write.format("io.qdrant.spark.Qdrant").option( + "qdrant_url", + qdrant.url, + ).option("collection_name", qdrant.collection_name).option( + "sparse_vector_value_fields", "sparse_values" + ).option("sparse_vector_index_fields", "sparse_indices").option( + "sparse_vector_names", "sparse" + ).option("schema", df.schema.json()).mode("append").save() + + qdrant.client.count(qdrant.collection_name) == df.count() + + +def test_upsert_multiple_sparse_vectors(qdrant: Qdrant, spark_session: SparkSession): + df = spark_session.read.option("multiline", "true").json(str(input_file_path)) + df.write.format("io.qdrant.spark.Qdrant").option( + "qdrant_url", + qdrant.url, + ).option("collection_name", qdrant.collection_name).option( + "sparse_vector_value_fields", "sparse_values,sparse_values" + ).option("sparse_vector_index_fields", "sparse_indices,sparse_indices").option( + "sparse_vector_names", "sparse,another_sparse" + ).option("schema", df.schema.json()).mode("append").save() + + qdrant.client.count(qdrant.collection_name) == df.count() + + +def test_upsert_sparse_named_dense_vectors(qdrant: Qdrant, spark_session: SparkSession): + df = spark_session.read.option("multiline", "true").json(str(input_file_path)) + df.write.format("io.qdrant.spark.Qdrant").option( + "qdrant_url", + qdrant.url, + ).option("collection_name", qdrant.collection_name).option( + "vector_fields", "dense_vector" + ).option("vector_names", "dense").option( + "sparse_vector_value_fields", "sparse_values" + ).option("sparse_vector_index_fields", "sparse_indices").option( + "sparse_vector_names", "sparse" + ).option("schema", df.schema.json()).mode("append").save() + + qdrant.client.count(qdrant.collection_name) == df.count() + + +def test_upsert_sparse_unnamed_dense_vectors( + qdrant: Qdrant, spark_session: SparkSession +): + df = spark_session.read.option("multiline", "true").json(str(input_file_path)) + df.write.format("io.qdrant.spark.Qdrant").option( + "qdrant_url", + qdrant.url, + ).option("collection_name", qdrant.collection_name).option( + "embedding_field", "dense_vector" + ).option("sparse_vector_value_fields", "sparse_values").option( + "sparse_vector_index_fields", "sparse_indices" + ).option("sparse_vector_names", "sparse").option("schema", df.schema.json()).mode( "append" ).save() qdrant.client.count(qdrant.collection_name) == df.count() -def test_missing_field(qdrant: Qdrant, spark_session: SparkSession): +def test_upsert_multiple_sparse_dense_vectors( + qdrant: Qdrant, spark_session: SparkSession +): df = spark_session.read.option("multiline", "true").json(str(input_file_path)) + df.write.format("io.qdrant.spark.Qdrant").option( + "qdrant_url", + qdrant.url, + ).option("collection_name", qdrant.collection_name).option( + "vector_fields", "dense_vector,dense_vector" + ).option("vector_names", "dense,another_dense").option( + "sparse_vector_value_fields", "sparse_values,sparse_values" + ).option("sparse_vector_index_fields", "sparse_indices,sparse_indices").option( + "sparse_vector_names", "sparse,another_sparse" + ).option("schema", df.schema.json()).mode("append").save() + + qdrant.client.count(qdrant.collection_name) == df.count() + - with pytest.raises(IllegalArgumentException) as e: - df.write.format("io.qdrant.spark.Qdrant").option( - "qdrant_url", - qdrant.url, - ).option("collection_name", qdrant.collection_name).option( - "embedding_field", "missing_field" - ).option("schema", df.schema.json()).option("vector_name", "dense").mode( - "append" - ).save() - - assert "Specified 'embedding_field' is not present in the schema" in str(e) +# Test an upsert without vectors. All the dataframe fields will be treated as payload +def test_upsert_without_vectors(qdrant: Qdrant, spark_session: SparkSession): + df = spark_session.read.option("multiline", "true").json(str(input_file_path)) + df.write.format("io.qdrant.spark.Qdrant").option( + "qdrant_url", + qdrant.url, + ).option("collection_name", qdrant.collection_name).option( + "schema", df.schema.json() + ).mode("append").save() + + qdrant.client.count(qdrant.collection_name) == df.count() def test_custom_id_field(qdrant: Qdrant, spark_session: SparkSession): diff --git a/src/test/python/users.json b/src/test/python/users.json index ce4fa74..3c733d3 100644 --- a/src/test/python/users.json +++ b/src/test/python/users.json @@ -71,6 +71,20 @@ 0.644, 0.123, 0.123 + ], + "sparse_indices": [ + 32, + 455, + 523, + 5324, + 42332 + ], + "sparse_values": [ + 0.23, + 0.3212, + 0.1233, + 0.644, + 0.123 ] }, { @@ -145,6 +159,20 @@ 0.5345644, 0.53426123, 0.412323 + ], + "sparse_indices": [ + 32, + 5632, + 23423, + 432432, + 123 + ], + "sparse_values": [ + 0.7967, + 0.53, + 0.34, + 0.644, + 0.756765 ] }, { @@ -219,6 +247,20 @@ 0.4214113, 0.12321, 0.6546534 + ], + "sparse_indices": [ + 32, + 455, + 523, + 5324, + 42332 + ], + "sparse_values": [ + 0.23, + 0.3212, + 0.1233, + 0.644, + 0.123 ] }, { @@ -293,6 +335,20 @@ 0.2341, 0.4535, 0.25435 + ], + "sparse_indices": [ + 32, + 455, + 523, + 5324, + 42332 + ], + "sparse_values": [ + 0.23, + 0.3212, + 0.1233, + 0.644, + 0.123 ] }, { @@ -367,6 +423,20 @@ 0.423, 0.4231, 0.443214 + ], + "sparse_indices": [ + 32, + 455, + 523, + 5324, + 42332 + ], + "sparse_values": [ + 0.23, + 0.3212, + 0.1233, + 0.644, + 0.123 ] }, { @@ -441,6 +511,20 @@ 0.534213, 0.543, 0.543 + ], + "sparse_indices": [ + 32, + 455, + 523, + 5324, + 42332 + ], + "sparse_values": [ + 0.23, + 0.3212, + 0.1233, + 0.644, + 0.123 ] }, { @@ -515,6 +599,20 @@ 0.5345644, 0.53426123, 0.412323 + ], + "sparse_indices": [ + 32, + 455, + 523, + 5324, + 42332 + ], + "sparse_values": [ + 0.23, + 0.3212, + 0.1233, + 0.644, + 0.123 ] }, { @@ -589,6 +687,20 @@ 0.5345644, 0.53426123, 0.412323 + ], + "sparse_indices": [ + 32, + 455, + 523, + 5324, + 42332 + ], + "sparse_values": [ + 0.23, + 0.3212, + 0.1233, + 0.644, + 0.123 ] }, { @@ -663,6 +775,20 @@ 0.5662346, 0.74537, 0.14532154 + ], + "sparse_indices": [ + 32, + 455, + 523, + 5324, + 42332 + ], + "sparse_values": [ + 0.23, + 0.3212, + 0.1233, + 0.644, + 0.123 ] }, { @@ -737,6 +863,20 @@ 0.5345644, 0.53426123, 0.412323 + ], + "sparse_indices": [ + 32, + 455, + 523, + 5324, + 42332 + ], + "sparse_values": [ + 0.23, + 0.3212, + 0.1233, + 0.644, + 0.123 ] }, { @@ -811,6 +951,20 @@ 0.5662346, 0.74537, 0.14532154 + ], + "sparse_indices": [ + 32, + 455, + 523, + 5324, + 42332 + ], + "sparse_values": [ + 0.23, + 0.3212, + 0.1233, + 0.644, + 0.123 ] }, { @@ -884,6 +1038,20 @@ 0.5662346, 0.74537, 0.14532154 + ], + "sparse_indices": [ + 32, + 455, + 523, + 5324, + 42332 + ], + "sparse_values": [ + 0.23, + 0.3212, + 0.1233, + 0.644, + 0.123 ] }, { @@ -958,6 +1126,20 @@ 0.5662346, 0.74537, 0.14532154 + ], + "sparse_indices": [ + 32, + 455, + 523, + 5324, + 42332 + ], + "sparse_values": [ + 0.23, + 0.3212, + 0.1233, + 0.644, + 0.123 ] }, { @@ -1032,6 +1214,20 @@ 0.5662346, 0.74537, 0.14532154 + ], + "sparse_indices": [ + 32, + 455, + 523, + 5324, + 42332 + ], + "sparse_values": [ + 0.23, + 0.3212, + 0.1233, + 0.644, + 0.123 ] }, { @@ -1106,6 +1302,20 @@ 0.5435, 0.7547, 0.6435643 + ], + "sparse_indices": [ + 32, + 455, + 523, + 5324, + 42332 + ], + "sparse_values": [ + 0.23, + 0.3212, + 0.1233, + 0.644, + 0.123 ] }, { @@ -1180,6 +1390,20 @@ 0.5662346, 0.74537, 0.14532154 + ], + "sparse_indices": [ + 32, + 455, + 523, + 5324, + 42332 + ], + "sparse_values": [ + 0.23, + 0.3212, + 0.1233, + 0.644, + 0.123 ] }, { @@ -1254,6 +1478,20 @@ 0.2431134, 0.423, 0.4324 + ], + "sparse_indices": [ + 32, + 455, + 523, + 5324, + 42332 + ], + "sparse_values": [ + 0.23, + 0.3212, + 0.1233, + 0.644, + 0.123 ] }, { @@ -1328,6 +1566,20 @@ 0.4231421, 0.4231412, 0.43214 + ], + "sparse_indices": [ + 32, + 455, + 523, + 5324, + 42332 + ], + "sparse_values": [ + 0.23, + 0.3212, + 0.1233, + 0.644, + 0.123 ] } ] \ No newline at end of file