From b54e4199625e4ecd012c80340a07f75067aba204 Mon Sep 17 00:00:00 2001 From: jchristopherson Date: Fri, 19 Jan 2024 06:04:17 -0600 Subject: [PATCH] Add banded matrix-vector multiplication --- src/blas.f90 | 18 +++ src/linalg.f90 | 179 ++++++++++++++++++++++++ src/linalg_basic.f90 | 308 ++++++++++++++++++++++++++++++++++++++++++ tests/linalg_test.f90 | 6 + tests/test_misc.f90 | 104 ++++++++++++++ 5 files changed, 615 insertions(+) diff --git a/src/blas.f90 b/src/blas.f90 index bb8f1e9b..17ac0e33 100644 --- a/src/blas.f90 +++ b/src/blas.f90 @@ -89,5 +89,23 @@ subroutine ZDSCAL(n, da, zx, incx) real(real64), intent(in) :: da complex(real64), intent(inout) :: zx(*) end subroutine + + subroutine DGBMV(trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, & + incy) + use iso_fortran_env, only : int32, real64 + character, intent(in) :: trans + integer(int32), intent(in) :: m, n, kl, ku, lda, incx, incy + real(real64), intent(in) :: alpha, beta, a(lda,*), x(*) + real(real64), intent(inout) :: y(*) + end subroutine + + subroutine ZGBMV(trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, & + incy) + use iso_fortran_env, only : int32, real64 + character, intent(in) :: trans + integer(int32), intent(in) :: m, n, kl, ku, lda, incx, incy + complex(real64), intent(in) :: alpha, beta, a(lda,*), x(*) + complex(real64), intent(inout) :: y(*) + end subroutine end interface end module \ No newline at end of file diff --git a/src/linalg.f90 b/src/linalg.f90 index f72a3b19..e0ea0df6 100644 --- a/src/linalg.f90 +++ b/src/linalg.f90 @@ -184,6 +184,8 @@ module linalg public :: form_lq public :: mult_lq public :: solve_lq + public :: band_mtx_mult + public :: band_mtx_to_full_mtx public :: LA_NO_OPERATION public :: LA_TRANSPOSE public :: LA_HERMITIAN_TRANSPOSE @@ -3796,6 +3798,148 @@ module linalg module procedure :: solve_lq_vec_cmplx end interface +! ------------------------------------------------------------------------------ +!> @brief Multiplies a banded matrix, A, by a vector x such that +!! alpha * op(A) * x + beta * y = y. +!! +!! @par Syntax 1 +!! @code{.f90} +!! subroutine band_mtx_mult( & +!! logical trans, & +!! integer(int32) kl, & +!! integer(int32) ku, & +!! real(real64) a(:,:), & +!! real(real64) x(:), & +!! real(real64) beta, & +!! real(real64) y(:), & +!! optional class(errors) err & +!! ) +!! @endcode +!! +!! @param[in] trans Set to true for op(A) == A**T; else, false for op(A) == A. +!! @param[in] kl The number of subdiagonals. Must be at least 0. +!! @param[in] ku The number of superdiagonals. Must be at least 0. +!! @param[in] alpha A scalar multiplier. +!! @param[in] a The M-by-N matrix A storing the banded matrix in a compressed +!! form supplied column by column. The following code segment transfers +!! between a full matrix to the banded matrix storage scheme. +!! @code{.f90} +!! do j = 1, n +!! k = ku + 1 - j +!! do i = max(1, j - ku), min(m, j + kl) +!! a(k + i, j) = matrix(i, j) +!! end do +!! end do +!! @endcode +!! @param[in] x If @p trans is true, this is an M-element vector; else, if +!! @p trans is false, this is an N-element vector. +!! @param[in] beta A scalar multiplier. +!! @param[in,out] On input, the vector Y. On output, the resulting vector. +!! if @p trans is true, this vector is an N-element vector; else, it is an +!! M-element vector. +!! @param[in,out] err An optional errors-based object that if provided can be +!! used to retrieve information relating to any errors encountered during +!! execution. If not provided, a default implementation of the errors +!! class is used internally to provide error handling. Possible errors and +!! warning messages that may be encountered are as follows. +!! - LA_ARRAY_SIZE_ERROR: Occurs if any of the input arrays are not sized +!! appropriately. +!! - LA_INVALID_INPUT_ERROR: Occurs if either @p ku or @p kl are not zero or +!! greater. +!! +!! @par Syntax 2 +!! @code{.f90} +!! subroutine band_mtx_mult( & +!! integer(int32) trans, & +!! integer(int32) kl, & +!! integer(int32) ku, & +!! complex(real64) a(:,:), & +!! complex(real64) x(:), & +!! complex(real64) beta, & +!! complex(real64) y(:), & +!! optional class(errors) err & +!! ) +!! @endcode +!! +!! @param[in] trans set to LA_TRANSPOSE if \f$ op(A) = A^T \f$, set to +!! LA_HERMITIAN_TRANSPOSE if \f$ op(A) = A^H \f$, otherwise set to +!! LA_NO_OPERATION if \f$ op(A) = A \f$. +!! @param[in] kl The number of subdiagonals. Must be at least 0. +!! @param[in] ku The number of superdiagonals. Must be at least 0. +!! @param[in] alpha A scalar multiplier. +!! @param[in] a The M-by-N matrix A storing the banded matrix in a compressed +!! form supplied column by column. The following code segment transfers +!! between a full matrix to the banded matrix storage scheme. +!! @code{.f90} +!! do j = 1, n +!! k = ku + 1 - j +!! do i = max(1, j - ku), min(m, j + kl) +!! a(k + i, j) = matrix(i, j) +!! end do +!! end do +!! @endcode +!! @param[in] x If @p trans is true, this is an M-element vector; else, if +!! @p trans is false, this is an N-element vector. +!! @param[in] beta A scalar multiplier. +!! @param[in,out] On input, the vector Y. On output, the resulting vector. +!! if @p trans is true, this vector is an N-element vector; else, it is an +!! M-element vector. +!! @param[in,out] err An optional errors-based object that if provided can be +!! used to retrieve information relating to any errors encountered during +!! execution. If not provided, a default implementation of the errors +!! class is used internally to provide error handling. Possible errors and +!! warning messages that may be encountered are as follows. +!! - LA_ARRAY_SIZE_ERROR: Occurs if any of the input arrays are not sized +!! appropriately. +!! - LA_INVALID_INPUT_ERROR: Occurs if either @p ku or @p kl are not zero or +!! greater. +interface band_mtx_mult + module procedure :: band_mtx_vec_mult_dbl + module procedure :: band_mtx_vec_mult_cmplx +end interface + +!> @brief Converts a banded matrix stored in dense form to a full matrix. +!! +!! @par Syntax 1 +!! @code{.f90} +!! subroutine band_mtx_to_full_mtx( & +!! integer(int32) kl, & +!! integer(int32) ku, & +!! real(real64) b(:,:), & +!! real(real64) f(:,:), & +!! optional class(errors) err & +!! ) +!! @endcode +!! +!! @par Syntax 2 +!! @code{.f90} +!! subroutine band_mtx_to_full_mtx( & +!! integer(int32) kl, & +!! integer(int32) ku, & +!! complex(real64) b(:,:), & +!! complex(real64) f(:,:), & +!! optional class(errors) err & +!! ) +!! @endcode +!! +!! @param[in] kl The number of subdiagonals. Must be at least 0. +!! @param[in] ku The number of superdiagonals. Must be at least 0. +!! @param[in] b The banded matrix to convert, stored in dense form. See +!! @ref band_mtx_vec_mult for details on this storage method. +!! @param[out] f The M-by-N element full matrix. +!! @param[in,out] err An optional errors-based object that if provided can be +!! used to retrieve information relating to any errors encountered during +!! execution. If not provided, a default implementation of the errors +!! class is used internally to provide error handling. Possible errors and +!! warning messages that may be encountered are as follows. +!! - LA_ARRAY_SIZE_ERROR: Occurs if @p b and @p f are not compatible in size. +!! - LA_INVALID_INPUT_ERROR: Occurs if either @p ku or @p kl are not zero or +!! greater. +interface band_mtx_to_full_mtx + module procedure :: band_to_full_mtx_dbl + module procedure :: band_to_full_mtx_cmplx +end interface + ! ****************************************************************************** ! LINALG_BASIC.F90 ! ------------------------------------------------------------------------------ @@ -3994,6 +4138,41 @@ module subroutine tri_mtx_mult_cmplx(upper, alpha, a, beta, b, err) class(errors), intent(inout), optional, target :: err end subroutine + module subroutine band_mtx_vec_mult_dbl(trans, kl, ku, alpha, a, x, beta, & + y, err) + logical, intent(in) :: trans + integer(int32), intent(in) :: kl, ku + real(real64), intent(in) :: alpha, beta + real(real64), intent(in), dimension(:,:) :: a + real(real64), intent(in), dimension(:) :: x + real(real64), intent(inout), dimension(:) :: y + class(errors), intent(inout), optional, target :: err + end subroutine + + module subroutine band_mtx_vec_mult_cmplx(trans, kl, ku, alpha, a, x, & + beta, y, err) + integer(int32), intent(in) :: trans + integer(int32), intent(in) :: kl, ku + complex(real64), intent(in) :: alpha, beta + complex(real64), intent(in), dimension(:,:) :: a + complex(real64), intent(in), dimension(:) :: x + complex(real64), intent(inout), dimension(:) :: y + class(errors), intent(inout), optional, target :: err + end subroutine + + module subroutine band_to_full_mtx_dbl(kl, ku, b, f, err) + integer(int32), intent(in) :: kl, ku + real(real64), intent(in), dimension(:,:) :: b + real(real64), intent(out), dimension(:,:) :: f + class(errors), intent(inout), optional, target :: err + end subroutine + + module subroutine band_to_full_mtx_cmplx(kl, ku, b, f, err) + integer(int32), intent(in) :: kl, ku + complex(real64), intent(in), dimension(:,:) :: b + complex(real64), intent(out), dimension(:,:) :: f + class(errors), intent(inout), optional, target :: err + end subroutine end interface ! ****************************************************************************** diff --git a/src/linalg_basic.f90 b/src/linalg_basic.f90 index 5ece43e5..bcb09784 100644 --- a/src/linalg_basic.f90 +++ b/src/linalg_basic.f90 @@ -2223,5 +2223,313 @@ module subroutine tri_mtx_mult_cmplx(upper, alpha, a, beta, b, err) 100 format(A, I0, A, I0, A, I0, A, I0, A, I0, A) end subroutine +! ****************************************************************************** +! BANDED MATRIX MULTIPLICATION ROUTINES +! ------------------------------------------------------------------------------ + module subroutine band_mtx_vec_mult_dbl(trans, kl, ku, alpha, a, x, beta, & + y, err) + ! Arguments + logical, intent(in) :: trans + integer(int32), intent(in) :: kl, ku + real(real64), intent(in) :: alpha, beta + real(real64), intent(in), dimension(:,:) :: a + real(real64), intent(in), dimension(:) :: x + real(real64), intent(inout), dimension(:) :: y + class(errors), intent(inout), optional, target :: err + + ! Local Variables + integer(int32) :: m, n + class(errors), pointer :: errmgr + type(errors), target :: deferr + + ! Initialization + if (present(err)) then + errmgr => err + else + errmgr => deferr + end if + if (trans) then + m = size(x) + n = size(y) + else + m = size(y) + n = size(x) + end if + + ! Input Checking + if (kl < 0) go to 10 + if (ku < 0) go to 20 + if (size(a, 1) /= kl + ku + 1) go to 30 + if (size(a, 2) /= n) go to 30 + + ! Process + if (trans) then + call DGBMV("T", m, n, kl, ku, alpha, a, size(a, 1), x, 1, beta, y, 1) + else + call DGBMV("N", m, n, kl, ku, alpha, a, size(a, 1), x, 1, beta, y, 1) + end if + + ! End + return + + ! KL < 0 +10 continue + call errmgr%report_error("band_mtx_vec_mult_dbl", & + "The number of subdiagonals must be at least 0.", & + LA_INVALID_INPUT_ERROR) + return + + ! KU < 0 +20 continue + call errmgr%report_error("band_mtx_vec_mult_dbl", & + "The number of superdiagonals must be at least 0.", & + LA_INVALID_INPUT_ERROR) + return + + ! A is incorrectly sized +30 continue + call errmgr%report_error("band_mtx_vec_mult_dbl", & + "The size of matrix A is not compatible with the other vectors.", & + LA_ARRAY_SIZE_ERROR) + return + end subroutine + +! ------------------------------------------------------------------------------ + module subroutine band_mtx_vec_mult_cmplx(trans, kl, ku, alpha, a, x, & + beta, y, err) + ! Arguments + integer(int32), intent(in) :: trans + integer(int32), intent(in) :: kl, ku + complex(real64), intent(in) :: alpha, beta + complex(real64), intent(in), dimension(:,:) :: a + complex(real64), intent(in), dimension(:) :: x + complex(real64), intent(inout), dimension(:) :: y + class(errors), intent(inout), optional, target :: err + + ! Local Variables + character :: op + logical :: trns + integer(int32) :: m, n + class(errors), pointer :: errmgr + type(errors), target :: deferr + + ! Initialization + if (present(err)) then + errmgr => err + else + errmgr => deferr + end if + if (trans == LA_TRANSPOSE) then + op = "T" + trns = .true. + else if (trans == LA_HERMITIAN_TRANSPOSE) then + op = "C" + trns = .true. + else + op = "N" + trns = .false. + end if + if (trns) then + m = size(x) + n = size(y) + else + m = size(y) + n = size(x) + end if + + ! Input Checking + if (kl < 0) go to 10 + if (ku < 0) go to 20 + if (size(a, 1) /= kl + ku + 1) go to 30 + if (size(a, 2) /= n) go to 30 + + ! Process + call ZGBMV(op, m, n, kl, ku, alpha, a, size(a, 1), x, 1, beta, y, 1) + + ! End + return + + ! KL < 0 +10 continue + call errmgr%report_error("band_mtx_vec_mult_cmplx", & + "The number of subdiagonals must be at least 0.", & + LA_INVALID_INPUT_ERROR) + return + + ! KU < 0 +20 continue + call errmgr%report_error("band_mtx_vec_mult_cmplx", & + "The number of superdiagonals must be at least 0.", & + LA_INVALID_INPUT_ERROR) + return + + ! A is incorrectly sized +30 continue + call errmgr%report_error("band_mtx_vec_mult_cmplx", & + "The size of matrix A is not compatible with the other vectors.", & + LA_ARRAY_SIZE_ERROR) + return + end subroutine + +! ------------------------------------------------------------------------------ + module subroutine band_to_full_mtx_dbl(kl, ku, b, f, err) + ! Arguments + integer(int32), intent(in) :: kl, ku + real(real64), intent(in), dimension(:,:) :: b + real(real64), intent(out), dimension(:,:) :: f + class(errors), intent(inout), optional, target :: err + + ! Parameters + real(real64), parameter :: zero = 0.0d0 + + ! Local Variables + class(errors), pointer :: errmgr + type(errors), target :: deferr + integer(int32) :: i, j, k, m, n, i1, i2 + + ! Initialization + if (present(err)) then + errmgr => err + else + errmgr => deferr + end if + m = size(f, 1) + n = size(f, 2) + + ! Input Check + if (kl < 0) go to 10 + if (ku < 0) go to 20 + if (size(b, 2) /= n) go to 30 + if (size(b, 1) /= kl + ku + 1) go to 40 + + ! Process + do j = 1, n + k = ku + 1 - j + i1 = max(1, j - ku) + i2 = min(m, j + kl) + do i = 1, i1 - 1 + f(i,j) = zero + end do + do i = i1, i2 + f(i,j) = b(k+i,j) + end do + do i = i2 + 1, m + f(i,j) = zero + end do + end do + + ! End + return + + ! KL < 0 +10 continue + call errmgr%report_error("band_to_full_mtx_dbl", & + "The number of subdiagonals must be at least 0.", & + LA_INVALID_INPUT_ERROR) + return + + ! KU < 0 +20 continue + call errmgr%report_error("band_to_full_mtx_dbl", & + "The number of superdiagonals must be at least 0.", & + LA_INVALID_INPUT_ERROR) + return + + ! A is incorrectly sized +30 continue + call errmgr%report_error("band_to_full_mtx_dbl", & + "The number of columns in the banded matrix does not match " // & + "the number of columns in the full matrix.", & + LA_ARRAY_SIZE_ERROR) + return + +40 continue + call errmgr%report_error("band_to_full_mtx_dbl", & + "The number of rows in the banded matrix does not align with " // & + "the number of sub and super-diagonals specified.", & + LA_ARRAY_SIZE_ERROR) + return + end subroutine + +! ------------------------------------------------------------------------------ + module subroutine band_to_full_mtx_cmplx(kl, ku, b, f, err) + ! Arguments + integer(int32), intent(in) :: kl, ku + complex(real64), intent(in), dimension(:,:) :: b + complex(real64), intent(out), dimension(:,:) :: f + class(errors), intent(inout), optional, target :: err + + ! Parameters + complex(real64), parameter :: zero = (0.0d0, 0.0d0) + + ! Local Variables + class(errors), pointer :: errmgr + type(errors), target :: deferr + integer(int32) :: i, j, k, m, n, i1, i2 + + ! Initialization + if (present(err)) then + errmgr => err + else + errmgr => deferr + end if + m = size(f, 1) + n = size(f, 2) + + ! Input Check + if (kl < 0) go to 10 + if (ku < 0) go to 20 + if (size(b, 2) /= n) go to 30 + if (size(b, 1) /= kl + ku + 1) go to 40 + + ! Process + do j = 1, n + k = ku + 1 - j + i1 = max(1, j - ku) + i2 = min(m, j + kl) + do i = 1, i1 - 1 + f(i,j) = zero + end do + do i = i1, i2 + f(i,j) = b(k+i,j) + end do + do i = i2 + 1, m + f(i,j) = zero + end do + end do + + ! End + return + + ! KL < 0 +10 continue + call errmgr%report_error("band_to_full_mtx_cmplx", & + "The number of subdiagonals must be at least 0.", & + LA_INVALID_INPUT_ERROR) + return + + ! KU < 0 +20 continue + call errmgr%report_error("band_to_full_mtx_cmplx", & + "The number of superdiagonals must be at least 0.", & + LA_INVALID_INPUT_ERROR) + return + + ! A is incorrectly sized +30 continue + call errmgr%report_error("band_to_full_mtx_cmplx", & + "The number of columns in the banded matrix does not match " // & + "the number of columns in the full matrix.", & + LA_ARRAY_SIZE_ERROR) + return + +40 continue + call errmgr%report_error("band_to_full_mtx_cmplx", & + "The number of rows in the banded matrix does not align with " // & + "the number of sub and super-diagonals specified.", & + LA_ARRAY_SIZE_ERROR) + return + end subroutine + ! ------------------------------------------------------------------------------ end submodule diff --git a/tests/linalg_test.f90 b/tests/linalg_test.f90 index 0d26ff8a..2319d717 100644 --- a/tests/linalg_test.f90 +++ b/tests/linalg_test.f90 @@ -288,6 +288,12 @@ program main rst = test_lq_mult_right_cmplx_ud() if (.not.rst) flag = 85 + rst = test_banded_mtx_mult_dbl() + if (.not.rst) flag = 86 + + rst = test_banded_mtx_mult_cmplx() + if (.not.rst) flag = 87 + ! End if (flag /= 0) stop flag end program diff --git a/tests/test_misc.f90 b/tests/test_misc.f90 index dfbc4528..48bb2d50 100644 --- a/tests/test_misc.f90 +++ b/tests/test_misc.f90 @@ -512,5 +512,109 @@ function test_tri_mtx_solve_1_cmplx() result(rst) end if end function +! ------------------------------------------------------------------------------ + function test_banded_mtx_mult_dbl() result(rst) + ! Arguments + logical :: rst + + ! Local Variables + integer(int32), parameter :: kl = 3 + integer(int32), parameter :: ku = 4 + integer(int32), parameter :: mb = kl + ku + 1 + integer(int32), parameter :: m = 52 + integer(int32), parameter :: n = 50 + real(real64) :: alpha, beta, a(mb,n), af(m,n), x1(n), y1(m), ans1(m) + real(real64) :: x2(m), y2(n), ans2(n) + + ! Initialization + rst = .true. + call random_number(alpha) + call random_number(beta) + call random_number(a) + call random_number(x1) + call random_number(y1) + call random_number(x2) + call random_number(y2) + + ! Construct a full matrix from the banded matrix to use for checking + call band_mtx_to_full_mtx(kl, ku, a, af) + + ! Test 1 + ans1 = alpha * matmul(af, x1) + beta * y1 + call band_mtx_mult(.false., kl, ku, alpha, a, x1, beta, y1) + if (.not.assert(ans1, y1, REAL64_TOL)) then + rst = .false. + print "(A)", "Test Failed: test_banded_mtx_mult_dbl -1" + end if + + ! Test 2 + ans2 = alpha * matmul(transpose(af), x2) + beta * y2 + call band_mtx_mult(.true., kl, ku, alpha, a, x2, beta, y2) + if (.not.assert(ans2, y2, REAL64_TOL)) then + rst = .false. + print "(A)", "Test Failed: test_banded_mtx_mult_dbl -2" + end if + end function + +! ------------------------------------------------------------------------------ + function test_banded_mtx_mult_cmplx() result(rst) + ! Arguments + logical :: rst + + ! Local Variables + integer(int32), parameter :: kl = 3 + integer(int32), parameter :: ku = 4 + integer(int32), parameter :: mb = kl + ku + 1 + integer(int32), parameter :: m = 52 + integer(int32), parameter :: n = 50 + complex(real64) :: alpha, beta, a(mb,n), af(m,n), x1(n), y1(m), ans1(m) + complex(real64) :: x2(m), y2(n), ans2(n), ans3(n) + real(real64) :: ar(mb,n), ai(mb,n), mr(m), mi(m), nr(n), ni(n) + + ! Initialization + rst = .true. + call random_number(ar) + call random_number(ai) + call random_number(mr) + call random_number(mi) + call random_number(nr) + call random_number(ni) + alpha = cmplx(ar(1,1), ai(1,1)) + beta = cmplx(ar(1,2), ai(1,2)) + a = cmplx(ar, ai) + x1 = cmplx(nr, ni) + y1 = cmplx(mr, mi) + x2 = cmplx(mr, mi) + y2 = cmplx(nr, ni) + + ! Construct a full matrix from the banded matrix to use for checking + call band_mtx_to_full_mtx(kl, ku, a, af) + + ! Test 1 + ans1 = alpha * matmul(af, x1) + beta * y1 + call band_mtx_mult(LA_NO_OPERATION, kl, ku, alpha, a, x1, beta, y1) + if (.not.assert(ans1, y1, tol = REAL64_TOL)) then + rst = .false. + print "(A)", "Test Failed: test_banded_mtx_mult_cmplx -1" + end if + + ! Test 2 + ans2 = alpha * matmul(transpose(af), x2) + beta * y2 + call band_mtx_mult(LA_TRANSPOSE, kl, ku, alpha, a, x2, beta, y2) + if (.not.assert(ans2, y2, tol = REAL64_TOL)) then + rst = .false. + print "(A)", "Test Failed: test_banded_mtx_mult_cmplx -2" + end if + + ! Test 3 + ans3 = alpha * matmul(conjg(transpose(af)), x2) + beta * y2 + call band_mtx_mult(LA_HERMITIAN_TRANSPOSE, kl, ku, alpha, a, x2, & + beta, y2) + if (.not.assert(ans3, y2, tol = REAL64_TOL)) then + rst = .false. + print "(A)", "Test Failed: test_banded_mtx_mult_cmplx -3" + end if + end function + ! ------------------------------------------------------------------------------ end module