Skip to content

Commit

Permalink
update plots
Browse files Browse the repository at this point in the history
  • Loading branch information
WGUNDERWOOD committed Mar 20, 2024
1 parent 69dce93 commit d0ea4aa
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 22 deletions.
16 changes: 8 additions & 8 deletions replication/partition_plots/partition_plots.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ end
function plot_mondrian_tree(tree::MondrianTree)
@assert isa(tree, MondrianTree{2})
splits = get_splits(tree)
(fig, ax) = plt.subplots(figsize=(2.1, 2.1))
(fig, ax) = plt.subplots(figsize=(3, 3))

# plot root cell
lw = 0.3
lw = 0.5
(l1, l2) = tree.lower
(u1, u2) = tree.upper
plot([l1, l1], [l1, u2], color="k", lw=lw)
Expand All @@ -38,10 +38,10 @@ function plot_mondrian_tree(tree::MondrianTree)
end

# format plot
plt.xticks([0, 1])
plt.yticks([0, 1])
plt.xlabel("\$x_1\$")
plt.ylabel("\$x_2\$")
plt.xticks([0, 1], fontsize=11)
plt.yticks([0, 1], fontsize=11)
plt.xlabel("\$x_1\$", fontsize=12)
plt.ylabel("\$x_2\$", fontsize=12)
ax.xaxis.set_label_coords(0.5, -0.04)
ax.yaxis.set_label_coords(-0.04, 0.5)
ax.tick_params(color="w", direction="in", pad=0)
Expand All @@ -61,7 +61,7 @@ for i in 1:length(lambdas)
global lambda = lambdas[i]
global tree = MondrianTree(d, lambda)
global (fig, ax) = plot_mondrian_tree(tree)
savefig("./replication/partition_plots/plot_mondrian_process_$i.pgf", bbox_inches="tight")
savefig("./replication/partition_plots/plot_mondrian_process_$i.pdf", bbox_inches="tight")
plt.tight_layout()
savefig("./replication/partition_plots/plot_mondrian_process_$i.pdf")
plt.close("all")
end
18 changes: 11 additions & 7 deletions replication/weather/weather.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,16 +48,20 @@ function format_plot(ax)
yticks = [990 + i * 10 for i in 0:5]
xticklabels = "\$" .* string.(xticks) .* "\$"
yticklabels = "\$" .* string.(yticks) .* "\$"
plt.xticks((xticks .- x_min) ./ (x_max - x_min), labels=xticklabels)
plt.yticks((yticks .- y_min) ./ (y_max - y_min), labels=yticklabels)
plt.xlabel("Relative humidity at 3pm (\\%)")
plt.ylabel("Pressure at 3pm (mbar)")
plt.xticks((xticks .- x_min) ./ (x_max - x_min), labels=xticklabels, fontsize=11)
plt.yticks((yticks .- y_min) ./ (y_max - y_min), labels=yticklabels, fontsize=11)
plt.xlabel("Relative humidity at 3pm (\\%)", fontsize=12)
plt.ylabel("Pressure at 3pm (mbar)", fontsize=12)
# color key
#handle = plt.scatter([], [], c="white")
dry_handle = plt.scatter([], [], c=dry_color)
wet_handle = plt.scatter([], [], c=wet_color)
ax.legend([dry_handle, wet_handle], ["Dry tomorrow", "Wet tomorrow"],
ax.legend([wet_handle, dry_handle],
["Rain next day", "No rain next day"],
handletextpad=0.1, frameon=false,
bbox_to_anchor=(1.01, 1.16), ncol=2)
bbox_to_anchor=(1.04, 1.13), ncol=3,
fontsize=12,
columnspacing=0.8)
# layout
plt.xlim([-0.05, 1.05])
plt.ylim([-0.05, 1.05])
Expand Down Expand Up @@ -221,7 +225,7 @@ limit = nothing
(data, x_min, x_max, y_min, y_max) = load_data(limit=limit)
dry_color = "#da6200"
wet_color = "#0080d0"
figsize = (3.5, 3.7)
figsize = (4, 4.2)
dpi = 500

# make trees
Expand Down
14 changes: 7 additions & 7 deletions replication/weather/weather_cv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ for lambda in lambdas
println("mse: ", mse)
end

(fig, ax) = plt.subplots(figsize=(3.5, 3.7))
(fig, ax) = plt.subplots(figsize=(4, 4.4))
best_lambda = 5.0
i = [i for i in 1:length(lambdas) if isapprox(lambdas[i], best_lambda, rtol=0.01)][]
plt.plot([best_lambda, best_lambda], [0.0, gcvs[i] - 0.0001], c="#666677",
Expand All @@ -110,15 +110,15 @@ plt.plot(lambdas, mses, lw=1.0, c="#aa44dd",
plt.plot(lambdas, gcvs, lw=1.0, c="#009944",
label="Generalized cross-validation")
plt.ylim([0.11 - 0.002, 0.17 + 0.002])
plt.yticks(range(0.11, stop=0.17, step=0.01))
plt.xlabel("Lifetime parameter \$\\lambda\$")
plt.ylabel("Loss function")
plt.legend(frameon=false)
plt.subplots_adjust(left=0.205, right=0.96, top=0.854, bottom=0.165)
plt.yticks(range(0.11, stop=0.17, step=0.01), fontsize=11)
plt.xticks(fontsize=11)
plt.xlabel("Lifetime parameter \$\\lambda\$", fontsize=12)
plt.ylabel("Loss function", fontsize=12)
plt.legend(fontsize=12)
plt.subplots_adjust(left=0.205, right=0.96, top=0.842, bottom=0.140)
plt.savefig("./replication/weather/weather_gcv.png", dpi=500)

# CIs
#limit = 1000
limit = nothing
(data, x_min, x_max, y_min, y_max) = load_data(limit=limit)
n = nrow(data)
Expand Down

0 comments on commit d0ea4aa

Please sign in to comment.