From 60fe69456c0dc06b4ef9b9974caa9e719d4ee3ee Mon Sep 17 00:00:00 2001
From: WeichenXu
Date: Thu, 7 Nov 2024 19:30:14 +0800
Subject: [PATCH] [bp][spark] Make xgboost spark support large model size
 (#10984)

---------

Signed-off-by: Weichen Xu
---
 python-package/xgboost/spark/core.py            | 34 ++++++++++++-------
 .../test_with_spark/test_spark_local.py         | 20 +++++++++++
 2 files changed, 42 insertions(+), 12 deletions(-)

diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py
index 6700aeed8675..591144cd1c27 100644
--- a/python-package/xgboost/spark/core.py
+++ b/python-package/xgboost/spark/core.py
@@ -597,6 +597,9 @@ def _get_unwrapped_vec_cols(feature_col: Column) -> List[Column]:
     )
 
 
+_MODEL_CHUNK_SIZE = 4096 * 1024
+
+
 class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
     _input_kwargs: Dict[str, Any]
 
@@ -1091,25 +1094,27 @@ def _train_booster(
             context.barrier()
 
             if context.partitionId() == 0:
-                yield pd.DataFrame(
-                    data={
-                        "config": [booster.save_config()],
-                        "booster": [booster.save_raw("json").decode("utf-8")],
-                    }
-                )
+                config = booster.save_config()
+                yield pd.DataFrame({"data": [config]})
+                booster_json = booster.save_raw("json").decode("utf-8")
+
+                for offset in range(0, len(booster_json), _MODEL_CHUNK_SIZE):
+                    booster_chunk = booster_json[offset : offset + _MODEL_CHUNK_SIZE]
+                    yield pd.DataFrame({"data": [booster_chunk]})
 
         def _run_job() -> Tuple[str, str]:
             rdd = (
                 dataset.mapInPandas(
                     _train_booster,  # type: ignore
-                    schema="config string, booster string",
+                    schema="data string",
                 )
                 .rdd.barrier()
                 .mapPartitions(lambda x: x)
             )
             rdd_with_resource = self._try_stage_level_scheduling(rdd)
-            ret = rdd_with_resource.collect()[0]
-            return ret[0], ret[1]
+            ret = rdd_with_resource.collect()
+            data = [v[0] for v in ret]
+            return data[0], "".join(data[1:])
 
         get_logger(_LOG_TAG).info(
             "Running xgboost-%s on %s workers with"
@@ -1690,7 +1695,12 @@ def saveImpl(self, path: str) -> None:
         _SparkXGBSharedReadWrite.saveMetadata(self.instance, path, self.sc, self.logger)
         model_save_path = os.path.join(path, "model")
         booster = xgb_model.get_booster().save_raw("json").decode("utf-8")
-        _get_spark_session().sparkContext.parallelize([booster], 1).saveAsTextFile(
+        booster_chunks = []
+
+        for offset in range(0, len(booster), _MODEL_CHUNK_SIZE):
+            booster_chunks.append(booster[offset : offset + _MODEL_CHUNK_SIZE])
+
+        _get_spark_session().sparkContext.parallelize(booster_chunks, 1).saveAsTextFile(
             model_save_path
         )
 
@@ -1721,8 +1731,8 @@ def load(self, path: str) -> "_SparkXGBModel":
         )
         model_load_path = os.path.join(path, "model")
 
-        ser_xgb_model = (
-            _get_spark_session().sparkContext.textFile(model_load_path).collect()[0]
+        ser_xgb_model = "".join(
+            _get_spark_session().sparkContext.textFile(model_load_path).collect()
         )
 
         def create_xgb_model() -> "XGBModel":
diff --git a/tests/test_distributed/test_with_spark/test_spark_local.py b/tests/test_distributed/test_with_spark/test_spark_local.py
index feb7b18bc035..f80ae0c670da 100644
--- a/tests/test_distributed/test_with_spark/test_spark_local.py
+++ b/tests/test_distributed/test_with_spark/test_spark_local.py
@@ -867,6 +867,26 @@ def test_regressor_model_pipeline_save_load(self, reg_data: RegData) -> None:
         )
         assert_model_compatible(model.stages[0], tmpdir)
 
+    def test_with_small_model_chunk_size(self, reg_data: RegData, monkeypatch) -> None:
+        import xgboost.spark.core
+
+        monkeypatch.setattr(xgboost.spark.core, "_MODEL_CHUNK_SIZE", 4)
+        with tempfile.TemporaryDirectory() as tmpdir:
+            path = "file:" + tmpdir
+            regressor = SparkXGBRegressor(**reg_data.reg_params)
+            model = regressor.fit(reg_data.reg_df_train)
+            model.save(path)
+            loaded_model = SparkXGBRegressorModel.load(path)
+            assert model.uid == loaded_model.uid
+            for k, v in reg_data.reg_params.items():
+                assert loaded_model.getOrDefault(k) == v
+
+            pred_result = loaded_model.transform(reg_data.reg_df_test).collect()
+            for row in pred_result:
+                assert np.isclose(
+                    row.prediction, row.expected_prediction_with_params, atol=1e-3
+                )
+
     def test_device_param(self, reg_data: RegData, clf_data: ClfData) -> None:
         clf = SparkXGBClassifier(device="cuda", tree_method="exact")
        with pytest.raises(ValueError, match="not supported for distributed"):
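
Note for readers: the idea behind this patch is to split the JSON-serialized booster into fixed-size string chunks of _MODEL_CHUNK_SIZE characters each, ship each chunk as its own row (via mapInPandas) or text-file record (via saveAsTextFile), and reassemble the model by plain string concatenation on the reading side. The sketch below is not part of the patch; it is a minimal standalone illustration of that chunk/join round trip, with an illustrative helper name and a deliberately tiny chunk size (mirroring the unit test's monkeypatched value of 4).

    # Minimal sketch of the chunking scheme; the real code lives in
    # python-package/xgboost/spark/core.py. Names here are illustrative.
    from typing import List

    CHUNK_SIZE = 4  # the patch uses _MODEL_CHUNK_SIZE = 4096 * 1024


    def split_model_json(model_json: str, chunk_size: int = CHUNK_SIZE) -> List[str]:
        """Split a serialized model string into fixed-size chunks."""
        return [
            model_json[offset : offset + chunk_size]
            for offset in range(0, len(model_json), chunk_size)
        ]


    if __name__ == "__main__":
        booster_json = '{"learner": {"gradient_booster": {"name": "gbtree"}}}'
        chunks = split_model_json(booster_json)
        # Reassembly is plain concatenation, as "".join(...) does in _run_job
        # and in the model reader changed by this patch.
        assert "".join(chunks) == booster_json
        print(f"round-tripped {len(chunks)} chunks losslessly")

One design detail worth noting: the save path writes one chunk per record with saveAsTextFile and the load path re-joins whatever textFile returns, so the round trip assumes the serialized JSON contains no newline characters (the booster JSON dump is not pretty-printed).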