GRAPH: Analyze SDXL Turbo model graph and list graph-level issues #418
From @qedawkins: OK, some initial observations. Looking at the IR immediately before:
```mlir
%148 = linalg.fill ins(%cst_1 : f16) outs(%147 : tensor<20x4096x4096xf16>) -> tensor<20x4096x4096xf16>
%149 = linalg.batch_matmul ins(%collapsed_76, %collapsed_77 : tensor<20x4096x64xf16>, tensor<20x64x4096xf16>) outs(%148 : tensor<20x4096x4096xf16>) -> tensor<20x4096x4096xf16>
%expanded_78 = tensor.expand_shape %149 [[0, 1], [2], [3]] : tensor<20x4096x4096xf16> into tensor<2x10x4096x4096xf16>
%150 = tensor.empty() : tensor<2x10x4096x4096xf16>
%151 = linalg.softmax dimension(3) ins(%expanded_78 : tensor<2x10x4096x4096xf16>) outs(%150 : tensor<2x10x4096x4096xf16>) -> tensor<2x10x4096x4096xf16>
%collapsed_79 = tensor.collapse_shape %151 [[0, 1], [2], [3]] : tensor<2x10x4096x4096xf16> into tensor<20x4096x4096xf16>
%collapsed_80 = tensor.collapse_shape %143 [[0, 1], [2], [3]] : tensor<2x10x4096x64xf16> into tensor<20x4096x64xf16>
%152 = tensor.empty() : tensor<20x4096x64xf16>
%153 = linalg.fill ins(%cst_1 : f16) outs(%152 : tensor<20x4096x64xf16>) -> tensor<20x4096x64xf16>
%154 = linalg.batch_matmul ins(%collapsed_79, %collapsed_80 : tensor<20x4096x4096xf16>, tensor<20x4096x64xf16>) outs(%153 : tensor<20x4096x64xf16>) -> tensor<20x4096x64xf16>
```
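For reference, this is the classic unfused scaled-dot-product-attention chain: a batched Q·Kᵀ, a softmax over the last dimension, and a second batched matmul against V. A minimal numpy sketch of the same computation (variable names are illustrative; the real shapes are taken from the IR above, scaled down here so it runs quickly):

```python
import numpy as np

# Real shapes from the IR: B=20 (2 batch x 10 heads), S=4096, D=64.
# Scaled-down S here so the sketch runs quickly; the math is identical.
B, S, D = 20, 256, 64
q = np.random.randn(B, S, D).astype(np.float16)
kT = np.random.randn(B, D, S).astype(np.float16)   # K pre-transposed, as in the IR
v = np.random.randn(B, S, D).astype(np.float16)

scores = q @ kT                                    # first linalg.batch_matmul
m = scores.max(axis=-1, keepdims=True)             # linalg.softmax dimension(3),
p = np.exp((scores - m).astype(np.float32))        # written out stably
p /= p.sum(axis=-1, keepdims=True)
out = p.astype(np.float16) @ v                     # second linalg.batch_matmul

# At the real sizes, the materialized f16 scores tensor alone is
# 20 * 4096 * 4096 * 2 bytes = 640 MiB, which is why keeping this chain
# fused (flash-attention style) rather than round-tripping it matters.
```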
```mlir
%163 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%159#0, %161, %162, %_params.unet.down_blocks.1.attentions.0.transformer_blocks.0.norm2.weight, %_params.unet.down_blocks.1.attentions.0.transformer_blocks.0.norm2.bias : tensor<8192x640xf16>, tensor<8192xf32>, tensor<8192xf32>, tensor<640xf16>, tensor<640xf16>) outs(%127 : tensor<8192x640xf16>) {
^bb0(%in: f16, %in_2126: f32, %in_2127: f32, %in_2128: f16, %in_2129: f16, %out: f16):
  %3783 = arith.divf %in_2127, %cst_10 : f32
  %3784 = arith.addf %3783, %cst_6 : f32
  %3785 = math.rsqrt %3784 : f32
  %3786 = arith.extf %in : f16 to f32
  %3787 = arith.subf %3786, %in_2126 : f32
  %3788 = arith.mulf %3787, %3785 : f32
  %3789 = arith.extf %in_2128 : f16 to f32
  %3790 = arith.mulf %3788, %3789 : f32
  %3791 = arith.extf %in_2129 : f16 to f32
  %3792 = arith.addf %3790, %3791 : f32
  %3793 = arith.truncf %3792 : f32 to f16
  linalg.yield %3793 : f16
} -> tensor<8192x640xf16>
```
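This generic is the normalize-and-affine tail of a LayerNorm over 640 features: %in_2126 and %in_2127 look like a precomputed per-row mean and sum of squared deviations, %cst_10 the feature count, and %cst_6 the epsilon (all assumptions read off the structure, since the constant definitions aren't shown). A numpy sketch of the same scalar body:

```python
import numpy as np

# Shapes from the IR: 8192 rows (2 x 4096 tokens), 640 features.
N, C = 8192, 640
x = np.random.randn(N, C).astype(np.float16)
gamma = np.ones(C, dtype=np.float16)    # ...norm2.weight
beta = np.zeros(C, dtype=np.float16)    # ...norm2.bias
eps = 1e-5                              # assumed value of %cst_6

# Stand-ins for the precomputed per-row statistics (%161, %162 in the IR).
xf = x.astype(np.float32)
mean = xf.mean(axis=1, keepdims=True)
sq_sum = ((xf - mean) ** 2).sum(axis=1, keepdims=True)

var = sq_sum / C                        # arith.divf %in_2127, %cst_10
rstd = 1.0 / np.sqrt(var + eps)         # arith.addf + math.rsqrt
y = (xf - mean) * rstd                  # arith.subf, arith.mulf
y = y * gamma.astype(np.float32) + beta.astype(np.float32)
y = y.astype(np.float16)                # arith.truncf back to f16
```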
```mlir
%115 = linalg.fill ins(%cst_1 : f16) outs(%114 : tensor<2x640x66x66xf16>) -> tensor<2x640x66x66xf16>
%inserted_slice_65 = tensor.insert_slice %113 into %115[0, 0, 1, 1] [2, 640, 64, 64] [1, 1, 1, 1] : tensor<2x640x64x64xf16> into tensor<2x640x66x66xf16>
%116 = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%inserted_slice_65, %_params.unet.down_blocks.1.resnets.0.conv2.weight : tensor<2x640x66x66xf16>, tensor<640x640x3x3xf16>) outs(%97 : tensor<2x640x64x64xf16>) -> tensor<2x640x64x64xf16>
```

(This pattern might just be kicking in before forming dispatches if pad fusion isn't enabled.)
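The fill + insert_slice pair here is just an explicit zero-pad feeding a 3x3 convolution. A small PyTorch sketch of the equivalence (f32 and scaled-down channel counts for a quick portable check; treating %cst_1 as 0.0 is an assumption):

```python
import torch
import torch.nn.functional as F

# Real shapes from the IR: NCHW 2x640x64x64 activations, 640x640x3x3 weights,
# all f16. Scaled-down f32 shapes here so the check runs quickly on CPU.
x = torch.randn(2, 64, 16, 16)
w = torch.randn(64, 64, 3, 3)

# What the IR spells out: fill an enlarged buffer with the pad value
# (linalg.fill), insert the activation at offset [1, 1] (tensor.insert_slice),
# then run an *unpadded* 3x3 conv back down to the original spatial size...
padded = F.pad(x, (1, 1, 1, 1), value=0.0)
y_explicit = F.conv2d(padded, w)

# ...which is just a same-padded convolution. Pad fusion would let the pad
# travel into the conv dispatch instead of materializing a new tensor.
y_fused = F.conv2d(x, w, padding=1)
assert torch.allclose(y_explicit, y_fused, atol=1e-5)
```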
```mlir
%4 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%2 : tensor<2x77x640xf16>) outs(%3 : tensor<2x640x77xf16>) {
^bb0(%in: f16, %out: f16):
  %5 = arith.truncf %cst : f32 to f16
  %6 = arith.mulf %in, %5 : f16
  linalg.yield %6 : f16
} -> tensor<2x640x77xf16>
```

```mlir
%8 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%6 : tensor<3276800xf16>) outs(%7 : tensor<3276800xf32>) {
^bb0(%in: f16, %out: f32):
  %9 = arith.extf %in : f16 to f32
  linalg.yield %9 : f32
} -> tensor<3276800xf32>
```

```mlir
%9 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%8, %7 : tensor<2x4096x2560xf16>, tensor<2x4096x2560xf16>) outs(%6 : tensor<2x4096x2560xf32>) {
^bb0(%in: f16, %in_2: f16, %out: f32):
  %10 = math.sqrt %cst_1 : f16
  %11 = arith.divf %in_2, %10 : f16
  %12 = math.erf %11 : f16
  %13 = arith.addf %12, %cst_0 : f16
  %14 = arith.mulf %13, %cst : f16
  %15 = arith.mulf %in_2, %14 : f16
  %16 = arith.mulf %in, %15 : f16
  %17 = arith.extf %16 : f16 to f32
  linalg.yield %17 : f32
} -> tensor<2x4096x2560xf32>
```
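This last generic is the gated GELU (GEGLU) of the transformer feed-forward: elementwise `x * gelu(gate)` using the erf form of GELU, evaluated in f16 with the extend to f32 only at the yield. A numpy sketch, reading %cst_1, %cst_0, and %cst as 2.0, 1.0, and 0.5 (the standard GELU constants; an assumption, since the constant definitions aren't shown):

```python
import numpy as np
from scipy.special import erf

# The two 2x4096x2560 halves of the feed-forward projection.
x = np.random.randn(2, 4096, 2560).astype(np.float16)
gate = np.random.randn(2, 4096, 2560).astype(np.float16)

# gelu(g) = g * 0.5 * (1 + erf(g / sqrt(2))), then gated by x.
# The IR keeps everything in f16 and only extends to f32 at the very end
# (arith.extf); numpy/scipy upcast along the way, so this is f32 throughout.
g = gate.astype(np.float32)
out = x.astype(np.float32) * (g * 0.5 * (1.0 + erf(g / np.sqrt(2.0))))
```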
The LHS is shared and larger than the right-hand sides of these matmuls, meaning there could be a chance to reduce the number of times we load the left operand. Probably only worth doing as an experiment if we have time.
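If the matmuls in question are Q/K/V-style projections off one set of activations (an assumption; they aren't shown here), a numpy sketch of the concatenation trick that loads the shared LHS once:

```python
import numpy as np

# Hypothetical projections sharing one large LHS; shapes loosely follow the
# 8192x640 activations above, in f32 for a quick numpy check.
lhs = np.random.randn(8192, 640).astype(np.float32)
w_q = np.random.randn(640, 640).astype(np.float32)
w_k = np.random.randn(640, 640).astype(np.float32)
w_v = np.random.randn(640, 640).astype(np.float32)

# Three separate matmuls stream the large LHS from memory three times.
q, k, v = lhs @ w_q, lhs @ w_k, lhs @ w_v

# Concatenating the small right-hand sides turns them into one matmul that
# reads the LHS once; the results are then split back apart.
qkv = lhs @ np.concatenate([w_q, w_k, w_v], axis=1)
q2, k2, v2 = np.split(qkv, 3, axis=1)
assert np.allclose(q, q2) and np.allclose(k, k2) and np.allclose(v, v2)
```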
FWIW, there won't be any decomposition done for SDPA or its variants at the torch.fx level going forward, as of #271.
Closing this, as the analysis is done. We have issues filed to track implementing the fixes.