Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding the missing quantized-type support in the ODS #1608

Merged
merged 2 commits into from
Jun 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions stablehlo/dialect/Base.td
Original file line number Diff line number Diff line change
Expand Up @@ -149,12 +149,18 @@ def HLO_PredOrIntTensor : TensorOf<[HLO_Pred, HLO_Int]>;
// Any floating-point or complex tensor types
def HLO_FpOrComplexTensor : TensorOf<[HLO_Float, HLO_Complex]>;

// Any int, floating-point or complex tensor types
def HLO_IntFpOrComplexTensor : TensorOf<[HLO_Int, HLO_Float, HLO_Complex]>;
// Any floating-point, complex or quantized tensor types
def HLO_FpComplexOrQuantizedIntTensor : TensorOf<[HLO_Float, HLO_Complex, HLO_QuantizedInt]>;

// Any int, floating-point, complex or quantized tensor types
def HLO_IntFpOrComplexOrQuantizedIntTensor : TensorOf<[HLO_Int, HLO_Float, HLO_Complex, HLO_QuantizedInt]>;

// Any pred, int or floating-point tensor types
def HLO_PredIntOrFpTensor : TensorOf<[HLO_Pred, HLO_Int, HLO_Float]>;

// Any pred, int, floating-point or quantized tensor types
def HLO_PredIntFpOrQuantizedTensor : TensorOf<[HLO_Pred, HLO_Int, HLO_Float, HLO_QuantizedInt]>;

//===----------------------------------------------------------------------===//
// HLO static shape type definitions.
//===----------------------------------------------------------------------===//
Expand Down
40 changes: 20 additions & 20 deletions stablehlo/dialect/StablehloOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,8 @@ class StableHLO_UnaryElementwiseOp<string mnemonic, list<Trait> traits,
// Abs supports complex to real, so element type is not guaranteed to match.
def StableHLO_AbsOp: StableHLO_UnaryElementwiseOp<"abs",
[Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>],
TensorOf<[HLO_SInt, HLO_Float, HLO_Complex] /* abs_i1 */>,
TensorOf<[HLO_SInt, HLO_Float]>> {
TensorOf<[HLO_SInt, HLO_Float, HLO_Complex, HLO_QuantizedInt] /* abs_i1 */>,
TensorOf<[HLO_SInt, HLO_Float, HLO_QuantizedInt]>> {
let summary = "Abs operation";
let description = [{
Performs element-wise abs operation on `operand` tensor and produces a
Expand All @@ -219,7 +219,7 @@ def StableHLO_AbsOp: StableHLO_UnaryElementwiseOp<"abs",

def StableHLO_CbrtOp: StableHLO_UnaryElementwiseOp<"cbrt",
[Pure, HLO_CompatibleOperandsAndResultType /*cbrt_c1*/],
HLO_FpOrComplexTensor /*cbrt_i1*/> { /*cbrt_c1*/
HLO_FpComplexOrQuantizedIntTensor /*cbrt_i1*/> { /*cbrt_c1*/
let summary = "Cbrt operation";
let description = [{
Performs element-wise cubic root operation on `operand` tensor and produces
Expand Down Expand Up @@ -289,7 +289,7 @@ def StableHLO_ClzOp: StableHLO_UnaryElementwiseOp<"count_leading_zeros",
}

def StableHLO_CosineOp: StableHLO_UnaryElementwiseOp<"cosine",
[Pure, HLO_CompatibleOperandsAndResultType], HLO_FpOrComplexTensor> {
[Pure, HLO_CompatibleOperandsAndResultType], HLO_FpComplexOrQuantizedIntTensor> {
let summary = "Cosine operation";
let description = [{
Performs element-wise cosine operation on `operand` tensor and produces a
Expand All @@ -307,7 +307,7 @@ def StableHLO_CosineOp: StableHLO_UnaryElementwiseOp<"cosine",

def StableHLO_ExpOp: StableHLO_UnaryElementwiseOp<"exponential",
[Pure, HLO_CompatibleOperandsAndResultType /*exponential_c1*/],
HLO_FpOrComplexTensor /*exponential_i1*/> {
HLO_FpComplexOrQuantizedIntTensor /*exponential_i1*/> {
let summary = "Exp operation";
let description = [{
Performs element-wise exponential operation on `operand` tensor and produces
Expand All @@ -325,7 +325,7 @@ def StableHLO_ExpOp: StableHLO_UnaryElementwiseOp<"exponential",

def StableHLO_Expm1Op: StableHLO_UnaryElementwiseOp<"exponential_minus_one",
[Pure, HLO_CompatibleOperandsAndResultType], /*exponential_minus_one_c1*/
HLO_FpOrComplexTensor /*exponential_minus_one_i1*/> { /*exponential_minus_one_c1*/
HLO_FpComplexOrQuantizedIntTensor /*exponential_minus_one_i1*/> { /*exponential_minus_one_c1*/
let summary = "Expm1 operation";
let description = [{
Performs element-wise exponential minus one operation on `operand` tensor
Expand Down Expand Up @@ -402,7 +402,7 @@ def StableHLO_IsFiniteOp: StableHLO_UnaryElementwiseOp<"is_finite", [Pure,

def StableHLO_LogOp: StableHLO_UnaryElementwiseOp<"log",
[Pure, HLO_CompatibleOperandsAndResultType /*log_c1*/],
HLO_FpOrComplexTensor /*log_i1*/> {
HLO_FpComplexOrQuantizedIntTensor /*log_i1*/> {
let summary = "Log operation";
let description = [{
Performs element-wise logarithm operation on `operand` tensor and produces a
Expand All @@ -420,7 +420,7 @@ def StableHLO_LogOp: StableHLO_UnaryElementwiseOp<"log",

def StableHLO_Log1pOp: StableHLO_UnaryElementwiseOp<"log_plus_one",
[Pure, HLO_CompatibleOperandsAndResultType /*log_plus_one_c1*/],
HLO_FpOrComplexTensor /*log_plus_one_i1*/> { /*log_plus_one_c1*/
HLO_FpComplexOrQuantizedIntTensor /*log_plus_one_i1*/> { /*log_plus_one_c1*/
let summary = "Log1p operation";
let description = [{
Performs element-wise logarithm plus one operation on `operand` tensor and
Expand All @@ -438,7 +438,7 @@ def StableHLO_Log1pOp: StableHLO_UnaryElementwiseOp<"log_plus_one",

def StableHLO_LogisticOp: StableHLO_UnaryElementwiseOp<"logistic",
[Pure, HLO_CompatibleOperandsAndResultType /*logistic_c1*/],
HLO_FpOrComplexTensor /*logistic_i1*/> { /*logistic_c1*/
HLO_FpComplexOrQuantizedIntTensor /*logistic_i1*/> { /*logistic_c1*/
let summary = "Logistic operation";
let description = [{
Performs element-wise logistic operation on `operand` tensor and produces a
Expand Down Expand Up @@ -472,7 +472,7 @@ def StableHLO_NotOp: StableHLO_UnaryElementwiseOp<"not",
}

def StableHLO_NegOp: StableHLO_UnaryElementwiseOp<"negate",
[Pure, HLO_CompatibleOperandsAndResultType], HLO_IntFpOrComplexTensor> {
[Pure, HLO_CompatibleOperandsAndResultType], HLO_IntFpOrComplexOrQuantizedIntTensor> {
let summary = "Neg operation";
let description = [{
Performs element-wise negation of `operand` tensor and produces a `result`
Expand Down Expand Up @@ -563,7 +563,7 @@ def StableHLO_RoundNearestEvenOp: StableHLO_UnaryElementwiseOp<"round_nearest_ev

def StableHLO_RsqrtOp: StableHLO_UnaryElementwiseOp<"rsqrt", [Pure,
HLO_CompatibleOperandsAndResultType /* rsqrt_c1 */],
HLO_FpOrComplexTensor /* rsqrt_i1 */> {
HLO_FpComplexOrQuantizedIntTensor /* rsqrt_i1 */> {
let summary = "Rsqrt operation";
let description = [{
Performs element-wise reciprocal square root operation on `operand` tensor
Expand All @@ -582,7 +582,7 @@ def StableHLO_RsqrtOp: StableHLO_UnaryElementwiseOp<"rsqrt", [Pure,

def StableHLO_SignOp: StableHLO_UnaryElementwiseOp<"sign",
[Pure, HLO_CompatibleOperandsAndResultType /*sign_c1*/],
TensorOf<[HLO_SInt, HLO_Float, HLO_Complex]> /*sign_i1*/> { /*sign_c1*/
TensorOf<[HLO_SInt, HLO_Float, HLO_Complex, HLO_QuantizedInt]> /*sign_i1*/> { /*sign_c1*/
let summary = "Sign operation";
let description = [{
Returns the sign of the `operand` element-wise and produces a `result`
Expand All @@ -599,7 +599,7 @@ def StableHLO_SignOp: StableHLO_UnaryElementwiseOp<"sign",
}

def StableHLO_SineOp: StableHLO_UnaryElementwiseOp<"sine",
[Pure, HLO_CompatibleOperandsAndResultType], HLO_FpOrComplexTensor> {
[Pure, HLO_CompatibleOperandsAndResultType], HLO_FpComplexOrQuantizedIntTensor> {
let summary = "Sine operation";
let description = [{
Performs element-wise sine operation on `operand` tensor and produces a
Expand All @@ -617,7 +617,7 @@ def StableHLO_SineOp: StableHLO_UnaryElementwiseOp<"sine",

def StableHLO_SqrtOp: StableHLO_UnaryElementwiseOp<"sqrt", [Pure,
HLO_CompatibleOperandsAndResultType /* sqrt_c1 */],
HLO_FpOrComplexTensor /* sqrt_i1 */> {
HLO_FpComplexOrQuantizedIntTensor /* sqrt_i1 */> {
let summary = "Sqrt operation";
let description = [{
Performs element-wise square root operation on `operand` tensor and produces
Expand All @@ -635,7 +635,7 @@ def StableHLO_SqrtOp: StableHLO_UnaryElementwiseOp<"sqrt", [Pure,

def StableHLO_TanhOp: StableHLO_UnaryElementwiseOp<"tanh",
[Pure, HLO_CompatibleOperandsAndResultType],
HLO_FpOrComplexTensor> {
HLO_FpComplexOrQuantizedIntTensor> {
let summary = "Tanh operation";
let description = [{
Performs element-wise hyperbolic tangent operation on `operand` tensor and
Expand Down Expand Up @@ -705,7 +705,7 @@ def StableHLO_AddOp : StableHLO_BinaryElementwiseOp<"add",

def StableHLO_Atan2Op : StableHLO_BinaryElementwiseOp<"atan2",
[Pure, HLO_CompatibleOperandsAndResultType /*atan2_c1*/],
HLO_FpOrComplexTensor /*atan2_i1, atan2_i2*/> { /*atan2_c1*/
HLO_FpComplexOrQuantizedIntTensor /*atan2_i1, atan2_i2*/> { /*atan2_c1*/
let summary = "Atan2 operation";
let description = [{
Performs element-wise atan2 operation on `lhs` and `rhs` tensor and produces
Expand Down Expand Up @@ -752,7 +752,7 @@ def StableHLO_ComplexOp: StableHLO_BinaryElementwiseOp<"complex", [Pure,

def StableHLO_DivOp : StableHLO_BinaryElementwiseOp<"divide", [Pure,
HLO_CompatibleOperandsAndResultType /* div_c1 */],
HLO_IntFpOrComplexTensor /* div_i1, div_i2 */> {
HLO_IntFpOrComplexOrQuantizedIntTensor /* div_i1, div_i2 */> {
let summary = "Div operation";
let description = [{
Performs element-wise division of dividend `lhs` and divisor `rhs` tensors
Expand Down Expand Up @@ -821,7 +821,7 @@ def StableHLO_MulOp : StableHLO_BinaryElementwiseOp<"multiply",

def StableHLO_PowOp : StableHLO_BinaryElementwiseOp<"power",
[Pure, HLO_CompatibleOperandsAndResultType /* pow_c1 */],
HLO_IntFpOrComplexTensor /* pow_i1, pow_i2 */> {
HLO_IntFpOrComplexOrQuantizedIntTensor /* pow_i1, pow_i2 */> {
let summary = "Power operation";
let description = [{
Performs element-wise exponentiation of `lhs` tensor by `rhs` tensor and
Expand All @@ -839,7 +839,7 @@ def StableHLO_PowOp : StableHLO_BinaryElementwiseOp<"power",

def StableHLO_RemOp : StableHLO_BinaryElementwiseOp<"remainder",
[Pure, HLO_CompatibleOperandsAndResultType /*remainder_c1*/],
HLO_IntFpOrComplexTensor /*remainder_i1, remainder_i2*/> {
HLO_IntFpOrComplexOrQuantizedIntTensor /*remainder_i1, remainder_i2*/> {
let summary = "Rem operation";
let description = [{
Performs element-wise remainder of dividend `lhs` and divisor `rhs` tensors
Expand Down Expand Up @@ -910,7 +910,7 @@ def StableHLO_ShiftRightLogicalOp : StableHLO_BinaryElementwiseOp<"shift_right_l
}

def StableHLO_SubtractOp : StableHLO_BinaryElementwiseOp<"subtract",
[Pure, HLO_CompatibleOperandsAndResultType], HLO_IntFpOrComplexTensor> {
[Pure, HLO_CompatibleOperandsAndResultType], HLO_IntFpOrComplexOrQuantizedIntTensor> {
let summary = "Subtract operation";
let description = [{
Performs element-wise subtraction of two tensors `lhs` and `rhs` and
Expand Down
30 changes: 30 additions & 0 deletions stablehlo/tests/ops_stablehlo.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -5431,6 +5431,36 @@ func.func @is_compatible_quant_signedness_mismatch(%arg0: tensor<1x!quant.unifor
func.return
}

// -----

// The following is the not the exhaustive list of ops supporting quantized
// types. The list will be updated as part of adding verification support for
// quantized ops.
func.func @quantization_supported_ops(%arg0: tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>, %arg1: tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>, %arg2: tensor<!quant.uniform<i8:f32, 1.0:17>>) {
%0 = "stablehlo.atan2"(%arg0, %arg1) : (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>, tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>
%1 = "stablehlo.divide"(%arg0, %arg1) : (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>, tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>
%2 = "stablehlo.power"(%arg0, %arg1) : (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>, tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>
%3 = "stablehlo.remainder"(%arg0, %arg1) : (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>, tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>
%4 = "stablehlo.subtract"(%arg0, %arg1) : (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>, tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>

%5 = "stablehlo.abs"(%arg0) : (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>
%6 = "stablehlo.cbrt"(%arg0) : (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>
%7 = "stablehlo.cosine"(%arg0) : (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>
%8 = "stablehlo.exponential"(%arg0) : (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>
%9 = "stablehlo.exponential_minus_one"(%arg0) : (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>
%10 = "stablehlo.log"(%arg0) : (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>
%11 = "stablehlo.log_plus_one"(%arg0) : (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>
%12 = "stablehlo.logistic"(%arg0) : (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>
%13 = "stablehlo.negate"(%arg0) : (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>
%14 = "stablehlo.rsqrt"(%arg0) : (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>
%15 = "stablehlo.sign"(%arg0) : (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>
%16 = "stablehlo.sine"(%arg0) : (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>
%17 = "stablehlo.sqrt"(%arg0) : (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>
%18 = "stablehlo.tanh"(%arg0) : (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>
func.return
}


// -----

// CHECK-LABEL: is_compatible_dynamism_bounds
Expand Down