feat: sparse, multiple vectors support (#16)
* feat: sparse, multiple vectors support

* refactor: Qdrant options

* chore: formatting

* chore: removed redundant validator

* test: multiple dense, multiple sparse

* docs: Updated README.md
Anush008 authored Mar 2, 2024
1 parent 7579916 commit 7ea10b8
Showing 14 changed files with 757 additions and 210 deletions.
187 changes: 152 additions & 35 deletions README.md
This will build and store the fat JAR in the `target` directory by default.

For use with Java and Scala projects, the package can be found [here](https://central.sonatype.com/artifact/io.qdrant/spark).

## Usage 📝

### Creating a Spark session (Single-node) with Qdrant support

```python
from pyspark.sql import SparkSession

spark = (
    SparkSession.builder.config(
        "spark.jars",
        "spark-2.1.0.jar",  # specify the downloaded JAR file
    )
    .master("local[*]")
    .appName("qdrant")
    .getOrCreate()
)
```

### Loading data 📊

> [!IMPORTANT]
> Before loading the data using this connector, a collection has to be [created](https://qdrant.tech/documentation/concepts/collections/#create-a-collection) in advance with the appropriate vector dimensions and configurations.

The connector supports ingesting multiple named/unnamed, dense/sparse vectors.
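
For reference, creating a collection with one named dense vector and one named sparse vector might look like the following sketch using the `qdrant-client` Python package. The URL, collection name, vector names, and dimension here are placeholder assumptions:

```python
from qdrant_client import QdrantClient, models

client = QdrantClient(url="http://localhost:6333")  # placeholder instance URL

client.create_collection(
    collection_name="my_collection",  # placeholder collection name
    vectors_config={
        # Dense vector config: size must match the embeddings in your DataFrame.
        "dense": models.VectorParams(size=384, distance=models.Distance.COSINE),
    },
    sparse_vectors_config={
        # Sparse vectors carry their own indices, so no dimension is set.
        "sparse": models.SparseVectorParams(),
    },
)
```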

<details>
<summary><b>Unnamed/Default vector</b></summary>

```python
<pyspark.sql.DataFrame>
.write
.format("io.qdrant.spark.Qdrant")
.option("qdrant_url", "<QDRANT_GRPC_URL>")
.option("collection_name", "<QDRANT_COLLECTION_NAME>")
.option("embedding_field", "<EMBEDDING_FIELD_NAME>")  # a column of type ArrayType(FloatType)
.option("schema", <pyspark.sql.DataFrame>.schema.json())
.mode("append")
.save()
```

</details>
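
For context, a minimal input DataFrame for this case might be built as follows, using the `spark` session created above; the column names and values are purely illustrative:

```python
from pyspark.sql.types import ArrayType, FloatType, StringType, StructField, StructType

# Illustrative schema: one payload column plus the embedding column.
df_schema = StructType([
    StructField("title", StringType()),
    StructField("embedding", ArrayType(FloatType())),  # passed as "embedding_field"
])

df = spark.createDataFrame(
    [
        ("first document", [0.12, 0.43, 0.87]),
        ("second document", [0.88, 0.04, 0.51]),
    ],
    df_schema,
)
```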

<details>
<summary><b>Named vector</b></summary>

```python
<pyspark.sql.DataFrame>
.write
.format("io.qdrant.spark.Qdrant")
.option("qdrant_url", "<QDRANT_GRPC_URL>")
.option("collection_name", "<QDRANT_COLLECTION_NAME>")
.option("embedding_field", "<EMBEDDING_FIELD_NAME>")  # a column of type ArrayType(FloatType)
.option("vector_name", <VECTOR_NAME>)
.option("schema", <pyspark.sql.DataFrame>.schema.json())
.mode("append")
.save()
```

> #### NOTE
>
> The `embedding_field` and `vector_name` options are maintained for backward compatibility. It is recommended to use `vector_fields` and `vector_names` for named vectors as shown below.
</details>

<details>
<summary><b>Multiple named vectors</b></summary>

```python
<pyspark.sql.DataFrame>
.write
.format("io.qdrant.spark.Qdrant")
.option("qdrant_url", "<QDRANT_GRPC_URL>")
.option("collection_name", "<QDRANT_COLLECTION_NAME>")
.option("vector_fields", "<COLUMN_NAME>,<ANOTHER_COLUMN_NAME>")
.option("vector_names", "<VECTOR_NAME>,<ANOTHER_VECTOR_NAME>")
.option("schema", <pyspark.sql.DataFrame>.schema.json())
.mode("append")
.save()
```

</details>

<details>
<summary><b>Sparse vectors</b></summary>

```python
<pyspark.sql.DataFrame>
.write
.format("io.qdrant.spark.Qdrant")
.option("qdrant_url", "<QDRANT_GRPC_URL>")
.option("collection_name", "<QDRANT_COLLECTION_NAME>")
.option("sparse_vector_value_fields", "<COLUMN_NAME>")
.option("sparse_vector_index_fields", "<COLUMN_NAME>")
.option("sparse_vector_names", "<SPARSE_VECTOR_NAME>")
.option("schema", <pyspark.sql.DataFrame>.schema.json())
.mode("append")
.save()
```

</details>
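
A sparse vector is represented by two parallel array columns: the indices of the non-zero dimensions and their values. A hypothetical input row (column names illustrative):

```python
from pyspark.sql.types import ArrayType, FloatType, IntegerType, StructField, StructType

# Illustrative schema: indices and values must be equal-length, aligned arrays.
sparse_schema = StructType([
    StructField("sparse_indices", ArrayType(IntegerType())),  # "sparse_vector_index_fields"
    StructField("sparse_values", ArrayType(FloatType())),     # "sparse_vector_value_fields"
])

sparse_df = spark.createDataFrame(
    [([3, 17, 42], [0.91, 0.33, 0.67])],  # non-zero dimensions and their weights
    sparse_schema,
)
```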

<details>
<summary><b>Multiple sparse vectors</b></summary>

```python
<pyspark.sql.DataFrame>
.write
.format("io.qdrant.spark.Qdrant")
.option("qdrant_url", "<QDRANT_GRPC_URL>")
.option("collection_name", "<QDRANT_COLLECTION_NAME>")
.option("sparse_vector_value_fields", "<COLUMN_NAME>,<ANOTHER_COLUMN_NAME>")
.option("sparse_vector_index_fields", "<COLUMN_NAME>,<ANOTHER_COLUMN_NAME>")
.option("sparse_vector_names", "<SPARSE_VECTOR_NAME>,<ANOTHER_SPARSE_VECTOR_NAME>")
.option("schema", <pyspark.sql.DataFrame>.schema.json())
.mode("append")
.save()
```

</details>

<details>
<summary><b>Combination of named dense and sparse vectors</b></summary>

```python
<pyspark.sql.DataFrame>
.write
.format("io.qdrant.spark.Qdrant")
.option("qdrant_url", "<QDRANT_GRPC_URL>")
.option("collection_name", "<QDRANT_COLLECTION_NAME>")
.option("vector_fields", "<COLUMN_NAME>,<ANOTHER_COLUMN_NAME>")
.option("vector_names", "<VECTOR_NAME>,<ANOTHER_VECTOR_NAME>")
.option("sparse_vector_value_fields", "<COLUMN_NAME>,<ANOTHER_COLUMN_NAME>")
.option("sparse_vector_index_fields", "<COLUMN_NAME>,<ANOTHER_COLUMN_NAME>")
.option("sparse_vector_names", "<SPARSE_VECTOR_NAME>,<ANOTHER_SPARSE_VECTOR_NAME>")
.option("schema", <pyspark.sql.DataFrame>.schema.json())
.mode("append")
.save()
```

</details>

<details>
<summary><b>No vectors - Entire dataframe is stored as payload</b></summary>

```python
<pyspark.sql.DataFrame>
.write
.format("io.qdrant.spark.Qdrant")
.option("qdrant_url", "<QDRANT_GRPC_URL>")
.option("collection_name", "<QDRANT_COLLECTION_NAME>")
.option("schema", <pyspark.sql.DataFrame>.schema.json())
.mode("append")
.save()
```

- By default, UUIDs are generated for each row. If you need to use custom IDs, you can do so by setting the `id_field` option.
- An API key can be set using the `api_key` option to make authenticated requests.
</details>
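
Putting the options together, a hypothetical authenticated write with custom point IDs and upload tuning might look like this; the URL, key, and column names are placeholders:

```python
(
    df.write.format("io.qdrant.spark.Qdrant")
    .option("qdrant_url", "http://localhost:6334")  # gRPC URL of the instance
    .option("collection_name", "my_collection")
    .option("embedding_field", "embedding")
    .option("id_field", "doc_id")         # assumes a string/integer ID column exists
    .option("api_key", "<YOUR_API_KEY>")  # sent with each request
    .option("batch_size", "128")          # default is 64
    .option("retries", "5")               # default is 3
    .option("schema", df.schema.json())
    .mode("append")
    .save()
)
```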

## Databricks

You can use the connector as a library in Databricks to ingest data into Qdrant.

- Go to the `Libraries` section in your cluster dashboard.
- Select `Install New` to open the library installation modal.
- Search for `io.qdrant:spark:2.1.0` in the Maven packages and click `Install`.

<img width="1064" alt="Screenshot 2024-01-05 at 17 20 01 (1)" src="https://github.com/qdrant/qdrant-spark/assets/46051506/d95773e0-c5c6-4ff2-bf50-8055bb08fd1b">

Qdrant supports all the Spark data types. The appropriate types are mapped based on the provided `schema`.
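
As an illustration, payload columns can be nested; a hypothetical DataFrame whose structs, arrays, and scalars would all be mapped to payload values:

```python
from pyspark.sql import Row

# Hypothetical rows: every non-vector column becomes a payload field.
payload_df = spark.createDataFrame([
    Row(title="first document", tags=["news", "tech"], meta=Row(views=42, pinned=True)),
])
```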

## Options and Spark types 🛠️

| Option                       | Description                                                         | Column DataType               | Required |
| :--------------------------- | :------------------------------------------------------------------ | :---------------------------- | :------- |
| `qdrant_url`                 | gRPC URL of the Qdrant instance. Eg: <http://localhost:6334>        | -                             | ✅       |
| `collection_name`            | Name of the collection to write data into                           | -                             | ✅       |
| `schema`                     | JSON string of the dataframe schema                                 | -                             | ✅       |
| `embedding_field`            | Name of the column holding the embeddings                           | `ArrayType(FloatType)`        | ❌       |
| `id_field`                   | Name of the column holding the point IDs. Default: Random UUID      | `StringType` or `IntegerType` | ❌       |
| `batch_size`                 | Max size of the upload batch. Default: 64                           | -                             | ❌       |
| `retries`                    | Number of upload retries. Default: 3                                | -                             | ❌       |
| `api_key`                    | Qdrant API key for authentication                                   | -                             | ❌       |
| `vector_name`                | Name of the vector in the collection                                | -                             | ❌       |
| `vector_fields`              | Comma-separated names of columns holding the vectors                | `ArrayType(FloatType)`        | ❌       |
| `vector_names`               | Comma-separated names of vectors in the collection                  | -                             | ❌       |
| `sparse_vector_index_fields` | Comma-separated names of columns holding the sparse vector indices  | `ArrayType(IntegerType)`      | ❌       |
| `sparse_vector_value_fields` | Comma-separated names of columns holding the sparse vector values   | `ArrayType(FloatType)`        | ❌       |
| `sparse_vector_names`        | Comma-separated names of the sparse vectors in the collection       | -                             | ❌       |

## LICENSE 📜

34 changes: 1 addition & 33 deletions src/main/java/io/qdrant/spark/Qdrant.java
```diff
@@ -1,7 +1,5 @@
 package io.qdrant.spark;
 
-import java.util.Arrays;
-import java.util.List;
 import java.util.Map;
 import org.apache.spark.sql.connector.catalog.Table;
 import org.apache.spark.sql.connector.catalog.TableProvider;
@@ -17,8 +15,7 @@
  */
 public class Qdrant implements TableProvider, DataSourceRegister {
 
-  private final String[] requiredFields =
-      new String[] {"schema", "collection_name", "embedding_field", "qdrant_url"};
+  private final String[] requiredFields = new String[] {"schema", "collection_name", "qdrant_url"};
 
   /**
    * Returns the short name of the data source.
@@ -44,11 +41,9 @@ public StructType inferSchema(CaseInsensitiveStringMap options) {
       }
     }
     StructType schema = (StructType) StructType.fromJson(options.get("schema"));
-    validateOptions(options, schema);
 
     return schema;
   }
-  ;
 
   /**
    * Returns a table for the data source based on the provided schema, partitioning, and properties.
@@ -64,31 +59,4 @@ public Table getTable(
       StructType schema, Transform[] partitioning, Map<String, String> properties) {
     QdrantOptions options = new QdrantOptions(properties);
     return new QdrantCluster(options, schema);
   }
-
-  /**
-   * Checks if the required options are present in the provided options and checks if the specified
-   * id_field and embedding_field are present in the provided schema.
-   *
-   * @param options The options to check.
-   * @param schema The schema to check.
-   */
-  void validateOptions(CaseInsensitiveStringMap options, StructType schema) {
-
-    List<String> fieldNames = Arrays.asList(schema.fieldNames());
-
-    if (options.containsKey("id_field")) {
-      String idField = options.get("id_field").toString();
-
-      if (!fieldNames.contains(idField)) {
-        throw new IllegalArgumentException("Specified 'id_field' is not present in the schema");
-      }
-    }
-
-    String embeddingField = options.get("embedding_field").toString();
-
-    if (!fieldNames.contains(embeddingField)) {
-      throw new IllegalArgumentException(
-          "Specified 'embedding_field' is not present in the schema");
-    }
-  }
 }
```
56 changes: 10 additions & 46 deletions src/main/java/io/qdrant/spark/QdrantDataWriter.java
```diff
@@ -1,25 +1,16 @@
 package io.qdrant.spark;
 
-import static io.qdrant.client.PointIdFactory.id;
-import static io.qdrant.client.VectorFactory.vector;
-import static io.qdrant.client.VectorsFactory.namedVectors;
-import static io.qdrant.client.VectorsFactory.vectors;
-import static io.qdrant.spark.QdrantValueFactory.value;
-
 import io.qdrant.client.grpc.JsonWithInt.Value;
+import io.qdrant.client.grpc.Points.PointId;
 import io.qdrant.client.grpc.Points.PointStruct;
+import io.qdrant.client.grpc.Points.Vectors;
 import java.io.Serializable;
 import java.net.URL;
 import java.util.ArrayList;
-import java.util.Collections;
-import java.util.HashMap;
 import java.util.Map;
-import java.util.UUID;
 import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.connector.write.DataWriter;
 import org.apache.spark.sql.connector.write.WriterCommitMessage;
-import org.apache.spark.sql.types.DataType;
-import org.apache.spark.sql.types.StructField;
 import org.apache.spark.sql.types.StructType;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -40,7 +31,7 @@ public class QdrantDataWriter implements DataWriter<InternalRow>, Serializable {
 
   private final ArrayList<PointStruct> points = new ArrayList<>();
 
-  public QdrantDataWriter(QdrantOptions options, StructType schema) throws Exception {
+  public QdrantDataWriter(QdrantOptions options, StructType schema) {
     this.options = options;
     this.schema = schema;
     this.qdrantUrl = options.qdrantUrl;
@@ -50,44 +41,17 @@ public QdrantDataWriter(QdrantOptions options, StructType schema) throws Exception {
   @Override
   public void write(InternalRow record) {
     PointStruct.Builder pointBuilder = PointStruct.newBuilder();
-    Map<String, Value> payload = new HashMap<>();
 
-    if (this.options.idField == null) {
-      pointBuilder.setId(id(UUID.randomUUID()));
-    }
-    for (StructField field : this.schema.fields()) {
-      int fieldIndex = this.schema.fieldIndex(field.name());
-      if (this.options.idField != null && field.name().equals(this.options.idField)) {
-
-        DataType dataType = field.dataType();
-        switch (dataType.typeName()) {
-          case "string":
-            pointBuilder.setId(id(UUID.fromString(record.getString(fieldIndex))));
-            break;
-
-          case "integer":
-          case "long":
-            pointBuilder.setId(id(record.getInt(fieldIndex)));
-            break;
-
-          default:
-            throw new IllegalArgumentException("Point ID should be of type string or integer");
-        }
-
-      } else if (field.name().equals(this.options.embeddingField)) {
-        float[] embeddings = record.getArray(fieldIndex).toFloatArray();
-        if (options.vectorName != null) {
-          pointBuilder.setVectors(
-              namedVectors(Collections.singletonMap(options.vectorName, vector(embeddings))));
-        } else {
-          pointBuilder.setVectors(vectors(embeddings));
-        }
-      } else {
-        payload.put(field.name(), value(record, field, fieldIndex));
-      }
-    }
+    PointId pointId = QdrantPointIdHandler.preparePointId(record, this.schema, this.options);
+    pointBuilder.setId(pointId);
+
+    Vectors vectors = QdrantVectorHandler.prepareVectors(record, this.schema, this.options);
+    pointBuilder.setVectors(vectors);
+
+    Map<String, Value> payload =
+        QdrantPayloadHandler.preparePayload(record, this.schema, this.options);
     pointBuilder.putAllPayload(payload);
 
     this.points.add(pointBuilder.build());
 
     if (this.points.size() >= this.options.batchSize) {
```
