Skip to content

Commit

Permalink
v1.91 - sundry
Browse files Browse the repository at this point in the history
- bug fix for ulam() automatic imputation when only 1 NA in variable
- lots of docs
- integrating PSIS-LOO better into compare function
- sim() can simulate multiple observable variables and uses causal structure of model to do so. Works for quap() now. ulam() coming soon.
- lots of testing and fuzzing yet to come
  • Loading branch information
Richard McElreath committed Oct 29, 2019
1 parent 32bf0d8 commit 2e01a81
Show file tree
Hide file tree
Showing 17 changed files with 864 additions and 191 deletions.
4 changes: 2 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
Package: rethinking
Type: Package
Title: Statistical Rethinking book package
Version: 1.90
Date: 2019-08-04
Version: 1.91
Date: 2019-10-25
Author: Richard McElreath
Maintainer: Richard McElreath <[email protected]>
Imports: coda, MASS, mvtnorm, loo, shape
Expand Down
52 changes: 52 additions & 0 deletions R/HMC2.R
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,55 @@ HMC2 <- function (U, grad_U, epsilon, L, current_q , ... ) {
accept = accept ,
dH = H1 - H0 ) )
}

HMC_2D_sample <- function( n=100 , U , U_gradient , step , L , start=c(0,0) , xlim=c(-5,5) , ylim=c(-4,4) , xlab="x" , ylab="y" , draw=TRUE , draw_contour=TRUE , nlvls=15 , adj_lvls=1 , ... ) {

Q <- list()
Q$q <- start
xr <- xlim
yr <- ylim

if ( draw==TRUE ) plot( NULL , xlab=xlab , ylab=ylab , xlim=xlim, ylim=ylim )

if ( draw==TRUE & draw_contour==TRUE ) {
# draw contour
zr <- 1
y_seq <- seq(from=yr[1]-zr,to=yr[2]+zr,length.out=50)
x_seq <- seq(from=xr[1]-zr,to=xr[2]+zr,length.out=50)
z5 <- matrix(NA,length(x_seq),length(y_seq))
for ( i in 1:length(x_seq) )
for ( j in 1:length(y_seq) )
z5[i,j] <- U( c( x_seq[i] , y_seq[j] ) , ... )
lz5 <- log(z5)
lvls <- exp( pretty( range(lz5), nlvls ) )
cl <- contourLines( x_seq , y_seq , z5 , level=lvls*adj_lvls )
for ( i in 1:length(cl) ) lines( cl[[i]]$x , cl[[i]]$y , col=col.alpha("black",0.6) , lwd=0.5 )
}

n_samples <- n
a <- rep(NA,n_samples)
dH <- rep(NA,n_samples) # energy
path_col <- col.alpha("black",0.3)
xpos <- c(1,3,4,2)
points( Q$q[1] , Q$q[2] , pch=4 , col="black" )
post <- matrix(NA,nrow=n,ncol=2)

for ( i in 1:n_samples ) {
Q <- HMC2( U , U_gradient , epsilon=step , L=L , current_q=Q$q )
if ( n_samples < 100 ) {
# draw paths
lines( Q$traj[,1] , Q$traj[,2] , col=path_col , lwd=2 )
}
r <- min(abs(Q$dH),1)
ptcol <- rgb( r , 0 , 0 )
ptcol <- "black"
points( Q$traj[L+1,1] , Q$traj[L+1,2] , pch=ifelse( Q$accept==1 , 16 , 1 ) , col=ptcol , cex=ifelse( Q$accept==1 , 0.7 , 1 ) , lwd=1.5 )
dH[i] <- Q$dH
a[i] <- Q$accept
if ( a[i]==1 ) {
post[i,] <- Q$q
}
}

return( invisible( post ) )
}
15 changes: 8 additions & 7 deletions R/compare.r
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,10 @@ compare <- function( ... , n=1e3 , sort="WAIC" , func=WAIC , WAIC=TRUE , refresh
}#j
}#i
}
if ( the_func=="LOO" ) {
if ( the_func=="LOO" | the_func=="PSIS" ) {
# use LOO
LOO.list <- lapply( L , function(z) LOO( z , n=n , refresh=refresh , pointwise=TRUE ) )
p.list <- sapply( LOO.list , function(x) sum(attr(x,"pLOO")) )
LOO.list <- lapply( L , function(z) PSIS( z , n=n , refresh=refresh , pointwise=TRUE ) )
p.list <- sapply( LOO.list , function(x) sum(attr(x,"pPSIS")) )
se.list <- sapply( LOO.list , function(x) attr(x,"se") )
IC.list <- sapply( LOO.list , sum )
# compute SE of differences between adjacent models from top to bottom in ranking
Expand All @@ -91,7 +91,7 @@ compare <- function( ... , n=1e3 , sort="WAIC" , func=WAIC , WAIC=TRUE , refresh
}#j
}#i
}
if ( !(the_func %in% c("DIC","WAIC","LOO")) ) {
if ( !(the_func %in% c("DIC","WAIC","LOO","PSIS")) ) {
# unrecognized IC function; just wing it
IC.list <- lapply( L , function(z) func( z ) )
}
Expand All @@ -110,20 +110,21 @@ compare <- function( ... , n=1e3 , sort="WAIC" , func=WAIC , WAIC=TRUE , refresh
result <- data.frame( WAIC=IC.list , pWAIC=p.list , dWAIC=dIC ,
weight=w.IC , SE=se.list , dSE=dSEcol )
}
if ( the_func=="LOO" ) {
if ( the_func=="LOO" | the_func=="PSIS" ) {
topm <- which( dIC==0 )
dSEcol <- dSE.matrix[,topm]
result <- data.frame( LOO=IC.list , pLOO=p.list , dLOO=dIC ,
result <- data.frame( PSIS=IC.list , pPSIS=p.list , dPSIS=dIC ,
weight=w.IC , SE=se.list , dSE=dSEcol )
}
if ( !(the_func %in% c("DIC","WAIC","LOO")) ) {
if ( !(the_func %in% c("DIC","WAIC","LOO","PSIS")) ) {
result <- data.frame( IC=IC.list , dIC=dIC , weight=w.IC )
}

rownames(result) <- mnames

if ( !is.null(sort) ) {
if ( sort!=FALSE ) {
if ( the_func=="LOO" ) the_func <- "PSIS" # must match header
if ( sort=="WAIC" ) sort <- the_func
result <- result[ order( result[[sort]] ) , ]
}
Expand Down
12 changes: 11 additions & 1 deletion R/drawdag.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,16 @@

drawdag <- function( x , col_arrow="black" , col_segment="black" , col_labels="black" , cex=1 , lwd=1.5 , goodarrow=TRUE , xlim , ylim , shapes , col_shapes , radius=3.5 , add=FALSE , xkcd=FALSE , latent_mark="c" , ... ){
require(dagitty)

# check for list of DAGs
if ( class(x)=="list" ) {
n <- length(x)
y <- make.grid(n)
par(mfrow=y)
for ( i in 1:n ) drawdag( x[[i]] , ... )
return(invisible(NULL))
}

x <- as.dagitty( x )
dagitty:::.supportsTypes(x,c("dag","mag","pdag"))
coords <- coordinates( x )
Expand Down Expand Up @@ -63,7 +73,7 @@ drawdag <- function( x , col_arrow="black" , col_segment="black" , col_labels="b
# edges
asp <- par("pin")[1]/diff(par("usr")[1:2]) /
(par("pin")[2]/diff(par("usr")[3:4]))
ex <- edges(x)
ex <- dagitty::edges(x)
ax1 <- rep(0,nrow(ex))
ax2 <- rep(0,nrow(ex))
ay1 <- rep(0,nrow(ex))
Expand Down
Loading

0 comments on commit 2e01a81

Please sign in to comment.