[train] Add TorchAwsNeuronXLABackend and XLAConfig (#39130)
This change adds a new TorchXLAConfig and a Neuron backend with XLA setup as a Ray Train backend.

---------

Signed-off-by: maheedhar reddy chappidi <[email protected]>
Signed-off-by: woshiyyya <[email protected]>
Signed-off-by: Yunxuan Xiao <[email protected]>
Co-authored-by: woshiyyya <[email protected]>
Co-authored-by: Yunxuan Xiao <[email protected]>
Co-authored-by: Scott Perry <[email protected]>
1 parent 676343a · commit 0115c3b
Showing 4 changed files with 177 additions and 1 deletion.
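For context, a minimal sketch of how the new config plugs into Ray Train. The TorchTrainer and ScalingConfig wiring is standard Ray Train usage; train_func is a hypothetical user training loop, and the "neuron_cores" resource key and worker count are illustrative assumptions for a Trainium cluster, not taken from this commit:

from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer
from ray.train.torch.xla import TorchXLAConfig


def train_func():
    # Hypothetical user training loop. By the time this runs, the backend
    # has already called dist.init_process_group("xla") on every worker.
    import torch_xla.core.xla_model as xm

    device = xm.xla_device()
    ...  # build the model on `device`, run forward/backward steps


trainer = TorchTrainer(
    train_func,
    torch_config=TorchXLAConfig(),
    scaling_config=ScalingConfig(
        num_workers=2,  # illustrative value
        resources_per_worker={"neuron_cores": 1},  # assumed resource key
    ),
)
result = trainer.fit()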
@@ -0,0 +1,5 @@
from ray.train.torch.xla.config import TorchXLAConfig

__all__ = [
    "TorchXLAConfig",
]
@@ -0,0 +1,169 @@
import logging
import os
import re
import shutil
import uuid
from dataclasses import dataclass

import ray
from ray.train._internal.utils import get_address_and_port
from ray.train._internal.worker_group import WorkerGroup
from ray.train.backend import Backend
from ray.train.torch import TorchConfig
from ray.util import PublicAPI

logger = logging.getLogger(__name__)


@PublicAPI(stability="alpha")
@dataclass
class TorchXLAConfig(TorchConfig):
    """
    Configuration for torch XLA setup.
    See https://pytorch.org/xla/release/1.13/index.html for more info.
    Currently, only the "neuron_cores" accelerator (AwsNeuronXLABackend)
    is supported with the xrt runtime.
    """

    neuron_parallel_compile: bool = False

    @property
    def backend_cls(self):
        return _TorchAwsNeuronXLABackend


def _kill_xrt_server():
    import subprocess

    subprocess.call(["pkill", "-f", "xrt_run_server"])


def _set_xla_env_vars():
    # https://pytorch.org/docs/1.13/elastic/run.html#environment-variables
    context = ray.train.get_context()

    os.environ["LOCAL_RANK"] = str(context.get_local_rank())
    os.environ["RANK"] = str(context.get_world_rank())
    os.environ["LOCAL_WORLD_SIZE"] = str(context.get_local_world_size())
    os.environ["WORLD_SIZE"] = str(context.get_world_size())
    os.environ["GROUP_RANK"] = str(context.get_node_rank())
    # Use floor division so the exported value is an integer string
    # (e.g. "2" rather than "2.0"), matching what torchelastic exports.
    os.environ["GROUP_WORLD_SIZE"] = str(
        context.get_world_size() // context.get_local_world_size()
    )
    os.environ["ROLE_RANK"] = str(context.get_world_rank())
    os.environ["ROLE_WORLD_RANK"] = str(context.get_world_rank())
    os.environ["ROLE_WORLD_SIZE"] = str(context.get_world_size())

    # EFA and XLA setup
    # https://github.com/aws/libfabric/blob/master/prov/efa/src/rxr/rxr_init.c
    # https://github.com/aws-neuron/aws-neuron-samples/blob/master/torch-neuronx/training/dp_bert_hf_pretrain/run_dp_bert_large_hf_pretrain_bf16_s128.sh  # noqa
    os.environ["FI_PROVIDER"] = "efa"
    os.environ["FI_EFA_USE_DEVICE_RDMA"] = "1"
    os.environ["FI_EFA_FORK_SAFE"] = "1"
    os.environ["XLA_TRANSFER_SEED_ASYNC"] = "1"
    os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "1"


def _setup_xla_torch_process_group():
    try:
        import torch.distributed as dist
        import torch_xla.core.xla_model as xm  # noqa F401
        import torch_xla.distributed.xla_backend  # noqa F401

        dist.init_process_group("xla")
    except ImportError:
        raise ImportError("torch_xla must be installed to use torch_xla backend.")


# The following env vars enable Neuron graph extraction for parallel compilation.
# Note: model outputs are invalid and should be ignored while these env vars are set.
def _set_neuron_parallel_compile_env_vars():
    os.environ["NEURON_PARALLEL_COMPILE"] = "1"
    os.environ["NEURON_EXTRACT_GRAPHS_ONLY"] = "1"
    os.environ["NEURON_FALL_BACK_TO_NULL_NEFF"] = "1"


# Compile previously extracted Neuron graphs.
def _neuron_compile_extracted_graphs():
    try:
        from libneuronxla.neuron_cc_cache import CacheUrl
        from libneuronxla.neuron_parallel_compile import parallel_compile
    except ImportError:
        raise ImportError(
            "libneuronxla must be installed to use Neuron parallel compilation."
        )

    # Only 1 worker per node should run parallel_compile().
    if os.environ.get("LOCAL_RANK") == "0":
        logger.info("Compiling extracted graphs on local rank 0 worker")

        parallel_compile_workdir = (
            f"/tmp/{os.environ.get('USER', 'no-user')}/parallel_compile_workdir/"
        )
        if os.path.exists(parallel_compile_workdir):
            shutil.rmtree(parallel_compile_workdir)
        os.makedirs(parallel_compile_workdir, exist_ok=True)

        # Users can set the cache directory using --cache_dir in NEURON_CC_FLAGS
        # or by using NEURON_COMPILE_CACHE_URL. --cache_dir takes precedence.
        explicit_cache_dir = None
        if neuron_cc_flags := os.environ.get("NEURON_CC_FLAGS"):
            if s := re.search(r"--cache_dir[= ](\S+)", neuron_cc_flags):
                explicit_cache_dir = s.group(1)

        parallel_compile(
            parallel_compile_workdir,
            CacheUrl.get_cache_url(explicit_cache_dir),
        )


class _TorchAwsNeuronXLABackend(Backend):
    unique_run_id: str = str(uuid.uuid4())

    def on_start(self, worker_group: WorkerGroup, backend_config: TorchXLAConfig):
        """Logic run right before training starts."""

        # On a previous worker failure, we don't run graceful shutdown on workers.
        # That would leak any running xrt server, so kill it here.
        worker_group.execute(_kill_xrt_server)

        # Get the master address and port from the first worker.
        master_addr, master_port = worker_group.execute_single(0, get_address_and_port)

        def set_env_vars(addr, port):
            os.environ["MASTER_ADDR"] = addr
            os.environ["MASTER_PORT"] = str(port)
            # To trigger the xrt server.
            os.environ["TORCHELASTIC_RUN_ID"] = self.unique_run_id

        # Set the env vars on all workers.
        worker_group.execute(set_env_vars, addr=master_addr, port=master_port)

        # Set up env vars for Neuron parallel compilation graph extraction.
        if backend_config.neuron_parallel_compile:
            logger.info("Extracting graphs for Neuron parallel compilation")
            worker_group.execute(_set_neuron_parallel_compile_env_vars)

    def on_training_start(
        self, worker_group: WorkerGroup, backend_config: TorchXLAConfig
    ):
        """
        Configure the environment variables for the worker group
        and initialize the xla distributed process group.
        TODO: The current setup only supports a homogeneous cluster with
        the neuron_cores accelerator and xrt runtime.
        """
        worker_group.execute(_set_xla_env_vars)
        worker_group.execute(_setup_xla_torch_process_group)

    def on_shutdown(self, worker_group: WorkerGroup, backend_config: TorchXLAConfig):
        """
        Logic run right after training finishes.
        This is a sanity cleanup to kill the xrt server and, optionally,
        compile the extracted Neuron graphs.
        """
        worker_group.execute(_kill_xrt_server)

        # Compile the extracted graphs. This must run at the end of training.
        if backend_config.neuron_parallel_compile:
            worker_group.execute(_neuron_compile_extracted_graphs)
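To round out the neuron_parallel_compile flag wired up above, a hedged sketch of the two-phase workflow it appears to enable, reusing the hypothetical train_func and a scaling_config like the one in the earlier example; the "run twice" pattern is my reading of the comments in this diff, not something the commit itself spells out:

# Phase 1: graph extraction + ahead-of-time compilation. Per the comments
# above, model outputs are invalid during this phase and must be ignored;
# on_shutdown() then runs _neuron_compile_extracted_graphs() on each node.
TorchTrainer(
    train_func,
    torch_config=TorchXLAConfig(neuron_parallel_compile=True),
    scaling_config=scaling_config,
).fit()

# Phase 2: the real training run, now served from the populated Neuron
# compilation cache. The cache location may be overridden via --cache_dir
# in NEURON_CC_FLAGS (takes precedence) or NEURON_COMPILE_CACHE_URL.
TorchTrainer(
    train_func,
    torch_config=TorchXLAConfig(),
    scaling_config=scaling_config,
).fit()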