From cf604670a04d3dc2c828f18bb6822f3d12a500ad Mon Sep 17 00:00:00 2001 From: Yu Shi Date: Mon, 5 Feb 2024 13:22:44 +0000 Subject: [PATCH] add tests for quantized training with categorical features --- tests/python_package_test/test_engine.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/tests/python_package_test/test_engine.py b/tests/python_package_test/test_engine.py index e355e5ab074a..3047565555e8 100644 --- a/tests/python_package_test/test_engine.py +++ b/tests/python_package_test/test_engine.py @@ -309,7 +309,8 @@ def test_missing_value_handle_none(): assert evals_result['valid_0']['auc'][-1] == pytest.approx(ret) -def test_categorical_handle(): +@pytest.mark.parametrize('use_quantized_grad', [True, False]) +def test_categorical_handle(use_quantized_grad): x = [0, 1, 2, 3, 4, 5, 6, 7] y = [0, 1, 0, 1, 0, 1, 0, 1] @@ -332,7 +333,8 @@ def test_categorical_handle(): 'cat_l2': 0, 'max_cat_to_onehot': 1, 'zero_as_missing': True, - 'categorical_column': 0 + 'categorical_column': 0, + 'use_quantized_grad': use_quantized_grad, } evals_result = {} gbm = lgb.train( @@ -349,7 +351,8 @@ def test_categorical_handle(): assert evals_result['valid_0']['auc'][-1] == pytest.approx(ret) -def test_categorical_handle_na(): +@pytest.mark.parametrize('use_quantized_grad', [True, False]) +def test_categorical_handle_na(use_quantized_grad): x = [0, np.nan, 0, np.nan, 0, np.nan] y = [0, 1, 0, 1, 0, 1] @@ -372,7 +375,8 @@ def test_categorical_handle_na(): 'cat_l2': 0, 'max_cat_to_onehot': 1, 'zero_as_missing': False, - 'categorical_column': 0 + 'categorical_column': 0, + 'use_quantized_grad': use_quantized_grad, } evals_result = {} gbm = lgb.train( @@ -389,7 +393,8 @@ def test_categorical_handle_na(): assert evals_result['valid_0']['auc'][-1] == pytest.approx(ret) -def test_categorical_non_zero_inputs(): +@pytest.mark.parametrize('use_quantized_grad', [True, False]) +def test_categorical_non_zero_inputs(use_quantized_grad): x = [1, 1, 1, 1, 1, 1, 2, 2] y = [1, 1, 1, 1, 1, 1, 0, 0] @@ -412,7 +417,8 @@ def test_categorical_non_zero_inputs(): 'cat_l2': 0, 'max_cat_to_onehot': 1, 'zero_as_missing': False, - 'categorical_column': 0 + 'categorical_column': 0, + 'use_quantized_grad': use_quantized_grad, } evals_result = {} gbm = lgb.train(