[train] Add TorchAwsNeuronXLABackend and XLAConfig #39130

Merged
33 commits merged into master from feat-train-config on Apr 3, 2024

Commits (33)
8f9753f  Add TorchAwsNeuronXLABackend and XLAConfig - refactor (chappidim, Sep 11, 2023)
cf591d1  Merge branch 'master' into feat-train-config (chappidim, Sep 12, 2023)
1e45e77  Refactor changes with fixes (chappidim, Sep 21, 2023)
4fcf6ac  Merge branch 'master' into feat-train-config (chappidim, Sep 21, 2023)
d7aa259  Refactor changes - format (chappidim, Sep 21, 2023)
5418191  Merge branch 'master' into feat-train-config (chappidim, Oct 13, 2023)
206cbce  Remove config (chappidim, Oct 13, 2023)
1c80fad  Merge branch 'master' into feat-train-config (chappidim, Oct 16, 2023)
2c66d64  add release test for Trainium integration (woshiyyya, Oct 19, 2023)
211ffcf  fix typo (woshiyyya, Oct 19, 2023)
e805f68  update prepare_model (woshiyyya, Oct 20, 2023)
c888e5d  update multi-node release test (woshiyyya, Nov 14, 2023)
4e836eb  Merge remote-tracking branch 'upstream/master' into feat-train-config (woshiyyya, Nov 14, 2023)
1ba3b4d  fix byod.sh (woshiyyya, Nov 15, 2023)
8e9148c  chmod 755 byod_train_trainium.sh (woshiyyya, Nov 15, 2023)
de9b620  update efa configs and vpn cld_id (woshiyyya, Nov 15, 2023)
553653e  fix lint (woshiyyya, Nov 15, 2023)
c319a75  fix compute config (woshiyyya, Nov 16, 2023)
717b3a8  change region (woshiyyya, Nov 16, 2023)
fc63097  try to put Network interfaces into new aws entry (woshiyyya, Nov 16, 2023)
34620f0  fix shell lint (woshiyyya, Nov 17, 2023)
c88cdff  Merge remote-tracking branch 'upstream/master' into feat-train-config (woshiyyya, Nov 17, 2023)
74251f4  Merge branch 'master' into feat-train-config (woshiyyya, Jan 4, 2024)
beece14  Avoid runtime warning when NEURON_RT_VISIBLE_CORES is empty (5cp, Jan 10, 2024)
1164d73  Add support for neuron_parallel_compile to pre-populate Neuron cache (5cp, Jan 16, 2024)
8b1a416  Clean up parallel_compile_workdir before parallel compilation (5cp, Jan 16, 2024)
6d995dd  Minor updates to address PR comments (5cp, Jan 16, 2024)
39eb872  Merge pull request #1 from 5cp/feat-train-config (chappidim, Jan 17, 2024)
bb28b81  Merge branch 'ray-project:master' into feat-train-config (chappidim, Jan 17, 2024)
f514f07  Merge branch 'master' into feat-train-config (woshiyyya, Mar 19, 2024)
66264f9  remove release test (woshiyyya, Mar 20, 2024)
b340222  fix ci test (woshiyyya, Apr 3, 2024)
4eeb74d  Merge branch 'master' into feat-train-config (woshiyyya, Apr 3, 2024)
3 changes: 2 additions & 1 deletion doc/source/train/api/api.rst
@@ -17,6 +17,7 @@ PyTorch Ecosystem

    ~train.torch.TorchTrainer
    ~train.torch.TorchConfig
    ~train.torch.xla.TorchXLAConfig

.. _train-pytorch-integration:

@@ -145,7 +146,7 @@ Ray Train Utilities
.. autosummary::
    :nosignatures:
    :toctree: doc/

    ~train.get_checkpoint
    ~train.get_context
    ~train.get_dataset_shard
5 changes: 5 additions & 0 deletions python/ray/train/torch/xla/__init__.py
@@ -0,0 +1,5 @@
from ray.train.torch.xla.config import TorchXLAConfig

__all__ = [
    "TorchXLAConfig",
]
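
With this __init__.py in place, TorchXLAConfig becomes importable from ray.train.torch.xla. A minimal import check (illustrative only, not part of the diff):

```python
from ray.train.torch.xla import TorchXLAConfig

config = TorchXLAConfig(neuron_parallel_compile=False)
# backend_cls resolves to the private _TorchAwsNeuronXLABackend defined in config.py below.
print(config.backend_cls)
```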
169 changes: 169 additions & 0 deletions python/ray/train/torch/xla/config.py
@@ -0,0 +1,169 @@
import logging
import os
import re
import shutil
import uuid
from dataclasses import dataclass

import ray
from ray.train._internal.utils import get_address_and_port
from ray.train._internal.worker_group import WorkerGroup
from ray.train.backend import Backend
from ray.train.torch import TorchConfig
from ray.util import PublicAPI

logger = logging.getLogger(__name__)


@PublicAPI(stability="alpha")
@dataclass
class TorchXLAConfig(TorchConfig):
"""
Configuration for torch XLA setup.
See https://pytorch.org/xla/release/1.13/index.html for more info.
Currently, only "neuron_cores" accelerator (AwsNeuronXLABackend)
is supported with xrt runtime.
"""

neuron_parallel_compile: bool = False

@property
def backend_cls(self):
return _TorchAwsNeuronXLABackend


def _kill_xrt_server():
    import subprocess

    subprocess.call(["pkill", "-f", "xrt_run_server"])
Contributor: Is there a Python API we can use to kill the server?

Contributor (Author): Unfortunately, no. I'm open to options if there are any alternatives.
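
For reference, a rough pure-Python equivalent of the pkill call is possible with psutil (a sketch only, not part of this PR; it assumes psutil is installed on the workers):

```python
import psutil


def _kill_xrt_server_with_psutil():
    # Mirror `pkill -f xrt_run_server`: SIGTERM any process whose command
    # line mentions xrt_run_server.
    for proc in psutil.process_iter(["pid", "cmdline"]):
        cmdline = " ".join(proc.info.get("cmdline") or [])
        if "xrt_run_server" in cmdline:
            try:
                proc.terminate()
            except psutil.NoSuchProcess:
                pass
```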



def _set_xla_env_vars():
    # https://pytorch.org/docs/1.13/elastic/run.html#environment-variables
    context = ray.train.get_context()

    os.environ["LOCAL_RANK"] = str(context.get_local_rank())
    os.environ["RANK"] = str(context.get_world_rank())
    os.environ["LOCAL_WORLD_SIZE"] = str(context.get_local_world_size())
    os.environ["WORLD_SIZE"] = str(context.get_world_size())
    os.environ["GROUP_RANK"] = str(context.get_node_rank())
    # Integer division so the value is an integer string (e.g. "2", not "2.0").
    os.environ["GROUP_WORLD_SIZE"] = str(
        context.get_world_size() // context.get_local_world_size()
    )
Comment on lines +50 to +52
Contributor: Do we need to consider the case of heterogeneous clusters? Should we pass this information in from the Backend instead?

Contributor (Author): I haven't seen any better documentation, to be honest, and my knowledge is slim in this area.

Also, do we have a way to run additional work/actions before training starts without extending the class? For example, I want to add a new EFA configuration as an env var for this train run - more like an on_start hook that would be executed on all workers.

Contributor: Okay, can we just add a TODO here? If we need to support heterogeneous clusters, we can add more robust logic in on_training_start to fetch the topology of the cluster and then set it.

For the question on on_start, just capturing the conversation we had offline here: this can be defined by the user in train_loop_per_worker.

os.environ["ROLE_RANK"] = str(context.get_world_rank())
os.environ["ROLE_WORLD_RANK"] = str(context.get_world_rank())
os.environ["ROLE_WORLD_SIZE"] = str(context.get_world_size())

# EFA and XLA setup
# https://github.com/aws/libfabric/blob/master/prov/efa/src/rxr/rxr_init.c
# https://github.com/aws-neuron/aws-neuron-samples/blob/master/torch-neuronx/training/dp_bert_hf_pretrain/run_dp_bert_large_hf_pretrain_bf16_s128.sh # noqa
os.environ["FI_PROVIDER"] = "efa"
os.environ["FI_EFA_USE_DEVICE_RDMA"] = "1"
os.environ["FI_EFA_FORK_SAFE"] = "1"
os.environ["XLA_TRANSFER_SEED_ASYNC"] = "1"
os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "1"


def _setup_xla_torch_process_group():
    try:
        import torch.distributed as dist
        import torch_xla.core.xla_model as xm  # noqa F401
        import torch_xla.distributed.xla_backend  # noqa F401

        dist.init_process_group("xla")
    except ImportError:
        raise ImportError("torch_xla must be installed to use torch_xla backend.")


# The following env vars enable Neuron graph extraction for parallel compilation
# Note: model outputs are invalid and should be ignored while these env vars are set
def _set_neuron_parallel_compile_env_vars():
os.environ["NEURON_PARALLEL_COMPILE"] = "1"
os.environ["NEURON_EXTRACT_GRAPHS_ONLY"] = "1"
os.environ["NEURON_FALL_BACK_TO_NULL_NEFF"] = "1"


# Compile previously extracted Neuron graphs
def _neuron_compile_extracted_graphs():
    try:
        from libneuronxla.neuron_cc_cache import CacheUrl
        from libneuronxla.neuron_parallel_compile import parallel_compile
    except ImportError:
        raise ImportError(
            "libneuronxla must be installed to use Neuron parallel compilation."
        )

    # Only 1 worker per node should run parallel_compile()
    if os.environ.get("LOCAL_RANK") == "0":
        logger.info("Compiling extracted graphs on local rank0 worker")

        parallel_compile_workdir = (
            f"/tmp/{os.environ.get('USER', 'no-user')}/parallel_compile_workdir/"
        )
        if os.path.exists(parallel_compile_workdir):
            shutil.rmtree(parallel_compile_workdir)
        os.makedirs(parallel_compile_workdir, exist_ok=True)

        # Users can set the cache directory using --cache_dir in NEURON_CC_FLAGS
        # or by using NEURON_COMPILE_CACHE_URL. --cache_dir takes precedence.
        explicit_cache_dir = None
        if neuron_cc_flags := os.environ.get("NEURON_CC_FLAGS"):
            if s := re.search(r"--cache_dir[= ](\S+)", neuron_cc_flags):
                explicit_cache_dir = s.group(1)

        parallel_compile(
            parallel_compile_workdir,
            CacheUrl.get_cache_url(explicit_cache_dir),
        )
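
For clarity, the --cache_dir parsing above accepts both the "=" and space-separated forms. A small sketch of the regex behavior (the paths are made up):

```python
import re


def _extract_cache_dir(neuron_cc_flags: str):
    # Same pattern as in the diff: matches "--cache_dir=PATH" or "--cache_dir PATH".
    match = re.search(r"--cache_dir[= ](\S+)", neuron_cc_flags)
    return match.group(1) if match else None


assert _extract_cache_dir("--cache_dir=/tmp/neuron-cache -O2") == "/tmp/neuron-cache"
assert _extract_cache_dir("--cache_dir /tmp/neuron-cache") == "/tmp/neuron-cache"
assert _extract_cache_dir("-O2") is None
```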


class _TorchAwsNeuronXLABackend(Backend):
Contributor: Should this subclass TorchBackend?

Contributor (Author): It's a private class; keeping it as a Backend subclass for now.

    unique_run_id: str = str(uuid.uuid4())

    def on_start(self, worker_group: WorkerGroup, backend_config: TorchXLAConfig):
        """Logic run right before training is started."""

        # On a previous worker failure, we don't run graceful shutdown on workers.
        # This would leak any running xrt server.
        worker_group.execute(_kill_xrt_server)

        # Get the master address and port from the first worker.
        master_addr, master_port = worker_group.execute_single(0, get_address_and_port)

        def set_env_vars(addr, port):
            os.environ["MASTER_ADDR"] = addr
            os.environ["MASTER_PORT"] = str(port)
            # To trigger the xrt server
            os.environ["TORCHELASTIC_RUN_ID"] = self.unique_run_id

        # Set the env vars on all workers.
        worker_group.execute(set_env_vars, addr=master_addr, port=master_port)

        # Set up env vars for Neuron parallel compilation graph extraction.
        if backend_config.neuron_parallel_compile:
            logger.info("Extracting graphs for Neuron parallel compilation")
            worker_group.execute(_set_neuron_parallel_compile_env_vars)

    def on_training_start(
        self, worker_group: WorkerGroup, backend_config: TorchXLAConfig
    ):
        """
        Configure the environment variables for the worker group
        and initialize the XLA distributed process group.
        TODO: The current setup only supports a homogeneous cluster with
        the neuron_cores accelerator and xrt runtime.
        """
        worker_group.execute(_set_xla_env_vars)
        worker_group.execute(_setup_xla_torch_process_group)

    def on_shutdown(self, worker_group: WorkerGroup, backend_config: TorchXLAConfig):
        """
        Logic run right after training is finished.
        This is a sanity cleanup to kill the xrt server and to optionally
        run Neuron parallel graph compilation.
        """
        worker_group.execute(_kill_xrt_server)

        # Compile the extracted graphs. This must run at the end of training.
        if backend_config.neuron_parallel_compile:
            worker_group.execute(_neuron_compile_extracted_graphs)
Member: Why are we compiling the graph at the end? How would people use the compiled graph?

Contributor: This code only runs during a pre-compilation step, which happens when neuron_parallel_compile is set to true. _neuron_compile_extracted_graphs() must run at the end of the job (after a short number of training iterations), once all the graphs have been encountered.

See: 1 and 2 and 3.

After precompilation, the user simply runs without neuron_parallel_compile to use the cached graphs, which avoids _neuron_compile_extracted_graphs().

Contributor: @woshiyyya does the above make sense?

Member: Ok, that makes sense to me. Can we add a docstring for neuron_parallel_compile in TorchConfig to explain the behavior?
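
Tying the pieces together, a minimal end-to-end usage sketch of the new backend (the model, training loop, and resource numbers are illustrative assumptions, not part of this diff): run once with neuron_parallel_compile=True to populate the Neuron cache, then run again with the default to train against the cached graphs.

```python
import torch
import torch.nn as nn

from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer
from ray.train.torch.xla import TorchXLAConfig


def train_loop_per_worker(config):
    import torch_xla.core.xla_model as xm

    device = xm.xla_device()
    model = nn.Linear(16, 1).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

    for _ in range(config["num_steps"]):
        x = torch.randn(8, 16, device=device)
        loss = model(x).sum()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        xm.mark_step()  # materialize the lazily traced XLA graph


# Phase 1: extract graphs and pre-populate the Neuron cache (model outputs
# are invalid during this run). Phase 2: train normally against the cache.
for precompile in (True, False):
    trainer = TorchTrainer(
        train_loop_per_worker,
        train_loop_config={"num_steps": 10},
        torch_config=TorchXLAConfig(neuron_parallel_compile=precompile),
        scaling_config=ScalingConfig(
            num_workers=32,  # e.g. one trn1.32xlarge node (assumed instance type)
            resources_per_worker={"neuron_cores": 1},
        ),
    )
    trainer.fit()
```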

1 change: 1 addition & 0 deletions release/release_tests.yaml
@@ -2223,6 +2223,7 @@
cluster:
cluster_compute: rte_gce_small.yaml


- name: runtime_env_wheel_urls
group: Runtime env tests
working_dir: runtime_env_tests