Skip to content

Commit

Permalink
v1.83 - inching forward
Browse files Browse the repository at this point in the history
- data sets: NWOGrants, PrussianHorses
- ulam's traceplot omits log_lik now
- ulam now accepts previous ulam fit and runs without recompiling
  • Loading branch information
Richard McElreath committed Jan 31, 2019
1 parent 285166c commit 03997bb
Show file tree
Hide file tree
Showing 8 changed files with 383 additions and 15 deletions.
4 changes: 2 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
Package: rethinking
Type: Package
Title: Statistical Rethinking book package
Version: 1.82
Date: 2019-01-20
Version: 1.83
Date: 2019-01-27
Author: Richard McElreath
Maintainer: Richard McElreath <[email protected]>
Imports: coda, MASS, mvtnorm, loo
Expand Down
6 changes: 3 additions & 3 deletions R/map2stan-divergent.r
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
# extracts n_divergent from stan fit
divergent <- function( fit , warmup=FALSE ) {
if ( class(fit)=="map2stan" ) fit <- fit@stanfit
if ( class(fit) %in% c("map2stan","ulam") ) fit <- fit@stanfit
x <- rstan::get_sampler_params(fit)
if ( warmup==FALSE ) {
nwarmup <- fit@stan_args[[1]]$warmup
niter <- fit@stan_args[[1]]$iter
n <- sapply( x , function(ch) sum(ch[(nwarmup+1):niter,5]) )
n <- sapply( x , function(ch) sum(ch[(nwarmup+1):niter,"divergent__"]) )
} else {
n <- sapply( x , function(ch) sum(ch[,5]) )
n <- sapply( x , function(ch) sum(ch[,"divergent__"]) )
}
sum(n)
}
Expand Down
4 changes: 2 additions & 2 deletions R/precis.r
Original file line number Diff line number Diff line change
Expand Up @@ -132,15 +132,15 @@ precis_format <- function( result , depth , sort , decreasing ) {
hits_idx <- which( hits > -1 )
if ( length(hits_idx)>0 ) {
result <- result[-hits_idx,]
message( paste( length(hits_idx) , "vector or matrix parameters omitted in display. Use depth=2 to show them." ) )
message( paste( length(hits_idx) , "vector or matrix parameters hidden. Use depth=2 to show them." ) )
}
}
if ( depth==2 ) {
hits <- regexpr(",",rownames(result),fixed=TRUE)
hits_idx <- which( hits > -1 )
if ( length(hits_idx)>0 ) {
result <- result[-hits_idx,]
message( paste( length(hits_idx) , "matrix parameters omitted in display. Use depth=3 to show them." ) )
message( paste( length(hits_idx) , "matrix parameters hidden. Use depth=3 to show them." ) )
}
}

Expand Down
7 changes: 4 additions & 3 deletions R/ulam-class.R
Original file line number Diff line number Diff line change
Expand Up @@ -190,8 +190,6 @@ setMethod("pairs" , "ulam" , function(x, n=200 , alpha=0.7 , cex=0.7 , pch=16 ,
pairs( posterior , cex=cex , pch=pch , upper.panel=panel.2d , lower.panel=panel.cor , diag.panel=panel.dens , ... )
})

setMethod("traceplot", "ulam" , function(object,...) traceplot_ulam(object,...) )

# my trace plot function
#rethink_palette <- c("#5BBCD6","#F98400","#F2AD00","#00A08A","#FF0000")
rethink_palette <- c("#8080FF","#F98400","#F2AD00","#00A08A","#FF0000")
Expand All @@ -214,11 +212,13 @@ traceplot_ulam <- function( object , pars , chains , col=rethink_palette , alpha
if ( missing(chains) ) chains <- 1:n_chains
pars <- dimnames$parameters
chain.cols <- rep_len(col,n_chains)
# cut out "dev" and "lp__"
# cut out "dev" and "lp__" and "log_lik"
wdev <- which(pars=="dev")
if ( length(wdev)>0 ) pars <- pars[-wdev]
wlp <- which(pars=="lp__")
if ( length(wlp)>0 & lp==FALSE ) pars <- pars[-wlp]
wlp <- grep( "log_lik" , pars , fixed=TRUE )
if ( length(wlp)>0 ) pars <- pars[-wlp]

# figure out grid and paging
n_pars <- length( pars )
Expand Down Expand Up @@ -290,6 +290,7 @@ traceplot_ulam <- function( object , pars , chains , col=rethink_palette , alpha
}#k

}
setMethod("traceplot", "ulam" , function(object,...) traceplot_ulam(object,...) )

setMethod( "plot" , "ulam" , function(x,y,...) precis_plot(precis(x,depth=y),...) )

Expand Down
28 changes: 23 additions & 5 deletions R/ulam-function.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,17 @@

ulam <- function( flist , data , pars , pars_omit , start , chains=1 , cores=1 , iter=1000 , control=list(adapt_delta=0.95) , distribution_library=ulam_dists , macro_library=ulam_macros , custom , declare_all_data=TRUE , log_lik=FALSE , sample=TRUE , messages=TRUE , pre_scan_data=TRUE , ... ) {

data <- as.list(data)
if ( !missing(data) )
data <- as.list(data)

# check for previous fit passed instead of formula
prev_stanfit <- FALSE
if ( class(flist)=="ulam" ) {
prev_stanfit <- TRUE
prev_stanfit_object <- flist@stanfit
if ( missing(data) ) data <- flist@data
flist <- flist@formula
}

if ( pre_scan_data==TRUE ) {
# pre-scan for index variables (integer) that are numeric by accident
Expand Down Expand Up @@ -1084,14 +1094,22 @@ ulam <- function( flist , data , pars , pars_omit , start , chains=1 , cores=1 ,

# fire lasers
if ( sample==TRUE ) {
if ( length(start)==0 )
stanfit <- stan( model_code = model_code , data = data , pars=use_pars ,
if ( length(start)==0 ) {
if ( prev_stanfit==FALSE )
stanfit <- stan( model_code = model_code , data = data , pars=use_pars ,
chains=chains , cores=cores , iter=iter , control=control , ... )
else
stanfit <- stan( fit = prev_stanfit_object , data = data , pars=use_pars ,
chains=chains , cores=cores , iter=iter , control=control , ... )
else {
} else {
f_init <- "random"
if ( class(start)=="list" ) f_init <- function() return(start)
if ( class(start)=="function" ) f_init <- start
stanfit <- stan( model_code = model_code , data = data , pars=use_pars ,
if ( prev_stanfit==FALSE )
stanfit <- stan( model_code = model_code , data = data , pars=use_pars ,
chains=chains , cores=cores , iter=iter , control=control , init=f_init , ... )
else
stanfit <- stan( fit = prev_stanfit_object , data = data , pars=use_pars ,
chains=chains , cores=cores , iter=iter , control=control , init=f_init , ... )
}
}
Expand Down
19 changes: 19 additions & 0 deletions data/NWOGrants.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
discipline;gender;applications;awards
Chemical sciences;m;83;22
Chemical sciences;f;39;10
Physical sciences;m;135;26
Physical sciences;f;39;9
Physics;m;67;18
Physics;f;9;2
Humanities;m;230;33
Humanities;f;166;32
Technical sciences;m;189;30
Technical sciences;f;62;13
Interdisciplinary;m;105;12
Interdisciplinary;f;78;17
Earth/life sciences;m;156;38
Earth/life sciences;f;126;18
Social sciences;m;425;65
Social sciences;f;409;47
Medical sciences;m;245;46
Medical sciences;f;260;29
Loading

0 comments on commit 03997bb

Please sign in to comment.