Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
wbo4958 committed Sep 4, 2023
1 parent d11e55a commit 1e97ae8
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions python-package/xgboost/spark/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1256,17 +1256,19 @@ def predict_udf(iterator: Iterator[pd.DataFrame]) -> Iterator[pd.Series]:
dev_ordinal = -1
if is_local:
if run_on_gpu and is_cupy_available():
total_gpus = cupy.cuda.runtime.getDeviceCount()
import cupy as cp # pylint: disable=import-error

total_gpus = cp.cuda.runtime.getDeviceCount()
if total_gpus > 0:
partition_id = context.partitionId()
from pyspark import TaskContext
partition_id = TaskContext.get().partitionId()
# For transform local mode, default the gpu_id to (partition id) % gpus.
dev_ordinal = partition_id % total_gpus
elif run_on_gpu:
from pyspark import TaskContext
dev_ordinal = _get_gpu_id(TaskContext.get())

if dev_ordinal >= 0:
print("------------------- ", TaskContext.get().partitionId())
device = "cuda:" + str(dev_ordinal)
get_logger("XGBoost-PySpark").info(
"Do the inference with device: %s", device
Expand All @@ -1278,7 +1280,7 @@ def predict_udf(iterator: Iterator[pd.DataFrame]) -> Iterator[pd.Series]:
)

def to_gpu(data: ArrayLike) -> ArrayLike:
"""Move the data to gpu"""
"""Move the data to gpu if possible"""
if dev_ordinal >= 0:
import cudf # pylint: disable=import-error
import cupy as cp # pylint: disable=import-error
Expand Down

0 comments on commit 1e97ae8

Please sign in to comment.