Skip to content

混合精度pass梳理

Xinyu Yang edited this page Nov 3, 2023 · 1 revision

首先给出几点原则:

  1. 同种类型的op的运行精度可以不同; 这是因为某op局部的情况可能会导致原本支持低精度的op只能run在float32下。局部不应影响全局。
  2. 涉及到环的op的运行精度需要保持一致; 这是因为涉及到环的op之间无法正确插入cast op。
  3. 根据op的运行精度,设置其输出var的dtype;
  4. op的运行精度和其输入var的dtype不同时,插入cast op做转换;

为了遵守以上几点原则,算法步骤如下:

  1. 根据原则1,给每一个op赋予一个独一无二的type;(SetOpUniqueType)

    1. 遍历操作节点:代码使用两层嵌套循环,首先遍历了一个名为 all_op_nodes_ 的数据结构,其中包含所有操作节点。外层循环迭代所有不同类型的操作,内层循环迭代每个操作类型中的具体操作节点。
    2. 检查特殊操作:在每次内层循环迭代中,首先检查操作类型是否为 "feed" 或 "fetch",如果是的话,就跳过该操作。这是因为 "feed" 和 "fetch" 操作通常用于数据输入和输出,通常不需要改变其类型。
    3. 生成唯一类型:对于其他类型的操作,代码生成一个唯一的类型 unique_type,这个唯一类型包括原始操作类型 op_type 和一个递增的后缀 suffix。后缀是为了确保每个操作都有不同的唯一类型。
    4. 记录原始类型:将唯一类型和原始类型的对应关系记录在 op_original_type_ 中,这样以后可以根据唯一类型找到原始类型。
    5. 设置操作类型:使用 op_node->Op()->SetType(unique_type) 将操作节点的类型设置为唯一类型,从而改变操作的类型。
    6. 刷新操作:调用 op_node->Op()->Flush() 来刷新操作,确保类型的更改在计算图中得到应用。
  2. 根据kernel库中op的注册情况,初步设定每个op run的精度;(GetOpPrecision)

    1. 遍历操作节点:代码首先使用两层嵌套循环,迭代了所有操作节点,按照不同的操作类型分类遍历这些节点。
    2. 检查特殊操作:对于特殊的操作类型,如 "feed" 和 "fetch",代码会根据 enable_low_precision_io_ 的值来确定是否支持低精度运行。如果 enable_low_precision_io_ 为 true,那么这些特殊操作可以以低精度运行。
    3. 对于 "tensorrt_engine" 类型的操作,代码会检查操作的属性(attributes)来确定是否支持低精度运行。具体来说,它会检查是否启用了 float16(enable_fp16)、是否启用了 int8(enable_int8)以及是否启用了低精度 I/O(enable_low_precision_io)。只有在满足这些条件的情况下,支持低精度运行。
    4. 对于其他类型的操作,代码会调用 OpSupportPrecision 函数来检查操作是否支持低精度运行。同时,还会检查操作的输入和输出类型是否是浮点32位(float32 或 float64),以确保数据类型的一致性。某些操作可能需要特殊处理,例如 "scale" 操作,它会检查 "scale" 和 "bias" 属性的值是否在低精度范围内,以确定是否支持低精度运行。
    5. 检查输入和输出变量类型:代码还会检查操作的输入和输出变量的类型,确保它们都是密集张量(dense tensor),如果不是,那么操作就不支持低精度运行。
    6. 记录支持低精度运行的操作:对于支持低精度运行的操作,代码将其操作类型添加到 op_run_low_precision_ 集合中,并输出相应的日志信息。
  3. 步骤2中的做法是将支持运行低精度的op存入哈希表(op_run_half_)中;

  4. 根据原则2,首先获取每个var的所有input op,即op的输出var name相同的op集合(input_op_nodes);

  5. 遍历每一个op(当前op),遍历这个op的输出var的所有input op(input_op_nodes),当当前op支持低精度时,如果input_op_nodes中有op不支持低精度,将当前op从低精度哈希表中移除,并标记本次遍历有op的精度发生变化;(UpdateOpPrecision)

    1. 创建数据结构和函数:在函数开头,代码定义了几个数据结构和函数,用于跟踪哪些变量不应该以低精度运行,以及哪些操作的输入变量来自哪些操作。这些信息将在后续的判断中使用。
    2. 获取变量的输入操作:GetVarInputOps 函数遍历所有操作节点,获取每个变量的输入操作,并根据一些条件将变量添加到 vars_should_not_low_precision 集合中。条件包括特殊操作类型、操作的输入和输出类型等。此函数的目的是为了记录哪些变量不应以低精度运行。
    3. 循环更新操作精度:接下来,代码进入一个循环,不断尝试更新操作的运行精度,直到不再有操作的运行精度需要更新为止。
    4. 遍历操作节点:在循环中,代码首先再次遍历所有操作节点。如果操作的类型不在支持低精度运行的操作类型集合 op_run_low_precision_ 中,那么跳过该操作。
    5. 遍历输入变量:对于每个操作的输入变量,代码检查变量是否不应以低精度运行,如果是,就从支持低精度运行的操作集合 op_run_low_precision_ 中移除该操作,并将 precision_updated 标记为 true,表示更新了精度。然后在日志中输出相关信息。
    6. 遍历输出变量:接下来,代码遍历每个操作的输出变量。对于每个输出变量,代码检查是否存在输入操作(根据之前记录的信息),如果存在输入操作,就检查输入操作是否支持低精度运行,以及输出变量是否不应以低精度运行。如果不满足这些条件,就从支持低精度运行的操作集合 op_run_low_precision_ 中移除该操作,并将 precision_updated 标记为 true,表示更新了精度。然后在日志中输出相关信息。
    7. 循环结束条件:循环会一直执行直到 precision_updated 不再为 true,这表示不再需要更新操作的精度。
  6. 反复重复步骤5,直到某次遍历后,无op的精度被改变;

  7. 根据op的运行精度,统一设定这些op输入(only权重)、输出var的dtype,将设定为低精度的权重var放入哈希表vars_convert_to_half_中;(SetVarPrecision)

    1. 获取变量存储域(Scope):首先,代码获取了用于操作的变量存储域(scope),确保它不为 null,以便后续能够访问变量数据。
    2. 遍历操作节点:代码遍历所有的操作节点,对于每个操作节点,首先检查该操作是否支持低精度运行,如果不支持,就跳过该操作。
    3. 遍历输入变量:对于支持低精度运行的操作,代码遍历其输入变量。对于每个输入变量,它会执行以下操作:
      • 检查变量的数据类型是否为浮点32位(float32 或 float64)。
      • 检查变量是否有数据类型(data type)信息。
      • 检查是否应该跳过将变量转换为低精度的条件,通过 InputVarsNotConvert 函数来判断。
      • 如果变量是持久变量(persistable),则检查该变量的数据类型是否与实际数据张量的数据类型匹配。如果不匹配,就跳过。
      • 如果变量是持久变量,且满足转换条件,将变量的数据类型设置为低精度(low_precision_)并将变量名添加到 vars_convert_to_low_precision_ 集合中。
    4. 遍历输出变量:类似地,代码遍历操作的输出变量,进行类似的检查和操作,将输出变量的数据类型设置为低精度并将其添加到 vars_convert_to_low_precision_ 集合中。
    5. 处理具有相同名称的变量:代码还包括一段用于处理具有相同名称的变量的逻辑。对于每个子图(subgraph)中的变量,如果该变量的名称在 vars_convert_to_low_precision_ 集合中,那么它的数据类型也会被设置为低精度。
  8. 转换vars_convert_to_half_中权重var的精度(涉及到实际数据的读入、转换、写回);(ConvertWeightsData)

    1. 获取变量存储域(Scope):首先,代码获取了用于操作的变量存储域(scope),确保它不为 null,以便访问和操作权重变量的数据。
    2. 遍历局部变量:代码遍历存储域中的所有局部变量的名称(var_names)。
    3. 检查是否需要转换:对于每个局部变量,代码检查其名称是否在 vars_convert_to_low_precision_ 集合中,如果在集合中,表示该变量需要转换为低精度。
    4. 执行数据类型转换:如果需要转换数据类型,代码执行以下操作:
      • 获取原始的权重张量(origin_tensor)。
      • 创建一个新的低精度张量(low_precision_tensor),其数据类型与配置的低精度 low_precision_ 一致。
      • 重新分配低精度张量的形状和数据类型。
      • 遍历原始权重张量中的元素,将每个元素从高精度转换为低精度,并存储在新的低精度张量中。转换根据低精度的数据类型(float16 或 bfloat16)进行。
      • 清空原始权重张量的数据。
      • 使用 paddle::framework::TensorCopySync 将新的低精度张量的数据拷贝回原始权重张量中,确保数据类型转换生效。
  9. 遍历所有op,在op node和其非权重的输入var node间视情况插入cast op;(InsertCastOp)

    1. 初始化变量:代码中定义了一些变量,包括 suffix 用于生成唯一的 Cast 操作名称,以及 cache 用于缓存已插入的 Cast 操作。
    2. 遍历所有操作节点:代码首先循环遍历所有操作节点,对于每个操作节点,执行以下操作:
      • 检查是否为 Feed 操作,如果是,就跳过该操作。
      • 检查是否为某些特殊类型的操作,特殊操作通常不需要插入 Cast 操作。
    3. 遍历输入变量:对于每个操作节点的输入变量,代码执行以下操作:
      • 检查变量是否为变量节点。
      • 检查变量是否有数据类型信息。
      • 检查变量是否为持久变量(persistable),如果是,就跳过该变量。
      • 获取输入变量的实际变量节点(real_in_var_node)以及其数据类型(in_var_type)。
    4. 插入 Cast 操作:根据变量的数据类型和规则,代码决定是否需要插入 Cast 操作。具体插入 Cast 操作的条件如下:
      • 如果输入变量数据类型为浮点32位或浮点64位(IsFP32AndFP64)并且该操作需要以低精度运行(op_run_low_precision_.count(op_type) 为真),则需要插入 Cast 操作将数据类型转换为低精度。
      • 如果输入变量数据类型为浮点16位或双精度 BFloat16(IsFP16AndBFP16)并且该操作不需要以低精度运行(op_run_low_precision_.count(op_type) 为假),则需要插入 Cast 操作将数据类型转换为浮点32位。
    5. 插入 Cast 操作:如果需要插入 Cast 操作,代码调用 DoInsertCastOp 函数来执行实际的插入操作。这个函数会处理 Cast 操作的插入,设置数据类型,生成唯一的名称,并将 Cast 操作缓存起来,以便后续使用。
    6. 特殊操作:对于一些特殊操作类型(如 "fused_multi_transformer"),代码执行一些额外的逻辑以处理特殊的输入和输出变量的名称。
    7. 日志输出:最后,代码输出插入的 Cast 操作的数量。
  10. 将每个op都的type设置回原始type;(RestoreOpOriginType)

    恢复第1步操作修改的op type