Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove seed argument #431

Merged
merged 5 commits into from
Oct 4, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions R/draws.R
Original file line number Diff line number Diff line change
Expand Up @@ -478,10 +478,6 @@ extract_data_nmar_as_na <- function(longdata) {
#' @export
draws.bayes <- function(data, data_ice = NULL, vars, method, ncores = 1, quiet = FALSE) {

if (!is.na(method$seed)) {
set.seed(method$seed)
}

longdata <- longDataConstructor$new(data, vars)
longdata$set_strategies(data_ice)

Expand Down
9 changes: 1 addition & 8 deletions R/mcmc.R
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ fit_mcmc <- function(

n_imputations <- method$n_samples
burn_in <- method$burn_in
seed <- method$seed
burn_between <- method$burn_between
same_cov <- method$same_cov

Expand Down Expand Up @@ -114,13 +113,7 @@ fit_mcmc <- function(
)
)

assert_that(
!is.na(seed),
!is.null(seed),
is.numeric(seed),
msg = "mcmc seed is invalid"
)
sampling_args$seed <- seed
sampling_args$seed <- sample.int(.Machine$integer.max, 1)

stan_fit <- record({
do.call(sampling, sampling_args)
Expand Down
17 changes: 10 additions & 7 deletions R/methods.R
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,7 @@
#' @param type a character string that specifies the resampling method used to perform inference
#' when a conditional mean imputation approach (set via `method_condmean()`) is used. Must be one of `"bootstrap"` or `"jackknife"`.
#'
#' @param seed a numeric that specifies the seed to be used in the call to Stan. This
#' argument is passed onto the `seed` argument of [rstan::sampling()]. Note that
#' this is only required for `method_bayes()`, for all other methods you can achieve
#' reproducible results by setting the seed via `set.seed()`. See details.
#' @param seed deprecated. Please use `set.seed()` instead.
#'
#' @details
#'
Expand Down Expand Up @@ -93,14 +90,20 @@ method_bayes <- function(
burn_between = 50,
same_cov = TRUE,
n_samples = 20,
seed = sample.int(.Machine$integer.max, 1)
seed = NULL
) {
if (!is.null(seed)) {
warning(paste0(
"The `seed` argument to `method_bayes()` has been deprecated",
" please use `set.seed()` instead"
))
gowerc marked this conversation as resolved.
Show resolved Hide resolved
}

x <- list(
burn_in = burn_in,
burn_between = burn_between,
same_cov = same_cov,
n_samples = n_samples,
seed = seed
n_samples = n_samples
)
return(as_class(x, c("method", "bayes")))
}
Expand Down
4 changes: 2 additions & 2 deletions data-raw/create_print_test_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -107,14 +107,14 @@ set.seed(413)
dobj <- get_data(40)

suppressWarnings({
set.seet(859)
gowerc marked this conversation as resolved.
Show resolved Hide resolved
drawobj_b <- draws(
data = dobj$dat,
data_ice = dobj$dat_ice,
vars = dobj$vars,
method = method_bayes(
n_samples = 50,
burn_between = 1,
seed = 859
burn_between = 1
)
)
})
Expand Down
7 changes: 2 additions & 5 deletions man/method.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion tests/testthat/_snaps/print.md
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,6 @@
burn_between: 1
same_cov: TRUE
n_samples: 50
seed: 859


---
Expand Down
46 changes: 13 additions & 33 deletions tests/testthat/test-mcmc.R
Original file line number Diff line number Diff line change
Expand Up @@ -529,8 +529,7 @@ test_that("fit_mcmc can recover known values with same_cov = FALSE", {
n_samples = 250,
burn_in = 100,
burn_between = 3,
same_cov = FALSE,
seed = 8931
same_cov = FALSE
)

### No missingness
Expand Down Expand Up @@ -604,36 +603,17 @@ test_that("fit_mcmc can recover known values with same_cov = FALSE", {
})


test_that("invalid seed throws an error", {

set.seed(301)
sigma <- as_vcov(c(6, 4, 4), c(0.5, 0.2, 0.3))
dat <- get_sim_data(50, sigma)

dat_ice <- dat %>%
group_by(id) %>%
arrange(desc(visit)) %>%
slice(1) %>%
ungroup() %>%
mutate(strategy = "MAR")

vars <- set_vars(
visit = "visit",
subjid = "id",
group = "group",
covariates = "sex",
strategy = "strategy",
outcome = "outcome"
)

expect_error(
draws(
dat,
dat_ice,
vars,
method_bayes(n_samples = 2, seed = NA),
quiet = TRUE
),
regexp = "mcmc seed is invalid"
test_that("seed argument to method_bayes is depreciated", {
gowerc marked this conversation as resolved.
Show resolved Hide resolved
expect_warning(
{
method <- method_bayes(
n_samples = 250,
burn_in = 100,
burn_between = 3,
same_cov = FALSE,
seed = 1234
)
},
regexp = "seed.*deprecated"
)
})
3 changes: 1 addition & 2 deletions tests/testthat/test-print.R
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,7 @@ test_that("print - bayesian", {
vars = dobj$vars,
method = method_bayes(
n_samples = 50,
burn_between = 1,
seed = 859
burn_between = 1
),
quiet = TRUE
)
Expand Down
9 changes: 4 additions & 5 deletions tests/testthat/test-reproducibility.R
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ test_that("Results are Reproducible", {



test_that("bayes - seed argument works without set.seed", {
test_that("bayes - set.seed produces identical results", {

sigma <- as_vcov(c(2, 1, 0.7), c(0.5, 0.3, 0.2))
dat <- get_sim_data(200, sigma, trt = 8) %>%
Expand All @@ -111,17 +111,16 @@ test_that("bayes - seed argument works without set.seed", {
)

meth <- method_bayes(
seed = 1482,
burn_between = 5,
burn_in = 200,
n_samples = 2
n_samples = 6
)

set.seed(49812)
set.seed(1234)
x <- suppressWarnings({
draws(dat, dat_ice, vars, meth, quiet = TRUE)
})
set.seed(2414)
set.seed(1234)
y <- suppressWarnings({
draws(dat, dat_ice, vars, meth, quiet = TRUE)
})
Expand Down
2 changes: 1 addition & 1 deletion vignettes/advanced.html
Original file line number Diff line number Diff line change
Expand Up @@ -714,7 +714,7 @@ <h1><span class="header-section-number">6</span> Custom imputation strategies</h
<span id="cb6-17"><a href="#cb6-17" tabindex="-1"></a><span class="co">#&gt; pars &lt;- list(mu = mu, sigma = sigma)</span></span>
<span id="cb6-18"><a href="#cb6-18" tabindex="-1"></a><span class="co">#&gt; return(pars)</span></span>
<span id="cb6-19"><a href="#cb6-19" tabindex="-1"></a><span class="co">#&gt; }</span></span>
<span id="cb6-20"><a href="#cb6-20" tabindex="-1"></a><span class="co">#&gt; &lt;bytecode: 0x7ff37e6af218&gt;</span></span>
<span id="cb6-20"><a href="#cb6-20" tabindex="-1"></a><span class="co">#&gt; &lt;bytecode: 0x7f86686ebac0&gt;</span></span>
<span id="cb6-21"><a href="#cb6-21" tabindex="-1"></a><span class="co">#&gt; &lt;environment: namespace:rbmi&gt;</span></span></code></pre></div>
<p>To further illustrate this for a simple example, assume that a new strategy is to be implemented as follows:
- The marginal mean of the imputation distribution is equal to the marginal mean trajectory for the subject according to their assigned group and covariates up to the ICE.
Expand Down
6 changes: 2 additions & 4 deletions vignettes/quickstart.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,7 @@ vars <- set_vars(
method <- method_bayes(
burn_in = 200,
burn_between = 5,
n_samples = 150,
seed = 675442751
n_samples = 150
)

# Create samples for the imputation parameters by running the draws() function
Expand Down Expand Up @@ -347,8 +346,7 @@ vars <- set_vars(
method <- method_bayes(
burn_in = 200,
burn_between = 5,
n_samples = 150,
seed = 675442751
n_samples = 150
)


Expand Down
Loading
Loading