-
Notifications
You must be signed in to change notification settings - Fork 1
混合精度pass梳理
Xinyu Yang edited this page Nov 3, 2023
·
1 revision
首先给出几点原则:
- 同种类型的op的运行精度可以不同; 这是因为某op局部的情况可能会导致原本支持低精度的op只能run在float32下。局部不应影响全局。
- 涉及到环的op的运行精度需要保持一致; 这是因为涉及到环的op之间无法正确插入cast op。
- 根据op的运行精度,设置其输出var的dtype;
- op的运行精度和其输入var的dtype不同时,插入cast op做转换;
为了遵守以上几点原则,算法步骤如下:
-
根据原则1,给每一个op赋予一个独一无二的type;(SetOpUniqueType)
- 遍历操作节点:代码使用两层嵌套循环,首先遍历了一个名为
all_op_nodes_
的数据结构,其中包含所有操作节点。外层循环迭代所有不同类型的操作,内层循环迭代每个操作类型中的具体操作节点。 - 检查特殊操作:在每次内层循环迭代中,首先检查操作类型是否为 "feed" 或 "fetch",如果是的话,就跳过该操作。这是因为 "feed" 和 "fetch" 操作通常用于数据输入和输出,通常不需要改变其类型。
- 生成唯一类型:对于其他类型的操作,代码生成一个唯一的类型
unique_type
,这个唯一类型包括原始操作类型op_type
和一个递增的后缀suffix
。后缀是为了确保每个操作都有不同的唯一类型。 - 记录原始类型:将唯一类型和原始类型的对应关系记录在
op_original_type_
中,这样以后可以根据唯一类型找到原始类型。 - 设置操作类型:使用
op_node->Op()->SetType(unique_type)
将操作节点的类型设置为唯一类型,从而改变操作的类型。 - 刷新操作:调用
op_node->Op()->Flush()
来刷新操作,确保类型的更改在计算图中得到应用。
- 遍历操作节点:代码使用两层嵌套循环,首先遍历了一个名为
-
根据kernel库中op的注册情况,初步设定每个op run的精度;(GetOpPrecision)
- 遍历操作节点:代码首先使用两层嵌套循环,迭代了所有操作节点,按照不同的操作类型分类遍历这些节点。
- 检查特殊操作:对于特殊的操作类型,如 "feed" 和 "fetch",代码会根据
enable_low_precision_io_
的值来确定是否支持低精度运行。如果enable_low_precision_io_
为 true,那么这些特殊操作可以以低精度运行。 - 对于 "tensorrt_engine" 类型的操作,代码会检查操作的属性(attributes)来确定是否支持低精度运行。具体来说,它会检查是否启用了 float16(
enable_fp16
)、是否启用了 int8(enable_int8
)以及是否启用了低精度 I/O(enable_low_precision_io
)。只有在满足这些条件的情况下,支持低精度运行。 - 对于其他类型的操作,代码会调用
OpSupportPrecision
函数来检查操作是否支持低精度运行。同时,还会检查操作的输入和输出类型是否是浮点32位(float32 或 float64),以确保数据类型的一致性。某些操作可能需要特殊处理,例如 "scale" 操作,它会检查 "scale" 和 "bias" 属性的值是否在低精度范围内,以确定是否支持低精度运行。 - 检查输入和输出变量类型:代码还会检查操作的输入和输出变量的类型,确保它们都是密集张量(dense tensor),如果不是,那么操作就不支持低精度运行。
- 记录支持低精度运行的操作:对于支持低精度运行的操作,代码将其操作类型添加到
op_run_low_precision_
集合中,并输出相应的日志信息。
-
步骤2中的做法是将支持运行低精度的op存入哈希表(op_run_half_)中;
-
根据原则2,首先获取每个var的所有input op,即op的输出var name相同的op集合(input_op_nodes);
-
遍历每一个op(当前op),遍历这个op的输出var的所有input op(input_op_nodes),当当前op支持低精度时,如果input_op_nodes中有op不支持低精度,将当前op从低精度哈希表中移除,并标记本次遍历有op的精度发生变化;(UpdateOpPrecision)
- 创建数据结构和函数:在函数开头,代码定义了几个数据结构和函数,用于跟踪哪些变量不应该以低精度运行,以及哪些操作的输入变量来自哪些操作。这些信息将在后续的判断中使用。
- 获取变量的输入操作:
GetVarInputOps
函数遍历所有操作节点,获取每个变量的输入操作,并根据一些条件将变量添加到vars_should_not_low_precision
集合中。条件包括特殊操作类型、操作的输入和输出类型等。此函数的目的是为了记录哪些变量不应以低精度运行。 - 循环更新操作精度:接下来,代码进入一个循环,不断尝试更新操作的运行精度,直到不再有操作的运行精度需要更新为止。
- 遍历操作节点:在循环中,代码首先再次遍历所有操作节点。如果操作的类型不在支持低精度运行的操作类型集合
op_run_low_precision_
中,那么跳过该操作。 - 遍历输入变量:对于每个操作的输入变量,代码检查变量是否不应以低精度运行,如果是,就从支持低精度运行的操作集合
op_run_low_precision_
中移除该操作,并将precision_updated
标记为true
,表示更新了精度。然后在日志中输出相关信息。 - 遍历输出变量:接下来,代码遍历每个操作的输出变量。对于每个输出变量,代码检查是否存在输入操作(根据之前记录的信息),如果存在输入操作,就检查输入操作是否支持低精度运行,以及输出变量是否不应以低精度运行。如果不满足这些条件,就从支持低精度运行的操作集合
op_run_low_precision_
中移除该操作,并将precision_updated
标记为true
,表示更新了精度。然后在日志中输出相关信息。 - 循环结束条件:循环会一直执行直到
precision_updated
不再为true
,这表示不再需要更新操作的精度。
-
反复重复步骤5,直到某次遍历后,无op的精度被改变;
-
根据op的运行精度,统一设定这些op输入(only权重)、输出var的dtype,将设定为低精度的权重var放入哈希表vars_convert_to_half_中;(SetVarPrecision)
- 获取变量存储域(Scope):首先,代码获取了用于操作的变量存储域(
scope
),确保它不为null
,以便后续能够访问变量数据。 - 遍历操作节点:代码遍历所有的操作节点,对于每个操作节点,首先检查该操作是否支持低精度运行,如果不支持,就跳过该操作。
- 遍历输入变量:对于支持低精度运行的操作,代码遍历其输入变量。对于每个输入变量,它会执行以下操作:
- 检查变量的数据类型是否为浮点32位(float32 或 float64)。
- 检查变量是否有数据类型(data type)信息。
- 检查是否应该跳过将变量转换为低精度的条件,通过
InputVarsNotConvert
函数来判断。 - 如果变量是持久变量(persistable),则检查该变量的数据类型是否与实际数据张量的数据类型匹配。如果不匹配,就跳过。
- 如果变量是持久变量,且满足转换条件,将变量的数据类型设置为低精度(
low_precision_
)并将变量名添加到vars_convert_to_low_precision_
集合中。
- 遍历输出变量:类似地,代码遍历操作的输出变量,进行类似的检查和操作,将输出变量的数据类型设置为低精度并将其添加到
vars_convert_to_low_precision_
集合中。 - 处理具有相同名称的变量:代码还包括一段用于处理具有相同名称的变量的逻辑。对于每个子图(subgraph)中的变量,如果该变量的名称在
vars_convert_to_low_precision_
集合中,那么它的数据类型也会被设置为低精度。
- 获取变量存储域(Scope):首先,代码获取了用于操作的变量存储域(
-
转换vars_convert_to_half_中权重var的精度(涉及到实际数据的读入、转换、写回);(ConvertWeightsData)
- 获取变量存储域(Scope):首先,代码获取了用于操作的变量存储域(
scope
),确保它不为null
,以便访问和操作权重变量的数据。 - 遍历局部变量:代码遍历存储域中的所有局部变量的名称(var_names)。
- 检查是否需要转换:对于每个局部变量,代码检查其名称是否在
vars_convert_to_low_precision_
集合中,如果在集合中,表示该变量需要转换为低精度。 - 执行数据类型转换:如果需要转换数据类型,代码执行以下操作:
- 获取原始的权重张量(origin_tensor)。
- 创建一个新的低精度张量(low_precision_tensor),其数据类型与配置的低精度
low_precision_
一致。 - 重新分配低精度张量的形状和数据类型。
- 遍历原始权重张量中的元素,将每个元素从高精度转换为低精度,并存储在新的低精度张量中。转换根据低精度的数据类型(float16 或 bfloat16)进行。
- 清空原始权重张量的数据。
- 使用
paddle::framework::TensorCopySync
将新的低精度张量的数据拷贝回原始权重张量中,确保数据类型转换生效。
- 获取变量存储域(Scope):首先,代码获取了用于操作的变量存储域(
-
遍历所有op,在op node和其非权重的输入var node间视情况插入cast op;(InsertCastOp)
- 初始化变量:代码中定义了一些变量,包括
suffix
用于生成唯一的 Cast 操作名称,以及cache
用于缓存已插入的 Cast 操作。 - 遍历所有操作节点:代码首先循环遍历所有操作节点,对于每个操作节点,执行以下操作:
- 检查是否为 Feed 操作,如果是,就跳过该操作。
- 检查是否为某些特殊类型的操作,特殊操作通常不需要插入 Cast 操作。
- 遍历输入变量:对于每个操作节点的输入变量,代码执行以下操作:
- 检查变量是否为变量节点。
- 检查变量是否有数据类型信息。
- 检查变量是否为持久变量(persistable),如果是,就跳过该变量。
- 获取输入变量的实际变量节点(
real_in_var_node
)以及其数据类型(in_var_type
)。
- 插入 Cast 操作:根据变量的数据类型和规则,代码决定是否需要插入 Cast 操作。具体插入 Cast 操作的条件如下:
- 如果输入变量数据类型为浮点32位或浮点64位(IsFP32AndFP64)并且该操作需要以低精度运行(
op_run_low_precision_.count(op_type)
为真),则需要插入 Cast 操作将数据类型转换为低精度。 - 如果输入变量数据类型为浮点16位或双精度 BFloat16(IsFP16AndBFP16)并且该操作不需要以低精度运行(
op_run_low_precision_.count(op_type)
为假),则需要插入 Cast 操作将数据类型转换为浮点32位。
- 如果输入变量数据类型为浮点32位或浮点64位(IsFP32AndFP64)并且该操作需要以低精度运行(
- 插入 Cast 操作:如果需要插入 Cast 操作,代码调用
DoInsertCastOp
函数来执行实际的插入操作。这个函数会处理 Cast 操作的插入,设置数据类型,生成唯一的名称,并将 Cast 操作缓存起来,以便后续使用。 - 特殊操作:对于一些特殊操作类型(如 "fused_multi_transformer"),代码执行一些额外的逻辑以处理特殊的输入和输出变量的名称。
- 日志输出:最后,代码输出插入的 Cast 操作的数量。
- 初始化变量:代码中定义了一些变量,包括
-
将每个op都的type设置回原始type;(RestoreOpOriginType)
恢复第1步操作修改的op type