Skip to content

Commit

Permalink
pr comments
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolasaunai committed Dec 19, 2023
1 parent 83280fb commit 94fac3e
Show file tree
Hide file tree
Showing 10 changed files with 51 additions and 63 deletions.
2 changes: 1 addition & 1 deletion pyphare/pyphare/pharein/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ def as_paths(rb):

#### adding diagnostics
diag_path = "simulation/diagnostics/"
for _, diag in simulation.diagnostics.items():
for diag in list(simulation.diagnostics.values()):
type_path = diag_path + diag.type + "/"
name_path = type_path + diag.name
add_string(name_path + "/" + "type", diag.type)
Expand Down
4 changes: 2 additions & 2 deletions src/core/data/ions/ion_population/ion_population.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -191,8 +191,8 @@ namespace core
NO_DISCARD VecField const& flux() const { return flux_; }
NO_DISCARD VecField& flux() { return flux_; }

TensorField const& momentumTensor() const { return momentumTensor_; }
TensorField& momentumTensor() { return momentumTensor_; }
NO_DISCARD TensorField const& momentumTensor() const { return momentumTensor_; }
NO_DISCARD TensorField& momentumTensor() { return momentumTensor_; }



Expand Down
3 changes: 1 addition & 2 deletions src/core/data/ions/ions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,7 @@ namespace core
{
if (isUsable())
return sameMasses_ ? *rho_ : *massDensity_;
else
throw std::runtime_error("Error - cannot access density data");
throw std::runtime_error("Error - cannot access density data");
}


Expand Down
75 changes: 29 additions & 46 deletions src/core/data/tensorfield/tensorfield.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,71 +116,54 @@ class TensorField
NO_DISCARD std::string const& name() const { return name_; }



NO_DISCARD field_type& getComponent(Component component)
template<typename Arg>
NO_DISCARD decltype(auto) static _switcheroo(Component component, Arg& arg)
{
if (isUsable())
switch (component)
{
switch (component)
{
case Component::X: return *components_[0];
case Component::Y: return *components_[1];
case Component::Z: return *components_[2];

case Component::XX: return *components_[0];
case Component::XY: return *components_[1];
case Component::XZ: return *components_[2];
case Component::YY: return *components_[3];
case Component::YZ: return *components_[4];
case Component::ZZ: return *components_[5];
}
case Component::X: return arg[0];
case Component::Y: return arg[1];
case Component::Z: return arg[2];

case Component::XX: return arg[0];
case Component::XY: return arg[1];
case Component::XZ: return arg[2];
case Component::YY: return arg[3];
case Component::YZ: return arg[4];
case Component::ZZ: return arg[5];
}
throw std::runtime_error("Error - TensorField not usable");
}

void _check() const
{
if (!isUsable())
throw std::runtime_error("Error - TensorField not usable");
}

NO_DISCARD field_type& getComponent(Component component)
{
_check();
return *_switcheroo(component, components_);
}




NO_DISCARD field_type const& getComponent(Component component) const
{
if (isUsable())
{
switch (component)
{
case Component::X: return *components_[0];
case Component::Y: return *components_[1];
case Component::Z: return *components_[2];

case Component::XX: return *components_[0];
case Component::XY: return *components_[1];
case Component::XZ: return *components_[2];
case Component::YY: return *components_[3];
case Component::YZ: return *components_[4];
case Component::ZZ: return *components_[5];
}
}
throw std::runtime_error("Error - TensorField not usable");
_check();
return *_switcheroo(component, components_);
}



NO_DISCARD std::string getComponentName(Component component) const
{
switch (component)
{
case Component::X: return componentNames_[0];
case Component::Y: return componentNames_[1];
case Component::Z: return componentNames_[2];
case Component::XX: return componentNames_[0];
case Component::XY: return componentNames_[1];
case Component::XZ: return componentNames_[2];
case Component::YY: return componentNames_[3];
case Component::YZ: return componentNames_[4];
case Component::ZZ: return componentNames_[5];
}
throw std::runtime_error("Error - TensorField not usable");
return _switcheroo(component, componentNames_);
}


template<std::size_t... Index>
NO_DISCARD auto components(std::index_sequence<Index...>) const
{
Expand Down
10 changes: 10 additions & 0 deletions src/core/data/vecfield/vecfield_component.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,16 @@ namespace core
{"xx", Component::XX}, {"xy", Component::XY}, {"xz", Component::XZ},
{"yy", Component::YY}, {"yz", Component::YZ}, {"zz", Component::ZZ}};

struct VectorComponents
{
auto static map() { return Components::componentMap(); }
};

struct TensorComponents
{
auto static map() { return Components::componentMap<2>(); }
};

} // namespace core
} // namespace PHARE

Expand Down
4 changes: 2 additions & 2 deletions src/diagnostic/detail/types/fluid.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,12 +184,12 @@ void FluidDiagnosticWriter<H5Writer>::getDataSetInfo(DiagnosticProperties& diagn
};

auto infoVF = [&](auto& vecF, std::string name, auto& attr) {
for (auto const& [id, type] : core::Components::componentMap<1>())
for (auto const& [id, type] : core::VectorComponents::map())
infoDS(vecF.getComponent(type), name + "_" + id, attr);
};

auto infoTF = [&](auto& tensorF, std::string name, auto& attr) {
for (auto const& [id, type] : core::Components::componentMap<2>())
for (auto const& [id, type] : core::TensorComponents::map())
infoDS(tensorF.getComponent(type), name + "_" + id, attr);
};

Expand Down
4 changes: 3 additions & 1 deletion tests/core/data/gridlayout/test_laplacian.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,9 @@ def test_laplacian_yee2D(path):
Jz = np.tensordot(np.sinh(0.2 * x_primal), np.cosh(0.2 * y_primal), axes=0)

Jx_x[psi_d_X : pei_d_X + 1, :] = (
Jx[psi_d_X + 1 : pei_d_X + 2 :,]
Jx[
psi_d_X + 1 : pei_d_X + 2 :,
]
- 2.0 * Jx[psi_d_X : pei_d_X + 1, :]
+ Jx[psi_d_X - 1 : pei_d_X, :]
) / (tv.meshSize[0] * tv.meshSize[0])
Expand Down
4 changes: 2 additions & 2 deletions tests/core/data/ion_population/test_ion_population.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,15 @@ struct DummyVecField
{
static constexpr std::size_t dimension = 1;
using field_type = DummyField;
DummyVecField(std::string name, [[maybe_unused]] HybridQuantity::Vector v) { (void)name; }
DummyVecField(std::string name, HybridQuantity::Vector /*v*/) { (void)name; }
bool isUsable() const { return false; }
bool isSettable() const { return true; }
};

struct DummyTensorField
{
static constexpr std::size_t dimension = 1;
DummyTensorField(std::string name, [[maybe_unused]] HybridQuantity::Tensor v) { (void)name; }
DummyTensorField(std::string name, HybridQuantity::Tensor /*v*/) { (void)name; }
bool isUsable() const { return false; }
bool isSettable() const { return true; }
};
Expand Down
2 changes: 1 addition & 1 deletion tests/core/data/ions/test_ions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
#include "core/data/vecfield/vecfield.hpp"
#include "core/hybrid/hybrid_quantities.hpp"

#include "core/data/tensorfield//tensorfield.hpp"
#include "core/data/tensorfield/tensorfield.hpp"
#include "core/data/grid/gridlayout.hpp"
#include "core/data/grid/gridlayout_impl.hpp"
#include "core/data/ions/particle_initializers/maxwellian_particle_initializer.hpp"
Expand Down
6 changes: 0 additions & 6 deletions tests/simulator/test_diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,12 +233,6 @@ def _test_dump_diags(self, dim, **simInput):
self.simulator = None
ph.global_vars.sim = None

# def test_twice_register(self):
# simulation = ph.Simulation(**simArgs.copy())
# model = setup_model()
# dump_all_diags(model.populations) # first register
# self.assertRaises(RuntimeError, dump_all_diags, model.populations)


if __name__ == "__main__":
unittest.main()

0 comments on commit 94fac3e

Please sign in to comment.