[ONNX][TORCH] Add Onnx->Linalg lowering for RotaryEmbedding Op #4002

Open · wants to merge 11 commits into base: main
1 change: 1 addition & 0 deletions include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h
@@ -472,6 +472,7 @@ class OnnxCustomOpConversionPattern

// Patterns are split into chunks to speed compile time and reduce some
// contention on the same source files.
void populateComMicrosoftDomain(OnnxCustomOpConversionPattern &patterns);
void populateDefaultDomainAtoF(OnnxCustomOpConversionPattern &patterns);
void populateDefaultDomainGtoP(OnnxCustomOpConversionPattern &patterns);
void populateDefaultDomainQtoZ(OnnxCustomOpConversionPattern &patterns);
32 changes: 32 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/TorchOps.td
@@ -1410,4 +1410,36 @@ def Torch_BindSymbolicShapeOp : Torch_Op<"bind_symbolic_shape", []> {
let hasVerifier = 1;
}

// This op corresponds to ONNX's RotaryEmbedding operator.
// Ref: https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#commicrosoftrotaryembedding
def Torch_OnnxVariantRotaryEmbeddingOp: Torch_Op<"onnx.rotary_embedding", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "`rotary_embedding op : (Tensor, Tensor, Tensor, Tensor, int, int, int, int, float) -> (Tensor)`";
let description = [{
The `torch.onnx.rotary_embedding` operation exists solely to support
ONNX's RotaryEmbedding op. ONNX ops cannot be lowered to Linalg
directly, so they must first be mapped to a legal Torch dialect op;
this op serves as that intermediate.
}];
let arguments = (ins
AnyTorchTensorType:$input,
AnyTorchTensorType:$position_ids,
AnyTorchTensorType:$cos_cache,
AnyTorchTensorType:$sin_cache,
Torch_IntType:$interleaved,
Torch_IntType:$is_packed_batching,
Torch_IntType:$num_heads,
Torch_IntType:$rotary_embedding_dim,
Torch_FloatType:$scale
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
}
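
// For illustration, a hedged sketch of how this op might appear in IR,
// written in MLIR's generic operation form (the op's custom assembly
// format is defined in C++ and is not part of this diff); all tensor
// shapes below are hypothetical.
//
// %result = "torch.onnx.rotary_embedding"(
//     %input, %position_ids, %cos_cache, %sin_cache, %interleaved,
//     %is_packed_batching, %num_heads, %rotary_embedding_dim, %scale)
//     : (!torch.vtensor<[1,3,2,6],f32>, !torch.vtensor<[1,2],si64>,
//        !torch.vtensor<[4,3],f32>, !torch.vtensor<[4,3],f32>,
//        !torch.int, !torch.int, !torch.int, !torch.int, !torch.float)
//     -> !torch.vtensor<[1,3,2,6],f32>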

#endif // TORCH_OPS
1 change: 1 addition & 0 deletions lib/Conversion/TorchOnnxToTorch/CMakeLists.txt
@@ -1,4 +1,5 @@
add_mlir_conversion_library(TorchMLIRTorchOnnxToTorch
ComMicrosoftDomain.cpp
DefaultDomainAtoF.cpp
DefaultDomainGtoP.cpp
DefaultDomainQtoZ.cpp
66 changes: 66 additions & 0 deletions lib/Conversion/TorchOnnxToTorch/ComMicrosoftDomain.cpp
@@ -0,0 +1,66 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h"
#include "torch-mlir/Conversion/TorchOnnxToTorch/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include <numeric>

using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::onnx_c;

void mlir::torch::onnx_c::populateComMicrosoftDomain(
OnnxCustomOpConversionPattern &patterns) {
patterns.onOp(
"RotaryEmbedding", 1,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Location loc = binder.getLoc();
int64_t interleaved, isPackedBatching, numHeads, rotaryEmbeddingDim;
float scale;
Value input, positionIds, cosCache, sinCache;
if (binder.tensorOperandAtIndex(input, 0) ||
binder.tensorOperandAtIndex(positionIds, 1) ||
binder.tensorOperandAtIndex(cosCache, 2) ||
binder.tensorOperandAtIndex(sinCache, 3) ||
binder.s64IntegerAttr(interleaved, "interleaved", 0) ||
binder.s64IntegerAttr(isPackedBatching, "is_packed_batching", 0) ||
binder.s64IntegerAttr(numHeads, "num_heads", 0) ||
binder.s64IntegerAttr(rotaryEmbeddingDim, "rotary_embedding_dim",
0) ||
binder.f32FloatAttr(scale, "scale", 1.0)) {
return rewriter.notifyMatchFailure(binder.op,
"Failed to get required inputs");
}

Torch::ValueTensorType resultType;
if (binder.tensorResultType(resultType)) {
return rewriter.notifyMatchFailure(binder.op,
"result type bind failure");
}

Value cstInterleaved = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(interleaved));
Value cstIsPackedBatching = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(isPackedBatching));
Value cstNumHeads = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(numHeads));
Value cstRotaryEmbeddingDim = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(rotaryEmbeddingDim));
Value cstScale = rewriter.create<Torch::ConstantFloatOp>(
loc, rewriter.getF64FloatAttr(scale));

rewriter.replaceOpWithNewOp<Torch::OnnxVariantRotaryEmbeddingOp>(
binder.op, resultType, input, positionIds, cosCache, sinCache,
cstInterleaved, cstIsPackedBatching, cstNumHeads,
cstRotaryEmbeddingDim, cstScale);
return success();
});
}
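
// For context, a hedged sketch of the kind of input this pattern is meant
// to match: torch-mlir's ONNX importer surfaces unhandled ops as
// torch.operator calls whose attributes carry a torch.onnx. prefix. The
// shapes and attribute values below are hypothetical.
//
// %0 = torch.operator "onnx.RotaryEmbedding"(
//     %input, %position_ids, %cos_cache, %sin_cache)
//     {torch.onnx.interleaved = 0 : si64, torch.onnx.num_heads = 2 : si64}
//     : (!torch.vtensor<[1,3,2,6],f32>, !torch.vtensor<[1,2],si64>,
//        !torch.vtensor<[4,3],f32>, !torch.vtensor<[4,3],f32>)
//     -> !torch.vtensor<[1,3,2,6],f32>
//
// The pattern reads each attribute into a compile-time constant,
// materializes it as a Torch constant op, and replaces the torch.operator
// with the torch.onnx.rotary_embedding op defined above.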
1 change: 1 addition & 0 deletions lib/Conversion/TorchOnnxToTorch/TorchOnnxToTorch.cpp
Original file line number Diff line number Diff line change
@@ -56,6 +56,7 @@ class ConvertTorchOnnxToTorch
std::make_unique<OnnxCustomOpConversionPattern>(
context, "onnx.",
/*domainVersion=*/defaultOpsetVersion);
populateComMicrosoftDomain(*defaultDomainPatterns);
populateDefaultDomainAtoF(*defaultDomainPatterns);
populateDefaultDomainGtoP(*defaultDomainPatterns);
populateDefaultDomainQtoZ(*defaultDomainPatterns);