Skip to content

Commit

Permalink
Support belief dependent action spaces (#12)
Browse files Browse the repository at this point in the history
* Support belief dependent action spaces

* Add tests
  • Loading branch information
lassepe authored and zsunberg committed Sep 20, 2019
1 parent e17df32 commit eb63d86
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 57 deletions.
24 changes: 10 additions & 14 deletions src/BasicPOMCP.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ export
default_action,

BeliefNode,
AOHistoryBelief,
AbstractPOMCPSolver,

PORollout,
Expand Down Expand Up @@ -127,8 +128,8 @@ struct POMCPTree{A,O}
a_labels::Vector{A} # actual action corresponding to this action node
end

function POMCPTree(pomdp::POMDP, sz::Int=1000)
acts = collect(actions(pomdp))
function POMCPTree(pomdp::POMDP, b, sz::Int=1000)
acts = collect(actions(pomdp, b))
A = actiontype(pomdp)
O = obstype(pomdp)
sz = min(100_000, sz)
Expand All @@ -144,8 +145,14 @@ function POMCPTree(pomdp::POMDP, sz::Int=1000)
)
end

struct AOHistoryBelief{H<:NTuple{<:Any, <:NamedTuple{(:a, :o)}}}
hist::H
end
POMDPs.currentobs(h::AOHistoryBelief) = h.hist[end].o
POMDPs.history(h::AOHistoryBelief) = h.hist

function insert_obs_node!(t::POMCPTree, pomdp::POMDP, ha::Int, o)
acts = actions(pomdp)
acts = actions(pomdp, AOHistoryBelief(tuple((a=t.a_labels[ha], o=o))))
push!(t.total_n, 0)
push!(t.children, sizehint!(Int[], length(acts)))
push!(t.o_labels, o)
Expand Down Expand Up @@ -200,17 +207,6 @@ function updater(p::POMCPPlanner)
return SIRParticleFilter(p.problem, p.solver.tree_queries, rng=p.rng)
end

# TODO (maybe): implement this for history-dependent policies
#=
immutable AOHistory
tree::POMCPTree
tail::Int
end
length
getindex
=#

include("solver.jl")

include("exceptions.jl")
Expand Down
2 changes: 1 addition & 1 deletion src/requirements_info.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ end
POMDPs.requirements_info(policy::POMCPPlanner, b) = @show_requirements action(policy, b)

@POMDP_require action(p::POMCPPlanner, b) begin
tree = POMCPTree(p.problem, p.solver.tree_queries)
tree = POMCPTree(p.problem, b, p.solver.tree_queries)
@subreq search(p, b, tree)
end

Expand Down
2 changes: 1 addition & 1 deletion src/solver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ function action_info(p::POMCPPlanner, b; tree_in_info=false)
local a::actiontype(p.problem)
info = Dict{Symbol, Any}()
try
tree = POMCPTree(p.problem, p.solver.tree_queries)
tree = POMCPTree(p.problem, b, p.solver.tree_queries)
a = search(p, b, tree, info)
p._tree = tree
if p.solver.tree_in_info || tree_in_info
Expand Down
135 changes: 94 additions & 41 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,50 +9,103 @@ using POMDPSimulators
using POMDPModelTools
using POMDPTesting

test_solver(POMCPSolver(), BabyPOMDP())

pomdp = BabyPOMDP()
solver = POMCPSolver(rng = MersenneTwister(1))
planner = solve(solver, pomdp)

tree = BasicPOMCP.POMCPTree(pomdp, solver.tree_queries)
node = BasicPOMCP.POMCPObsNode(tree, 1)

r = @inferred BasicPOMCP.simulate(planner, initialstate(pomdp, MersenneTwister(1)), node, 20)

sim = HistoryRecorder(max_steps=10)
simulate(sim, pomdp, planner, updater(pomdp))

solver = POMCPSolver(max_time=0.1, tree_queries=typemax(Int), rng = MersenneTwister(1))
planner = solve(solver, pomdp)
a, info = action_info(planner, initialstate_distribution(pomdp))
a, info = action_info(planner, initialstate_distribution(pomdp))
println("time below should be about 0.1 seconds")
etime = @elapsed a, info = action_info(planner, initialstate_distribution(pomdp))
@show etime
@test etime < 0.2
@show info[:search_time_us]

solver = POMCPSolver(max_time=0.1, tree_queries=typemax(Int), rng = MersenneTwister(1))
planner = solve(solver, pomdp)
a, info = action_info(planner, initialstate_distribution(pomdp), tree_in_info=true)

#d3t = D3Tree(planner, title="test")
d3t = D3Tree(info[:tree], title="test tree")
# inchrome(d3t)
show(stdout, MIME("text/plain"), d3t)
import POMDPs:
transition,
observation,
reward,
discount,
initialstate_distribution,
updater,
states,
actions,
observations

struct ConstObsPOMDP <: POMDP{Bool, Symbol, Bool} end
updater(problem::ConstObsPOMDP) = DiscreteUpdater(problem)
initialstate_distribution(::ConstObsPOMDP) = BoolDistribution(0.0)
transition(p::ConstObsPOMDP, s::Bool, a::Symbol) = BoolDistribution(0.5)
observation(p::ConstObsPOMDP, a::Symbol, sp::Bool) = BoolDistribution(1.0)
reward(p::ConstObsPOMDP, s::Bool, a::Symbol, sp::Bool) = 1.
discount(p::ConstObsPOMDP) = 0.9
states(p::ConstObsPOMDP) = (true, false)
actions(p::ConstObsPOMDP) = (:the_only_action,)
observations(p::ConstObsPOMDP) = (true, false)

@testset "POMDPTesting" begin
pomdp = BabyPOMDP()
test_solver(POMCPSolver(), BabyPOMDP())
end;

@testset "type stability" begin
pomdp = BabyPOMDP()
solver = POMCPSolver(rng = MersenneTwister(1))
planner = solve(solver, pomdp)
b = initialstate_distribution(pomdp)
tree = BasicPOMCP.POMCPTree(pomdp, b, solver.tree_queries)
node = BasicPOMCP.POMCPObsNode(tree, 1)

r = @inferred BasicPOMCP.simulate(planner, initialstate(pomdp, MersenneTwister(1)), node, 20)
end;

@testset "belief dependent actions" begin
pomdp = ConstObsPOMDP()
function POMDPs.actions(m::ConstObsPOMDP, b::AOHistoryBelief)
@test currentobs(b) == true
@test history(b)[end].o == true
@test history(b)[end].a == :the_only_action
return actions(m)
end

solver = POMCPSolver(rng = MersenneTwister(1))
planner = solve(solver, pomdp)
b = initialstate_distribution(pomdp)
tree = BasicPOMCP.POMCPTree(pomdp, b, solver.tree_queries)
node = BasicPOMCP.POMCPObsNode(tree, 1)

@inferred BasicPOMCP.simulate(planner, initialstate(pomdp, MersenneTwister(1)), node, 20)
end;

@testset "simulation" begin
pomdp = BabyPOMDP()
solver = POMCPSolver(rng = MersenneTwister(1))
planner = solve(solver, pomdp)
solver = POMCPSolver(max_time=0.1, tree_queries=typemax(Int), rng = MersenneTwister(1))
planner = solve(solver, pomdp)
b = initialstate_distribution(pomdp)

a, info = action_info(planner, b)
println("time below should be about 0.1 seconds")
etime = @elapsed a, info = action_info(planner, b)
@show etime
@test etime < 0.2
@show info[:search_time_us]

sim = HistoryRecorder(max_steps=10)
simulate(sim, pomdp, planner, updater(pomdp))
end;

@testset "d3t" begin
pomdp = BabyPOMDP()
solver = POMCPSolver(max_time=0.1, tree_queries=typemax(Int), rng = MersenneTwister(1))
planner = solve(solver, pomdp)
b = initialstate_distribution(pomdp)
a, info = action_info(planner, b, tree_in_info=true)

d3t = D3Tree(info[:tree], title="test tree")
# inchrome(d3t)
show(stdout, MIME("text/plain"), d3t)

solver = POMCPSolver(max_time=0.1, tree_queries=typemax(Int), rng=MersenneTwister(1), tree_in_info=true)
planner = solve(solver, pomdp)
a, info = action_info(planner, initialstate_distribution(pomdp))

d3t = D3Tree(info[:tree], title="test tree (tree_in_info solver option)")
solver = POMCPSolver(max_time=0.1, tree_queries=typemax(Int), rng=MersenneTwister(1), tree_in_info=true)
planner = solve(solver, pomdp)
a, info = action_info(planner, b)

@nbinclude(joinpath(dirname(@__FILE__), "..", "notebooks", "Minimal_Example.ipynb"))
d3t = D3Tree(info[:tree], title="test tree (tree_in_info solver option)")
end;

#d3t = D3Tree(planner, title="test")
# inchrome(d3t)
@testset "Minimal_Example" begin
@nbinclude(joinpath(dirname(@__FILE__), "..", "notebooks", "Minimal_Example.ipynb"))
end;

@testset "consistency" begin
# test consistency when rng is specified
Expand All @@ -66,7 +119,7 @@ d3t = D3Tree(info[:tree], title="test tree (tree_in_info solver option)")
hist2 = simulate(HistoryRecorder(max_steps=1000, rng=MersenneTwister(3)), pomdp, planner)

@test discounted_reward(hist1) == discounted_reward(hist2)
end
end;

@testset "requires" begin
# REQUIREMENTS
Expand All @@ -77,4 +130,4 @@ end
@requirements_info solver
println("============== @requirements_info with solver and pomdp:")
@requirements_info solver pomdp
end
end;

0 comments on commit eb63d86

Please sign in to comment.