Skip to content

Commit

Permalink
V0.8 compat (#11)
Browse files Browse the repository at this point in the history
* V0.8 compat

* Remove support for Julia 1.0

- fixes #12

* Deprecate current_obs and implement currentobs and history on POWTreeObsNode

* Never hand POWTreeObsNode to actions(m,b)

* Bump version
  • Loading branch information
lassepe authored and zsunberg committed Sep 20, 2019
1 parent f399aad commit 142c355
Show file tree
Hide file tree
Showing 9 changed files with 87 additions and 53 deletions.
3 changes: 1 addition & 2 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@ language: julia
os:
- linux
julia:
- 0.7
- 1.0
- 1
notifications:
email: false
# uncomment the following lines to override the default test script
Expand Down
5 changes: 3 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "POMCPOW"
uuid = "4c53ee00-974c-466f-8fa5-8dd73959bbab"
repo = "https://github.com/JuliaPOMDP/POMCPOW.jl"
version = "0.2.2"
version = "0.3.0"

[deps]
BasicPOMCP = "d721219e-3fc6-5570-a8ef-e5402f47c49e"
Expand All @@ -18,7 +18,8 @@ Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

[compat]
POMDPs = "< 0.7.3"
POMDPs = "0.7.3, 0.8"
julia = "^1.1"

[extras]
BeliefUpdaters = "8bb6e9a1-7d73-552c-a44a-e5dc5634aac4"
Expand Down
3 changes: 1 addition & 2 deletions src/POMCPOW.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ using POMDPPolicies
using BasicPOMCP: convert_estimator

import Base: insert!
import POMDPs: action, solve, mean, rand, updater
import POMDPs: action, solve, mean, rand, updater, currentobs, history
import POMDPModelTools: action_info

import MCTS: n_children, next_action, isroot, node_tag, tooltip_tag
Expand All @@ -40,7 +40,6 @@ export
n_children,
belief,
sr_belief,
current_obs,
isroot,

POMCPOWVisualizer,
Expand Down
9 changes: 7 additions & 2 deletions src/beliefs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,26 @@ end

rand(rng::AbstractRNG, b::POWNodeBelief) = rand(rng, b.dist)
state_mean(b::POWNodeBelief) = first_mean(b.dist)
POMDPs.currentobs(b::POWNodeBelief) = b.o
POMDPs.history(b::POWNodeBelief) = tuple((a=b.a, o=b.o))


struct POWNodeFilter end

belief_type(::Type{POWNodeFilter}, ::Type{P}) where {P<:POMDP} = POWNodeBelief{statetype(P), actiontype(P), obstype(P), P}

init_node_sr_belief(::POWNodeFilter, p::POMDP, s, a, sp, o, r) = POWNodeBelief(p, s, a, sp, o, r)
init_node_sr_belief(::POWNodeFilter, p::POMDP, s, a, sp, o, r) = POWNodeBelief(p, s, a, sp, o, r)

function push_weighted!(b::POWNodeBelief, ::POWNodeFilter, s, sp, r)
w = obs_weight(b.model, s, b.a, sp, b.o)
insert!(b.dist, (sp, convert(Float64, r)), w)
end

struct StateBelief{SRB}
struct StateBelief{SRB<:POWNodeBelief}
sr_belief::SRB
end

rand(rng::AbstractRNG, b::StateBelief) = first(rand(rng, b.sr_belief))
mean(b::StateBelief) = state_mean(b.sr_belief)
POMDPs.currentobs(b::StateBelief) = currentobs(b.sr_belief)
POMDPs.history(b::StateBelief) = history(b.sr_belief)
10 changes: 7 additions & 3 deletions src/solver2.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,11 @@ function simulate(pomcp::POMCPOWPlanner, h_node::POWTreeObsNode{B,A,O}, s::S, d)
end
else # run through all the actions
if isempty(tree.tried[h])
action_space_iter = POMDPs.actions(pomcp.problem, h_node)
if h == 1
action_space_iter = POMDPs.actions(pomcp.problem, tree.root_belief)
else
action_space_iter = POMDPs.actions(pomcp.problem, StateBelief(tree.sr_beliefs[h]))
end
anode = length(tree.n)
for a in action_space_iter
push_anode!(tree, h, a,
Expand All @@ -44,7 +48,7 @@ function simulate(pomcp::POMCPOWPlanner, h_node::POWTreeObsNode{B,A,O}, s::S, d)
new_node = false
if tree.n_a_children[best_node] <= sol.k_observation*(tree.n[best_node]^sol.alpha_observation)

sp, o, r = generate_sor(pomcp.problem, s, a, sol.rng)
sp, o, r = gen(DDNOut(:sp, :o, :r), pomcp.problem, s, a, sol.rng)

if sol.check_repeat_obs && haskey(tree.a_child_lookup, (best_node,o))
hao = tree.a_child_lookup[(best_node, o)]
Expand All @@ -66,7 +70,7 @@ function simulate(pomcp::POMCPOWPlanner, h_node::POWTreeObsNode{B,A,O}, s::S, d)
push!(tree.generated[best_node], o=>hao)
else

sp, r = generate_sr(pomcp.problem, s, a, sol.rng)
sp, r = gen(DDNOut(:sp, :r), pomcp.problem, s, a, sol.rng)

end

Expand Down
9 changes: 1 addition & 8 deletions src/tree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,14 +59,7 @@ struct POWTreeObsNode{B,A,O,RB} <: BeliefNode
end

isroot(h::POWTreeObsNode) = h.node==1
function current_obs(h::POWTreeObsNode)
if isroot(h)
error("Tried to access the observation for the root node in a POMCPOW tree")
else
return h.tree.o_labels[h.node]
end
end
function belief(h::POWTreeObsNode)
@inline function belief(h::POWTreeObsNode)
if isroot(h)
return h.tree.root_belief
else
Expand Down
4 changes: 2 additions & 2 deletions src/visualization.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
function D3Trees.D3Tree(p::POMCPOWPlanner; title="POMCPOW Tree", kwargs...)
@warn("""
D3Tree(planner::POMCPOWPlanner) is deprecated and may be removed in the future. Instead, please use
a, info = action_info(planner, state)
D3Tree(info[:tree])
Or, you can get this info from a POMDPSimulators History
info = first(ainfo_hist(hist))
D3Tree(info[:tree])
""")
Expand Down
2 changes: 0 additions & 2 deletions test/init_node_sr_belief_error.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,6 @@ end
POMDPs.actions(m::SimplePOMDP) = [-1, 1]
POMDPs.states(m::SimplePOMDP) = 1:7
POMDPs.actionindex(m::SimplePOMDP, a::Int) = a == 1 ? 1 : 2
POMDPs.n_states(m::SimplePOMDP) = 7
POMDPs.n_actions(m::SimplePOMDP) = 2

pomdp = SimplePOMDP(0.7)

Expand Down
95 changes: 65 additions & 30 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,43 +9,78 @@ using D3Trees
using BeliefUpdaters
using POMDPModelTools

solver = POMCPOWSolver()
@testset "all" begin

pomdp = BabyPOMDP()
@testset "POMDPTesting" begin
solver = POMCPOWSolver()
pomdp = BabyPOMDP()
test_solver(solver, pomdp, updater=DiscreteUpdater(pomdp))
test_solver(solver, pomdp)

test_solver(solver, pomdp, updater=DiscreteUpdater(pomdp))
test_solver(solver, pomdp)
solver = POMCPOWSolver(max_time=0.1, tree_queries=typemax(Int))
test_solver(solver, pomdp, updater=DiscreteUpdater(pomdp))
end

solver = POMCPOWSolver(max_time=0.1, tree_queries=typemax(Int))
test_solver(solver, pomdp, updater=DiscreteUpdater(pomdp))
@testset "type stability" begin
# make sure internal function is type stable
pomdp = BabyPOMDP()
solver = POMCPOWSolver()
planner = solve(solver, pomdp)
b = initialstate_distribution(pomdp)
B = POMCPOW.belief_type(POMCPOW.POWNodeFilter, typeof(pomdp))
tree = POMCPOWTree{B,Bool,Bool,typeof(b)}(b, 2*planner.solver.tree_queries)
@inferred POMCPOW.simulate(planner, POMCPOW.POWTreeObsNode(tree, 1), true, 10)

# make sure internal function is type stable
solver = POMCPOWSolver()
planner = solve(solver, pomdp)
b = initialstate_distribution(pomdp)
B = POMCPOW.belief_type(POMCPOW.POWNodeFilter, typeof(pomdp))
tree = POMCPOWTree{B,Bool,Bool,typeof(b)}(b, 2*planner.solver.tree_queries)
@inferred POMCPOW.simulate(planner, POMCPOW.POWTreeObsNode(tree, 1), true, 10)
# @code_warntype POMCPOW.simulate(planner, POMCPOW.POWTreeObsNode(tree, 1), true, 10)
pomdp = LightDark1D()
solver = POMCPOWSolver(default_action=485)
planner = solve(solver, pomdp)

pomdp = LightDark1D()
solver = POMCPOWSolver(default_action=485)
planner = solve(solver, pomdp)
b = ParticleCollection([LightDark1DState(-1, 0)])
@test @test_logs (:warn,) @inferred(action(planner, b)) == 485

b = ParticleCollection([LightDark1DState(-1, 0)])
println("There should be a warning about a default action below")
@test @inferred(action(planner, b)) == 485
b = initialstate_distribution(pomdp)
@inferred action(planner, b)
end;

b = initialstate_distribution(pomdp)
@inferred action(planner, b)
@testset "currentobs and history" begin
pomdp = BabyPOMDP()
solver = POMCPOWSolver()
planner = solve(solver, pomdp)
b = initialstate_distribution(pomdp)
B = POMCPOW.belief_type(POMCPOW.POWNodeFilter, typeof(pomdp))
tree = POMCPOWTree{B,Bool,Bool,typeof(b)}(b, 2*planner.solver.tree_queries)

a, info = action_info(planner, b)
# d3t = D3Tree(planner)
@test_throws KeyError d3t = D3Tree(info[:tree])
n = POMCPOW.POWTreeObsNode(tree, 1)
nb = belief(n)
# we can't call current obs on the root node
@test_throws MethodError currentobs(nb)
# simulate the tree to expand one step
POMCPOW.simulate(planner, n, true, 1)
n = POMCPOW.POWTreeObsNode(tree, 2)
nb = belief(n)
# but at a non-root node, this should work
@test currentobs(nb) isa Bool
@test currentobs(nb) == history(nb)[end].o
@test history(nb)[end].a isa Bool
end;

a, info = action_info(planner, b, tree_in_info=true)
# d3t = D3Tree(planner)
d3t = D3Tree(info[:tree])
# inchrome(d3t)
@testset "D3tree" begin
# make sure internal function is type stable
pomdp = BabyPOMDP()
solver = POMCPOWSolver()
planner = solve(solver, pomdp)
b = initialstate_distribution(pomdp)
a, info = action_info(planner, b)
# d3t = D3Tree(planner)
@test_throws KeyError d3t = D3Tree(info[:tree])

include("init_node_sr_belief_error.jl")
a, info = action_info(planner, b, tree_in_info=true)
# d3t = D3Tree(planner)
d3t = D3Tree(info[:tree])
# inchrome(d3t)
end;

@testset "init_node_sr_belief_error" begin
include("init_node_sr_belief_error.jl")
end;
end;

0 comments on commit 142c355

Please sign in to comment.