diff --git a/src/solve_helmholtz.f90 b/src/solve_helmholtz.f90 index e111589f..e935fc1e 100644 --- a/src/solve_helmholtz.f90 +++ b/src/solve_helmholtz.f90 @@ -25,8 +25,7 @@ module mod_solve_helmholtz end type rhs_bound public solve_helmholtz,rhs_bound contains - subroutine solve_helmholtz(n,ng,hi,arrplan,normfft,alpha, & - lambdaxyi,ai,bi,ci,rhsbxi,rhsbyi,rhsbzi,is_bound,cbc,c_or_f,p) + subroutine solve_helmholtz(n,ng,hi,arrplan,normfft,alpha,lambdaxy,a,b,c,rhsbx,rhsby,rhsbz,is_bound,cbc,c_or_f,p) ! ! this is a wrapper subroutine to solve 1D/3D helmholtz problems: p/alpha + lap(p) = rhs ! @@ -38,16 +37,15 @@ subroutine solve_helmholtz(n,ng,hi,arrplan,normfft,alpha, & #endif real(rp), intent(in ), optional :: normfft real(rp), intent(in ) :: alpha - real(rp), intent(in ), dimension(:,:), optional :: lambdaxyi - real(rp), intent(in ), dimension(:) :: ai,bi,ci - real(rp), intent(in ), dimension(:,:,0:), optional :: rhsbxi,rhsbyi,rhsbzi + real(rp), intent(in ), dimension(:,:), optional :: lambdaxy + real(rp), intent(in ), dimension(:) :: a,b,c + real(rp), intent(in ), dimension(:,:,0:), optional :: rhsbx,rhsby,rhsbz logical , intent(in ), dimension(2,3) :: is_bound character(len=1), intent(in), dimension(0:1,3) :: cbc character(len=1), intent(in), dimension(3) :: c_or_f real(rp), intent(inout), dimension(:,:,:) :: p - real(rp), allocatable, dimension(:,:) , save :: lambdaxy - real(rp), allocatable, dimension(:) , save :: a,b,c - real(rp), allocatable, dimension(:,:,:), save :: rhsbx,rhsby,rhsbz + real(rp), allocatable, dimension(:), save :: bb + real(rp) :: alphai ! logical, save :: is_first = .true. ! @@ -55,50 +53,21 @@ subroutine solve_helmholtz(n,ng,hi,arrplan,normfft,alpha, & ! if(is_first) then ! leverage save attribute to allocate these arrays on the device only once is_first = .false. - if(present(lambdaxyi)) allocate(lambdaxy,mold=lambdaxyi) - allocate(a,mold=ai) - allocate(b,mold=bi) - allocate(c,mold=ci) - if(present(rhsbxi)) allocate(rhsbx(n(2),n(3),0:1)) ! allocate(rhsbx,mold=rhsbxi) ! gfortran 11.4.0 bug - if(present(rhsbyi)) allocate(rhsby(n(1),n(3),0:1)) ! allocate(rhsby,mold=rhsbyi) ! gfortran 11.4.0 bug - if(present(rhsbzi)) allocate(rhsbz(n(1),n(2),0:1)) ! allocate(rhsbz,mold=rhsbzi) ! gfortran 11.4.0 bug - !$acc enter data create(lambdaxy,a,b,c,rhsbx,rhsby,rhsbz) async(1) + allocate(bb,mold=b) + !$acc enter data create(bb) async(1) end if ! - if(.not.is_impdiff_1d) then - !$acc kernels default(present) async(1) - !$OMP PARALLEL WORKSHARE - rhsbx(:,:,0:1) = rhsbxi(:,:,0:1)*alpha - rhsby(:,:,0:1) = rhsbyi(:,:,0:1)*alpha - rhsbz(:,:,0:1) = rhsbzi(:,:,0:1)*alpha - !$OMP END PARALLEL WORKSHARE - !$acc end kernels - else - !$acc kernels default(present) async(1) - !$OMP PARALLEL WORKSHARE - rhsbz(:,:,0:1) = rhsbzi(:,:,0:1)*alpha - !$OMP END PARALLEL WORKSHARE - !$acc end kernels - end if call updt_rhs_b(c_or_f,cbc,n,is_bound,rhsbx,rhsby,rhsbz,p) + alphai = alpha**(-1) !$acc kernels default(present) async(1) !$OMP PARALLEL WORKSHARE - a(:) = ai(:)*alpha - b(:) = bi(:)*alpha + 1. - c(:) = ci(:)*alpha + bb(:) = b(:) + alphai !$OMP END PARALLEL WORKSHARE !$acc end kernels if(.not.is_impdiff_1d) then - !$acc kernels default(present) async(1) - !$OMP PARALLEL WORKSHARE - lambdaxy(:,:) = lambdaxyi(:,:)*alpha - !$OMP END PARALLEL WORKSHARE - !$acc end kernels - end if - if(.not.is_impdiff_1d) then - call solver(n,ng,arrplan,normfft,lambdaxy,a,b,c,cbc,c_or_f,p) + call solver(n,ng,arrplan,normfft*alphai,lambdaxy,a,bb,c,cbc,c_or_f,p) else - call solver_gaussel_z(n,ng,hi,a,b,c,cbc(:,3),c_or_f,p) + call solver_gaussel_z(n,ng,hi,a,bb,c,cbc(:,3),c_or_f,alphai,p) end if end subroutine solve_helmholtz end module mod_solve_helmholtz diff --git a/src/solver.f90 b/src/solver.f90 index a21b1ee0..6416abe4 100644 --- a/src/solver.f90 +++ b/src/solver.f90 @@ -35,6 +35,9 @@ subroutine solver(n,ng,arrplan,normfft,lambdaxy,a,b,c,bc,c_or_f,p,is_ptdma_updat logical :: is_periodic_z integer, dimension(3) :: n_z,hi_z logical :: is_ptdma_update_ + real(rp) :: norm + ! + norm = normfft ! is_ptdma_update_ = .true. if(present(is_ptdma_update)) is_ptdma_update_ = is_ptdma_update @@ -76,11 +79,11 @@ subroutine solver(n,ng,arrplan,normfft,lambdaxy,a,b,c,bc,c_or_f,p,is_ptdma_updat if(.not.is_poisson_pcr_tdma) then call transpose_y_to_z(py,pz) ! - call gaussel(n_z(1),n_z(2),n_z(3)-q,0,a,b,c,is_periodic_z,pz,lambdaxy) + call gaussel(n_z(1),n_z(2),n_z(3)-q,0,a,b,c,is_periodic_z,norm,pz,lambdaxy) ! call transpose_z_to_y(pz,py) else - call gaussel_ptdma(n_z(1),n_z(2),n_z(3)-q,0,a,b,c,is_periodic_z,py,lambdaxy,is_ptdma_update_,aa_z,cc_z) + call gaussel_ptdma(n_z(1),n_z(2),n_z(3)-q,0,a,b,c,is_periodic_z,norm,py,lambdaxy,is_ptdma_update_,aa_z,cc_z) if(present(is_ptdma_update)) is_ptdma_update = is_ptdma_update_ end if call fft(arrplan(2,2),py) ! bwd transform in y @@ -91,29 +94,30 @@ subroutine solver(n,ng,arrplan,normfft,lambdaxy,a,b,c,bc,c_or_f,p,is_ptdma_updat select case(ipencil_axis) case(1) !$OMP PARALLEL WORKSHARE - p(1:n(1),1:n(2),1:n(3)) = px(:,:,:)*normfft + p(1:n(1),1:n(2),1:n(3)) = px(:,:,:) !$OMP END PARALLEL WORKSHARE case(2) call transpose_x_to_y(px,py) !$OMP PARALLEL WORKSHARE - p(1:n(1),1:n(2),1:n(3)) = py(:,:,:)*normfft + p(1:n(1),1:n(2),1:n(3)) = py(:,:,:) !$OMP END PARALLEL WORKSHARE case(3) !call transpose_x_to_z(px,pz) call transpose_x_to_y(px,py) call transpose_y_to_z(py,pz) !$OMP PARALLEL WORKSHARE - p(1:n(1),1:n(2),1:n(3)) = pz(:,:,:)*normfft + p(1:n(1),1:n(2),1:n(3)) = pz(:,:,:) !$OMP END PARALLEL WORKSHARE end select end subroutine solver ! - subroutine gaussel(nx,ny,n,nh,a,b,c,is_periodic,p,lambdaxy) + subroutine gaussel(nx,ny,n,nh,a,b,c,is_periodic,norm,p,lambdaxy) use mod_param, only: eps implicit none integer , intent(in) :: nx,ny,n,nh real(rp), intent(in), dimension(:) :: a,b,c logical , intent(in) :: is_periodic + real(rp), intent(in) :: norm real(rp), intent(inout), dimension(1-nh:,1-nh:,1-nh:) :: p real(rp), intent(in), dimension(nx,ny), optional :: lambdaxy real(rp), dimension(n) :: bb,p2 @@ -129,7 +133,7 @@ subroutine gaussel(nx,ny,n,nh,a,b,c,is_periodic,p,lambdaxy) do j=1,ny do i=1,nx bb(:) = b(:) + lambdaxy(i,j) - call dgtsv_homebrewed(nn,a,bb,c,p(i,j,1:nn)) + call dgtsv_homebrewed(nn,a,bb,c,norm,p(i,j,1:nn)) end do end do !$OMP END PARALLEL @@ -138,7 +142,7 @@ subroutine gaussel(nx,ny,n,nh,a,b,c,is_periodic,p,lambdaxy) !$OMP DO COLLAPSE(2) do j=1,ny do i=1,nx - call dgtsv_homebrewed(nn,a,b,c,p(i,j,1:nn)) + call dgtsv_homebrewed(nn,a,b,c,norm,p(i,j,1:nn)) end do end do !$OMP END PARALLEL @@ -153,9 +157,9 @@ subroutine gaussel(nx,ny,n,nh,a,b,c,is_periodic,p,lambdaxy) p2(1 ) = -a(1 ) p2(nn) = p2(nn) - c(nn) bb(:) = b(:) + lambdaxy(i,j) - call dgtsv_homebrewed(nn,a,bb,c,p2(1:nn)) - p(i,j,nn+1) = (p(i,j,nn+1) - c(nn+1)*p(i,j,1) - a(nn+1)*p(i,j,nn)) / & - (bb( nn+1) + c(nn+1)*p2( 1) + a(nn+1)*p2( nn)+eps) + call dgtsv_homebrewed(nn,a,bb,c,1._rp,p2(1:nn)) + p(i,j,nn+1) = (p(i,j,nn+1)*norm - c(nn+1)*p(i,j,1) - a(nn+1)*p(i,j,nn)) / & + (bb( nn+1) + c(nn+1)*p2( 1) + a(nn+1)*p2( nn)+eps) p(i,j,1:nn) = p(i,j,1:nn) + p2(1:nn)*p(i,j,nn+1) end do end do @@ -168,9 +172,9 @@ subroutine gaussel(nx,ny,n,nh,a,b,c,is_periodic,p,lambdaxy) p2(:) = 0. p2(1 ) = -a(1 ) p2(nn) = p2(nn) - c(nn) - call dgtsv_homebrewed(nn,a,b,c,p2(1:nn)) - p(i,j,nn+1) = (p(i,j,nn+1) - c(nn+1)*p(i,j,1) - a(nn+1)*p(i,j,nn)) / & - (b( nn+1) + c(nn+1)*p2( 1) + a(nn+1)*p2( nn)+eps) + call dgtsv_homebrewed(nn,a,b,c,1._rp,p2(1:nn)) + p(i,j,nn+1) = (p(i,j,nn+1)*norm - c(nn+1)*p(i,j,1) - a(nn+1)*p(i,j,nn)) / & + (b( nn+1) + c(nn+1)*p2( 1) + a(nn+1)*p2( nn)+eps) p(i,j,1:nn) = p(i,j,1:nn) + p2(1:nn)*p(i,j,nn+1) end do end do @@ -179,7 +183,7 @@ subroutine gaussel(nx,ny,n,nh,a,b,c,is_periodic,p,lambdaxy) end if end subroutine gaussel ! - subroutine gaussel_ptdma(nx,ny,n,nh,a,b,c,is_periodic,p,lambdaxy,is_update,aa_z_save,cc_z_save) + subroutine gaussel_ptdma(nx,ny,n,nh,a,b,c,is_periodic,norm,p,lambdaxy,is_update,aa_z_save,cc_z_save) ! ! distributed TDMA solver ! @@ -190,6 +194,7 @@ subroutine gaussel_ptdma(nx,ny,n,nh,a,b,c,is_periodic,p,lambdaxy,is_update,aa_z_ integer , intent(in) :: nx,ny,n,nh real(rp), intent(in), dimension(:) :: a,b,c logical , intent(in) :: is_periodic + real(rp), intent(in) :: norm real(rp), intent(inout), dimension(1-nh:,1-nh:,1-nh:) :: p real(rp), intent(in), dimension(:,:), optional :: lambdaxy logical , intent(inout), optional :: is_update @@ -227,13 +232,13 @@ subroutine gaussel_ptdma(nx,ny,n,nh,a,b,c,is_periodic,p,lambdaxy,is_update,aa_z_ zz(:) = 1./(bb(1:2)+eps) aa(i,j,1:2) = a(1:2)*zz(:) cc(i,j,1:2) = c(1:2)*zz(:) - p( i,j,1:2) = p(i,j,1:2)*zz(:) + p( i,j,1:2) = p(i,j,1:2)*norm*zz(:) ! ! elimination of lower diagonals ! do k=3,n z = 1./(bb(k)-a(k)*cc(i,j,k-1)+eps) - p(i,j,k) = (p(i,j,k)-a(k)*p(i,j,k-1))*z + p(i,j,k) = (p(i,j,k)*norm-a(k)*p(i,j,k-1))*z aa(i,j,k) = -a(k)*aa(i,j,k-1)*z cc(i,j,k) = c(k)*z end do @@ -266,13 +271,13 @@ subroutine gaussel_ptdma(nx,ny,n,nh,a,b,c,is_periodic,p,lambdaxy,is_update,aa_z_ zz(:) = 1./(b(1:2)+eps) aa(i,j,1:2) = a(1:2)*zz(:) cc(i,j,1:2) = c(1:2)*zz(:) - p( i,j,1:2) = p(i,j,1:2)*zz(:) + p( i,j,1:2) = p(i,j,1:2)*norm*zz(:) ! ! elimination of lower diagonals ! do k=3,n z = 1./(b(k)-a(k)*cc(i,j,k-1)+eps) - p(i,j,k) = (p(i,j,k)-a(k)*p(i,j,k-1))*z + p(i,j,k) = (p(i,j,k)*norm-a(k)*p(i,j,k-1))*z aa(i,j,k) = -a(k)*aa(i,j,k-1)*z cc(i,j,k) = c(k)*z end do @@ -389,11 +394,12 @@ subroutine gaussel_ptdma(nx,ny,n,nh,a,b,c,is_periodic,p,lambdaxy,is_update,aa_z_ !$OMP END PARALLEL end subroutine gaussel_ptdma ! - subroutine dgtsv_homebrewed(n,a,b,c,p) + subroutine dgtsv_homebrewed(n,a,b,c,norm,p) use mod_param, only: eps implicit none integer , intent(in) :: n real(rp), intent(in ), dimension(:) :: a,b,c + real(rp), intent(in ) :: norm real(rp), intent(inout), dimension(:) :: p real(rp), dimension(n) :: d real(rp) :: z @@ -403,11 +409,11 @@ subroutine dgtsv_homebrewed(n,a,b,c,p) ! z = 1./(b(1)+eps) d(1) = c(1)*z - p(1) = p(1)*z + p(1) = p(1)*norm*z do l=2,n z = 1./(b(l)-a(l)*d(l-1)+eps) d(l) = c(l)*z - p(l) = (p(l)-a(l)*p(l-1))*z + p(l) = (p(l)*norm-a(l)*p(l-1))*z end do ! ! backward substitution @@ -417,12 +423,13 @@ subroutine dgtsv_homebrewed(n,a,b,c,p) end do end subroutine dgtsv_homebrewed ! - subroutine solver_gaussel_z(n,ng,hi,a,b,c,bcz,c_or_f,p) + subroutine solver_gaussel_z(n,ng,hi,a,b,c,bcz,c_or_f,norm,p) implicit none integer , intent(in), dimension(3) :: n,ng,hi real(rp), intent(in), dimension(:) :: a,b,c character(len=1), dimension(0:1), intent(in) :: bcz character(len=1), intent(in), dimension(3) :: c_or_f + real(rp), intent(in) :: norm real(rp), intent(inout), dimension(0:,0:,0:) :: p real(rp), allocatable, dimension(:,:,:) :: px,py,pz integer :: q @@ -461,12 +468,12 @@ subroutine solver_gaussel_z(n,ng,hi,a,b,c,bcz,c_or_f,p) is_periodic_z = bcz(0)//bcz(1) == 'PP' if(.not.is_no_decomp_z) then if(.not.is_poisson_pcr_tdma) then - call gaussel( n_z(1),n_z(2),n_z(3)-q,0,a,b,c,is_periodic_z,pz) + call gaussel( n_z(1),n_z(2),n_z(3)-q,0,a,b,c,is_periodic_z,norm,pz) else - call gaussel_ptdma(n_z(1),n_z(2),n_z(3)-q,1,a,b,c,is_periodic_z,p) + call gaussel_ptdma(n_z(1),n_z(2),n_z(3)-q,1,a,b,c,is_periodic_z,norm,p) end if else - call gaussel(n(1),n(2),n(3)-q,1,a,b,c,is_periodic_z,p) + call gaussel(n(1),n(2),n(3)-q,1,a,b,c,is_periodic_z,norm,p) end if ! if(.not.is_poisson_pcr_tdma .and. .not.is_no_decomp_z) then diff --git a/src/solver_gpu.f90 b/src/solver_gpu.f90 index 8ce6e276..71f9b381 100644 --- a/src/solver_gpu.f90 +++ b/src/solver_gpu.f90 @@ -43,6 +43,9 @@ subroutine solver_gpu(n,ng,arrplan,normfft,lambdaxy,a,b,c,bc,c_or_f,p,is_ptdma_u integer, dimension(3) :: n_x,n_y,n_z,n_z_0,lo_z_0,hi_z_0 integer :: istat logical :: is_ptdma_update_ + real(rp) :: norm + ! + norm = normfft ! is_ptdma_update_ = .true. if(present(is_ptdma_update)) is_ptdma_update_ = is_ptdma_update @@ -121,7 +124,7 @@ subroutine solver_gpu(n,ng,arrplan,normfft,lambdaxy,a,b,c,bc,c_or_f,p,is_ptdma_u istat = cudecompTransposeYtoZ(ch,gd,py,pz,work,dtype_rp,stream=istream) !$acc end host_data ! - call gaussel_gpu(n_z_0(1),n_z_0(2),n_z_0(3)-q,0,a,b,c,is_periodic_z,pz,work,pz_aux_1,lambdaxy) + call gaussel_gpu(n_z_0(1),n_z_0(2),n_z_0(3)-q,0,a,b,c,is_periodic_z,norm,pz,work,pz_aux_1,lambdaxy) ! !$acc host_data use_device(pz,py,work) istat = cudecompTransposeZtoY(ch,gd,pz,py,work,dtype_rp,stream=istream) @@ -148,7 +151,7 @@ subroutine solver_gpu(n,ng,arrplan,normfft,lambdaxy,a,b,c,bc,c_or_f,p,is_ptdma_u end do end block ! - call gaussel_ptdma_gpu(n_z_0(1),n_z_0(2),n_z_0(3)-q,lo_z_0(3),0,a,b,c,is_periodic_z,pz,work,pz_aux_1,is_ptdma_update_,lambdaxy,aa_z,cc_z) + call gaussel_ptdma_gpu(n_z_0(1),n_z_0(2),n_z_0(3)-q,lo_z_0(3),0,a,b,c,is_periodic_z,norm,pz,work,pz_aux_1,is_ptdma_update_,lambdaxy,aa_z,cc_z) if(present(is_ptdma_update)) is_ptdma_update = is_ptdma_update_ ! block @@ -189,7 +192,7 @@ subroutine solver_gpu(n,ng,arrplan,normfft,lambdaxy,a,b,c,bc,c_or_f,p,is_ptdma_u case(1) !$acc kernels default(present) async(1) !$OMP PARALLEL WORKSHARE - p(1:n(1),1:n(2),1:n(3)) = px(1:n(1),1:n(2),1:n(3))*normfft + p(1:n(1),1:n(2),1:n(3)) = px(1:n(1),1:n(2),1:n(3)) !$OMP END PARALLEL WORKSHARE !$acc end kernels case(2) @@ -206,7 +209,7 @@ subroutine solver_gpu(n,ng,arrplan,normfft,lambdaxy,a,b,c,bc,c_or_f,p,is_ptdma_u do k=1,n(3) do j=1,n(2) do i=1,n(1) - p(i,j,k) = py(j,k,i)*normfft + p(i,j,k) = py(j,k,i) end do end do end do @@ -218,18 +221,19 @@ subroutine solver_gpu(n,ng,arrplan,normfft,lambdaxy,a,b,c,bc,c_or_f,p,is_ptdma_u !$acc end host_data !$acc kernels default(present) async(1) !$OMP PARALLEL WORKSHARE - p(1:n(1),1:n(2),1:n(3)) = pz(1:n(1),1:n(2),1:n(3))*normfft + p(1:n(1),1:n(2),1:n(3)) = pz(1:n(1),1:n(2),1:n(3)) !$OMP END PARALLEL WORKSHARE !$acc end kernels end select end subroutine solver_gpu ! - subroutine gaussel_gpu(nx,ny,n,nh,a,b,c,is_periodic,p,d,p2,lambdaxy) + subroutine gaussel_gpu(nx,ny,n,nh,a,b,c,is_periodic,norm,p,d,p2,lambdaxy) use mod_param, only: eps implicit none integer , intent(in) :: nx,ny,n,nh real(rp), intent(in), dimension(:) :: a,b,c logical , intent(in) :: is_periodic + real(rp), intent(in) :: norm real(rp), intent(inout), dimension(1-nh:,1-nh:,1-nh:) :: p real(rp), dimension(nx,ny,n) :: d,p2 real(rp), intent(in), dimension(:,:), optional :: lambdaxy @@ -249,11 +253,11 @@ subroutine gaussel_gpu(nx,ny,n,nh,a,b,c,is_periodic,p,d,p2,lambdaxy) ! z = 1./(b(1)+lxy+eps) d(i,j,1) = c(1)*z - p(i,j,1) = p(i,j,1)*z + p(i,j,1) = p(i,j,1)*norm*z !$acc loop seq do k=2,nn z = 1./(b(k)+lxy-a(k)*d(i,j,k-1)+eps) - p(i,j,k) = (p(i,j,k)-a(k)*p(i,j,k-1))*z + p(i,j,k) = (p(i,j,k)*norm-a(k)*p(i,j,k-1))*z d(i,j,k) = c(k)*z end do ! @@ -291,8 +295,8 @@ subroutine gaussel_gpu(nx,ny,n,nh,a,b,c,is_periodic,p,d,p2,lambdaxy) p2(i,j,k) = p2(i,j,k) - d(i,j,k)*p2(i,j,k+1) end do ! - p(i,j,nn+1) = (p(i,j,nn+1) - c(nn+1)*p( i,j,1) - a(nn+1)*p( i,j,nn)) / & - (b( nn+1) + lxy + c(nn+1)*p2(i,j,1) + a(nn+1)*p2(i,j,nn)+eps) + p(i,j,nn+1) = (p(i,j,nn+1)*norm - c(nn+1)*p( i,j,1) - a(nn+1)*p( i,j,nn)) / & + (b( nn+1) + lxy + c(nn+1)*p2(i,j,1) + a(nn+1)*p2(i,j,nn)+eps) !$acc loop seq do k=1,nn p(i,j,k) = p(i,j,k) + p2(i,j,k)*p(i,j,nn+1) @@ -310,11 +314,11 @@ subroutine gaussel_gpu(nx,ny,n,nh,a,b,c,is_periodic,p,d,p2,lambdaxy) do i=1,nx z = 1./(b(1)+eps) dd(1) = c(1)*z - p(i,j,1) = p(i,j,1)*z + p(i,j,1) = p(i,j,1)*norm*z !$acc loop seq do k=2,nn z = 1./(b(k)-a(k)*dd(k-1)+eps) - p(i,j,k) = (p(i,j,k)-a(k)*p(i,j,k-1))*z + p(i,j,k) = (p(i,j,k)*norm-a(k)*p(i,j,k-1))*z dd(k) = c(k)*z end do ! @@ -354,8 +358,8 @@ subroutine gaussel_gpu(nx,ny,n,nh,a,b,c,is_periodic,p,d,p2,lambdaxy) pp2(k) = pp2(k) - dd(k)*pp2(k+1) end do ! - p(i,j,nn+1) = (p(i,j,nn+1) - c(nn+1)*p( i,j,1) - a(nn+1)*p( i,j,nn)) / & - (b( nn+1) + lxy + c(nn+1)*pp2( 1) + a(nn+1)*pp2( nn)+eps) + p(i,j,nn+1) = (p(i,j,nn+1)*norm - c(nn+1)*p( i,j,1) - a(nn+1)*p( i,j,nn)) / & + (b( nn+1) + lxy + c(nn+1)*pp2( 1) + a(nn+1)*pp2( nn)+eps) !$acc loop seq do k=1,nn-1 p(i,j,k) = p(i,j,k) + pp2(k)*p(i,j,nn+1) @@ -366,7 +370,7 @@ subroutine gaussel_gpu(nx,ny,n,nh,a,b,c,is_periodic,p,d,p2,lambdaxy) end if end subroutine gaussel_gpu ! - subroutine gaussel_ptdma_gpu(nx,ny,n,lo,nh,a,b,c,is_periodic,p,aa,cc,is_update,lambdaxy,aa_z_save,cc_z_save) + subroutine gaussel_ptdma_gpu(nx,ny,n,lo,nh,a,b,c,is_periodic,norm,p,aa,cc,is_update,lambdaxy,aa_z_save,cc_z_save) ! ! distributed TDMA solver ! @@ -377,6 +381,7 @@ subroutine gaussel_ptdma_gpu(nx,ny,n,lo,nh,a,b,c,is_periodic,p,aa,cc,is_update,l integer , intent(in) :: nx,ny,n,lo,nh real(rp), intent(in), dimension(:) :: a,b,c logical , intent(in) :: is_periodic + real(rp), intent(in) :: norm real(rp), intent(inout), dimension(1-nh:,1-nh:,1-nh:) :: p real(rp), dimension(nx,ny,n) :: aa,cc logical , intent(inout), optional :: is_update @@ -421,15 +426,15 @@ subroutine gaussel_ptdma_gpu(nx,ny,n,lo,nh,a,b,c,is_periodic,p,aa,cc,is_update,l aa(i,j,2) = a(2+dk_g)*z2 cc(i,j,1) = c(1+dk_g)*z1 cc(i,j,2) = c(2+dk_g)*z2 - p(i,j,1) = p(i,j,1)*z1 - p(i,j,2) = p(i,j,2)*z2 + p(i,j,1) = p(i,j,1)*norm*z1 + p(i,j,2) = p(i,j,2)*norm*z2 ! ! elimination of lower diagonals ! !$acc loop seq do k=3,n z = 1./(b(k+dk_g)+lxy-a(k+dk_g)*cc(i,j,k-1)+eps) - p(i,j,k) = (p(i,j,k)-a(k+dk_g)*p(i,j,k-1))*z + p(i,j,k) = (p(i,j,k)*norm-a(k+dk_g)*p(i,j,k-1))*z aa(i,j,k) = -a(k+dk_g)*aa(i,j,k-1)*z cc(i,j,k) = c(k+dk_g)*z end do @@ -467,15 +472,15 @@ subroutine gaussel_ptdma_gpu(nx,ny,n,lo,nh,a,b,c,is_periodic,p,aa,cc,is_update,l aa(i,j,2) = a(2+dk_g)*z2 cc(i,j,1) = c(1+dk_g)*z1 cc(i,j,2) = c(2+dk_g)*z2 - p(i,j,1) = p(i,j,1)*z1 - p(i,j,2) = p(i,j,2)*z2 + p(i,j,1) = p(i,j,1)*norm*z1 + p(i,j,2) = p(i,j,2)*norm*z2 ! ! elimination of lower diagonals ! !$acc loop seq do k=3,n z = 1./(b(k+dk_g)-a(k+dk_g)*cc(i,j,k-1)+eps) - p(i,j,k) = (p(i,j,k)-a(k+dk_g)*p(i,j,k-1))*z + p(i,j,k) = (p(i,j,k)*norm-a(k+dk_g)*p(i,j,k-1))*z aa(i,j,k) = -a(k+dk_g)*aa(i,j,k-1)*z cc(i,j,k) = c(k+dk_g)*z end do @@ -607,7 +612,7 @@ subroutine gaussel_ptdma_gpu(nx,ny,n,lo,nh,a,b,c,is_periodic,p,aa,cc,is_update,l end do end subroutine gaussel_ptdma_gpu ! - subroutine gaussel_ptdma_gpu_fast_1d(nx,ny,n,lo,nh,a_g,b_g,c_g,is_periodic,p) + subroutine gaussel_ptdma_gpu_fast_1d(nx,ny,n,lo,nh,a_g,b_g,c_g,is_periodic,norm,p) ! ! distributed TDMA solver for many 1D systems on GPUs ! @@ -623,6 +628,7 @@ subroutine gaussel_ptdma_gpu_fast_1d(nx,ny,n,lo,nh,a_g,b_g,c_g,is_periodic,p) integer , intent(in) :: nx,ny,n,lo,nh real(rp), intent(in), dimension(:) :: a_g,b_g,c_g logical , intent(in) :: is_periodic + real(rp), intent(in) :: norm real(rp), intent(inout), dimension(1-nh:,1-nh:,1-nh:) :: p real(rp), pointer , dimension(:,:,:) :: pp_x,pp_y real(rp), allocatable, dimension(:,:,:), save :: pp_z @@ -759,13 +765,15 @@ subroutine gaussel_ptdma_gpu_fast_1d(nx,ny,n,lo,nh,a_g,b_g,c_g,is_periodic,p) !$acc parallel loop gang vector collapse(2) default(present) async(1) do j=1,ny do i=1,nx + p(i,j,1) = p(i,j,1)*norm + p(i,j,2) = p(i,j,2)*norm !$acc loop seq do k=3,n - p(i,j,k) = p(i,j,k) - a_g(k+dk_g)*bb(k-1)*p(i,j,k-1) + p(i,j,k) = p(i,j,k)*norm - a_g(k+dk_g)*bb(k-1)*p(i,j,k-1) end do !$acc loop seq do k=n-2,1,-1 - p(i,j,k) = p(i,j,k) - c_g(k+dk_g)*bb(k+1)*p(i,j,k+1) + p(i,j,k) = p(i,j,k) - c_g(k+dk_g)*bb(k+1)*p(i,j,k+1) end do end do end do @@ -835,13 +843,14 @@ subroutine gaussel_ptdma_gpu_fast_1d(nx,ny,n,lo,nh,a_g,b_g,c_g,is_periodic,p) end do end do end subroutine gaussel_ptdma_gpu_fast_1d - subroutine solver_gaussel_z_gpu(n,ng,hi,a,b,c,bcz,c_or_f,p) + subroutine solver_gaussel_z_gpu(n,ng,hi,a,b,c,bcz,c_or_f,norm,p) use mod_param, only: eps implicit none integer , intent(in), dimension(3) :: ng,n,hi real(rp), intent(in), dimension(:) :: a,b,c character(len=1), dimension(0:1), intent(in) :: bcz character(len=1), intent(in), dimension(3) :: c_or_f + real(rp), intent(in) :: norm real(rp), intent(inout), dimension(0:,0:,0:) :: p real(rp), pointer, contiguous, dimension(:,:,:) :: px,py,pz integer :: q @@ -909,12 +918,12 @@ subroutine solver_gaussel_z_gpu(n,ng,hi,a,b,c,bcz,c_or_f,p) is_periodic_z = bcz(0)//bcz(1) == 'PP' if(ipencil_axis /= 3) then if(.not.is_poisson_pcr_tdma) then - call gaussel_gpu(n_z_0(1),n_z_0(2),n_z_0(3)-q,0,a,b,c,is_periodic_z,pz,work,pz_aux_1) + call gaussel_gpu(n_z_0(1),n_z_0(2),n_z_0(3)-q,0,a,b,c,is_periodic_z,norm,pz,work,pz_aux_1) else - call gaussel_ptdma_gpu_fast_1d(n(1),n(2),n(3)-q,lo_z,1,a,b,c,is_periodic_z,p) + call gaussel_ptdma_gpu_fast_1d(n(1),n(2),n(3)-q,lo_z,1,a,b,c,is_periodic_z,norm,p) end if else - call gaussel_gpu(n(1),n(2),n(3)-q,1,a,b,c,is_periodic_z,p,work,pz_aux_1) + call gaussel_gpu(n(1),n(2),n(3)-q,1,a,b,c,is_periodic_z,norm,p,work,pz_aux_1) end if ! if(.not.is_poisson_pcr_tdma .and. .not.is_no_decomp_z) then @@ -953,7 +962,7 @@ subroutine solver_gaussel_z_gpu(n,ng,hi,a,b,c,bcz,c_or_f,p) end if end subroutine solver_gaussel_z_gpu #if 0 - subroutine gaussel_ptdma_gpu_fast(nx,ny,n,lo,nh,a_g,b_g,c_g,is_periodic,p,is_update,lambdaxy,aa,bb,cc,aa_z,bb_z,cc_z,pp_z_2) + subroutine gaussel_ptdma_gpu_fast(nx,ny,n,lo,nh,a_g,b_g,c_g,is_periodic,norm,p,is_update,lambdaxy,aa,bb,cc,aa_z,bb_z,cc_z,pp_z_2) ! ! distributed TDMA solver using pre-computed coefficients ! @@ -966,6 +975,7 @@ subroutine gaussel_ptdma_gpu_fast(nx,ny,n,lo,nh,a_g,b_g,c_g,is_periodic,p,is_upd integer , intent(in) :: nx,ny,n,lo,nh real(rp), intent(in), dimension(:) :: a_g,b_g,c_g logical , intent(in) :: is_periodic + real(rp), intent(in) :: norm real(rp), intent(inout), dimension(1-nh:,1-nh:,1-nh:) :: p logical , intent(inout) :: is_update real(rp), intent(in), dimension(:,:), optional :: lambdaxy @@ -1107,13 +1117,15 @@ subroutine gaussel_ptdma_gpu_fast(nx,ny,n,lo,nh,a_g,b_g,c_g,is_periodic,p,is_upd !$acc parallel loop gang vector collapse(2) default(present) async(1) do j=1,ny do i=1,nx + p(i,j,1) = p(i,j,1)*norm + p(i,j,2) = p(i,j,2)*norm !$acc loop seq do k=3,n - p(i,j,k) = p(i,j,k) - a_g(k+dk_g)*bb(i,j,k-1)*p(i,j,k-1) + p(i,j,k) = p(i,j,k)*norm - a_g(k+dk_g)*bb(i,j,k-1)*p(i,j,k-1) end do !$acc loop seq do k=n-2,1,-1 - p(i,j,k) = p(i,j,k) - c_g(k+dk_g)*bb(i,j,k+1)*p(i,j,k+1) + p(i,j,k) = p(i,j,k) - c_g(k+dk_g)*bb(i,j,k+1)*p(i,j,k+1) end do end do end do