Skip to content

Commit

Permalink
fix #109 by updating documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
zsunberg committed Jun 12, 2024
1 parent e6a4944 commit 450dd60
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 153 deletions.
138 changes: 40 additions & 98 deletions src/dpw_types.jl
Original file line number Diff line number Diff line change
@@ -1,104 +1,46 @@
"""
MCTS solver with DPW
Fields:
depth::Int64
Maximum rollout horizon and tree depth.
default: 10
exploration_constant::Float64
Specified how much the solver should explore.
In the UCB equation, Q + c*sqrt(log(t/N)), c is the exploration constant.
default: 1.0
n_iterations::Int64
Number of iterations during each action() call.
default: 100
max_time::Float64
Maximum amount of CPU time spent iterating through simulations.
default: Inf
k_action::Float64
alpha_action::Float64
k_state::Float64
alpha_state::Float64
These constants control the double progressive widening. A new state
or action will be added if the number of children is less than or equal to kN^alpha.
defaults: k:10, alpha:0.5
keep_tree::Bool
If true, store the tree in the planner for reuse at the next timestep (and every time it is used in the future). There is a computational cost for maintaining the state dictionary necessary for this.
default: false
enable_action_pw::Bool
If true, enable progressive widening on the action space; if false just use the whole action space.
default: true
enable_state_pw::Bool
If true, enable progressive widening on the state space; if false just use the single next state (for deterministic problems).
default: true
check_repeat_state::Bool
check_repeat_action::Bool
When constructing the tree, check whether a state or action has been seen before (there is a computational cost to maintaining the dictionaries necessary for this)
default: true
tree_in_info::Bool
If true, return the tree in the info dict when action_info is called. False by default because it can use a lot of memory if histories are being saved.
default: false
rng::AbstractRNG
Random number generator
estimate_value::Any (rollout policy)
Function, object, or number used to estimate the value at the leaf nodes.
If this is a function `f`, `f(mdp, s, depth)` will be called to estimate the value (depth can be ignored).
If this is an object `o`, `estimate_value(o, mdp, s, depth)` will be called.
If this is a number, the value will be set to that number.
default: RolloutEstimator(RandomSolver(rng))
init_Q::Any
Function, object, or number used to set the initial Q(s,a) value at a new node.
If this is a function `f`, `f(mdp, s, a)` will be called to set the value.
If this is an object `o`, `init_Q(o, mdp, s, a)` will be called.
If this is a number, Q will always be set to that number.
default: 0.0
init_N::Any
Function, object, or number used to set the initial N(s,a) value at a new node.
If this is a function `f`, `f(mdp, s, a)` will be called to set the value.
If this is an object `o`, `init_N(o, mdp, s, a)` will be called.
If this is a number, N will always be set to that number.
default: 0
next_action::Any
Function or object used to choose the next action to be considered for progressive widening.
The next action is determined based on the MDP, the state, `s`, and the current `DPWStateNode`, `snode`.
If this is a function `f`, `f(mdp, s, snode)` will be called to set the value.
If this is an object `o`, `next_action(o, mdp, s, snode)` will be called.
default: RandomActionGenerator(rng)
default_action::Any
Function, action, or Policy used to determine the action if POMCP fails with exception `ex`.
If this is a Function `f`, `f(pomdp, belief, ex)` will be called.
If this is a Policy `p`, `action(p, belief)` will be called.
If it is an object `a`, `default_action(a, pomdp, belief, ex)` will be called, and if this method is not implemented, `a` will be returned directly.
default: `ExceptionRethrow()`
reset_callback::Function
Function used to reset/reinitialize the MDP to a given state `s`.
Useful when the simulator state is not truly separate from the MDP state.
`f(mdp, s)` will be called.
default: `(mdp, s)->false` (optimized out)
show_progress::Bool
Show progress bar during simulation.
default: false
timer::Function:
Timekeeping method. Search iterations ended when `timer() - start_time ≥ max_time`.
Fields
======
- `depth::Int64=10`: Maximum tree search depth. Rollout depth is controlled via the `estimate_value` field.
- `exploration_constant::Float64=1.0`: Specified how much the solver should explore. In the UCB equation, Q + c*sqrt(log(t/N)), c is the exploration constant.
- `n_iterations::Int64=100`: Number of iterations during each action() call.
- `max_time::Float64`: Maximum amount of CPU time in seconds spent iterating through simulations.
- Double progressive widening parameters. These constants control the double progressive widening. A new state
or action will be added if the number of children is less than or equal to kN^alpha.
- `k_action::Float64=10`
- `alpha_action::Float64=0.5`
- `k_state::Float64=10`
- `alpha_state::Float64=0.5`
- `keep_tree::Bool=false`: If true, store the tree in the planner for reuse at the next timestep (and every time it is used in the future). There is a computational cost for maintaining the state dictionary necessary for this.
- `enable_action_pw::Bool=true`: If true, enable progressive widening on the action space; if false just use the whole action space.
- `enable_state_pw::Bool=true`: If true, enable progressive widening on the state space; if false just use the single next state (for deterministic problems).
- `check_repeat_state::Bool=true`, `check_repeat_action::Bool=true`: When constructing the tree, check whether a state or action has been seen before (there is a computational cost to maintaining the dictionaries necessary for this)
- `tree_in_info::Bool=false`: If true, return the tree in the info dict when action_info is called. False by default because it can use a lot of memory if histories are being saved.
- `rng::AbstractRNG=Random.GLOBAL_RNG`: Random number generator
- `estimate_value::Any=RolloutEstimator(RandomSolver(rng))`: (rollout policy) Function, object, or number used to estimate the value at the leaf nodes.
- If this is a function `f`, `f(mdp, s, depth)` will be called to estimate the value (depth can be ignored).
- If this is an object `o`, `estimate_value(o, mdp, s, depth)` will be called.
- If this is a number, the value will be set to that number.
- `init_Q::Any=0.0`: Function, object, or number used to set the initial Q(s,a) value at a new node.
- If this is a function `f`, `f(mdp, s, a)` will be called to set the value.
- If this is an object `o`, `init_Q(o, mdp, s, a)` will be called.
- If this is a number, Q will always be set to that number.
- `init_N::Any=0`: Function, object, or number used to set the initial N(s,a) value at a new node.
- If this is a function `f`, `f(mdp, s, a)` will be called to set the value.
- If this is an object `o`, `init_N(o, mdp, s, a)` will be called.
- If this is a number, N will always be set to that number.
- `next_action::Any=RandomActionGenerator(rng)`: Function or object used to choose the next action to be considered for progressive widening. The next action is determined based on the MDP, the state, `s`, and the current `DPWStateNode`, `snode`.
- If this is a function `f`, `f(mdp, s, snode)` will be called to set the value.
- If this is an object `o`, `next_action(o, mdp, s, snode)` will be called.
- `default_action::Any=ExceptionRethrow()`: Function, action, or Policy used to determine the action if POMCP fails with exception `ex`.
- If this is a Function `f`, `f(pomdp, belief, ex)` will be called.
- If this is a Policy `p`, `action(p, belief)` will be called.
- If it is an object `a`, `default_action(a, pomdp, belief, ex)` will be called, and if this method is not implemented, `a` will be returned directly.
- `reset_callback::Function=(mdp, s)->false`: Function used to reset/reinitialize the MDP to a given state `s`. Useful when the simulator state is not truly separate from the MDP state. `f(mdp, s)` will be called. By default, this will be optimized out by the compiler.
- `show_progress::Bool=false` Show progress bar during simulation.
- `timer::Function=()->1e-9*time_ns()`: Timekeeping method. Search iterations ended when `timer() - start_time ≥ max_time`.
"""
mutable struct DPWSolver <: AbstractMCTSSolver
depth::Int
Expand Down
77 changes: 22 additions & 55 deletions src/vanilla.jl
Original file line number Diff line number Diff line change
@@ -1,61 +1,28 @@
"""
MCTS solver type
Fields:
n_iterations::Int64
Number of iterations during each action() call.
default: 100
max_time::Float64
Maximum amount of CPU time spent iterating through simulations.
default: Inf
depth::Int64:
Maximum rollout horizon and tree depth.
default: 10
exploration_constant::Float64:
Specifies how much the solver should explore.
In the UCB equation, Q + c*sqrt(log(t/N)), c is the exploration constant.
default: 1.0
rng::AbstractRNG:
Random number generator
estimate_value::Any (rollout policy)
Function, object, or number used to estimate the value at the leaf nodes.
If this is a function `f`, `f(mdp, s, remaining_depth)` will be called to estimate the value (remaining_depth can be ignored).
If this is an object `o`, `estimate_value(o, mdp, s, remaining_depth)` will be called.
If this is a number, the value will be set to that number
default: RolloutEstimator(RandomSolver(rng); max_depth=50, eps=nothing)
init_Q::Any
Function, object, or number used to set the initial Q(s,a) value at a new node.
If this is a function `f`, `f(mdp, s, a)` will be called to set the value.
If this is an object `o`, `init_Q(o, mdp, s, a)` will be called.
If this is a number, Q will be set to that number
default: 0.0
init_N::Any
Function, object, or number used to set the initial N(s,a) value at a new node.
If this is a function `f`, `f(mdp, s, a)` will be called to set the value.
If this is an object `o`, `init_N(o, mdp, s, a)` will be called.
If this is a number, N will be set to that number
default: 0
reuse_tree::Bool:
If this is true, the tree information is re-used for calculating the next plan.
Of course, clear_tree! can always be called to override this.
default: false
enable_tree_vis::Bool:
If this is true, extra information needed for tree visualization will
be recorded. If it is false, the tree cannot be visualized.
default: false
timer::Function:
Timekeeping method. Search iterations ended when `timer() - start_time ≥ max_time`.
Fields
======
- `n_iterations::Int64=100` Number of iterations during each action() call.
- `max_time::Float64=Inf` Maximum amount of CPU time spent iterating through simulations in seconds.
- `depth::Int64=10`: Maximum tree search depth. Rollout depth is controlled via the `estimate_value` field.
- `exploration_constant::Float64=1.0`: Specifies how much the solver should explore. In the UCB equation, Q + c*sqrt(log(t/N)), c is the exploration constant.
- `rng::AbstractRNG=Random.GLOBAL_RNG`: Random number generator
- `estimate_value::Any=RolloutEstimator(RandomSolver(rng))`: Function, object, or number used to estimate the value at the leaf nodes (rollout policy).
- If this is a function `f`, `f(mdp, s, remaining_depth)` will be called to estimate the value (remaining_depth can be ignored).
- If this is an object `o`, `estimate_value(o, mdp, s, remaining_depth)` will be called.
- If this is a number, the value will be set to that number
- `init_Q::Any=0.0`: Function, object, or number used to set the initial Q(s,a) value at a new node.
- If this is a function `f`, `f(mdp, s, a)` will be called to set the value.
- If this is an object `o`, `init_Q(o, mdp, s, a)` will be called.
- If this is a number, Q will be set to that number
- `init_N::Any=0`: Function, object, or number used to set the initial N(s,a) value at a new node.
- If this is a function `f`, `f(mdp, s, a)` will be called to set the value.
- If this is an object `o`, `init_N(o, mdp, s, a)` will be called.
- If this is a number, N will be set to that number
- `reuse_tree::Bool=false`: If this is true, the tree information is re-used for calculating the next plan. Of course, clear_tree! can always be called to override this.
- `enable_tree_vis::Bool=false`: If this is true, extra information needed for tree visualization will be recorded. If it is false, the tree cannot be visualized.
- `timer::Function=()->1e-9*time_ns()`: Timekeeping method. Search iterations ended when `timer() - start_time ≥ max_time`.
"""
mutable struct MCTSSolver <: AbstractMCTSSolver
n_iterations::Int64
Expand Down

0 comments on commit 450dd60

Please sign in to comment.