Skip to content

Commit

Permalink
Added learning_rate option for LightGBM - #2
Browse files Browse the repository at this point in the history
  • Loading branch information
ankane committed Apr 19, 2020
1 parent e2a1e22 commit e9246b9
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 3 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
## 0.3.5 (unreleased)

- Added `learning_rate` option for LightGBM

## 0.3.4 (2020-04-05)

- Added `predict_probability` for classification
Expand Down
4 changes: 2 additions & 2 deletions lib/eps/base_estimator.rb
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def _predict(data, probabilities)
singular ? predictions.first : predictions
end

def train(data, y = nil, target: nil, weight: nil, split: nil, validation_set: nil, verbose: nil, text_features: nil, early_stopping: nil)
def train(data, y = nil, target: nil, weight: nil, split: nil, validation_set: nil, text_features: nil, **options)
data, @target = prep_data(data, y, target, weight)
@target_type = Utils.column_type(data.label, @target)

Expand Down Expand Up @@ -175,7 +175,7 @@ def train(data, y = nil, target: nil, weight: nil, split: nil, validation_set: n
raise "No data in validation set" if validation_set && validation_set.empty?

@validation_set = validation_set
@evaluator = _train(verbose: verbose, early_stopping: early_stopping)
@evaluator = _train(**options)

# reset pmml
@pmml = nil
Expand Down
3 changes: 2 additions & 1 deletion lib/eps/lightgbm.rb
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def _summary(extended: false)
str
end

def _train(verbose: nil, early_stopping: nil)
def _train(verbose: nil, early_stopping: nil, learning_rate: 0.1)

This comment has been minimized.

Copy link
@HansHauge

HansHauge Apr 19, 2020

Wow, thank you! 👍

train_set = @train_set
validation_set = @validation_set.dup
summary_label = train_set.label
Expand Down Expand Up @@ -66,6 +66,7 @@ def _train(verbose: nil, early_stopping: nil)
params[:min_data_in_bin] = 1
params[:min_data_in_leaf] = 1
end
params[:learning_rate] = learning_rate

# create datasets
categorical_idx = @features.values.map.with_index.select { |type, _| type == "categorical" }.map(&:last)
Expand Down
10 changes: 10 additions & 0 deletions test/lightgbm_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,16 @@ def test_text_features_classification
assert_equal ["ham", "spam"], model.predict(test_data)
end

def test_learning_rate
data = mpg_data
model = Eps::LightGBM.new(data, target: :hwy, split: false, learning_rate: 1)
assert model.summary

expected = [30.80980036, 34.39919293, 17.99841545, 17.298401, 28.4685196, 29.31558087, 27.79557906, 18.44068633, 24.64178236, 29.31558087]
predictions = model.predict(data.first(10))
assert_elements_in_delta expected, predictions
end

private

def model
Expand Down

0 comments on commit e9246b9

Please sign in to comment.