Improves lightgbm conversion speed (#491)
* improves lightgbm conversion speed
xadupre authored Aug 20, 2021
1 parent 3d81a0a commit cb2782b
Showing 5 changed files with 310 additions and 57 deletions.
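
For context, a minimal usage sketch (not part of the commit) of the conversion path this change speeds up: converting a raw lightgbm.Booster with onnxmltools, which wraps the model in the WrappedBooster class modified below. The training data and parameters here are illustrative only.

# Hedged sketch: typical conversion of a raw lightgbm.Booster to ONNX.
# onnxmltools wraps the Booster in WrappedBooster (see _parse.py below),
# so how WrappedBooster reads the model's metadata drives conversion speed.
import numpy
import lightgbm
from onnxmltools import convert_lightgbm
from onnxmltools.convert.common.data_types import FloatTensorType

X = numpy.random.rand(200, 4)
y = (X[:, 0] > 0.5).astype(int)
booster = lightgbm.train({'objective': 'binary'},
                         lightgbm.Dataset(X, y), num_boost_round=100)

onnx_model = convert_lightgbm(
    booster, initial_types=[('input', FloatTensorType([None, 4]))])
with open("lgbm_booster.onnx", "wb") as f:
    f.write(onnx_model.SerializeToString())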
1 change: 1 addition & 0 deletions .azure-pipelines/linux-conda-CI.yml
@@ -97,6 +97,7 @@ jobs:
       displayName: 'Install dependencies'
     - script: |
         pip install flake8
+        python -m flake8 ./onnxmltools
       displayName: 'run flake8 check'
2 changes: 1 addition & 1 deletion onnxmltools/__init__.py
@@ -5,7 +5,7 @@
 This framework converts any machine learned model into onnx format
 which is a common language to describe any machine learned model.
 """
-__version__ = "1.8.0"
+__version__ = "1.9.0"
 __author__ = "Microsoft"
 __producer__ = "OnnxMLTools"
 __producer_version__ = __version__
45 changes: 33 additions & 12 deletions onnxmltools/convert/lightgbm/_parse.py
@@ -21,28 +21,49 @@ class WrappedBooster:
 
     def __init__(self, booster):
         self.booster_ = booster
-        _model_dict = self.booster_.dump_model()
-        self.classes_ = self._generate_classes(_model_dict)
-        self.n_features_ = len(_model_dict['feature_names'])
-        if (_model_dict['objective'].startswith('binary') or
-                _model_dict['objective'].startswith('multiclass')):
+        self.n_features_ = self.booster_.feature_name()
+        self.objective_ = self.get_objective()
+        if self.objective_.startswith('binary'):
             self.operator_name = 'LgbmClassifier'
-        elif _model_dict['objective'].startswith(('regression', 'poisson', 'gamma')):
+            self.classes_ = self._generate_classes(booster)
+        elif self.objective_.startswith('multiclass'):
+            self.operator_name = 'LgbmClassifier'
+            self.classes_ = self._generate_classes(booster)
+        elif self.objective_.startswith('regression'):
             self.operator_name = 'LgbmRegressor'
         else:
-            # Other objectives are not supported.
-            raise ValueError("Unsupported LightGbm objective: '{}'.".format(_model_dict['objective']))
-        if _model_dict.get('average_output', False):
+            raise NotImplementedError(
+                'Unsupported LightGbm objective: %r.' % self.objective_)
+        average_output = self.booster_.attr('average_output')
+        if average_output:
             self.boosting_type = 'rf'
         else:
             # Other than random forest, other boosting types do not affect later conversion.
             # Here `gbdt` is chosen for no reason.
             self.boosting_type = 'gbdt'
 
-    def _generate_classes(self, model_dict):
-        if model_dict['num_class'] == 1:
+    @staticmethod
+    def _generate_classes(booster):
+        if isinstance(booster, dict):
+            num_class = booster['num_class']
+        else:
+            num_class = booster.attr('num_class')
+        if num_class is None:
+            dp = booster.dump_model(num_iteration=1)
+            num_class = dp['num_class']
+        if num_class == 1:
             return numpy.asarray([0, 1])
-        return numpy.arange(model_dict['num_class'])
+        return numpy.arange(num_class)
+
+    def get_objective(self):
+        "Returns the objective."
+        if hasattr(self, 'objective_') and self.objective_ is not None:
+            return self.objective_
+        objective = self.booster_.attr('objective')
+        if objective is not None:
+            return objective
+        dp = self.booster_.dump_model(num_iteration=1)
+        return dp['objective']
 
 
 def _get_lightgbm_operator_name(model):
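The core of the speed-up is visible in the hunk above: WrappedBooster no longer calls Booster.dump_model() on the whole model in __init__ (which serializes every tree) and instead reads metadata through Booster.attr(), falling back to a single-iteration dump. A hedged sketch of the difference follows; the data, model size, and printed timings are illustrative only.

# Hedged sketch: comparing the old and new ways WrappedBooster reads metadata.
# dump_model() serializes every tree; attr() and dump_model(num_iteration=1)
# only touch attributes / the first iteration.
import time
import numpy
import lightgbm

X = numpy.random.rand(1000, 10)
y = (X[:, 0] > 0.5).astype(int)
booster = lightgbm.train({'objective': 'binary'},
                         lightgbm.Dataset(X, y), num_boost_round=200)

t0 = time.perf_counter()
full = booster.dump_model()               # old path: dumps all 200 trees
t1 = time.perf_counter()
objective = booster.attr('objective')     # new path: cheap attribute lookup
if objective is None:                     # fallback, as in get_objective()
    objective = booster.dump_model(num_iteration=1)['objective']
t2 = time.perf_counter()

print("full dump_model(): %.3fs -> objective %r" % (t1 - t0, full['objective']))
print("attr()/small dump: %.3fs -> objective %r" % (t2 - t1, objective))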
(The remaining 2 changed files are not shown in this view.)
