diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py index 069f66b02a..82e9c9f064 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py @@ -67,6 +67,12 @@ class BoundsCheckMode(enum.IntEnum): IGNORE = 2 # No bounds checks. NONE = 3 + # IGNORE with V2 enabled + V2_IGNORE = 4 + # WARNING with V2 enabled + V2_WARNING = 5 + # FATAL with V2 enabled + V2_FATAL = 6 class EmbeddingSpecInfo(enum.IntEnum): diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py index d8667abe07..85ebd69f2e 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py @@ -638,6 +638,20 @@ def __init__( # noqa C901 self.pooling_mode = pooling_mode self.is_nobag: bool = self.pooling_mode == PoolingMode.NONE # If environment variable is set, it overwrites the default bounds check mode. + self.bounds_check_version: int = 1 + if bounds_check_mode.name.startswith("V2_"): + self.bounds_check_version = 2 + if bounds_check_mode == BoundsCheckMode.V2_IGNORE: + bounds_check_mode = BoundsCheckMode.IGNORE + elif bounds_check_mode == BoundsCheckMode.V2_WARNING: + bounds_check_mode = BoundsCheckMode.WARNING + elif bounds_check_mode == BoundsCheckMode.V2_FATAL: + bounds_check_mode = BoundsCheckMode.FATAL + else: + raise NotImplementedError( + f"Did not recognize V2 bounds check mode: {bounds_check_mode}" + ) + self.bounds_check_mode_int: int = int( os.environ.get("FBGEMM_TBE_BOUNDS_CHECK_MODE", bounds_check_mode.value) ) @@ -3352,6 +3366,7 @@ def prepare_inputs( b_t_map=b_t_map, info_B_num_bits=info_B_num_bits, info_B_mask=info_B_mask, + bounds_check_version=self.bounds_check_version, ) return indices, offsets, per_sample_weights, vbe_metadata