From 4d4fd728a35486c9a29007f8b75f0c33990da15b Mon Sep 17 00:00:00 2001 From: lassepe Date: Sat, 28 Sep 2019 15:04:37 -0700 Subject: [PATCH] Update notebook example --- Project.toml | 2 +- notebooks/Minimal_Example.ipynb | 105 ++++++++++++++------------------ 2 files changed, 47 insertions(+), 60 deletions(-) diff --git a/Project.toml b/Project.toml index c567ee2..87bb538 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "BasicPOMCP" uuid = "d721219e-3fc6-5570-a8ef-e5402f47c49e" repo = "https://github.com/JuliaPOMDP/BasicPOMCP.jl" -version = "0.3.0" +version = "0.3.1" [deps] BeliefUpdaters = "8bb6e9a1-7d73-552c-a44a-e5dc5634aac4" diff --git a/notebooks/Minimal_Example.ipynb b/notebooks/Minimal_Example.ipynb index 02628d0..d7cc82a 100644 --- a/notebooks/Minimal_Example.ipynb +++ b/notebooks/Minimal_Example.ipynb @@ -28,20 +28,20 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 48, "metadata": {}, "outputs": [], "source": [ "using POMDPs\n", - "using Distributions # for Normal\n", + "using Distributions: Normal\n", "using Random\n", - "import POMDPs: initialstate_distribution, actions, reward, generate_o, generate_s, discount, isterminal\n", + "import POMDPs: initialstate_distribution, actions, gen, discount, isterminal\n", "Random.seed!(1);" ] }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 49, "metadata": {}, "outputs": [], "source": [ @@ -59,53 +59,38 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 50, "metadata": {}, "outputs": [], "source": [ "noise(x) = ceil(Int, abs(x - 5)/sqrt(2) + 1e-2)\n", "\n", - "function generate_o(p::LightDark1D, s::Float64, a::Int, sp::Float64, rng::AbstractRNG)\n", + "function gen(m::LightDark1D, s::Float64, a::Int, rng::AbstractRNG)\n", + " # generate next state\n", + " sp = iszero(a) ? NaN : s+a\n", + " # generate observation\n", " if isnan(sp)\n", - " return 0\n", + " o = 0\n", " else\n", " n = noise(sp)\n", - " return round(Int, sp) + rand(rng, -n:n)\n", + " o = round(Int, sp) + rand(rng, -n:n)\n", " end\n", - "end\n", - "\n", - "function generate_s(p::LightDark1D, s::Float64, a::Int, rng::AbstractRNG)\n", - " if a == 0\n", - " return NaN\n", - " else\n", - " return s+a\n", - " end\n", - "end\n", - "\n", - "function reward(p::LightDark1D, s::Float64, a::Int, sp::Float64)\n", - " if a == 0\n", - " if abs(s) < 1\n", - " return p.correct_r\n", - " else\n", - " return p.incorrect_r\n", - " end\n", - " else\n", - " return 0.0\n", - " end \n", + " # generate reward\n", + " r = iszero(a) ? (abs(s) < 1 ? m.correct_r : m.incorrect_r) : 0.0\n", + " \n", + " return (sp=sp, o=o, r=r)\n", "end;" ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 51, "metadata": {}, "outputs": [], "source": [ "actions(::LightDark1D) = [-1, 0, 1] # Left Stop Right\n", "\n", - "function initialstate_distribution(pomdp::LightDark1D)\n", - " return Normal(2.0, 3.0)\n", - "end;" + "initialstate_distribution(pomdp::LightDark1D) = Normal(2.0, 3.0);" ] }, { @@ -119,7 +104,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 52, "metadata": {}, "outputs": [], "source": [ @@ -129,7 +114,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 53, "metadata": {}, "outputs": [], "source": [ @@ -140,7 +125,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 54, "metadata": { "scrolled": true }, @@ -177,7 +162,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 55, "metadata": {}, "outputs": [ { @@ -215,7 +200,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 56, "metadata": {}, "outputs": [], "source": [ @@ -250,7 +235,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 57, "metadata": {}, "outputs": [], "source": [ @@ -259,27 +244,22 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 58, "metadata": {}, "outputs": [ { - "ename": "MethodError", - "evalue": "MethodError: no method matching iterate(::RealInterval)\nClosest candidates are:\n iterate(!Matched::Core.SimpleVector) at essentials.jl:578\n iterate(!Matched::Core.SimpleVector, !Matched::Any) at essentials.jl:578\n iterate(!Matched::ExponentialBackOff) at error.jl:171\n ...", - "output_type": "error", - "traceback": [ - "MethodError: no method matching iterate(::RealInterval)\nClosest candidates are:\n iterate(!Matched::Core.SimpleVector) at essentials.jl:578\n iterate(!Matched::Core.SimpleVector, !Matched::Any) at essentials.jl:578\n iterate(!Matched::ExponentialBackOff) at error.jl:171\n ...", - "", - "Stacktrace:", - " [1] first(::RealInterval) at ./abstractarray.jl:288", - " [2] macro expansion at /home/zach/.julia/dev/POMDPs/src/requirements_internals.jl:95 [inlined]", - " [3] macro expansion at /home/zach/.julia/dev/POMDPs/src/requirements_interface.jl:17 [inlined]", - " [4] initialize_belief at /home/zach/.julia/packages/ParticleFilters/fp73A/src/pomdps.jl:56 [inlined]", - " [5] simulate(::StepSimulator, ::LightDark1D, ::POMCPPlanner{LightDark1D,BasicPOMCP.SolvedPORollout{POMDPPolicies.RandomPolicy{MersenneTwister,LightDark1D,BeliefUpdaters.NothingUpdater},BeliefUpdaters.NothingUpdater,MersenneTwister},MersenneTwister}, ::BasicParticleFilter{LightDark1D,LightDark1D,LowVarianceResampler,MersenneTwister,Array{Float64,1}}, ::Normal{Float64}, ::Nothing) at /home/zach/.julia/dev/POMDPSimulators/src/stepthrough.jl:32", - " [6] simulate at /home/zach/.julia/dev/POMDPSimulators/src/stepthrough.jl:29 [inlined]", - " [7] simulate(::StepSimulator, ::LightDark1D, ::POMCPPlanner{LightDark1D,BasicPOMCP.SolvedPORollout{POMDPPolicies.RandomPolicy{MersenneTwister,LightDark1D,BeliefUpdaters.NothingUpdater},BeliefUpdaters.NothingUpdater,MersenneTwister},MersenneTwister}, ::BasicParticleFilter{LightDark1D,LightDark1D,LowVarianceResampler,MersenneTwister,Array{Float64,1}}) at /home/zach/.julia/dev/POMDPSimulators/src/stepthrough.jl:25", - " [8] #stepthrough#21(::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}, ::Function, ::LightDark1D, ::POMCPPlanner{LightDark1D,BasicPOMCP.SolvedPORollout{POMDPPolicies.RandomPolicy{MersenneTwister,LightDark1D,BeliefUpdaters.NothingUpdater},BeliefUpdaters.NothingUpdater,MersenneTwister},MersenneTwister}, ::BasicParticleFilter{LightDark1D,LightDark1D,LowVarianceResampler,MersenneTwister,Array{Float64,1}}, ::Vararg{Any,N} where N) at /home/zach/.julia/dev/POMDPSimulators/src/stepthrough.jl:213", - " [9] stepthrough(::LightDark1D, ::POMCPPlanner{LightDark1D,BasicPOMCP.SolvedPORollout{POMDPPolicies.RandomPolicy{MersenneTwister,LightDark1D,BeliefUpdaters.NothingUpdater},BeliefUpdaters.NothingUpdater,MersenneTwister},MersenneTwister}, ::BasicParticleFilter{LightDark1D,LightDark1D,LowVarianceResampler,MersenneTwister,Array{Float64,1}}, ::String) at /home/zach/.julia/dev/POMDPSimulators/src/stepthrough.jl:202", - " [10] top-level scope at ./In[11]:2" + "name": "stdout", + "output_type": "stream", + "text": [ + "(s, a, r, sp, o) = (7.6069078240874, -1, 0.0, 6.6069078240874, 5)\n", + "(s, a, r, sp, o) = (6.6069078240874, -1, 0.0, 5.6069078240874, 7)\n", + "(s, a, r, sp, o) = (5.6069078240874, -1, 0.0, 4.6069078240874, 5)\n", + "(s, a, r, sp, o) = (4.6069078240874, -1, 0.0, 3.6069078240873997, 5)\n", + "(s, a, r, sp, o) = (3.6069078240873997, -1, 0.0, 2.6069078240873997, 3)\n", + "(s, a, r, sp, o) = (2.6069078240873997, -1, 0.0, 1.6069078240873997, 3)\n", + "(s, a, r, sp, o) = (1.6069078240873997, -1, 0.0, 0.6069078240873997, 0)\n", + "(s, a, r, sp, o) = (0.6069078240873997, -1, 0.0, -0.39309217591260026, 3)\n", + "(s, a, r, sp, o) = (-0.39309217591260026, 0, 10.0, NaN, 0)\n" ] } ], @@ -290,6 +270,13 @@ "end" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, { "cell_type": "code", "execution_count": null, @@ -301,15 +288,15 @@ "metadata": { "anaconda-cloud": {}, "kernelspec": { - "display_name": "Julia 1.0.0", + "display_name": "Julia 1.2.0", "language": "julia", - "name": "julia-1.0" + "name": "julia-1.2" }, "language_info": { "file_extension": ".jl", "mimetype": "application/julia", "name": "julia", - "version": "1.0.0" + "version": "1.2.0" } }, "nbformat": 4,