Commit: feature: get started with advanced worker

Showing 14 changed files with 705 additions and 30 deletions.

One `.gitignore` is updated so that `.datashare` is no longer ignored:

```diff
@@ -1,2 +1,2 @@
 *
-!.gitignore
+!.datashare
```

A new `.gitignore` is added, ignoring everything in its directory except itself:

```diff
@@ -0,0 +1,2 @@
+*
+!.gitignore
```

The bulk of the additions is the new advanced worker guide:

# Advanced Datashare worker

In this section, we'll augment the worker template app (translation and classification) with a
[vector store](https://en.wikipedia.org/wiki/Vector_database) to allow us to perform semantic similarity searches
between queries and Datashare docs.

Make sure you've followed the [basic worker example](worker-basic.md) to understand the basics!

## Clone the template repository

Start over and clone the [template repository](https://github.com/ICIJ/datashare-python) once again:

<!-- termynal -->
```console
$ git clone git@github.com:ICIJ/datashare-python.git
---> 100%
```

## Install extra dependencies

We'll use [LanceDB](https://lancedb.github.io/lancedb/) to implement our vector store, so we need to add it, as well as
[sentence-transformers](https://github.com/UKPLab/sentence-transformers), to our dependencies.

If you haven't already, install [`uv`](https://docs.astral.sh/uv/getting-started/installation/), sync the existing dependencies, and add the new ones:

<!-- termynal -->
```console
$ curl -LsSf https://astral.sh/uv/install.sh | sh
$ uv sync --frozen --group dev
$ uv add lancedb sentence-transformers
---> 100%
```

!!! note
    In a production setup, since Elasticsearch implements its [own vector database](https://www.elastic.co/elasticsearch/vector-database),
    it might be convenient to use it. For this example, we're using LanceDB as it's embedded and doesn't require
    any deployment update.

## Embedding Datashare documents

For demo purposes, we'll split the work of embedding docs into two tasks:

- `create_vectorization_tasks` scans the index, gets the IDs of Datashare docs, batches them, and creates `vectorize_docs` tasks
- the `vectorize_docs` tasks (triggered by `create_vectorization_tasks`) receive doc IDs,
  fetch the doc contents from the index and add them to the vector database

!!! note
    We could have performed vectorization in a single task; having a first task split a large task into batches/chunks
    is a commonly used pattern to distribute heavy workloads across workers (learn more in the
    [task workflow guide](../../guides/task-workflows.md)).

### The `create_vectorization_tasks` task

`create_vectorization_tasks` is defined in the `tasks/vectorize.py` file as follows:

```python title="tasks/vectorize.py"
--8<--
vectorize.py:create_vectorization_tasks
--8<--
```

The function starts by creating a schema for our vector DB table using the convenient
[LanceDB embedding function](https://lancedb.github.io/lancedb/embeddings/embedding_functions/) feature,
which automatically creates the record's `vector` field from the provided source field (`content` in our case) using
our HuggingFace embedding model:

```python title="tasks/vectorize.py" hl_lines="2 6 7"
--8<--
vectorize.py:embedding-schema
--8<--
```
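
Under the hood, such a schema can be built with LanceDB's embedding registry. Here is a minimal sketch of what `make_record_schema` might look like (field names are illustrative, assuming a sentence-transformers model):

```python
from lancedb.embeddings import get_registry
from lancedb.pydantic import LanceModel, Vector


def make_record_schema(model: str) -> type[LanceModel]:
    # Resolve the embedding function from LanceDB's registry
    func = get_registry().get("sentence-transformers").create(name=model)

    class Record(LanceModel):
        doc_id: str
        # Source field: the text the embedding function reads from
        content: str = func.SourceField()
        # Vector field: computed automatically from the source field
        vector: Vector(func.ndims()) = func.VectorField()

    return Record
```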

We then (re)create a vector table using the **DB connection provided by dependency injection** (see the next section to learn more):

```python title="tasks/vectorize.py" hl_lines="4"
--8<--
vectorize.py:create-table
--8<--
```
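
As a rough sketch (assuming the table name `ds_docs`, which the tests below also use), the (re)creation boils down to:

```python
from lancedb import AsyncConnection


async def recreate_vector_table(vector_db: AsyncConnection, schema):
    # mode="overwrite" drops and re-creates the table if it already exists
    return await vector_db.create_table("ds_docs", schema=schema, mode="overwrite")
```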

Next, `create_vectorization_tasks` queries the index, matching all documents:

```python title="tasks/vectorize.py"
--8<--
vectorize.py:query-docs
--8<--
```

and scrolls through the result pages, creating batches of `batch_size`:

```python title="tasks/vectorize.py"
--8<--
vectorize.py:retrieve-docs
--8<--
```
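
If you're wondering what the scroll-and-batch step can look like, here is a sketch written against the stock elasticsearch-py async helpers (the actual code uses the `icij_common` ES client, so names differ):

```python
from typing import AsyncIterator, List

from elasticsearch import AsyncElasticsearch
from elasticsearch.helpers import async_scan


async def iter_doc_id_batches(
    es: AsyncElasticsearch, index: str, batch_size: int
) -> AsyncIterator[List[str]]:
    batch: List[str] = []
    # _source=False: we only need document IDs at this stage
    async for hit in async_scan(
        es, index=index, query={"query": {"match_all": {}}}, _source=False
    ):
        batch.append(hit["_id"])
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:  # don't drop the last, partial batch
        yield batch
```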

Finally, for each batch, it spawns a vectorization task using the Datashare task client and returns the list of created tasks:

```python title="tasks/vectorize.py" hl_lines="5 6 7 8 10"
--8<--
vectorize.py:batch-vectorization
--8<--
```
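
Conceptually, the spawning loop looks like this (`task_client.create_task` is illustrative here, not the exact `DSTaskClient` signature):

```python
async def spawn_vectorization_tasks(batches, project: str, task_client) -> list:
    task_ids = []
    for batch in batches:
        # One vectorize_docs task per batch of doc IDs (hypothetical API)
        task_id = await task_client.create_task(
            "vectorize_docs", {"docs": batch, "project": project}
        )
        task_ids.append(task_id)
    return task_ids
```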

### The `lifespan_vector_db` dependency injection

To avoid re-creating a DB connection each time the worker processes a task, we leverage
[dependency injection](../../guides/dependency-injection.md): the connection is created at startup and
retrieved inside our function.

This pattern is already used for the Elasticsearch client and the Datashare task client. To use it for the vector DB
connection, we'll need to update the
[dependencies.py](https://github.com/ICIJ/datashare-python/blob/main/ml_worker/tasks/dependencies.py) file.

First, we need to implement the dependency setup function:

```python title="dependencies.py" hl_lines="10 11"
--8<--
vector_db_dependencies.py:setup
--8<--
```

The function creates a connection to the vector DB located on the filesystem and stores the connection in a
global variable.

We then have to implement a function making this global available to the rest of the codebase:

```python title="dependencies.py" hl_lines="4"
--8<--
vector_db_dependencies.py:provide
--8<--
```

We also need to make sure the connection is properly closed when the worker stops, by implementing the dependency
teardown. We just call the `:::python AsyncConnection.__aexit__` method:

```python title="dependencies.py" hl_lines="2"
--8<--
vector_db_dependencies.py:teardown
--8<--
```
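
Put together, the three pieces might look roughly like this (a sketch; function names are illustrative, and the real module also registers them in the worker's dependency list):

```python
from lancedb import AsyncConnection, connect_async

_VECTOR_DB: AsyncConnection | None = None


async def vector_db_setup(**_) -> None:
    # Runs once at worker startup: open the connection and store it globally
    global _VECTOR_DB
    _VECTOR_DB = await connect_async("vector.db")  # path is illustrative


def lifespan_vector_db() -> AsyncConnection:
    # Provide the connection to task functions
    if _VECTOR_DB is None:
        raise RuntimeError("vector DB dependency is not initialized")
    return _VECTOR_DB


async def vector_db_teardown(exc_type, exc_val, exc_tb) -> None:
    # Runs at worker shutdown: close the connection cleanly
    await lifespan_vector_db().__aexit__(exc_type, exc_val, exc_tb)
```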

Read the [dependency injection guide](../../guides/dependency-injection.md) to learn more!

### The `vectorize_docs` task

Next, we implement the `vectorize_docs` task as follows:

```python title="tasks/vectorize.py"
--8<--
vectorize.py:vectorize_docs
--8<--
```

The task function starts by retrieving the batch document contents, querying the index by doc IDs:

```python title="tasks/vectorize.py" hl_lines="1-4"
--8<--
vectorize.py:retrieve-doc-content
--8<--
```
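
A sketch of such a query, written against a stock elasticsearch-py style client (the actual helper names differ):

```python
from typing import Dict, List


async def fetch_doc_contents(es_client, index: str, doc_ids: List[str]) -> Dict[str, str]:
    # ids query: fetch only the listed documents, restricted to their content field
    res = await es_client.search(
        index=index,
        body={"query": {"ids": {"values": doc_ids}}, "_source": ["content"]},
    )
    return {hit["_id"]: hit["_source"]["content"] for hit in res["hits"]["hits"]}
```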

Finally, we add each doc's content to the vector DB table. Because we created the table using a schema and the
[embedding function](https://lancedb.github.io/lancedb/embeddings/embedding_functions/) feature, the embedding vector
will be automatically created from the `content` source field:

```python title="tasks/vectorize.py" hl_lines="5-7"
--8<--
vectorize.py:vectorization
--8<--
```
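
In other words, inserting plain records is enough; LanceDB computes the vectors on write. A sketch, reusing the `ds_docs` table name from the tests:

```python
from lancedb import AsyncConnection


async def add_docs(vector_db: AsyncConnection, contents: dict) -> int:
    records = [
        {"doc_id": doc_id, "content": content} for doc_id, content in contents.items()
    ]
    table = await vector_db.open_table("ds_docs")
    # The "vector" field is filled automatically from "content" by the embedding function
    await table.add(records)
    return len(records)
```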

## Semantic similarity search

Now that we've built a vector store from Datashare's docs, we need to query it. Let's create a `find_most_similar`
task, which finds the most similar docs for a provided set of queries.

The task function starts by loading the embedding model and vectorizing the input queries:

```python title="tasks/vectorize.py" hl_lines="13-14"
--8<--
vectorize.py:find_most_similar
--8<--
```
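
Loading and applying the model is standard [sentence-transformers](https://github.com/UKPLab/sentence-transformers) usage, along these lines (the model name is the one used in the tests):

```python
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("BAAI/bge-small-en-v1.5")
queries = ["doc about books", "doc speaking about animal"]
vectors = model.encode(queries)  # one embedding vector per query
```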

It then performs a [hybrid search](https://lancedb.github.io/lancedb/hybrid_search/hybrid_search/), using both the
input query vector and its text:

```python title="tasks/vectorize.py" hl_lines="4-11"
--8<--
vectorize.py:hybrid-search
--8<--
```
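
With LanceDB's async API, a hybrid query combines a vector search with a full-text search, roughly as follows (a sketch; it assumes a full-text index exists on the text column, and the real snippet may differ in options and re-ranking):

```python
async def search_similar(table, query: str, query_vector, n_similar: int):
    # Combine semantic (vector) relevance with keyword (full-text) relevance
    return await (
        table.query()
        .nearest_to(query_vector)
        .nearest_to_text(query)
        .limit(n_similar)
        .to_list()
    )
```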

## Registering the new tasks

To turn our functions into Datashare [tasks](../../learn/concepts-basic.md#tasks), we have to register them in the
`:::python app` [async app](../../learn/concepts-basic.md#app) variable of the
[app.py](https://github.com/ICIJ/datashare-python/blob/main/ml_worker/app.py) file, using the `:::python @task` decorator:

```python title="app.py" hl_lines="16 17 18 19 20 25 32 37"
--8<--
vectorize_app.py:vectorize-app
--8<--
```
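
Registration itself is a one-liner per task, in the spirit of the following sketch (the app name and decorator arguments are illustrative):

```python
from icij_worker import AsyncApp

app = AsyncApp("ml")


@app.task
async def create_vectorization_tasks(project: str, batch_size: int = 16) -> list:
    ...  # scan the index, batch doc IDs, spawn vectorize_docs tasks
```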

## Testing

Finally, we implement some tests in the `tests/tasks/test_vectorize.py` file:

```python title="tests/tasks/test_vectorize.py"
--8<--
test_vectorize.py:test-vectorize
--8<--
```

We can then run the tests, after starting the test services using the `ml-worker` Docker Compose wrapper:

<!-- termynal -->
```console
$ ./ml-worker up -d postgresql redis elasticsearch rabbitmq datashare_web
$ uv run --frozen pytest ml_worker/tests/tasks/test_vectorize.py
===== test session starts =====
collected 3 items

ml_worker/tests/tasks/test_vectorize.py ...                                 [100%]

====== 3 passed in 6.87s ======
```

## Summary

We've successfully added a vector store to Datashare!

Rather than copy-pasting the code blocks above, you can replace/update your codebase with the following files:

- [ml_worker/tasks/vectorize.py](https://github.com/ICIJ/datashare-python/blob/main/docs/src/vectorize.py)
- [ml_worker/tasks/dependencies.py](https://github.com/ICIJ/datashare-python/blob/main/docs/src/vector_db_dependencies.py)
- [ml_worker/app.py](https://github.com/ICIJ/datashare-python/blob/main/docs/src/vectorize_app.py)
- [ml_worker/tests/tasks/test_vectorize.py](https://github.com/ICIJ/datashare-python/blob/main/docs/src/test_vectorize.py)

The commit also ships the full test file, `docs/src/test_vectorize.py`:

```python
# pylint: disable=redefined-outer-name
# --8<-- [start:test-vectorize]
from pathlib import Path
from typing import List

import pytest
from icij_common.es import ESClient
from lancedb import AsyncConnection as LanceDBConnection, connect_async

from ml_worker.objects import Document
from ml_worker.tasks.vectorize import (
    create_vectorization_tasks,
    find_most_similar,
    make_record_schema,
    recreate_vector_table,
    vectorize_docs,
)
from ml_worker.tests.conftest import TEST_PROJECT
from ml_worker.utils import DSTaskClient


@pytest.fixture
async def test_vector_db(tmpdir) -> LanceDBConnection:
    db = await connect_async(Path(tmpdir) / "test_vectors.db")
    return db


@pytest.mark.integration
async def test_create_vectorization_tasks(
    populate_es: List[Document],  # pylint: disable=unused-argument
    test_es_client: ESClient,
    test_task_client: DSTaskClient,
    test_vector_db: LanceDBConnection,
):
    # When
    task_ids = await create_vectorization_tasks(
        project=TEST_PROJECT,
        es_client=test_es_client,
        task_client=test_task_client,
        vector_db=test_vector_db,
        batch_size=2,
    )
    # Then
    assert len(task_ids) == 2


@pytest.mark.integration
async def test_vectorize_docs(
    populate_es: List[Document],  # pylint: disable=unused-argument
    test_es_client: ESClient,
    test_vector_db: LanceDBConnection,
):
    # Given
    model = "BAAI/bge-small-en-v1.5"
    docs = ["doc-0", "doc-3"]
    schema = make_record_schema(model)
    await recreate_vector_table(test_vector_db, schema)

    # When
    n_vectorized = await vectorize_docs(
        docs,
        TEST_PROJECT,
        es_client=test_es_client,
        vector_db=test_vector_db,
    )
    # Then
    assert n_vectorized == 2
    table = await test_vector_db.open_table("ds_docs")
    records = await table.query().to_list()
    assert len(records) == 2
    doc_ids = sorted(d["doc_id"] for d in records)
    assert doc_ids == ["doc-0", "doc-3"]
    assert all("vector" in r for r in records)


@pytest.mark.integration
async def test_find_most_similar(test_vector_db: LanceDBConnection):
    # Given
    model = "BAAI/bge-small-en-v1.5"
    schema = make_record_schema(model)
    table = await recreate_vector_table(test_vector_db, schema)
    docs = [
        {"doc_id": "novel", "content": "I'm a doc about novels"},
        {"doc_id": "monkey", "content": "I'm speaking about monkeys"},
    ]
    await table.add(docs)
    queries = ["doc about books", "doc speaking about animal"]

    # When
    n_similar = 1
    most_similar = await find_most_similar(
        queries, model, vector_db=test_vector_db, n_similar=n_similar
    )
    # Then
    assert len(most_similar) == 2
    similar_ids = [s["doc_id"] for s in most_similar]
    assert similar_ids == ["novel", "monkey"]
    assert all("distance" in s for s in most_similar)


# --8<-- [end:test-vectorize]
```