From 8f0e419d5c25432137f4aff57cd20d77b59a2f7b Mon Sep 17 00:00:00 2001 From: William G Underwood <42812654+WGUNDERWOOD@users.noreply.github.com> Date: Fri, 20 Oct 2023 00:37:55 +0100 Subject: [PATCH] diagram working again --- .../construction_diagrams.jl | 185 +++++++++++------- 1 file changed, 109 insertions(+), 76 deletions(-) diff --git a/replication/construction_diagrams/construction_diagrams.jl b/replication/construction_diagrams/construction_diagrams.jl index 5b5d1e2..bf234e8 100644 --- a/replication/construction_diagrams/construction_diagrams.jl +++ b/replication/construction_diagrams/construction_diagrams.jl @@ -8,22 +8,47 @@ rcParams["text.usetex"] = true rcParams["text.latex.preamble"] = "\\usepackage[sfdefault,light]{FiraSans}" plt.ioff() -function plot_mondrian_process(state, tree) - t = state["tree"] - splits = get_splits(t) +function get_splits(tree::MondrianTree{d}) where {d} + if !isnothing(tree.split_axis) + lower = tree.tree_right.cell.lower + upper = tree.tree_left.cell.upper + return [(lower, upper); get_splits(tree.tree_left); get_splits(tree.tree_right)] + else + return Tuple{NTuple{d,Float64},NTuple{d,Float64}}[] + end +end + +function get_cells(tree::MondrianTree{d}) where {d} + if !isnothing(tree.split_axis) + return [get_cells(tree.tree_left); get_cells(tree.tree_right)] + else + return [tree.cell] + end +end + +function get_ids(tree::MondrianTree{d}) where {d} + if !isnothing(tree.split_axis) + return [get_ids(tree.tree_left); get_ids(tree.tree_right)] + else + return [tree.cell.id] + end +end + +function plot_mondrian_process(partition) + tree = partition["tree"] + splits = get_splits(tree) (fig, ax) = plt.subplots(figsize=(2.2, 2.2)) # highlight current cell - if !isnothing(state["current"]) - cell = state["current"] + if !isnothing(partition["current"]) + cell = partition["current"] x1s = [cell.lower[1], cell.lower[1], cell.upper[1], cell.upper[1]] x2s = [cell.lower[2], cell.upper[2], cell.upper[2], cell.lower[2]] fill(x1s, x2s, facecolor="#ecd9ff") end - #= # highlight leaves - for cell in states["leaves"] + for cell in partition["terminals"] x1s = [cell.lower[1], cell.lower[1], cell.upper[1], cell.upper[1]] x2s = [cell.lower[2], cell.upper[2], cell.upper[2], cell.lower[2]] fill(x1s, x2s, facecolor="#b5fdc7") @@ -46,10 +71,10 @@ function plot_mondrian_process(state, tree) end # annotate cells - ids = MondrianForests.get_ids(tree) - leaves = MondrianForests.get_leaves(tree) - centers = MondrianForests.get_center.([t.cell for t in leaves]) - for i in 1:length(leaves) + ids = get_ids(tree) + cells = get_cells(tree) + centers = MondrianForests.get_center.(cells) + for i in 1:length(cells) if ids[i] == "" label = "\$C_{\\emptyset}\$" else @@ -61,9 +86,9 @@ function plot_mondrian_process(state, tree) # add split point subtrees = MondrianForests.get_subtrees(tree) - current = state["current"] + current = partition["current"] if !isnothing(current) - if !(current in leaves) + if !(current in cells) subtree = [c for c in subtrees if c.cell == current][] J = subtree.split_axis S = subtree.split_location @@ -90,21 +115,29 @@ function plot_mondrian_process(state, tree) # color key current_handle = plt.scatter([], [], c="#d9b3ff") terminal_handle = plt.scatter([], [], c="#6feb8e") - if length(leaves) >= 4 || length(state["leaves"]) >= 1 + if length(cells) >= 4 || length(partition["terminals"]) >= 1 ax.legend([current_handle, terminal_handle], ["Current", "Leaf"], ncol=2, handletextpad=0.1, frameon=false, columnspacing=0.8, bbox_to_anchor=(0.47, 1.19), loc="upper center") - elseif length(leaves) >= 2 || !isnothing(state["current"]) + elseif length(cells) >= 2 || !isnothing(partition["current"]) ax.legend([current_handle], ["Current"], ncol=2, handletextpad=0.1, frameon=false, columnspacing=0.8, bbox_to_anchor=(0.277, 1.19), loc="upper center") end - =# return (fig, ax) end +function get_tree_info(tree::MondrianTree) + info = (tree.cell.id, tree.creation_time, tree.cell, tree.cell.id, tree.split_axis) + if !isnothing(tree.split_axis) + return [info; get_tree_info(tree.tree_left); get_tree_info(tree.tree_right)] + else + return [info] + end +end + function get_horizontal_value(id::String) value = 0.0 for i in 1:length(id) @@ -118,20 +151,20 @@ function get_horizontal_value(id::String) return value end -function plot_mondrian_tree(state, tree) +function plot_mondrian_tree(partition) + tree = partition["tree"] (fig, ax) = plt.subplots(figsize=(2.2, 2.2)) - state_tree = state["tree"] - leaves = get_leaves(state_tree) - ids = [t.cell.id for t in leaves] - times = [t.creation_time for t in leaves] + info = get_tree_info(tree) + ids = [i[1] for i in info] + times = [i[2] for i in info] + cells = [i[3] for i in info] n = length(times) - println(state["current"]) # plot split points for i in 1:n - if leaves[i] == state["current"] + if cells[i] == partition["current"] color = "#ecd9ff" - elseif leaves[i] in state["leaves"] + elseif cells[i] in partition["terminals"] color = "#b5fdc7" else color = "white" @@ -148,7 +181,6 @@ function plot_mondrian_tree(state, tree) plt.text(x_locs[ids[i]] + 0.005, times[i] + 0.01, label, ha="center", va="center", fontsize=8, zorder=30) end - #= # plot tree lw = 0.9 @@ -166,14 +198,15 @@ function plot_mondrian_tree(state, tree) end # add split time + cells = get_cells(tree) subtrees = MondrianForests.get_subtrees(tree) - current = state["current"] + current = partition["current"] if !isnothing(current) if !(current in cells) subtree = [c for c in subtrees if c.cell == current][] t = subtree.tree_left.creation_time - x_left = x_locs[subtree.tree_left.id] - x_right = x_locs[subtree.tree_right.id] + x_left = x_locs[subtree.tree_left.cell.id] + x_right = x_locs[subtree.tree_right.cell.id] if x_left > 1 plt.text(x_left - 0.8, t, "\$t + E\$", fontsize=10, ha="center", va="center") @@ -191,7 +224,6 @@ function plot_mondrian_tree(state, tree) # time label plt.text(2.84, 2.47, "\$t\$", fontsize=10) - =# # format ax.invert_yaxis() plt.yticks([0, 1, 2]) @@ -209,68 +241,71 @@ function plot_mondrian_tree(state, tree) return (fig, ax) end -function update_states(states, tree) - state = states[end] - tree_leaves = MondrianForests.get_leaves(tree) - state_leaves = MondrianForests.get_leaves(state["tree"]) - times = [t.creation_time for t in tree_leaves] - current_split = !isnothing(state["current"]) && !(state["current"] in tree_leaves) - current_parent = !isnothing(state["current"]) && !(state["current"] in state_leaves) +function update_partitions(partitions, tree) + p = partitions[end] + info = get_tree_info(p["tree"]) + cells = get_cells(tree) + leaves = get_cells(p["tree"]) + current_split = !isnothing(p["current"]) && !(p["current"] in cells) + current_parent = !(p["current"] in leaves) # update time and tree if current_split && !current_parent - new_time = minimum(t for t in times if t > state["time"]) + new_time = minimum(t for t in times if t > p["time"]) new_tree = MondrianForests.restrict(tree, new_time) else - new_time = state["time"] - new_tree = state["tree"] + new_time = p["time"] + new_tree = p["tree"] end # update current - if isnothing(state["current"]) || current_parent || !current_split - println([t.cell.id for t in state_leaves]) - ts = [t for t in state_leaves if !(t in state["leaves"]) && !(t == state["current"])] - ts = [t for t in ts if t.cell.id == minimum(tt.cell.id for tt in ts)] - if !isempty(ts) - new_current = ts[] + if isnothing(p["current"]) || current_parent || !current_split + ids = [i[4] + for i in info + if i[3] in leaves && + !(i[3] in p["terminals"]) && !(i[3] == p["current"])] + ids = [i for i in ids if length(i) == minimum(length(j) for j in ids)] + if !isempty(ids) + new_current_id = minimum(ids) + new_current = [i[3] for i in info if i[4] == new_current_id][] else new_current = nothing end else - new_current = state["current"] + new_current = p["current"] end - # update leaves - if current_split || isnothing(state["current"]) - new_leaves = state["leaves"] + # update terminals + if current_split || isnothing(p["current"]) + new_terminals = p["terminals"] else - new_leaves = [state["leaves"]; [state["current"]]] + new_terminals = [p["terminals"]; [p["current"]]] end - new_state = Dict("time" => new_time, + new_partition = Dict("time" => new_time, "tree" => new_tree, "current" => new_current, - "leaves" => new_leaves) + "terminals" => new_terminals) - return [states; [new_state]] + return [partitions; [new_partition]] end # construct a good tree d = 2 -lambda = 1.5 +lambda = 2.0 Random.seed!(0) min_vol = 0.0 -n_leaves = 1 -while min_vol < 0.2 || n_leaves != 4 +n_cells = 1 +while min_vol < 0.2 || n_cells != 4 global tree = MondrianTree(d, lambda) - global leaves = MondrianForests.get_leaves(tree) - global min_vol = minimum(MondrianForests.get_volume(t.cell) for t in leaves) - global n_leaves = length(leaves) + global cells = get_cells(tree) + global min_vol = minimum(MondrianForests.get_volume(c) for c in cells) + global n_cells = length(cells) end # get locations of tree nodes for diagram -subtrees = get_subtrees(tree) -ids = [t.cell.id for t in subtrees] +info = get_tree_info(tree) +ids = [i[1] for i in info] xs = get_horizontal_value.(ids) xs = invperm(sortperm(xs)) / 3 x_locs = Dict() @@ -279,36 +314,34 @@ for i in 1:length(ids) end # calculate the current and terminal nodes -states = [Dict("time" => 0.0, - "tree" => MondrianForests.restrict(tree, 0.0), - "current" => nothing, - "leaves" => [])] +partitions = [Dict("time" => 0.0, + "tree" => MondrianForests.restrict(tree, 0.0), + "current" => nothing, + "terminals" => [])] +info = get_tree_info(tree) +times = [i[2] for i in info] for rep in 1:11 - global states = update_states(states, tree) + global partitions = update_partitions(partitions, tree) end -#[display(s) for s in states] - # plot the tree structures println("plotting trees") dpi = 500 -for i in 1:length(states) +for i in 1:length(partitions) println(i) - state = states[i] - global (fig, ax) = plot_mondrian_tree(state, tree) + partition = partitions[i] + global (fig, ax) = plot_mondrian_tree(partition) plt.savefig("replication/construction_diagrams/construction_mondrian_tree_$(i).png", dpi=dpi) plt.close("all") end -#= # plot the generation of the partition println("plotting partitions") -for i in 1:length(states) +for i in 1:length(partitions) println(i) - state = states[i] - global (fig, ax) = plot_mondrian_process(state) + partition = partitions[i] + global (fig, ax) = plot_mondrian_process(partition) plt.savefig("replication/construction_diagrams/construction_mondrian_partition_$(i).png", dpi=dpi) plt.close("all") end -=#