diff --git a/docs/src/documentation.md b/docs/src/documentation.md index 43f5449..99ac301 100644 --- a/docs/src/documentation.md +++ b/docs/src/documentation.md @@ -1,12 +1,5 @@ # Documentation -## Mondrian cells - -```@autodocs -Modules = [MondrianForests] -Pages = ["cell.jl"] -``` - ## Mondrian trees ```@autodocs diff --git a/src/MondrianForests.jl b/src/MondrianForests.jl index e368660..23e5c80 100644 --- a/src/MondrianForests.jl +++ b/src/MondrianForests.jl @@ -12,7 +12,7 @@ export get_subtrees export get_leaves export get_common_refinement #export get_cell_id -#export are_in_same_cell +export are_in_same_leaf #export count_cells #export restrict diff --git a/src/data.jl b/src/data.jl index fd7a4e7..287c5e2 100644 --- a/src/data.jl +++ b/src/data.jl @@ -1,10 +1,13 @@ using Distributions """ + generate_data(n::Int, X_dist::Distribution, eps_dist::Distribution, + mu::Function, sigma2::Function) + Generate sample data for Mondrian forest estimation. -Draws `n` independent samples from `Y = mu(X) + sigma(X) eps`, -with `X ~ X_dist` and `eps ~ eps_dist` +Draws `n` independent samples from \$Y = \\mu(X) + \\sigma(X) \\varepsilon\$, +with \$X \\sim\$ `X_dist` and \$\\varepsilon \\sim\$ `eps_dist`. """ function generate_data(n::Int, X_dist::Distribution, eps_dist::Distribution, mu::Function, sigma2::Function) @@ -17,10 +20,13 @@ function generate_data(n::Int, X_dist::Distribution, eps_dist::Distribution, end """ + generate_uniform_data_uniform_errors(d::Int, n::Int) + Generate uniform sample data with uniform errors for Mondrian forest estimation. -Draws `n` independent samples from `Y = eps`, -with `X ~ U[0, 1]` and `eps ~ U[-sqrt(3), sqrt(3)]` +Draws `n` independent samples from \$Y = \\varepsilon\$, +with \$X \\sim \\mathcal{U}[0, 1]\$ and +\$\\varepsilon \\sim \\mathcal{U}\\big[-\\sqrt 3, \\sqrt 3\\big]\$. """ function generate_uniform_data_uniform_errors(d::Int, n::Int) X_dist = product_distribution([Uniform(0, 1) for _ in 1:d]) @@ -31,10 +37,12 @@ function generate_uniform_data_uniform_errors(d::Int, n::Int) end """ + generate_uniform_data_normal_errors(d::Int, n::Int) + Generate uniform sample data with normal errors for Mondrian forest estimation. -Draws `n` independent samples from `Y = eps`, -with `X ~ U[0, 1]` and `eps ~ N(0, 1)` +Draws `n` independent samples from \$Y = \\varepsilon\$, +with \$X \\sim \\mathcal{U}[0, 1]\$ and \$\\varepsilon \\sim \\mathcal{N}(0, 1)\$. """ function generate_uniform_data_normal_errors(d::Int, n::Int) X_dist = product_distribution([Uniform(0, 1) for _ in 1:d]) diff --git a/src/tree.jl b/src/tree.jl index d38e6c3..f481874 100644 --- a/src/tree.jl +++ b/src/tree.jl @@ -1,22 +1,21 @@ using Random using Distributions -# TODO rewrite functions based on subtrees # TODO rewrite replication files to use new functions -# TODO docs """ A Mondrian tree is determined by: - `id`: a string to identify the tree - `lambda`: the non-negative lifetime parameter -- `lower`: the lower coordinate of the cell -- `upper`: the upper coordinate of the cell -- `creation_time`: the time when the cell was created during sampling -- `is_split`: whether the cell is split -- `split_axis`: the direction in which the cell is split, if any -- `split_location`: the location on `split_axis` at which the cell is split, if any -- `tree_left`: the left child tree of the cell, if any -- `tree_right`: the right child tree of the cell, if any +- `lower`: the lower coordinate of the root cell +- `upper`: the upper coordinate of the root cell +- `creation_time`: the time when the root cell was created during sampling +- `is_split`: whether the root cell is split +- `split_axis`: the direction in which the root cell is split, if any +- `split_location`: the location on `split_axis` + at which the root cell is split, if any +- `tree_left`: the left child tree of the root cell, if any +- `tree_right`: the right child tree of the root cell, if any """ struct MondrianTree{d} id::String @@ -32,7 +31,22 @@ struct MondrianTree{d} end """ -Sample a Mondrian tree with a given lifetime, lower and upper cell coordinates, and creation time. + MondrianTree(id::String, lambda::Float64, lower::NTuple{d,Float64}, + upper::NTuple{d,Float64}, creation_time::Float64) where {d} + +Sample a Mondrian tree with a given id, lifetime, +lower and upper cell coordinates, and creation time. +To be used in internal construction methods. + +# Examples +```julia +id = "" +lambda = 3.0 +lower = ntuple(i -> 0.2, d) +upper = ntuple(i -> 0.7, d) +creation_time = 0.0 +tree = MondrianTree(id, lambda, lower, upper, creation_time) +``` """ function MondrianTree(id::String, lambda::Float64, lower::NTuple{d,Float64}, upper::NTuple{d,Float64}, creation_time::Float64) where {d} @@ -59,7 +73,18 @@ function MondrianTree(id::String, lambda::Float64, lower::NTuple{d,Float64}, return tree end -"""Sample a Mondrian tree `M([0,1]^d, lambda)`.""" +""" + MondrianTree(d::Int, lambda::Float64) + +Sample a Mondrian tree \$\\mathcal{M}([0,1]^d, \\lambda)\$. + +# Examples +```julia +d = 2 +lambda = 3.0 +tree = MondrianTree(d, lambda); +``` +""" function MondrianTree(d::Int, lambda::Float64) if lambda < 0 throw(DomainError(lambda, "lambda must be non-negative")) @@ -73,7 +98,7 @@ end """ is_in(x::NTuple{d,Float64}, tree::MondrianTree{d}) where {d} -Check if a point `x` is contained in a Mondrian tree. +Check if a point `x` is contained in the root cell of a Mondrian tree. # Examples ```jldoctest @@ -92,7 +117,7 @@ end """ get_center(tree::MondrianTree{d}) where {d} -Get the center point of a Mondrian tree. +Get the center point of the root cell of a Mondrian tree. # Examples ```jldoctest @@ -110,7 +135,7 @@ end """ get_volume(tree::MondrianTree{d}) where {d} -Get the d-dimensional volume of a Mondrian tree. +Get the d-dimensional volume of the root cell of a Mondrian tree. # Examples ```jldoctest @@ -125,34 +150,61 @@ function get_volume(tree::MondrianTree{d}) where {d} return prod(tree.upper .- tree.lower) end -# TODO doc -function apply_split(tree::MondrianTree{d}, split_lower::NTuple{d,Float64}, - split_upper::NTuple{d,Float64}, split_time::Float64, - split_axis::Int, split_location::Float64) where {d} +""" + apply_split(tree::MondrianTree{d}, split_tree::MondrianTree{d}) where {d} + +Take the split which occurred in `split_tree` and apply it to `tree`. +Returns a new `MondrianTree`. + +# Examples +```julia +tree = MondrianTree(2, 3.0) +split_tree = MondrianTree(2, 3.0) +new_tree = apply_split(tree, split_tree) +``` +""" +function apply_split(tree::MondrianTree{d}, split_tree::MondrianTree{d}) where {d} if tree.is_split - tree_left = apply_split(tree.tree_left, split_lower, split_upper, - split_time, split_axis, split_location) - tree_right = apply_split(tree.tree_right, split_lower, split_upper, - split_time, split_axis, split_location) + tree_left = apply_split(tree.tree_left, split_tree) + tree_right = apply_split(tree.tree_right, split_tree) return MondrianTree(tree.id, tree.lambda, tree.lower, tree.upper, tree.creation_time, - true, tree.split_axis, tree.split_location, tree_left, tree_right) + true, split_tree.split_axis, split_tree.split_location, + tree_left, tree_right) else - if all(tree.lower .< split_upper) && all(split_lower .< tree.upper) - left_upper = min.(split_upper, tree.upper) - right_lower = max.(split_lower, tree.lower) + if all(tree.lower .< split_tree.tree_left.upper) && + all(split_tree.tree_right.lower .< tree.upper) + left_upper = min.(split_tree.tree_left.upper, tree.upper) + right_lower = max.(split_tree.tree_right.lower, tree.lower) tree_left = MondrianTree(tree.id * "L", tree.lambda, tree.lower, left_upper, - split_time, false, nothing, nothing, nothing, nothing) + split_tree.tree_left.creation_time, false, + nothing, nothing, nothing, nothing) tree_right = MondrianTree(tree.id * "R", tree.lambda, right_lower, tree.upper, - split_time, false, nothing, nothing, nothing, nothing) + split_tree.tree_left.creation_time, false, + nothing, nothing, nothing, nothing) return MondrianTree(tree.id, tree.lambda, tree.lower, tree.upper, tree.creation_time, - true, split_axis, split_location, tree_left, tree_right) + true, split_tree.split_axis, split_tree.split_location, + tree_left, tree_right) else return tree end end end -# TODO doc +""" + get_common_refinement(tree1::MondrianTree{d}, tree2::MondrianTree{d}) where {d} + +Get the common refinement of two Mondrian trees, +whose leaf cells are the intersections of all leaf cells in `tree1` +and `tree2`. Preserves the split times and +returns a new equivalent `MondrianTree`. + +# Examples +```julia +tree1 = MondrianTree(2, 3.0) +tree2 = MondrianTree(2, 3.0) +refined_tree = get_common_refinement(tree1, tree2) +``` +""" function get_common_refinement(tree1::MondrianTree{d}, tree2::MondrianTree{d}) where {d} @assert tree1.id == tree2.id @assert tree1.lambda == tree2.lambda @@ -168,21 +220,26 @@ function get_common_refinement(tree1::MondrianTree{d}, tree2::MondrianTree{d}) w false, nothing, nothing, nothing, nothing) for subtree in subtrees - split_location = subtree.split_location - split_axis = subtree.split_axis - lower = subtree.lower - upper = subtree.upper - split_lower = ntuple(j -> (j == split_axis ? split_location : lower[j]), d) - split_upper = ntuple(j -> (j == split_axis ? split_location : upper[j]), d) - split_time = subtree.tree_left.creation_time - tree = apply_split(tree, split_lower, split_upper, split_time, - split_axis, split_location) + tree = apply_split(tree, subtree) end return tree end -# TODO doc +""" + get_common_refinement(trees::Vector{MondrianTree{d}}) where {d} + +Get the common refinement of several Mondrian trees, +whose leaf cells are the intersections of all leaf cells in `trees`. +Preserves the split times and +returns a new equivalent `MondrianTree`. + +# Examples +```julia +trees = [MondrianTree(2, 3.0) for _ in 1:3] +refined_tree = get_common_refinement(trees) +``` +""" function get_common_refinement(trees::Vector{MondrianTree{d}}) where {d} @assert !isempty(trees) if length(trees) == 1 @@ -197,9 +254,23 @@ function get_common_refinement(trees::Vector{MondrianTree{d}}) where {d} end """ -Check if two points are in the same leaf of a Mondrian tree. + are_in_same_leaf(x1::NTuple{d,Float64}, x2::NTuple{d,Float64}, tree::MondrianTree) + +Check if two points are in the same leaf cell of a Mondrian tree. + +# Examples +```jldoctest +d = 2 +tree = MondrianTree(d, 0.0) +x1 = ntuple(i -> 0.2, d) +x2 = ntuple(i -> 0.7, d) +are_in_same_leaf(x1, x2, tree) + +# output +true +``` """ -function are_in_same_leaf(x1::Vector{Float64}, x2::Vector{Float64}, tree::MondrianTree) +function are_in_same_leaf(x1::NTuple{d,Float64}, x2::NTuple{d,Float64}, tree::MondrianTree) where {d} if tree.is_split if is_in(x1, tree.tree_left) && is_in(x2, tree.tree_right) return are_in_same_leaf(x1, x2, tree.left) @@ -213,7 +284,17 @@ function are_in_same_leaf(x1::Vector{Float64}, x2::Vector{Float64}, tree::Mondri end end -"""Get a list of the subtrees contained in a Mondrian tree.""" +""" + get_subtrees(tree::MondrianTree{d}) where {d} + +Get a list of the subtrees contained in a Mondrian tree. + +# Examples +```julia +tree = MondrianTree(2, 3.0) +get_subtrees(tree) +``` +""" function get_subtrees(tree::MondrianTree{d}) where {d} if tree.is_split subtrees_left = get_subtrees(tree.tree_left) @@ -224,22 +305,64 @@ function get_subtrees(tree::MondrianTree{d}) where {d} end end -"""Get a list of the leaves in a Mondrian tree.""" +""" + get_leaves(tree::MondrianTree{d}) where {d} + +Get a list of the leaves in a Mondrian tree. + +# Examples +```julia +tree = MondrianTree(2, 3.0) +get_leaves(tree) +``` +""" function get_leaves(tree::MondrianTree{d}) where {d} return [t for t in get_subtrees(tree) if !t.is_split] end -"""Get the leaf of a Mondrian tree containing a point `x`.""" +""" + get_leaf_containing(x::NTuple{d,Float64}, tree::MondrianTree{d}) where {d} + +Get the leaf of a Mondrian tree containing a point `x`. + +# Examples +```julia +d = 2 +x = ntuple(i -> 0.2, d) +tree = MondrianTree(d, 3.0) +get_leaf_containing(x, tree) +``` +""" function get_leaf_containing(x::NTuple{d,Float64}, tree::MondrianTree{d}) where {d} return [t for t in get_leaves(tree) if is_in(x, t)][] end -"""Count the leaves of a Mondrian tree.""" +""" + count_leaves(tree::MondrianTree{d}) where {d} + +Count the leaves of a Mondrian tree. + +# Examples +```julia +tree = MondrianTree(2, 3.0) +count_leaves(tree) +``` +""" function count_leaves(tree::MondrianTree{d}) where {d} return length(get_leaves(tree)) end -"""Restrict a Mondrian tree to a stopping time.""" +""" + restrict(tree::MondrianTree{d}, time::Float64) where {d} + +Restrict a Mondrian tree to a stopping time. + +# Examples +```julia +tree = MondrianTree(2, 3.0) +restrict(tree, 2.0) +``` +""" function restrict(tree::MondrianTree{d}, time::Float64) where {d} if tree.is_split && tree.tree_left.creation_time <= time tree_left = restrict(tree.tree_left, time) @@ -252,7 +375,11 @@ function restrict(tree::MondrianTree{d}, time::Float64) where {d} end end -"""Show a Mondrian tree.""" +""" + Base.show(tree::MondrianTree{d}) where {d} + +Show the recursive structure of a Mondrian tree. +""" function Base.show(tree::MondrianTree{d}) where {d} depth = length(tree.id) if depth >= 1