Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adjust base cube classification to extract time series using block strategy #1166

Merged
merged 12 commits into from
Jul 3, 2024
1 change: 0 additions & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,6 @@ S3method(sits_apply,raster_cube)
S3method(sits_apply,sits)
S3method(sits_as_sf,raster_cube)
S3method(sits_as_sf,sits)
S3method(sits_bands,base_raster_cube)
S3method(sits_bands,default)
S3method(sits_bands,patterns)
S3method(sits_bands,raster_cube)
Expand Down
2 changes: 1 addition & 1 deletion R/api_check.R
Original file line number Diff line number Diff line change
Expand Up @@ -1766,7 +1766,7 @@
n_bands <- length(.samples_bands.sits(samples))
n_times <- length(.samples_timeline(samples))
if(inherits(samples, "sits_base"))
n_bands_base <- length(.samples_bands_base(samples))
n_bands_base <- length(.samples_base_bands(samples))
else
n_bands_base <- 0
.check_that(ncol(pred) == 2 + n_bands * n_times + n_bands_base)
Expand Down
42 changes: 31 additions & 11 deletions R/api_classify.R
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
#' in the classified images for each corresponding year.
#'
#' @param tile Single tile of a data cube.
#' @param band Band to be produced.
#' @param out_band Band to be produced.
#' @param bands Bands to extract time series
#' @param base_bands Base bands to extract values
#' @param ml_model Model trained by \code{\link[sits]{sits_train}}.
#' @param block Optimized block to be read into memory.
#' @param roi Region of interest.
Expand All @@ -29,7 +31,9 @@
#' @param progress Show progress bar?
#' @return List of the classified raster layers.
.classify_tile <- function(tile,
band,
out_band,
bands,
base_bands,
ml_model,
block,
roi,
Expand All @@ -42,7 +46,7 @@
# Output file
out_file <- .file_derived_name(
tile = tile,
band = band,
band = out_band,
version = version,
output_dir = output_dir
)
Expand All @@ -53,7 +57,7 @@
}
probs_tile <- .tile_derived_from_file(
file = out_file,
band = band,
band = out_band,
base_tile = tile,
labels = .ml_labels_code(ml_model),
derived_class = "probs_cube",
Expand Down Expand Up @@ -105,7 +109,8 @@
values <- .classify_data_read(
tile = tile,
block = block,
bands = .ml_bands(ml_model),
bands = bands,
base_bands = base_bands,
ml_model = ml_model,
impute_fn = impute_fn,
filter_fn = filter_fn
Expand Down Expand Up @@ -138,7 +143,7 @@
# Prepare probability to be saved
band_conf <- .conf_derived_band(
derived_class = "probs_cube",
band = band
band = out_band
)
offset <- .offset(band_conf)
if (.has(offset) && offset != 0) {
Expand Down Expand Up @@ -181,7 +186,7 @@
# Merge blocks into a new probs_cube tile
probs_tile <- .tile_derived_merge_blocks(
file = out_file,
band = band,
band = out_band,
labels = .ml_labels_code(ml_model),
base_tile = tile,
block_files = block_files,
Expand Down Expand Up @@ -374,11 +379,12 @@
#' @param tile Input tile to read data.
#' @param block Bounding box in (col, row, ncols, nrows).
#' @param bands Bands to extract time series
#' @param base_bands Base bands to extract values
#' @param ml_model Model trained by \code{\link[sits]{sits_train}}.
#' @param impute_fn Imputation function
#' @param filter_fn Smoothing filter function to be applied to the data.
#' @return A matrix with values for classification.
.classify_data_read <- function(tile, block, bands,
.classify_data_read <- function(tile, block, bands, base_bands,
ml_model, impute_fn, filter_fn) {
# For cubes that have a time limit to expire (MPC cubes only)
tile <- .cube_token_generator(tile)
Expand All @@ -388,7 +394,7 @@
tile = tile,
block = block
)
# Read and preprocess values of each band
# Read and preprocess values of each eo band
values <- purrr::map(bands, function(band) {
# Get band values (stops if band not found)
values <- .tile_read_block(
Expand Down Expand Up @@ -436,9 +442,23 @@
# Return values
return(as.data.frame(values))
})
# Read and preprocess values of each base band
values_base <- purrr::map(base_bands, function(band) {
# Read and preprocess values of each base band
values_base <- .tile_read_block(
tile = .tile_base_info(tile),
band = band,
block = block
)
# Return values
return(as.data.frame(values_base))
})
# Combine two lists
values <- c(values, values_base)
# collapse list to get data.frame
values <- suppressMessages(purrr::list_cbind(values,
name_repair = "universal"))
values <- suppressMessages(
purrr::list_cbind(values, name_repair = "universal")
)
# Compose final values
values <- as.matrix(values)
# Set values features name
Expand Down
35 changes: 32 additions & 3 deletions R/api_cube.R
Original file line number Diff line number Diff line change
Expand Up @@ -131,11 +131,29 @@ NULL
}
.cube_set_class(cube)
}
#' @title Identity function for data cubes
#' @keywords internal
#' @noRd
#' @name .cube
#' @param x cube
#'
#' @return data cube object.
.cube <- function(x) {
# return the cube
x
}
#' @title Return areas of classes of a class_cue
#' @title Get base info from a data cube
#' @keywords internal
#' @noRd
#' @name .cube
#' @param x cube
#'
#' @return data cube from base_info
.cube_base_info <- function(x) {
# return base info data cube
dplyr::bind_rows(x[["base_info"]])
}
#' @title Return areas of classes of a class_cube
#' @keywords internal
#' @noRd
#' @name .cube_class_areas
Expand Down Expand Up @@ -192,7 +210,7 @@ NULL
class(cube) <- c("raster_cube", class(cube))
bands <- .cube_bands(cube)
} else {
stop(.conf("messages", "cube_bands"))
stop(.conf("messages", ".cube_bands"))
}
return(bands)
}
Expand All @@ -203,7 +221,7 @@ NULL
cube <- tibble::as_tibble(cube)
bands <- .cube_bands(cube, add_cloud, dissolve)
} else {
stop(.conf("messages", "cube_bands"))
stop(.conf("messages", ".cube_bands"))
}
return(bands)
}
Expand Down Expand Up @@ -544,6 +562,17 @@ NULL
}
return(is_regular)
}

#' @title Check that cube is a base cube
#' @name .cube_is_base
#' @keywords internal
#' @noRd
#' @param cube datacube
#' @return Called for side effects.
.cube_is_base <- function(cube) {
inherits(cube, "base_raster_cube")
}

#' @title Find out how many images are in cube during a period
#' @noRd
#' @param cube A data cube.
Expand Down
9 changes: 6 additions & 3 deletions R/api_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@
} else {
cld_band <- NULL
}
if (.cube_is_base(cube)) {
bands <- setdiff(bands, .cube_bands(.cube_base_info(cube)))
}

# define parallelization strategy
# find block size
rast <- .raster_open_rast(.tile_path(cube))
Expand Down Expand Up @@ -81,9 +85,7 @@
}
if (.has(cube[["base_info"]])) {
# get base info
cube_base <- cube[["base_info"]]
# bind all base info
cube_base <- dplyr::bind_rows(cube_base)
cube_base <- .cube_base_info(cube)
# get bands
bands_base <- .cube_bands(cube_base)
# extract data
Expand All @@ -97,6 +99,7 @@
)
# save base data
ts_tbl[["base_data"]] <- base_tbl[["time_series"]]
# add base class
class(ts_tbl) <- c("sits_base", class(ts_tbl))
}
return(ts_tbl)
Expand Down
2 changes: 1 addition & 1 deletion R/api_plot_time_series.R
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@
# how many time series are to be plotted?
number <- nrow(data2)
# what are the band names?
bands <- .samples_bands(data2)
bands <- .samples_bands(data2, include_base = FALSE)
# what are the reference dates?
ref_dates <- .samples_timeline(data2)
# align all time series to the same dates
Expand Down
2 changes: 1 addition & 1 deletion R/api_plot_vector.R
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@
# verifies if stars package is installed
.check_require_packages("stars")
# verifies if tmap package is installed
.check_require_packages("plot")
.check_require_packages("tmap")
# precondition - check color palette
.check_palette(palette)
# revert the palette
Expand Down
42 changes: 14 additions & 28 deletions R/api_predictors.R
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@
# Get samples time series
pred <- .ts(samples)
# By default get bands as the same of first sample
bands <- .samples_bands(samples)
bands <- .samples_bands(samples, include_base = FALSE)
# Preprocess time series
if (.has(ml_model)) {
# If a model is informed, get predictors from model bands
bands <- .ml_bands(ml_model)
bands <- intersect(.ml_bands(ml_model), bands)

# Normalize values for old version model classifiers that
# do not normalize values itself
# Models trained after version 1.2 do this automatically before
Expand All @@ -49,7 +50,7 @@
})
}
}
# Create predictors...
# Create predictors
pred <- pred[c(.pred_cols, bands)]
# Add sequence 'index' column grouped by 'sample_id'
pred <- pred |>
Expand All @@ -68,36 +69,22 @@
}
#' @export
.predictors.sits_base <- function(samples, ml_model = NULL) {
# Get predictors for time series
# Prune samples time series
samples <- .samples_prune(samples)
# Get samples time series
pred <- .ts(samples)
# By default get bands as the same of first sample
bands <- .samples_bands.sits(samples)
# Create predictors...
pred <- pred[c(.pred_cols, bands)]
# Add sequence 'index' column grouped by 'sample_id'
pred <- pred |>
dplyr::select("sample_id", "label", dplyr::all_of(bands)) |>
dplyr::group_by(.data[["sample_id"]]) |>
dplyr::mutate(index = seq_len(dplyr::n())) |>
dplyr::ungroup()
# Rearrange data to create predictors
pred <- tidyr::pivot_wider(
data = pred, names_from = "index", values_from = dplyr::all_of(bands),
names_prefix = if (length(bands) == 1) bands else "",
names_sep = ""
)
# get predictors for base data
base <- dplyr::bind_rows(samples$base_data)
base <- base[,-1]
# join time series predictors with base data predictors
pred <- dplyr::bind_cols(pred, base)
pred <- .predictors.sits(samples, ml_model)
# Get predictors for base data
pred_base <- samples |>
dplyr::rename(
"_" = "time_series", "time_series" = "base_data"
) |>
.predictors.sits() |>
dplyr::select(-.data[["label"]])
# Merge predictors
pred <- dplyr::inner_join(pred, pred_base, by = "sample_id")
# Return predictors
pred
}

#' @title Get predictors names with timeline
#' @keywords internal
#' @noRd
Expand All @@ -113,7 +100,6 @@
USE.NAMES = FALSE
))
}

#' @title Get features from predictors
#' @keywords internal
#' @noRd
Expand Down
Loading
Loading