Skip to content

Commit

Permalink
Merge pull request #325 from awslabs/hughcars/new_postpro
Browse files Browse the repository at this point in the history
Suggested revisions for phdum/new_postpro
  • Loading branch information
phdum-a authored Jan 22, 2025
2 parents 20660d6 + 4630bb0 commit d3e0ce8
Show file tree
Hide file tree
Showing 16 changed files with 458 additions and 793 deletions.
2 changes: 1 addition & 1 deletion cmake/ExternalMFEM.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ if(CMAKE_BUILD_TYPE MATCHES "Debug|debug|DEBUG")
endif()

# Replace mfem abort calls with exceptions for testing, default off
set(PALACE_MFEM_USE_EXCEPTIONS CACHE NO "MFEM throw exceptsions instead of abort calls")
set(PALACE_MFEM_USE_EXCEPTIONS NO)

set(MFEM_OPTIONS ${PALACE_SUPERBUILD_DEFAULT_ARGS})
list(APPEND MFEM_OPTIONS
Expand Down
100 changes: 41 additions & 59 deletions palace/drivers/basesolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
#include "linalg/ksp.hpp"
#include "models/domainpostoperator.hpp"
#include "models/postoperator.hpp"
#include "models/spaceoperator.hpp"
#include "models/surfacepostoperator.hpp"
#include "utils/communication.hpp"
#include "utils/dorfler.hpp"
Expand Down Expand Up @@ -287,14 +286,12 @@ void BaseSolver::SaveMetadata(const Timer &timer) const
}
}

BaseSolver::DomainsPostPrinter::DomainsPostPrinter(bool do_measurement, bool root,
const fs::path &post_dir,
BaseSolver::DomainsPostPrinter::DomainsPostPrinter(const fs::path &post_dir,
const PostOperator &post_op,
const std::string &idx_col_name,
int n_expected_rows)
: do_measurement_{do_measurement}, root_{root}
{
if (!do_measurement_ || !root_)
if (!Mpi::Root(post_op.GetComm()))
{
return;
}
Expand Down Expand Up @@ -323,7 +320,7 @@ void BaseSolver::DomainsPostPrinter::AddMeasurement(double idx_value_dimensionfu
const PostOperator &post_op,
const IoData &iodata)
{
if (!do_measurement_ || !root_)
if (!Mpi::Root(post_op.GetComm()))
{
return;
}
Expand Down Expand Up @@ -357,26 +354,17 @@ void BaseSolver::DomainsPostPrinter::AddMeasurement(double idx_value_dimensionfu
domain_E.WriteFullTableTrunc();
}

BaseSolver::SurfacesPostPrinter::SurfacesPostPrinter(bool do_measurement, bool root,
const fs::path &post_dir,
BaseSolver::SurfacesPostPrinter::SurfacesPostPrinter(const fs::path &post_dir,
const PostOperator &post_op,
const std::string &idx_col_name,
int n_expected_rows)
: root_{root},
do_measurement_flux_(do_measurement //
&& post_op.GetSurfacePostOp().flux_surfs.size() > 0 // Has flux
),
do_measurement_eps_(do_measurement //
&& post_op.GetSurfacePostOp().eps_surfs.size() > 0 // Has eps
)
{
if (!root_)
if (!Mpi::Root(post_op.GetComm()))
{
return;
}
using fmt::format;

if (do_measurement_flux_)
if (post_op.GetSurfacePostOp().flux_surfs.size() > 0)
{
surface_F = TableWithCSVFile(post_dir / "surface-F.csv");
surface_F.table.reserve(n_expected_rows,
Expand Down Expand Up @@ -425,7 +413,7 @@ BaseSolver::SurfacesPostPrinter::SurfacesPostPrinter(bool do_measurement, bool r
surface_F.AppendHeader();
}

if (do_measurement_eps_)
if (post_op.GetSurfacePostOp().eps_surfs.size() > 0)
{
surface_Q = TableWithCSVFile(post_dir / "surface-Q.csv");
surface_Q.table.reserve(n_expected_rows,
Expand All @@ -444,15 +432,11 @@ void BaseSolver::SurfacesPostPrinter::AddMeasurementFlux(double idx_value_dimens
const PostOperator &post_op,
const IoData &iodata)
{
if (!do_measurement_flux_ || !root_)
{
return;
}
using VT = IoData::ValueType;
using fmt::format;

const bool has_imaginary = post_op.HasImag();
auto flux_data_vec = post_op.GetSurfaceFluxAll();
auto flux_data_vec = post_op.GetSurfaceFluxes();
auto dimensionlize_flux = [&iodata](auto Phi, SurfaceFluxType flux_type)
{
switch (flux_type)
Expand Down Expand Up @@ -489,15 +473,11 @@ void BaseSolver::SurfacesPostPrinter::AddMeasurementEps(double idx_value_dimensi
const PostOperator &post_op,
const IoData &iodata)
{
if (!do_measurement_eps_ || !root_)
{
return;
}
using VT = IoData::ValueType;
using fmt::format;

// Interface Participation adds energy contriutions E_elec + E_cap
// E_cap returns zero if the solver does not supprot lumped ports.
// E_cap returns zero if the solver does not support lumped ports.
double E_elec = post_op.GetEFieldEnergy() + post_op.GetLumpedCapacitorEnergy();
auto eps_data_vec = post_op.GetInterfaceEFieldEnergyAll();

Expand All @@ -517,35 +497,36 @@ void BaseSolver::SurfacesPostPrinter::AddMeasurement(double idx_value_dimensionf
const PostOperator &post_op,
const IoData &iodata)
{
if (!Mpi::Root(post_op.GetComm()))
{
return;
}
// If surfaces have been specified for postprocessing, compute the corresponding values
// and write out to disk. The passed in E_elec is the sum of the E-field and lumped
// capacitor energies, and E_mag is the same for the B-field and lumped inductors.
AddMeasurementFlux(idx_value_dimensionful, post_op, iodata);
AddMeasurementEps(idx_value_dimensionful, post_op, iodata);
if (post_op.GetSurfacePostOp().flux_surfs.size() > 0)
{
AddMeasurementFlux(idx_value_dimensionful, post_op, iodata);
}
if (post_op.GetSurfacePostOp().eps_surfs.size() > 0)
{
AddMeasurementEps(idx_value_dimensionful, post_op, iodata);
}
}

BaseSolver::ProbePostPrinter::ProbePostPrinter(bool do_measurement, bool root,
const fs::path &post_dir,
BaseSolver::ProbePostPrinter::ProbePostPrinter(const fs::path &post_dir,
const PostOperator &post_op,
const std::string &idx_col_name,
int n_expected_rows)
: root_{root}, do_measurement_E_{do_measurement}, do_measurement_B_{do_measurement},
has_imag{post_op.HasImag()}, v_dim{post_op.GetInterpolationOpVDim()}
{
#if defined(MFEM_USE_GSLIB)
do_measurement_E_ = do_measurement_E_ //
&& (post_op.GetProbes().size() > 0) // Has probes defined
&& post_op.HasE(); // Has E fields

do_measurement_B_ = do_measurement_B_ //
&& (post_op.GetProbes().size() > 0) // Has probes defined
&& post_op.HasB(); // Has B fields

if (!root_ || (!do_measurement_E_ && !do_measurement_B_))
if (post_op.GetProbes().size() == 0 || !Mpi::Root(post_op.GetComm()))
{
return;
}
using fmt::format;
const int v_dim = post_op.GetInterpolationOpVDim();
const bool has_imag = post_op.HasImag();
int scale_col = (has_imag ? 2 : 1) * v_dim;
auto dim_labeler = [](int i) -> std::string
{
Expand All @@ -563,7 +544,7 @@ BaseSolver::ProbePostPrinter::ProbePostPrinter(bool do_measurement, bool root,
}
};

if (do_measurement_E_)
if (post_op.HasE())
{
probe_E = TableWithCSVFile(post_dir / "probe-E.csv");
probe_E.table.reserve(n_expected_rows, scale_col * post_op.GetProbes().size());
Expand Down Expand Up @@ -592,7 +573,7 @@ BaseSolver::ProbePostPrinter::ProbePostPrinter(bool do_measurement, bool root,
probe_E.AppendHeader();
}

if (do_measurement_B_)
if (post_op.HasB())
{
probe_B = TableWithCSVFile(post_dir / "probe-B.csv");
probe_B.table.reserve(n_expected_rows, scale_col * post_op.GetProbes().size());
Expand Down Expand Up @@ -627,16 +608,18 @@ void BaseSolver::ProbePostPrinter::AddMeasurementE(double idx_value_dimensionful
const PostOperator &post_op,
const IoData &iodata)
{
if (!do_measurement_E_ || !root_)
if (!post_op.HasE())
{
return;
}
using VT = IoData::ValueType;
using fmt::format;

auto probe_field = post_op.ProbeEField();
const int v_dim = post_op.GetInterpolationOpVDim();
const bool has_imag = post_op.HasImag();
MFEM_VERIFY(probe_field.size() == v_dim * post_op.GetProbes().size(),
format("Size mismatch: expect vector field to ahve size {} * {} = {}; got {}",
format("Size mismatch: expect vector field to have size {} * {} = {}; got {}",
v_dim, post_op.GetProbes().size(), v_dim * post_op.GetProbes().size(),
probe_field.size()))

Expand All @@ -662,16 +645,18 @@ void BaseSolver::ProbePostPrinter::AddMeasurementB(double idx_value_dimensionful
const PostOperator &post_op,
const IoData &iodata)
{
if (!do_measurement_B_ || !root_)
if (!post_op.HasB())
{
return;
}
using VT = IoData::ValueType;
using fmt::format;

auto probe_field = post_op.ProbeBField();
const int v_dim = post_op.GetInterpolationOpVDim();
const bool has_imag = post_op.HasImag();
MFEM_VERIFY(probe_field.size() == v_dim * post_op.GetProbes().size(),
format("Size mismatch: expect vector field to ahve size {} * {} = {}; got {}",
format("Size mismatch: expect vector field to have size {} * {} = {}; got {}",
v_dim, post_op.GetProbes().size(), v_dim * post_op.GetProbes().size(),
probe_field.size()))

Expand All @@ -698,20 +683,17 @@ void BaseSolver::ProbePostPrinter::AddMeasurement(double idx_value_dimensionful,
const IoData &iodata)
{
#if defined(MFEM_USE_GSLIB)
if (!Mpi::Root(post_op.GetComm()) || post_op.GetProbes().size() == 0)
{
return;
}
AddMeasurementE(idx_value_dimensionful, post_op, iodata);
AddMeasurementB(idx_value_dimensionful, post_op, iodata);
#endif
}

BaseSolver::ErrorIndicatorPostPrinter::ErrorIndicatorPostPrinter(bool do_measurement,
bool root,
const fs::path &post_dir)
: root_{root}, do_measurement_{do_measurement}
BaseSolver::ErrorIndicatorPostPrinter::ErrorIndicatorPostPrinter(const fs::path &post_dir)
{
if (!do_measurement_ || !root_)
{
return;
}
error_indicator = TableWithCSVFile(post_dir / "error-indicators.csv");
error_indicator.table.reserve(1, 4);

Expand All @@ -724,7 +706,7 @@ BaseSolver::ErrorIndicatorPostPrinter::ErrorIndicatorPostPrinter(bool do_measure
void BaseSolver::ErrorIndicatorPostPrinter::PrintIndicatorStatistics(
const PostOperator &post_op, const ErrorIndicator::SummaryStatistics &indicator_stats)
{
if (!do_measurement_ || !root_)
if (!Mpi::Root(post_op.GetComm()))
{
return;
}
Expand Down
47 changes: 14 additions & 33 deletions palace/drivers/basesolver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,63 +38,47 @@ class BaseSolver
// Common domain postprocessing for all simulation types.
class DomainsPostPrinter
{
bool root_ = false;
bool do_measurement_ = false;
TableWithCSVFile domain_E;

public:
DomainsPostPrinter() = default;
DomainsPostPrinter(bool do_measurement, bool root, const fs::path &post_dir,
const PostOperator &post_op, const std::string &idx_col_name,
int n_expected_rows);
DomainsPostPrinter(const fs::path &post_dir, const PostOperator &post_op,
const std::string &idx_col_name, int n_expected_rows);
void AddMeasurement(double idx_value_dimensionful, const PostOperator &post_op,
const IoData &iodata);
};

// Common surface postprocessing for all simulation types.
class SurfacesPostPrinter
{
bool root_ = false;
bool do_measurement_flux_ = false;
bool do_measurement_eps_ = false;
TableWithCSVFile surface_F;
TableWithCSVFile surface_Q;

public:
SurfacesPostPrinter() = default;
SurfacesPostPrinter(bool do_measurement, bool root, const fs::path &post_dir,
const PostOperator &post_op, const std::string &idx_col_name,
int n_expected_rows);
void AddMeasurement(double idx_value_dimensionful, const PostOperator &post_op,
const IoData &iodata);
void AddMeasurementFlux(double idx_value_dimensionful, const PostOperator &post_op,
const IoData &iodata);
void AddMeasurementEps(double idx_value_dimensionful, const PostOperator &post_op,
const IoData &iodata);

public:
SurfacesPostPrinter(const fs::path &post_dir, const PostOperator &post_op,
const std::string &idx_col_name, int n_expected_rows);
void AddMeasurement(double idx_value_dimensionful, const PostOperator &post_op,
const IoData &iodata);
};

// Common probe postprocessing for all simulation types.
class ProbePostPrinter
{
bool root_ = false;
bool do_measurement_E_ = false;
bool do_measurement_B_ = false;
TableWithCSVFile probe_E;
TableWithCSVFile probe_B;

int v_dim = 0;
bool has_imag = false;

public:
ProbePostPrinter() = default;
ProbePostPrinter(bool do_measurement, bool root, const fs::path &post_dir,
const PostOperator &post_op, const std::string &idx_col_name,
int n_expected_rows);

void AddMeasurementE(double idx_value_dimensionful, const PostOperator &post_op,
const IoData &iodata);
void AddMeasurementB(double idx_value_dimensionful, const PostOperator &post_op,
const IoData &iodata);

public:
ProbePostPrinter(const fs::path &post_dir, const PostOperator &post_op,
const std::string &idx_col_name, int n_expected_rows);

void AddMeasurement(double idx_value_dimensionful, const PostOperator &post_op,
const IoData &iodata);
};
Expand All @@ -104,13 +88,10 @@ class BaseSolver
// step (time / frequency / eigenvector).
class ErrorIndicatorPostPrinter
{
bool root_ = false;
bool do_measurement_ = false;
TableWithCSVFile error_indicator;

public:
ErrorIndicatorPostPrinter() = default;
ErrorIndicatorPostPrinter(bool do_measurement, bool root, const fs::path &post_dir);
ErrorIndicatorPostPrinter(const fs::path &post_dir);

void PrintIndicatorStatistics(const PostOperator &post_op,
const ErrorIndicator::SummaryStatistics &indicator_stats);
Expand Down
Loading

0 comments on commit d3e0ce8

Please sign in to comment.