From cb2782b155ff67dc1e586f36a27c5d032070c801 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Xavier=20Dupr=C3=A9?=
Date: Fri, 20 Aug 2021 12:40:08 +0200
Subject: [PATCH] Improves lightgbm conversion speed (#491)

* improves lightgbm conversion speed
---
 .azure-pipelines/linux-conda-CI.yml           |   1 +
 onnxmltools/__init__.py                       |   2 +-
 onnxmltools/convert/lightgbm/_parse.py        |  45 ++-
 .../lightgbm/operator_converters/LightGbm.py  | 295 +++++++++++++++---
 .../test_LightGbmTreeEnsembleConverters.py    |  24 +-
 5 files changed, 310 insertions(+), 57 deletions(-)

diff --git a/.azure-pipelines/linux-conda-CI.yml b/.azure-pipelines/linux-conda-CI.yml
index 8a4c0a53d..1df477490 100644
--- a/.azure-pipelines/linux-conda-CI.yml
+++ b/.azure-pipelines/linux-conda-CI.yml
@@ -97,6 +97,7 @@ jobs:
       displayName: 'Install dependencies'
 
     - script: |
+        pip install flake8
         python -m flake8 ./onnxmltools
       displayName: 'run flake8 check'
 
diff --git a/onnxmltools/__init__.py b/onnxmltools/__init__.py
index e6441c043..c7f618b4b 100644
--- a/onnxmltools/__init__.py
+++ b/onnxmltools/__init__.py
@@ -5,7 +5,7 @@
 This framework converts any machine learned model into onnx format
 which is a common language to describe any machine learned model.
 """
-__version__ = "1.8.0"
+__version__ = "1.9.0"
 __author__ = "Microsoft"
 __producer__ = "OnnxMLTools"
 __producer_version__ = __version__
diff --git a/onnxmltools/convert/lightgbm/_parse.py b/onnxmltools/convert/lightgbm/_parse.py
index d711abe0c..fd68e6f15 100644
--- a/onnxmltools/convert/lightgbm/_parse.py
+++ b/onnxmltools/convert/lightgbm/_parse.py
@@ -21,28 +21,49 @@ class WrappedBooster:
 
     def __init__(self, booster):
         self.booster_ = booster
-        _model_dict = self.booster_.dump_model()
-        self.classes_ = self._generate_classes(_model_dict)
-        self.n_features_ = len(_model_dict['feature_names'])
-        if (_model_dict['objective'].startswith('binary') or
-                _model_dict['objective'].startswith('multiclass')):
+        self.n_features_ = len(self.booster_.feature_name())
+        self.objective_ = self.get_objective()
+        if self.objective_.startswith('binary'):
             self.operator_name = 'LgbmClassifier'
-        elif _model_dict['objective'].startswith(('regression', 'poisson', 'gamma')):
+            self.classes_ = self._generate_classes(booster)
+        elif self.objective_.startswith('multiclass'):
+            self.operator_name = 'LgbmClassifier'
+            self.classes_ = self._generate_classes(booster)
+        elif self.objective_.startswith(('regression', 'poisson', 'gamma')):
             self.operator_name = 'LgbmRegressor'
         else:
-            # Other objectives are not supported.
-            raise ValueError("Unsupported LightGbm objective: '{}'.".format(_model_dict['objective']))
-        if _model_dict.get('average_output', False):
+            raise NotImplementedError(
+                'Unsupported LightGbm objective: %r.' % self.objective_)
+        average_output = self.booster_.attr('average_output')
+        if average_output:
             self.boosting_type = 'rf'
         else:
             # Other than random forest, other boosting types do not affect later conversion.
             # Here `gbdt` is chosen for no reason.
             self.boosting_type = 'gbdt'
 
-    def _generate_classes(self, model_dict):
-        if model_dict['num_class'] == 1:
+    @staticmethod
+    def _generate_classes(booster):
+        if isinstance(booster, dict):
+            num_class = booster['num_class']
+        else:
+            num_class = booster.attr('num_class')
+        if num_class is None:
+            dp = booster.dump_model(num_iteration=1)
+            num_class = dp['num_class']
+        if num_class == 1:
             return numpy.asarray([0, 1])
-        return numpy.arange(model_dict['num_class'])
+        return numpy.arange(num_class)
+
+    def get_objective(self):
+        "Returns the objective."
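+        # `__init__` stores the result in `objective_`, so the cached
+        # value is returned when available; otherwise the C API
+        # attribute is tried before the slower one-iteration dump.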
+        if hasattr(self, 'objective_') and self.objective_ is not None:
+            return self.objective_
+        objective = self.booster_.attr('objective')
+        if objective is not None:
+            return objective
+        dp = self.booster_.dump_model(num_iteration=1)
+        return dp['objective']
 
 
 def _get_lightgbm_operator_name(model):
diff --git a/onnxmltools/convert/lightgbm/operator_converters/LightGbm.py b/onnxmltools/convert/lightgbm/operator_converters/LightGbm.py
index aff27d5b1..03150c82e 100644
--- a/onnxmltools/convert/lightgbm/operator_converters/LightGbm.py
+++ b/onnxmltools/convert/lightgbm/operator_converters/LightGbm.py
@@ -2,8 +2,10 @@
 
 import copy
 import numbers
+from collections import deque, Counter
+import ctypes
+import json
 import numpy as np
-from collections import Counter
 from ...common._apply_operation import (
     apply_div, apply_reshape, apply_sub, apply_cast, apply_identity, apply_clip)
 from ...common._registration import register_converter
@@ -11,6 +13,14 @@
 from ....proto import onnx_proto
 
 
+def has_tqdm():
+    try:
+        from tqdm import tqdm  # noqa
+        return True
+    except ImportError:
+        return False
+
+
 def _translate_split_criterion(criterion):
     # If the criterion is true, LightGBM use the left child.
     # Otherwise, right child is selected.
@@ -216,13 +226,152 @@ def _parse_node(tree_id, class_id, node_id, node_id_pool, node_pyid_pool,
             float(node['leaf_value']) * learning_rate)
 
 
+def dump_booster_model(self, num_iteration=None, start_iteration=0,
+                       importance_type='split', verbose=0):
+    """
+    Dumps Booster to JSON format.
+
+    Parameters
+    ----------
+    self: booster
+    num_iteration : int or None, optional (default=None)
+        Index of the iteration that should be dumped.
+        If None, if the best iteration exists, it is dumped; otherwise,
+        all iterations are dumped.
+        If <= 0, all iterations are dumped.
+    start_iteration : int, optional (default=0)
+        Start index of the iteration that should be dumped.
+    importance_type : string, optional (default="split")
+        What type of feature importance should be dumped.
+        If "split", result contains numbers of times the feature is used in a model.
+        If "gain", result contains total gains of splits which use the feature.
+    verbose: displays progress (useful for big trees)
+
+    Returns
+    -------
+    json_repr : tuple(dict, dict)
+        JSON format of Booster and the collected decision nodes.
+
+    .. note::
+        This function is inspired by the *lightgbm* method
+        ``dump_model``. It creates an intermediate structure to
+        speed up the conversion of such a model into ONNX. The
+        function customizes the JSON decoder to extract the nodes
+        quickly.
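+
+    A minimal usage sketch, assuming ``booster`` is an already trained
+    ``lightgbm.Booster``::
+
+        dumped, info = dump_booster_model(booster, verbose=0)
+        n_trees = len(dumped['tree_info'])
+        decision_nodes = info['decision_nodes']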
+ """ + if getattr(self, 'is_mock', False): + return self.dump_model(), None + from lightgbm.basic import ( + _LIB, FEATURE_IMPORTANCE_TYPE_MAPPER, _safe_call, + json_default_with_numpy) + if num_iteration is None: + num_iteration = self.best_iteration + importance_type_int = FEATURE_IMPORTANCE_TYPE_MAPPER[importance_type] + buffer_len = 1 << 20 + tmp_out_len = ctypes.c_int64(0) + string_buffer = ctypes.create_string_buffer(buffer_len) + ptr_string_buffer = ctypes.c_char_p(*[ctypes.addressof(string_buffer)]) + if verbose >= 2: + print("[dump_booster_model] call CAPI: LGBM_BoosterDumpModel") + _safe_call(_LIB.LGBM_BoosterDumpModel( + self.handle, + ctypes.c_int(start_iteration), + ctypes.c_int(num_iteration), + ctypes.c_int(importance_type_int), + ctypes.c_int64(buffer_len), + ctypes.byref(tmp_out_len), + ptr_string_buffer)) + actual_len = tmp_out_len.value + # if buffer length is not long enough, reallocate a buffer + if actual_len > buffer_len: + string_buffer = ctypes.create_string_buffer(actual_len) + ptr_string_buffer = ctypes.c_char_p( + *[ctypes.addressof(string_buffer)]) + _safe_call(_LIB.LGBM_BoosterDumpModel( + self.handle, + ctypes.c_int(start_iteration), + ctypes.c_int(num_iteration), + ctypes.c_int(importance_type_int), + ctypes.c_int64(actual_len), + ctypes.byref(tmp_out_len), + ptr_string_buffer)) + + class Hook(json.JSONDecoder): + """ + Keep track of the progress, stores a copy of all objects with + a decision into a different container in order to walk through + all nodes in a much faster way than going through the architecture. + """ + def __init__(self, *args, info=None, n_trees=None, verbose=0, + **kwargs): + json.JSONDecoder.__init__( + self, object_hook=self.hook, *args, **kwargs) + self.nodes = [] + self.buffer = [] + self.info = info + self.n_trees = n_trees + self.verbose = verbose + self.stored = 0 + if verbose >= 2 and n_trees is not None and has_tqdm(): + from tqdm import tqdm + self.loop = tqdm(total=n_trees) + self.loop.set_description("dump_booster") + else: + self.loop = None + + def hook(self, obj): + """ + Hook called everytime a JSON object is created. + Keep track of the progress, stores a copy of all objects with + a decision into a different container. + """ + # Every obj goes through this function from the leaves to the root. + if 'tree_info' in obj: + self.info['decision_nodes'] = self.nodes + if self.n_trees is not None and len(self.nodes) != self.n_trees: + raise RuntimeError( + "Unexpected number of trees %d (expecting %d)." % ( + len(self.nodes), self.n_trees)) + self.nodes = [] + if self.loop is not None: + self.loop.close() + if 'tree_structure' in obj: + self.nodes.append(self.buffer) + if self.loop is not None: + self.loop.update(len(self.nodes)) + if len(self.nodes) % 10 == 0: + self.loop.set_description( + "dump_booster: %d/%d trees, %d nodes" % ( + len(self.nodes), self.n_trees, self.stored)) + self.buffer = [] + if "decision_type" in obj: + self.buffer.append(obj) + self.stored += 1 + return obj + + if verbose >= 2: + print("[dump_booster_model] to_json") + info = {} + ret = json.loads(string_buffer.value.decode('utf-8'), cls=Hook, + info=info, n_trees=self.num_trees(), verbose=verbose) + ret['pandas_categorical'] = json.loads( + json.dumps(self.pandas_categorical, + default=json_default_with_numpy)) + if verbose >= 2: + print("[dump_booster_model] end.") + return ret, info + + def convert_lightgbm(scope, operator, container): """ Converters for *lightgbm*. 
""" + verbose = getattr(container, 'verbose', 0) gbm_model = operator.raw_operator - gbm_text = gbm_model.booster_.dump_model() - modify_tree_for_rule_in_set(gbm_text, use_float=True) + gbm_text, info = dump_booster_model(gbm_model.booster_, verbose=verbose) + modify_tree_for_rule_in_set(gbm_text, use_float=True, verbose=verbose, info=info) attrs = get_default_tree_classifier_attribute_pairs() attrs['name'] = operator.full_name @@ -417,32 +566,49 @@ def convert_lightgbm(scope, operator, container): name=scope.get_unique_operator_name('Identity')) -def modify_tree_for_rule_in_set(gbm, use_float=False): +def modify_tree_for_rule_in_set(gbm, use_float=False, verbose=0, count=0, # pylint: disable=R1710 + info=None): """ LightGBM produces sometimes a tree with a node set to use rule ``==`` to a set of values (= in set), the values are separated by ``||``. This function unfold theses nodes. + + :param gbm: a tree coming from lightgbm dump + :param use_float: use float otherwise int first + then float if it does not work + :param verbose: verbosity, use *tqdm* to show progress + :param count: number of nodes already changed (origin) before this call + :param info: addition information to speed up this search + :return: number of changed nodes (include *count*) """ if 'tree_info' in gbm: - for tree in gbm['tree_info']: - modify_tree_for_rule_in_set(tree, use_float=use_float) - return + if info is not None: + dec_nodes = info['decision_nodes'] + else: + dec_nodes = None + if verbose >= 2 and has_tqdm(): + from tqdm import tqdm + loop = tqdm(gbm['tree_info']) + for i, tree in enumerate(loop): + loop.set_description("rules tree %d c=%d" % (i, count)) + count = modify_tree_for_rule_in_set( + tree, use_float=use_float, count=count, + info=None if dec_nodes is None else dec_nodes[i]) + else: + for i, tree in enumerate(gbm['tree_info']): + count = modify_tree_for_rule_in_set( + tree, use_float=use_float, count=count, + info=None if dec_nodes is None else dec_nodes[i]) + return count if 'tree_structure' in gbm: - modify_tree_for_rule_in_set(gbm['tree_structure'], use_float=use_float) - return + return modify_tree_for_rule_in_set( + gbm['tree_structure'], use_float=use_float, count=count, + info=info) if 'decision_type' not in gbm: - return - - def recursive_call(this): - if 'left_child' in this: - modify_tree_for_rule_in_set( - this['left_child'], use_float=use_float) - if 'right_child' in this: - modify_tree_for_rule_in_set( - this['right_child'], use_float=use_float) + return count def str2number(val): if use_float: @@ -450,29 +616,88 @@ def str2number(val): else: try: return int(val) - except ValueError: + except ValueError: # pragma: no cover return float(val) - dec = gbm['decision_type'] - if dec != '==': - return recursive_call(gbm) + if info is None: + + def recursive_call(this, c): + if 'left_child' in this: + c = process_node(this['left_child'], count=c) + if 'right_child' in this: + c = process_node(this['right_child'], count=c) + return c + + def process_node(node, count): + if 'decision_type' not in node: + return count + if node['decision_type'] != '==': + return recursive_call(node, count) + th = node['threshold'] + if not isinstance(th, str): + return recursive_call(node, count) + pos = th.find('||') + if pos == -1: + return recursive_call(node, count) + th1 = str2number(th[:pos]) + + def doit(): + rest = th[pos + 2:] + if '||' not in rest: + rest = str2number(rest) + + node['threshold'] = th1 + new_node = node.copy() + node['right_child'] = new_node + new_node['threshold'] = rest + + 
+            doit()
+            return recursive_call(node, count + 1)
+
+        return process_node(gbm, count)
+
+    # when info is used
+
+    def split_node(node, th, pos):
+        th1 = str2number(th[:pos])
+
+        rest = th[pos + 2:]
+        if '||' not in rest:
+            rest = str2number(rest)
+            app = False
+        else:
+            app = True
+
+        node['threshold'] = th1
+        new_node = node.copy()
+        node['right_child'] = new_node
+        new_node['threshold'] = rest
+        return new_node, app
+
+    stack = deque(info)
+    while len(stack) > 0:
+        node = stack.pop()
+
+        if 'decision_type' not in node:
+            continue  # leaf
+
+        if node['decision_type'] != '==':
+            continue
 
-    th = gbm['threshold']
-    if not isinstance(th, str) or '||' not in th:
-        return recursive_call(gbm)
+        th = node['threshold']
+        if not isinstance(th, str):
+            continue
 
-    pos = th.index('||')
-    th1 = str2number(th[:pos])
+        pos = th.find('||')
+        if pos == -1:
+            continue
 
-    rest = th[pos + 2:]
-    if '||' not in rest:
-        rest = str2number(rest)
+        new_node, app = split_node(node, th, pos)
+        count += 1
+        if app:
+            stack.append(new_node)
 
-    gbm['threshold'] = th1
-    new_node = gbm.copy()
-    gbm['right_child'] = new_node
-    new_node['threshold'] = rest
-    return recursive_call(gbm)
+    return count
 
 
 def convert_lgbm_zipmap(scope, operator, container):
diff --git a/tests/lightgbm/test_LightGbmTreeEnsembleConverters.py b/tests/lightgbm/test_LightGbmTreeEnsembleConverters.py
index e5dc40547..bcbfb1912 100644
--- a/tests/lightgbm/test_LightGbmTreeEnsembleConverters.py
+++ b/tests/lightgbm/test_LightGbmTreeEnsembleConverters.py
@@ -6,6 +6,7 @@
 import lightgbm
 import numpy
 from numpy.testing import assert_almost_equal
+from onnx.defs import onnx_opset_version
 from lightgbm import LGBMClassifier, LGBMRegressor
 import onnxruntime
 from onnxmltools.convert.common.utils import hummingbird_installed
@@ -16,6 +17,8 @@
 from onnxmltools.utils import dump_single_regression
 from onnxmltools.utils.tests_helper import convert_model
 
+TARGET_OPSET = min(13, onnx_opset_version())
+
 
 class TestLightGbmTreeEnsembleModels(unittest.TestCase):
 
@@ -31,7 +34,8 @@ def test_lightgbm_classifier_zipmap(self):
         model = LGBMClassifier(n_estimators=3, min_child_samples=1)
         model.fit(X, y)
         onx = convert_model(
-            model, 'dummy', input_types=[('X', FloatTensorType([None, X.shape[1]]))])
+            model, 'dummy', input_types=[('X', FloatTensorType([None, X.shape[1]]))],
+            target_opset=TARGET_OPSET)
         assert "zipmap" in str(onx).lower()
 
     def test_lightgbm_classifier_nozipmap(self):
@@ -42,7 +46,7 @@ def test_lightgbm_classifier_nozipmap(self):
         model.fit(X, y)
         onx = convert_model(
             model, 'dummy', input_types=[('X', FloatTensorType([None, X.shape[1]]))],
-            zipmap=False)
+            zipmap=False, target_opset=TARGET_OPSET)
         assert "zipmap" not in str(onx).lower()
         onxs = onx[0].SerializeToString()
         try:
@@ -99,7 +103,8 @@ def test_lightgbm_booster_classifier(self):
                                 'n_estimators': 3, 'min_child_samples': 1},
                                data)
         model_onnx, prefix = convert_model(model, 'tree-based classifier',
-                                           [('input', FloatTensorType([None, 2]))])
+                                           [('input', FloatTensorType([None, 2]))],
+                                           target_opset=TARGET_OPSET)
         dump_data_and_model(X, model, model_onnx,
                             allow_failure="StrictVersion(onnx.__version__) < StrictVersion('1.3.0')",
                             basename=prefix + "BoosterBin" + model.__class__.__name__)
@@ -114,7 +119,7 @@ def test_lightgbm_booster_classifier_nozipmap(self):
                                data)
         model_onnx, prefix = convert_model(model, 'tree-based classifier',
                                            [('input', FloatTensorType([None, 2]))],
-                                           zipmap=False)
+                                           zipmap=False, target_opset=TARGET_OPSET)
         assert "zipmap" not in str(model_onnx).lower()
         dump_data_and_model(X, model, model_onnx,
allow_failure="StrictVersion(onnx.__version__) < StrictVersion('1.3.0')", @@ -129,7 +134,8 @@ def test_lightgbm_booster_classifier_zipmap(self): 'n_estimators': 3, 'min_child_samples': 1}, data) model_onnx, prefix = convert_model(model, 'tree-based classifier', - [('input', FloatTensorType([None, 2]))]) + [('input', FloatTensorType([None, 2]))], + target_opset=TARGET_OPSET) assert "zipmap" in str(model_onnx).lower() dump_data_and_model(X, model, model_onnx, allow_failure="StrictVersion(onnx.__version__) < StrictVersion('1.3.0')", @@ -246,13 +252,13 @@ def test_lightgbm_booster_regressor(self): def _test_lgbm(self, X, model, extra_config={}): # Create ONNX-ML model onnx_ml_model = convert_model( - model, 'lgbm-onnxml', [("input", FloatTensorType([X.shape[0], X.shape[1]]))] - )[0] + model, 'lgbm-onnxml', [("input", FloatTensorType([X.shape[0], X.shape[1]]))], + target_opset=TARGET_OPSET)[0] # Create ONNX model onnx_model = convert_model( - model, 'lgbm-onnx', [("input", FloatTensorType([X.shape[0], X.shape[1]]))], without_onnx_ml=True - )[0] + model, 'lgbm-onnx', [("input", FloatTensorType([X.shape[0], X.shape[1]]))], without_onnx_ml=True, + target_opset=TARGET_OPSET)[0] try: from onnxruntime import InferenceSession
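
For reference, the end-to-end path exercised by these tests, as a minimal
sketch (the dataset and hyperparameters are illustrative; the tests'
`convert_model` helper dispatches to the same lightgbm converter):

    import numpy
    from lightgbm import LGBMClassifier
    from onnxmltools import convert_lightgbm
    from onnxmltools.convert.common.data_types import FloatTensorType
    from onnxruntime import InferenceSession

    # a small synthetic binary classification problem
    X = numpy.random.randn(100, 4).astype(numpy.float32)
    y = (X.sum(axis=1) >= 0).astype(numpy.int64)
    model = LGBMClassifier(n_estimators=3, min_child_samples=1).fit(X, y)

    # conversion goes through the faster dump_booster_model path above;
    # target_opset mirrors TARGET_OPSET in the tests
    onx = convert_lightgbm(
        model, 'lgbm', [('X', FloatTensorType([None, X.shape[1]]))],
        target_opset=13)

    sess = InferenceSession(onx.SerializeToString(),
                            providers=['CPUExecutionProvider'])
    labels = sess.run(None, {'X': X})[0]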