Support System Allocated Memory (SAM) #701

Draft: wants to merge 16 commits into base: branch-24.10

10 changes: 7 additions & 3 deletions docs/site/configuration.md
@@ -6,7 +6,11 @@ nav_order: 6

The following configurations can be supplied as Spark properties.

| Property name | Default | Meaning |
| :-------------- | :------ | :------- |
| spark.rapids.ml.uvm.enabled | false | if set to true, enables [CUDA unified virtual memory](https://developer.nvidia.com/blog/unified-memory-cuda-beginners/) (aka managed memory) during estimator.fit() operations to allow processing of larger datasets than would fit in GPU memory|
| Property name | Default | Meaning |
|:----------------------------------|:--------|:--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| spark.rapids.ml.uvm.enabled | false | if set to true, enables [CUDA unified virtual memory](https://developer.nvidia.com/blog/unified-memory-cuda-beginners/) (aka managed memory) during estimator.fit() operations to allow processing of larger datasets than would fit in GPU memory |
| spark.rapids.ml.sam.enabled | false | if set to true, enables System Allocated Memory (SAM) on [HMM](https://developer.nvidia.com/blog/simplifying-gpu-application-development-with-heterogeneous-memory-management/) or [ATS](https://developer.nvidia.com/blog/nvidia-grace-hopper-superchip-architecture-in-depth/) systems during estimator.fit() operations to allow processing of larger datasets than would fit in GPU memory |
| spark.rapids.ml.sam.headroom | None | when using System Allocated Memory (SAM) and GPU memory is oversubscribed, we may need to reserve some GPU memory as headroom to allow other CUDA calls to function without running out of memory. Set a size appropriate for your application. |
| spark.executorEnv.CUPY_ENABLE_SAM | 0 | if set to 1, enables System Allocated Memory (SAM) for CuPy operations. This enables CuPy to work with SAM and also avoids unnecessary memory copying. |
| spark.driverEnv.CUPY_ENABLE_SAM | 0 | if set to 1, enables System Allocated Memory (SAM) for CuPy operations. This enables CuPy to work with SAM and also avoids unnecessary memory copying. This is needed when running Spark in local mode. |
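
For example, the SAM-related properties can be supplied like any other Spark conf; a minimal sketch of a local-mode PySpark session (the property values here are illustrative, not defaults):

```python
from pyspark.sql import SparkSession

spark = (
    SparkSession.builder.master("local[*]")
    # enable the RMM system (SAM) memory resource during fit()
    .config("spark.rapids.ml.sam.enabled", "true")
    # reserve some GPU memory for other CUDA calls when oversubscribed
    .config("spark.rapids.ml.sam.headroom", "1g")
    # let CuPy operate on system-allocated memory in the workers and the driver
    .config("spark.executorEnv.CUPY_ENABLE_SAM", "1")
    .config("spark.driverEnv.CUPY_ENABLE_SAM", "1")
    .getOrCreate()
)
```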

22 changes: 22 additions & 0 deletions python/benchmark/benchmark/base.py
@@ -227,6 +227,28 @@ def run(self) -> None:
        with WithSparkSession(
            self._args.spark_confs, shutdown=(not self._args.no_shutdown)
        ) as spark:
            cuda_system_mem_enabled = (
                spark.conf.get("spark.rapids.ml.sam.enabled", "false")
                == "true"
            )
            if cuda_system_mem_enabled:
                def configure_numpy_allocator():
                    import cupy._core.numpy_allocator as ac
                    import numpy_allocator
                    import ctypes
                    lib = ctypes.CDLL(ac.__file__)

                    class my_allocator(metaclass=numpy_allocator.type):
                        _calloc_ = ctypes.addressof(lib._calloc)
                        _malloc_ = ctypes.addressof(lib._malloc)
                        _realloc_ = ctypes.addressof(lib._realloc)
                        _free_ = ctypes.addressof(lib._free)

                    my_allocator.__enter__()  # change the allocator globally
                    print("Custom numpy allocator setup on all executors")
                    return [True]
                spark.sparkContext.runJob(spark.sparkContext.parallelize([1]), lambda _: configure_numpy_allocator())

            for _ in range(self._args.num_runs):
                train_df, features_col, label_col = self.input_dataframe(
                    spark, *self._args.train_path
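For reference, the hook above can be exercised outside Spark; a small sketch assuming the same cupy and numpy_allocator packages used in the diff (the class name sam_allocator is illustrative):

```python
import ctypes

import cupy._core.numpy_allocator as ac
import numpy as np
import numpy_allocator

# cupy ships a small shared library exposing SAM-aware calloc/malloc/realloc/free
lib = ctypes.CDLL(ac.__file__)


class sam_allocator(metaclass=numpy_allocator.type):
    _calloc_ = ctypes.addressof(lib._calloc)
    _malloc_ = ctypes.addressof(lib._malloc)
    _realloc_ = ctypes.addressof(lib._realloc)
    _free_ = ctypes.addressof(lib._free)


# scoped use: only arrays allocated inside the block go through the SAM allocator
with sam_allocator:
    arr = np.zeros(1 << 20)

# the benchmark code instead calls the allocator's __enter__() once per executor,
# switching NumPy's allocator globally for the lifetime of the python worker
```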
1 change: 1 addition & 0 deletions python/requirements.txt
@@ -1,2 +1,3 @@
numpy_allocator
pyspark>=3.2.1,<3.5
scikit-learn>=1.2.1
5 changes: 4 additions & 1 deletion python/run_benchmark.sh
@@ -98,7 +98,10 @@ cat <<EOF
--spark_confs spark.python.worker.reuse=true \
--spark_confs spark.master=local[$local_threads] \
--spark_confs spark.driver.memory=128g \
--spark_confs spark.rapids.ml.uvm.enabled=true
--spark_confs spark.rapids.ml.uvm.enabled=false \
--spark_confs spark.rapids.ml.sam.enabled=true \

Collaborator: Does sam work on non-GH machines? Will it fall back to uvm?

Collaborator (author): Only when HMM is supported:

$ nvidia-smi -q | grep Addressing
    Addressing Mode                       : HMM

Collaborator (@wbo4958, Aug 5, 2024): What will happen if HMM is not supported but we've enabled SAM?

Collaborator (author): Invoking the RMM system mr would cause a crash. I guess we should figure out what value to default to that's the most convenient for us.

Collaborator: Agree. Does our CI machine support RMM? Nightly CI executes run_benchmark.sh with an A100 40G GPU, I think.

Collaborator (author): It should work as long as we install the open source driver. These are the requirements:

  • NVIDIA CUDA 12.2 with the open-source r535_00 driver or newer.
  • A sufficiently recent Linux kernel: 6.1.24+, 6.2.11+, or 6.3+.
  • A GPU with one of the following supported architectures: NVIDIA Turing, NVIDIA Ampere, NVIDIA Ada Lovelace, NVIDIA Hopper, or newer.
  • A 64-bit x86 CPU.

--spark_confs spark.rapids.ml.sam.headroom=1g \
--spark_confs spark.executorEnv.CUPY_ENABLE_SAM=1
EOF
)

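The review thread above leaves open how to behave when SAM is enabled on a machine without HMM/ATS support. A hypothetical guard (not part of this PR) could detect the addressing mode the same way the reviewer does and fall back to UVM:

```python
import subprocess


def system_memory_supported() -> bool:
    """Mirrors the manual check from the review thread: nvidia-smi -q | grep Addressing."""
    out = subprocess.run(
        ["nvidia-smi", "-q"], capture_output=True, text=True, check=False
    ).stdout
    for line in out.splitlines():
        if "Addressing Mode" in line:
            mode = line.split(":")[-1].strip()
            # "HMM" on x86 with the open driver; "ATS" assumed on Grace Hopper systems
            return mode in ("HMM", "ATS")
    return False


# Hypothetical default: prefer SAM where supported, otherwise fall back to UVM.
sam_enabled = system_memory_supported()
uvm_enabled = not sam_enabled
```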
8 changes: 7 additions & 1 deletion python/src/spark_rapids_ml/clustering.py
@@ -932,6 +932,12 @@ def _get_cuml_fit_func(
_get_spark_session().conf.get("spark.rapids.ml.uvm.enabled", "false")
== "true"
)
cuda_system_mem_enabled = (
_get_spark_session().conf.get("spark.rapids.ml.sam.enabled", "false")
== "true"
)
if cuda_managed_mem_enabled and cuda_system_mem_enabled:
raise ValueError("Both CUDA managed memory and system allocated memory cannot be enabled at the same time.")

idCol_bc = self.idCols_
raw_data_bc = self.raw_data_
Expand All @@ -957,7 +963,7 @@ def _cuml_fit(

# experiments indicate it is faster to convert to numpy array and then to cupy array than directly
# invoking cupy array on the list
if cuda_managed_mem_enabled:
if cuda_managed_mem_enabled or cuda_system_mem_enabled:
features = cp.array(features)

inputs.append(features)
62 changes: 53 additions & 9 deletions python/src/spark_rapids_ml/core.py
@@ -16,6 +16,7 @@
import json
import os
import threading
import time
from abc import ABCMeta, abstractmethod
from collections import namedtuple
from typing import (
@@ -68,6 +69,7 @@
    Row,
    StructType,
)
from pyspark.util import _parse_memory
from scipy.sparse import csr_matrix

from .common.cuml_context import CumlContext
@@ -664,6 +666,19 @@ def _call_cuml_fit_func(
        if cuda_managed_mem_enabled:
            get_logger(cls).info("CUDA managed memory enabled.")

        cuda_system_mem_enabled = (
            _get_spark_session().conf.get("spark.rapids.ml.sam.enabled", "false")
            == "true"
        )
        if cuda_managed_mem_enabled and cuda_system_mem_enabled:
            raise ValueError("Both CUDA managed memory and system allocated memory cannot be enabled at the same time.")
        if cuda_system_mem_enabled:
            get_logger(cls).info("CUDA system allocated memory enabled.")
        cuda_system_mem_headroom = _get_spark_session().conf.get("spark.rapids.ml.sam.headroom", None)
        if cuda_system_mem_enabled and cuda_system_mem_headroom is not None:
            cuda_system_mem_headroom = _parse_memory(cuda_system_mem_headroom) << 20
            get_logger(cls).info(f"CUDA system allocated memory headroom set to {cuda_system_mem_headroom}.")

        # parameters passed to subclass
        params: Dict[str, Any] = {
            param_alias.cuml_init: self.cuml_params,
@@ -709,13 +724,19 @@ def _train_udf(pdf_iter: Iterator[pd.DataFrame]) -> pd.DataFrame:

            if cuda_managed_mem_enabled:
                import rmm
                from rmm.allocators.cupy import rmm_cupy_allocator

                rmm.reinitialize(
                    managed_memory=True,
                    devices=_CumlCommon._get_gpu_device(context, is_local),
                )
                cp.cuda.set_allocator(rmm_cupy_allocator)
                # cupy allocator is set to rmm in cudf
            if cuda_system_mem_enabled:
                import rmm
                if cuda_system_mem_headroom is None:
                    mr = rmm.mr.SystemMemoryResource()
                else:
                    mr = rmm.mr.SamHeadroomMemoryResource(headroom=cuda_system_mem_headroom)
                rmm.mr.set_current_device_resource(mr)
                # cupy allocator is set to rmm in cudf

            _CumlCommon._initialize_cuml_logging(cuml_verbose)

@@ -724,6 +745,7 @@
            logger.info("Loading data into python worker memory")
            inputs = []
            sizes = []
            start_time = time.time()

            for pdf in pdf_iter:
                sizes.append(pdf.shape[0])
@@ -738,7 +760,7 @@ def _train_udf(pdf_iter: Iterator[pd.DataFrame]) -> pd.DataFrame:

                # experiments indicate it is faster to convert to numpy array and then to cupy array than directly
                # invoking cupy array on the list
                if cuda_managed_mem_enabled and use_sparse_array is False:
                if (cuda_managed_mem_enabled or cuda_system_mem_enabled) and use_sparse_array is False:
                    features = cp.array(features)

                label = pdf[alias.label] if alias.label in pdf.columns else None
                row_number = (
                )
                inputs.append((features, label, row_number))

            if cuda_managed_mem_enabled and use_sparse_array is True:
            if (cuda_managed_mem_enabled or cuda_system_mem_enabled) and use_sparse_array is True:
                concated_nnz = sum(triplet[0].nnz for triplet in inputs)  # type: ignore
                if concated_nnz > np.iinfo(np.int32).max:
                    logger.warn(
@@ -764,6 +786,8 @@ def _train_udf(pdf_iter: Iterator[pd.DataFrame]) -> pd.DataFrame:
                    "A python worker received no data. Please increase amount of data or use fewer workers."
                )

            logger.info(f"Data loaded into python worker memory in {time.time() - start_time:.3f} seconds")

            logger.info("Initializing cuml context")
            with CumlContext(
                partition_id, num_workers, context, enable_nccl, require_ucx
@@ -774,12 +798,14 @@ def _train_udf(pdf_iter: Iterator[pd.DataFrame]) -> pd.DataFrame:
                params[param_alias.loop] = cc._loop

                logger.info("Invoking cuml fit")
                start_time = time.time()

                # call the cuml fit function
                # *note*: cuml_fit_func may delete components of inputs to free
                # memory. do not rely on inputs after this call.
                result = cuml_fit_func(inputs, params)
                logger.info("Cuml fit complete")
                logger.info(f"Cuml fit took {time.time() - start_time:.3f} seconds")

                if partially_collect == True:
                    if enable_nccl:
@@ -1364,6 +1390,19 @@ def _transform_evaluate_internal(
        if cuda_managed_mem_enabled:
            get_logger(self.__class__).info("CUDA managed memory enabled.")

        cuda_system_mem_enabled = (
            _get_spark_session().conf.get("spark.rapids.ml.sam.enabled", "false")
            == "true"
        )
        if cuda_managed_mem_enabled and cuda_system_mem_enabled:
            raise ValueError("Both CUDA managed memory and system allocated memory cannot be enabled at the same time.")
        if cuda_system_mem_enabled:
            get_logger(self.__class__).info("CUDA system allocated memory enabled.")
        cuda_system_mem_headroom = _get_spark_session().conf.get("spark.rapids.ml.sam.headroom", None)
        if cuda_system_mem_enabled and cuda_system_mem_headroom is not None:
            cuda_system_mem_headroom = _parse_memory(cuda_system_mem_headroom) << 20
            get_logger(self.__class__).info(f"CUDA system allocated memory headroom set to {cuda_system_mem_headroom}.")

        def _transform_udf(pdf_iter: Iterator[pd.DataFrame]) -> pd.DataFrame:
            from pyspark import TaskContext

@@ -1372,17 +1411,22 @@ def _transform_udf(pdf_iter: Iterator[pd.DataFrame]) -> pd.DataFrame:
            _CumlCommon._set_gpu_device(context, is_local, True)

            if cuda_managed_mem_enabled:
                import cupy as cp
                import rmm
                from rmm.allocators.cupy import rmm_cupy_allocator

                rmm.reinitialize(
                    managed_memory=True,
                    devices=_CumlCommon._get_gpu_device(
                        context, is_local, is_transform=True
                    ),
                )
                cp.cuda.set_allocator(rmm_cupy_allocator)
                # cupy allocator is set to rmm in cudf
            if cuda_system_mem_enabled:
                import rmm
                if cuda_system_mem_headroom is None:
                    mr = rmm.mr.SystemMemoryResource()
                else:
                    mr = rmm.mr.SamHeadroomMemoryResource(headroom=cuda_system_mem_headroom)
                rmm.mr.set_current_device_resource(mr)
                # cupy allocator is set to rmm in cudf

            # Construct the cuml counterpart object
            cuml_instance = construct_cuml_object_func()
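Standalone, the memory-resource selection added above amounts to the following; a minimal sketch assuming an RMM build that ships SystemMemoryResource and SamHeadroomMemoryResource, as this PR requires (the "1g" value is illustrative):

```python
import rmm
from pyspark.util import _parse_memory  # parses "1g" -> 1024 (MiB)

headroom_conf = "1g"  # value of spark.rapids.ml.sam.headroom, or None if unset

if headroom_conf is None:
    # serve all allocations from system (CPU) memory; pages migrate to the GPU on demand
    mr = rmm.mr.SystemMemoryResource()
else:
    # keep some device memory free so other CUDA calls do not run out of memory
    headroom_bytes = _parse_memory(headroom_conf) << 20  # MiB -> bytes
    mr = rmm.mr.SamHeadroomMemoryResource(headroom=headroom_bytes)

rmm.mr.set_current_device_resource(mr)
# cupy allocations are routed through RMM once cudf sets the cupy allocator,
# mirroring the comment in the diff above
```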
28 changes: 24 additions & 4 deletions python/src/spark_rapids_ml/umap.py
@@ -54,6 +54,7 @@
    StructField,
    StructType,
)
from pyspark.util import _parse_memory

from spark_rapids_ml.core import FitInputType, _CumlModel

@@ -992,6 +993,19 @@ def _call_cuml_fit_func_dataframe(
        if cuda_managed_mem_enabled:
            get_logger(cls).info("CUDA managed memory enabled.")

        cuda_system_mem_enabled = (
            _get_spark_session().conf.get("spark.rapids.ml.sam.enabled", "false")
            == "true"
        )
        if cuda_managed_mem_enabled and cuda_system_mem_enabled:
            raise ValueError("Both CUDA managed memory and system allocated memory cannot be enabled at the same time.")
        if cuda_system_mem_enabled:
            get_logger(cls).info("CUDA system allocated memory enabled.")
        cuda_system_mem_headroom = _get_spark_session().conf.get("spark.rapids.ml.sam.headroom", None)
        if cuda_system_mem_enabled and cuda_system_mem_headroom is not None:
            cuda_system_mem_headroom = _parse_memory(cuda_system_mem_headroom) << 20
            get_logger(cls).info(f"CUDA system allocated memory headroom set to {cuda_system_mem_headroom}.")

        # parameters passed to subclass
        params: Dict[str, Any] = {
            param_alias.cuml_init: self.cuml_params,
@@ -1017,10 +1031,16 @@ def _train_udf(pdf_iter: Iterable[pd.DataFrame]) -> Iterable[pd.DataFrame]:

            if cuda_managed_mem_enabled:
                import rmm
                from rmm.allocators.cupy import rmm_cupy_allocator

                rmm.reinitialize(managed_memory=True)
                cp.cuda.set_allocator(rmm_cupy_allocator)
                # cupy allocator is set to rmm in cudf
            if cuda_system_mem_enabled:
                import rmm
                if cuda_system_mem_headroom is None:
                    mr = rmm.mr.SystemMemoryResource()
                else:
                    mr = rmm.mr.SamHeadroomMemoryResource(headroom=cuda_system_mem_headroom)
                rmm.mr.set_current_device_resource(mr)
                # cupy allocator is set to rmm in cudf

            _CumlCommon._initialize_cuml_logging(cuml_verbose)

@@ -1042,7 +1062,7 @@ def _train_udf(pdf_iter: Iterable[pd.DataFrame]) -> Iterable[pd.DataFrame]:
                features = np.array(list(pdf[alias.data]), order=array_order)
                # experiments indicate it is faster to convert to numpy array and then to cupy array than directly
                # invoking cupy array on the list
                if cuda_managed_mem_enabled:
                if cuda_managed_mem_enabled or cuda_system_mem_enabled:
                    features = cp.array(features)

                label = pdf[alias.label] if alias.label in pdf.columns else None