Remove MPI from multi-GPU example #268

Open
wants to merge 6 commits into base: main
6 changes: 2 additions & 4 deletions examples/3_MultiGPU/CMakeLists.txt
@@ -15,7 +15,6 @@ if(NOT CMAKE_BUILD_TYPE)
endif()

find_package(FTorch)
-find_package(MPI REQUIRED)
message(STATUS "Building with Fortran PyTorch coupling")

include(CheckLanguage)
@@ -29,15 +28,14 @@ endif()
# Fortran example
add_executable(multigpu_infer_fortran multigpu_infer_fortran.f90)
target_link_libraries(multigpu_infer_fortran PRIVATE FTorch::ftorch)
-target_link_libraries(multigpu_infer_fortran PRIVATE MPI::MPI_Fortran)

# Integration testing
if (CMAKE_BUILD_TESTS)
include(CTest)

# 1. Check the PyTorch model runs and its outputs meet expectations
-  add_test(NAME multigpu COMMAND ${Python_EXECUTABLE}
-           ${PROJECT_SOURCE_DIR}/multigpu.py)
+  add_test(NAME simplenet COMMAND ${Python_EXECUTABLE}
+           ${PROJECT_SOURCE_DIR}/simplenet.py)

# 2. Check the model is saved to file in the expected location with the
# pt2ts.py script
69 changes: 31 additions & 38 deletions examples/3_MultiGPU/README.md
@@ -6,24 +6,22 @@ multiple GPU devices.

## Description

-The Python file `multigpu.py` is used, which is similar to the `simplenet.py`
-from the earlier example.
-Recall that it defines a very simple PyTorch network that takes an input of length 5
+The same Python file `simplenet.py` is used from the earlier example. Recall
+that it defines a very simple PyTorch network that takes an input of length 5
and applies a single `Linear` layer to multiply it by 2.
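
For reference, the behaviour described here can be reproduced with a network
along these lines (a minimal sketch assuming a fixed-weight `Linear` layer; the
attribute name `_layer` is illustrative rather than taken from `simplenet.py`):
```
import torch
from torch import nn


class SimpleNet(nn.Module):
    """PyTorch module multiplying an input vector of length 5 by 2."""

    def __init__(self) -> None:
        super().__init__()
        # A single Linear layer whose weights are fixed to 2 * identity,
        # so that forward(x) == 2 * x for inputs of length 5.
        self._layer = nn.Linear(5, 5, bias=False)
        with torch.no_grad():
            self._layer.weight.copy_(2.0 * torch.eye(5))

    def forward(self, batch: torch.Tensor) -> torch.Tensor:
        return self._layer(batch)
```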

The same `pt2ts.py` tool is used to save the simple network to TorchScript.

-A series of files `multigpu_infer_<LANG>` then bind from other languages to run the
-TorchScript model in inference mode.
+A series of files `multigpu_infer_<LANG>` then bind from other languages to run
+the TorchScript model in inference mode.

## Dependencies

To run this example requires:

- CMake
-- An MPI installation.
-- mpif90
-- FTorch (installed as described in main package)
- Two GPU devices that support CUDA and have it installed.
+- FTorch (installed with CUDA enabled as described in main package)
- Python 3

## Running
@@ -36,47 +34,47 @@ source venv/bin/activate
pip install -r requirements.txt
```

-You can check that everything is working by running `multigpu.py`:
+You can check that everything is working by running `simplenet.py`:
```
-python3 multigpu.py
+python3 simplenet.py
```
As before, this defines the network and runs it with an input tensor
-[0.0, 1.0, 2.0, 3.0, 4.0] to produce the result:
+[0.0, 1.0, 2.0, 3.0, 4.0]. The difference is that the code will make use of the
+default CUDA device (index 0) to produce the result:
```
SimpleNet forward pass on CUDA device 0
tensor([[0, 2, 4, 6, 8]])
```
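
(In PyTorch the bare device string "cuda" refers to the current CUDA device,
which is index 0 unless changed, so `torch.device("cuda")` and
`torch.device("cuda:0")` address the same GPU here; a quick check:)
```
import torch

# "cuda" with no index means the current CUDA device (index 0 by default)
print(repr(torch.device("cuda")))    # device(type='cuda')
print(repr(torch.device("cuda:0")))  # device(type='cuda', index=0)
```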

-To save the MultiGPUNet model to TorchScript run the modified version of the `pt2ts.py`
-tool:
+To save the `SimpleNet` model to TorchScript run the modified version of the
+`pt2ts.py` tool:
```
python3 pt2ts.py
```
-which will generate `saved_multigpu_model_cuda.pt` - the TorchScript instance of the
-network. The only difference with the earlier example is that the model is built to
-be run using CUDA rather than on CPU.
+which will generate `saved_multigpu_model_cuda.pt` - the TorchScript instance
+of the network. The only difference with the earlier example is that the model
+is built to be run using CUDA rather than on CPU.
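
The CUDA-specific step amounts to moving the model onto the GPU before
TorchScripting it, along the lines of the following sketch (assumed structure;
`pt2ts.py` itself may organise this differently):
```
import torch

from simplenet import SimpleNet

# Move the model to the default CUDA device before scripting, so the saved
# TorchScript module carries CUDA weights.
model = SimpleNet().to(torch.device("cuda"))
model.eval()

scripted_model = torch.jit.script(model)
scripted_model.save("saved_multigpu_model_cuda.pt")
```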

-You can check that everything is working by running the `multigpu_infer_python.py`
-script. It's set up with MPI such that a different GPU device is associated with each
-MPI rank. You should substitute `<NP>` with the number of GPUs you wish to run with:
+You can check that everything is working by running the
+`multigpu_infer_python.py` script. It's set up such that it loops over two GPU
+devices. Run with:
```
-mpiexec -np <NP> python3 multigpu_infer_python.py
+python3 multigpu_infer_python.py
```
This reads the model in from the TorchScript file and runs it with a different input
tensor on each GPU device: [0.0, 1.0, 2.0, 3.0, 4.0], plus the device index in each
-entry. The result should be (some permutation of):
+entry. The result should be:
```
-0: tensor([[0., 2., 4., 6., 8.]])
-1: tensor([[ 2., 4., 6., 8., 10.]])
-2: tensor([[ 4., 6., 8., 10., 12.]])
-3: tensor([[ 6., 8., 10., 12., 14.]])
+Output on device 0: tensor([[0., 2., 4., 6., 8.]])
+Output on device 1: tensor([[ 2., 4., 6., 8., 10.]])
```
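
Each output is simply double the corresponding input: device 1, for example,
receives the base tensor plus its device index, [1.0, 2.0, 3.0, 4.0, 5.0],
which the network doubles to [2., 4., 6., 8., 10.].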

At this point we no longer require Python, so we can deactivate the virtual environment:
```
deactivate
```

-To call the saved MultiGPUNet model from Fortran we need to compile the
+To call the saved `SimpleNet` model from Fortran we need to compile the
`multigpu_infer_fortran.f90` file. With MPI removed from this example, a plain
Fortran compiler suffices. This can be done using the included
`CMakeLists.txt` as follows:
@@ -90,24 +88,19 @@ cmake --build .
(Note that the Fortran compiler can be chosen explicitly with the `-DCMAKE_Fortran_COMPILER` flag,
and should match the compiler that was used to locally build FTorch.)
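
For example (gfortran is used here purely as an illustration; substitute
whichever compiler built your FTorch installation):
```
cmake .. -DCMAKE_Fortran_COMPILER=gfortran
cmake --build .
```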

-To run the compiled code calling the saved MultiGPUNet TorchScript from Fortran, run the
-executable with an argument of the saved model file. Again, specify the number of MPI
-processes according to the desired number of GPUs:
+To run the compiled code calling the saved `SimpleNet` TorchScript from
+Fortran, run the executable with an argument of the saved model file:
```
-mpiexec -np <NP> ./multigpu_infer_fortran ../saved_multigpu_model_cuda.pt
+./multigpu_infer_fortran ../saved_multigpu_model_cuda.pt
```

This runs the model with the same inputs as described above. Since the devices
are now visited in a deterministic loop, the output should be exactly:
```
-input on rank0: [ 0.0, 1.0, 2.0, 3.0, 4.0]
-input on rank1: [ 1.0, 2.0, 3.0, 4.0, 5.0]
-input on rank2: [ 2.0, 3.0, 4.0, 5.0, 6.0]
-input on rank3: [ 3.0, 4.0, 5.0, 6.0, 7.0]
-output on rank0: [ 0.0, 2.0, 4.0, 6.0, 8.0]
-output on rank1: [ 2.0, 4.0, 6.0, 8.0, 10.0]
-output on rank2: [ 4.0, 6.0, 8.0, 10.0, 12.0]
-output on rank3: [ 6.0, 8.0, 10.0, 12.0, 14.0]
+input on device 0: [ 0.0, 1.0, 2.0, 3.0, 4.0]
+output on device 0: [ 0.0, 2.0, 4.0, 6.0, 8.0]
+input on device 1: [ 1.0, 2.0, 3.0, 4.0, 5.0]
+output on device 1: [ 2.0, 4.0, 6.0, 8.0, 10.0]
```

Alternatively, we can use `make`, instead of CMake, copying the Makefile over from the
80 changes: 39 additions & 41 deletions examples/3_MultiGPU/multigpu_infer_fortran.f90
@@ -8,9 +8,6 @@ program inference
torch_tensor_from_array, torch_model_load, torch_model_forward, &
torch_delete

-  ! Import MPI
-  use mpi, only : mpi_init, mpi_finalize, mpi_comm_world, mpi_comm_rank

implicit none

! Set precision for reals
@@ -29,51 +26,52 @@ program inference
type(torch_tensor), dimension(1) :: in_tensors
type(torch_tensor), dimension(1) :: out_tensors

-  ! MPI configuration
-  integer :: rank, ierr, i
-
-  call mpi_init(ierr)
-  call mpi_comm_rank(mpi_comm_world, rank, ierr)
+  ! Variables for multi-GPU setup
+  integer, parameter :: num_devices = 2
+  integer :: device_index, i

! Get TorchScript model file as a command line argument
num_args = command_argument_count()
allocate(args(num_args))
do ix = 1, num_args
    call get_command_argument(ix,args(ix))
end do

-  ! Initialise data and print the values used on each MPI rank.
-  in_data = [(rank + i, i = 0, 4)]
-  write (6, 100) rank, in_data(:)
-  100 format("input on rank ", i1,": [", 4(f5.1,","), f5.1,"]")
-
-  ! Create Torch input tensor from the above array and assign it to the first (and only)
-  ! element in the array of input tensors.
-  ! We use the torch_kCUDA device type with device index corresponding to the MPI rank.
-  call torch_tensor_from_array(in_tensors(1), in_data, tensor_layout, &
-                               torch_kCUDA, device_index=rank)
-
-  ! Create Torch output tensor from the above array.
-  ! Here we use the torch_kCPU device type since the tensor is for output only
-  ! i.e. to be subsequently used by Fortran on CPU.
-  call torch_tensor_from_array(out_tensors(1), out_data, tensor_layout, torch_kCPU)
-
-  ! Load ML model. Ensure that the same device type and device index are used
-  ! as for the input data.
-  call torch_model_load(model, args(1), torch_kCUDA, device_index=rank)
-
-  ! Infer
-  call torch_model_forward(model, in_tensors, out_tensors)
-
-  ! Print the values computed on each MPI rank.
-  write (6, 200) rank, out_data(:)
-  200 format("output on rank ", i1,": [", 4(f5.1,","), f5.1,"]")
-
-  ! Cleanup
-  call torch_delete(model)
-  call torch_delete(in_tensors)
-  call torch_delete(out_tensors)
-  call mpi_finalize(ierr)
+  do device_index = 0, num_devices-1
+
+    ! Initialise data and print the values used
+    in_data = [(device_index + i, i = 0, 4)]
+    write (6, 100) device_index, in_data(:)
+    100 format("input on device ", i1,": [", 4(f5.1,","), f5.1,"]")
+
+    ! Create Torch input tensor from the above array and assign it to the first (and only)
+    ! element in the array of input tensors.
+    ! We use the torch_kCUDA device type with the given device index
+    call torch_tensor_from_array(in_tensors(1), in_data, tensor_layout, torch_kCUDA, &
+                                 device_index=device_index)
+
+    ! Create Torch output tensor from the above array.
+    ! Here we use the torch_kCPU device type since the tensor is for output only
+    ! i.e. to be subsequently used by Fortran on CPU.
+    call torch_tensor_from_array(out_tensors(1), out_data, tensor_layout, torch_kCPU)
+
+    ! Load ML model. Ensure that the same device type and device index are used
+    ! as for the input data.
+    call torch_model_load(model, args(1), torch_kCUDA, device_index=device_index)
+
+    ! Infer
+    call torch_model_forward(model, in_tensors, out_tensors)
+
+    ! Print the values computed on the current device.
+    write (6, 200) device_index, out_data(:)
+    200 format("output on device ", i1,": [", 4(f5.1,","), f5.1,"]")
+
+    ! Cleanup
+    call torch_delete(model)
+    call torch_delete(in_tensors)
+    call torch_delete(out_tensors)
+
+  end do

write (*,*) "MultiGPU example ran successfully"

23 changes: 12 additions & 11 deletions examples/3_MultiGPU/multigpu_infer_python.py
@@ -1,7 +1,6 @@
"""Load saved MultiGPUNet to TorchScript and run inference example."""
"""Load saved SimpleNet to TorchScript and run inference example."""

import torch
-from mpi4py import MPI


def deploy(saved_model: str, device: str, batch_size: int = 1) -> torch.Tensor:
@@ -13,7 +12,8 @@ def deploy(saved_model: str, device: str, batch_size: int = 1) -> torch.Tensor:
saved_model : str
location of SimpleNet model saved to Torchscript
device : str
-        Torch device to run model on, 'cpu' or 'cuda'
+        Torch device to run model on, 'cpu' or 'cuda'. May be followed by a colon and
+        then a device index, e.g., 'cuda:0' for the 0th CUDA device.
batch_size : int
batch size to run (default 1)

@@ -24,16 +24,16 @@ def deploy(saved_model: str, device: str, batch_size: int = 1) -> torch.Tensor:
"""
input_tensor = torch.tensor([0.0, 1.0, 2.0, 3.0, 4.0]).repeat(batch_size, 1)

-    # Add the rank (device index) to each tensor to make them differ
-    input_tensor += MPI.COMM_WORLD.rank

if device == "cpu":
# Load saved TorchScript model
model = torch.jit.load(saved_model)
# Inference
output = model.forward(input_tensor)

elif device.startswith("cuda"):
+        # Add the device index to each tensor to make them differ
+        # (default to index 0 when no index is given, e.g. a plain "cuda" string)
+        input_tensor += int(device.split(":")[1]) if ":" in device else 0

# All previously saved modules, no matter their device, are first
# loaded onto CPU, and then are moved to the devices they were saved
# from, so we don't need to manually transfer the model to the GPU
@@ -53,11 +53,12 @@ def deploy(saved_model: str, device: str, batch_size: int = 1) -> torch.Tensor:
if __name__ == "__main__":
saved_model_file = "saved_multigpu_model_cuda.pt"

device_to_run = f"cuda:{MPI.COMM_WORLD.rank}"
for device_index in range(2):
device_to_run = f"cuda:{device_index}"

batch_size_to_run = 1
batch_size_to_run = 1

with torch.no_grad():
result = deploy(saved_model_file, device_to_run, batch_size_to_run)
with torch.no_grad():
result = deploy(saved_model_file, device_to_run, batch_size_to_run)

print(f"Output on device {device_to_run}: {result}")
print(f"Output on device {device_to_run}: {result}")
1 change: 0 additions & 1 deletion examples/3_MultiGPU/requirements.txt
@@ -1,2 +1 @@
-mpi4py
torch
examples/3_MultiGPU/multigpu.py → examples/3_MultiGPU/simplenet.py
@@ -4,7 +4,7 @@
from torch import nn


-class MultiGPUNet(nn.Module):
+class SimpleNet(nn.Module):
"""PyTorch module multiplying an input vector by 2."""

def __init__(
@@ -42,7 +42,7 @@ def forward(self, batch: torch.Tensor) -> torch.Tensor:


if __name__ == "__main__":
-    model = MultiGPUNet().to(torch.device("cuda"))
+    model = SimpleNet().to(torch.device("cuda"))
model.eval()

input_tensor = torch.Tensor([0.0, 1.0, 2.0, 3.0, 4.0])