use model_path_txt fixture for lgb.model
nicklamiller committed Jul 4, 2024
1 parent c180b76 commit 262bfa8
Showing 1 changed file with 9 additions and 9 deletions.
tests/python_package_test/test_engine.py: 9 additions & 9 deletions
@@ -70,6 +70,11 @@ def categorize(continuous_x):
     return np.digitize(continuous_x, bins=np.arange(0, 1, 0.01))
 
 
+@pytest.fixture
+def model_path_txt(tmp_path):
+    return str(tmp_path / "lgb.model")
+
+
 def test_binary():
     X, y = load_breast_cancer(return_X_y=True)
     X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
@@ -1258,14 +1263,13 @@ def test_cv():
     np.testing.assert_allclose(cv_res_lambda["valid ndcg@3-mean"], cv_res_lambda_obj["valid ndcg@3-mean"])
 
 
-def test_cv_works_with_init_model(tmp_path):
+def test_cv_works_with_init_model(model_path_txt):
     X, y = make_synthetic_regression()
     params = {"objective": "regression", "verbose": -1}
     num_train_rounds = 2
     lgb_train = lgb.Dataset(X, y, free_raw_data=False)
     bst = lgb.train(params=params, train_set=lgb_train, num_boost_round=num_train_rounds)
     preds_raw = bst.predict(X, raw_score=True)
-    model_path_txt = str(tmp_path / "lgb.model")
     bst.save_model(model_path_txt)
 
     num_cv_rounds = 5
@@ -1349,7 +1353,7 @@ def test_cvbooster():
     assert ret < 0.15
 
 
-def test_cvbooster_save_load(tmp_path):
+def test_cvbooster_save_load(model_path_txt):
     X, y = load_breast_cancer(return_X_y=True)
     X_train, X_test, y_train, _ = train_test_split(X, y, test_size=0.1, random_state=42)
     params = {
@@ -1372,8 +1376,6 @@ def test_cvbooster_save_load(tmp_path):
     preds = cvbooster.predict(X_test)
     best_iteration = cvbooster.best_iteration
 
-    model_path_txt = str(tmp_path / "lgb.model")
-
     cvbooster.save_model(model_path_txt)
     model_string = cvbooster.model_to_string()
     del cvbooster
@@ -1432,7 +1434,7 @@ def test_feature_name():
     assert feature_names == gbm.feature_name()
 
 
-def test_feature_name_with_non_ascii(rng, tmp_path):
+def test_feature_name_with_non_ascii(rng, model_path_txt):
     X_train = rng.normal(size=(100, 4))
     y_train = rng.normal(size=(100,))
     # This has non-ascii strings.
@@ -1442,7 +1444,6 @@ def test_feature_name_with_non_ascii(rng, tmp_path):
 
     gbm = lgb.train(params, lgb_train, num_boost_round=5)
     assert feature_names == gbm.feature_name()
-    model_path_txt = str(tmp_path / "lgb.model")
     gbm.save_model(model_path_txt)
 
     gbm2 = lgb.Booster(model_file=model_path_txt)
@@ -1498,7 +1499,7 @@ def test_parameters_are_loaded_from_model_file(tmp_path, capsys, rng):
     np.testing.assert_allclose(preds, orig_preds)
 
 
-def test_save_load_copy_pickle(tmp_path):
+def test_save_load_copy_pickle(model_path_txt, tmp_path):
     def train_and_predict(init_model=None, return_model=False):
         X, y = make_synthetic_regression()
         X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
@@ -1510,7 +1511,6 @@ def train_and_predict(init_model=None, return_model=False):
     gbm = train_and_predict(return_model=True)
     ret_origin = train_and_predict(init_model=gbm)
     other_ret = []
-    model_path_txt = str(tmp_path / "lgb.model")
     gbm.save_model(model_path_txt)
     with open(model_path_txt) as f:  # check all params are logged into model file correctly
         assert f.read().find("[num_iterations: 10]") != -1
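For reference, the new fixture builds on pytest's built-in tmp_path fixture, which provides a fresh temporary directory for each test; any test that lists model_path_txt as a parameter receives a per-test path string ending in lgb.model. Below is a minimal, self-contained sketch of the same pattern; the test_save_and_reload test and its synthetic data are hypothetical illustrations, not part of this commit.

import lightgbm as lgb
import numpy as np
import pytest


@pytest.fixture
def model_path_txt(tmp_path):
    # tmp_path is pytest's built-in per-test temporary directory (a pathlib.Path);
    # returning str keeps the path usable by APIs that expect plain strings.
    return str(tmp_path / "lgb.model")


def test_save_and_reload(model_path_txt):  # hypothetical test, for illustration only
    rng = np.random.default_rng(42)
    X = rng.normal(size=(100, 4))
    y = X.sum(axis=1)
    # Train a tiny model, save it to the fixture-provided path, reload it,
    # and check that the reloaded model reproduces the original predictions.
    bst = lgb.train({"objective": "regression", "verbose": -1}, lgb.Dataset(X, y), num_boost_round=5)
    bst.save_model(model_path_txt)
    reloaded = lgb.Booster(model_file=model_path_txt)
    np.testing.assert_allclose(bst.predict(X), reloaded.predict(X))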
