From e9246b9ea7626138833a42ee0db8b3a735d96b8e Mon Sep 17 00:00:00 2001 From: Andrew Kane Date: Sun, 19 Apr 2020 13:15:46 -0700 Subject: [PATCH] Added learning_rate option for LightGBM - #2 --- CHANGELOG.md | 4 ++++ lib/eps/base_estimator.rb | 4 ++-- lib/eps/lightgbm.rb | 3 ++- test/lightgbm_test.rb | 10 ++++++++++ 4 files changed, 18 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3a71625..6fe8c26 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +## 0.3.5 (unreleased) + +- Added `learning_rate` option for LightGBM + ## 0.3.4 (2020-04-05) - Added `predict_probability` for classification diff --git a/lib/eps/base_estimator.rb b/lib/eps/base_estimator.rb index a57bf15..cb6e2e2 100644 --- a/lib/eps/base_estimator.rb +++ b/lib/eps/base_estimator.rb @@ -83,7 +83,7 @@ def _predict(data, probabilities) singular ? predictions.first : predictions end - def train(data, y = nil, target: nil, weight: nil, split: nil, validation_set: nil, verbose: nil, text_features: nil, early_stopping: nil) + def train(data, y = nil, target: nil, weight: nil, split: nil, validation_set: nil, text_features: nil, **options) data, @target = prep_data(data, y, target, weight) @target_type = Utils.column_type(data.label, @target) @@ -175,7 +175,7 @@ def train(data, y = nil, target: nil, weight: nil, split: nil, validation_set: n raise "No data in validation set" if validation_set && validation_set.empty? @validation_set = validation_set - @evaluator = _train(verbose: verbose, early_stopping: early_stopping) + @evaluator = _train(**options) # reset pmml @pmml = nil diff --git a/lib/eps/lightgbm.rb b/lib/eps/lightgbm.rb index 6b359ef..289eec9 100644 --- a/lib/eps/lightgbm.rb +++ b/lib/eps/lightgbm.rb @@ -17,7 +17,7 @@ def _summary(extended: false) str end - def _train(verbose: nil, early_stopping: nil) + def _train(verbose: nil, early_stopping: nil, learning_rate: 0.1) train_set = @train_set validation_set = @validation_set.dup summary_label = train_set.label @@ -66,6 +66,7 @@ def _train(verbose: nil, early_stopping: nil) params[:min_data_in_bin] = 1 params[:min_data_in_leaf] = 1 end + params[:learning_rate] = learning_rate # create datasets categorical_idx = @features.values.map.with_index.select { |type, _| type == "categorical" }.map(&:last) diff --git a/test/lightgbm_test.rb b/test/lightgbm_test.rb index d5ccb27..02aa541 100644 --- a/test/lightgbm_test.rb +++ b/test/lightgbm_test.rb @@ -218,6 +218,16 @@ def test_text_features_classification assert_equal ["ham", "spam"], model.predict(test_data) end + def test_learning_rate + data = mpg_data + model = Eps::LightGBM.new(data, target: :hwy, split: false, learning_rate: 1) + assert model.summary + + expected = [30.80980036, 34.39919293, 17.99841545, 17.298401, 28.4685196, 29.31558087, 27.79557906, 18.44068633, 24.64178236, 29.31558087] + predictions = model.predict(data.first(10)) + assert_elements_in_delta expected, predictions + end + private def model