From bd1f04e4b5185f4a690e5596d3853bb45bb9f975 Mon Sep 17 00:00:00 2001 From: Felipe Carlos Date: Tue, 2 Jul 2024 16:38:49 -0300 Subject: [PATCH] fix base data extraction --- R/sits_classify.R | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/R/sits_classify.R b/R/sits_classify.R index b5334489c..114eab59d 100644 --- a/R/sits_classify.R +++ b/R/sits_classify.R @@ -253,10 +253,14 @@ sits_classify.raster_cube <- function(data, .check_filter_fn(filter_fn) # Retrieve the samples from the model samples <- .ml_samples(ml_model) - # Retrieve bands from the model - base_bands <- intersect( - .ml_bands(ml_model), .cube_bands(.cube_base_info(data)) - ) + # By default, base bands is null. + base_bands <- NULL + if (.cube_is_base(data)) { + # Get base bands + base_bands <- intersect( + .ml_bands(ml_model), .cube_bands(.cube_base_info(data)) + ) + } # get non-base bands bands <- setdiff(.ml_bands(ml_model), base_bands) # Do the samples and tile match their timeline length? @@ -270,8 +274,12 @@ sits_classify.raster_cube <- function(data, job_size = .block_size(block = block, overlap = 0), npaths = ( length(.tile_paths(data, bands)) + - length(.tile_paths(.cube_base_info(data), base_bands)) + - length(.ml_labels(ml_model)) + length(.ml_labels(ml_model)) + + ifelse( + test = .cube_is_base(data), + yes = length(.tile_paths(.cube_base_info(data), base_bands)), + no = 0 + ) ), nbytes = 8, proc_bloat = proc_bloat