From 6fd479294a6b924b298f6b871ec00264c1bb2545 Mon Sep 17 00:00:00 2001 From: Chris Vasiladiotis Date: Sat, 25 Jan 2025 14:25:07 +0100 Subject: [PATCH] dialects: (math) Add support for custom format and container types to `math` dialect ops (#3786) This PR: - Adds support for custom format - Adds support for container types in the operands and results of the relevant ops (e.g., `tensor`) - Tests (including interoperability) of the above Resolves: #3739, #2739 --- .../dialects/math/math_ops_custom.mlir | 289 ++++++++ .../dialects/math/math_ops_custom.mlir | 289 ++++++++ xdsl/dialects/math.py | 657 +++++------------- 3 files changed, 769 insertions(+), 466 deletions(-) create mode 100644 tests/filecheck/dialects/math/math_ops_custom.mlir create mode 100644 tests/filecheck/mlir-conversion/with-mlir/dialects/math/math_ops_custom.mlir diff --git a/tests/filecheck/dialects/math/math_ops_custom.mlir b/tests/filecheck/dialects/math/math_ops_custom.mlir new file mode 100644 index 0000000000..0956fcd519 --- /dev/null +++ b/tests/filecheck/dialects/math/math_ops_custom.mlir @@ -0,0 +1,289 @@ +// RUN: XDSL_ROUNDTRIP + +%vali32 = "test.op"() : () -> i32 +%vali64 = "test.op"() : () -> i64 +%valf32 = "test.op"() : () -> f32 +%valf64 = "test.op"() : () -> f64 +%vec_vali64 = "test.op"() : () -> vector<4xi64> +%vec_valf64 = "test.op"() : () -> vector<4xf64> + +// CHECK: [[VALI32:%.*]] = "test.op"() : () -> i32 +// CHECK-NEXT: [[VALI64:%.*]] = "test.op"() : () -> i64 +// CHECK-NEXT: [[VALF32:%.*]] = "test.op"() : () -> f32 +// CHECK-NEXT: [[VALF64:%.*]] = "test.op"() : () -> f64 +// CHECK-NEXT: [[VEC_VALI64:%.*]] = "test.op"() : () -> vector<4xi64> +// CHECK-NEXT: [[VEC_VALF64:%.*]] = "test.op"() : () -> vector<4xf64> + +%rhsi32 = "test.op"() : () -> i32 +%rhsi64 = "test.op"() : () -> i64 +%rhsf32 = "test.op"() : () -> f32 +%rhsf64 = "test.op"() : () -> f64 +%vec_rhsi64 = "test.op"() : () -> vector<4xi64> +%vec_rhsf64 = "test.op"() : () -> vector<4xf64> + +// CHECK-NEXT: [[RHSI32:%.*]] = "test.op"() : () -> i32 +// CHECK-NEXT: [[RHSI64:%.*]] = "test.op"() : () -> i64 +// CHECK-NEXT: [[RHSF32:%.*]] = "test.op"() : () -> f32 +// CHECK-NEXT: [[RHSF64:%.*]] = "test.op"() : () -> f64 +// CHECK-NEXT: [[VEC_RHSI64:%.*]] = "test.op"() : () -> vector<4xi64> +// CHECK-NEXT: [[VEC_RHSF64:%.*]] = "test.op"() : () -> vector<4xf64> + +%absf0 = math.absf %valf32 : f32 +%absf1 = math.absf %valf64 : f64 +%vabsf1 = math.absf %vec_valf64 : vector<4xf64> + +// CHECK-NEXT: {{%.*}} = math.absf [[VALF32]] : f32 +// CHECK-NEXT: {{%.*}} = math.absf [[VALF64]] : f64 +// CHECK-NEXT: {{%.*}} = math.absf [[VEC_VALF64]] : vector<4xf64> + +%absi0 = math.absi %vali32: i32 +%absi1 = math.absi %vali64: i64 +%vabsi1 = math.absi %vec_vali64 : vector<4xi64> + +// CHECK-NEXT: {{%.*}} = math.absi [[VALI32]] : i32 +// CHECK-NEXT: {{%.*}} = math.absi [[VALI64]] : i64 +// CHECK-NEXT: {{%.*}} = math.absi [[VEC_VALI64]] : vector<4xi64> + +%atan2f0 = math.atan2 %valf32, %rhsf32 : f32 +%atan2f1 = math.atan2 %valf64, %rhsf64 : f64 +%vatan2f1 = math.atan2 %vec_valf64, %vec_valf64: vector<4xf64> + +// CHECK-NEXT: {{%.*}} = math.atan2 [[VALF32]], [[RHSF32]] : f32 +// CHECK-NEXT: {{%.*}} = math.atan2 [[VALF64]], [[RHSF64]] : f64 +// CHECK-NEXT: {{%.*}} = math.atan2 [[VEC_VALF64]], [[VEC_VALF64]] : vector<4xf64> + +%atanf0 = math.atan %valf32 : f32 +%atanf1 = math.atan %valf64 : f64 +%vatanf1 = math.atan %vec_valf64 : vector<4xf64> + +// CHECK-NEXT: {{%.*}} = math.atan [[VALF32]] : f32 +// CHECK-NEXT: {{%.*}} = math.atan [[VALF64]] : f64 +// CHECK-NEXT: {{%.*}} = math.atan [[VEC_VALF64]] : vector<4xf64> + +%cbrtf0 = math.cbrt %valf32 : f32 +%cbrtf1 = math.cbrt %valf64 : f64 +%vcbrtf1 = math.cbrt %vec_valf64 : vector<4xf64> + +// CHECK-NEXT: {{%.*}} = math.cbrt [[VALF32]] : f32 +// CHECK-NEXT: {{%.*}} = math.cbrt [[VALF64]] : f64 +// CHECK-NEXT: {{%.*}} = math.cbrt [[VEC_VALF64]] : vector<4xf64> + +%ceilf0 = math.ceil %valf32 : f32 +%ceilf1 = math.ceil %valf64 : f64 +%vceilf1 = math.ceil %vec_valf64 : vector<4xf64> + +// CHECK-NEXT: {{%.*}} = math.ceil [[VALF32]] : f32 +// CHECK-NEXT: {{%.*}} = math.ceil [[VALF64]] : f64 +// CHECK-NEXT: {{%.*}} = math.ceil [[VEC_VALF64]] : vector<4xf64> + +%copysign0 = math.copysign %valf32, %rhsf32 : f32 +%copysign1 = math.copysign %valf64, %rhsf64 : f64 +%vcopysign1 = math.copysign %vec_valf64, %vec_rhsf64 : vector<4xf64> + +// CHECK-NEXT: {{%.*}} = math.copysign [[VALF32]], [[RHSF32]] : f32 +// CHECK-NEXT: {{%.*}} = math.copysign [[VALF64]], [[RHSF64]] : f64 +// CHECK-NEXT: {{%.*}} = math.copysign [[VEC_VALF64]], [[VEC_RHSF64]] : vector<4xf64> + +%cosf0 = math.cos %valf32 : f32 +%cosf1 = math.cos %valf64 : f64 +%vcosf1 = math.cos %vec_valf64 : vector<4xf64> + +// CHECK-NEXT: {{%.*}} = math.cos [[VALF32]] : f32 +// CHECK-NEXT: {{%.*}} = math.cos [[VALF64]] : f64 +// CHECK-NEXT: {{%.*}} = math.cos [[VEC_VALF64]] : vector<4xf64> + +%ctlzi0 = math.ctlz %vali32 : i32 +%ctlzi1 = math.ctlz %vali64 : i64 +%vctlzi1 = math.ctlz %vec_vali64 : vector<4xi64> + +// CHECK-NEXT: {{%.*}} = math.ctlz [[VALI32]] : i32 +// CHECK-NEXT: {{%.*}} = math.ctlz [[VALI64]] : i64 +// CHECK-NEXT: {{%.*}} = math.ctlz [[VEC_VALI64]] : vector<4xi64> + +%cttzi0 = math.cttz %vali32 : i32 +%cttzi1 = math.cttz %vali64 : i64 +%vcttzi1 = math.cttz %vec_vali64 : vector<4xi64> + +// CHECK-NEXT: {{%.*}} = math.cttz [[VALI32]] : i32 +// CHECK-NEXT: {{%.*}} = math.cttz [[VALI64]] : i64 +// CHECK-NEXT: {{%.*}} = math.cttz [[VEC_VALI64]] : vector<4xi64> + +%ctpopi0 = math.ctpop %vali32 : i32 +%ctpopi1 = math.ctpop %vali64 : i64 +%vctpopi1 = math.ctpop %vec_vali64 : vector<4xi64> + +// CHECK-NEXT: {{%.*}} = math.ctpop [[VALI32]] : i32 +// CHECK-NEXT: {{%.*}} = math.ctpop [[VALI64]] : i64 +// CHECK-NEXT: {{%.*}} = math.ctpop [[VEC_VALI64]] : vector<4xi64> + +%erff0 = math.erf %valf32 : f32 +%erff1 = math.erf %valf64 : f64 +%verff1 = math.erf %vec_valf64 : vector<4xf64> + +// CHECK-NEXT: {{%.*}} = math.erf [[VALF32]] : f32 +// CHECK-NEXT: {{%.*}} = math.erf [[VALF64]] : f64 +// CHECK-NEXT: {{%.*}} = math.erf [[VEC_VALF64]] : vector<4xf64> + +%exp2f0 = math.exp2 %valf32 : f32 +%exp2f1 = math.exp2 %valf64 : f64 +%vexp2f1 = math.exp2 %vec_valf64 : vector<4xf64> + +// CHECK-NEXT: {{%.*}} = math.exp2 [[VALF32]] : f32 +// CHECK-NEXT: {{%.*}} = math.exp2 [[VALF64]] : f64 +// CHECK-NEXT: {{%.*}} = math.exp2 [[VEC_VALF64]] : vector<4xf64> + +%expm10 = math.expm1 %valf32 : f32 +%expm11 = math.expm1 %valf64 : f64 +%vexpm11 = math.expm1 %vec_valf64 : vector<4xf64> + +// CHECK-NEXT: {{%.*}} = math.expm1 [[VALF32]] : f32 +// CHECK-NEXT: {{%.*}} = math.expm1 [[VALF64]] : f64 +// CHECK-NEXT: {{%.*}} = math.expm1 [[VEC_VALF64]] : vector<4xf64> + +%exp0 = math.exp %valf32 : f32 +%exp1 = math.exp %valf64 : f64 +%vexp1 = math.exp %vec_valf64 : vector<4xf64> + +// CHECK-NEXT: {{%.*}} = math.exp [[VALF32]] : f32 +// CHECK-NEXT: {{%.*}} = math.exp [[VALF64]] : f64 +// CHECK-NEXT: {{%.*}} = math.exp [[VEC_VALF64]] : vector<4xf64> + +%fpowi0 = math.fpowi %valf32, %vali32 : f32, i32 +%fpowi1 = math.fpowi %valf32, %vali64 : f32, i64 +%fpowi2 = math.fpowi %valf64, %vali32 : f64, i32 +%fpowi3 = math.fpowi %valf64, %vali64 : f64, i64 +%vfpowi3 = math.fpowi %vec_valf64, %vec_vali64 : vector<4xf64>, vector<4xi64> + +// CHECK-NEXT: {{%.*}} = math.fpowi [[VALF32]], [[VALI32]] : f32, i32 +// CHECK-NEXT: {{%.*}} = math.fpowi [[VALF32]], [[VALI64]] : f32, i64 +// CHECK-NEXT: {{%.*}} = math.fpowi [[VALF64]], [[VALI32]] : f64, i32 +// CHECK-NEXT: {{%.*}} = math.fpowi [[VALF64]], [[VALI64]] : f64, i64 +// CHECK-NEXT: {{%.*}} = math.fpowi [[VEC_VALF64]], [[VEC_VALI64]] : vector<4xf64>, vector<4xi64> + +%floor0 = math.floor %valf32 : f32 +%floor1 = math.floor %valf64 : f64 +%vfloor1 = math.floor %vec_valf64 : vector<4xf64> + +// CHECK-NEXT: {{%.*}} = math.floor [[VALF32]] : f32 +// CHECK-NEXT: {{%.*}} = math.floor [[VALF64]] : f64 +// CHECK-NEXT: {{%.*}} = math.floor [[VEC_VALF64]] : vector<4xf64> + +%fma0 = math.fma %valf32, %valf32, %rhsf32 : f32 +%fma1 = math.fma %valf64, %valf64, %rhsf64 : f64 +%vfma1 = math.fma %vec_valf64, %vec_valf64, %vec_rhsf64 : vector<4xf64> + +// CHECK-NEXT: {{%.*}} = math.fma [[VALF32]], [[VALF32]], [[RHSF32]] : f32 +// CHECK-NEXT: {{%.*}} = math.fma [[VALF64]], [[VALF64]], [[RHSF64]] : f64 +// CHECK-NEXT: {{%.*}} = math.fma [[VEC_VALF64]], [[VEC_VALF64]], [[VEC_RHSF64]] : vector<4xf64> + +%ipowi0 = math.ipowi %vali32, %rhsi32 : i32 +%ipowi1 = math.ipowi %vali64, %rhsi64 : i64 +%vipowi1 = math.ipowi %vec_vali64, %vec_rhsi64 : vector<4xi64> + +// CHECK-NEXT: {{%.*}} = math.ipowi [[VALI32]], [[RHSI32]] : i32 +// CHECK-NEXT: {{%.*}} = math.ipowi [[VALI64]], [[RHSI64]] : i64 +// CHECK-NEXT: {{%.*}} = math.ipowi [[VEC_VALI64]], [[VEC_RHSI64]] : vector<4xi64> + +%log100 = math.log10 %valf32 : f32 +%log101 = math.log10 %valf64 : f64 +%vlog101 = math.log10 %vec_valf64 : vector<4xf64> + +// CHECK-NEXT: {{%.*}} = math.log10 [[VALF32]] : f32 +// CHECK-NEXT: {{%.*}} = math.log10 [[VALF64]] : f64 +// CHECK-NEXT: {{%.*}} = math.log10 [[VEC_VALF64]] : vector<4xf64> + +%log1p0 = math.log1p %valf32 : f32 +%log1p1 = math.log1p %valf64 : f64 +%vlog1p1 = math.log1p %vec_valf64 : vector<4xf64> + +// CHECK-NEXT: {{%.*}} = math.log1p [[VALF32]] : f32 +// CHECK-NEXT: {{%.*}} = math.log1p [[VALF64]] : f64 +// CHECK-NEXT: {{%.*}} = math.log1p [[VEC_VALF64]] : vector<4xf64> + +%log20 = math.log2 %valf32 : f32 +%log21 = math.log2 %valf64 : f64 +%vlog21 = math.log2 %vec_valf64 : vector<4xf64> + +// CHECK-NEXT: {{%.*}} = math.log2 [[VALF32]] : f32 +// CHECK-NEXT: {{%.*}} = math.log2 [[VALF64]] : f64 +// CHECK-NEXT: {{%.*}} = math.log2 [[VEC_VALF64]] : vector<4xf64> + +%log0 = math.log %valf32 : f32 +%log1 = math.log %valf64 : f64 +%vlog1 = math.log %vec_valf64 : vector<4xf64> + +// CHECK-NEXT: {{%.*}} = math.log [[VALF32]] : f32 +// CHECK-NEXT: {{%.*}} = math.log [[VALF64]] : f64 +// CHECK-NEXT: {{%.*}} = math.log [[VEC_VALF64]] : vector<4xf64> + +%powf0 = math.powf %valf32, %rhsf32 : f32 +%powf1 = math.powf %valf64, %rhsf64 : f64 +%vpowf1 = math.powf %vec_valf64, %vec_rhsf64 : vector<4xf64> + +// CHECK-NEXT: {{%.*}} = math.powf [[VALF32]], [[RHSF32]] : f32 +// CHECK-NEXT: {{%.*}} = math.powf [[VALF64]], [[RHSF64]] : f64 +// CHECK-NEXT: {{%.*}} = math.powf [[VEC_VALF64]], [[VEC_RHSF64]] : vector<4xf64> + +%roundeven0 = math.roundeven %valf32 : f32 +%roundeven1 = math.roundeven %valf64 : f64 +%vroundeven1 = math.roundeven %vec_valf64 : vector<4xf64> + +// CHECK-NEXT: {{%.*}} = math.roundeven [[VALF32]] : f32 +// CHECK-NEXT: {{%.*}} = math.roundeven [[VALF64]] : f64 +// CHECK-NEXT: {{%.*}} = math.roundeven [[VEC_VALF64]] : vector<4xf64> + +%round0 = math.round %valf32 : f32 +%round1 = math.round %valf64 : f64 +%vround1 = math.round %vec_valf64 : vector<4xf64> + +// CHECK-NEXT: {{%.*}} = math.round [[VALF32]] : f32 +// CHECK-NEXT: {{%.*}} = math.round [[VALF64]] : f64 +// CHECK-NEXT: {{%.*}} = math.round [[VEC_VALF64]] : vector<4xf64> + +%rsqrt0 = math.rsqrt %valf32 : f32 +%rsqrt1 = math.rsqrt %valf64 : f64 +%vrsqrt1 = math.rsqrt %vec_valf64 : vector<4xf64> + +// CHECK-NEXT: {{%.*}} = math.rsqrt [[VALF32]] : f32 +// CHECK-NEXT: {{%.*}} = math.rsqrt [[VALF64]] : f64 +// CHECK-NEXT: {{%.*}} = math.rsqrt [[VEC_VALF64]] : vector<4xf64> + +%sin0 = math.sin %valf32 : f32 +%sin1 = math.sin %valf64 : f64 +%vsin1 = math.sin %vec_valf64 : vector<4xf64> + +// CHECK-NEXT: {{%.*}} = math.sin [[VALF32]] : f32 +// CHECK-NEXT: {{%.*}} = math.sin [[VALF64]] : f64 +// CHECK-NEXT: {{%.*}} = math.sin [[VEC_VALF64]] : vector<4xf64> + +%sqrt0 = math.sqrt %valf32 : f32 +%sqrt1 = math.sqrt %valf64 : f64 +%vsqrt1 = math.sqrt %vec_valf64 : vector<4xf64> + +// CHECK-NEXT: {{%.*}} = math.sqrt [[VALF32]] : f32 +// CHECK-NEXT: {{%.*}} = math.sqrt [[VALF64]] : f64 +// CHECK-NEXT: {{%.*}} = math.sqrt [[VEC_VALF64]] : vector<4xf64> + +%tan0 = math.tan %valf32 : f32 +%tan1 = math.tan %valf64 : f64 +%vtan1 = math.tan %vec_valf64 : vector<4xf64> + +// CHECK-NEXT: {{%.*}} = math.tan [[VALF32]] : f32 +// CHECK-NEXT: {{%.*}} = math.tan [[VALF64]] : f64 +// CHECK-NEXT: {{%.*}} = math.tan [[VEC_VALF64]] : vector<4xf64> + +%tanh0 = math.tanh %valf32 : f32 +%tanh1 = math.tanh %valf64 : f64 +%vtanh1 = math.tanh %vec_valf64 : vector<4xf64> + +// CHECK-NEXT: {{%.*}} = math.tanh [[VALF32]] : f32 +// CHECK-NEXT: {{%.*}} = math.tanh [[VALF64]] : f64 +// CHECK-NEXT: {{%.*}} = math.tanh [[VEC_VALF64]] : vector<4xf64> + +%trunc0 = math.trunc %valf32 : f32 +%trunc1 = math.trunc %valf64 : f64 +%vtrunc1 = math.trunc %vec_valf64 : vector<4xf64> + +// CHECK-NEXT: {{%.*}} = math.trunc [[VALF32]] : f32 +// CHECK-NEXT: {{%.*}} = math.trunc [[VALF64]] : f64 +// CHECK-NEXT: {{%.*}} = math.trunc [[VEC_VALF64]] : vector<4xf64> diff --git a/tests/filecheck/mlir-conversion/with-mlir/dialects/math/math_ops_custom.mlir b/tests/filecheck/mlir-conversion/with-mlir/dialects/math/math_ops_custom.mlir new file mode 100644 index 0000000000..d2e00d2492 --- /dev/null +++ b/tests/filecheck/mlir-conversion/with-mlir/dialects/math/math_ops_custom.mlir @@ -0,0 +1,289 @@ +// RUN: xdsl-opt %s | xdsl-opt | mlir-opt --allow-unregistered-dialect | filecheck %s + +%vali32 = "test.op"() : () -> i32 +%vali64 = "test.op"() : () -> i64 +%valf32 = "test.op"() : () -> f32 +%valf64 = "test.op"() : () -> f64 +%vec_vali64 = "test.op"() : () -> vector<4xi64> +%vec_valf64 = "test.op"() : () -> vector<4xf64> + +// CHECK: [[VALI32:%.*]] = "test.op"() : () -> i32 +// CHECK-NEXT: [[VALI64:%.*]] = "test.op"() : () -> i64 +// CHECK-NEXT: [[VALF32:%.*]] = "test.op"() : () -> f32 +// CHECK-NEXT: [[VALF64:%.*]] = "test.op"() : () -> f64 +// CHECK-NEXT: [[VEC_VALI64:%.*]] = "test.op"() : () -> vector<4xi64> +// CHECK-NEXT: [[VEC_VALF64:%.*]] = "test.op"() : () -> vector<4xf64> + +%rhsi32 = "test.op"() : () -> i32 +%rhsi64 = "test.op"() : () -> i64 +%rhsf32 = "test.op"() : () -> f32 +%rhsf64 = "test.op"() : () -> f64 +%vec_rhsi64 = "test.op"() : () -> vector<4xi64> +%vec_rhsf64 = "test.op"() : () -> vector<4xf64> + +// CHECK-NEXT: [[RHSI32:%.*]] = "test.op"() : () -> i32 +// CHECK-NEXT: [[RHSI64:%.*]] = "test.op"() : () -> i64 +// CHECK-NEXT: [[RHSF32:%.*]] = "test.op"() : () -> f32 +// CHECK-NEXT: [[RHSF64:%.*]] = "test.op"() : () -> f64 +// CHECK-NEXT: [[VEC_RHSI64:%.*]] = "test.op"() : () -> vector<4xi64> +// CHECK-NEXT: [[VEC_RHSF64:%.*]] = "test.op"() : () -> vector<4xf64> + +%absf0 = math.absf %valf32 : f32 +%absf1 = math.absf %valf64 : f64 +%vabsf1 = math.absf %vec_valf64 : vector<4xf64> + +// CHECK-NEXT: {{%.*}} = math.absf [[VALF32]] : f32 +// CHECK-NEXT: {{%.*}} = math.absf [[VALF64]] : f64 +// CHECK-NEXT: {{%.*}} = math.absf [[VEC_VALF64]] : vector<4xf64> + +%absi0 = math.absi %vali32: i32 +%absi1 = math.absi %vali64: i64 +%vabsi1 = math.absi %vec_vali64 : vector<4xi64> + +// CHECK-NEXT: {{%.*}} = math.absi [[VALI32]] : i32 +// CHECK-NEXT: {{%.*}} = math.absi [[VALI64]] : i64 +// CHECK-NEXT: {{%.*}} = math.absi [[VEC_VALI64]] : vector<4xi64> + +%atan2f0 = math.atan2 %valf32, %rhsf32 : f32 +%atan2f1 = math.atan2 %valf64, %rhsf64 : f64 +%vatan2f1 = math.atan2 %vec_valf64, %vec_valf64: vector<4xf64> + +// CHECK-NEXT: {{%.*}} = math.atan2 [[VALF32]], [[RHSF32]] : f32 +// CHECK-NEXT: {{%.*}} = math.atan2 [[VALF64]], [[RHSF64]] : f64 +// CHECK-NEXT: {{%.*}} = math.atan2 [[VEC_VALF64]], [[VEC_VALF64]] : vector<4xf64> + +%atanf0 = math.atan %valf32 : f32 +%atanf1 = math.atan %valf64 : f64 +%vatanf1 = math.atan %vec_valf64 : vector<4xf64> + +// CHECK-NEXT: {{%.*}} = math.atan [[VALF32]] : f32 +// CHECK-NEXT: {{%.*}} = math.atan [[VALF64]] : f64 +// CHECK-NEXT: {{%.*}} = math.atan [[VEC_VALF64]] : vector<4xf64> + +%cbrtf0 = math.cbrt %valf32 : f32 +%cbrtf1 = math.cbrt %valf64 : f64 +%vcbrtf1 = math.cbrt %vec_valf64 : vector<4xf64> + +// CHECK-NEXT: {{%.*}} = math.cbrt [[VALF32]] : f32 +// CHECK-NEXT: {{%.*}} = math.cbrt [[VALF64]] : f64 +// CHECK-NEXT: {{%.*}} = math.cbrt [[VEC_VALF64]] : vector<4xf64> + +%ceilf0 = math.ceil %valf32 : f32 +%ceilf1 = math.ceil %valf64 : f64 +%vceilf1 = math.ceil %vec_valf64 : vector<4xf64> + +// CHECK-NEXT: {{%.*}} = math.ceil [[VALF32]] : f32 +// CHECK-NEXT: {{%.*}} = math.ceil [[VALF64]] : f64 +// CHECK-NEXT: {{%.*}} = math.ceil [[VEC_VALF64]] : vector<4xf64> + +%copysign0 = math.copysign %valf32, %rhsf32 : f32 +%copysign1 = math.copysign %valf64, %rhsf64 : f64 +%vcopysign1 = math.copysign %vec_valf64, %vec_rhsf64 : vector<4xf64> + +// CHECK-NEXT: {{%.*}} = math.copysign [[VALF32]], [[RHSF32]] : f32 +// CHECK-NEXT: {{%.*}} = math.copysign [[VALF64]], [[RHSF64]] : f64 +// CHECK-NEXT: {{%.*}} = math.copysign [[VEC_VALF64]], [[VEC_RHSF64]] : vector<4xf64> + +%cosf0 = math.cos %valf32 : f32 +%cosf1 = math.cos %valf64 : f64 +%vcosf1 = math.cos %vec_valf64 : vector<4xf64> + +// CHECK-NEXT: {{%.*}} = math.cos [[VALF32]] : f32 +// CHECK-NEXT: {{%.*}} = math.cos [[VALF64]] : f64 +// CHECK-NEXT: {{%.*}} = math.cos [[VEC_VALF64]] : vector<4xf64> + +%ctlzi0 = math.ctlz %vali32 : i32 +%ctlzi1 = math.ctlz %vali64 : i64 +%vctlzi1 = math.ctlz %vec_vali64 : vector<4xi64> + +// CHECK-NEXT: {{%.*}} = math.ctlz [[VALI32]] : i32 +// CHECK-NEXT: {{%.*}} = math.ctlz [[VALI64]] : i64 +// CHECK-NEXT: {{%.*}} = math.ctlz [[VEC_VALI64]] : vector<4xi64> + +%cttzi0 = math.cttz %vali32 : i32 +%cttzi1 = math.cttz %vali64 : i64 +%vcttzi1 = math.cttz %vec_vali64 : vector<4xi64> + +// CHECK-NEXT: {{%.*}} = math.cttz [[VALI32]] : i32 +// CHECK-NEXT: {{%.*}} = math.cttz [[VALI64]] : i64 +// CHECK-NEXT: {{%.*}} = math.cttz [[VEC_VALI64]] : vector<4xi64> + +%ctpopi0 = math.ctpop %vali32 : i32 +%ctpopi1 = math.ctpop %vali64 : i64 +%vctpopi1 = math.ctpop %vec_vali64 : vector<4xi64> + +// CHECK-NEXT: {{%.*}} = math.ctpop [[VALI32]] : i32 +// CHECK-NEXT: {{%.*}} = math.ctpop [[VALI64]] : i64 +// CHECK-NEXT: {{%.*}} = math.ctpop [[VEC_VALI64]] : vector<4xi64> + +%erff0 = math.erf %valf32 : f32 +%erff1 = math.erf %valf64 : f64 +%verff1 = math.erf %vec_valf64 : vector<4xf64> + +// CHECK-NEXT: {{%.*}} = math.erf [[VALF32]] : f32 +// CHECK-NEXT: {{%.*}} = math.erf [[VALF64]] : f64 +// CHECK-NEXT: {{%.*}} = math.erf [[VEC_VALF64]] : vector<4xf64> + +%exp2f0 = math.exp2 %valf32 : f32 +%exp2f1 = math.exp2 %valf64 : f64 +%vexp2f1 = math.exp2 %vec_valf64 : vector<4xf64> + +// CHECK-NEXT: {{%.*}} = math.exp2 [[VALF32]] : f32 +// CHECK-NEXT: {{%.*}} = math.exp2 [[VALF64]] : f64 +// CHECK-NEXT: {{%.*}} = math.exp2 [[VEC_VALF64]] : vector<4xf64> + +%expm10 = math.expm1 %valf32 : f32 +%expm11 = math.expm1 %valf64 : f64 +%vexpm11 = math.expm1 %vec_valf64 : vector<4xf64> + +// CHECK-NEXT: {{%.*}} = math.expm1 [[VALF32]] : f32 +// CHECK-NEXT: {{%.*}} = math.expm1 [[VALF64]] : f64 +// CHECK-NEXT: {{%.*}} = math.expm1 [[VEC_VALF64]] : vector<4xf64> + +%exp0 = math.exp %valf32 : f32 +%exp1 = math.exp %valf64 : f64 +%vexp1 = math.exp %vec_valf64 : vector<4xf64> + +// CHECK-NEXT: {{%.*}} = math.exp [[VALF32]] : f32 +// CHECK-NEXT: {{%.*}} = math.exp [[VALF64]] : f64 +// CHECK-NEXT: {{%.*}} = math.exp [[VEC_VALF64]] : vector<4xf64> + +%fpowi0 = math.fpowi %valf32, %vali32 : f32, i32 +%fpowi1 = math.fpowi %valf32, %vali64 : f32, i64 +%fpowi2 = math.fpowi %valf64, %vali32 : f64, i32 +%fpowi3 = math.fpowi %valf64, %vali64 : f64, i64 +%vfpowi3 = math.fpowi %vec_valf64, %vec_vali64 : vector<4xf64>, vector<4xi64> + +// CHECK-NEXT: {{%.*}} = math.fpowi [[VALF32]], [[VALI32]] : f32, i32 +// CHECK-NEXT: {{%.*}} = math.fpowi [[VALF32]], [[VALI64]] : f32, i64 +// CHECK-NEXT: {{%.*}} = math.fpowi [[VALF64]], [[VALI32]] : f64, i32 +// CHECK-NEXT: {{%.*}} = math.fpowi [[VALF64]], [[VALI64]] : f64, i64 +// CHECK-NEXT: {{%.*}} = math.fpowi [[VEC_VALF64]], [[VEC_VALI64]] : vector<4xf64>, vector<4xi64> + +%floor0 = math.floor %valf32 : f32 +%floor1 = math.floor %valf64 : f64 +%vfloor1 = math.floor %vec_valf64 : vector<4xf64> + +// CHECK-NEXT: {{%.*}} = math.floor [[VALF32]] : f32 +// CHECK-NEXT: {{%.*}} = math.floor [[VALF64]] : f64 +// CHECK-NEXT: {{%.*}} = math.floor [[VEC_VALF64]] : vector<4xf64> + +%fma0 = math.fma %valf32, %valf32, %rhsf32 : f32 +%fma1 = math.fma %valf64, %valf64, %rhsf64 : f64 +%vfma1 = math.fma %vec_valf64, %vec_valf64, %vec_rhsf64 : vector<4xf64> + +// CHECK-NEXT: {{%.*}} = math.fma [[VALF32]], [[VALF32]], [[RHSF32]] : f32 +// CHECK-NEXT: {{%.*}} = math.fma [[VALF64]], [[VALF64]], [[RHSF64]] : f64 +// CHECK-NEXT: {{%.*}} = math.fma [[VEC_VALF64]], [[VEC_VALF64]], [[VEC_RHSF64]] : vector<4xf64> + +%ipowi0 = math.ipowi %vali32, %rhsi32 : i32 +%ipowi1 = math.ipowi %vali64, %rhsi64 : i64 +%vipowi1 = math.ipowi %vec_vali64, %vec_rhsi64 : vector<4xi64> + +// CHECK-NEXT: {{%.*}} = math.ipowi [[VALI32]], [[RHSI32]] : i32 +// CHECK-NEXT: {{%.*}} = math.ipowi [[VALI64]], [[RHSI64]] : i64 +// CHECK-NEXT: {{%.*}} = math.ipowi [[VEC_VALI64]], [[VEC_RHSI64]] : vector<4xi64> + +%log100 = math.log10 %valf32 : f32 +%log101 = math.log10 %valf64 : f64 +%vlog101 = math.log10 %vec_valf64 : vector<4xf64> + +// CHECK-NEXT: {{%.*}} = math.log10 [[VALF32]] : f32 +// CHECK-NEXT: {{%.*}} = math.log10 [[VALF64]] : f64 +// CHECK-NEXT: {{%.*}} = math.log10 [[VEC_VALF64]] : vector<4xf64> + +%log1p0 = math.log1p %valf32 : f32 +%log1p1 = math.log1p %valf64 : f64 +%vlog1p1 = math.log1p %vec_valf64 : vector<4xf64> + +// CHECK-NEXT: {{%.*}} = math.log1p [[VALF32]] : f32 +// CHECK-NEXT: {{%.*}} = math.log1p [[VALF64]] : f64 +// CHECK-NEXT: {{%.*}} = math.log1p [[VEC_VALF64]] : vector<4xf64> + +%log20 = math.log2 %valf32 : f32 +%log21 = math.log2 %valf64 : f64 +%vlog21 = math.log2 %vec_valf64 : vector<4xf64> + +// CHECK-NEXT: {{%.*}} = math.log2 [[VALF32]] : f32 +// CHECK-NEXT: {{%.*}} = math.log2 [[VALF64]] : f64 +// CHECK-NEXT: {{%.*}} = math.log2 [[VEC_VALF64]] : vector<4xf64> + +%log0 = math.log %valf32 : f32 +%log1 = math.log %valf64 : f64 +%vlog1 = math.log %vec_valf64 : vector<4xf64> + +// CHECK-NEXT: {{%.*}} = math.log [[VALF32]] : f32 +// CHECK-NEXT: {{%.*}} = math.log [[VALF64]] : f64 +// CHECK-NEXT: {{%.*}} = math.log [[VEC_VALF64]] : vector<4xf64> + +%powf0 = math.powf %valf32, %rhsf32 : f32 +%powf1 = math.powf %valf64, %rhsf64 : f64 +%vpowf1 = math.powf %vec_valf64, %vec_rhsf64 : vector<4xf64> + +// CHECK-NEXT: {{%.*}} = math.powf [[VALF32]], [[RHSF32]] : f32 +// CHECK-NEXT: {{%.*}} = math.powf [[VALF64]], [[RHSF64]] : f64 +// CHECK-NEXT: {{%.*}} = math.powf [[VEC_VALF64]], [[VEC_RHSF64]] : vector<4xf64> + +%roundeven0 = math.roundeven %valf32 : f32 +%roundeven1 = math.roundeven %valf64 : f64 +%vroundeven1 = math.roundeven %vec_valf64 : vector<4xf64> + +// CHECK-NEXT: {{%.*}} = math.roundeven [[VALF32]] : f32 +// CHECK-NEXT: {{%.*}} = math.roundeven [[VALF64]] : f64 +// CHECK-NEXT: {{%.*}} = math.roundeven [[VEC_VALF64]] : vector<4xf64> + +%round0 = math.round %valf32 : f32 +%round1 = math.round %valf64 : f64 +%vround1 = math.round %vec_valf64 : vector<4xf64> + +// CHECK-NEXT: {{%.*}} = math.round [[VALF32]] : f32 +// CHECK-NEXT: {{%.*}} = math.round [[VALF64]] : f64 +// CHECK-NEXT: {{%.*}} = math.round [[VEC_VALF64]] : vector<4xf64> + +%rsqrt0 = math.rsqrt %valf32 : f32 +%rsqrt1 = math.rsqrt %valf64 : f64 +%vrsqrt1 = math.rsqrt %vec_valf64 : vector<4xf64> + +// CHECK-NEXT: {{%.*}} = math.rsqrt [[VALF32]] : f32 +// CHECK-NEXT: {{%.*}} = math.rsqrt [[VALF64]] : f64 +// CHECK-NEXT: {{%.*}} = math.rsqrt [[VEC_VALF64]] : vector<4xf64> + +%sin0 = math.sin %valf32 : f32 +%sin1 = math.sin %valf64 : f64 +%vsin1 = math.sin %vec_valf64 : vector<4xf64> + +// CHECK-NEXT: {{%.*}} = math.sin [[VALF32]] : f32 +// CHECK-NEXT: {{%.*}} = math.sin [[VALF64]] : f64 +// CHECK-NEXT: {{%.*}} = math.sin [[VEC_VALF64]] : vector<4xf64> + +%sqrt0 = math.sqrt %valf32 : f32 +%sqrt1 = math.sqrt %valf64 : f64 +%vsqrt1 = math.sqrt %vec_valf64 : vector<4xf64> + +// CHECK-NEXT: {{%.*}} = math.sqrt [[VALF32]] : f32 +// CHECK-NEXT: {{%.*}} = math.sqrt [[VALF64]] : f64 +// CHECK-NEXT: {{%.*}} = math.sqrt [[VEC_VALF64]] : vector<4xf64> + +%tan0 = math.tan %valf32 : f32 +%tan1 = math.tan %valf64 : f64 +%vtan1 = math.tan %vec_valf64 : vector<4xf64> + +// CHECK-NEXT: {{%.*}} = math.tan [[VALF32]] : f32 +// CHECK-NEXT: {{%.*}} = math.tan [[VALF64]] : f64 +// CHECK-NEXT: {{%.*}} = math.tan [[VEC_VALF64]] : vector<4xf64> + +%tanh0 = math.tanh %valf32 : f32 +%tanh1 = math.tanh %valf64 : f64 +%vtanh1 = math.tanh %vec_valf64 : vector<4xf64> + +// CHECK-NEXT: {{%.*}} = math.tanh [[VALF32]] : f32 +// CHECK-NEXT: {{%.*}} = math.tanh [[VALF64]] : f64 +// CHECK-NEXT: {{%.*}} = math.tanh [[VEC_VALF64]] : vector<4xf64> + +%trunc0 = math.trunc %valf32 : f32 +%trunc1 = math.trunc %valf64 : f64 +%vtrunc1 = math.trunc %vec_valf64 : vector<4xf64> + +// CHECK-NEXT: {{%.*}} = math.trunc [[VALF32]] : f32 +// CHECK-NEXT: {{%.*}} = math.trunc [[VALF64]] : f64 +// CHECK-NEXT: {{%.*}} = math.trunc [[VEC_VALF64]] : vector<4xf64> diff --git a/xdsl/dialects/math.py b/xdsl/dialects/math.py index a2c103da19..332f4d4996 100644 --- a/xdsl/dialects/math.py +++ b/xdsl/dialects/math.py @@ -8,11 +8,21 @@ from __future__ import annotations +import abc +from typing import ClassVar + from xdsl.dialects.arith import FastMathFlagsAttr -from xdsl.dialects.builtin import AnyFloatConstr, IntegerType +from xdsl.dialects.builtin import ( + AnyFloatConstr, + ContainerOf, + IndexType, + SignlessIntegerConstraint, +) from xdsl.ir import Dialect, Operation, SSAValue from xdsl.irdl import ( + AnyOf, IRDLOperation, + VarConstraint, irdl_op_definition, operand_def, opt_prop_def, @@ -21,9 +31,138 @@ ) from xdsl.traits import Pure, SameOperandsAndResultType +signlessIntegerLike = ContainerOf(AnyOf([SignlessIntegerConstraint, IndexType])) +floatingPointLike = ContainerOf(AnyFloatConstr) + + +class SignlessIntegerLikeUnaryMathOperation(IRDLOperation, abc.ABC): + """A generic signless integer-like unary math operation.""" + + T: ClassVar = VarConstraint("T", signlessIntegerLike) + + operand = operand_def(T) + result = result_def(T) + + assembly_format = "$operand attr-dict `:` type($result)" + + def __init__(self, operand: Operation | SSAValue): + operand = SSAValue.get(operand) + super().__init__( + operands=[operand], + result_types=[operand.type], + ) + + +class FloatingPointLikeUnaryMathOperation(IRDLOperation, abc.ABC): + """A generic floating-point-like unary math operation.""" + + T: ClassVar = VarConstraint("T", floatingPointLike) + + operand = operand_def(T) + result = result_def(T) + + assembly_format = "$operand attr-dict `:` type($result)" + + def __init__(self, operand: Operation | SSAValue): + operand = SSAValue.get(operand) + super().__init__( + operands=[operand], + result_types=[operand.type], + ) + + +class FloatingPointLikeUnaryMathOperationWithFastMath( + FloatingPointLikeUnaryMathOperation, abc.ABC +): + """A generic floating-point-like unary math operation with fastmath flags.""" + + fastmath = opt_prop_def(FastMathFlagsAttr) + + assembly_format = "$operand (`fastmath` `` $fastmath^)? attr-dict `:` type($result)" + + def __init__( + self, operand: Operation | SSAValue, fastmath: FastMathFlagsAttr | None = None + ): + operand = SSAValue.get(operand) + IRDLOperation.__init__( + self, + attributes={"fastmath": fastmath}, + operands=[operand], + result_types=[operand.type], + ) + + +class SignlessIntegerLikeBinaryMathOperation(IRDLOperation, abc.ABC): + """A generic signless integer-like binary math operation.""" + + T: ClassVar = VarConstraint("T", signlessIntegerLike) + + lhs = operand_def(T) + rhs = operand_def(T) + result = result_def(T) + + assembly_format = "$lhs `,` $rhs attr-dict `:` type($result)" + + def __init__( + self, + lhs: Operation | SSAValue, + rhs: Operation | SSAValue, + ): + super().__init__( + operands=[lhs, rhs], + result_types=[SSAValue.get(lhs).type], + ) + + +class FloatingPointLikeBinaryMathOperation(IRDLOperation, abc.ABC): + """A generic floating-point-like binary math operation.""" + + T: ClassVar = VarConstraint("T", floatingPointLike) + + lhs = operand_def(T) + rhs = operand_def(T) + result = result_def(T) + + assembly_format = "$lhs `,` $rhs attr-dict `:` type($result)" + + def __init__( + self, + lhs: Operation | SSAValue, + rhs: Operation | SSAValue, + ): + super().__init__( + operands=[lhs, rhs], + result_types=[SSAValue.get(lhs).type], + ) + + +class FloatingPointLikeBinaryMathOperationWithFastMath( + FloatingPointLikeBinaryMathOperation, abc.ABC +): + """A generic floating-point-like binary math operation with fastmath flags.""" + + fastmath = opt_prop_def(FastMathFlagsAttr) + + assembly_format = ( + "$lhs `,` $rhs (`fastmath` `` $fastmath^)? attr-dict `:` type($result)" + ) + + def __init__( + self, + lhs: Operation | SSAValue, + rhs: Operation | SSAValue, + fastmath: FastMathFlagsAttr | None = None, + ): + IRDLOperation.__init__( + self, + attributes={"fastmath": fastmath}, + operands=[lhs, rhs], + result_types=[SSAValue.get(lhs).type], + ) + @irdl_op_definition -class AbsFOp(IRDLOperation): +class AbsFOp(FloatingPointLikeUnaryMathOperation): """ The absf operation computes the absolute value. It takes one operand of floating point type (i.e., scalar, tensor or vector) and returns one result @@ -36,25 +175,12 @@ class AbsFOp(IRDLOperation): """ name = "math.absf" - fastmath = opt_prop_def(FastMathFlagsAttr) - operand = operand_def(AnyFloatConstr) - result = result_def(AnyFloatConstr) traits = traits_def(Pure(), SameOperandsAndResultType()) - def __init__( - self, operand: Operation | SSAValue, fastmath: FastMathFlagsAttr | None = None - ): - operand = SSAValue.get(operand) - super().__init__( - attributes={"fastmath": fastmath}, - operands=[operand], - result_types=[operand.type], - ) - @irdl_op_definition -class AbsIOp(IRDLOperation): +class AbsIOp(SignlessIntegerLikeUnaryMathOperation): """ The absi operation computes the absolute value. It takes one operand of integer type (i.e., scalar, tensor or vector) and returns one result of the @@ -67,22 +193,13 @@ class AbsIOp(IRDLOperation): """ name = "math.absi" - operand = operand_def(IntegerType) - result = result_def(IntegerType) traits = traits_def(Pure(), SameOperandsAndResultType()) - def __init__(self, operand: Operation | SSAValue): - operand = SSAValue.get(operand) - super().__init__(operands=[operand], result_types=[operand.type]) - @irdl_op_definition -class Atan2Op(IRDLOperation): +class Atan2Op(FloatingPointLikeBinaryMathOperationWithFastMath): """ - Syntax: - operation ::= ssa-id `=` `math.atan2` ssa-use `,` ssa-use `:` type - The atan2 operation takes two operands and returns one result, all of which must be of the same type. The operands must be of floating point type (i.e., scalar, tensor or vector). @@ -101,33 +218,13 @@ class Atan2Op(IRDLOperation): """ name = "math.atan2" - fastmath = opt_prop_def(FastMathFlagsAttr) - lhs = operand_def(AnyFloatConstr) - rhs = operand_def(AnyFloatConstr) - result = result_def(AnyFloatConstr) traits = traits_def(Pure(), SameOperandsAndResultType()) - def __init__( - self, - lhs: Operation | SSAValue, - rhs: Operation | SSAValue, - fastmath: FastMathFlagsAttr | None = None, - ): - attributes = {"fastmath": fastmath} - super().__init__( - attributes=attributes, - operands=[lhs, rhs], - result_types=[SSAValue.get(lhs).type], - ) - @irdl_op_definition -class AtanOp(IRDLOperation): +class AtanOp(FloatingPointLikeUnaryMathOperationWithFastMath): """ - Syntax: - operation ::= ssa-id `=` `math.atan` ssa-use `:` type - The atan operation computes the arcus tangent of a given value. It takes one operand of floating point type (i.e., scalar, tensor or vector) and returns one result of the same type. It has no standard attributes. @@ -139,25 +236,12 @@ class AtanOp(IRDLOperation): """ name = "math.atan" - fastmath = opt_prop_def(FastMathFlagsAttr) - operand = operand_def(AnyFloatConstr) - result = result_def(AnyFloatConstr) traits = traits_def(Pure(), SameOperandsAndResultType()) - def __init__( - self, operand: Operation | SSAValue, fastmath: FastMathFlagsAttr | None = None - ): - operand = SSAValue.get(operand) - super().__init__( - attributes={"fastmath": fastmath}, - operands=[operand], - result_types=[operand.type], - ) - @irdl_op_definition -class CbrtOp(IRDLOperation): +class CbrtOp(FloatingPointLikeUnaryMathOperationWithFastMath): """ The cbrt operation computes the cube root. It takes one operand of floating point type (i.e., scalar, tensor or vector) and returns one result @@ -172,29 +256,13 @@ class CbrtOp(IRDLOperation): """ name = "math.cbrt" - fastmath = opt_prop_def(FastMathFlagsAttr) - operand = operand_def(AnyFloatConstr) - result = result_def(AnyFloatConstr) traits = traits_def(Pure(), SameOperandsAndResultType()) - def __init__( - self, operand: Operation | SSAValue, fastmath: FastMathFlagsAttr | None = None - ): - attributes = {"fastmath": fastmath} - - operand = SSAValue.get(operand) - super().__init__( - attributes=attributes, operands=[operand], result_types=[operand.type] - ) - @irdl_op_definition -class CeilOp(IRDLOperation): +class CeilOp(FloatingPointLikeUnaryMathOperationWithFastMath): """ - Syntax: - operation ::= ssa-id `=` `math.ceil` ssa-use `:` type - The ceil operation computes the ceiling of a given value. It takes one operand of floating point type (i.e., scalar, tensor or vector) and returns one result of the same type. It has no standard attributes. @@ -206,29 +274,13 @@ class CeilOp(IRDLOperation): """ name = "math.ceil" - fastmath = opt_prop_def(FastMathFlagsAttr) - operand = operand_def(AnyFloatConstr) - result = result_def(AnyFloatConstr) traits = traits_def(Pure(), SameOperandsAndResultType()) - def __init__( - self, operand: Operation | SSAValue, fastmath: FastMathFlagsAttr | None = None - ): - operand = SSAValue.get(operand) - super().__init__( - attributes={"fastmath": fastmath}, - operands=[operand], - result_types=[operand.type], - ) - @irdl_op_definition -class CopySignOp(IRDLOperation): +class CopySignOp(FloatingPointLikeBinaryMathOperationWithFastMath): """ - Syntax: - operation ::= ssa-id `=` `math.copysign` ssa-use `,` ssa-use `:` type - The copysign returns a value with the magnitude of the first operand and the sign of the second operand. It takes two operands and returns one result of the same type. The operands must be of floating point type (i.e., scalar, @@ -241,34 +293,13 @@ class CopySignOp(IRDLOperation): """ name = "math.copysign" - fastmath = opt_prop_def(FastMathFlagsAttr) - lhs = operand_def(AnyFloatConstr) - rhs = operand_def(AnyFloatConstr) - result = result_def(AnyFloatConstr) traits = traits_def(Pure(), SameOperandsAndResultType()) - def __init__( - self, - lhs: Operation | SSAValue, - rhs: Operation | SSAValue, - fastmath: FastMathFlagsAttr | None = None, - ): - attributes = {"fastmath": fastmath} - - super().__init__( - attributes=attributes, - operands=[lhs, rhs], - result_types=[SSAValue.get(lhs).type], - ) - @irdl_op_definition -class CosOp(IRDLOperation): +class CosOp(FloatingPointLikeUnaryMathOperationWithFastMath): """ - Syntax: - operation ::= ssa-id `=` `math.cos` ssa-use `:` type - The `cos` operation computes the cosine of a given value. It takes one operand of floating point type (i.e., scalar, tensor or vector) and returns one result of the same type. It has no standard attributes. @@ -280,25 +311,12 @@ class CosOp(IRDLOperation): """ name = "math.cos" - fastmath = opt_prop_def(FastMathFlagsAttr) - operand = operand_def(AnyFloatConstr) - result = result_def(AnyFloatConstr) traits = traits_def(Pure(), SameOperandsAndResultType()) - def __init__( - self, operand: Operation | SSAValue, fastmath: FastMathFlagsAttr | None = None - ): - attributes = {"fastmath": fastmath} - - operand = SSAValue.get(operand) - super().__init__( - attributes=attributes, operands=[operand], result_types=[operand.type] - ) - @irdl_op_definition -class CountLeadingZerosOp(IRDLOperation): +class CountLeadingZerosOp(SignlessIntegerLikeUnaryMathOperation): """ The ctlz operation computes the number of leading zeros of an integer value. It operates on scalar, tensor or vector. @@ -310,18 +328,12 @@ class CountLeadingZerosOp(IRDLOperation): """ name = "math.ctlz" - operand = operand_def(IntegerType) - result = result_def(IntegerType) traits = traits_def(Pure(), SameOperandsAndResultType()) - def __init__(self, operand: Operation | SSAValue): - operand = SSAValue.get(operand) - super().__init__(operands=[operand], result_types=[operand.type]) - @irdl_op_definition -class CountTrailingZerosOp(IRDLOperation): +class CountTrailingZerosOp(SignlessIntegerLikeUnaryMathOperation): """ The cttz operation computes the number of trailing zeros of an integer value. It operates on scalar, tensor or vector. @@ -333,18 +345,12 @@ class CountTrailingZerosOp(IRDLOperation): """ name = "math.cttz" - operand = operand_def(IntegerType) - result = result_def(IntegerType) traits = traits_def(Pure(), SameOperandsAndResultType()) - def __init__(self, operand: Operation | SSAValue): - operand = SSAValue.get(operand) - super().__init__(operands=[operand], result_types=[operand.type]) - @irdl_op_definition -class CtPopOp(IRDLOperation): +class CtPopOp(SignlessIntegerLikeUnaryMathOperation): """ The ctpop operation computes the number of set bits of an integer value. It operates on scalar, tensor or vector. @@ -356,22 +362,13 @@ class CtPopOp(IRDLOperation): """ name = "math.ctpop" - operand = operand_def(IntegerType) - result = result_def(IntegerType) traits = traits_def(Pure(), SameOperandsAndResultType()) - def __init__(self, operand: Operation | SSAValue): - operand = SSAValue.get(operand) - super().__init__(operands=[operand], result_types=[operand.type]) - @irdl_op_definition -class ErfOp(IRDLOperation): +class ErfOp(FloatingPointLikeUnaryMathOperationWithFastMath): """ - Syntax: - operation ::= ssa-id `=` `math.erf` ssa-use `:` type - The erf operation computes the error function. It takes one operand of floating point type (i.e., scalar, tensor or vector) and returns one result of the same type. It has no standard attributes. @@ -383,29 +380,13 @@ class ErfOp(IRDLOperation): """ name = "math.erf" - fastmath = opt_prop_def(FastMathFlagsAttr) - operand = operand_def(AnyFloatConstr) - result = result_def(AnyFloatConstr) traits = traits_def(Pure(), SameOperandsAndResultType()) - def __init__( - self, operand: Operation | SSAValue, fastmath: FastMathFlagsAttr | None = None - ): - attributes = {"fastmath": fastmath} - - operand = SSAValue.get(operand) - super().__init__( - attributes=attributes, operands=[operand], result_types=[operand.type] - ) - @irdl_op_definition -class Exp2Op(IRDLOperation): +class Exp2Op(FloatingPointLikeUnaryMathOperationWithFastMath): """ - Syntax: - operation ::= ssa-id `=` `math.exp2` ssa-use `:` type - The exp operation takes one operand of floating point type (i.e., scalar, tensor or vector) and returns one result of the same type. It has no standard attributes. @@ -417,29 +398,13 @@ class Exp2Op(IRDLOperation): """ name = "math.exp2" - fastmath = opt_prop_def(FastMathFlagsAttr) - operand = operand_def(AnyFloatConstr) - result = result_def(AnyFloatConstr) traits = traits_def(Pure(), SameOperandsAndResultType()) - def __init__( - self, operand: Operation | SSAValue, fastmath: FastMathFlagsAttr | None = None - ): - attributes = {"fastmath": fastmath} - - operand = SSAValue.get(operand) - super().__init__( - attributes=attributes, operands=[operand], result_types=[operand.type] - ) - @irdl_op_definition -class ExpM1Op(IRDLOperation): +class ExpM1Op(FloatingPointLikeUnaryMathOperationWithFastMath): """ - Syntax: - operation ::= ssa-id `=` `math.expm1` ssa-use `:` type - expm1(x) := exp(x) - 1 The expm1 operation takes one operand of floating point type (i.e., @@ -453,29 +418,13 @@ class ExpM1Op(IRDLOperation): """ name = "math.expm1" - fastmath = opt_prop_def(FastMathFlagsAttr) - operand = operand_def(AnyFloatConstr) - result = result_def(AnyFloatConstr) traits = traits_def(Pure(), SameOperandsAndResultType()) - def __init__( - self, operand: Operation | SSAValue, fastmath: FastMathFlagsAttr | None = None - ): - attributes = {"fastmath": fastmath} - - operand = SSAValue.get(operand) - super().__init__( - attributes=attributes, operands=[operand], result_types=[operand.type] - ) - @irdl_op_definition -class ExpOp(IRDLOperation): +class ExpOp(FloatingPointLikeUnaryMathOperationWithFastMath): """ - Syntax: - operation ::= ssa-id `=` `math.exp` ssa-use `:` type - The exp operation takes one operand of floating point type (i.e., scalar, tensor or vector) and returns one result of the same type. It has no standard attributes. @@ -487,36 +436,20 @@ class ExpOp(IRDLOperation): """ name = "math.exp" - fastmath = opt_prop_def(FastMathFlagsAttr) - operand = operand_def(AnyFloatConstr) - result = result_def(AnyFloatConstr) traits = traits_def(Pure(), SameOperandsAndResultType()) - def __init__( - self, operand: Operation | SSAValue, fastmath: FastMathFlagsAttr | None = None - ): - attributes = {"fastmath": fastmath} - - operand = SSAValue.get(operand) - super().__init__( - attributes=attributes, operands=[operand], result_types=[operand.type] - ) - @irdl_op_definition class FPowIOp(IRDLOperation): """ - Syntax: - operation ::= ssa-id `=` `math.fpowi` ssa-use `,` ssa-use `:` type - The fpowi operation takes a `base` operand of floating point type (i.e. scalar, tensor or vector) and a `power` operand of integer type (also scalar, tensor or vector) and returns one result of the same type as `base`. The result is `base` raised to the power of `power`. The operation is elementwise for non-scalars, e.g.: - %v = math.fpowi %base, %power : vector<2xf32>, vector<2xi32 + %v = math.fpowi %base, %power : vector<2xf32>, vector<2xi32> The result is a vector of: @@ -529,12 +462,17 @@ class FPowIOp(IRDLOperation): """ name = "math.fpowi" + + T: ClassVar = VarConstraint("T1", floatingPointLike) + fastmath = opt_prop_def(FastMathFlagsAttr) - lhs = operand_def(AnyFloatConstr) - rhs = operand_def(IntegerType) - result = result_def(AnyFloatConstr) + lhs = operand_def(T) + rhs = operand_def(signlessIntegerLike) + result = result_def(T) - traits = traits_def(Pure(), SameOperandsAndResultType()) + traits = traits_def(Pure()) + + assembly_format = "$lhs `,` $rhs (`fastmath` `` $fastmath^)? attr-dict `:` type($lhs) `,` type($rhs)" def __init__( self, @@ -543,7 +481,6 @@ def __init__( fastmath: FastMathFlagsAttr | None = None, ): attributes = {"fastmath": fastmath} - super().__init__( attributes=attributes, operands=[lhs, rhs], @@ -552,11 +489,8 @@ def __init__( @irdl_op_definition -class FloorOp(IRDLOperation): +class FloorOp(FloatingPointLikeUnaryMathOperationWithFastMath): """ - Syntax: - operation ::= ssa-id `=` `math.floor` ssa-use `:` type - The floor operation computes the floor of a given value. It takes one operand of floating point type (i.e., scalar, tensor or vector) and returns one result of the same type. It has no standard attributes. @@ -568,29 +502,13 @@ class FloorOp(IRDLOperation): """ name = "math.floor" - fastmath = opt_prop_def(FastMathFlagsAttr) - operand = operand_def(AnyFloatConstr) - result = result_def(AnyFloatConstr) traits = traits_def(Pure(), SameOperandsAndResultType()) - def __init__( - self, operand: Operation | SSAValue, fastmath: FastMathFlagsAttr | None = None - ): - attributes = {"fastmath": fastmath} - - operand = SSAValue.get(operand) - super().__init__( - attributes=attributes, operands=[operand], result_types=[operand.type] - ) - @irdl_op_definition class FmaOp(IRDLOperation): """ - Syntax: - operation ::= ssa-id `=` `math.fma` ssa-use `,` ssa-use `,` ssa-use `:` type - The fma operation takes three operands and returns one result, each of these is required to be the same type. Operands must be of floating point type (i.e., scalar, tensor or vector). @@ -606,15 +524,22 @@ class FmaOp(IRDLOperation): to the `llvm.fma.*` intrinsic. """ + T: ClassVar = VarConstraint("T", floatingPointLike) + name = "math.fma" + fastmath = opt_prop_def(FastMathFlagsAttr) - a = operand_def(AnyFloatConstr) - b = operand_def(AnyFloatConstr) - c = operand_def(AnyFloatConstr) - result = result_def(AnyFloatConstr) + a = operand_def(T) + b = operand_def(T) + c = operand_def(T) + result = result_def(T) traits = traits_def(Pure(), SameOperandsAndResultType()) + assembly_format = ( + "$a `,` $b `,` $c (`fastmath` `` $fastmath^)? attr-dict `:` type($result)" + ) + def __init__( self, a: Operation | SSAValue, @@ -632,11 +557,8 @@ def __init__( @irdl_op_definition -class IPowIOp(IRDLOperation): +class IPowIOp(SignlessIntegerLikeBinaryMathOperation): """ - Syntax: - operation ::= ssa-id `=` `math.ipowi` ssa-use `,` ssa-use `:` type - The ipowi operation takes two operands of integer type (i.e., scalar, tensor or vector) and returns one result of the same type. Operands must have the same type. @@ -647,20 +569,12 @@ class IPowIOp(IRDLOperation): """ name = "math.ipowi" - lhs = operand_def(IntegerType) - rhs = operand_def(IntegerType) - result = result_def(IntegerType) traits = traits_def(Pure(), SameOperandsAndResultType()) - def __init__(self, lhs: Operation | SSAValue, rhs: Operation | SSAValue): - lhs = SSAValue.get(lhs) - rhs = SSAValue.get(rhs) - super().__init__(operands=[lhs, rhs], result_types=[lhs.type]) - @irdl_op_definition -class Log10Op(IRDLOperation): +class Log10Op(FloatingPointLikeUnaryMathOperationWithFastMath): """ Computes the base-10 logarithm of the given value. It takes one operand of floating point type (i.e., scalar, tensor or vector) and returns one result of @@ -673,25 +587,12 @@ class Log10Op(IRDLOperation): """ name = "math.log10" - fastmath = opt_prop_def(FastMathFlagsAttr) - operand = operand_def(AnyFloatConstr) - result = result_def(AnyFloatConstr) traits = traits_def(Pure(), SameOperandsAndResultType()) - def __init__( - self, operand: Operation | SSAValue, fastmath: FastMathFlagsAttr | None = None - ): - attributes = {"fastmath": fastmath} - - operand = SSAValue.get(operand) - super().__init__( - attributes=attributes, operands=[operand], result_types=[operand.type] - ) - @irdl_op_definition -class Log1pOp(IRDLOperation): +class Log1pOp(FloatingPointLikeUnaryMathOperationWithFastMath): """ Computes the base-e logarithm of one plus the given value. It takes one operand of floating point type (i.e., scalar, tensor or vector) and returns one @@ -706,25 +607,12 @@ class Log1pOp(IRDLOperation): """ name = "math.log1p" - fastmath = opt_prop_def(FastMathFlagsAttr) - operand = operand_def(AnyFloatConstr) - result = result_def(AnyFloatConstr) traits = traits_def(Pure(), SameOperandsAndResultType()) - def __init__( - self, operand: Operation | SSAValue, fastmath: FastMathFlagsAttr | None = None - ): - attributes = {"fastmath": fastmath} - - operand = SSAValue.get(operand) - super().__init__( - attributes=attributes, operands=[operand], result_types=[operand.type] - ) - @irdl_op_definition -class Log2Op(IRDLOperation): +class Log2Op(FloatingPointLikeUnaryMathOperationWithFastMath): """ Computes the base-2 logarithm of the given value. It takes one operand of floating point type (i.e., scalar, tensor or vector) and returns one result of @@ -737,25 +625,12 @@ class Log2Op(IRDLOperation): """ name = "math.log2" - fastmath = opt_prop_def(FastMathFlagsAttr) - operand = operand_def(AnyFloatConstr) - result = result_def(AnyFloatConstr) traits = traits_def(Pure(), SameOperandsAndResultType()) - def __init__( - self, operand: Operation | SSAValue, fastmath: FastMathFlagsAttr | None = None - ): - attributes = {"fastmath": fastmath} - - operand = SSAValue.get(operand) - super().__init__( - attributes=attributes, operands=[operand], result_types=[operand.type] - ) - @irdl_op_definition -class LogOp(IRDLOperation): +class LogOp(FloatingPointLikeUnaryMathOperationWithFastMath): """ Computes the base-e logarithm of the given value. It takes one operand of floating point type (i.e., scalar, tensor or vector) and returns one result of @@ -768,29 +643,13 @@ class LogOp(IRDLOperation): """ name = "math.log" - fastmath = opt_prop_def(FastMathFlagsAttr) - operand = operand_def(AnyFloatConstr) - result = result_def(AnyFloatConstr) traits = traits_def(Pure(), SameOperandsAndResultType()) - def __init__( - self, operand: Operation | SSAValue, fastmath: FastMathFlagsAttr | None = None - ): - attributes = {"fastmath": fastmath} - - operand = SSAValue.get(operand) - super().__init__( - attributes=attributes, operands=[operand], result_types=[operand.type] - ) - @irdl_op_definition -class PowFOp(IRDLOperation): +class PowFOp(FloatingPointLikeBinaryMathOperationWithFastMath): """ - Syntax: - operation ::= ssa-id `=` `math.powf` ssa-use `,` ssa-use `:` type - The powf operation takes two operands of floating point type (i.e., scalar, tensor or vector) and returns one result of the same type. Operands must have the same type. @@ -802,34 +661,13 @@ class PowFOp(IRDLOperation): """ name = "math.powf" - fastmath = opt_prop_def(FastMathFlagsAttr) - lhs = operand_def(AnyFloatConstr) - rhs = operand_def(AnyFloatConstr) - result = result_def(AnyFloatConstr) traits = traits_def(Pure(), SameOperandsAndResultType()) - def __init__( - self, - lhs: Operation | SSAValue, - rhs: Operation | SSAValue, - fastmath: FastMathFlagsAttr | None = None, - ): - attributes = {"fastmath": fastmath} - - super().__init__( - attributes=attributes, - operands=[lhs, rhs], - result_types=[SSAValue.get(lhs).type], - ) - @irdl_op_definition -class RoundEvenOp(IRDLOperation): +class RoundEvenOp(FloatingPointLikeUnaryMathOperationWithFastMath): """ - Syntax: - operation ::= ssa-id `=` `math.roundeven` ssa-use `:` type - The roundeven operation returns the operand rounded to the nearest integer value in floating-point format. It takes one operand of floating point type (i.e., scalar, tensor or vector) and produces one result of the same type. The @@ -844,29 +682,13 @@ class RoundEvenOp(IRDLOperation): """ name = "math.roundeven" - fastmath = opt_prop_def(FastMathFlagsAttr) - operand = operand_def(AnyFloatConstr) - result = result_def(AnyFloatConstr) traits = traits_def(Pure(), SameOperandsAndResultType()) - def __init__( - self, operand: Operation | SSAValue, fastmath: FastMathFlagsAttr | None = None - ): - attributes = {"fastmath": fastmath} - - operand = SSAValue.get(operand) - super().__init__( - attributes=attributes, operands=[operand], result_types=[operand.type] - ) - @irdl_op_definition -class RoundOp(IRDLOperation): +class RoundOp(FloatingPointLikeUnaryMathOperationWithFastMath): """ - Syntax: - operation ::= ssa-id `=` `math.round` ssa-use `:` type - The round operation returns the operand rounded to the nearest integer value in floating-point format. It takes one operand of floating point type (i.e., scalar, tensor or vector) and produces one result of the same type. The @@ -881,25 +703,12 @@ class RoundOp(IRDLOperation): """ name = "math.round" - fastmath = opt_prop_def(FastMathFlagsAttr) - operand = operand_def(AnyFloatConstr) - result = result_def(AnyFloatConstr) traits = traits_def(Pure(), SameOperandsAndResultType()) - def __init__( - self, operand: Operation | SSAValue, fastmath: FastMathFlagsAttr | None = None - ): - attributes = {"fastmath": fastmath} - - operand = SSAValue.get(operand) - super().__init__( - attributes=attributes, operands=[operand], result_types=[operand.type] - ) - @irdl_op_definition -class RsqrtOp(IRDLOperation): +class RsqrtOp(FloatingPointLikeUnaryMathOperationWithFastMath): """ The rsqrt operation computes the reciprocal of the square root. It takes one operand of floating point type (i.e., scalar, tensor or vector) and returns @@ -911,29 +720,13 @@ class RsqrtOp(IRDLOperation): """ name = "math.rsqrt" - fastmath = opt_prop_def(FastMathFlagsAttr) - operand = operand_def(AnyFloatConstr) - result = result_def(AnyFloatConstr) traits = traits_def(Pure(), SameOperandsAndResultType()) - def __init__( - self, operand: Operation | SSAValue, fastmath: FastMathFlagsAttr | None = None - ): - attributes = {"fastmath": fastmath} - - operand = SSAValue.get(operand) - super().__init__( - attributes=attributes, operands=[operand], result_types=[operand.type] - ) - @irdl_op_definition -class SinOp(IRDLOperation): +class SinOp(FloatingPointLikeUnaryMathOperationWithFastMath): """ - Syntax: - operation ::= ssa-id `=` `math.sin` ssa-use `:` type - The sin operation computes the sine of a given value. It takes one operand of floating point type (i.e., scalar, tensor or vector) and returns one result of the same type. It has no standard attributes. @@ -945,25 +738,12 @@ class SinOp(IRDLOperation): """ name = "math.sin" - fastmath = opt_prop_def(FastMathFlagsAttr) - operand = operand_def(AnyFloatConstr) - result = result_def(AnyFloatConstr) traits = traits_def(Pure(), SameOperandsAndResultType()) - def __init__( - self, operand: Operation | SSAValue, fastmath: FastMathFlagsAttr | None = None - ): - attributes = {"fastmath": fastmath} - - operand = SSAValue.get(operand) - super().__init__( - attributes=attributes, operands=[operand], result_types=[operand.type] - ) - @irdl_op_definition -class SqrtOp(IRDLOperation): +class SqrtOp(FloatingPointLikeUnaryMathOperationWithFastMath): """ The sqrt operation computes the square root. It takes one operand of floating point type (i.e., scalar, tensor or vector) and returns one result of @@ -975,25 +755,12 @@ class SqrtOp(IRDLOperation): """ name = "math.sqrt" - fastmath = opt_prop_def(FastMathFlagsAttr) - operand = operand_def(AnyFloatConstr) - result = result_def(AnyFloatConstr) traits = traits_def(Pure(), SameOperandsAndResultType()) - def __init__( - self, operand: Operation | SSAValue, fastmath: FastMathFlagsAttr | None = None - ): - attributes = {"fastmath": fastmath} - - operand = SSAValue.get(operand) - super().__init__( - attributes=attributes, operands=[operand], result_types=[operand.type] - ) - @irdl_op_definition -class TanOp(IRDLOperation): +class TanOp(FloatingPointLikeUnaryMathOperation): """ The tan operation computes the tangent. It takes one operand of floating point type (i.e., scalar, tensor or vector) and returns one @@ -1006,25 +773,12 @@ class TanOp(IRDLOperation): """ name = "math.tan" - fastmath = opt_prop_def(FastMathFlagsAttr) - operand = operand_def(AnyFloatConstr) - result = result_def(AnyFloatConstr) traits = traits_def(Pure(), SameOperandsAndResultType()) - def __init__( - self, operand: Operation | SSAValue, fastmath: FastMathFlagsAttr | None = None - ): - attributes = {"fastmath": fastmath} - - operand = SSAValue.get(operand) - super().__init__( - attributes=attributes, operands=[operand], result_types=[operand.type] - ) - @irdl_op_definition -class TanhOp(IRDLOperation): +class TanhOp(FloatingPointLikeUnaryMathOperationWithFastMath): """ The tanh operation computes the hyperbolic tangent. It takes one operand of floating point type (i.e., scalar, tensor or vector) and returns one @@ -1037,29 +791,13 @@ class TanhOp(IRDLOperation): """ name = "math.tanh" - fastmath = opt_prop_def(FastMathFlagsAttr) - operand = operand_def(AnyFloatConstr) - result = result_def(AnyFloatConstr) traits = traits_def(Pure(), SameOperandsAndResultType()) - def __init__( - self, operand: Operation | SSAValue, fastmath: FastMathFlagsAttr | None = None - ): - attributes = {"fastmath": fastmath} - - operand = SSAValue.get(operand) - super().__init__( - attributes=attributes, operands=[operand], result_types=[operand.type] - ) - @irdl_op_definition -class TruncOp(IRDLOperation): +class TruncOp(FloatingPointLikeUnaryMathOperationWithFastMath): """ - Syntax: - operation ::= ssa-id `=` `math.trunc` ssa-use `:` type - The trunc operation returns the operand rounded to the nearest integer value in floating-point format. It takes one operand of floating point type (i.e., scalar, tensor or vector) and produces one result of the same type. @@ -1073,22 +811,9 @@ class TruncOp(IRDLOperation): """ name = "math.trunc" - fastmath = opt_prop_def(FastMathFlagsAttr) - operand = operand_def(AnyFloatConstr) - result = result_def(AnyFloatConstr) traits = traits_def(Pure(), SameOperandsAndResultType()) - def __init__( - self, operand: Operation | SSAValue, fastmath: FastMathFlagsAttr | None = None - ): - attributes = {"fastmath": fastmath} - - operand = SSAValue.get(operand) - super().__init__( - attributes=attributes, operands=[operand], result_types=[operand.type] - ) - Math = Dialect( "math",