From f251454f9a438755891e2f281cf0f0c405e01ae1 Mon Sep 17 00:00:00 2001
From: Bobby Wang
Date: Tue, 12 Nov 2024 11:38:40 +0800
Subject: [PATCH] [pyspark] Support collective.Config in spark

---
 python-package/xgboost/spark/core.py      | 79 ++++++++++---------
 python-package/xgboost/spark/estimator.py | 10 +++
 python-package/xgboost/spark/utils.py     |  7 +-
 .../test_with_spark/test_spark_local.py   | 36 +++++++--
 4 files changed, 88 insertions(+), 44 deletions(-)

diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py
index 3d5618d5d8f4..43f9f8e88fc0 100644
--- a/python-package/xgboost/spark/core.py
+++ b/python-package/xgboost/spark/core.py
@@ -8,6 +8,7 @@
 import logging
 import os
 from collections import namedtuple
+from dataclasses import asdict
 from typing import (
     Any,
     Callable,
@@ -67,6 +68,7 @@
 from xgboost.training import train as worker_train
 
 from .._typing import ArrayLike
+from ..collective import Config
 from .data import (
     _read_csr_matrix_from_unwrapped_spark_vec,
     alias,
@@ -123,8 +125,7 @@
     "pred_contrib_col",
     "use_gpu",
     "launch_tracker_on_driver",
-    "tracker_host_ip",
-    "tracker_port",
+    "tracker",
 ]
 
 _non_booster_params = ["missing", "n_estimators", "feature_types", "feature_weights"]
@@ -257,21 +258,20 @@ class _SparkXGBParams(
         "launched on the driver side; otherwise, it will be launched on the executor side.",
         TypeConverters.toBoolean,
     )
-    tracker_host_ip = Param(
+    tracker = Param(
         Params._dummy(),
-        "tracker_host_ip",
-        "A string variable. The tracker host IP address. To set tracker host ip, you need to "
-        "enable launch_tracker_on_driver to be true first",
-        TypeConverters.toString,
-    )
-    tracker_port = Param(
-        Params._dummy(),
-        "tracker_port",
-        "A string variable. The port number tracker listens on. To set tracker host port, you need "
+        "tracker",
+        "xgboost.collective.Config. The communicator configuration. To set it, you need "
         "to enable launch_tracker_on_driver first",
-        TypeConverters.toInt,
+        TypeConverters.identity,
     )
 
+    def set_tracker(self, value: Config) -> "_SparkXGBParams":
+        """Set the communicator configuration."""
+        assert isinstance(value, Config)
+        self.set(self.tracker, value)
+        return self
+
     def set_device(self, value: str) -> "_SparkXGBParams":
         """Set device, optional value: cpu, cuda, gpu"""
         _check_distributed_params({"device": value})
@@ -621,7 +621,6 @@ def _get_unwrapped_vec_cols(feature_col: Column) -> List[Column]:
     ("enable_sparse_data_optim", "has_validation_col", "features_cols_names"),
 )
 
-
 _MODEL_CHUNK_SIZE = 4096 * 1024
 
 
@@ -1030,30 +1029,26 @@ def _get_tracker_args(self) -> Tuple[bool, Dict[str, Any]]:
         launch_tracker_on_driver = self.getOrDefault(self.launch_tracker_on_driver)
         rabit_args = {}
         if launch_tracker_on_driver:
-            tracker_host_ip: Optional[str] = None
-            if self.isDefined(self.tracker_host_ip):
-                tracker_host_ip = self.getOrDefault(self.tracker_host_ip)
-            else:
-                tracker_host_ip = (
+            tracker = Config()
+            if self.isDefined(self.tracker):
+                tracker = self.getOrDefault(self.tracker)
+            assert isinstance(tracker, Config)
+
+            if tracker.tracker_host_ip is None:
+                tracker.tracker_host_ip = (
                     _get_spark_session().sparkContext.getConf().get("spark.driver.host")
                 )
-            assert tracker_host_ip is not None
-            tracker_port = 0
-            if self.isDefined(self.tracker_port):
-                tracker_port = self.getOrDefault(self.tracker_port)
             num_workers = self.getOrDefault(self.num_workers)
-            rabit_args.update(
-                _get_rabit_args(tracker_host_ip, num_workers, tracker_port)
-            )
+            rabit_args.update(_get_rabit_args(tracker, num_workers))
         else:
-            if self.isDefined(self.tracker_host_ip) or self.isDefined(
-                self.tracker_port
-            ):
-                raise ValueError(
-                    "You must enable launch_tracker_on_driver to use "
-                    "tracker_host_ip and tracker_port"
-                )
+            if self.isDefined(self.tracker):
+                tracker = self.getOrDefault(self.tracker)
+                assert isinstance(tracker, Config)
+                if tracker.tracker_host_ip is not None:
+                    raise ValueError(
+                        "You must enable launch_tracker_on_driver to use "
+                        f"tracker host: {tracker.tracker_host_ip}"
+                    )
 
         return launch_tracker_on_driver, rabit_args
 
@@ -1075,6 +1070,9 @@ def _fit(self, dataset: DataFrame) -> "_SparkXGBModel":
         num_workers = self.getOrDefault(self.num_workers)
 
         launch_tracker_on_driver, rabit_args = self._get_tracker_args()
+        tracker: Optional[Config] = (
+            self.getOrDefault(self.tracker) if self.isSet(self.tracker) else None
+        )
 
         log_level = get_logger_level(_LOG_TAG)
 
@@ -1114,11 +1112,12 @@ def _train_booster(
 
             if use_qdm and (booster_params.get("max_bin", None) is not None):
                 dmatrix_kwargs["max_bin"] = booster_params["max_bin"]
-
             _rabit_args = rabit_args
             if context.partitionId() == 0:
                 if not launch_tracker_on_driver:
-                    _rabit_args = _get_rabit_args(_get_host_ip(context), num_workers)
+                    _tracker = tracker if tracker is not None else Config()
+                    _tracker.tracker_host_ip = _get_host_ip(context)
+                    _rabit_args = _get_rabit_args(_tracker, num_workers)
                 get_logger(_LOG_TAG, log_level).info(msg)
 
             worker_message: Dict[str, Any] = {
@@ -1629,7 +1628,7 @@ def saveMetadata(
         xgboost.spark._SparkXGBModel.
""" instance._validate_params() - skipParams = ["callbacks", "xgb_model"] + skipParams = ["callbacks", "xgb_model", "tracker"] jsonParams = {} for p, v in instance._paramMap.items(): # pylint: disable=protected-access if p.name not in skipParams: @@ -1650,6 +1649,10 @@ def saveMetadata( init_booster = instance.getOrDefault("xgb_model") if init_booster is not None: extraMetadata["init_booster"] = _INIT_BOOSTER_SAVE_PATH + + tracker_conf: Config = instance.getOrDefault("tracker") + if tracker_conf is not None: + extraMetadata["tracker_conf"] = asdict(tracker_conf) DefaultParamsWriter.saveMetadata( instance, path, sc, extraMetadata=extraMetadata, paramMap=jsonParams ) @@ -1691,6 +1694,8 @@ def loadMetadataAndInstance( f"Fails to load the callbacks param due to {e}. Please set the " "callbacks param manually for the loaded estimator." ) + if "tracker_conf" in metadata: + pyspark_xgb.set_tracker(Config(**metadata["tracker_conf"])) if "init_booster" in metadata: load_path = os.path.join(path, metadata["init_booster"]) diff --git a/python-package/xgboost/spark/estimator.py b/python-package/xgboost/spark/estimator.py index 8a4840846ac2..b198f821ef7d 100644 --- a/python-package/xgboost/spark/estimator.py +++ b/python-package/xgboost/spark/estimator.py @@ -14,6 +14,7 @@ from xgboost import XGBClassifier, XGBRanker, XGBRegressor +from ..collective import Config from .core import ( # type: ignore _ClassificationModel, _SparkXGBEstimator, @@ -164,6 +165,8 @@ class SparkXGBRegressor(_SparkXGBEstimator): launch_tracker_on_driver: Boolean value to indicate whether the tracker should be launched on the driver side or the executor side. + tracker: + The communicator configuration. See :py:class:`~xgboost.collective.Config` kwargs: A dictionary of xgboost parameters, please refer to @@ -219,6 +222,7 @@ def __init__( # pylint:disable=too-many-arguments repartition_random_shuffle: bool = False, enable_sparse_data_optim: bool = False, launch_tracker_on_driver: bool = True, + tracker: Optional[Config] = None, **kwargs: Any, ) -> None: super().__init__() @@ -348,6 +352,8 @@ class SparkXGBClassifier(_SparkXGBEstimator, HasProbabilityCol, HasRawPrediction launch_tracker_on_driver: Boolean value to indicate whether the tracker should be launched on the driver side or the executor side. + tracker: + The communicator configuration. See :py:class:`~xgboost.collective.Config` kwargs: A dictionary of xgboost parameters, please refer to @@ -403,6 +409,7 @@ def __init__( # pylint:disable=too-many-arguments repartition_random_shuffle: bool = False, enable_sparse_data_optim: bool = False, launch_tracker_on_driver: bool = True, + tracker: Optional[Config] = None, **kwargs: Any, ) -> None: super().__init__() @@ -535,6 +542,8 @@ class SparkXGBRanker(_SparkXGBEstimator): launch_tracker_on_driver: Boolean value to indicate whether the tracker should be launched on the driver side or the executor side. + tracker: + The communicator configuration. 
 
     kwargs:
         A dictionary of xgboost parameters, please refer to
@@ -596,6 +605,7 @@ def __init__(  # pylint:disable=too-many-arguments
         repartition_random_shuffle: bool = False,
         enable_sparse_data_optim: bool = False,
         launch_tracker_on_driver: bool = True,
+        tracker: Optional[Config] = None,
         **kwargs: Any,
     ) -> None:
         super().__init__()
diff --git a/python-package/xgboost/spark/utils.py b/python-package/xgboost/spark/utils.py
index a8a2314272a6..c96ec284abe3 100644
--- a/python-package/xgboost/spark/utils.py
+++ b/python-package/xgboost/spark/utils.py
@@ -15,6 +15,7 @@
 from pyspark.sql.session import SparkSession
 
 from ..collective import CommunicatorContext as CCtx
+from ..collective import Config
 from ..collective import _Args as CollArgs
 from ..collective import _ArgVals as CollArgsVals
 from ..core import Booster
@@ -66,9 +67,11 @@ def _start_tracker(host: str, n_workers: int, port: int = 0) -> CollArgs:
     return args
 
 
-def _get_rabit_args(host: str, n_workers: int, port: int = 0) -> CollArgs:
+def _get_rabit_args(conf: Config, n_workers: int) -> CollArgs:
     """Get rabit context arguments to send to each worker."""
-    env = _start_tracker(host, n_workers, port)
+    assert conf.tracker_host_ip is not None
+    port = 0 if conf.tracker_port is None else conf.tracker_port
+    env = _start_tracker(conf.tracker_host_ip, n_workers, port)
     return env
 
 
diff --git a/tests/test_distributed/test_with_spark/test_spark_local.py b/tests/test_distributed/test_with_spark/test_spark_local.py
index 8d64dc205ef0..bc3a3e02631a 100644
--- a/tests/test_distributed/test_with_spark/test_spark_local.py
+++ b/tests/test_distributed/test_with_spark/test_spark_local.py
@@ -12,6 +12,7 @@
 
 import xgboost as xgb
 from xgboost import testing as tm
+from xgboost.collective import Config
 from xgboost.spark.data import pred_contribs
 
 pytestmark = [tm.timeout(60), pytest.mark.skipif(**tm.no_spark())]
@@ -1650,16 +1651,14 @@ def test_unsupported_params(self):
     def test_tracker(self):
         classifier = SparkXGBClassifier(
             launch_tracker_on_driver=True,
-            tracker_host_ip="192.168.1.32",
-            tracker_port=59981,
+            tracker=Config(tracker_host_ip="192.168.1.32", tracker_port=59981),
         )
         with pytest.raises(Exception, match="Failed to bind socket"):
             classifier._get_tracker_args()
 
         classifier = SparkXGBClassifier(
             launch_tracker_on_driver=False,
-            tracker_host_ip="127.0.0.1",
-            tracker_port=58892,
+            tracker=Config(tracker_host_ip="127.0.0.1", tracker_port=58892),
         )
         with pytest.raises(
             ValueError, match="You must enable launch_tracker_on_driver"
@@ -1668,13 +1667,40 @@ def test_tracker(self):
 
         classifier = SparkXGBClassifier(
             launch_tracker_on_driver=True,
-            tracker_host_ip="127.0.0.1",
+            tracker=Config(tracker_host_ip="127.0.0.1", tracker_port=58893),
             num_workers=2,
         )
         launch_tracker_on_driver, rabit_envs = classifier._get_tracker_args()
         assert launch_tracker_on_driver is True
         assert rabit_envs["n_workers"] == 2
         assert rabit_envs["dmlc_tracker_uri"] == "127.0.0.1"
+        assert rabit_envs["dmlc_tracker_port"] == 58893
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            path = "file:" + tmpdir
+            classifier = SparkXGBClassifier(
+                launch_tracker_on_driver=True,
+                tracker=Config(tracker_host_ip="127.0.0.1", tracker_port=58894),
+                num_workers=1,
+                n_estimators=1,
+            )
+
+            def check_tracker(tracker: Config) -> None:
+                assert tracker.tracker_host_ip == "127.0.0.1"
+                assert tracker.tracker_port == 58894
+
+            check_tracker(classifier.getOrDefault(classifier.tracker))
+            classifier.write().overwrite().save(path)
+
+            loaded_classifier = SparkXGBClassifier.load(path)
+            check_tracker(loaded_classifier.getOrDefault(classifier.tracker))
+
+            model = classifier.fit(self.cls_df_sparse_train)
+            check_tracker(model.getOrDefault(classifier.tracker))
+
+            model.write().overwrite().save(path)
+            loaded_model = SparkXGBClassifierModel.load(path)
+            check_tracker(loaded_model.getOrDefault(classifier.tracker))
 
 
 LTRData = namedtuple("LTRData", ("df_train", "df_test", "df_train_1"))
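
Usage note (illustrative, not part of the patch): after this change the tracker is
configured through a single `tracker` parameter that takes an
`xgboost.collective.Config`, replacing the separate `tracker_host_ip` and
`tracker_port` parameters. A minimal sketch, assuming an active SparkSession and a
hypothetical training DataFrame `train_df` with "features" and "label" columns:

    from xgboost.collective import Config
    from xgboost.spark import SparkXGBClassifier

    # Pin the tracker to the driver host. Leaving tracker_port unset (None) makes
    # _get_rabit_args fall back to port 0, i.e. the tracker binds a free port.
    conf = Config(tracker_host_ip="127.0.0.1", tracker_port=59999)

    clf = SparkXGBClassifier(
        launch_tracker_on_driver=True,  # required when pinning the tracker address
        tracker=conf,
        num_workers=2,
    )
    model = clf.fit(train_df)

As exercised by test_tracker above, a `tracker` configuration set on an estimator
survives save()/load() round-trips of both the estimator and the fitted model via
the "tracker_conf" entry in the saved metadata.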