From 6fa403e16183b3d8fbd4b94884d232bbd470dd56 Mon Sep 17 00:00:00 2001 From: Andreas Stefl Date: Tue, 20 Aug 2024 23:02:14 +0200 Subject: [PATCH] fix: Handle missing field in `SympyStepper` (#3525) A missing B field is not correctly handled in the `SympyStepper` and can result in segmentation faults. Fixes - https://github.com/acts-project/acts/issues/3523 --- Core/src/Propagator/SympyStepper.cpp | 12 ++++--- .../Propagator/codegen/sympy_stepper_math.hpp | 33 ++++++++++++++----- codegen/generate_sympy_stepper.py | 22 +++++++++---- 3 files changed, 46 insertions(+), 21 deletions(-) diff --git a/Core/src/Propagator/SympyStepper.cpp b/Core/src/Propagator/SympyStepper.cpp index 77b075be31c..669e17ee7cb 100644 --- a/Core/src/Propagator/SympyStepper.cpp +++ b/Core/src/Propagator/SympyStepper.cpp @@ -116,9 +116,8 @@ Result SympyStepper::stepImpl( double m = particleHypothesis(state).mass(); double p_abs = absoluteMomentum(state); - auto getB = [&](const double* p) -> Vector3 { - auto fieldRes = getField(state, {p[0], p[1], p[2]}); - return *fieldRes; + auto getB = [&](const double* p) -> Result { + return getField(state, {p[0], p[1], p[2]}); }; const auto calcStepSizeScaling = [&](const double errorEstimate_) -> double { @@ -155,17 +154,20 @@ Result SympyStepper::stepImpl( nStepTrials++; // For details about the factor 4 see ATL-SOFT-PUB-2009-001 - bool ok = + Result res = rk4(pos.data(), dir.data(), t, h, qop, m, p_abs, getB, &errorEstimate, 4 * stepTolerance, state.pars.template segment<3>(eFreePos0).data(), state.pars.template segment<3>(eFreeDir0).data(), state.pars.template segment<1>(eFreeTime).data(), state.derivative.data(), state.covTransport ? state.jacTransport.data() : nullptr); + if (!res.ok()) { + return res.error(); + } // Protect against division by zero errorEstimate = std::max(1e-20, errorEstimate); - if (ok) { + if (*res) { break; } diff --git a/Core/src/Propagator/codegen/sympy_stepper_math.hpp b/Core/src/Propagator/codegen/sympy_stepper_math.hpp index aa0017f3d1d..e65c911cb51 100644 --- a/Core/src/Propagator/codegen/sympy_stepper_math.hpp +++ b/Core/src/Propagator/codegen/sympy_stepper_math.hpp @@ -11,13 +11,20 @@ #pragma once +#include + #include template -bool rk4(const T* p, const T* d, const T t, const T h, const T lambda, - const T m, const T p_abs, GetB getB, T* err, const T errTol, T* new_p, - T* new_d, T* new_time, T* path_derivatives, T* J) { - const auto B1 = getB(p); +Acts::Result rk4(const T* p, const T* d, const T t, const T h, + const T lambda, const T m, const T p_abs, GetB getB, + T* err, const T errTol, T* new_p, T* new_d, T* new_time, + T* path_derivatives, T* J) { + const auto B1res = getB(p); + if (!B1res.ok()) { + return Acts::Result::failure(B1res.error()); + } + const auto B1 = *B1res; const auto x5 = std::pow(h, 2); const auto x0 = B1[1] * d[2]; const auto x1 = B1[0] * d[2]; @@ -35,7 +42,11 @@ bool rk4(const T* p, const T* d, const T t, const T h, const T lambda, p2[0] = (1.0 / 2.0) * x4 + x6 * k1[0] + p[0]; p2[1] = x6 * k1[1] + (1.0 / 2.0) * x7 + p[1]; p2[2] = x6 * k1[2] + (1.0 / 2.0) * x8 + p[2]; - const auto B2 = getB(p2); + const auto B2res = getB(p2); + if (!B2res.ok()) { + return Acts::Result::failure(B2res.error()); + } + const auto B2 = *B2res; const auto x9 = (1.0 / 2.0) * h; const auto x19 = (1.0 / 2.0) * x5; const auto x11 = lambda * B2[2]; @@ -62,7 +73,11 @@ bool rk4(const T* p, const T* d, const T t, const T h, const T lambda, p3[0] = x19 * k3[0] + x20; p3[1] = x19 * k3[1] + x21; p3[2] = x19 * k3[2] + x22; - const auto B3 = getB(p3); + const auto B3res = getB(p3); + if (!B3res.ok()) { + return Acts::Result::failure(B3res.error()); + } + const auto B3 = *B3res; const auto x24 = lambda * B3[2]; const auto x26 = lambda * B3[1]; const auto x28 = lambda * B3[0]; @@ -80,7 +95,7 @@ bool rk4(const T* p, const T* d, const T t, const T h, const T lambda, x5 * (std::fabs(-x29 + k2[0] + k3[0]) + std::fabs(-x30 + k2[1] + k3[1]) + std::fabs(-x31 + k2[2] + k3[2])); if (*err > errTol) { - return false; + return Acts::Result::success(false); } const auto x32 = (1.0 / 6.0) * x5; new_p[0] = x20 + x32 * (k1[0] + k2[0] + k3[0]); @@ -101,7 +116,7 @@ bool rk4(const T* p, const T* d, const T t, const T h, const T lambda, const auto dtds = std::sqrt(std::pow(p_abs, 2) + x35) / p_abs; *new_time = dtds * h + t; if (J == nullptr) { - return true; + return Acts::Result::success(true); } path_derivatives[0] = new_d[0]; path_derivatives[1] = new_d[1]; @@ -345,5 +360,5 @@ bool rk4(const T* p, const T* d, const T t, const T h, const T lambda, J[61] = new_J[61]; J[62] = new_J[62]; J[63] = new_J[63]; - return true; + return Acts::Result::success(true); } diff --git a/codegen/generate_sympy_stepper.py b/codegen/generate_sympy_stepper.py index b28ea2bc758..7602a30c4a8 100644 --- a/codegen/generate_sympy_stepper.py +++ b/codegen/generate_sympy_stepper.py @@ -195,10 +195,14 @@ def my_step_function_print(name_exprs, run_cse=True): lines = [] - head = "template bool rk4(const T* p, const T* d, const T t, const T h, const T lambda, const T m, const T p_abs, GetB getB, T* err, const T errTol, T* new_p, T* new_d, T* new_time, T* path_derivatives, T* J) {" + head = "template Acts::Result rk4(const T* p, const T* d, const T t, const T h, const T lambda, const T m, const T p_abs, GetB getB, T* err, const T errTol, T* new_p, T* new_d, T* new_time, T* path_derivatives, T* J) {" lines.append(head) - lines.append(" const auto B1 = getB(p);") + lines.append(" const auto B1res = getB(p);") + lines.append( + " if (!B1res.ok()) {\n return Acts::Result::failure(B1res.error());\n }" + ) + lines.append(" const auto B1 = *B1res;") def pre_expr_hook(var): if str(var) == "p2": @@ -211,13 +215,15 @@ def pre_expr_hook(var): def post_expr_hook(var): if str(var) == "p2": - return "const auto B2 = getB(p2);" + return "const auto B2res = getB(p2);\n if (!B2res.ok()) {\n return Acts::Result::failure(B2res.error());\n }\n const auto B2 = *B2res;" if str(var) == "p3": - return "const auto B3 = getB(p3);" + return "const auto B3res = getB(p3);\n if (!B3res.ok()) {\n return Acts::Result::failure(B3res.error());\n }\n const auto B3 = *B3res;" if str(var) == "err": - return "if (*err > errTol) {\n return false;\n}" + return ( + "if (*err > errTol) {\n return Acts::Result::success(false);\n}" + ) if str(var) == "new_time": - return "if (J == nullptr) {\n return true;\n}" + return "if (J == nullptr) {\n return Acts::Result::success(true);\n}" if str(var) == "new_J": return printer.doprint(Assignment(MatrixSymbol("J", 8, 8), var)) return None @@ -232,7 +238,7 @@ def post_expr_hook(var): ) lines.extend([f" {l}" for l in code.split("\n")]) - lines.append(" return true;") + lines.append(" return Acts::Result::success(true);") lines.append("}") @@ -279,6 +285,8 @@ def post_expr_hook(var): #pragma once +#include + #include """ )