diff --git a/nxontology_ml/model/predict.py b/nxontology_ml/model/predict.py
index 6fb4254..1365221 100644
--- a/nxontology_ml/model/predict.py
+++ b/nxontology_ml/model/predict.py
@@ -52,6 +52,7 @@ def predict(
     take: int | None = 0,
 ) -> pd.DataFrame:
     target_nodes: list[str] = list(get_disease_nodes(take=take, nxo=nxo))
+    assert len(target_nodes) > 0, "No disease node found"
     target_features = feature_pipeline.transform(target_nodes)
     target_labels = model.predict(target_features)
     target_probas = model.predict_proba(target_features)
diff --git a/nxontology_ml/sklearn_transformer.py b/nxontology_ml/sklearn_transformer.py
index 1095d80..1cd0308 100644
--- a/nxontology_ml/sklearn_transformer.py
+++ b/nxontology_ml/sklearn_transformer.py
@@ -8,6 +8,7 @@
 from nxontology.node import NodeInfo
 from pandas.core.dtypes.base import ExtensionDtype
 from sklearn.base import TransformerMixin
+from tqdm import tqdm
 
 
 @dataclass
@@ -73,16 +74,32 @@ def transform(self, X: NodeFeatures, copy: bool | None = None) -> NodeFeatures:
             return X
         if self._num_features_fn:
             assert self._num_features_names
+
+            vecs: list[np.array] = []
+            for node in tqdm(
+                X.nodes,
+                desc=f"{self.__class__.__name__}: Computing num features",
+                delay=5,
+            ):
+                vecs.append(self._num_features_fn(node))
+
             new_features = pd.DataFrame(
-                data=[self._num_features_fn(node) for node in X.nodes],
+                data=vecs,
                 columns=self._num_features_names,
                 dtype=self._num_feature_dtype,
             )
             X.num_features = pd.concat([X.num_features, new_features], axis=1)
         if self._cat_features_fn:
             assert self._cat_features_names
+            cat_vecs: list[np.array] = []
+            for node in tqdm(
+                X.nodes,
+                desc=f"{self.__class__.__name__}: Computing cat features",
+                delay=5,
+            ):
+                cat_vecs.append(self._cat_features_fn(node))
             new_features = pd.DataFrame(
-                data=[self._cat_features_fn(node) for node in X.nodes],
+                data=cat_vecs,
                 columns=self._cat_features_names,
             )
             X.cat_features = pd.concat([X.cat_features, new_features], axis=1)
diff --git a/nxontology_ml/text_embeddings/text_embeddings_transformer.py b/nxontology_ml/text_embeddings/text_embeddings_transformer.py
index ada34b9..f788272 100644
--- a/nxontology_ml/text_embeddings/text_embeddings_transformer.py
+++ b/nxontology_ml/text_embeddings/text_embeddings_transformer.py
@@ -3,6 +3,7 @@
 from sklearn.base import TransformerMixin
 from sklearn.decomposition import PCA
 from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
+from tqdm import tqdm
 
 from nxontology_ml.sklearn_transformer import (
     NodeFeatures,
@@ -77,7 +78,15 @@ def transform(self, X: NodeFeatures, copy: bool | None = None) -> NodeFeatures:
         return X
 
     def _nodes_to_vec(self, X: NodeFeatures) -> np.ndarray:
-        return np.array([self._embedding_model.embed_node(node) for node in X.nodes])
+        embedded_nodes: list[np.array] = []
+        for node in tqdm(
+            X.nodes,
+            desc="Fetching node embeddings",
+            delay=5,
+        ):
+            embedded_nodes.append(self._embedding_model.embed_node(node))
+
+        return np.array(embedded_nodes)
 
     @classmethod
     def from_config(