
Commit

kNNDM with train/test split
JanLinnenbrink committed Jun 10, 2024
1 parent ed4987d commit 5af0a17
Showing 3 changed files with 113 additions and 19 deletions.
105 changes: 90 additions & 15 deletions R/knndm.R
@@ -9,6 +9,8 @@
#' @param predpoints sf or sfc point object, or data.frame if space = "feature". Contains the target prediction points. Optional; alternative to modeldomain (see Details).
#' @param space character. Either "geographical" or "feature".
#' @param k integer. Number of folds desired for CV. Defaults to 10.
#' @param prop_test numeric. Proportion of the data to assign to the test set (train/test split).
#' Defaults to NULL, in which case k-fold cross-validation is performed instead of a train/test split.
#' @param maxp numeric. Maximum fold size allowed, defaults to 0.5, i.e. a single fold can hold a maximum of half of the training points.
#' @param clustering character. Possible values include "hierarchical" and "kmeans". See details.
#' @param linkf character. Only relevant if clustering = "hierarchical". Link function for agglomerative hierarchical clustering.
@@ -19,7 +21,6 @@
#' Only required if modeldomain is used instead of predpoints.
#' @param useMD boolean. Only for `space`=feature: shall the Mahalanobis distance be calculated instead of Euclidean?
#' Only works with numerical variables.
#'
#' @return An object of class \emph{knndm} consisting of a list of eight elements:
#' indx_train, indx_test (indices of the observations to use as
#' training/test data in each kNNDM CV iteration), Gij (distances for
@@ -97,7 +98,7 @@
#' plot(train_points, add = TRUE, col = "red")
#'
#' # Run kNNDM for the whole domain, here the prediction points are known.
#' knndm_folds <- knndm(train_points, predpoints = pred_points, k = 5)
#' knndm_folds <- knndm(train_points, predpoints = pred_points, k = 10)
#' knndm_folds
#' plot(knndm_folds)
#' plot(knndm_folds, type = "simple") # For more accessible legend labels
@@ -125,15 +126,20 @@
#' plot(train_points, add = TRUE, col = "red")
#'
#' # Run kNNDM for the whole domain, here the prediction points are known.
#' knndm_folds <- knndm(train_points, predpoints = pred_points, k = 5)
#' # First, split the data into a training and a test set
#' knndm_test <- knndm(train_points, predpoints = pred_points, k = 2, prop_test = 0.3)
#' # Then split the training data into CV folds
#' train_points_split <- train_points[knndm_test$clusters == 1,]
#' knndm_folds <- knndm(train_points_split, predpoints = pred_points, k = 10)
#' knndm_folds
#' plot(knndm_folds)
#' plot(knndm_folds, type = "simple") # For more accessible legend labels
#' plot(knndm_folds, type = "simple", stat = "density") # To visualize densities rather than ECDFs
#' folds <- as.character(knndm_folds$clusters)
#' ggplot() +
#' geom_sf(data = simarea, alpha = 0) +
#' geom_sf(data = train_points, aes(col = folds))
#' geom_sf(data = train_points_split, aes(col = folds)) +
#' geom_sf(data = train_points[knndm_test$clusters == 2,], colour = "black")
#'}
#' ########################################################################
#' # Example 3: Real- world example; using a modeldomain instead of previously
@@ -198,15 +204,17 @@
#'
#'knndm_folds <- knndm(trainDat[,predictors], modeldomain = predictors_sp, space = "feature",
#' clustering="kmeans", k=4, maxp=0.8)
#'plot(knndm_folds)
#'plot(knndm_folds, type="simple")
#'
#'}
knndm <- function(tpoints, modeldomain = NULL, predpoints = NULL,
space = "geographical",
k = 10, maxp = 0.5,
k = 10, maxp = 0.5, prop_test=NULL,
clustering = "hierarchical", linkf = "ward.D2",
samplesize = 1000, sampling = "regular", useMD=FALSE){



# create sample points from modeldomain
if(is.null(predpoints)&!is.null(modeldomain)){

@@ -303,7 +311,13 @@ knndm <- function(tpoints, modeldomain = NULL, predpoints = NULL,
}



# if train/test split desired: adjust parameters
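# a train/test split is handled as a special 2-fold case: fold 1 becomes the
# training set and fold 2 the test set, and maxp is relaxed so the training
# fold may hold up to 90% of the points (unless a larger maxp was supplied)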
if(!is.null(prop_test)) {
k <- 2
if(maxp <= 0.5) {
maxp <- 0.9
}
}


# kNNDM in the geographical / feature space
@@ -312,14 +326,14 @@
# prior checks
check_knndm_geo(tpoints, predpoints, space, k, maxp, clustering, islonglat)
# kNNDM in geographical space
knndm_res <- knndm_geo(tpoints, predpoints, k, maxp, clustering, linkf, islonglat)
knndm_res <- knndm_geo(tpoints, predpoints, k, maxp, clustering, linkf, islonglat, prop_test)

} else if (isTRUE(space == "feature")) {

# prior checks
check_knndm_feature(tpoints, predpoints, space, k, maxp, clustering, islonglat, catVars,useMD)
# kNNDM in feature space
knndm_res <- knndm_feature(tpoints, predpoints, k, maxp, clustering, linkf, catVars, useMD)
knndm_res <- knndm_feature(tpoints, predpoints, k, maxp, clustering, linkf, catVars, useMD, prop_test)

}

@@ -379,7 +393,7 @@ check_knndm_feature <- function(tpoints, predpoints, space, k, maxp, clustering,


# kNNDM in the geographical space
knndm_geo <- function(tpoints, predpoints, k, maxp, clustering, linkf, islonglat){
knndm_geo <- function(tpoints, predpoints, k, maxp, clustering, linkf, islonglat, prop_test){

# Gj and Gij calculation
tcoords <- sf::st_coordinates(tpoints)[,1:2]
@@ -401,7 +415,13 @@ knndm_geo <- function(tpoints, predpoints, k, maxp, clustering, linkf, islonglat
testks <- suppressWarnings(stats::ks.test(Gj, Gij, alternative = "great"))
if(testks$p.value >= 0.05){

clust <- sample(rep(1:k, ceiling(nrow(tpoints)/k)), size = nrow(tpoints), replace=F)

if(!is.null(prop_test)) {
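# when prop_test is given, the initial random assignment already draws fold 1
# (train) with probability 1 - prop_test and fold 2 (test) with probability prop_test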
clust <- sample(rep(1:k, ceiling(nrow(tpoints)/k)), size = nrow(tpoints), replace=T,
prob = rep(c(1-prop_test, prop_test), ceiling(nrow(tpoints)/k)))
} else {
clust <- sample(rep(1:k, ceiling(nrow(tpoints)/k)), size = nrow(tpoints), replace=F)
}

if(isTRUE(islonglat)){
Gjstar <- distclust_distmat(distmat, clust)
Expand All @@ -426,8 +446,15 @@ knndm_geo <- function(tpoints, predpoints, k, maxp, clustering, linkf, islonglat
# Build grid of number of clusters to try - we sample low numbers more intensively
clustgrid <- data.frame(nk = as.integer(round(exp(seq(log(k), log(nrow(tpoints)-2),
length.out = 100)))))

clustgrid$W <- NA
clustgrid <- clustgrid[!duplicated(clustgrid$nk),]

if(!is.null(prop_test)) {
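# for a train/test split, drop candidate cluster numbers nk whose test share
# (floor(prop_test * nk)) would contain fewer than two clusters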
n_test <- floor(prop_test*clustgrid$nk)
clustgrid <- clustgrid[n_test>1,]
}

clustgroups <- list()

# Compute 1st PC for ordering clusters
@@ -467,7 +494,24 @@ knndm_geo <- function(tpoints, predpoints, k, maxp, clustering, linkf, islonglat

# And we merge the remaining into k groups
clust_i <- setdiff(1:k, unique(tabclust$clust_k))
tabclust$clust_k[is.na(tabclust$clust_k)] <- rep(clust_i, ceiling(nk/length(clust_i)))[1:sum(is.na(tabclust$clust_k))]
if(is.null(prop_test)) {
tabclust$clust_k[is.na(tabclust$clust_k)] <- rep(clust_i, ceiling(nk/length(clust_i)))[1:sum(is.na(tabclust$clust_k))]
} else {
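# with a train/test split, the remaining clusters are instead filled by
# alternating train (1) and test (2) labels derived from prop_test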
n_train <- ceiling((1-prop_test)*nk)
n_test <- floor(prop_test*nk)

if(n_train > n_test) {
v1 <- rep(1, times=n_train)
v2 <- c(rep(2, times=n_test), rep(NA, n_train - n_test))
} else {
v1 <- rep(1, times=n_test)
v2 <- c(rep(2, times=n_train), rep(NA, n_test - n_train))
}

v3 <- c(rbind(v1, v2))
tabclust$clust_k[is.na(tabclust$clust_k)] <- v3[!is.na(v3)]
}

tabclust2 <- data.frame(ID = 1:length(clust_nk), clust_nk = clust_nk)
tabclust2 <- merge(tabclust2, tabclust, by = "clust_nk")
tabclust2 <- tabclust2[order(tabclust2$ID),]
@@ -509,7 +553,7 @@ knndm_geo <- function(tpoints, predpoints, k, maxp, clustering, linkf, islonglat


# kNNDM in the feature space
knndm_feature <- function(tpoints, predpoints, k, maxp, clustering, linkf, catVars, useMD) {
knndm_feature <- function(tpoints, predpoints, k, maxp, clustering, linkf, catVars, useMD, prop_test) {

# rescale data
if(is.null(catVars)) {
@@ -594,7 +638,12 @@ knndm_feature <- function(tpoints, predpoints, k, maxp, clustering, linkf, catVa
testks <- suppressWarnings(stats::ks.test(Gj, Gij, alternative = "great"))
if(testks$p.value >= 0.05){

clust <- sample(rep(1:k, ceiling(nrow(tpoints)/k)), size = nrow(tpoints), replace=F)
if(!is.null(prop_test)) {
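# same train/test handling as in knndm_geo: draw fold 1 (train) with
# probability 1 - prop_test and fold 2 (test) with probability prop_test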
clust <- sample(rep(1:k, ceiling(nrow(tpoints)/k)), size = nrow(tpoints), replace=T,
prob = rep(c(1-prop_test, prop_test), ceiling(nrow(tpoints)/k)))
} else {
clust <- sample(rep(1:k, ceiling(nrow(tpoints)/k)), size = nrow(tpoints), replace=F)
}

if(is.null(catVars)) {
if(isTRUE(useMD)) {
@@ -645,6 +694,12 @@ knndm_feature <- function(tpoints, predpoints, k, maxp, clustering, linkf, catVa
length.out = 100)))))
clustgrid$W <- NA
clustgrid <- clustgrid[!duplicated(clustgrid$nk),]

if(!is.null(prop_test)) {
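# as in knndm_geo: drop candidate cluster numbers whose test share would
# contain fewer than two clusters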
n_test <- floor(prop_test*clustgrid$nk)
clustgrid <- clustgrid[n_test>1,]
}

clustgroups <- list()

# Compute 1st PC for ordering clusters
@@ -713,12 +768,32 @@ knndm_feature <- function(tpoints, predpoints, k, maxp, clustering, linkf, catVa

# And we merge the remaining into k groups
clust_i <- setdiff(1:k, unique(tabclust$clust_k))
tabclust$clust_k[is.na(tabclust$clust_k)] <- rep(clust_i, ceiling(nk/length(clust_i)))[1:sum(is.na(tabclust$clust_k))]

if(is.null(prop_test)) {
tabclust$clust_k[is.na(tabclust$clust_k)] <- rep(clust_i, ceiling(nk/length(clust_i)))[1:sum(is.na(tabclust$clust_k))]
} else {
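# as in knndm_geo: fill the remaining clusters by alternating train (1) and
# test (2) labels derived from prop_test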
n_train <- ceiling((1-prop_test)*nk)
n_test <- floor(prop_test*nk)

if(n_train > n_test) {
v1 <- rep(1, times=n_train)
v2 <- c(rep(2, times=n_test), rep(NA, n_train - n_test))
} else {
v1 <- rep(1, times=n_test)
v2 <- c(rep(2, times=n_train), rep(NA, n_test - n_train))
}

v3 <- c(rbind(v1, v2))
tabclust$clust_k[is.na(tabclust$clust_k)] <- v3[!is.na(v3)]
}

tabclust2 <- data.frame(ID = 1:length(clust_nk), clust_nk = clust_nk)
tabclust2 <- merge(tabclust2, tabclust, by = "clust_nk")
tabclust2 <- tabclust2[order(tabclust2$ID),]
clust_k <- tabclust2$clust_k



# Compute W statistic if not exceeding maxp
if(!(any(table(clust_k)/length(clust_k)>maxp))){

17 changes: 13 additions & 4 deletions man/knndm.Rd


10 changes: 10 additions & 0 deletions vignettes/cast01-CAST-intro.R
@@ -0,0 +1,10 @@
## ----setup, echo=FALSE--------------------------------------------------------
knitr::opts_chunk$set(fig.width = 8.83)

## ----message = FALSE, warning=FALSE-------------------------------------------
#install.packages("CAST")
library(CAST)

## ----message = FALSE, warning=FALSE-------------------------------------------
help(CAST)
