Skip to content

Commit

Permalink
Auto-detection of accelerator_type for aws_accelerators trn1_inf (ray…
Browse files Browse the repository at this point in the history
…-project#37998)

This change adds support for auto-detecting AWS accelerators and for configuring the appropriate environment variables to designate the neuron_core per task/actor.

Related
REP
ray-project#33707

Signed-off-by: e428265 <[email protected]>
  • Loading branch information
chappidim authored and arvind-chandra committed Aug 31, 2023
1 parent de409c6 commit 3debc82
Show file tree
Hide file tree
Showing 20 changed files with 839 additions and 84 deletions.
28 changes: 28 additions & 0 deletions doc/source/ray-core/doc_code/neuron_core_accelerator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# __neuron_core_accelerator_start__
import ray
import os
from ray.util.accelerators import AWS_NEURON_CORE

# A trn1.2xlarge instance has 2 Neuron cores, so both are registered here
# under the custom "neuron_cores" resource.
ray.init(resources={"neuron_cores": 2})


@ray.remote(resources={"neuron_cores": 1})
class NeuronCoreActor:
    """Actor that requests one Neuron core and reports what it was given."""

    def info(self):
        """Print the assigned Neuron core IDs and NEURON_RT_VISIBLE_CORES."""
        resource_ids = ray.get_runtime_context().get_resource_ids()
        print("neuron_core_ids: {}".format(resource_ids["neuron_cores"]))
        visible_cores = os.environ["NEURON_RT_VISIBLE_CORES"]
        print(f"NEURON_RT_VISIBLE_CORES: {visible_cores}")


@ray.remote(resources={"neuron_cores": 1}, accelerator_type=AWS_NEURON_CORE)
def use_neuron_core_task():
    """Print the Neuron core IDs and NEURON_RT_VISIBLE_CORES seen by a task."""
    resource_ids = ray.get_runtime_context().get_resource_ids()
    print("neuron_core_ids: {}".format(resource_ids["neuron_cores"]))
    visible_cores = os.environ["NEURON_RT_VISIBLE_CORES"]
    print(f"NEURON_RT_VISIBLE_CORES: {visible_cores}")


# Start the actor, then block until the actor method and the task have both
# finished printing their Neuron core assignment.
neuron_core_actor = NeuronCoreActor.remote()
ray.get(neuron_core_actor.info.remote())
ray.get(use_neuron_core_task.remote())
# __neuron_core_accelerator_end__
24 changes: 23 additions & 1 deletion doc/source/ray-core/tasks/using-ray-with-gpus.rst
Original file line number Diff line number Diff line change
Expand Up @@ -123,4 +123,26 @@ This also lets the multi-node-type autoscaler know that there is demand for that
:start-after: __accelerator_type_start__
:end-before: __accelerator_type_end__

See ``ray.util.accelerators`` for available accelerator types. Current automatically detected accelerator types include Nvidia GPUs.
See ``ray.util.accelerators`` for available accelerator types. Current automatically detected accelerator types include:

- Nvidia GPUs
- AWS Neuron Cores

AWS Neuron Core Accelerator (Experimental)
------------------------------------------

Similar to Nvidia GPUs, Ray auto-detects `AWS Neuron Cores`_ by default.
The user can specify ``resources={"neuron_cores": some_number}`` on
task or actor resource requirements to assign the Neuron Core(s).

.. _`AWS Neuron Cores` : https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/arch/model-architecture-fit.html
.. note::

Ray supports a heterogeneous cluster of GPUs and Neuron Cores but doesn't allow specifying resource requirements of
``num_gpus`` and ``neuron_cores`` together for a task or actor.

.. literalinclude:: ../doc_code/neuron_core_accelerator.py
:language: python
:start-after: __neuron_core_accelerator_start__
:end-before: __neuron_core_accelerator_end__
111 changes: 111 additions & 0 deletions python/ray/_private/accelerator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
import json
import os
import subprocess
import sys
from typing import Optional


def update_resources_with_accelerator_type(resources: dict):
    """Add detected accelerator entries to ``resources`` in place.

    AWS NeuronCore (``neuron_cores`` /
    ``accelerator_type:aws-neuron-core``) is the only accelerator handled
    here at the moment.

    Args:
        resources: Resources dictionary that receives the accelerator
            type and custom resource entries.
    """
    # Single entry point so further accelerator families can be added here.
    _detect_and_configure_aws_neuron_core(resources)


def _detect_and_configure_aws_neuron_core(resources: dict):
    """Resolve the NeuronCore count for this node and record it in ``resources``.

    The count comes from ``resources`` when the user supplied
    ``neuron_cores``; otherwise it is auto-detected. In both cases it is
    bounded by the number of cores listed in NEURON_RT_VISIBLE_CORES when
    that is set. If a count is available, ``resources`` is updated in
    place with both the ``neuron_cores`` entry and the NeuronCore
    accelerator-type constraint entry.

    Args:
        resources: Resources dictionary to be updated with the NeuronCore
            accelerator type and the ``neuron_cores`` custom resource.

    Raises:
        ValueError: If the user-specified number of NeuronCores is greater
            than the number of visible NeuronCores.
    """
    import ray._private.ray_constants as ray_constants
    import ray._private.utils as utils

    requested_cores = resources.get(ray_constants.NEURON_CORES, None)
    # Core IDs the user exposed through NEURON_RT_VISIBLE_CORES (None if unset).
    visible_core_ids = utils.get_aws_neuron_core_visible_ids()

    if requested_cores is not None:
        # An explicit request must fit inside the visible set.
        if visible_core_ids is not None and requested_cores > len(
            visible_core_ids
        ):
            raise ValueError(
                f"Attempting to start raylet with {requested_cores} "
                f"neuron cores, but NEURON_RT_VISIBLE_CORES contains "
                f"{visible_core_ids}."
            )
        core_count = requested_cores
    else:
        # Nothing requested: fall back to hardware auto-detection.
        core_count = _autodetect_aws_neuron_cores()

    if core_count is None:
        # No NeuronCores configured or detected; leave resources untouched.
        return

    # Don't use more neuron cores than allowed by NEURON_RT_VISIBLE_CORES.
    if visible_core_ids is not None:
        core_count = min(core_count, len(visible_core_ids))

    resources.update(
        {
            ray_constants.NEURON_CORES: core_count,
            utils.get_neuron_core_constraint_name(): core_count,
        }
    )


def _autodetect_aws_neuron_cores() -> Optional[int]:
"""
Attempt to detect the number of Neuron cores on this machine.
Returns:
The number of Neuron cores if any were detected, otherwise None.
"""
result = None
if sys.platform.startswith("linux") and os.path.isdir("/opt/aws/neuron/bin/"):
result = _get_neuron_core_count()
return result


def _get_neuron_core_count() -> int:
"""Get the number of Neuron cores on a machine based on neuron_path.
Returns:
The number of Neuron cores on this machine (Default to 0).
"""
neuron_path = "/opt/aws/neuron/bin/"
nc_count: int = 0
result = subprocess.run(
[os.path.join(neuron_path, "neuron-ls"), "--json-output"],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
if result.returncode == 0 and result.stdout:
json_out = json.loads(result.stdout)
for neuron_device in json_out:
nc_count += neuron_device.get("nc_count", 0)
return nc_count
20 changes: 20 additions & 0 deletions python/ray/_private/ray_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,27 @@ def env_set_by_user(key):

LANGUAGE_WORKER_TYPES = ["python", "java", "cpp"]

# Accelerator constants
# NOTE(review): the two NOSET_* names look like opt-out switches that stop
# Ray from setting the matching *_VISIBLE_* variable -- confirm at the call
# sites, which are not visible here.
NOSET_AWS_NEURON_RT_VISIBLE_CORES_ENV_VAR = (
"RAY_EXPERIMENTAL_NOSET_NEURON_RT_VISIBLE_CORES"
)
NOSET_CUDA_VISIBLE_DEVICES_ENV_VAR = "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES"
# Names of the environment variables that list the visible CUDA devices and
# the visible Neuron cores.
CUDA_VISIBLE_DEVICES_ENV_VAR = "CUDA_VISIBLE_DEVICES"
NEURON_RT_VISIBLE_CORES_ENV_VAR = "NEURON_RT_VISIBLE_CORES"
# Resource names: "neuron_cores" is the key looked up in the resources dict.
NEURON_CORES = "neuron_cores"
GPU = "GPU"
# Instance type -> a per-instance count, presumably the number of Neuron
# cores according to the architecture pages below -- TODO confirm.
# https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/arch/neuron-hardware/inf2-arch.html#aws-inf2-arch
# https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/arch/neuron-hardware/trn1-arch.html#aws-trn1-arch
# Subject to removal after the information is available via public API
AWS_NEURON_INSTANCE_MAP = {
"trn1.2xlarge": 2,
"trn1.32xlarge": 32,
"trn1n.32xlarge": 32,
"inf2.xlarge": 2,
"inf2.8xlarge": 2,
"inf2.24xlarge": 12,
"inf2.48xlarge": 24,
}
RAY_WORKER_NICENESS = "RAY_worker_niceness"

# Default max_retries option in @ray.remote for non-actor
Expand Down
31 changes: 31 additions & 0 deletions python/ray/_private/ray_option_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import ray
from ray._private import ray_constants
from ray._private.utils import get_ray_doc_version
from ray.util.accelerators import accelerators
from ray.util.placement_group import PlacementGroup
from ray.util.scheduling_strategies import (
NodeAffinitySchedulingStrategy,
Expand Down Expand Up @@ -107,6 +108,34 @@ def _validate_resources(resources: Optional[Dict[str, float]]) -> Optional[str]:
return None


def _validate_neuron_core_accelerator(options: Dict[str, Any]):
"""Validate options for NeuronCore accelerator/ neuron_cores and GPU,
supports only one or the other (Either NeuronCore or GPU).
Args:
options: The options to be validated.
Raises:
ValueError: If the options are invalid.
"""
num_gpus = options.get("num_gpus", None)
if num_gpus is not None and num_gpus > 0:
resources = options["resources"] if "resources" in options else None
accelerator_type_value: str = options.get("accelerator_type", "")
if resources is not None:
neuron_cores: int = resources.get(ray_constants.NEURON_CORES, 0)
if neuron_cores > 0:
raise ValueError(
"'num_gpus' cannot be used together with "
"neuron_cores/accelerator_type:aws-neuron-core."
)
elif accelerator_type_value == accelerators.AWS_NEURON_CORE:
raise ValueError(
"'num_gpus' cannot be used together with "
"neuron_cores/accelerator_type:aws-neuron-core."
)


_common_options = {
"accelerator_type": Option((str, type(None))),
"memory": _resource_option("memory"),
Expand Down Expand Up @@ -291,6 +320,7 @@ def validate_task_options(options: Dict[str, Any], in_options: bool):
if in_options and "max_calls" in options:
raise ValueError("Setting 'max_calls' is not supported in '.options()'.")
_check_deprecate_placement_group(options)
_validate_neuron_core_accelerator(options)


def validate_actor_options(options: Dict[str, Any], in_options: bool):
Expand Down Expand Up @@ -335,6 +365,7 @@ def validate_actor_options(options: Dict[str, Any], in_options: bool):
)

_check_deprecate_placement_group(options)
_validate_neuron_core_accelerator(options)


def update_options(
Expand Down
25 changes: 12 additions & 13 deletions python/ray/_private/resource_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,12 @@
from collections import namedtuple
from typing import Optional

import ray._private.accelerator as accelerator

import ray
import ray._private.ray_constants as ray_constants


try:
import GPUtil
except ImportError:
Expand Down Expand Up @@ -198,6 +201,8 @@ def resolve(self, is_head: bool, node_ip_address: Optional[str] = None):
except Exception:
logger.exception("Could not parse gpu information.")

accelerator.update_resources_with_accelerator_type(resources)

# Choose a default object store size.
system_memory = ray._private.utils.get_system_memory()
avail_memory = ray._private.utils.estimate_available_memory()
Expand Down Expand Up @@ -305,12 +310,9 @@ def _get_gpu_types_gputil():
if len(gpu_list) > 0:
gpu_list_names = [gpu.name for gpu in gpu_list]
info_str = gpu_list_names.pop()
pretty_name = _pretty_gpu_name(info_str)
pretty_name = _pretty_nvidia_gpu_name(info_str)
if pretty_name:
constraint_name = (
f"{ray_constants.RESOURCE_CONSTRAINT_PREFIX}" f"{pretty_name}"
)
return {constraint_name: 1}
return {ray._private.utils.get_constraint_name(pretty_name): 1}
return {}


Expand All @@ -336,11 +338,8 @@ def _constraints_from_gpu_info(info_str: str):
if k.strip() == "Model":
full_model_name = v.strip()
break
pretty_name = _pretty_gpu_name(full_model_name)
if pretty_name:
constraint_name = f"{ray_constants.RESOURCE_CONSTRAINT_PREFIX}" f"{pretty_name}"
return {constraint_name: 1}
return {}
pretty_name = _pretty_nvidia_gpu_name(full_model_name)
return {ray._private.utils.get_constraint_name(pretty_name): 1}


def _get_gpu_info_string():
Expand All @@ -364,11 +363,11 @@ def _get_gpu_info_string():

# TODO(Alex): This pattern may not work for non NVIDIA Tesla GPUs (which have
# the form "Tesla V100-SXM2-16GB" or "Tesla K80").
GPU_NAME_PATTERN = re.compile(r"\w+\s+([A-Z0-9]+)")
NVIDIA_GPU_NAME_PATTERN = re.compile(r"\w+\s+([A-Z0-9]+)")


def _pretty_gpu_name(name):
def _pretty_nvidia_gpu_name(name):
if name is None:
return None
match = GPU_NAME_PATTERN.match(name)
match = NVIDIA_GPU_NAME_PATTERN.match(name)
return match.group(1) if match else None
Loading

0 comments on commit 3debc82

Please sign in to comment.