diff --git a/BUILD.bazel b/BUILD.bazel index 949d669543..6eafc665da 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -994,7 +994,7 @@ cc_library( ":linalg_passes", ":reference_api", ":reference_configuration", - ":stablehlo_dialect_capi_objects", + ":stablehlo_dialect_capi", ":stablehlo_ops", ":stablehlo_passes", ":stablehlo_portable_api", diff --git a/WORKSPACE.bazel b/WORKSPACE.bazel index 58281118ce..4df05fdac6 100644 --- a/WORKSPACE.bazel +++ b/WORKSPACE.bazel @@ -17,9 +17,9 @@ workspace(name = "stablehlo") load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") -LLVM_COMMIT = "0e779ad4998ef65907502101c5b82ede05ddfa4e" +LLVM_COMMIT = "43d71baae36c8d8b5a9995aa35efebe09cc9c2d6" -LLVM_SHA256 = "d5c2560b2d9ce3ced7951113f2b5d1ea428665678f4dcb1fb8780eb1219ca615" +LLVM_SHA256 = "436af8b4c3403e251ab0b7a471eda7df6063f9da9d22ccbe498f3115cd35225a" http_archive( name = "llvm-raw", diff --git a/build_tools/llvm_version.txt b/build_tools/llvm_version.txt index dca3e3453a..7404929700 100644 --- a/build_tools/llvm_version.txt +++ b/build_tools/llvm_version.txt @@ -1 +1 @@ -0e779ad4998ef65907502101c5b82ede05ddfa4e +43d71baae36c8d8b5a9995aa35efebe09cc9c2d6 diff --git a/stablehlo/dialect/ChloOps.cpp b/stablehlo/dialect/ChloOps.cpp index b175006fcf..e1dd8c5eef 100644 --- a/stablehlo/dialect/ChloOps.cpp +++ b/stablehlo/dialect/ChloOps.cpp @@ -16,9 +16,13 @@ limitations under the License. #include "stablehlo/dialect/ChloOps.h" +#include #include #include +#include +#include #include +#include #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" @@ -426,12 +430,12 @@ namespace { // Mode 1, where the ragged dimension is an lhs non-contracting dim (m). // lhs : [b, m, k] // rhs : [g, b, k, n] -// group_sizes : [g] +// group_sizes : [b, g] // result : [b, m, n] // Mode 2, where the ragged dimension is an lhs/rhs contracting dim (k). // lhs : [b, m, k] // rhs : [b, k, n] -// group_sizes : [g] +// group_sizes : [b, g] // result : [g, b, m, n] // Mode 3, where the ragged dimension is an lhs/rhs batch dim (b). // lhs : [b, m, k] @@ -440,9 +444,18 @@ namespace { // result : [b, m, n] // As with dot_general, the lhs and rhs can have arbitrary batching, // contracting and non-contracting dimensions. +// The group_sizes arg has the shape [b...,x...,g], where: +// - b... are all the lhs batch dims before (outer-to) the lhs ragged dim, +// - x... are, +// - in mode 1, all the lhs non-contracting dims before the lhs ragged dim, +// - in mode 2, all the lhs contracting dims before the lhs ragged dim, and +// - in mode 3, empty; +// - g is the number of groups in the lhs ragged dim. // Additionally: // - In all modes, the lhs must have exactly one ragged dimension. // - In mode 1, the rhs must have exactly one group dimension. +// - If a group_sizes of shape [g] is passed, it is broadcasted according to +// the rules above. LogicalResult checkRaggedDotConstraints( std::optional location, RankedTensorType rankedLhsType, RankedTensorType rankedRhsType, RankedTensorType rankedGroupSizesType, @@ -452,14 +465,6 @@ LogicalResult checkRaggedDotConstraints( ArrayRef rhsContractingDimensions, ArrayRef lhsRaggedDimensions, ArrayRef rhsGroupDimensions) { - // Check that the group sizes has rank=1. - if (rankedGroupSizesType.getRank() != 1) { - return emitOptionalError( - location, "expected rank of group_sizes of ragged dot to be 1, got ", - rankedGroupSizesType.getRank()); - } - auto numGroups = rankedGroupSizesType.getDimSize(0); - // Check that there is exactly one lhs ragged dimension. if (lhsRaggedDimensions.size() != 1) { return emitOptionalError( @@ -474,6 +479,82 @@ LogicalResult checkRaggedDotConstraints( return failure(); } + enum Mode { + // Ragged non-contracting (m): [b,m,k], [g,b,k,n], [b,g] -> [b,m,n]. + kNonContracting, + // Ragged contracting (k): [b,m,k], [b,k,n], [b,g] -> [g,b,m,n]. + kContracting, + // Ragged batch (b): [b,m,k], [b,k,n], [g] -> [b,m,n]. + kBatch + }; + Mode mode; + if (llvm::is_contained(lhsBatchingDimensions, lhsRaggedDim)) { + mode = kBatch; + } else if (llvm::is_contained(lhsContractingDimensions, lhsRaggedDim)) { + mode = kContracting; + } else { + mode = kNonContracting; + } + + // Validate the shape of group_sizes. + { + // Construct the expected shape [b...,x...,g] of group_sizes. + SmallVector prefixDims; + prefixDims.reserve(rankedLhsType.getRank() - 1); + prefixDims.insert(prefixDims.end(), lhsBatchingDimensions.begin(), + lhsBatchingDimensions.end()); + switch (mode) { + case kBatch: + prefixDims.resize( + std::distance(lhsBatchingDimensions.begin(), + llvm::find(lhsBatchingDimensions, lhsRaggedDim))); + break; + case kContracting: + prefixDims.insert(prefixDims.end(), lhsContractingDimensions.begin(), + llvm::find(lhsContractingDimensions, lhsRaggedDim)); + break; + case kNonContracting: + for (int64_t i = 0; i < lhsRaggedDim; ++i) { + if (!llvm::is_contained(lhsBatchingDimensions, i) && + !llvm::is_contained(lhsContractingDimensions, i)) { + prefixDims.push_back(i); + } + } + break; + } + SmallVector expectedPrefix; + expectedPrefix.reserve(prefixDims.size()); + for (const int64_t dim : prefixDims) { + expectedPrefix.push_back(rankedLhsType.getDimSize(dim)); + } + + // Validate the actual shape, if it was passed as something other than [g]. + if (rankedGroupSizesType.getRank() != 1) { + if (rankedGroupSizesType.getRank() != + static_cast(expectedPrefix.size()) + 1) { + return emitOptionalError(location, "expected group_sizes to have rank ", + expectedPrefix.size() + 1, ", got ", + rankedGroupSizesType.getRank()); + } + auto groupSizesShape = rankedGroupSizesType.getShape(); + if (!std::equal(expectedPrefix.begin(), expectedPrefix.end(), + groupSizesShape.begin())) { + auto nonEmptyShapeStr = [](ArrayRef shape) { + std::string s = ""; + for (size_t i = 0; i < shape.size() - 1; ++i) { + s += std::to_string(shape[i]) + ", "; + } + return s + std::to_string(shape.back()); + }; + return emitOptionalError( + location, "group_sizes is expected to have shape [", + nonEmptyShapeStr(expectedPrefix), ", ", groupSizesShape.back(), + "], got [", nonEmptyShapeStr(groupSizesShape), "]"); + } + } + } + const int64_t numGroups = rankedGroupSizesType.getShape().back(); + // Validate basic properties of the rhs group dimension(s). for (auto rhsGroupDim : rhsGroupDimensions) { if (failed(hlo::checkDimInBounds(location, rhsGroupDim, @@ -491,32 +572,34 @@ LogicalResult checkRaggedDotConstraints( return failure(); } - if (llvm::is_contained(lhsBatchingDimensions, lhsRaggedDim) || - llvm::is_contained(lhsContractingDimensions, lhsRaggedDim)) { - // Ragged batch (b): [b,m,k], [b,k,n], [g] -> [b,m,n]. - // Ragged contracting (k): [b,m,k], [b,k,n], [g] -> [g,b,m,n]. - if (!rhsGroupDimensions.empty()) { - return emitOptionalError( - location, - "There must be zero group dimensions in the rhs when the " - "ragged dimension is batch or contracting."); - } - } else { - // Ragged non-contracting (m): [b,m,k], [g,b,k,n], [g] -> [b,m,n]. - if (rhsGroupDimensions.size() != 1) { - return emitOptionalError( - location, - "There must be exactly one group dimension in the rhs when the lhs " - "ragged dimension is non-contracting."); - } - // Compare the group dimension size with the number of groups. - const int64_t rhsGroupDim = rhsGroupDimensions[0]; - if (!hlo::verifyCompatibleDims(numGroups, - rankedRhsType.getDimSize(rhsGroupDim))) { - return emitOptionalError( - location, "group_sizes is expected to have shape=[", - rankedRhsType.getDimSize(rhsGroupDim), "], got [", numGroups, "]"); - } + switch (mode) { + case kBatch: + [[fallthrough]]; + case kContracting: + if (!rhsGroupDimensions.empty()) { + return emitOptionalError( + location, + "There must be zero group dimensions in the rhs when the " + "ragged dimension is batch or contracting."); + } + break; + case kNonContracting: + if (rhsGroupDimensions.size() != 1) { + return emitOptionalError( + location, + "There must be exactly one group dimension in the rhs when the lhs " + "ragged dimension is non-contracting."); + } + // Compare the group dimension size with the number of groups. + const int64_t rhsGroupDim = rhsGroupDimensions[0]; + if (!hlo::verifyCompatibleDims(numGroups, + rankedRhsType.getDimSize(rhsGroupDim))) { + return emitOptionalError( + location, + "rhs group dimension is expected to have size=", numGroups, + ", got ", rankedRhsType.getDimSize(rhsGroupDim)); + } + break; } return success(); } @@ -530,10 +613,10 @@ SmallVector inferRaggedDotOutputDimensions( ArrayRef rhsContractingDimensions, ArrayRef lhsRaggedDimensions, ArrayRef rhsGroupDimensions) { - // Must have already checked that group_sizes is 1-D. - const int64_t numGroups = rankedGroupSizesType.getDimSize(0); // Must have already checked that there is exactly one lhs ragged dim. const int64_t lhsRaggedDim = lhsRaggedDimensions[0]; + // Must have already checked the shape of group_sizes. + const int64_t numGroups = rankedGroupSizesType.getShape().back(); SmallVector dimensions; // Add the group dimension to the result shape in case of ragged contracting. diff --git a/stablehlo/dialect/ChloOps.td b/stablehlo/dialect/ChloOps.td index 1cca13b2ac..9ffccc45e4 100644 --- a/stablehlo/dialect/ChloOps.td +++ b/stablehlo/dialect/ChloOps.td @@ -869,12 +869,12 @@ def CHLO_RaggedDotOp : CHLO_Op<"ragged_dot", most one group dimension. The op has three modes, depending on the kind of the lhs ragged dimension. - In mode 1, the shape-signature is `[b,m,k], [g,b,k,n], [g] -> [b,m,n]`. + In mode 1, the shape-signature is `[b,m,k], [g,b,k,n], [b,g] -> [b,m,n]`. Here the ragged dimension is an lhs non-contracting dimension (`m`). The dimensions `b` and `k` represent batch and contracting dimensions respectively. The rhs is required to have a group dimension (`g`). - In mode 2, the shape-signature is `[b,m,k], [b,k,n], [g] -> [g,b,m,n]`. + In mode 2, the shape-signature is `[b,m,k], [b,k,n], [b,g] -> [g,b,m,n]`. Here the ragged dimension is an lhs/rhs contracting dimension (`k`). In mode 3, the shape-signature is `[b,m,k], [b,k,n], [g] -> [b,m,n]`. Here diff --git a/stablehlo/tests/ops_chlo.mlir b/stablehlo/tests/ops_chlo.mlir index b5a021043b..8f6a74f5b9 100644 --- a/stablehlo/tests/ops_chlo.mlir +++ b/stablehlo/tests/ops_chlo.mlir @@ -146,7 +146,7 @@ func.func @ragged_dot_incompatible_contracting_dims(%lhs : tensor<11x5xf32>, %rh // ----- func.func @ragged_dot_group_sizes_incorrect_rank(%lhs : tensor<11x5xf32>, %rhs : tensor<3x5x7xf32>, %group_sizes : tensor<3x2xi64>) -> tensor<11x7xf32> { - // @expected-error@+1 {{expected rank of group_sizes of ragged dot to be 1, got 2}} + // @expected-error@+1 {{expected group_sizes to have rank 1, got 2}} %0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) { ragged_dot_dimension_numbers = #chlo.ragged_dot< lhs_batching_dimensions = [], @@ -163,8 +163,79 @@ func.func @ragged_dot_group_sizes_incorrect_rank(%lhs : tensor<11x5xf32>, %rhs : // ----- -func.func @ragged_dot_group_sizes_incorrect_shape(%lhs : tensor<11x5xf32>, %rhs : tensor<3x5x7xf32>, %group_sizes : tensor<2xi64>) -> tensor<11x7xf32> { - // @expected-error@+1 {{group_sizes is expected to have shape=[3], got [2]}} +func.func @ragged_dot_mode1_group_sizes_broadcasted(%lhs : tensor<19x17x11x5xf32>, %rhs : tensor<3x19x5x7xf32>, %group_sizes : tensor<3xi64>) -> tensor<19x17x11x7xf32> { + %0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) { + ragged_dot_dimension_numbers = #chlo.ragged_dot< + lhs_batching_dimensions = [0], + rhs_batching_dimensions = [1], + lhs_contracting_dimensions = [3], + rhs_contracting_dimensions = [2], + lhs_ragged_dimensions = [2], + rhs_group_dimensions = [0] + >, + precision_config = [#chlo, #chlo] + } : (tensor<19x17x11x5xf32>, tensor<3x19x5x7xf32>, tensor<3xi64>) -> tensor<19x17x11x7xf32> + func.return %0 : tensor<19x17x11x7xf32> +} + +// ----- + +func.func @ragged_dot_mode1_group_sizes_incorrect_shape(%lhs : tensor<19x17x11x5xf32>, %rhs : tensor<3x19x5x7xf32>, %group_sizes : tensor<19x11x3xi64>) -> tensor<19x17x11x7xf32> { + // @expected-error@+1 {{group_sizes is expected to have shape [19, 17, 3], got [19, 11, 3]}} + %0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) { + ragged_dot_dimension_numbers = #chlo.ragged_dot< + lhs_batching_dimensions = [0], + rhs_batching_dimensions = [1], + lhs_contracting_dimensions = [3], + rhs_contracting_dimensions = [2], + lhs_ragged_dimensions = [2], + rhs_group_dimensions = [0] + >, + precision_config = [#chlo, #chlo] + } : (tensor<19x17x11x5xf32>, tensor<3x19x5x7xf32>, tensor<19x11x3xi64>) -> tensor<19x17x11x7xf32> + func.return %0 : tensor<19x17x11x7xf32> +} + +// ----- + +func.func @ragged_dot_mode2_group_sizes_incorrect_shape(%lhs : tensor<19x11x17x5xf32>, %rhs : tensor<19x17x5x7xf32>, %group_sizes : tensor<19x11x3xi64>) -> tensor<3x19x11x7xf32> { + // @expected-error@+1 {{group_sizes is expected to have shape [19, 17, 3], got [19, 11, 3]}} + %0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) { + ragged_dot_dimension_numbers = #chlo.ragged_dot< + lhs_batching_dimensions = [0], + rhs_batching_dimensions = [0], + lhs_contracting_dimensions = [2,3], + rhs_contracting_dimensions = [1,2], + lhs_ragged_dimensions = [3], + rhs_group_dimensions = [] + >, + precision_config = [#chlo, #chlo] + } : (tensor<19x11x17x5xf32>, tensor<19x17x5x7xf32>, tensor<19x11x3xi64>) -> tensor<3x19x11x7xf32> + func.return %0 : tensor<3x19x11x7xf32> +} + +// ----- + +func.func @ragged_dot_mode3_group_sizes_incorrect_shape(%lhs : tensor<17x19x11x5xf32>, %rhs : tensor<17x19x5x7xf32>, %group_sizes : tensor<19x3xi64>) -> tensor<17x19x11x7xf32> { + // @expected-error@+1 {{group_sizes is expected to have shape [17, 3], got [19, 3]}} + %0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) { + ragged_dot_dimension_numbers = #chlo.ragged_dot< + lhs_batching_dimensions = [0,1], + rhs_batching_dimensions = [0,1], + lhs_contracting_dimensions = [3], + rhs_contracting_dimensions = [2], + lhs_ragged_dimensions = [1], + rhs_group_dimensions = [] + >, + precision_config = [#chlo, #chlo] + } : (tensor<17x19x11x5xf32>, tensor<17x19x5x7xf32>, tensor<19x3xi64>) -> tensor<17x19x11x7xf32> + func.return %0 : tensor<17x19x11x7xf32> +} + +// ----- + +func.func @ragged_dot_incorrect_group_dim_size(%lhs : tensor<11x5xf32>, %rhs : tensor<3x5x7xf32>, %group_sizes : tensor<2xi64>) -> tensor<11x7xf32> { + // @expected-error@+1 {{rhs group dimension is expected to have size=2, got 3}} %0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) { ragged_dot_dimension_numbers = #chlo.ragged_dot< lhs_batching_dimensions = [], diff --git a/stablehlo/tests/transforms/stablehlo_legalize_quant_to_int.mlir b/stablehlo/tests/transforms/stablehlo_legalize_quant_to_int.mlir index 31ed1edf93..84990dd46d 100644 --- a/stablehlo/tests/transforms/stablehlo_legalize_quant_to_int.mlir +++ b/stablehlo/tests/transforms/stablehlo_legalize_quant_to_int.mlir @@ -2675,3 +2675,73 @@ func.func @convolution_with_i8_result_element_type( -> tensor<128x26x26x128x!quant.uniform> return %0 : tensor<128x26x26x128x!quant.uniform> } + +// ----- + +// CHECK-LABEL: func.func public @conv2d_nchw +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x3x224x224xi8>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<64x3x7x7xi8>) -> tensor<1x64x112x112xi8> { +func.func public @conv2d_nchw( + %arg0: tensor<1x3x224x224x!quant.uniform>, + %arg1: tensor<64x3x7x7x!quant.uniform> + ) -> tensor<1x64x112x112x!quant.uniform> { +// CHECK: stablehlo.convolution +// CHECK-SAME: dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1] +// CHECK-SAME: window = { +// CHECK-SAME: stride = [2, 2] +// CHECK-SAME{LITERAL}: pad = [[0, 0], [0, 0]] +// CHECK-SAME: rhs_dilate = [1, 1] +// CHECK-SAME: } +// CHECK-SAME: { +// CHECK-SAME: batch_group_count = 1 : i64 +// CHECK-SAME: feature_group_count = 1 : i64 +// CHECK-SAME: } +// CHECK-SAME : (tensor<1x3x230x230xi8>, tensor<64x3x7x7xi8>) -> tensor<1x64x112x112xi32> + %0 = stablehlo.convolution(%arg0, %arg1) + dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], + window = { + stride = [2, 2], + pad = [[3, 3], [3, 3]], + rhs_dilate = [1, 1] + } { + batch_group_count = 1 : i64, + feature_group_count = 1 : i64 + } : ( + tensor<1x3x224x224x!quant.uniform>, + tensor<64x3x7x7x!quant.uniform> + ) -> tensor<1x64x112x112x!quant.uniform> + return %0 : tensor<1x64x112x112x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: func.func @conv3d_ncdhw +func.func @conv3d_ncdhw( + %arg0: tensor<128x1x28x28x28x!quant.uniform>, + %arg1: tensor<128x1x3x3x3x!quant.uniform>) { +// CHECK: stablehlo.convolution +// CHECK-SAME: dim_numbers = [b, f, 0, 1, 2]x[o, i, 0, 1, 2]->[b, f, 0, 1, 2], +// CHECK-SAME: window = { +// CHECK-SAME{LITERAL}: stride = [1, 1, 1], pad = [[0, 0], [0, 0], [0, 0]], +// CHECK-SAME: lhs_dilate = [1, 1, 1], +// CHECK-SAME: rhs_dilate = [1, 1, 1] +// CHECK-SAME: } +// CHECK-SAME: { +// CHECK-SAME: batch_group_count = 1 : i64, +// CHECK-SAME: feature_group_count = 1 : i64 +// CHECK-SAME: } +// CHECK-SAME: : (tensor<128x1x28x28x28xf32>, tensor<128x1x3x3x3xf32>) -> tensor<128x128x26x26x26xf32> + %0 = stablehlo.convolution(%arg0, %arg1) + dim_numbers = [b, f, 0, 1, 2]x[o, i, 0, 1, 2]->[b, f, 0, 1, 2], + window = { + stride = [1, 1, 1], pad = [[0, 0], [0, 0], [0, 0]], + lhs_dilate = [1, 1, 1], + rhs_dilate = [1, 1, 1] + } + { + batch_group_count = 1 : i64, + feature_group_count = 1 : i64 + } : (tensor<128x1x28x28x28x!quant.uniform>, tensor<128x1x3x3x3x!quant.uniform>) + -> tensor<128x128x26x26x26x!quant.uniform> + return +} diff --git a/stablehlo/transforms/PassUtils.cpp b/stablehlo/transforms/PassUtils.cpp index c54a4ddd7f..6ce7ad1d36 100644 --- a/stablehlo/transforms/PassUtils.cpp +++ b/stablehlo/transforms/PassUtils.cpp @@ -4,7 +4,7 @@ you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, +distributed under the License is distributed on an "AS IS" BASIS,G WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. @@ -12,6 +12,8 @@ limitations under the License. #include "stablehlo/transforms/PassUtils.h" +#include + #include "llvm/Support/ErrorHandling.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Location.h" @@ -65,5 +67,22 @@ bool isAnyQuantizedTypes(TypeRange types) { }); } +Type getQuantizedElementType(Location loc, Type storageType, Type expressedType, + ArrayRef scales, + ArrayRef zeroPoints, + int32_t quantizedDimension, int64_t storageTypeMin, + int64_t storageTypeMax) { + unsigned flags = + !storageType.isUnsignedInteger() ? quant::QuantizationFlags::Signed : 0; + if (quantizedDimension == -1) { + return quant::UniformQuantizedType::getChecked( + loc, flags, storageType, expressedType, scales[0], zeroPoints[0], + storageTypeMin, storageTypeMax); + } + return quant::UniformQuantizedPerAxisType::getChecked( + loc, flags, storageType, expressedType, scales, zeroPoints, + quantizedDimension, storageTypeMin, storageTypeMax); +} + } // namespace stablehlo } // namespace mlir diff --git a/stablehlo/transforms/PassUtils.h b/stablehlo/transforms/PassUtils.h index 9bf6c62d76..b0bc4018e0 100644 --- a/stablehlo/transforms/PassUtils.h +++ b/stablehlo/transforms/PassUtils.h @@ -69,6 +69,13 @@ Value getConstantLike(OpBuilder &b, Location loc, const APFloat &constant, // Check if any of the given types are mlir::quant::QuantizedType. bool isAnyQuantizedTypes(TypeRange types); +// Creates a quantized element type based on the given parameters. +Type getQuantizedElementType(Location loc, Type storageType, Type expressedType, + ArrayRef scales, + ArrayRef zeroPoints, + int32_t quantizedDimension, int64_t storageTypeMin, + int64_t storageTypeMax); + } // namespace stablehlo } // namespace mlir diff --git a/stablehlo/transforms/StablehloLegalizeQuantToMath.cpp b/stablehlo/transforms/StablehloLegalizeQuantToMath.cpp index 35e3c15bd7..01e35f0ebd 100644 --- a/stablehlo/transforms/StablehloLegalizeQuantToMath.cpp +++ b/stablehlo/transforms/StablehloLegalizeQuantToMath.cpp @@ -766,6 +766,7 @@ Value calculateZeroPointOffset(OpBuilder &builder, Location loc, Value lhs, Value outputDimsValue = nullptr; // Calculate LHS contribution when RHS zp is non-zero. if (rhsZp != 0) { + // Contributions from RHS_ZP * LHS. SmallVector reductionDims = to_vector(llvm::concat( dims.lhsSpatialDims, dims.lhsContractingDims)); Value lhsZpContribution = @@ -778,15 +779,23 @@ Value calculateZeroPointOffset(OpBuilder &builder, Location loc, Value lhs, } // Calculate RHS contribution when LHS zp is non-zero. if (lhsZp != 0) { + // Contributions from LHS_ZP * RHS. SmallVector reductionDims = to_vector(llvm::concat( dims.rhsSpatialDims, dims.rhsContractingDims)); Value rhsZpContribution = createZeroPointPartialOffset(builder, loc, rhs, lhsZp, reductionDims); + + int64_t nonBatchingStartingIdx = + lhsShape.getRank() - dims.lhsContractingDims.size(); + if (!dims.lhsSpatialDims.empty() && !dims.rhsSpatialDims.empty()) { + // In case of convolution, identified via absence of spatial dims, the + // non-batching starting index is the lhs contracting dim. + nonBatchingStartingIdx = dims.lhsContractingDims[0]; + } // Broadcast rhs ZP contribution to result tensor shape. rhsZpContribution = broadcastZpContribution( builder, loc, rhsZpContribution, reductionDims, dims.rhsBatchingDims, - lhsShape.getRank() - dims.lhsContractingDims.size(), output, - outputTensorType, outputDimsValue); + nonBatchingStartingIdx, output, outputTensorType, outputDimsValue); if (result) { result = builder.create(loc, result, rhsZpContribution); } else { @@ -833,7 +842,8 @@ Value calculateZeroPointOffset(OpBuilder &builder, Location loc, Value lhs, template Value createDotLikeKernel(OpBuilder &builder, Location loc, DotLikeOp, Type resultType, Value &lhs, Value &rhs, - ArrayRef attrs) { + ArrayRef attrs, + const DotLikeDimensionNumbers &dims) { return builder.create( loc, resultType, ArrayRef{lhs, rhs}, attrs); } @@ -843,7 +853,8 @@ Value createDotLikeKernel(OpBuilder &builder, Location loc, DotLikeOp, template <> Value createDotLikeKernel( OpBuilder &builder, Location loc, stablehlo::ConvolutionOp op, - Type resultType, Value &lhs, Value &rhs, ArrayRef attrs) { + Type resultType, Value &lhs, Value &rhs, ArrayRef attrs, + const DotLikeDimensionNumbers &dims) { // We only handle the case where RHS zp is zero. // Explicitly pad LHS with zp and update LHS value. SmallVector newAttrs(attrs); @@ -865,9 +876,9 @@ Value createDotLikeKernel( int64_t rank = cast(lhs.getType()).getRank(); SmallVector paddingLow(rank, 0), paddingHigh(rank, 0), paddingInterior(rank, 0); - for (int64_t i = 1; i < rank - 1; ++i) { - paddingLow[i] = originalPadding[i * 2 - 2]; - paddingHigh[i] = originalPadding[i * 2 - 1]; + for (size_t i = 0; i < dims.lhsSpatialDims.size(); ++i) { + paddingLow[dims.lhsSpatialDims[i]] = originalPadding[i * 2]; + paddingHigh[dims.lhsSpatialDims[i]] = originalPadding[i * 2 + 1]; } lhs = builder.create(loc, lhs, zp, paddingLow, paddingHigh, paddingInterior); @@ -897,6 +908,8 @@ LogicalResult matchAndRewriteDotLikeOp(DotLikeOp op, DotLikeOpAdaptor adaptor, Value rhs = adaptor.getRhs(); auto resInt32TensorType = op.getResult().getType().clone(rewriter.getI32Type()); + auto resFloat32TensorType = + op.getResult().getType().clone(rewriter.getF32Type()); // Dot result // = dot((lhs - zp_l) * scale_l, (rhs - zp_r) * scale_r) / scale_res @@ -908,7 +921,7 @@ LogicalResult matchAndRewriteDotLikeOp(DotLikeOp op, DotLikeOpAdaptor adaptor, // combined_zp = res_zp - zp_offset * combined_scale // zp_offset = zp_l*rhs + zp_r*lhs - zp_l*zp_r Value resI32 = createDotLikeKernel(rewriter, op->getLoc(), op, - resInt32TensorType, lhs, rhs, attrs); + resInt32TensorType, lhs, rhs, attrs, dims); auto lhsElementQuantType = getPerTensorType(op.getLhs().getType()); auto rhsElementQuantType = dyn_cast( @@ -945,8 +958,6 @@ LogicalResult matchAndRewriteDotLikeOp(DotLikeOp op, DotLikeOpAdaptor adaptor, Value combinedScale = rewriter.create( op->getLoc(), rewriter.getF32FloatAttr(combinedScaleFp)); - auto resFloat32TensorType = - op.getResult().getType().clone(rewriter.getF32Type()); Value resF32 = rewriter.create( op->getLoc(), resFloat32TensorType, resI32); resF32 = rewriter.create( @@ -1089,6 +1100,7 @@ class ConvertUniformQuantizedDotGeneralOp }; bool isConvNhwc(const stablehlo::ConvDimensionNumbersAttr &dims) { + // lhs(b, 0, 1, f) x rhs(0, 1, i, o) -> res(b, 0, 1, f) return dims.getInputBatchDimension() == 0 && dims.getInputFeatureDimension() == 3 && dims.getInputSpatialDimensions().size() == 2 && @@ -1106,7 +1118,27 @@ bool isConvNhwc(const stablehlo::ConvDimensionNumbersAttr &dims) { dims.getOutputSpatialDimensions()[1] == 2; } +bool isConvNchw(const stablehlo::ConvDimensionNumbersAttr &dims) { + // lhs(b, f, 0, 1) x rhs(o, i, 0, 1) -> res(b, f, 0, 1) + return dims.getInputBatchDimension() == 0 && + dims.getInputFeatureDimension() == 1 && + dims.getInputSpatialDimensions().size() == 2 && + dims.getInputSpatialDimensions()[0] == 2 && + dims.getInputSpatialDimensions()[1] == 3 && + dims.getKernelInputFeatureDimension() == 1 && + dims.getKernelOutputFeatureDimension() == 0 && + dims.getKernelSpatialDimensions().size() == 2 && + dims.getKernelSpatialDimensions()[0] == 2 && + dims.getKernelSpatialDimensions()[1] == 3 && + dims.getOutputBatchDimension() == 0 && + dims.getOutputFeatureDimension() == 1 && + dims.getOutputSpatialDimensions().size() == 2 && + dims.getOutputSpatialDimensions()[0] == 2 && + dims.getOutputSpatialDimensions()[1] == 3; +} + bool isConvNDHWC(const stablehlo::ConvDimensionNumbersAttr &dims) { + // lhs(b, 0, 1, 2, f) x rhs(0, 1, 2, i, o) -> res(b, 0, 1, 2, f) return dims.getInputBatchDimension() == 0 && dims.getInputFeatureDimension() == 4 && dims.getInputSpatialDimensions().size() == 3 && @@ -1127,6 +1159,28 @@ bool isConvNDHWC(const stablehlo::ConvDimensionNumbersAttr &dims) { dims.getOutputSpatialDimensions()[2] == 3; } +bool isConvNCDHW(const stablehlo::ConvDimensionNumbersAttr &dims) { + // lhs(b, f, 0, 1, 2) x rhs(o, i, 0, 1, 2) -> res(b, f, 0, 1, 2) + return dims.getInputBatchDimension() == 0 && + dims.getInputFeatureDimension() == 1 && + dims.getInputSpatialDimensions().size() == 3 && + dims.getInputSpatialDimensions()[0] == 2 && + dims.getInputSpatialDimensions()[1] == 3 && + dims.getInputSpatialDimensions()[2] == 4 && + dims.getKernelInputFeatureDimension() == 0 && + dims.getKernelOutputFeatureDimension() == 1 && + dims.getKernelSpatialDimensions().size() == 3 && + dims.getKernelSpatialDimensions()[0] == 2 && + dims.getKernelSpatialDimensions()[1] == 3 && + dims.getKernelSpatialDimensions()[2] == 4 && + dims.getOutputBatchDimension() == 0 && + dims.getOutputFeatureDimension() == 1 && + dims.getOutputSpatialDimensions().size() == 3 && + dims.getOutputSpatialDimensions()[0] == 2 && + dims.getOutputSpatialDimensions()[1] == 3 && + dims.getOutputSpatialDimensions()[2] == 4; +} + FailureOr verifyAndConstructDims( stablehlo::ConvolutionOp op, ConversionPatternRewriter &rewriter) { // RHS (weight) must have zero zp. @@ -1185,7 +1239,8 @@ FailureOr verifyAndConstructDims( // We only support NHWC Conv2D and NDHWC Conv3D. auto dims = op.getDimensionNumbers(); if (isConvNhwc(dims)) { - // 2D Convolution. + // 2D Nhwc Convolution. + // lhs(b, 0, 1, f) x rhs(0, 1, i, o) -> res(b, 0, 1, f) return DotLikeDimensionNumbers{/*lhs_batching_dims=*/{0}, /*lhs_spatial_dims=*/{1, 2}, /*lhs_contracting_dims=*/{3}, @@ -1193,8 +1248,19 @@ FailureOr verifyAndConstructDims( /*rhs_spatial_dims=*/{0, 1}, /*rhs_contracting_dims=*/{2}}; } + if (isConvNchw(dims)) { + // 2D Nchw Convolution. + // lhs(b, f, 0, 1) x rhs(o, i, 0, 1) -> res(b, f, 0, 1) + return DotLikeDimensionNumbers{/*lhs_batching_dims=*/{0}, + /*lhs_spatial_dims=*/{2, 3}, + /*lhs_contracting_dims=*/{1}, + /*rhs_batching_dims=*/{}, + /*rhs_spatial_dims=*/{2, 3}, + /*rhs_contracting_dims=*/{1}}; + } if (isConvNDHWC(dims)) { - // 3D Convolution. + // 3D Ndhwc Convolution. + // lhs(b, 0, 1, 2, f) x rhs(0, 1, 2, i, o) -> res(b, 0, 1, 2, f) return DotLikeDimensionNumbers{/*lhs_batching_dims=*/{0}, /*lhs_spatial_dims=*/{1, 2, 3}, /*lhs_contracting_dims=*/{4}, @@ -1202,6 +1268,16 @@ FailureOr verifyAndConstructDims( /*rhs_spatial_dims=*/{0, 1, 2}, /*rhs_contracting_dims=*/{3}}; } + if (isConvNCDHW(dims)) { + // 3D Ncdhw Convolution. + // lhs(b, f, 0, 1, 2) x rhs(o, i, 0, 1, 2) -> res(b, f, 0, 1, 2) + return DotLikeDimensionNumbers{/*lhs_batching_dims=*/{0}, + /*lhs_spatial_dims=*/{2, 3, 4}, + /*lhs_contracting_dims=*/{1}, + /*rhs_batching_dims=*/{}, + /*rhs_spatial_dims=*/{2, 3, 4}, + /*rhs_contracting_dims=*/{1}}; + } return rewriter.notifyMatchFailure(op, "Convolution data format must be NHWC."); }