Skip to content

Commit

Permalink
diagram working again
Browse files Browse the repository at this point in the history
  • Loading branch information
WGUNDERWOOD committed Oct 19, 2023
1 parent 4e2f2d0 commit 8f0e419
Showing 1 changed file with 109 additions and 76 deletions.
185 changes: 109 additions & 76 deletions replication/construction_diagrams/construction_diagrams.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,47 @@ rcParams["text.usetex"] = true
rcParams["text.latex.preamble"] = "\\usepackage[sfdefault,light]{FiraSans}"
plt.ioff()

function plot_mondrian_process(state, tree)
t = state["tree"]
splits = get_splits(t)
function get_splits(tree::MondrianTree{d}) where {d}
if !isnothing(tree.split_axis)
lower = tree.tree_right.cell.lower
upper = tree.tree_left.cell.upper
return [(lower, upper); get_splits(tree.tree_left); get_splits(tree.tree_right)]
else
return Tuple{NTuple{d,Float64},NTuple{d,Float64}}[]
end
end

function get_cells(tree::MondrianTree{d}) where {d}
if !isnothing(tree.split_axis)
return [get_cells(tree.tree_left); get_cells(tree.tree_right)]
else
return [tree.cell]
end
end

function get_ids(tree::MondrianTree{d}) where {d}
if !isnothing(tree.split_axis)
return [get_ids(tree.tree_left); get_ids(tree.tree_right)]
else
return [tree.cell.id]
end
end

function plot_mondrian_process(partition)
tree = partition["tree"]
splits = get_splits(tree)
(fig, ax) = plt.subplots(figsize=(2.2, 2.2))

# highlight current cell
if !isnothing(state["current"])
cell = state["current"]
if !isnothing(partition["current"])
cell = partition["current"]
x1s = [cell.lower[1], cell.lower[1], cell.upper[1], cell.upper[1]]
x2s = [cell.lower[2], cell.upper[2], cell.upper[2], cell.lower[2]]
fill(x1s, x2s, facecolor="#ecd9ff")
end

#=
# highlight leaves
for cell in states["leaves"]
for cell in partition["terminals"]
x1s = [cell.lower[1], cell.lower[1], cell.upper[1], cell.upper[1]]
x2s = [cell.lower[2], cell.upper[2], cell.upper[2], cell.lower[2]]
fill(x1s, x2s, facecolor="#b5fdc7")
Expand All @@ -46,10 +71,10 @@ function plot_mondrian_process(state, tree)
end

# annotate cells
ids = MondrianForests.get_ids(tree)
leaves = MondrianForests.get_leaves(tree)
centers = MondrianForests.get_center.([t.cell for t in leaves])
for i in 1:length(leaves)
ids = get_ids(tree)
cells = get_cells(tree)
centers = MondrianForests.get_center.(cells)
for i in 1:length(cells)
if ids[i] == ""
label = "\$C_{\\emptyset}\$"
else
Expand All @@ -61,9 +86,9 @@ function plot_mondrian_process(state, tree)

# add split point
subtrees = MondrianForests.get_subtrees(tree)
current = state["current"]
current = partition["current"]
if !isnothing(current)
if !(current in leaves)
if !(current in cells)
subtree = [c for c in subtrees if c.cell == current][]
J = subtree.split_axis
S = subtree.split_location
Expand All @@ -90,21 +115,29 @@ function plot_mondrian_process(state, tree)
# color key
current_handle = plt.scatter([], [], c="#d9b3ff")
terminal_handle = plt.scatter([], [], c="#6feb8e")
if length(leaves) >= 4 || length(state["leaves"]) >= 1
if length(cells) >= 4 || length(partition["terminals"]) >= 1
ax.legend([current_handle, terminal_handle],
["Current", "Leaf"], ncol=2,
handletextpad=0.1, frameon=false, columnspacing=0.8,
bbox_to_anchor=(0.47, 1.19), loc="upper center")
elseif length(leaves) >= 2 || !isnothing(state["current"])
elseif length(cells) >= 2 || !isnothing(partition["current"])
ax.legend([current_handle], ["Current"], ncol=2,
handletextpad=0.1, frameon=false, columnspacing=0.8,
bbox_to_anchor=(0.277, 1.19), loc="upper center")
end
=#

return (fig, ax)
end

function get_tree_info(tree::MondrianTree)
info = (tree.cell.id, tree.creation_time, tree.cell, tree.cell.id, tree.split_axis)
if !isnothing(tree.split_axis)
return [info; get_tree_info(tree.tree_left); get_tree_info(tree.tree_right)]
else
return [info]
end
end

function get_horizontal_value(id::String)
value = 0.0
for i in 1:length(id)
Expand All @@ -118,20 +151,20 @@ function get_horizontal_value(id::String)
return value
end

function plot_mondrian_tree(state, tree)
function plot_mondrian_tree(partition)
tree = partition["tree"]
(fig, ax) = plt.subplots(figsize=(2.2, 2.2))
state_tree = state["tree"]
leaves = get_leaves(state_tree)
ids = [t.cell.id for t in leaves]
times = [t.creation_time for t in leaves]
info = get_tree_info(tree)
ids = [i[1] for i in info]
times = [i[2] for i in info]
cells = [i[3] for i in info]
n = length(times)
println(state["current"])

# plot split points
for i in 1:n
if leaves[i] == state["current"]
if cells[i] == partition["current"]
color = "#ecd9ff"
elseif leaves[i] in state["leaves"]
elseif cells[i] in partition["terminals"]
color = "#b5fdc7"
else
color = "white"
Expand All @@ -148,7 +181,6 @@ function plot_mondrian_tree(state, tree)
plt.text(x_locs[ids[i]] + 0.005, times[i] + 0.01, label, ha="center",
va="center", fontsize=8, zorder=30)
end
#=

# plot tree
lw = 0.9
Expand All @@ -166,14 +198,15 @@ function plot_mondrian_tree(state, tree)
end

# add split time
cells = get_cells(tree)
subtrees = MondrianForests.get_subtrees(tree)
current = state["current"]
current = partition["current"]
if !isnothing(current)
if !(current in cells)
subtree = [c for c in subtrees if c.cell == current][]
t = subtree.tree_left.creation_time
x_left = x_locs[subtree.tree_left.id]
x_right = x_locs[subtree.tree_right.id]
x_left = x_locs[subtree.tree_left.cell.id]
x_right = x_locs[subtree.tree_right.cell.id]
if x_left > 1
plt.text(x_left - 0.8, t, "\$t + E\$", fontsize=10,
ha="center", va="center")
Expand All @@ -191,7 +224,6 @@ function plot_mondrian_tree(state, tree)
# time label
plt.text(2.84, 2.47, "\$t\$", fontsize=10)

=#
# format
ax.invert_yaxis()
plt.yticks([0, 1, 2])
Expand All @@ -209,68 +241,71 @@ function plot_mondrian_tree(state, tree)
return (fig, ax)
end

function update_states(states, tree)
state = states[end]
tree_leaves = MondrianForests.get_leaves(tree)
state_leaves = MondrianForests.get_leaves(state["tree"])
times = [t.creation_time for t in tree_leaves]
current_split = !isnothing(state["current"]) && !(state["current"] in tree_leaves)
current_parent = !isnothing(state["current"]) && !(state["current"] in state_leaves)
function update_partitions(partitions, tree)
p = partitions[end]
info = get_tree_info(p["tree"])
cells = get_cells(tree)
leaves = get_cells(p["tree"])
current_split = !isnothing(p["current"]) && !(p["current"] in cells)
current_parent = !(p["current"] in leaves)

# update time and tree
if current_split && !current_parent
new_time = minimum(t for t in times if t > state["time"])
new_time = minimum(t for t in times if t > p["time"])
new_tree = MondrianForests.restrict(tree, new_time)
else
new_time = state["time"]
new_tree = state["tree"]
new_time = p["time"]
new_tree = p["tree"]
end

# update current
if isnothing(state["current"]) || current_parent || !current_split
println([t.cell.id for t in state_leaves])
ts = [t for t in state_leaves if !(t in state["leaves"]) && !(t == state["current"])]
ts = [t for t in ts if t.cell.id == minimum(tt.cell.id for tt in ts)]
if !isempty(ts)
new_current = ts[]
if isnothing(p["current"]) || current_parent || !current_split
ids = [i[4]
for i in info
if i[3] in leaves &&
!(i[3] in p["terminals"]) && !(i[3] == p["current"])]
ids = [i for i in ids if length(i) == minimum(length(j) for j in ids)]
if !isempty(ids)
new_current_id = minimum(ids)
new_current = [i[3] for i in info if i[4] == new_current_id][]
else
new_current = nothing
end
else
new_current = state["current"]
new_current = p["current"]
end

# update leaves
if current_split || isnothing(state["current"])
new_leaves = state["leaves"]
# update terminals
if current_split || isnothing(p["current"])
new_terminals = p["terminals"]
else
new_leaves = [state["leaves"]; [state["current"]]]
new_terminals = [p["terminals"]; [p["current"]]]
end

new_state = Dict("time" => new_time,
new_partition = Dict("time" => new_time,
"tree" => new_tree,
"current" => new_current,
"leaves" => new_leaves)
"terminals" => new_terminals)

return [states; [new_state]]
return [partitions; [new_partition]]
end

# construct a good tree
d = 2
lambda = 1.5
lambda = 2.0
Random.seed!(0)
min_vol = 0.0
n_leaves = 1
while min_vol < 0.2 || n_leaves != 4
n_cells = 1
while min_vol < 0.2 || n_cells != 4
global tree = MondrianTree(d, lambda)
global leaves = MondrianForests.get_leaves(tree)
global min_vol = minimum(MondrianForests.get_volume(t.cell) for t in leaves)
global n_leaves = length(leaves)
global cells = get_cells(tree)
global min_vol = minimum(MondrianForests.get_volume(c) for c in cells)
global n_cells = length(cells)
end

# get locations of tree nodes for diagram
subtrees = get_subtrees(tree)
ids = [t.cell.id for t in subtrees]
info = get_tree_info(tree)
ids = [i[1] for i in info]
xs = get_horizontal_value.(ids)
xs = invperm(sortperm(xs)) / 3
x_locs = Dict()
Expand All @@ -279,36 +314,34 @@ for i in 1:length(ids)
end

# calculate the current and terminal nodes
states = [Dict("time" => 0.0,
"tree" => MondrianForests.restrict(tree, 0.0),
"current" => nothing,
"leaves" => [])]
partitions = [Dict("time" => 0.0,
"tree" => MondrianForests.restrict(tree, 0.0),
"current" => nothing,
"terminals" => [])]

info = get_tree_info(tree)
times = [i[2] for i in info]
for rep in 1:11
global states = update_states(states, tree)
global partitions = update_partitions(partitions, tree)
end

#[display(s) for s in states]

# plot the tree structures
println("plotting trees")
dpi = 500
for i in 1:length(states)
for i in 1:length(partitions)
println(i)
state = states[i]
global (fig, ax) = plot_mondrian_tree(state, tree)
partition = partitions[i]
global (fig, ax) = plot_mondrian_tree(partition)
plt.savefig("replication/construction_diagrams/construction_mondrian_tree_$(i).png", dpi=dpi)
plt.close("all")
end

#=
# plot the generation of the partition
println("plotting partitions")
for i in 1:length(states)
for i in 1:length(partitions)
println(i)
state = states[i]
global (fig, ax) = plot_mondrian_process(state)
partition = partitions[i]
global (fig, ax) = plot_mondrian_process(partition)
plt.savefig("replication/construction_diagrams/construction_mondrian_partition_$(i).png", dpi=dpi)
plt.close("all")
end
=#

0 comments on commit 8f0e419

Please sign in to comment.