Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MLIR] Fix BubbleDownVectorBitCastForExtract crash on non-static index #116518

Merged
merged 1 commit into from
Nov 19, 2024

Conversation

lialan
Copy link
Contributor

@lialan lialan commented Nov 17, 2024

Previously the patch was not expecting to handle non-static index, when the index is a non constant value it will crash.

This patch is to make sure it return gracefully instead of crashing.

@llvmbot
Copy link

llvmbot commented Nov 17, 2024

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: lialan (lialan)

Changes

Previously the patch was not expecting to handle non-static index, when the index is a non constant value it will crash.

This patch is to make sure it return gracefully instead of crashing.


Full diff: https://github.com/llvm/llvm-project/pull/116518.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp (+7-3)
  • (modified) mlir/test/Dialect/Vector/vector-transforms.mlir (+13)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 7f6b2303f86e10..3745bee98f3b85 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -596,12 +596,16 @@ struct BubbleDownVectorBitCastForExtract
     unsigned expandRatio =
         castDstType.getNumElements() / castSrcType.getNumElements();
 
-    auto getFirstIntValue = [](ArrayRef<OpFoldResult> values) -> uint64_t {
-      assert(values[0].is<Attribute>() && "Unexpected non-constant index");
+    auto getFirstIntValue = [](ArrayRef<OpFoldResult> values) -> std::optional<uint64_t> {
+      if (!values[0].is<Attribute>())
+        return std::nullopt;
       return cast<IntegerAttr>(values[0].get<Attribute>()).getInt();
     };
 
-    uint64_t index = getFirstIntValue(extractOp.getMixedPosition());
+    std::optional<uint64_t> optIndex = getFirstIntValue(extractOp.getMixedPosition());
+    if (!optIndex)
+      return failure();
+    uint64_t index = *optIndex;
 
     // Get the single scalar (as a vector) in the source value that packs the
     // desired scalar. E.g. extract vector<1xf32> from vector<4xf32>
diff --git a/mlir/test/Dialect/Vector/vector-transforms.mlir b/mlir/test/Dialect/Vector/vector-transforms.mlir
index 89e8ca1d93109c..de12a87253a673 100644
--- a/mlir/test/Dialect/Vector/vector-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-transforms.mlir
@@ -433,3 +433,16 @@ func.func @vec_0D(%arg0: vector<f32>) -> vector<i32> {
   %0 = vector.bitcast %arg0 : vector<f32> to vector<i32>
   return %0 : vector<i32>
 }
+
+// Make sure not crash on dynamic index `vector.extract`:
+func.func @vector_extract_dynamic_index(%arg0 : vector<4xi32>, %index : index) -> i16 {
+  %0 = vector.bitcast %arg0 : vector<4xi32> to vector<8xi16>
+  %1 = vector.extract %0[%index] : i16 from vector<8xi16>
+  return %1 : i16
+}
+
+// CHECK-LABEL: func.func @vector_extract_dynamic_index
+// CHECK-SAME: (%[[VEC:.+]]: vector<4xi32>, %[[IDX:.+]]: index) -> i16 {
+// CHECK: %[[BC:.+]] = vector.bitcast %[[VEC]] : vector<4xi32> to vector<8xi16>
+// CHECK: %[[EXTRACT:.+]] = vector.extract %[[BC]][%[[IDX]]] : i16 from vector<8xi16>
+// CHECK: return %[[EXTRACT]]

Copy link

⚠️ We detected that you are using a GitHub private e-mail address to contribute to the repo.
Please turn off Keep my email addresses private setting in your account.
See LLVM Discourse for more information.

Copy link

github-actions bot commented Nov 17, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

Copy link
Contributor

@dcaballe dcaballe left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the fix! LGTM, please, allow one or two non-weekend days for other to review.

mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp Outdated Show resolved Hide resolved
@lialan
Copy link
Contributor Author

lialan commented Nov 17, 2024

@hanhanW Turns out it is needed in turning on i1 attention support.

Copy link
Contributor

@hanhanW hanhanW left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the fix, just one nit about the comment!

mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp Outdated Show resolved Hide resolved
Previously the patch was not expecting to handle non-static index, when
the index is a non constant value it will crash.

This patch is to make sure it return gracefully instead of crashing.
@hanhanW hanhanW merged commit 6626ed6 into llvm:main Nov 19, 2024
8 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants