Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ONNX][TORCH] Add Onnx->Linalg lowering for RotaryEmbedding Op #4002

Open
wants to merge 11 commits into
base: main
Choose a base branch
from

Conversation

vivekkhandelwal1
Copy link
Collaborator

This commit adds the Onnx->Linalg lowering for Onnx's RotaryEmbedding op (ref: https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#commicrosoftrotaryembedding) by registering a customized torch op named OnnxVariantAtenRotaryEmbeddingOp. This is done so that the Onnx's RotaryEmbedding op can be lowered to this op and this op can be lowered from Torch->Linalg.

The lowering has been adopted from the OnnxRuntime. Files for references:
1.) https://github.com/microsoft/onnxruntime/blob/e1e3f623f61816008e79dddc91a51ffe7f0ff5cf/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc#L47-L93
2.) https://github.com/microsoft/onnxruntime/blob/94c69f55d480cb4a8dcbc161d29ef3acca9392a7/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h

Signed-off-by: Vivek Khandelwal [email protected]

@AmosLewis
Copy link
Collaborator

We need a test in https://github.com/nod-ai/SHARK-TestSuite/tree/main/alt_e2eshark/onnx_tests/operators to verify the numeric before merge

@vivekkhandelwal1 vivekkhandelwal1 force-pushed the rotary-embedding branch 2 times, most recently from 16f397a to caa9622 Compare February 10, 2025 10:10
@vivekkhandelwal1
Copy link
Collaborator Author

We need a test in https://github.com/nod-ai/SHARK-TestSuite/tree/main/alt_e2eshark/onnx_tests/operators to verify the numeric before merge

Actually, the test in SHARK-Testsuite is not working since the op comes from "com.microsoft" domain. Alhtough, I have verified the e2e correctness of lowering by manually generating the IR and then compiling and executing it with the IREE.

Copy link
Collaborator

@zjgarvey zjgarvey left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A few small comments. I haven't double checked that the implementation is correct.

include/torch-mlir/Dialect/Torch/IR/TorchOps.td Outdated Show resolved Hide resolved
include/torch-mlir/Dialect/Torch/IR/TorchOps.td Outdated Show resolved Hide resolved
test/Conversion/TorchToLinalg/basic.mlir Show resolved Hide resolved
Add custom parser and printer for the op
Move the op lowering to a seperate code file for com.microsoft domain ops
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants