Skip to content

Commit

Permalink
add tests for quantized training with categorical features
Browse files Browse the repository at this point in the history
  • Loading branch information
shiyu1994 committed Feb 5, 2024
1 parent b07caf2 commit cf60467
Showing 1 changed file with 12 additions and 6 deletions.
18 changes: 12 additions & 6 deletions tests/python_package_test/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,8 @@ def test_missing_value_handle_none():
assert evals_result['valid_0']['auc'][-1] == pytest.approx(ret)


def test_categorical_handle():
@pytest.mark.parametrize('use_quantized_grad', [True, False])
def test_categorical_handle(use_quantized_grad):
x = [0, 1, 2, 3, 4, 5, 6, 7]
y = [0, 1, 0, 1, 0, 1, 0, 1]

Expand All @@ -332,7 +333,8 @@ def test_categorical_handle():
'cat_l2': 0,
'max_cat_to_onehot': 1,
'zero_as_missing': True,
'categorical_column': 0
'categorical_column': 0,
'use_quantized_grad': use_quantized_grad,
}
evals_result = {}
gbm = lgb.train(
Expand All @@ -349,7 +351,8 @@ def test_categorical_handle():
assert evals_result['valid_0']['auc'][-1] == pytest.approx(ret)


def test_categorical_handle_na():
@pytest.mark.parametrize('use_quantized_grad', [True, False])
def test_categorical_handle_na(use_quantized_grad):
x = [0, np.nan, 0, np.nan, 0, np.nan]
y = [0, 1, 0, 1, 0, 1]

Expand All @@ -372,7 +375,8 @@ def test_categorical_handle_na():
'cat_l2': 0,
'max_cat_to_onehot': 1,
'zero_as_missing': False,
'categorical_column': 0
'categorical_column': 0,
'use_quantized_grad': use_quantized_grad,
}
evals_result = {}
gbm = lgb.train(
Expand All @@ -389,7 +393,8 @@ def test_categorical_handle_na():
assert evals_result['valid_0']['auc'][-1] == pytest.approx(ret)


def test_categorical_non_zero_inputs():
@pytest.mark.parametrize('use_quantized_grad', [True, False])
def test_categorical_non_zero_inputs(use_quantized_grad):
x = [1, 1, 1, 1, 1, 1, 2, 2]
y = [1, 1, 1, 1, 1, 1, 0, 0]

Expand All @@ -412,7 +417,8 @@ def test_categorical_non_zero_inputs():
'cat_l2': 0,
'max_cat_to_onehot': 1,
'zero_as_missing': False,
'categorical_column': 0
'categorical_column': 0,
'use_quantized_grad': use_quantized_grad,
}
evals_result = {}
gbm = lgb.train(
Expand Down

0 comments on commit cf60467

Please sign in to comment.