Skip to content

Commit

Permalink
1.57 - fixes and features / 1.6 candidate
Browse files Browse the repository at this point in the history
- added Vehtari’s Pareto-smoothed LOO comparison metric, available via
LOO(); it can also be used with compare(…,func=LOO)
- map2stan now gracefully accepts previous fits and resamples without
recompiling. resample() is deprecated.
- better approach to sampling inits for multiple chains
- fixed map2stan so it is compatible with rstan 2.8 (refresh bug)
- bug fix for link() and sim() with dordlogit models
- added WAIC method for raw stanfit objects that include a log_lik
matrix. See the example in ?WAIC.
  • Loading branch information
Richard McElreath committed Sep 25, 2015
1 parent a6631ad commit a9acd53
Show file tree
Hide file tree
Showing 16 changed files with 526 additions and 152 deletions.
6 changes: 3 additions & 3 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
Package: rethinking
Type: Package
Title: Statistical Rethinking book package
Version: 1.56
Date: 2015-07-20
Version: 1.57
Date: 2015-09-25
Author: Richard McElreath
Maintainer: Richard McElreath <[email protected]>
Imports: coda, MASS, mvtnorm
Imports: coda, MASS, mvtnorm, loo
Depends: rstan, parallel, methods
Description: Utilities for fitting and comparing models
License: GPL (>= 3)
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ importFrom(parallel, makeCluster)
importFrom(parallel, clusterExport)
importFrom(mvtnorm, dmvnorm)
importFrom(mvtnorm, rmvnorm)
importFrom(loo, loo)
export(dmvnorm)
export(rmvnorm)
exportClasses(map2stan)
Expand Down
76 changes: 57 additions & 19 deletions R/compare.r
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ setMethod( "show" , "compareIC" , function(object) {
} )

# new compare function, defaulting to WAIC
compare <- function( ... , n=1e3 , sort="WAIC" , WAIC=TRUE , refresh=0 ) {
compare <- function( ... , n=1e3 , sort="WAIC" , func=WAIC , WAIC=TRUE , refresh=0 ) {
# retrieve list of models
L <- list(...)
if ( is.list(L[[1]]) && length(L)==1 )
Expand All @@ -23,6 +23,9 @@ compare <- function( ... , n=1e3 , sort="WAIC" , WAIC=TRUE , refresh=0 ) {
# retrieve model names from function call
mnames <- match.call()
mnames <- as.character(mnames)[2:(length(L)+1)]

# use substitute to deparse the func argument
the_func <- deparse(substitute(func))

# check class of fit models and warn when more than one class represented
classes <- as.character(sapply( L , class ))
Expand All @@ -32,7 +35,7 @@ compare <- function( ... , n=1e3 , sort="WAIC" , WAIC=TRUE , refresh=0 ) {

# check nobs for all models
# if different, warn
nobs_list <- sapply( L , nobs )
nobs_list <- try( sapply( L , nobs ) )
if ( any(nobs_list != nobs_list[1]) ) {
nobs_out <- paste( mnames , nobs_list , "\n" )
nobs_out <- concat(nobs_out)
Expand All @@ -41,16 +44,20 @@ compare <- function( ... , n=1e3 , sort="WAIC" , WAIC=TRUE , refresh=0 ) {
}

dSE.matrix <- matrix( NA , nrow=length(L) , ncol=length(L) )
if ( WAIC==FALSE ) {
DIC.list <- lapply( L , function(z) DIC( z , n=n ) )
pD.list <- sapply( DIC.list , function(x) attr(x,"pD") )
} else {
# use WAIC instead of DIC
# deprecate WAIC==TRUE/FALSE flag
# catch it and convert to func intent
if ( WAIC==FALSE ) func <- DIC # assume old code that wants DIC
if ( the_func=="DIC" ) {
IC.list <- lapply( L , function(z) DIC( z , n=n ) )
p.list <- sapply( IC.list , function(x) attr(x,"pD") )
}
if ( the_func=="WAIC" ) {
# use WAIC
# WAIC is processed pointwise, so can compute SE of differences, and summed later
WAIC.list <- lapply( L , function(z) WAIC( z , n=n , refresh=refresh , pointwise=TRUE ) )
pD.list <- sapply( WAIC.list , function(x) sum(attr(x,"pWAIC")) )
p.list <- sapply( WAIC.list , function(x) sum(attr(x,"pWAIC")) )
se.list <- sapply( WAIC.list , function(x) attr(x,"se") )
DIC.list <- sapply( WAIC.list , sum )
IC.list <- sapply( WAIC.list , sum )
# compute SE of differences between adjacent models from top to bottom in ranking
colnames(dSE.matrix) <- mnames
rownames(dSE.matrix) <- mnames
Expand All @@ -63,27 +70,58 @@ compare <- function( ... , n=1e3 , sort="WAIC" , WAIC=TRUE , refresh=0 ) {
}#j
}#i
}
if ( the_func=="LOO" ) {
# use LOO
LOO.list <- lapply( L , function(z) LOO( z , n=n , refresh=refresh , pointwise=TRUE ) )
p.list <- sapply( LOO.list , function(x) sum(attr(x,"pLOO")) )
se.list <- sapply( LOO.list , function(x) attr(x,"se") )
IC.list <- sapply( LOO.list , sum )
# compute SE of differences between adjacent models from top to bottom in ranking
colnames(dSE.matrix) <- mnames
rownames(dSE.matrix) <- mnames
for ( i in 1:(length(L)-1) ) {
for ( j in (i+1):length(L) ) {
loo_ptw1 <- LOO.list[[i]]
loo_ptw2 <- LOO.list[[j]]
dSE.matrix[i,j] <- as.numeric( sqrt( length(loo_ptw1)*var( loo_ptw1 - loo_ptw2 ) ) )
dSE.matrix[j,i] <- dSE.matrix[i,j]
}#j
}#i
}
if ( !(the_func %in% c("DIC","WAIC","LOO")) ) {
# unrecognized IC function; just wing it
IC.list <- lapply( L , function(z) func( z ) )
}

DIC.list <- unlist(DIC.list)
IC.list <- unlist(IC.list)

dDIC <- DIC.list - min( DIC.list )
w.DIC <- ICweights( DIC.list )
dIC <- IC.list - min( IC.list )
w.IC <- ICweights( IC.list )

if ( WAIC==FALSE )
result <- data.frame( DIC=DIC.list , pD=pD.list , dDIC=dDIC , weight=w.DIC )
else {
if ( the_func=="DIC" )
result <- data.frame( DIC=IC.list , pD=p.list , dDIC=dIC , weight=w.IC )
if ( the_func=="WAIC" ) {
# find out which model has dWAIC==0
topm <- which( dDIC==0 )
topm <- which( dIC==0 )
dSEcol <- dSE.matrix[,topm]
result <- data.frame( WAIC=DIC.list , pWAIC=pD.list , dWAIC=dDIC ,
weight=w.DIC , SE=se.list , dSE=dSEcol )
result <- data.frame( WAIC=IC.list , pWAIC=p.list , dWAIC=dIC ,
weight=w.IC , SE=se.list , dSE=dSEcol )
}
if ( the_func=="LOO" ) {
topm <- which( dIC==0 )
dSEcol <- dSE.matrix[,topm]
result <- data.frame( LOO=IC.list , pLOO=p.list , dLOO=dIC ,
weight=w.IC , SE=se.list , dSE=dSEcol )
}
if ( !(the_func %in% c("DIC","WAIC","LOO")) ) {
result <- data.frame( IC=IC.list , dIC=dIC , weight=w.IC )
}

rownames(result) <- mnames

if ( !is.null(sort) ) {
if ( sort!=FALSE ) {
if ( WAIC==FALSE & sort=="WAIC" ) sort <- "DIC"
if ( sort=="WAIC" ) sort <- the_func
result <- result[ order( result[[sort]] ) , ]
}
}
Expand Down
20 changes: 14 additions & 6 deletions R/ensemble.R
Original file line number Diff line number Diff line change
@@ -1,19 +1,27 @@
# build ensemble of samples using DIC/WAIC weights
ensemble <- function( ... , data , n=1e3 , WAIC=TRUE , refresh=0 , replace=list() , do_link=TRUE , do_sim=TRUE ) {
ensemble <- function( ... , data , n=1e3 , func=WAIC , weights , refresh=0 , replace=list() , do_link=TRUE , do_sim=TRUE ) {
# retrieve list of models
L <- list(...)
if ( is.list(L[[1]]) && length(L)==1 )
L <- L[[1]]
# retrieve model names from function call
mnames <- match.call()
mnames <- as.character(mnames)[2:(length(L)+1)]
if ( length(L)>1 ) {
ictab <- compare( ... , WAIC=WAIC , refresh=refresh , n=n , sort=FALSE )
rownames(ictab@output) <- mnames
weights <- ictab@output$weight

if ( missing(weights) ) {
if ( length(L)>1 ) {
use_func <- func
ictab <- compare( ... , func=use_func , refresh=refresh , n=n , sort=FALSE )
rownames(ictab@output) <- mnames
weights <- ictab@output$weight
} else {
ictab <- NA
weights <- 1
}
} else {
# explicit custom weights
ictab <- NA
weights <- 1
weights <- weights/sum(weights) # ensure sum to one
}

# compute number of predictions per model
Expand Down
3 changes: 2 additions & 1 deletion R/map2stan-class.r
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,8 @@ setMethod("pairs" , "map2stan" , function(x, n=500 , alpha=0.7 , cex=0.7 , pch=1
# my trace plot function
#rethink_palette <- c("#5BBCD6","#F98400","#F2AD00","#00A08A","#FF0000")
rethink_palette <- c("#8080FF","#F98400","#F2AD00","#00A08A","#FF0000")
tracerplot <- function( object , pars , col=rethink_palette , alpha=1 , bg=gray(0.7,0.5) , ask=TRUE , window , n_cols=3 , max_rows=5 , ... ) {
rethink_cmyk <- c(col.alpha("black",0.25),"cyan")
tracerplot <- function( object , pars , col=rethink_palette , alpha=1 , bg=col.alpha("black",0.15) , ask=TRUE , window , n_cols=3 , max_rows=5 , ... ) {
chain.cols <- col

if ( class(object)!="map2stan" ) stop( "requires map2stan fit" )
Expand Down
19 changes: 15 additions & 4 deletions R/map2stan-templates.r
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,8 @@ map2stan.templates <- list(
for ( i in 1:length(kout) ) {
# clear from start list, so not in parameters block
apar <- as.character(kout[[i]])
start[[apar]] <- startp[[apar]] <- NULL
start[[apar]] <- NULL
for ( ch in 1:length(startp) ) startp[[ch]][[apar]] <- NULL
# pattern:
# apar <- col(v,i)
trans_txt <- c(
Expand Down Expand Up @@ -356,10 +357,13 @@ map2stan.templates <- list(
assign( "start" , start , envir=parent.frame() )
# have to check start_prior too
startp <- get( "start_prior" , envir=parent.frame() )
startp[[Rho_name]] <- NULL
for ( ch in 1:length(startp) ) startp[[ch]][[Rho_name]] <- NULL
assign( "start_prior" , startp , envir=parent.frame() )

# add Rho to transformed parameters
# need to change this from transpars to GQ
# intent: add Rho to GENERATED QUANTITIES
# do not want it in transformed pars, bc less efficient
# just need one eval per sample
# Rho <- L_Rho * L_Rho'
transpars <- get( "transpars" , envir=parent.frame() )
transpars[[Rho_name]] <- c(
Expand Down Expand Up @@ -473,7 +477,14 @@ map2stan.templates <- list(
par_bounds = c("","<lower=0>"),
par_types = c("real","real"),
out_type = "real",
par_map = function(k,...) {
par_map = function(k,e,...) {
# get constraints and add <lower=0> for scale
constr_list <- get( "constraints" , envir=e )
scale_name <- as.character( k[[2]] )
if ( is.null(constr_list[[scale_name]]) ) {
constr_list[[scale_name]] <- "lower=0"
assign( "constraints" , constr_list , envir=e )
}
return(k);
},
vectorized = TRUE
Expand Down
Loading

0 comments on commit a9acd53

Please sign in to comment.