Skip to content

Commit

Permalink
init
Browse files Browse the repository at this point in the history
  • Loading branch information
LucasKook committed Feb 14, 2024
1 parent 7f46cab commit d5ef687
Show file tree
Hide file tree
Showing 20 changed files with 1,111 additions and 674 deletions.
4 changes: 4 additions & 0 deletions .Rbuildignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
^.*\.Rproj$
^\.Rproj\.user$
^\.github$
^inst$
1 change: 1 addition & 0 deletions .github/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
*.html
29 changes: 29 additions & 0 deletions .github/workflows/R-CMD-check.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Workflow derived from https://github.com/r-lib/actions/tree/v2/examples
# Need help debugging build failures? Start at https://github.com/r-lib/actions#where-to-find-help
on:
push:
branches: [main, master]
pull_request:
branches: [main, master]

name: R-CMD-check

jobs:
R-CMD-check:
runs-on: ubuntu-latest
env:
GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }}
R_KEEP_PKG_SOURCE: yes
steps:
- uses: actions/checkout@v3

- uses: r-lib/actions/setup-r@v2
with:
use-public-rspm: true

- uses: r-lib/actions/setup-r-dependencies@v2
with:
extra-packages: any::rcmdcheck
needs: check

- uses: r-lib/actions/check-r-package@v2
15 changes: 15 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
.Rproj.user
.Rhistory
.RData
.Ruserdata
*.Rproj
*.csv
*.pdf
inst/data/
inst/ignore/
mimic
*.Rout
*.bib
*.rds
*.sh

18 changes: 18 additions & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
Package: comet
Type: Package
Title: Covariance Measure Tests for Conditional Independence
Version: 0.0-1
Authors@R: person("Lucas", "Kook", email = "[email protected]",
role = c("aut", "cre"))
Description: Covariance measure tests for conditional independence testing
against conditional covariance and nonlinear conditional mean alternatives.
Contains versions of the generalised covariance measure test (Shah and Peters,
2020, <doi:TODO>) and projected covariance measure test (Lundborg et al., 2023,
<doi:TODO>). Applications can be found in Kook and Lundborg (2024, <doi:>).
Imports: mlt, sandwich, ranger, glmnet
License: GPL-3
Encoding: UTF-8
RoxygenNote: 7.2.3
Suggests:
testthat (>= 3.0.0), tram, ggplot2, tidyr, ggpubr, dplyr
Config/testthat/edition: 3
674 changes: 0 additions & 674 deletions LICENSE

This file was deleted.

13 changes: 13 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Generated by roxygen2: do not edit by hand

S3method(plot,gcm)
S3method(plot,pcm)
export(gcm)
export(pcm)
importFrom(glmnet,cv.glmnet)
importFrom(ranger,ranger)
importFrom(stats,model.frame)
importFrom(stats,model.response)
importFrom(stats,pnorm)
importFrom(stats,predict)
importFrom(stats,terms)
166 changes: 166 additions & 0 deletions R/gcm.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
#' GCM using random forests
#'
#' Regresses `Y` on `Z` and every column of `X` on `Z` via `pcm_ranger()`,
#' then computes the test statistic and p-value from the two sets of
#' residuals.
#'
#' @param Y Response
#' @param X Covariates; each column is regressed on `Z` separately
#' @param Z Covariates used as predictors in all regressions
#' @param alternative Alternative; one of `"two.sided"`, `"less"` or
#'   `"greater"`
#' @param ... Additional arguments forwarded to `pcm_ranger()`
#'
#' @return Object of class `c("gcm", "htest")` holding the statistic, the
#'   p-value and the residuals `rY` and `rX`
#' @export
#'
#' @examples
#' X <- matrix(rnorm(2e3), ncol = 2)
#' colnames(X) <- c("X1", "X2")
#' Z <- matrix(rnorm(2e3), ncol = 2)
#' colnames(Z) <- c("Z1", "Z2")
#' Y <- rnorm(1e3) # X[, 2] + Z[, 2] + rnorm(1e3)
#' (gcm1 <- gcm(Y, X, Z))
#' plot(gcm1)
#'
gcm <- function(Y, X, Z, alternative = c("two.sided", "less", "greater"), ...) {
  alternative <- match.arg(alternative)
  dots <- list(...)

  # Regressions on Z: the response first, then one fit per column of X.
  fit_y <- do.call("pcm_ranger", c(list(y = Y, x = Z), dots))
  fits_x <- apply(as.data.frame(X), 2, function(x_col) {
    do.call("pcm_ranger", c(list(y = x_col, x = Z), dots))
  })

  # Residuals of Y and of every column of X given Z.
  res_y <- Y - predict(fit_y, data = Z)
  fitted_x <- lapply(fits_x, predict.pcm_ranger, data = Z)
  res_x <- X - do.call("cbind", fitted_x)

  stat <- .gcm(res_y, res_x)
  null_value <- c("E[cov(Y, X | Z)]" = "0")

  out <- list(
    statistic = c("Z" = stat),
    p.value = .compute_normal_pval(stat, alternative),
    hypothesis = null_value,
    null.value = null_value,
    alternative = alternative,
    method = "Generalized covariance measure test",
    data.name = deparse(match.call(), width.cutoff = 80),
    rY = res_y,
    rX = res_x
  )
  class(out) <- c("gcm", "htest")
  out
}

# Helpers -----------------------------------------------------------------

# Compute the GCM test statistic from two sets of residuals.
#
# r: residuals of the response (vector or matrix with `nn` rows).
# e: residuals of the covariates (vector or matrix with `nn` rows).
#
# If both inputs have a single column, the result is the studentised mean
# of the residual products. Otherwise the result is a quadratic form in the
# column sums of the residual products, carrying a `df` attribute equal to
# `NCOL(r) * NCOL(e)`; `.compute_normal_pval()` uses that attribute to pick
# the chi-squared reference distribution.
.gcm <- function(r, e) {
  dR <- NCOL(r)
  dE <- NCOL(e)
  nn <- NROW(r)
  if (dR > 1 || dE > 1) {
    R_mat <- matrix(r, nrow = nn, ncol = dE) * e
    sigma <- crossprod(R_mat) / nn - tcrossprod(colMeans(R_mat))
    eig <- eigen(sigma)
    if (min(eig$values) < .Machine$double.eps) {
      warning("`vcov` of test statistic is not invertible")
    }
    # `diag()` applied to a length-one numeric builds an identity matrix of
    # that size instead of a 1 x 1 matrix holding the value, so the dimension
    # is passed explicitly to keep the 1 x 1 case correct.
    inv_sqrt_vals <- diag(eig$values^(-1/2), nrow = length(eig$values))
    siginvhalf <- eig$vectors %*% inv_sqrt_vals %*% t(eig$vectors)
    tstat <- siginvhalf %*% colSums(R_mat) / sqrt(nn)
    stat <- structure(sum(tstat^2), df = dR * dE)
  } else {
    R <- r * e
    R.sq <- R^2
    meanR <- mean(R)
    stat <- sqrt(nn) * meanR / sqrt(mean(R.sq) - meanR^2)
  }
  stat
}

# Turn a test statistic into a p-value.
#
# stat: the statistic; if it carries a `df` attribute, the upper tail of a
#   chi-squared distribution with `df` degrees of freedom is returned and
#   `alternative` is ignored.
# alternative: "two.sided", "greater" or "less"; used with the standard
#   normal distribution when `stat` has no `df` attribute.
#
# The one-sided p-values keep the sign of `stat`: "greater" is the upper
# tail probability and "less" the lower tail probability.
#' @importFrom stats pnorm
.compute_normal_pval <- function(stat, alternative) {
  df <- attr(stat, "df")
  if (!is.null(df)) {
    return(stats::pchisq(stat, df = df, lower.tail = FALSE))
  }
  switch(
    alternative,
    "two.sided" = 2 * stats::pnorm(-abs(stat)),
    "greater" = stats::pnorm(stat, lower.tail = FALSE),
    "less" = stats::pnorm(stat)
  )
}

# Drop the first column of `x` when every entry in it equals 1 (as in an
# intercept column); otherwise return `x` unchanged. The result stays
# two-dimensional even if only one column (or none) remains.
.rm_int <- function(x) {
  first_is_ones <- all(x[, 1] == 1)
  if (!first_is_ones) {
    return(x)
  }
  x[, -1L, drop = FALSE]
}

# Split the term labels of `formula` into groups.
#
# Returns NULL for a NULL formula, otherwise a list with
#   all:      every term label (labels containing "|" wrapped in parentheses),
#   me:       labels without ":",
#   ie:       labels containing ":",
#   response: the first variable name found in the formula,
#   terms:    the `terms` object,
#   fml:      the formula itself.
#' @importFrom stats terms
.get_terms <- function(formula) {
  if (is.null(formula)) {
    return(NULL)
  }
  parsed <- stats::terms(formula)
  labels <- attr(parsed, "term.labels")

  has_bar <- grepl("|", labels, fixed = TRUE)
  labels[has_bar] <- paste0("(", labels[has_bar], ")")

  is_interaction <- grepl(":", labels)
  list(
    all = labels,
    me = labels[!is_interaction],
    ie = labels[is_interaction],
    response = all.vars(formula)[1],
    terms = parsed,
    fml = formula
  )
}

# Ranger ------------------------------------------------------------------

# Fit a random forest for `formula` on `data` and attach the data, the
# response and a factor flag to the result (class "ranger").
#
# A factor response is stored as its treatment-contrast indicator matrix
# with the intercept column removed, and the forest is fitted with
# `probability = TRUE`. If the formula has no term labels without ":",
# no forest is fitted; the returned object then only stores the mean of the
# response (column means for a factor response) next to the attached fields.
# `...` is passed on to `ranger::ranger()`.
#' @importFrom stats model.response model.frame
.ranger <- function(formula, data, ...) {
  response <- stats::model.response(stats::model.frame(formula, data))
  is_factor <- is.factor(response)
  tms <- .get_terms(formula)

  if (is_factor) {
    indicators <- stats::model.matrix(
      ~ response,
      contrasts.arg = list("response" = "contr.treatment")
    )
    resp <- .rm_int(indicators)
  } else {
    resp <- response
  }
  tmp <- list(data = data, response = resp, is_factor = is_factor)

  # Nothing to regress on: fall back to the (column) mean of the response.
  if (identical(tms$me, character(0))) {
    avg <- if (is_factor) {
      base::colMeans(resp)
    } else {
      mean(as.numeric(response))
    }
    return(structure(c(list(mean = avg), tmp), class = "ranger"))
  }

  ret <- ranger::ranger(formula, data, probability = is_factor, ...)
  structure(c(ret, tmp), class = "ranger")
}

# Residuals for objects created by `.ranger()`.
#
# object:  fitted object carrying `data`, `response`, `is_factor` and, for
#          mean-only fits, `mean`.
# newdata: data to predict on; defaults to the data stored in `object`.
# newy:    responses to compare against; defaults to the stored response.
#          For factor fits a supplied `newy` is converted to its
#          treatment-contrast indicator matrix without the intercept column.
#
# Returns `newy` minus the stored mean (mean-only fits) or minus the
# predictions of the forest; for factor fits the first prediction column is
# dropped before subtracting.
#' @importFrom stats predict
residuals.ranger <- function(object, newdata = NULL, newy = NULL, ...) {
  if (is.null(newdata)) {
    newdata <- object$data
  }
  if (is.null(newy)) {
    newy <- object$response
  } else if (object$is_factor) {
    newy <- .rm_int(stats::model.matrix(
      ~ newy,
      contrasts.arg = list("newy" = "contr.treatment")
    ))
  }

  stored_mean <- object[["mean"]]
  if (!is.null(stored_mean)) {
    # With one mean per indicator column, a plain `newy - mean` would recycle
    # the means down the rows; subtract them column by column instead.
    if (is.matrix(newy) && length(stored_mean) == ncol(newy)) {
      return(sweep(newy, 2L, stored_mean, "-"))
    }
    return(newy - stored_mean)
  }

  preds <- stats::predict(object, data = newdata)$predictions
  if (object$is_factor) {
    preds <- preds[, -1]
  }
  unname(newy - preds)
}

# Diagnostics -------------------------------------------------------------

# Scatter plot of the response residuals against each column of covariate
# residuals stored in a "gcm" object, with a linear fit per column.
#
# Needs the suggested packages ggplot2, tidyr and dplyr. If any of them is
# missing, a warning is issued and NULL is returned invisibly; otherwise the
# plot is printed and returned invisibly.
#' @exportS3Method plot gcm
plot.gcm <- function(x, ...) {
  needed <- c("ggplot2", "tidyr", "dplyr")
  available <- vapply(needed, requireNamespace, logical(1), quietly = TRUE)
  if (!all(available)) {
    warning(
      "Plotting requires the following packages: ",
      paste(needed[!available], collapse = ", "),
      call. = FALSE
    )
    return(invisible(NULL))
  }

  # Local binding for the `.data` pronoun used in `aes()` below; presumably
  # there to avoid an undefined-global note from R CMD check.
  .data <- NULL
  pd <- tidyr::pivot_longer(
    data.frame(rY = x$rY, rX = unname(x$rX)),
    dplyr::starts_with("rX")
  )
  p1 <- ggplot2::ggplot(pd, ggplot2::aes(x = .data[["value"]],
                                         y = .data[["rY"]],
                                         color = .data[["name"]])) +
    ggplot2::geom_point(alpha = 0.3, show.legend = FALSE) +
    ggplot2::geom_smooth(method = "lm", se = FALSE, show.legend = FALSE) +
    ggplot2::theme_bw() +
    ggplot2::labs(x = "Residuals X | Z", y = "Residuals Y | Z")
  print(p1)
  invisible(p1)
}

# Build the model matrix for the predictors named in `preds` from `data`
# and strip the intercept column via `.rm_int()`.
.mm <- function(preds, data) {
  rhs_formula <- stats::reformulate(preds)
  design <- stats::model.matrix(rhs_formula, data = data)
  .rm_int(design)
}
Loading

0 comments on commit d5ef687

Please sign in to comment.