From dc68aeeda6bd2cda3a3e08a99cf95ed423086e46 Mon Sep 17 00:00:00 2001 From: Morten Schou Date: Fri, 7 May 2021 17:33:46 +0200 Subject: [PATCH] Refactor TraceBack in Solver --- src/pdaaal/Solver.h | 315 ++++++++++++++++---------------------------- 1 file changed, 114 insertions(+), 201 deletions(-) diff --git a/src/pdaaal/Solver.h b/src/pdaaal/Solver.h index d1265af..e9a5ceb 100644 --- a/src/pdaaal/Solver.h +++ b/src/pdaaal/Solver.h @@ -75,9 +75,9 @@ namespace pdaaal { using early_termination_fn = std::function)>; template - class pre_star_saturation { + class PreStarSaturation { public: - explicit pre_star_saturation(PAutomaton &automaton, const early_termination_fn& early_termination = [](size_t f, uint32_t l, size_t t, trace_ptr trace) -> bool { return false; }) + explicit PreStarSaturation(PAutomaton &automaton, const early_termination_fn& early_termination = [](size_t f, uint32_t l, size_t t, trace_ptr trace) -> bool { return false; }) : _automaton(automaton), _early_termination(early_termination), _pda_states(_automaton.pda().states()), _n_pda_states(_pda_states.size()), _n_automaton_states(_automaton.states().size()), _n_pda_labels(_automaton.number_of_labels()), _rel(_n_automaton_states), _delta_prime(_n_automaton_states) { @@ -212,9 +212,9 @@ namespace pdaaal { }; template - class post_star_saturation { + class PostStarSaturation { public: - explicit post_star_saturation(PAutomaton &automaton, const early_termination_fn& early_termination = [](size_t f, uint32_t l, size_t t, trace_ptr trace) -> bool { return false; }) + explicit PostStarSaturation(PAutomaton &automaton, const early_termination_fn& early_termination = [](size_t f, uint32_t l, size_t t, trace_ptr trace) -> bool { return false; }) : _automaton(automaton), _early_termination(early_termination), _pda_states(_automaton.pda().states()), _n_pda_states(_pda_states.size()), _n_Q(_automaton.states().size()) { initialize(); @@ -353,7 +353,7 @@ namespace pdaaal { }; template> - class post_star_shortest_saturation { + class PostStarShortestSaturation { static_assert(is_weighted); struct weight_edge_trace { @@ -388,7 +388,7 @@ namespace pdaaal { }; public: - post_star_shortest_saturation(PAutomaton &automaton, const early_termination_fn& early_termination) + PostStarShortestSaturation(PAutomaton &automaton, const early_termination_fn& early_termination) : _automaton(automaton), _early_termination(early_termination), _pda_states(_automaton.pda().states()), _n_pda_states(_pda_states.size()), _n_Q(_automaton.states().size()) { initialize(); @@ -607,6 +607,76 @@ namespace pdaaal { return _found; } }; + + template + class TraceBack { + using rule_t = user_rule_t; + public: + TraceBack(const PAutomaton& automaton, std::deque>&& edges) + : _automaton(automaton), _edges(std::move(edges)) { }; + private: + const PAutomaton& _automaton; + std::deque> _edges; + bool _post = false; + public: + [[nodiscard]] bool post() const { return _post; } + [[nodiscard]] const std::deque>& edges() const { return _edges; } + std::optional next() { + while(true) { // In case of post_epsilon_trace, keep going until a rule is found or we are done. + auto[from, label, to] = _edges.back(); + const trace_t* trace_label = _automaton.get_trace_label(from, label, to); + if (trace_label == nullptr) return std::nullopt; // Done + _edges.pop_back(); + + if (trace_label->is_pre_trace()) { + // pre* trace + const auto &[rule, labels] = _automaton.pda().states()[from]._rules[trace_label->_rule_id]; + switch (rule._operation) { + case POP: + break; + case SWAP: + _edges.emplace_back(rule._to, rule._op_label, to); + break; + case NOOP: + _edges.emplace_back(rule._to, label, to); + break; + case PUSH: + _edges.emplace_back(trace_label->_state, label, to); + _edges.emplace_back(rule._to, rule._op_label, trace_label->_state); + break; + } + return rule_t(from, label, rule); + } else if (trace_label->is_post_epsilon_trace()) { + // Intermediate post* trace + // Current edge is the result of merging with an epsilon edge. + // Reconstruct epsilon edge and the other edge. + _edges.emplace_back(trace_label->_state, label, to); + _edges.emplace_back(from, std::numeric_limits::max(), trace_label->_state); + + } else { // post* trace + _post = true; + const auto &[rule, labels] = _automaton.pda().states()[trace_label->_state]._rules[trace_label->_rule_id]; + switch (rule._operation) { + case POP: + case SWAP: + case NOOP: + _edges.emplace_back(trace_label->_state, trace_label->_label, to); + break; + case PUSH: + auto[from2, label2, to2] = _edges.back(); + _edges.pop_back(); + trace_label = _automaton.get_trace_label(from2, label2, to2); + assert(trace_label != nullptr); + _edges.emplace_back(trace_label->_state, trace_label->_label, to2); + break; + } + assert(from == rule._to); + return rule_t(trace_label->_state, trace_label->_label, rule); + } + } + } + }; + } class Solver { @@ -629,8 +699,8 @@ namespace pdaaal { static bool dual_search(PAutomaton &pre_star_automaton, PAutomaton &post_star_automaton, const details::early_termination_fn& pre_star_early_termination, const details::early_termination_fn& post_star_early_termination) { - details::pre_star_saturation pre_star(pre_star_automaton, pre_star_early_termination); - details::post_star_saturation post_star(post_star_automaton, post_star_early_termination); + details::PreStarSaturation pre_star(pre_star_automaton, pre_star_early_termination); + details::PostStarSaturation post_star(post_star_automaton, post_star_early_termination); if constexpr (ET) { if (pre_star.found() || post_star.found()) return true; } @@ -671,7 +741,7 @@ namespace pdaaal { template static bool pre_star(PAutomaton &automaton, const details::early_termination_fn& early_termination = [](size_t f, uint32_t l, size_t t, trace_ptr trace) -> bool { return false; }) { - details::pre_star_saturation saturation(automaton, early_termination); + details::PreStarSaturation saturation(automaton, early_termination); while(!saturation.workset_empty()) { if constexpr (ET) { if (saturation.found()) return true; @@ -792,28 +862,28 @@ namespace pdaaal { } private: - template> - static bool post_star_shortest(PAutomaton &automaton, const details::early_termination_fn& early_termination) { - details::post_star_shortest_saturation saturation(automaton, early_termination); + template + static bool post_star_any(PAutomaton &automaton, const details::early_termination_fn& early_termination) { + details::PostStarSaturation saturation(automaton, early_termination); while(!saturation.workset_empty()) { if constexpr (ET) { - if (saturation.found()) break; + if (saturation.found()) return true; } saturation.step(); } - saturation.finalize(); return saturation.found(); } - template - static bool post_star_any(PAutomaton &automaton, const details::early_termination_fn& early_termination) { - details::post_star_saturation saturation(automaton, early_termination); + template> + static bool post_star_shortest(PAutomaton &automaton, const details::early_termination_fn& early_termination) { + details::PostStarShortestSaturation saturation(automaton, early_termination); while(!saturation.workset_empty()) { if constexpr (ET) { - if (saturation.found()) return true; + if (saturation.found()) break; } saturation.step(); } + saturation.finalize(); return saturation.found(); } @@ -841,63 +911,13 @@ namespace pdaaal { return result; }; - bool post = false; - std::vector trace; trace.push_back(decode_edges(edges)); - while (true) { - auto [from, label, to] = edges.back(); - edges.pop_back(); - const trace_t *trace_label = automaton.get_trace_label(from, label, to); - if (trace_label == nullptr) break; - - if (trace_label->is_pre_trace()) { - // pre* trace - const auto &[rule,labels] = automaton.pda().states()[from]._rules[trace_label->_rule_id]; - switch (rule._operation) { - case POP: - break; - case SWAP: - edges.emplace_back(rule._to, rule._op_label, to); - break; - case NOOP: - edges.emplace_back(rule._to, label, to); - break; - case PUSH: - edges.emplace_back(trace_label->_state, label, to); - edges.emplace_back(rule._to, rule._op_label, trace_label->_state); - break; - } - trace.push_back(decode_edges(edges)); - - } else if (trace_label->is_post_epsilon_trace()) { - // Intermediate post* trace - // Current edge is the result of merging with an epsilon edge. - // Reconstruct epsilon edge and the other edge. - edges.emplace_back(trace_label->_state, label, to); - edges.emplace_back(from, std::numeric_limits::max(), trace_label->_state); - - } else { - // post* trace - const auto &[rule, labels] = automaton.pda().states()[trace_label->_state]._rules[trace_label->_rule_id]; - switch (rule._operation) { - case POP: - case SWAP: - case NOOP: - edges.emplace_back(trace_label->_state, trace_label->_label, to); - break; - case PUSH: - auto [from2, label2, to2] = edges.back(); - edges.pop_back(); - auto trace_label2 = automaton.get_trace_label(from2, label2, to2); - edges.emplace_back(trace_label2->_state, trace_label2->_label, to2); - break; - } - trace.push_back(decode_edges(edges)); - post = true; - } + details::TraceBack tb(automaton, std::move(edges)); + while (tb.next()) { + trace.push_back(decode_edges(tb.edges())); } - if (post) { + if (tb.post()) { std::reverse(trace.begin(), trace.end()); } return trace; @@ -933,74 +953,24 @@ namespace pdaaal { edges.emplace_back(paths[i - 1].first, stack[i - 1], paths[i].first); } - bool post = false; - + details::TraceBack tb(automaton, std::move(edges)); std::vector trace; - while (true) { - auto [from, label, to] = edges.back(); - const trace_t *trace_label = automaton.get_trace_label(from, label, to); - if (trace_label == nullptr) break; - edges.pop_back(); - - if (trace_label->is_pre_trace()) { - // pre* trace - const auto &[rule,labels] = automaton.pda().states()[from]._rules[trace_label->_rule_id]; - switch (rule._operation) { - case POP: - break; - case SWAP: - edges.emplace_back(rule._to, rule._op_label, to); - break; - case NOOP: - edges.emplace_back(rule._to, label, to); - break; - case PUSH: - edges.emplace_back(trace_label->_state, label, to); - edges.emplace_back(rule._to, rule._op_label, trace_label->_state); - break; - } - trace.emplace_back(from, label, rule); - } else if (trace_label->is_post_epsilon_trace()) { - // Intermediate post* trace - // Current edge is the result of merging with an epsilon edge. - // Reconstruct epsilon edge and the other edge. - edges.emplace_back(trace_label->_state, label, to); - edges.emplace_back(from, std::numeric_limits::max(), trace_label->_state); - - } else { // post* trace - post = true; - const auto &[rule, labels] = automaton.pda().states()[trace_label->_state]._rules[trace_label->_rule_id]; - switch (rule._operation) { - case POP: - case SWAP: - case NOOP: - edges.emplace_back(trace_label->_state, trace_label->_label, to); - break; - case PUSH: - auto [from2, label2, to2] = edges.back(); - edges.pop_back(); - trace_label = automaton.get_trace_label(from2, label2, to2); - assert(trace_label != nullptr); - edges.emplace_back(trace_label->_state, trace_label->_label, to2); - break; - } - assert(from == rule._to); - trace.emplace_back(trace_label->_state, trace_label->_label, rule); - } + while (auto rule = tb.next()) { + trace.emplace_back(rule.value()); } // Get accepting path of initial stack (and the initial stack itself - for post*) std::vector start_stack; - start_stack.reserve(edges.size()); + start_stack.reserve(tb.edges().size()); std::vector start_path; - start_path.reserve(edges.size() + 1); - start_path.push_back(std::get<0>(edges.back())); - for (auto it = edges.crbegin(); it != edges.crend(); ++it) { + start_path.reserve(tb.edges().size() + 1); + start_path.push_back(std::get<0>(tb.edges().back())); + for (auto it = tb.edges().crbegin(); it != tb.edges().crend(); ++it) { start_path.push_back(std::get<2>(*it)); start_stack.push_back(std::get<1>(*it)); } - if (post) { // post* was used + if (tb.post()) { // post* was used std::reverse(trace.begin(), trace.end()); return std::make_tuple(trace[0].from(), trace, start_stack, stack, start_path, goal_path); } else { // pre* was used @@ -1008,9 +978,6 @@ namespace pdaaal { } } - /** - * For dual search. Refactor some more later... TODO - */ template static std::tuple< size_t, // Initial state. (State is size_t::max if no trace exists.) @@ -1048,81 +1015,27 @@ namespace pdaaal { template static std::tuple>, std::vector, std::vector> _get_trace_stack_path(const PAutomaton& automaton, std::deque>&& edges) { - auto trace = _follow_trace_labels(automaton, edges); + std::vector> trace; + details::TraceBack tb(automaton, std::move(edges)); + while(auto rule = tb.next()) { + trace.emplace_back(rule.value()); + } + if (tb.post()) { + std::reverse(trace.begin(), trace.end()); + } + // Get accepting path of initial stack (and the initial stack itself - for post*) - std::vector stack; stack.reserve(edges.size()); - std::vector path; path.reserve(edges.size() + 1); - path.push_back(std::get<0>(edges.back())); - for (auto it = edges.crbegin(); it != edges.crend(); ++it) { + std::vector stack; stack.reserve(tb.edges().size()); + std::vector path; path.reserve(tb.edges().size() + 1); + path.push_back(std::get<0>(tb.edges().back())); + for (auto it = tb.edges().crbegin(); it != tb.edges().crend(); ++it) { path.push_back(std::get<2>(*it)); stack.push_back(std::get<1>(*it)); } return {trace, stack, path}; } - - template - static std::vector> _follow_trace_labels(const PAutomaton& automaton, std::deque>& edges) { - bool post = false; - std::vector> trace; - while (true) { - auto[from, label, to] = edges.back(); - const trace_t* trace_label = automaton.get_trace_label(from, label, to); - if (trace_label == nullptr) break; // Done - edges.pop_back(); - - if (trace_label->is_pre_trace()) { - // pre* trace - const auto &[rule, labels] = automaton.pda().states()[from]._rules[trace_label->_rule_id]; - switch (rule._operation) { - case POP: - break; - case SWAP: - edges.emplace_back(rule._to, rule._op_label, to); - break; - case NOOP: - edges.emplace_back(rule._to, label, to); - break; - case PUSH: - edges.emplace_back(trace_label->_state, label, to); - edges.emplace_back(rule._to, rule._op_label, trace_label->_state); - break; - } - trace.emplace_back(from, label, rule); - } else if (trace_label->is_post_epsilon_trace()) { - // Intermediate post* trace - // Current edge is the result of merging with an epsilon edge. - // Reconstruct epsilon edge and the other edge. - edges.emplace_back(trace_label->_state, label, to); - edges.emplace_back(from, std::numeric_limits::max(), trace_label->_state); - - } else { // post* trace - post = true; - const auto &[rule, labels] = automaton.pda().states()[trace_label->_state]._rules[trace_label->_rule_id]; - switch (rule._operation) { - case POP: - case SWAP: - case NOOP: - edges.emplace_back(trace_label->_state, trace_label->_label, to); - break; - case PUSH: - auto[from2, label2, to2] = edges.back(); - edges.pop_back(); - trace_label = automaton.get_trace_label(from2, label2, to2); - assert(trace_label != nullptr); - edges.emplace_back(trace_label->_state, trace_label->_label, to2); - break; - } - assert(from == rule._to); - trace.emplace_back(trace_label->_state, trace_label->_label, rule); - } - } - if (post) { - std::reverse(trace.begin(), trace.end()); - } - return trace; - } - }; + } #endif //PDAAAL_SOLVER_H