From 5ef25d8341815842cc69defb950e82b2b79f3540 Mon Sep 17 00:00:00 2001 From: Semih Akkurt Date: Thu, 5 Dec 2024 19:35:54 +0000 Subject: [PATCH] Improve performance of transeq on OpenMP backend. --- src/omp/exec_dist.f90 | 34 ++++------- src/omp/kernels/distributed.f90 | 100 ++++++++++++++++++++++++++++++++ 2 files changed, 111 insertions(+), 23 deletions(-) diff --git a/src/omp/exec_dist.f90 b/src/omp/exec_dist.f90 index f33a844b2..b6fdf93e6 100644 --- a/src/omp/exec_dist.f90 +++ b/src/omp/exec_dist.f90 @@ -3,7 +3,8 @@ module m_omp_exec_dist use m_common, only: dp, VERT use m_omp_common, only: SZ - use m_omp_kernels_dist, only: der_univ_dist, der_univ_subs + use m_omp_kernels_dist, only: der_univ_dist, der_univ_subs, & + der_univ_fused_subs use m_tdsops, only: tdsops_t use m_omp_sendrecv, only: sendrecv_fields @@ -164,28 +165,15 @@ subroutine exec_dist_transeq_compact( & !$omp parallel do do k = 1, n_groups - call der_univ_subs(rhs_du(:, :, k), & - du_recv_s(:, :, k), du_recv_e(:, :, k), & - n, tdsops_du%dist_sa, tdsops_du%dist_sc) - - call der_univ_subs(dud(:, :, k), & - dud_recv_s(:, :, k), dud_recv_e(:, :, k), & - n, tdsops_dud%dist_sa, tdsops_dud%dist_sc) - - call der_univ_subs(d2u(:, :, k), & - d2u_recv_s(:, :, k), d2u_recv_e(:, :, k), & - n, tdsops_d2u%dist_sa, tdsops_d2u%dist_sc) - - do j = 1, n - !$omp simd - do i = 1, SZ - rhs_du(i, j, k) = -0.5_dp*(v(i, j, k)*rhs_du(i, j, k) & - + dud(i, j, k)) & - + nu*d2u(i, j, k) - end do - !$omp end simd - end do - + call der_univ_fused_subs( & + rhs_du(:, :, k), dud(:, :, k), d2u(:, :, k), v(:, :, k), & + du_recv_s(:, :, k), du_recv_e(:, :, k), & + dud_recv_s(:, :, k), dud_recv_e(:, :, k), & + d2u_recv_s(:, :, k), d2u_recv_e(:, :, k), & + nu, n, tdsops_du%dist_sa, tdsops_du%dist_sc, & + tdsops_dud%dist_sa, tdsops_dud%dist_sc, & + tdsops_d2u%dist_sa, tdsops_d2u%dist_sc & + ) end do !$omp end parallel do diff --git a/src/omp/kernels/distributed.f90 b/src/omp/kernels/distributed.f90 index c07b5a1b2..f800405fb 100644 --- a/src/omp/kernels/distributed.f90 +++ b/src/omp/kernels/distributed.f90 @@ -227,4 +227,104 @@ subroutine der_univ_subs(du, recv_u_s, recv_u_e, n, dist_sa, dist_sc) end subroutine der_univ_subs + subroutine der_univ_fused_subs(rhs_du, dud, d2u, v, & + du_recv_s, du_recv_e, & + dud_recv_s, dud_recv_e, & + d2u_recv_s, d2u_recv_e, & + nu, n, du_dist_sa, du_dist_sc, & + dud_dist_sa, dud_dist_sc, & + d2u_dist_sa, d2u_dist_sc) + implicit none + + ! Arguments + real(dp), intent(inout), dimension(:, :) :: rhs_du + real(dp), intent(in), dimension(:, :) :: dud, d2u, v + real(dp), intent(in), dimension(:, :) :: du_recv_s, du_recv_e + real(dp), intent(in), dimension(:, :) :: dud_recv_s, dud_recv_e + real(dp), intent(in), dimension(:, :) :: d2u_recv_s, d2u_recv_e + real(dp), intent(in), dimension(:) :: du_dist_sa, du_dist_sc + real(dp), intent(in), dimension(:) :: dud_dist_sa, dud_dist_sc + real(dp), intent(in), dimension(:) :: d2u_dist_sa, d2u_dist_sc + real(dp), intent(in) :: nu + integer, intent(in) :: n + + ! Local variables + integer :: i, j + real(dp) :: ur, bl, recp + real(dp), dimension(SZ) :: du_s, du_e, dud_s, dud_e, d2u_s, d2u_e, & + temp_du, temp_dud, temp_d2u + + !$omp simd + do i = 1, SZ + ! A small trick we do here is valid for symmetric Toeplitz matrices. + ! In our case our matrices satisfy this criteria in the (5:n-4) region + ! and as long as a rank has around at least 20 entries the assumptions + ! we make here are perfectly valid. + + ! bl is the bottom left entry in the 2x2 matrix + ! ur is the upper right entry in the 2x2 matrix + + ! Start + ! At the start we have the 'bl', and assume 'ur' + bl = du_dist_sa(1) + ur = du_dist_sa(1) + recp = 1._dp/(1._dp - ur*bl) + du_s(i) = recp*(rhs_du(i, 1) - bl*du_recv_s(i, 1)) + + bl = dud_dist_sa(1) + ur = dud_dist_sa(1) + recp = 1._dp/(1._dp - ur*bl) + dud_s(i) = recp*(dud(i, 1) - bl*dud_recv_s(i, 1)) + + bl = d2u_dist_sa(1) + ur = d2u_dist_sa(1) + recp = 1._dp/(1._dp - ur*bl) + d2u_s(i) = recp*(d2u(i, 1) - bl*d2u_recv_s(i, 1)) + + ! End + ! At the end we have the 'ur', and assume 'bl' + bl = du_dist_sc(n) + ur = du_dist_sc(n) + recp = 1._dp/(1._dp - ur*bl) + du_e(i) = recp*(rhs_du(i, n) - ur*du_recv_e(i, 1)) + + bl = dud_dist_sc(n) + ur = dud_dist_sc(n) + recp = 1._dp/(1._dp - ur*bl) + dud_e(i) = recp*(dud(i, n) - ur*dud_recv_e(i, 1)) + + bl = d2u_dist_sc(n) + ur = d2u_dist_sc(n) + recp = 1._dp/(1._dp - ur*bl) + d2u_e(i) = recp*(d2u(i, n) - ur*d2u_recv_e(i, 1)) + end do + !$omp end simd + + !$omp simd + do i = 1, SZ + rhs_du(i, 1) = -0.5_dp*(v(i, 1)*du_s(i) + dud_s(i)) + nu*d2u_s(i) + end do + !$omp end simd + do j = 2, n - 1 + !$omp simd + do i = 1, SZ + temp_du(i) = rhs_du(i, j) & + - du_dist_sa(j)*du_s(i) - du_dist_sc(j)*du_e(i) + temp_dud(i) = dud(i, j) & + - dud_dist_sa(j)*dud_s(i) - dud_dist_sc(j)*dud_e(i) + temp_d2u(i) = d2u(i, j) & + - d2u_dist_sa(j)*d2u_s(i) - d2u_dist_sc(j)*d2u_e(i) + rhs_du(i, j) = -0.5_dp*(v(i, j)*temp_du(i) + temp_dud(i)) & + + nu*temp_d2u(i) + end do + !$omp end simd + end do + !$omp simd + do i = 1, SZ + rhs_du(i, n) = -0.5_dp*(v(i, n)*du_e(i) + dud_e(i)) + nu*d2u_e(i) + end do + !$omp end simd + + end subroutine der_univ_fused_subs + end module m_omp_kernels_dist