diff --git a/.travis.yml b/.travis.yml index fb562bf..d421fad 100644 --- a/.travis.yml +++ b/.travis.yml @@ -3,8 +3,7 @@ language: julia os: - linux julia: - - 0.7 - - 1.0 + - 1 notifications: email: false # uncomment the following lines to override the default test script diff --git a/Project.toml b/Project.toml index 75dd343..6531ff8 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "POMCPOW" uuid = "4c53ee00-974c-466f-8fa5-8dd73959bbab" repo = "https://github.com/JuliaPOMDP/POMCPOW.jl" -version = "0.2.2" +version = "0.3.0" [deps] BasicPOMCP = "d721219e-3fc6-5570-a8ef-e5402f47c49e" @@ -18,7 +18,8 @@ Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" [compat] -POMDPs = "< 0.7.3" +POMDPs = "0.7.3, 0.8" +julia = "^1.1" [extras] BeliefUpdaters = "8bb6e9a1-7d73-552c-a44a-e5dc5634aac4" diff --git a/src/POMCPOW.jl b/src/POMCPOW.jl index 6034ffa..d0475ba 100644 --- a/src/POMCPOW.jl +++ b/src/POMCPOW.jl @@ -15,7 +15,7 @@ using POMDPPolicies using BasicPOMCP: convert_estimator import Base: insert! -import POMDPs: action, solve, mean, rand, updater +import POMDPs: action, solve, mean, rand, updater, currentobs, history import POMDPModelTools: action_info import MCTS: n_children, next_action, isroot, node_tag, tooltip_tag @@ -40,7 +40,6 @@ export n_children, belief, sr_belief, - current_obs, isroot, POMCPOWVisualizer, diff --git a/src/beliefs.jl b/src/beliefs.jl index ef85b69..1d960d3 100644 --- a/src/beliefs.jl +++ b/src/beliefs.jl @@ -18,21 +18,26 @@ end rand(rng::AbstractRNG, b::POWNodeBelief) = rand(rng, b.dist) state_mean(b::POWNodeBelief) = first_mean(b.dist) +POMDPs.currentobs(b::POWNodeBelief) = b.o +POMDPs.history(b::POWNodeBelief) = tuple((a=b.a, o=b.o)) + struct POWNodeFilter end belief_type(::Type{POWNodeFilter}, ::Type{P}) where {P<:POMDP} = POWNodeBelief{statetype(P), actiontype(P), obstype(P), P} -init_node_sr_belief(::POWNodeFilter, p::POMDP, s, a, sp, o, r) = POWNodeBelief(p, s, a, sp, o, r) +init_node_sr_belief(::POWNodeFilter, p::POMDP, s, a, sp, o, r) = POWNodeBelief(p, s, a, sp, o, r) function push_weighted!(b::POWNodeBelief, ::POWNodeFilter, s, sp, r) w = obs_weight(b.model, s, b.a, sp, b.o) insert!(b.dist, (sp, convert(Float64, r)), w) end -struct StateBelief{SRB} +struct StateBelief{SRB<:POWNodeBelief} sr_belief::SRB end rand(rng::AbstractRNG, b::StateBelief) = first(rand(rng, b.sr_belief)) mean(b::StateBelief) = state_mean(b.sr_belief) +POMDPs.currentobs(b::StateBelief) = currentobs(b.sr_belief) +POMDPs.history(b::StateBelief) = history(b.sr_belief) diff --git a/src/solver2.jl b/src/solver2.jl index 4c0c778..3abdf57 100644 --- a/src/solver2.jl +++ b/src/solver2.jl @@ -26,7 +26,11 @@ function simulate(pomcp::POMCPOWPlanner, h_node::POWTreeObsNode{B,A,O}, s::S, d) end else # run through all the actions if isempty(tree.tried[h]) - action_space_iter = POMDPs.actions(pomcp.problem, h_node) + if h == 1 + action_space_iter = POMDPs.actions(pomcp.problem, tree.root_belief) + else + action_space_iter = POMDPs.actions(pomcp.problem, StateBelief(tree.sr_beliefs[h])) + end anode = length(tree.n) for a in action_space_iter push_anode!(tree, h, a, @@ -44,7 +48,7 @@ function simulate(pomcp::POMCPOWPlanner, h_node::POWTreeObsNode{B,A,O}, s::S, d) new_node = false if tree.n_a_children[best_node] <= sol.k_observation*(tree.n[best_node]^sol.alpha_observation) - sp, o, r = generate_sor(pomcp.problem, s, a, sol.rng) + sp, o, r = gen(DDNOut(:sp, :o, :r), pomcp.problem, s, a, sol.rng) if sol.check_repeat_obs && haskey(tree.a_child_lookup, (best_node,o)) hao = tree.a_child_lookup[(best_node, o)] @@ -66,7 +70,7 @@ function simulate(pomcp::POMCPOWPlanner, h_node::POWTreeObsNode{B,A,O}, s::S, d) push!(tree.generated[best_node], o=>hao) else - sp, r = generate_sr(pomcp.problem, s, a, sol.rng) + sp, r = gen(DDNOut(:sp, :r), pomcp.problem, s, a, sol.rng) end diff --git a/src/tree.jl b/src/tree.jl index 9d5a092..f4724a6 100644 --- a/src/tree.jl +++ b/src/tree.jl @@ -59,14 +59,7 @@ struct POWTreeObsNode{B,A,O,RB} <: BeliefNode end isroot(h::POWTreeObsNode) = h.node==1 -function current_obs(h::POWTreeObsNode) - if isroot(h) - error("Tried to access the observation for the root node in a POMCPOW tree") - else - return h.tree.o_labels[h.node] - end -end -function belief(h::POWTreeObsNode) +@inline function belief(h::POWTreeObsNode) if isroot(h) return h.tree.root_belief else diff --git a/src/visualization.jl b/src/visualization.jl index 0e26c15..e05c484 100644 --- a/src/visualization.jl +++ b/src/visualization.jl @@ -1,12 +1,12 @@ function D3Trees.D3Tree(p::POMCPOWPlanner; title="POMCPOW Tree", kwargs...) @warn(""" D3Tree(planner::POMCPOWPlanner) is deprecated and may be removed in the future. Instead, please use - + a, info = action_info(planner, state) D3Tree(info[:tree]) Or, you can get this info from a POMDPSimulators History - + info = first(ainfo_hist(hist)) D3Tree(info[:tree]) """) diff --git a/test/init_node_sr_belief_error.jl b/test/init_node_sr_belief_error.jl index fd0c4e4..8e82d52 100644 --- a/test/init_node_sr_belief_error.jl +++ b/test/init_node_sr_belief_error.jl @@ -35,8 +35,6 @@ end POMDPs.actions(m::SimplePOMDP) = [-1, 1] POMDPs.states(m::SimplePOMDP) = 1:7 POMDPs.actionindex(m::SimplePOMDP, a::Int) = a == 1 ? 1 : 2 -POMDPs.n_states(m::SimplePOMDP) = 7 -POMDPs.n_actions(m::SimplePOMDP) = 2 pomdp = SimplePOMDP(0.7) diff --git a/test/runtests.jl b/test/runtests.jl index 675929b..62f9ba2 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -9,43 +9,78 @@ using D3Trees using BeliefUpdaters using POMDPModelTools -solver = POMCPOWSolver() +@testset "all" begin -pomdp = BabyPOMDP() + @testset "POMDPTesting" begin + solver = POMCPOWSolver() + pomdp = BabyPOMDP() + test_solver(solver, pomdp, updater=DiscreteUpdater(pomdp)) + test_solver(solver, pomdp) -test_solver(solver, pomdp, updater=DiscreteUpdater(pomdp)) -test_solver(solver, pomdp) + solver = POMCPOWSolver(max_time=0.1, tree_queries=typemax(Int)) + test_solver(solver, pomdp, updater=DiscreteUpdater(pomdp)) + end -solver = POMCPOWSolver(max_time=0.1, tree_queries=typemax(Int)) -test_solver(solver, pomdp, updater=DiscreteUpdater(pomdp)) + @testset "type stability" begin + # make sure internal function is type stable + pomdp = BabyPOMDP() + solver = POMCPOWSolver() + planner = solve(solver, pomdp) + b = initialstate_distribution(pomdp) + B = POMCPOW.belief_type(POMCPOW.POWNodeFilter, typeof(pomdp)) + tree = POMCPOWTree{B,Bool,Bool,typeof(b)}(b, 2*planner.solver.tree_queries) + @inferred POMCPOW.simulate(planner, POMCPOW.POWTreeObsNode(tree, 1), true, 10) -# make sure internal function is type stable -solver = POMCPOWSolver() -planner = solve(solver, pomdp) -b = initialstate_distribution(pomdp) -B = POMCPOW.belief_type(POMCPOW.POWNodeFilter, typeof(pomdp)) -tree = POMCPOWTree{B,Bool,Bool,typeof(b)}(b, 2*planner.solver.tree_queries) -@inferred POMCPOW.simulate(planner, POMCPOW.POWTreeObsNode(tree, 1), true, 10) -# @code_warntype POMCPOW.simulate(planner, POMCPOW.POWTreeObsNode(tree, 1), true, 10) + pomdp = LightDark1D() + solver = POMCPOWSolver(default_action=485) + planner = solve(solver, pomdp) -pomdp = LightDark1D() -solver = POMCPOWSolver(default_action=485) -planner = solve(solver, pomdp) + b = ParticleCollection([LightDark1DState(-1, 0)]) + @test @test_logs (:warn,) @inferred(action(planner, b)) == 485 -b = ParticleCollection([LightDark1DState(-1, 0)]) -println("There should be a warning about a default action below") -@test @inferred(action(planner, b)) == 485 + b = initialstate_distribution(pomdp) + @inferred action(planner, b) + end; -b = initialstate_distribution(pomdp) -@inferred action(planner, b) + @testset "currentobs and history" begin + pomdp = BabyPOMDP() + solver = POMCPOWSolver() + planner = solve(solver, pomdp) + b = initialstate_distribution(pomdp) + B = POMCPOW.belief_type(POMCPOW.POWNodeFilter, typeof(pomdp)) + tree = POMCPOWTree{B,Bool,Bool,typeof(b)}(b, 2*planner.solver.tree_queries) -a, info = action_info(planner, b) -# d3t = D3Tree(planner) -@test_throws KeyError d3t = D3Tree(info[:tree]) + n = POMCPOW.POWTreeObsNode(tree, 1) + nb = belief(n) + # we can't call current obs on the root node + @test_throws MethodError currentobs(nb) + # simulate the tree to expand one step + POMCPOW.simulate(planner, n, true, 1) + n = POMCPOW.POWTreeObsNode(tree, 2) + nb = belief(n) + # but at a non-root node, this should work + @test currentobs(nb) isa Bool + @test currentobs(nb) == history(nb)[end].o + @test history(nb)[end].a isa Bool + end; -a, info = action_info(planner, b, tree_in_info=true) -# d3t = D3Tree(planner) -d3t = D3Tree(info[:tree]) -# inchrome(d3t) + @testset "D3tree" begin + # make sure internal function is type stable + pomdp = BabyPOMDP() + solver = POMCPOWSolver() + planner = solve(solver, pomdp) + b = initialstate_distribution(pomdp) + a, info = action_info(planner, b) + # d3t = D3Tree(planner) + @test_throws KeyError d3t = D3Tree(info[:tree]) -include("init_node_sr_belief_error.jl") + a, info = action_info(planner, b, tree_in_info=true) + # d3t = D3Tree(planner) + d3t = D3Tree(info[:tree]) + # inchrome(d3t) + end; + + @testset "init_node_sr_belief_error" begin + include("init_node_sr_belief_error.jl") + end; +end;