Skip to content

Commit

Permalink
Allow for applying a constant scaling factor in the TDMA/PCR-TDMA solver.
Browse files Browse the repository at this point in the history
  • Loading branch information
p-costa committed Feb 16, 2025
1 parent 7a6e6d2 commit eaef752
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 100 deletions.
55 changes: 12 additions & 43 deletions src/solve_helmholtz.f90
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@ module mod_solve_helmholtz
end type rhs_bound
public solve_helmholtz,rhs_bound
contains
subroutine solve_helmholtz(n,ng,hi,arrplan,normfft,alpha, &
lambdaxyi,ai,bi,ci,rhsbxi,rhsbyi,rhsbzi,is_bound,cbc,c_or_f,p)
subroutine solve_helmholtz(n,ng,hi,arrplan,normfft,alpha,lambdaxy,a,b,c,rhsbx,rhsby,rhsbz,is_bound,cbc,c_or_f,p)
!
! this is a wrapper subroutine to solve 1D/3D helmholtz problems: p/alpha + lap(p) = rhs
!
Expand All @@ -38,67 +37,37 @@ subroutine solve_helmholtz(n,ng,hi,arrplan,normfft,alpha, &
#endif
real(rp), intent(in ), optional :: normfft
real(rp), intent(in ) :: alpha
real(rp), intent(in ), dimension(:,:), optional :: lambdaxyi
real(rp), intent(in ), dimension(:) :: ai,bi,ci
real(rp), intent(in ), dimension(:,:,0:), optional :: rhsbxi,rhsbyi,rhsbzi
real(rp), intent(in ), dimension(:,:), optional :: lambdaxy
real(rp), intent(in ), dimension(:) :: a,b,c
real(rp), intent(in ), dimension(:,:,0:), optional :: rhsbx,rhsby,rhsbz
logical , intent(in ), dimension(2,3) :: is_bound
character(len=1), intent(in), dimension(0:1,3) :: cbc
character(len=1), intent(in), dimension(3) :: c_or_f
real(rp), intent(inout), dimension(:,:,:) :: p
real(rp), allocatable, dimension(:,:) , save :: lambdaxy
real(rp), allocatable, dimension(:) , save :: a,b,c
real(rp), allocatable, dimension(:,:,:), save :: rhsbx,rhsby,rhsbz
real(rp), allocatable, dimension(:), save :: bb
real(rp) :: alphai
!
logical, save :: is_first = .true.
!
! initialization
!
if(is_first) then ! leverage save attribute to allocate these arrays on the device only once
is_first = .false.
if(present(lambdaxyi)) allocate(lambdaxy,mold=lambdaxyi)
allocate(a,mold=ai)
allocate(b,mold=bi)
allocate(c,mold=ci)
if(present(rhsbxi)) allocate(rhsbx(n(2),n(3),0:1)) ! allocate(rhsbx,mold=rhsbxi) ! gfortran 11.4.0 bug
if(present(rhsbyi)) allocate(rhsby(n(1),n(3),0:1)) ! allocate(rhsby,mold=rhsbyi) ! gfortran 11.4.0 bug
if(present(rhsbzi)) allocate(rhsbz(n(1),n(2),0:1)) ! allocate(rhsbz,mold=rhsbzi) ! gfortran 11.4.0 bug
!$acc enter data create(lambdaxy,a,b,c,rhsbx,rhsby,rhsbz) async(1)
allocate(bb,mold=b)
!$acc enter data create(bb) async(1)
end if
!
if(.not.is_impdiff_1d) then
!$acc kernels default(present) async(1)
!$OMP PARALLEL WORKSHARE
rhsbx(:,:,0:1) = rhsbxi(:,:,0:1)*alpha
rhsby(:,:,0:1) = rhsbyi(:,:,0:1)*alpha
rhsbz(:,:,0:1) = rhsbzi(:,:,0:1)*alpha
!$OMP END PARALLEL WORKSHARE
!$acc end kernels
else
!$acc kernels default(present) async(1)
!$OMP PARALLEL WORKSHARE
rhsbz(:,:,0:1) = rhsbzi(:,:,0:1)*alpha
!$OMP END PARALLEL WORKSHARE
!$acc end kernels
end if
call updt_rhs_b(c_or_f,cbc,n,is_bound,rhsbx,rhsby,rhsbz,p)
alphai = alpha**(-1)
!$acc kernels default(present) async(1)
!$OMP PARALLEL WORKSHARE
a(:) = ai(:)*alpha
b(:) = bi(:)*alpha + 1.
c(:) = ci(:)*alpha
bb(:) = b(:) + alphai
!$OMP END PARALLEL WORKSHARE
!$acc end kernels
if(.not.is_impdiff_1d) then
!$acc kernels default(present) async(1)
!$OMP PARALLEL WORKSHARE
lambdaxy(:,:) = lambdaxyi(:,:)*alpha
!$OMP END PARALLEL WORKSHARE
!$acc end kernels
end if
if(.not.is_impdiff_1d) then
call solver(n,ng,arrplan,normfft,lambdaxy,a,b,c,cbc,c_or_f,p)
call solver(n,ng,arrplan,normfft*alphai,lambdaxy,a,bb,c,cbc,c_or_f,p)
else
call solver_gaussel_z(n,ng,hi,a,b,c,cbc(:,3),c_or_f,p)
call solver_gaussel_z(n,ng,hi,a,bb,c,cbc(:,3),c_or_f,alphai,p)
end if
end subroutine solve_helmholtz
end module mod_solve_helmholtz
59 changes: 33 additions & 26 deletions src/solver.f90
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ subroutine solver(n,ng,arrplan,normfft,lambdaxy,a,b,c,bc,c_or_f,p,is_ptdma_updat
logical :: is_periodic_z
integer, dimension(3) :: n_z,hi_z
logical :: is_ptdma_update_
real(rp) :: norm
!
norm = normfft
!
is_ptdma_update_ = .true.
if(present(is_ptdma_update)) is_ptdma_update_ = is_ptdma_update
Expand Down Expand Up @@ -76,11 +79,11 @@ subroutine solver(n,ng,arrplan,normfft,lambdaxy,a,b,c,bc,c_or_f,p,is_ptdma_updat
if(.not.is_poisson_pcr_tdma) then
call transpose_y_to_z(py,pz)
!
call gaussel(n_z(1),n_z(2),n_z(3)-q,0,a,b,c,is_periodic_z,pz,lambdaxy)
call gaussel(n_z(1),n_z(2),n_z(3)-q,0,a,b,c,is_periodic_z,norm,pz,lambdaxy)
!
call transpose_z_to_y(pz,py)
else
call gaussel_ptdma(n_z(1),n_z(2),n_z(3)-q,0,a,b,c,is_periodic_z,py,lambdaxy,is_ptdma_update_,aa_z,cc_z)
call gaussel_ptdma(n_z(1),n_z(2),n_z(3)-q,0,a,b,c,is_periodic_z,norm,py,lambdaxy,is_ptdma_update_,aa_z,cc_z)
if(present(is_ptdma_update)) is_ptdma_update = is_ptdma_update_
end if
call fft(arrplan(2,2),py) ! bwd transform in y
Expand All @@ -91,29 +94,30 @@ subroutine solver(n,ng,arrplan,normfft,lambdaxy,a,b,c,bc,c_or_f,p,is_ptdma_updat
select case(ipencil_axis)
case(1)
!$OMP PARALLEL WORKSHARE
p(1:n(1),1:n(2),1:n(3)) = px(:,:,:)*normfft
p(1:n(1),1:n(2),1:n(3)) = px(:,:,:)
!$OMP END PARALLEL WORKSHARE
case(2)
call transpose_x_to_y(px,py)
!$OMP PARALLEL WORKSHARE
p(1:n(1),1:n(2),1:n(3)) = py(:,:,:)*normfft
p(1:n(1),1:n(2),1:n(3)) = py(:,:,:)
!$OMP END PARALLEL WORKSHARE
case(3)
!call transpose_x_to_z(px,pz)
call transpose_x_to_y(px,py)
call transpose_y_to_z(py,pz)
!$OMP PARALLEL WORKSHARE
p(1:n(1),1:n(2),1:n(3)) = pz(:,:,:)*normfft
p(1:n(1),1:n(2),1:n(3)) = pz(:,:,:)
!$OMP END PARALLEL WORKSHARE
end select
end subroutine solver
!
subroutine gaussel(nx,ny,n,nh,a,b,c,is_periodic,p,lambdaxy)
subroutine gaussel(nx,ny,n,nh,a,b,c,is_periodic,norm,p,lambdaxy)
use mod_param, only: eps
implicit none
integer , intent(in) :: nx,ny,n,nh
real(rp), intent(in), dimension(:) :: a,b,c
logical , intent(in) :: is_periodic
real(rp), intent(in) :: norm
real(rp), intent(inout), dimension(1-nh:,1-nh:,1-nh:) :: p
real(rp), intent(in), dimension(nx,ny), optional :: lambdaxy
real(rp), dimension(n) :: bb,p2
Expand All @@ -129,7 +133,7 @@ subroutine gaussel(nx,ny,n,nh,a,b,c,is_periodic,p,lambdaxy)
do j=1,ny
do i=1,nx
bb(:) = b(:) + lambdaxy(i,j)
call dgtsv_homebrewed(nn,a,bb,c,p(i,j,1:nn))
call dgtsv_homebrewed(nn,a,bb,c,norm,p(i,j,1:nn))
end do
end do
!$OMP END PARALLEL
Expand All @@ -138,7 +142,7 @@ subroutine gaussel(nx,ny,n,nh,a,b,c,is_periodic,p,lambdaxy)
!$OMP DO COLLAPSE(2)
do j=1,ny
do i=1,nx
call dgtsv_homebrewed(nn,a,b,c,p(i,j,1:nn))
call dgtsv_homebrewed(nn,a,b,c,norm,p(i,j,1:nn))
end do
end do
!$OMP END PARALLEL
Expand All @@ -153,9 +157,9 @@ subroutine gaussel(nx,ny,n,nh,a,b,c,is_periodic,p,lambdaxy)
p2(1 ) = -a(1 )
p2(nn) = p2(nn) - c(nn)
bb(:) = b(:) + lambdaxy(i,j)
call dgtsv_homebrewed(nn,a,bb,c,p2(1:nn))
p(i,j,nn+1) = (p(i,j,nn+1) - c(nn+1)*p(i,j,1) - a(nn+1)*p(i,j,nn)) / &
(bb( nn+1) + c(nn+1)*p2( 1) + a(nn+1)*p2( nn)+eps)
call dgtsv_homebrewed(nn,a,bb,c,1._rp,p2(1:nn))
p(i,j,nn+1) = (p(i,j,nn+1)*norm - c(nn+1)*p(i,j,1) - a(nn+1)*p(i,j,nn)) / &
(bb( nn+1) + c(nn+1)*p2( 1) + a(nn+1)*p2( nn)+eps)
p(i,j,1:nn) = p(i,j,1:nn) + p2(1:nn)*p(i,j,nn+1)
end do
end do
Expand All @@ -168,9 +172,9 @@ subroutine gaussel(nx,ny,n,nh,a,b,c,is_periodic,p,lambdaxy)
p2(:) = 0.
p2(1 ) = -a(1 )
p2(nn) = p2(nn) - c(nn)
call dgtsv_homebrewed(nn,a,b,c,p2(1:nn))
p(i,j,nn+1) = (p(i,j,nn+1) - c(nn+1)*p(i,j,1) - a(nn+1)*p(i,j,nn)) / &
(b( nn+1) + c(nn+1)*p2( 1) + a(nn+1)*p2( nn)+eps)
call dgtsv_homebrewed(nn,a,b,c,1._rp,p2(1:nn))
p(i,j,nn+1) = (p(i,j,nn+1)*norm - c(nn+1)*p(i,j,1) - a(nn+1)*p(i,j,nn)) / &
(b( nn+1) + c(nn+1)*p2( 1) + a(nn+1)*p2( nn)+eps)
p(i,j,1:nn) = p(i,j,1:nn) + p2(1:nn)*p(i,j,nn+1)
end do
end do
Expand All @@ -179,7 +183,7 @@ subroutine gaussel(nx,ny,n,nh,a,b,c,is_periodic,p,lambdaxy)
end if
end subroutine gaussel
!
subroutine gaussel_ptdma(nx,ny,n,nh,a,b,c,is_periodic,p,lambdaxy,is_update,aa_z_save,cc_z_save)
subroutine gaussel_ptdma(nx,ny,n,nh,a,b,c,is_periodic,norm,p,lambdaxy,is_update,aa_z_save,cc_z_save)
!
! distributed TDMA solver
!
Expand All @@ -190,6 +194,7 @@ subroutine gaussel_ptdma(nx,ny,n,nh,a,b,c,is_periodic,p,lambdaxy,is_update,aa_z_
integer , intent(in) :: nx,ny,n,nh
real(rp), intent(in), dimension(:) :: a,b,c
logical , intent(in) :: is_periodic
real(rp), intent(in) :: norm
real(rp), intent(inout), dimension(1-nh:,1-nh:,1-nh:) :: p
real(rp), intent(in), dimension(:,:), optional :: lambdaxy
logical , intent(inout), optional :: is_update
Expand Down Expand Up @@ -227,13 +232,13 @@ subroutine gaussel_ptdma(nx,ny,n,nh,a,b,c,is_periodic,p,lambdaxy,is_update,aa_z_
zz(:) = 1./(bb(1:2)+eps)
aa(i,j,1:2) = a(1:2)*zz(:)
cc(i,j,1:2) = c(1:2)*zz(:)
p( i,j,1:2) = p(i,j,1:2)*zz(:)
p( i,j,1:2) = p(i,j,1:2)*norm*zz(:)
!
! elimination of lower diagonals
!
do k=3,n
z = 1./(bb(k)-a(k)*cc(i,j,k-1)+eps)
p(i,j,k) = (p(i,j,k)-a(k)*p(i,j,k-1))*z
p(i,j,k) = (p(i,j,k)*norm-a(k)*p(i,j,k-1))*z
aa(i,j,k) = -a(k)*aa(i,j,k-1)*z
cc(i,j,k) = c(k)*z
end do
Expand Down Expand Up @@ -266,13 +271,13 @@ subroutine gaussel_ptdma(nx,ny,n,nh,a,b,c,is_periodic,p,lambdaxy,is_update,aa_z_
zz(:) = 1./(b(1:2)+eps)
aa(i,j,1:2) = a(1:2)*zz(:)
cc(i,j,1:2) = c(1:2)*zz(:)
p( i,j,1:2) = p(i,j,1:2)*zz(:)
p( i,j,1:2) = p(i,j,1:2)*norm*zz(:)
!
! elimination of lower diagonals
!
do k=3,n
z = 1./(b(k)-a(k)*cc(i,j,k-1)+eps)
p(i,j,k) = (p(i,j,k)-a(k)*p(i,j,k-1))*z
p(i,j,k) = (p(i,j,k)*norm-a(k)*p(i,j,k-1))*z
aa(i,j,k) = -a(k)*aa(i,j,k-1)*z
cc(i,j,k) = c(k)*z
end do
Expand Down Expand Up @@ -389,11 +394,12 @@ subroutine gaussel_ptdma(nx,ny,n,nh,a,b,c,is_periodic,p,lambdaxy,is_update,aa_z_
!$OMP END PARALLEL
end subroutine gaussel_ptdma
!
subroutine dgtsv_homebrewed(n,a,b,c,p)
subroutine dgtsv_homebrewed(n,a,b,c,norm,p)
use mod_param, only: eps
implicit none
integer , intent(in) :: n
real(rp), intent(in ), dimension(:) :: a,b,c
real(rp), intent(in ) :: norm
real(rp), intent(inout), dimension(:) :: p
real(rp), dimension(n) :: d
real(rp) :: z
Expand All @@ -403,11 +409,11 @@ subroutine dgtsv_homebrewed(n,a,b,c,p)
!
z = 1./(b(1)+eps)
d(1) = c(1)*z
p(1) = p(1)*z
p(1) = p(1)*norm*z
do l=2,n
z = 1./(b(l)-a(l)*d(l-1)+eps)
d(l) = c(l)*z
p(l) = (p(l)-a(l)*p(l-1))*z
p(l) = (p(l)*norm-a(l)*p(l-1))*z
end do
!
! backward substitution
Expand All @@ -417,12 +423,13 @@ subroutine dgtsv_homebrewed(n,a,b,c,p)
end do
end subroutine dgtsv_homebrewed
!
subroutine solver_gaussel_z(n,ng,hi,a,b,c,bcz,c_or_f,p)
subroutine solver_gaussel_z(n,ng,hi,a,b,c,bcz,c_or_f,norm,p)
implicit none
integer , intent(in), dimension(3) :: n,ng,hi
real(rp), intent(in), dimension(:) :: a,b,c
character(len=1), dimension(0:1), intent(in) :: bcz
character(len=1), intent(in), dimension(3) :: c_or_f
real(rp), intent(in) :: norm
real(rp), intent(inout), dimension(0:,0:,0:) :: p
real(rp), allocatable, dimension(:,:,:) :: px,py,pz
integer :: q
Expand Down Expand Up @@ -461,12 +468,12 @@ subroutine solver_gaussel_z(n,ng,hi,a,b,c,bcz,c_or_f,p)
is_periodic_z = bcz(0)//bcz(1) == 'PP'
if(.not.is_no_decomp_z) then
if(.not.is_poisson_pcr_tdma) then
call gaussel( n_z(1),n_z(2),n_z(3)-q,0,a,b,c,is_periodic_z,pz)
call gaussel( n_z(1),n_z(2),n_z(3)-q,0,a,b,c,is_periodic_z,norm,pz)
else
call gaussel_ptdma(n_z(1),n_z(2),n_z(3)-q,1,a,b,c,is_periodic_z,p)
call gaussel_ptdma(n_z(1),n_z(2),n_z(3)-q,1,a,b,c,is_periodic_z,norm,p)
end if
else
call gaussel(n(1),n(2),n(3)-q,1,a,b,c,is_periodic_z,p)
call gaussel(n(1),n(2),n(3)-q,1,a,b,c,is_periodic_z,norm,p)
end if
!
if(.not.is_poisson_pcr_tdma .and. .not.is_no_decomp_z) then
Expand Down
Loading

0 comments on commit eaef752

Please sign in to comment.