From fffeaf89b668a4ec8df0786dcf75da3295f65dda Mon Sep 17 00:00:00 2001 From: Romain Yon Date: Fri, 13 Oct 2023 14:04:06 -0400 Subject: [PATCH] Fix default take value of train and predict methods --- nxontology_ml/model/predict.py | 6 +++--- nxontology_ml/model/train.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/nxontology_ml/model/predict.py b/nxontology_ml/model/predict.py index 1365221..0840db1 100644 --- a/nxontology_ml/model/predict.py +++ b/nxontology_ml/model/predict.py @@ -49,7 +49,7 @@ def predict( training_set: tuple[np.ndarray, np.ndarray], nxo: NXOntology[str], include_feature_values: bool = True, - take: int | None = 0, + take: int | None = None, ) -> pd.DataFrame: target_nodes: list[str] = list(get_disease_nodes(take=take, nxo=nxo)) assert len(target_nodes) > 0, "No disease node found" @@ -96,8 +96,8 @@ def train_predict( nxo: NXOntology[str] | None = None, training_set: tuple[np.ndarray, np.ndarray] | None = None, include_feature_values: bool = True, - train_take: int | None = 0, - predict_take: int | None = 0, + train_take: int | None = None, + predict_take: int | None = None, ) -> pd.DataFrame: """ Run both model training and prediction tasks. diff --git a/nxontology_ml/model/train.py b/nxontology_ml/model/train.py index fa573c3..c45b851 100644 --- a/nxontology_ml/model/train.py +++ b/nxontology_ml/model/train.py @@ -24,7 +24,7 @@ def train_model( conf: ModelConfig = DEFAULT_MODEL_CONFIG, nxo: NXOntology[str] | None = None, training_set: tuple[np.ndarray, np.ndarray] | None = None, - take: int | None = 0, + take: int | None = None, ) -> tuple[Pipeline, CatBoostClassifier]: nxo = nxo or get_efo_otar_slim() nxo.freeze()