Skip to content

Commit

Permalink
[ONNX] Add OnnxToTorch lowering for SpaceToDepth op (#3393)
Browse files Browse the repository at this point in the history
Signed-Off By: Vivek Khandelwal <[email protected]>
  • Loading branch information
vivekkhandelwal1 authored Jun 3, 2024
1 parent 285b087 commit 6382dbb
Show file tree
Hide file tree
Showing 7 changed files with 269 additions and 18 deletions.
10 changes: 10 additions & 0 deletions include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,16 @@ m_OnnxListOfConstantInts(SmallVectorImpl<int64_t> &bind_values) {

std::optional<int64_t> onnxDtypeIntToTorchDtypeInt(int64_t dtypeIntOnnx);

LogicalResult createTorchTransposeOp(ConversionPatternRewriter &rewriter,
Location loc, Value input, int64_t dimA,
int64_t dimB, Value &transposed);

LogicalResult createTorchPermuteOp(OpBinder binder,
ConversionPatternRewriter &rewriter,
Location loc, Value input,
SmallVector<int64_t> permuteDims,
Value &permuted);

} // namespace mlir::torch::onnx_c

#endif // TORCHMLIR_CONVERSION_TORCHONNXTOTORCH_UTILS_H
4 changes: 4 additions & 0 deletions include/torch-mlir/Dialect/Torch/Utils/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,10 @@ LogicalResult getTransposedType(BaseTensorType inType, int64_t dimA,
// Torch flags, user options, etc).
Type getDefaultAccType(PatternRewriter &rewriter, Type inputType);

LogicalResult getPermutedType(BaseTensorType inType,
SmallVector<int64_t> permuteDims,
Type &permutedType);

} // namespace Torch
} // namespace torch
} // namespace mlir
Expand Down
17 changes: 0 additions & 17 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,23 +18,6 @@ using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::onnx_c;

static LogicalResult createTorchTransposeOp(ConversionPatternRewriter &rewriter,
Location loc, Value input,
int64_t dimA, int64_t dimB,
Value &transposed) {
Type transposedType;
if (failed(getTransposedType(cast<Torch::BaseTensorType>(input.getType()),
dimA, dimB, transposedType)))
return failure();
Value cstDimA = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(dimA));
Value cstDimB = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(dimB));
transposed = rewriter.create<Torch::AtenTransposeIntOp>(
loc, transposedType, input, cstDimA, cstDimB);
return success();
}

namespace {
LogicalResult windowFunctionImpl(OpBinder binder,
ConversionPatternRewriter &rewriter,
Expand Down
98 changes: 98 additions & 0 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2952,4 +2952,102 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
/*Torch_BoolType:$antialias*/ cstFalse);
return success();
});
patterns.onOp(
"SpaceToDepth", 1,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value input;
int64_t blockSize;
std::string mode;
if (binder.tensorOperand(input) ||
binder.s64IntegerAttr(blockSize, "blocksize") ||
binder.customOpNameStringAttr(mode, "mode", "DCR") ||
binder.tensorResultType(resultType))
return failure();
auto inputTy = dyn_cast<Torch::BaseTensorType>(input.getType());
if (!inputTy || !inputTy.hasSizes()) {
return rewriter.notifyMatchFailure(
binder.op, "Expected input type having sizes");
}
SmallVector<int64_t> inputSizes{inputTy.getSizes()};
if (inputSizes.size() != 4) {
return rewriter.notifyMatchFailure(binder.op,
"Expected input rank to be 4");
}

Value b = rewriter.create<Torch::AtenSizeIntOp>(
binder.getLoc(), input,
rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(0)));
Value c = rewriter.create<Torch::AtenSizeIntOp>(
binder.getLoc(), input,
rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(1)));
Value h = rewriter.create<Torch::AtenSizeIntOp>(
binder.getLoc(), input,
rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(2)));
Value w = rewriter.create<Torch::AtenSizeIntOp>(
binder.getLoc(), input,
rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(3)));
Value cstBlockSize = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(blockSize));
Value cstBlockSizeSquare = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(blockSize * blockSize));
Value hDivBlockSize = rewriter.create<Torch::AtenDivIntOp>(
binder.getLoc(), h, cstBlockSize);
Value wDivBlockSize = rewriter.create<Torch::AtenDivIntOp>(
binder.getLoc(), w, cstBlockSize);
hDivBlockSize = rewriter.create<Torch::AtenIntFloatOp>(binder.getLoc(),
hDivBlockSize);
wDivBlockSize = rewriter.create<Torch::AtenIntFloatOp>(binder.getLoc(),
wDivBlockSize);

// The implementation is as follows:
// tmp = np.reshape(
// x, [b, c, h // blocksize, blocksize, w // blocksize, blocksize]
// )
// tmp = np.transpose(tmp, [0, 3, 5, 1, 2, 4])
// y = np.reshape(tmp, [b, c * (blocksize**2), h // blocksize, w //
// blocksize])
Value reshapeSizesList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
Torch::ListType::get(Torch::IntType::get(input.getContext())),
llvm::SmallVector<Value>{b, c, hDivBlockSize, cstBlockSize,
wDivBlockSize, cstBlockSize});
int64_t hDivBlockSizeInt = inputSizes[2] == Torch::kUnknownSize
? Torch::kUnknownSize
: inputSizes[2] / blockSize;
int64_t wDivBlockSizeInt = inputSizes[3] == Torch::kUnknownSize
? Torch::kUnknownSize
: inputSizes[3] / blockSize;
SmallVector<int64_t, 6> reshapeSizesInt{inputSizes[0], inputSizes[1],
hDivBlockSizeInt, blockSize,
wDivBlockSizeInt, blockSize};
Value reshapedInput = rewriter.create<Torch::AtenReshapeOp>(
binder.getLoc(),
inputTy.getWithSizesAndDtype(reshapeSizesInt,
inputTy.getOptionalDtype()),
input, reshapeSizesList);

SmallVector<int64_t, 6> permuteDimsInt{0, 3, 5, 1, 2, 4};
Value permutedInput;
if (failed(createTorchPermuteOp(binder, rewriter, binder.getLoc(),
reshapedInput, permuteDimsInt,
permutedInput)))
return rewriter.notifyMatchFailure(
binder.op, "Failed to create Torch Permute op");

Value cMulBlockSizeSquare = rewriter.create<Torch::AtenMulIntOp>(
binder.getLoc(), c, cstBlockSizeSquare);
reshapeSizesList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
Torch::ListType::get(Torch::IntType::get(input.getContext())),
llvm::SmallVector<Value>{b, cMulBlockSizeSquare, hDivBlockSize,
wDivBlockSize});
rewriter.replaceOpWithNewOp<Torch::AtenReshapeOp>(
binder.op, resultType, permutedInput, reshapeSizesList);
return success();
});
}
30 changes: 30 additions & 0 deletions lib/Conversion/TorchOnnxToTorch/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,3 +97,33 @@ mlir::torch::onnx_c::onnxDtypeIntToTorchDtypeInt(int64_t dtypeIntOnnx) {

return dtypeIntTorch;
}

LogicalResult mlir::torch::onnx_c::createTorchTransposeOp(
ConversionPatternRewriter &rewriter, Location loc, Value input,
int64_t dimA, int64_t dimB, Value &transposed) {
Type transposedType;
if (failed(getTransposedType(cast<Torch::BaseTensorType>(input.getType()),
dimA, dimB, transposedType)))
return failure();
Value cstDimA = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(dimA));
Value cstDimB = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(dimB));
transposed = rewriter.create<Torch::AtenTransposeIntOp>(
loc, transposedType, input, cstDimA, cstDimB);
return success();
}

LogicalResult mlir::torch::onnx_c::createTorchPermuteOp(
OpBinder binder, ConversionPatternRewriter &rewriter, Location loc,
Value input, SmallVector<int64_t> permuteDims, Value &permuted) {
Type permutedType;
if (failed(
Torch::getPermutedType(cast<Torch::BaseTensorType>(input.getType()),
permuteDims, permutedType)))
return failure();
Value permuteDimsList = createConstantIntList(binder, rewriter, permuteDims);
permuted = rewriter.create<Torch::AtenPermuteOp>(loc, permutedType, input,
permuteDimsList);
return success();
}
18 changes: 18 additions & 0 deletions lib/Dialect/Torch/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -570,6 +570,24 @@ LogicalResult Torch::getTransposedType(BaseTensorType inType, int64_t dimA,
return success();
}

LogicalResult Torch::getPermutedType(BaseTensorType inType,
SmallVector<int64_t> permuteDims,
Type &permutedType) {
if (!inType.hasSizes())
return failure();

SmallVector<int64_t> shape(inType.getSizes());
if (shape.size() != permuteDims.size())
return failure();

SmallVector<int64_t> permutedShape;
for (unsigned i = 0; i < shape.size(); i++)
permutedShape.push_back(shape[permuteDims[i]]);
permutedType = inType.getWithSizesAndDtype(llvm::ArrayRef(permutedShape),
inType.getOptionalDtype());
return success();
}

Type Torch::getDefaultAccType(PatternRewriter &rewriter, Type inputType) {
if (inputType.isF16())
return rewriter.getF32Type();
Expand Down
Loading

0 comments on commit 6382dbb

Please sign in to comment.