Skip to content

Commit

Permalink
Update notebook example
Browse files (browse the repository at this point in the history)
  • Loading branch information
lassepe committed Sep 28, 2019
1 parent eb63d86 commit 4d4fd72
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 60 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "BasicPOMCP"
uuid = "d721219e-3fc6-5570-a8ef-e5402f47c49e"
repo = "https://github.com/JuliaPOMDP/BasicPOMCP.jl"
version = "0.3.0"
version = "0.3.1"

[deps]
BeliefUpdaters = "8bb6e9a1-7d73-552c-a44a-e5dc5634aac4"
Expand Down
105 changes: 46 additions & 59 deletions notebooks/Minimal_Example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -28,20 +28,20 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 48,
"metadata": {},
"outputs": [],
"source": [
"using POMDPs\n",
"using Distributions # for Normal\n",
"using Distributions: Normal\n",
"using Random\n",
"import POMDPs: initialstate_distribution, actions, reward, generate_o, generate_s, discount, isterminal\n",
"import POMDPs: initialstate_distribution, actions, gen, discount, isterminal\n",
"Random.seed!(1);"
]
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 49,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -59,53 +59,38 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 50,
"metadata": {},
"outputs": [],
"source": [
"noise(x) = ceil(Int, abs(x - 5)/sqrt(2) + 1e-2)\n",
"\n",
"function generate_o(p::LightDark1D, s::Float64, a::Int, sp::Float64, rng::AbstractRNG)\n",
"function gen(m::LightDark1D, s::Float64, a::Int, rng::AbstractRNG)\n",
" # generate next state\n",
" sp = iszero(a) ? NaN : s+a\n",
" # generate observation\n",
" if isnan(sp)\n",
" return 0\n",
" o = 0\n",
" else\n",
" n = noise(sp)\n",
" return round(Int, sp) + rand(rng, -n:n)\n",
" o = round(Int, sp) + rand(rng, -n:n)\n",
" end\n",
"end\n",
"\n",
"function generate_s(p::LightDark1D, s::Float64, a::Int, rng::AbstractRNG)\n",
" if a == 0\n",
" return NaN\n",
" else\n",
" return s+a\n",
" end\n",
"end\n",
"\n",
"function reward(p::LightDark1D, s::Float64, a::Int, sp::Float64)\n",
" if a == 0\n",
" if abs(s) < 1\n",
" return p.correct_r\n",
" else\n",
" return p.incorrect_r\n",
" end\n",
" else\n",
" return 0.0\n",
" end \n",
" # generate reward\n",
" r = iszero(a) ? (abs(s) < 1 ? m.correct_r : m.incorrect_r) : 0.0\n",
" \n",
" return (sp=sp, o=o, r=r)\n",
"end;"
]
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 51,
"metadata": {},
"outputs": [],
"source": [
"actions(::LightDark1D) = [-1, 0, 1] # Left Stop Right\n",
"\n",
"function initialstate_distribution(pomdp::LightDark1D)\n",
" return Normal(2.0, 3.0)\n",
"end;"
"initialstate_distribution(pomdp::LightDark1D) = Normal(2.0, 3.0);"
]
},
{
Expand All @@ -119,7 +104,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 52,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -129,7 +114,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 53,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -140,7 +125,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 54,
"metadata": {
"scrolled": true
},
Expand Down Expand Up @@ -177,7 +162,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 55,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -215,7 +200,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 56,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -250,7 +235,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 57,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -259,27 +244,22 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 58,
"metadata": {},
"outputs": [
{
"ename": "MethodError",
"evalue": "MethodError: no method matching iterate(::RealInterval)\nClosest candidates are:\n iterate(!Matched::Core.SimpleVector) at essentials.jl:578\n iterate(!Matched::Core.SimpleVector, !Matched::Any) at essentials.jl:578\n iterate(!Matched::ExponentialBackOff) at error.jl:171\n ...",
"output_type": "error",
"traceback": [
"MethodError: no method matching iterate(::RealInterval)\nClosest candidates are:\n iterate(!Matched::Core.SimpleVector) at essentials.jl:578\n iterate(!Matched::Core.SimpleVector, !Matched::Any) at essentials.jl:578\n iterate(!Matched::ExponentialBackOff) at error.jl:171\n ...",
"",
"Stacktrace:",
" [1] first(::RealInterval) at ./abstractarray.jl:288",
" [2] macro expansion at /home/zach/.julia/dev/POMDPs/src/requirements_internals.jl:95 [inlined]",
" [3] macro expansion at /home/zach/.julia/dev/POMDPs/src/requirements_interface.jl:17 [inlined]",
" [4] initialize_belief at /home/zach/.julia/packages/ParticleFilters/fp73A/src/pomdps.jl:56 [inlined]",
" [5] simulate(::StepSimulator, ::LightDark1D, ::POMCPPlanner{LightDark1D,BasicPOMCP.SolvedPORollout{POMDPPolicies.RandomPolicy{MersenneTwister,LightDark1D,BeliefUpdaters.NothingUpdater},BeliefUpdaters.NothingUpdater,MersenneTwister},MersenneTwister}, ::BasicParticleFilter{LightDark1D,LightDark1D,LowVarianceResampler,MersenneTwister,Array{Float64,1}}, ::Normal{Float64}, ::Nothing) at /home/zach/.julia/dev/POMDPSimulators/src/stepthrough.jl:32",
" [6] simulate at /home/zach/.julia/dev/POMDPSimulators/src/stepthrough.jl:29 [inlined]",
" [7] simulate(::StepSimulator, ::LightDark1D, ::POMCPPlanner{LightDark1D,BasicPOMCP.SolvedPORollout{POMDPPolicies.RandomPolicy{MersenneTwister,LightDark1D,BeliefUpdaters.NothingUpdater},BeliefUpdaters.NothingUpdater,MersenneTwister},MersenneTwister}, ::BasicParticleFilter{LightDark1D,LightDark1D,LowVarianceResampler,MersenneTwister,Array{Float64,1}}) at /home/zach/.julia/dev/POMDPSimulators/src/stepthrough.jl:25",
" [8] #stepthrough#21(::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}, ::Function, ::LightDark1D, ::POMCPPlanner{LightDark1D,BasicPOMCP.SolvedPORollout{POMDPPolicies.RandomPolicy{MersenneTwister,LightDark1D,BeliefUpdaters.NothingUpdater},BeliefUpdaters.NothingUpdater,MersenneTwister},MersenneTwister}, ::BasicParticleFilter{LightDark1D,LightDark1D,LowVarianceResampler,MersenneTwister,Array{Float64,1}}, ::Vararg{Any,N} where N) at /home/zach/.julia/dev/POMDPSimulators/src/stepthrough.jl:213",
" [9] stepthrough(::LightDark1D, ::POMCPPlanner{LightDark1D,BasicPOMCP.SolvedPORollout{POMDPPolicies.RandomPolicy{MersenneTwister,LightDark1D,BeliefUpdaters.NothingUpdater},BeliefUpdaters.NothingUpdater,MersenneTwister},MersenneTwister}, ::BasicParticleFilter{LightDark1D,LightDark1D,LowVarianceResampler,MersenneTwister,Array{Float64,1}}, ::String) at /home/zach/.julia/dev/POMDPSimulators/src/stepthrough.jl:202",
" [10] top-level scope at ./In[11]:2"
"name": "stdout",
"output_type": "stream",
"text": [
"(s, a, r, sp, o) = (7.6069078240874, -1, 0.0, 6.6069078240874, 5)\n",
"(s, a, r, sp, o) = (6.6069078240874, -1, 0.0, 5.6069078240874, 7)\n",
"(s, a, r, sp, o) = (5.6069078240874, -1, 0.0, 4.6069078240874, 5)\n",
"(s, a, r, sp, o) = (4.6069078240874, -1, 0.0, 3.6069078240873997, 5)\n",
"(s, a, r, sp, o) = (3.6069078240873997, -1, 0.0, 2.6069078240873997, 3)\n",
"(s, a, r, sp, o) = (2.6069078240873997, -1, 0.0, 1.6069078240873997, 3)\n",
"(s, a, r, sp, o) = (1.6069078240873997, -1, 0.0, 0.6069078240873997, 0)\n",
"(s, a, r, sp, o) = (0.6069078240873997, -1, 0.0, -0.39309217591260026, 3)\n",
"(s, a, r, sp, o) = (-0.39309217591260026, 0, 10.0, NaN, 0)\n"
]
}
],
Expand All @@ -290,6 +270,13 @@
"end"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
Expand All @@ -301,15 +288,15 @@
"metadata": {
"anaconda-cloud": {},
"kernelspec": {
"display_name": "Julia 1.0.0",
"display_name": "Julia 1.2.0",
"language": "julia",
"name": "julia-1.0"
"name": "julia-1.2"
},
"language_info": {
"file_extension": ".jl",
"mimetype": "application/julia",
"name": "julia",
"version": "1.0.0"
"version": "1.2.0"
}
},
"nbformat": 4,
Expand Down

0 comments on commit 4d4fd72

Please sign in to comment.