Make Chains objects display only information and not statistical eval #307

41 changes: 17 additions & 24 deletions src/chains.jl
@@ -314,13 +314,6 @@ end

function Base.show(io::IO, mime::MIME"text/plain", chains::Chains)
    print(io, "Chains ", chains, ":\n\n", header(chains))
-
-    # Show summary stats.
-    summaries = describe(chains)
-    for summary in summaries
-        println(io)
-        show(io, mime, summary)
-    end
end

Base.keys(c::Chains) = names(c)
@@ -333,7 +326,7 @@ Base.last(c::Chains) = last(c.value[Axis{:iter}].val)

Base.convert(::Type{Array}, chn::Chains) = convert(Array, chn.value)

# Convenience functions to handle different types of
# timestamps.
to_datetime(t::DateTime) = t
to_datetime(t::Float64) = unix2datetime(t)
@@ -375,15 +368,15 @@ max_stop(c::Chains) = max_datetime(stop_times(c))
"""
start_times(c::Chains)

Retrieve the contents of `c.info.start_time`, or `missing` if no
`start_time` is set.
"""
start_times(c::Chains) = to_datetime_vec(get(c.info, :start_time, missing))

"""
stop_times(c::Chains)

Retrieve the contents of `c.info.stop_time`, or `missing` if no
`stop_time` is set.
"""
stop_times(c::Chains) = to_datetime_vec(get(c.info, :stop_time, missing))
@@ -411,14 +404,14 @@ end

Calculate the compute time for all chains in seconds.

The duration is calculated as the sum of `stop - start` in seconds.

`compute_duration` is more useful in cases of parallel sampling, where `wall_duration`
may understate how much computation time was utilized.
"""
function compute_duration(
    c::Chains;
    start=start_times(c),
    stop=stop_times(c)
)
    # Calculate total time for each chain, then add it up.
@@ -776,24 +769,24 @@ function _cat(::Val{3}, c1::Chains, args::Chains...)

    # concatenate all chains
    data = mapreduce(
        c -> c.value.data,
        (x, y) -> cat(x, y; dims = 3),
        args;
        init = c1.value.data
    )
    value = AxisArray(data; iter = rng, var = nms, chain = 1:size(data, 3))

    # Concatenate times, if available
    starts = mapreduce(
        c -> get(c.info, :start_time, nothing),
        vcat,
        args,
        init = get(c1.info, :start_time, nothing)
    )
    stops = mapreduce(
        c -> get(c.info, :stop_time, nothing),
        vcat,
        args,
        init = get(c1.info, :stop_time, nothing)
    )
    nontime_props = filter(x -> !(x in [:start_time, :stop_time]), [propertynames(c1.info)...])
@@ -810,7 +803,7 @@ function pool_chain(c::Chains)
end

"""
replacenames(chains::Chains, dict::AbstractDict)

Replace parameter names by creating a new `Chains` object that shares the same underlying data.

@@ -827,7 +820,7 @@ julia> names(chn2)

julia> chn3 = replacenames(chn2, Dict("A" => "one", "two" => "B"));

julia> names(chn3)
2-element Vector{Symbol}:
:one
:B
9 changes: 9 additions & 0 deletions src/stats.jl
@@ -171,6 +171,15 @@ function describe(
return dfs
end

+function Base.show(io::IO, mime::MIME"text/plain", cs::Vector{ChainDataFrame})
Member

This only captures the abstract Vector{ChainDataFrame} but not any concretely typed Vector{<:ChainDataFrame}. In general, I am a bit worried about changing the display of vectors of ChainDataFrames - it seems wrong to completely opt out of the default display mechanism of vectors in Julia (I also wonder if it causes any problems) just to change the way in which describe(chain) is displayed. Maybe rather describe should not return a vector of ChainDataFrame.
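
The dispatch concern above comes from Julia's invariant type parameters. A minimal sketch, assuming a hypothetical parametric type `Frame{T}` as a stand-in for `ChainDataFrame` (it is not part of the package):

```julia
# Hypothetical stand-in for a parametric type such as ChainDataFrame.
struct Frame{T}
    data::T
end

# Element type is inferred as the concrete Frame{Int}.
concrete = [Frame(1), Frame(2)]

# Element type is forced to the unparameterized Frame.
mixed = Frame[Frame(1), Frame("a")]

# A method written for Vector{Frame} would accept `mixed` only.
concrete_matches_exact = concrete isa Vector{Frame}
mixed_matches_exact = mixed isa Vector{Frame}

# A method written for Vector{<:Frame} would accept both.
concrete_matches_subtype = concrete isa Vector{<:Frame}
mixed_matches_subtype = mixed isa Vector{<:Frame}
```

Under this assumption, a `show` method declared on `Vector{Frame}` is skipped whenever the vector happens to be concretely typed, which is the gap the comment points out.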

Member

Yeah, I'm down with this. My understanding of your suggestion is that describe should be a pure IO function -- we should make display(io::IO, chn::Chain) and the output is all the stuff inside the Base.show definition above.

Member

We should not define display (it calls show(io, MIME("text/plain"), x) which should be implemented), but we should just implement DataAPI.describe(io, chain) if we want to display the summary statistics in a nice way: https://juliastats.org/StatsBase.jl/latest/scalarstats/#Summary-Statistics-1
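
As a rough illustration of the point about `display` and `show`, the sketch below implements only the three-argument `show` for a hypothetical `Report` type (not part of the package) and renders it through `sprint`, which calls `show(io, MIME("text/plain"), x)`:

```julia
# Hypothetical type used only for illustration.
struct Report
    title::String
    rows::Vector{Pair{String,Float64}}
end

# Implement the text/plain `show` method; `display` is left untouched.
function Base.show(io::IO, ::MIME"text/plain", r::Report)
    println(io, "Report: ", r.title)
    for (name, value) in r.rows
        println(io, "  ", name, " = ", value)
    end
end

r = Report("demo", ["a" => 1.0, "b" => 2.5])

# Capture what the text/plain `show` method writes.
text = sprint(show, MIME("text/plain"), r)
```

With this method in place, the REPL's default `display` picks it up without any extra definition.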

Contributor Author

Thank you so much for your help! So, to double check that I'm understanding these suggestions. Instead of returning a Vector{ChainDataFrame}, describe should be an implementation of StatsBase.describe and return something like this?

chn = Chains(rand(100, 2, 2), [:a, :b])
chn_arr = Array(chn)
sections = chn.name_map[:parameters]
for i in 1:length(sections)
    println("Parameter $(sections[i])")
    describe(chn_arr[:,i])
    println()
end

StatsBase _describe

Member

I think it's fine if summarystats returns different summary statistics (so this method does not have to be changed) but we should make sure that describe just prints these summary statistics in a pretty way to be consistent with how describe and summarystats are defined in StatsBase. I.e., in particular describe should not return anything but only print to IO and it should not print the quantiles if they are not part of summarystats (here it might actually be better to include them in summarystats as well and return two dataframes, possibly as a named tuple).
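
One possible shape for this split, sketched with the `Statistics` standard library only. The names `summarize` and `print_description` are hypothetical stand-ins for `summarystats` and `describe`, and the rows are plain named tuples rather than the package's data frames:

```julia
using Statistics

# Hypothetical stand-in for `summarystats`: returns data and prints nothing.
# `samples` holds one column per parameter.
function summarize(samples::AbstractMatrix{<:Real}, names::Vector{Symbol})
    stats = [(parameter = n, mean = mean(col), std = std(col))
             for (n, col) in zip(names, eachcol(samples))]
    quants = [(parameter = n, q25 = quantile(col, 0.25),
               q50 = quantile(col, 0.5), q75 = quantile(col, 0.75))
              for (n, col) in zip(names, eachcol(samples))]
    # Two tables returned together as a named tuple.
    return (summary = stats, quantiles = quants)
end

# Hypothetical stand-in for `describe`: prints to `io` and returns `nothing`.
function print_description(io::IO, samples::AbstractMatrix{<:Real}, names::Vector{Symbol})
    result = summarize(samples, names)
    for table in result          # iterating a named tuple yields its values
        for row in table
            println(io, row)
        end
        println(io)
    end
    return nothing
end

samples = [1.0 2.0; 3.0 4.0; 5.0 6.0]
result = summarize(samples, [:a, :b])
returned = print_description(IOBuffer(), samples, [:a, :b])
```

Keeping the computation in one function and the printing in another means callers who need the numbers never have to parse printed output.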

Member

To clarify maybe a little bit, there's no need to change to this code you've posted:

chn = Chains(rand(100, 2, 2), [:a, :b])
chn_arr = Array(chn)
sections = chn.name_map[:parameters]
for i in 1:length(sections)
    println("Parameter $(sections[i])")
    describe(chn_arr[:,i]) # <- This is not what we want, we want to print the results of `describe(chain)` here instead
    println()
end

Basically we want to change the stuff that is currently in show to describe, so that describe becomes a pure IO function and not a weird Vector{<:ChainDataFrame} thing that we have now.

Maybe one way is to do something like

function DataAPI.describe(io::IO, chains::Chains)
    print(io, "Chains ", chains, ":\n\n", header(chains))

    summstats = summarystats(chains)
    qs = quantiles(chains)

    println(io)
    show(io, summstats)

    println(io)
    show(io, qs)
end

which won't actually return anything, it just prints stuff out to the screen. There's probably a way more sane way to do this, but it's a rough sketch to get you started.


+    # Show summary stats.
+    for c in cs
+        println(io)
+        show(io, mime, c)
+    end
+end

function _hpd(x::AbstractVector{<:Real}; alpha::Real=0.05)
n = length(x)
m = max(1, ceil(Int, alpha * n))