From 9221c9694b8818e04fedb66c57b0285b0c83266b Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Mon, 5 Aug 2024 12:37:22 -0700 Subject: [PATCH 01/13] support system allocated memory Signed-off-by: Rong Ou --- docs/site/configuration.md | 8 ++-- python/run_benchmark.sh | 5 ++- python/src/spark_rapids_ml/clustering.py | 9 ++++- python/src/spark_rapids_ml/core.py | 47 +++++++++++++++++++++++- python/src/spark_rapids_ml/umap.py | 24 +++++++++++- 5 files changed, 86 insertions(+), 7 deletions(-) diff --git a/docs/site/configuration.md b/docs/site/configuration.md index 7926d609..5abeb841 100644 --- a/docs/site/configuration.md +++ b/docs/site/configuration.md @@ -6,7 +6,9 @@ nav_order: 6 The following configurations can be supplied as Spark properties. -| Property name | Default | Meaning | -| :-------------- | :------ | :------- | -| spark.rapids.ml.uvm.enabled | false | if set to true, enables [CUDA unified virtual memory](https://developer.nvidia.com/blog/unified-memory-cuda-beginners/) (aka managed memory) during estimator.fit() operations to allow processing of larger datasets than would fit in GPU memory| +| Property name | Default | Meaning | +|:-------------------------------|:--------|:-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| spark.rapids.ml.uvm.enabled | false | if set to true, enables [CUDA unified virtual memory](https://developer.nvidia.com/blog/unified-memory-cuda-beginners/) (aka managed memory) during estimator.fit() operations to allow processing of larger datasets than would fit in GPU memory | +| spark.rapids.ml.sam.enabled | false | if set to true, enables System Allocated Memory (SAM) on [HMM](https://developer.nvidia.com/blog/simplifying-gpu-application-development-with-heterogeneous-memory-management/) or [ATS](https://developer.nvidia.com/blog/nvidia-grace-hopper-superchip-architecture-in-depth/) systems during estimator.fit() operations to allow processing of larger datasets than would fit in GPU memory | +| spark.rapids.ml.sam.headroom | None | when using System Allocated Memory (SAM) and GPU memory is oversubscribed, we may need to reserve some GPU memory as headroom to allow other CUDA calls to function without running out memory | diff --git a/python/run_benchmark.sh b/python/run_benchmark.sh index 518728cf..9f1db683 100755 --- a/python/run_benchmark.sh +++ b/python/run_benchmark.sh @@ -98,7 +98,10 @@ cat < pd.DataFrame: devices=_CumlCommon._get_gpu_device(context, is_local), ) cp.cuda.set_allocator(rmm_cupy_allocator) + if cuda_system_mem_enabled: + import rmm + from rmm.allocators.cupy import rmm_cupy_allocator + + if cuda_system_mem_headroom is None: + mr = rmm.mr.SystemMemoryResource() + else: + mr = rmm.mr.SamHeadroomMemoryResource(headroom=cuda_system_mem_headroom) + rmm.mr.set_current_device_resource(mr) + cp.cuda.set_allocator(rmm_cupy_allocator) _CumlCommon._initialize_cuml_logging(cuml_verbose) @@ -738,7 +760,7 @@ def _train_udf(pdf_iter: Iterator[pd.DataFrame]) -> pd.DataFrame: # experiments indicate it is faster to convert to numpy array and then to cupy array than directly # invoking cupy array on the list - if cuda_managed_mem_enabled: + if cuda_managed_mem_enabled or cuda_system_mem_enabled: features = ( cp.array(features) if use_sparse_array is False @@ -1356,6 +1378,18 @@ def _transform_evaluate_internal( if cuda_managed_mem_enabled: get_logger(self.__class__).info("CUDA managed memory enabled.") + cuda_system_mem_enabled = ( + _get_spark_session().conf.get("spark.rapids.ml.sam.enabled", "false") + == "true" + ) + if cuda_managed_mem_enabled and cuda_system_mem_enabled: + raise ValueError("Both CUDA managed memory and system allocated memory cannot be enabled at the same time.") + if cuda_system_mem_enabled: + get_logger(self.__class__).info("CUDA system allocated memory enabled.") + cuda_system_mem_headroom = _get_spark_session().conf.get("spark.rapids.ml.sam.headroom", None) + if cuda_system_mem_headroom is not None: + get_logger(self.__class__).info(f"CUDA system allocated memory headroom set to {cuda_system_mem_headroom}.") + def _transform_udf(pdf_iter: Iterator[pd.DataFrame]) -> pd.DataFrame: from pyspark import TaskContext @@ -1375,6 +1409,17 @@ def _transform_udf(pdf_iter: Iterator[pd.DataFrame]) -> pd.DataFrame: ), ) cp.cuda.set_allocator(rmm_cupy_allocator) + if cuda_system_mem_enabled: + import cupy as cp + import rmm + from rmm.allocators.cupy import rmm_cupy_allocator + + if cuda_system_mem_headroom is None: + mr = rmm.mr.SystemMemoryResource() + else: + mr = rmm.mr.SamHeadroomMemoryResource(headroom=cuda_system_mem_headroom) + rmm.mr.set_current_device_resource(mr) + cp.cuda.set_allocator(rmm_cupy_allocator) # Construct the cuml counterpart object cuml_instance = construct_cuml_object_func() diff --git a/python/src/spark_rapids_ml/umap.py b/python/src/spark_rapids_ml/umap.py index ed9fabf1..68b27114 100644 --- a/python/src/spark_rapids_ml/umap.py +++ b/python/src/spark_rapids_ml/umap.py @@ -992,6 +992,18 @@ def _call_cuml_fit_func_dataframe( if cuda_managed_mem_enabled: get_logger(cls).info("CUDA managed memory enabled.") + cuda_system_mem_enabled = ( + _get_spark_session().conf.get("spark.rapids.ml.sam.enabled", "false") + == "true" + ) + if cuda_managed_mem_enabled and cuda_system_mem_enabled: + raise ValueError("Both CUDA managed memory and system allocated memory cannot be enabled at the same time.") + if cuda_system_mem_enabled: + get_logger(cls).info("CUDA system allocated memory enabled.") + cuda_system_mem_headroom = _get_spark_session().conf.get("spark.rapids.ml.sam.headroom", None) + if cuda_system_mem_headroom is not None: + get_logger(cls).info(f"CUDA system allocated memory headroom set to {cuda_system_mem_headroom}.") + # parameters passed to subclass params: Dict[str, Any] = { param_alias.cuml_init: self.cuml_params, @@ -1021,6 +1033,16 @@ def _train_udf(pdf_iter: Iterable[pd.DataFrame]) -> Iterable[pd.DataFrame]: rmm.reinitialize(managed_memory=True) cp.cuda.set_allocator(rmm_cupy_allocator) + if cuda_system_mem_enabled: + import rmm + from rmm.allocators.cupy import rmm_cupy_allocator + + if cuda_system_mem_headroom is None: + mr = rmm.mr.SystemMemoryResource() + else: + mr = rmm.mr.SamHeadroomMemoryResource(headroom=cuda_system_mem_headroom) + rmm.mr.set_current_device_resource(mr) + cp.cuda.set_allocator(rmm_cupy_allocator) _CumlCommon._initialize_cuml_logging(cuml_verbose) @@ -1042,7 +1064,7 @@ def _train_udf(pdf_iter: Iterable[pd.DataFrame]) -> Iterable[pd.DataFrame]: features = np.array(list(pdf[alias.data]), order=array_order) # experiments indicate it is faster to convert to numpy array and then to cupy array than directly # invoking cupy array on the list - if cuda_managed_mem_enabled: + if cuda_managed_mem_enabled or cuda_system_mem_enabled: features = cp.array(features) label = pdf[alias.label] if alias.label in pdf.columns else None From 590fb38968187d0a1404a42ed9bdeb4fc8ee4966 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Mon, 5 Aug 2024 12:52:42 -0700 Subject: [PATCH 02/13] add timing info Signed-off-by: Rong Ou --- python/src/spark_rapids_ml/core.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/python/src/spark_rapids_ml/core.py b/python/src/spark_rapids_ml/core.py index 3a7f473a..346984de 100644 --- a/python/src/spark_rapids_ml/core.py +++ b/python/src/spark_rapids_ml/core.py @@ -16,6 +16,7 @@ import json import os import threading +import time from abc import ABCMeta, abstractmethod from collections import namedtuple from typing import ( @@ -746,6 +747,7 @@ def _train_udf(pdf_iter: Iterator[pd.DataFrame]) -> pd.DataFrame: logger.info("Loading data into python worker memory") inputs = [] sizes = [] + start_time = time.time() for pdf in pdf_iter: sizes.append(pdf.shape[0]) @@ -778,6 +780,8 @@ def _train_udf(pdf_iter: Iterator[pd.DataFrame]) -> pd.DataFrame: "A python worker received no data. Please increase amount of data or use fewer workers." ) + logger.info(f"Data loaded into python worker memory in {time.time() - start_time} seconds") + logger.info("Initializing cuml context") with CumlContext( partition_id, num_workers, context, enable_nccl, require_ucx @@ -788,12 +792,14 @@ def _train_udf(pdf_iter: Iterator[pd.DataFrame]) -> pd.DataFrame: params[param_alias.loop] = cc._loop logger.info("Invoking cuml fit") + start_time = time.time() # call the cuml fit function # *note*: cuml_fit_func may delete components of inputs to free # memory. do not rely on inputs after this call. result = cuml_fit_func(inputs, params) logger.info("Cuml fit complete") + logger.info(f"Cuml fit took {time.time() - start_time} seconds") if partially_collect == True: if enable_nccl: From 51d9248cc75c6d39681c2cd2c0841d0ed33e6583 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Mon, 5 Aug 2024 14:36:50 -0700 Subject: [PATCH 03/13] parse headroom size into integer Signed-off-by: Rong Ou --- python/src/spark_rapids_ml/clustering.py | 3 +++ python/src/spark_rapids_ml/core.py | 3 +++ python/src/spark_rapids_ml/umap.py | 2 ++ 3 files changed, 8 insertions(+) diff --git a/python/src/spark_rapids_ml/clustering.py b/python/src/spark_rapids_ml/clustering.py index 0c6fa966..b446cf73 100644 --- a/python/src/spark_rapids_ml/clustering.py +++ b/python/src/spark_rapids_ml/clustering.py @@ -39,6 +39,7 @@ StructField, StructType, ) +from pyspark.util import _parse_memory from .core import ( CumlT, @@ -939,6 +940,8 @@ def _get_cuml_fit_func( if cuda_managed_mem_enabled and cuda_system_mem_enabled: raise ValueError("Both CUDA managed memory and system allocated memory cannot be enabled at the same time.") cuda_system_mem_headroom = _get_spark_session().conf.get("spark.rapids.ml.sam.headroom", None) + if cuda_system_mem_headroom is not None: + cuda_system_mem_headroom = _parse_memory(cuda_system_mem_headroom) idCol_bc = self.idCols_ raw_data_bc = self.raw_data_ diff --git a/python/src/spark_rapids_ml/core.py b/python/src/spark_rapids_ml/core.py index 346984de..47eb9cde 100644 --- a/python/src/spark_rapids_ml/core.py +++ b/python/src/spark_rapids_ml/core.py @@ -69,6 +69,7 @@ Row, StructType, ) +from pyspark.util import _parse_memory from scipy.sparse import csr_matrix from .common.cuml_context import CumlContext @@ -675,6 +676,7 @@ def _call_cuml_fit_func( get_logger(cls).info("CUDA system allocated memory enabled.") cuda_system_mem_headroom = _get_spark_session().conf.get("spark.rapids.ml.sam.headroom", None) if cuda_system_mem_headroom is not None: + cuda_system_mem_headroom = _parse_memory(cuda_system_mem_headroom) get_logger(cls).info(f"CUDA system allocated memory headroom set to {cuda_system_mem_headroom}.") # parameters passed to subclass @@ -1394,6 +1396,7 @@ def _transform_evaluate_internal( get_logger(self.__class__).info("CUDA system allocated memory enabled.") cuda_system_mem_headroom = _get_spark_session().conf.get("spark.rapids.ml.sam.headroom", None) if cuda_system_mem_headroom is not None: + cuda_system_mem_headroom = _parse_memory(cuda_system_mem_headroom) get_logger(self.__class__).info(f"CUDA system allocated memory headroom set to {cuda_system_mem_headroom}.") def _transform_udf(pdf_iter: Iterator[pd.DataFrame]) -> pd.DataFrame: diff --git a/python/src/spark_rapids_ml/umap.py b/python/src/spark_rapids_ml/umap.py index 68b27114..b9819d9f 100644 --- a/python/src/spark_rapids_ml/umap.py +++ b/python/src/spark_rapids_ml/umap.py @@ -54,6 +54,7 @@ StructField, StructType, ) +from pyspark.util import _parse_memory from spark_rapids_ml.core import FitInputType, _CumlModel @@ -1002,6 +1003,7 @@ def _call_cuml_fit_func_dataframe( get_logger(cls).info("CUDA system allocated memory enabled.") cuda_system_mem_headroom = _get_spark_session().conf.get("spark.rapids.ml.sam.headroom", None) if cuda_system_mem_headroom is not None: + cuda_system_mem_headroom = _parse_memory(cuda_system_mem_headroom) get_logger(cls).info(f"CUDA system allocated memory headroom set to {cuda_system_mem_headroom}.") # parameters passed to subclass From ca5ed0862523334b73836ce2a49b5615db7bc96d Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Mon, 5 Aug 2024 14:45:45 -0700 Subject: [PATCH 04/13] document cupy sam env var Signed-off-by: Rong Ou --- docs/site/configuration.md | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/docs/site/configuration.md b/docs/site/configuration.md index 5abeb841..c3ddbd30 100644 --- a/docs/site/configuration.md +++ b/docs/site/configuration.md @@ -6,9 +6,10 @@ nav_order: 6 The following configurations can be supplied as Spark properties. -| Property name | Default | Meaning | -|:-------------------------------|:--------|:-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| spark.rapids.ml.uvm.enabled | false | if set to true, enables [CUDA unified virtual memory](https://developer.nvidia.com/blog/unified-memory-cuda-beginners/) (aka managed memory) during estimator.fit() operations to allow processing of larger datasets than would fit in GPU memory | -| spark.rapids.ml.sam.enabled | false | if set to true, enables System Allocated Memory (SAM) on [HMM](https://developer.nvidia.com/blog/simplifying-gpu-application-development-with-heterogeneous-memory-management/) or [ATS](https://developer.nvidia.com/blog/nvidia-grace-hopper-superchip-architecture-in-depth/) systems during estimator.fit() operations to allow processing of larger datasets than would fit in GPU memory | -| spark.rapids.ml.sam.headroom | None | when using System Allocated Memory (SAM) and GPU memory is oversubscribed, we may need to reserve some GPU memory as headroom to allow other CUDA calls to function without running out memory | +| Property name | Default | Meaning | +|:----------------------------------|:--------|:-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| spark.rapids.ml.uvm.enabled | false | if set to true, enables [CUDA unified virtual memory](https://developer.nvidia.com/blog/unified-memory-cuda-beginners/) (aka managed memory) during estimator.fit() operations to allow processing of larger datasets than would fit in GPU memory | +| spark.rapids.ml.sam.enabled | false | if set to true, enables System Allocated Memory (SAM) on [HMM](https://developer.nvidia.com/blog/simplifying-gpu-application-development-with-heterogeneous-memory-management/) or [ATS](https://developer.nvidia.com/blog/nvidia-grace-hopper-superchip-architecture-in-depth/) systems during estimator.fit() operations to allow processing of larger datasets than would fit in GPU memory | +| spark.rapids.ml.sam.headroom | None | when using System Allocated Memory (SAM) and GPU memory is oversubscribed, we may need to reserve some GPU memory as headroom to allow other CUDA calls to function without running out memory. Set a size appropriate for your application | +| spark.executorEnv.CUPY_ENABLE_SAM | 0 | if set to 1, enables System Allocated Memory (SAM) for CuPy operations. This enabled CuPy to work with SAM, and also avoid unnecessary memory coping | From d2641babad1a24a7209302e8ed8c41b0e9f9d816 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Mon, 5 Aug 2024 14:51:19 -0700 Subject: [PATCH 05/13] convert headroom from MiB to bytes Signed-off-by: Rong Ou --- python/src/spark_rapids_ml/clustering.py | 2 +- python/src/spark_rapids_ml/core.py | 4 ++-- python/src/spark_rapids_ml/umap.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/python/src/spark_rapids_ml/clustering.py b/python/src/spark_rapids_ml/clustering.py index b446cf73..c9f12c8c 100644 --- a/python/src/spark_rapids_ml/clustering.py +++ b/python/src/spark_rapids_ml/clustering.py @@ -941,7 +941,7 @@ def _get_cuml_fit_func( raise ValueError("Both CUDA managed memory and system allocated memory cannot be enabled at the same time.") cuda_system_mem_headroom = _get_spark_session().conf.get("spark.rapids.ml.sam.headroom", None) if cuda_system_mem_headroom is not None: - cuda_system_mem_headroom = _parse_memory(cuda_system_mem_headroom) + cuda_system_mem_headroom = _parse_memory(cuda_system_mem_headroom) << 20 idCol_bc = self.idCols_ raw_data_bc = self.raw_data_ diff --git a/python/src/spark_rapids_ml/core.py b/python/src/spark_rapids_ml/core.py index 47eb9cde..847831f3 100644 --- a/python/src/spark_rapids_ml/core.py +++ b/python/src/spark_rapids_ml/core.py @@ -676,7 +676,7 @@ def _call_cuml_fit_func( get_logger(cls).info("CUDA system allocated memory enabled.") cuda_system_mem_headroom = _get_spark_session().conf.get("spark.rapids.ml.sam.headroom", None) if cuda_system_mem_headroom is not None: - cuda_system_mem_headroom = _parse_memory(cuda_system_mem_headroom) + cuda_system_mem_headroom = _parse_memory(cuda_system_mem_headroom) << 20 get_logger(cls).info(f"CUDA system allocated memory headroom set to {cuda_system_mem_headroom}.") # parameters passed to subclass @@ -1396,7 +1396,7 @@ def _transform_evaluate_internal( get_logger(self.__class__).info("CUDA system allocated memory enabled.") cuda_system_mem_headroom = _get_spark_session().conf.get("spark.rapids.ml.sam.headroom", None) if cuda_system_mem_headroom is not None: - cuda_system_mem_headroom = _parse_memory(cuda_system_mem_headroom) + cuda_system_mem_headroom = _parse_memory(cuda_system_mem_headroom) << 20 get_logger(self.__class__).info(f"CUDA system allocated memory headroom set to {cuda_system_mem_headroom}.") def _transform_udf(pdf_iter: Iterator[pd.DataFrame]) -> pd.DataFrame: diff --git a/python/src/spark_rapids_ml/umap.py b/python/src/spark_rapids_ml/umap.py index b9819d9f..7ff57adc 100644 --- a/python/src/spark_rapids_ml/umap.py +++ b/python/src/spark_rapids_ml/umap.py @@ -1003,7 +1003,7 @@ def _call_cuml_fit_func_dataframe( get_logger(cls).info("CUDA system allocated memory enabled.") cuda_system_mem_headroom = _get_spark_session().conf.get("spark.rapids.ml.sam.headroom", None) if cuda_system_mem_headroom is not None: - cuda_system_mem_headroom = _parse_memory(cuda_system_mem_headroom) + cuda_system_mem_headroom = _parse_memory(cuda_system_mem_headroom) << 20 get_logger(cls).info(f"CUDA system allocated memory headroom set to {cuda_system_mem_headroom}.") # parameters passed to subclass From da5b3bbffe9f0d352f7c12fea43667d14d43e7af Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Tue, 6 Aug 2024 10:51:48 -0700 Subject: [PATCH 06/13] address review feedback Signed-off-by: Rong Ou --- docs/site/configuration.md | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/docs/site/configuration.md b/docs/site/configuration.md index c3ddbd30..dafd3cfe 100644 --- a/docs/site/configuration.md +++ b/docs/site/configuration.md @@ -6,10 +6,11 @@ nav_order: 6 The following configurations can be supplied as Spark properties. -| Property name | Default | Meaning | -|:----------------------------------|:--------|:-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| spark.rapids.ml.uvm.enabled | false | if set to true, enables [CUDA unified virtual memory](https://developer.nvidia.com/blog/unified-memory-cuda-beginners/) (aka managed memory) during estimator.fit() operations to allow processing of larger datasets than would fit in GPU memory | -| spark.rapids.ml.sam.enabled | false | if set to true, enables System Allocated Memory (SAM) on [HMM](https://developer.nvidia.com/blog/simplifying-gpu-application-development-with-heterogeneous-memory-management/) or [ATS](https://developer.nvidia.com/blog/nvidia-grace-hopper-superchip-architecture-in-depth/) systems during estimator.fit() operations to allow processing of larger datasets than would fit in GPU memory | -| spark.rapids.ml.sam.headroom | None | when using System Allocated Memory (SAM) and GPU memory is oversubscribed, we may need to reserve some GPU memory as headroom to allow other CUDA calls to function without running out memory. Set a size appropriate for your application | -| spark.executorEnv.CUPY_ENABLE_SAM | 0 | if set to 1, enables System Allocated Memory (SAM) for CuPy operations. This enabled CuPy to work with SAM, and also avoid unnecessary memory coping | +| Property name | Default | Meaning | +|:----------------------------------|:--------|:--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| spark.rapids.ml.uvm.enabled | false | if set to true, enables [CUDA unified virtual memory](https://developer.nvidia.com/blog/unified-memory-cuda-beginners/) (aka managed memory) during estimator.fit() operations to allow processing of larger datasets than would fit in GPU memory | +| spark.rapids.ml.sam.enabled | false | if set to true, enables System Allocated Memory (SAM) on [HMM](https://developer.nvidia.com/blog/simplifying-gpu-application-development-with-heterogeneous-memory-management/) or [ATS](https://developer.nvidia.com/blog/nvidia-grace-hopper-superchip-architecture-in-depth/) systems during estimator.fit() operations to allow processing of larger datasets than would fit in GPU memory | +| spark.rapids.ml.sam.headroom | None | when using System Allocated Memory (SAM) and GPU memory is oversubscribed, we may need to reserve some GPU memory as headroom to allow other CUDA calls to function without running out memory. Set a size appropriate for your application | +| spark.executorEnv.CUPY_ENABLE_SAM | 0 | if set to 1, enables System Allocated Memory (SAM) for CuPy operations. This enables CuPy to work with SAM, and also avoid unnecessary memory copying | +| spark.driverEnv.CUPY_ENABLE_SAM | 0 | if set to 1, enables System Allocated Memory (SAM) for CuPy operations. This enables CuPy to work with SAM, and also avoid unnecessary memory copying. This is needed when running Spark in local mode. | From d423f2a545c811edf5d09973043886c1775bd5b3 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Tue, 6 Aug 2024 16:29:03 -0700 Subject: [PATCH 07/13] use system memory in rmm cupy allocator Signed-off-by: Rong Ou --- python/run_benchmark.sh | 2 +- python/src/spark_rapids_ml/common/rmm_cupy.py | 39 +++++++++++++++++++ python/src/spark_rapids_ml/core.py | 9 +++-- python/src/spark_rapids_ml/umap.py | 4 +- 4 files changed, 47 insertions(+), 7 deletions(-) create mode 100644 python/src/spark_rapids_ml/common/rmm_cupy.py diff --git a/python/run_benchmark.sh b/python/run_benchmark.sh index 9f1db683..81da2924 100755 --- a/python/run_benchmark.sh +++ b/python/run_benchmark.sh @@ -100,7 +100,7 @@ cat <>> from rmm.allocators.cupy import rmm_cupy_allocator + >>> import cupy + >>> cupy.cuda.set_allocator(rmm_cupy_system_allocator) + """ + stream = Stream(obj=cupy.cuda.get_current_stream()) + buf = librmm.device_buffer.DeviceBuffer(size=nbytes, stream=stream) + mem = cupy.cuda.SystemMemory.from_external( + ptr=buf.ptr, size=buf.size, owner=buf + ) + ptr = cupy.cuda.memory.MemoryPointer(mem, 0) + + return ptr diff --git a/python/src/spark_rapids_ml/core.py b/python/src/spark_rapids_ml/core.py index 847831f3..501a2120 100644 --- a/python/src/spark_rapids_ml/core.py +++ b/python/src/spark_rapids_ml/core.py @@ -732,15 +732,16 @@ def _train_udf(pdf_iter: Iterator[pd.DataFrame]) -> pd.DataFrame: ) cp.cuda.set_allocator(rmm_cupy_allocator) if cuda_system_mem_enabled: + from .common.rmm_cupy import rmm_cupy_system_allocator import rmm - from rmm.allocators.cupy import rmm_cupy_allocator if cuda_system_mem_headroom is None: mr = rmm.mr.SystemMemoryResource() else: mr = rmm.mr.SamHeadroomMemoryResource(headroom=cuda_system_mem_headroom) rmm.mr.set_current_device_resource(mr) - cp.cuda.set_allocator(rmm_cupy_allocator) + + cp.cuda.set_allocator(rmm_cupy_system_allocator) _CumlCommon._initialize_cuml_logging(cuml_verbose) @@ -1419,16 +1420,16 @@ def _transform_udf(pdf_iter: Iterator[pd.DataFrame]) -> pd.DataFrame: ) cp.cuda.set_allocator(rmm_cupy_allocator) if cuda_system_mem_enabled: + from .common.rmm_cupy import rmm_cupy_system_allocator import cupy as cp import rmm - from rmm.allocators.cupy import rmm_cupy_allocator if cuda_system_mem_headroom is None: mr = rmm.mr.SystemMemoryResource() else: mr = rmm.mr.SamHeadroomMemoryResource(headroom=cuda_system_mem_headroom) rmm.mr.set_current_device_resource(mr) - cp.cuda.set_allocator(rmm_cupy_allocator) + cp.cuda.set_allocator(rmm_cupy_system_allocator) # Construct the cuml counterpart object cuml_instance = construct_cuml_object_func() diff --git a/python/src/spark_rapids_ml/umap.py b/python/src/spark_rapids_ml/umap.py index 7ff57adc..4b449c6b 100644 --- a/python/src/spark_rapids_ml/umap.py +++ b/python/src/spark_rapids_ml/umap.py @@ -1037,14 +1037,14 @@ def _train_udf(pdf_iter: Iterable[pd.DataFrame]) -> Iterable[pd.DataFrame]: cp.cuda.set_allocator(rmm_cupy_allocator) if cuda_system_mem_enabled: import rmm - from rmm.allocators.cupy import rmm_cupy_allocator + from .common.rmm_cupy import rmm_cupy_system_allocator if cuda_system_mem_headroom is None: mr = rmm.mr.SystemMemoryResource() else: mr = rmm.mr.SamHeadroomMemoryResource(headroom=cuda_system_mem_headroom) rmm.mr.set_current_device_resource(mr) - cp.cuda.set_allocator(rmm_cupy_allocator) + cp.cuda.set_allocator(rmm_cupy_system_allocator) _CumlCommon._initialize_cuml_logging(cuml_verbose) From 31faca45d0da016faa80f00d0e705329737df4c1 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Wed, 7 Aug 2024 14:36:32 -0700 Subject: [PATCH 08/13] switch back to rmm_cupy_allocator Signed-off-by: Rong Ou --- python/src/spark_rapids_ml/common/rmm_cupy.py | 39 ------------------- python/src/spark_rapids_ml/core.py | 23 +++-------- python/src/spark_rapids_ml/umap.py | 8 +--- 3 files changed, 8 insertions(+), 62 deletions(-) delete mode 100644 python/src/spark_rapids_ml/common/rmm_cupy.py diff --git a/python/src/spark_rapids_ml/common/rmm_cupy.py b/python/src/spark_rapids_ml/common/rmm_cupy.py deleted file mode 100644 index cd9029ec..00000000 --- a/python/src/spark_rapids_ml/common/rmm_cupy.py +++ /dev/null @@ -1,39 +0,0 @@ -# -# Copyright (c) 2024, NVIDIA CORPORATION. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -import cupy -from rmm import _lib as librmm -from rmm._cuda.stream import Stream - - -# TODO(rongou): move this into RMM. -def rmm_cupy_system_allocator(nbytes): - """ - A CuPy allocator that makes use of RMM system memory resource. - - Examples - -------- - >>> from rmm.allocators.cupy import rmm_cupy_allocator - >>> import cupy - >>> cupy.cuda.set_allocator(rmm_cupy_system_allocator) - """ - stream = Stream(obj=cupy.cuda.get_current_stream()) - buf = librmm.device_buffer.DeviceBuffer(size=nbytes, stream=stream) - mem = cupy.cuda.SystemMemory.from_external( - ptr=buf.ptr, size=buf.size, owner=buf - ) - ptr = cupy.cuda.memory.MemoryPointer(mem, 0) - - return ptr diff --git a/python/src/spark_rapids_ml/core.py b/python/src/spark_rapids_ml/core.py index c1c25e6c..7b06837b 100644 --- a/python/src/spark_rapids_ml/core.py +++ b/python/src/spark_rapids_ml/core.py @@ -724,24 +724,19 @@ def _train_udf(pdf_iter: Iterator[pd.DataFrame]) -> pd.DataFrame: if cuda_managed_mem_enabled: import rmm - from rmm.allocators.cupy import rmm_cupy_allocator - rmm.reinitialize( managed_memory=True, devices=_CumlCommon._get_gpu_device(context, is_local), ) - cp.cuda.set_allocator(rmm_cupy_allocator) + # cupy allocator is set to rmm in cudf if cuda_system_mem_enabled: - from .common.rmm_cupy import rmm_cupy_system_allocator import rmm - if cuda_system_mem_headroom is None: mr = rmm.mr.SystemMemoryResource() else: mr = rmm.mr.SamHeadroomMemoryResource(headroom=cuda_system_mem_headroom) rmm.mr.set_current_device_resource(mr) - - cp.cuda.set_allocator(rmm_cupy_system_allocator) + # cupy allocator is set to rmm in cudf _CumlCommon._initialize_cuml_logging(cuml_verbose) @@ -791,7 +786,7 @@ def _train_udf(pdf_iter: Iterator[pd.DataFrame]) -> pd.DataFrame: "A python worker received no data. Please increase amount of data or use fewer workers." ) - logger.info(f"Data loaded into python worker memory in {time.time() - start_time} seconds") + logger.info(f"Data loaded into python worker memory in {time.time() - start_time:.3f} seconds") logger.info("Initializing cuml context") with CumlContext( @@ -810,7 +805,7 @@ def _train_udf(pdf_iter: Iterator[pd.DataFrame]) -> pd.DataFrame: # memory. do not rely on inputs after this call. result = cuml_fit_func(inputs, params) logger.info("Cuml fit complete") - logger.info(f"Cuml fit took {time.time() - start_time} seconds") + logger.info(f"Cuml fit took {time.time() - start_time:.3f} seconds") if partially_collect == True: if enable_nccl: @@ -1416,28 +1411,22 @@ def _transform_udf(pdf_iter: Iterator[pd.DataFrame]) -> pd.DataFrame: _CumlCommon._set_gpu_device(context, is_local, True) if cuda_managed_mem_enabled: - import cupy as cp import rmm - from rmm.allocators.cupy import rmm_cupy_allocator - rmm.reinitialize( managed_memory=True, devices=_CumlCommon._get_gpu_device( context, is_local, is_transform=True ), ) - cp.cuda.set_allocator(rmm_cupy_allocator) + # cupy allocator is set to rmm in cudf if cuda_system_mem_enabled: - from .common.rmm_cupy import rmm_cupy_system_allocator - import cupy as cp import rmm - if cuda_system_mem_headroom is None: mr = rmm.mr.SystemMemoryResource() else: mr = rmm.mr.SamHeadroomMemoryResource(headroom=cuda_system_mem_headroom) rmm.mr.set_current_device_resource(mr) - cp.cuda.set_allocator(rmm_cupy_system_allocator) + # cupy allocator is set to rmm in cudf # Construct the cuml counterpart object cuml_instance = construct_cuml_object_func() diff --git a/python/src/spark_rapids_ml/umap.py b/python/src/spark_rapids_ml/umap.py index 4b449c6b..a9b494c6 100644 --- a/python/src/spark_rapids_ml/umap.py +++ b/python/src/spark_rapids_ml/umap.py @@ -1031,20 +1031,16 @@ def _train_udf(pdf_iter: Iterable[pd.DataFrame]) -> Iterable[pd.DataFrame]: if cuda_managed_mem_enabled: import rmm - from rmm.allocators.cupy import rmm_cupy_allocator - rmm.reinitialize(managed_memory=True) - cp.cuda.set_allocator(rmm_cupy_allocator) + # cupy allocator is set to rmm in cudf if cuda_system_mem_enabled: import rmm - from .common.rmm_cupy import rmm_cupy_system_allocator - if cuda_system_mem_headroom is None: mr = rmm.mr.SystemMemoryResource() else: mr = rmm.mr.SamHeadroomMemoryResource(headroom=cuda_system_mem_headroom) rmm.mr.set_current_device_resource(mr) - cp.cuda.set_allocator(rmm_cupy_system_allocator) + # cupy allocator is set to rmm in cudf _CumlCommon._initialize_cuml_logging(cuml_verbose) From 53eace8958474a648944669f5bae66f041f27c0d Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Thu, 8 Aug 2024 16:09:51 -0700 Subject: [PATCH 09/13] fix sparse Signed-off-by: Rong Ou --- python/src/spark_rapids_ml/clustering.py | 3 --- python/src/spark_rapids_ml/core.py | 6 +++--- python/src/spark_rapids_ml/umap.py | 2 +- 3 files changed, 4 insertions(+), 7 deletions(-) diff --git a/python/src/spark_rapids_ml/clustering.py b/python/src/spark_rapids_ml/clustering.py index c9f12c8c..f14d35dd 100644 --- a/python/src/spark_rapids_ml/clustering.py +++ b/python/src/spark_rapids_ml/clustering.py @@ -939,9 +939,6 @@ def _get_cuml_fit_func( ) if cuda_managed_mem_enabled and cuda_system_mem_enabled: raise ValueError("Both CUDA managed memory and system allocated memory cannot be enabled at the same time.") - cuda_system_mem_headroom = _get_spark_session().conf.get("spark.rapids.ml.sam.headroom", None) - if cuda_system_mem_headroom is not None: - cuda_system_mem_headroom = _parse_memory(cuda_system_mem_headroom) << 20 idCol_bc = self.idCols_ raw_data_bc = self.raw_data_ diff --git a/python/src/spark_rapids_ml/core.py b/python/src/spark_rapids_ml/core.py index 7b06837b..58dab72a 100644 --- a/python/src/spark_rapids_ml/core.py +++ b/python/src/spark_rapids_ml/core.py @@ -675,7 +675,7 @@ def _call_cuml_fit_func( if cuda_system_mem_enabled: get_logger(cls).info("CUDA system allocated memory enabled.") cuda_system_mem_headroom = _get_spark_session().conf.get("spark.rapids.ml.sam.headroom", None) - if cuda_system_mem_headroom is not None: + if cuda_system_mem_enabled and cuda_system_mem_headroom is not None: cuda_system_mem_headroom = _parse_memory(cuda_system_mem_headroom) << 20 get_logger(cls).info(f"CUDA system allocated memory headroom set to {cuda_system_mem_headroom}.") @@ -769,7 +769,7 @@ def _train_udf(pdf_iter: Iterator[pd.DataFrame]) -> pd.DataFrame: ) inputs.append((features, label, row_number)) - if cuda_managed_mem_enabled and use_sparse_array is True: + if (cuda_managed_mem_enabled or cuda_system_mem_enabled) and use_sparse_array is True: concated_nnz = sum(triplet[0].nnz for triplet in inputs) # type: ignore if concated_nnz > np.iinfo(np.int32).max: logger.warn( @@ -1399,7 +1399,7 @@ def _transform_evaluate_internal( if cuda_system_mem_enabled: get_logger(self.__class__).info("CUDA system allocated memory enabled.") cuda_system_mem_headroom = _get_spark_session().conf.get("spark.rapids.ml.sam.headroom", None) - if cuda_system_mem_headroom is not None: + if cuda_system_mem_enabled and cuda_system_mem_headroom is not None: cuda_system_mem_headroom = _parse_memory(cuda_system_mem_headroom) << 20 get_logger(self.__class__).info(f"CUDA system allocated memory headroom set to {cuda_system_mem_headroom}.") diff --git a/python/src/spark_rapids_ml/umap.py b/python/src/spark_rapids_ml/umap.py index a9b494c6..49fcf28b 100644 --- a/python/src/spark_rapids_ml/umap.py +++ b/python/src/spark_rapids_ml/umap.py @@ -1002,7 +1002,7 @@ def _call_cuml_fit_func_dataframe( if cuda_system_mem_enabled: get_logger(cls).info("CUDA system allocated memory enabled.") cuda_system_mem_headroom = _get_spark_session().conf.get("spark.rapids.ml.sam.headroom", None) - if cuda_system_mem_headroom is not None: + if cuda_system_mem_enabled and cuda_system_mem_headroom is not None: cuda_system_mem_headroom = _parse_memory(cuda_system_mem_headroom) << 20 get_logger(cls).info(f"CUDA system allocated memory headroom set to {cuda_system_mem_headroom}.") From ac3ec87dcfee816418409f2a45b84d96e425f65e Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Mon, 12 Aug 2024 18:14:08 -0700 Subject: [PATCH 10/13] use custom numpy allocator Signed-off-by: Rong Ou --- python/src/spark_rapids_ml/core.py | 24 ++++++++++++++++++++++++ python/src/spark_rapids_ml/umap.py | 12 ++++++++++++ 2 files changed, 36 insertions(+) diff --git a/python/src/spark_rapids_ml/core.py b/python/src/spark_rapids_ml/core.py index 58dab72a..a55cca4e 100644 --- a/python/src/spark_rapids_ml/core.py +++ b/python/src/spark_rapids_ml/core.py @@ -730,6 +730,18 @@ def _train_udf(pdf_iter: Iterator[pd.DataFrame]) -> pd.DataFrame: ) # cupy allocator is set to rmm in cudf if cuda_system_mem_enabled: + import cupy._core.numpy_allocator as ac + import numpy_allocator + import ctypes + lib = ctypes.CDLL(ac.__file__) + + class my_allocator(metaclass=numpy_allocator.type): + _calloc_ = ctypes.addressof(lib._calloc) + _malloc_ = ctypes.addressof(lib._malloc) + _realloc_ = ctypes.addressof(lib._realloc) + _free_ = ctypes.addressof(lib._free) + my_allocator.__enter__() # change the allocator globally + import rmm if cuda_system_mem_headroom is None: mr = rmm.mr.SystemMemoryResource() @@ -1420,6 +1432,18 @@ def _transform_udf(pdf_iter: Iterator[pd.DataFrame]) -> pd.DataFrame: ) # cupy allocator is set to rmm in cudf if cuda_system_mem_enabled: + import cupy._core.numpy_allocator as ac + import numpy_allocator + import ctypes + lib = ctypes.CDLL(ac.__file__) + + class my_allocator(metaclass=numpy_allocator.type): + _calloc_ = ctypes.addressof(lib._calloc) + _malloc_ = ctypes.addressof(lib._malloc) + _realloc_ = ctypes.addressof(lib._realloc) + _free_ = ctypes.addressof(lib._free) + my_allocator.__enter__() # change the allocator globally + import rmm if cuda_system_mem_headroom is None: mr = rmm.mr.SystemMemoryResource() diff --git a/python/src/spark_rapids_ml/umap.py b/python/src/spark_rapids_ml/umap.py index 49fcf28b..b97ba992 100644 --- a/python/src/spark_rapids_ml/umap.py +++ b/python/src/spark_rapids_ml/umap.py @@ -1034,6 +1034,18 @@ def _train_udf(pdf_iter: Iterable[pd.DataFrame]) -> Iterable[pd.DataFrame]: rmm.reinitialize(managed_memory=True) # cupy allocator is set to rmm in cudf if cuda_system_mem_enabled: + import cupy._core.numpy_allocator as ac + import numpy_allocator + import ctypes + lib = ctypes.CDLL(ac.__file__) + + class my_allocator(metaclass=numpy_allocator.type): + _calloc_ = ctypes.addressof(lib._calloc) + _malloc_ = ctypes.addressof(lib._malloc) + _realloc_ = ctypes.addressof(lib._realloc) + _free_ = ctypes.addressof(lib._free) + my_allocator.__enter__() # change the allocator globally + import rmm if cuda_system_mem_headroom is None: mr = rmm.mr.SystemMemoryResource() From b50ed60c1905405dd253fc3f5a2ea840dc9ada9e Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Wed, 14 Aug 2024 15:45:17 -0700 Subject: [PATCH 11/13] configure custom numpy allocator correctly Signed-off-by: Rong Ou --- python/benchmark/benchmark/base.py | 22 ++++++++++++++++++++++ python/src/spark_rapids_ml/core.py | 24 ------------------------ python/src/spark_rapids_ml/umap.py | 12 ------------ 3 files changed, 22 insertions(+), 36 deletions(-) diff --git a/python/benchmark/benchmark/base.py b/python/benchmark/benchmark/base.py index 9af91252..78b04fd0 100644 --- a/python/benchmark/benchmark/base.py +++ b/python/benchmark/benchmark/base.py @@ -227,6 +227,28 @@ def run(self) -> None: with WithSparkSession( self._args.spark_confs, shutdown=(not self._args.no_shutdown) ) as spark: + cuda_system_mem_enabled = ( + spark.conf.get("spark.rapids.ml.sam.enabled", "false") + == "true" + ) + if cuda_system_mem_enabled: + def configure_numpy_allocator(): + import cupy._core.numpy_allocator as ac + import numpy_allocator + import ctypes + lib = ctypes.CDLL(ac.__file__) + + class my_allocator(metaclass=numpy_allocator.type): + _calloc_ = ctypes.addressof(lib._calloc) + _malloc_ = ctypes.addressof(lib._malloc) + _realloc_ = ctypes.addressof(lib._realloc) + _free_ = ctypes.addressof(lib._free) + + my_allocator.__enter__() # change the allocator globally + print("Custom numpy allocator setup on all executors") + return [True] + spark.sparkContext.runJob(spark.sparkContext.parallelize([1]), lambda _: configure_numpy_allocator()) + for _ in range(self._args.num_runs): train_df, features_col, label_col = self.input_dataframe( spark, *self._args.train_path diff --git a/python/src/spark_rapids_ml/core.py b/python/src/spark_rapids_ml/core.py index a55cca4e..58dab72a 100644 --- a/python/src/spark_rapids_ml/core.py +++ b/python/src/spark_rapids_ml/core.py @@ -730,18 +730,6 @@ def _train_udf(pdf_iter: Iterator[pd.DataFrame]) -> pd.DataFrame: ) # cupy allocator is set to rmm in cudf if cuda_system_mem_enabled: - import cupy._core.numpy_allocator as ac - import numpy_allocator - import ctypes - lib = ctypes.CDLL(ac.__file__) - - class my_allocator(metaclass=numpy_allocator.type): - _calloc_ = ctypes.addressof(lib._calloc) - _malloc_ = ctypes.addressof(lib._malloc) - _realloc_ = ctypes.addressof(lib._realloc) - _free_ = ctypes.addressof(lib._free) - my_allocator.__enter__() # change the allocator globally - import rmm if cuda_system_mem_headroom is None: mr = rmm.mr.SystemMemoryResource() @@ -1432,18 +1420,6 @@ def _transform_udf(pdf_iter: Iterator[pd.DataFrame]) -> pd.DataFrame: ) # cupy allocator is set to rmm in cudf if cuda_system_mem_enabled: - import cupy._core.numpy_allocator as ac - import numpy_allocator - import ctypes - lib = ctypes.CDLL(ac.__file__) - - class my_allocator(metaclass=numpy_allocator.type): - _calloc_ = ctypes.addressof(lib._calloc) - _malloc_ = ctypes.addressof(lib._malloc) - _realloc_ = ctypes.addressof(lib._realloc) - _free_ = ctypes.addressof(lib._free) - my_allocator.__enter__() # change the allocator globally - import rmm if cuda_system_mem_headroom is None: mr = rmm.mr.SystemMemoryResource() diff --git a/python/src/spark_rapids_ml/umap.py b/python/src/spark_rapids_ml/umap.py index b97ba992..49fcf28b 100644 --- a/python/src/spark_rapids_ml/umap.py +++ b/python/src/spark_rapids_ml/umap.py @@ -1034,18 +1034,6 @@ def _train_udf(pdf_iter: Iterable[pd.DataFrame]) -> Iterable[pd.DataFrame]: rmm.reinitialize(managed_memory=True) # cupy allocator is set to rmm in cudf if cuda_system_mem_enabled: - import cupy._core.numpy_allocator as ac - import numpy_allocator - import ctypes - lib = ctypes.CDLL(ac.__file__) - - class my_allocator(metaclass=numpy_allocator.type): - _calloc_ = ctypes.addressof(lib._calloc) - _malloc_ = ctypes.addressof(lib._malloc) - _realloc_ = ctypes.addressof(lib._realloc) - _free_ = ctypes.addressof(lib._free) - my_allocator.__enter__() # change the allocator globally - import rmm if cuda_system_mem_headroom is None: mr = rmm.mr.SystemMemoryResource() From 49bf63bebc230c7431da017235e12d4eac0e53dc Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Fri, 16 Aug 2024 10:21:43 -0700 Subject: [PATCH 12/13] add numpy_allcoator as requirement Signed-off-by: Rong Ou --- python/requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/python/requirements.txt b/python/requirements.txt index 21d77116..17839abb 100644 --- a/python/requirements.txt +++ b/python/requirements.txt @@ -1,2 +1,3 @@ +numpy_allocator pyspark>=3.2.1,<3.5 scikit-learn>=1.2.1 From b2bb98b263a38fcc4225e4e6c87f19c6c54f30b1 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Tue, 20 Aug 2024 14:09:57 -0700 Subject: [PATCH 13/13] remove unused import Signed-off-by: Rong Ou --- python/src/spark_rapids_ml/clustering.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/src/spark_rapids_ml/clustering.py b/python/src/spark_rapids_ml/clustering.py index f14d35dd..4ac3a9e1 100644 --- a/python/src/spark_rapids_ml/clustering.py +++ b/python/src/spark_rapids_ml/clustering.py @@ -39,7 +39,6 @@ StructField, StructType, ) -from pyspark.util import _parse_memory from .core import ( CumlT,