Skip to content

Commit

Permalink
Add checks for comm size
Browse files Browse the repository at this point in the history
  • Loading branch information
jwallwork23 committed Jan 30, 2025
1 parent f99f704 commit 7095302
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 5 deletions.
8 changes: 4 additions & 4 deletions examples/7_MPI/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ if(CMAKE_BUILD_TESTS)
# the model
WORKING_DIRECTORY ${PROJECT_BINARY_DIR})

# 3. Check the model can be loaded from file and run in Python and that its
# outputs meet expectations
# 3. Check the model can be loaded from file and run with MPI in Python and
# that its outputs meet expectations
add_test(
NAME mpi_infer_python
COMMAND
Expand All @@ -53,8 +53,8 @@ if(CMAKE_BUILD_TESTS)
mpi_infer_python PROPERTIES PASS_REGULAR_EXPRESSION
"MPI Python example ran successfully")

# 4. Check the model can be loaded from file and run in Fortran and that its
# outputs meet expectations
# 4. Check the model can be loaded from file and run with MPI in Fortran and
# that its outputs meet expectations
add_test(
NAME mpi_infer_fortran
COMMAND
Expand Down
10 changes: 9 additions & 1 deletion examples/7_MPI/mpi_infer_fortran.f90
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,15 @@ program inference

call mpi_init(ierr)
call mpi_comm_rank(mpi_comm_world, rank, ierr)
call mpi_comm_size(mpi_comm_world, size, ierr)

! Check MPI was configured correctly
if (size == 1) then
write(*,*) "MPI communicator size is 1, indicating that it is not configured correctly"
write(*,*) "(assuming you specified more than one rank)"
call clean_up()
stop 999
end if

! Get TorchScript model file as a command line argument
num_args = command_argument_count()
Expand Down Expand Up @@ -76,7 +85,6 @@ program inference
write(unit=6, fmt=100) out_data(:)

! Gather the outputs onto rank 0
call mpi_comm_size(mpi_comm_world, size, ierr)
allocate(recvbuf(5,size))
call mpi_gather(out_data, 5, mpi_float, recvbuf, 5, mpi_float, 0, mpi_comm_world, ierr)

Expand Down
6 changes: 6 additions & 0 deletions examples/7_MPI/mpi_infer_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,12 @@ def deploy(saved_model: str, device: str, batch_size: int = 1) -> torch.Tensor:
comm = MPI.COMM_WORLD
rank = comm.rank
device_to_run = "cpu"
if comm.size == 1:
size_error = (
"MPI communicator size is 1, indicating that it is not configured correctly"
" (assuming you specified more than one rank)"
)
raise ValueError(size_error)

batch_size_to_run = 1

Expand Down

0 comments on commit 7095302

Please sign in to comment.