Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve performance of transeq on OpenMP backend #132

Merged
merged 1 commit into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 11 additions & 23 deletions src/omp/exec_dist.f90
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ module m_omp_exec_dist

use m_common, only: dp, VERT
use m_omp_common, only: SZ
use m_omp_kernels_dist, only: der_univ_dist, der_univ_subs
use m_omp_kernels_dist, only: der_univ_dist, der_univ_subs, &
der_univ_fused_subs
use m_tdsops, only: tdsops_t
use m_omp_sendrecv, only: sendrecv_fields

Expand Down Expand Up @@ -164,28 +165,15 @@ subroutine exec_dist_transeq_compact( &

!$omp parallel do
do k = 1, n_groups
call der_univ_subs(rhs_du(:, :, k), &
du_recv_s(:, :, k), du_recv_e(:, :, k), &
n, tdsops_du%dist_sa, tdsops_du%dist_sc)

call der_univ_subs(dud(:, :, k), &
dud_recv_s(:, :, k), dud_recv_e(:, :, k), &
n, tdsops_dud%dist_sa, tdsops_dud%dist_sc)

call der_univ_subs(d2u(:, :, k), &
d2u_recv_s(:, :, k), d2u_recv_e(:, :, k), &
n, tdsops_d2u%dist_sa, tdsops_d2u%dist_sc)

do j = 1, n
!$omp simd
do i = 1, SZ
rhs_du(i, j, k) = -0.5_dp*(v(i, j, k)*rhs_du(i, j, k) &
+ dud(i, j, k)) &
+ nu*d2u(i, j, k)
end do
!$omp end simd
end do

call der_univ_fused_subs( &
rhs_du(:, :, k), dud(:, :, k), d2u(:, :, k), v(:, :, k), &
du_recv_s(:, :, k), du_recv_e(:, :, k), &
dud_recv_s(:, :, k), dud_recv_e(:, :, k), &
d2u_recv_s(:, :, k), d2u_recv_e(:, :, k), &
nu, n, tdsops_du%dist_sa, tdsops_du%dist_sc, &
tdsops_dud%dist_sa, tdsops_dud%dist_sc, &
tdsops_d2u%dist_sa, tdsops_d2u%dist_sc &
)
end do
!$omp end parallel do

Expand Down
100 changes: 100 additions & 0 deletions src/omp/kernels/distributed.f90
Original file line number Diff line number Diff line change
Expand Up @@ -227,4 +227,104 @@ subroutine der_univ_subs(du, recv_u_s, recv_u_e, n, dist_sa, dist_sc)

end subroutine der_univ_subs

subroutine der_univ_fused_subs(rhs_du, dud, d2u, v, &
du_recv_s, du_recv_e, &
dud_recv_s, dud_recv_e, &
d2u_recv_s, d2u_recv_e, &
nu, n, du_dist_sa, du_dist_sc, &
dud_dist_sa, dud_dist_sc, &
d2u_dist_sa, d2u_dist_sc)
implicit none

! Arguments
real(dp), intent(inout), dimension(:, :) :: rhs_du
real(dp), intent(in), dimension(:, :) :: dud, d2u, v
real(dp), intent(in), dimension(:, :) :: du_recv_s, du_recv_e
real(dp), intent(in), dimension(:, :) :: dud_recv_s, dud_recv_e
real(dp), intent(in), dimension(:, :) :: d2u_recv_s, d2u_recv_e
real(dp), intent(in), dimension(:) :: du_dist_sa, du_dist_sc
real(dp), intent(in), dimension(:) :: dud_dist_sa, dud_dist_sc
real(dp), intent(in), dimension(:) :: d2u_dist_sa, d2u_dist_sc
real(dp), intent(in) :: nu
integer, intent(in) :: n

! Local variables
integer :: i, j
real(dp) :: ur, bl, recp
real(dp), dimension(SZ) :: du_s, du_e, dud_s, dud_e, d2u_s, d2u_e, &
temp_du, temp_dud, temp_d2u

!$omp simd
do i = 1, SZ
! A small trick we do here is valid for symmetric Toeplitz matrices.
! In our case our matrices satisfy this criteria in the (5:n-4) region
! and as long as a rank has around at least 20 entries the assumptions
! we make here are perfectly valid.
Nanoseb marked this conversation as resolved.
Show resolved Hide resolved

! bl is the bottom left entry in the 2x2 matrix
! ur is the upper right entry in the 2x2 matrix

! Start
! At the start we have the 'bl', and assume 'ur'
bl = du_dist_sa(1)
ur = du_dist_sa(1)
recp = 1._dp/(1._dp - ur*bl)
du_s(i) = recp*(rhs_du(i, 1) - bl*du_recv_s(i, 1))

bl = dud_dist_sa(1)
ur = dud_dist_sa(1)
recp = 1._dp/(1._dp - ur*bl)
dud_s(i) = recp*(dud(i, 1) - bl*dud_recv_s(i, 1))

bl = d2u_dist_sa(1)
ur = d2u_dist_sa(1)
recp = 1._dp/(1._dp - ur*bl)
d2u_s(i) = recp*(d2u(i, 1) - bl*d2u_recv_s(i, 1))

! End
! At the end we have the 'ur', and assume 'bl'
bl = du_dist_sc(n)
ur = du_dist_sc(n)
recp = 1._dp/(1._dp - ur*bl)
du_e(i) = recp*(rhs_du(i, n) - ur*du_recv_e(i, 1))

bl = dud_dist_sc(n)
ur = dud_dist_sc(n)
recp = 1._dp/(1._dp - ur*bl)
dud_e(i) = recp*(dud(i, n) - ur*dud_recv_e(i, 1))

bl = d2u_dist_sc(n)
ur = d2u_dist_sc(n)
recp = 1._dp/(1._dp - ur*bl)
d2u_e(i) = recp*(d2u(i, n) - ur*d2u_recv_e(i, 1))
end do
!$omp end simd

!$omp simd
do i = 1, SZ
rhs_du(i, 1) = -0.5_dp*(v(i, 1)*du_s(i) + dud_s(i)) + nu*d2u_s(i)
end do
!$omp end simd
do j = 2, n - 1
!$omp simd
do i = 1, SZ
temp_du(i) = rhs_du(i, j) &
- du_dist_sa(j)*du_s(i) - du_dist_sc(j)*du_e(i)
temp_dud(i) = dud(i, j) &
- dud_dist_sa(j)*dud_s(i) - dud_dist_sc(j)*dud_e(i)
temp_d2u(i) = d2u(i, j) &
- d2u_dist_sa(j)*d2u_s(i) - d2u_dist_sc(j)*d2u_e(i)
rhs_du(i, j) = -0.5_dp*(v(i, j)*temp_du(i) + temp_dud(i)) &
+ nu*temp_d2u(i)
end do
!$omp end simd
end do
!$omp simd
do i = 1, SZ
rhs_du(i, n) = -0.5_dp*(v(i, n)*du_e(i) + dud_e(i)) + nu*d2u_e(i)
end do
!$omp end simd

end subroutine der_univ_fused_subs

end module m_omp_kernels_dist
Loading