chore: pre-commit formatting (#91)
* chore: formatting

* chore: formatting

* chore: remove other hooks

* Update poetry lock

---------

Co-authored-by: Nirant Kasliwal <[email protected]>
Anush008 and NirantK authored Jan 16, 2024
1 parent b01f882 commit 9b63427
Showing 19 changed files with 733 additions and 789 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/ci.yml
@@ -1,8 +1,8 @@
-name: ci
+name: ci
on:
  push:
    branches:
-      - master
+      - master
      - main
permissions:
  contents: write
@@ -14,7 +14,7 @@ jobs:
      - uses: actions/setup-python@v4
        with:
          python-version: 3.x
-      - run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV
+      - run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV
      - uses: actions/cache@v3
        with:
          key: mkdocs-material-${{ env.cache_id }}
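For context, `date --utc '+%V'` prints the ISO week number, so the mkdocs-material cache key above rolls over once per week. A rough Python equivalent of that one step, illustrative only and not part of the workflow:

```python
from datetime import datetime, timezone

# ISO week number of the current UTC date, zero-padded like `date --utc '+%V'` (e.g. "03").
cache_id = f"{datetime.now(timezone.utc).isocalendar()[1]:02d}"
print(f"mkdocs-material-{cache_id}")
```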
19 changes: 8 additions & 11 deletions .pre-commit-config.yaml
@@ -1,12 +1,9 @@
repos:
-  - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v3.2.0
-    hooks:
-      - id: trailing-whitespace
-      - id: end-of-file-fixer
-      - id: check-yaml
-      - id: check-added-large-files
-  - repo: https://github.com/psf/black
-    rev: 23.7.0
-    hooks:
-      - id: black
+  - repo: https://github.com/astral-sh/ruff-pre-commit
+    rev: v0.1.13
+    hooks:
+      - id: ruff
+        types_or: [ python, pyi, jupyter ]
+        args: [ --fix ]
+      - id: ruff-format
+        types_or: [ python, pyi, jupyter ]
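The remaining ruff hooks are what produced most of the notebook churn in this commit: `ruff --fix` drops unused imports, and `ruff-format` normalizes quotes, blank lines, and call-site layout. A minimal before/after sketch of that style, borrowing the `pd.set_option` change from Supported_Models.ipynb further down (the "before" form is kept as a comment):

```python
import pandas as pd

# Before formatting (single quotes, as in the old notebook cell):
# pd.set_option('display.max_colwidth', None)

# After ruff-format: double quotes throughout; behavior is unchanged.
pd.set_option("display.max_colwidth", None)
```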
8 changes: 4 additions & 4 deletions README.md
@@ -17,7 +17,7 @@ The default embedding supports "query" and "passage" prefixes for the input text

## 🚀 Installation

-To install the FastEmbed library, pip works:
+To install the FastEmbed library, pip works:

```bash
pip install fastembed
@@ -36,8 +36,8 @@ documents: List[str] = [
"passage: This is an example passage.",
"fastembed is supported by and maintained by Qdrant." # You can leave out the prefix but it's recommended
]
-embedding_model = Embedding(model_name="BAAI/bge-base-en", max_length=512)
-embeddings: List[np.ndarray] = list(embedding_model.embed(documents)) # Note the list() call - this is a generator
+embedding_model = Embedding(model_name="BAAI/bge-base-en", max_length=512)
+embeddings: List[np.ndarray] = list(embedding_model.embed(documents)) # Note the list() call - this is a generator
```
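Each item in `embeddings` is a plain NumPy vector (768-dimensional for bge-base-en), so downstream similarity math needs nothing beyond NumPy. A minimal sketch that assumes the snippet above has been run; the cosine helper is illustrative, not part of FastEmbed's API:

```python
import numpy as np

def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    # Cosine similarity between two 1-D embedding vectors.
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

# `embeddings` comes from the snippet above: one vector per document.
print(cosine_similarity(embeddings[0], embeddings[1]))
```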

## Usage with Qdrant
Expand All @@ -48,7 +48,7 @@ Installation with Qdrant Client in Python:
pip install qdrant-client[fastembed]
```

-Might have to use ```pip install 'qdrant-client[fastembed]'``` on zsh.
+Might have to use ```pip install 'qdrant-client[fastembed]'``` on zsh.

```python
from qdrant_client import QdrantClient
9 changes: 6 additions & 3 deletions docs/examples/FastEmbed_vs_HF_Comparison.ipynb
@@ -20,10 +20,8 @@
"outputs": [],
"source": [
"import time\n",
"from pathlib import Path\n",
"from typing import Any, Callable, List, Tuple\n",
"from typing import Callable, List, Tuple\n",
"\n",
"import numpy as np\n",
"import torch.nn.functional as F\n",
"from fastembed.embedding import DefaultEmbedding\n",
"import matplotlib.pyplot as plt\n",
@@ -116,6 +114,7 @@
" sentence_embeddings = F.normalize(sentence_embeddings)\n",
" return sentence_embeddings\n",
"\n",
"\n",
"hf = HF(model_id=\"BAAI/bge-small-en-v1.5\")\n",
"hf.embed(documents).shape"
]
@@ -165,6 +164,8 @@
],
"source": [
"import types\n",
"\n",
"\n",
"def calculate_time_stats(embed_func: Callable, documents: list, k: int) -> Tuple[float, float, float]:\n",
" times = []\n",
" for _ in range(k):\n",
@@ -181,6 +182,7 @@
" # Returning mean, max, and min time for the call\n",
" return (sum(times) / k, max(times), min(times))\n",
"\n",
"\n",
"hf_stats = calculate_time_stats(hf.embed, documents, k=2)\n",
"print(f\"Huggingface Transformers (Average, Max, Min): {hf_stats}\")\n",
"fst_stats = calculate_time_stats(embedding_model.embed, documents, k=2)\n",
@@ -289,6 +291,7 @@
" \"\"\"\n",
" return F.cosine_similarity(embeddings1, embeddings2).mean().item()\n",
"\n",
"\n",
"calculate_cosine_similarity(hf.embed(documents), Tensor(list(embedding_model.embed(documents))))"
]
},
16 changes: 12 additions & 4 deletions docs/examples/Supported_Models.ipynb
@@ -1,5 +1,15 @@
{
"cells": [
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"%load_ext autoreload\n",
+"%autoreload 2"
+]
+},
{
"cell_type": "code",
"execution_count": 1,
@@ -141,12 +151,10 @@
}
],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
"\n",
"from fastembed.embedding import Embedding\n",
"import pandas as pd\n",
"pd.set_option('display.max_colwidth', None)\n",
"\n",
"pd.set_option(\"display.max_colwidth\", None)\n",
"pd.DataFrame(Embedding.list_supported_models())"
]
}
14 changes: 2 additions & 12 deletions docs/examples/Usage_With_Qdrant.ipynb
@@ -47,8 +47,6 @@
"outputs": [],
"source": [
"from typing import List\n",
"import numpy as np\n",
"from fastembed.embedding import FlagEmbedding as Embedding\n",
"from qdrant_client import QdrantClient"
]
},
@@ -170,12 +168,7 @@
"ids = [42, 2]\n",
"\n",
"# Use the new add method\n",
"client.add(\n",
" collection_name=\"demo_collection\",\n",
" documents=docs,\n",
" metadata=metadata,\n",
" ids=ids\n",
")"
"client.add(collection_name=\"demo_collection\", documents=docs, metadata=metadata, ids=ids)"
]
},
{
@@ -199,10 +192,7 @@
}
],
"source": [
"search_result = client.query(\n",
" collection_name=\"demo_collection\",\n",
" query_text=[\"This is a query document\"]\n",
")\n",
"search_result = client.query(collection_name=\"demo_collection\", query_text=[\"This is a query document\"])\n",
"print(search_result)"
]
},
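Pulled together, the reformatted cells above amount to the flow below. A minimal sketch: the `:memory:` location and the placeholder documents/metadata are assumptions, since the notebook's own values are collapsed in this diff; the `ids` match the cell shown above.

```python
from typing import List

from qdrant_client import QdrantClient

# In-memory instance for a quick local run (assumed; a Qdrant server URL works the same way).
client = QdrantClient(":memory:")

docs: List[str] = ["This is an example passage.", "Here is another example passage."]  # placeholders
metadata = [{"source": "example"}, {"source": "example"}]  # placeholders
ids = [42, 2]

# qdrant-client embeds the documents with FastEmbed under the hood.
client.add(collection_name="demo_collection", documents=docs, metadata=metadata, ids=ids)

search_result = client.query(collection_name="demo_collection", query_text="This is a query document")
print(search_result)
```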
3 changes: 2 additions & 1 deletion docs/experimental/Binary Quantization from Scratch.ipynb
@@ -46,7 +46,6 @@
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"from tqdm import tqdm"
]
},
@@ -305,9 +304,11 @@
"sampling_rate = [1, 2, 3, 5]\n",
"results = []\n",
"\n",
"\n",
"def mean_accuracy(number_of_samples, limit, sampling_rate):\n",
" return np.mean([accuracy(i, limit=limit, oversampling=sampling_rate) for i in range(number_of_samples)])\n",
"\n",
"\n",
"for i in tqdm(sampling_rate):\n",
" for j in tqdm(limits):\n",
" result = {\"sampling_rate\": i, \"limit\": j, \"recall\": mean_accuracy(number_of_samples, j, i)}\n",
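For readers skimming the diff, the `limit` / `oversampling` knobs in `mean_accuracy` refer to the usual two-stage binary-quantization search: rank by 1-bit vectors first, then rescore an oversampled candidate set with the full-precision vectors. A self-contained toy sketch of that idea (random vectors, not the notebook's OpenAI embeddings):

```python
import numpy as np

rng = np.random.default_rng(42)
vectors = rng.normal(size=(1000, 1536))  # toy corpus
query = rng.normal(size=1536)

# 1-bit quantization: keep only the sign of each component.
binary_vectors = np.where(vectors > 0, 1, -1)
binary_query = np.where(query > 0, 1, -1)

limit, oversampling = 10, 3

# Stage 1: cheap candidate search against the binary vectors, fetching limit * oversampling candidates.
candidates = np.argsort(binary_vectors @ binary_query)[::-1][: limit * oversampling]

# Stage 2: rescore the candidates with the original float vectors and keep the top `limit`.
rescored = candidates[np.argsort(vectors[candidates] @ query)[::-1][:limit]]

# Exact full-precision baseline, i.e. what recall is measured against.
exact = np.argsort(vectors @ query)[::-1][:limit]
print(len(set(rescored) & set(exact)) / limit)
```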
50 changes: 25 additions & 25 deletions docs/experimental/Binary Quantization with Qdrant.ipynb
@@ -34,10 +34,8 @@
"outputs": [],
"source": [
"import pandas as pd\n",
"import uuid\n",
"from qdrant_client import QdrantClient\n",
"from qdrant_client.http import models\n",
"from qdrant_client.http.models import PointStruct"
"from qdrant_client.http import models"
]
},
{
@@ -71,6 +69,7 @@
],
"source": [
"import datasets\n",
"\n",
"dataset = datasets.load_dataset(\"KShivendu/dbpedia-entities-openai-1M\", split=\"train[0:100000]\")"
]
},
@@ -133,10 +132,8 @@
}
],
"source": [
"from qdrant_client import QdrantClient\n",
"\n",
"# client = QdrantClient(\n",
"# url=\"https://2aaa9439-b209-4ba6-8beb-d0b61dbd9388.us-east-1-0.aws.cloud.qdrant.io:6333\", \n",
"# url=\"https://2aaa9439-b209-4ba6-8beb-d0b61dbd9388.us-east-1-0.aws.cloud.qdrant.io:6333\",\n",
"# api_key=\"FCF8_ADVuSRrtNGeg_rBJvAMJecEDgQhzuXMZGW8F7OzvaC9wYOPeQ\",\n",
"# prefer_grpc=True\n",
"# )\n",
@@ -175,12 +172,10 @@
"bs = 10000\n",
"for i in range(0, len(dataset), bs):\n",
" client.upload_collection(\n",
" collection_name=collection_name, \n",
" ids=range(i, i+bs),\n",
" vectors=dataset[i:i+bs][\"openai\"],\n",
" payload=[\n",
" {\"text\": x} for x in dataset[i:i+bs][\"text\"]\n",
" ],\n",
" collection_name=collection_name,\n",
" ids=range(i, i + bs),\n",
" vectors=dataset[i : i + bs][\"openai\"],\n",
" payload=[{\"text\": x} for x in dataset[i : i + bs][\"text\"]],\n",
" parallel=10,\n",
" )"
]
@@ -203,10 +198,7 @@
],
"source": [
"client.update_collection(\n",
" collection_name=f\"{collection_name}\",\n",
" optimizer_config=models.OptimizersConfigDiff(\n",
" indexing_threshold=20000\n",
" )\n",
" collection_name=f\"{collection_name}\", optimizer_config=models.OptimizersConfigDiff(indexing_threshold=20000)\n",
")"
]
},
@@ -289,6 +281,7 @@
"source": [
"import random\n",
"from random import randint\n",
"\n",
"random.seed(37)\n",
"\n",
"query_indices = [randint(0, len(dataset)) for _ in range(100)]\n",
@@ -304,7 +297,10 @@
"source": [
"## Add Gaussian noise to any vector\n",
"import numpy as np\n",
"\n",
"np.random.seed(37)\n",
"\n",
"\n",
"def add_noise(vector, noise=0.05):\n",
" return vector + noise * np.random.randn(*vector.shape)"
]
@@ -959,6 +955,8 @@
],
"source": [
"import time\n",
"\n",
"\n",
"def correct(results, text):\n",
" result_texts = [x.payload[\"text\"] for x in results]\n",
" return text in result_texts\n",
@@ -977,7 +975,7 @@
" rescore=rescore,\n",
" oversampling=oversampling,\n",
" )\n",
" )\n",
" ),\n",
" )\n",
" correct_results += correct(results, text)\n",
" return correct_results\n",
@@ -996,14 +994,16 @@
" start = time.time()\n",
" correct_results = count_correct(query_dataset, limit=limit, oversampling=oversampling, rescore=rescore)\n",
" end = time.time()\n",
" results.append({\n",
" \"limit\": limit,\n",
" \"oversampling\": oversampling,\n",
" \"rescore\": rescore,\n",
" \"correct\": correct_results,\n",
" \"total queries\": len(query_dataset[\"text\"]),\n",
" \"time\": end - start,\n",
" })\n",
" results.append(\n",
" {\n",
" \"limit\": limit,\n",
" \"oversampling\": oversampling,\n",
" \"rescore\": rescore,\n",
" \"correct\": correct_results,\n",
" \"total queries\": len(query_dataset[\"text\"]),\n",
" \"time\": end - start,\n",
" }\n",
" )\n",
"\n",
"results_df = pd.DataFrame(results)\n",
"results_df"
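Most of the search helper around these lines is collapsed, but the `rescore` / `oversampling` arguments it forwards are Qdrant's binary-quantization search knobs. A minimal, self-contained sketch of how they are typically wired up; the collection name, vector size, and random vectors here are placeholders, not the notebook's actual setup:

```python
import numpy as np
from qdrant_client import QdrantClient, models

# Toy in-memory collection with binary quantization enabled (placeholder setup).
client = QdrantClient(":memory:")
client.create_collection(
    collection_name="bq-demo",
    vectors_config=models.VectorParams(size=1536, distance=models.Distance.DOT),
    quantization_config=models.BinaryQuantization(binary=models.BinaryQuantizationConfig(always_ram=True)),
)
client.upload_collection(
    collection_name="bq-demo",
    ids=range(100),
    vectors=np.random.rand(100, 1536).tolist(),
)

# The oversampling / rescore values swept by count_correct above map to these search params.
results = client.search(
    collection_name="bq-demo",
    query_vector=np.random.rand(1536).tolist(),
    limit=10,
    search_params=models.SearchParams(
        quantization=models.QuantizationSearchParams(
            ignore=False,      # query the binary index first
            rescore=True,      # then rescore candidates with the original vectors
            oversampling=2.0,  # fetch limit * oversampling candidates before rescoring
        )
    ),
)
print(len(results))
```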
10 changes: 5 additions & 5 deletions docs/index.md
@@ -15,7 +15,7 @@ The default embedding supports "query" and "passage" prefixes for the input text

## 🚀 Installation

-To install the FastEmbed library, pip works:
+To install the FastEmbed library, pip works:

```bash
pip install fastembed
@@ -32,8 +32,8 @@ documents: List[str] = [
"passage: This is an example passage.",
"fastembed is supported by and maintained by Qdrant." # You can leave out the prefix but it's recommended
]
-embedding_model = Embedding(model_name="BAAI/bge-base-en", max_length=512)
-embeddings: List[np.ndarray] = embedding_model.embed(documents) # If you use
+embedding_model = Embedding(model_name="BAAI/bge-base-en", max_length=512)
+embeddings: List[np.ndarray] = embedding_model.embed(documents) # If you use
```
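The inline comment in this version of the snippet is cut off in the diff; as the README copy of the same block notes, `embed` returns a generator, so materialize it before indexing or reusing the results. A minimal sketch, assuming the snippet above has been run:

```python
# `embed` is lazy; wrap it in list() (as the README version does) before reuse.
embeddings = list(embedding_model.embed(documents))
print(len(embeddings), embeddings[0].shape)
```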

## Usage with Qdrant
Expand All @@ -44,7 +44,7 @@ Installation with Qdrant Client in Python:
pip install qdrant-client[fastembed]
```

-Might have to use ```pip install 'qdrant-client[fastembed]'``` on zsh.
+Might have to use ```pip install 'qdrant-client[fastembed]'``` on zsh.

```python
from qdrant_client import QdrantClient
@@ -73,4 +73,4 @@ search_result = client.query(
query_text="This is a query document"
)
print(search_result)
-```
+```
4 changes: 2 additions & 2 deletions docs/overrides/main.html
@@ -5,7 +5,7 @@
<a href="{{ page.nb_url }}" title="Download Notebook" class="md-content__button md-icon jp-DownloadNB">
{% include ".icons/material/download.svg" %}
</a>
-{% endif %}
+{% endif %}

{{ super() }}

@@ -24,4 +24,4 @@
href="https://cloud.qdrant.io?utm_source=twitter&utm_medium=website&utm_campaign=fastembed">Qdrant Cloud</a> to
get started with vector search!
</div>
-{% endblock %}
+{% endblock %}