Skip to content

Commit

Permalink
plots
Browse files Browse the repository at this point in the history
  • Loading branch information
WGUNDERWOOD committed Mar 22, 2024
1 parent 0d1bb0e commit 4f68595
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 17 deletions.
27 changes: 12 additions & 15 deletions replication/weather/weather.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,13 @@ using DataFrames
using PyPlot
using Random
using Colors
using Plots
using Dates
using MondrianForests

# plot setup
rcParams = PyPlot.PyDict(PyPlot.matplotlib."rcParams")
rcParams["text.usetex"] = true
rcParams["font.family"] = "serif"
#rcParams["text.latex.preamble"] = "\\usepackage[sfdefault,light]{FiraSans}"
plt.ioff()

function load_data(; limit=nothing)
Expand Down Expand Up @@ -172,9 +170,9 @@ function plot_debiased_forest(trees1, trees2, ax)
Y = [data[i, :RainTomorrow] for i in 1:nrow(data)]
n = length(X)
all_counts1 = [[sum(MondrianForests.is_in(X[i], cell) for i in 1:n)
for cell in all_cells1[j]] for j in 1:length(trees1)]
for cell in all_cells1[j]] for j in 1:length(trees1)]
all_counts2 = [[sum(MondrianForests.is_in(X[i], cell) for i in 1:n)
for cell in all_cells2[j]] for j in 1:length(trees2)]
for cell in all_cells2[j]] for j in 1:length(trees2)]
all_ones1 = [[sum(MondrianForests.is_in(X[i], cell) * Y[i] for i in 1:n)
for cell in all_cells1[j]] for j in 1:length(trees1)]
all_ones2 = [[sum(MondrianForests.is_in(X[i], cell) * Y[i] for i in 1:n)
Expand All @@ -188,18 +186,22 @@ function plot_debiased_forest(trees1, trees2, ax)
for c in 1:length(refined_cells)
r_cell = refined_cells[c]
x = MondrianForests.get_center(r_cell)
cell_ids1 = [[c for c in 1:length(all_cells1[j]) if
MondrianForests.is_in(x, all_cells1[j][c])][] for j in 1:length(trees1)]
cell_ids2 = [[c for c in 1:length(all_cells2[j]) if
MondrianForests.is_in(x, all_cells2[j][c])][] for j in 1:length(trees2)]
cell_ids1 = [[c
for c in 1:length(all_cells1[j])
if
MondrianForests.is_in(x, all_cells1[j][c])][] for j in 1:length(trees1)]
cell_ids2 = [[c
for c in 1:length(all_cells2[j])
if
MondrianForests.is_in(x, all_cells2[j][c])][] for j in 1:length(trees2)]
counts1 = [all_counts1[j][cell_ids1[j]] for j in 1:length(trees1)]
counts2 = [all_counts2[j][cell_ids2[j]] for j in 1:length(trees2)]
ones1 = [all_ones1[j][cell_ids1[j]] for j in 1:length(trees1)]
ones2 = [all_ones2[j][cell_ids2[j]] for j in 1:length(trees2)]
rs1 = ones1 ./ counts1
rs2 = ones2 ./ counts2
if !all(isnan.(rs1))
# TODO is this right
# TODO is this right, fix in other non-debiased version too
ratios1[c] = sum(r for r in rs1 if !isnan(r)) / sum(1 for r in rs1 if !isnan(r))
else
ratios1[c] = NaN
Expand Down Expand Up @@ -293,7 +295,7 @@ function make_forest_design_plot(data, trees, x_min, x_max,
end

function make_debiased_forest_design_plot(data, trees1, trees2, x_min, x_max,
y_min, y_max, design_points, filename)
y_min, y_max, design_points, filename)
# TODO
(fig, ax) = plt.subplots(figsize=figsize)
plot_debiased_forest(trees1, trees2, ax)
Expand All @@ -309,7 +311,6 @@ function make_debiased_forest_design_plot(data, trees1, trees2, x_min, x_max,
end

# get data and plot params
#limit = 1000
limit = nothing
(data, x_min, x_max, y_min, y_max) = load_data(limit=limit)
dry_color = "#da6200"
Expand Down Expand Up @@ -339,7 +340,6 @@ for i in 1:length(seeds)
push!(trees, tree)
end

#=
# plot data
println("plotting data")
filename = "./replication/weather/weather_data.png"
Expand All @@ -362,7 +362,6 @@ for i in [2]
global filename = "./replication/weather/weather_forest_" * string(i) * ".png"
make_forest_plot(data, trees[1:i], x_min, x_max, y_min, y_max, filename)
end
=#

# plot debiased forest with design points
i = 15
Expand All @@ -382,11 +381,9 @@ global filename = "./replication/weather/weather_debiased_forest_design.png"
make_debiased_forest_design_plot(data, trees[1:i], trees2[1:i], x_min, x_max,
y_min, y_max, design_points, filename)

#=
i = 30
# plot forest with design points
println("plotting forest with ", i, " trees and design points")
global filename = "./replication/weather/weather_forest_design.png"
make_forest_design_plot(data, trees[1:i], x_min, x_max, y_min,
y_max, design_points, filename)
=#
1 change: 0 additions & 1 deletion replication/weather/weather_cv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ using MondrianForests
rcParams = PyPlot.PyDict(PyPlot.matplotlib."rcParams")
rcParams["text.usetex"] = true
rcParams["font.family"] = "serif"
#rcParams["text.latex.preamble"] = "\\usepackage[sfdefault,light]{FiraSans}"
plt.ioff()

function load_data(; limit=nothing)
Expand Down
2 changes: 1 addition & 1 deletion shell.nix
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,6 @@ in pkgs.mkShell {
];
shellHook = ''
# run this to link Julia PyCall package to nixpkgs python3
#julia --color=yes -e 'using Pkg; ENV["PYTHON"]="${pkgs.python3}/bin/python3"; Pkg.build("PyCall")'
julia --color=yes -e 'using Pkg; ENV["PYTHON"]="${pkgs.python3}/bin/python3"; Pkg.build("PyCall")'
'';
}

0 comments on commit 4f68595

Please sign in to comment.