Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MPI example #270

Draft
wants to merge 9 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test_suite_ubuntu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ jobs:
- name: Install an MPI distribution
run: |
sudo apt update
sudo apt install mpich
sudo apt install openmpi-bin openmpi-common libopenmpi-dev

- name: Install pFUnit
run: |
Expand Down
68 changes: 68 additions & 0 deletions examples/7_MPI/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
cmake_minimum_required(VERSION 3.15...3.31)
# policy CMP0076 - target_sources source files are relative to file where
# target_sources is run
cmake_policy(SET CMP0076 NEW)

set(PROJECT_NAME SimpleNetExample)

project(${PROJECT_NAME} LANGUAGES Fortran)

# Build in Debug mode if not specified
if(NOT CMAKE_BUILD_TYPE)
  set(CMAKE_BUILD_TYPE
      Debug
      CACHE STRING "" FORCE)
endif()

# Both packages are hard requirements: the executable below links against
# imported targets from each of them, so fail early with a clear message if
# either is missing.
find_package(FTorch REQUIRED)
find_package(MPI REQUIRED)
message(STATUS "Building with Fortran PyTorch coupling")

# Fortran example
add_executable(mpi_infer_fortran mpi_infer_fortran.f90)
target_link_libraries(mpi_infer_fortran PRIVATE FTorch::ftorch)
target_link_libraries(mpi_infer_fortran PRIVATE MPI::MPI_Fortran)

# Integration testing
if(CMAKE_BUILD_TESTS)
  include(CTest)

  # The tests below invoke Python scripts. Only search for an interpreter if
  # one has not already been provided (e.g. by a parent project or with
  # -DPython_EXECUTABLE=...), so that an existing choice is respected.
  if(NOT Python_EXECUTABLE)
    find_package(Python COMPONENTS Interpreter REQUIRED)
  endif()

  # 1. Check the PyTorch model runs and its outputs meet expectations
  add_test(NAME simplenet COMMAND ${Python_EXECUTABLE}
                                  ${PROJECT_SOURCE_DIR}/simplenet.py)

  # 2. Check the model is saved to file in the expected location with the
  # pt2ts.py script
  add_test(
    NAME pt2ts
    COMMAND ${Python_EXECUTABLE} ${PROJECT_SOURCE_DIR}/pt2ts.py
            ${PROJECT_BINARY_DIR} # Command line argument: filepath for saving
                                  # the model
    WORKING_DIRECTORY ${PROJECT_BINARY_DIR})
  # pt2ts writes the model file that both MPI tests load, so expose it as a
  # fixture. This guarantees ordering under parallel ctest runs and pulls pt2ts
  # in automatically when only a subset of tests is selected.
  set_tests_properties(pt2ts PROPERTIES FIXTURES_SETUP mpi_example_saved_model)

  # 3. Check the model can be loaded from file and run with MPI in Python and
  # that its outputs meet expectations
  #
  # MPIEXEC_PREFLAGS / MPIEXEC_POSTFLAGS are empty by default; they follow the
  # launch line recommended by FindMPI and let users inject launcher-specific
  # flags at configure time.
  add_test(
    NAME mpi_infer_python
    COMMAND
      ${MPIEXEC_EXECUTABLE} ${MPIEXEC_NUMPROC_FLAG} 2 ${MPIEXEC_PREFLAGS}
      ${Python_EXECUTABLE} ${PROJECT_SOURCE_DIR}/mpi_infer_python.py
      ${MPIEXEC_POSTFLAGS}
      ${PROJECT_BINARY_DIR} # Command line argument: filepath to find the model
    WORKING_DIRECTORY ${PROJECT_BINARY_DIR})
  set_tests_properties(
    mpi_infer_python
    PROPERTIES PASS_REGULAR_EXPRESSION "MPI Python example ran successfully"
               FIXTURES_REQUIRED mpi_example_saved_model)

  # 4. Check the model can be loaded from file and run with MPI in Fortran and
  # that its outputs meet expectations
  #
  # $<TARGET_FILE:...> resolves to the built executable regardless of generator
  # or per-configuration output directory, unlike a path relative to the
  # working directory.
  add_test(
    NAME mpi_infer_fortran
    COMMAND
      ${MPIEXEC_EXECUTABLE} ${MPIEXEC_NUMPROC_FLAG} 2 ${MPIEXEC_PREFLAGS}
      $<TARGET_FILE:mpi_infer_fortran> ${MPIEXEC_POSTFLAGS}
      ${PROJECT_BINARY_DIR}/saved_simplenet_model_cpu.pt
      # Command line argument: model file
    WORKING_DIRECTORY ${PROJECT_BINARY_DIR})
  set_tests_properties(
    mpi_infer_fortran
    PROPERTIES PASS_REGULAR_EXPRESSION "MPI Fortran example ran successfully"
               FIXTURES_REQUIRED mpi_example_saved_model)
endif()
3 changes: 3 additions & 0 deletions examples/7_MPI/README.md
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Needs completing.

Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Example 7 - MPI

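This example runs inference on the saved TorchScript SimpleNet model from
several MPI ranks at once, in both Python (`mpi_infer_python.py`) and Fortran
(`mpi_infer_fortran.f90`). Each rank offsets the example input by its own rank,
runs the model on the CPU, and the outputs are gathered onto rank 0, where they
are checked against the expected values.

TODO: add build and run instructions.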
124 changes: 124 additions & 0 deletions examples/7_MPI/mpi_infer_fortran.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
program inference
  ! Run SimpleNet inference on every MPI rank, gather the outputs onto rank 0
  ! and check them there.
  !
  ! Usage: launched under an MPI launcher with more than one rank, with the
  ! path to the saved TorchScript model file as the first command line argument.
  ! Exits with stop code 999 if the communicator has only one rank, if no model
  ! file was given, or if any gathered output does not match its expected value.

  ! Import precision info from iso
  use, intrinsic :: iso_fortran_env, only : sp => real32

  ! Import our library for interfacing with PyTorch
  use ftorch, only : torch_model, torch_tensor, torch_kCPU, torch_delete, &
                     torch_tensor_from_array, torch_model_load, torch_model_forward

  ! Import our tools module for testing utils
  use ftorch_test_utils, only : assert_allclose

  ! Import MPI
  ! mpi_real is used (rather than the C datatype mpi_float) because the buffers
  ! being communicated are Fortran reals.
  use mpi, only : mpi_comm_rank, mpi_comm_size, mpi_comm_world, mpi_finalize, mpi_gather, &
                  mpi_init, mpi_real

  implicit none

  ! Set working precision for reals
  ! NOTE: mpi_real describes default REAL; this assumes default REAL has the same
  ! kind as real32 (i.e. no compiler flags promoting default reals).
  integer, parameter :: wp = sp

  integer :: num_args, ix
  character(len=128), dimension(:), allocatable :: args

  ! Set up Fortran data structures
  real(wp), dimension(5), target :: in_data
  real(wp), dimension(5), target :: out_data
  real(wp), dimension(5), target :: expected
  integer, parameter :: tensor_layout(1) = [1]

  ! Set up Torch data structures
  ! The net, a vector of input tensors (in this case we only have one), and the output tensor
  type(torch_model) :: model
  type(torch_tensor), dimension(1) :: in_tensors
  type(torch_tensor), dimension(1) :: out_tensors

  ! Flag for testing
  logical :: test_pass

  ! MPI configuration
  integer :: rank, size, ierr, i

  ! Variables for testing
  real(wp), allocatable, dimension(:,:) :: recvbuf
  real(wp), dimension(5) :: result_chk
  integer :: rank_chk

  call mpi_init(ierr)
  call mpi_comm_rank(mpi_comm_world, rank, ierr)
  call mpi_comm_size(mpi_comm_world, size, ierr)

  ! Check MPI was configured correctly
  if (size == 1) then
    write(*,*) "MPI communicator size is 1, indicating that it is not configured correctly"
    write(*,*) "(assuming you specified more than one rank)"
    ! Nothing has been allocated or handed to Torch yet, so only MPI needs
    ! shutting down here; clean_up() is reserved for after set-up.
    call mpi_finalize(ierr)
    stop 999
  end if

  ! Get TorchScript model file as a command line argument
  num_args = command_argument_count()
  if (num_args < 1) then
    if (rank == 0) then
      write(*,*) "Usage: mpi_infer_fortran <path to saved TorchScript model file>"
    end if
    call mpi_finalize(ierr)
    stop 999
  end if
  allocate(args(num_args))
  do ix = 1, num_args
    call get_command_argument(ix,args(ix))
  end do

  ! Initialise data and print the values used on each MPI rank
  ! Rank r uses the input [r, r+1, r+2, r+3, r+4]
  in_data = [(rank + i, i = 0, 4)]
  write(unit=6, fmt="('input on rank ',i0,': ')", advance="no") rank
  write(unit=6, fmt=100) in_data(:)
  100 format('[',4(f5.1,','),f5.1,']')

  ! Create Torch input/output tensors from the above arrays
  call torch_tensor_from_array(in_tensors(1), in_data, tensor_layout, torch_kCPU)
  call torch_tensor_from_array(out_tensors(1), out_data, tensor_layout, torch_kCPU)

  ! Load ML model
  call torch_model_load(model, args(1), torch_kCPU)

  ! Run inference on each MPI rank
  call torch_model_forward(model, in_tensors, out_tensors)

  ! Print the values computed on each MPI rank
  write(unit=6, fmt="('output on rank ',i0,': ')", advance="no") rank
  write(unit=6, fmt=100) out_data(:)

  ! Gather the outputs onto rank 0
  ! Column r+1 of recvbuf receives the five output values from rank r
  allocate(recvbuf(5,size))
  call mpi_gather(out_data, 5, mpi_real, recvbuf, 5, mpi_real, 0, mpi_comm_world, ierr)

  ! Check that the correct values were attained
  if (rank == 0) then

    ! Check output tensor matches expected value
    ! The expected output for each rank is twice that rank's input
    do rank_chk = 0, size-1
      expected = [(2 * (rank_chk + i), i = 0, 4)]
      result_chk(:) = recvbuf(:,rank_chk+1)
      test_pass = assert_allclose(result_chk, expected, test_name="MPI")
      if (.not. test_pass) then
        write(unit=6, fmt="('rank ',i0,' result: ')") rank_chk
        write(unit=6, fmt=100) result_chk(:)
        write(unit=6, fmt="('does not match expected value')")
        write(unit=6, fmt=100) expected(:)
        call clean_up()
        stop 999
      end if
    end do

    write (*,*) "MPI Fortran example ran successfully"
  end if

  call clean_up()

contains

  ! Release the Torch objects and local buffers, then shut down MPI.
  ! Only call this once the model and tensors have been created.
  subroutine clean_up()
    call torch_delete(model)
    call torch_delete(in_tensors)
    call torch_delete(out_tensors)
    ! Guard the deallocations so the routine is safe whichever buffers exist
    if (allocated(recvbuf)) deallocate(recvbuf)
    if (allocated(args)) deallocate(args)
    call mpi_finalize(ierr)
  end subroutine clean_up

end program inference
91 changes: 91 additions & 0 deletions examples/7_MPI/mpi_infer_python.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
"""Load saved SimpleNet to TorchScript and run inference example."""

import os
import sys

import torch
from mpi4py import MPI


def deploy(saved_model: str, device: str, batch_size: int = 1) -> torch.Tensor:
    """
    Run the saved TorchScript SimpleNet on a rank-dependent example input.

    The example input is ``[0, 1, 2, 3, 4]`` repeated ``batch_size`` times, with
    the calling process's rank in ``MPI.COMM_WORLD`` added to every element.

    Parameters
    ----------
    saved_model : str
        location of SimpleNet model saved to Torchscript
    device : str
        Torch device to run model on, 'cpu' or 'cuda'
    batch_size : int
        batch size to run (default 1)

    Returns
    -------
    output : torch.Tensor
        result of running inference on model with example Tensor input

    Raises
    ------
    ValueError
        if ``device`` is neither 'cpu' nor 'cuda'
    """
    base_values = torch.tensor([0.0, 1.0, 2.0, 3.0, 4.0])
    input_tensor = base_values.repeat(batch_size, 1)

    # Shift the input by the MPI rank so that every rank works on different data
    input_tensor += MPI.COMM_WORLD.rank

    # Reject unknown devices before touching the model file
    if device not in ("cpu", "cuda"):
        device_error = f"Device '{device}' not recognised."
        raise ValueError(device_error)

    # Load saved TorchScript model. torch.jit.load first loads modules onto the
    # CPU and then moves them to the devices they were saved from, so the model
    # is never transferred manually here.
    model = torch.jit.load(saved_model)

    if device == "cuda":
        # Move the input to the GPU, run inference, bring the result back
        gpu_input = input_tensor.to(torch.device("cuda"))
        gpu_output = model.forward(gpu_input)
        return gpu_output.to(torch.device("cpu"))

    # CPU inference
    return model.forward(input_tensor)


if __name__ == "__main__":
    # Directory holding the saved model: the first command line argument when
    # one is given, otherwise the directory containing this script
    if len(sys.argv) == 1:
        model_dir = os.path.dirname(__file__)
    else:
        model_dir = sys.argv[1]
    model_path = os.path.join(model_dir, "saved_simplenet_model_cpu.pt")

    world = MPI.COMM_WORLD
    my_rank = world.rank
    num_ranks = world.size

    # A single-rank communicator suggests the MPI launcher was not used properly
    if num_ranks == 1:
        size_error = (
            "MPI communicator size is 1, indicating that it is not configured correctly"
            " (assuming you specified more than one rank)"
        )
        raise ValueError(size_error)

    run_device = "cpu"
    run_batch_size = 1

    # Run inference on each rank, without tracking gradients
    with torch.no_grad():
        result = deploy(model_path, run_device, run_batch_size)
    print(f"rank {my_rank}: result:\n{result}")

    # Gather the outputs onto rank 0; only the root needs a receive buffer,
    # holding one row of five values per rank
    recvbuf = None
    if my_rank == 0:
        recvbuf = torch.empty([num_ranks, 5], dtype=torch.float32)
    world.Gather(result, recvbuf, root=0)

    # Check on the root that every rank's output is twice that rank's input
    if my_rank == 0:
        for rank_chk in range(num_ranks):
            result_chk = recvbuf[rank_chk]
            expected = 2 * (torch.arange(5, dtype=torch.float32) + rank_chk)
            if torch.allclose(expected, result_chk):
                continue
            result_error = (
                f"rank {rank_chk}: result:\n{result_chk}\n"
                f"does not match expected value:\n{expected}"
            )
            raise ValueError(result_error)

        print("MPI Python example ran successfully")
Loading
Loading