Skip to content

Commit

Permalink
patching arrays still
Browse files Browse the repository at this point in the history
- deprecation warning for map2stan
- patched multi_normal template for new Stan array syntax
- some new tests in test folder, more to come
- still need to write some docs blarg
  • Loading branch information
Richard McElreath committed Dec 14, 2022
1 parent 6b3ba17 commit 61765ff
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 7 deletions.
2 changes: 2 additions & 0 deletions R/map2stan.r
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@

map2stan <- function( flist , data , start , pars , constraints=list() , types=list() , sample=TRUE , iter=2000 , warmup=floor(iter/2) , chains=1 , debug=FALSE , verbose=FALSE , WAIC=TRUE , cores=1 , rng_seed , rawstanfit=FALSE , control=list(adapt_delta=0.95) , add_unique_tag=TRUE , code , log_lik=FALSE , DIC=FALSE , declare_all_data=TRUE , do_discrete_imputation=FALSE , ... ) {

warning("DEPRECATED: map2stan is no longer supported and may behave unpredictably or stop working altogether. Start using ulam instead.",immediate.=TRUE)

if ( missing(rng_seed) ) rng_seed <- sample( 1:1e5 , 1 )
set.seed(rng_seed)

Expand Down
14 changes: 8 additions & 6 deletions R/ulam_templates.R
Original file line number Diff line number Diff line change
Expand Up @@ -396,12 +396,13 @@ ulam_dists <- list(
}

# do we need a local var for left side?
# patched for new Stan array syntax
# need an array of vectors for multi_normal outcome
if ( length(left) > 1 ) {
out <- concat( out , indent , "vector[" , n_vars , "] YY" )
if ( n_cases > 1 )
out <- concat( out , "[" , n_cases , "];\n" )
out <- concat( out , indent , "array[" , n_cases , "] vector[" , n_vars , "] YY;\n" )
else
out <- concat( out , ";\n" )
out <- concat( out , indent , "vector[" , n_vars , "] YY;\n" )
}

# do we need a local var for means?
Expand All @@ -415,11 +416,12 @@ ulam_dists <- list(
if ( vlen != n_cases ) warning( "multi_normal mean vector has length > 1 but not same length as outcome" )
}
# build text
out <- concat( out , indent , "vector[" , n_vars , "] MU" )
# patch for new Stan array format
if ( vlen > 1 )
out <- concat( out , "[" , vlen , "];\n" )
out <- concat( out , indent , "array[" , vlen , "] vector[" , n_vars , "] MU;\n" )
else
out <- concat( out , ";\n" )
out <- concat( out , indent , "vector[" , n_vars , "] MU;\n" )

# assign it too
vsuf <- ""
if ( vlen==1 )
Expand Down
3 changes: 2 additions & 1 deletion man/ulam.Rd
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ ulam( flist , data , pars , pars_omit , start , chains=1 , cores=1 , iter=1000 ,
control=list(adapt_delta=0.95) , distribution_library=ulam_dists ,
macro_library=ulam_macros , custom , declare_all_data=TRUE , log_lik=FALSE ,
sample=TRUE , messages=TRUE , pre_scan_data=TRUE , coerce_int=TRUE ,
sample_prior=FALSE , file=NULL , cmdstan=FALSE , threads=1 ... )
sample_prior=FALSE , file=NULL , cmdstan=FALSE , threads=1 ,
stanc_options=list("O1") , ... )
}
%- maybe also 'usage' for other objects documented here.
\arguments{
Expand Down
71 changes: 71 additions & 0 deletions tests/rethinking_tests/test_stanc_O1.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# test --O1


N <- 1e4
x <- rnorm(N)
m <- 1 + rpois(N,2)
y <- rbinom( N , size=m , prob=inv_logit( -3 + x ) )
dat <- list( y=y , x=x , m=m )

N_id <- round(N/10)
id <- rep(1:N_id,each=10)
a <- rnorm(N_id,0,1.5)
y <- rbinom( N , size=m , prob=inv_logit( -3 + a[id] + x ) )
dat$id <- id
dat$y <- y

system.time(
mO1 <- ulam(
alist(
y ~ binomial_logit( m , logit_p ),
logit_p <- a + z[id]*tau + b*x,
a ~ normal(0,1.5),
b ~ normal(0,0.5),
z[id] ~ normal(0,1),
tau ~ exponential(1)
) , data=dat ,
cmdstan=TRUE , chains=4 , threads=1 , cores=4 , refresh=1000 , stanc_options = list("O1") ) )

system.time(
mO0 <- ulam(
alist(
y ~ binomial_logit( m , logit_p ),
logit_p <- a + z[id]*tau + b*x,
a ~ normal(0,1.5),
b ~ normal(0,0.5),
z[id] ~ normal(0,1),
tau ~ exponential(1)
) , data=dat ,
cmdstan=TRUE , chains=4 , threads=1 , cores=4 , refresh=1000 , stanc_options = list("O0") ) )

# slopes

dat$N_id <- N_id
system.time(
mO1s <- ulam(
alist(
y ~ binomial_logit( m , logit_p ),
logit_p <- a + v[id,1] + (b+v[id,2])*x,
transpars> matrix[N_id,2]:v <- compose_noncentered( tau , LR , z ),
a ~ normal(0,1.5),
b ~ normal(0,0.5),
matrix[2,N_id]:z ~ normal(0,1),
vector[2]:tau ~ dexp(1),
cholesky_factor_corr[2]:LR ~ lkj_corr_cholesky( 2 )
) , data=dat ,
cmdstan=TRUE , chains=4 , threads=1 , cores=4 , refresh=1000 , stanc_options = list("O1") ) )

system.time(
mO0s <- ulam(
alist(
y ~ binomial_logit( m , logit_p ),
logit_p <- a + v[id,1] + (b+v[id,2])*x,
transpars> matrix[N_id,2]:v <- compose_noncentered( tau , LR , z ),
a ~ normal(0,1.5),
b ~ normal(0,0.5),
matrix[2,N_id]:z ~ normal(0,1),
vector[2]:tau ~ dexp(1),
cholesky_factor_corr[2]:LR ~ lkj_corr_cholesky( 2 )
) , data=dat ,
cmdstan=TRUE , chains=4 , threads=1 , cores=4 , refresh=1000 , stanc_options = list("O0") ) )

0 comments on commit 61765ff

Please sign in to comment.