catboost method to embed categorical variables #138

Open
talegari opened this issue Jun 22, 2022 · 11 comments
Labels
feature a feature request or enhancement

Comments

@talegari

talegari commented Jun 22, 2022

Hi Emil,
I am planning to implement a step_catboost (on these lines). IMHO, it should belong here.

Let me know if you are open to a PR.

@juliasilge
Member

Unfortunately catboost (the R package) is not on CRAN 😔 which is a blocker for us being able to implement catboost methods in our packages. You can see related discussion in catboost/catboost#439.

@talegari
Author

hey Julia, step_catboost would not depend on the catboost package. The step involves permutations and target encoding. Here is the Python implementation of the same.
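
For readers unfamiliar with the encoding being proposed: each row's category is replaced by a smoothed mean of the outcome over the earlier rows that share its level, so a row never sees its own outcome. Below is a minimal base-R sketch, assuming a single factor column with no missing values. The names `ordered_target_encode`, `a`, and `prior` are illustrative and not part of {embed} or catboost. CatBoost applies the same idea over random permutations of the rows; the sketch simply uses the existing row order.

```r
# Ordered ("CatBoost-style") target encoding for one factor column.
# Hypothetical helper for illustration only.
ordered_target_encode <- function(x, y, a = 1, prior = mean(y)) {
  stopifnot(is.factor(x), is.numeric(y),
            length(x) == length(y), !anyNA(x), !anyNA(y))
  # sum of y over earlier rows with the same level (current row excluded)
  prev_sum <- ave(y, x, FUN = function(v) cumsum(v) - v)
  # number of earlier rows with the same level
  prev_n <- ave(y, x, FUN = function(v) seq_along(v) - 1)
  # smoothed running mean: the prior acts like `a` pseudo-observations
  (prev_sum + prior * a) / (prev_n + a)
}

x <- factor(c("a", "b", "a", "a", "b"))
y <- c(1, 2, 3, 5, 4)
ordered_target_encode(x, y)
```

The first row of every level has no history, so it receives the prior (here the overall mean of `y`); later rows move toward the level's own running mean.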

@EmilHvitfeldt EmilHvitfeldt added the feature a feature request or enhancement label Jun 22, 2022
@EmilHvitfeldt
Member

Hey @talegari 👋

That sounds great! Feel free to open an issue, and ping me if you need any help or assistance!

@EmilHvitfeldt
Member

Hello @talegari 👋 Are you still interested in opening a PR for this step? If not, I will do it.

@talegari
Author

Hey @EmilHvitfeldt ... it just fell off the radar. I will submit a PR. I am planning on these lines. Let me know if you have a different suggestion.

@EmilHvitfeldt
Member

Amazing! That looks like a great place to start!
Do you know when you will have time to work on this? No rush!

@talegari
Author

talegari commented Mar 17, 2023 via email

@talegari
Author

hey @EmilHvitfeldt, something unforeseen stopped me from working on this. This is to let you know that I am on it and will raise a PR shortly.

@EmilHvitfeldt
Member

no problem! It might not make it into the next {embed} release, but that is fine, we can send it in later

@talegari
Author

talegari commented Apr 4, 2023

@EmilHvitfeldt , I am one step away from raising a PR. I need your help in resolving a small issue. Here is the context:

I have implemented the catboost encoder as an R6 class here:

Category encoder R6 class
# catboost encoder core logic
pacman::p_load("tidyverse")

#' catboost_encoder R6 class
#'
#' An R6 class to encode categorical variables with the CatBoost method.
#'
#' @name catboost_encoder
#' @docType class
#' @importFrom R6 R6Class
#'
#' @field dataset The dataset to fit the encoder
#' @field mean The mean of the response variable in the dataset
#' @field varnames_to_encode The names of the categorical variables to encode
#' @field response_varname The name of the response variable in the dataset
#' @field is_fitted A flag indicating whether the encoder has been fitted
#' @field a A hyperparameter to control the strength of the encoding
#'
#' @section Public methods: \describe{
#'   \item{\code{initialize(dataset)}}{Constructor method for the
#'   catboost_encoder class} \item{\code{fit(varnames_to_encode,
#'   response_varname, a = 1, encode_novel_levels = TRUE,
#'   encode_missing_levels = FALSE)}}{Fit the encoder to the data}
#'   \item{\code{transform(new_data = NULL)}}{Transform a new dataset using the
#'   fitted encoder} }
#'
#' @section Private methods: \describe{ \item{\code{encode_with_y(df,
#'   varname_to_encode)}}{Encode a categorical variable using the response
#'   variable} \item{\code{encode_without_y(df, varname_to_encode)}}{Encode a
#'   categorical variable without using the response variable} }
#'
#' @section Usage:
#'
#'   catboost_encoder <- catboost_encoder$new(dataset)
#'   catboost_encoder$fit(varnames_to_encode, response_varname) 
#'   encoded_data <- catboost_encoder$transform(new_data)
#'
#' @export catboost_encoder
catboost_encoder = R6::R6Class(
  "catboost_encoder",
  public = list(
    
    dataset               = NULL,
    mean                  = NULL,
    varnames_to_encode    = NULL,
    response_varname      = NULL,
    is_fitted             = FALSE,
    a                     = NULL,
    encode_novel_levels   = NULL,
    encode_missing_levels = NULL,
    
    initialize = function(dataset){
      checkmate::assert_data_frame(dataset)
      self$dataset = dataset
      return(invisible(NULL))
    },
    
    fit = function(varnames_to_encode,
                   response_varname,
                   a = 1,
                   encode_novel_levels = TRUE,
                   encode_missing_levels = FALSE
                   ){
      
      checkmate::assert_string(response_varname)
      checkmate::assert_subset(response_varname,
                               choices = colnames(self$dataset)
                               )
      checkmate::assert_numeric(self$dataset[[response_varname]],
                                any.missing = FALSE
                                )
      checkmate::assert_character(varnames_to_encode)
      checkmate::assert_subset(varnames_to_encode,
                               choices = colnames(self$dataset)
                               )
      for (avarname in varnames_to_encode){
        checkmate::assert_factor(self$dataset[[avarname]])
      }
      
      checkmate::assert_number(a)
      checkmate::assert_flag(encode_novel_levels)
      checkmate::assert_flag(encode_missing_levels)
      
      self$varnames_to_encode = varnames_to_encode
      self$response_varname = response_varname
      self$mean = mean(self$dataset[[response_varname]], na.rm = TRUE)
      self$a = a
      # store the arguments as passed instead of hard-coding the defaults
      self$encode_novel_levels = encode_novel_levels
      self$encode_missing_levels = encode_missing_levels
      
      self$is_fitted = TRUE
      return(invisible(NULL))
    },
    
    transform = function(new_data = NULL){
      # fail early: the checks below rely on fields that fit() sets
      if (!self$is_fitted){
        stop("please 'fit' before 'transform'")
      }
      
      new_data_is_null = TRUE
      if (!is.null(new_data)){
        checkmate::assert_data_frame(new_data)
        checkmate::assert_false(self$response_varname %in% colnames(new_data))
        names_sorted = sort(colnames(new_data))
        checkmate::assert_set_equal(colnames(new_data),
                                    setdiff(colnames(self$dataset),
                                            self$response_varname
                                            )
                                    )
        checkmate::assert_set_equal(
          sapply(new_data, class)[names_sorted],
          sapply(dplyr::select(self$dataset,
                               -dplyr::all_of(self$response_varname)
                               )
                 , class
                 )[names_sorted]
          )
        new_data_is_null = FALSE
      }
      
      if (new_data_is_null){
        message("transforming on the dataset")
        new_data = self$dataset
      }
      
      if (new_data_is_null){
        encoded_cols = map(self$varnames_to_encode,
                           ~ private$encode_with_y(new_data, .x)
                           )
        
      } else {
        encoded_cols = map(self$varnames_to_encode,
                           ~ private$encode_without_y(new_data,.x)
                           )
      }
      
      names(encoded_cols) = self$varnames_to_encode
      
      # all_of() makes the selection by character vector explicit
      res = as_tibble(encoded_cols) %>% 
        bind_cols(select(new_data, -all_of(self$varnames_to_encode))) %>% 
        relocate(all_of(colnames(new_data)))
      
      # encode novel (in new data case only)
      if (self$encode_novel_levels && !new_data_is_null){
        for (avarname in self$varnames_to_encode){
          new_levels = setdiff(levels(new_data[[avarname]]),
                               levels(self$dataset[[avarname]])
                               )
          if (length(new_levels) > 0){
            res[[avarname]] = ifelse(new_data[[avarname]] %in% new_levels,
                                     self$mean,
                                     res[[avarname]]
                                     )
          }
        }
      }
            
      # encode missing (in new data case only)
      if (self$encode_missing_levels && !new_data_is_null){
        for (avarname in self$varnames_to_encode){
          res[[avarname]][ is.na(new_data[[avarname]]) ] = NA
        }
      }
      
      return(res)
    }
  ),
  private = list(
    
    encode_with_y = function(df, varname_to_encode){
      
      # new levels: not applicable
      # NA: encoded
      
      res = df %>% 
        select(all_of(c(varname_to_encode, self$response_varname))) %>% 
        group_by(.data[[varname_to_encode]]) %>% 
        mutate(cs__ = cumsum(.data[[self$response_varname]]),
               cc__ = row_number() - 1L
               ) %>% 
        ungroup() %>% 
        transmute({{varname_to_encode}} := (cs__ -
                      .data[[self$response_varname]] +
                      mean(.data[[self$response_varname]], na.rm = TRUE) *
                      self$a
                      ) / (cc__ + self$a)
               ) %>% 
        pull()
      
      return(res)
    },
    
    encode_without_y = function(df, varname_to_encode){
      
      # new levels: NA
      # NA: NA
      
      level_means = "level_means__"
      
      agg_frame = self$dataset %>% 
        select(all_of(c(varname_to_encode, self$response_varname))) %>% 
        group_by(.data[[varname_to_encode]]) %>% 
        summarise(sum__ = sum(.data[[self$response_varname]], na.rm = TRUE),
                  count__ = n()
                  ) %>% 
        ungroup() %>% 
        mutate(level_means__ = 
                 ifelse(count__ == 1,
                        self$mean,
                        (sum__ + self$mean * self$a) / (count__ + self$a)
                        )
               ) %>% 
        drop_na(all_of(varname_to_encode)) %>% 
        select(all_of(c(varname_to_encode, level_means)))
      
      res = df %>% 
        select(all_of(c(varname_to_encode))) %>%
        left_join(agg_frame, by = varname_to_encode) %>% 
        pull(level_means)
      
      return(res)
    }
    
  )
)
recipe wrapper as 'step_catboost'
step_catboost = function(recipe,
                         ...,
                         role = NA,
                         trained = FALSE,
                         outcome = NULL,
                         mapping = NULL,
                         skip = FALSE,
                         id = rand_id("catboost")
                         ){
    if (is.null(outcome)) {
      rlang::abort("Please list a variable in `outcome`")
    }
    recipes::add_step(
      recipe,
      step_catboost_new(
        terms = enquos(...),
        role = role,
        trained = trained,
        outcome = outcome,
        mapping = mapping,
        skip = skip,
        id = id
      )
    )
  }

step_catboost_new = 
  function(terms,
           role,
           trained,
           outcome,
           mapping,
           skip,
           id
           ){
    step(
      subclass = "catboost",
      terms = terms,
      role = role,
      trained = trained,
      outcome = outcome,
      mapping = mapping,
      skip = skip,
      id = id
      )
  }

#' @export
prep.step_catboost = function(x,
                              training,
                              info = NULL,
                              ...
                              ){
  col_names = recipes_eval_select(x$terms, training, info)

  if (length(col_names) > 0) {
    y_name = recipes_eval_select(x$outcome, training, info)
    
    # instantiate R6 class obj
    ce = catboost_encoder$new(training)
    ce$fit(varnames_to_encode = col_names,
           response_varname = y_name
           )
  } else {
    ce = list()
  }
  step_catboost_new(
    terms = x$terms,
    role = x$role,
    trained = TRUE,
    outcome = x$outcome,
    mapping = ce,
    skip = x$skip,
    id = x$id
  )
}

#' @export
bake.step_catboost = function(object, new_data, ...) {
  
  if (!is.null(new_data)){
    y_name = purrr::map_chr(object$outcome, rlang::as_name) # string
    ce = object$mapping
    if (y_name %in% colnames(new_data)){
      new_data[[y_name]] = NULL
    }
    res = ce$transform(new_data)
  } else {
    res = ce$transform()
  }
  
  res = ce$transform(new_data)
  return(res)
}

#' @rdname required_pkgs.embed
#' @export
required_pkgs.step_catboost = function(x, ...) {
  c("embed")
}
Example
pacman::p_load("recipes", "tidyverse")
source("~/personal/catboost_encoding_r6.R")
#> transforming on the dataset
#> transforming on the dataset
source("~/personal/step_catboost.R")

pen1 = palmerpenguins::penguins %>% 
  drop_na(bill_length_mm) %>% 
  slice_sample(prop = 0.7, by = 'species')

pen2 = palmerpenguins::penguins %>% 
  drop_na(bill_length_mm) %>% 
  setdiff(pen1)

# example with R6 class
ce = catboost_encoder$new(pen1)
ce$fit(c('species', 'sex'), response_varname = 'bill_length_mm')

# when input to transform is empty, it uses the training dataset 
# (here it is pen1)
ce$transform()
#> transforming on the dataset
#> # A tibble: 238 × 8
#>    species island    bill_length_mm bill_depth_mm flipper_…¹ body_…²   sex  year
#>      <dbl> <fct>              <dbl>         <dbl>      <int>   <int> <dbl> <int>
#>  1    43.8 Torgersen           39.6          17.2        196    3550  43.8  2008
#>  2    41.7 Dream               37.5          18.9        179    2975  43.8  2007
#>  3    40.3 Biscoe              35.5          16.2        195    3350  41.7  2008
#>  4    39.1 Torgersen           40.6          19          199    4000  43.8  2009
#>  5    39.4 Biscoe              40.1          18.9        188    4300  42.2  2008
#>  6    39.5 Dream               39.6          18.8        190    4600  41.5  2007
#>  7    39.5 Dream               32.1          15.5        188    3050  39.6  2009
#>  8    38.6 Dream               39.8          19.1        184    4650  41.0  2007
#>  9    38.7 Torgersen           34.1          18.1        193    3475  40.6  2007
#> 10    38.3 Dream               37            16.9        185    3000  37.7  2007
#> # … with 228 more rows, and abbreviated variable names ¹​flipper_length_mm,
#> #   ²​body_mass_g

# transform on a new dataset
ce$transform(pen2 %>% select(-bill_length_mm))
#> # A tibble: 104 × 7
#>    species island    bill_depth_mm flipper_length_mm body_mass_g   sex  year
#>      <dbl> <fct>             <dbl>             <int>       <int> <dbl> <int>
#>  1    38.7 Torgersen          18                 195        3250  42.2  2007
#>  2    38.7 Torgersen          20.6               190        3650  45.6  2007
#>  3    38.7 Torgersen          17.8               181        3625  42.2  2007
#>  4    38.7 Torgersen          19.6               195        4675  45.6  2007
#>  5    38.7 Torgersen          21.2               191        3800  45.6  2007
#>  6    38.7 Torgersen          17.8               185        3700  42.2  2007
#>  7    38.7 Torgersen          20.7               197        4500  45.6  2007
#>  8    38.7 Torgersen          21.5               194        4200  45.6  2007
#>  9    38.7 Biscoe             18.6               172        3150  42.2  2007
#> 10    38.7 Dream              16.7               178        3250  42.2  2007
#> # … with 94 more rows

# example with step_catboost recipe
ar = recipe(bill_length_mm ~ ., data = pen1) %>% 
  step_catboost(species, outcome = "bill_length_mm") %>% 
  prep(training = pen1)

ar
#> Recipe
#> 
#> Inputs:
#> 
#>       role #variables
#>    outcome          1
#>  predictor          7
#> 
#> Training data contained 238 data points and 9 incomplete rows. 
#> 
#> Operations:
#> 
#> $terms
#> <list_of<quosure>>
#> 
#> [[1]]
#> <quosure>
#> expr: ^species
#> env:  0x7fbbb5a65120
#> 
#> 
#> $role
#> [1] NA
#> 
#> $trained
#> [1] TRUE
#> 
#> $outcome
#> [1] "bill_length_mm"
#> 
#> $mapping
#> <catboost_encoder>
#>   Public:
#>     a: 1
#>     clone: function (deep = FALSE) 
#>     dataset: tbl_df, tbl, data.frame
#>     encode_missing_levels: FALSE
#>     encode_novel_levels: TRUE
#>     fit: function (varnames_to_encode, response_varname, a = 1, encode_novel_levels = TRUE, 
#>     initialize: function (dataset) 
#>     is_fitted: TRUE
#>     mean: 43.7655462184874
#>     response_varname: bill_length_mm
#>     transform: function (new_data = NULL) 
#>     varnames_to_encode: species
#>   Private:
#>     encode_with_y: function (df, varname_to_encode) 
#>     encode_without_y: function (df, varname_to_encode) 
#> 
#> $skip
#> [1] FALSE
#> 
#> $id
#> [1] "catboost_LGVzz"
#> 
#> attr(,"class")
#> [1] "step_catboost" "step"

ar %>% 
  juice()
#> # A tibble: 238 × 7
#>    species island    bill_depth_mm flipper_length_mm body_mass_g sex     year
#>      <dbl> <fct>             <dbl>             <int>       <int> <fct>  <int>
#>  1    38.7 Torgersen          17.2               196        3550 female  2008
#>  2    38.7 Dream              18.9               179        2975 <NA>    2007
#>  3    38.7 Biscoe             16.2               195        3350 female  2008
#>  4    38.7 Torgersen          19                 199        4000 male    2009
#>  5    38.7 Biscoe             18.9               188        4300 male    2008
#>  6    38.7 Dream              18.8               190        4600 male    2007
#>  7    38.7 Dream              15.5               188        3050 female  2009
#>  8    38.7 Dream              19.1               184        4650 male    2007
#>  9    38.7 Torgersen          18.1               193        3475 <NA>    2007
#> 10    38.7 Dream              16.9               185        3000 female  2007
#> # … with 228 more rows

ar %>% 
  bake(new_data = NULL)
#> # A tibble: 238 × 7
#>    species island    bill_depth_mm flipper_length_mm body_mass_g sex     year
#>      <dbl> <fct>             <dbl>             <int>       <int> <fct>  <int>
#>  1    38.7 Torgersen          17.2               196        3550 female  2008
#>  2    38.7 Dream              18.9               179        2975 <NA>    2007
#>  3    38.7 Biscoe             16.2               195        3350 female  2008
#>  4    38.7 Torgersen          19                 199        4000 male    2009
#>  5    38.7 Biscoe             18.9               188        4300 male    2008
#>  6    38.7 Dream              18.8               190        4600 male    2007
#>  7    38.7 Dream              15.5               188        3050 female  2009
#>  8    38.7 Dream              19.1               184        4650 male    2007
#>  9    38.7 Torgersen          18.1               193        3475 <NA>    2007
#> 10    38.7 Dream              16.9               185        3000 female  2007
#> # … with 228 more rows

ar %>% 
  bake(new_data = pen1)
#> # A tibble: 238 × 7
#>    species island    bill_depth_mm flipper_length_mm body_mass_g sex     year
#>      <dbl> <fct>             <dbl>             <int>       <int> <fct>  <int>
#>  1    38.7 Torgersen          17.2               196        3550 female  2008
#>  2    38.7 Dream              18.9               179        2975 <NA>    2007
#>  3    38.7 Biscoe             16.2               195        3350 female  2008
#>  4    38.7 Torgersen          19                 199        4000 male    2009
#>  5    38.7 Biscoe             18.9               188        4300 male    2008
#>  6    38.7 Dream              18.8               190        4600 male    2007
#>  7    38.7 Dream              15.5               188        3050 female  2009
#>  8    38.7 Dream              19.1               184        4650 male    2007
#>  9    38.7 Torgersen          18.1               193        3475 <NA>    2007
#> 10    38.7 Dream              16.9               185        3000 female  2007
#> # … with 228 more rows

ar %>% 
  bake(new_data = pen2)
#> # A tibble: 104 × 7
#>    species island    bill_depth_mm flipper_length_mm body_mass_g sex     year
#>      <dbl> <fct>             <dbl>             <int>       <int> <fct>  <int>
#>  1    38.7 Torgersen          18                 195        3250 female  2007
#>  2    38.7 Torgersen          20.6               190        3650 male    2007
#>  3    38.7 Torgersen          17.8               181        3625 female  2007
#>  4    38.7 Torgersen          19.6               195        4675 male    2007
#>  5    38.7 Torgersen          21.2               191        3800 male    2007
#>  6    38.7 Torgersen          17.8               185        3700 female  2007
#>  7    38.7 Torgersen          20.7               197        4500 male    2007
#>  8    38.7 Torgersen          21.5               194        4200 male    2007
#>  9    38.7 Biscoe             18.6               172        3150 female  2007
#> 10    38.7 Dream              16.7               178        3250 female  2007
#> # … with 94 more rows

Issue: The ce$transform() and ar %>% bake(new_data = NULL) give different results. How do I resolve this?

@EmilHvitfeldt
Member

Hello @talegari Sorry for taking a while to answer.

I'm not terribly familiar with {R6} so I'm not sure how much I can help you. However, I can tell you where something might happen. In bake.step_catboost() you have

  if (!is.null(new_data)){
    y_name = purrr::map_chr(object$outcome, rlang::as_name) # string
    ce = object$mapping
    if (y_name %in% colnames(new_data)){
      new_data[[y_name]] = NULL
    }
    res = ce$transform(new_data)
  } else {
    res = ce$transform()
  }

I'm assuming that you thought this was needed to deal with bake(new_data = NULL). That is not the case: the data passed to any bake method will always be a non-NULL tibble. When you call bake(new_data = NULL), it extracts ar$template and does a couple of other things, so it just returns the data we got when running prep/bake() the first time.

Secondly, and I'm sad to say this since you put in a lot of effort, I don't want to include {R6} and {checkmate} as dependencies just to include this step. If you don't want to go through the work of translating away from {R6} and {checkmate}, I understand, and if you want I can take over and do the last parts.

Thanks again for all the work!
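
As a rough idea of what translating away from {R6} and {checkmate} could look like, here is a base-R sketch in which "fitting" returns a plain list and "applying" is an ordinary function. The names `fit_level_means` and `apply_level_means` are hypothetical, and the sketch covers only the new-data path (smoothed per-level means, overall mean for unseen levels, NA kept as NA). It leaves out the special case for levels with a single observation that the R6 class has.

```r
# Hypothetical base-R replacement for the fitted encoder: a plain list.
fit_level_means <- function(x, y, a = 1) {
  stopifnot(is.factor(x), is.numeric(y), !anyNA(y), length(x) == length(y))
  prior <- mean(y)
  # one numeric vector per observed level; rows with NA in x are dropped
  groups <- split(y, x, drop = TRUE)
  level_means <- vapply(
    groups,
    function(v) (sum(v) + prior * a) / (length(v) + a),
    numeric(1)
  )
  list(prior = prior, level_means = level_means)
}

apply_level_means <- function(mapping, x) {
  # look up each value by level name; unknown names and NA give NA
  out <- unname(mapping$level_means[as.character(x)])
  # levels unseen at fit time fall back to the overall mean; NA stays NA
  out[is.na(out) & !is.na(x)] <- mapping$prior
  out
}

mapping <- fit_level_means(factor(c("a", "a", "b")), c(1, 3, 5))
apply_level_means(mapping, factor(c("a", "b", "c", NA)))
```

A list like `mapping` could be stored in the step's `mapping` field by a prep method and read back by a bake method, which avoids carrying the whole training dataset inside the step object.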
