GRAPH: Analyze SDXL Turbo model graph and list graph-level issues #418

Closed · powderluv opened this issue Feb 10, 2024 · 3 comments

@powderluv (Contributor)

No description provided.

@antiagainst commented Feb 15, 2024

From @qedawkins: OK, some initial observations. Looking at the IR immediately before iree-flow-form-dispatch-regions, there are a few oddities/places to look at higher up the stack.

  1. GRAPH: Lower PyTorch sdpa to linalg ext attention #433 We should be lowering torch.sdpa -> linalg_ext.attention for all of the attention operators. I remember hearing that decompositions for attention were being employed at the torch FX level, but don't know if they are kicking in for this model. If not, we can write a raising pattern for the following, but it would be best not to rely on raisings.
  %148 = linalg.fill ins(%cst_1 : f16) outs(%147 : tensor<20x4096x4096xf16>) -> tensor<20x4096x4096xf16>
  %149 = linalg.batch_matmul ins(%collapsed_76, %collapsed_77 : tensor<20x4096x64xf16>, tensor<20x64x4096xf16>) outs(%148 : tensor<20x4096x4096xf16>) -> tensor<20x4096x4096xf16>
  %expanded_78 = tensor.expand_shape %149 [[0, 1], [2], [3]] : tensor<20x4096x4096xf16> into tensor<2x10x4096x4096xf16>
  %150 = tensor.empty() : tensor<2x10x4096x4096xf16>
  %151 = linalg.softmax dimension(3) ins(%expanded_78 : tensor<2x10x4096x4096xf16>) outs(%150 : tensor<2x10x4096x4096xf16>) -> tensor<2x10x4096x4096xf16>
  %collapsed_79 = tensor.collapse_shape %151 [[0, 1], [2], [3]] : tensor<2x10x4096x4096xf16> into tensor<20x4096x4096xf16>
  %collapsed_80 = tensor.collapse_shape %143 [[0, 1], [2], [3]] : tensor<2x10x4096x64xf16> into tensor<20x4096x64xf16>
  %152 = tensor.empty() : tensor<20x4096x64xf16>
  %153 = linalg.fill ins(%cst_1 : f16) outs(%152 : tensor<20x4096x64xf16>) -> tensor<20x4096x64xf16>
  %154 = linalg.batch_matmul ins(%collapsed_79, %collapsed_80 : tensor<20x4096x4096xf16>, tensor<20x4096x64xf16>) outs(%153 : tensor<20x4096x64xf16>) -> tensor<20x4096x64xf16>
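
For reference, a rough sketch of the fused form this could lower to, using IREE's iree_linalg_ext.attention op (the exact assembly format and operand list vary across IREE versions, and the SSA names below are illustrative):

  // %q/%k/%v stand for the collapsed query/key/value tensors (e.g. %collapsed_76
  // and friends), with the key kept in its un-transposed 20x4096x64 layout; how
  // the softmax scale is supplied depends on the LinalgExt op definition in use.
  %acc = tensor.empty() : tensor<20x4096x64xf16>
  %attn = iree_linalg_ext.attention ins(%q, %k, %v : tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, tensor<20x4096x64xf16>) outs(%acc : tensor<20x4096x64xf16>) -> tensor<20x4096x64xf16>
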
  2. There is a decent amount of strange-looking extension/truncation arithmetic that temporarily does some computations in f32. It would be good to understand whether that is baked into the model or being introduced by lowerings.
  %163 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%159#0, %161, %162, %_params.unet.down_blocks.1.attentions.0.transformer_blocks.0.norm2.weight, %_params.unet.down_blocks.1.attentions.0.transformer_blocks.0.norm2.bias : tensor<8192x640xf16>, tensor<8192xf32>, tensor<8192xf32>, tensor<640xf16>, tensor<640xf16>) outs(%127 : tensor<8192x640xf16>) {
  ^bb0(%in: f16, %in_2126: f32, %in_2127: f32, %in_2128: f16, %in_2129: f16, %out: f16):
    %3783 = arith.divf %in_2127, %cst_10 : f32
    %3784 = arith.addf %3783, %cst_6 : f32
    %3785 = math.rsqrt %3784 : f32
    %3786 = arith.extf %in : f16 to f32
    %3787 = arith.subf %3786, %in_2126 : f32
    %3788 = arith.mulf %3787, %3785 : f32
    %3789 = arith.extf %in_2128 : f16 to f32
    %3790 = arith.mulf %3788, %3789 : f32
    %3791 = arith.extf %in_2129 : f16 to f32
    %3792 = arith.addf %3790, %3791 : f32
    %3793 = arith.truncf %3792 : f32 to f16
    linalg.yield %3793 : f16
  } -> tensor<8192x640xf16>
  3. GRAPH: Add pass for NCHW -> NHWC convolution transposing #448 NCHW convolutions will be problematic for performance. All loads of the filter (which is large in the case of the unet) will be scalarized when trying to target MMA intrinsics, unless we play some tricks with img2col.
  %252 = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%inserted_slice_130, %_params.unet.down_blocks.1.resnets.1.conv1.weight : tensor<2x640x66x66xf16>, tensor<640x640x3x3xf16>) outs(%97 : tensor<2x640x64x64xf16>) -> tensor<2x640x64x64xf16>
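
As a point of comparison, a sketch of the NHWC form such a pass would produce (illustrative SSA names; the transposes or layout propagation feeding it are elided):

  // Input transposed NCHW -> NHWC and filter FCHW -> HWCF so the channel
  // dimension is innermost for both operands.
  %conv_nhwc = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%input_nhwc, %filter_hwcf : tensor<2x66x66x640xf16>, tensor<3x3x640x640xf16>) outs(%init_nhwc : tensor<2x64x64x640xf16>) -> tensor<2x64x64x640xf16>
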
  4. Pads are becoming fill + insert_slice?
  %115 = linalg.fill ins(%cst_1 : f16) outs(%114 : tensor<2x640x66x66xf16>) -> tensor<2x640x66x66xf16>
  %inserted_slice_65 = tensor.insert_slice %113 into %115[0, 0, 1, 1] [2, 640, 64, 64] [1, 1, 1, 1] : tensor<2x640x64x64xf16> into tensor<2x640x66x66xf16>
  %116 = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%inserted_slice_65, %_params.unet.down_blocks.1.resnets.0.conv2.weight : tensor<2x640x66x66xf16>, tensor<640x640x3x3xf16>) outs(%97 : tensor<2x640x64x64xf16>) -> tensor<2x640x64x64xf16>

(this pattern might just be kicking in before forming dispatches if pad fusion isn't enabled).
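
For comparison, the same padding written directly as a tensor.pad (a sketch reusing the SSA names from the snippet above for illustration):

  // Pad H and W by one element on each side, filling with %cst_1.
  %padded = tensor.pad %113 low[0, 0, 1, 1] high[0, 0, 1, 1] {
  ^bb0(%i: index, %j: index, %k: index, %l: index):
    tensor.yield %cst_1 : f16
  } : tensor<2x640x64x64xf16> to tensor<2x640x66x66xf16>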

  5. GRAPH: Investigate missing elementwise (extf/dequant-like) fusions #447 Some elementwise dispatches that failed to fuse:
        %4 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%2 : tensor<2x77x640xf16>) outs(%3 : tensor<2x640x77xf16>) {
        ^bb0(%in: f16, %out: f16):
          %5 = arith.truncf %cst : f32 to f16
          %6 = arith.mulf %in, %5 : f16
          linalg.yield %6 : f16
        } -> tensor<2x640x77xf16>
        %8 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%6 : tensor<3276800xf16>) outs(%7 : tensor<3276800xf32>) {
        ^bb0(%in: f16, %out: f32):
          %9 = arith.extf %in : f16 to f32
          linalg.yield %9 : f32
        } -> tensor<3276800xf32>
        %9 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%8, %7 : tensor<2x4096x2560xf16>, tensor<2x4096x2560xf16>) outs(%6 : tensor<2x4096x2560xf32>) {
        ^bb0(%in: f16, %in_2: f16, %out: f32):
          %10 = math.sqrt %cst_1 : f16
          %11 = arith.divf %in_2, %10 : f16
          %12 = math.erf %11 : f16
          %13 = arith.addf %12, %cst_0 : f16
          %14 = arith.mulf %13, %cst : f16
          %15 = arith.mulf %in_2, %14 : f16
          %16 = arith.mulf %in, %15 : f16
          %17 = arith.extf %16 : f16 to f32
          linalg.yield %17 : f32
        } -> tensor<2x4096x2560xf32>
  6. GRAPH: enable horizontal fusion in SDXL #495 Some potentially beneficial horizontal fusions here:
  %278 = linalg.matmul_transpose_b ins(%277, %_params.unet.down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_q.weight : tensor<8192x640xf16>, tensor<640x640xf16>) outs(%128 : tensor<8192x640xf16>) -> tensor<8192x640xf16>
  %279 = linalg.matmul_transpose_b ins(%277, %_params.unet.down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_k.weight : tensor<8192x640xf16>, tensor<640x640xf16>) outs(%128 : tensor<8192x640xf16>) -> tensor<8192x640xf16>
  %280 = linalg.matmul_transpose_b ins(%277, %_params.unet.down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_v.weight : tensor<8192x640xf16>, tensor<640x640xf16>) outs(%128 : tensor<8192x640xf16>) -> tensor<8192x640xf16>

The LHS is shared and larger than the right-hand sides of these matmuls, meaning there could be a chance to reduce the number of times we load the left-hand side. Probably only worth doing as an experiment if we have time.
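
A sketch of what that horizontal fusion could look like (illustrative only: %w_q/%w_k/%w_v stand in for the to_q/to_k/to_v weights, %cst_1 is reused as the f16 zero, and the concatenation could equally be folded into the parameters offline):

  // Concatenate the three 640x640 weights along dim 0 (the transposed-B "N"
  // dimension), run one matmul over the shared LHS, then slice the result apart.
  %wqkv = tensor.concat dim(0) %w_q, %w_k, %w_v : (tensor<640x640xf16>, tensor<640x640xf16>, tensor<640x640xf16>) -> tensor<1920x640xf16>
  %init = tensor.empty() : tensor<8192x1920xf16>
  %fill = linalg.fill ins(%cst_1 : f16) outs(%init : tensor<8192x1920xf16>) -> tensor<8192x1920xf16>
  %qkv = linalg.matmul_transpose_b ins(%277, %wqkv : tensor<8192x640xf16>, tensor<1920x640xf16>) outs(%fill : tensor<8192x1920xf16>) -> tensor<8192x1920xf16>
  %q = tensor.extract_slice %qkv[0, 0] [8192, 640] [1, 1] : tensor<8192x1920xf16> to tensor<8192x640xf16>
  %k = tensor.extract_slice %qkv[0, 640] [8192, 640] [1, 1] : tensor<8192x1920xf16> to tensor<8192x640xf16>
  %v = tensor.extract_slice %qkv[0, 1280] [8192, 640] [1, 1] : tensor<8192x1920xf16> to tensor<8192x640xf16>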

@monorimet (Contributor)

> GRAPH: Lower PyTorch sdpa to linalg ext attention #433 We should be lowering torch.sdpa -> linalg_ext.attention for all of the attention operators. I remember hearing that decompositions for attention were being employed at the torch FX level, but don't know if they are kicking in for this model. If not, we can write a raising pattern for the following, but it would be best not to rely on raisings.

FWIW, there won't be any decomposition done for sdpa or its variants at the torch.fx level going forward, as of #271.

@antiagainst

Closing this as the analysis is done. We have issues filed to track implementing the fixes.
