Sample only for feature extraction (#1130)
* New branch for testing patch

* Fixes for sampled usage and docs - changed API

* Removed parameter
azimov authored Aug 27, 2024
1 parent 9eabd7d commit e98c495
Showing 8 changed files with 102 additions and 85 deletions.
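As context for the reviewed change, a hedged sketch of calling executeDiagnostics() after this commit: runOnSample is renamed to runFeatureExtractionOnSample, the sampleIdentifierExpression argument is removed, and sampleN/seed keep their previous defaults. All schema, table, and folder values below are placeholders rather than values from this repository.

# Illustrative call only; connectionDetails and cohortDefinitionSet are assumed
# to have been created beforehand (e.g. via DatabaseConnector and CohortGenerator).
executeDiagnostics(
  cohortDefinitionSet = cohortDefinitionSet,
  connectionDetails = connectionDetails,
  cdmDatabaseSchema = "main",                    # placeholder schema
  cohortDatabaseSchema = "main",                 # placeholder schema
  cohortTable = "cohort",                        # placeholder cohort table
  exportFolder = file.path(tempdir(), "export"),
  databaseId = "MyCdm",                          # placeholder database id
  runFeatureExtractionOnSample = TRUE,           # renamed from runOnSample
  sampleN = 1000,                                # default sample size
  seed = 64374                                   # default seed, kept for reproducibility
)

With this change, sampling affects only the feature-extraction (cohort characterization) step; cohort counts and the other diagnostics still run against the full cohorts.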
4 changes: 2 additions & 2 deletions DESCRIPTION
@@ -38,7 +38,7 @@ Imports:
SqlRender (>= 1.9.0),
stringr,
tidyr (>= 1.2.0),
CohortGenerator (>= 0.8.0),
CohortGenerator (>= 0.10.0),
remotes,
scales
Suggests:
@@ -61,7 +61,7 @@ License: Apache License
VignetteBuilder: knitr
URL: https://ohdsi.github.io/CohortDiagnostics, https://github.com/OHDSI/CohortDiagnostics
BugReports: https://github.com/OHDSI/CohortDiagnostics/issues
RoxygenNote: 7.2.3
RoxygenNote: 7.3.2
Encoding: UTF-8
Language: en-US
StagedInstall: no
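The functional change in DESCRIPTION is the CohortGenerator minimum moving from 0.8.0 to 0.10.0, presumably to support the cohort sample table used later in this PR (the RoxygenNote bump is generated). A small, hedged sketch for checking an existing library against the new minimum; the install_github call assumes the usual OHDSI repository layout:

# Verify the installed CohortGenerator meets the new minimum version.
if (utils::packageVersion("CohortGenerator") < "0.10.0") {
  # Update from the OHDSI GitHub repository; a recent CRAN release may also work.
  remotes::install_github("OHDSI/CohortGenerator")
}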
19 changes: 11 additions & 8 deletions R/CohortLevelDiagnostics.R
@@ -53,7 +53,7 @@ getCohortCounts <- function(connectionDetails = NULL,
)
counts <-
DatabaseConnector::querySql(connection, sql, snakeCaseToCamelCase = TRUE) %>%
tidyr::tibble()
tidyr::tibble()

if (length(cohortIds) > 0) {
cohortIdDf <- tidyr::tibble(cohortId = as.numeric(cohortIds))
@@ -97,7 +97,8 @@ computeCohortCounts <- function(connection,
cohorts,
exportFolder,
minCellCount,
databaseId) {
databaseId,
writeResult = TRUE) {
ParallelLogger::logInfo("Counting cohort records and subjects")
cohortCounts <- getCohortCounts(
connection = connection,
Expand All @@ -117,11 +118,13 @@ computeCohortCounts <- function(connection,
databaseId = databaseId
)

writeToCsv(
data = cohortCounts,
fileName = file.path(exportFolder, "cohort_count.csv"),
incremental = FALSE,
cohortId = cohorts$cohortId
)
if (writeResult) {
writeToCsv(
data = cohortCounts,
fileName = file.path(exportFolder, "cohort_count.csv"),
incremental = FALSE,
cohortId = cohorts$cohortId
)
}
return(cohortCounts)
}
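A minimal sketch of the new writeResult flag, mirroring how this PR calls computeCohortCounts() for the sampled cohorts: the counts are computed for the sample but not written to disk, so the exported cohort_count.csv keeps describing the full cohorts. Connection, schema, and table values below are placeholders.

# Count records and subjects without writing cohort_count.csv.
sampleCounts <- computeCohortCounts(
  connection = connection,                 # assumed DatabaseConnector connection
  cohortDatabaseSchema = "results",        # placeholder schema
  cohortTable = "cohort_cd_sample",        # hypothetical sample table name
  cohorts = cohortDefinitionSet,
  exportFolder = exportFolder,
  minCellCount = 5,                        # placeholder
  databaseId = "MyCdm",                    # placeholder
  writeResult = FALSE                      # new argument; default remains TRUE
)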
3 changes: 2 additions & 1 deletion R/Incremental.R
@@ -153,7 +153,7 @@ writeToCsv <- function(data, fileName, incremental = FALSE, ...) {
UseMethod("writeToCsv", data)
}


#' @noRd
writeToCsv.default <- function(data, fileName, incremental = FALSE, ...) {
colnames(data) <- SqlRender::camelCaseToSnakeCase(colnames(data))
if (incremental) {
@@ -186,6 +186,7 @@ writeToCsv.default <- function(data, fileName, incremental = FALSE, ...) {
}
}

#'@noRd
writeToCsv.tbl_Andromeda <-
function(data, fileName, incremental = FALSE, ...) {
if (incremental && file.exists(fileName)) {
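The @noRd tags added above document internal S3 methods. As a toy, self-contained illustration (not the package code) of how UseMethod() routes a writeToCsv() call based on the class of data:

# Toy generic and methods; only the dispatch mechanism matches the package.
writeToCsv <- function(data, fileName, incremental = FALSE, ...) {
  UseMethod("writeToCsv", data)                    # dispatch on class(data)
}
writeToCsv.default <- function(data, fileName, incremental = FALSE, ...) {
  message("default method: in-memory data (e.g. a data.frame)")
}
writeToCsv.tbl_Andromeda <- function(data, fileName, incremental = FALSE, ...) {
  message("tbl_Andromeda method: Andromeda's on-disk tables")
}

writeToCsv(data.frame(cohort_id = 1), "cohort_count.csv")   # -> default method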
3 changes: 2 additions & 1 deletion R/Private.R
@@ -317,8 +317,9 @@ getPrefixedTableNames <- function(tablePrefix) {
return(resultList)
}

#' @noRd

#' Internal utility function for logging execution of variables
#' @noRd
timeExecution <- function(exportFolder,
taskName,
cohortIds = NULL,
132 changes: 75 additions & 57 deletions R/RunDiagnostics.R
@@ -136,23 +136,18 @@ getDefaultCovariateSettings <- function() {
#' @param incremental Create only cohort diagnostics that haven't been created before?
#' @param incrementalFolder If \code{incremental = TRUE}, specify a folder where records are kept
#' of which cohort diagnostics has been executed.
#' @param runOnSample Logical. If TRUE, the function will operate on a sample of the data.
#' @param runFeatureExtractionOnSample Logical. If TRUE, the function will operate on a sample of the data.
#' Default is FALSE, meaning the function will operate on the full data set.
#'
#' @param sampleN Integer. The number of records to include in the sample if runOnSample is TRUE.
#' Default is 1000. Ignored if runOnSample is FALSE.
#' @param sampleN Integer. The number of records to include in the sample if runFeatureExtractionOnSample is TRUE.
#' Default is 1000. Ignored if runFeatureExtractionOnSample is FALSE.
#'
#' @param seed Integer. The seed for the random number generator used to create the sample.
#' This ensures that the same sample can be drawn again in future runs. Default is 64374.
#'
#' @param seedArgs List. Additional arguments to pass to the sampling function.
#' This can be used to control aspects of the sampling process beyond the seed and sample size.
#'
#' @param sampleIdentifierExpression Character. An expression that generates unique identifiers for each sample.
#' This expression can use the variables 'cohortId' and 'seed'.
#' Default is "cohortId * 1000 + seed", which ensures unique identifiers
#' as long as there are fewer than 1000 cohorts.

#' @examples
#' \dontrun{
#' # Load cohorts (assumes that they have already been instantiated)
@@ -228,11 +223,10 @@ executeDiagnostics <- function(cohortDefinitionSet,
irWashoutPeriod = 0,
incremental = FALSE,
incrementalFolder = file.path(exportFolder, "incremental"),
runOnSample = FALSE,
runFeatureExtractionOnSample = FALSE,
sampleN = 1000,
seed = 64374,
seedArgs = NULL,
sampleIdentifierExpression = "cohortId * 1000 + seed") {
seedArgs = NULL) {
# collect arguments that were passed to cohort diagnostics at initiation
callingArgs <- formals(executeDiagnostics)
callingArgsJson <-
@@ -250,7 +244,7 @@ executeDiagnostics <- function(cohortDefinitionSet,
incremental = callingArgs$incremental,
temporalCovariateSettings = callingArgs$temporalCovariateSettings
) %>%
RJSONIO::toJSON(digits = 23, pretty = TRUE)
RJSONIO::toJSON(digits = 23, pretty = TRUE)

exportFolder <- normalizePath(exportFolder, mustWork = FALSE)
incrementalFolder <- normalizePath(incrementalFolder, mustWork = FALSE)
@@ -279,25 +273,25 @@ executeDiagnostics <- function(cohortDefinitionSet,
errorMessage <- checkmate::makeAssertCollection()
checkmate::assertList(cohortTableNames, null.ok = FALSE, types = "character", add = errorMessage, names = "named")
checkmate::assertNames(names(cohortTableNames),
must.include = c(
"cohortTable",
"cohortInclusionTable",
"cohortInclusionResultTable",
"cohortInclusionStatsTable",
"cohortSummaryStatsTable",
"cohortCensorStatsTable"
),
add = errorMessage
must.include = c(
"cohortTable",
"cohortInclusionTable",
"cohortInclusionResultTable",
"cohortInclusionStatsTable",
"cohortSummaryStatsTable",
"cohortCensorStatsTable"
),
add = errorMessage
)
checkmate::assertDataFrame(cohortDefinitionSet, add = errorMessage)
checkmate::assertNames(names(cohortDefinitionSet),
must.include = c(
"json",
"cohortId",
"cohortName",
"sql"
),
add = errorMessage
must.include = c(
"json",
"cohortId",
"cohortName",
"sql"
),
add = errorMessage
)

cohortTable <- cohortTableNames$cohortTable
@@ -474,17 +468,17 @@ executeDiagnostics <- function(cohortDefinitionSet,
sort()
cohortTableColumnNamesExpected <-
getResultsDataModelSpecifications() %>%
dplyr::filter(.data$tableName == "cohort") %>%
dplyr::pull(.data$columnName) %>%
SqlRender::snakeCaseToCamelCase() %>%
sort()
dplyr::filter(.data$tableName == "cohort") %>%
dplyr::pull(.data$columnName) %>%
SqlRender::snakeCaseToCamelCase() %>%
sort()
cohortTableColumnNamesRequired <-
getResultsDataModelSpecifications() %>%
dplyr::filter(.data$tableName == "cohort") %>%
dplyr::filter(.data$isRequired == "Yes") %>%
dplyr::pull(.data$columnName) %>%
SqlRender::snakeCaseToCamelCase() %>%
sort()
dplyr::filter(.data$tableName == "cohort") %>%
dplyr::filter(.data$isRequired == "Yes") %>%
dplyr::pull(.data$columnName) %>%
SqlRender::snakeCaseToCamelCase() %>%
sort()

expectedButNotObsevered <-
setdiff(x = cohortTableColumnNamesExpected, y = cohortTableColumnNamesObserved)
@@ -549,23 +543,6 @@ executeDiagnostics <- function(cohortDefinitionSet,
}
}

if (runOnSample & !isTRUE(attr(cohortDefinitionSet, "isSampledCohortDefinition"))) {
cohortDefinitionSet <-
CohortGenerator::sampleCohortDefinitionSet(
connection = connection,
cohortDefinitionSet = cohortDefinitionSet,
tempEmulationSchema = tempEmulationSchema,
cohortDatabaseSchema = cohortDatabaseSchema,
cohortTableNames = cohortTableNames,
n = sampleN,
seed = seed,
seedArgs = seedArgs,
identifierExpression = sampleIdentifierExpression,
incremental = incremental,
incrementalFolder = incrementalFolder
)
}

## CDM source information----
timeExecution(
exportFolder,
@@ -871,18 +848,59 @@ executeDiagnostics <- function(cohortDefinitionSet,
cohortIds,
parent = "executeDiagnostics",
expr = {

feCohortDefinitionSet <- cohortDefinitionSet
feCohortTable <- cohortTable
feCohortCounts <- cohortCounts

if (runFeatureExtractionOnSample & !isTRUE(attr(cohortDefinitionSet, "isSampledCohortDefinition"))) {
cohortTableNames$cohortSampleTable <- paste0(cohortTableNames$cohortTable, "_cd_sample")
CohortGenerator::createCohortTables(connection = connection,
cohortTableNames = cohortTableNames,
cohortDatabaseSchema = cohortDatabaseSchema,
incremental = TRUE)

feCohortTable <- cohortTableNames$cohortSampleTable
feCohortDefinitionSet <-
CohortGenerator::sampleCohortDefinitionSet(
connection = connection,
cohortDefinitionSet = cohortDefinitionSet,
tempEmulationSchema = tempEmulationSchema,
cohortDatabaseSchema = cohortDatabaseSchema,
cohortTableNames = cohortTableNames,
n = sampleN,
seed = seed,
seedArgs = seedArgs,
identifierExpression = "cohortId",
incremental = incremental,
incrementalFolder = incrementalFolder
)

feCohortCounts <- computeCohortCounts(
connection = connection,
cohortDatabaseSchema = cohortDatabaseSchema,
cohortTable = cohortTableNames$cohortSampleTable,
cohorts = feCohortDefinitionSet,
exportFolder = exportFolder,
minCellCount = minCellCount,
databaseId = databaseId,
writeResult = FALSE
)
}


executeCohortCharacterization(
connection = connection,
databaseId = databaseId,
exportFolder = exportFolder,
cdmDatabaseSchema = cdmDatabaseSchema,
cohortDatabaseSchema = cohortDatabaseSchema,
cohortTable = cohortTable,
cohortTable = feCohortTable,
covariateSettings = temporalCovariateSettings,
tempEmulationSchema = tempEmulationSchema,
cdmVersion = cdmVersion,
cohorts = cohortDefinitionSet,
cohortCounts = cohortCounts,
cohorts = feCohortDefinitionSet,
cohortCounts = feCohortCounts,
minCellCount = minCellCount,
instantiatedCohorts = instantiatedCohorts,
incremental = incremental,
16 changes: 5 additions & 11 deletions man/executeDiagnostics.Rd

Some generated files are not rendered by default.

4 changes: 2 additions & 2 deletions tests/testthat/test-1-ResultsDataModel.R
@@ -123,7 +123,7 @@ VALUES ('Synthea','Synthea','OHDSI Community','SyntheaTM is a Synthetic Patient
incremental = TRUE,
incrementalFolder = file.path(folder, "incremental"),
temporalCovariateSettings = temporalCovariateSettings,
runOnSample = TRUE
runFeatureExtractionOnSample = TRUE
)
},
"CDM Source table has more than one record while only one is expected."
@@ -149,7 +149,7 @@ VALUES ('Synthea','Synthea','OHDSI Community','SyntheaTM is a Synthetic Patient
incremental = TRUE,
incrementalFolder = file.path(folder, "incremental"),
temporalCovariateSettings = temporalCovariateSettings,
runOnSample = TRUE
runFeatureExtractionOnSample = TRUE
)
}

6 changes: 3 additions & 3 deletions tests/testthat/test-2-againstCdm.R
@@ -45,7 +45,7 @@ test_that("Cohort diagnostics in incremental mode", {
incremental = TRUE,
incrementalFolder = file.path(folder, "incremental"),
temporalCovariateSettings = temporalCovariateSettings,
runOnSample = TRUE
runFeatureExtractionOnSample = TRUE
)
)

@@ -76,7 +76,7 @@ test_that("Cohort diagnostics in incremental mode", {
incremental = TRUE,
incrementalFolder = file.path(folder, "incremental"),
temporalCovariateSettings = temporalCovariateSettings,
runOnSample = TRUE
runFeatureExtractionOnSample = TRUE
)
)
# generate sqlite file
@@ -123,7 +123,7 @@ test_that("Cohort diagnostics in incremental mode", {
incremental = FALSE,
incrementalFolder = file.path(folder, "incremental"),
temporalCovariateSettings = temporalCovariateSettings,
runOnSample = TRUE
runFeatureExtractionOnSample = TRUE
)
})

