Skip to content

Commit

Permalink
v1.88 trankplot edition
Browse files Browse the repository at this point in the history
- ranked histogram trace plots, trankplot(), added. See ?trankplot
- small adjustment to rethinking color palette
-
  • Loading branch information
Richard McElreath committed Apr 9, 2019
1 parent 2ce24c8 commit f699ad3
Show file tree
Hide file tree
Showing 6 changed files with 100 additions and 12 deletions.
4 changes: 4 additions & 0 deletions R/colors.r
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,7 @@ col.dist <- function( x , mu=0 , sd=1 , col="slateblue" ) {

# Package colour constants.

# "blue" passed through col.desat() with 0.5 (presumably a half-desaturated
# blue -- confirm against col.desat).
rangi2 <- col.desat("blue", 0.5)

# Black passed through col.alpha(); `alpha` is forwarded unchanged and
# defaults to 0.5.
grau <- function(alpha = 0.5) {
  col.alpha("black", alpha)
}

# Five-colour palette; it is the default `col` argument of the trace/rank
# plotting functions.
# Previous palette, kept for reference:
#rethink_palette <- c("#5BBCD6","#F98400","#F2AD00","#00A08A","#FF0000")
rethink_palette <- c(
  "#8080FF",
  "#F98400",
  "#00A08A",
  "#E2AD00",
  "#FF0000"
)

# Two-colour vector: black through col.alpha() at 0.25, then cyan.
rethink_cmyk <- c(col.alpha("black", 0.25), "cyan")
4 changes: 1 addition & 3 deletions R/map2stan-class.r
Original file line number Diff line number Diff line change
Expand Up @@ -346,9 +346,7 @@ setMethod("pairs" , "map2stan" , function(x, n=500 , alpha=0.7 , cex=0.7 , pch=1
})

# my trace plot function
#rethink_palette <- c("#5BBCD6","#F98400","#F2AD00","#00A08A","#FF0000")
rethink_palette <- c("#8080FF","#F98400","#00A08A","#E2AD00","#FF0000")
rethink_cmyk <- c(col.alpha("black",0.25),"cyan")

tracerplot <- function( object , pars , col=rethink_palette , alpha=1 , bg=col.alpha("black",0.15) , ask=TRUE , window , n_cols=3 , max_rows=5 , lwd=0.5 , ... ) {

if ( !(class(object) %in% c("map2stan","stanfit")) ) stop( "requires map2stan or stanfit fit object" )
Expand Down
8 changes: 6 additions & 2 deletions R/trankplot.R
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@

# Convert a vector or matrix to a matrix of ranks, ranking over ALL entries
# jointly (not column by column).
#   x : a vector (treated as a single column) or a matrix of values
# Returns a matrix with ncol(x) columns that holds rank(x), so every entry
# keeps its original position. Ties are resolved by rank()'s default method.
rank_mat <- function( x ) {
    # Decide on the dimensions instead of class(x)=="numeric": the class test
    # misses integer vectors, and for a matrix class(x) has length two in
    # R >= 4.0, which makes the if() condition invalid (an error in R >= 4.2).
    if ( length(dim(x)) < 2 ) x <- array( x , dim=c(length(x),1) )
    matrix( rank(x) , ncol=ncol(x) )
}

trankplot <- function( object , bins=30 , pars , chains , col=rethink_palette , alpha=1 , bg=col.alpha("black",0.15) , ask=TRUE , window , n_cols=3 , max_rows=5 , lwd=1.5 , lp=FALSE , axes=FALSE , ... ) {
trankplot <- function( object , bins=30 , pars , chains , col=rethink_palette , alpha=1 , bg=col.alpha("black",0.15) , ask=TRUE , window , n_cols=3 , max_rows=5 , lwd=1.5 , lp=FALSE , axes=FALSE , off=0 , ... ) {

if ( !(class(object) %in% c("map2stan","ulam","stanfit")) ) stop( "requires map2stan, ulam or stanfit object" )

Expand All @@ -24,6 +25,9 @@ trankplot <- function( object , bins=30 , pars , chains , col=rethink_palette ,
# names
dimnames <- attr(post,"dimnames")
n_chains <- length(dimnames$chains)

if ( n_chains==1 ) stop( "trankplot requires more than one chain." )

if ( missing(chains) ) chains <- 1:n_chains
n_chains <- length(chains)
pars <- dimnames$parameters
Expand Down Expand Up @@ -92,7 +96,7 @@ trankplot <- function( object , bins=30 , pars , chains , col=rethink_palette ,
for ( i in chains ) {
x <- c( breaks[1] , rep( breaks[2:(nb-1)] , each=2 ) , breaks[nb] )
y <- rep( r[ 1:(nb-1) ,i] , each=2 )
lines( x , y , col=col.alpha(chain.cols[i],alpha) , lwd=lwd )
lines( x + (i-1)*off , y , col=col.alpha(chain.cols[i],alpha) , lwd=lwd )
}#i
}

Expand Down
3 changes: 0 additions & 3 deletions R/ulam-class.R
Original file line number Diff line number Diff line change
Expand Up @@ -197,9 +197,6 @@ setMethod("pairs" , "ulam" , function(x, n=200 , alpha=0.7 , cex=0.7 , pch=16 ,
})

# my trace plot function
#rethink_palette <- c("#5BBCD6","#F98400","#F2AD00","#00A08A","#FF0000")
rethink_palette <- c("#8080FF","#F98400","#F2AD00","#00A08A","#FF0000")
rethink_cmyk <- c(col.alpha("black",0.25),"cyan")
traceplot_ulam <- function( object , pars , chains , col=rethink_palette , alpha=1 , bg=col.alpha("black",0.15) , ask=TRUE , window , trim=100 , n_cols=3 , max_rows=5 , lwd=0.5 , lp=FALSE , ... ) {

if ( !(class(object) %in% c("map2stan","ulam","stanfit")) ) stop( "requires map2stan or stanfit fit object" )
Expand Down
8 changes: 4 additions & 4 deletions man/rethinking.Rd
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@
\docType{package}
\title{Statistical Rethinking package}
\description{
This package accompanies a book and course on Bayesian data analysis, featured MAP estimation through \code{\link{map}} and Hamiltonian Monte Carlo through \code{\link{map2stan}}.
This package accompanies a book and course on Bayesian data analysis, featuring MAP estimation through \code{\link{quap}} and Hamiltonian Monte Carlo through \code{\link{ulam}}.
}

\details{
\tabular{ll}{
Package: \tab rethinking\cr
Type: \tab Package\cr
Version: \tab 1.70\cr
Date: \tab 21 Aug 2017\cr
Version: \tab 1.88\cr
Date: \tab 02 April 2019\cr
License: \tab GPL-3 \cr
}

Expand All @@ -27,7 +27,7 @@
\examples{}

\seealso{
\code{\link{map}}, \code{\link{map2stan}}
\code{\link{quap}}, \code{\link{ulam}}
}
\keyword{rethinking}
\keyword{package}
85 changes: 85 additions & 0 deletions man/trankplot.Rd
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
\name{trankplot}
\alias{trankplot}
\alias{traceplot}
%- Also NEED an '\alias' for EACH other topic documented here.
\title{Diagnostic trace and rank histogram plots for MCMC output}
\description{
The functions \code{trankplot} and \code{traceplot} display MCMC chain diagnostic plots. \code{trankplot} displays ranked histograms and \code{traceplot} shows the more traditional trace of the samples.
}
\usage{
trankplot( object , bins=30 , pars , chains , col=rethink_palette , alpha=1 ,
bg=col.alpha("black",0.15) , ask=TRUE , window , n_cols=3 , max_rows=5 ,
lwd=1.5 , lp=FALSE , axes=FALSE , off=0 , ... )
traceplot( object , pars , chains , col=rethink_palette , alpha=1 ,
bg=col.alpha("black",0.15) , ask=TRUE , window , trim=100 , n_cols=3 ,
max_rows=5 , lwd=0.5 , lp=FALSE , ... )
}
%- maybe also 'usage' for other objects documented here.
\arguments{
\item{object}{A \code{stanfit}, \code{ulam} or \code{map2stan} object}
\item{bins}{For \code{trankplot}, the number of histogram bins to use}
\item{pars}{Optional character vector of parameters to display}
\item{chains}{Optional integer vector of chains to display}
\item{col}{Vector of colors to use for chains}
\item{alpha}{Transparency}
\item{bg}{Background color for warmup samples}
\item{ask}{Interactive paging when \code{TRUE}}
\item{window}{Optional range of samples to show}
\item{n_cols}{Number of columns in display}
\item{max_rows}{Maximum number of rows on each page}
\item{lwd}{Line width}
\item{lp}{Whether to include log_prob in display}
\item{axes}{Whether to show axes on plots}
\item{off}{For \code{trankplot}, horizontal offset added to each chain's histogram line: chain \code{i} is shifted by \code{(i-1)*off}. The default of 0 draws all chains at the same position.}
\item{trim}{For \code{traceplot}, the number of samples to trim from the start. This helps with display, since early warmup samples are typically very far from typical samples.}
\item{...}{Additional arguments to pass to \code{\link{plot}}}
}
\details{
\code{trankplot} produces rank histograms of each chain, as described in Vehtari et al 2019 (see reference below). For each parameter, the samples from all chains are first ranked, using \code{rank_mat}. This returns a matrix of ranks, with the chains preserved. Then a histogram is built for each chain, using the same break points. These histograms are then overlaid in the plot.

For healthy well-mixing chains, the histograms should be uniform. When there are spikes for some chains, especially in the low or high ranks, this suggests problems in exploring the posterior.

\code{traceplot} shows the sequential samples for each parameter and chain. This is the same information as the \code{trankplot}, but often much harder to see, given the volume of samples.
}
\references{Vehtari, Gelman, Simpson, Carpenter, Bürkner. 2019. Rank-normalization, folding, and localization: An improved R-hat for assessing convergence of MCMC. https://arxiv.org/abs/1903.08008}
\author{Richard McElreath}
\seealso{}
\examples{
\dontrun{
library(rethinking)
data(chimpanzees)

d <- list(
pulled_left = chimpanzees$pulled_left ,
prosoc_left = chimpanzees$prosoc_left ,
condition = as.integer( 2 - chimpanzees$condition ) ,
actor = as.integer( chimpanzees$actor ) ,
blockid = as.integer( chimpanzees$block )
)

m <- ulam(
alist(
# likelihood
pulled_left ~ bernoulli(theta),

# linear models
logit(theta) <- A + BP*prosoc_left,
A <- a + v[actor,1],
BP <- bp + v[actor,condition+1],

# adaptive prior
vector[3]: v[actor] ~ multi_normal( 0 , Rho_actor , sigma_actor ),

# fixed priors
c(a,bp) ~ normal(0,1),
sigma_actor ~ exponential(1),
Rho_actor ~ lkjcorr(4)
) , data=d , chains=3 , cores=1 , sample=TRUE )

trankplot(m)

}
}
% Add one or more standard keywords, see file 'KEYWORDS' in the
% R documentation directory.
\keyword{ }

0 comments on commit f699ad3

Please sign in to comment.