From 3debc82f11c3bd0460c5b689665e59fb9bb2cdd1 Mon Sep 17 00:00:00 2001 From: Maheedhar Reddy Chappidi Date: Thu, 17 Aug 2023 13:02:28 -0700 Subject: [PATCH] Auto-detection of accelerator_type for aws_accelerators trn1_inf (#37998) This change is to support auto-detection of AWS accelerators and configuring appropriate environment variables to designate the neuron_core per task/actor. Related REP #33707 Signed-off-by: e428265 --- .../doc_code/neuron_core_accelerator.py | 28 +++ .../ray-core/tasks/using-ray-with-gpus.rst | 24 +- python/ray/_private/accelerator.py | 111 +++++++++ python/ray/_private/ray_constants.py | 20 ++ python/ray/_private/ray_option_utils.py | 31 +++ python/ray/_private/resource_spec.py | 25 +- python/ray/_private/utils.py | 154 ++++++++++-- python/ray/_private/worker.py | 93 ++++--- python/ray/_raylet.pyx | 5 +- .../autoscaler/_private/aws/node_provider.py | 21 ++ python/ray/runtime_context.py | 24 +- python/ray/tests/BUILD | 1 + python/ray/tests/test_accelerator.py | 99 ++++++++ python/ray/tests/test_advanced_2.py | 232 ++++++++++++++++++ python/ray/tests/test_advanced_8.py | 21 +- python/ray/tests/test_autoscaler_yaml.py | 22 ++ python/ray/util/accelerators/__init__.py | 2 + python/ray/util/accelerators/accelerators.py | 1 + src/ray/common/ray_config_def.h | 6 +- src/ray/common/test/scheduling_ids_test.cc | 3 +- 20 files changed, 839 insertions(+), 84 deletions(-) create mode 100644 doc/source/ray-core/doc_code/neuron_core_accelerator.py create mode 100644 python/ray/_private/accelerator.py create mode 100644 python/ray/tests/test_accelerator.py diff --git a/doc/source/ray-core/doc_code/neuron_core_accelerator.py b/doc/source/ray-core/doc_code/neuron_core_accelerator.py new file mode 100644 index 000000000000..4128102d6b0e --- /dev/null +++ b/doc/source/ray-core/doc_code/neuron_core_accelerator.py @@ -0,0 +1,28 @@ +# __neuron_core_accelerator_start__ +import ray +import os +from ray.util.accelerators import AWS_NEURON_CORE + +# On trn1.2xlarge instance, there will be 2 neuron cores. +ray.init(resources={"neuron_cores": 2}) + + +@ray.remote(resources={"neuron_cores": 1}) +class NeuronCoreActor: + def info(self): + ids = ray.get_runtime_context().get_resource_ids() + print("neuron_core_ids: {}".format(ids["neuron_cores"])) + print(f"NEURON_RT_VISIBLE_CORES: {os.environ['NEURON_RT_VISIBLE_CORES']}") + + +@ray.remote(resources={"neuron_cores": 1}, accelerator_type=AWS_NEURON_CORE) +def use_neuron_core_task(): + ids = ray.get_runtime_context().get_resource_ids() + print("neuron_core_ids: {}".format(ids["neuron_cores"])) + print(f"NEURON_RT_VISIBLE_CORES: {os.environ['NEURON_RT_VISIBLE_CORES']}") + + +neuron_core_actor = NeuronCoreActor.remote() +ray.get(neuron_core_actor.info.remote()) +ray.get(use_neuron_core_task.remote()) +# __neuron_core_accelerator_end__ diff --git a/doc/source/ray-core/tasks/using-ray-with-gpus.rst b/doc/source/ray-core/tasks/using-ray-with-gpus.rst index 32305549070c..59430e1a8445 100644 --- a/doc/source/ray-core/tasks/using-ray-with-gpus.rst +++ b/doc/source/ray-core/tasks/using-ray-with-gpus.rst @@ -123,4 +123,26 @@ This also lets the multi-node-type autoscaler know that there is demand for that :start-after: __accelerator_type_start__ :end-before: __accelerator_type_end__ -See ``ray.util.accelerators`` for available accelerator types. Current automatically detected accelerator types include Nvidia GPUs. +See ``ray.util.accelerators`` for available accelerator types. 
Current automatically detected accelerator types include: + + - Nvidia GPUs + - AWS Neuron Cores + +AWS Neuron Core Accelerator (Experimental) +------------------------------------------ + +Similar to Nvidia GPUs, Ray auto-detects `AWS Neuron Cores`_ by default. +The user can specify `resources={"neuron_cores": some_number}` on +task or actor resource requirements to assign the Neuron Core(s). + +.. _`AWS Neuron Cores` : https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/arch/model-architecture-fit.html + +.. note:: + + Ray supports a heterogeneous cluster of GPUs and Neuron Cores but doesn't allow specifying resources requirements of + ``num_gpus`` and ``neuron_cores`` together for a task or actor. + +.. literalinclude:: ../doc_code/neuron_core_accelerator.py + :language: python + :start-after: __neuron_core_accelerator_start__ + :end-before: __neuron_core_accelerator_end__ diff --git a/python/ray/_private/accelerator.py b/python/ray/_private/accelerator.py new file mode 100644 index 000000000000..2a98e7048b41 --- /dev/null +++ b/python/ray/_private/accelerator.py @@ -0,0 +1,111 @@ +import json +import os +import subprocess +import sys +from typing import Optional + + +def update_resources_with_accelerator_type(resources: dict): + """Update the resources dictionary with the accelerator type and custom + resources. + + Currently, we support AWS NeuronCore (neuron_cores / + accelerator_type:aws-neuron-core) detection and configuration. + + Args: + resources: Resources dictionary to be updated with + accelerator type and custom resources. + """ + _detect_and_configure_aws_neuron_core(resources) + + +def _detect_and_configure_aws_neuron_core(resources: dict): + """Configuration and auto-detection of AWS NeuronCore accelerator type + and number of NeuronCore (neuron_cores). + + If the number of NeuronCore is not specified in the resources, this + function will try to detect the number of NeuronCore. + + If the number of NeuronCore is specified in the resources, this + function will check if the number of NeuronCore is greater than the + number of visible NeuronCore and raise an error if it is true. + + If the number of NeuronCore is greater than the number of visible + NeuronCore, this function will raise an error. + + Lastly, update accelerator_type and neuron_cores in resources. + + Args: + resources: Resources dictionary to be updated with + NeuronCore accelerator type and custom resources(neuron_cores). + + Raises: + ValueError: If the number of NeuronCore is greater than the number of + visible NeuronCore. + """ + import ray._private.ray_constants as ray_constants + import ray._private.utils as utils + + # AWS NeuronCore detection and configuration + # 1. Check if the user specified neuron_cores in resources + neuron_cores = resources.get(ray_constants.NEURON_CORES, None) + # 2. Check if the user specified NEURON_RT_VISIBLE_CORES + neuron_core_ids = utils.get_aws_neuron_core_visible_ids() + if ( + neuron_cores is not None + and neuron_core_ids is not None + and neuron_cores > len(neuron_core_ids) + ): + raise ValueError( + f"Attempting to start raylet with {neuron_cores} " + f"neuron cores, but NEURON_RT_VISIBLE_CORES contains " + f"{neuron_core_ids}." + ) + # 3. Auto-detect neuron_cores if not specified in resources + if neuron_cores is None: + neuron_cores = _autodetect_aws_neuron_cores() + # Don't use more neuron cores than allowed by NEURON_RT_VISIBLE_CORES. 
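+ # For example, if neuron-ls reports 4 cores but NEURON_RT_VISIBLE_CORES
+ # is set to "0,1", only 2 neuron_cores are registered with the raylet.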
+ if neuron_cores is not None and neuron_core_ids is not None: + neuron_cores = min(neuron_cores, len(neuron_core_ids)) + if neuron_cores is not None: + # 4. Update accelerator_type and neuron_cores with + # number of neuron cores detected or configured. + resources.update( + { + ray_constants.NEURON_CORES: neuron_cores, + utils.get_neuron_core_constraint_name(): neuron_cores, + } + ) + + +def _autodetect_aws_neuron_cores() -> Optional[int]: + """ + Attempt to detect the number of Neuron cores on this machine. + + Returns: + The number of Neuron cores if any were detected, otherwise None. + """ + result = None + if sys.platform.startswith("linux") and os.path.isdir("/opt/aws/neuron/bin/"): + result = _get_neuron_core_count() + return result + + +def _get_neuron_core_count() -> int: + """Get the number of Neuron cores on a machine based on neuron_path. + + Returns: + The number of Neuron cores on this machine (Default to 0). + """ + neuron_path = "/opt/aws/neuron/bin/" + nc_count: int = 0 + result = subprocess.run( + [os.path.join(neuron_path, "neuron-ls"), "--json-output"], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + if result.returncode == 0 and result.stdout: + json_out = json.loads(result.stdout) + for neuron_device in json_out: + nc_count += neuron_device.get("nc_count", 0) + return nc_count diff --git a/python/ray/_private/ray_constants.py b/python/ray/_private/ray_constants.py index 429935b6e3a7..ae605634c050 100644 --- a/python/ray/_private/ray_constants.py +++ b/python/ray/_private/ray_constants.py @@ -395,7 +395,27 @@ def env_set_by_user(key): LANGUAGE_WORKER_TYPES = ["python", "java", "cpp"] +# Accelerator constants +NOSET_AWS_NEURON_RT_VISIBLE_CORES_ENV_VAR = ( + "RAY_EXPERIMENTAL_NOSET_NEURON_RT_VISIBLE_CORES" +) NOSET_CUDA_VISIBLE_DEVICES_ENV_VAR = "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES" +CUDA_VISIBLE_DEVICES_ENV_VAR = "CUDA_VISIBLE_DEVICES" +NEURON_RT_VISIBLE_CORES_ENV_VAR = "NEURON_RT_VISIBLE_CORES" +NEURON_CORES = "neuron_cores" +GPU = "GPU" +# https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/arch/neuron-hardware/inf2-arch.html#aws-inf2-arch +# https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/arch/neuron-hardware/trn1-arch.html#aws-trn1-arch +# Subject to removal after the information is available via public API +AWS_NEURON_INSTANCE_MAP = { + "trn1.2xlarge": 2, + "trn1.32xlarge": 32, + "trn1n.32xlarge": 32, + "inf2.xlarge": 2, + "inf2.8xlarge": 2, + "inf2.24xlarge": 12, + "inf2.48xlarge": 24, +} RAY_WORKER_NICENESS = "RAY_worker_niceness" # Default max_retries option in @ray.remote for non-actor diff --git a/python/ray/_private/ray_option_utils.py b/python/ray/_private/ray_option_utils.py index d9ec6fcabe22..4d15fa7f9711 100644 --- a/python/ray/_private/ray_option_utils.py +++ b/python/ray/_private/ray_option_utils.py @@ -6,6 +6,7 @@ import ray from ray._private import ray_constants from ray._private.utils import get_ray_doc_version +from ray.util.accelerators import accelerators from ray.util.placement_group import PlacementGroup from ray.util.scheduling_strategies import ( NodeAffinitySchedulingStrategy, @@ -107,6 +108,34 @@ def _validate_resources(resources: Optional[Dict[str, float]]) -> Optional[str]: return None +def _validate_neuron_core_accelerator(options: Dict[str, Any]): + """Validate options for NeuronCore accelerator/ neuron_cores and GPU, + supports only one or the other (Either NeuronCore or GPU). + + Args: + options: The options to be validated. + + Raises: + ValueError: If the options are invalid. 
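+ For example, requesting ``num_gpus`` together with ``neuron_cores``
+ or ``accelerator_type:aws-neuron-core`` raises this error.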
+ """ + num_gpus = options.get("num_gpus", None) + if num_gpus is not None and num_gpus > 0: + resources = options["resources"] if "resources" in options else None + accelerator_type_value: str = options.get("accelerator_type", "") + if resources is not None: + neuron_cores: int = resources.get(ray_constants.NEURON_CORES, 0) + if neuron_cores > 0: + raise ValueError( + "'num_gpus' cannot be used together with " + "neuron_cores/accelerator_type:aws-neuron-core." + ) + elif accelerator_type_value == accelerators.AWS_NEURON_CORE: + raise ValueError( + "'num_gpus' cannot be used together with " + "neuron_cores/accelerator_type:aws-neuron-core." + ) + + _common_options = { "accelerator_type": Option((str, type(None))), "memory": _resource_option("memory"), @@ -291,6 +320,7 @@ def validate_task_options(options: Dict[str, Any], in_options: bool): if in_options and "max_calls" in options: raise ValueError("Setting 'max_calls' is not supported in '.options()'.") _check_deprecate_placement_group(options) + _validate_neuron_core_accelerator(options) def validate_actor_options(options: Dict[str, Any], in_options: bool): @@ -335,6 +365,7 @@ def validate_actor_options(options: Dict[str, Any], in_options: bool): ) _check_deprecate_placement_group(options) + _validate_neuron_core_accelerator(options) def update_options( diff --git a/python/ray/_private/resource_spec.py b/python/ray/_private/resource_spec.py index d916a7db43b6..4cf5b01a0338 100644 --- a/python/ray/_private/resource_spec.py +++ b/python/ray/_private/resource_spec.py @@ -7,9 +7,12 @@ from collections import namedtuple from typing import Optional +import ray._private.accelerator as accelerator + import ray import ray._private.ray_constants as ray_constants + try: import GPUtil except ImportError: @@ -198,6 +201,8 @@ def resolve(self, is_head: bool, node_ip_address: Optional[str] = None): except Exception: logger.exception("Could not parse gpu information.") + accelerator.update_resources_with_accelerator_type(resources) + # Choose a default object store size. system_memory = ray._private.utils.get_system_memory() avail_memory = ray._private.utils.estimate_available_memory() @@ -305,12 +310,9 @@ def _get_gpu_types_gputil(): if len(gpu_list) > 0: gpu_list_names = [gpu.name for gpu in gpu_list] info_str = gpu_list_names.pop() - pretty_name = _pretty_gpu_name(info_str) + pretty_name = _pretty_nvidia_gpu_name(info_str) if pretty_name: - constraint_name = ( - f"{ray_constants.RESOURCE_CONSTRAINT_PREFIX}" f"{pretty_name}" - ) - return {constraint_name: 1} + return {ray._private.utils.get_constraint_name(pretty_name): 1} return {} @@ -336,11 +338,8 @@ def _constraints_from_gpu_info(info_str: str): if k.strip() == "Model": full_model_name = v.strip() break - pretty_name = _pretty_gpu_name(full_model_name) - if pretty_name: - constraint_name = f"{ray_constants.RESOURCE_CONSTRAINT_PREFIX}" f"{pretty_name}" - return {constraint_name: 1} - return {} + pretty_name = _pretty_nvidia_gpu_name(full_model_name) + return {ray._private.utils.get_constraint_name(pretty_name): 1} def _get_gpu_info_string(): @@ -364,11 +363,11 @@ def _get_gpu_info_string(): # TODO(Alex): This pattern may not work for non NVIDIA Tesla GPUs (which have # the form "Tesla V100-SXM2-16GB" or "Tesla K80"). 
-GPU_NAME_PATTERN = re.compile(r"\w+\s+([A-Z0-9]+)") +NVIDIA_GPU_NAME_PATTERN = re.compile(r"\w+\s+([A-Z0-9]+)") -def _pretty_gpu_name(name): +def _pretty_nvidia_gpu_name(name): if name is None: return None - match = GPU_NAME_PATTERN.match(name) + match = NVIDIA_GPU_NAME_PATTERN.match(name) return match.group(1) if match else None diff --git a/python/ray/_private/utils.py b/python/ray/_private/utils.py index fb767df4b672..4d90dc4610c0 100644 --- a/python/ray/_private/utils.py +++ b/python/ray/_private/utils.py @@ -33,6 +33,7 @@ Union, Coroutine, List, + Mapping, ) # Import psutil after ray so the packaged version is used. @@ -48,7 +49,6 @@ if TYPE_CHECKING: from ray.runtime_env import RuntimeEnv - pwd = None if sys.platform != "win32": import pwd @@ -64,7 +64,6 @@ win32_job = None win32_AssignProcessToJobObject = None - ENV_DISABLE_DOCKER_CPU_WARNING = "RAY_DISABLE_DOCKER_CPU_WARNING" in os.environ _PYARROW_VERSION = None @@ -72,7 +71,6 @@ _CALLED_FREQ = defaultdict(lambda: 0) _CALLED_FREQ_LOCK = threading.Lock() - PLACEMENT_GROUP_INDEXED_BUNDLED_RESOURCE_PATTERN = re.compile( r"(.+)_group_(\d+)_([0-9a-zA-Z]+)" ) @@ -278,30 +276,82 @@ def compute_driver_id_from_job(job_id): return ray.WorkerID(driver_id_str) -def get_cuda_visible_devices(): - """Get the device IDs in the CUDA_VISIBLE_DEVICES environment variable. +def get_gpu_and_accelerator_runtime_ids() -> Mapping[str, Optional[List[str]]]: + """ + Get the device IDs of GPUs (CUDA), accelerators(NeuronCore) using + (CUDA_VISIBLE_DEVICES, NEURON_RT_VISIBLE_CORES) environment variables. + + Returns: + A dictionary with keys: + - ray_constants.GPU: The list of device IDs of GPUs. + - ray_constants.NEURON_CORES: The list of device IDs of + accelerators. + If either of the environment variables is not set, returns None for + corresponding key. + """ + return { + ray_constants.GPU: get_cuda_visible_devices(), + ray_constants.NEURON_CORES: get_aws_neuron_core_visible_ids(), + } + + +def get_cuda_visible_devices() -> Optional[List[str]]: + """ + Get the device IDs using CUDA_VISIBLE_DEVICES environment variable. Returns: - devices (List[str]): If CUDA_VISIBLE_DEVICES is set, returns a - list of strings representing the IDs of the visible GPUs. + devices (List[str]): If environment variable is set, returns a + list of strings representing the IDs of the visible devices. If it is not set or is set to NoDevFiles, returns empty list. """ - gpu_ids_str = os.environ.get("CUDA_VISIBLE_DEVICES", None) + return _get_visible_ids(env_var=ray_constants.CUDA_VISIBLE_DEVICES_ENV_VAR) + + +def get_aws_neuron_core_visible_ids() -> Optional[List[str]]: + """ + Get the device IDs using NEURON_RT_VISIBLE_CORES environment variable. - if gpu_ids_str is None: + Returns: + devices (List[str]): If environment variable is set, returns a + list of strings representing the IDs of the visible devices. + If it is not set or is set to NoDevFiles, returns empty list. + """ + return _get_visible_ids(env_var=ray_constants.NEURON_RT_VISIBLE_CORES_ENV_VAR) + + +def _get_visible_ids(env_var: str) -> Optional[List[str]]: + """Get the device IDs from defined environment variable. + Args: + env_var: Environment variable (e.g., CUDA_VISIBLE_DEVICES, + NEURON_RT_VISIBLE_CORES) to set based on the accelerator runtime. + + Returns: + devices (List[str]): If environment variable is set, returns a + list of strings representing the IDs of the visible devices or cores. + If it is not set or is set to NoDevFiles, returns empty list. 
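+
+ Raises:
+ ValueError: If ``env_var`` is not one of the supported environment
+ variables.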
+ """ + if env_var not in ( + ray_constants.CUDA_VISIBLE_DEVICES_ENV_VAR, + ray_constants.NEURON_RT_VISIBLE_CORES_ENV_VAR, + ): + raise ValueError(f"Invalid environment variable {env_var} to get visible IDs.") + visible_ids_str = os.environ.get(env_var, None) + + if visible_ids_str is None: return None - if gpu_ids_str == "": + if visible_ids_str == "": return [] - if gpu_ids_str == "NoDevFiles": + if visible_ids_str == "NoDevFiles": return [] - # GPU identifiers are given as strings representing integers or UUIDs. - return list(gpu_ids_str.split(",")) + # Identifiers are given as strings representing integers or UUIDs. + return list(visible_ids_str.split(",")) last_set_gpu_ids = None +last_set_neuron_core_ids = None def set_omp_num_threads_if_unset() -> bool: @@ -344,24 +394,92 @@ def set_omp_num_threads_if_unset() -> bool: return True -def set_cuda_visible_devices(gpu_ids): +def set_cuda_visible_devices(gpu_ids: List[str]): """Set the CUDA_VISIBLE_DEVICES environment variable. Args: gpu_ids (List[str]): List of strings representing GPU IDs. """ - if os.environ.get(ray_constants.NOSET_CUDA_VISIBLE_DEVICES_ENV_VAR): return - global last_set_gpu_ids if last_set_gpu_ids == gpu_ids: return # optimization: already set - - os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(i) for i in gpu_ids]) + _set_visible_ids(gpu_ids, ray_constants.CUDA_VISIBLE_DEVICES_ENV_VAR) last_set_gpu_ids = gpu_ids +def set_gpu_and_accelerator_runtime_ids() -> None: + """Set (CUDA_VISIBLE_DEVICES, NEURON_RT_VISIBLE_CORES, ..) environment variables + based on the accelerator runtime. + + Raises: + ValueError: If the environment variable is set to a different + environment variable. + """ + ids = ray.get_runtime_context().get_resource_ids() + set_cuda_visible_devices(ids[ray_constants.GPU]) + set_aws_neuron_core_visible_ids(ids[ray_constants.NEURON_CORES]) + + +def set_aws_neuron_core_visible_ids(neuron_core_ids: List[str]) -> None: + """Set the NEURON_RT_VISIBLE_CORES environment variable based on + given neuron_core_ids. + + Args: + neuron_core_ids (List[str]): List of int representing core IDs. + """ + if os.environ.get(ray_constants.NOSET_AWS_NEURON_RT_VISIBLE_CORES_ENV_VAR): + return + global last_set_neuron_core_ids + if last_set_neuron_core_ids == neuron_core_ids: + return # optimization: already set + _set_visible_ids(neuron_core_ids, ray_constants.NEURON_RT_VISIBLE_CORES_ENV_VAR) + last_set_neuron_core_ids = neuron_core_ids + + +def _set_visible_ids(visible_ids: List[str], env_var: str): + """Set the environment variable (e.g., CUDA_VISIBLE_DEVICES, NEURON_RT_VISIBLE_CORES) + passed based on accelerator runtime and will raise an error if the function uses + different environment variable. + + Args: + visible_ids (List[str]): List of strings representing GPU IDs or NeuronCore IDs. + env_var: Environment variable to set based on accelerator runtime. + + """ + if env_var not in ( + ray_constants.CUDA_VISIBLE_DEVICES_ENV_VAR, + ray_constants.NEURON_RT_VISIBLE_CORES_ENV_VAR, + ): + raise ValueError(f"Invalid environment variable {env_var} to set visible IDs.") + os.environ[env_var] = ",".join([str(i) for i in visible_ids]) + + +def get_neuron_core_constraint_name(): + """Get the name of the constraint that represents the AWS Neuron core accelerator. + + Returns: + (str) The constraint name. 
+ """ + import ray.util.accelerators.accelerators as accelerators + + return get_constraint_name(accelerators.AWS_NEURON_CORE) + + +def get_constraint_name(pretty_name: str): + """Get the name of the constraint that represents the given resource. + + Args: + pretty_name: The name of the resource. + + Returns: + (str) The constraint name. + """ + constraint_name = f"{ray_constants.RESOURCE_CONSTRAINT_PREFIX}" f"{pretty_name}" + return constraint_name + + def resources_from_ray_options(options_dict: Dict[str, Any]) -> Dict[str, Any]: """Determine a task's resource requirements. diff --git a/python/ray/_private/worker.py b/python/ray/_private/worker.py index 32cf18a575db..d8389d3ae36e 100644 --- a/python/ray/_private/worker.py +++ b/python/ray/_private/worker.py @@ -426,8 +426,10 @@ def __init__(self): self.mode = None self.actors = {} # When the worker is constructed. Record the original value of the - # CUDA_VISIBLE_DEVICES environment variable. - self.original_gpu_ids = ray._private.utils.get_cuda_visible_devices() + # (CUDA_VISIBLE_DEVICES, NEURON_RT_VISIBLE_CORES, ..) environment variables. + self.original_gpu_and_accelerator_runtime_ids = ( + ray._private.utils.get_gpu_and_accelerator_runtime_ids() + ) # A dictionary that maps from driver id to SerializationContext # TODO: clean up the SerializationContext once the job finished. self.serialization_context_map = {} @@ -832,6 +834,54 @@ def print_logs(self): # Close the pubsub client to avoid leaking file descriptors. subscriber.close() + def get_resource_ids_for_resource( + self, resource_name: str, resource_regex: str + ) -> Union[List[str], List[int]]: + """Get the resource IDs that are assigned to the given resource. + + Args: + resource_name: The name of the resource. + resource_regex: The regex of the resource. + + Returns: + (List[str]) The IDs that are assigned to the given resource pre-configured. + (List[int]) The IDs that are assigned to the given resource. + + """ + resource_ids = self.core_worker.resource_ids() + assigned_ids = set() + # Handle both normal and placement group GPU, accelerator resources. + # Note: We should only get the GPU, accelerator ids from the placement + # group resource that does not contain the bundle index! + import re + + for resource, assignment in resource_ids.items(): + if resource == resource_name or re.match(resource_regex, resource): + for resource_id, _ in assignment: + assigned_ids.add(resource_id) + + # If the user had already set the environment variables + # (CUDA_VISIBLE_DEVICES, NEURON_RT_VISIBLE_CORES, ..) then respect that + # in the sense that only IDs that appear in (CUDA_VISIBLE_DEVICES, + # NEURON_RT_VISIBLE_CORES, ..) should be returned. + if ( + self.original_gpu_and_accelerator_runtime_ids.get(resource_name, None) + is not None + ): + runtime_ids = self.original_gpu_and_accelerator_runtime_ids[resource_name] + assigned_ids = [str(runtime_ids[i]) for i in assigned_ids] + # Give all accelerator ids local_mode. 
+ if self.mode == LOCAL_MODE: + if resource_name == ray_constants.GPU: + max_runtime_ids = self.node.get_resource_spec().num_gpus + else: + max_runtime_ids = self.node.get_resource_spec().resources.get( + resource_name, None + ) + if max_runtime_ids: + assigned_ids = runtime_ids[:max_runtime_ids] + return list(assigned_ids) + @PublicAPI @client_mode_hook @@ -848,42 +898,9 @@ def get_gpu_ids(): """ worker = global_worker worker.check_connected() - - if worker.mode != WORKER_MODE: - if log_once("worker_get_gpu_ids_empty_from_driver"): - logger.warning( - "`ray.get_gpu_ids()` will always return the empty list when " - "called from the driver. This is because Ray does not manage " - "GPU allocations to the driver process." - ) - - # TODO(ilr) Handle inserting resources in local mode - all_resource_ids = global_worker.core_worker.resource_ids() - assigned_ids = set() - for resource, assignment in all_resource_ids.items(): - # Handle both normal and placement group GPU resources. - # Note: We should only get the GPU ids from the placement - # group resource that does not contain the bundle index! - import re - - if resource == "GPU" or re.match(r"^GPU_group_[0-9A-Za-z]+$", resource): - for resource_id, _ in assignment: - assigned_ids.add(resource_id) - - assigned_ids = list(assigned_ids) - # If the user had already set CUDA_VISIBLE_DEVICES, then respect that (in - # the sense that only GPU IDs that appear in CUDA_VISIBLE_DEVICES should be - # returned). - if global_worker.original_gpu_ids is not None: - assigned_ids = [ - global_worker.original_gpu_ids[gpu_id] for gpu_id in assigned_ids - ] - # Give all GPUs in local_mode. - if global_worker.mode == LOCAL_MODE: - max_gpus = global_worker.node.get_resource_spec().num_gpus - assigned_ids = global_worker.original_gpu_ids[:max_gpus] - - return assigned_ids + return worker.get_resource_ids_for_resource( + ray_constants.GPU, f"^{ray_constants.GPU}_group_[0-9A-Za-z]+$" + ) @Deprecated( diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index a96ef61e2c12..43ada4130442 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -1819,8 +1819,9 @@ cdef execute_task_with_cancellation_handler( task_name = name.decode("utf-8") title = f"ray::{task_name}" - # Automatically restrict the GPUs available to this task. - ray._private.utils.set_cuda_visible_devices(ray.get_gpu_ids()) + # Automatically restrict the GPUs (CUDA), neuron_core accelerator + # runtime_ids to restrict availability to this task. + ray._private.utils.set_gpu_and_accelerator_runtime_ids() # Automatically configure OMP_NUM_THREADS to the assigned CPU number. # It will be unset after the task execution if it was overwridden here. diff --git a/python/ray/autoscaler/_private/aws/node_provider.py b/python/ray/autoscaler/_private/aws/node_provider.py index 631018c12512..e633e3e592ef 100644 --- a/python/ray/autoscaler/_private/aws/node_provider.py +++ b/python/ray/autoscaler/_private/aws/node_provider.py @@ -10,6 +10,7 @@ from boto3.resources.base import ServiceResource import ray._private.ray_constants as ray_constants +from ray._private.utils import get_neuron_core_constraint_name from ray.autoscaler._private.aws.cloudwatch.cloudwatch_helper import ( CLOUDWATCH_AGENT_INSTALLED_AMI_TAG, CLOUDWATCH_AGENT_INSTALLED_TAG, @@ -648,6 +649,26 @@ def fillout_available_node_types_resources( autodetected_resources.update( {"GPU": gpus[0]["Count"], f"accelerator_type:{gpu_name}": 1} ) + # TODO: AWS SDK (public API) doesn't yet expose the NeuronCore + # information. 
It will be available (work-in-progress) + # as xxAcceleratorInfo in InstanceTypeInfo. + # https://docs.aws.amazon.com/AWSEC2/latest/APIReference/API_InstanceTypeInfo.html + # See https://github.com/ray-project/ray/issues/38473 + if ( + instance_type.lower() + in ray_constants.AWS_NEURON_INSTANCE_MAP.keys() + and gpus is None + ): + neuron_cores = ray_constants.AWS_NEURON_INSTANCE_MAP.get( + instance_type.lower() + ) + autodetected_resources.update( + { + ray_constants.NEURON_CORES: neuron_cores, + get_neuron_core_constraint_name(): neuron_cores, + } + ) + autodetected_resources.update( available_node_types[node_type].get("resources", {}) ) diff --git a/python/ray/runtime_context.py b/python/ray/runtime_context.py index 1ec4a3851196..ae6a1a14fc50 100644 --- a/python/ray/runtime_context.py +++ b/python/ray/runtime_context.py @@ -1,7 +1,8 @@ import logging -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional import ray._private.worker +from ray._private import ray_constants from ray._private.client_mode_hook import client_mode_hook from ray._private.utils import pasre_pg_formatted_resources_to_original from ray.runtime_env import RuntimeEnv @@ -381,6 +382,27 @@ def _get_actor_call_stats(self): worker.check_connected() return worker.core_worker.get_actor_call_stats() + def get_resource_ids(self) -> Dict[str, List[str]]: + """ + Get the current worker's GPU and accelerator ids. + + Returns: + A dictionary keyed by the resource name. The values are list + of ids `{'GPU': ['0', '1'], 'neuron_cores': ['0', '1']}`. + """ + worker = self.worker + worker.check_connected() + ids_dict: Dict[str, List[str]] = {} + for name in [ray_constants.GPU, ray_constants.NEURON_CORES]: + resource_ids = worker.get_resource_ids_for_resource( + name, f"^{name}_group_[0-9A-Za-z]+$" + ) + # Convert resource_ids to strings as they can be user-configured + # or system-generated. 
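+ # (e.g., plain integer indices or device UUIDs).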
+ resource_ids = [str(i) for i in resource_ids] + ids_dict[name] = resource_ids + return ids_dict + _runtime_context = None diff --git a/python/ray/tests/BUILD b/python/ray/tests/BUILD index d67186a928d8..b08b720cd3d7 100644 --- a/python/ray/tests/BUILD +++ b/python/ray/tests/BUILD @@ -123,6 +123,7 @@ py_test_module_list( files = [ "test_actor_bounded_threads.py", "test_autoscaler_fake_scaledown.py", + "test_accelerator.py", "test_log_dedup.py", "test_logging.py", "test_memory_scheduling.py", diff --git a/python/ray/tests/test_accelerator.py b/python/ray/tests/test_accelerator.py new file mode 100644 index 000000000000..33a5f4f4d2ce --- /dev/null +++ b/python/ray/tests/test_accelerator.py @@ -0,0 +1,99 @@ +import mock +import pytest + +import ray._private.accelerator as accelerator +import ray._private.utils as utils +import ray._private.ray_constants as ray_constants + + +def test_configured_aws_neuron_core(): + resources = {"CPU": 1, "neuron_cores": 4} + accelerator.update_resources_with_accelerator_type(resources) + assert resources.get(utils.get_neuron_core_constraint_name()) == 4 + assert resources.get(ray_constants.NEURON_CORES) == 4 + + +@mock.patch( + "ray._private.utils.get_aws_neuron_core_visible_ids", return_value=[0, 1, 2] +) +def test_aws_neuron_core_with_more_user_configured(mock_get_nc_ids): + resources = {"CPU": 1, "neuron_cores": 4} + with pytest.raises(ValueError): + accelerator.update_resources_with_accelerator_type(resources) + assert mock_get_nc_ids.called + + +@mock.patch("ray._private.accelerator._autodetect_aws_neuron_cores", return_value=2) +def test_auto_detect_aws_neuron_core(mock_autodetect_aws_neuron_cores): + resources = {"CPU": 1} + accelerator.update_resources_with_accelerator_type(resources) + assert mock_autodetect_aws_neuron_cores.called + assert resources.get(utils.get_neuron_core_constraint_name()) == 2 + assert resources.get(ray_constants.NEURON_CORES) == 2 + + +@mock.patch( + "ray._private.utils.get_aws_neuron_core_visible_ids", return_value=[0, 1, 2] +) +@mock.patch("ray._private.accelerator._autodetect_aws_neuron_cores", return_value=4) +def test_auto_detect_nc_with_more_user_configured( + mock_get_nc_ids, mock_autodetect_aws_neuron_cores +): + resources = {"CPU": 1} + accelerator.update_resources_with_accelerator_type(resources) + assert mock_get_nc_ids.called + assert mock_autodetect_aws_neuron_cores.called + assert resources.get(utils.get_neuron_core_constraint_name()) == 3 + assert resources.get(ray_constants.NEURON_CORES) == 3 + + +@mock.patch("subprocess.run") +def test_get_neuron_core_count_single_device(mock_subprocess): + mock_subprocess.return_value.returncode = 0 + mock_subprocess.return_value.stdout = ( + b'[{"neuron_device":0,"bdf":"00:1e.0",' + b'"connected_to":null,"nc_count":2,' + b'"memory_size":34359738368,"neuron_processes":[]}]' + ) + assert accelerator._get_neuron_core_count() == 2 + assert mock_subprocess.called + + +@mock.patch("subprocess.run") +def test_get_neuron_core_count_multiple_devices(mock_subprocess): + mock_subprocess.return_value.returncode = 0 + mock_subprocess.return_value.stdout = ( + b'[{"neuron_device":0,"bdf":"00:1e.0",' + b'"connected_to":null,"nc_count":2,' + b'"memory_size":34359738368,"neuron_processes":[]},' + b'{"neuron_device":1,"bdf":"00:1f.0","connected_to":null,' + b'"nc_count":2,"memory_size":34359738368,"neuron_processes":[]}]' + ) + assert accelerator._get_neuron_core_count() == 4 + assert mock_subprocess.called + + +@mock.patch("subprocess.run") +def 
test_get_neuron_core_count_failure_with_error(mock_subprocess): + mock_subprocess.return_value.returncode = 1 + mock_subprocess.return_value.stderr = b"AccessDenied" + assert accelerator._get_neuron_core_count() == 0 + assert mock_subprocess.called + + +@mock.patch("subprocess.run") +def test_get_neuron_core_count_failure_with_empty_results(mock_subprocess): + mock_subprocess.return_value.returncode = 0 + mock_subprocess.return_value.stdout = b"[{}]" + assert accelerator._get_neuron_core_count() == 0 + assert mock_subprocess.called + + +if __name__ == "__main__": + import sys + import os + + if os.environ.get("PARALLEL_CI"): + sys.exit(pytest.main(["-n", "auto", "--boxed", "-vs", __file__])) + else: + sys.exit(pytest.main(["-sv", __file__])) diff --git a/python/ray/tests/test_advanced_2.py b/python/ray/tests/test_advanced_2.py index 8bb2838c61e6..9be2733a2649 100644 --- a/python/ray/tests/test_advanced_2.py +++ b/python/ray/tests/test_advanced_2.py @@ -10,6 +10,9 @@ import ray import ray.cluster_utils from ray._private.test_utils import RayTestTimeoutException, wait_for_condition +from ray.util.placement_group import placement_group +from ray.util.accelerators import AWS_NEURON_CORE +from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy logger = logging.getLogger(__name__) @@ -21,6 +24,12 @@ def test_gpu_ids(shutdown_only): def get_gpu_ids(num_gpus_per_worker): gpu_ids = ray.get_gpu_ids() assert len(gpu_ids) == num_gpus_per_worker + neuron_core_ids = ray.get_runtime_context().get_resource_ids()["neuron_cores"] + gpu_ids_from_runtime_context = ray.get_runtime_context().get_resource_ids()[ + "GPU" + ] + assert len(gpu_ids) == len(gpu_ids_from_runtime_context) + assert len(neuron_core_ids) == 0 assert os.environ["CUDA_VISIBLE_DEVICES"] == ",".join( [str(i) for i in gpu_ids] # noqa ) @@ -477,6 +486,229 @@ def f(): ray.get(results) +def test_neuron_core_ids(shutdown_only): + num_nc = 3 + accelerator_type = AWS_NEURON_CORE + ray.init(num_cpus=num_nc, resources={"neuron_cores": num_nc}) + + def get_neuron_core_ids(neuron_cores_per_worker): + neuron_core_ids = ray.get_runtime_context().get_resource_ids()["neuron_cores"] + gpu_ids = ray.get_gpu_ids() + assert len(neuron_core_ids) == neuron_cores_per_worker + assert len(gpu_ids) == 0 + cores = os.environ.get("NEURON_RT_VISIBLE_CORES") + if cores is not None: + assert cores == ",".join([str(i) for i in neuron_core_ids]) # noqa + for neuron_core_id in neuron_core_ids: + assert neuron_core_id in [str(i) for i in range(num_nc)] + return neuron_core_ids + + f0 = ray.remote(resources={"neuron_cores": 0})(lambda: get_neuron_core_ids(0)) + f1 = ray.remote(resources={"neuron_cores": 1})(lambda: get_neuron_core_ids(1)) + f2 = ray.remote(resources={"neuron_cores": 2})(lambda: get_neuron_core_ids(2)) + + # Wait for all workers to start up. + @ray.remote + def g(): + time.sleep(0.2) + return os.getpid() + + start_time = time.time() + while True: + num_workers_started = len(set(ray.get([g.remote() for _ in range(num_nc)]))) + if num_workers_started == num_nc: + break + if time.time() > start_time + 10: + raise RayTestTimeoutException( + "Timed out while waiting for workers to start up." + ) + + list_of_ids = ray.get([f0.remote() for _ in range(10)]) + assert list_of_ids == 10 * [[]] + ray.get([f1.remote() for _ in range(10)]) + ray.get([f2.remote() for _ in range(10)]) + + # Test that actors have NEURON_RT_VISIBLE_CORES set properly. 
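+ # Actor0, Actor1, and Actor2 below request 0, 1, and 2 neuron_cores
+ # respectively and check that the runtime context matches the env var.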
+ + @ray.remote + class Actor0: + def __init__(self): + neuron_core_ids = ray.get_runtime_context().get_resource_ids()[ + "neuron_cores" + ] + assert len(neuron_core_ids) == 0 + assert os.environ["NEURON_RT_VISIBLE_CORES"] == ",".join( + [str(i) for i in neuron_core_ids] # noqa + ) + # Set self.x to make sure that we got here. + self.x = 0 + + def test(self): + neuron_core_ids = ray.get_runtime_context().get_resource_ids()[ + "neuron_cores" + ] + assert len(neuron_core_ids) == 0 + assert os.environ["NEURON_RT_VISIBLE_CORES"] == ",".join( + [str(i) for i in neuron_core_ids] # noqa + ) + return self.x + + @ray.remote(resources={"neuron_cores": 1}) + class Actor1: + def __init__(self): + neuron_core_ids = ray.get_runtime_context().get_resource_ids()[ + "neuron_cores" + ] + assert len(neuron_core_ids) == 1 + assert os.environ["NEURON_RT_VISIBLE_CORES"] == ",".join( + [str(i) for i in neuron_core_ids] # noqa + ) + # Set self.x to make sure that we got here. + self.x = 1 + + def test(self): + neuron_core_ids = ray.get_runtime_context().get_resource_ids()[ + "neuron_cores" + ] + assert len(neuron_core_ids) == 1 + assert os.environ["NEURON_RT_VISIBLE_CORES"] == ",".join( + [str(i) for i in neuron_core_ids] + ) + return self.x + + @ray.remote(resources={"neuron_cores": 2}, accelerator_type=accelerator_type) + class Actor2: + def __init__(self): + neuron_core_ids = ray.get_runtime_context().get_resource_ids()[ + "neuron_cores" + ] + assert len(neuron_core_ids) == 2 + assert os.environ["NEURON_RT_VISIBLE_CORES"] == ",".join( + [str(i) for i in neuron_core_ids] + ) + # Set self.x to make sure that we got here. + self.x = 2 + + def test(self): + neuron_core_ids = ray.get_runtime_context().get_resource_ids()[ + "neuron_cores" + ] + assert len(neuron_core_ids) == 2 + assert os.environ["NEURON_RT_VISIBLE_CORES"] == ",".join( + [str(i) for i in neuron_core_ids] + ) + return self.x + + a0 = Actor0.remote() + assert ray.get(a0.test.remote()) == 0 + + a1 = Actor1.remote() + assert ray.get(a1.test.remote()) == 1 + + a2 = Actor2.remote() + assert ray.get(a2.test.remote()) == 2 + + +def test_neuron_core_with_placement_group(shutdown_only): + neuron_cores = 2 + ray.init(num_cpus=1, resources={"neuron_cores": neuron_cores}) + + @ray.remote(resources={"neuron_cores": neuron_cores}) + class NeuronCoreActor: + def __init__(self): + pass + + def ready(self): + neuron_core_ids = ray.get_runtime_context().get_resource_ids()[ + "neuron_cores" + ] + assert len(neuron_core_ids) == neuron_cores + assert os.environ["NEURON_RT_VISIBLE_CORES"] == ",".join( + [str(i) for i in neuron_core_ids] # noqa + ) + + # Reserve a placement group of 1 bundle that reserves 1 CPU and 2 NeuronCore. + pg = placement_group([{"CPU": 1, "neuron_cores": neuron_cores}]) + + # Wait until placement group is created. 
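+ # pg.ready() returns an ObjectRef that resolves once the bundle is placed.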
+ ray.get(pg.ready(), timeout=10) + + actor = NeuronCoreActor.options( + scheduling_strategy=PlacementGroupSchedulingStrategy( + placement_group=pg, + ) + ).remote() + + ray.get(actor.ready.remote(), timeout=10) + + +def test_gpu_and_neuron_cores(shutdown_only): + num_gpus = 2 + num_nc = 2 + nc_accelerator_type = AWS_NEURON_CORE + ray.init(num_cpus=2, num_gpus=num_gpus, resources={"neuron_cores": num_nc}) + + def get_gpu_ids(num_gpus_per_worker): + gpu_ids = ray.get_gpu_ids() + assert len(gpu_ids) == num_gpus_per_worker + assert os.environ["CUDA_VISIBLE_DEVICES"] == ",".join( + [str(i) for i in gpu_ids] # noqa + ) + for gpu_id in gpu_ids: + assert gpu_id in range(num_gpus) + gpu_ids_from_runtime_context = ray.get_runtime_context().get_resource_ids()[ + "GPU" + ] + for gpu_id in gpu_ids_from_runtime_context: + assert gpu_id in [str(i) for i in range(num_gpus)] + return len(gpu_ids) + + def get_neuron_core_ids(neuron_cores_per_worker): + neuron_core_ids = ray.get_runtime_context().get_resource_ids()["neuron_cores"] + assert len(neuron_core_ids) == neuron_cores_per_worker + cores = os.environ.get("NEURON_RT_VISIBLE_CORES") + if cores is not None: + assert cores == ",".join([str(i) for i in neuron_core_ids]) # noqa + for neuron_core_id in neuron_core_ids: + assert neuron_core_id in [str(i) for i in range(num_nc)] + return len(neuron_core_ids) + + gpu_f = ray.remote(num_gpus=2)(lambda: get_gpu_ids(2)) + assert ray.get(gpu_f.remote()) == 2 + nc_f = ray.remote(resources={"neuron_cores": 2})(lambda: get_neuron_core_ids(2)) + assert ray.get(nc_f.remote()) == 2 + + with pytest.raises(ValueError): + ray.remote(resources={"neuron_cores": 2}, num_gpus=1)( + lambda: get_neuron_core_ids(2) + ) + + with pytest.raises(ValueError): + ray.remote(accelerator_type=nc_accelerator_type, num_gpus=1)( + lambda: get_neuron_core_ids(2) + ) + + with pytest.raises(ValueError): + + @ray.remote(resources={"neuron_cores": 2}, num_gpus=2) + class IncorrectNeuronCoreActorWithGPU: + def test(self): + neuron_core_ids = ray.get_runtime_context().get_resource_ids()[ + "neuron_cores" + ] + return len(neuron_core_ids) + + with pytest.raises(ValueError): + + @ray.remote(accelerator_type=nc_accelerator_type, num_gpus=2) + class IncorrectNeuronCoreAcceleratorWithGPU: + def test(self): + neuron_core_ids = ray.get_runtime_context().get_resource_ids()[ + "neuron_cores" + ] + return len(neuron_core_ids) + + # TODO: 5 retry attempts may be too little for Travis and we may need to # increase it if this test begins to be flaky on Travis. 
def test_zero_capacity_deletion_semantics(shutdown_only): diff --git a/python/ray/tests/test_advanced_8.py b/python/ray/tests/test_advanced_8.py index aaabf0367b97..526f082902ee 100644 --- a/python/ray/tests/test_advanced_8.py +++ b/python/ray/tests/test_advanced_8.py @@ -242,14 +242,17 @@ def test_gpu_info_parsing(): assert resource_spec._constraints_from_gpu_info(None) == {} -def test_accelerator_type_api(shutdown_only): - v100 = ray.util.accelerators.NVIDIA_TESLA_V100 - resource_name = f"{ray_constants.RESOURCE_CONSTRAINT_PREFIX}{v100}" +@pytest.mark.parametrize( + "accelerator_type", + [ray.util.accelerators.NVIDIA_TESLA_V100, ray.util.accelerators.AWS_NEURON_CORE], +) +def test_accelerator_type_api(accelerator_type, shutdown_only): + resource_name = f"{ray_constants.RESOURCE_CONSTRAINT_PREFIX}{accelerator_type}" ray.init(num_cpus=4, resources={resource_name: 1}) quantity = 1 - @ray.remote(accelerator_type=v100) + @ray.remote(accelerator_type=accelerator_type) def decorated_func(quantity): wait_for_condition(lambda: ray.available_resources()[resource_name] < quantity) return True @@ -261,10 +264,12 @@ def via_options_func(quantity): return True assert ray.get( - ray.remote(via_options_func).options(accelerator_type=v100).remote(quantity) + ray.remote(via_options_func) + .options(accelerator_type=accelerator_type) + .remote(quantity) ) - @ray.remote(accelerator_type=v100) + @ray.remote(accelerator_type=accelerator_type) class DecoratedActor: def __init__(self): pass @@ -286,7 +291,9 @@ def initialized(self): wait_for_condition(lambda: ray.available_resources()[resource_name] < quantity) quantity = ray.available_resources()[resource_name] - with_options = ray.remote(ActorWithOptions).options(accelerator_type=v100).remote() + with_options = ( + ray.remote(ActorWithOptions).options(accelerator_type=accelerator_type).remote() + ) ray.get(with_options.initialized.remote()) wait_for_condition(lambda: ray.available_resources()[resource_name] < quantity) diff --git a/python/ray/tests/test_autoscaler_yaml.py b/python/ray/tests/test_autoscaler_yaml.py index 089acaf2ef95..e6ea2ca04d84 100644 --- a/python/ray/tests/test_autoscaler_yaml.py +++ b/python/ray/tests/test_autoscaler_yaml.py @@ -148,6 +148,13 @@ def testValidateDefaultConfigAWSMultiNodeTypes(self): "cpu_4_ondemand": new_config["available_node_types"]["cpu_4_ondemand"], "cpu_16_spot": new_config["available_node_types"]["cpu_16_spot"], "gpu_8_ondemand": new_config["available_node_types"]["gpu_8_ondemand"], + "neuron_core_inf_1_ondemand": { + "node_config": { + "InstanceType": "inf2.xlarge", + "ImageId": "latest_dlami", + }, + "max_workers": 2, + }, } orig_new_config = copy.deepcopy(new_config) expected_available_node_types = orig_new_config["available_node_types"] @@ -164,8 +171,15 @@ def testValidateDefaultConfigAWSMultiNodeTypes(self): "GPU": 4, "accelerator_type:V100": 1, } + expected_available_node_types["neuron_core_inf_1_ondemand"]["resources"] = { + "CPU": 4, + "memory": 12025908428, + "neuron_cores": 2, + "accelerator_type:aws-neuron-core": 2, + } expected_available_node_types["cpu_16_spot"]["min_workers"] = 0 expected_available_node_types["gpu_8_ondemand"]["min_workers"] = 0 + expected_available_node_types["neuron_core_inf_1_ondemand"]["min_workers"] = 0 boto3_dict = { "InstanceTypes": [ @@ -185,6 +199,14 @@ def testValidateDefaultConfigAWSMultiNodeTypes(self): "MemoryInfo": {"SizeInMiB": 249856}, "GpuInfo": {"Gpus": [{"Name": "V100", "Count": 4}]}, }, + { + "InstanceType": "inf2.xlarge", + "VCpuInfo": {"DefaultVCpus": 4}, + 
"MemoryInfo": {"SizeInMiB": 16384}, + "AcceleratorInfo": { + "Accelerators": [{"Name": "Inferentia", "Count": 1}] + }, + }, ] } describe_instance_types_mock = Mock() diff --git a/python/ray/util/accelerators/__init__.py b/python/ray/util/accelerators/__init__.py index 6d125fd6503a..c30f5936e0df 100644 --- a/python/ray/util/accelerators/__init__.py +++ b/python/ray/util/accelerators/__init__.py @@ -6,6 +6,7 @@ NVIDIA_TESLA_K80, NVIDIA_TESLA_A100, NVIDIA_TESLA_A10G, + AWS_NEURON_CORE, ) __all__ = [ @@ -16,4 +17,5 @@ "NVIDIA_TESLA_K80", "NVIDIA_TESLA_A100", "NVIDIA_TESLA_A10G", + "AWS_NEURON_CORE", ] diff --git a/python/ray/util/accelerators/accelerators.py b/python/ray/util/accelerators/accelerators.py index a8ffd1dc97f4..291b2222b157 100644 --- a/python/ray/util/accelerators/accelerators.py +++ b/python/ray/util/accelerators/accelerators.py @@ -5,3 +5,4 @@ NVIDIA_TESLA_K80 = "K80" NVIDIA_TESLA_A100 = "A100" NVIDIA_TESLA_A10G = "A10G" +AWS_NEURON_CORE = "aws-neuron-core" diff --git a/src/ray/common/ray_config_def.h b/src/ray/common/ray_config_def.h index 56a8f11ab453..20bd6347d31d 100644 --- a/src/ray/common/ray_config_def.h +++ b/src/ray/common/ray_config_def.h @@ -688,9 +688,9 @@ RAY_CONFIG(uint32_t, RAY_CONFIG(std::string, predefined_unit_instance_resources, "GPU") /// The scheduler will treat these custom resource types as unit_instance. -/// Default custom_unit_instance_resources is empty. -/// When set it to "FPGA", we will treat FPGA as unit_instance. -RAY_CONFIG(std::string, custom_unit_instance_resources, "") +/// Default custom_unit_instance_resources is "neuron_cores". +/// When set it to "neuron_cores,FPGA", we will also treat FPGA as unit_instance. +RAY_CONFIG(std::string, custom_unit_instance_resources, "neuron_cores") // Maximum size of the batches when broadcasting resources to raylet. RAY_CONFIG(uint64_t, resource_broadcast_batch_size, 512) diff --git a/src/ray/common/test/scheduling_ids_test.cc b/src/ray/common/test/scheduling_ids_test.cc index be99abc6d199..d85aa3b08bf7 100644 --- a/src/ray/common/test/scheduling_ids_test.cc +++ b/src/ray/common/test/scheduling_ids_test.cc @@ -55,12 +55,13 @@ TEST_F(SchedulingIDsTest, UnitInstanceResourceTest) { R"( { "predefined_unit_instance_resources": "CPU,GPU", - "custom_unit_instance_resources": "custom1" + "custom_unit_instance_resources": "neuron_cores,custom1" } )"); ASSERT_TRUE(ResourceID::CPU().IsUnitInstanceResource()); ASSERT_TRUE(ResourceID::GPU().IsUnitInstanceResource()); ASSERT_TRUE(ResourceID("custom1").IsUnitInstanceResource()); + ASSERT_TRUE(ResourceID("neuron_cores").IsUnitInstanceResource()); ASSERT_FALSE(ResourceID::Memory().IsUnitInstanceResource()); ASSERT_FALSE(ResourceID("custom2").IsUnitInstanceResource());