Skip to content

Commit

Permalink
Merge branch 'get-trace-weight' into reduction-options
Browse files Browse the repository at this point in the history
  • Loading branch information
MortenSchou committed May 11, 2020
2 parents f9049ff + e205e51 commit ed1e7fb
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 15 deletions.
20 changes: 10 additions & 10 deletions src/pdaaal/Solver.h
Original file line number Diff line number Diff line change
Expand Up @@ -241,27 +241,27 @@ namespace pdaaal {
}

template <Trace_Type trace_type = Trace_Type::Any, typename T, typename W, typename C, typename A>
static std::vector<typename TypedPDA<T>::tracestate_t> get_trace(const TypedPDA<T,W,C>& pda, const PAutomaton<W,C,A>& automaton, size_t state, const std::vector<T>& stack) {
static auto get_trace(const TypedPDA<T,W,C>& pda, const PAutomaton<W,C,A>& automaton, size_t state, const std::vector<T>& stack) {
static_assert(trace_type != Trace_Type::None, "If you want a trace, don't ask for none.");
std::vector<size_t> path;
auto stack_native = pda.encode_pre(stack);
if constexpr (trace_type == Trace_Type::Shortest) {
path = automaton.template accept_path<trace_type>(state, stack_native).first;
auto [path, weight] = automaton.template accept_path<trace_type>(state, stack_native);
return std::make_pair(_get_trace(pda, automaton, path, stack_native), weight);
} else {
path = automaton.template accept_path<trace_type>(state, stack_native);
auto path = automaton.template accept_path<trace_type>(state, stack_native);
return _get_trace(pda, automaton, path, stack_native);
}
return _get_trace(pda, automaton, path, stack_native);
}
template <Trace_Type trace_type = Trace_Type::Any, typename T, typename W, typename C, typename A, typename = std::enable_if_t<!std::is_same_v<T,uint32_t>>>
static std::vector<typename TypedPDA<T>::tracestate_t> get_trace(const TypedPDA<T,W,C>& pda, const PAutomaton<W,C,A>& automaton, size_t state, const std::vector<uint32_t>& stack_native) {
static auto get_trace(const TypedPDA<T,W,C>& pda, const PAutomaton<W,C,A>& automaton, size_t state, const std::vector<uint32_t>& stack_native) {
static_assert(trace_type != Trace_Type::None, "If you want a trace, don't ask for none.");
std::vector<size_t> path;
if constexpr (trace_type == Trace_Type::Shortest) {
path = automaton.template accept_path<trace_type>(state, stack_native).first;
auto [path, weight] = automaton.template accept_path<trace_type>(state, stack_native);
return std::make_pair(_get_trace(pda, automaton, path, stack_native),weight);
} else {
path = automaton.template accept_path<trace_type>(state, stack_native);
auto path = automaton.template accept_path<trace_type>(state, stack_native);
return _get_trace(pda, automaton, path, stack_native);
}
return _get_trace(pda, automaton, path, stack_native);
}

private:
Expand Down
8 changes: 6 additions & 2 deletions src/pdaaal/SolverAdapter.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,13 @@ namespace pdaaal {
}

template <Trace_Type trace_type = Trace_Type::Any, typename T, typename W, typename C, typename A>
[[nodiscard]] std::vector<typename TypedPDA<T>::tracestate_t> get_trace(const PDAAdapter<T,W,C>& pda, trace_info<W,C,A> info) const {
[[nodiscard]] auto get_trace(const PDAAdapter<T,W,C>& pda, trace_info<W,C,A> info) const {
auto trace = Solver::get_trace<trace_type>(pda, *info.first, info.second, pda.initial_stack());
trace.pop_back(); // Removes terminal state from trace.
if constexpr (trace_type == Trace_Type::Shortest) {
trace.first.pop_back();
} else { // Removes terminal state from trace.
trace.pop_back();
}
return trace;
}
};
Expand Down
6 changes: 3 additions & 3 deletions test/Solver_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ BOOST_AUTO_TEST_CASE(SolverTest1)
Solver::post_star<Trace_Type::Shortest>(automaton);

std::vector<char> test_stack_reachable{'B', 'A', 'A', 'A'};
auto trace = Solver::get_trace<Trace_Type::Shortest>(pda, automaton, 1, test_stack_reachable);
auto [trace,weight] = Solver::get_trace<Trace_Type::Shortest>(pda, automaton, 1, test_stack_reachable);
BOOST_CHECK_EQUAL(trace.size(), 7);
}

Expand All @@ -73,7 +73,7 @@ BOOST_AUTO_TEST_CASE(SolverTest2)
Solver::post_star<Trace_Type::Shortest>(automaton);

std::vector<uint32_t> test_stack_reachable{3, 2, 2, 2};
auto trace = Solver::get_trace<Trace_Type::Shortest>(pda, automaton, 1, test_stack_reachable);
auto [trace,weight] = Solver::get_trace<Trace_Type::Shortest>(pda, automaton, 1, test_stack_reachable);
BOOST_CHECK_EQUAL(trace.size(), 7);
}

Expand All @@ -97,7 +97,7 @@ BOOST_AUTO_TEST_CASE(SolverTest3)

std::vector<char> test_stack_reachable{'B', 'A', 'A', 'A'};
auto stack_native = pda.encode_pre(test_stack_reachable);
auto trace = Solver::get_trace<Trace_Type::Shortest>(pda, automaton, 1, stack_native);
auto [trace,weight] = Solver::get_trace<Trace_Type::Shortest>(pda, automaton, 1, stack_native);
BOOST_CHECK_EQUAL(trace.size(), 7);
}

Expand Down

0 comments on commit ed1e7fb

Please sign in to comment.