From 8f9753f9f91b469cbabe06a6fd3c594bd5c9cd18 Mon Sep 17 00:00:00 2001 From: maheedhar reddy chappidi Date: Mon, 11 Sep 2023 12:01:00 -0700 Subject: [PATCH 01/22] Add TorchAwsNeuronXLABackend and XLAConfig - refactor Signed-off-by: maheedhar reddy chappidi --- doc/source/train/api/api.rst | 3 +- python/ray/train/torch/xla/__init__.py | 5 + python/ray/train/torch/xla/config.py | 126 +++++++++++++++++++++++++ 3 files changed, 133 insertions(+), 1 deletion(-) create mode 100644 python/ray/train/torch/xla/__init__.py create mode 100644 python/ray/train/torch/xla/config.py diff --git a/doc/source/train/api/api.rst b/doc/source/train/api/api.rst index 59dc83c5a514..7df62e1f3422 100644 --- a/doc/source/train/api/api.rst +++ b/doc/source/train/api/api.rst @@ -17,6 +17,7 @@ PyTorch Ecosystem ~train.torch.TorchTrainer ~train.torch.TorchConfig + ~train.torch.xla.TorchXLAConfig .. _train-pytorch-integration: @@ -132,7 +133,7 @@ Ray Train Utilities .. autosummary:: :toctree: doc/ - + ~train.get_checkpoint ~train.get_context ~train.get_dataset_shard diff --git a/python/ray/train/torch/xla/__init__.py b/python/ray/train/torch/xla/__init__.py new file mode 100644 index 000000000000..ea32abc8c9d7 --- /dev/null +++ b/python/ray/train/torch/xla/__init__.py @@ -0,0 +1,5 @@ +from ray.train.torch.xla.config import TorchXLAConfig + +__all__ = [ + "TorchXLAConfig", +] diff --git a/python/ray/train/torch/xla/config.py b/python/ray/train/torch/xla/config.py new file mode 100644 index 000000000000..7da15d66009d --- /dev/null +++ b/python/ray/train/torch/xla/config.py @@ -0,0 +1,126 @@ +import os +import uuid +from dataclasses import dataclass + +import ray +from ray._private.ray_constants import NEURON_CORES +from ray.train import BackendConfig +from ray.train._internal.utils import get_address_and_port +from ray.train._internal.worker_group import WorkerGroup +from ray.train.backend import Backend +from ray.util import PublicAPI + + +@PublicAPI(stability="alpha") +@dataclass +class TorchXLAConfig(BackendConfig): + """ + Configuration for torch XLA setup. + See https://pytorch.org/xla/release/1.13/index.html for more info. + + Args: + runtime: Runtime to use for training. Supported values are "xrt", "pjrt". + Currently, only "xrt" is supported. + accelerator_type: The accelerator type used to differentiate the XLA backend. + Currently, only "neuron_cores" is supported. 
+ + """ + + runtime: str = "xrt" + accelerator_type: str = NEURON_CORES + + @property + def backend_cls(self): + if self.accelerator_type == NEURON_CORES: + return _TorchAwsNeuronXLABackend + else: + raise NotImplementedError + + +def _kill_xrt_server(): + import subprocess + + subprocess.call(["pkill", "-f", "xrt_run_server"]) + + +def _set_xla_env_vars(): + # https://pytorch.org/docs/1.13/elastic/run.html#environment-variables + context = ray.train.get_context() + + os.environ["LOCAL_RANK"] = str(context.get_local_rank()) + os.environ["RANK"] = str(context.get_world_rank()) + os.environ["LOCAL_WORLD_SIZE"] = str(context.get_local_world_size()) + os.environ["WORLD_SIZE"] = str(context.get_world_size()) + os.environ["GROUP_RANK"] = str(context.get_node_rank()) + os.environ["GROUP_WORLD_SIZE"] = str( + context.get_world_size() / context.get_local_world_size() + ) + os.environ["ROLE_RANK"] = str(context.get_world_rank()) + os.environ["ROLE_WORLD_RANK"] = str(context.get_world_rank()) + os.environ["ROLE_WORLD_SIZE"] = str(context.get_world_size()) + + # EFA and XLA setup + # https://github.com/aws/libfabric/blob/master/prov/efa/src/rxr/rxr_init.c + # https://github.com/aws-neuron/aws-neuron-samples/blob/master/torch-neuronx/training/dp_bert_hf_pretrain/run_dp_bert_large_hf_pretrain_bf16_s128.sh # noqa + os.environ["FI_PROVIDER"] = "efa" + os.environ["FI_EFA_USE_DEVICE_RDMA"] = "1" + os.environ["FI_EFA_FORK_SAFE"] = "1" + os.environ["XLA_TRANSFER_SEED_ASYNC"] = "1" + os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "1" + + +def _setup_xla_torch_process_group(): + try: + import torch_xla.core.xla_model as xm # noqa + import torch_xla.distributed.xla_backend # noqa + import torch.distributed as dist + + dist.init_process_group("xla") + except ImportError: + raise ImportError("torch_xla must be installed to use torch_xla backend.") + + +class _TorchAwsNeuronXLABackend(Backend): + unique_run_id: str = str(uuid.uuid4()) + + def on_start(self, worker_group: WorkerGroup, backend_config: TorchXLAConfig): + """Logic ran right before training is started.""" + + if backend_config.runtime != "xrt": + # pjrt is not yet supported in torch-neuronx. + raise ValueError( + f"Expected runtime to be 'xrt', but got '{backend_config.runtime}'." + ) + + # On previous worker failure, we don't run graceful shutdown on workers. + # This would leak any running xrt server. + worker_group.execute(_kill_xrt_server) + + # Get master address and port from the first worker. + master_addr, master_port = worker_group.execute_single(0, get_address_and_port) + + def set_env_vars(addr, port): + os.environ["MASTER_ADDR"] = addr + os.environ["MASTER_PORT"] = str(port) + # To trigger the xrt server + os.environ["TORCHELASTIC_RUN_ID"] = self.unique_run_id + + # Set the env vars on all workers. + worker_group.execute(set_env_vars, addr=master_addr, port=master_port) + + def on_training_start( + self, worker_group: WorkerGroup, backend_config: TorchXLAConfig + ): + """ + Configure the environment variables for the worker group. + And initialize the xla distributed process group. + """ + worker_group.execute(_set_xla_env_vars) + worker_group.execute(_setup_xla_torch_process_group) + + def on_shutdown(self, worker_group: WorkerGroup, backend_config: TorchXLAConfig): + """ + Logic ran right before training is started. + This is a sanity cleanup to kill xrt server. 
+ """ + worker_group.execute(_kill_xrt_server) From 1e45e77d03ffdd25c81c4fc6069df130493981ac Mon Sep 17 00:00:00 2001 From: maheedhar reddy chappidi Date: Thu, 21 Sep 2023 11:33:00 -0700 Subject: [PATCH 02/22] Refactor changes with fixes Signed-off-by: maheedhar reddy chappidi --- python/ray/train/torch/xla/config.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/python/ray/train/torch/xla/config.py b/python/ray/train/torch/xla/config.py index 7da15d66009d..3f1fc03e42de 100644 --- a/python/ray/train/torch/xla/config.py +++ b/python/ray/train/torch/xla/config.py @@ -1,19 +1,20 @@ + import os import uuid from dataclasses import dataclass import ray from ray._private.ray_constants import NEURON_CORES -from ray.train import BackendConfig from ray.train._internal.utils import get_address_and_port from ray.train._internal.worker_group import WorkerGroup from ray.train.backend import Backend +from ray.train.torch import TorchConfig from ray.util import PublicAPI @PublicAPI(stability="alpha") @dataclass -class TorchXLAConfig(BackendConfig): +class TorchXLAConfig(TorchConfig): """ Configuration for torch XLA setup. See https://pytorch.org/xla/release/1.13/index.html for more info. @@ -71,8 +72,8 @@ def _set_xla_env_vars(): def _setup_xla_torch_process_group(): try: - import torch_xla.core.xla_model as xm # noqa - import torch_xla.distributed.xla_backend # noqa + import torch_xla.core.xla_model as xm # noqa F401 + import torch_xla.distributed.xla_backend # noqa F401 import torch.distributed as dist dist.init_process_group("xla") @@ -120,7 +121,7 @@ def on_training_start( def on_shutdown(self, worker_group: WorkerGroup, backend_config: TorchXLAConfig): """ - Logic ran right before training is started. + Logic ran right after training is finished. This is a sanity cleanup to kill xrt server. """ worker_group.execute(_kill_xrt_server) From d7aa2596c4b6dfb1f52c611191282923f26e2c49 Mon Sep 17 00:00:00 2001 From: maheedhar reddy chappidi Date: Thu, 21 Sep 2023 12:08:17 -0700 Subject: [PATCH 03/22] Refactor changes - format Signed-off-by: maheedhar reddy chappidi --- python/ray/train/torch/xla/config.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/ray/train/torch/xla/config.py b/python/ray/train/torch/xla/config.py index 3f1fc03e42de..632be0053e3a 100644 --- a/python/ray/train/torch/xla/config.py +++ b/python/ray/train/torch/xla/config.py @@ -1,4 +1,3 @@ - import os import uuid from dataclasses import dataclass From 206cbceaf6d0d0c002b329c75b91062a38eb936c Mon Sep 17 00:00:00 2001 From: maheedhar reddy chappidi Date: Fri, 13 Oct 2023 16:18:31 -0700 Subject: [PATCH 04/22] Remove config Signed-off-by: maheedhar reddy chappidi --- python/ray/train/torch/xla/config.py | 28 ++++++---------------------- 1 file changed, 6 insertions(+), 22 deletions(-) diff --git a/python/ray/train/torch/xla/config.py b/python/ray/train/torch/xla/config.py index 632be0053e3a..f43069d2a811 100644 --- a/python/ray/train/torch/xla/config.py +++ b/python/ray/train/torch/xla/config.py @@ -3,7 +3,6 @@ from dataclasses import dataclass import ray -from ray._private.ray_constants import NEURON_CORES from ray.train._internal.utils import get_address_and_port from ray.train._internal.worker_group import WorkerGroup from ray.train.backend import Backend @@ -17,24 +16,13 @@ class TorchXLAConfig(TorchConfig): """ Configuration for torch XLA setup. See https://pytorch.org/xla/release/1.13/index.html for more info. - - Args: - runtime: Runtime to use for training. Supported values are "xrt", "pjrt". 
- Currently, only "xrt" is supported. - accelerator_type: The accelerator type used to differentiate the XLA backend. - Currently, only "neuron_cores" is supported. - + Currently, only "neuron_cores" accelerator (AwsNeuronXLABackend) + is supported with xrt runtime. """ - runtime: str = "xrt" - accelerator_type: str = NEURON_CORES - @property def backend_cls(self): - if self.accelerator_type == NEURON_CORES: - return _TorchAwsNeuronXLABackend - else: - raise NotImplementedError + return _TorchAwsNeuronXLABackend def _kill_xrt_server(): @@ -71,9 +59,9 @@ def _set_xla_env_vars(): def _setup_xla_torch_process_group(): try: + import torch.distributed as dist import torch_xla.core.xla_model as xm # noqa F401 import torch_xla.distributed.xla_backend # noqa F401 - import torch.distributed as dist dist.init_process_group("xla") except ImportError: @@ -86,12 +74,6 @@ class _TorchAwsNeuronXLABackend(Backend): def on_start(self, worker_group: WorkerGroup, backend_config: TorchXLAConfig): """Logic ran right before training is started.""" - if backend_config.runtime != "xrt": - # pjrt is not yet supported in torch-neuronx. - raise ValueError( - f"Expected runtime to be 'xrt', but got '{backend_config.runtime}'." - ) - # On previous worker failure, we don't run graceful shutdown on workers. # This would leak any running xrt server. worker_group.execute(_kill_xrt_server) @@ -114,6 +96,8 @@ def on_training_start( """ Configure the environment variables for the worker group. And initialize the xla distributed process group. + TODO: Current setup only supports homogenous cluster with + neuron_cores accelerator and xrt runtime. """ worker_group.execute(_set_xla_env_vars) worker_group.execute(_setup_xla_torch_process_group) From 2c66d643bd8c1db374618f9a4fc3838a0a33e745 Mon Sep 17 00:00:00 2001 From: woshiyyya Date: Thu, 19 Oct 2023 15:40:20 -0700 Subject: [PATCH 05/22] add release test for Trainium integration Signed-off-by: woshiyyya --- .../ray_release/byod/byod_train_trainium.sh | 22 ++++++++ release/release_tests.yaml | 17 ++++++ release/train_tests/trainium/compute_aws.yml | 17 ++++++ release/train_tests/trainium/test_trainium.py | 54 +++++++++++++++++++ 4 files changed, 110 insertions(+) create mode 100755 release/ray_release/byod/byod_train_trainium.sh create mode 100644 release/train_tests/trainium/compute_aws.yml create mode 100644 release/train_tests/trainium/test_trainium.py diff --git a/release/ray_release/byod/byod_train_trainium.sh b/release/ray_release/byod/byod_train_trainium.sh new file mode 100755 index 000000000000..029ad132d9cd --- /dev/null +++ b/release/ray_release/byod/byod_train_trainium.sh @@ -0,0 +1,22 @@ +#!/bin/bash + +# Configure Linux for Neuron repository updates +. 
/etc/os-release && \ +echo "deb https://apt.repos.neuron.amazonaws.com ${VERSION_CODENAME} main" | \ +sudo tee /etc/apt/sources.list.d/neuron.list > /dev/null + +wget -qO - https://apt.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEURON.PUB | sudo apt-key add - + +# Update OS packages +sudo apt-get update -y + +# Install Neuron Runtime +sudo apt-get install aws-neuronx-collectives=2.* -y +sudo apt-get install aws-neuronx-runtime-lib=2.* -y + +# Install Neuron Tools +sudo apt-get install aws-neuronx-tools=2.* -y + +# Install neuronx and torch_xla +pip config set global.extra-index-url https://pip.repos.neuron.amazonaws.com +pip install neuronx-cc==2.* torch-neuronx torchvision diff --git a/release/release_tests.yaml b/release/release_tests.yaml index 6f48c4552b9b..91fee2d593ec 100644 --- a/release/release_tests.yaml +++ b/release/release_tests.yaml @@ -3280,6 +3280,23 @@ alert: default +- name: train_torch_trainium_integration + group: Train tests + working_dir: train_tests/trainium + + frequency: nightly + team: ml + + cluster: + byod: + post_build_script: byod_train_trainium.sh + cluster_compute: compute_aws.yaml + + run: + timeout: 1000 + script: python test_trainium.py + + alert: default ######################## # Alpa tests diff --git a/release/train_tests/trainium/compute_aws.yml b/release/train_tests/trainium/compute_aws.yml new file mode 100644 index 000000000000..5995662e9cec --- /dev/null +++ b/release/train_tests/trainium/compute_aws.yml @@ -0,0 +1,17 @@ +cloud_id: {{env["ANYSCALE_CLOUD_ID"]}} +region: us-west-2 + +max_workers: 0 + +head_node_type: + name: head_node + instance_type: trn1.2xlarge + +worker_node_types: [] + +aws: + TagSpecifications: + - ResourceType: "instance" + Tags: + - Key: ttl-hours + Value: '24' diff --git a/release/train_tests/trainium/test_trainium.py b/release/train_tests/trainium/test_trainium.py new file mode 100644 index 000000000000..1f31bfb06c97 --- /dev/null +++ b/release/train_tests/trainium/test_trainium.py @@ -0,0 +1,54 @@ +import torch +import torch.nn as nn +from torch.nn.parallel import DistributedDataParallel as DDP +import torch.optim as optim +import torch_xla.core.xla_model as xm +import torch_xla.distributed.xla_backend # noqa: F401 + +from ray.train import ScalingConfig +from ray.train.torch import TorchTrainer +from ray.train.torch.xla import TorchXLAConfig + + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + self.net1 = nn.Linear(10, 10) + self.relu = nn.ReLU() + self.net2 = nn.Linear(10, 5) + + def forward(self, x): + return self.net2(self.relu(self.net1(x))) + + +def train_func(): + device = xm.xla_device() + rank = xm.get_ordinal() + + # Create the model and move to device + model = Model().to(device) + ddp_model = DDP(model, gradient_as_bucket_view=True) + + loss_fn = nn.MSELoss() + optimizer = optim.SGD(ddp_model.parameters(), lr=0.001) + for step in range(5): + optimizer.zero_grad() + outputs = ddp_model(torch.randn(20, 10).to(device)) + labels = torch.randn(20, 5).to(device) + loss = loss_fn(outputs, labels) + loss.backward() + optimizer.step() + xm.mark_step() + if rank == 0: + print(f"Loss after step {step}: {loss.cpu()}") + + +trainer = TorchTrainer( + train_loop_per_worker=train_func, + torch_config=TorchXLAConfig(), + scaling_config=ScalingConfig( + num_workers=2, resources_per_worker={"neuron_cores": 1} + ), +) +result = trainer.fit() +print(result) From 211ffcf404c579165a22654f31ab9c8677b445c7 Mon Sep 17 00:00:00 2001 From: woshiyyya Date: Thu, 19 Oct 2023 15:49:53 -0700 Subject: 
[PATCH 06/22] fix typo Signed-off-by: woshiyyya --- .../train_tests/trainium/{compute_aws.yml => compute_aws.yaml} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename release/train_tests/trainium/{compute_aws.yml => compute_aws.yaml} (100%) diff --git a/release/train_tests/trainium/compute_aws.yml b/release/train_tests/trainium/compute_aws.yaml similarity index 100% rename from release/train_tests/trainium/compute_aws.yml rename to release/train_tests/trainium/compute_aws.yaml From e805f68b3c0649e15af1539889fd5ae105cd2a6f Mon Sep 17 00:00:00 2001 From: woshiyyya Date: Thu, 19 Oct 2023 18:03:22 -0700 Subject: [PATCH 07/22] update prepare_model Signed-off-by: woshiyyya --- release/train_tests/trainium/test_trainium.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/release/train_tests/trainium/test_trainium.py b/release/train_tests/trainium/test_trainium.py index 1f31bfb06c97..be6ea8d79ccc 100644 --- a/release/train_tests/trainium/test_trainium.py +++ b/release/train_tests/trainium/test_trainium.py @@ -1,12 +1,11 @@ import torch import torch.nn as nn -from torch.nn.parallel import DistributedDataParallel as DDP import torch.optim as optim import torch_xla.core.xla_model as xm import torch_xla.distributed.xla_backend # noqa: F401 from ray.train import ScalingConfig -from ray.train.torch import TorchTrainer +from ray.train.torch import TorchTrainer, prepare_model from ray.train.torch.xla import TorchXLAConfig @@ -27,7 +26,11 @@ def train_func(): # Create the model and move to device model = Model().to(device) - ddp_model = DDP(model, gradient_as_bucket_view=True) + ddp_model = prepare_model( + model, + move_to_device=False, + parallel_strategy_kwargs={"gradient_as_bucket_view": True}, + ) loss_fn = nn.MSELoss() optimizer = optim.SGD(ddp_model.parameters(), lr=0.001) From c888e5df9a1c5e6eeb427fe84e4ea2dec1ff9118 Mon Sep 17 00:00:00 2001 From: woshiyyya Date: Tue, 14 Nov 2023 11:54:25 -0800 Subject: [PATCH 08/22] update multi-node release test Signed-off-by: woshiyyya --- .../ray_release/byod/byod_train_trainium.sh | 24 +++- release/train_tests/trainium/compute_aws.yaml | 11 +- release/train_tests/trainium/test_trainium.py | 130 +++++++++++++----- 3 files changed, 122 insertions(+), 43 deletions(-) mode change 100755 => 100644 release/ray_release/byod/byod_train_trainium.sh diff --git a/release/ray_release/byod/byod_train_trainium.sh b/release/ray_release/byod/byod_train_trainium.sh old mode 100755 new mode 100644 index 029ad132d9cd..bf23b2470951 --- a/release/ray_release/byod/byod_train_trainium.sh +++ b/release/ray_release/byod/byod_train_trainium.sh @@ -1,10 +1,19 @@ -#!/bin/bash +# Configure EFA installer +sudo apt-get install curl +curl -O https://efa-installer.amazonaws.com/aws-efa-installer-latest.tar.gz +wget https://efa-installer.amazonaws.com/aws-efa-installer.key && gpg --import aws-efa-installer.key +cat aws-efa-installer.key | gpg --fingerprint +wget https://efa-installer.amazonaws.com/aws-efa-installer-latest.tar.gz.sig && gpg --verify ./aws-efa-installer-latest.tar.gz.sig +tar -xvf aws-efa-installer-latest.tar.gz +cd aws-efa-installer && sudo bash efa_installer.sh --yes --skip-kmod +cd +sudo rm -rf aws-efa-installer-latest.tar.gz aws-efa-installer # Configure Linux for Neuron repository updates -. /etc/os-release && \ -echo "deb https://apt.repos.neuron.amazonaws.com ${VERSION_CODENAME} main" | \ -sudo tee /etc/apt/sources.list.d/neuron.list > /dev/null - +. 
/etc/os-release +sudo tee /etc/apt/sources.list.d/neuron.list > /dev/null < None: + super().__init__() + self.layers = nn.Sequential( + nn.Linear(28 * 28, 128), + nn.ReLU(), + nn.Linear(128, 10), + nn.ReLU(), + ) def forward(self, x): - return self.net2(self.relu(self.net1(x))) + return self.layers(x) def train_func(): - device = xm.xla_device() - rank = xm.get_ordinal() - - # Create the model and move to device - model = Model().to(device) - ddp_model = prepare_model( - model, - move_to_device=False, - parallel_strategy_kwargs={"gradient_as_bucket_view": True}, + # Load MNIST train dataset + if not xm.is_master_ordinal(): + xm.rendezvous("dataset_download") + train_dataset = mnist.MNIST( + root="/tmp/MNIST_DATA_train", train=True, download=True, transform=ToTensor() ) + if xm.is_master_ordinal(): + xm.rendezvous("dataset_download") + + # XLA MP: get world size + world_size = xm.xrt_world_size() + # multi-processing: ensure each worker has same initial weights + torch.manual_seed(0) + + # Move model to device and declare optimizer and loss function + device = "xla" + model = MLP().to(device) + # For multiprocessing, scale up learning rate + optimizer = torch.optim.SGD(model.parameters(), lr=1e-5) + loss_fn = torch.nn.NLLLoss() + + # Prepare data loader + train_sampler = None + if world_size > 1: + train_sampler = DistributedSampler( + train_dataset, num_replicas=world_size, rank=xm.get_ordinal(), shuffle=True + ) + train_loader = DataLoader( + train_dataset, + batch_size=BATCH_SIZE, + sampler=train_sampler, + shuffle=False if train_sampler else True, + ) + # XLA MP: use MpDeviceLoader from torch_xla.distributed + train_device_loader = pl.MpDeviceLoader(train_loader, device) + + # Run the training loop + print("----------Training ---------------") + model.train() + for epoch in range(EPOCHS): + start = time.time() + for idx, (train_x, train_label) in enumerate(train_device_loader): + optimizer.zero_grad() + train_x = train_x.view(train_x.size(0), -1) + output = model(train_x) + loss = loss_fn(output, train_label) + loss.backward() + xm.optimizer_step( + optimizer + ) # XLA MP: performs grad allreduce and optimizer step + if idx < WARMUP_STEPS: # skip warmup iterations + start = time.time() + + # Compute statistics for the last epoch + interval = len(train_device_loader) - WARMUP_STEPS # skip warmup iterations + throughput = interval / (time.time() - start) + print("Train throughput (iter/sec): {}".format(throughput)) + print("Final loss is {:0.4f}".format(loss.detach().to("cpu"))) + + # Save checkpoint for evaluation (xm.save ensures only one process save) + os.makedirs("checkpoints", exist_ok=True) + checkpoint = {"state_dict": model.state_dict()} + xm.save(checkpoint, "checkpoints/checkpoint.pt") - loss_fn = nn.MSELoss() - optimizer = optim.SGD(ddp_model.parameters(), lr=0.001) - for step in range(5): - optimizer.zero_grad() - outputs = ddp_model(torch.randn(20, 10).to(device)) - labels = torch.randn(20, 5).to(device) - loss = loss_fn(outputs, labels) - loss.backward() - optimizer.step() - xm.mark_step() - if rank == 0: - print(f"Loss after step {step}: {loss.cpu()}") + print("----------End Training ---------------") +# trn1.32xlarge -> 32 neuron_cores, 128 CPU +# 2x trn1.32xlarge trainer = TorchTrainer( train_loop_per_worker=train_func, torch_config=TorchXLAConfig(), scaling_config=ScalingConfig( - num_workers=2, resources_per_worker={"neuron_cores": 1} + num_workers=64, resources_per_worker={"neuron_cores": 1} ), ) result = trainer.fit() From 1ba3b4dc86581d1c2d646633c2bd52b634a49e1a Mon 
Sep 17 00:00:00 2001 From: woshiyyya Date: Tue, 14 Nov 2023 17:14:52 -0800 Subject: [PATCH 09/22] fix byod.sh Signed-off-by: woshiyyya --- release/ray_release/byod/byod_train_trainium.sh | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/release/ray_release/byod/byod_train_trainium.sh b/release/ray_release/byod/byod_train_trainium.sh index bf23b2470951..0cc64fd62b2d 100644 --- a/release/ray_release/byod/byod_train_trainium.sh +++ b/release/ray_release/byod/byod_train_trainium.sh @@ -10,10 +10,10 @@ cd sudo rm -rf aws-efa-installer-latest.tar.gz aws-efa-installer # Configure Linux for Neuron repository updates -. /etc/os-release -sudo tee /etc/apt/sources.list.d/neuron.list > /dev/null < /dev/null + wget -qO - https://apt.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEURON.PUB | sudo apt-key add - # Update OS packages From 8e9148cf473e86a2371b112efbc877d2758bcb37 Mon Sep 17 00:00:00 2001 From: woshiyyya Date: Tue, 14 Nov 2023 22:33:15 -0800 Subject: [PATCH 10/22] chmod 755 byod_train_trainium.sh Signed-off-by: woshiyyya --- release/ray_release/byod/byod_train_trainium.sh | 0 1 file changed, 0 insertions(+), 0 deletions(-) mode change 100644 => 100755 release/ray_release/byod/byod_train_trainium.sh diff --git a/release/ray_release/byod/byod_train_trainium.sh b/release/ray_release/byod/byod_train_trainium.sh old mode 100644 new mode 100755 From de9b620bd6828c5a0374839e4be2a76e10aec825 Mon Sep 17 00:00:00 2001 From: woshiyyya Date: Wed, 15 Nov 2023 11:23:51 -0800 Subject: [PATCH 11/22] update efa configs and vpn cld_id Signed-off-by: woshiyyya --- release/train_tests/trainium/compute_aws.yaml | 51 ++++++++++++++++++- 1 file changed, 50 insertions(+), 1 deletion(-) diff --git a/release/train_tests/trainium/compute_aws.yaml b/release/train_tests/trainium/compute_aws.yaml index 230833269f08..8fe95d128dc4 100644 --- a/release/train_tests/trainium/compute_aws.yaml +++ b/release/train_tests/trainium/compute_aws.yaml @@ -1,4 +1,4 @@ -cloud_id: {{env["ANYSCALE_CLOUD_ID"]}} +cloud_id: cld_ywvPyX4hezrjv8RtrU28wTbr region: us-west-2 max_workers: 1 @@ -20,3 +20,52 @@ aws: Tags: - Key: ttl-hours Value: '24' + NetworkInterfaces: + - Groups: + - sg-02dbb4c6957712ab8 + SubnetId: subnet-06425628175889cdb + DeviceIndex: 0 + InterfaceType: efa + NetworkCardIndex: 0 + - Groups: + - sg-02dbb4c6957712ab8 + SubnetId: subnet-06425628175889cdb + DeviceIndex: 1 + InterfaceType: efa + NetworkCardIndex: 1 + - Groups: + - sg-02dbb4c6957712ab8 + SubnetId: subnet-06425628175889cdb + DeviceIndex: 1 + InterfaceType: efa + NetworkCardIndex: 2 + - Groups: + - sg-02dbb4c6957712ab8 + SubnetId: subnet-06425628175889cdb + DeviceIndex: 1 + InterfaceType: efa + NetworkCardIndex: 3 + - Groups: + - sg-02dbb4c6957712ab8 + SubnetId: subnet-06425628175889cdb + DeviceIndex: 1 + InterfaceType: efa + NetworkCardIndex: 4 + - Groups: + - sg-02dbb4c6957712ab8 + SubnetId: subnet-06425628175889cdb + DeviceIndex: 1 + InterfaceType: efa + NetworkCardIndex: 5 + - Groups: + - sg-02dbb4c6957712ab8 + SubnetId: subnet-06425628175889cdb + DeviceIndex: 1 + InterfaceType: efa + NetworkCardIndex: 6 + - Groups: + - sg-02dbb4c6957712ab8 + SubnetId: subnet-06425628175889cdb + DeviceIndex: 1 + InterfaceType: efa + NetworkCardIndex: 7 From 553653ecad83c0c7d3953fb7f0e791f67fcfb84f Mon Sep 17 00:00:00 2001 From: woshiyyya Date: Wed, 15 Nov 2023 12:08:26 -0800 Subject: [PATCH 12/22] fix lint Signed-off-by: woshiyyya --- release/ray_release/byod/byod_train_trainium.sh | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git 
a/release/ray_release/byod/byod_train_trainium.sh b/release/ray_release/byod/byod_train_trainium.sh index 0cc64fd62b2d..d03687d51124 100755 --- a/release/ray_release/byod/byod_train_trainium.sh +++ b/release/ray_release/byod/byod_train_trainium.sh @@ -1,4 +1,7 @@ -# Configure EFA installer +#!/bin/bash +# This script is used to build an extra layer on top of the base anyscale/ray image +# to run the train_torch_trainium_integration test. + sudo apt-get install curl curl -O https://efa-installer.amazonaws.com/aws-efa-installer-latest.tar.gz wget https://efa-installer.amazonaws.com/aws-efa-installer.key && gpg --import aws-efa-installer.key From c319a75bb10e8d70fdc508cb1ebbb68b3b5c7238 Mon Sep 17 00:00:00 2001 From: woshiyyya Date: Wed, 15 Nov 2023 17:37:10 -0800 Subject: [PATCH 13/22] fix compute config Signed-off-by: woshiyyya --- release/train_tests/trainium/compute_aws.yaml | 52 +++++++++---------- 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/release/train_tests/trainium/compute_aws.yaml b/release/train_tests/trainium/compute_aws.yaml index 8fe95d128dc4..9994671f7569 100644 --- a/release/train_tests/trainium/compute_aws.yaml +++ b/release/train_tests/trainium/compute_aws.yaml @@ -21,51 +21,51 @@ aws: - Key: ttl-hours Value: '24' NetworkInterfaces: - - Groups: - - sg-02dbb4c6957712ab8 - SubnetId: subnet-06425628175889cdb - DeviceIndex: 0 - InterfaceType: efa + - DeviceIndex: 0 NetworkCardIndex: 0 - - Groups: - - sg-02dbb4c6957712ab8 SubnetId: subnet-06425628175889cdb - DeviceIndex: 1 + Groups: + - sg-02dbb4c6957712ab8 InterfaceType: efa + - DeviceIndex: 1 NetworkCardIndex: 1 - - Groups: - - sg-02dbb4c6957712ab8 SubnetId: subnet-06425628175889cdb - DeviceIndex: 1 + Groups: + - sg-02dbb4c6957712ab8 InterfaceType: efa + - DeviceIndex: 1 NetworkCardIndex: 2 - - Groups: - - sg-02dbb4c6957712ab8 SubnetId: subnet-06425628175889cdb - DeviceIndex: 1 + Groups: + - sg-02dbb4c6957712ab8 InterfaceType: efa + - DeviceIndex: 1 NetworkCardIndex: 3 - - Groups: - - sg-02dbb4c6957712ab8 SubnetId: subnet-06425628175889cdb - DeviceIndex: 1 + Groups: + - sg-02dbb4c6957712ab8 InterfaceType: efa + - DeviceIndex: 1 NetworkCardIndex: 4 - - Groups: - - sg-02dbb4c6957712ab8 SubnetId: subnet-06425628175889cdb - DeviceIndex: 1 + Groups: + - sg-02dbb4c6957712ab8 InterfaceType: efa + - DeviceIndex: 1 NetworkCardIndex: 5 - - Groups: - - sg-02dbb4c6957712ab8 SubnetId: subnet-06425628175889cdb - DeviceIndex: 1 + Groups: + - sg-02dbb4c6957712ab8 InterfaceType: efa + - DeviceIndex: 1 NetworkCardIndex: 6 - - Groups: - - sg-02dbb4c6957712ab8 SubnetId: subnet-06425628175889cdb - DeviceIndex: 1 + Groups: + - sg-02dbb4c6957712ab8 InterfaceType: efa + - DeviceIndex: 1 NetworkCardIndex: 7 + SubnetId: subnet-06425628175889cdb + Groups: + - sg-02dbb4c6957712ab8 + InterfaceType: efa From 717b3a85c4b40e91c49619fa99832e1717f87f15 Mon Sep 17 00:00:00 2001 From: woshiyyya Date: Wed, 15 Nov 2023 21:09:34 -0800 Subject: [PATCH 14/22] change region Signed-off-by: woshiyyya --- release/train_tests/trainium/compute_aws.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/release/train_tests/trainium/compute_aws.yaml b/release/train_tests/trainium/compute_aws.yaml index 9994671f7569..1aa0de5ce08e 100644 --- a/release/train_tests/trainium/compute_aws.yaml +++ b/release/train_tests/trainium/compute_aws.yaml @@ -1,5 +1,5 @@ cloud_id: cld_ywvPyX4hezrjv8RtrU28wTbr -region: us-west-2 +region: us-east-2 max_workers: 1 From fc630975865886034d882c442dd79f8b9804aa31 Mon Sep 17 00:00:00 2001 From: woshiyyya Date: Thu, 
16 Nov 2023 15:19:39 -0800 Subject: [PATCH 15/22] try to put Network interfaces into new aws entry Signed-off-by: woshiyyya --- release/train_tests/trainium/compute_aws.yaml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/release/train_tests/trainium/compute_aws.yaml b/release/train_tests/trainium/compute_aws.yaml index 1aa0de5ce08e..5653288c09cd 100644 --- a/release/train_tests/trainium/compute_aws.yaml +++ b/release/train_tests/trainium/compute_aws.yaml @@ -20,6 +20,8 @@ aws: Tags: - Key: ttl-hours Value: '24' + +aws_advanced_configurations_json: NetworkInterfaces: - DeviceIndex: 0 NetworkCardIndex: 0 From 34620f0b3210e279fda4052fde418ab3ba7f50e1 Mon Sep 17 00:00:00 2001 From: woshiyyya Date: Fri, 17 Nov 2023 11:22:14 -0800 Subject: [PATCH 16/22] fix shell lint Signed-off-by: woshiyyya --- release/ray_release/byod/byod_train_trainium.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/release/ray_release/byod/byod_train_trainium.sh b/release/ray_release/byod/byod_train_trainium.sh index d03687d51124..14e12a177c3b 100755 --- a/release/ray_release/byod/byod_train_trainium.sh +++ b/release/ray_release/byod/byod_train_trainium.sh @@ -5,11 +5,11 @@ sudo apt-get install curl curl -O https://efa-installer.amazonaws.com/aws-efa-installer-latest.tar.gz wget https://efa-installer.amazonaws.com/aws-efa-installer.key && gpg --import aws-efa-installer.key -cat aws-efa-installer.key | gpg --fingerprint +gpg --fingerprint < aws-efa-installer.key wget https://efa-installer.amazonaws.com/aws-efa-installer-latest.tar.gz.sig && gpg --verify ./aws-efa-installer-latest.tar.gz.sig tar -xvf aws-efa-installer-latest.tar.gz cd aws-efa-installer && sudo bash efa_installer.sh --yes --skip-kmod -cd +cd || exit sudo rm -rf aws-efa-installer-latest.tar.gz aws-efa-installer # Configure Linux for Neuron repository updates From beece1475277e9adb04dbf49703435959ccedd56 Mon Sep 17 00:00:00 2001 From: Scott Perry <48838323+5cp@users.noreply.github.com> Date: Tue, 9 Jan 2024 17:46:06 -0700 Subject: [PATCH 17/22] Avoid runtime warning when NEURON_RT_VISIBLE_CORES is empty --- python/ray/_private/accelerators/neuron.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/ray/_private/accelerators/neuron.py b/python/ray/_private/accelerators/neuron.py index 7ba9eeb0666b..2f5b0327f9c3 100644 --- a/python/ray/_private/accelerators/neuron.py +++ b/python/ray/_private/accelerators/neuron.py @@ -108,6 +108,9 @@ def set_current_process_visible_accelerator_ids( if os.environ.get(NOSET_AWS_NEURON_RT_VISIBLE_CORES_ENV_VAR): return + if not os.environ.get(NeuronAcceleratorManager.get_visible_accelerator_ids_env_var()): + return + os.environ[ NeuronAcceleratorManager.get_visible_accelerator_ids_env_var() ] = ",".join([str(i) for i in visible_neuron_core_ids]) From 1164d73f1d3ea45e8910bd18a9118a0591cc80b9 Mon Sep 17 00:00:00 2001 From: Scott Perry <48838323+5cp@users.noreply.github.com> Date: Tue, 16 Jan 2024 09:41:30 +0000 Subject: [PATCH 18/22] Add support for neuron_parallel_compile to pre-populate Neuron cache --- python/ray/train/torch/xla/config.py | 55 +++++++++++++++++++++++++++- 1 file changed, 54 insertions(+), 1 deletion(-) diff --git a/python/ray/train/torch/xla/config.py b/python/ray/train/torch/xla/config.py index f43069d2a811..401b902d37bb 100644 --- a/python/ray/train/torch/xla/config.py +++ b/python/ray/train/torch/xla/config.py @@ -1,6 +1,7 @@ import os import uuid from dataclasses import dataclass +import re import ray from ray.train._internal.utils import get_address_and_port @@ -20,6 
+21,8 @@ class TorchXLAConfig(TorchConfig): is supported with xrt runtime. """ + neuron_parallel_compile: bool = False + @property def backend_cls(self): return _TorchAwsNeuronXLABackend @@ -68,6 +71,46 @@ def _setup_xla_torch_process_group(): raise ImportError("torch_xla must be installed to use torch_xla backend.") +# The following env vars enable Neuron graph extraction for parallel compilation +# Note: model outputs are invalid and should be ignored while these env vars are set +def _set_neuron_parallel_compile_env_vars(): + os.environ["NEURON_PARALLEL_COMPILE"] = "1" + os.environ["NEURON_EXTRACT_GRAPHS_ONLY"] = "1" + os.environ["NEURON_FALL_BACK_TO_NULL_NEFF"] = "1" + + +# Compile previously extracted Neuron graphs +def _neuron_compile_extracted_graphs(): + try: + from libneuronxla.neuron_parallel_compile import parallel_compile + from libneuronxla.neuron_cc_cache import CacheUrl + except ImportError: + raise ImportError( + "libneuronxla must be installed to use Neuron parallel compilation." + ) + + # Only 1 worker per node should run parallel_compile() + if os.environ.get("LOCAL_RANK") == "0": + print("Compiling extracted graphs on local rank0 worker") + + parallel_compile_workdir = ( + f"/var/tmp/{os.environ.get('USER','no-user')}_parallel_compile_workdir/" + ) + os.makedirs(parallel_compile_workdir, exist_ok=True) + + # Users can set the cache directory using --cache_dir in NEURON_CC_FLAGS or by + # using NEURON_COMPILE_CACHE_URL. --cache_dir takes precedence. + explicit_cache_dir = None + if neuron_cc_flags := os.environ.get("NEURON_CC_FLAGS"): + if s := re.search(r"--cache_dir[= ](\S+)", neuron_cc_flags): + explicit_cache_dir = s.group(1) + + parallel_compile( + parallel_compile_workdir, + CacheUrl.get_cache_url(explicit_cache_dir), + ) + + class _TorchAwsNeuronXLABackend(Backend): unique_run_id: str = str(uuid.uuid4()) @@ -90,6 +133,11 @@ def set_env_vars(addr, port): # Set the env vars on all workers. worker_group.execute(set_env_vars, addr=master_addr, port=master_port) + # Set up env vars for neuron parallel compilation graph extraction + if backend_config.neuron_parallel_compile: + print("Extracting graphs for Neuron parallel compilation") + worker_group.execute(_set_neuron_parallel_compile_env_vars) + def on_training_start( self, worker_group: WorkerGroup, backend_config: TorchXLAConfig ): @@ -105,6 +153,11 @@ def on_training_start( def on_shutdown(self, worker_group: WorkerGroup, backend_config: TorchXLAConfig): """ Logic ran right after training is finished. - This is a sanity cleanup to kill xrt server. + This is a sanity cleanup to kill xrt server, and to optionally + run neuron parallel graph compilation """ worker_group.execute(_kill_xrt_server) + + # Compile the extracted graphs. This must run at end of training. 
+ if backend_config.neuron_parallel_compile: + worker_group.execute(_neuron_compile_extracted_graphs) From 8b1a4167ecfbed85dcc186ef23e7fdc38e3fc0b6 Mon Sep 17 00:00:00 2001 From: Scott Perry <48838323+5cp@users.noreply.github.com> Date: Tue, 16 Jan 2024 17:00:10 +0000 Subject: [PATCH 19/22] Clean up parallel_compile_workdir before parallel compilation --- python/ray/train/torch/xla/config.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/ray/train/torch/xla/config.py b/python/ray/train/torch/xla/config.py index 401b902d37bb..4e9ff61ba6eb 100644 --- a/python/ray/train/torch/xla/config.py +++ b/python/ray/train/torch/xla/config.py @@ -2,6 +2,7 @@ import uuid from dataclasses import dataclass import re +import shutil import ray from ray.train._internal.utils import get_address_and_port @@ -96,6 +97,8 @@ def _neuron_compile_extracted_graphs(): parallel_compile_workdir = ( f"/var/tmp/{os.environ.get('USER','no-user')}_parallel_compile_workdir/" ) + if os.path.exists(parallel_compile_workdir): + shutil.rmtree(parallel_compile_workdir) os.makedirs(parallel_compile_workdir, exist_ok=True) # Users can set the cache directory using --cache_dir in NEURON_CC_FLAGS or by From 6d995ddc5afc2ce31f9819029febd7970229fea6 Mon Sep 17 00:00:00 2001 From: Scott Perry <48838323+5cp@users.noreply.github.com> Date: Tue, 16 Jan 2024 23:08:04 +0000 Subject: [PATCH 20/22] Minor updates to address PR comments --- python/ray/train/torch/xla/config.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/python/ray/train/torch/xla/config.py b/python/ray/train/torch/xla/config.py index 4e9ff61ba6eb..a01603134f1f 100644 --- a/python/ray/train/torch/xla/config.py +++ b/python/ray/train/torch/xla/config.py @@ -3,6 +3,7 @@ from dataclasses import dataclass import re import shutil +import logging import ray from ray.train._internal.utils import get_address_and_port @@ -11,6 +12,8 @@ from ray.train.torch import TorchConfig from ray.util import PublicAPI +logger = logging.getLogger(__name__) + @PublicAPI(stability="alpha") @dataclass @@ -92,10 +95,10 @@ def _neuron_compile_extracted_graphs(): # Only 1 worker per node should run parallel_compile() if os.environ.get("LOCAL_RANK") == "0": - print("Compiling extracted graphs on local rank0 worker") + logger.info("Compiling extracted graphs on local rank0 worker") parallel_compile_workdir = ( - f"/var/tmp/{os.environ.get('USER','no-user')}_parallel_compile_workdir/" + f"/tmp/{os.environ.get('USER','no-user')}/parallel_compile_workdir/" ) if os.path.exists(parallel_compile_workdir): shutil.rmtree(parallel_compile_workdir) @@ -138,7 +141,7 @@ def set_env_vars(addr, port): # Set up env vars for neuron parallel compilation graph extraction if backend_config.neuron_parallel_compile: - print("Extracting graphs for Neuron parallel compilation") + logger.info("Extracting graphs for Neuron parallel compilation") worker_group.execute(_set_neuron_parallel_compile_env_vars) def on_training_start( From 66264f9ffebf578826aa4d64719bcb9012076dc6 Mon Sep 17 00:00:00 2001 From: woshiyyya Date: Wed, 20 Mar 2024 13:29:31 -0700 Subject: [PATCH 21/22] remove release test Signed-off-by: woshiyyya --- python/ray/_private/accelerators/neuron.py | 4 +- python/ray/train/torch/xla/config.py | 8 +- .../ray_release/byod/byod_train_trainium.sh | 37 ------ release/release_tests.yaml | 19 --- release/train_tests/trainium/compute_aws.yaml | 73 ----------- release/train_tests/trainium/test_trainium.py | 119 ------------------ 6 files changed, 7 insertions(+), 253 deletions(-) delete 
mode 100755 release/ray_release/byod/byod_train_trainium.sh delete mode 100644 release/train_tests/trainium/compute_aws.yaml delete mode 100644 release/train_tests/trainium/test_trainium.py diff --git a/python/ray/_private/accelerators/neuron.py b/python/ray/_private/accelerators/neuron.py index 2f5b0327f9c3..ce0fcba266dd 100644 --- a/python/ray/_private/accelerators/neuron.py +++ b/python/ray/_private/accelerators/neuron.py @@ -108,7 +108,9 @@ def set_current_process_visible_accelerator_ids( if os.environ.get(NOSET_AWS_NEURON_RT_VISIBLE_CORES_ENV_VAR): return - if not os.environ.get(NeuronAcceleratorManager.get_visible_accelerator_ids_env_var()): + if not os.environ.get( + NeuronAcceleratorManager.get_visible_accelerator_ids_env_var() + ): return os.environ[ diff --git a/python/ray/train/torch/xla/config.py b/python/ray/train/torch/xla/config.py index a01603134f1f..e965f9fc269a 100644 --- a/python/ray/train/torch/xla/config.py +++ b/python/ray/train/torch/xla/config.py @@ -1,9 +1,9 @@ +import logging import os -import uuid -from dataclasses import dataclass import re import shutil -import logging +import uuid +from dataclasses import dataclass import ray from ray.train._internal.utils import get_address_and_port @@ -86,8 +86,8 @@ def _set_neuron_parallel_compile_env_vars(): # Compile previously extracted Neuron graphs def _neuron_compile_extracted_graphs(): try: - from libneuronxla.neuron_parallel_compile import parallel_compile from libneuronxla.neuron_cc_cache import CacheUrl + from libneuronxla.neuron_parallel_compile import parallel_compile except ImportError: raise ImportError( "libneuronxla must be installed to use Neuron parallel compilation." diff --git a/release/ray_release/byod/byod_train_trainium.sh b/release/ray_release/byod/byod_train_trainium.sh deleted file mode 100755 index 14e12a177c3b..000000000000 --- a/release/ray_release/byod/byod_train_trainium.sh +++ /dev/null @@ -1,37 +0,0 @@ -#!/bin/bash -# This script is used to build an extra layer on top of the base anyscale/ray image -# to run the train_torch_trainium_integration test. - -sudo apt-get install curl -curl -O https://efa-installer.amazonaws.com/aws-efa-installer-latest.tar.gz -wget https://efa-installer.amazonaws.com/aws-efa-installer.key && gpg --import aws-efa-installer.key -gpg --fingerprint < aws-efa-installer.key -wget https://efa-installer.amazonaws.com/aws-efa-installer-latest.tar.gz.sig && gpg --verify ./aws-efa-installer-latest.tar.gz.sig -tar -xvf aws-efa-installer-latest.tar.gz -cd aws-efa-installer && sudo bash efa_installer.sh --yes --skip-kmod -cd || exit -sudo rm -rf aws-efa-installer-latest.tar.gz aws-efa-installer - -# Configure Linux for Neuron repository updates -. 
/etc/os-release && \ -echo "deb https://apt.repos.neuron.amazonaws.com ${VERSION_CODENAME} main" | \ -sudo tee /etc/apt/sources.list.d/neuron.list > /dev/null - -wget -qO - https://apt.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEURON.PUB | sudo apt-key add - - -# Update OS packages -sudo apt-get update -y - -# Install Neuron Runtime -sudo apt-get install aws-neuronx-collectives=2.* -y -sudo apt-get install aws-neuronx-runtime-lib=2.* -y - -# Install Neuron Tools -sudo apt-get install aws-neuronx-tools=2.* -y - -pip config set global.extra-index-url https://pip.repos.neuron.amazonaws.com -pip install neuronx-cc==2.* torch-neuronx torchvision - -export FI_PROVIDER=efa -export FI_EFA_USE_DEVICE_RDMA=1 -export LD_LIBRARY_PATH=/opt/amazon/efa/lib:$LD_LIBRARY_PATH diff --git a/release/release_tests.yaml b/release/release_tests.yaml index 570f7f104966..6a80e55a1403 100644 --- a/release/release_tests.yaml +++ b/release/release_tests.yaml @@ -2264,25 +2264,6 @@ cluster: cluster_compute: rte_gce_small.yaml - alert: default - -- name: train_torch_trainium_integration - group: Train tests - working_dir: train_tests/trainium - - frequency: nightly - team: ml - - cluster: - byod: - post_build_script: byod_train_trainium.sh - cluster_compute: compute_aws.yaml - - run: - timeout: 1000 - script: python test_trainium.py - - alert: default - name: runtime_env_wheel_urls group: Runtime env tests diff --git a/release/train_tests/trainium/compute_aws.yaml b/release/train_tests/trainium/compute_aws.yaml deleted file mode 100644 index 5653288c09cd..000000000000 --- a/release/train_tests/trainium/compute_aws.yaml +++ /dev/null @@ -1,73 +0,0 @@ -cloud_id: cld_ywvPyX4hezrjv8RtrU28wTbr -region: us-east-2 - -max_workers: 1 - -head_node_type: - name: head_node - instance_type: trn1.32xlarge - -worker_node_types: - - name: worker_node - instance_type: trn1.32xlarge - max_workers: 1 - min_workers: 1 - use_spot: false - -aws: - TagSpecifications: - - ResourceType: "instance" - Tags: - - Key: ttl-hours - Value: '24' - -aws_advanced_configurations_json: - NetworkInterfaces: - - DeviceIndex: 0 - NetworkCardIndex: 0 - SubnetId: subnet-06425628175889cdb - Groups: - - sg-02dbb4c6957712ab8 - InterfaceType: efa - - DeviceIndex: 1 - NetworkCardIndex: 1 - SubnetId: subnet-06425628175889cdb - Groups: - - sg-02dbb4c6957712ab8 - InterfaceType: efa - - DeviceIndex: 1 - NetworkCardIndex: 2 - SubnetId: subnet-06425628175889cdb - Groups: - - sg-02dbb4c6957712ab8 - InterfaceType: efa - - DeviceIndex: 1 - NetworkCardIndex: 3 - SubnetId: subnet-06425628175889cdb - Groups: - - sg-02dbb4c6957712ab8 - InterfaceType: efa - - DeviceIndex: 1 - NetworkCardIndex: 4 - SubnetId: subnet-06425628175889cdb - Groups: - - sg-02dbb4c6957712ab8 - InterfaceType: efa - - DeviceIndex: 1 - NetworkCardIndex: 5 - SubnetId: subnet-06425628175889cdb - Groups: - - sg-02dbb4c6957712ab8 - InterfaceType: efa - - DeviceIndex: 1 - NetworkCardIndex: 6 - SubnetId: subnet-06425628175889cdb - Groups: - - sg-02dbb4c6957712ab8 - InterfaceType: efa - - DeviceIndex: 1 - NetworkCardIndex: 7 - SubnetId: subnet-06425628175889cdb - Groups: - - sg-02dbb4c6957712ab8 - InterfaceType: efa diff --git a/release/train_tests/trainium/test_trainium.py b/release/train_tests/trainium/test_trainium.py deleted file mode 100644 index ae88a9725143..000000000000 --- a/release/train_tests/trainium/test_trainium.py +++ /dev/null @@ -1,119 +0,0 @@ -import os -import time -import torch -from torch import nn - -from ray.train import ScalingConfig -from ray.train.torch import TorchTrainer -from 
ray.train.torch.xla import TorchXLAConfig - -from torchvision.datasets import mnist -from torch.utils.data import DataLoader -from torchvision.transforms import ToTensor - -# XLA imports -import torch_xla.core.xla_model as xm - -# XLA imports for parallel loader and multi-processing -import torch_xla.distributed.parallel_loader as pl -from torch.utils.data.distributed import DistributedSampler - -# Global constants -EPOCHS = 4 -WARMUP_STEPS = 2 -BATCH_SIZE = 32 - - -class MLP(nn.Module): - def __init__(self) -> None: - super().__init__() - self.layers = nn.Sequential( - nn.Linear(28 * 28, 128), - nn.ReLU(), - nn.Linear(128, 10), - nn.ReLU(), - ) - - def forward(self, x): - return self.layers(x) - - -def train_func(): - # Load MNIST train dataset - if not xm.is_master_ordinal(): - xm.rendezvous("dataset_download") - train_dataset = mnist.MNIST( - root="/tmp/MNIST_DATA_train", train=True, download=True, transform=ToTensor() - ) - if xm.is_master_ordinal(): - xm.rendezvous("dataset_download") - - # XLA MP: get world size - world_size = xm.xrt_world_size() - # multi-processing: ensure each worker has same initial weights - torch.manual_seed(0) - - # Move model to device and declare optimizer and loss function - device = "xla" - model = MLP().to(device) - # For multiprocessing, scale up learning rate - optimizer = torch.optim.SGD(model.parameters(), lr=1e-5) - loss_fn = torch.nn.NLLLoss() - - # Prepare data loader - train_sampler = None - if world_size > 1: - train_sampler = DistributedSampler( - train_dataset, num_replicas=world_size, rank=xm.get_ordinal(), shuffle=True - ) - train_loader = DataLoader( - train_dataset, - batch_size=BATCH_SIZE, - sampler=train_sampler, - shuffle=False if train_sampler else True, - ) - # XLA MP: use MpDeviceLoader from torch_xla.distributed - train_device_loader = pl.MpDeviceLoader(train_loader, device) - - # Run the training loop - print("----------Training ---------------") - model.train() - for epoch in range(EPOCHS): - start = time.time() - for idx, (train_x, train_label) in enumerate(train_device_loader): - optimizer.zero_grad() - train_x = train_x.view(train_x.size(0), -1) - output = model(train_x) - loss = loss_fn(output, train_label) - loss.backward() - xm.optimizer_step( - optimizer - ) # XLA MP: performs grad allreduce and optimizer step - if idx < WARMUP_STEPS: # skip warmup iterations - start = time.time() - - # Compute statistics for the last epoch - interval = len(train_device_loader) - WARMUP_STEPS # skip warmup iterations - throughput = interval / (time.time() - start) - print("Train throughput (iter/sec): {}".format(throughput)) - print("Final loss is {:0.4f}".format(loss.detach().to("cpu"))) - - # Save checkpoint for evaluation (xm.save ensures only one process save) - os.makedirs("checkpoints", exist_ok=True) - checkpoint = {"state_dict": model.state_dict()} - xm.save(checkpoint, "checkpoints/checkpoint.pt") - - print("----------End Training ---------------") - - -# trn1.32xlarge -> 32 neuron_cores, 128 CPU -# 2x trn1.32xlarge -trainer = TorchTrainer( - train_loop_per_worker=train_func, - torch_config=TorchXLAConfig(), - scaling_config=ScalingConfig( - num_workers=64, resources_per_worker={"neuron_cores": 1} - ), -) -result = trainer.fit() -print(result) From b3402225de7bc5eecf85dabd4a974c3cdecd5e10 Mon Sep 17 00:00:00 2001 From: woshiyyya Date: Wed, 3 Apr 2024 12:43:36 -0700 Subject: [PATCH 22/22] fix ci test Signed-off-by: woshiyyya --- python/ray/_private/accelerators/neuron.py | 5 ----- 1 file changed, 5 deletions(-) diff --git 
a/python/ray/_private/accelerators/neuron.py b/python/ray/_private/accelerators/neuron.py index ce0fcba266dd..7ba9eeb0666b 100644 --- a/python/ray/_private/accelerators/neuron.py +++ b/python/ray/_private/accelerators/neuron.py @@ -108,11 +108,6 @@ def set_current_process_visible_accelerator_ids( if os.environ.get(NOSET_AWS_NEURON_RT_VISIBLE_CORES_ENV_VAR): return - if not os.environ.get( - NeuronAcceleratorManager.get_visible_accelerator_ids_env_var() - ): - return - os.environ[ NeuronAcceleratorManager.get_visible_accelerator_ids_env_var() ] = ",".join([str(i) for i in visible_neuron_core_ids])
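
---

Usage illustration (not part of the patch series): the sketch below shows how the `neuron_parallel_compile` flag introduced in PATCH 18/22 might be exercised together with the `TorchXLAConfig` backend added in PATCH 01/22, following the `TorchTrainer` pattern from the (since-removed) `test_trainium.py` release test. It assumes a Trn1/Neuron cluster with `torch-neuronx`/`torch_xla` installed and a Ray version that ships `ray.train.torch.xla.TorchXLAConfig`; the tiny model and training loop are illustrative only.

```python
# Sketch: pre-populate the Neuron compilation cache via neuron_parallel_compile.
# Assumes a Neuron (Trn1) environment with torch-neuronx/torch_xla available.
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm

from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer
from ray.train.torch.xla import TorchXLAConfig


def train_func():
    # Minimal loop; while NEURON_PARALLEL_COMPILE is set the backend only
    # extracts graphs, so losses/outputs from this run should be discarded.
    device = xm.xla_device()
    model = nn.Linear(10, 5).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    loss_fn = nn.MSELoss()
    for _ in range(3):
        optimizer.zero_grad()
        loss = loss_fn(
            model(torch.randn(8, 10).to(device)),
            torch.randn(8, 5).to(device),
        )
        loss.backward()
        xm.optimizer_step(optimizer)  # grad all-reduce + optimizer step on XLA


# First run: extract graphs and compile them into the Neuron cache on shutdown.
trainer = TorchTrainer(
    train_loop_per_worker=train_func,
    torch_config=TorchXLAConfig(neuron_parallel_compile=True),
    scaling_config=ScalingConfig(
        num_workers=2, resources_per_worker={"neuron_cores": 1}
    ),
)
trainer.fit()

# Subsequent runs reuse the populated cache; drop the flag for real training.
```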