Skip to content

Commit

Permalink
Fix default take value of train and predict methods
Browse files Browse the repository at this point in the history
  • Loading branch information
yonromai committed Oct 13, 2023
1 parent bbe83ba commit 3c6cc0f
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
6 changes: 3 additions & 3 deletions nxontology_ml/model/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def predict(
training_set: tuple[np.ndarray, np.ndarray],
nxo: NXOntology[str],
include_feature_values: bool = True,
take: int | None = 0,
take: int | None = None,
) -> pd.DataFrame:
target_nodes: list[str] = list(get_disease_nodes(take=take, nxo=nxo))
assert len(target_nodes) > 0, "No disease node found"
Expand Down Expand Up @@ -96,8 +96,8 @@ def train_predict(
nxo: NXOntology[str] | None = None,
training_set: tuple[np.ndarray, np.ndarray] | None = None,
include_feature_values: bool = True,
train_take: int | None = 0,
predict_take: int | None = 0,
train_take: int | None = None,
predict_take: int | None = None,
) -> pd.DataFrame:
"""
Run both model training and prediction tasks.
Expand Down
2 changes: 1 addition & 1 deletion nxontology_ml/model/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def train_model(
conf: ModelConfig = DEFAULT_MODEL_CONFIG,
nxo: NXOntology[str] | None = None,
training_set: tuple[np.ndarray, np.ndarray] | None = None,
take: int | None = 0,
take: int | None = None,
) -> tuple[Pipeline, CatBoostClassifier]:
nxo = nxo or get_efo_otar_slim()
nxo.freeze()
Expand Down

0 comments on commit 3c6cc0f

Please sign in to comment.