ingest.py
import os
from shutil import rmtree

import click
import chromadb
from sentence_transformers import SentenceTransformer
from langchain.text_splitter import SentenceTransformersTokenTextSplitter
from datasets import load_dataset

from const import (
    DATABASE_PATH,
    VECTOR_SEARCH_COLLECTION_NAME,
    VECTOR_SEARCH_SENTENCE_TRANSFORMER_MODEL,
    VECTOR_SEARCH_DISTANCE_FUNCTION,
    VECTOR_SEARCH_NORMALIZE_EMBEDDINGS,
    VECTOR_SEARCH_CHUNK_PREFIX,
)


@click.command()
@click.option(
    "--sentence-transformer-device",
    default="cpu",
    help="Device on which the Sentence Transformers model should run.",
)
def ingest(sentence_transformer_device: str) -> bool:
    """
    Generates embeddings for the COLD French Law Dataset and stores them in a ChromaDB collection.

    May require Hugging Face authentication via `huggingface-cli login`.
    """
    #
    # Clear up existing vector store, if any
    #
    rmtree(DATABASE_PATH, ignore_errors=True)
    os.makedirs(DATABASE_PATH, exist_ok=True)

    #
    # Initialize vector store, embedding model and dataset access
    #
    chroma_client = chromadb.PersistentClient(
        path=DATABASE_PATH,
        settings=chromadb.Settings(anonymized_telemetry=False),
    )

    chroma_collection = chroma_client.create_collection(
        name=VECTOR_SEARCH_COLLECTION_NAME,
        metadata={"hnsw:space": VECTOR_SEARCH_DISTANCE_FUNCTION},
    )

    embedding_model = SentenceTransformer(
        VECTOR_SEARCH_SENTENCE_TRANSFORMER_MODEL,
        device=sentence_transformer_device,
    )

    text_splitter = SentenceTransformersTokenTextSplitter(
        model_name=VECTOR_SEARCH_SENTENCE_TRANSFORMER_MODEL,
        tokens_per_chunk=embedding_model[0].max_seq_length - 4,
        chunk_overlap=25,
    )  # Note: The text splitter adjusts its cut-off based on the model's max_seq_length
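    # (Assumption: the "- 4" leaves headroom for the special tokens the tokenizer adds
    #  around each chunk; the exact margin required depends on the model.)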

    dataset = load_dataset(
        "harvard-lil/cold-french-law",
        data_files="cold-french-law.csv",
        split="train",
    )

    dataset_total = len(dataset)
    embeddings_total = 0
    dataset_i = 0

    click.echo(f"{dataset_total} entries to ingest.")

    # For each entry in the dataset:
    # - Split text into chunks of X tokens
    # - Generate embeddings and associated metadata
    # - Add to vector store
    for entry in dataset:
        dataset_i += 1

        text = entry_to_text(entry)  # Generate text for current entry
        text_chunks = text_splitter.split_text(text)  # Split text into chunks

        # Add VECTOR_SEARCH_CHUNK_PREFIX to every chunk
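        # (Some embedding models, e.g. E5-style models, are trained with "passage: " /
        #  "query: " prefixes; whether that is the role of VECTOR_SEARCH_CHUNK_PREFIX
        #  here depends on its value in const.py.)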
        for i in range(0, len(text_chunks)):
            text_chunks[i] = VECTOR_SEARCH_CHUNK_PREFIX + text_chunks[i]

        # Status update
        status = f"{dataset_i}/{dataset_total} | "
        status += f"{entry['article_identifier']} was split into {len(text_chunks)} chunks."
        click.echo(status)

        # Generate embeddings and metadata for each chunk
        embeddings = embedding_model.encode(
            text_chunks,
            normalize_embeddings=VECTOR_SEARCH_NORMALIZE_EMBEDDINGS,
        )
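        # (Normalized embeddings are typically paired with cosine or inner-product
        #  distance; whether that applies here is governed by the settings in const.py.)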

        documents = []
        metadatas = []
        ids = []

        for i in range(0, len(text_chunks)):
            documents.append(entry["article_identifier"])
            ids.append(f"{entry['article_identifier']}-{i+1}")

            metadata = {
                "article_identifier": entry["article_identifier"],
                "texte_nature": entry["texte_nature"] if entry["texte_nature"] else "",
                "texte_titre": entry["texte_titre"] if entry["texte_titre"] else "",
                "texte_ministere": entry["texte_ministere"] if entry["texte_ministere"] else "",
                "text_chunk": text_chunks[i][len(VECTOR_SEARCH_CHUNK_PREFIX) :],  # noqa
            }

            metadatas.append(metadata)

        embeddings = embeddings.tolist()
        embeddings_total += len(embeddings)

        # Store embeddings and metadata
        chroma_collection.add(
            documents=documents,
            embeddings=embeddings,
            metadatas=metadatas,
            ids=ids,
        )

    click.echo(f"Done - {embeddings_total} embeddings for {dataset_total} documents.")

    return True


def entry_to_text(entry: dict) -> str:
    """
    Generates an "embeddable" text version of a `harvard-lil/cold-french-law` dataset record.
    """
output = ""
# Pick a "title" based on texte_nature
if entry["texte_nature"] == "CODE":
output = f"Article {entry['article_num']} du {entry['texte_titre_court']}. "
else:
output = f"{entry['texte_titre']}. "
# Remove line-breaks to increase N tokens per embedding
text = entry["article_contenu_text"] if entry["article_contenu_text"] else ""
output += text.replace("\n", " ")
return output


if __name__ == "__main__":
    ingest()
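
# Example invocation (the device value is illustrative; any device supported by
# sentence-transformers, such as "cpu", "cuda" or "mps", should work):
#
#   python ingest.py --sentence-transformer-device cuda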