Skip to content

Commit

Permalink
[python-package] Allow to pass early stopping min delta in params
Browse files Browse the repository at this point in the history
  • Loading branch information
borchero committed Jan 14, 2024
1 parent ef2a49c commit c8ae768
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 4 deletions.
2 changes: 2 additions & 0 deletions python-package/lightgbm/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,7 @@ def train(
callback.early_stopping(
stopping_rounds=params["early_stopping_round"], # type: ignore[arg-type]
first_metric_only=first_metric_only,
min_delta=params.get("early_stopping_min_delta", 0.0),
verbose=_choose_param_value(
main_param_name="verbosity",
params=params,
Expand Down Expand Up @@ -737,6 +738,7 @@ def cv(
callback.early_stopping(
stopping_rounds=params["early_stopping_round"], # type: ignore[arg-type]
first_metric_only=first_metric_only,
min_delta=params.get("early_stopping_min_delta", 0.0),
verbose=_choose_param_value(
main_param_name="verbosity",
params=params,
Expand Down
14 changes: 10 additions & 4 deletions tests/python_package_test/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -981,8 +981,11 @@ def train_fn():
assert bst.best_iteration == 0


@pytest.mark.parametrize('first_metric_only', [True, False])
def test_early_stopping_via_global_params(first_metric_only):
@pytest.mark.parametrize(
('first_metric_only', 'early_stopping_min_delta'),
[(True, 0.0), (True, 1e3), (False, 0.0)]
)
def test_early_stopping_via_global_params(first_metric_only, early_stopping_min_delta):
X, y = load_breast_cancer(return_X_y=True)
num_trees = 5
params = {
Expand All @@ -991,7 +994,8 @@ def test_early_stopping_via_global_params(first_metric_only):
'metric': 'None',
'verbose': -1,
'early_stopping_round': 2,
'first_metric_only': first_metric_only
'first_metric_only': first_metric_only,
'early_stopping_min_delta': early_stopping_min_delta,
}
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
lgb_train = lgb.Dataset(X_train, y_train)
Expand All @@ -1002,8 +1006,10 @@ def test_early_stopping_via_global_params(first_metric_only):
feval=[decreasing_metric, constant_metric],
valid_sets=lgb_eval,
valid_names=valid_set_name)
if first_metric_only:
if first_metric_only and early_stopping_min_delta == 0:
assert gbm.best_iteration == num_trees
elif first_metric_only:
assert gbm.best_iteration == 2
else:
assert gbm.best_iteration == 1
assert valid_set_name in gbm.best_score
Expand Down

0 comments on commit c8ae768

Please sign in to comment.