Skip to content

Commit

Permalink
updating replication diagrams
Browse files Browse the repository at this point in the history
  • Loading branch information
WGUNDERWOOD committed Oct 31, 2023
1 parent f25d790 commit e507a68
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 111 deletions.
1 change: 0 additions & 1 deletion docs/src/documentation.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,5 +37,4 @@ Private = false
```@autodocs
Modules = [MondrianForests]
Pages = ["data.jl"]
Private = false
```
147 changes: 52 additions & 95 deletions replication/construction_diagrams/construction_diagrams.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,56 +8,32 @@ rcParams["text.usetex"] = true
rcParams["text.latex.preamble"] = "\\usepackage[sfdefault,light]{FiraSans}"
plt.ioff()

function get_splits(tree::MondrianTree{d}) where {d}
if !isnothing(tree.split_axis)
lower = tree.tree_right.cell.lower
upper = tree.tree_left.cell.upper
return [(lower, upper); get_splits(tree.tree_left); get_splits(tree.tree_right)]
else
return Tuple{NTuple{d,Float64},NTuple{d,Float64}}[]
end
end

function get_cells(tree::MondrianTree{d}) where {d}
if !isnothing(tree.split_axis)
return [get_cells(tree.tree_left); get_cells(tree.tree_right)]
else
return [tree.cell]
end
end

function get_ids(tree::MondrianTree{d}) where {d}
if !isnothing(tree.split_axis)
return [get_ids(tree.tree_left); get_ids(tree.tree_right)]
else
return [tree.cell.id]
end
end

function plot_mondrian_process(partition)
tree = partition["tree"]
splits = get_splits(tree)
split_trees = [t for t in get_subtrees(tree) if t.is_split]
splits = [(t.tree_right.lower, t.tree_left.upper) for t in split_trees]
(fig, ax) = plt.subplots(figsize=(2.2, 2.2))

# highlight current cell
if !isnothing(partition["current"])
cell = partition["current"]
cell = [t for t in get_subtrees(tree) if t.id == partition["current"]][]
x1s = [cell.lower[1], cell.lower[1], cell.upper[1], cell.upper[1]]
x2s = [cell.lower[2], cell.upper[2], cell.upper[2], cell.lower[2]]
fill(x1s, x2s, facecolor="#ecd9ff")
end

# highlight leaves
for cell in partition["terminals"]
for id in partition["terminals"]
cell = [t for t in get_subtrees(tree) if t.id == id][]
x1s = [cell.lower[1], cell.lower[1], cell.upper[1], cell.upper[1]]
x2s = [cell.lower[2], cell.upper[2], cell.upper[2], cell.lower[2]]
fill(x1s, x2s, facecolor="#b5fdc7")
end

# plot root cell
lw = 0.9
(l1, l2) = tree.cell.lower
(u1, u2) = tree.cell.upper
(l1, l2) = tree.lower
(u1, u2) = tree.upper
plt.plot([l1, l1], [l2, u2], color="k", lw=lw)
plt.plot([u1, u1], [l2, u2], color="k", lw=lw)
plt.plot([l1, u1], [l2, l2], color="k", lw=lw)
Expand All @@ -71,8 +47,8 @@ function plot_mondrian_process(partition)
end

# annotate cells
ids = get_ids(tree)
cells = get_cells(tree)
cells = get_leaves(tree)
ids = [t.id for t in cells]
centers = MondrianForests.get_center.(cells)
for i in 1:length(cells)
if ids[i] == ""
Expand All @@ -88,8 +64,8 @@ function plot_mondrian_process(partition)
subtrees = MondrianForests.get_subtrees(tree)
current = partition["current"]
if !isnothing(current)
if !(current in cells)
subtree = [c for c in subtrees if c.cell == current][]
if !(current in [t.id for t in cells])
subtree = [c for c in subtrees if c.id == current][]
J = subtree.split_axis
S = subtree.split_location
J == 1 ? x = S - 0.005 : x = -0.06
Expand Down Expand Up @@ -129,15 +105,6 @@ function plot_mondrian_process(partition)
return (fig, ax)
end

function get_tree_info(tree::MondrianTree)
info = (tree.cell.id, tree.creation_time, tree.cell, tree.cell.id, tree.split_axis)
if !isnothing(tree.split_axis)
return [info; get_tree_info(tree.tree_left); get_tree_info(tree.tree_right)]
else
return [info]
end
end

function get_horizontal_value(id::String)
value = 0.0
for i in 1:length(id)
Expand All @@ -152,19 +119,18 @@ function get_horizontal_value(id::String)
end

function plot_mondrian_tree(partition)
tree = partition["tree"]
(fig, ax) = plt.subplots(figsize=(2.2, 2.2))
info = get_tree_info(tree)
ids = [i[1] for i in info]
times = [i[2] for i in info]
cells = [i[3] for i in info]
tree = partition["tree"]
cells = get_subtrees(tree)
ids = [t.id for t in cells]
times = [t.creation_time for t in cells]
n = length(times)

# plot split points
for i in 1:n
if cells[i] == partition["current"]
if ids[i] == partition["current"]
color = "#ecd9ff"
elseif cells[i] in partition["terminals"]
elseif ids[i] in partition["terminals"]
color = "#b5fdc7"
else
color = "white"
Expand Down Expand Up @@ -198,15 +164,14 @@ function plot_mondrian_tree(partition)
end

# add split time
cells = get_cells(tree)
subtrees = MondrianForests.get_subtrees(tree)
leaves = [t.id for t in get_leaves(partition["tree"])]
current = partition["current"]
if !isnothing(current)
if !(current in cells)
subtree = [c for c in subtrees if c.cell == current][]
if !(current in leaves)
subtree = [c for c in cells if c.id == current][]
t = subtree.tree_left.creation_time
x_left = x_locs[subtree.tree_left.cell.id]
x_right = x_locs[subtree.tree_right.cell.id]
x_left = x_locs[subtree.tree_left.id]
x_right = x_locs[subtree.tree_right.id]
if x_left > 1
plt.text(x_left - 0.8, t, "\$t + E\$", fontsize=10,
ha="center", va="center")
Expand All @@ -222,7 +187,7 @@ function plot_mondrian_tree(partition)
end

# time label
plt.text(2.84, 2.47, "\$t\$", fontsize=10)
#plt.text(2.84, 2.47, "\$t\$", fontsize=10)

# format
ax.invert_yaxis()
Expand All @@ -242,50 +207,46 @@ function plot_mondrian_tree(partition)
end

function update_partitions(partitions, tree)
subtrees = get_subtrees(tree)
times = sort(unique([t.creation_time for t in subtrees]))
p = partitions[end]
info = get_tree_info(p["tree"])
cells = get_cells(tree)
leaves = get_cells(p["tree"])
cells = [t.id for t in get_leaves(tree)]
leaves = [t.id for t in get_leaves(p["tree"])]
current_split = !isnothing(p["current"]) && !(p["current"] in cells)
current_parent = !(p["current"] in leaves)
current_parent = !isnothing(p["current"]) && !(p["current"] in leaves)

# update time and tree
if current_split && !current_parent
new_time = minimum(t for t in times if t > p["time"])
new_tree = MondrianForests.restrict(tree, new_time)
new_tree = restrict(tree, new_time)
else
new_time = p["time"]
new_tree = p["tree"]
end

# update terminals
if current_split || isnothing(p["current"])
new_terminals = p["terminals"]
else
new_terminals = [p["terminals"]; [p["current"]]]
end

# update current
if isnothing(p["current"]) || current_parent || !current_split
ids = [i[4]
for i in info
if i[3] in leaves &&
!(i[3] in p["terminals"]) && !(i[3] == p["current"])]
ids = [i for i in leaves if !(i in new_terminals)]
ids = [i for i in ids if !(i in p["terminals"])]
ids = [i for i in ids if length(i) == minimum(length(j) for j in ids)]
if !isempty(ids)
new_current_id = minimum(ids)
new_current = [i[3] for i in info if i[4] == new_current_id][]
new_current = minimum(ids)
else
new_current = nothing
end
else
new_current = p["current"]
end

# update terminals
if current_split || isnothing(p["current"])
new_terminals = p["terminals"]
else
new_terminals = [p["terminals"]; [p["current"]]]
end

new_partition = Dict("time" => new_time,
"tree" => new_tree,
"current" => new_current,
"terminals" => new_terminals)
new_partition = Dict("time" => new_time, "tree" => new_tree,
"current" => new_current, "terminals" => new_terminals)

return [partitions; [new_partition]]
end
Expand All @@ -295,17 +256,17 @@ d = 2
lambda = 2.0
Random.seed!(0)
min_vol = 0.0
n_cells = 1
while min_vol < 0.2 || n_cells != 4
n_leaves = 1
while min_vol < 0.2 || n_leaves != 4
global tree = MondrianTree(d, lambda)
global cells = get_cells(tree)
global min_vol = minimum(MondrianForests.get_volume(c) for c in cells)
global n_cells = length(cells)
global leaves = get_leaves(tree)
global min_vol = minimum(MondrianForests.get_volume(l) for l in leaves)
global n_leaves = length(leaves)
end

# get locations of tree nodes for diagram
info = get_tree_info(tree)
ids = [i[1] for i in info]
subtrees = get_subtrees(tree)
ids = unique([t.id for t in subtrees])
xs = get_horizontal_value.(ids)
xs = invperm(sortperm(xs)) / 3
x_locs = Dict()
Expand All @@ -314,13 +275,9 @@ for i in 1:length(ids)
end

# calculate the current and terminal nodes
partitions = [Dict("time" => 0.0,
"tree" => MondrianForests.restrict(tree, 0.0),
"current" => nothing,
"terminals" => [])]
partitions = [Dict("time" => 0.0, "tree" => restrict(tree, 0.0),
"current" => nothing, "terminals" => [])]

info = get_tree_info(tree)
times = [i[2] for i in info]
for rep in 1:11
global partitions = update_partitions(partitions, tree)
end
Expand All @@ -329,7 +286,6 @@ end
println("plotting trees")
dpi = 500
for i in 1:length(partitions)
println(i)
partition = partitions[i]
global (fig, ax) = plot_mondrian_tree(partition)
plt.savefig("replication/construction_diagrams/construction_mondrian_tree_$(i).png", dpi=dpi)
Expand All @@ -342,6 +298,7 @@ for i in 1:length(partitions)
println(i)
partition = partitions[i]
global (fig, ax) = plot_mondrian_process(partition)
plt.savefig("replication/construction_diagrams/construction_mondrian_partition_$(i).png", dpi=dpi)
plt.savefig("replication/construction_diagrams/construction_mondrian_partition_$(i).png",
dpi=dpi)
plt.close("all")
end
18 changes: 9 additions & 9 deletions replication/logo/logo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,18 @@ lw = 12
x_eps = 0.037
y_eps = 0.020
top = 0.59
xs_outer = [0.5, 1+x_eps, 1+x_eps, -x_eps, -x_eps, 0.5]
ys_outer = [-y_eps, -y_eps, top+y_eps, top+y_eps, -y_eps, -y_eps]
xs_outer = [0.5, 1 + x_eps, 1 + x_eps, -x_eps, -x_eps, 0.5]
ys_outer = [-y_eps, -y_eps, top + y_eps, top + y_eps, -y_eps, -y_eps]
outer_col = col_gray
plt.plot(xs_outer, ys_outer, c=outer_col, lw=2*lw, zorder=0)
plt.plot(xs_outer, ys_outer, c=outer_col, lw=2 * lw, zorder=0)
plt.fill(xs_outer, ys_outer, c=outer_col, lw=0)

# curve
n = 100
xs = range(0, 1, length=n)
ys = (1.2 .* xs .- 0.63).^3 .+ 0.60 - 0.2 * xs.^2 -
0.05 * (0.5 .* xs .+ 0.5).^10 +
0.07 * (1 .- xs).^10
ys = (1.2 .* xs .- 0.63) .^ 3 .+ 0.60 - 0.2 * xs .^ 2 -
0.05 * (0.5 .* xs .+ 0.5) .^ 10 +
0.07 * (1 .- xs) .^ 10
all_xs = [[0.5, 0]; xs; [1, 0.5]]
all_ys = [[0, 0]; ys; [0, 0]]
plt.plot(all_xs, all_ys, c="#111111", lw=lw)
Expand All @@ -45,9 +45,9 @@ x2 = 0.80
y1 = 0.11
y2 = 0.25
plt.plot([x1, x1], [eps, 0.55], c="#111111", lw=lw)
plt.plot([eps, 1-eps], [y2, y2], c="#111111", lw=lw)
plt.plot([x2, x2], [eps, y2-eps], c="#111111", lw=lw)
plt.plot([x2+eps, 1-eps], [y1, y1], c="#111111", lw=lw)
plt.plot([eps, 1 - eps], [y2, y2], c="#111111", lw=lw)
plt.plot([x2, x2], [eps, y2 - eps], c="#111111", lw=lw)
plt.plot([x2 + eps, 1 - eps], [y1, y1], c="#111111", lw=lw)

# piet blocks
red_xs = [xs[i] for i in 1:n if xs[i] >= x1]
Expand Down
13 changes: 7 additions & 6 deletions src/tree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ function apply_split(tree::MondrianTree{d}, split_tree::MondrianTree{d}) where {
tree_left, tree_right)
else
if all(tree.lower .< split_tree.tree_left.upper) &&
all(split_tree.tree_right.lower .< tree.upper)
all(split_tree.tree_right.lower .< tree.upper)
left_upper = min.(split_tree.tree_left.upper, tree.upper)
right_lower = max.(split_tree.tree_right.lower, tree.lower)
tree_left = MondrianTree(tree.id * "L", tree.lambda, tree.lower, left_upper,
Expand Down Expand Up @@ -241,7 +241,8 @@ are_in_same_leaf(x1, x2, tree)
true
```
"""
function are_in_same_leaf(x1::NTuple{d,Float64}, x2::NTuple{d,Float64}, tree::MondrianTree) where {d}
function are_in_same_leaf(x1::NTuple{d,Float64}, x2::NTuple{d,Float64},
tree::MondrianTree) where {d}
if tree.is_split
if is_in(x1, tree.tree_left) && is_in(x2, tree.tree_right)
return are_in_same_leaf(x1, x2, tree.left)
Expand Down Expand Up @@ -338,11 +339,11 @@ function restrict(tree::MondrianTree{d}, time::Float64) where {d}
if tree.is_split && tree.tree_left.creation_time <= time
tree_left = restrict(tree.tree_left, time)
tree_right = restrict(tree.tree_right, time)
return MondrianTree{d}(tree.lambda, tree.lower, tree.upper, tree.creation_time, true,
tree.split_axis, tree.split_location, tree_left, tree_right)
return MondrianTree{d}(tree.id, tree.lambda, tree.lower, tree.upper, tree.creation_time,
true, tree.split_axis, tree.split_location, tree_left, tree_right)
else
return MondrianTree{d}(tree.lambda, tree.lower, tree.upper, tree.creation_time, false,
nothing, nothing, nothing, nothing)
return MondrianTree{d}(tree.id, tree.lambda, tree.lower, tree.upper, tree.creation_time,
false, nothing, nothing, nothing, nothing)
end
end

Expand Down

0 comments on commit e507a68

Please sign in to comment.