From 3abac46bc9a9f45079f74c59518e7999d34bd468 Mon Sep 17 00:00:00 2001 From: Sebastian Grimberg Date: Thu, 1 Feb 2024 18:18:49 -0800 Subject: [PATCH] Parallel prolongation and restriction sometimes do not have appropriate AddMult/AddMultTranspose overrides, causing temporary vector allocation So, fix it by handling this at the ParOperator level. --- palace/linalg/rap.cpp | 118 ++++++++++-------------------------------- palace/linalg/rap.hpp | 2 - 2 files changed, 26 insertions(+), 94 deletions(-) diff --git a/palace/linalg/rap.cpp b/palace/linalg/rap.cpp index 457f22511..4773ac9aa 100644 --- a/palace/linalg/rap.cpp +++ b/palace/linalg/rap.cpp @@ -52,18 +52,21 @@ void ParOperator::EliminateRHS(const Vector &x, Vector &b) const } MFEM_VERIFY(A, "No local matrix available for ParOperator::EliminateRHS!"); - auto &tx = trial_fespace.GetTVector(); auto &lx = trial_fespace.GetLVector(); auto &ly = GetTestLVector(); - tx = 0.0; - linalg::SetSubVector(tx, *dbc_tdof_list, x); - trial_fespace.GetProlongationMatrix()->Mult(tx, lx); + { + auto &tx = trial_fespace.GetTVector(); + tx = 0.0; + linalg::SetSubVector(tx, *dbc_tdof_list, x); + trial_fespace.GetProlongationMatrix()->Mult(tx, lx); + } // Apply the unconstrained operator. A->Mult(lx, ly); - ly *= -1.0; - RestrictionMatrixAddMult(ly, b); + auto &ty = test_fespace.GetTVector(); + RestrictionMatrixMult(ly, ty); + b.Add(-1.0, ty); if (diag_policy == DiagonalPolicy::DIAG_ONE) { linalg::SetSubVector(b, *dbc_tdof_list, x); @@ -292,10 +295,10 @@ void ParOperator::AddMult(const Vector &x, Vector &y, const double a) const // Apply the operator on the L-vector. A->Mult(lx, ly); + auto &ty = test_fespace.GetTVector(); + RestrictionMatrixMult(ly, ty); if (dbc_tdof_list) { - auto &ty = test_fespace.GetTVector(); - RestrictionMatrixMult(ly, ty); if (diag_policy == DiagonalPolicy::DIAG_ONE) { linalg::SetSubVector(ty, *dbc_tdof_list, x); @@ -304,16 +307,8 @@ void ParOperator::AddMult(const Vector &x, Vector &y, const double a) const { linalg::SetSubVector(ty, *dbc_tdof_list, 0.0); } - y.Add(a, ty); - } - else - { - if (a != 1.0) - { - ly *= a; - } - RestrictionMatrixAddMult(ly, y); } + y.Add(a, ty); } void ParOperator::AddMultTranspose(const Vector &x, Vector &y, const double a) const @@ -343,10 +338,10 @@ void ParOperator::AddMultTranspose(const Vector &x, Vector &y, const double a) c // Apply the operator on the L-vector. A->MultTranspose(ly, lx); + auto &tx = trial_fespace.GetTVector(); + trial_fespace.GetProlongationMatrix()->MultTranspose(lx, tx); if (dbc_tdof_list) { - auto &tx = trial_fespace.GetTVector(); - trial_fespace.GetProlongationMatrix()->MultTranspose(lx, tx); if (diag_policy == DiagonalPolicy::DIAG_ONE) { linalg::SetSubVector(tx, *dbc_tdof_list, x); @@ -355,16 +350,8 @@ void ParOperator::AddMultTranspose(const Vector &x, Vector &y, const double a) c { linalg::SetSubVector(tx, *dbc_tdof_list, 0.0); } - y.Add(a, tx); - } - else - { - if (a != 1.0) - { - lx *= a; - } - trial_fespace.GetProlongationMatrix()->AddMultTranspose(lx, y); } + y.Add(a, tx); } void ParOperator::RestrictionMatrixMult(const Vector &ly, Vector &ty) const @@ -379,18 +366,6 @@ void ParOperator::RestrictionMatrixMult(const Vector &ly, Vector &ty) const } } -void ParOperator::RestrictionMatrixAddMult(const Vector &ly, Vector &ty) const -{ - if (!use_R) - { - test_fespace.GetProlongationMatrix()->AddMultTranspose(ly, ty); - } - else - { - test_fespace.GetRestrictionMatrix()->AddMult(ly, ty); - } -} - void ParOperator::RestrictionMatrixMultTranspose(const Vector &ty, Vector &ly) const { if (!use_R) @@ -623,10 +598,10 @@ void ComplexParOperator::AddMult(const ComplexVector &x, ComplexVector &y, // Apply the operator on the L-vector. A->Mult(lx, ly); + auto &ty = test_fespace.GetTVector(); + RestrictionMatrixMult(ly, ty); if (dbc_tdof_list) { - auto &ty = test_fespace.GetTVector(); - RestrictionMatrixMult(ly, ty); if (diag_policy == Operator::DiagonalPolicy::DIAG_ONE) { linalg::SetSubVector(ty, *dbc_tdof_list, x); @@ -635,16 +610,8 @@ void ComplexParOperator::AddMult(const ComplexVector &x, ComplexVector &y, { linalg::SetSubVector(ty, *dbc_tdof_list, 0.0); } - y.AXPY(a, ty); - } - else - { - if (a != 1.0) - { - ly *= a; - } - RestrictionMatrixAddMult(ly, y); } + y.AXPY(a, ty); } void ComplexParOperator::AddMultTranspose(const ComplexVector &x, ComplexVector &y, @@ -669,11 +636,11 @@ void ComplexParOperator::AddMultTranspose(const ComplexVector &x, ComplexVector // Apply the operator on the L-vector. A->MultTranspose(ly, lx); + auto &tx = trial_fespace.GetTVector(); + trial_fespace.GetProlongationMatrix()->MultTranspose(lx.Real(), tx.Real()); + trial_fespace.GetProlongationMatrix()->MultTranspose(lx.Imag(), tx.Imag()); if (dbc_tdof_list) { - auto &tx = trial_fespace.GetTVector(); - trial_fespace.GetProlongationMatrix()->MultTranspose(lx.Real(), tx.Real()); - trial_fespace.GetProlongationMatrix()->MultTranspose(lx.Imag(), tx.Imag()); if (diag_policy == Operator::DiagonalPolicy::DIAG_ONE) { linalg::SetSubVector(tx, *dbc_tdof_list, x); @@ -682,17 +649,8 @@ void ComplexParOperator::AddMultTranspose(const ComplexVector &x, ComplexVector { linalg::SetSubVector(tx, *dbc_tdof_list, 0.0); } - y.AXPY(a, tx); - } - else - { - if (a != 1.0) - { - lx *= a; - } - trial_fespace.GetProlongationMatrix()->AddMultTranspose(lx.Real(), y.Real()); - trial_fespace.GetProlongationMatrix()->AddMultTranspose(lx.Imag(), y.Imag()); } + y.AXPY(a, tx); } void ComplexParOperator::AddMultHermitianTranspose(const ComplexVector &x, ComplexVector &y, @@ -717,11 +675,11 @@ void ComplexParOperator::AddMultHermitianTranspose(const ComplexVector &x, Compl // Apply the operator on the L-vector. A->MultHermitianTranspose(ly, lx); + auto &tx = trial_fespace.GetTVector(); + trial_fespace.GetProlongationMatrix()->MultTranspose(lx.Real(), tx.Real()); + trial_fespace.GetProlongationMatrix()->MultTranspose(lx.Imag(), tx.Imag()); if (dbc_tdof_list) { - auto &tx = trial_fespace.GetTVector(); - trial_fespace.GetProlongationMatrix()->MultTranspose(lx.Real(), tx.Real()); - trial_fespace.GetProlongationMatrix()->MultTranspose(lx.Imag(), tx.Imag()); if (diag_policy == Operator::DiagonalPolicy::DIAG_ONE) { linalg::SetSubVector(tx, *dbc_tdof_list, x); @@ -730,17 +688,8 @@ void ComplexParOperator::AddMultHermitianTranspose(const ComplexVector &x, Compl { linalg::SetSubVector(tx, *dbc_tdof_list, 0.0); } - y.AXPY(a, tx); - } - else - { - if (a != 1.0) - { - lx *= a; - } - trial_fespace.GetProlongationMatrix()->AddMultTranspose(lx.Real(), y.Real()); - trial_fespace.GetProlongationMatrix()->AddMultTranspose(lx.Imag(), y.Imag()); } + y.AXPY(a, tx); } void ComplexParOperator::RestrictionMatrixMult(const ComplexVector &ly, @@ -758,21 +707,6 @@ void ComplexParOperator::RestrictionMatrixMult(const ComplexVector &ly, } } -void ComplexParOperator::RestrictionMatrixAddMult(const ComplexVector &ly, - ComplexVector &ty) const -{ - if (!use_R) - { - test_fespace.GetProlongationMatrix()->AddMultTranspose(ly.Real(), ty.Real()); - test_fespace.GetProlongationMatrix()->AddMultTranspose(ly.Imag(), ty.Imag()); - } - else - { - test_fespace.GetRestrictionMatrix()->AddMult(ly.Real(), ty.Real()); - test_fespace.GetRestrictionMatrix()->AddMult(ly.Imag(), ty.Imag()); - } -} - void ComplexParOperator::RestrictionMatrixMultTranspose(const ComplexVector &ty, ComplexVector &ly) const { diff --git a/palace/linalg/rap.hpp b/palace/linalg/rap.hpp index 2ee99d3e0..4c9d36174 100644 --- a/palace/linalg/rap.hpp +++ b/palace/linalg/rap.hpp @@ -43,7 +43,6 @@ class ParOperator : public Operator // Helper methods for operator application. void RestrictionMatrixMult(const Vector &ly, Vector &ty) const; - void RestrictionMatrixAddMult(const Vector &ly, Vector &ty) const; void RestrictionMatrixMultTranspose(const Vector &ty, Vector &ly) const; Vector &GetTestLVector() const; @@ -130,7 +129,6 @@ class ComplexParOperator : public ComplexOperator // Helper methods for operator application. void RestrictionMatrixMult(const ComplexVector &ly, ComplexVector &ty) const; - void RestrictionMatrixAddMult(const ComplexVector &ly, ComplexVector &ty) const; void RestrictionMatrixMultTranspose(const ComplexVector &ty, ComplexVector &ly) const; ComplexVector &GetTestLVector() const;