Skip to content

Commit

Permalink
Merge branch 'main' into chore/warning-reduction
Browse files Browse the repository at this point in the history
  • Loading branch information
kodiakhq[bot] authored Aug 20, 2024
2 parents 4b7b772 + 6fa403e commit bd96716
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 21 deletions.
12 changes: 7 additions & 5 deletions Core/src/Propagator/SympyStepper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,8 @@ Result<double> SympyStepper::stepImpl(
double m = particleHypothesis(state).mass();
double p_abs = absoluteMomentum(state);

auto getB = [&](const double* p) -> Vector3 {
auto fieldRes = getField(state, {p[0], p[1], p[2]});
return *fieldRes;
auto getB = [&](const double* p) -> Result<Vector3> {
return getField(state, {p[0], p[1], p[2]});
};

const auto calcStepSizeScaling = [&](const double errorEstimate_) -> double {
Expand Down Expand Up @@ -155,17 +154,20 @@ Result<double> SympyStepper::stepImpl(
nStepTrials++;

// For details about the factor 4 see ATL-SOFT-PUB-2009-001
bool ok =
Result<bool> res =
rk4(pos.data(), dir.data(), t, h, qop, m, p_abs, getB, &errorEstimate,
4 * stepTolerance, state.pars.template segment<3>(eFreePos0).data(),
state.pars.template segment<3>(eFreeDir0).data(),
state.pars.template segment<1>(eFreeTime).data(),
state.derivative.data(),
state.covTransport ? state.jacTransport.data() : nullptr);
if (!res.ok()) {
return res.error();
}
// Protect against division by zero
errorEstimate = std::max(1e-20, errorEstimate);

if (ok) {
if (*res) {
break;
}

Expand Down
33 changes: 24 additions & 9 deletions Core/src/Propagator/codegen/sympy_stepper_math.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,20 @@

#pragma once

#include <Acts/Utilities/Result.hpp>

#include <cmath>

template <typename T, typename GetB>
bool rk4(const T* p, const T* d, const T t, const T h, const T lambda,
const T m, const T p_abs, GetB getB, T* err, const T errTol, T* new_p,
T* new_d, T* new_time, T* path_derivatives, T* J) {
const auto B1 = getB(p);
Acts::Result<bool> rk4(const T* p, const T* d, const T t, const T h,
const T lambda, const T m, const T p_abs, GetB getB,
T* err, const T errTol, T* new_p, T* new_d, T* new_time,
T* path_derivatives, T* J) {
const auto B1res = getB(p);
if (!B1res.ok()) {
return Acts::Result<bool>::failure(B1res.error());
}
const auto B1 = *B1res;
const auto x5 = std::pow(h, 2);
const auto x0 = B1[1] * d[2];
const auto x1 = B1[0] * d[2];
Expand All @@ -35,7 +42,11 @@ bool rk4(const T* p, const T* d, const T t, const T h, const T lambda,
p2[0] = (1.0 / 2.0) * x4 + x6 * k1[0] + p[0];
p2[1] = x6 * k1[1] + (1.0 / 2.0) * x7 + p[1];
p2[2] = x6 * k1[2] + (1.0 / 2.0) * x8 + p[2];
const auto B2 = getB(p2);
const auto B2res = getB(p2);
if (!B2res.ok()) {
return Acts::Result<bool>::failure(B2res.error());
}
const auto B2 = *B2res;
const auto x9 = (1.0 / 2.0) * h;
const auto x19 = (1.0 / 2.0) * x5;
const auto x11 = lambda * B2[2];
Expand All @@ -62,7 +73,11 @@ bool rk4(const T* p, const T* d, const T t, const T h, const T lambda,
p3[0] = x19 * k3[0] + x20;
p3[1] = x19 * k3[1] + x21;
p3[2] = x19 * k3[2] + x22;
const auto B3 = getB(p3);
const auto B3res = getB(p3);
if (!B3res.ok()) {
return Acts::Result<bool>::failure(B3res.error());
}
const auto B3 = *B3res;
const auto x24 = lambda * B3[2];
const auto x26 = lambda * B3[1];
const auto x28 = lambda * B3[0];
Expand All @@ -80,7 +95,7 @@ bool rk4(const T* p, const T* d, const T t, const T h, const T lambda,
x5 * (std::fabs(-x29 + k2[0] + k3[0]) + std::fabs(-x30 + k2[1] + k3[1]) +
std::fabs(-x31 + k2[2] + k3[2]));
if (*err > errTol) {
return false;
return Acts::Result<bool>::success(false);
}
const auto x32 = (1.0 / 6.0) * x5;
new_p[0] = x20 + x32 * (k1[0] + k2[0] + k3[0]);
Expand All @@ -101,7 +116,7 @@ bool rk4(const T* p, const T* d, const T t, const T h, const T lambda,
const auto dtds = std::sqrt(std::pow(p_abs, 2) + x35) / p_abs;
*new_time = dtds * h + t;
if (J == nullptr) {
return true;
return Acts::Result<bool>::success(true);
}
path_derivatives[0] = new_d[0];
path_derivatives[1] = new_d[1];
Expand Down Expand Up @@ -345,5 +360,5 @@ bool rk4(const T* p, const T* d, const T t, const T h, const T lambda,
J[61] = new_J[61];
J[62] = new_J[62];
J[63] = new_J[63];
return true;
return Acts::Result<bool>::success(true);
}
22 changes: 15 additions & 7 deletions codegen/generate_sympy_stepper.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,10 +195,14 @@ def my_step_function_print(name_exprs, run_cse=True):

lines = []

head = "template <typename T, typename GetB> bool rk4(const T* p, const T* d, const T t, const T h, const T lambda, const T m, const T p_abs, GetB getB, T* err, const T errTol, T* new_p, T* new_d, T* new_time, T* path_derivatives, T* J) {"
head = "template <typename T, typename GetB> Acts::Result<bool> rk4(const T* p, const T* d, const T t, const T h, const T lambda, const T m, const T p_abs, GetB getB, T* err, const T errTol, T* new_p, T* new_d, T* new_time, T* path_derivatives, T* J) {"
lines.append(head)

lines.append(" const auto B1 = getB(p);")
lines.append(" const auto B1res = getB(p);")
lines.append(
" if (!B1res.ok()) {\n return Acts::Result<bool>::failure(B1res.error());\n }"
)
lines.append(" const auto B1 = *B1res;")

def pre_expr_hook(var):
if str(var) == "p2":
Expand All @@ -211,13 +215,15 @@ def pre_expr_hook(var):

def post_expr_hook(var):
if str(var) == "p2":
return "const auto B2 = getB(p2);"
return "const auto B2res = getB(p2);\n if (!B2res.ok()) {\n return Acts::Result<bool>::failure(B2res.error());\n }\n const auto B2 = *B2res;"
if str(var) == "p3":
return "const auto B3 = getB(p3);"
return "const auto B3res = getB(p3);\n if (!B3res.ok()) {\n return Acts::Result<bool>::failure(B3res.error());\n }\n const auto B3 = *B3res;"
if str(var) == "err":
return "if (*err > errTol) {\n return false;\n}"
return (
"if (*err > errTol) {\n return Acts::Result<bool>::success(false);\n}"
)
if str(var) == "new_time":
return "if (J == nullptr) {\n return true;\n}"
return "if (J == nullptr) {\n return Acts::Result<bool>::success(true);\n}"
if str(var) == "new_J":
return printer.doprint(Assignment(MatrixSymbol("J", 8, 8), var))
return None
Expand All @@ -232,7 +238,7 @@ def post_expr_hook(var):
)
lines.extend([f" {l}" for l in code.split("\n")])

lines.append(" return true;")
lines.append(" return Acts::Result<bool>::success(true);")

lines.append("}")

Expand Down Expand Up @@ -279,6 +285,8 @@ def post_expr_hook(var):
#pragma once
#include <Acts/Utilities/Result.hpp>
#include <cmath>
"""
)
Expand Down

0 comments on commit bd96716

Please sign in to comment.