Skip to content

Commit

Permalink
added quantized type support in the tablegen specification of some ops
Browse files Browse the repository at this point in the history
  • Loading branch information
sdasgup3 committed Jun 9, 2023
1 parent e57f8ce commit 4200771
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 27 deletions.
10 changes: 8 additions & 2 deletions stablehlo/dialect/Base.td
Original file line number Diff line number Diff line change
Expand Up @@ -149,12 +149,18 @@ def HLO_PredOrIntTensor : TensorOf<[HLO_Pred, HLO_Int]>;
// Any floating-point or complex tensor types
def HLO_FpOrComplexTensor : TensorOf<[HLO_Float, HLO_Complex]>;

// Any int, floating-point or complex tensor types
def HLO_IntFpOrComplexTensor : TensorOf<[HLO_Int, HLO_Float, HLO_Complex]>;
// Any floating-point, complex or quantized tensor types
def HLO_FpComplexOrQuantizedIntTensor : TensorOf<[HLO_Float, HLO_Complex, HLO_QuantizedInt]>;

// Any int, floating-point, complex or quantized tensor types
def HLO_IntFpOrComplexOrQuantizedIntTensor : TensorOf<[HLO_Int, HLO_Float, HLO_Complex, HLO_QuantizedInt]>;

// Any pred, int or floating-point tensor types
def HLO_PredIntOrFpTensor : TensorOf<[HLO_Pred, HLO_Int, HLO_Float]>;

// Any pred, int, floating-point or quantized tensor types
def HLO_PredIntFpOrQuantizedTensor : TensorOf<[HLO_Pred, HLO_Int, HLO_Float, HLO_QuantizedInt]>;

//===----------------------------------------------------------------------===//
// HLO static shape type definitions.
//===----------------------------------------------------------------------===//
Expand Down
50 changes: 25 additions & 25 deletions stablehlo/dialect/StablehloOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,8 @@ class StableHLO_UnaryElementwiseOp<string mnemonic, list<Trait> traits,
// Abs supports complex to real, so element type is not guaranteed to match.
def StableHLO_AbsOp: StableHLO_UnaryElementwiseOp<"abs",
[Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>],
TensorOf<[HLO_SInt, HLO_Float, HLO_Complex] /* abs_i1 */>,
TensorOf<[HLO_SInt, HLO_Float]>> {
TensorOf<[HLO_SInt, HLO_Float, HLO_Complex, HLO_QuantizedInt] /* abs_i1 */>,
TensorOf<[HLO_SInt, HLO_Float, HLO_QuantizedInt]>> {
let summary = "Abs operation";
let description = [{
Performs element-wise abs operation on `operand` tensor and produces a
Expand All @@ -219,7 +219,7 @@ def StableHLO_AbsOp: StableHLO_UnaryElementwiseOp<"abs",

def StableHLO_CbrtOp: StableHLO_UnaryElementwiseOp<"cbrt",
[Pure, HLO_CompatibleOperandsAndResultType /*cbrt_c1*/],
HLO_FpOrComplexTensor /*cbrt_i1*/> { /*cbrt_c1*/
HLO_FpComplexOrQuantizedIntTensor /*cbrt_i1*/> { /*cbrt_c1*/
let summary = "Cbrt operation";
let description = [{
Performs element-wise cubic root operation on `operand` tensor and produces
Expand Down Expand Up @@ -289,7 +289,7 @@ def StableHLO_ClzOp: StableHLO_UnaryElementwiseOp<"count_leading_zeros",
}

def StableHLO_CosineOp: StableHLO_UnaryElementwiseOp<"cosine",
[Pure, HLO_CompatibleOperandsAndResultType], HLO_FpOrComplexTensor> {
[Pure, HLO_CompatibleOperandsAndResultType], HLO_FpComplexOrQuantizedIntTensor> {
let summary = "Cosine operation";
let description = [{
Performs element-wise cosine operation on `operand` tensor and produces a
Expand All @@ -307,7 +307,7 @@ def StableHLO_CosineOp: StableHLO_UnaryElementwiseOp<"cosine",

def StableHLO_ExpOp: StableHLO_UnaryElementwiseOp<"exponential",
[Pure, HLO_CompatibleOperandsAndResultType /*exponential_c1*/],
HLO_FpOrComplexTensor /*exponential_i1*/> {
HLO_FpComplexOrQuantizedIntTensor /*exponential_i1*/> {
let summary = "Exp operation";
let description = [{
Performs element-wise exponential operation on `operand` tensor and produces
Expand All @@ -325,7 +325,7 @@ def StableHLO_ExpOp: StableHLO_UnaryElementwiseOp<"exponential",

def StableHLO_Expm1Op: StableHLO_UnaryElementwiseOp<"exponential_minus_one",
[Pure, HLO_CompatibleOperandsAndResultType], /*exponential_minus_one_c1*/
HLO_FpOrComplexTensor /*exponential_minus_one_i1*/> { /*exponential_minus_one_c1*/
HLO_FpComplexOrQuantizedIntTensor /*exponential_minus_one_i1*/> { /*exponential_minus_one_c1*/
let summary = "Expm1 operation";
let description = [{
Performs element-wise exponential minus one operation on `operand` tensor
Expand Down Expand Up @@ -402,7 +402,7 @@ def StableHLO_IsFiniteOp: StableHLO_UnaryElementwiseOp<"is_finite", [Pure,

def StableHLO_LogOp: StableHLO_UnaryElementwiseOp<"log",
[Pure, HLO_CompatibleOperandsAndResultType /*log_c1*/],
HLO_FpOrComplexTensor /*log_i1*/> {
HLO_FpComplexOrQuantizedIntTensor /*log_i1*/> {
let summary = "Log operation";
let description = [{
Performs element-wise logarithm operation on `operand` tensor and produces a
Expand All @@ -420,7 +420,7 @@ def StableHLO_LogOp: StableHLO_UnaryElementwiseOp<"log",

def StableHLO_Log1pOp: StableHLO_UnaryElementwiseOp<"log_plus_one",
[Pure, HLO_CompatibleOperandsAndResultType /*log_plus_one_c1*/],
HLO_FpOrComplexTensor /*log_plus_one_i1*/> { /*log_plus_one_c1*/
HLO_FpComplexOrQuantizedIntTensor /*log_plus_one_i1*/> { /*log_plus_one_c1*/
let summary = "Log1p operation";
let description = [{
Performs element-wise logarithm plus one operation on `operand` tensor and
Expand All @@ -438,7 +438,7 @@ def StableHLO_Log1pOp: StableHLO_UnaryElementwiseOp<"log_plus_one",

def StableHLO_LogisticOp: StableHLO_UnaryElementwiseOp<"logistic",
[Pure, HLO_CompatibleOperandsAndResultType /*logistic_c1*/],
HLO_FpOrComplexTensor /*logistic_i1*/> { /*logistic_c1*/
HLO_FpComplexOrQuantizedIntTensor /*logistic_i1*/> { /*logistic_c1*/
let summary = "Logistic operation";
let description = [{
Performs element-wise logistic operation on `operand` tensor and produces a
Expand Down Expand Up @@ -472,7 +472,7 @@ def StableHLO_NotOp: StableHLO_UnaryElementwiseOp<"not",
}

def StableHLO_NegOp: StableHLO_UnaryElementwiseOp<"negate",
[Pure, HLO_CompatibleOperandsAndResultType], HLO_IntFpOrComplexTensor> {
[Pure, HLO_CompatibleOperandsAndResultType], HLO_IntFpOrComplexOrQuantizedIntTensor> {
let summary = "Neg operation";
let description = [{
Performs element-wise negation of `operand` tensor and produces a `result`
Expand Down Expand Up @@ -563,7 +563,7 @@ def StableHLO_RoundNearestEvenOp: StableHLO_UnaryElementwiseOp<"round_nearest_ev

def StableHLO_RsqrtOp: StableHLO_UnaryElementwiseOp<"rsqrt", [Pure,
HLO_CompatibleOperandsAndResultType /* rsqrt_c1 */],
HLO_FpOrComplexTensor /* rsqrt_i1 */> {
HLO_FpComplexOrQuantizedIntTensor /* rsqrt_i1 */> {
let summary = "Rsqrt operation";
let description = [{
Performs element-wise reciprocal square root operation on `operand` tensor
Expand All @@ -582,7 +582,7 @@ def StableHLO_RsqrtOp: StableHLO_UnaryElementwiseOp<"rsqrt", [Pure,

def StableHLO_SignOp: StableHLO_UnaryElementwiseOp<"sign",
[Pure, HLO_CompatibleOperandsAndResultType /*sign_c1*/],
TensorOf<[HLO_SInt, HLO_Float, HLO_Complex]> /*sign_i1*/> { /*sign_c1*/
TensorOf<[HLO_SInt, HLO_Float, HLO_Complex, HLO_QuantizedInt]> /*sign_i1*/> { /*sign_c1*/
let summary = "Sign operation";
let description = [{
Returns the sign of the `operand` element-wise and produces a `result`
Expand All @@ -599,7 +599,7 @@ def StableHLO_SignOp: StableHLO_UnaryElementwiseOp<"sign",
}

def StableHLO_SineOp: StableHLO_UnaryElementwiseOp<"sine",
[Pure, HLO_CompatibleOperandsAndResultType], HLO_FpOrComplexTensor> {
[Pure, HLO_CompatibleOperandsAndResultType], HLO_FpComplexOrQuantizedIntTensor> {
let summary = "Sine operation";
let description = [{
Performs element-wise sine operation on `operand` tensor and produces a
Expand All @@ -617,7 +617,7 @@ def StableHLO_SineOp: StableHLO_UnaryElementwiseOp<"sine",

def StableHLO_SqrtOp: StableHLO_UnaryElementwiseOp<"sqrt", [Pure,
HLO_CompatibleOperandsAndResultType /* sqrt_c1 */],
HLO_FpOrComplexTensor /* sqrt_i1 */> {
HLO_FpComplexOrQuantizedIntTensor /* sqrt_i1 */> {
let summary = "Sqrt operation";
let description = [{
Performs element-wise square root operation on `operand` tensor and produces
Expand All @@ -635,7 +635,7 @@ def StableHLO_SqrtOp: StableHLO_UnaryElementwiseOp<"sqrt", [Pure,

def StableHLO_TanhOp: StableHLO_UnaryElementwiseOp<"tanh",
[Pure, HLO_CompatibleOperandsAndResultType],
HLO_FpOrComplexTensor> {
HLO_FpComplexOrQuantizedIntTensor> {
let summary = "Tanh operation";
let description = [{
Performs element-wise hyperbolic tangent operation on `operand` tensor and
Expand Down Expand Up @@ -705,7 +705,7 @@ def StableHLO_AddOp : StableHLO_BinaryElementwiseOp<"add",

def StableHLO_Atan2Op : StableHLO_BinaryElementwiseOp<"atan2",
[Pure, HLO_CompatibleOperandsAndResultType /*atan2_c1*/],
HLO_FpOrComplexTensor /*atan2_i1, atan2_i2*/> { /*atan2_c1*/
HLO_FpComplexOrQuantizedIntTensor /*atan2_i1, atan2_i2*/> { /*atan2_c1*/
let summary = "Atan2 operation";
let description = [{
Performs element-wise atan2 operation on `lhs` and `rhs` tensor and produces
Expand Down Expand Up @@ -752,7 +752,7 @@ def StableHLO_ComplexOp: StableHLO_BinaryElementwiseOp<"complex", [Pure,

def StableHLO_DivOp : StableHLO_BinaryElementwiseOp<"divide", [Pure,
HLO_CompatibleOperandsAndResultType /* div_c1 */],
HLO_IntFpOrComplexTensor /* div_i1, div_i2 */> {
HLO_IntFpOrComplexOrQuantizedIntTensor /* div_i1, div_i2 */> {
let summary = "Div operation";
let description = [{
Performs element-wise division of dividend `lhs` and divisor `rhs` tensors
Expand Down Expand Up @@ -821,7 +821,7 @@ def StableHLO_MulOp : StableHLO_BinaryElementwiseOp<"multiply",

def StableHLO_PowOp : StableHLO_BinaryElementwiseOp<"power",
[Pure, HLO_CompatibleOperandsAndResultType /* pow_c1 */],
HLO_IntFpOrComplexTensor /* pow_i1, pow_i2 */> {
HLO_IntFpOrComplexOrQuantizedIntTensor /* pow_i1, pow_i2 */> {
let summary = "Power operation";
let description = [{
Performs element-wise exponentiation of `lhs` tensor by `rhs` tensor and
Expand All @@ -839,7 +839,7 @@ def StableHLO_PowOp : StableHLO_BinaryElementwiseOp<"power",

def StableHLO_RemOp : StableHLO_BinaryElementwiseOp<"remainder",
[Pure, HLO_CompatibleOperandsAndResultType /*remainder_c1*/],
HLO_IntFpOrComplexTensor /*remainder_i1, remainder_i2*/> {
HLO_IntFpOrComplexOrQuantizedIntTensor /*remainder_i1, remainder_i2*/> {
let summary = "Rem operation";
let description = [{
Performs element-wise remainder of dividend `lhs` and divisor `rhs` tensors
Expand Down Expand Up @@ -910,7 +910,7 @@ def StableHLO_ShiftRightLogicalOp : StableHLO_BinaryElementwiseOp<"shift_right_l
}

def StableHLO_SubtractOp : StableHLO_BinaryElementwiseOp<"subtract",
[Pure, HLO_CompatibleOperandsAndResultType], HLO_IntFpOrComplexTensor> {
[Pure, HLO_CompatibleOperandsAndResultType], HLO_IntFpOrComplexOrQuantizedIntTensor> {
let summary = "Subtract operation";
let description = [{
Performs element-wise subtraction of two tensors `lhs` and `rhs` and
Expand Down Expand Up @@ -1945,11 +1945,11 @@ def StableHLO_CholeskyOp : StableHLO_Op<"cholesky",
```
}];
let arguments = (ins
HLO_FpOrComplexTensor:$a,
HLO_FpComplexOrQuantizedIntTensor:$a,
DefaultValuedOptionalAttr<BoolAttr, "false">:$lower
);

let results = (outs HLO_FpOrComplexTensor:$result);
let results = (outs HLO_FpComplexOrQuantizedIntTensor:$result);

let assemblyFormat = [{
$a (`,` `lower` `=` $lower^)? attr-dict `:` custom<SameOperandsAndResultType>(type($a), type($result))
Expand Down Expand Up @@ -2800,14 +2800,14 @@ def StableHLO_TriangularSolveOp: StableHLO_Op<"triangular_solve",
```
}];
let arguments = (ins
HLO_FpOrComplexTensor:$a,
HLO_FpOrComplexTensor:$b,
HLO_FpComplexOrQuantizedIntTensor:$a,
HLO_FpComplexOrQuantizedIntTensor:$b,
BoolAttr:$left_side,
BoolAttr:$lower,
BoolAttr:$unit_diagonal,
StableHLO_TransposeAttr:$transpose_a
);
let results = (outs HLO_FpOrComplexTensor);
let results = (outs HLO_FpComplexOrQuantizedIntTensor);
}

def StableHLO_ReduceWindowOp: StableHLO_Op<"reduce_window", [
Expand Down
34 changes: 34 additions & 0 deletions stablehlo/tests/ops_stablehlo.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -5431,6 +5431,40 @@ func.func @is_compatible_quant_signedness_mismatch(%arg0: tensor<1x!quant.unifor
func.return
}

// -----

// The following is the not the exhaustive list of ops supporting quantized
// types. The list will be updated as part of adding verification support for
// quantized ops.
func.func @quantization_supported_ops(%arg0: tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>, %arg1: tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>, %arg2: tensor<!quant.uniform<i8:f32, 1.0:17>>) {
%0 = "stablehlo.atan2"(%arg0, %arg1) : (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>, tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>
%1 = "stablehlo.divide"(%arg0, %arg1) : (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>, tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>
%2 = "stablehlo.power"(%arg0, %arg1) : (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>, tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>
%3 = "stablehlo.remainder"(%arg0, %arg1) : (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>, tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>
%4 = "stablehlo.subtract"(%arg0, %arg1) : (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>, tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>

%5 = "stablehlo.abs"(%arg0) : (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>
%6 = "stablehlo.cbrt"(%arg0) : (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>
%7 = "stablehlo.cosine"(%arg0) : (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>
%8 = "stablehlo.exponential"(%arg0) : (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>
%9 = "stablehlo.exponential_minus_one"(%arg0) : (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>
%10 = "stablehlo.log"(%arg0) : (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>
%11 = "stablehlo.log_plus_one"(%arg0) : (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>
%12 = "stablehlo.logistic"(%arg0) : (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>
%13 = "stablehlo.negate"(%arg0) : (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>
%14 = "stablehlo.rsqrt"(%arg0) : (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>
%15 = "stablehlo.sign"(%arg0) : (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>
%16 = "stablehlo.sine"(%arg0) : (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>
%17 = "stablehlo.sqrt"(%arg0) : (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>
%18 = "stablehlo.tanh"(%arg0) : (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>

%19 = "stablehlo.cholesky"(%arg0) { lower = true } : (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>
%20 = "stablehlo.triangular_solve"(%arg0, %arg1) {left_side = true, lower = true, transpose_a = #stablehlo<transpose NO_TRANSPOSE>, unit_diagonal = true} : (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>, tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>

func.return
}


// -----

// CHECK-LABEL: is_compatible_dynamism_bounds
Expand Down

0 comments on commit 4200771

Please sign in to comment.