GRAPH: Analyze SDXL model graph and list graph-level issues #417

Closed
powderluv opened this issue Feb 10, 2024 · 2 comments

@powderluv

No description provided.

@powderluv powderluv converted this from a draft issue Feb 10, 2024
@powderluv

From @qedawkins https://github.com/nod-ai/playbook/issues/6#issuecomment-1936845981

OK, some initial observations. Looking at the IR immediately before iree-flow-form-dispatch-regions, there are a few oddities and places to look at higher up the stack.

  • We should be lowering to torch.sdpa -> linalg_ext.attention for all of the attention operators. I remember hearing that decompositions for attention were being employed at the torch FX level, but I don't know if they are kicking in for this model. If not, we can write a raising pattern for the following, but it would be best not to rely on raisings.
  %148 = linalg.fill ins(%cst_1 : f16) outs(%147 : tensor<20x4096x4096xf16>) -> tensor<20x4096x4096xf16>
  %149 = linalg.batch_matmul ins(%collapsed_76, %collapsed_77 : tensor<20x4096x64xf16>, tensor<20x64x4096xf16>) outs(%148 : tensor<20x4096x4096xf16>) -> tensor<20x4096x4096xf16>
  %expanded_78 = tensor.expand_shape %149 [[0, 1], [2], [3]] : tensor<20x4096x4096xf16> into tensor<2x10x4096x4096xf16>
  %150 = tensor.empty() : tensor<2x10x4096x4096xf16>
  %151 = linalg.softmax dimension(3) ins(%expanded_78 : tensor<2x10x4096x4096xf16>) outs(%150 : tensor<2x10x4096x4096xf16>) -> tensor<2x10x4096x4096xf16>
  %collapsed_79 = tensor.collapse_shape %151 [[0, 1], [2], [3]] : tensor<2x10x4096x4096xf16> into tensor<20x4096x4096xf16>
  %collapsed_80 = tensor.collapse_shape %143 [[0, 1], [2], [3]] : tensor<2x10x4096x64xf16> into tensor<20x4096x64xf16>
  %152 = tensor.empty() : tensor<20x4096x64xf16>
  %153 = linalg.fill ins(%cst_1 : f16) outs(%152 : tensor<20x4096x64xf16>) -> tensor<20x4096x64xf16>
  %154 = linalg.batch_matmul ins(%collapsed_79, %collapsed_80 : tensor<20x4096x4096xf16>, tensor<20x4096x64xf16>) outs(%153 : tensor<20x4096x64xf16>) -> tensor<20x4096x64xf16>
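
For reference, the sequence above (batch_matmul, softmax over the last dimension, batch_matmul) is the decomposed form of attention. Below is a minimal pure-Python sketch on tiny shapes. It compares the decomposed form with a streaming (online softmax) form that never materializes the full score matrix. The function names are hypothetical, the key matrix is passed untransposed for readability, and this is not meant to describe how linalg_ext.attention is implemented.

```python
import math

def attention_decomposed(q, k, v):
    """q: [M][D], k: [N][D], v: [N][E]. Materializes the full MxN score matrix."""
    scores = [[sum(a * b for a, b in zip(qi, kj)) for kj in k] for qi in q]
    out = []
    for row in scores:
        m = max(row)
        exps = [math.exp(s - m) for s in row]
        denom = sum(exps)
        probs = [e / denom for e in exps]
        out.append([sum(p * vj[e] for p, vj in zip(probs, v))
                    for e in range(len(v[0]))])
    return out

def attention_streaming(q, k, v):
    """Same result, one key/value row at a time with a running max and sum."""
    out = []
    for qi in q:
        running_max = -math.inf
        running_sum = 0.0
        acc = [0.0] * len(v[0])
        for kj, vj in zip(k, v):
            s = sum(a * b for a, b in zip(qi, kj))
            new_max = max(running_max, s)
            rescale = math.exp(running_max - new_max)  # rescale old partial results
            w = math.exp(s - new_max)
            running_sum = running_sum * rescale + w
            acc = [a * rescale + w * x for a, x in zip(acc, vj)]
            running_max = new_max
        out.append([a / running_sum for a in acc])
    return out

def intermediate_bytes(batch, m, n, bytes_per_elem=2):
    """Size of the materialized score tensor (f16 by default)."""
    return batch * m * n * bytes_per_elem
```

For the shapes in the IR above, intermediate_bytes(20, 4096, 4096) gives 671088640 bytes (640 MiB) for the f16 score tensor that the decomposed form has to materialize.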

  • There is a decent amount of strange-looking extension/truncation arithmetic to temporarily do some computations in f32. It would be good to understand whether that is baked into the model or being introduced by lowerings.
  %163 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%159#0, %161, %162, %_params.unet.down_blocks.1.attentions.0.transformer_blocks.0.norm2.weight, %_params.unet.down_blocks.1.attentions.0.transformer_blocks.0.norm2.bias : tensor<8192x640xf16>, tensor<8192xf32>, tensor<8192xf32>, tensor<640xf16>, tensor<640xf16>) outs(%127 : tensor<8192x640xf16>) {
  ^bb0(%in: f16, %in_2126: f32, %in_2127: f32, %in_2128: f16, %in_2129: f16, %out: f16):
    %3783 = arith.divf %in_2127, %cst_10 : f32
    %3784 = arith.addf %3783, %cst_6 : f32
    %3785 = math.rsqrt %3784 : f32
    %3786 = arith.extf %in : f16 to f32
    %3787 = arith.subf %3786, %in_2126 : f32
    %3788 = arith.mulf %3787, %3785 : f32
    %3789 = arith.extf %in_2128 : f16 to f32
    %3790 = arith.mulf %3788, %3789 : f32
    %3791 = arith.extf %in_2129 : f16 to f32
    %3792 = arith.addf %3790, %3791 : f32
    %3793 = arith.truncf %3792 : f32 to f16
    linalg.yield %3793 : f16
  } -> tensor<8192x640xf16>
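
To make the arithmetic in that region body easier to follow, here is a rough Python sketch that mirrors it op by op. It relies on assumptions not visible in the IR: that %cst_10 is the reduced element count (640 here), that %cst_6 is an epsilon, and that the second f32 input is a sum of squared deviations. The helper names are hypothetical. Rounding is emulated with the struct module's half and single precision formats.

```python
import math
import struct

def to_f16(x):
    """Round a Python float to the nearest IEEE half-precision value."""
    return struct.unpack("e", struct.pack("e", x))[0]

def to_f32(x):
    """Round a Python float to the nearest IEEE single-precision value."""
    return struct.unpack("f", struct.pack("f", x))[0]

def norm_element(x_f16, mean_f32, sq_sum_f32, weight_f16, bias_f16,
                 n=640.0, eps=1e-5):
    # n and eps stand in for %cst_10 and %cst_6 (assumed values).
    var = to_f32(sq_sum_f32 / n)                # arith.divf
    rstd = to_f32(1.0 / math.sqrt(to_f32(var + eps)))  # arith.addf, math.rsqrt
    centered = to_f32(x_f16 - mean_f32)         # arith.extf + arith.subf
    normed = to_f32(centered * rstd)            # arith.mulf
    scaled = to_f32(normed * weight_f16)        # arith.extf + arith.mulf
    shifted = to_f32(scaled + bias_f16)         # arith.extf + arith.addf
    return to_f16(shifted)                      # arith.truncf
```

The sketch only spells out the dataflow: f16 inputs are widened, all intermediate math happens in f32, and a single truncation happens at the end. It does not answer whether that widening comes from the model or from a lowering.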

  • NCHW convolutions will be problematic for performance. All loads of the (large in the case of unet) filter will be scalarized when trying to target mma intrinsics, unless we play some tricks with img2col.

%252 = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%inserted_slice_130, %_params.unet.down_blocks.1.resnets.1.conv1.weight : tensor<2x640x66x66xf16>, tensor<640x640x3x3xf16>) outs(%97 : tensor<2x640x64x64xf16>) -> tensor<2x640x64x64xf16>
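
One way to look at the layout concern is to compare element strides of the filter in the two layouts. This is a small sketch assuming dense row-major storage (an assumption; the actual in-memory layout is up to the compiler), and it may not be the exact mechanism behind the scalarized loads mentioned above.

```python
def row_major_strides(shape):
    """Element strides of a dense row-major tensor with the given shape."""
    strides = []
    step = 1
    for dim in reversed(shape):
        strides.append(step)
        step *= dim
    return tuple(reversed(strides))

# Filter from the IR above: tensor<640x640x3x3xf16> in (F, C, H, W) order.
fchw = row_major_strides((640, 640, 3, 3))
# The same filter in a channels-last (H, W, C, F) order, for comparison.
hwcf = row_major_strides((3, 3, 640, 640))

print(fchw)  # (5760, 9, 3, 1)
print(hwcf)  # (1228800, 409600, 640, 1)
```

In FCHW the innermost dimension is the 3-wide kernel width and consecutive input channels are 9 elements apart. In HWCF the innermost dimension is the 640-wide output-channel dimension, which is generally friendlier to wide contiguous loads.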

  • Pads are becoming fill + insert slice?

  %115 = linalg.fill ins(%cst_1 : f16) outs(%114 : tensor<2x640x66x66xf16>) -> tensor<2x640x66x66xf16>
  %inserted_slice_65 = tensor.insert_slice %113 into %115[0, 0, 1, 1] [2, 640, 64, 64] [1, 1, 1, 1] : tensor<2x640x64x64xf16> into tensor<2x640x66x66xf16>
  %116 = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%inserted_slice_65, %_params.unet.down_blocks.1.resnets.0.conv2.weight : tensor<2x640x66x66xf16>, tensor<640x640x3x3xf16>) outs(%97 : tensor<2x640x64x64xf16>) -> tensor<2x640x64x64xf16>

(this pattern might just be kicking in before forming dispatches if pad fusion isn't enabled).
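
As a sanity check on the equivalence, here is a small 2-D Python sketch showing that filling a larger buffer and inserting the source at offset 1 gives the same result as padding by 1 on each side (64 + 1 + 1 = 66 in the IR above). It assumes the fill value %cst_1 is zero, which the IR does not show, and the function names are hypothetical.

```python
def fill_then_insert(src, low, high, value=0.0):
    """Mimics linalg.fill followed by tensor.insert_slice at offset (low, low)."""
    h, w = len(src), len(src[0])
    out = [[value] * (w + low + high) for _ in range(h + low + high)]
    for i in range(h):
        for j in range(w):
            out[i + low][j + low] = src[i][j]
    return out

def pad(src, low, high, value=0.0):
    """Pads both dimensions with `low` leading and `high` trailing elements."""
    width = len(src[0]) + low + high
    rows = [[value] * width for _ in range(low)]
    rows += [[value] * low + list(r) + [value] * high for r in src]
    rows += [[value] * width for _ in range(high)]
    return rows
```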

  • Some elementwise dispatches that failed to fuse
        %4 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%2 : tensor<2x77x640xf16>) outs(%3 : tensor<2x640x77xf16>) {
        ^bb0(%in: f16, %out: f16):
          %5 = arith.truncf %cst : f32 to f16
          %6 = arith.mulf %in, %5 : f16
          linalg.yield %6 : f16
        } -> tensor<2x640x77xf16>
        %8 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%6 : tensor<3276800xf16>) outs(%7 : tensor<3276800xf32>) {
        ^bb0(%in: f16, %out: f32):
          %9 = arith.extf %in : f16 to f32
          linalg.yield %9 : f32
        } -> tensor<3276800xf32>
        %9 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%8, %7 : tensor<2x4096x2560xf16>, tensor<2x4096x2560xf16>) outs(%6 : tensor<2x4096x2560xf32>) {
        ^bb0(%in: f16, %in_2: f16, %out: f32):
          %10 = math.sqrt %cst_1 : f16
          %11 = arith.divf %in_2, %10 : f16
          %12 = math.erf %11 : f16
          %13 = arith.addf %12, %cst_0 : f16
          %14 = arith.mulf %13, %cst : f16
          %15 = arith.mulf %in_2, %14 : f16
          %16 = arith.mulf %in, %15 : f16
          %17 = arith.extf %16 : f16 to f32
          linalg.yield %17 : f32
        } -> tensor<2x4096x2560xf32>
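
For readers trying to place the last generic above: its body has the shape of an erf-based GELU applied to one input and multiplied by the other (a gated GELU). Here is a Python sketch of the same arithmetic, assuming %cst_1 = 2, %cst_0 = 1 and %cst = 0.5. Those values are not shown in the IR and are chosen only because they give the standard erf GELU.

```python
import math

def gated_gelu(x, gate, cst_1=2.0, cst_0=1.0, cst=0.5):
    """Mirrors the region body: x * (gate * ((erf(gate / sqrt(cst_1)) + cst_0) * cst))."""
    t = gate / math.sqrt(cst_1)                  # math.sqrt, arith.divf
    gelu_scale = (math.erf(t) + cst_0) * cst     # math.erf, arith.addf, arith.mulf
    return x * (gate * gelu_scale)               # two arith.mulf ops
```

The sketch works in double precision, whereas the IR computes in f16 and only extends the final value to f32.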

  • Some potentially beneficial horizontal fusions here:

  %278 = linalg.matmul_transpose_b ins(%277, %_params.unet.down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_q.weight : tensor<8192x640xf16>, tensor<640x640xf16>) outs(%128 : tensor<8192x640xf16>) -> tensor<8192x640xf16>
  %279 = linalg.matmul_transpose_b ins(%277, %_params.unet.down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_k.weight : tensor<8192x640xf16>, tensor<640x640xf16>) outs(%128 : tensor<8192x640xf16>) -> tensor<8192x640xf16>
  %280 = linalg.matmul_transpose_b ins(%277, %_params.unet.down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_v.weight : tensor<8192x640xf16>, tensor<640x640xf16>) outs(%128 : tensor<8192x640xf16>) -> tensor<8192x640xf16>

The LHS is shared and larger than the right-hand sides of these matmuls, meaning there could be a chance to reduce the number of times we load the LHS. Probably only worth doing as an experiment if we have time.
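
The horizontal fusion idea can be sketched in plain Python: three matmul_transpose_b ops that share an LHS give the same results as one op against the weights concatenated along N, followed by slicing the result columns. This is a toy illustration with hypothetical function names, not a proposal for how the compiler should implement it.

```python
def matmul_transpose_b(a, b):
    """a: [M][K], b: [N][K] -> [M][N], i.e. a times b transposed."""
    return [[sum(x * y for x, y in zip(row, col)) for col in b] for row in a]

def fused_qkv(a, w_q, w_k, w_v):
    """One matmul over the concatenated weights, then split the columns."""
    fused = matmul_transpose_b(a, w_q + w_k + w_v)  # concatenate along N
    n_q, n_k = len(w_q), len(w_k)
    q = [row[:n_q] for row in fused]
    k = [row[n_q:n_q + n_k] for row in fused]
    v = [row[n_q + n_k:] for row in fused]
    return q, k, v
```

In the fused form each LHS row is read once for all three projections instead of once per projection. For the shapes above, that would be one 8192x640 by 1920x640 matmul instead of three 8192x640 by 640x640 ones.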

@antiagainst

Closing this one given Turbo has the same architecture so #418 is good for tracking.
