[train] Add TorchAwsNeuronXLABackend and XLAConfig #39130
New file (`@@ -0,0 +1,5 @@`):

```python
from ray.train.torch.xla.config import TorchXLAConfig

__all__ = [
    "TorchXLAConfig",
]
```
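For context, a minimal usage sketch (not part of this diff) of how `TorchXLAConfig` might be passed to a `TorchTrainer`. The training-loop body and worker counts are placeholders; the `neuron_cores` resource name is taken from the config docstring below, but the exact cluster setup is an assumption:

```python
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer
from ray.train.torch.xla import TorchXLAConfig


def train_func():
    # Per-worker training loop; torch / torch_xla code would go here.
    ...


trainer = TorchTrainer(
    train_loop_per_worker=train_func,
    torch_config=TorchXLAConfig(),
    # Assumed resource name for AWS Neuron cores; adjust to your cluster setup.
    scaling_config=ScalingConfig(
        num_workers=32, resources_per_worker={"neuron_cores": 1}
    ),
)
result = trainer.fit()
```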
New file (`@@ -0,0 +1,169 @@`):

```python
import logging
import os
import re
import shutil
import uuid
from dataclasses import dataclass

import ray
from ray.train._internal.utils import get_address_and_port
from ray.train._internal.worker_group import WorkerGroup
from ray.train.backend import Backend
from ray.train.torch import TorchConfig
from ray.util import PublicAPI

logger = logging.getLogger(__name__)


@PublicAPI(stability="alpha")
@dataclass
class TorchXLAConfig(TorchConfig):
    """
    Configuration for torch XLA setup.
    See https://pytorch.org/xla/release/1.13/index.html for more info.
    Currently, only "neuron_cores" accelerator (AwsNeuronXLABackend)
    is supported with xrt runtime.
    """

    neuron_parallel_compile: bool = False

    @property
    def backend_cls(self):
        return _TorchAwsNeuronXLABackend


def _kill_xrt_server():
    import subprocess

    subprocess.call(["pkill", "-f", "xrt_run_server"])


def _set_xla_env_vars():
    # https://pytorch.org/docs/1.13/elastic/run.html#environment-variables
    context = ray.train.get_context()

    os.environ["LOCAL_RANK"] = str(context.get_local_rank())
    os.environ["RANK"] = str(context.get_world_rank())
    os.environ["LOCAL_WORLD_SIZE"] = str(context.get_local_world_size())
    os.environ["WORLD_SIZE"] = str(context.get_world_size())
    os.environ["GROUP_RANK"] = str(context.get_node_rank())
    os.environ["GROUP_WORLD_SIZE"] = str(
        context.get_world_size() / context.get_local_world_size()
    )
```
Review thread (on lines +50 to +52):

Comment: Do we need to consider the case of heterogeneous clusters? Should we pass this information in from the […]

Reply: I haven't seen any better documentation, to be honest, and my knowledge is slim in this area. Also, do we have a way to pass additional work/actions before the training starts without extending the class? For example, I want to add a new EFA configuration as an env var for this training run. More like a function to […]

Reply: Okay, can we just add a TODO here? If we need to support heterogeneous clusters we can add more robust logic in […]. For the question on […]
```python
    os.environ["ROLE_RANK"] = str(context.get_world_rank())
    os.environ["ROLE_WORLD_RANK"] = str(context.get_world_rank())
    os.environ["ROLE_WORLD_SIZE"] = str(context.get_world_size())

    # EFA and XLA setup
    # https://github.com/aws/libfabric/blob/master/prov/efa/src/rxr/rxr_init.c
    # https://github.com/aws-neuron/aws-neuron-samples/blob/master/torch-neuronx/training/dp_bert_hf_pretrain/run_dp_bert_large_hf_pretrain_bf16_s128.sh # noqa
    os.environ["FI_PROVIDER"] = "efa"
    os.environ["FI_EFA_USE_DEVICE_RDMA"] = "1"
    os.environ["FI_EFA_FORK_SAFE"] = "1"
    os.environ["XLA_TRANSFER_SEED_ASYNC"] = "1"
    os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "1"


def _setup_xla_torch_process_group():
    try:
        import torch.distributed as dist
        import torch_xla.core.xla_model as xm  # noqa F401
        import torch_xla.distributed.xla_backend  # noqa F401

        dist.init_process_group("xla")
    except ImportError:
        raise ImportError("torch_xla must be installed to use torch_xla backend.")


# The following env vars enable Neuron graph extraction for parallel compilation
# Note: model outputs are invalid and should be ignored while these env vars are set
def _set_neuron_parallel_compile_env_vars():
    os.environ["NEURON_PARALLEL_COMPILE"] = "1"
    os.environ["NEURON_EXTRACT_GRAPHS_ONLY"] = "1"
    os.environ["NEURON_FALL_BACK_TO_NULL_NEFF"] = "1"


# Compile previously extracted Neuron graphs
def _neuron_compile_extracted_graphs():
    try:
        from libneuronxla.neuron_cc_cache import CacheUrl
        from libneuronxla.neuron_parallel_compile import parallel_compile
    except ImportError:
        raise ImportError(
            "libneuronxla must be installed to use Neuron parallel compilation."
        )

    # Only 1 worker per node should run parallel_compile()
    if os.environ.get("LOCAL_RANK") == "0":
        logger.info("Compiling extracted graphs on local rank0 worker")

        parallel_compile_workdir = (
            f"/tmp/{os.environ.get('USER', 'no-user')}/parallel_compile_workdir/"
        )
        if os.path.exists(parallel_compile_workdir):
            shutil.rmtree(parallel_compile_workdir)
        os.makedirs(parallel_compile_workdir, exist_ok=True)

        # Users can set the cache directory using --cache_dir in NEURON_CC_FLAGS or by
        # using NEURON_COMPILE_CACHE_URL. --cache_dir takes precedence.
        explicit_cache_dir = None
        if neuron_cc_flags := os.environ.get("NEURON_CC_FLAGS"):
            if s := re.search(r"--cache_dir[= ](\S+)", neuron_cc_flags):
                explicit_cache_dir = s.group(1)

        parallel_compile(
            parallel_compile_workdir,
            CacheUrl.get_cache_url(explicit_cache_dir),
        )
```
Review thread (on the `_TorchAwsNeuronXLABackend` class declaration):

Comment: Should this subclass […]

Reply: It's a private class, kept to `Backend` for now.
The diff continues with the backend class in question:

```python
class _TorchAwsNeuronXLABackend(Backend):
    unique_run_id: str = str(uuid.uuid4())

    def on_start(self, worker_group: WorkerGroup, backend_config: TorchXLAConfig):
        """Logic ran right before training is started."""

        # On previous worker failure, we don't run graceful shutdown on workers.
        # This would leak any running xrt server.
        worker_group.execute(_kill_xrt_server)

        # Get master address and port from the first worker.
        master_addr, master_port = worker_group.execute_single(0, get_address_and_port)

        def set_env_vars(addr, port):
            os.environ["MASTER_ADDR"] = addr
            os.environ["MASTER_PORT"] = str(port)
            # To trigger the xrt server
            os.environ["TORCHELASTIC_RUN_ID"] = self.unique_run_id

        # Set the env vars on all workers.
        worker_group.execute(set_env_vars, addr=master_addr, port=master_port)

        # Set up env vars for neuron parallel compilation graph extraction
        if backend_config.neuron_parallel_compile:
            logger.info("Extracting graphs for Neuron parallel compilation")
            worker_group.execute(_set_neuron_parallel_compile_env_vars)

    def on_training_start(
        self, worker_group: WorkerGroup, backend_config: TorchXLAConfig
    ):
        """
        Configure the environment variables for the worker group
        and initialize the XLA distributed process group.
        TODO: The current setup only supports a homogeneous cluster with
        the neuron_cores accelerator and xrt runtime.
        """
        worker_group.execute(_set_xla_env_vars)
        worker_group.execute(_setup_xla_torch_process_group)

    def on_shutdown(self, worker_group: WorkerGroup, backend_config: TorchXLAConfig):
        """
        Logic ran right after training is finished.
        This is a sanity cleanup to kill the xrt server, and to optionally
        run Neuron parallel graph compilation.
        """
        worker_group.execute(_kill_xrt_server)

        # Compile the extracted graphs. This must run at end of training.
        if backend_config.neuron_parallel_compile:
            worker_group.execute(_neuron_compile_extracted_graphs)
```
Review thread:

Comment: Why are we compiling the graph at the end? How would people use the compiled graph?

Reply: This code only runs during a pre-compilation step, which happens when `neuron_parallel_compile` is set to true. After pre-compilation, the user simply runs without `neuron_parallel_compile` to use the cached graphs, and this avoids […]

Comment: @woshiyyya does the above make sense?

Reply: Ok, that makes sense to me. Can we add a docstring for […]
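To make the pre-compilation flow described in that reply concrete, here is a hedged sketch of the two-run workflow. It reuses the hypothetical `train_func` and scaling configuration from the usage sketch near the top; nothing beyond the `neuron_parallel_compile` flag is defined by this PR:

```python
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer
from ray.train.torch.xla import TorchXLAConfig

# train_func and scaling_config as in the earlier usage sketch (assumptions).
scaling_config = ScalingConfig(num_workers=32, resources_per_worker={"neuron_cores": 1})

# Run 1: graph extraction + ahead-of-time compilation. Model outputs from this
# run are invalid and should be discarded; it only populates the Neuron
# compiler cache.
TorchTrainer(
    train_loop_per_worker=train_func,
    torch_config=TorchXLAConfig(neuron_parallel_compile=True),
    scaling_config=scaling_config,
).fit()

# Run 2: the real training run. With the cache already populated, it avoids
# compilation stalls during training.
TorchTrainer(
    train_loop_per_worker=train_func,
    torch_config=TorchXLAConfig(),
    scaling_config=scaling_config,
).fit()
```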
Review thread:

Comment: Is there a Python API we can use to kill the server?

Reply: Unfortunately, no. I'm open to options if there are any alternatives.
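One alternative worth considering, sketched here with `psutil` (an extra dependency that this PR does not use, so this is purely illustrative rather than a drop-in replacement): match the `xrt_run_server` process by command line and terminate it from Python instead of shelling out to `pkill`.

```python
import psutil


def _kill_xrt_server_psutil():
    # Hypothetical pkill replacement: terminate any process whose command
    # line mentions xrt_run_server.
    for proc in psutil.process_iter(["pid", "cmdline"]):
        try:
            cmdline = " ".join(proc.info["cmdline"] or [])
            if "xrt_run_server" in cmdline:
                proc.terminate()
        except (psutil.NoSuchProcess, psutil.AccessDenied):
            continue
```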