Skip to content

Commit

Permalink
Added more tests and support for setting consecutive seeds (#61)
Browse files Browse the repository at this point in the history
  • Loading branch information
niekdt committed Nov 4, 2022
1 parent 9135ee8 commit 3401fd2
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 0 deletions.
17 changes: 17 additions & 0 deletions R/meta-method-converged.R
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ lcMetaConverged = function(method, maxRep = Inf) {
#' @rdname lcMetaMethod-interface
setMethod('fit', 'lcMetaConverged', function(method, data, envir, verbose) {
attempt = 1L

repeat {
enter(verbose, level = verboseLevels$fine, suffix = '')
model = fit(getLcMethod(method), data = data, envir = envir, verbose = verbose)
Expand All @@ -44,6 +45,12 @@ setMethod('fit', 'lcMetaConverged', function(method, data, envir, verbose) {
return (model)
} else {
attempt = attempt + 1L
seed = sample.int(.Machine$integer.max, 1L)
set.seed(seed)
if (has_lcMethod_args(getLcMethod(method), 'seed')) {
# update fit method with new seed
method@arguments$method = update(getLcMethod(method), seed = seed, .eval = TRUE)
}

if (is.infinite(method$maxRep)) {
cat(verbose, sprintf('Method failed to converge. Retrying... attempt %d', attempt))
Expand All @@ -53,3 +60,13 @@ setMethod('fit', 'lcMetaConverged', function(method, data, envir, verbose) {
}
}
})

#' @rdname lcMetaMethod-interface
setMethod('validate', 'lcMetaConverged', function(method, data, envir = NULL, ...) {
callNextMethod()

validate_that(
has_lcMethod_args(method, 'maxRep'),
is.count(method$maxRep)
)
})
3 changes: 3 additions & 0 deletions man/lcMetaMethod-interface.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

65 changes: 65 additions & 0 deletions tests/testthat/test-meta-methods.R
Original file line number Diff line number Diff line change
@@ -1,5 +1,36 @@
method = lcMethodLMKM(Value ~ Assessment, id = 'Traj', time = 'Assessment', nClusters = 2)

setClass('lcMethodConv', contains = 'lcMethod')

lcMethodConv = function(
response = 'Value',
time = 'Assessment',
id = 'Traj',
nClusters = 1,
nAttempts = 1,
...
) {
mc = match.call.all()
mc$Class = 'lcMethodConv'
do.call(new, as.list(mc))
}

setMethod('preFit', 'lcMethodConv', function(method, data, envir, verbose) {
convAttempts <<- 0
callNextMethod()
})

setMethod('fit', 'lcMethodConv', function(method, data, envir, verbose) {
convAttempts <<- convAttempts + 1
lcModelPartition(
data = data,
response = method$response,
trajectoryAssignments = rep(1, uniqueN(data[[method$id]])),
converged = convAttempts >= method$nAttempts
)
})


test_that('specify converged', {
metaMethod = lcMetaConverged(method)
expect_s4_class(metaMethod, 'lcMetaConverged')
Expand Down Expand Up @@ -40,3 +71,37 @@ test_that('meta converged fit', {
model = latrend(metaMethod, testLongData)
})
})

test_that('meta converged fit until converged', {
metaMethod = lcMetaConverged(lcMethodConv(nAttempts = 2), maxRep = 3)

# workaround because testthat::expect_message() is failing to capture the output...
out = capture.output({
model = latrend(metaMethod, testLongData, verbose = TRUE)
}, type = 'message')
expect_match(paste0(out, collapse = '\n'), regexp = 'attempt 2')
expect_true(converged(model))
})

test_that('meta converged fit always fails', {
metaMethod = lcMetaConverged(lcMethodConv(nAttempts = 3), maxRep = 2)
expect_warning({
model = latrend(metaMethod, testLongData)
}, regexp = 'Failed to obtain converged')

expect_false(converged(model))
})

test_that('meta converged fit with seed on first attempt', {
metaMethod = lcMetaConverged(lcMethodConv(nAttempts = 1, seed = 13))
model = latrend(metaMethod, testLongData)

expect_equal(getLcMethod(model)$method$seed, 13)
})

test_that('meta converged fit different seed on second attempt', {
metaMethod = lcMetaConverged(lcMethodConv(nAttempts = 2, seed = 13))
model = latrend(metaMethod, testLongData)

expect_true(getLcMethod(model)$method$seed != 13)
})

0 comments on commit 3401fd2

Please sign in to comment.