diff --git a/python/ray/train/_internal/backend_executor.py b/python/ray/train/_internal/backend_executor.py
index 586777c22eaf9..eb5344f417c2d 100644
--- a/python/ray/train/_internal/backend_executor.py
+++ b/python/ray/train/_internal/backend_executor.py
@@ -1,9 +1,11 @@
 import logging
 import os
 from collections import defaultdict
+from dataclasses import dataclass
 from typing import Callable, Dict, List, Optional, Tuple, Type, TypeVar, Any
 
 import ray
+import ray._private.ray_constants as ray_constants
 from ray.data import Dataset
 from ray._private.ray_constants import env_integer
 from ray.air.config import CheckpointConfig
@@ -27,6 +29,7 @@
     TRAIN_ENABLE_WORKER_SPREAD_ENV,
     TRAIN_PLACEMENT_GROUP_TIMEOUT_S_ENV,
     DISABLE_LAZY_CHECKPOINTING_ENV,
+    ENABLE_SHARE_NEURON_CORES_ACCELERATOR_ENV,
 )
 from ray.util.placement_group import get_current_placement_group, remove_placement_group
 
@@ -43,6 +46,25 @@
 class TrainingWorkerError(Exception):
     """Raised if a worker fails during training."""
 
 
+@dataclass
+class ResourceConfig:
+    """
+    Resource configuration for resource IDs to share between workers.
+
+    Args:
+        resource_name: The name of the resource to configure
+            (Example: "neuron_cores" or "GPU").
+        resource_enable_sharing_env_var: The environment variable to
+            check if the resource should be shared.
+        share_resource_ids_env_var: The environment variable to configure for
+            sharing the resources with other workers.
+    """
+
+    resource_name: str
+    resource_enable_sharing_env_var: str
+    share_resource_ids_env_var: str
+
+
 class BackendExecutor:
     """Main execution class for training backends.
@@ -101,6 +123,13 @@ def __init__(
         self._checkpoint_upload_from_workers = (
             checkpoint_config and checkpoint_config._checkpoint_upload_from_workers
         )
+        self._resource_configs = [
+            ResourceConfig(
+                ray_constants.NEURON_CORES,
+                ENABLE_SHARE_NEURON_CORES_ACCELERATOR_ENV,
+                ray_constants.NEURON_RT_VISIBLE_CORES_ENV_VAR,
+            )
+        ]
 
     def start(
         self,
@@ -153,6 +182,16 @@
 
             if self._num_gpus_per_worker > 0 and share_cuda_visible_devices_enabled:
                 self._share_cuda_visible_devices()
+            elif self._additional_resources_per_worker:
+                for resource_config in self._resource_configs:
+                    if self._is_share_resources_enabled(
+                        resource_config.resource_name,
+                        resource_config.resource_enable_sharing_env_var,
+                    ):
+                        self._share_resource_ids(
+                            resource_config.resource_name,
+                            resource_config.share_resource_ids_env_var,
+                        )
             self._backend.on_start(self.worker_group, self._backend_config)
         except RayActorError as exc:
             logger.exception(str(exc))
@@ -245,32 +284,82 @@ def _share_cuda_visible_devices(self):
             - Worker2: "0,1"
 
         """
+        self._share_resource_ids(
+            ray_constants.GPU, ray_constants.CUDA_VISIBLE_DEVICES_ENV_VAR
+        )
 
-        node_ids_and_gpu_ids = [
-            (w.metadata.node_id, w.metadata.gpu_ids) for w in self.worker_group.workers
-        ]
+    def _share_resource_ids(self, resource: str, env_var: str):
+        """Sets the given env_var on all workers.
+
+        For each worker, env_var is set to the IDs of the given resource that
+        are visible to all workers on that worker's node. This allows workers
+        on the same node to communicate with one another.
+
+        Example:
+
+            Setup:
+            - Node1:
+                - Worker1: {0, 1}
+                - Worker2: {2, 3}
+            - Node2:
+                - Worker3: {0, 1}
+
+            NEURON_RT_VISIBLE_CORES/TPU_VISIBLE_CHIPS/...:
+            - Worker1: "0,1,2,3"
+            - Worker2: "0,1,2,3"
+            - Worker3: "0,1"
+
+        Args:
+            resource: The name of the resource/accelerator.
+            env_var: The name of the environment variable to set.
+        """
+        node_ids_and_resource_ids = [
+            (
+                w.metadata.node_id,
+                w.metadata.resource_ids[resource],
+            )
+            for w in self.worker_group.workers
+        ]
 
         node_id_to_worker_id = defaultdict(set)
-        node_id_to_gpu_ids = defaultdict(set)
+        node_id_to_resource_ids = defaultdict(set)
 
-        for worker_id, (node_id, gpu_ids) in enumerate(node_ids_and_gpu_ids):
+        for worker_id, (node_id, resource_ids) in enumerate(node_ids_and_resource_ids):
             node_id_to_worker_id[node_id].add(worker_id)
-            node_id_to_gpu_ids[node_id].update(gpu_ids)
+            node_id_to_resource_ids[node_id].update(resource_ids)
 
         futures = []
-        for node_id, gpu_ids in node_id_to_gpu_ids.items():
-            gpu_ids = sorted(gpu_ids)
-            all_gpu_ids = ",".join(gpu_ids)
+        for node_id, resource_ids in node_id_to_resource_ids.items():
+            resource_ids = sorted(resource_ids)
+            all_resource_ids = ",".join(resource_ids)
 
-            def set_gpu_ids():
-                os.environ["CUDA_VISIBLE_DEVICES"] = all_gpu_ids
+            def set_resource_ids():
+                os.environ[env_var] = all_resource_ids
 
             for worker_id in node_id_to_worker_id[node_id]:
                 futures.append(
-                    self.worker_group.execute_single_async(worker_id, set_gpu_ids)
+                    self.worker_group.execute_single_async(worker_id, set_resource_ids)
                 )
         ray.get(futures)
 
+    def _is_share_resources_enabled(self, resource_name: str, enable_sharing_env: str):
+        """Whether resource IDs should be shared across workers,
+        based on enable_sharing_env.
+
+        This returns True if the resource is requested with a count greater
+        than 0. The user can disable sharing by setting the environment
+        variable named by `enable_sharing_env` to "0".
+
+        Args:
+            resource_name: The name of the resource/accelerator.
+            enable_sharing_env: The name of the environment variable
+                to check.
+        """
+        has_resource_requested = (
+            self._additional_resources_per_worker.get(resource_name, 0) > 0
+        )
+        return has_resource_requested and ray_constants.env_bool(
+            enable_sharing_env, True
+        )
+
     def _create_rank_world_size_mappings(self) -> List[Dict]:
         """Create rank and world size mappings for workers.
         There are three maps returned:
diff --git a/python/ray/train/_internal/worker_group.py b/python/ray/train/_internal/worker_group.py
index 59ba775b5630d..229f5a474daaa 100644
--- a/python/ray/train/_internal/worker_group.py
+++ b/python/ray/train/_internal/worker_group.py
@@ -44,14 +44,15 @@ class WorkerMetadata:
         node_id: ID of the node this worker is on.
         node_ip: IP address of the node this worker is on.
         hostname: Hostname that this worker is on.
-        gpu_ids: List of CUDA IDs available to this worker.
+        resource_ids: Map of accelerator resources
+            ("GPU", "neuron_cores", ...) to their IDs.
         pid: Process ID of this worker.
     """
 
     node_id: str
     node_ip: str
     hostname: str
-    gpu_ids: Optional[List[str]]
+    resource_ids: Dict[str, List[str]]
     pid: int
 
 
@@ -86,14 +87,14 @@ def construct_metadata() -> WorkerMetadata:
     node_id = ray.get_runtime_context().get_node_id()
     node_ip = ray.util.get_node_ip_address()
     hostname = socket.gethostname()
-    gpu_ids = [str(gpu_id) for gpu_id in ray.get_gpu_ids()]
+    resource_ids = ray.get_runtime_context().get_resource_ids()
     pid = os.getpid()
 
     return WorkerMetadata(
         node_id=node_id,
         node_ip=node_ip,
         hostname=hostname,
-        gpu_ids=gpu_ids,
+        resource_ids=resource_ids,
         pid=pid,
     )
 
diff --git a/python/ray/train/constants.py b/python/ray/train/constants.py
index 11bf5964a24c6..f9eedec1cbd83 100644
--- a/python/ray/train/constants.py
+++ b/python/ray/train/constants.py
@@ -67,6 +67,12 @@ def _get_defaults_results_dir() -> str:
 # Backend.share_cuda_visible_devices. 1 for True, 0 for False.
 ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV = "TRAIN_ENABLE_SHARE_CUDA_VISIBLE_DEVICES"
 
+# Integer value which controls whether neuron-core accelerator visible cores are
+# shared across workers on the same node. 1 for True (default), 0 for False.
+ENABLE_SHARE_NEURON_CORES_ACCELERATOR_ENV = (
+    "TRAIN_ENABLE_SHARE_NEURON_CORES_ACCELERATOR"
+)
+
 # Integer value which indicates the number of seconds to wait when creating
 # the worker placement group before timing out.
 TRAIN_PLACEMENT_GROUP_TIMEOUT_S_ENV = "TRAIN_PLACEMENT_GROUP_TIMEOUT_S"
@@ -85,6 +91,7 @@ def _get_defaults_results_dir() -> str:
 TRAIN_ENV_VARS = {
     ENABLE_DETAILED_AUTOFILLED_METRICS_ENV,
     ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV,
+    ENABLE_SHARE_NEURON_CORES_ACCELERATOR_ENV,
     TRAIN_PLACEMENT_GROUP_TIMEOUT_S_ENV,
     TRAIN_ENABLE_WORKER_SPREAD_ENV,
     RAY_AIR_NEW_PERSISTENCE_MODE,
diff --git a/python/ray/train/tests/conftest.py b/python/ray/train/tests/conftest.py
index 5c0338ea3a510..302306f4fc6d9 100644
--- a/python/ray/train/tests/conftest.py
+++ b/python/ray/train/tests/conftest.py
@@ -55,6 +55,20 @@ def ray_2_node_2_gpu():
     cluster.shutdown()
 
 
+@pytest.fixture
+def ray_2_node_2_neuron_cores():
+    cluster = Cluster()
+    for _ in range(2):
+        cluster.add_node(num_cpus=4, resources={"neuron_cores": 2})
+
+    ray.init(address=cluster.address)
+
+    yield
+
+    ray.shutdown()
+    cluster.shutdown()
+
+
 @pytest.fixture
 def ray_start_2_cpus():
     address_info = ray.init(num_cpus=2)
diff --git a/python/ray/train/tests/test_backend.py b/python/ray/train/tests/test_backend.py
index c3a44cddedb30..eafbccab34698 100644
--- a/python/ray/train/tests/test_backend.py
+++ b/python/ray/train/tests/test_backend.py
@@ -7,6 +7,7 @@
 import time
 
 import ray
+import ray._private.ray_constants as ray_constants
 from ray import train
 
 from ray.air._internal.util import StartTraceback
@@ -25,6 +26,7 @@
 from ray.train.constants import (
     ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV,
     TRAIN_ENABLE_WORKER_SPREAD_ENV,
+    ENABLE_SHARE_NEURON_CORES_ACCELERATOR_ENV,
 )
 from ray.train.tensorflow import TensorflowConfig
 from ray.train.torch import TorchConfig
@@ -97,7 +99,7 @@ def mock_add_workers(self, num_workers):
             node_id=0,
             node_ip=str(i % 2),
             hostname=0,
-            gpu_ids=[0],
+            resource_ids={"GPU": ["0"]},
             pid=0,
         )
         worker.metadata = metadata
@@ -307,12 +309,6 @@
 def test_cuda_visible_devices(ray_2_node_2_gpu, worker_results):
     config = TestConfig()
 
-    if worker_results[0] != len(worker_results[1]):
-        raise ValueError(
-            "Invalid test parameter. Length of expected result should "
-            "match number of workers."
-        )
-
     def get_resources():
         cuda_visible_devices = os.environ["CUDA_VISIBLE_DEVICES"]
         # Sort the cuda visible devices to have exact match with expected result.
@@ -419,6 +415,80 @@
     assert results == expected_results
 
 
+@pytest.mark.parametrize(
+    "worker_results",
+    [
+        (1, [[0]]),
+        (2, [[0, 1]] * 2),
+        (3, [[0]] + [[0, 1]] * 2),
+        (4, [[0, 1]] * 4),
+    ],
+)
+def test_neuron_core_accelerator_ids(ray_2_node_2_neuron_cores, worker_results):
+    config = TestConfig()
+
+    def get_resources():
+        neuron_resource_ids = os.environ[ray_constants.NEURON_RT_VISIBLE_CORES_ENV_VAR]
+        # Sort the runtime ids to have exact match with expected result.
+        sorted_devices = [
+            int(device) for device in sorted(neuron_resource_ids.split(","))
+        ]
+        return sorted_devices
+
+    num_workers, expected_results = worker_results
+    # sharing enabled by default
+    os.environ.pop(ENABLE_SHARE_NEURON_CORES_ACCELERATOR_ENV, None)
+    e = BackendExecutor(
+        config,
+        num_workers=num_workers,
+        num_cpus_per_worker=0,
+        additional_resources_per_worker={"neuron_cores": 1},
+    )
+    e.start()
+    _start_training(e, get_resources)
+    results = e.finish_training()
+    results.sort()
+    assert results == expected_results
+
+
+@pytest.mark.parametrize(
+    "worker_results",
+    [
+        (1, [[0]]),
+        (2, [[0]] + [[1]]),
+        (3, [[0]] * 2 + [[1]]),
+        (4, [[0]] * 2 + [[1]] * 2),
+    ],
+)
+def test_neuron_core_accelerator_ids_sharing_disabled(
+    ray_2_node_2_neuron_cores, worker_results
+):
+    config = TestConfig()
+
+    def get_resources():
+        neuron_resource_ids = os.environ[ray_constants.NEURON_RT_VISIBLE_CORES_ENV_VAR]
+        # Sort the runtime ids to have exact match with expected result.
+        sorted_devices = [
+            int(device) for device in sorted(neuron_resource_ids.split(","))
+        ]
+        return sorted_devices
+
+    num_workers, expected_results = worker_results
+
+    os.environ[ENABLE_SHARE_NEURON_CORES_ACCELERATOR_ENV] = "0"
+    e = BackendExecutor(
+        config,
+        num_workers=num_workers,
+        num_cpus_per_worker=0,
+        additional_resources_per_worker={"neuron_cores": 1},
+    )
+    e.start()
+    _start_training(e, get_resources)
+    results = e.finish_training()
+    results.sort()
+    assert results == expected_results
+
+
 def get_node_id_set():
     node_id_set = set()
     for actor_info in ray._private.state.actors().values():
diff --git a/python/ray/train/tests/test_worker_group.py b/python/ray/train/tests/test_worker_group.py
index 37ae87e72f1de..1c5445977953b 100644
--- a/python/ray/train/tests/test_worker_group.py
+++ b/python/ray/train/tests/test_worker_group.py
@@ -4,6 +4,7 @@
 
 import ray
 from ray.train._internal.worker_group import WorkerGroup, Worker, WorkerMetadata
+import ray._private.ray_constants as ray_constants
 
 
 @pytest.fixture
@@ -14,6 +15,22 @@ def ray_start_2_cpus():
     ray.shutdown()
 
 
+@pytest.fixture
+def ray_start_2_cpus_and_gpus():
+    address_info = ray.init(num_cpus=2, num_gpus=2)
+    yield address_info
+    # The code after the yield will run as teardown code.
+    ray.shutdown()
+
+
+@pytest.fixture
+def ray_start_2_cpus_and_neuron_core_accelerator():
+    address_info = ray.init(num_cpus=2, resources={ray_constants.NEURON_CORES: 2})
+    yield address_info
+    # The code after the yield will run as teardown code.
+    ray.shutdown()
+
+
 def test_worker_creation(ray_start_2_cpus):
     assert ray.available_resources()["CPU"] == 2
     wg = WorkerGroup(num_workers=2)
@@ -59,6 +76,42 @@ def test_worker_restart(ray_start_2_cpus):
     wg.execute(lambda: 1)
 
 
+def test_worker_with_gpu_ids(ray_start_2_cpus_and_gpus):
+    num_gpus = 2
+    wg = WorkerGroup(num_workers=2, num_gpus_per_worker=1)
+    assert len(wg.workers) == 2
+    time.sleep(1)
+    assert ray_constants.GPU not in ray.available_resources()
+    wg.execute(lambda: 1)
+    assert len(wg.workers) == 2
+    for w in wg.workers:
+        resource_ids = w.metadata.resource_ids
+        gpu_ids = resource_ids[ray_constants.GPU]
+        for gpu_id in gpu_ids:
+            assert gpu_id in [str(i) for i in range(num_gpus)]
+        assert len(resource_ids[ray_constants.NEURON_CORES]) == 0
+
+
+def test_worker_with_neuron_core_accelerator_ids(
+    ray_start_2_cpus_and_neuron_core_accelerator,
+):
+    num_nc = 2
+    wg = WorkerGroup(
+        num_workers=2, additional_resources_per_worker={ray_constants.NEURON_CORES: 1}
+    )
+    assert len(wg.workers) == 2
+    time.sleep(1)
+    assert ray_constants.NEURON_CORES not in ray.available_resources()
+    wg.execute(lambda: 1)
+    assert len(wg.workers) == 2
+    for w in wg.workers:
+        resource_ids = w.metadata.resource_ids
+        assert len(resource_ids[ray_constants.GPU]) == 0
+        neuron_core_ids = resource_ids[ray_constants.NEURON_CORES]
+        for neuron_core_id in neuron_core_ids:
+            assert neuron_core_id in [str(i) for i in range(num_nc)]
+
+
 def test_execute_async(ray_start_2_cpus):
     wg = WorkerGroup(num_workers=2)
     futures = wg.execute_async(lambda: 1)
@@ -91,7 +144,7 @@ def create_worker_group(ips):
                 node_id="dummy",
                 node_ip=ip,
                 hostname="dummy",
-                gpu_ids=None,
+                resource_ids=None,
                 pid=0,
             ),
         )
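Note on the `WorkerMetadata.resource_ids` change above: the field is populated from `ray.get_runtime_context().get_resource_ids()`, which this diff switches to in place of `ray.get_gpu_ids()`. A minimal sketch of what that call returns, assuming a local Ray instance whose node registers a custom `neuron_cores` resource (the setup below is illustrative only, not part of this change):

import ray

# Illustrative local setup; any node that registers a "neuron_cores" resource
# behaves the same way.
ray.init(num_cpus=2, resources={"neuron_cores": 2})


@ray.remote(resources={"neuron_cores": 1})
def show_assigned_ids():
    # Mapping from accelerator resource name to the string IDs assigned to this
    # worker, e.g. {"GPU": [], "neuron_cores": ["0"]}. The exact keys depend on
    # the Ray version and on the resources present on the node.
    return ray.get_runtime_context().get_resource_ids()


print(ray.get(show_assigned_ids.remote()))
ray.shutdown()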
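The `ResourceConfig` dataclass added in `backend_executor.py` keeps the sharing logic table-driven: supporting another shareable accelerator should amount to appending one entry to `self._resource_configs`. A hypothetical sketch of such an extension follows; the TPU resource name and environment variables are assumptions for illustration and are not introduced by this diff:

from dataclasses import dataclass


@dataclass
class ResourceConfig:
    # Mirrors the dataclass added in backend_executor.py.
    resource_name: str
    resource_enable_sharing_env_var: str
    share_resource_ids_env_var: str


resource_configs = [
    # The entry this diff registers (constants spelled out for clarity).
    ResourceConfig(
        "neuron_cores",
        "TRAIN_ENABLE_SHARE_NEURON_CORES_ACCELERATOR",
        "NEURON_RT_VISIBLE_CORES",
    ),
    # Hypothetical follow-up entry for TPUs; names are illustrative only.
    ResourceConfig(
        "TPU",
        "TRAIN_ENABLE_SHARE_TPU_VISIBLE_CHIPS",
        "TPU_VISIBLE_CHIPS",
    ),
]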
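End to end, the new behavior should surface to Ray Train users roughly as sketched below. This is a hedged usage example rather than part of the change: it assumes a Torch-based trainer, nodes that expose two `neuron_cores` each, and that `TRAIN_ENABLE_SHARE_NEURON_CORES_ACCELERATOR` is visible in the environment where the `BackendExecutor` runs.

import os

from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer  # assumed trainer choice


def train_loop_per_worker():
    # With sharing enabled (the default), each worker on a 2-core node should
    # see the union of that node's Neuron core IDs, e.g. "0,1", even though it
    # reserved only one core for itself. Setting
    # TRAIN_ENABLE_SHARE_NEURON_CORES_ACCELERATOR=0 before training keeps each
    # worker restricted to its own assigned core IDs instead.
    print(os.environ.get("NEURON_RT_VISIBLE_CORES"))


trainer = TorchTrainer(
    train_loop_per_worker,
    scaling_config=ScalingConfig(
        num_workers=2,
        resources_per_worker={"neuron_cores": 1},
    ),
)
trainer.fit()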