diff --git a/fbgemm_gpu/codegen/genscript/optimizer_args.py b/fbgemm_gpu/codegen/genscript/optimizer_args.py
index 669b1a44f..d73cbff8b 100644
--- a/fbgemm_gpu/codegen/genscript/optimizer_args.py
+++ b/fbgemm_gpu/codegen/genscript/optimizer_args.py
@@ -205,6 +205,40 @@ def schema_bool_arg(name: str, default: bool = False) -> str:
     return f"bool {name} = {default}"
 
 
+def list_arg(ty: str) -> str:
+    """
+    Returns a C++ argument for a list of optimizer arguments of the given type.
+
+    Parameters:
+        ty (str) - type of the list, e.g., "int", "float", "tensor"
+    Returns:
+        C++ argument for a list of the given type, e.g., for a list of int returns "std::vector<int64_t> optim_int".
+    """
+    return {
+        "tensor": "std::vector<std::optional<Tensor>> optim_tensor",
+        "int": "std::vector<int64_t> optim_int",
+        "float": "std::vector<double> optim_float",
+        "bool": "c10::List<bool> optim_bool",
+    }[ty]
+
+
+def schema_list_arg(ty: str) -> str:
+    """
+    Returns a C++ schema for a list of optimizer arguments of the given type.
+
+    Parameters:
+        ty (str) - type of the list, e.g., "int", "float", "tensor"
+    Returns:
+        C++ schema argument for a list of the given type, e.g., for a list of int returns "int[] optim_int".
+    """
+    return {
+        "tensor": "Tensor?[] optim_tensor",
+        "int": "int[] optim_int",
+        "float": "float[] optim_float",
+        "bool": "bool[] optim_bool",
+    }[ty]
+
+
 def optional_tensor_arg(name: str) -> str:
     return f"std::optional<Tensor> {name} = std::nullopt"
 
@@ -230,7 +264,6 @@ def schema_optional_tensorlist_arg(name: str) -> str:
 
 
 def make_kernel_arg(
-    # pyre-fixme[11]: Annotation `ArgType` is not defined as a type.
     ty: ArgType,
     name: str,
     default: Union[int, float, None],
@@ -505,6 +538,10 @@ class PT2ArgsSet:
     split_function_schemas: List[str]
     split_saved_tensorlist: List[str]
     split_saved_tensorlist_optional: List[str]
+    split_saved_data: List[dict[str, str]]
+    split_variables: List[str]
+    split_unpacked_arg_names: List[str]
+    split_args_dict: Dict[str, List[str]]
 
     @staticmethod
     # pyre-ignore[3]
@@ -525,27 +562,52 @@ def create(
         Returns:
             PT2ArgsSet object with the following attributes:
                 split_function_args: List[str] - List of function arguments used in unified lookup and autograd functions
-                    Tensors will be packed and pass as TensorList
-                    e.g., ['at::TensorList momentum1', 'double eps', 'double weight_decay'].
+                    Tensors will be packed and passed as a TensorList. Auxiliary arguments will be packed in a dict.
+                    e.g., ['at::TensorList momentum1', 'at::Dict optim_int'].
                 split_function_arg_names: List[str] - List of argument names used in unified lookup and autograd functions
-                    e.g., ['momentum1', 'eps', 'weight_decay'].
+                    e.g., ['momentum1', 'optim_int', 'optim_float'].
                 split_function_schemas: List[str] - List of arguments used in unified lookup and autograd functions in the schema format
                     e.g., ['Tensor[] momentum1', 'float eps', 'float weight_decay'].
                 split_saved_tensorlist: List[str] - List of tensor names that are packed into tensorlist and will be unpacked in
                     PT2 autograd function. e.g., ['momentum1'].
                 split_saved_tensorlist_optional: List[str] - List of tensor names that are packed into tensorlist but are optional
                     and will be unpacked in PT2 autograd function e.g., ['row_counter'].
+                split_saved_data: List[dict[str, str]] - List of non-tensor arguments that are saved for backward
+                split_unpacked_arg_names: List[str] - List of argument names, unrolled from lists
+                    e.g., ['momentum1', 'eps', 'weight_decay', 'iter'].
+                split_args_dict: Dict[str, List[str]] - Dict of optim arguments' types containing the argument names of that type.
+ e.g., if an optimizer only has an int argument called iter, the dict will look like: + {'optim_tensor': [], 'optim_int': ['iter'], 'optim_float': [], 'optim_bool': []} """ split_function_arg_names = [] split_function_args = [] split_function_schemas = [] split_saved_tensorlist = [] split_saved_tensorlist_optional = [] + split_saved_data = [] + split_variables = [] + split_unpacked_arg_names = [] + has_optim_tensor = False # optim tensors here are optional tensor + has_optim_int = False + has_optim_float = False + has_optim_bool = False + split_args_dict = { + "optim_tensor": [], + "optim_int": [], + "optim_float": [], + "optim_bool": [], + } + # list of symint args to be appended after optim_xxx args + # since they have default values + symint_list: List[OptimItem] = [] + for s in arg_spec: if s.name == "learning_rate_tensor": split_function_arg_names.append(s.name) + split_unpacked_arg_names.append(s.name) split_function_args.append(tensor_arg(s.name)) split_function_schemas.append(tensor_arg(s.name)) + split_variables.append(f"ret.push_back(Variable()); // {s.name}") elif s.ty in ( ArgType.TENSOR, ArgType.INT_TENSOR, @@ -553,31 +615,121 @@ def create( ArgType.PLACEHOLDER_TENSOR, ): name = s.name - split_function_arg_names.append(name) + split_unpacked_arg_names.append(name) if s.is_optional: - split_function_args.append(optional_tensorlist_arg(name)) - split_function_schemas.append(schema_optional_tensorlist_arg(name)) split_saved_tensorlist_optional.append(name) + split_args_dict["optim_tensor"].append(s.name) + has_optim_tensor = True else: split_function_args.append( tensor_list_arg_no_default(name, pass_by_ref=False) ) + split_function_arg_names.append(name) split_function_schemas.append( schema_tensor_list_arg_no_default(name) ) split_saved_tensorlist.append(name) + split_variables.append( + f"ret.push_back(Variable()); // {s.name}_dev or host" + ) + split_variables.append( + f"ret.push_back(Variable()); // {s.name}_placements" + ) + split_variables.append( + f"ret.push_back(Variable()); // {s.name}_offsets" + ) + split_variables.append("if (" + name + "_host.numel() == 0) {") + split_variables.append( + f"ret.push_back(Variable()); // {s.name}_uvm" + ) + split_variables.append("}") else: - split_function_arg_names.append(s.name) - split_function_args.append(make_function_arg(s.ty, s.name, s.default)) - split_function_schemas.append( - make_function_schema_arg(s.ty, s.name, s.default) - ) + if s.ty == ArgType.INT: + # iter is passed in aux_int + if s.name != "iter": + split_args_dict["optim_int"].append(s.name) + split_saved_data.append( + ( + s.name, + f'optim_int[{len(split_args_dict["optim_int"]) - 1}]', + make_ivalue_cast(s.ty), + "int64_t", + ) + ) + has_optim_int = True + elif s.ty == ArgType.SYM_INT: + symint_list.append(s) + split_saved_data.append( + ( + s.name, + "", + make_ivalue_cast(s.ty), + "c10::SymInt", + ) + ) + elif s.ty == ArgType.FLOAT: + split_args_dict["optim_float"].append(s.name) + split_saved_data.append( + ( + s.name, + f'optim_float[{len(split_args_dict["optim_float"])- 1}]', + make_ivalue_cast(s.ty), + "double", + ) + ) + has_optim_float = True + elif s.ty == ArgType.BOOL: + split_args_dict["optim_bool"].append(s.name) + split_saved_data.append( + ( + s.name, + f'optim_bool[{len(split_args_dict["optim_bool"])- 1}]', + make_ivalue_cast(s.ty), + "bool", + ) + ) + has_optim_bool = True + else: + raise ValueError(f"Unsupported type {s.ty}") + split_unpacked_arg_names.append(s.name) + if has_optim_tensor: + split_function_args.append(list_arg("tensor")) + 
split_function_schemas.append(schema_list_arg("tensor")) + split_function_arg_names.append("optim_tensor") + split_variables.append(f"ret.push_back(Variable()); // optim_tensor") + + if has_optim_int: + split_function_args.append(list_arg("int")) + split_function_schemas.append(schema_list_arg("int")) + split_function_arg_names.append("optim_int") + split_variables.append(f"ret.push_back(Variable()); // optim_int") + if has_optim_float: + split_function_args.append(list_arg("float")) + split_function_schemas.append(schema_list_arg("float")) + split_function_arg_names.append("optim_float") + split_variables.append(f"ret.push_back(Variable()); // optim_float") + if has_optim_bool: + split_function_args.append(list_arg("bool")) + split_function_schemas.append(schema_list_arg("bool")) + split_function_arg_names.append("optim_bool") + split_variables.append(f"ret.push_back(Variable()); // optim_bool") + for s in symint_list: + split_function_arg_names.append(s.name) + split_function_args.append(make_function_arg(s.ty, s.name, s.default)) + split_function_schemas.append( + make_function_schema_arg(s.ty, s.name, s.default) + ) + split_variables.append(f"ret.push_back(Variable()); // {s.name}") return PT2ArgsSet( split_function_args=split_function_args, split_function_arg_names=split_function_arg_names, split_function_schemas=split_function_schemas, split_saved_tensorlist=split_saved_tensorlist, split_saved_tensorlist_optional=split_saved_tensorlist_optional, + split_saved_data=split_saved_data, + split_variables=split_variables, + split_unpacked_arg_names=split_unpacked_arg_names, + split_args_dict=split_args_dict, ) @@ -637,12 +789,14 @@ def create( if s.is_optional: has_optional_tensors = True - # Optional tensors are converted to tensor in autograd functions - # Reorganize arguments for wrapper, backend and kernel functions + # Optim arg order: non-optional tensors, learning_rate_tensor, non-tensors, optional tensors + # The optional tensors are converted to Tensor in autograd functions + # Hence, need to reorganize such that the tensors come before non-tensors which have default values values + # This is used in wrapper, backend and kernel functions if has_optional_tensors: - # Arg order: non-optional tensors, learning_rate_tensor, optional tensors as tensors, non-tensors, + # reordered args for split_arg_spec: non-optional tensors, learning_rate_tensor, optional tensors as tensors, non-tensors split_arg_spec = reorder_args(split_arg_spec) - # Arg order: non-optional tensors, optional tensors as tensors, learning rate (float), non-tensors + # reordered args for kernel_split_arg_spec: non-optional tensors, optional tensors as tensors, learning rate (float), non-tensors kernel_split_arg_spec = reorder_args(kernel_split_arg_spec) # Compute placeholder tensor combinations diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp index a6ccbd7ed..6a2669c11 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp @@ -171,10 +171,10 @@ enum SSDTensor { {%- endif %} BT_block_size, max_segment_length_per_warp, - {%- if optimizer != "none" and not dense %} + {%- if not dense %} + {%- if optimizer != "none" %} stochastic_rounding, {%- endif %} - {%- if not dense %} info_B_num_bits, info_B_mask_int64, {%- endif %} @@ -311,42 +311,41 @@ enum SSDTensor { hash_size_cumsum, 
total_hash_size_bits, indices, - {%- if not nobag and dense and not vbe %} + {%- if dense and nobag %} + offsets + {%- else %} offsets, + {%- endif %} + {%- if not nobag %} pooling_mode, indice_weights, + {%- if dense and not vbe %} feature_requires_grad - {%- elif not nobag %} - offsets, - pooling_mode, - indice_weights, - feature_requires_grad, - {%- elif nobag and dense and not vbe %} - offsets {%- else %} - offsets, + feature_requires_grad, {%- endif %} + {%- endif %} {# /* if not nobag */ #} {%- if not dense %} lxu_cache_locations, uvm_cache_stats, - {%- endif %} - {%- if optimizer != "none" and not dense %} + {%- if optimizer != "none" %} gradient_clipping, max_gradient, stochastic_rounding, {%- endif %} + {%- endif %} {# /* if not dense */ #} {%- if vbe %} B_offsets, vbe_output_offsets_feature_rank, vbe_B_offsets_rank_per_feature, max_B, max_B_feature_rank, - {%- endif %} - {%- if vbe and not dense %} - vbe_output_size, - {%- elif vbe and dense %} + {%- if dense %} vbe_output_size - {%- endif %} + {%- else %} + vbe_output_size, + {%- endif %} {# /* if dense */ #} + {%- endif %} {# /* if vbe */ #} {%- if not dense %} is_experimental, use_uniq_cache_locations_bwd, @@ -359,12 +358,12 @@ enum SSDTensor { iter, {%- endif %} gwd_lower_bound, - {%- endif %} + {%- endif %} {# /* if is_gwd */ #} {%- if ssd %} ssd_tensors.value(), {%- endif %} {{ args.split_function_arg_names_autograd | join(", ") }} - {%- endif %} + {%- endif %} {# /* if not dense */ #} )[0]; {%- endmacro %} @@ -577,20 +576,19 @@ class {{ autograd_func }} : const Tensor& hash_size_cumsum, const int64_t total_hash_size_bits, const Tensor& indices, - {%- if not nobag and dense and not vbe %} + {%- if dense and nobag %} + const Tensor& offsets + {%- else %} const Tensor& offsets, + {%- endif %} + {%- if not nobag %} const int64_t pooling_mode, const std::optional& indice_weights, + {%- if dense and not vbe %} const std::optional& feature_requires_grad - {%- elif not nobag %} - const Tensor& offsets, - const int64_t pooling_mode, - const std::optional& indice_weights, - const std::optional& feature_requires_grad, - {%- elif nobag and dense and not vbe %} - const Tensor& offsets {%- else %} - const Tensor& offsets, + const std::optional& feature_requires_grad, + {%- endif %} {%- endif %} {%- if not dense %} const Tensor& lxu_cache_locations, @@ -619,7 +617,7 @@ class {{ autograd_func }} : const int64_t iter, {%- endif %} const double gwd_lower_bound, - {%- endif %} + {%- endif %} {#-/* if is_gwd */#} {%- if ssd %} const at::TensorList& ssd_tensors, {%- endif %} @@ -633,14 +631,13 @@ class {{ autograd_func }} : const c10::SymInt max_B_feature_rank, const c10::SymInt vbe_output_size {%- endif %} - {%- endif %}) { + {%- endif %} {# /* if not dense */ #}) { const auto T = weights_offsets.sym_numel(); {%- if vbe %} const auto B_offsets_ = B_offsets.value_or(Tensor()); const auto vbe_output_offsets_feature_rank_ = vbe_output_offsets_feature_rank.value_or(Tensor()); const auto vbe_B_offsets_rank_per_feature_ = vbe_B_offsets_rank_per_feature.value_or(Tensor()); - const c10::SymInt max_B_ = max_B; {%- else %} const auto max_B_ = offsets.sym_size(0) / T; diff --git a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp index b224c3e70..f4ab3e886 100644 --- a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp +++ b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp @@ -92,7 +92,7 
@@ enum SSDTensor { const Tensor& /*weights_dev*/, {%- if not dense %} const Tensor& /*weights_uvm*/, - const Tensor& /*lxu_cache_weights*/, + const Tensor& /*weights_lxu_cache*/, const Tensor& /*weights_placements*/, {%- endif %} const Tensor& /*weights_offsets*/, @@ -136,7 +136,7 @@ enum SSDTensor { weights_host, flatten_weights_dev, weights_uvm, - lxu_cache_weights, + weights_lxu_cache, weights_placements, weights_offsets, {%- if nobag %} @@ -155,7 +155,7 @@ enum SSDTensor { {%- endif %} {# /* if not nobag */ #} {%- if not dense %} {{ "ssd_tensors[SSDTensor::ROW_ADDRS]" if ssd else "lxu_cache_locations" }}, - uvm_cache_stats_, + uvm_cache_stats, {%- endif %} {%- if not nobag %} {%- if vbe %} @@ -214,6 +214,7 @@ enum SSDTensor { {%- else %} const Tensor& /*D_offsets*/, const c10::SymInt /*max_D*/, + const bool /*mixed_D*/, {%- endif %} const Tensor& /*hash_size_cumsum*/, const int64_t /*total_hash_size_bits*/, @@ -261,17 +262,22 @@ enum SSDTensor { grad_weights_dev = embedding_codegen{{ wdesc }}_backward_op.call( grad_output, + {% if dense %} + dev_weights, + {% else %} weights_host, weights_dev, weights_uvm, - lxu_cache_weights, + weights_lxu_cache, weights_placements, weights_offsets, + {% endif %} {% if nobag %} D, {%- else %} D_offsets, max_D, + mixed_D, {%- endif %} {# /* if nobag */ #} hash_size_cumsum, total_hash_size_bits, @@ -283,16 +289,18 @@ enum SSDTensor { {%- endif %} {# /* if not nobag */ #} {%- if ssd %} ssd_row_addrs, - {%- else %} + {%- elif not dense %} lxu_cache_locations, {%- endif %} BT_block_size, max_segment_length_per_warp, + {%- if not dense %} {%- if optimizer != "none" %} stochastic_rounding, {%- endif %} info_B_num_bits, info_B_mask_int64, + {%- endif %} {# /* if not dense */ #} {%- if vbe %} B_offsets, vbe_row_output_offsets, @@ -311,7 +319,11 @@ enum SSDTensor { {%- endif %} gwd_lower_bound, {%- endif %} {# /* if is_gwd */ #} + {%- if dense %} + /*unused=*/0 + {%- else %} {{ args_pt2.split_function_arg_names | join(", ") }} + {%- endif %} {%- if not nobag %} , output_dtype {%- endif %} @@ -321,73 +333,65 @@ enum SSDTensor { record_trace->record.end(); } - return { - {%- if not dense %} - Tensor(), // placeholder autograd tensor - {%- endif %} - Variable(), // output_dtype - Variable(), // weights_host - grad_weights_dev, // weights_dev - {%- if not dense %} - Variable(), // weights_uvm - Variable(), // lxu_cache_weights - Variable(), // weights_placements - {%- endif %} - Variable(), // weights_offsets - {%- if nobag %} - Variable(), // D - {%- else %} - Variable(), // D_offsets - Variable(), // total_D - Variable(), // max_D - {%- endif %} - Variable(), // hash_size_cumsum - Variable(), //total_hash_size_bits - Variable(), // indices - Variable(), // offsets - {%- if not nobag %} - Variable(), // pooling_mode - grad_indice_weights, // indice_weights - Variable(), // feature_requires_grad - {%- endif %} - {%- if not dense %} - Variable(), // lxu_cache_locations - Variable(), // uvm_cache_stats - {%- endif %} - {%- if optimizer != "none" and not dense %} - Variable(), // gradient_clipping - Variable(), // max_gradient - Variable(), // stochastic_rounding - {%- endif %} - {%- if vbe %} - Variable(), // B_offsets - Variable(), // vbe_output_offsets_feature_rank - Variable(), // vbe_B_offsets_rank_per_feature - Variable(), // max_B - Variable(), // max_B_feature_rank - Variable(), // vbe_output_size - {%- endif %} - {%- if not dense %} - Variable(), // is_experimental - Variable(), // use_uniq_cache_locations_bwd - Variable(), // use_homogeneous_placements - 
{%- endif %} - {%- if is_gwd %} - {%- if "prev_iter_dev" not in args_pt2.split_function_arg_names %} - Variable(), // prev_iter_dev - {%- endif %} - {%- if "iter" not in args_pt2.split_function_arg_names %} - Variable(), // iter - {%- endif %} - Variable(), // gwd_lower_bound - {%- endif %} - {%- if ssd %} - {%- for tensor in ssd_tensors %} - Variable(), // {{ tensor }} - {%- endfor %} - {%- endif %} - {{ args_pt2.split_variables | join(", ") }} - }; + // Number of returned gradients have to match the input to Autograd's forward + // The number of items in the tensorlist differ between devices and is determined at runtime + std::vector ret; + + {%- if not dense %} + ret.push_back(Variable()); // placeholder autograd tensor + {%- endif %} + ret.push_back(Variable()); // output_dtype + {%- if not dense %} + if (weights_host.numel() > 0) { + ret.push_back(Tensor()); // host_weights + } + else { + ret.push_back(grad_weights_dev); // dev_weights + ret.push_back(Variable()); // weights_uvm + ret.push_back(Variable()); // weights_lxu_cache + } + ret.push_back(Variable()); // weights_placement + {%- endif %} + ret.push_back(Variable()); // weights_offsets + {%- if nobag %} + ret.push_back(Variable()); // D + {%- else %} + ret.push_back(Variable()); // D_offsets + ret.push_back(Variable()); // total_D + ret.push_back(Variable()); // max_D + {%- endif %} + ret.push_back(Variable()); // hash_size_cumsum + ret.push_back(Variable()); // total_hash_size_bits + ret.push_back(Variable()); // indices + ret.push_back(Variable()); // offsets + {%- if not nobag %} + ret.push_back(Variable()); // pooling_mode + ret.push_back(grad_indice_weights); // indice_weights + ret.push_back(Variable()); // feature_requires_grad + {%- endif %} + {%- if vbe %} + {%- if dense %} + ret.push_back(Variable()); // B_offsets + ret.push_back(Variable()); // vbe_output_offsets_feature_rank + ret.push_back(Variable()); // vbe_B_offsets_rank_per_feature + {%- endif %} {# /* if dense */ #} + ret.push_back(Variable()); // max_B + ret.push_back(Variable()); // max_B_feature_rank + ret.push_back(Variable()); // vbe_output_size + {%- endif %} {# /* if vbe */ #} + {%- if not dense %} + ret.push_back(Variable()); // aux_tensor + ret.push_back(Variable()); // aux_int + ret.push_back(Variable()); // aux_float + ret.push_back(Variable()); // aux_bool + {%- endif %} + {%- if ssd %} + {%- for tensor in ssd_tensors %} + ret.push_back(Variable()); // {{ tensor }} + {%- endfor %} + {%- endif %} + {{ args_pt2.unified_pt2.split_variables | join("\n") }} + return ret; {%- endmacro %} /* This macro generates a code blob that calls corresponding autograd function @@ -407,9 +411,11 @@ enum SSDTensor { placeholder_autograd_tensor, {%- endif %} output_dtype, + {%- if dense %} + dev_weights, + weights_offsets, + {%- else %} weights, - {%- if not dense %} - lxu_cache_weights, {%- endif %} {%- if nobag %} max_D, @@ -421,51 +427,35 @@ enum SSDTensor { hash_size_cumsum, total_hash_size_bits, indices, - {%- if not nobag and dense and not vbe %} + {%- if dense and nobag %} + offsets + {%- else %} offsets, + {%- endif %} + {%- if not nobag %} pooling_mode, indice_weights, + {%- if dense and not vbe %} feature_requires_grad - {%- elif not nobag %} - offsets, - pooling_mode, - indice_weights, - feature_requires_grad, - {%- elif nobag and dense and not vbe %} - offsets {%- else %} - offsets, - {%- endif %} - {%- if not dense %} - lxu_cache_locations, - uvm_cache_stats, + feature_requires_grad, {%- endif %} - {%- if optimizer != "none" and not dense %} - 
gradient_clipping, - max_gradient, - stochastic_rounding, {%- endif %} {%- if vbe %} + {%- if dense %} B_offsets, vbe_output_offsets_feature_rank, vbe_B_offsets_rank_per_feature, + {%- endif %} {# /* if dense */ #} max_B, max_B_feature_rank, vbe_output_size, - {%- endif %} + {%- endif %} {# /* if vbe */ #} {%- if not dense %} - is_experimental, - use_uniq_cache_locations_bwd, - use_homogeneous_placements, - {%- if is_gwd %} - {%- if "prev_iter_dev" not in args_pt2.split_function_arg_names %} - prev_iter_dev, - {%- endif %} - {%- if "iter" not in args_pt2.split_function_arg_names %} - iter, - {%- endif %} - gwd_lower_bound, - {%- endif %} + aux_tensor, + aux_int, + aux_float, + aux_bool, {%- if ssd %} ssd_tensors.value(), {%- endif %} @@ -474,39 +464,57 @@ enum SSDTensor { )[0]; {%- endmacro %} -/* This macro generates a code blob for unpacking the tensor list +/* This macro generates a code blob for unpacking TensorList */ {%- macro unpack_tensorlist(name) %} - const Tensor {{ name }}_host = {{ name }}[0]; - const Tensor {{ name }}_dev = {{ name }}[1]; - const Tensor {{ name }}_uvm = {{ name }}[2]; - const Tensor {{ name }}_placements = {{ name }}[3]; - const Tensor {{ name }}_offsets = {{ name }}[4]; -{%- endmacro %} - -{%- macro unpack_tensorlist_optional(name) %} Tensor {{ name }}_host; Tensor {{ name }}_dev; Tensor {{ name }}_uvm; Tensor {{ name }}_placements; Tensor {{ name }}_offsets; - if ({{ name }}.has_value()) { - at::TensorList _{{ name }} = {{ name }}.value(); - {{ name }}_host = _{{ name }}[0]; - {{ name }}_dev = _{{ name }}[1]; - {{ name }}_uvm = _{{ name }}[2]; - {{ name }}_placements = _{{ name }}[3]; - {{ name }}_offsets = _{{ name }}[4]; + {%- if name == "weights" %} + Tensor {{ name }}_lxu_cache; + {%- endif %} + + if ({{ name }}.size() == 3) { + {{ name }}_host = {{ name }}[0]; + {{ name }}_placements = {{ name }}[1]; + {{ name }}_offsets = {{ name }}[2]; + } + else if ({{ name }}.size() == {{ 5 if name == "weights" else 4 }}) { + {{ name }}_dev = {{ name }}[0]; + {{ name }}_uvm = {{ name }}[1]; + {{ name }}_placements = {{ name }}[2]; + {{ name }}_offsets = {{ name }}[3]; + {%- if name == "weights" %} + {{ name }}_lxu_cache = {{ name }}[4]; + {%- endif %} } - else{ - {{ name }}_host = at::empty({0}, weights_host.options()); - {{ name }}_dev = at::empty({0}, weights_dev.options()); - {{ name }}_uvm = at::empty({0}, weights_uvm.options()); - {{ name }}_placements = at::empty({0}, weights_placements.options()); - {{ name }}_offsets = at::empty({0}, weights_offsets.options()); + else { + TORCH_CHECK(false, "Invalid size of {{ name }}, expected 3 for CPU or {{ 5 if name == "weights" else 4 }} for CUDA but got ", {{ name }}.size()); } {%- endmacro %} +/* This macro generates a code blob for unpacking a list of optional tensors + We cannot do list of optional tensorlist. We need to pack optimizer optional tensors in a flatten manner. + For readability and programmability, we pass all unified args (i.e., 5 items), as opposed to passing per device (like above) + which needs to be determined at runtime. +*/ +{%- macro unpack_tensorlist_optional(name, arg_index) %} + + {%- set idx = arg_index * 5 %} + at::TensorOptions options = weights_host.numel() > 0 ? weights_host.options() : weights_dev.options(); + Tensor {{ name }}_host = optim_tensor[{{ idx }}].has_value() ? optim_tensor[{{ idx }}].value() : at::empty({0}, options); + {%- set idx = arg_index * 5 + 1 %} + Tensor {{ name }}_dev = optim_tensor[{{ idx }}].has_value() ? 
optim_tensor[{{ idx }}].value() : at::empty({0}, options); + {%- set idx = arg_index * 5 + 2 %} + Tensor {{ name }}_uvm = optim_tensor[{{ idx }}].has_value() ? optim_tensor[{{ idx }}].value() : at::empty({0}, options); + {%- set idx = arg_index * 5 + 3 %} + Tensor {{ name }}_placements = optim_tensor[{{ idx }}].has_value() ? optim_tensor[{{ idx }}].value() : at::empty({0}, weights_placements.options()); + {%- set idx = arg_index * 5 + 4 %} + Tensor {{ name }}_offsets = optim_tensor[{{ idx }}].has_value() ? optim_tensor[{{ idx }}].value() : at::empty({0}, weights_offsets.options()); +{%- endmacro %} + //////////////////////////////////////////////////////////////////////////////// // Autograd Function Declarations @@ -552,10 +560,16 @@ class {{ autograd_func }} : static constexpr bool is_traceable = true; static torch::autograd::variable_list forward( torch::autograd::AutogradContext* ctx, + {%- if not dense %} const Tensor& placeholder_autograd_tensor, + {%- endif %} const int64_t output_dtype, + {%- if dense %} + const Tensor& dev_weights, + const Tensor& weights_offsets, + {%- else %} const at::TensorList weights, - const Tensor& lxu_cache_weights, + {%- endif %} {%- if not nobag %} const Tensor& D_offsets, const c10::SymInt total_D, @@ -572,32 +586,21 @@ class {{ autograd_func }} : const std::optional& indice_weights, const std::optional& feature_requires_grad, {%- endif %} - const Tensor& lxu_cache_locations, - std::optional uvm_cache_stats, - {%- if optimizer != "none" %} - const bool gradient_clipping, - const double max_gradient, - const bool stochastic_rounding, - {%- endif %} {%- if vbe %} + {%- if dense %} const std::optional& B_offsets, const std::optional& vbe_output_offsets_feature_rank, const std::optional& vbe_B_offsets_rank_per_feature, + {%- endif %} {# /* if dense */ #} const c10::SymInt max_B, const c10::SymInt max_B_feature_rank, const c10::SymInt vbe_output_size, - {%- endif %} - const bool is_experimental, - const bool use_uniq_cache_locations_bwd, - const bool use_homogeneous_placements, - {%- if is_gwd %} - {%- if "prev_iter_dev" not in args_pt2.split_function_arg_names %} - const std::optional& prev_iter_dev, - {%- endif %} - {%- if "iter" not in args_pt2.split_function_arg_names %} - const int64_t iter, - {%- endif %} - const double gwd_lower_bound, + {%- endif %} {# /* if vbe */ #} + {%- if not dense %} + std::vector> aux_tensor, + std::vector aux_int, + std::vector aux_float, + c10::List aux_bool, {%- endif %} {%- if ssd %} const at::TensorList& ssd_tensors, @@ -609,16 +612,19 @@ class {{ autograd_func }} : {%- for arg_name in args_pt2.unified_pt2.split_saved_tensorlist %} {{ unpack_tensorlist(arg_name) }} {%- endfor %} + {%- if "optim_tensor" in args_pt2.unified_pt2.split_function_arg_names %} + TORCH_CHECK(optim_tensor.size() % 5 == 0); + {%- endif %} {%- for arg_name in args_pt2.unified_pt2.split_saved_tensorlist_optional %} - {{ unpack_tensorlist_optional(arg_name) }} + {{ unpack_tensorlist_optional(arg_name, loop.index0) }} {%- endfor %} const auto T = weights_offsets.sym_numel(); - {%- if vbe %} - const auto B_offsets_ = B_offsets.value_or(Tensor()); - const auto vbe_output_offsets_feature_rank_ = vbe_output_offsets_feature_rank.value_or(Tensor()); - const auto vbe_B_offsets_rank_per_feature_ = vbe_B_offsets_rank_per_feature.value_or(Tensor()); + {%- if vbe %} + const auto B_offsets_ = aux_tensor[0].has_value() ? aux_tensor[0].value() : Tensor(); + const auto vbe_output_offsets_feature_rank_ = aux_tensor[1].has_value() ? 
aux_tensor[1].value() : Tensor(); // .at("vbe_output_offsets_feature_rank") + const auto vbe_B_offsets_rank_per_feature_ = aux_tensor[2].has_value() ? aux_tensor[2].value() : Tensor(); // .at("vbe_B_offsets_rank_per_feature") const c10::SymInt max_B_ = max_B; {%- else %} const auto max_B_ = offsets.sym_size(0) / T; @@ -627,7 +633,7 @@ class {{ autograd_func }} : // Annotate Kineto trace const static bool is_annotate_trace_enabled = config::is_feature_enabled( config::FeatureGateName::TBE_ANNOTATE_KINETO_TRACE); - std::string op_annotation = ""; + at::string op_annotation = ""; c10::intrusive_ptr record_trace; if (is_annotate_trace_enabled) { std::stringstream ss; @@ -648,11 +654,15 @@ class {{ autograd_func }} : "{{ fwd_mdesc }}_tbe_fwd" + op_annotation); ctx->saved_data["op_annotation"] = op_annotation; } - + {%- if not dense %} // NOTE: The `local_uvm_cache_stats` variable held by the nn.Module has dtype int32_t // TODO: Hook up with frontend code - const auto uvm_cache_stats_ = uvm_cache_stats - .value_or(at::empty({0}, weights_uvm.options().dtype(at::kInt))); + at::TensorOptions uvm_options = weights_host.numel() > 0 ? weights_host.options() : weights_dev.options(); + const auto uvm_cache_stats = aux_tensor[4].has_value() ? aux_tensor[4].value() : at::empty({0}, uvm_options.dtype(at::kInt)); // .at("uvm_cache_stats") + TORCH_CHECK(aux_tensor[3].has_value(), "lxu_cache_locations should have value."); + const auto lxu_cache_locations = aux_tensor[3].value(); // .at("lxu_cache_locations") + const auto is_experimental = aux_bool[0]; // .at("is_experimental_tbe") + {%- endif %} // Default values for Dynamo tracing // SymInt does not support bitshifts operator @@ -709,20 +719,29 @@ class {{ autograd_func }} : {%- endif %} // vbe {%- if is_gwd %} - const auto prev_iter_dev_ = prev_iter_dev.value_or(Tensor()); + {%- if "prev_iter" in args_pt2.unified_pt2.split_function_arg_names %} + const auto prev_iter_dev_ = prev_iter_dev.has_value() ? prev_iter_dev.value() : Tensor(); + {%- else %} + const auto prev_iter_dev_ = aux_tensor[5].has_value() ? aux_tensor[5].value() : Tensor(); // .at("prev_iter_dev") + {%- endif %} {%- endif %} {%- if not nobag %} - const auto indice_weights_value = indice_weights.value_or(Tensor()); + const auto indice_weights_value = indice_weights.has_value() ? 
indice_weights.value() : Tensor(); {%- endif %} ctx->save_for_backward({ + {%- if dense %} + dev_weights, + weights_offsets, + {%- else %} weights_host, weights_dev, weights_uvm, - lxu_cache_weights, + weights_lxu_cache, weights_placements, weights_offsets, + {%- endif %} {%- if not nobag %} D_offsets, {%- endif %} @@ -733,7 +752,9 @@ class {{ autograd_func }} : indice_weights_value, feature_requires_grad.value_or(Tensor()), {%- endif %} + {%- if not dense %} lxu_cache_locations, + {%- endif %} {%- if vbe %} B_offsets_, vbe_row_output_offsets, @@ -752,35 +773,50 @@ class {{ autograd_func }} : {%- if not nobag %} ctx->saved_data["max_D"] = max_D; + {%- if optimizer == "none" %} + ctx->saved_data["mixed_D"] = (bool) aux_bool[4]; + {%- else %} + ctx->saved_data["mixed_D"] = (bool) aux_bool[6]; + {%- endif %} ctx->saved_data["pooling_mode"] = pooling_mode; {%- else %} ctx->saved_data["D"] = D; {%- endif %} ctx->saved_data["total_hash_size_bits"] = total_hash_size_bits; - {%- if optimizer != "none" %} - ctx->saved_data["gradient_clipping"] = gradient_clipping; - ctx->saved_data["max_gradient"] = max_gradient; - ctx->saved_data["stochastic_rounding"] = stochastic_rounding; + {%- if optimizer != "none" and not dense %} + ctx->saved_data["gradient_clipping"] = (bool)aux_bool[4]; // .at("gradient_clipping"); + ctx->saved_data["max_gradient"] = aux_float[1]; // .at("max_gradient"); + ctx->saved_data["stochastic_rounding"] = (bool)aux_bool[5]; // .at("stochastic_rounding"); {%- endif %} {#-/* if optimizer != "none" */#} ctx->saved_data["info_B_num_bits"] = info_B_num_bits; const auto info_B_mask_int64 = static_cast(info_B_mask); ctx->saved_data["info_B_mask"] = info_B_mask_int64; - ctx->saved_data["use_uniq_cache_locations_bwd"] = use_uniq_cache_locations_bwd; - ctx->saved_data["use_homogeneous_placements"] = use_homogeneous_placements; - {%- if is_gwd %} - {%- if "iter" not in args_pt2.split_function_arg_names %} - ctx->saved_data["iter"] = iter; + {%- if not dense %} + ctx->saved_data["use_uniq_cache_locations_bwd"] = (bool)aux_bool[1]; + ctx->saved_data["use_homogeneous_placements"] = (bool)aux_bool[2]; // .at("use_homogeneous_placements") {%- endif %} + const auto iter = aux_int[0]; // .at("iter") + ctx->saved_data["iter"] = iter; + {%- if is_gwd %} + const auto gwd_lower_bound = aux_float[0]; //.at("gwd_lower_bound") ctx->saved_data["gwd_lower_bound"] = gwd_lower_bound; {%- endif %} {%- if not nobag %} ctx->saved_data["output_dtype"] = output_dtype; {%- endif %} - {%- for (var, _) in args_pt2.saved_data %} + {%- if not dense %} + // unpack optim args + {%- for (var, dict_val, _, type) in args_pt2.unified_pt2.split_saved_data %} + {%- if type == "bool" %} + bool {{ var }} = {{ dict_val }}; + {%- elif type != "c10::SymInt" %} + auto {{ var }} = {{ dict_val }}; + {%- endif %} ctx->saved_data["{{ var }}"] = {{ var }}; {%- endfor %} + {%- endif %} {%- if optimizer == "none" %} // Flatten @@ -827,12 +863,17 @@ static torch::autograd::variable_list backward( torch::autograd::variable_list grad_outputs) { const auto saved = ctx->get_saved_variables(); auto savedItr = std::begin(saved); + {%- if dense %} + auto dev_weights = *savedItr++; + auto weights_offsets = *savedItr++; + {%- else %} auto weights_host = *savedItr++; auto weights_dev = *savedItr++; auto weights_uvm = *savedItr++; - auto lxu_cache_weights = *savedItr++; + auto weights_lxu_cache = *savedItr++; auto weights_placements = *savedItr++; auto weights_offsets = *savedItr++; + {%- endif %} {%- if not nobag %} auto D_offsets = *savedItr++; {%- 
endif %} @@ -843,7 +884,9 @@ static torch::autograd::variable_list backward( auto indice_weights = *savedItr++; auto feature_requires_grad = *savedItr++; {%- endif %} + {%- if not dense %} auto lxu_cache_locations = *savedItr++; + {%- endif %} {%- if vbe %} auto B_offsets = *savedItr++; auto vbe_row_output_offsets = *savedItr++; @@ -864,35 +907,39 @@ static torch::autograd::variable_list backward( {%- if not nobag %} auto max_D = ctx->saved_data["max_D"].toInt(); + const auto mixed_D = ctx->saved_data["mixed_D"].toBool(); auto pooling_mode = ctx->saved_data["pooling_mode"].toInt(); {%- else %} auto D = ctx->saved_data["D"].toInt(); {%- endif %} auto total_hash_size_bits = ctx->saved_data["total_hash_size_bits"].toInt(); - {%- if optimizer != "none" %} + {%- if optimizer != "none" and not dense %} auto gradient_clipping = ctx->saved_data["gradient_clipping"].toBool(); auto max_gradient = ctx->saved_data["max_gradient"].toDouble(); auto stochastic_rounding = ctx->saved_data["stochastic_rounding"].toBool(); {%- endif %} {#-/* if optimizer != "none" */#} [[maybe_unused]] const int32_t info_B_num_bits = ctx->saved_data["info_B_num_bits"].toInt(); [[maybe_unused]] const int64_t info_B_mask_int64 = ctx->saved_data["info_B_mask"].toInt(); + {%- if not dense %} const auto use_uniq_cache_locations_bwd = ctx->saved_data["use_uniq_cache_locations_bwd"].toBool(); const auto use_homogeneous_placements = ctx->saved_data["use_homogeneous_placements"].toBool(); - {%- if is_gwd %} - {%- if "iter" not in args_pt2.split_function_arg_names %} + {%- endif %} + {%- if is_gwd or "iter" in args_pt2.unified_pt2.split_unpacked_arg_names %} const auto iter = ctx->saved_data["iter"].toInt(); {%- endif %} + {%- if is_gwd %} const auto gwd_lower_bound = ctx->saved_data["gwd_lower_bound"].toDouble(); {%- endif %} {%- if not nobag %} auto output_dtype = ctx->saved_data["output_dtype"].toInt(); {%- endif %} - - {%- for (var, ivalue_cast) in args_pt2.saved_data %} + {%- if not dense %} + {%- for (var, _ , ivalue_cast, type) in args_pt2.unified_pt2.split_saved_data %} auto {{ var }} = ctx->saved_data["{{ var }}"].{{ ivalue_cast }}(); {%- endfor %} + {%- endif %} const static bool is_annotate_trace_enabled = config::is_feature_enabled( config::FeatureGateName::TBE_ANNOTATE_KINETO_TRACE); @@ -914,7 +961,7 @@ static torch::autograd::variable_list backward( #endif using torch::autograd::Variable; - {%- if optimizer != "none" %} + {%- if optimizer != "none" and not dense %} auto grad_output = gradient_clipping ? 
clamp(grad_outputs[0], -max_gradient, max_gradient) : grad_outputs[0]; {%- else %} auto& grad_output = grad_outputs[0]; @@ -948,19 +995,24 @@ static torch::autograd::variable_list backward( .findSchemaOrThrow("fbgemm::{{ grad_indice_weights_op }}", "") .typed& indice_weights, const std::optional& feature_requires_grad, - const Tensor& lxu_cache_locations, - {%- if optimizer != "none" %} - const bool gradient_clipping, - const double max_gradient, - const bool stochastic_rounding, + const int64_t output_dtype, + {%- if not dense %} + std::vector> aux_tensor, + std::vector aux_int, + std::vector aux_float, + c10::List aux_bool, {%- endif %} {{ args_pt2.unified_pt2.split_function_args | join(", ") }}, - const int64_t output_dtype = static_cast(SparseType::FP32), - const std::optional& B_offsets = std::nullopt, - const std::optional& vbe_output_offsets_feature_rank = std::nullopt, - const std::optional& vbe_B_offsets_rank_per_feature = std::nullopt, const c10::SymInt max_B = -1, const c10::SymInt max_B_feature_rank = -1, - const c10::SymInt vbe_output_size = -1, - const bool is_experimental_tbe = false, // formerly named is_experimental - const bool use_uniq_cache_locations_bwd = false, - const bool use_homogeneous_placements = false, - const std::optional& uvm_cache_stats = std::nullopt, - {%- if "prev_iter_dev" not in args_pt2.split_function_arg_names %} - const std::optional& prev_iter_dev = std::nullopt, - {%- endif %} - {%- if "iter" not in args_pt2.split_function_arg_names %} - const int64_t iter = 0, - {%- endif %} - const bool apply_global_weight_decay = false, {%- if ssd %} - const std::optional& ssd_tensors = std::nullopt, + const c10::SymInt vbe_output_size = -1, + const std::optional& ssd_tensors = std::nullopt + {%- else %} + const c10::SymInt vbe_output_size = -1 {%- endif %} - const double gwd_lower_bound = 0 ) { + {%- if has_gpu_support or has_cpu_support %} {%- if not dense %} @@ -1097,7 +1141,7 @@ Tensor {{ bwd_mdesc }}_embedding_codegen_lookup_{{ optimizer }}_function_pt2( static auto is_tbev2_enabled = config::is_feature_enabled(config::FeatureGateName::TBE_V2); // Set to experimental if either the feature is enabled in JK, or the user specifies to use TBEv2 - const auto is_experimental = is_tbev2_enabled || is_experimental_tbe; + aux_bool[0] = is_tbev2_enabled || aux_bool[0]; // .at("is_experimental_tbe") {%- endif %} {%- if ssd %} @@ -1108,10 +1152,12 @@ Tensor {{ bwd_mdesc }}_embedding_codegen_lookup_{{ optimizer }}_function_pt2( {%- if has_vbe_support %} // has vbe support and on gpu - if (B_offsets.has_value()) { + if (aux_tensor[0].has_value()) { // .at("B_offsets") {%- if has_global_weight_decay_support and not ssd %} // vbe and has gwd support - if (apply_global_weight_decay && weight_decay > 0) { + // if weight_decay arg is not passed or < 0 even though apply_global_weight_decay is True, we don't do gwd + // TODO: add check to ensure weight decay exists + if (aux_bool[3] && optim_float[{{args_pt2.unified_pt2.split_args_dict["optim_float"].index("weight_decay")}}] > 0) { // .at("apply_global_weight_decay") && .at("weight_decay") {{ call_autograd(nobag=False, vbe=True, is_gwd=True) }} } {%- endif %} {#-/* if has_global_weight_decay_support */ #} @@ -1122,7 +1168,7 @@ Tensor {{ bwd_mdesc }}_embedding_codegen_lookup_{{ optimizer }}_function_pt2( {%- if has_global_weight_decay_support and not ssd %} // has gwd support - if (apply_global_weight_decay && weight_decay > 0) { + if (aux_bool[3] && 
optim_float[{{args_pt2.unified_pt2.split_args_dict["optim_float"].index("weight_decay")}}] > 0) { // .at("apply_global_weight_decay") && .at("weight_decay") // not vbe and gwd {{ call_autograd(nobag=False, vbe=False, is_gwd=True) }} } @@ -1146,11 +1192,15 @@ Tensor {{ bwd_mdesc }}_embedding_codegen_lookup_{{ optimizer }}_function_pt2( TORCH_LIBRARY_FRAGMENT(fbgemm, m) { - {%- set op_name = "{}_embedding_codegen_lookup_{}_function_pt2".format(bwd_mdesc, optimizer) %} + {%- set op_name = "{}_embedding_codegen_lookup_{}_function_pt2".format(fwd_mdesc, optimizer) %} m.def("{{ op_name }}(" + {%- if dense %} + " Tensor dev_weights, " + " Tensor weights_offsets, " + {%- else %} " Tensor placeholder_autograd_tensor, " " Tensor[] weights, " - " Tensor lxu_cache_weights, " + {%- endif %} " Tensor D_offsets, " " SymInt total_D, " " SymInt max_D, " @@ -1161,35 +1211,22 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { " int pooling_mode, " " Tensor? indice_weights, " " Tensor? feature_requires_grad, " - " Tensor lxu_cache_locations, " - {%- if optimizer != "none" %} - " bool gradient_clipping, " - " float max_gradient, " - " bool stochastic_rounding, " - {%- endif %} + " int output_dtype, " + {%- if not dense %} + " Tensor?[] aux_tensor, " + " int[] aux_int, " + " float[] aux_float, " + " bool[] aux_bool, " " {{ args_pt2.unified_pt2.split_function_schemas | join(", ") }}, " - " int output_dtype=0, " - " Tensor? B_offsets=None, " - " Tensor? vbe_output_offsets_feature_rank=None, " - " Tensor? vbe_B_offsets_rank_per_feature=None, " " SymInt max_B=-1, " " SymInt max_B_feature_rank=-1, " + {%- if ssd %} " SymInt vbe_output_size=-1, " - " bool is_experimental_tbe=False, " - " bool use_uniq_cache_locations_bwd=False, " - " bool use_homogeneous_placements=False, " - " Tensor? uvm_cache_stats=None," - {%- if "prev_iter_dev" not in args_pt2.split_function_arg_names %} - " Tensor? prev_iter_dev=None, " - {%- endif %} - {%- if "iter" not in args_pt2.split_function_arg_names %} - " int iter=0, " + " Tensor[]? ssd_tensors=None" + {%- else %} + " SymInt vbe_output_size=-1 " {%- endif %} - " bool apply_global_weight_decay=False, " - {%- if ssd %} - " Tensor[]? 
ssd_tensors=None," {%- endif %} - " float gwd_lower_bound=0 " ") -> Tensor", {PT2_COMPLIANT_TAG}); // We're playing a funny trick here: we're using the autograd diff --git a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_cpu_wrapper_template.cpp b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_cpu_wrapper_template.cpp index 7c405d4f9..807fc94d5 100644 --- a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_cpu_wrapper_template.cpp +++ b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_cpu_wrapper_template.cpp @@ -175,6 +175,7 @@ Tensor split_embedding_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_p const Tensor& weights_offsets, const Tensor& D_offsets, const int64_t max_D, + const bool mixed_D, const Tensor& hash_size_cumsum, const int64_t total_hash_size_bits, const Tensor& indices, @@ -207,19 +208,19 @@ Tensor split_embedding_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_p torch::Dispatcher::singleton() .findSchemaOrThrow("fbgemm::{{ backward_op }}", "") .typedH sync point since it will be used as float in the kernel # this fails fake_tensor test as the test expects all tensors to be on the same device "test_pt2_compliant_tag_fbgemm_split_embedding_codegen_lookup_rowwise_adagrad_function": [ @@ -273,18 +294,6 @@ def execute_forward_( # noqa C901 use_experimental_tbe=use_experimental_tbe, ) - if not use_cpu and torch.cuda.is_available(): - # NOTE: Test TorchScript-compatible! - try: - # Occasionally, we run into the following error when running - # against PyTorch nightly: - # - # RuntimeError: Can't redefine method: - # forward on class: __torch__.fbgemm_gpu.split_table_batched_embeddings_ops_training.___torch_mangle_0.SplitTableBatchedEmbeddingBagsCodegen (of Python compilation unit at: 0x5e74890) - cc = torch.jit.script(cc) - except Exception as e: - print(f"Torch JIT compilation failed: {e}") - for t in range(T): cc.split_embedding_weights()[t].data.copy_( bs[t].weight
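
Note (illustrative, not part of the patch): the sketch below shows how the new list_arg/schema_list_arg mappings and split_args_dict from optimizer_args.py fit together for a hypothetical optimizer with a momentum1 tensor, an eps float, and an iter int. The type-to-string mappings are copied from the diff above; the arg_spec and the standalone functions are assumptions for illustration only.

def list_arg(ty: str) -> str:
    # Mirrors the C++ argument strings emitted by the patched codegen.
    return {
        "tensor": "std::vector<std::optional<Tensor>> optim_tensor",
        "int": "std::vector<int64_t> optim_int",
        "float": "std::vector<double> optim_float",
        "bool": "c10::List<bool> optim_bool",
    }[ty]

def schema_list_arg(ty: str) -> str:
    # Mirrors the operator schema strings emitted by the patched codegen.
    return {
        "tensor": "Tensor?[] optim_tensor",
        "int": "int[] optim_int",
        "float": "float[] optim_float",
        "bool": "bool[] optim_bool",
    }[ty]

# Hypothetical optimizer arg spec: (name, type string)
arg_spec = [("momentum1", "tensor"), ("eps", "float"), ("iter", "int")]

split_args_dict = {"optim_tensor": [], "optim_int": [], "optim_float": [], "optim_bool": []}
for name, ty in arg_spec:
    if ty == "int" and name == "iter":
        continue  # iter is passed via aux_int, not optim_int (see the diff)
    if ty != "tensor":  # non-optional tensors keep their own TensorList argument
        split_args_dict[f"optim_{ty}"].append(name)

print(list_arg("float"))         # std::vector<double> optim_float
print(schema_list_arg("float"))  # float[] optim_float
print(split_args_dict)           # {'optim_tensor': [], 'optim_int': [], 'optim_float': ['eps'], 'optim_bool': []}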
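
Note (illustrative, not part of the patch): a minimal sketch of the TensorList packing convention that the reworked unpack_tensorlist macro expects at runtime: 3 entries on CPU (host, placements, offsets) and 4 on CUDA (dev, uvm, placements, offsets), with a 5th lxu_cache entry only for the main weights list. pack_tensorlist is a hypothetical helper, not an FBGEMM API.

import torch
from typing import List, Optional

def pack_tensorlist(
    placements: torch.Tensor,
    offsets: torch.Tensor,
    host: Optional[torch.Tensor] = None,
    dev: Optional[torch.Tensor] = None,
    uvm: Optional[torch.Tensor] = None,
    lxu_cache: Optional[torch.Tensor] = None,
) -> List[torch.Tensor]:
    if host is not None:
        # CPU path: size 3, matching the `size() == 3` branch in the macro
        return [host, placements, offsets]
    assert dev is not None and uvm is not None
    packed = [dev, uvm, placements, offsets]  # CUDA path: size 4
    if lxu_cache is not None:
        packed.append(lxu_cache)              # only the "weights" list carries lxu_cache (size 5)
    return packed

# e.g., a CPU momentum1 state packs to 3 tensors; the autograd wrapper dispatches
# on the list length at runtime, matching the TORCH_CHECK in the macro.
cpu_momentum1 = pack_tensorlist(
    placements=torch.zeros(2, dtype=torch.int32),
    offsets=torch.zeros(2, dtype=torch.int64),
    host=torch.zeros(10),
)
assert len(cpu_momentum1) == 3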
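
Note (illustrative, not part of the patch): the flattened slot layout used by unpack_tensorlist_optional for optional optimizer states. Each optional state occupies five consecutive Tensor? slots in optim_tensor, in the order host, dev, uvm, placements, offsets; optim_tensor_slots below is an illustrative helper, not an FBGEMM API.

def optim_tensor_slots(arg_index: int) -> dict:
    # arg_index is the position of the optional state in split_saved_tensorlist_optional
    base = arg_index * 5
    return {
        "host": base,
        "dev": base + 1,
        "uvm": base + 2,
        "placements": base + 3,
        "offsets": base + 4,
    }

# For an optimizer whose only optional tensor state is row_counter (arg_index 0),
# the autograd function reads optim_tensor[0..4]; a second optional state starts at
# index 5, which is why the wrapper checks optim_tensor.size() % 5 == 0.
assert optim_tensor_slots(0)["offsets"] == 4
assert optim_tensor_slots(1)["host"] == 5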