From eb63d868af448911beace7e61a5f64c95c19c57c Mon Sep 17 00:00:00 2001 From: Lasse Peters Date: Fri, 20 Sep 2019 11:36:25 -0700 Subject: [PATCH] Support belief dependent action spaces (#12) * Support belief dependent action spaces * Add tests --- src/BasicPOMCP.jl | 24 +++---- src/requirements_info.jl | 2 +- src/solver.jl | 2 +- test/runtests.jl | 135 +++++++++++++++++++++++++++------------ 4 files changed, 106 insertions(+), 57 deletions(-) diff --git a/src/BasicPOMCP.jl b/src/BasicPOMCP.jl index 044a8e1..b34fb49 100644 --- a/src/BasicPOMCP.jl +++ b/src/BasicPOMCP.jl @@ -43,6 +43,7 @@ export default_action, BeliefNode, + AOHistoryBelief, AbstractPOMCPSolver, PORollout, @@ -127,8 +128,8 @@ struct POMCPTree{A,O} a_labels::Vector{A} # actual action corresponding to this action node end -function POMCPTree(pomdp::POMDP, sz::Int=1000) - acts = collect(actions(pomdp)) +function POMCPTree(pomdp::POMDP, b, sz::Int=1000) + acts = collect(actions(pomdp, b)) A = actiontype(pomdp) O = obstype(pomdp) sz = min(100_000, sz) @@ -144,8 +145,14 @@ function POMCPTree(pomdp::POMDP, sz::Int=1000) ) end +struct AOHistoryBelief{H<:NTuple{<:Any, <:NamedTuple{(:a, :o)}}} + hist::H +end +POMDPs.currentobs(h::AOHistoryBelief) = h.hist[end].o +POMDPs.history(h::AOHistoryBelief) = h.hist + function insert_obs_node!(t::POMCPTree, pomdp::POMDP, ha::Int, o) - acts = actions(pomdp) + acts = actions(pomdp, AOHistoryBelief(tuple((a=t.a_labels[ha], o=o)))) push!(t.total_n, 0) push!(t.children, sizehint!(Int[], length(acts))) push!(t.o_labels, o) @@ -200,17 +207,6 @@ function updater(p::POMCPPlanner) return SIRParticleFilter(p.problem, p.solver.tree_queries, rng=p.rng) end -# TODO (maybe): implement this for history-dependent policies -#= -immutable AOHistory - tree::POMCPTree - tail::Int -end - -length -getindex -=# - include("solver.jl") include("exceptions.jl") diff --git a/src/requirements_info.jl b/src/requirements_info.jl index 6530b30..7fdd2d9 100644 --- a/src/requirements_info.jl +++ b/src/requirements_info.jl @@ -19,7 +19,7 @@ end POMDPs.requirements_info(policy::POMCPPlanner, b) = @show_requirements action(policy, b) @POMDP_require action(p::POMCPPlanner, b) begin - tree = POMCPTree(p.problem, p.solver.tree_queries) + tree = POMCPTree(p.problem, b, p.solver.tree_queries) @subreq search(p, b, tree) end diff --git a/src/solver.jl b/src/solver.jl index 2744ebe..21e07e4 100644 --- a/src/solver.jl +++ b/src/solver.jl @@ -2,7 +2,7 @@ function action_info(p::POMCPPlanner, b; tree_in_info=false) local a::actiontype(p.problem) info = Dict{Symbol, Any}() try - tree = POMCPTree(p.problem, p.solver.tree_queries) + tree = POMCPTree(p.problem, b, p.solver.tree_queries) a = search(p, b, tree, info) p._tree = tree if p.solver.tree_in_info || tree_in_info diff --git a/test/runtests.jl b/test/runtests.jl index 48354bb..3eae371 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -9,50 +9,103 @@ using POMDPSimulators using POMDPModelTools using POMDPTesting -test_solver(POMCPSolver(), BabyPOMDP()) - -pomdp = BabyPOMDP() -solver = POMCPSolver(rng = MersenneTwister(1)) -planner = solve(solver, pomdp) - -tree = BasicPOMCP.POMCPTree(pomdp, solver.tree_queries) -node = BasicPOMCP.POMCPObsNode(tree, 1) - -r = @inferred BasicPOMCP.simulate(planner, initialstate(pomdp, MersenneTwister(1)), node, 20) - -sim = HistoryRecorder(max_steps=10) -simulate(sim, pomdp, planner, updater(pomdp)) - -solver = POMCPSolver(max_time=0.1, tree_queries=typemax(Int), rng = MersenneTwister(1)) -planner = solve(solver, pomdp) -a, info = action_info(planner, initialstate_distribution(pomdp)) -a, info = action_info(planner, initialstate_distribution(pomdp)) -println("time below should be about 0.1 seconds") -etime = @elapsed a, info = action_info(planner, initialstate_distribution(pomdp)) -@show etime -@test etime < 0.2 -@show info[:search_time_us] - -solver = POMCPSolver(max_time=0.1, tree_queries=typemax(Int), rng = MersenneTwister(1)) -planner = solve(solver, pomdp) -a, info = action_info(planner, initialstate_distribution(pomdp), tree_in_info=true) - -#d3t = D3Tree(planner, title="test") -d3t = D3Tree(info[:tree], title="test tree") -# inchrome(d3t) -show(stdout, MIME("text/plain"), d3t) +import POMDPs: + transition, + observation, + reward, + discount, + initialstate_distribution, + updater, + states, + actions, + observations + +struct ConstObsPOMDP <: POMDP{Bool, Symbol, Bool} end +updater(problem::ConstObsPOMDP) = DiscreteUpdater(problem) +initialstate_distribution(::ConstObsPOMDP) = BoolDistribution(0.0) +transition(p::ConstObsPOMDP, s::Bool, a::Symbol) = BoolDistribution(0.5) +observation(p::ConstObsPOMDP, a::Symbol, sp::Bool) = BoolDistribution(1.0) +reward(p::ConstObsPOMDP, s::Bool, a::Symbol, sp::Bool) = 1. +discount(p::ConstObsPOMDP) = 0.9 +states(p::ConstObsPOMDP) = (true, false) +actions(p::ConstObsPOMDP) = (:the_only_action,) +observations(p::ConstObsPOMDP) = (true, false) + +@testset "POMDPTesting" begin + pomdp = BabyPOMDP() + test_solver(POMCPSolver(), BabyPOMDP()) +end; + +@testset "type stability" begin + pomdp = BabyPOMDP() + solver = POMCPSolver(rng = MersenneTwister(1)) + planner = solve(solver, pomdp) + b = initialstate_distribution(pomdp) + tree = BasicPOMCP.POMCPTree(pomdp, b, solver.tree_queries) + node = BasicPOMCP.POMCPObsNode(tree, 1) + + r = @inferred BasicPOMCP.simulate(planner, initialstate(pomdp, MersenneTwister(1)), node, 20) +end; + +@testset "belief dependent actions" begin + pomdp = ConstObsPOMDP() + function POMDPs.actions(m::ConstObsPOMDP, b::AOHistoryBelief) + @test currentobs(b) == true + @test history(b)[end].o == true + @test history(b)[end].a == :the_only_action + return actions(m) + end + + solver = POMCPSolver(rng = MersenneTwister(1)) + planner = solve(solver, pomdp) + b = initialstate_distribution(pomdp) + tree = BasicPOMCP.POMCPTree(pomdp, b, solver.tree_queries) + node = BasicPOMCP.POMCPObsNode(tree, 1) + + @inferred BasicPOMCP.simulate(planner, initialstate(pomdp, MersenneTwister(1)), node, 20) +end; + +@testset "simulation" begin + pomdp = BabyPOMDP() + solver = POMCPSolver(rng = MersenneTwister(1)) + planner = solve(solver, pomdp) + solver = POMCPSolver(max_time=0.1, tree_queries=typemax(Int), rng = MersenneTwister(1)) + planner = solve(solver, pomdp) + b = initialstate_distribution(pomdp) + + a, info = action_info(planner, b) + println("time below should be about 0.1 seconds") + etime = @elapsed a, info = action_info(planner, b) + @show etime + @test etime < 0.2 + @show info[:search_time_us] + + sim = HistoryRecorder(max_steps=10) + simulate(sim, pomdp, planner, updater(pomdp)) +end; + +@testset "d3t" begin + pomdp = BabyPOMDP() + solver = POMCPSolver(max_time=0.1, tree_queries=typemax(Int), rng = MersenneTwister(1)) + planner = solve(solver, pomdp) + b = initialstate_distribution(pomdp) + a, info = action_info(planner, b, tree_in_info=true) + d3t = D3Tree(info[:tree], title="test tree") + # inchrome(d3t) + show(stdout, MIME("text/plain"), d3t) -solver = POMCPSolver(max_time=0.1, tree_queries=typemax(Int), rng=MersenneTwister(1), tree_in_info=true) -planner = solve(solver, pomdp) -a, info = action_info(planner, initialstate_distribution(pomdp)) -d3t = D3Tree(info[:tree], title="test tree (tree_in_info solver option)") + solver = POMCPSolver(max_time=0.1, tree_queries=typemax(Int), rng=MersenneTwister(1), tree_in_info=true) + planner = solve(solver, pomdp) + a, info = action_info(planner, b) -@nbinclude(joinpath(dirname(@__FILE__), "..", "notebooks", "Minimal_Example.ipynb")) + d3t = D3Tree(info[:tree], title="test tree (tree_in_info solver option)") +end; -#d3t = D3Tree(planner, title="test") -# inchrome(d3t) +@testset "Minimal_Example" begin + @nbinclude(joinpath(dirname(@__FILE__), "..", "notebooks", "Minimal_Example.ipynb")) +end; @testset "consistency" begin # test consistency when rng is specified @@ -66,7 +119,7 @@ d3t = D3Tree(info[:tree], title="test tree (tree_in_info solver option)") hist2 = simulate(HistoryRecorder(max_steps=1000, rng=MersenneTwister(3)), pomdp, planner) @test discounted_reward(hist1) == discounted_reward(hist2) -end +end; @testset "requires" begin # REQUIREMENTS @@ -77,4 +130,4 @@ end @requirements_info solver println("============== @requirements_info with solver and pomdp:") @requirements_info solver pomdp -end +end;