diff --git a/.ci/docker/requirements.txt b/.ci/docker/requirements.txt index 7aede51dda..00cf2f2103 100644 --- a/.ci/docker/requirements.txt +++ b/.ci/docker/requirements.txt @@ -13,7 +13,7 @@ tqdm==4.66.1 numpy==1.24.4 matplotlib librosa -torch==2.3 +torch==2.4 torchvision torchtext torchdata @@ -28,9 +28,9 @@ tensorboard jinja2==3.1.3 pytorch-lightning torchx -torchrl==0.3.0 -tensordict==0.3.0 -ax-platform +torchrl==0.5.0 +tensordict==0.5.0 +ax-platform>==0.4.0 nbformat>==5.9.2 datasets transformers @@ -68,4 +68,4 @@ pygame==2.1.2 pycocotools semilearn==0.3.2 torchao==0.0.3 -segment_anything==1.0 +segment_anything==1.0 \ No newline at end of file diff --git a/.jenkins/build.sh b/.jenkins/build.sh index dbec2d2552..8830c4259a 100755 --- a/.jenkins/build.sh +++ b/.jenkins/build.sh @@ -22,8 +22,8 @@ sudo apt-get install -y pandoc #Install PyTorch Nightly for test. # Nightly - pip install --pre torch torchvision torchaudio -f https://download.pytorch.org/whl/nightly/cu102/torch_nightly.html # Install 2.4 to merge all 2.4 PRs - uncomment to install nightly binaries (update the version as needed). -pip uninstall -y torch torchvision torchaudio torchtext torchdata -pip3 install torch==2.4.0 torchvision torchaudio --no-cache-dir --index-url https://download.pytorch.org/whl/test/cu124 +# pip uninstall -y torch torchvision torchaudio torchtext torchdata +# pip3 install torch==2.4.0 torchvision torchaudio --no-cache-dir --index-url https://download.pytorch.org/whl/test/cu124 # Install two language tokenizers for Translation with TorchText tutorial python -m spacy download en_core_web_sm diff --git a/_static/img/pinmem/pinmem.png b/_static/img/pinmem/pinmem.png new file mode 100644 index 0000000000..9d84e9d229 Binary files /dev/null and b/_static/img/pinmem/pinmem.png differ diff --git a/_static/img/pinmem/trace_streamed0_pinned0.png b/_static/img/pinmem/trace_streamed0_pinned0.png new file mode 100644 index 0000000000..dedac997b0 Binary files /dev/null and b/_static/img/pinmem/trace_streamed0_pinned0.png differ diff --git a/_static/img/pinmem/trace_streamed0_pinned1.png b/_static/img/pinmem/trace_streamed0_pinned1.png new file mode 100644 index 0000000000..2d5ff462e1 Binary files /dev/null and b/_static/img/pinmem/trace_streamed0_pinned1.png differ diff --git a/_static/img/pinmem/trace_streamed1_pinned0.png b/_static/img/pinmem/trace_streamed1_pinned0.png new file mode 100644 index 0000000000..130182a197 Binary files /dev/null and b/_static/img/pinmem/trace_streamed1_pinned0.png differ diff --git a/_static/img/pinmem/trace_streamed1_pinned1.png b/_static/img/pinmem/trace_streamed1_pinned1.png new file mode 100644 index 0000000000..c596fcdb69 Binary files /dev/null and b/_static/img/pinmem/trace_streamed1_pinned1.png differ diff --git a/advanced_source/coding_ddpg.py b/advanced_source/coding_ddpg.py index 7dd3acf238..c634932971 100644 --- a/advanced_source/coding_ddpg.py +++ b/advanced_source/coding_ddpg.py @@ -182,7 +182,7 @@ # Later, we will see how the target parameters should be updated in TorchRL. # -from tensordict.nn import TensorDictModule +from tensordict.nn import TensorDictModule, TensorDictSequential def _init( @@ -290,12 +290,11 @@ def _loss_actor( ) -> torch.Tensor: td_copy = tensordict.select(*self.actor_in_keys) # Get an action from the actor network: since we made it functional, we need to pass the params - td_copy = self.actor_network(td_copy, params=self.actor_network_params) + with self.actor_network_params.to_module(self.actor_network): + td_copy = self.actor_network(td_copy) # get the value associated with that action - td_copy = self.value_network( - td_copy, - params=self.value_network_params.detach(), - ) + with self.value_network_params.detach().to_module(self.value_network): + td_copy = self.value_network(td_copy) return -td_copy.get("state_action_value") @@ -317,7 +316,8 @@ def _loss_value( td_copy = tensordict.clone() # V(s, a) - self.value_network(td_copy, params=self.value_network_params) + with self.value_network_params.to_module(self.value_network): + self.value_network(td_copy) pred_val = td_copy.get("state_action_value").squeeze(-1) # we manually reconstruct the parameters of the actor-critic, where the first @@ -332,9 +332,8 @@ def _loss_value( batch_size=self.target_actor_network_params.batch_size, device=self.target_actor_network_params.device, ) - target_value = self.value_estimator.value_estimate( - tensordict, target_params=target_params - ).squeeze(-1) + with target_params.to_module(self.actor_critic): + target_value = self.value_estimator.value_estimate(tensordict).squeeze(-1) # Computes the value loss: L2, L1 or smooth L1 depending on `self.loss_function` loss_value = distance_loss(pred_val, target_value, loss_function=self.loss_function) @@ -717,7 +716,7 @@ def get_env_stats(): ActorCriticWrapper, DdpgMlpActor, DdpgMlpQNet, - OrnsteinUhlenbeckProcessWrapper, + OrnsteinUhlenbeckProcessModule, ProbabilisticActor, TanhDelta, ValueOperator, @@ -776,15 +775,18 @@ def make_ddpg_actor( # Exploration # ~~~~~~~~~~~ # -# The policy is wrapped in a :class:`~torchrl.modules.OrnsteinUhlenbeckProcessWrapper` +# The policy is passed into a :class:`~torchrl.modules.OrnsteinUhlenbeckProcessModule` # exploration module, as suggested in the original paper. # Let's define the number of frames before OU noise reaches its minimum value annealing_frames = 1_000_000 -actor_model_explore = OrnsteinUhlenbeckProcessWrapper( +actor_model_explore = TensorDictSequential( actor, - annealing_num_steps=annealing_frames, -).to(device) + OrnsteinUhlenbeckProcessModule( + spec=actor.spec.clone(), + annealing_num_steps=annealing_frames, + ).to(device), +) if device == torch.device("cpu"): actor_model_explore.share_memory() @@ -1168,7 +1170,7 @@ def ceil_div(x, y): ) # update the exploration strategy - actor_model_explore.step(current_frames) + actor_model_explore[1].step(current_frames) collector.shutdown() del collector diff --git a/beginner_source/introyt/tensors_deeper_tutorial.py b/beginner_source/introyt/tensors_deeper_tutorial.py index b5f9dc0bc9..d7293dfe29 100644 --- a/beginner_source/introyt/tensors_deeper_tutorial.py +++ b/beginner_source/introyt/tensors_deeper_tutorial.py @@ -448,17 +448,19 @@ m2 = torch.tensor([[3., 0.], [0., 3.]]) # three times identity matrix print('\nVectors & Matrices:') -print(torch.cross(v2, v1)) # negative of z unit vector (v1 x v2 == -v2 x v1) +print(torch.linalg.cross(v2, v1)) # negative of z unit vector (v1 x v2 == -v2 x v1) print(m1) -m3 = torch.matmul(m1, m2) +m3 = torch.linalg.matmul(m1, m2) print(m3) # 3 times m1 -print(torch.svd(m3)) # singular value decomposition +print(torch.linalg.svd(m3)) # singular value decomposition ################################################################################## # This is a small sample of operations. For more details and the full inventory of # math functions, have a look at the # `documentation `__. +# For more details and the full inventory of linear algebra operations, have a +# look at this `documentation `__. # # Altering Tensors in Place # ~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/beginner_source/knowledge_distillation_tutorial.py b/beginner_source/knowledge_distillation_tutorial.py index 4601352ff0..49ab9a134d 100644 --- a/beginner_source/knowledge_distillation_tutorial.py +++ b/beginner_source/knowledge_distillation_tutorial.py @@ -352,7 +352,7 @@ def train_knowledge_distillation(teacher, student, train_loader, epochs, learnin # Cosine loss minimization run # ---------------------------- # Feel free to play around with the temperature parameter that controls the softness of the softmax function and the loss coefficients. -# In neural networks, it is easy to include to include additional loss functions to the main objectives to achieve goals like better generalization. +# In neural networks, it is easy to include additional loss functions to the main objectives to achieve goals like better generalization. # Let's try including an objective for the student, but now let's focus on their hidden states rather than their output layers. # Our goal is to convey information from the teacher's representation to the student by including a naive loss function, # whose minimization implies that the flattened vectors that are subsequently passed to the classifiers have become more *similar* as the loss decreases. diff --git a/en-wordlist.txt b/en-wordlist.txt index f4b5d6d3bc..56951442db 100644 --- a/en-wordlist.txt +++ b/en-wordlist.txt @@ -1,3 +1,4 @@ + ACL ADI AOT @@ -50,6 +51,7 @@ DDP DDPG DDQN DLRM +DMA DNN DQN DataLoaders @@ -68,6 +70,8 @@ Ecker ExportDB FC FGSM +tensordict +DataLoader's FLAVA FSDP FX @@ -139,6 +143,7 @@ MKLDNN MLP MLPs MNIST +MPS MUC MacBook MacOS @@ -222,6 +227,7 @@ STR SVE SciPy Sequentials +Sharding Sigmoid SoTA Sohn @@ -257,6 +263,7 @@ VLDB VQA VS Code ViT +Volterra WMT WSI WSIs @@ -339,11 +346,11 @@ dataset’s deallocation decompositions decorrelated -devicemesh deserialize deserialized desynchronization deterministically +devicemesh dimensionality dir discontiguous @@ -388,6 +395,7 @@ hessian hessians histoencoder histologically +homonymous hotspot hvp hyperparameter @@ -463,6 +471,7 @@ optimizer's optimizers otsu overfitting +pageable parallelizable parallelization parametrization @@ -527,7 +536,6 @@ runtime runtimes scalable sharded -Sharding softmax sparsified sparsifier @@ -616,4 +624,4 @@ warmstarting warmup webp wsi -wsis +wsis \ No newline at end of file diff --git a/index.rst b/index.rst index 1ddba17d1b..91517834fd 100644 --- a/index.rst +++ b/index.rst @@ -3,11 +3,12 @@ Welcome to PyTorch Tutorials **What's new in PyTorch tutorials?** -* `Using User-Defined Triton Kernels with torch.compile `__ -* `Large Scale Transformer model training with Tensor Parallel (TP) `__ -* `Accelerating BERT with semi-structured (2:4) sparsity `__ -* `torch.export Tutorial with torch.export.Dim `__ -* `Extension points in nn.Module for load_state_dict and tensor subclasses `__ +* `A guide on good usage of non_blocking and pin_memory() in PyTorch `__ +* `Introduction to Distributed Pipeline Parallelism `__ +* `Introduction to Libuv TCPStore Backend `__ +* `Asynchronous Saving with Distributed Checkpoint (DCP) `__ +* `Python Custom Operators `__ +* Updated `Getting Started with DeviceMesh `__ .. raw:: html @@ -93,6 +94,13 @@ Welcome to PyTorch Tutorials :link: intermediate/tensorboard_tutorial.html :tags: Interpretability,Getting-Started,TensorBoard +.. customcarditem:: + :header: Good usage of `non_blocking` and `pin_memory()` in PyTorch + :card_description: A guide on best practices to copy data from CPU to GPU. + :image: _static/img/pinmem.png + :link: intermediate/pinmem_nonblock.html + :tags: Getting-Started + .. Image/Video .. customcarditem:: @@ -969,6 +977,7 @@ Additional Resources beginner/pytorch_with_examples beginner/nn_tutorial intermediate/tensorboard_tutorial + intermediate/pinmem_nonblock .. toctree:: :maxdepth: 2 diff --git a/intermediate_source/FSDP_adavnced_tutorial.rst b/intermediate_source/FSDP_adavnced_tutorial.rst index 5a0cb5376d..f7ee1e7de1 100644 --- a/intermediate_source/FSDP_adavnced_tutorial.rst +++ b/intermediate_source/FSDP_adavnced_tutorial.rst @@ -502,7 +502,7 @@ layer class (holding MHSA and FFN). model = FSDP(model, - fsdp_auto_wrap_policy=t5_auto_wrap_policy) + auto_wrap_policy=t5_auto_wrap_policy) To see the wrapped model, you can easily print the model and visually inspect the sharding and FSDP units as well. diff --git a/intermediate_source/FSDP_tutorial.rst b/intermediate_source/FSDP_tutorial.rst index 034225ec46..9b9845667f 100644 --- a/intermediate_source/FSDP_tutorial.rst +++ b/intermediate_source/FSDP_tutorial.rst @@ -70,7 +70,7 @@ We add the following code snippets to a python script “FSDP_mnist.py”. 1.2 Import necessary packages .. note:: - This tutorial is intended for PyTorch versions 1.12 and later. If you are using an earlier version, replace all instances of `size_based_auto_wrap_policy` with `default_auto_wrap_policy`. + This tutorial is intended for PyTorch versions 1.12 and later. If you are using an earlier version, replace all instances of `size_based_auto_wrap_policy` with `default_auto_wrap_policy` and `fsdp_auto_wrap_policy` with `auto_wrap_policy`. .. code-block:: python @@ -308,7 +308,7 @@ We have recorded cuda events to measure the time of FSDP model specifics. The CU CUDA event elapsed time on training loop 40.67462890625sec Wrapping the model with FSDP, the model will look as follows, we can see the model has been wrapped in one FSDP unit. -Alternatively, we will look at adding the fsdp_auto_wrap_policy next and will discuss the differences. +Alternatively, we will look at adding the auto_wrap_policy next and will discuss the differences. .. code-block:: bash @@ -335,12 +335,12 @@ The following is the peak memory usage from FSDP MNIST training on g4dn.12.xlarg FSDP Peak Memory Usage -Applying *fsdp_auto_wrap_policy* in FSDP otherwise, FSDP will put the entire model in one FSDP unit, which will reduce computation efficiency and memory efficiency. +Applying *auto_wrap_policy* in FSDP otherwise, FSDP will put the entire model in one FSDP unit, which will reduce computation efficiency and memory efficiency. The way it works is that, suppose your model contains 100 Linear layers. If you do FSDP(model), there will only be one FSDP unit which wraps the entire model. In that case, the allgather would collect the full parameters for all 100 linear layers, and hence won't save CUDA memory for parameter sharding. Also, there is only one blocking allgather call for the all 100 linear layers, there will not be communication and computation overlapping between layers. -To avoid that, you can pass in an fsdp_auto_wrap_policy, which will seal the current FSDP unit and start a new one automatically when the specified condition is met (e.g., size limit). +To avoid that, you can pass in an auto_wrap_policy, which will seal the current FSDP unit and start a new one automatically when the specified condition is met (e.g., size limit). In that way you will have multiple FSDP units, and only one FSDP unit needs to collect full parameters at a time. E.g., suppose you have 5 FSDP units, and each wraps 20 linear layers. Then, in the forward, the 1st FSDP unit will allgather parameters for the first 20 linear layers, do computation, discard the parameters and then move on to the next 20 linear layers. So, at any point in time, each rank only materializes parameters/grads for 20 linear layers instead of 100. @@ -358,9 +358,9 @@ Finding an optimal auto wrap policy is challenging, PyTorch will add auto tuning model = Net().to(rank) model = FSDP(model, - fsdp_auto_wrap_policy=my_auto_wrap_policy) + auto_wrap_policy=my_auto_wrap_policy) -Applying the fsdp_auto_wrap_policy, the model would be as follows: +Applying the auto_wrap_policy, the model would be as follows: .. code-block:: bash @@ -411,7 +411,7 @@ In 2.4 we just add it to the FSDP wrapper .. code-block:: python model = FSDP(model, - fsdp_auto_wrap_policy=my_auto_wrap_policy, + auto_wrap_policy=my_auto_wrap_policy, cpu_offload=CPUOffload(offload_params=True)) diff --git a/intermediate_source/TCPStore_libuv_backend.rst b/intermediate_source/TCPStore_libuv_backend.rst index 34037b5be7..1e285eba7c 100644 --- a/intermediate_source/TCPStore_libuv_backend.rst +++ b/intermediate_source/TCPStore_libuv_backend.rst @@ -8,7 +8,8 @@ Introduction to Libuv TCPStore Backend .. grid:: 2 .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn - :class-card: card-prerequisites + :class-card: card-prerequisites + * What is the new TCPStore backend * Compare the new libuv backend against the legacy backend * How to enable to use the legacy backend diff --git a/intermediate_source/ax_multiobjective_nas_tutorial.py b/intermediate_source/ax_multiobjective_nas_tutorial.py index 79b096b9e6..0f1ae21a55 100644 --- a/intermediate_source/ax_multiobjective_nas_tutorial.py +++ b/intermediate_source/ax_multiobjective_nas_tutorial.py @@ -232,21 +232,21 @@ def trainer( # we get the logic to read and parse the TensorBoard logs for free. # -from ax.metrics.tensorboard import TensorboardCurveMetric +from ax.metrics.tensorboard import TensorboardMetric +from tensorboard.backend.event_processing import plugin_event_multiplexer as event_multiplexer - -class MyTensorboardMetric(TensorboardCurveMetric): +class MyTensorboardMetric(TensorboardMetric): # NOTE: We need to tell the new TensorBoard metric how to get the id / # file handle for the TensorBoard logs from a trial. In this case # our convention is to just save a separate file per trial in # the prespecified log dir. - @classmethod - def get_ids_from_trials(cls, trials): - return { - trial.index: Path(log_dir).joinpath(str(trial.index)).as_posix() - for trial in trials - } + def _get_event_multiplexer_for_trial(self, trial): + mul = event_multiplexer.EventMultiplexer(max_reload_threads=20) + mul.AddRunsFromDirectory(Path(log_dir).joinpath(str(trial.index)).as_posix(), None) + mul.Reload() + + return mul # This indicates whether the metric is queryable while the trial is # still running. We don't use this in the current tutorial, but Ax @@ -266,12 +266,12 @@ def is_available_while_running(cls): val_acc = MyTensorboardMetric( name="val_acc", - curve_name="val_acc", + tag="val_acc", lower_is_better=False, ) model_num_params = MyTensorboardMetric( name="num_params", - curve_name="num_params", + tag="num_params", lower_is_better=True, ) diff --git a/intermediate_source/dqn_with_rnn_tutorial.py b/intermediate_source/dqn_with_rnn_tutorial.py index 991a0ff8bd..6ea0955939 100644 --- a/intermediate_source/dqn_with_rnn_tutorial.py +++ b/intermediate_source/dqn_with_rnn_tutorial.py @@ -298,7 +298,7 @@ # either by passing a string or an action-spec. This allows us to use # Categorical (sometimes called "sparse") encoding or the one-hot version of it. # -qval = QValueModule(action_space=env.action_spec) +qval = QValueModule(spec=env.action_spec) ###################################################################### # .. note:: diff --git a/intermediate_source/pinmem_nonblock.py b/intermediate_source/pinmem_nonblock.py new file mode 100644 index 0000000000..fa69507a0e --- /dev/null +++ b/intermediate_source/pinmem_nonblock.py @@ -0,0 +1,728 @@ +# -*- coding: utf-8 -*- +""" +A guide on good usage of ``non_blocking`` and ``pin_memory()`` in PyTorch +========================================================================= + +**Author**: `Vincent Moens `_ + +Introduction +------------ + +Transferring data from the CPU to the GPU is fundamental in many PyTorch applications. +It's crucial for users to understand the most effective tools and options available for moving data between devices. +This tutorial examines two key methods for device-to-device data transfer in PyTorch: +:meth:`~torch.Tensor.pin_memory` and :meth:`~torch.Tensor.to` with the ``non_blocking=True`` option. + +What you will learn +~~~~~~~~~~~~~~~~~~~ + +Optimizing the transfer of tensors from the CPU to the GPU can be achieved through asynchronous transfers and memory +pinning. However, there are important considerations: + +- Using ``tensor.pin_memory().to(device, non_blocking=True)`` can be up to twice as slow as a straightforward ``tensor.to(device)``. +- Generally, ``tensor.to(device, non_blocking=True)`` is an effective choice for enhancing transfer speed. +- While ``cpu_tensor.to("cuda", non_blocking=True).mean()`` executes correctly, attempting + ``cuda_tensor.to("cpu", non_blocking=True).mean()`` will result in erroneous outputs. + +Preamble +~~~~~~~~ + +The performance reported in this tutorial are conditioned on the system used to build the tutorial. +Although the conclusions are applicable across different systems, the specific observations may vary slightly +depending on the hardware available, especially on older hardware. +The primary objective of this tutorial is to offer a theoretical framework for understanding CPU to GPU data transfers. +However, any design decisions should be tailored to individual cases and guided by benchmarked throughput measurements, +as well as the specific requirements of the task at hand. + +""" + +import torch + +assert torch.cuda.is_available(), "A cuda device is required to run this tutorial" + + +###################################################################### +# +# This tutorial requires tensordict to be installed. If you don't have tensordict in your environment yet, install it +# by running the following command in a separate cell: +# +# .. code-block:: bash +# +# # Install tensordict with the following command +# !pip3 install tensordict +# +# We start by outlining the theory surrounding these concepts, and then move to concrete test examples of the features. +# +# +# Background +# ---------- +# +# .. _pinned_memory_background: +# +# Memory management basics +# ~~~~~~~~~~~~~~~~~~~~~~~~ +# +# .. _pinned_memory_memory: +# +# When one creates a CPU tensor in PyTorch, the content of this tensor needs to be placed +# in memory. The memory we talk about here is a rather complex concept worth looking at carefully. +# We distinguish two types of memory that are handled by the Memory Management Unit: the RAM (for simplicity) +# and the swap space on disk (which may or may not be the hard drive). Together, the available space in disk and RAM (physical memory) +# make up the virtual memory, which is an abstraction of the total resources available. +# In short, the virtual memory makes it so that the available space is larger than what can be found on RAM in isolation +# and creates the illusion that the main memory is larger than it actually is. +# +# In normal circumstances, a regular CPU tensor is pageable which means that it is divided in blocks called pages that +# can live anywhere in the virtual memory (both in RAM or on disk). As mentioned earlier, this has the advantage that +# the memory seems larger than what the main memory actually is. +# +# Typically, when a program accesses a page that is not in RAM, a "page fault" occurs and the operating system (OS) then brings +# back this page into RAM ("swap in" or "page in"). +# In turn, the OS may have to swap out (or "page out") another page to make room for the new page. +# +# In contrast to pageable memory, a pinned (or page-locked or non-pageable) memory is a type of memory that cannot +# be swapped out to disk. +# It allows for faster and more predictable access times, but has the downside that it is more limited than the +# pageable memory (aka the main memory). +# +# .. figure:: /_static/img/pinmem/pinmem.png +# :alt: +# +# CUDA and (non-)pageable memory +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# .. _pinned_memory_cuda_pageable_memory: +# +# To understand how CUDA copies a tensor from CPU to CUDA, let's consider the two scenarios above: +# +# - If the memory is page-locked, the device can access the memory directly in the main memory. The memory addresses are well +# defined and functions that need to read these data can be significantly accelerated. +# - If the memory is pageable, all the pages will have to be brought to the main memory before being sent to the GPU. +# This operation may take time and is less predictable than when executed on page-locked tensors. +# +# More precisely, when CUDA sends pageable data from CPU to GPU, it must first create a page-locked copy of that data +# before making the transfer. +# +# Asynchronous vs. Synchronous Operations with ``non_blocking=True`` (CUDA ``cudaMemcpyAsync``) +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# .. _pinned_memory_async_sync: +# +# When executing a copy from a host (e.g., CPU) to a device (e.g., GPU), the CUDA toolkit offers modalities to do these +# operations synchronously or asynchronously with respect to the host. +# +# In practice, when calling :meth:`~torch.Tensor.to`, PyTorch always makes a call to +# `cudaMemcpyAsync `_. +# If ``non_blocking=False`` (default), a ``cudaStreamSynchronize`` will be called after each and every ``cudaMemcpyAsync``, making +# the call to :meth:`~torch.Tensor.to` blocking in the main thread. +# If ``non_blocking=True``, no synchronization is triggered, and the main thread on the host is not blocked. +# Therefore, from the host perspective, multiple tensors can be sent to the device simultaneously, +# as the thread does not need to wait for one transfer to be completed to initiate the other. +# +# .. note:: In general, the transfer is blocking on the device side (even if it isn't on the host side): +# the copy on the device cannot occur while another operation is being executed. +# However, in some advanced scenarios, a copy and a kernel execution can be done simultaneously on the GPU side. +# As the following example will show, three requirements must be met to enable this: +# +# 1. The device must have at least one free DMA (Direct Memory Access) engine. Modern GPU architectures such as Volterra, +# Tesla, or H100 devices have more than one DMA engine. +# +# 2. The transfer must be done on a separate, non-default cuda stream. In PyTorch, cuda streams can be handles using +# :class:`~torch.cuda.Stream`. +# +# 3. The source data must be in pinned memory. +# +# We demonstrate this by running profiles on the following script. +# + +import contextlib + +from torch.cuda import Stream + + +s = Stream() + +torch.manual_seed(42) +t1_cpu_pinned = torch.randn(1024**2 * 5, pin_memory=True) +t2_cpu_paged = torch.randn(1024**2 * 5, pin_memory=False) +t3_cuda = torch.randn(1024**2 * 5, device="cuda:0") + +assert torch.cuda.is_available() +device = torch.device("cuda", torch.cuda.current_device()) + + +# The function we want to profile +def inner(pinned: bool, streamed: bool): + with torch.cuda.stream(s) if streamed else contextlib.nullcontext(): + if pinned: + t1_cuda = t1_cpu_pinned.to(device, non_blocking=True) + else: + t2_cuda = t2_cpu_paged.to(device, non_blocking=True) + t_star_cuda_h2d_event = s.record_event() + # This operation can be executed during the CPU to GPU copy if and only if the tensor is pinned and the copy is + # done in the other stream + t3_cuda_mul = t3_cuda * t3_cuda * t3_cuda + t3_cuda_h2d_event = torch.cuda.current_stream().record_event() + t_star_cuda_h2d_event.synchronize() + t3_cuda_h2d_event.synchronize() + + +# Our profiler: profiles the `inner` function and stores the results in a .json file +def benchmark_with_profiler( + pinned, + streamed, +) -> None: + torch._C._profiler._set_cuda_sync_enabled_val(True) + wait, warmup, active = 1, 1, 2 + num_steps = wait + warmup + active + rank = 0 + with torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + schedule=torch.profiler.schedule( + wait=wait, warmup=warmup, active=active, repeat=1, skip_first=1 + ), + ) as prof: + for step_idx in range(1, num_steps + 1): + inner(streamed=streamed, pinned=pinned) + if rank is None or rank == 0: + prof.step() + prof.export_chrome_trace(f"trace_streamed{int(streamed)}_pinned{int(pinned)}.json") + + +###################################################################### +# Loading these profile traces in chrome (``chrome://tracing``) shows the following results: first, let's see +# what happens if both the arithmetic operation on ``t3_cuda`` is executed after the pageable tensor is sent to GPU +# in the main stream: +# + +benchmark_with_profiler(streamed=False, pinned=False) + +###################################################################### +# .. figure:: /_static/img/pinmem/trace_streamed0_pinned0.png +# :alt: +# +# Using a pinned tensor doesn't change the trace much, both operations are still executed consecutively: + +benchmark_with_profiler(streamed=False, pinned=True) + +###################################################################### +# +# .. figure:: /_static/img/pinmem/trace_streamed0_pinned1.png +# :alt: +# +# Sending a pageable tensor to GPU on a separate stream is also a blocking operation: + +benchmark_with_profiler(streamed=True, pinned=False) + +###################################################################### +# +# .. figure:: /_static/img/pinmem/trace_streamed1_pinned0.png +# :alt: +# +# Only pinned tensors copies to GPU on a separate stream overlap with another cuda kernel executed on +# the main stream: + +benchmark_with_profiler(streamed=True, pinned=True) + +###################################################################### +# +# .. figure:: /_static/img/pinmem/trace_streamed1_pinned1.png +# :alt: +# +# A PyTorch perspective +# --------------------- +# +# .. _pinned_memory_pt_perspective: +# +# ``pin_memory()`` +# ~~~~~~~~~~~~~~~~ +# +# .. _pinned_memory_pinned: +# +# PyTorch offers the possibility to create and send tensors to page-locked memory through the +# :meth:`~torch.Tensor.pin_memory` method and constructor arguments. +# CPU tensors on a machine where CUDA is initialized can be cast to pinned memory through the :meth:`~torch.Tensor.pin_memory` +# method. Importantly, ``pin_memory`` is blocking on the main thread of the host: it will wait for the tensor to be copied to +# page-locked memory before executing the next operation. +# New tensors can be directly created in pinned memory with functions like :func:`~torch.zeros`, :func:`~torch.ones` and other +# constructors. +# +# Let us check the speed of pinning memory and sending tensors to CUDA: + + +import torch +import gc +from torch.utils.benchmark import Timer +import matplotlib.pyplot as plt + + +def timer(cmd): + median = ( + Timer(cmd, globals=globals()) + .adaptive_autorange(min_run_time=1.0, max_run_time=20.0) + .median + * 1000 + ) + print(f"{cmd}: {median: 4.4f} ms") + return median + + +# A tensor in pageable memory +pageable_tensor = torch.randn(1_000_000) + +# A tensor in page-locked (pinned) memory +pinned_tensor = torch.randn(1_000_000, pin_memory=True) + +# Runtimes: +pageable_to_device = timer("pageable_tensor.to('cuda:0')") +pinned_to_device = timer("pinned_tensor.to('cuda:0')") +pin_mem = timer("pageable_tensor.pin_memory()") +pin_mem_to_device = timer("pageable_tensor.pin_memory().to('cuda:0')") + +# Ratios: +r1 = pinned_to_device / pageable_to_device +r2 = pin_mem_to_device / pageable_to_device + +# Create a figure with the results +fig, ax = plt.subplots() + +xlabels = [0, 1, 2] +bar_labels = [ + "pageable_tensor.to(device) (1x)", + f"pinned_tensor.to(device) ({r1:4.2f}x)", + f"pageable_tensor.pin_memory().to(device) ({r2:4.2f}x)" + f"\npin_memory()={100*pin_mem/pin_mem_to_device:.2f}% of runtime.", +] +values = [pageable_to_device, pinned_to_device, pin_mem_to_device] +colors = ["tab:blue", "tab:red", "tab:orange"] +ax.bar(xlabels, values, label=bar_labels, color=colors) + +ax.set_ylabel("Runtime (ms)") +ax.set_title("Device casting runtime (pin-memory)") +ax.set_xticks([]) +ax.legend() + +plt.show() + +# Clear tensors +del pageable_tensor, pinned_tensor +_ = gc.collect() + +###################################################################### +# +# We can observe that casting a pinned-memory tensor to GPU is indeed much faster than a pageable tensor, because under +# the hood, a pageable tensor must be copied to pinned memory before being sent to GPU. +# +# However, contrary to a somewhat common belief, calling :meth:`~torch.Tensor.pin_memory()` on a pageable tensor before +# casting it to GPU should not bring any significant speed-up, on the contrary this call is usually slower than just +# executing the transfer. This makes sense, since we're actually asking Python to execute an operation that CUDA will +# perform anyway before copying the data from host to device. +# +# .. note:: The PyTorch implementation of +# `pin_memory `_ +# which relies on creating a brand new storage in pinned memory through `cudaHostAlloc `_ +# could be, in rare cases, faster than transitioning data in chunks as ``cudaMemcpy`` does. +# Here too, the observation may vary depending on the available hardware, the size of the tensors being sent or +# the amount of available RAM. +# +# ``non_blocking=True`` +# ~~~~~~~~~~~~~~~~~~~~~ +# +# .. _pinned_memory_non_blocking: +# +# As mentioned earlier, many PyTorch operations have the option of being executed asynchronously with respect to the host +# through the ``non_blocking`` argument. +# +# Here, to account accurately of the benefits of using ``non_blocking``, we will design a slightly more complex +# experiment since we want to assess how fast it is to send multiple tensors to GPU with and without calling +# ``non_blocking``. +# + + +# A simple loop that copies all tensors to cuda +def copy_to_device(*tensors): + result = [] + for tensor in tensors: + result.append(tensor.to("cuda:0")) + return result + + +# A loop that copies all tensors to cuda asynchronously +def copy_to_device_nonblocking(*tensors): + result = [] + for tensor in tensors: + result.append(tensor.to("cuda:0", non_blocking=True)) + # We need to synchronize + torch.cuda.synchronize() + return result + + +# Create a list of tensors +tensors = [torch.randn(1000) for _ in range(1000)] +to_device = timer("copy_to_device(*tensors)") +to_device_nonblocking = timer("copy_to_device_nonblocking(*tensors)") + +# Ratio +r1 = to_device_nonblocking / to_device + +# Plot the results +fig, ax = plt.subplots() + +xlabels = [0, 1] +bar_labels = [f"to(device) (1x)", f"to(device, non_blocking=True) ({r1:4.2f}x)"] +colors = ["tab:blue", "tab:red"] +values = [to_device, to_device_nonblocking] + +ax.bar(xlabels, values, label=bar_labels, color=colors) + +ax.set_ylabel("Runtime (ms)") +ax.set_title("Device casting runtime (non-blocking)") +ax.set_xticks([]) +ax.legend() + +plt.show() + + +###################################################################### +# To get a better sense of what is happening here, let us profile these two functions: + + +from torch.profiler import profile, ProfilerActivity + + +def profile_mem(cmd): + with profile(activities=[ProfilerActivity.CPU]) as prof: + exec(cmd) + print(cmd) + print(prof.key_averages().table(row_limit=10)) + + +###################################################################### +# Let's see the call stack with a regular ``to(device)`` first: +# + +print("Call to `to(device)`", profile_mem("copy_to_device(*tensors)")) + +###################################################################### +# and now the ``non_blocking`` version: +# + +print( + "Call to `to(device, non_blocking=True)`", + profile_mem("copy_to_device_nonblocking(*tensors)"), +) + + +###################################################################### +# The results are without any doubt better when using ``non_blocking=True``, as all transfers are initiated simultaneously +# on the host side and only one synchronization is done. +# +# The benefit will vary depending on the number and the size of the tensors as well as depending on the hardware being +# used. +# +# .. note:: Interestingly, the blocking ``to("cuda")`` actually performs the same asynchronous device casting operation +# (``cudaMemcpyAsync``) as the one with ``non_blocking=True`` with a synchronization point after each copy. +# +# Synergies +# ~~~~~~~~~ +# +# .. _pinned_memory_synergies: +# +# Now that we have made the point that data transfer of tensors already in pinned memory to GPU is faster than from +# pageable memory, and that we know that doing these transfers asynchronously is also faster than synchronously, we can +# benchmark combinations of these approaches. First, let's write a couple of new functions that will call ``pin_memory`` +# and ``to(device)`` on each tensor: +# + + +def pin_copy_to_device(*tensors): + result = [] + for tensor in tensors: + result.append(tensor.pin_memory().to("cuda:0")) + return result + + +def pin_copy_to_device_nonblocking(*tensors): + result = [] + for tensor in tensors: + result.append(tensor.pin_memory().to("cuda:0", non_blocking=True)) + # We need to synchronize + torch.cuda.synchronize() + return result + + +###################################################################### +# The benefits of using :meth:`~torch.Tensor.pin_memory` are more pronounced for +# somewhat large batches of large tensors: +# + +tensors = [torch.randn(1_000_000) for _ in range(1000)] +page_copy = timer("copy_to_device(*tensors)") +page_copy_nb = timer("copy_to_device_nonblocking(*tensors)") + +tensors_pinned = [torch.randn(1_000_000, pin_memory=True) for _ in range(1000)] +pinned_copy = timer("copy_to_device(*tensors_pinned)") +pinned_copy_nb = timer("copy_to_device_nonblocking(*tensors_pinned)") + +pin_and_copy = timer("pin_copy_to_device(*tensors)") +pin_and_copy_nb = timer("pin_copy_to_device_nonblocking(*tensors)") + +# Plot +strategies = ("pageable copy", "pinned copy", "pin and copy") +blocking = { + "blocking": [page_copy, pinned_copy, pin_and_copy], + "non-blocking": [page_copy_nb, pinned_copy_nb, pin_and_copy_nb], +} + +x = torch.arange(3) +width = 0.25 +multiplier = 0 + + +fig, ax = plt.subplots(layout="constrained") + +for attribute, runtimes in blocking.items(): + offset = width * multiplier + rects = ax.bar(x + offset, runtimes, width, label=attribute) + ax.bar_label(rects, padding=3, fmt="%.2f") + multiplier += 1 + +# Add some text for labels, title and custom x-axis tick labels, etc. +ax.set_ylabel("Runtime (ms)") +ax.set_title("Runtime (pin-mem and non-blocking)") +ax.set_xticks([0, 1, 2]) +ax.set_xticklabels(strategies) +plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor") +ax.legend(loc="upper left", ncols=3) + +plt.show() + +del tensors, tensors_pinned +_ = gc.collect() + + +###################################################################### +# Other copy directions (GPU -> CPU, CPU -> MPS) +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# .. _pinned_memory_other_direction: +# +# Until now, we have operated under the assumption that asynchronous copies from the CPU to the GPU are safe. +# This is generally true because CUDA automatically handles synchronization to ensure that the data being accessed is +# valid at read time. +# However, this guarantee does not extend to transfers in the opposite direction, from GPU to CPU. +# Without explicit synchronization, these transfers offer no assurance that the copy will be complete at the time of +# data access. Consequently, the data on the host might be incomplete or incorrect, effectively rendering it garbage: +# + + +tensor = ( + torch.arange(1, 1_000_000, dtype=torch.double, device="cuda") + .expand(100, 999999) + .clone() +) +torch.testing.assert_close( + tensor.mean(), torch.tensor(500_000, dtype=torch.double, device="cuda") +), tensor.mean() +try: + i = -1 + for i in range(100): + cpu_tensor = tensor.to("cpu", non_blocking=True) + torch.testing.assert_close( + cpu_tensor.mean(), torch.tensor(500_000, dtype=torch.double) + ) + print("No test failed with non_blocking") +except AssertionError: + print(f"{i}th test failed with non_blocking. Skipping remaining tests") +try: + i = -1 + for i in range(100): + cpu_tensor = tensor.to("cpu", non_blocking=True) + torch.cuda.synchronize() + torch.testing.assert_close( + cpu_tensor.mean(), torch.tensor(500_000, dtype=torch.double) + ) + print("No test failed with synchronize") +except AssertionError: + print(f"One test failed with synchronize: {i}th assertion!") + + +###################################################################### +# The same considerations apply to copies from the CPU to non-CUDA devices, such as MPS. +# Generally, asynchronous copies to a device are safe without explicit synchronization only when the target is a +# CUDA-enabled device. +# +# In summary, copying data from CPU to GPU is safe when using ``non_blocking=True``, but for any other direction, +# ``non_blocking=True`` can still be used but the user must make sure that a device synchronization is executed before +# the data is accessed. +# +# Practical recommendations +# ------------------------- +# +# .. _pinned_memory_recommendations: +# +# We can now wrap up some early recommendations based on our observations: +# +# In general, ``non_blocking=True`` will provide good throughput, regardless of whether the original tensor is or +# isn't in pinned memory. +# If the tensor is already in pinned memory, the transfer can be accelerated, but sending it to +# pin memory manually from python main thread is a blocking operation on the host, and hence will annihilate much of +# the benefit of using ``non_blocking=True`` (as CUDA does the `pin_memory` transfer anyway). +# +# One might now legitimately ask what use there is for the :meth:`~torch.Tensor.pin_memory` method. +# In the following section, we will explore further how this can be used to accelerate the data transfer even more. +# +# Additional considerations +# ------------------------- +# +# .. _pinned_memory_considerations: +# +# PyTorch notoriously provides a :class:`~torch.utils.data.DataLoader` class whose constructor accepts a +# ``pin_memory`` argument. +# Considering our previous discussion on ``pin_memory``, you might wonder how the ``DataLoader`` manages to +# accelerate data transfers if memory pinning is inherently blocking. +# +# The key lies in the DataLoader's use of a separate thread to handle the transfer of data from pageable to pinned +# memory, thus preventing any blockage in the main thread. +# +# To illustrate this, we will use the TensorDict primitive from the homonymous library. +# When invoking :meth:`~tensordict.TensorDict.to`, the default behavior is to send tensors to the device asynchronously, +# followed by a single call to ``torch.device.synchronize()`` afterwards. +# +# Additionally, ``TensorDict.to()`` includes a ``non_blocking_pin`` option which initiates multiple threads to execute +# ``pin_memory()`` before proceeding with to ``to(device)``. +# This approach can further accelerate data transfers, as demonstrated in the following example. +# +# + +from tensordict import TensorDict +import torch +from torch.utils.benchmark import Timer +import matplotlib.pyplot as plt + +# Create the dataset +td = TensorDict({str(i): torch.randn(1_000_000) for i in range(1000)}) + +# Runtimes +copy_blocking = timer("td.to('cuda:0', non_blocking=False)") +copy_non_blocking = timer("td.to('cuda:0')") +copy_pin_nb = timer("td.to('cuda:0', non_blocking_pin=True, num_threads=0)") +copy_pin_multithread_nb = timer("td.to('cuda:0', non_blocking_pin=True, num_threads=4)") + +# Rations +r1 = copy_non_blocking / copy_blocking +r2 = copy_pin_nb / copy_blocking +r3 = copy_pin_multithread_nb / copy_blocking + +# Figure +fig, ax = plt.subplots() + +xlabels = [0, 1, 2, 3] +bar_labels = [ + "Blocking copy (1x)", + f"Non-blocking copy ({r1:4.2f}x)", + f"Blocking pin, non-blocking copy ({r2:4.2f}x)", + f"Non-blocking pin, non-blocking copy ({r3:4.2f}x)", +] +values = [copy_blocking, copy_non_blocking, copy_pin_nb, copy_pin_multithread_nb] +colors = ["tab:blue", "tab:red", "tab:orange", "tab:green"] + +ax.bar(xlabels, values, label=bar_labels, color=colors) + +ax.set_ylabel("Runtime (ms)") +ax.set_title("Device casting runtime") +ax.set_xticks([]) +ax.legend() + +plt.show() + +###################################################################### +# In this example, we are transferring many large tensors from the CPU to the GPU. +# This scenario is ideal for utilizing multithreaded ``pin_memory()``, which can significantly enhance performance. +# However, if the tensors are small, the overhead associated with multithreading may outweigh the benefits. +# Similarly, if there are only a few tensors, the advantages of pinning tensors on separate threads become limited. +# +# As an additional note, while it might seem advantageous to create permanent buffers in pinned memory to shuttle +# tensors from pageable memory before transferring them to the GPU, this strategy does not necessarily expedite +# computation. The inherent bottleneck caused by copying data into pinned memory remains a limiting factor. +# +# Moreover, transferring data that resides on disk (whether in shared memory or files) to the GPU typically requires an +# intermediate step of copying the data into pinned memory (located in RAM). +# Utilizing non_blocking for large data transfers in this context can significantly increase RAM consumption, +# potentially leading to adverse effects. +# +# In practice, there is no one-size-fits-all solution. +# The effectiveness of using multithreaded ``pin_memory`` combined with ``non_blocking`` transfers depends on a +# variety of factors, including the specific system, operating system, hardware, and the nature of the tasks +# being executed. +# Here is a list of factors to check when trying to speed-up data transfers between CPU and GPU, or comparing +# throughput's across scenarios: +# +# - **Number of available cores** +# +# How many CPU cores are available? Is the system shared with other users or processes that might compete for +# resources? +# +# - **Core utilization** +# +# Are the CPU cores heavily utilized by other processes? Does the application perform other CPU-intensive tasks +# concurrently with data transfers? +# +# - **Memory utilization** +# +# How much pageable and page-locked memory is currently being used? Is there sufficient free memory to allocate +# additional pinned memory without affecting system performance? Remember that nothing comes for free, for instance +# ``pin_memory`` will consume RAM and may impact other tasks. +# +# - **CUDA Device Capabilities** +# +# Does the GPU support multiple DMA engines for concurrent data transfers? What are the specific capabilities and +# limitations of the CUDA device being used? +# +# - **Number of tensors to be sent** +# +# How many tensors are transferred in a typical operation? +# +# - **Size of the tensors to be sent** +# +# What is the size of the tensors being transferred? A few large tensors or many small tensors may not benefit from +# the same transfer program. +# +# - **System Architecture** +# +# How is the system's architecture influencing data transfer speeds (for example, bus speeds, network latency)? +# +# Additionally, allocating a large number of tensors or sizable tensors in pinned memory can monopolize a substantial +# portion of RAM. +# This reduces the available memory for other critical operations, such as paging, which can negatively impact the +# overall performance of an algorithm. +# +# Conclusion +# ---------- +# +# .. _pinned_memory_conclusion: +# +# Throughout this tutorial, we have explored several critical factors that influence transfer speeds and memory +# management when sending tensors from the host to the device. We've learned that using ``non_blocking=True`` generally +# accelerates data transfers, and that :meth:`~torch.Tensor.pin_memory` can also enhance performance if implemented +# correctly. However, these techniques require careful design and calibration to be effective. +# +# Remember that profiling your code and keeping an eye on the memory consumption are essential to optimize resource +# usage and achieve the best possible performance. +# +# Additional resources +# -------------------- +# +# .. _pinned_memory_resources: +# +# If you are dealing with issues with memory copies when using CUDA devices or want to learn more about +# what was discussed in this tutorial, check the following references: +# +# - `CUDA toolkit memory management doc `_; +# - `CUDA pin-memory note `_; +# - `How to Optimize Data Transfers in CUDA C/C++ `_; +# - `tensordict doc `_ and `repo `_. +# diff --git a/intermediate_source/pipelining_tutorial.rst b/intermediate_source/pipelining_tutorial.rst index 3d6533cef2..0c6fc79846 100644 --- a/intermediate_source/pipelining_tutorial.rst +++ b/intermediate_source/pipelining_tutorial.rst @@ -12,6 +12,7 @@ APIs. .. grid:: 2 .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn + :class-card: card-prerequisites * How to use ``torch.distributed.pipelining`` APIs * How to apply pipeline parallelism to a transformer model @@ -19,6 +20,7 @@ APIs. .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites + :class-card: card-prerequisites * Familiarity with `basic distributed training `__ in PyTorch diff --git a/prototype_source/pt2e_quantizer.rst b/prototype_source/pt2e_quantizer.rst index df666b1f6a..be6d6949ed 100644 --- a/prototype_source/pt2e_quantizer.rst +++ b/prototype_source/pt2e_quantizer.rst @@ -8,7 +8,7 @@ Prerequisites: Required: -- `Torchdynamo concepts in PyTorch `__ +- `Torchdynamo concepts in PyTorch `__ - `Quantization concepts in PyTorch `__ diff --git a/recipes_source/distributed_checkpoint_recipe.rst b/recipes_source/distributed_checkpoint_recipe.rst index 32666e5a3a..2467db878e 100644 --- a/recipes_source/distributed_checkpoint_recipe.rst +++ b/recipes_source/distributed_checkpoint_recipe.rst @@ -333,6 +333,7 @@ Since this can be an issue when users wish to share models with users used to th to their applications. For this case, we provide the ``format_utils`` module in ``torch.distributed.checkpoint.format_utils``. A command line utility is provided for the users convenience, which follows the following format: + .. code-block:: bash python -m torch.distributed.checkpoint.format_utils -m @@ -341,6 +342,7 @@ In the command above, ``mode`` is one of ``torch_to_dcp``` or ``dcp_to_torch``. Alternatively, methods are also provided for users who may wish to convert checkpoints directly. + .. code-block:: python import os diff --git a/recipes_source/intel_extension_for_pytorch.rst b/recipes_source/intel_extension_for_pytorch.rst index 03416102d2..7632ee73f3 100644 --- a/recipes_source/intel_extension_for_pytorch.rst +++ b/recipes_source/intel_extension_for_pytorch.rst @@ -12,8 +12,8 @@ easy GPU acceleration for Intel discrete GPUs with PyTorch*. Intel® Extension for PyTorch* has been released as an open–source project at `Github `_. -- Source code for CPU is available at `master branch `_. -- Source code for GPU is available at `xpu-master branch `_. +- Source code for CPU is available at `main branch `_. +- Source code for GPU is available at `xpu-main branch `_. Features -------- diff --git a/recipes_source/recipes_index.rst b/recipes_source/recipes_index.rst index c9aa2947a7..8959ea98a3 100644 --- a/recipes_source/recipes_index.rst +++ b/recipes_source/recipes_index.rst @@ -275,10 +275,10 @@ Recipes are bite-sized, actionable examples of how to use specific PyTorch featu :tags: Model-Optimization .. customcarditem:: - :header: CPU launcher script for optimal performance on Intel® Xeon - :card_description: How to use launcher script for optimal runtime configurations on Intel® Xeon CPUs. + :header: Optimizing CPU Performance on Intel® Xeon® with run_cpu Script + :card_description: How to use run_cpu script for optimal runtime configurations on Intel® Xeon CPUs. :image: ../_static/img/thumbnails/cropped/profiler.png - :link: ../recipes/recipes/xeon_run_cpu.html + :link: ../recipes/xeon_run_cpu.html :tags: Model-Optimization .. customcarditem:: diff --git a/recipes_source/xeon_run_cpu.rst b/recipes_source/xeon_run_cpu.rst index fcf96a2ee8..6426bc5781 100644 --- a/recipes_source/xeon_run_cpu.rst +++ b/recipes_source/xeon_run_cpu.rst @@ -1,4 +1,4 @@ -Optimizing PyTorch Inference with Intel® Xeon® Scalable Processors +Optimizing CPU Performance on Intel® Xeon® with run_cpu Script ====================================================================== There are several configuration options that can impact the performance of PyTorch inference when executed on Intel® Xeon® Scalable Processors.