From 450dd60560dc1d3eb93317150460a775e3e6cb41 Mon Sep 17 00:00:00 2001 From: Zachary Sunberg Date: Wed, 12 Jun 2024 12:02:40 -0600 Subject: [PATCH] fix #109 by updating documentation --- src/dpw_types.jl | 138 ++++++++++++++--------------------------------- src/vanilla.jl | 77 ++++++++------------------ 2 files changed, 62 insertions(+), 153 deletions(-) diff --git a/src/dpw_types.jl b/src/dpw_types.jl index b07c1fc..a6ebe8e 100644 --- a/src/dpw_types.jl +++ b/src/dpw_types.jl @@ -1,104 +1,46 @@ """ MCTS solver with DPW -Fields: - - depth::Int64 - Maximum rollout horizon and tree depth. - default: 10 - - exploration_constant::Float64 - Specified how much the solver should explore. - In the UCB equation, Q + c*sqrt(log(t/N)), c is the exploration constant. - default: 1.0 - - n_iterations::Int64 - Number of iterations during each action() call. - default: 100 - - max_time::Float64 - Maximum amount of CPU time spent iterating through simulations. - default: Inf - - k_action::Float64 - alpha_action::Float64 - k_state::Float64 - alpha_state::Float64 - These constants control the double progressive widening. A new state - or action will be added if the number of children is less than or equal to kN^alpha. - defaults: k:10, alpha:0.5 - - keep_tree::Bool - If true, store the tree in the planner for reuse at the next timestep (and every time it is used in the future). There is a computational cost for maintaining the state dictionary necessary for this. - default: false - - enable_action_pw::Bool - If true, enable progressive widening on the action space; if false just use the whole action space. - default: true - - enable_state_pw::Bool - If true, enable progressive widening on the state space; if false just use the single next state (for deterministic problems). 
- default: true - - check_repeat_state::Bool - check_repeat_action::Bool - When constructing the tree, check whether a state or action has been seen before (there is a computational cost to maintaining the dictionaries necessary for this) - default: true - - tree_in_info::Bool - If true, return the tree in the info dict when action_info is called. False by default because it can use a lot of memory if histories are being saved. - default: false - - rng::AbstractRNG - Random number generator - - estimate_value::Any (rollout policy) - Function, object, or number used to estimate the value at the leaf nodes. - If this is a function `f`, `f(mdp, s, depth)` will be called to estimate the value (depth can be ignored). - If this is an object `o`, `estimate_value(o, mdp, s, depth)` will be called. - If this is a number, the value will be set to that number. - default: RolloutEstimator(RandomSolver(rng)) - - init_Q::Any - Function, object, or number used to set the initial Q(s,a) value at a new node. - If this is a function `f`, `f(mdp, s, a)` will be called to set the value. - If this is an object `o`, `init_Q(o, mdp, s, a)` will be called. - If this is a number, Q will always be set to that number. - default: 0.0 - - init_N::Any - Function, object, or number used to set the initial N(s,a) value at a new node. - If this is a function `f`, `f(mdp, s, a)` will be called to set the value. - If this is an object `o`, `init_N(o, mdp, s, a)` will be called. - If this is a number, N will always be set to that number. - default: 0 - - next_action::Any - Function or object used to choose the next action to be considered for progressive widening. - The next action is determined based on the MDP, the state, `s`, and the current `DPWStateNode`, `snode`. - If this is a function `f`, `f(mdp, s, snode)` will be called to set the value. - If this is an object `o`, `next_action(o, mdp, s, snode)` will be called. 
- default: RandomActionGenerator(rng) - - default_action::Any - Function, action, or Policy used to determine the action if POMCP fails with exception `ex`. - If this is a Function `f`, `f(pomdp, belief, ex)` will be called. - If this is a Policy `p`, `action(p, belief)` will be called. - If it is an object `a`, `default_action(a, pomdp, belief, ex)` will be called, and if this method is not implemented, `a` will be returned directly. - default: `ExceptionRethrow()` - - reset_callback::Function - Function used to reset/reinitialize the MDP to a given state `s`. - Useful when the simulator state is not truly separate from the MDP state. - `f(mdp, s)` will be called. - default: `(mdp, s)->false` (optimized out) - - show_progress::Bool - Show progress bar during simulation. - default: false - - timer::Function: - Timekeeping method. Search iterations ended when `timer() - start_time ≥ max_time`. +Fields +====== +- `depth::Int64=10`: Maximum tree search depth. Rollout depth is controlled via the `estimate_value` field. +- `exploration_constant::Float64=1.0`: Specifies how much the solver should explore. In the UCB equation, Q + c*sqrt(log(t)/N), c is the exploration constant. +- `n_iterations::Int64=100`: Number of iterations during each action() call. +- `max_time::Float64=Inf`: Maximum amount of CPU time in seconds spent iterating through simulations. +- Double progressive widening parameters. These constants control the double progressive widening. A new state + or action will be added if the number of children is less than or equal to kN^alpha. + - `k_action::Float64=10` + - `alpha_action::Float64=0.5` + - `k_state::Float64=10` + - `alpha_state::Float64=0.5` +- `keep_tree::Bool=false`: If true, store the tree in the planner for reuse at the next timestep (and every time it is used in the future). There is a computational cost for maintaining the state dictionary necessary for this. 
+- `enable_action_pw::Bool=true`: If true, enable progressive widening on the action space; if false, just use the whole action space. +- `enable_state_pw::Bool=true`: If true, enable progressive widening on the state space; if false, just use the single next state (for deterministic problems). +- `check_repeat_state::Bool=true`, `check_repeat_action::Bool=true`: When constructing the tree, check whether a state or action has been seen before (there is a computational cost to maintaining the dictionaries necessary for this). +- `tree_in_info::Bool=false`: If true, return the tree in the info dict when action_info is called. False by default because it can use a lot of memory if histories are being saved. +- `rng::AbstractRNG=Random.GLOBAL_RNG`: Random number generator. +- `estimate_value::Any=RolloutEstimator(RandomSolver(rng))`: (rollout policy) Function, object, or number used to estimate the value at the leaf nodes. + - If this is a function `f`, `f(mdp, s, depth)` will be called to estimate the value (depth can be ignored). + - If this is an object `o`, `estimate_value(o, mdp, s, depth)` will be called. + - If this is a number, the value will be set to that number. +- `init_Q::Any=0.0`: Function, object, or number used to set the initial Q(s,a) value at a new node. + - If this is a function `f`, `f(mdp, s, a)` will be called to set the value. + - If this is an object `o`, `init_Q(o, mdp, s, a)` will be called. + - If this is a number, Q will always be set to that number. +- `init_N::Any=0`: Function, object, or number used to set the initial N(s,a) value at a new node. + - If this is a function `f`, `f(mdp, s, a)` will be called to set the value. + - If this is an object `o`, `init_N(o, mdp, s, a)` will be called. + - If this is a number, N will always be set to that number. +- `next_action::Any=RandomActionGenerator(rng)`: Function or object used to choose the next action to be considered for progressive widening. 
The next action is determined based on the MDP, the state, `s`, and the current `DPWStateNode`, `snode`. + - If this is a function `f`, `f(mdp, s, snode)` will be called to set the value. + - If this is an object `o`, `next_action(o, mdp, s, snode)` will be called. +- `default_action::Any=ExceptionRethrow()`: Function, action, or Policy used to determine the action if POMCP fails with exception `ex`. + - If this is a Function `f`, `f(pomdp, belief, ex)` will be called. + - If this is a Policy `p`, `action(p, belief)` will be called. + - If it is an object `a`, `default_action(a, pomdp, belief, ex)` will be called, and if this method is not implemented, `a` will be returned directly. +- `reset_callback::Function=(mdp, s)->false`: Function used to reset/reinitialize the MDP to a given state `s`. Useful when the simulator state is not truly separate from the MDP state. `f(mdp, s)` will be called. By default, this will be optimized out by the compiler. +- `show_progress::Bool=false`: Show progress bar during simulation. +- `timer::Function=()->1e-9*time_ns()`: Timekeeping method. Search iterations end when `timer() - start_time ≥ max_time`. """ mutable struct DPWSolver <: AbstractMCTSSolver depth::Int diff --git a/src/vanilla.jl b/src/vanilla.jl index 37296f3..0b07f6a 100644 --- a/src/vanilla.jl +++ b/src/vanilla.jl @@ -1,61 +1,28 @@ """ MCTS solver type -Fields: - - n_iterations::Int64 - Number of iterations during each action() call. - default: 100 - - max_time::Float64 - Maximum amount of CPU time spent iterating through simulations. - default: Inf - - depth::Int64: - Maximum rollout horizon and tree depth. - default: 10 - - exploration_constant::Float64: - Specifies how much the solver should explore. - In the UCB equation, Q + c*sqrt(log(t/N)), c is the exploration constant. - default: 1.0 - - rng::AbstractRNG: - Random number generator - - estimate_value::Any (rollout policy) - Function, object, or number used to estimate the value at the leaf nodes. 
- If this is a function `f`, `f(mdp, s, remaining_depth)` will be called to estimate the value (remaining_depth can be ignored). - If this is an object `o`, `estimate_value(o, mdp, s, remaining_depth)` will be called. - If this is a number, the value will be set to that number - default: RolloutEstimator(RandomSolver(rng); max_depth=50, eps=nothing) - - init_Q::Any - Function, object, or number used to set the initial Q(s,a) value at a new node. - If this is a function `f`, `f(mdp, s, a)` will be called to set the value. - If this is an object `o`, `init_Q(o, mdp, s, a)` will be called. - If this is a number, Q will be set to that number - default: 0.0 - - init_N::Any - Function, object, or number used to set the initial N(s,a) value at a new node. - If this is a function `f`, `f(mdp, s, a)` will be called to set the value. - If this is an object `o`, `init_N(o, mdp, s, a)` will be called. - If this is a number, N will be set to that number - default: 0 - - reuse_tree::Bool: - If this is true, the tree information is re-used for calculating the next plan. - Of course, clear_tree! can always be called to override this. - default: false - - enable_tree_vis::Bool: - If this is true, extra information needed for tree visualization will - be recorded. If it is false, the tree cannot be visualized. - default: false - - timer::Function: - Timekeeping method. Search iterations ended when `timer() - start_time ≥ max_time`. +Fields +====== +- `n_iterations::Int64=100`: Number of iterations during each action() call. +- `max_time::Float64=Inf`: Maximum amount of CPU time in seconds spent iterating through simulations. +- `depth::Int64=10`: Maximum tree search depth. Rollout depth is controlled via the `estimate_value` field. +- `exploration_constant::Float64=1.0`: Specifies how much the solver should explore. In the UCB equation, Q + c*sqrt(log(t)/N), c is the exploration constant. 
+- `rng::AbstractRNG=Random.GLOBAL_RNG`: Random number generator. +- `estimate_value::Any=RolloutEstimator(RandomSolver(rng))`: Function, object, or number used to estimate the value at the leaf nodes (rollout policy). + - If this is a function `f`, `f(mdp, s, remaining_depth)` will be called to estimate the value (remaining_depth can be ignored). + - If this is an object `o`, `estimate_value(o, mdp, s, remaining_depth)` will be called. + - If this is a number, the value will be set to that number. +- `init_Q::Any=0.0`: Function, object, or number used to set the initial Q(s,a) value at a new node. + - If this is a function `f`, `f(mdp, s, a)` will be called to set the value. + - If this is an object `o`, `init_Q(o, mdp, s, a)` will be called. + - If this is a number, Q will be set to that number. +- `init_N::Any=0`: Function, object, or number used to set the initial N(s,a) value at a new node. + - If this is a function `f`, `f(mdp, s, a)` will be called to set the value. + - If this is an object `o`, `init_N(o, mdp, s, a)` will be called. + - If this is a number, N will be set to that number. +- `reuse_tree::Bool=false`: If this is true, the tree information is re-used for calculating the next plan. Of course, clear_tree! can always be called to override this. +- `enable_tree_vis::Bool=false`: If this is true, extra information needed for tree visualization will be recorded. If it is false, the tree cannot be visualized. +- `timer::Function=()->1e-9*time_ns()`: Timekeeping method. Search iterations end when `timer() - start_time ≥ max_time`. """ mutable struct MCTSSolver <: AbstractMCTSSolver n_iterations::Int64