Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reduce memory use in transeq for both backends #130

Merged
merged 3 commits into from
Dec 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 10 additions & 13 deletions src/cuda/backend.f90
Original file line number Diff line number Diff line change
Expand Up @@ -288,16 +288,17 @@ subroutine transeq_halo_exchange(self, u_dev, v_dev, w_dev, dir)

end subroutine transeq_halo_exchange

subroutine transeq_dist_component(self, rhs_dev, u_dev, conv_dev, &
subroutine transeq_dist_component(self, rhs_du_dev, u_dev, conv_dev, &
u_recv_s_dev, u_recv_e_dev, &
conv_recv_s_dev, conv_recv_e_dev, &
tdsops_du, tdsops_dud, tdsops_d2u, &
dir, blocks, threads)
!! Computes RHS_x^u following:
!!
!! rhs_x^u = -0.5*(conv*du/dx + d(u*conv)/dx) + nu*d2u/dx2
!! Computes RHS_x^u following:
!!
!! rhs_x^u = -0.5*(conv*du/dx + d(u*conv)/dx) + nu*d2u/dx2
class(cuda_backend_t) :: self
real(dp), device, dimension(:, :, :), intent(inout) :: rhs_dev
!> The result field, it is also used as temporary storage
real(dp), device, dimension(:, :, :), intent(out) :: rhs_du_dev
real(dp), device, dimension(:, :, :), intent(in) :: u_dev, conv_dev
real(dp), device, dimension(:, :, :), intent(in) :: &
u_recv_s_dev, u_recv_e_dev, &
Expand All @@ -306,25 +307,22 @@ subroutine transeq_dist_component(self, rhs_dev, u_dev, conv_dev, &
integer, intent(in) :: dir
type(dim3), intent(in) :: blocks, threads

class(field_t), pointer :: du, dud, d2u
class(field_t), pointer :: dud, d2u

real(dp), device, pointer, dimension(:, :, :) :: &
du_dev, dud_dev, d2u_dev
real(dp), device, pointer, dimension(:, :, :) :: dud_dev, d2u_dev

! Get some fields for storing the intermediate results
du => self%allocator%get_block(dir, VERT)
dud => self%allocator%get_block(dir, VERT)
d2u => self%allocator%get_block(dir, VERT)

call resolve_field_t(du_dev, du)
call resolve_field_t(dud_dev, dud)
call resolve_field_t(d2u_dev, d2u)

call exec_dist_transeq_3fused( &
rhs_dev, &
rhs_du_dev, &
u_dev, u_recv_s_dev, u_recv_e_dev, &
conv_dev, conv_recv_s_dev, conv_recv_e_dev, &
du_dev, dud_dev, d2u_dev, &
dud_dev, d2u_dev, &
self%du_send_s_dev, self%du_send_e_dev, &
self%du_recv_s_dev, self%du_recv_e_dev, &
self%dud_send_s_dev, self%dud_send_e_dev, &
Expand All @@ -337,7 +335,6 @@ subroutine transeq_dist_component(self, rhs_dev, u_dev, conv_dev, &
)

! Release temporary blocks
call self%allocator%release_block(du)
call self%allocator%release_block(dud)
call self%allocator%release_block(d2u)

Expand Down
15 changes: 8 additions & 7 deletions src/cuda/exec_dist.f90
Original file line number Diff line number Diff line change
Expand Up @@ -56,24 +56,25 @@ subroutine exec_dist_tds_compact( &
end subroutine exec_dist_tds_compact

subroutine exec_dist_transeq_3fused( &
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a comment explaining that r_du is where the output is

r_u, u, u_recv_s, u_recv_e, v, v_recv_s, v_recv_e, &
du, dud, d2u, &
r_du, u, u_recv_s, u_recv_e, v, v_recv_s, v_recv_e, &
dud, d2u, &
du_send_s, du_send_e, du_recv_s, du_recv_e, &
dud_send_s, dud_send_e, dud_recv_s, dud_recv_e, &
d2u_send_s, d2u_send_e, d2u_recv_s, d2u_recv_e, &
der1st, der2nd, nu, nproc, pprev, pnext, blocks, threads &
)
implicit none

! r_u = -1/2*(v*d1(u) + d1(u*v)) + nu*d2(u)
real(dp), device, dimension(:, :, :), intent(out) :: r_u
! r_du = -1/2*(v*d1(u) + d1(u*v)) + nu*d2(u)
pbartholomew08 marked this conversation as resolved.
Show resolved Hide resolved
!> The result array, it is also used as temporary storage
real(dp), device, dimension(:, :, :), intent(out) :: r_du
real(dp), device, dimension(:, :, :), intent(in) :: u, u_recv_s, u_recv_e
real(dp), device, dimension(:, :, :), intent(in) :: v, v_recv_s, v_recv_e

! The ones below are intent(out) just so that we can write data in them,
! not because we actually need the data they store later where this
! subroutine is called. We absolutely don't care the data they pass back
real(dp), device, dimension(:, :, :), intent(out) :: du, dud, d2u
real(dp), device, dimension(:, :, :), intent(out) :: dud, d2u
real(dp), device, dimension(:, :, :), intent(out) :: &
du_send_s, du_send_e, du_recv_s, du_recv_e, &
dud_send_s, dud_send_e, dud_recv_s, dud_recv_e, &
Expand All @@ -89,7 +90,7 @@ subroutine exec_dist_transeq_3fused( &
n_data = SZ*1*blocks%x

call transeq_3fused_dist<<<blocks, threads>>>( & !&
du, dud, d2u, &
r_du, dud, d2u, &
du_send_s, du_send_e, &
dud_send_s, dud_send_e, &
d2u_send_s, d2u_send_e, &
Expand All @@ -111,7 +112,7 @@ subroutine exec_dist_transeq_3fused( &
)

call transeq_3fused_subs<<<blocks, threads>>>( & !&
r_u, v, du, dud, d2u, &
r_du, v, dud, d2u, &
du_recv_s, du_recv_e, &
dud_recv_s, dud_recv_e, &
d2u_recv_s, d2u_recv_e, &
Expand Down
20 changes: 10 additions & 10 deletions src/cuda/kernels/distributed.f90
Original file line number Diff line number Diff line change
Expand Up @@ -571,15 +571,16 @@ attributes(global) subroutine transeq_3fused_dist( &
end subroutine transeq_3fused_dist

attributes(global) subroutine transeq_3fused_subs( &
r_u, conv, du, dud, d2u, &
r_du, conv, dud, d2u, &
recv_du_s, recv_du_e, recv_dud_s, recv_dud_e, recv_d2u_s, recv_d2u_e, &
d1_sa, d1_sc, d2_sa, d2_sc, n, nu &
)
implicit none

! Arguments
real(dp), device, intent(out), dimension(:, :, :) :: r_u
real(dp), device, intent(in), dimension(:, :, :) :: conv, du, dud, d2u
!> The result array, it stores 'du' first then its overwritten
real(dp), device, intent(inout), dimension(:, :, :) :: r_du
real(dp), device, intent(in), dimension(:, :, :) :: conv, dud, d2u
real(dp), device, intent(in), dimension(:, :, :) :: &
recv_du_s, recv_du_e, recv_dud_s, recv_dud_e, recv_d2u_s, recv_d2u_e
real(dp), device, intent(in), dimension(:) :: d1_sa, d1_sc, d2_sa, d2_sc
Expand Down Expand Up @@ -610,7 +611,7 @@ attributes(global) subroutine transeq_3fused_subs( &
ur = d1_sa(1)
recp = 1._dp/(1._dp - ur*bl)

du_s = recp*(du(i, 1, b) - bl*recv_du_s(i, 1, b))
du_s = recp*(r_du(i, 1, b) - bl*recv_du_s(i, 1, b))
dud_s = recp*(dud(i, 1, b) - bl*recv_dud_s(i, 1, b))

! second derivative
Expand All @@ -627,7 +628,7 @@ attributes(global) subroutine transeq_3fused_subs( &
ur = d1_sc(n)
recp = 1._dp/(1._dp - ur*bl)

du_e = recp*(du(i, n, b) - ur*recv_du_e(i, 1, b))
du_e = recp*(r_du(i, n, b) - ur*recv_du_e(i, 1, b))
dud_e = recp*(dud(i, n, b) - ur*recv_dud_e(i, 1, b))

! second derivative
Expand All @@ -638,15 +639,14 @@ attributes(global) subroutine transeq_3fused_subs( &
d2u_e = recp*(d2u(i, n, b) - ur*recv_d2u_e(i, 1, b))

! final substitution
r_u(i, 1, b) = -0.5_dp*(conv(i, 1, b)*du_s + dud_s) + nu*d2u_s
r_du(i, 1, b) = -0.5_dp*(conv(i, 1, b)*du_s + dud_s) + nu*d2u_s
do j = 2, n - 1
du_temp = (du(i, j, b) - d1_sa(j)*du_s - d1_sc(j)*du_e)
du_temp = (r_du(i, j, b) - d1_sa(j)*du_s - d1_sc(j)*du_e)
dud_temp = (dud(i, j, b) - d1_sa(j)*dud_s - d1_sc(j)*dud_e)
d2u_temp = (d2u(i, j, b) - d2_sa(j)*d2u_s - d2_sc(j)*d2u_e)
r_u(i, j, b) = -0.5_dp*(conv(i, j, b)*du_temp + dud_temp) &
+ nu*d2u_temp
r_du(i, j, b) = -0.5_dp*(conv(i, j, b)*du_temp + dud_temp) + nu*d2u_temp
end do
r_u(i, n, b) = -0.5_dp*(conv(i, n, b)*du_e + dud_e) + nu*d2u_e
r_du(i, n, b) = -0.5_dp*(conv(i, n, b)*du_e + dud_e) + nu*d2u_e

end subroutine transeq_3fused_subs

Expand Down
17 changes: 8 additions & 9 deletions src/omp/backend.f90
Original file line number Diff line number Diff line change
Expand Up @@ -233,30 +233,30 @@ subroutine transeq_halo_exchange(self, u, v, w, dir)

end subroutine transeq_halo_exchange

subroutine transeq_dist_component(self, rhs, u, conv, &
subroutine transeq_dist_component(self, rhs_du, u, conv, &
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same, add a comment explaining that rhs_du is where the output is stored.

u_recv_s, u_recv_e, &
conv_recv_s, conv_recv_e, &
tdsops_du, tdsops_dud, tdsops_d2u, dir)
!! Computes RHS_x^u following:
!!
!! rhs_x^u = -0.5*(conv*du/dx + d(u*conv)/dx) + nu*d2u/dx2
!! Computes RHS_x^u following:
!!
!! rhs_x^u = -0.5*(conv*du/dx + d(u*conv)/dx) + nu*d2u/dx2
class(omp_backend_t) :: self
class(field_t), intent(inout) :: rhs
!> The result field, it is also used as temporary storage
class(field_t), intent(inout) :: rhs_du
class(field_t), intent(in) :: u, conv
real(dp), dimension(:, :, :), intent(in) :: u_recv_s, u_recv_e, &
conv_recv_s, conv_recv_e
class(tdsops_t), intent(in) :: tdsops_du
class(tdsops_t), intent(in) :: tdsops_dud
class(tdsops_t), intent(in) :: tdsops_d2u
integer, intent(in) :: dir
class(field_t), pointer :: du, d2u, dud
class(field_t), pointer :: d2u, dud

du => self%allocator%get_block(dir, VERT)
dud => self%allocator%get_block(dir, VERT)
d2u => self%allocator%get_block(dir, VERT)

call exec_dist_transeq_compact( &
rhs%data, du%data, dud%data, d2u%data, &
rhs_du%data, dud%data, d2u%data, &
self%du_send_s, self%du_send_e, self%du_recv_s, self%du_recv_e, &
self%dud_send_s, self%dud_send_e, self%dud_recv_s, self%dud_recv_e, &
self%d2u_send_s, self%d2u_send_e, self%d2u_recv_s, self%d2u_recv_e, &
Expand All @@ -266,7 +266,6 @@ subroutine transeq_dist_component(self, rhs, u, conv, &
self%mesh%par%nproc_dir(dir), self%mesh%par%pprev(dir), &
self%mesh%par%pnext(dir), self%mesh%get_n_groups(dir))

call self%allocator%release_block(du)
call self%allocator%release_block(dud)
call self%allocator%release_block(d2u)

Expand Down
17 changes: 10 additions & 7 deletions src/omp/exec_dist.f90
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ subroutine exec_dist_tds_compact( &
end subroutine exec_dist_tds_compact

subroutine exec_dist_transeq_compact( &
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe add a comment explaining that rhs_du is where the results are also stored.

rhs, du, dud, d2u, &
rhs_du, dud, d2u, &
du_send_s, du_send_e, du_recv_s, du_recv_e, &
dud_send_s, dud_send_e, dud_recv_s, dud_recv_e, &
d2u_send_s, d2u_send_e, d2u_recv_s, d2u_recv_e, &
Expand All @@ -71,8 +71,10 @@ subroutine exec_dist_transeq_compact( &

implicit none

! du = d(u)
real(dp), dimension(:, :, :), intent(out) :: rhs, du, dud, d2u
!> The result array, it is also used as temporary storage
real(dp), dimension(:, :, :), intent(out) :: rhs_du
!> Temporary storage arrays
real(dp), dimension(:, :, :), intent(out) :: dud, d2u

! The ones below are intent(out) just so that we can write data in them,
! not because we actually need the data they store later where this
Expand Down Expand Up @@ -109,7 +111,7 @@ subroutine exec_dist_transeq_compact( &
!$omp parallel do private(ud, ud_recv_e, ud_recv_s)
do k = 1, n_groups
call der_univ_dist( &
du(:, :, k), du_send_s(:, :, k), du_send_e(:, :, k), u(:, :, k), &
rhs_du(:, :, k), du_send_s(:, :, k), du_send_e(:, :, k), u(:, :, k), &
u_recv_s(:, :, k), u_recv_e(:, :, k), &
tdsops_du%coeffs_s, tdsops_du%coeffs_e, tdsops_du%coeffs, &
n, tdsops_du%dist_fw, tdsops_du%dist_bw, tdsops_du%dist_af &
Expand Down Expand Up @@ -162,7 +164,7 @@ subroutine exec_dist_transeq_compact( &

!$omp parallel do
do k = 1, n_groups
call der_univ_subs(du(:, :, k), &
call der_univ_subs(rhs_du(:, :, k), &
du_recv_s(:, :, k), du_recv_e(:, :, k), &
n, tdsops_du%dist_sa, tdsops_du%dist_sc)

Expand All @@ -177,8 +179,9 @@ subroutine exec_dist_transeq_compact( &
do j = 1, n
!$omp simd
do i = 1, SZ
rhs(i, j, k) = -0.5_dp*(v(i, j, k)*du(i, j, k) + dud(i, j, k)) &
+ nu*d2u(i, j, k)
rhs_du(i, j, k) = -0.5_dp*(v(i, j, k)*rhs_du(i, j, k) &
+ dud(i, j, k)) &
+ nu*d2u(i, j, k)
end do
!$omp end simd
end do
Expand Down
5 changes: 2 additions & 3 deletions tests/cuda/test_cuda_transeq.f90
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ program test_cuda_tridiag
real(dp), allocatable, dimension(:, :, :) :: u, v, r_u
real(dp), device, allocatable, dimension(:, :, :) :: &
u_dev, v_dev, r_u_dev, & ! main fields u, v and result r_u
du_dev, dud_dev, d2u_dev ! intermediate solution arrays
dud_dev, d2u_dev ! intermediate solution arrays
real(dp), device, allocatable, dimension(:, :, :) :: &
du_recv_s_dev, du_recv_e_dev, du_send_s_dev, du_send_e_dev, &
dud_recv_s_dev, dud_recv_e_dev, dud_send_s_dev, dud_send_e_dev, &
Expand Down Expand Up @@ -65,7 +65,6 @@ program test_cuda_tridiag
! field for storing the result
allocate (r_u_dev(SZ, n, n_block))
! intermediate solution fields
allocate (du_dev(SZ, n, n_block))
allocate (dud_dev(SZ, n, n_block))
allocate (d2u_dev(SZ, n, n_block))

Expand Down Expand Up @@ -133,7 +132,7 @@ program test_cuda_tridiag
r_u_dev, &
u_dev, u_recv_s_dev, u_recv_e_dev, &
v_dev, v_recv_s_dev, v_recv_e_dev, &
du_dev, dud_dev, d2u_dev, &
dud_dev, d2u_dev, &
du_send_s_dev, du_send_e_dev, du_recv_s_dev, du_recv_e_dev, &
dud_send_s_dev, dud_send_e_dev, dud_recv_s_dev, dud_recv_e_dev, &
d2u_send_s_dev, d2u_send_e_dev, d2u_recv_s_dev, d2u_recv_e_dev, &
Expand Down
6 changes: 2 additions & 4 deletions tests/omp/test_omp_dist_transeq.f90
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ program test_transeq

logical :: allpass = .true.
real(dp), allocatable, dimension(:, :, :) :: u, v, r_u
real(dp), allocatable, dimension(:, :, :) :: du, dud, d2u ! intermediate solution arrays
real(dp), allocatable, dimension(:, :, :) :: dud, d2u ! intermediate solution arrays
real(dp), allocatable, dimension(:, :, :) :: &
du_recv_s, du_recv_e, du_send_s, du_send_e, &
dud_recv_s, dud_recv_e, dud_send_s, dud_send_e, &
Expand Down Expand Up @@ -52,7 +52,6 @@ program test_transeq
! main input fields
! field for storing the result
! intermediate solution fields
allocate (du(SZ, n, n_block))
allocate (dud(SZ, n, n_block))
allocate (d2u(SZ, n, n_block))

Expand Down Expand Up @@ -108,8 +107,7 @@ program test_transeq
SZ*4*n_block, nproc, pprev, pnext)

call exec_dist_transeq_compact( &
r_u, &
du, dud, d2u, &
r_u, dud, d2u, &
du_send_s, du_send_e, du_recv_s, du_recv_e, &
dud_send_s, dud_send_e, dud_recv_s, dud_recv_e, &
d2u_send_s, d2u_send_e, d2u_recv_s, d2u_recv_e, &
Expand Down
Loading