diff --git a/src/cuda/kernels/spectral_processing.f90 b/src/cuda/kernels/spectral_processing.f90 index 27fc8297..6d386825 100644 --- a/src/cuda/kernels/spectral_processing.f90 +++ b/src/cuda/kernels/spectral_processing.f90 @@ -231,4 +231,42 @@ attributes(global) subroutine process_spectral_010( & end subroutine process_spectral_010 + attributes(global) subroutine enforce_periodicity_y(f_out, f_in, ny) + implicit none + + real(dp), device, intent(out), dimension(:, :, :) :: f_out + real(dp), device, intent(in), dimension(:, :, :) :: f_in + integer, value, intent(in) :: ny + + integer :: i, j, k + + i = threadIdx%x + k = blockIdx%x + + do j = 1, ny/2 + f_out(i, j, k) = f_in(i, 2*j - 1, k) + f_out(i, j + ny/2, k) = f_in(i, ny - 2*j + 2, k) + end do + + end subroutine enforce_periodicity_y + + attributes(global) subroutine undo_periodicity_y(f_out, f_in, ny) + implicit none + + real(dp), device, intent(out), dimension(:, :, :) :: f_out + real(dp), device, intent(in), dimension(:, :, :) :: f_in + integer, value, intent(in) :: ny + + integer :: i, j, k + + i = threadIdx%x + k = blockIdx%x + + do j = 1, ny/2 + f_out(i, 2*j - 1, k) = f_in(i, j, k) + f_out(i, 2*j, k) = f_in(i, ny - j + 1, k) + end do + + end subroutine undo_periodicity_y + end module m_cuda_spectral diff --git a/src/cuda/poisson_fft.f90 b/src/cuda/poisson_fft.f90 index f958eab6..630d4cfb 100644 --- a/src/cuda/poisson_fft.f90 +++ b/src/cuda/poisson_fft.f90 @@ -13,7 +13,8 @@ module m_cuda_poisson_fft use m_tdsops, only: dirps_t use m_cuda_allocator, only: cuda_field_t - use m_cuda_spectral, only: process_spectral_000, process_spectral_010 + use m_cuda_spectral, only: process_spectral_000, process_spectral_010, & + enforce_periodicity_y, undo_periodicity_y implicit none @@ -254,6 +255,24 @@ subroutine enforce_periodicity_y_cuda(self, f_out, f_in) class(field_t), intent(inout) :: f_out class(field_t), intent(in) :: f_in + real(dp), device, pointer, dimension(:, :, :) :: f_out_dev, f_in_dev + type(dim3) :: blocks, threads + + select type (f_out) + type is (cuda_field_t) + f_out_dev => f_out%data_d + end select + select type (f_in) + type is (cuda_field_t) + f_in_dev => f_in%data_d + end select + + blocks = dim3(self%nz_spec, 1, 1) + threads = dim3(self%nx_spec, 1, 1) + call enforce_periodicity_y<<>>( & !& + f_out_dev, f_in_dev, self%ny_spec & + ) + end subroutine enforce_periodicity_y_cuda subroutine undo_periodicity_y_cuda(self, f_out, f_in) @@ -263,6 +282,24 @@ subroutine undo_periodicity_y_cuda(self, f_out, f_in) class(field_t), intent(inout) :: f_out class(field_t), intent(in) :: f_in + real(dp), device, pointer, dimension(:, :, :) :: f_out_dev, f_in_dev + type(dim3) :: blocks, threads + + select type (f_out) + type is (cuda_field_t) + f_out_dev => f_out%data_d + end select + select type (f_in) + type is (cuda_field_t) + f_in_dev => f_in%data_d + end select + + blocks = dim3(self%nz_spec, 1, 1) + threads = dim3(self%nx_spec, 1, 1) + call undo_periodicity_y<<>>( & !& + f_out_dev, f_in_dev, self%ny_spec & + ) + end subroutine undo_periodicity_y_cuda end module m_cuda_poisson_fft