Skip to content

Commit

Permalink
Merge pull request #549 from beomki-yeo/refactor-simulator
Browse files Browse the repository at this point in the history
Refactor Simulator Template Parameter
  • Loading branch information
beomki-yeo authored Aug 30, 2023
2 parents cd3adc8 + ef1171a commit 57fd086
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 26 deletions.
9 changes: 8 additions & 1 deletion tests/simulation/run_simulation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,15 @@ int main() {
measurement_smearer<transform3> smearer(100.f * unit<scalar>::um,
100.f * unit<scalar>::um);

using detector_type = decltype(detector);
using generator_type = decltype(generator);
using writer_type = smearing_writer<measurement_smearer<transform3>>;

typename writer_type::config writer_cfg{smearer};

std::size_t n_events = 2u;
auto sim = simulator(n_events, detector, std::move(generator), smearer);
auto sim = simulator<detector_type, generator_type, writer_type>(
n_events, detector, std::move(generator), std::move(writer_cfg));

sim.run();

Expand Down
23 changes: 19 additions & 4 deletions tests/validation/src/simulation_validation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,16 @@ GTEST_TEST(detray_simulation, toy_geometry_simulation) {
170.f * unit<scalar>::um);

std::size_t n_events{10u};
auto sim = simulator(n_events, detector, std::move(generator), smearer,
test::filenames);

using detector_type = decltype(detector);
using generator_type = decltype(generator);
using writer_type = smearing_writer<measurement_smearer<transform3>>;

typename writer_type::config writer_cfg{smearer};

auto sim = simulator<detector_type, generator_type, writer_type>(
n_events, detector, std::move(generator), std::move(writer_cfg),
test::filenames);

// Lift step size constraints
sim.get_config().step_constraint = std::numeric_limits<scalar>::max();
Expand Down Expand Up @@ -245,8 +253,15 @@ TEST_P(TelescopeDetectorSimulation, telescope_detector_simulation) {

std::size_t n_events{1000u};

auto sim = simulator(n_events, detector, std::move(generator), smearer,
test::filenames);
using detector_type = decltype(detector);
using generator_type = decltype(generator);
using writer_type = smearing_writer<measurement_smearer<transform3>>;

typename writer_type::config writer_cfg{smearer};

auto sim = simulator<detector_type, generator_type, writer_type>(
n_events, detector, std::move(generator), std::move(writer_cfg),
test::filenames);

// Lift step size constraints
sim.get_config().step_constraint = std::numeric_limits<scalar>::max();
Expand Down
1 change: 1 addition & 0 deletions utils/include/detray/simulation/measurement_smearer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ namespace detray {
template <typename transform3_t>
struct measurement_smearer {

using transform3_type = transform3_t;
using matrix_operator = typename transform3_t::matrix_actor;
using scalar_type = typename transform3_t::scalar_type;
using size_type = typename matrix_operator::size_ty;
Expand Down
31 changes: 17 additions & 14 deletions utils/include/detray/simulation/simulator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,16 @@
#include "detray/propagator/navigator.hpp"
#include "detray/propagator/propagator.hpp"
#include "detray/propagator/rk_stepper.hpp"
#include "detray/simulation/event_writer.hpp"
#include "detray/simulation/random_scatterer.hpp"
#include "detray/simulation/smearing_writer.hpp"

// System include(s).
#include <limits>
#include <memory>

namespace detray {

template <typename detector_t, typename track_generator_t, typename smearer_t>
template <typename detector_t, typename track_generator_t, typename writer_t>
struct simulator {

using scalar_type = typename detector_t::scalar_type;
Expand All @@ -37,9 +37,10 @@ struct simulator {
using transform3 = typename detector_t::transform3;
using bfield_type = typename detector_t::bfield_type;

using actor_chain_type = actor_chain<
dtuple, parameter_transporter<transform3>, random_scatterer<transform3>,
parameter_resetter<transform3>, event_writer<transform3, smearer_t>>;
using actor_chain_type =
actor_chain<dtuple, parameter_transporter<transform3>,
random_scatterer<transform3>,
parameter_resetter<transform3>, writer_t>;

using navigator_type = navigator<detector_t>;
using stepper_type = rk_stepper<typename bfield_type::view_t, transform3,
Expand All @@ -48,33 +49,35 @@ struct simulator {
propagator<stepper_type, navigator_type, actor_chain_type>;

simulator(std::size_t events, const detector_t& det,
track_generator_t&& track_gen, smearer_t& smearer,
track_generator_t&& track_gen,
typename writer_t::config&& writer_cfg,
const std::string directory = "")
: m_events(events),
m_directory(directory),
m_detector(std::make_unique<detector_t>(det)),
m_track_generator(
std::make_unique<track_generator_t>(std::move(track_gen))),
m_smearer(smearer) {}
m_writer_cfg(writer_cfg) {}

config& get_config() { return m_cfg; }

void run() {

for (std::size_t event_id = 0u; event_id < m_events; event_id++) {
typename event_writer<transform3, smearer_t>::state writer(
event_id, m_smearer, m_directory);

typename writer_t::state writer_state(
event_id, std::move(m_writer_cfg), m_directory);

// Set random seed
m_scatterer.set_seed(event_id);
writer.set_seed(event_id);
writer_state.set_seed(event_id);

auto actor_states =
std::tie(m_transporter, m_scatterer, m_resetter, writer);
std::tie(m_transporter, m_scatterer, m_resetter, writer_state);

for (auto track : *m_track_generator.get()) {

writer.write_particle(track);
writer_state.write_particle(track);

typename propagator_type::state propagation(
track, m_detector->get_bfield(), *m_detector);
Expand All @@ -91,7 +94,7 @@ struct simulator {
p.propagate(propagation, actor_states);

// Increase the particle id
writer.particle_id++;
writer_state.particle_id++;
}
}
}
Expand All @@ -102,7 +105,7 @@ struct simulator {
std::string m_directory = "";
std::unique_ptr<detector_t> m_detector;
std::unique_ptr<track_generator_t> m_track_generator;
smearer_t m_smearer;
typename writer_t::config m_writer_cfg;

/// Actor states
typename parameter_transporter<transform3>::state m_transporter{};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,18 @@

namespace detray {

template <typename transform3_t, typename smearer_t>
struct event_writer : actor {
template <typename smearer_t>
struct smearing_writer : actor {

using scalar_type = typename transform3_t::scalar_type;
using transform3_type = typename smearer_t::transform3_type;
using scalar_type = typename transform3_type::scalar_type;

struct config {
smearer_t smearer;
};

struct state {
state(std::size_t event_id, smearer_t& smearer,
state(std::size_t event_id, config&& writer_cfg,
const std::string directory)
: m_particle_writer(directory + detail::get_event_filename(
event_id, "-particles.csv")),
Expand All @@ -41,7 +46,7 @@ struct event_writer : actor {
m_meas_hit_id_writer(
directory + detail::get_event_filename(
event_id, "-measurement-simhit-map.csv")),
m_meas_smearer(smearer) {}
m_meas_smearer(writer_cfg.smearer) {}

uint64_t particle_id = 0u;
particle_writer m_particle_writer;
Expand All @@ -53,7 +58,8 @@ struct event_writer : actor {

void set_seed(const uint_fast64_t sd) { m_meas_smearer.set_seed(sd); }

void write_particle(const free_track_parameters<transform3_t>& track) {
void write_particle(
const free_track_parameters<transform3_type>& track) {
csv_particle particle;
const auto pos = track.pos();
const auto mom = track.mom();
Expand All @@ -77,7 +83,7 @@ struct event_writer : actor {
template <typename mask_group_t, typename index_t>
inline std::array<scalar_type, 2> operator()(
const mask_group_t& mask_group, const index_t& index,
const bound_track_parameters<transform3_t>& bound_params,
const bound_track_parameters<transform3_type>& bound_params,
smearer_t& smearer) const {

const auto& mask = mask_group[index];
Expand Down

0 comments on commit 57fd086

Please sign in to comment.