diff --git a/R/sits_classify.R b/R/sits_classify.R index c18ea38f..8aee78aa 100644 --- a/R/sits_classify.R +++ b/R/sits_classify.R @@ -293,7 +293,9 @@ sits_classify.raster_cube <- function(data, proc_bloat = proc_bloat ) # Update multicores parameter - if ("xgb_model" %in% .ml_class(ml_model) || .is_torch_model(ml_model)) + if ("xgb_model" %in% .ml_class(ml_model)) + multicores <- 1 + else if (.torch_mps_enabled(ml_model) || .torch_cuda_enabled(ml_model)) multicores <- 1 else # Update multicores parameter