Unifying TBE API using List (Backend) #3563

Open · wants to merge 2 commits into base: main
186 changes: 170 additions & 16 deletions fbgemm_gpu/codegen/genscript/optimizer_args.py
@@ -205,6 +205,40 @@ def schema_bool_arg(name: str, default: bool = False) -> str:
return f"bool {name} = {default}"


def list_arg(ty: str) -> str:
"""
Returns a C++ argument for a list of optimizer arguments of the given type.

Parameters:
ty (str) - type of the list e.g., "int", "float", "tensor"
Returns:
C++ argument for a list of the given type, e.g., for "int" returns "std::vector<int64_t> optim_int".
"""
return {
"tensor": "std::vector<std::optional<at::Tensor>> optim_tensor",
"int": "std::vector<int64_t> optim_int",
"float": "std::vector<double> optim_float",
"bool": "c10::List<bool> optim_bool",
}[ty]


def schema_list_arg(ty: str) -> str:
"""
Returns a C++ schema for a list of optimizer arguments of the given type.

Parameters:
ty (str) - type of the list e.g., "int", "float", "tensor"
Returns:
C++ schema for a list of the given type, e.g., for "int" returns "int[] optim_int".
"""
return {
"tensor": "Tensor?[] optim_tensor",
"int": "int[] optim_int",
"float": "float[] optim_float",
"bool": "bool[] optim_bool",
}[ty]
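
# Example (illustrative): the two helpers are meant to be used in lockstep, so a
# single type key yields a matched C++ argument / schema pair:
#   list_arg("int")           -> "std::vector<int64_t> optim_int"
#   schema_list_arg("int")    -> "int[] optim_int"
#   list_arg("tensor")        -> "std::vector<std::optional<at::Tensor>> optim_tensor"
#   schema_list_arg("tensor") -> "Tensor?[] optim_tensor"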


def optional_tensor_arg(name: str) -> str:
return f"std::optional<Tensor> {name} = std::nullopt"

@@ -230,7 +264,6 @@ def schema_optional_tensorlist_arg(name: str) -> str:


def make_kernel_arg(
# pyre-fixme[11]: Annotation `ArgType` is not defined as a type.
ty: ArgType,
name: str,
default: Union[int, float, None],
@@ -505,6 +538,10 @@ class PT2ArgsSet:
split_function_schemas: List[str]
split_saved_tensorlist: List[str]
split_saved_tensorlist_optional: List[str]
split_saved_data: List[dict[str, str]]
split_variables: List[str]
split_unpacked_arg_names: List[str]
split_args_dict: Dict[str, List[str]]

@staticmethod
# pyre-ignore[3]
@@ -525,59 +562,174 @@ def create(
Returns:
PT2ArgsSet object with the following attributes:
split_function_args: List[str] - List of function arguments used in unified lookup and autograd functions
Tensors will be packed and passed as a TensorList. Auxiliary arguments will be packed in a dict.
e.g., ['at::TensorList momentum1', 'at::Dict<std::string, int> optim_int'].
split_function_arg_names: List[str] - List of argument names used in unified lookup and autograd functions
e.g., ['momentum1', 'optim_int', 'optim_float'].
split_function_schemas: List[str] - List of arguments used in unified lookup and autograd functions in the schema format
e.g., ['Tensor[] momentum1', 'int[] optim_int', 'float[] optim_float'].
split_saved_tensorlist: List[str] - List of tensor names that are packed into tensorlist and will be unpacked in
PT2 autograd function. e.g., ['momentum1'].
split_saved_tensorlist_optional: List[str] - List of tensor names that are packed into tensorlist but are optional
and will be unpacked in PT2 autograd function e.g., ['row_counter'].
split_saved_data: List[dict[str, str]] - List of non-tensor arguments that are saved for backward
split_variables: List[str] - List of C++ statements that push placeholder Variables for the autograd backward outputs
split_unpacked_arg_names: List[str] - List of argument names, unrolled from the packed lists
e.g., ['momentum1', 'eps', 'weight_decay', 'iter'].
split_args_dict: Dict[str, List[str]] - Dict mapping each optim argument type to the argument names of that type.
e.g., if an optimizer only has an int argument, say my_int_arg (iter is special-cased and passed via aux_int), the dict will look like:
{'optim_tensor': [], 'optim_int': ['my_int_arg'], 'optim_float': [], 'optim_bool': []}
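
Example (illustrative sketch; assumes a spec whose only entries are a momentum1 tensor and a float eps):
split_function_args -> ['at::TensorList momentum1', 'std::vector<double> optim_float']
split_args_dict -> {'optim_tensor': [], 'optim_int': [], 'optim_float': ['eps'], 'optim_bool': []}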
"""
split_function_arg_names = []
split_function_args = []
split_function_schemas = []
split_saved_tensorlist = []
split_saved_tensorlist_optional = []
split_saved_data = []
split_variables = []
split_unpacked_arg_names = []
has_optim_tensor = False # optim tensors here are optional tensors
has_optim_int = False
has_optim_float = False
has_optim_bool = False
split_args_dict = {
"optim_tensor": [],
"optim_int": [],
"optim_float": [],
"optim_bool": [],
}
# list of symint args to be appended after optim_xxx args
# since they have default values
symint_list: List[OptimItem] = []

for s in arg_spec:
if s.name == "learning_rate_tensor":
split_function_arg_names.append(s.name)
split_unpacked_arg_names.append(s.name)
split_function_args.append(tensor_arg(s.name))
split_function_schemas.append(tensor_arg(s.name))
split_variables.append(f"ret.push_back(Variable()); // {s.name}")
elif s.ty in (
ArgType.TENSOR,
ArgType.INT_TENSOR,
ArgType.LONG_TENSOR,
ArgType.PLACEHOLDER_TENSOR,
):
name = s.name
split_function_arg_names.append(name)
split_unpacked_arg_names.append(name)
if s.is_optional:
split_function_args.append(optional_tensorlist_arg(name))
split_function_schemas.append(schema_optional_tensorlist_arg(name))
split_saved_tensorlist_optional.append(name)
split_args_dict["optim_tensor"].append(s.name)
has_optim_tensor = True
else:
split_function_args.append(
tensor_list_arg_no_default(name, pass_by_ref=False)
)
split_function_schemas.append(
schema_tensor_list_arg_no_default(name)
)
split_saved_tensorlist.append(name)
split_variables.append(
f"ret.push_back(Variable()); // {s.name}_dev or host"
)
split_variables.append(
f"ret.push_back(Variable()); // {s.name}_placements"
)
split_variables.append(
f"ret.push_back(Variable()); // {s.name}_offsets"
)
split_variables.append("if (" + name + "_host.numel() == 0) {")
split_variables.append(
f"ret.push_back(Variable()); // {s.name}_uvm"
)
split_variables.append("}")
else:
split_function_arg_names.append(s.name)
split_function_args.append(make_function_arg(s.ty, s.name, s.default))
split_function_schemas.append(
make_function_schema_arg(s.ty, s.name, s.default)
)
if s.ty == ArgType.INT:
# iter is passed in aux_int
if s.name != "iter":
split_args_dict["optim_int"].append(s.name)
split_saved_data.append(
(
s.name,
f'optim_int[{len(split_args_dict["optim_int"]) - 1}]',
make_ivalue_cast(s.ty),
"int64_t",
)
)
has_optim_int = True
elif s.ty == ArgType.SYM_INT:
symint_list.append(s)
split_saved_data.append(
(
s.name,
"",
make_ivalue_cast(s.ty),
"c10::SymInt",
)
)
elif s.ty == ArgType.FLOAT:
split_args_dict["optim_float"].append(s.name)
split_saved_data.append(
(
s.name,
f'optim_float[{len(split_args_dict["optim_float"]) - 1}]',
make_ivalue_cast(s.ty),
"double",
)
)
has_optim_float = True
elif s.ty == ArgType.BOOL:
split_args_dict["optim_bool"].append(s.name)
split_saved_data.append(
(
s.name,
f'optim_bool[{len(split_args_dict["optim_bool"]) - 1}]',
make_ivalue_cast(s.ty),
"bool",
)
)
has_optim_bool = True
else:
raise ValueError(f"Unsupported type {s.ty}")
split_unpacked_arg_names.append(s.name)
if has_optim_tensor:
split_function_args.append(list_arg("tensor"))
split_function_schemas.append(schema_list_arg("tensor"))
split_function_arg_names.append("optim_tensor")
split_variables.append(f"ret.push_back(Variable()); // optim_tensor")

if has_optim_int:
split_function_args.append(list_arg("int"))
split_function_schemas.append(schema_list_arg("int"))
split_function_arg_names.append("optim_int")
split_variables.append(f"ret.push_back(Variable()); // optim_int")
if has_optim_float:
split_function_args.append(list_arg("float"))
split_function_schemas.append(schema_list_arg("float"))
split_function_arg_names.append("optim_float")
split_variables.append(f"ret.push_back(Variable()); // optim_float")
if has_optim_bool:
split_function_args.append(list_arg("bool"))
split_function_schemas.append(schema_list_arg("bool"))
split_function_arg_names.append("optim_bool")
split_variables.append(f"ret.push_back(Variable()); // optim_bool")
for s in symint_list:
split_function_arg_names.append(s.name)
split_function_args.append(make_function_arg(s.ty, s.name, s.default))
split_function_schemas.append(
make_function_schema_arg(s.ty, s.name, s.default)
)
split_variables.append(f"ret.push_back(Variable()); // {s.name}")
return PT2ArgsSet(
split_function_args=split_function_args,
split_function_arg_names=split_function_arg_names,
split_function_schemas=split_function_schemas,
split_saved_tensorlist=split_saved_tensorlist,
split_saved_tensorlist_optional=split_saved_tensorlist_optional,
split_saved_data=split_saved_data,
split_variables=split_variables,
split_unpacked_arg_names=split_unpacked_arg_names,
split_args_dict=split_args_dict,
)
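
# A minimal sketch (hypothetical standalone helper, assuming only the four type
# buckets above): the core of the packing pattern that create() implements.
# Non-tensor optimizer arguments are grouped per type so the unified op schema
# stays fixed while optimizers vary.
def _bucket_optim_args(arg_specs: list[tuple[str, str]]) -> dict[str, list[str]]:
    """arg_specs: (name, type) pairs, with type in {"tensor", "int", "float", "bool"}."""
    buckets: dict[str, list[str]] = {
        "optim_tensor": [], "optim_int": [], "optim_float": [], "optim_bool": []
    }
    for name, ty in arg_specs:
        if name == "iter":  # iter is passed via aux_int, mirroring create() above
            continue
        buckets[f"optim_{ty}"].append(name)
    return buckets

# e.g., _bucket_optim_args([("eps", "float"), ("iter", "int")])
# -> {'optim_tensor': [], 'optim_int': [], 'optim_float': ['eps'], 'optim_bool': []}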


@@ -637,12 +789,14 @@ def create(
if s.is_optional:
has_optional_tensors = True

# Reorganize arguments for wrapper, backend and kernel functions
# Optim arg order: non-optional tensors, learning_rate_tensor, non-tensors, optional tensors
# The optional tensors are converted to Tensor in autograd functions
# Hence, we need to reorganize such that the tensors come before non-tensors, which have default values
if has_optional_tensors:
# reordered args for split_arg_spec: non-optional tensors, learning_rate_tensor, optional tensors as tensors, non-tensors
split_arg_spec = reorder_args(split_arg_spec)
# reordered args for kernel_split_arg_spec: non-optional tensors, optional tensors as tensors, learning rate (float), non-tensors
kernel_split_arg_spec = reorder_args(kernel_split_arg_spec)
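# e.g., (illustrative) a split_arg_spec ordered [momentum1, learning_rate_tensor, eps, row_counter]
# reorders to [momentum1, learning_rate_tensor, row_counter, eps], so the defaulted
# non-tensor eps stays trailing.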

# Compute placeholder tensor combinations