Skip to content

Commit

Permalink
[torch-mlir] Support lowering of aten constraint ops (#3943)
Browse files Browse the repository at this point in the history
1. aten::sym_constrain_range
2. aten::sym_constrain_range_for_size
3. aten::_assert_scalar
  • Loading branch information
praveen-g-ctt authored Feb 5, 2025
1 parent 25aa0c6 commit fd65a66
Show file tree
Hide file tree
Showing 8 changed files with 370 additions and 1 deletion.
71 changes: 71 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -17771,6 +17771,77 @@ def Torch_Aten_MakePerTensorQuantizedTensorOp : Torch_Op<"aten._make_per_tensor_
}];
}

def Torch_AtenSymConstrainRangeOp : Torch_Op<"aten.sym_constrain_range", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::sym_constrain_range : (Scalar, int?, int?) -> ()`";
let arguments = (ins
AnyTorchScalarType:$size,
AnyTorchOptionalIntType:$min,
AnyTorchOptionalIntType:$max
);
let results = (outs
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenSymConstrainRangeOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 3, 0);
}
void AtenSymConstrainRangeOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 3, 0);
}
}];
}

def Torch_AtenSymConstrainRangeForSizeOp : Torch_Op<"aten.sym_constrain_range_for_size", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::sym_constrain_range_for_size : (Scalar, int?, int?) -> ()`";
let arguments = (ins
AnyTorchScalarType:$size,
AnyTorchOptionalIntType:$min,
AnyTorchOptionalIntType:$max
);
let results = (outs
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenSymConstrainRangeForSizeOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 3, 0);
}
void AtenSymConstrainRangeForSizeOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 3, 0);
}
}];
}

def Torch_Aten_AssertScalarOp : Torch_Op<"aten._assert_scalar", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::_assert_scalar : (Scalar, str) -> ()`";
let arguments = (ins
AnyTorchScalarType:$self,
Torch_StringType:$assert_msg
);
let results = (outs
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult Aten_AssertScalarOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 0);
}
void Aten_AssertScalarOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 0);
}
}];
}

def Torch_PrimLayoutOp : Torch_Op<"prim.layout", [
AllowsTypeRefinement,
HasValueSemantics,
Expand Down
66 changes: 66 additions & 0 deletions lib/Conversion/TorchToLinalg/Uncategorized.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,12 @@
#include "torch-mlir/Conversion/TorchToLinalg/Utils.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "llvm/ADT/APSInt.h"
#include <numeric>
#include <string>
#include <type_traits>

using namespace mlir;
Expand Down Expand Up @@ -3564,6 +3566,68 @@ class ConvertAtenPolarOp : public OpConversionPattern<AtenPolarOp> {
};
} // namespace

namespace {
class ConvertSymConstrainRangeOp
: public OpConversionPattern<AtenSymConstrainRangeOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenSymConstrainRangeOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();

auto loc = op.getLoc();
auto min = op.getMin();
auto max = op.getMax();

int64_t minValue = std::numeric_limits<int64_t>::min();
int64_t maxValue = std::numeric_limits<int64_t>::max();

Type operandType = getTypeConverter()->convertType(op.getSize().getType());

if (!isa<Torch::NoneType>(min.getType()))
if (!matchPattern(min, m_TorchConstantInt(&minValue)))
return rewriter.notifyMatchFailure(
op, "Expected min value to be constant integer");

if (!isa<Torch::NoneType>(max.getType()))
if (!matchPattern(max, m_TorchConstantInt(&maxValue)))
return rewriter.notifyMatchFailure(
op, "Expected max value to be constant integer");

if (maxValue < minValue) {
std::string errorMsg =
"Max must be greater than or equal to min, got min = " +
std::to_string(minValue) + ", max = " + std::to_string(maxValue);
return op.emitError(errorMsg);
}

min = getConstant(rewriter, loc, minValue, operandType);
max = getConstant(rewriter, loc, maxValue, operandType);

// Check min <= size <= max

// FIXME:: Skip the below checks if constraint ops are already inserted as
// part of symbol expr evaluation
auto checkMin = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::sle, min, adaptor.getSize());
auto checkMax = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::sle, adaptor.getSize(), max);
auto compareVal = rewriter.create<arith::AndIOp>(loc, checkMin, checkMax);

std::string assertMessage = "Size constraint failed. Expected range: [" +
std::to_string(minValue) + ", " +
std::to_string(maxValue) + "]";
rewriter.create<cf::AssertOp>(loc, compareVal,
rewriter.getStringAttr(assertMessage));

rewriter.eraseOp(op);
return success();
}
};
} // namespace

void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target) {
Expand Down Expand Up @@ -3626,4 +3690,6 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
patterns.add<ConvertAtenLinalgDetOp>(typeConverter, context);
target.addIllegalOp<AtenPolarOp>();
patterns.add<ConvertAtenPolarOp>(typeConverter, context);
target.addIllegalOp<AtenSymConstrainRangeOp>();
patterns.add<ConvertSymConstrainRangeOp>(typeConverter, context);
}
78 changes: 78 additions & 0 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11455,6 +11455,80 @@ class DecomposeAtenSpecialExpm1Op
};
} // namespace

namespace {
class DecomposeAtenConstrainRangeForSizeOp
: public OpRewritePattern<AtenSymConstrainRangeForSizeOp> {
public:
using OpRewritePattern<AtenSymConstrainRangeForSizeOp>::OpRewritePattern;
LogicalResult matchAndRewrite(AtenSymConstrainRangeForSizeOp op,
PatternRewriter &rewriter) const override {

auto loc = op.getLoc();
auto min = op.getMin();
auto max = op.getMax();

int64_t minValue, maxValue;

if (isa<Torch::NoneType>(min.getType())) {
// Set min value to 0
min = rewriter.create<Torch::ConstantIntOp>(loc, 0);
} else {
// Check if min value is a constant
if (!matchPattern(min, m_TorchConstantInt(&minValue)))
return rewriter.notifyMatchFailure(
op, "Expected min value to be constant integer");
}

if (!isa<Torch::NoneType>(max.getType())) {
// Verify that max value is greater than 2
if (!matchPattern(max, m_TorchConstantInt(&maxValue)))
return rewriter.notifyMatchFailure(
op, "Expected max value to be constant integer");

if (maxValue <= 2) {
std::string errorMsg = "Max value to constrain_range_for_size must be "
"greater than 2, got: " +
std::to_string(maxValue);
return op.emitError(errorMsg);
}
}

rewriter.replaceOpWithNewOp<AtenSymConstrainRangeOp>(op, op.getSize(), min,
max);
return success();
}
};
} // namespace

namespace {
class DecomposeAten_AssertScalarOp
: public OpRewritePattern<Aten_AssertScalarOp> {
public:
using OpRewritePattern<Aten_AssertScalarOp>::OpRewritePattern;
LogicalResult matchAndRewrite(Aten_AssertScalarOp op,
PatternRewriter &rewriter) const override {

auto loc = op.getLoc();
auto assertCond = op.getSelf();

if (isa<Torch::IntType>(assertCond.getType()))
assertCond = rewriter.create<AtenBoolIntOp>(loc, assertCond);
else if (isa<Torch::FloatType>(assertCond.getType()))
assertCond = rewriter.create<AtenBoolFloatOp>(loc, assertCond);
assert(isa<Torch::BoolType>(assertCond.getType()) &&
"Unhandled type encountered in aten._assert_scalar op");

std::string assertMessage;
if (!matchPattern(op.getAssertMsg(), m_TorchConstantStr(assertMessage)))
return rewriter.notifyMatchFailure(
op, "Assert message must be a constant string");

rewriter.replaceOpWithNewOp<RuntimeAssertOp>(op, assertCond, assertMessage);
return success();
}
};
} // namespace

namespace {
class DecomposeComplexOpsPass
: public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
Expand Down Expand Up @@ -11753,6 +11827,10 @@ class DecomposeComplexOpsPass
// Torchvision ops
addPatternIfTargetOpIsIllegal<DecomposeTorchvisionNmsOp>(patterns);

addPatternIfTargetOpIsIllegal<DecomposeAtenConstrainRangeForSizeOp>(
patterns);
addPatternIfTargetOpIsIllegal<DecomposeAten_AssertScalarOp>(patterns);

GreedyRewriteConfig config;
config.useTopDownTraversal = true;
config.maxIterations = GreedyRewriteConfig::kNoLimit;
Expand Down
12 changes: 11 additions & 1 deletion projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@
"Aten_TrilinearModuleZerodDimBug_basic",
# missing lowering from aten.pow.Tensor_Tensor for integer result
"PowIntIntModule_basic",
# Unknown builtin op: aten::_check_is_size in TorchScript
"AtenSymConstrainRange_basic",
"AtenSymConstrainRangeForSize_basic",
"Aten_AssertScalar_basic",
}

if torch_version_for_comparison() < version.parse("2.5.0.dev"):
Expand Down Expand Up @@ -623,7 +627,6 @@
"AtenMmQMixedSigni8_basic",
"AtenMmQint8_basic",
"AtenMmQuint8_basic",
"AtenNonzero1DDynamicModule_basic",
"AtenRealView128Module_basic",
"AtenRealView64Module_basic",
"AtenTopKModule_basic",
Expand Down Expand Up @@ -941,6 +944,9 @@
"UniformModule_basic",
"UniformStaticShapeModule_basic",
"ScaledDotProductAttentionGQAModule_basic",
"AtenSymConstrainRange_basic",
"AtenSymConstrainRangeForSize_basic",
"Aten_AssertScalar_basic",
}

FX_IMPORTER_STABLEHLO_CRASHING_SET = {
Expand All @@ -964,6 +970,7 @@
"Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic",
"CrossEntropyLossModule_basic",
"CrossEntropyLossNoReductionModule_basic",
"AtenNonzero1DDynamicModule_basic", # error: Mismatched ranks of types2 vs 1
}

STABLEHLO_PASS_SET = {
Expand Down Expand Up @@ -3254,6 +3261,9 @@
"Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic",
"Aten_TrilinearModuleZerodDimBug_basic",
"ScaledDotProductAttentionGQAModule_basic",
"AtenSymConstrainRange_basic",
"AtenSymConstrainRangeForSize_basic",
"Aten_AssertScalar_basic",
}

if torch_version_for_comparison() < version.parse("2.3.0.dev"):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1232,6 +1232,11 @@ def emit_with_mutating_variants(key, **kwargs):
)
emit("aten::_make_per_tensor_quantized_tensor : (Tensor, float, int) -> (Tensor)")

# Constraint ops
emit("aten::sym_constrain_range : (Scalar, int?, int?) -> ()")
emit("aten::sym_constrain_range_for_size : (Scalar, int?, int?) -> ()")
emit("aten::_assert_scalar : (Scalar, str) -> ()")

# ==========================================================================
# `prim::` namespace.
# ==========================================================================
Expand Down
59 changes: 59 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -6480,3 +6480,62 @@ def forward(self, x):
@register_test_case(module_factory=lambda: AtenNonzero1DDynamicModule())
def AtenNonzero1DDynamicModule_basic(module, tu: TestUtils):
module.forward(torch.tensor([0, 0, 1, 1, 0, 0], dtype=torch.bool))


# ==============================================================================


class AtenSymConstrainRange(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([None, ([-1], torch.int, True)])
def forward(self, x):
a = x.item()
torch.ops.aten.sym_constrain_range(a, max=5)
return a


@register_test_case(module_factory=lambda: AtenSymConstrainRange())
def AtenSymConstrainRange_basic(module, tu: TestUtils):
module.forward(torch.tensor(4))


# ==============================================================================


class AtenSymConstrainRangeForSize(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([None, ([-1], torch.int, True)])
def forward(self, x):
a = x.item()
torch.ops.aten.sym_constrain_range_for_size(a, min=0, max=10)
return a


@register_test_case(module_factory=lambda: AtenSymConstrainRangeForSize())
def AtenSymConstrainRangeForSize_basic(module, tu: TestUtils):
module.forward(torch.tensor(4))


# ==============================================================================
class Aten_AssertScalar(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([None, ([-1], torch.int, True)])
def forward(self, x):
a = x.item()
assert_msg = "Assertion failed for condition x.item() > 3"
torch.ops.aten._assert_scalar(a > 3, assert_msg)
return a


@register_test_case(module_factory=lambda: Aten_AssertScalar())
def Aten_AssertScalar_basic(module, tu: TestUtils):
module.forward(torch.tensor(4))
Loading

0 comments on commit fd65a66

Please sign in to comment.