Skip to content

Commit

Permalink
fix base data extraction
Browse files Browse the repository at this point in the history
  • Loading branch information
M3nin0 committed Jul 2, 2024
1 parent 6f05e4c commit bd1f04e
Showing 1 changed file with 14 additions and 6 deletions.
20 changes: 14 additions & 6 deletions R/sits_classify.R
Original file line number Diff line number Diff line change
Expand Up @@ -253,10 +253,14 @@ sits_classify.raster_cube <- function(data,
.check_filter_fn(filter_fn)
# Retrieve the samples from the model
samples <- .ml_samples(ml_model)
# Retrieve bands from the model
base_bands <- intersect(
.ml_bands(ml_model), .cube_bands(.cube_base_info(data))
)
# By default, base bands is null.
base_bands <- NULL
if (.cube_is_base(data)) {
# Get base bands
base_bands <- intersect(
.ml_bands(ml_model), .cube_bands(.cube_base_info(data))
)
}
# get non-base bands
bands <- setdiff(.ml_bands(ml_model), base_bands)
# Do the samples and tile match their timeline length?
Expand All @@ -270,8 +274,12 @@ sits_classify.raster_cube <- function(data,
job_size = .block_size(block = block, overlap = 0),
npaths = (
length(.tile_paths(data, bands)) +
length(.tile_paths(.cube_base_info(data), base_bands)) +
length(.ml_labels(ml_model))
length(.ml_labels(ml_model)) +
ifelse(
test = .cube_is_base(data),
yes = length(.tile_paths(.cube_base_info(data), base_bands)),
no = 0
)
),
nbytes = 8,
proc_bloat = proc_bloat
Expand Down

0 comments on commit bd1f04e

Please sign in to comment.