Patch summary
-------------
* examples/simple_tune.py: load the tuned model from
  analysis.best_trial.local_path instead of analysis.best_logdir.
* main.py: only run the internal checkpoint callback when
  this.checkpoint_frequency > 0, so the `% this.checkpoint_frequency`
  test is never evaluated with a zero frequency.
* tests/test_tune.py: accept both callback layouts (attributes under
  _report/_checkpoint, or _metrics/_filename directly on the callback).
* tune.py: choose the callback implementations depending on whether
  `ray.train` exposes `report`, and also replace user callbacks when
  get_session() reports an active session.

NOTE(review): the guard in _try_add_tune_callback is parenthesised as
`TUNE_INSTALLED and (is_session_enabled() or get_session())`. Without the
parentheses `and` binds tighter than `or`, so an active session alone would
enter the body even when TUNE_INSTALLED is false, although the body relies on
the Tune callback classes. Because this added line differs from the one the
patch was generated with, the post-image blob id on the tune.py `index` line
is no longer exact.

NOTE(review): this patch was recovered from a copy whose line breaks had been
collapsed. Hunk line counts were re-checked against the @@ headers, but the
indentation and the position of blank context lines are reconstructed
(the main.py nesting depth in particular is a guess). Verify against the
target files, or apply with whitespace tolerance, e.g.
`git apply --ignore-whitespace`.

diff --git a/lightgbm_ray/examples/simple_tune.py b/lightgbm_ray/examples/simple_tune.py
index b5d5ced..1eee3a1 100644
--- a/lightgbm_ray/examples/simple_tune.py
+++ b/lightgbm_ray/examples/simple_tune.py
@@ -70,7 +70,7 @@ def main(cpus_per_actor, num_actors, num_samples):
 
     # Load the best model checkpoint.
     best_bst = lightgbm_ray.tune.load_model(
-        os.path.join(analysis.best_logdir, "tuned.lgbm")
+        os.path.join(analysis.best_trial.local_path, "tuned.lgbm")
     )
 
     best_bst.save_model("best_model.lgbm")
diff --git a/lightgbm_ray/main.py b/lightgbm_ray/main.py
index 5d24725..3aa64d8 100644
--- a/lightgbm_ray/main.py
+++ b/lightgbm_ray/main.py
@@ -253,7 +253,7 @@ def _save_internal_checkpoint_callback() -> Callable:
             def _callback(env: CallbackEnv) -> None:
                 if not is_rank_0:
                     return
-                if (
+                if this.checkpoint_frequency > 0 and (
                     env.iteration == env.end_iteration - 1
                     or env.iteration % this.checkpoint_frequency == 0
                 ):
diff --git a/lightgbm_ray/tests/test_tune.py b/lightgbm_ray/tests/test_tune.py
index 2f44068..9e54622 100644
--- a/lightgbm_ray/tests/test_tune.py
+++ b/lightgbm_ray/tests/test_tune.py
@@ -144,8 +144,13 @@ def testReplaceTuneCheckpoints(self):
 
         replaced = in_dict["callbacks"][0]
         self.assertTrue(isinstance(replaced, TuneReportCheckpointCallback))
-        self.assertSequenceEqual(replaced._report._metrics, ["met"])
-        self.assertEqual(replaced._checkpoint._filename, "test")
+
+        if getattr(replaced, "_report", None):
+            self.assertSequenceEqual(replaced._report._metrics, ["met"])
+            self.assertEqual(replaced._checkpoint._filename, "test")
+        else:
+            self.assertSequenceEqual(replaced._metrics, ["met"])
+            self.assertEqual(replaced._filename, "test")
 
     def testEndToEndCheckpointing(self):
         ray.init(num_cpus=4)
diff --git a/lightgbm_ray/tune.py b/lightgbm_ray/tune.py
index f4b4df2..0ee6e56 100644
--- a/lightgbm_ray/tune.py
+++ b/lightgbm_ray/tune.py
@@ -5,12 +5,13 @@ import ray
 
 from lightgbm.basic import Booster
 from lightgbm.callback import CallbackEnv
+from ray.train._internal.session import get_session
 from ray.util.annotations import PublicAPI
 from xgboost_ray.session import put_queue
 from xgboost_ray.util import force_on_current_node
 
 try:
-    from ray import tune
+    from ray import train, tune
     from ray.tune import is_session_enabled
     from ray.tune.integration.lightgbm import (
         TuneReportCallback as OrigTuneReportCallback,
@@ -49,49 +50,68 @@ def is_rank_0(self, val: bool):
 
 
 if TUNE_INSTALLED:
-
-    class TuneReportCallback(_TuneLGBMRank0Mixin, OrigTuneReportCallback):
-        def __call__(self, env: CallbackEnv) -> None:
-            if not self.is_rank_0:
-                return
-            eval_result = self._get_eval_result(env)
-            report_dict = self._get_report_dict(eval_result)
-            put_queue(lambda: tune.report(**report_dict))
-
-    class _TuneCheckpointCallback(_TuneLGBMRank0Mixin, _OrigTuneCheckpointCallback):
-        def __call__(self, env: CallbackEnv) -> None:
-            if not self.is_rank_0:
-                return
-            put_queue(
-                lambda: self._create_checkpoint(
-                    env.model, env.iteration, self._filename, self._frequency
+    if not hasattr(train, "report"):
+
+        class TuneReportCallback(_TuneLGBMRank0Mixin, OrigTuneReportCallback):
+            def __call__(self, env: CallbackEnv) -> None:
+                if not self.is_rank_0:
+                    return
+                eval_result = self._get_eval_result(env)
+                report_dict = self._get_report_dict(eval_result)
+                put_queue(lambda: tune.report(**report_dict))
+
+        class _TuneCheckpointCallback(_TuneLGBMRank0Mixin, _OrigTuneCheckpointCallback):
+            def __call__(self, env: CallbackEnv) -> None:
+                if not self.is_rank_0:
+                    return
+                put_queue(
+                    lambda: self._create_checkpoint(
+                        env.model, env.iteration, self._filename, self._frequency
+                    )
                 )
-            )
-
-    class TuneReportCheckpointCallback(
-        _TuneLGBMRank0Mixin, OrigTuneReportCheckpointCallback
-    ):
-        _checkpoint_callback_cls = _TuneCheckpointCallback
-        _report_callback_cls = TuneReportCallback
-
-        @property
-        def is_rank_0(self) -> bool:
-            try:
-                return self._is_rank_0
-            except AttributeError:
-                return False
-
-        @is_rank_0.setter
-        def is_rank_0(self, val: bool):
-            self._is_rank_0 = val
-            if hasattr(self, "_checkpoint"):
-                self._checkpoint.is_rank_0 = val
-            if hasattr(self, "_report"):
-                self._report.is_rank_0 = val
+
+        class TuneReportCheckpointCallback(
+            _TuneLGBMRank0Mixin, OrigTuneReportCheckpointCallback
+        ):
+            _checkpoint_callback_cls = _TuneCheckpointCallback
+            _report_callback_cls = TuneReportCallback
+
+            @property
+            def is_rank_0(self) -> bool:
+                try:
+                    return self._is_rank_0
+                except AttributeError:
+                    return False
+
+            @is_rank_0.setter
+            def is_rank_0(self, val: bool):
+                self._is_rank_0 = val
+                if hasattr(self, "_checkpoint"):
+                    self._checkpoint.is_rank_0 = val
+                if hasattr(self, "_report"):
+                    self._report.is_rank_0 = val
+
+    else:
+
+        class TuneReportCheckpointCallback(
+            _TuneLGBMRank0Mixin, OrigTuneReportCheckpointCallback
+        ):
+            def __call__(self, env: CallbackEnv):
+                if self.is_rank_0:
+                    put_queue(
+                        lambda: super(TuneReportCheckpointCallback, self).__call__(
+                            env=env
+                        )
+                    )
+
+        class TuneReportCallback(_TuneLGBMRank0Mixin, OrigTuneReportCallback):
+            def __call__(self, env: CallbackEnv):
+                if self.is_rank_0:
+                    put_queue(lambda: super(TuneReportCallback, self).__call__(env=env))
 
 
 def _try_add_tune_callback(kwargs: Dict):
-    if TUNE_INSTALLED and is_session_enabled():
+    if TUNE_INSTALLED and (is_session_enabled() or get_session()):
         callbacks = kwargs.get("callbacks", []) or []
         new_callbacks = []
         has_tune_callback = False
@@ -117,10 +137,19 @@ def _try_add_tune_callback(kwargs: Dict):
                 )
                 has_tune_callback = True
             elif isinstance(cb, OrigTuneReportCheckpointCallback):
+                if getattr(cb, "_report", None):
+                    orig_metrics = cb._report._metrics
+                    orig_filename = cb._checkpoint._filename
+                    orig_frequency = cb._checkpoint._frequency
+                else:
+                    orig_metrics = cb._metrics
+                    orig_filename = cb._filename
+                    orig_frequency = cb._frequency
+
                 replace_cb = TuneReportCheckpointCallback(
-                    metrics=cb._report._metrics,
-                    filename=cb._checkpoint._filename,
-                    frequency=cb._checkpoint._frequency,
+                    metrics=orig_metrics,
+                    filename=orig_filename,
+                    frequency=orig_frequency,
                 )
                 new_callbacks.append(replace_cb)
                 logging.warning(