[pyspark] Support collective.Conf in spark
wbo4958 committed Nov 18, 2024
1 parent 1766d48 commit f251454
Showing 4 changed files with 88 additions and 44 deletions.
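This commit replaces the separate tracker_host_ip / tracker_port estimator parameters with a single tracker parameter that accepts an xgboost.collective.Config object. A minimal usage sketch of the new API (the Spark session, the DataFrame df, and the address 10.0.0.5 are illustrative assumptions, not part of this commit):

    from xgboost.collective import Config
    from xgboost.spark import SparkXGBClassifier

    # Launch the tracker on the driver, pinned to an explicit IP; leaving
    # tracker_port unset (None) lets the OS pick a free port.
    clf = SparkXGBClassifier(
        launch_tracker_on_driver=True,
        tracker=Config(tracker_host_ip="10.0.0.5"),
        num_workers=2,
    )
    model = clf.fit(df)  # df: a DataFrame with feature and label columns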
79 changes: 42 additions & 37 deletions python-package/xgboost/spark/core.py
@@ -8,6 +8,7 @@
 import logging
 import os
 from collections import namedtuple
+from dataclasses import asdict
 from typing import (
     Any,
     Callable,
@@ -67,6 +68,7 @@
 from xgboost.training import train as worker_train
 
 from .._typing import ArrayLike
+from ..collective import Config
 from .data import (
     _read_csr_matrix_from_unwrapped_spark_vec,
     alias,
@@ -123,8 +125,7 @@
     "pred_contrib_col",
     "use_gpu",
     "launch_tracker_on_driver",
-    "tracker_host_ip",
-    "tracker_port",
+    "tracker",
 ]
 
 _non_booster_params = ["missing", "n_estimators", "feature_types", "feature_weights"]
@@ -257,21 +258,20 @@ class _SparkXGBParams(
         "launched on the driver side; otherwise, it will be launched on the executor side.",
         TypeConverters.toBoolean,
     )
-    tracker_host_ip = Param(
+    tracker = Param(
         Params._dummy(),
-        "tracker_host_ip",
-        "A string variable. The tracker host IP address. To set tracker host ip, you need to "
-        "enable launch_tracker_on_driver to be true first",
-        TypeConverters.toString,
-    )
-    tracker_port = Param(
-        Params._dummy(),
-        "tracker_port",
-        "A string variable. The port number tracker listens on. To set tracker host port, you need "
+        "tracker",
+        "xgboost.collective.Config. The communicator configuration, you need "
         "to enable launch_tracker_on_driver first",
-        TypeConverters.toInt,
+        TypeConverters.identity,
     )
+
+    def set_tracker(self, value: Config) -> "_SparkXGBParams":
+        """Set communicator configuration"""
+        assert isinstance(value, Config)
+        self.set(self.tracker, value)
+        return self
 
     def set_device(self, value: str) -> "_SparkXGBParams":
         """Set device, optional value: cpu, cuda, gpu"""
         _check_distributed_params({"device": value})
@@ -621,7 +621,6 @@ def _get_unwrapped_vec_cols(feature_col: Column) -> List[Column]:
     ("enable_sparse_data_optim", "has_validation_col", "features_cols_names"),
 )
 
-
 _MODEL_CHUNK_SIZE = 4096 * 1024
 
 
@@ -1030,30 +1029,26 @@ def _get_tracker_args(self) -> Tuple[bool, Dict[str, Any]]:
         launch_tracker_on_driver = self.getOrDefault(self.launch_tracker_on_driver)
         rabit_args = {}
         if launch_tracker_on_driver:
-            tracker_host_ip: Optional[str] = None
-            if self.isDefined(self.tracker_host_ip):
-                tracker_host_ip = self.getOrDefault(self.tracker_host_ip)
-            else:
-                tracker_host_ip = (
+            tracker = Config()
+            if self.isDefined(self.tracker):
+                tracker = self.getOrDefault(self.tracker)
+                assert isinstance(tracker, Config)
+
+            if tracker.tracker_host_ip is None:
+                tracker.tracker_host_ip = (
                     _get_spark_session().sparkContext.getConf().get("spark.driver.host")
                 )
-            assert tracker_host_ip is not None
-            tracker_port = 0
-            if self.isDefined(self.tracker_port):
-                tracker_port = self.getOrDefault(self.tracker_port)
 
             num_workers = self.getOrDefault(self.num_workers)
-            rabit_args.update(
-                _get_rabit_args(tracker_host_ip, num_workers, tracker_port)
-            )
+            rabit_args.update(_get_rabit_args(tracker, num_workers))
         else:
-            if self.isDefined(self.tracker_host_ip) or self.isDefined(
-                self.tracker_port
-            ):
-                raise ValueError(
-                    "You must enable launch_tracker_on_driver to use "
-                    "tracker_host_ip and tracker_port"
-                )
+            if self.isDefined(self.tracker):
+                tracker = self.getOrDefault(self.tracker)
+                assert isinstance(tracker, Config)
+                if tracker.tracker_host_ip is not None:
+                    raise ValueError(
+                        f"You must enable launch_tracker_on_driver to use "
+                        f"tracker host: {tracker.tracker_host_ip}"
+                    )
         return launch_tracker_on_driver, rabit_args
 
     def _fit(self, dataset: DataFrame) -> "_SparkXGBModel":
@@ -1075,6 +1070,9 @@ def _fit(self, dataset: DataFrame) -> "_SparkXGBModel":
         num_workers = self.getOrDefault(self.num_workers)
 
         launch_tracker_on_driver, rabit_args = self._get_tracker_args()
+        tracker: Optional[Config] = (
+            self.getOrDefault(self.tracker) if self.isSet(self.tracker) else None
+        )
 
         log_level = get_logger_level(_LOG_TAG)
 
@@ -1114,11 +1112,12 @@ def _train_booster(
 
             if use_qdm and (booster_params.get("max_bin", None) is not None):
                 dmatrix_kwargs["max_bin"] = booster_params["max_bin"]
-
             _rabit_args = rabit_args
             if context.partitionId() == 0:
                 if not launch_tracker_on_driver:
-                    _rabit_args = _get_rabit_args(_get_host_ip(context), num_workers)
+                    _tracker = tracker if tracker is not None else Config()
+                    _tracker.tracker_host_ip = _get_host_ip(context)
+                    _rabit_args = _get_rabit_args(_tracker, num_workers)
                 get_logger(_LOG_TAG, log_level).info(msg)
 
             worker_message: Dict[str, Any] = {
@@ -1629,7 +1628,7 @@ def saveMetadata(
         xgboost.spark._SparkXGBModel.
         """
         instance._validate_params()
-        skipParams = ["callbacks", "xgb_model"]
+        skipParams = ["callbacks", "xgb_model", "tracker"]
         jsonParams = {}
         for p, v in instance._paramMap.items():  # pylint: disable=protected-access
             if p.name not in skipParams:
@@ -1650,6 +1649,10 @@ def saveMetadata(
         init_booster = instance.getOrDefault("xgb_model")
         if init_booster is not None:
             extraMetadata["init_booster"] = _INIT_BOOSTER_SAVE_PATH
+
+        tracker_conf: Config = instance.getOrDefault("tracker")
+        if tracker_conf is not None:
+            extraMetadata["tracker_conf"] = asdict(tracker_conf)
         DefaultParamsWriter.saveMetadata(
             instance, path, sc, extraMetadata=extraMetadata, paramMap=jsonParams
         )
@@ -1691,6 +1694,8 @@ def loadMetadataAndInstance(
                     f"Fails to load the callbacks param due to {e}. Please set the "
                     "callbacks param manually for the loaded estimator."
                 )
+        if "tracker_conf" in metadata:
+            pyspark_xgb.set_tracker(Config(**metadata["tracker_conf"]))
 
         if "init_booster" in metadata:
             load_path = os.path.join(path, metadata["init_booster"])
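Since Config is a dataclass, saveMetadata above flattens it with dataclasses.asdict and loadMetadataAndInstance rebuilds it with Config(**...). A standalone sketch of that round trip (assuming only the two fields exercised by this commit are set):

    from dataclasses import asdict
    from xgboost.collective import Config

    conf = Config(tracker_host_ip="127.0.0.1", tracker_port=58894)
    meta = asdict(conf)        # JSON-friendly dict stored in extraMetadata
    restored = Config(**meta)  # reconstructed on load
    assert restored == conf    # dataclasses compare field-by-field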
10 changes: 10 additions & 0 deletions python-package/xgboost/spark/estimator.py
@@ -14,6 +14,7 @@
 
 from xgboost import XGBClassifier, XGBRanker, XGBRegressor
 
+from ..collective import Config
 from .core import (  # type: ignore
     _ClassificationModel,
     _SparkXGBEstimator,
@@ -164,6 +165,8 @@ class SparkXGBRegressor(_SparkXGBEstimator):
     launch_tracker_on_driver:
         Boolean value to indicate whether the tracker should be launched on the driver side or
         the executor side.
+    tracker:
+        The communicator configuration. See :py:class:`~xgboost.collective.Config`
     kwargs:
         A dictionary of xgboost parameters, please refer to
Expand Down Expand Up @@ -219,6 +222,7 @@ def __init__( # pylint:disable=too-many-arguments
repartition_random_shuffle: bool = False,
enable_sparse_data_optim: bool = False,
launch_tracker_on_driver: bool = True,
tracker: Optional[Config] = None,
**kwargs: Any,
) -> None:
super().__init__()
@@ -348,6 +352,8 @@ class SparkXGBClassifier(_SparkXGBEstimator, HasProbabilityCol, HasRawPredictionCol):
     launch_tracker_on_driver:
         Boolean value to indicate whether the tracker should be launched on the driver side or
         the executor side.
+    tracker:
+        The communicator configuration. See :py:class:`~xgboost.collective.Config`
     kwargs:
         A dictionary of xgboost parameters, please refer to
Expand Down Expand Up @@ -403,6 +409,7 @@ def __init__( # pylint:disable=too-many-arguments
repartition_random_shuffle: bool = False,
enable_sparse_data_optim: bool = False,
launch_tracker_on_driver: bool = True,
tracker: Optional[Config] = None,
**kwargs: Any,
) -> None:
super().__init__()
@@ -535,6 +542,8 @@ class SparkXGBRanker(_SparkXGBEstimator):
     launch_tracker_on_driver:
         Boolean value to indicate whether the tracker should be launched on the driver side or
         the executor side.
+    tracker:
+        The communicator configuration. See :py:class:`~xgboost.collective.Config`
     kwargs:
         A dictionary of xgboost parameters, please refer to
Expand Down Expand Up @@ -596,6 +605,7 @@ def __init__( # pylint:disable=too-many-arguments
repartition_random_shuffle: bool = False,
enable_sparse_data_optim: bool = False,
launch_tracker_on_driver: bool = True,
tracker: Optional[Config] = None,
**kwargs: Any,
) -> None:
super().__init__()
Expand Down
7 changes: 5 additions & 2 deletions python-package/xgboost/spark/utils.py
@@ -15,6 +15,7 @@
 from pyspark.sql.session import SparkSession
 
 from ..collective import CommunicatorContext as CCtx
+from ..collective import Config
 from ..collective import _Args as CollArgs
 from ..collective import _ArgVals as CollArgsVals
 from ..core import Booster
@@ -66,9 +67,11 @@ def _start_tracker(host: str, n_workers: int, port: int = 0) -> CollArgs:
     return args
 
 
-def _get_rabit_args(host: str, n_workers: int, port: int = 0) -> CollArgs:
+def _get_rabit_args(conf: Config, n_workers: int) -> CollArgs:
     """Get rabit context arguments to send to each worker."""
-    env = _start_tracker(host, n_workers, port)
+    assert conf.tracker_host_ip is not None
+    port = 0 if conf.tracker_port is None else conf.tracker_port
+    env = _start_tracker(conf.tracker_host_ip, n_workers, port)
     return env
 
 
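The new _get_rabit_args derives the tracker address entirely from the Config dataclass: a None tracker_port becomes 0, which makes the tracker bind an OS-assigned port. For reference, a minimal stand-in for the fields this diff relies on (an assumption-labeled sketch, not the library definition; the real xgboost.collective.Config may define additional fields such as retry or timeout):

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class Config:
        # IP the tracker binds to; None -> filled in from spark.driver.host
        # on the driver, or the partition-0 executor host otherwise.
        tracker_host_ip: Optional[str] = None
        # Port the tracker listens on; None -> 0 -> OS-assigned port.
        tracker_port: Optional[int] = None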
36 changes: 31 additions & 5 deletions tests/test_distributed/test_with_spark/test_spark_local.py
@@ -12,6 +12,7 @@
 
 import xgboost as xgb
 from xgboost import testing as tm
+from xgboost.collective import Config
 from xgboost.spark.data import pred_contribs
 
 pytestmark = [tm.timeout(60), pytest.mark.skipif(**tm.no_spark())]
@@ -1650,16 +1651,14 @@ def test_unsupported_params(self):
     def test_tracker(self):
         classifier = SparkXGBClassifier(
             launch_tracker_on_driver=True,
-            tracker_host_ip="192.168.1.32",
-            tracker_port=59981,
+            tracker=Config(tracker_host_ip="192.168.1.32", tracker_port=59981),
         )
         with pytest.raises(Exception, match="Failed to bind socket"):
             classifier._get_tracker_args()
 
         classifier = SparkXGBClassifier(
             launch_tracker_on_driver=False,
-            tracker_host_ip="127.0.0.1",
-            tracker_port=58892,
+            tracker=Config(tracker_host_ip="127.0.0.1", tracker_port=58892),
         )
         with pytest.raises(
             ValueError, match="You must enable launch_tracker_on_driver"
@@ -1668,13 +1667,40 @@ def test_tracker(self):
 
         classifier = SparkXGBClassifier(
             launch_tracker_on_driver=True,
-            tracker_host_ip="127.0.0.1",
+            tracker=Config(tracker_host_ip="127.0.0.1", tracker_port=58893),
             num_workers=2,
         )
         launch_tracker_on_driver, rabit_envs = classifier._get_tracker_args()
         assert launch_tracker_on_driver is True
         assert rabit_envs["n_workers"] == 2
         assert rabit_envs["dmlc_tracker_uri"] == "127.0.0.1"
+        assert rabit_envs["dmlc_tracker_port"] == 58893
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            path = "file:" + tmpdir
+            classifier = SparkXGBClassifier(
+                launch_tracker_on_driver=True,
+                tracker=Config(tracker_host_ip="127.0.0.1", tracker_port=58894),
+                num_workers=1,
+                n_estimators=1,
+            )
+
+            def check_tracker(tracker: Config) -> None:
+                assert tracker.tracker_host_ip == "127.0.0.1"
+                assert tracker.tracker_port == 58894
+
+            check_tracker(classifier.getOrDefault(classifier.tracker))
+            classifier.write().overwrite().save(path)
+
+            loaded_classifier = SparkXGBClassifier.load(path)
+            check_tracker(loaded_classifier.getOrDefault(classifier.tracker))
+
+            model = classifier.fit(self.cls_df_sparse_train)
+            check_tracker(model.getOrDefault(classifier.tracker))
+
+            model.write().overwrite().save(path)
+            loaded_model = SparkXGBClassifierModel.load(path)
+            check_tracker(loaded_model.getOrDefault(classifier.tracker))
 
 
 LTRData = namedtuple("LTRData", ("df_train", "df_test", "df_train_1"))