Use multithreading to fetch text embeddings
yonromai committed Oct 12, 2023
1 parent 9eeb752 commit b6f441d
Showing 1 changed file with 17 additions and 9 deletions.
26 changes: 17 additions & 9 deletions nxontology_ml/text_embeddings/text_embeddings_transformer.py
@@ -1,3 +1,6 @@
+import os
+from concurrent.futures import ThreadPoolExecutor
+
 import numpy as np
 import pandas as pd
 from sklearn.base import TransformerMixin
@@ -29,11 +32,13 @@ def __init__(
         lda: LDA | None,
         pca: PCA | None,
         embedding_model: AutoModelEmbeddings,
+        max_workers: int | None = None,
     ):
         self._enabled = enabled
         self._lda = lda
         self._pca = pca
         self._embedding_model = embedding_model
+        self._max_workers = max_workers or os.cpu_count() or 1

     def fit(
         self,
@@ -78,15 +83,18 @@ def transform(self, X: NodeFeatures, copy: bool | None = None) -> NodeFeatures:
         return X

     def _nodes_to_vec(self, X: NodeFeatures) -> np.ndarray:
-        embedded_nodes: list[np.array] = []
-        for node in tqdm(
-            X.nodes,
-            desc="Fetching node embeddings",
-            delay=5,
-        ):
-            embedded_nodes.append(self._embedding_model.embed_node(node))
-
-        return np.array(embedded_nodes)
+        with ThreadPoolExecutor(max_workers=self._max_workers) as executor:
+            parallel_it = executor.map(self._embedding_model.embed_node, X.nodes)
+            return np.array(
+                list(
+                    tqdm(
+                        parallel_it,
+                        desc="Fetching node embeddings",
+                        total=len(X.nodes),
+                        delay=5,
+                    )
+                )
+            )

     @classmethod
     def from_config(
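Why threads help here: embed_node is presumably I/O-bound (fetching embeddings from a model or API), so a thread pool can overlap the waits even under the GIL, and executor.map preserves input order so the resulting array rows still line up with X.nodes. Below is a minimal, standalone sketch of the same pattern; the names fetch_embedding, texts, and embed_all are hypothetical stand-ins for the repository's AutoModelEmbeddings.embed_node and NodeFeatures.nodes, not part of this commit.

import os
from concurrent.futures import ThreadPoolExecutor

import numpy as np
from tqdm import tqdm


def fetch_embedding(text: str) -> np.ndarray:
    # Hypothetical placeholder for an I/O-bound embedding call (API or model server).
    return np.zeros(3)


def embed_all(texts: list[str], max_workers: int | None = None) -> np.ndarray:
    # Default to one worker per CPU, falling back to 1 when cpu_count() is None,
    # mirroring the commit's `max_workers or os.cpu_count() or 1`.
    max_workers = max_workers or os.cpu_count() or 1
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # executor.map returns results in input order; tqdm wraps the lazy
        # iterator so progress is reported as embeddings come back.
        results = executor.map(fetch_embedding, texts)
        return np.array(list(tqdm(results, total=len(texts), desc="Fetching embeddings")))


if __name__ == "__main__":
    print(embed_all(["a", "b", "c"]).shape)  # (3, 3)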
