Skip to content

Commit

Permalink
Rename model + Add SPLADE notebook
Browse files Browse the repository at this point in the history
  • Loading branch information
NirantK committed Mar 20, 2024
1 parent 256b226 commit c907b3b
Show file tree
Hide file tree
Showing 2 changed files with 393 additions and 1 deletion.
392 changes: 392 additions & 0 deletions docs/examples/SPLADE_with_FastEmbed.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,392 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Introduction to SPLADE with FastEmbed\n",
"\n",
"In this notebook, we will explore how to generate Sparse Vectors -- in particular a variant of the [SPLADE](https://arxiv.org/abs/2107.05720).\n",
"\n",
"> 💡 The original [naver/SPLADE](https://github.com/naver/splade) models were licensed CC BY-NC-SA 4.0 -- Not for Commercial Use. This [SPLADE++](https://huggingface.co/prithivida/Splade_PP_en_v1) model is Apache License and hence, licensed for commercial use. \n",
"\n",
"## Outline:\n",
"1. [What is SPALDE?](#What-is-SPLADE?)\n",
"2. [Setting up the environment](#Setting-up-the-environment)\n",
"3. Generating SPLADE vectors with FastEmbed\n",
"4. Understanding SPLADE vectors\n",
"5. Applications of SPLADE vectors\n",
"\n",
"## What is SPLADE?\n",
"\n",
"SPLADE was a novel method for _learning_ sparse vectors for text representation. This particular model beats BM25 -- the underlying approach for Elastic/Lucene family of implementations. Thus making it highly effective for tasks such as information retrieval, document classification, and more. \n",
"\n",
"The key advantage of SPLADE is its ability to generate sparse vectors, which are more efficient and interpretable than dense vectors. This makes SPLADE a powerful tool for handling large-scale text data.\n",
"\n",
"## Setting up the environment\n",
"\n",
"This notebook uses few dependencies, which are installed below: "
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"# !pip install -q fastembed"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's get started! 🚀"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from fastembed.sparse.sparse_text_embedding import SparseTextEmbedding, SparseEmbedding\n",
"from typing import List"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"> You can find the list of all supported Sparse Embedding models by calling this API: `SparseTextEmbedding.list_supported_models()`"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[{'model': 'prithvida/Splade_PP_en_v1',\n",
" 'vocab_size': 30522,\n",
" 'description': 'Independent Implementation of SPLADE++ Model for English',\n",
" 'size_in_GB': 0.532,\n",
" 'sources': {'hf': 'Qdrant/SPLADE_PP_en_v1'}}]"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"SparseTextEmbedding.list_supported_models()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "b4308dd64349499b9107d723a035ce73",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Fetching 9 files: 0%| | 0/9 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model_name = \"prithvida/Splade_PP_en_v1\"\n",
"# This triggers the model download\n",
"model = SparseTextEmbedding(model_name=model_name)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"documents: List[str] = [\n",
" \"Chandrayaan-3 is India's third lunar mission\",\n",
" \"It aimed to land a rover on the Moon's surface - joining the US, China and Russia\",\n",
" \"The mission is a follow-up to Chandrayaan-2, which had partial success\",\n",
" \"Chandrayaan-3 will be launched by the Indian Space Research Organisation (ISRO)\",\n",
" \"The estimated cost of the mission is around $35 million\",\n",
" \"It will carry instruments to study the lunar surface and atmosphere\",\n",
" \"Chandrayaan-3 landed on the Moon's surface on 23rd August 2023\",\n",
" \"It consists of a lander named Vikram and a rover named Pragyan similar to Chandrayaan-2. Its propulsion module would act like an orbiter.\",\n",
" \"The propulsion module carries the lander and rover configuration until the spacecraft is in a 100-kilometre (62 mi) lunar orbit\",\n",
" \"The mission used GSLV Mk III rocket for its launch\",\n",
" \"Chandrayaan-3 was launched from the Satish Dhawan Space Centre in Sriharikota\",\n",
" \"Chandrayaan-3 was launched earlier in the year 2023\",\n",
"]\n",
"sparse_embeddings_list: List[SparseEmbedding] = list(\n",
" model.embed(documents, batch_size=6)\n",
") # batch_size is optional, notice the generator"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"SparseEmbedding(values=array([0.05297208, 0.01963477, 0.36459631, 1.38508618, 0.71776593,\n",
" 0.12667948, 0.46230844, 0.446771 , 0.26897505, 1.01519883,\n",
" 1.5655334 , 0.29412213, 1.53102326, 0.59785569, 1.1001817 ,\n",
" 0.02079751, 0.09955651, 0.44249091, 0.09747757, 1.53519952,\n",
" 1.36765671, 0.15740395, 0.49882549, 0.38629025, 0.76612782,\n",
" 1.25805044, 0.39058095, 0.27236196, 0.45152301, 0.48262018,\n",
" 0.26085234, 1.35912788, 0.70710695, 1.71639752]), indices=array([ 1010, 1011, 1016, 1017, 2001, 2018, 2034, 2093, 2117,\n",
" 2319, 2353, 2509, 2634, 2686, 2796, 2817, 2922, 2959,\n",
" 3003, 3148, 3260, 3390, 3462, 3523, 3822, 4231, 4316,\n",
" 4774, 5590, 5871, 6416, 11926, 12076, 16469]))"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"index = 0\n",
"sparse_embeddings_list[index]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The previous output is a SparseEmbedding object for the first document in our list.\n",
"\n",
"It contains two arrays: values and indices. \n",
"- The 'values' array represents the weights of the features (tokens) in the document.\n",
"- The 'indices' array represents the indices of these features in the model's vocabulary.\n",
"\n",
"Each pair of corresponding values and indices represents a token and its weight in the document."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Token at index 1010 has weight 0.05297207832336426\n",
"Token at index 1011 has weight 0.01963476650416851\n",
"Token at index 1016 has weight 0.36459630727767944\n",
"Token at index 1017 has weight 1.385086178779602\n",
"Token at index 2001 has weight 0.7177659273147583\n"
]
}
],
"source": [
"# Let's print the first 5 features and their weights for better understanding.\n",
"for i in range(5):\n",
" print(f\"Token at index {sparse_embeddings_list[0].indices[i]} has weight {sparse_embeddings_list[0].values[i]}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This is still a little abstract, so let's use the tokenizer vocab to make sense of these indices."
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "3e13d98f38244741aa8d5473991a30aa",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"tokenizer_config.json: 0%| | 0.00/1.38k [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "67128c35e1504a99b8ec61b0f572a90c",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"vocab.txt: 0%| | 0.00/232k [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "85522823dd444b218c00e4e77c9d5d51",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"tokenizer.json: 0%| | 0.00/712k [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "4f890e772918465dba9177f0e74478f6",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"special_tokens_map.json: 0%| | 0.00/695 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{',': 0.05297207832336426, '-': 0.01963476650416851, '2': 0.36459630727767944, '3': 1.385086178779602, 'was': 0.7177659273147583, 'had': 0.12667948007583618, 'first': 0.46230843663215637, 'three': 0.4467709958553314, 'second': 0.26897504925727844, '##an': 1.015198826789856, 'third': 1.5655333995819092, '##3': 0.2941221296787262, 'india': 1.5310232639312744, 'space': 0.5978556871414185, 'indian': 1.1001816987991333, 'study': 0.02079751156270504, 'largest': 0.09955651313066483, 'fourth': 0.44249090552330017, 'leader': 0.09747757017612457, '##ya': 1.535199522972107, 'mission': 1.3676567077636719, 'launched': 0.15740394592285156, 'flight': 0.4988254904747009, 'iii': 0.3862902522087097, '3rd': 0.7661278247833252, 'moon': 1.2580504417419434, 'vehicle': 0.390580952167511, 'planet': 0.27236196398735046, 'expedition': 0.4515230059623718, 'satellite': 0.4826201796531677, 'missions': 0.2608523368835449, 'lunar': 1.3591278791427612, 'spacecraft': 0.7071069478988647, 'chandra': 1.7163975238800049}\n"
]
}
],
"source": [
"from transformers import AutoTokenizer\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(SparseTextEmbedding.list_supported_models()[0][\"sources\"][\"hf\"])"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{\n",
" \"chandra\": 1.7163975238800049,\n",
" \"third\": 1.5655333995819092,\n",
" \"##ya\": 1.535199522972107,\n",
" \"india\": 1.5310232639312744,\n",
" \"3\": 1.385086178779602,\n",
" \"mission\": 1.3676567077636719,\n",
" \"lunar\": 1.3591278791427612,\n",
" \"moon\": 1.2580504417419434,\n",
" \"indian\": 1.1001816987991333,\n",
" \"##an\": 1.015198826789856,\n",
" \"3rd\": 0.7661278247833252,\n",
" \"was\": 0.7177659273147583,\n",
" \"spacecraft\": 0.7071069478988647,\n",
" \"space\": 0.5978556871414185,\n",
" \"flight\": 0.4988254904747009,\n",
" \"satellite\": 0.4826201796531677,\n",
" \"first\": 0.46230843663215637,\n",
" \"expedition\": 0.4515230059623718,\n",
" \"three\": 0.4467709958553314,\n",
" \"fourth\": 0.44249090552330017,\n",
" \"vehicle\": 0.390580952167511,\n",
" \"iii\": 0.3862902522087097,\n",
" \"2\": 0.36459630727767944,\n",
" \"##3\": 0.2941221296787262,\n",
" \"planet\": 0.27236196398735046,\n",
" \"second\": 0.26897504925727844,\n",
" \"missions\": 0.2608523368835449,\n",
" \"launched\": 0.15740394592285156,\n",
" \"had\": 0.12667948007583618,\n",
" \"largest\": 0.09955651313066483,\n",
" \"leader\": 0.09747757017612457,\n",
" \",\": 0.05297207832336426,\n",
" \"study\": 0.02079751156270504,\n",
" \"-\": 0.01963476650416851\n",
"}\n"
]
}
],
"source": [
"import json\n",
"\n",
"\n",
"def get_tokens_and_weights(sparse_embedding, tokenizer):\n",
" token_weight_dict = {}\n",
" for i in range(len(sparse_embedding.indices)):\n",
" token = tokenizer.decode([sparse_embedding.indices[i]])\n",
" weight = sparse_embedding.values[i]\n",
" token_weight_dict[token] = weight\n",
"\n",
" # Sort the dictionary by weights\n",
" token_weight_dict = dict(sorted(token_weight_dict.items(), key=lambda item: item[1], reverse=True))\n",
" return token_weight_dict\n",
"\n",
"\n",
"# Test the function with the first SparseEmbedding\n",
"print(json.dumps(get_tokens_and_weights(sparse_embeddings_list[index], tokenizer), indent=4))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Observations\n",
"\n",
"1. The relative order of importance is quite useful. The most important tokens in the sentence have the highest weights.\n",
"1. **Term Expansion**: The model can expand the terms in the document. This means that the model can generate weights for tokens that are not present in the document but are related to the tokens in the document. This is a powerful feature that allows the model to capture the context of the document. Here, you'll see that the model has added the token '3' from 'third' and 'moon' from 'lunar' to the sparse vector.\n",
"\n",
"## Design Choices\n",
"\n",
"1. The weights are not normalized. This means that the sum of the weights is not 1 or 100. This is a common practice in sparse embeddings, as it allows the model to capture the importance of each token in the document.\n",
"1. Tokens are included in the sparse vector only if they are present in the model's vocabulary. This means that the model will not generate a weight for tokens that it has not seen during training."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "fst",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Loading

0 comments on commit c907b3b

Please sign in to comment.