diff --git a/nxontology_ml/text_embeddings/text_embeddings_transformer.py b/nxontology_ml/text_embeddings/text_embeddings_transformer.py
index f788272..065da94 100644
--- a/nxontology_ml/text_embeddings/text_embeddings_transformer.py
+++ b/nxontology_ml/text_embeddings/text_embeddings_transformer.py
@@ -1,3 +1,6 @@
+import os
+from concurrent.futures import ThreadPoolExecutor
+
 import numpy as np
 import pandas as pd
 from sklearn.base import TransformerMixin
@@ -29,11 +32,13 @@ def __init__(
         lda: LDA | None,
         pca: PCA | None,
         embedding_model: AutoModelEmbeddings,
+        max_workers: int | None = None,
     ):
         self._enabled = enabled
         self._lda = lda
         self._pca = pca
         self._embedding_model = embedding_model
+        self._max_workers = max_workers or os.cpu_count() or 1

     def fit(
         self,
@@ -78,15 +83,18 @@ def transform(self, X: NodeFeatures, copy: bool | None = None) -> NodeFeatures:
         return X

     def _nodes_to_vec(self, X: NodeFeatures) -> np.ndarray:
-        embedded_nodes: list[np.array] = []
-        for node in tqdm(
-            X.nodes,
-            desc="Fetching node embeddings",
-            delay=5,
-        ):
-            embedded_nodes.append(self._embedding_model.embed_node(node))
-
-        return np.array(embedded_nodes)
+        with ThreadPoolExecutor(max_workers=self._max_workers) as executor:
+            parallel_it = executor.map(self._embedding_model.embed_node, X.nodes)
+            return np.array(
+                list(
+                    tqdm(
+                        parallel_it,
+                        desc="Fetching node embeddings",
+                        total=len(X.nodes),
+                        delay=5,
+                    )
+                )
+            )

     @classmethod
     def from_config(
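
For context, here is a minimal, self-contained sketch of the pattern the diff introduces in `_nodes_to_vec`: fanning out I/O-bound embedding lookups over a `ThreadPoolExecutor` while keeping a `tqdm` progress bar and preserving input order (which `executor.map` guarantees). The `embed_node` function, node list, and embedding dimension below are placeholders for illustration only, not part of the repository.

```python
import os
from concurrent.futures import ThreadPoolExecutor

import numpy as np
from tqdm import tqdm


def embed_node(node: str) -> np.ndarray:
    # Placeholder for a network- or cache-backed embedding lookup,
    # standing in for AutoModelEmbeddings.embed_node.
    return np.full(4, float(len(node)))


nodes = ["disease", "phenotype", "gene"]
max_workers = os.cpu_count() or 1

with ThreadPoolExecutor(max_workers=max_workers) as executor:
    # executor.map returns a lazy iterator in input order; tqdm needs an
    # explicit total because the iterator has no len().
    embeddings = np.array(
        list(
            tqdm(
                executor.map(embed_node, nodes),
                desc="Fetching node embeddings",
                total=len(nodes),
            )
        )
    )

print(embeddings.shape)  # (3, 4)
```

Threads (rather than processes) fit here because each lookup is presumably dominated by I/O or cache access, so the GIL is released while waiting and `os.cpu_count()` workers is a reasonable default cap.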