-
Notifications
You must be signed in to change notification settings - Fork 12k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MLIR] Fix BubbleDownVectorBitCastForExtract
crash on non-static index
#116518
Conversation
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir Author: lialan (lialan) ChangesPreviously the patch was not expecting to handle non-static index, when the index is a non constant value it will crash. This patch is to make sure it return gracefully instead of crashing. Full diff: https://github.com/llvm/llvm-project/pull/116518.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 7f6b2303f86e10..3745bee98f3b85 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -596,12 +596,16 @@ struct BubbleDownVectorBitCastForExtract
unsigned expandRatio =
castDstType.getNumElements() / castSrcType.getNumElements();
- auto getFirstIntValue = [](ArrayRef<OpFoldResult> values) -> uint64_t {
- assert(values[0].is<Attribute>() && "Unexpected non-constant index");
+ auto getFirstIntValue = [](ArrayRef<OpFoldResult> values) -> std::optional<uint64_t> {
+ if (!values[0].is<Attribute>())
+ return std::nullopt;
return cast<IntegerAttr>(values[0].get<Attribute>()).getInt();
};
- uint64_t index = getFirstIntValue(extractOp.getMixedPosition());
+ std::optional<uint64_t> optIndex = getFirstIntValue(extractOp.getMixedPosition());
+ if (!optIndex)
+ return failure();
+ uint64_t index = *optIndex;
// Get the single scalar (as a vector) in the source value that packs the
// desired scalar. E.g. extract vector<1xf32> from vector<4xf32>
diff --git a/mlir/test/Dialect/Vector/vector-transforms.mlir b/mlir/test/Dialect/Vector/vector-transforms.mlir
index 89e8ca1d93109c..de12a87253a673 100644
--- a/mlir/test/Dialect/Vector/vector-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-transforms.mlir
@@ -433,3 +433,16 @@ func.func @vec_0D(%arg0: vector<f32>) -> vector<i32> {
%0 = vector.bitcast %arg0 : vector<f32> to vector<i32>
return %0 : vector<i32>
}
+
+// Make sure not crash on dynamic index `vector.extract`:
+func.func @vector_extract_dynamic_index(%arg0 : vector<4xi32>, %index : index) -> i16 {
+ %0 = vector.bitcast %arg0 : vector<4xi32> to vector<8xi16>
+ %1 = vector.extract %0[%index] : i16 from vector<8xi16>
+ return %1 : i16
+}
+
+// CHECK-LABEL: func.func @vector_extract_dynamic_index
+// CHECK-SAME: (%[[VEC:.+]]: vector<4xi32>, %[[IDX:.+]]: index) -> i16 {
+// CHECK: %[[BC:.+]] = vector.bitcast %[[VEC]] : vector<4xi32> to vector<8xi16>
+// CHECK: %[[EXTRACT:.+]] = vector.extract %[[BC]][%[[IDX]]] : i16 from vector<8xi16>
+// CHECK: return %[[EXTRACT]]
|
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
3d940cc
to
91342cc
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the fix! LGTM, please, allow one or two non-weekend days for other to review.
91342cc
to
9b10ccf
Compare
@hanhanW Turns out it is needed in turning on |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the fix, just one nit about the comment!
Previously the patch was not expecting to handle non-static index, when the index is a non constant value it will crash. This patch is to make sure it return gracefully instead of crashing.
9b10ccf
to
001d08f
Compare
Previously the patch was not expecting to handle non-static index, when the index is a non constant value it will crash.
This patch is to make sure it return gracefully instead of crashing.