diff --git a/tests/dialects/test_vector.py b/tests/dialects/test_vector.py index faea6e9bd6..097cee4da2 100644 --- a/tests/dialects/test_vector.py +++ b/tests/dialects/test_vector.py @@ -1,8 +1,11 @@ import pytest from xdsl.dialects.builtin import ( + AffineMapAttr, + ArrayAttr, IndexType, IntAttr, + IntegerAttr, MemRefType, VectorType, i1, @@ -20,8 +23,11 @@ MaskedstoreOp, PrintOp, StoreOp, + TransferReadOp, + TransferWriteOp, ) from xdsl.ir import Attribute, OpResult +from xdsl.ir.affine import AffineExpr, AffineMap from xdsl.utils.test_value import TestSSAValue @@ -653,3 +659,64 @@ def test_vector_insert_element_0d_verify_empty_position(): match="Expected position to be empty with 0-D vector.", ): insert_element.verify() + + +def test_vector_transfer_write_construction(): + x = AffineExpr.dimension(0) + vector_type = VectorType(IndexType(), [3]) + memref_type = MemRefType(IndexType(), [3, 3]) + # (x, y) -> x + permutation_map = AffineMapAttr(AffineMap(2, 0, (x,))) + in_bounds = ArrayAttr( + [IntegerAttr.from_bool(False) for _ in range(vector_type.get_num_dims())] + ) + + vector = TestSSAValue(vector_type) + source = TestSSAValue(memref_type) + index = TestSSAValue(IndexType()) + + transfer_write = TransferWriteOp( + vector, + source, + [index, index], + in_bounds, + permutation_map=permutation_map, + ) + + transfer_write.verify() + + assert transfer_write.vector is vector + assert transfer_write.source is source + assert len(transfer_write.indices) == 2 + assert transfer_write.indices[0] is index + assert transfer_write.permutation_map is permutation_map + + +def test_vector_transfer_read_construction(): + x = AffineExpr.dimension(0) + vector_type = VectorType(IndexType(), [3]) + memref_type = MemRefType(IndexType(), [3, 3]) + permutation_map = AffineMapAttr(AffineMap(2, 0, (x,))) + in_bounds = ArrayAttr( + [IntegerAttr.from_bool(False) for _ in range(vector_type.get_num_dims())] + ) + + source = TestSSAValue(memref_type) + index = TestSSAValue(IndexType()) + padding = TestSSAValue(IndexType()) + + transfer_read = TransferReadOp( + source, + [index, index], + padding, + vector_type, + in_bounds, + permutation_map=permutation_map, + ) + + transfer_read.verify() + + assert transfer_read.source is source + assert len(transfer_read.indices) == 2 + assert transfer_read.indices[0] is index + assert transfer_read.permutation_map is permutation_map diff --git a/tests/filecheck/dialects/vector/vector_ops.mlir b/tests/filecheck/dialects/vector/vector_ops.mlir index 7d6d069aca..7d8fbf0cad 100644 --- a/tests/filecheck/dialects/vector/vector_ops.mlir +++ b/tests/filecheck/dialects/vector/vector_ops.mlir @@ -1,5 +1,5 @@ // RUN: XDSL_ROUNDTRIP - +#map = affine_map<(d0, d1) -> (d0)> builtin.module { func.func private @vector_test(%0 : memref<4x4xindex>, %1 : vector<1xi1>, %2 : index) { %3 = "vector.load"(%0, %2, %2) : (memref<4x4xindex>, index, index) -> vector<2xindex> @@ -10,6 +10,9 @@ builtin.module { "vector.maskedstore"(%0, %2, %2, %1, %6) : (memref<4x4xindex>, index, index, vector<1xi1>, vector<1xindex>) -> () "vector.print"(%6) : (vector<1xindex>) -> () %7 = "vector.create_mask"(%2) : (index) -> vector<2xi1> + %8 = "vector.transfer_read"(%0, %2, %2, %2) <{"in_bounds" = [true], "operandSegmentSizes" = array, "permutation_map" = #map}> : (memref<4x4xindex>, index, index, index) -> vector<4xindex> + "vector.transfer_write"(%8, %0, %2, %2) <{"in_bounds" = [true], "operandSegmentSizes" = array, "permutation_map" = #map}> : (vector<4xindex>, memref<4x4xindex>, index, index) -> () + func.return } } @@ -25,6 +28,8 @@ builtin.module { // CHECK-NEXT: "vector.maskedstore"(%0, %2, %2, %1, %6) : (memref<4x4xindex>, index, index, vector<1xi1>, vector<1xindex>) -> () // CHECK-NEXT: "vector.print"(%6) : (vector<1xindex>) -> () // CHECK-NEXT: %7 = "vector.create_mask"(%2) : (index) -> vector<2xi1> +// CHECK-NEXT: %8 = "vector.transfer_read"(%0, %2, %2, %2) <{"in_bounds" = [true], "operandSegmentSizes" = array, "permutation_map" = affine_map<(d0, d1) -> (d0)>}> : (memref<4x4xindex>, index, index, index) -> vector<4xindex> +// CHECK-NEXT: "vector.transfer_write"(%8, %0, %2, %2) <{"in_bounds" = [true], "operandSegmentSizes" = array, "permutation_map" = affine_map<(d0, d1) -> (d0)>}> : (vector<4xindex>, memref<4x4xindex>, index, index) -> () // CHECK-NEXT: func.return // CHECK-NEXT: } // CHECK-NEXT: } diff --git a/tests/filecheck/dialects/vector/vector_transfer_read_verify.mlir b/tests/filecheck/dialects/vector/vector_transfer_read_verify.mlir new file mode 100644 index 0000000000..2a91859943 --- /dev/null +++ b/tests/filecheck/dialects/vector/vector_transfer_read_verify.mlir @@ -0,0 +1,43 @@ +// RUN: xdsl-opt --split-input-file --verify-diagnostics %s | filecheck %s + +%source, %index, %padding = "test.op"() : () -> (vector<4x3xf32>, index, f32) +"vector.transfer_read"(%source, %index, %index, %padding) <{in_bounds=[true], operandSegmentSizes = array, permutation_map = affine_map<() -> (0)>}> : (vector<4x3xf32>, index, index, f32) -> vector<1x1x2x3xf32> +// CHECK: Expected tensor or memref type, got vector<4x3xf32> + +// ----- + +%source, %index, %padding = "test.op"() : () -> (memref, index, f32) +"vector.transfer_read"(%source, %index, %index, %index, %padding) <{in_bounds=[true], operandSegmentSizes = array, permutation_map = affine_map<() -> (0)>}> : (memref, index, index, index, f32) -> vector<128xf32> +// CHECK: Expected an index for each memref/tensor dimension + +// ----- + +%source, %index, %padding = "test.op"() : () -> (memref, index, f32) +"vector.transfer_read"(%source, %index, %index, %padding) <{in_bounds=[true], operandSegmentSizes = array, permutation_map = affine_map<(d0) -> (d0)>}> : (memref, index, index, f32) -> vector<128xf32> +// CHECK: requires a permutation_map with input dims of the same rank as the source type + +// ----- + +%source, %index, %padding = "test.op"() : () -> (memref, index, f32) +"vector.transfer_read"(%source, %index, %index, %padding) <{in_bounds=[true, true], operandSegmentSizes = array, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (memref, index, index, f32) -> vector<128xf32> +// CHECK: requires a permutation_map with result dims of the same rank as the vector type + +// ----- + +%source, %index, %padding = "test.op"() : () -> (memref, index, f32) +"vector.transfer_read"(%source, %index, %index, %padding) <{in_bounds=[true, true], operandSegmentSizes = array, permutation_map = affine_map<(d0, d1) -> (d0 + d1)>}> : (memref, index, index, f32) -> vector<128xf32> +// CHECK: requires a projected permutation_map (at most one dim or the zero constant can appear in each result + +// ----- + +%source, %index, %padding = "test.op"() : () -> (memref, index, f32) +"vector.transfer_read"(%source, %index, %index, %padding) <{in_bounds=[true], operandSegmentSizes = array, permutation_map = affine_map<(d0, d1) -> (d0 + 1)>}> : (memref, index, index, f32) -> vector<128xf32> +// CHECK: requires a projected permutation_map (at most one dim or the zero constant can appear in each result) + +// ----- + +%source, %index, %padding = "test.op"() : () -> (memref, index, f32) +"vector.transfer_read"(%source, %index, %index, %index, %padding) <{in_bounds=[true, true], operandSegmentSizes = array, permutation_map = affine_map<(d0, d1, d2) -> (d0, d0)>}> : (memref, index, index, index, f32) -> vector<3x7xf32> +// CHECK: requires a permutation_map that is a permutation (found one dim used more than once) + +// TODO transfer other tests from mlir/test/Dialect/Vector/invalid.mlir once verification for vector element types is implemented diff --git a/tests/filecheck/dialects/vector/vector_transfer_write_verify.mlir b/tests/filecheck/dialects/vector/vector_transfer_write_verify.mlir new file mode 100644 index 0000000000..11465b0ad8 --- /dev/null +++ b/tests/filecheck/dialects/vector/vector_transfer_write_verify.mlir @@ -0,0 +1,49 @@ +// RUN: xdsl-opt --split-input-file --verify-diagnostics %s | filecheck %s + +%source, %index = "test.op"() : () -> (vector<4x3xf32>, index) +"vector.transfer_write"(%source, %source, %index, %index) <{in_bounds=[true], operandSegmentSizes = array, permutation_map = affine_map<() -> (0)>}> : (vector<4x3xf32>, vector<4x3xf32>, index, index) -> () +// CHECK: Expected tensor or memref type + +// ----- + +%source, %vector, %index = "test.op"() : () -> (memref, vector<128xf32>, index) +"vector.transfer_write"(%vector, %source, %index, %index, %index) <{in_bounds=[true], operandSegmentSizes = array, permutation_map = affine_map<() -> (0)>}> : (vector<128xf32>, memref, index, index, index) -> () +// CHECK: Expected an index for each memref/tensor dimension + +// ----- + +%source, %vector, %index = "test.op"() : () -> (memref, vector<128xf32>, index) +"vector.transfer_write"(%vector, %source, %index, %index) <{in_bounds=[true], operandSegmentSizes = array, permutation_map = affine_map<(d0) -> (d0)>}> : (vector<128xf32>, memref, index, index) -> () +// CHECK: requires a permutation_map with input dims of the same rank as the source type + +// ----- + +%source, %vector, %index = "test.op"() : () -> (memref, vector<128xf32>, index) +"vector.transfer_write"(%vector, %source, %index, %index) <{in_bounds=[true, true], operandSegmentSizes = array, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (vector<128xf32>, memref, index, index) -> () +// CHECK: requires a permutation_map with result dims of the same rank as the vector type + +// ----- + +%source, %vector, %index = "test.op"() : () -> (memref, vector<128xf32>, index) +"vector.transfer_write"(%vector, %source, %index, %index) <{in_bounds=[true, true], operandSegmentSizes = array, permutation_map = affine_map<(d0, d1) -> (d0 + d1)>}> : (vector<128xf32>, memref, index, index) -> () +// CHECK: requires a projected permutation_map (at most one dim or the zero constant can appear in each result) + +// ----- + +%source, %vector, %index = "test.op"() : () -> (memref, vector<128xf32>, index) +"vector.transfer_write"(%vector, %source, %index, %index) <{in_bounds=[true], operandSegmentSizes = array, permutation_map = affine_map<(d0, d1) -> (d0 + 1)>}> : (vector<128xf32>, memref, index, index) -> () +// CHECK: requires a projected permutation_map (at most one dim or the zero constant can appear in each result) + +// ----- + +%source, %vector, %index = "test.op"() : () -> (memref, vector<3x7xf32>, index) +"vector.transfer_write"(%vector, %source, %index, %index, %index) <{in_bounds=[true, true], operandSegmentSizes = array, permutation_map = affine_map<(d0, d1, d2) -> (d0, d0)>}> : (vector<3x7xf32>, memref, index, index, index) -> () +// CHECK: requires a permutation_map that is a permutation (found one dim used more than once) + +// ----- + +%source, %vector, %index = "test.op"() : () -> (memref, vector<7xf32>, index) +"vector.transfer_write"(%vector, %source, %index) <{in_bounds=[true], operandSegmentSizes = array, permutation_map = affine_map<(d0) -> (0)>}> : (vector<7xf32>, memref, index) -> () +// CHECK: requires a permutation_map that is a permutation (found one dim used more than once) + +// TODO transfer other tests from mlir/test/Dialect/Vector/invalid.mlir once verification for vector element types is implemented diff --git a/tests/filecheck/mlir-conversion/with-mlir/dialects/vector/ops.mlir b/tests/filecheck/mlir-conversion/with-mlir/dialects/vector/ops.mlir index e4c2d5649c..04365eb993 100644 --- a/tests/filecheck/mlir-conversion/with-mlir/dialects/vector/ops.mlir +++ b/tests/filecheck/mlir-conversion/with-mlir/dialects/vector/ops.mlir @@ -1,8 +1,6 @@ -// RUN: xdsl-opt --print-op-generic %s | mlir-opt --mlir-print-op-generic | xdsl-opt --print-op-generic | filecheck %s +// RUN: xdsl-opt --split-input-file --print-op-generic %s | mlir-opt --mlir-print-op-generic | xdsl-opt --print-op-generic | filecheck %s -builtin.module{ - -%vector0, %vector1, %i0 = "test.op"() : () -> (vector, vector<3xindex>, index) +%vector0, %vector1, %i0= "test.op"() : () -> (vector, vector<3xindex>, index) // CHECK: %0, %1, %2 = "test.op"() : () -> (vector, vector<3xindex>, index) %0 = "vector.insertelement"(%i0, %vector0) : (index, vector) -> vector @@ -16,4 +14,161 @@ builtin.module{ %3 = "vector.extractelement"(%vector0) : (vector) -> index // CHECK-NEXT: %6 = "vector.extractelement"(%0) : (vector) -> index -} + + +// ----- +// Vector transfer ops 0d +// See func vector_transfer_ops_0d in mlir/test/Dialect/Vector/ops.mlir + +%tensor, %vector, %memref, %f= "test.op"() : () -> (tensor, vector, memref, f32) +// CHECK: %0, %1, %2, %3 = "test.op"() : () -> (tensor, vector, memref, f32) + +%0 = "vector.transfer_read"(%tensor, %f) <{in_bounds = [], operandSegmentSizes = array, permutation_map = affine_map<() -> ()>}> : (tensor, f32) -> vector +// CHECK-NEXT: %4 = "vector.transfer_read"(%0, %3) <{"in_bounds" = [], "operandSegmentSizes" = array, "permutation_map" = affine_map<() -> ()>}> : (tensor, f32) -> vector + +%1 = "vector.transfer_write"(%vector, %tensor) <{in_bounds = [], operandSegmentSizes = array, permutation_map = affine_map<() -> ()>}> : (vector, tensor) -> tensor +// CHECK-NEXT: %5 = "vector.transfer_write"(%1, %0) <{"in_bounds" = [], "operandSegmentSizes" = array, "permutation_map" = affine_map<() -> ()>}> : (vector, tensor) -> tensor + +%2 = "vector.transfer_read"(%memref, %f) <{in_bounds = [], operandSegmentSizes = array, permutation_map = affine_map<() -> ()>}> : (memref, f32) -> vector +// CHECK-NEXT: %6 = "vector.transfer_read"(%2, %3) <{"in_bounds" = [], "operandSegmentSizes" = array, "permutation_map" = affine_map<() -> ()>}> : (memref, f32) -> vector + +"vector.transfer_write"(%vector, %memref) <{in_bounds = [], operandSegmentSizes = array, permutation_map = affine_map<() -> ()>}> : (vector, memref) -> () +// CHECK-NEXT: "vector.transfer_write"(%1, %2) <{"in_bounds" = [], "operandSegmentSizes" = array, "permutation_map" = affine_map<() -> ()>}> : (vector, memref) -> () + +// ----- +// Vector transfer ops 0d from higher d +// func vector_transfer_ops_0d_from_higher_d in mlir/test/Dialect/Vector/ops.mlir + +%tensor, %memref, %index, %f= "test.op"() : () -> (tensor, memref, index, f32) +// CHECK: %0, %1, %2, %3 = "test.op"() : () -> (tensor, memref, index, f32) + +%0 = "vector.transfer_read"(%tensor, %index, %f) <{in_bounds = [], operandSegmentSizes = array, permutation_map = affine_map<(d0) -> ()>}> : (tensor, index, f32) -> vector +// CHECK-NEXT: %4 = "vector.transfer_read"(%0, %2, %3) <{"in_bounds" = [], "operandSegmentSizes" = array, "permutation_map" = affine_map<(d0) -> ()>}> : (tensor, index, f32) -> vector + +%1 = "vector.transfer_write"(%0, %tensor, %index) <{in_bounds = [], operandSegmentSizes = array, permutation_map = affine_map<(d0) -> ()>}> : (vector, tensor, index) -> tensor +// CHECK-NEXT: %5 = "vector.transfer_write"(%4, %0, %2) <{"in_bounds" = [], "operandSegmentSizes" = array, "permutation_map" = affine_map<(d0) -> ()>}> : (vector, tensor, index) -> tensor + +%2 = "vector.transfer_read"(%memref, %index, %index, %f) <{in_bounds = [], operandSegmentSizes = array, permutation_map = affine_map<(d0, d1) -> ()>}> : (memref, index, index, f32) -> vector +// CHECK-NEXT: %6 = "vector.transfer_read"(%1, %2, %2, %3) <{"in_bounds" = [], "operandSegmentSizes" = array, "permutation_map" = affine_map<(d0, d1) -> ()>}> : (memref, index, index, f32) -> vector + +"vector.transfer_write"(%2, %memref, %index, %index) <{in_bounds = [], operandSegmentSizes = array, permutation_map = affine_map<(d0, d1) -> ()>}> : (vector, memref, index, index) -> () +// CHECK-NEXT: "vector.transfer_write"(%6, %1, %2, %2) <{"in_bounds" = [], "operandSegmentSizes" = array, "permutation_map" = affine_map<(d0, d1) -> ()>}> : (vector, memref, index, index) -> () + +// ----- +// Vector transfer ops +// func vector_transfer_ops in mlir/test/Dialect/Vector/ops.mlir + +%0, %1, %2, %3, %4 = "test.op"() : () -> (memref, memref>, memref>, memref>, memref) +// CHECK: %0, %1, %2, %3, %4 = "test.op"() : () -> (memref, memref>, memref>, memref>, memref) + +%5, %6, %7, %8, %9, %10 = "test.op"() : () -> (index, f32, f32, i32, index, i1) +// CHECK-NEXT: %5, %6, %7, %8, %9, %10 = "test.op"() : () -> (index, f32, f32, i32, index, i1) + +%11, %12, %13, %14, %15 = "test.op"() : () -> (vector<4x3xf32>, vector<4x3xi32>, vector<4x3xindex>, vector<5xi1>, vector<4x5xi1>) +// CHECK-NEXT: %11, %12, %13, %14, %15 = "test.op"() : () -> (vector<4x3xf32>, vector<4x3xi32>, vector<4x3xindex>, vector<5xi1>, vector<4x5xi1>) + +%16 = "vector.transfer_read"(%0, %5, %5, %7) <{"in_bounds" = [false], "operandSegmentSizes" = array, "permutation_map" = affine_map<(d0, d1) -> (d0)>}> : (memref, index, index, f32) -> vector<128xf32> +// CHECK-NEXT: %16 = "vector.transfer_read"(%0, %5, %5, %7) <{"in_bounds" = [false], "operandSegmentSizes" = array, "permutation_map" = affine_map<(d0, d1) -> (d0)>}> : (memref, index, index, f32) -> vector<128xf32> + +%17 = "vector.transfer_read"(%0, %5, %5, %7) <{"in_bounds" = [false, false], "operandSegmentSizes" = array, "permutation_map" = affine_map<(d0, d1) -> (d1, d0)>}> : (memref, index, index, f32) -> vector<3x7xf32> +// CHECK-NEXT: %17 = "vector.transfer_read"(%0, %5, %5, %7) <{"in_bounds" = [false, false], "operandSegmentSizes" = array, "permutation_map" = affine_map<(d0, d1) -> (d1, d0)>}> : (memref, index, index, f32) -> vector<3x7xf32> + +%18 = "vector.transfer_read"(%0, %5, %5, %6) <{"in_bounds" = [false], "operandSegmentSizes" = array, "permutation_map" = affine_map<(d0, d1) -> (d0)>}> : (memref, index, index, f32) -> vector<128xf32> +// CHECK-NEXT: %18 = "vector.transfer_read"(%0, %5, %5, %6) <{"in_bounds" = [false], "operandSegmentSizes" = array, "permutation_map" = affine_map<(d0, d1) -> (d0)>}> : (memref, index, index, f32) -> vector<128xf32> + +%19 = "vector.transfer_read"(%0, %5, %5, %6) <{"in_bounds" = [false], "operandSegmentSizes" = array, "permutation_map" = affine_map<(d0, d1) -> (d1)>}> : (memref, index, index, f32) -> vector<128xf32> +// CHECK-NEXT: %19 = "vector.transfer_read"(%0, %5, %5, %6) <{"in_bounds" = [false], "operandSegmentSizes" = array, "permutation_map" = affine_map<(d0, d1) -> (d1)>}> : (memref, index, index, f32) -> vector<128xf32> + +%20 = "vector.transfer_read"(%1, %5, %5, %11) <{"in_bounds" = [false, false], "operandSegmentSizes" = array, "permutation_map" = affine_map<(d0, d1) -> (d0, d1)>}> : (memref>, index, index, vector<4x3xf32>) -> vector<1x1x4x3xf32> +// CHECK-NEXT: %20 = "vector.transfer_read"(%1, %5, %5, %11) <{"in_bounds" = [false, false], "operandSegmentSizes" = array, "permutation_map" = affine_map<(d0, d1) -> (d0, d1)>}> : (memref>, index, index, vector<4x3xf32>) -> vector<1x1x4x3xf32> + +%21 = "vector.transfer_read"(%1, %5, %5, %11) <{"in_bounds" = [false, true], "operandSegmentSizes" = array, "permutation_map" = affine_map<(d0, d1) -> (d0, d1)>}> : (memref>, index, index, vector<4x3xf32>) -> vector<1x1x4x3xf32> +// CHECK-NEXT: %21 = "vector.transfer_read"(%1, %5, %5, %11) <{"in_bounds" = [false, true], "operandSegmentSizes" = array, "permutation_map" = affine_map<(d0, d1) -> (d0, d1)>}> : (memref>, index, index, vector<4x3xf32>) -> vector<1x1x4x3xf32> + +%22 = "vector.transfer_read"(%2, %5, %5, %12) <{"in_bounds" = [], "operandSegmentSizes" = array, "permutation_map" = affine_map<(d0, d1) -> ()>}> : (memref>, index, index, vector<4x3xi32>) -> vector<5x24xi8> +// CHECK-NEXT: %22 = "vector.transfer_read"(%2, %5, %5, %12) <{"in_bounds" = [], "operandSegmentSizes" = array, "permutation_map" = affine_map<(d0, d1) -> ()>}> : (memref>, index, index, vector<4x3xi32>) -> vector<5x24xi8> + +%23 = "vector.transfer_read"(%3, %5, %5, %13) <{"in_bounds" = [], "operandSegmentSizes" = array, "permutation_map" = affine_map<(d0, d1) -> ()>}> : (memref>, index, index, vector<4x3xindex>) -> vector<5x48xi8> +// CHECK-NEXT: %23 = "vector.transfer_read"(%3, %5, %5, %13) <{"in_bounds" = [], "operandSegmentSizes" = array, "permutation_map" = affine_map<(d0, d1) -> ()>}> : (memref>, index, index, vector<4x3xindex>) -> vector<5x48xi8> + +%24 = "vector.transfer_read"(%0, %5, %5, %7, %14) <{"in_bounds" = [false], "operandSegmentSizes" = array, "permutation_map" = affine_map<(d0, d1) -> (d1)>}> : (memref, index, index, f32, vector<5xi1>) -> vector<5xf32> +// CHECK-NEXT: %24 = "vector.transfer_read"(%0, %5, %5, %7, %14) <{"in_bounds" = [false], "operandSegmentSizes" = array, "permutation_map" = affine_map<(d0, d1) -> (d1)>}> : (memref, index, index, f32, vector<5xi1>) -> vector<5xf32> + +%25 = "vector.transfer_read"(%4, %5, %5, %5, %7, %15) <{"in_bounds" = [false, false, true], "operandSegmentSizes" = array, "permutation_map" = affine_map<(d0, d1, d2) -> (d1, d0, 0)>}> : (memref, index, index, index, f32, vector<4x5xi1>) -> vector<5x4x8xf32> +// CHECK-NEXT: %25 = "vector.transfer_read"(%4, %5, %5, %5, %7, %15) <{"in_bounds" = [false, false, true], "operandSegmentSizes" = array, "permutation_map" = affine_map<(d0, d1, d2) -> (d1, d0, 0)>}> : (memref, index, index, index, f32, vector<4x5xi1>) -> vector<5x4x8xf32> + +"vector.transfer_write"(%16, %0, %5, %5) <{"in_bounds" = [false], "operandSegmentSizes" = array, "permutation_map" = affine_map<(d0, d1) -> (d0)>}> : (vector<128xf32>, memref, index, index) -> () +// CHECK-NEXT: "vector.transfer_write"(%16, %0, %5, %5) <{"in_bounds" = [false], "operandSegmentSizes" = array, "permutation_map" = affine_map<(d0, d1) -> (d0)>}> : (vector<128xf32>, memref, index, index) -> () + +"vector.transfer_write"(%17, %0, %5, %5) <{"in_bounds" = [false, false], "operandSegmentSizes" = array, "permutation_map" = affine_map<(d0, d1) -> (d1, d0)>}> : (vector<3x7xf32>, memref, index, index) -> () +// CHECK-NEXT: "vector.transfer_write"(%17, %0, %5, %5) <{"in_bounds" = [false, false], "operandSegmentSizes" = array, "permutation_map" = affine_map<(d0, d1) -> (d1, d0)>}> : (vector<3x7xf32>, memref, index, index) -> () + +"vector.transfer_write"(%20, %1, %5, %5) <{"in_bounds" = [false, false], "operandSegmentSizes" = array, "permutation_map" = affine_map<(d0, d1) -> (d0, d1)>}> : (vector<1x1x4x3xf32>, memref>, index, index) -> () +// CHECK-NEXT: "vector.transfer_write"(%20, %1, %5, %5) <{"in_bounds" = [false, false], "operandSegmentSizes" = array, "permutation_map" = affine_map<(d0, d1) -> (d0, d1)>}> : (vector<1x1x4x3xf32>, memref>, index, index) -> () + +"vector.transfer_write"(%21, %1, %5, %5) <{"in_bounds" = [false, false], "operandSegmentSizes" = array, "permutation_map" = affine_map<(d0, d1) -> (d0, d1)>}> : (vector<1x1x4x3xf32>, memref>, index, index) -> () +// CHECK-NEXT: "vector.transfer_write"(%21, %1, %5, %5) <{"in_bounds" = [false, false], "operandSegmentSizes" = array, "permutation_map" = affine_map<(d0, d1) -> (d0, d1)>}> : (vector<1x1x4x3xf32>, memref>, index, index) -> () + +"vector.transfer_write"(%22, %2, %5, %5) <{"in_bounds" = [], "operandSegmentSizes" = array, "permutation_map" = affine_map<(d0, d1) -> ()>}> : (vector<5x24xi8>, memref>, index, index) -> () +// CHECK-NEXT: "vector.transfer_write"(%22, %2, %5, %5) <{"in_bounds" = [], "operandSegmentSizes" = array, "permutation_map" = affine_map<(d0, d1) -> ()>}> : (vector<5x24xi8>, memref>, index, index) -> () + +"vector.transfer_write"(%23, %3, %5, %5) <{"in_bounds" = [], "operandSegmentSizes" = array, "permutation_map" = affine_map<(d0, d1) -> ()>}> : (vector<5x48xi8>, memref>, index, index) -> () +// CHECK-NEXT: "vector.transfer_write"(%23, %3, %5, %5) <{"in_bounds" = [], "operandSegmentSizes" = array, "permutation_map" = affine_map<(d0, d1) -> ()>}> : (vector<5x48xi8>, memref>, index, index) -> () + +"vector.transfer_write"(%24, %0, %5, %5, %14) <{"in_bounds" = [false], "operandSegmentSizes" = array, "permutation_map" = affine_map<(d0, d1) -> (d1)>}> : (vector<5xf32>, memref, index, index, vector<5xi1>) -> () +// CHECK-NEXT: "vector.transfer_write"(%24, %0, %5, %5, %14) <{"in_bounds" = [false], "operandSegmentSizes" = array, "permutation_map" = affine_map<(d0, d1) -> (d1)>}> : (vector<5xf32>, memref, index, index, vector<5xi1>) -> () + +// ----- +// Vector transfer ops tensor +// func vector_transfer_ops_tensor in mlir/test/Dialect/Vector/ops.mlir + +%0, %1, %2, %3 = "test.op"() : () -> (tensor, tensor>, tensor>, tensor>) +// CHECK: %0, %1, %2, %3 = "test.op"() : () -> (tensor, tensor>, tensor>, tensor>) + +%4, %5, %6, %7, %8 = "test.op"() : () -> (index, f32, f32, i32, index) +// CHECK-NEXT: %4, %5, %6, %7, %8 = "test.op"() : () -> (index, f32, f32, i32, index) + +%9, %10, %11 = "test.op"() : () -> (vector<4x3xf32>, vector<4x3xi32>, vector<4x3xindex>) +// CHECK-NEXT: %9, %10, %11 = "test.op"() : () -> (vector<4x3xf32>, vector<4x3xi32>, vector<4x3xindex>) + +%12 = "vector.transfer_read"(%0, %4, %4, %6) <{"in_bounds" = [false], "operandSegmentSizes" = array, "permutation_map" = affine_map<(d0, d1) -> (d0)>}> : (tensor, index, index, f32) -> vector<128xf32> +// CHECK-NEXT: %12 = "vector.transfer_read"(%0, %4, %4, %6) <{"in_bounds" = [false], "operandSegmentSizes" = array, "permutation_map" = affine_map<(d0, d1) -> (d0)>}> : (tensor, index, index, f32) -> vector<128xf32> + +%13 = "vector.transfer_read"(%0, %4, %4, %6) <{"in_bounds" = [false, false], "operandSegmentSizes" = array, "permutation_map" = affine_map<(d0, d1) -> (d1, d0)>}> : (tensor, index, index, f32) -> vector<3x7xf32> +// CHECK-NEXT: %13 = "vector.transfer_read"(%0, %4, %4, %6) <{"in_bounds" = [false, false], "operandSegmentSizes" = array, "permutation_map" = affine_map<(d0, d1) -> (d1, d0)>}> : (tensor, index, index, f32) -> vector<3x7xf32> + +%14 = "vector.transfer_read"(%0, %4, %4, %5) <{"in_bounds" = [false], "operandSegmentSizes" = array, "permutation_map" = affine_map<(d0, d1) -> (d0)>}> : (tensor, index, index, f32) -> vector<128xf32> +// CHECK-NEXT: %14 = "vector.transfer_read"(%0, %4, %4, %5) <{"in_bounds" = [false], "operandSegmentSizes" = array, "permutation_map" = affine_map<(d0, d1) -> (d0)>}> : (tensor, index, index, f32) -> vector<128xf32> + +%15 = "vector.transfer_read"(%0, %4, %4, %5) <{"in_bounds" = [false], "operandSegmentSizes" = array, "permutation_map" = affine_map<(d0, d1) -> (d1)>}> : (tensor, index, index, f32) -> vector<128xf32> +// CHECK-NEXT: %15 = "vector.transfer_read"(%0, %4, %4, %5) <{"in_bounds" = [false], "operandSegmentSizes" = array, "permutation_map" = affine_map<(d0, d1) -> (d1)>}> : (tensor, index, index, f32) -> vector<128xf32> + +%16 = "vector.transfer_read"(%1, %4, %4, %9) <{"in_bounds" = [false, false], "operandSegmentSizes" = array, "permutation_map" = affine_map<(d0, d1) -> (d0, d1)>}> : (tensor>, index, index, vector<4x3xf32>) -> vector<1x1x4x3xf32> +// CHECK-NEXT: %16 = "vector.transfer_read"(%1, %4, %4, %9) <{"in_bounds" = [false, false], "operandSegmentSizes" = array, "permutation_map" = affine_map<(d0, d1) -> (d0, d1)>}> : (tensor>, index, index, vector<4x3xf32>) -> vector<1x1x4x3xf32> + +%17 = "vector.transfer_read"(%1, %4, %4, %9) <{"in_bounds" = [false, true], "operandSegmentSizes" = array, "permutation_map" = affine_map<(d0, d1) -> (d0, d1)>}> : (tensor>, index, index, vector<4x3xf32>) -> vector<1x1x4x3xf32> +// CHECK-NEXT: %17 = "vector.transfer_read"(%1, %4, %4, %9) <{"in_bounds" = [false, true], "operandSegmentSizes" = array, "permutation_map" = affine_map<(d0, d1) -> (d0, d1)>}> : (tensor>, index, index, vector<4x3xf32>) -> vector<1x1x4x3xf32> + +%18 = "vector.transfer_read"(%2, %4, %4, %10) <{"in_bounds" = [], "operandSegmentSizes" = array, "permutation_map" = affine_map<(d0, d1) -> ()>}> : (tensor>, index, index, vector<4x3xi32>) -> vector<5x24xi8> +// CHECK-NEXT: %18 = "vector.transfer_read"(%2, %4, %4, %10) <{"in_bounds" = [], "operandSegmentSizes" = array, "permutation_map" = affine_map<(d0, d1) -> ()>}> : (tensor>, index, index, vector<4x3xi32>) -> vector<5x24xi8> + +%19 = "vector.transfer_read"(%3, %4, %4, %11) <{"in_bounds" = [], "operandSegmentSizes" = array, "permutation_map" = affine_map<(d0, d1) -> ()>}> : (tensor>, index, index, vector<4x3xindex>) -> vector<5x48xi8> +// CHECK-NEXT: %19 = "vector.transfer_read"(%3, %4, %4, %11) <{"in_bounds" = [], "operandSegmentSizes" = array, "permutation_map" = affine_map<(d0, d1) -> ()>}> : (tensor>, index, index, vector<4x3xindex>) -> vector<5x48xi8> + +%20 = "vector.transfer_write"(%12, %0, %4, %4) <{"in_bounds" = [false], "operandSegmentSizes" = array, "permutation_map" = affine_map<(d0, d1) -> (d0)>}> : (vector<128xf32>, tensor, index, index) -> tensor +// CHECK-NEXT: %20 = "vector.transfer_write"(%12, %0, %4, %4) <{"in_bounds" = [false], "operandSegmentSizes" = array, "permutation_map" = affine_map<(d0, d1) -> (d0)>}> : (vector<128xf32>, tensor, index, index) -> tensor + +%21 = "vector.transfer_write"(%13, %0, %4, %4) <{"in_bounds" = [false, false], "operandSegmentSizes" = array, "permutation_map" = affine_map<(d0, d1) -> (d1, d0)>}> : (vector<3x7xf32>, tensor, index, index) -> tensor +// CHECK-NEXT: %21 = "vector.transfer_write"(%13, %0, %4, %4) <{"in_bounds" = [false, false], "operandSegmentSizes" = array, "permutation_map" = affine_map<(d0, d1) -> (d1, d0)>}> : (vector<3x7xf32>, tensor, index, index) -> tensor + +%22 = "vector.transfer_write"(%16, %1, %4, %4) <{"in_bounds" = [false, false], "operandSegmentSizes" = array, "permutation_map" = affine_map<(d0, d1) -> (d0, d1)>}> : (vector<1x1x4x3xf32>, tensor>, index, index) -> tensor> +// CHECK-NEXT: %22 = "vector.transfer_write"(%16, %1, %4, %4) <{"in_bounds" = [false, false], "operandSegmentSizes" = array, "permutation_map" = affine_map<(d0, d1) -> (d0, d1)>}> : (vector<1x1x4x3xf32>, tensor>, index, index) -> tensor> + +%23 = "vector.transfer_write"(%17, %1, %4, %4) <{"in_bounds" = [false, false], "operandSegmentSizes" = array, "permutation_map" = affine_map<(d0, d1) -> (d0, d1)>}> : (vector<1x1x4x3xf32>, tensor>, index, index) -> tensor> +// CHECK-NEXT: %23 = "vector.transfer_write"(%17, %1, %4, %4) <{"in_bounds" = [false, false], "operandSegmentSizes" = array, "permutation_map" = affine_map<(d0, d1) -> (d0, d1)>}> : (vector<1x1x4x3xf32>, tensor>, index, index) -> tensor> + +%24 = "vector.transfer_write"(%18, %2, %4, %4) <{"in_bounds" = [], "operandSegmentSizes" = array, "permutation_map" = affine_map<(d0, d1) -> ()>}> : (vector<5x24xi8>, tensor>, index, index) -> tensor> +// CHECK-NEXT: %24 = "vector.transfer_write"(%18, %2, %4, %4) <{"in_bounds" = [], "operandSegmentSizes" = array, "permutation_map" = affine_map<(d0, d1) -> ()>}> : (vector<5x24xi8>, tensor>, index, index) -> tensor> + +%25 = "vector.transfer_write"(%19, %3, %4, %4) <{"in_bounds" = [], "operandSegmentSizes" = array, "permutation_map" = affine_map<(d0, d1) -> ()>}> : (vector<5x48xi8>, tensor>, index, index) -> tensor> +// CHECK-NEXT: %25 = "vector.transfer_write"(%19, %3, %4, %4) <{"in_bounds" = [], "operandSegmentSizes" = array, "permutation_map" = affine_map<(d0, d1) -> ()>}> : (vector<5x48xi8>, tensor>, index, index) -> tensor> diff --git a/tests/test_affine_builtins.py b/tests/test_affine_builtins.py index bf50e0c041..03b0318223 100644 --- a/tests/test_affine_builtins.py +++ b/tests/test_affine_builtins.py @@ -212,8 +212,58 @@ def test_compress_dims(): ) == AffineMap.from_callable(lambda d0, d1: (d1, d1)) -def test_used_dims(): +def test_affine_expr_used_dims(): assert AffineExpr.dimension(1).used_dims() == {1} assert (AffineExpr.dimension(2) + AffineExpr.dimension(3)).used_dims() == {2, 3} assert AffineExpr.symbol(4).used_dims() == set() assert AffineExpr.constant(5).used_dims() == set() + + +def test_affine_expr_is_function_of_dim(): + assert AffineExpr.dimension(0).is_function_of_dim(0) + assert not AffineExpr.dimension(1).is_function_of_dim(0) + assert not AffineExpr.constant(0).is_function_of_dim(0) + assert not AffineExpr.symbol(0).is_function_of_dim(0) + assert AffineMap(2, 0, (AffineExpr.dimension(0),)).results[0].is_function_of_dim(0) + assert not ( + AffineMap(2, 0, (AffineExpr.dimension(0),)).results[0].is_function_of_dim(1) + ) + assert ( + AffineMap.from_callable(lambda i, j: (i + j,)).results[0].is_function_of_dim(0) + ) + assert ( + AffineMap.from_callable(lambda i, j: (i + j,)).results[0].is_function_of_dim(1) + ) + + +def test_affine_map_is_function_of_dim(): + assert AffineMap.from_callable(lambda i, j: (i, j)).is_function_of_dim(0) + assert AffineMap.from_callable(lambda i, j: (i, j)).is_function_of_dim(1) + assert not AffineMap.from_callable(lambda i, j, _: (i, j)).is_function_of_dim(2) + + +def test_affine_map_used_dims(): + assert AffineMap.from_callable(lambda i, j: (i, j)).used_dims() == {0, 1} + assert AffineMap.from_callable(lambda i, j, _: (i + j,)).used_dims() == {0, 1} + assert AffineMap.from_callable(lambda i, _, k: (i, k)).used_dims() == {0, 2} + + +def test_affine_map_unused_dims(): + assert AffineMap.from_callable(lambda i, j: (i, j)).unused_dims() == set() + assert AffineMap.from_callable(lambda i, j, _: (i + j,)).unused_dims() == {2} + assert AffineMap.from_callable(lambda i, _, k: (i, k)).unused_dims() == {1} + + +def test_unused_dims_bit_vector(): + assert AffineMap.from_callable(lambda i, j: (i, j)).unused_dims_bit_vector() == ( + False, + False, + ) + assert AffineMap.from_callable( + lambda i, j, _: (i + j,) + ).unused_dims_bit_vector() == (False, False, True) + assert AffineMap.from_callable(lambda i, _, k: (i, k)).unused_dims_bit_vector() == ( + False, + True, + False, + ) diff --git a/xdsl/dialects/builtin.py b/xdsl/dialects/builtin.py index e03b9be0d4..77830a2a35 100644 --- a/xdsl/dialects/builtin.py +++ b/xdsl/dialects/builtin.py @@ -1898,8 +1898,8 @@ def get_resolvers( } def verify(self, attr: Attribute, constraint_context: ConstraintContext) -> None: - if isinstance(attr, VectorType) or isinstance(attr, TensorType): - attr = cast(VectorType[Attribute] | TensorType[Attribute], attr) + if isinstance(attr, MemRefType) or isinstance(attr, TensorType): + attr = cast(MemRefType[Attribute] | TensorType[Attribute], attr) self.elem_constr.verify(attr.element_type, constraint_context) else: raise VerifyException(f"Expected tensor or memref type, got {attr}") diff --git a/xdsl/dialects/vector.py b/xdsl/dialects/vector.py index fd99afeb70..b9fb0f1b8c 100644 --- a/xdsl/dialects/vector.py +++ b/xdsl/dialects/vector.py @@ -1,12 +1,21 @@ from __future__ import annotations +from abc import ABC, abstractmethod from collections.abc import Sequence from xdsl.dialects.builtin import ( + I1, + AffineMapAttr, + AnyFloat, + ArrayAttr, + BoolAttr, IndexType, IndexTypeConstr, + IntegerType, MemRefType, SignlessIntegerConstraint, + TensorOrMemrefOf, + TensorType, VectorBaseTypeAndRankConstraint, VectorBaseTypeConstraint, VectorRankConstraint, @@ -14,11 +23,16 @@ i1, ) from xdsl.ir import Attribute, Dialect, Operation, SSAValue +from xdsl.ir.affine import AffineConstantExpr, AffineDimExpr, AffineMap from xdsl.irdl import ( + AttrSizedOperandSegments, IRDLOperation, + ParsePropInAttrDict, irdl_op_definition, operand_def, opt_operand_def, + opt_result_def, + prop_def, result_def, traits_def, var_operand_def, @@ -383,6 +397,257 @@ def __init__( ) +def verify_permutation_map( + op: TransferReadOp | TransferWriteOp, + permutation_map: AffineMap, +): + """ + This mirrors VectorOps.cpp -> verifyPermutationMap + """ + + seen: list[bool] = [False for _ in range(permutation_map.num_dims)] + + for expr in permutation_map.results: + if isa(expr, AffineConstantExpr): + if expr.value != 0: + raise VerifyException( + f'"{op.name}" requires a projected permutation_map (at most one dim or the zero constant can appear in each result)' + ) + continue + if not isa(expr, AffineDimExpr): + raise VerifyException( + f'"{op.name}" requires a projected permutation_map (at most one dim or the zero constant can appear in each result)' + ) + if seen[expr.position]: + raise VerifyException( + f'"{op.name}" requires a permutation_map that is a permutation (found one dim used more than once)' + ) + seen[expr.position] = True + + +def verify_transfer_op( + op: TransferReadOp | TransferWriteOp, + shaped_type: MemRefType[Attribute] | TensorType[Attribute], + vector_type: VectorType[Attribute], + permutation_map: AffineMap, + in_bounds: ArrayAttr[BoolAttr], +): + """ + This mirrors VectorOps.cpp -> verifyTransferOp from MLIR + """ + + element_type = shaped_type.element_type + vector_element_type = vector_type.element_type + + if isa(element_type, VectorType[Attribute]): + # Memref or tensor has vector element type + # TODO verify vector element type + pass + else: + # Memref of tensor has scalar element type + if isa(vector_element_type, IndexType): + if not isa(element_type, IndexType): + raise VerifyException( + "Element type of source is index, expected element type of vector also to be index" + ) + else: + assert isa(vector_element_type, IntegerType | AnyFloat) + assert isa(element_type, IntegerType | AnyFloat) + + minor_size = ( + 1 if vector_type.get_num_dims() == 0 else vector_type.get_shape()[-1] + ) + result_vec_size = vector_element_type.bitwidth * minor_size + if result_vec_size % element_type.bitwidth != 0: + raise VerifyException( + f'"{op.name}" requires the bitwidth of the minor 1-D vector to be an integral multiple of the bitwidth of the source element type' + ) + + # Check that permutation map results match rank of vector type. + if len(permutation_map.results) != vector_type.get_num_dims(): + raise VerifyException( + f'"{op.name}" requires a permutation_map with result dims of the same rank as the vector type' + ) + + if permutation_map.num_symbols != 0: + raise VerifyException(f'"{op.name}" requires permutation_map without symbols') + + if permutation_map.num_dims != shaped_type.get_num_dims(): + raise VerifyException( + f'"{op.name}" requires a permutation_map with input dims of the same rank as the source type' + ) + + if len(in_bounds) != len(permutation_map.results): + raise VerifyException( + f'"{op.name}" expects the optional in_bounds attr of same rank as permutation_map results: {str(permutation_map)} vs in_bounds of of size {len(in_bounds)}' + ) + + for i in range(len(permutation_map.results)): + if ( + isa(permutation_map.results[i], AffineConstantExpr) + and not in_bounds.data[i].value.data + ): + raise VerifyException( + f'"{op.name}" requires broadcast dimensions to be in-bounds' + ) + + +class VectorTransferOp(ABC): + """ + Mirrors VectorTransferOpInterface from VectorInterfaces.h.inc + """ + + @abstractmethod + def get_permutation_map(self) -> AffineMap: + raise NotImplementedError() + + def is_broadcast_dim(self, dim: int) -> bool: + expr = self.get_permutation_map().results[dim] + if not isa(expr, AffineConstantExpr): + return False + return expr.value == 0 + + def has_broadcast_dim(self): + for dim in range(self.get_transfer_rank()): + if self.is_broadcast_dim(dim): + return True + + return False + + def get_transfer_rank(self) -> int: + return len(self.get_permutation_map().results) + + +@irdl_op_definition +class TransferReadOp(IRDLOperation, VectorTransferOp): + name = "vector.transfer_read" + + source = operand_def(TensorOrMemrefOf(Attribute)) + indices = var_operand_def(IndexType) + padding = operand_def(Attribute) + mask = opt_operand_def(VectorType[I1]) + + permutation_map = prop_def(AffineMapAttr) + in_bounds = prop_def(ArrayAttr[BoolAttr]) + + result = result_def(VectorType) + + irdl_options = [AttrSizedOperandSegments(as_property=True), ParsePropInAttrDict()] + + def verify_(self): + assert isa(self.source.type, MemRefType[Attribute] | TensorType[Attribute]) + assert isa(self.result.type, VectorType[Attribute]) + + if len(self.indices) != self.source.type.get_num_dims(): + raise VerifyException("Expected an index for each memref/tensor dimension.") + + verify_transfer_op( + self, + self.source.type, + self.result.type, + self.permutation_map.data, + self.in_bounds, + ) + + if isa(self.source.type.element_type, VectorType[Attribute]): + # TODO verify vector element type + pass + else: + # source memref/tensor has scalar element type + # TODO verify that padding type is a valid element_type for a vector + if self.source.type.element_type != self.padding.type: + raise VerifyException( + f'"{self.name}" requires formal padding and source of the same elemental type' + ) + + verify_permutation_map( + self, + self.permutation_map.data, + ) + + def __init__( + self, + source: SSAValue | Operation, + indices: Sequence[SSAValue | Operation], + padding: SSAValue | Operation, + result_type: Attribute, + in_bounds: ArrayAttr[BoolAttr], + mask: Sequence[SSAValue | Operation] | None = None, + permutation_map: AffineMapAttr | None = None, + ): + super().__init__( + operands=[source, indices, padding, mask], + result_types=[result_type], + properties={"permutation_map": permutation_map, "in_bounds": in_bounds}, + ) + + # override VectorTransferOp.get_permutation_map + def get_permutation_map(self): + return self.permutation_map.data + + +@irdl_op_definition +class TransferWriteOp(IRDLOperation, VectorTransferOp): + name = "vector.transfer_write" + + vector = operand_def(VectorType[Attribute]) + source = operand_def(TensorOrMemrefOf(Attribute)) + indices = var_operand_def(IndexType) + mask = opt_operand_def(VectorType[I1]) + + in_bounds = prop_def(ArrayAttr[BoolAttr]) + permutation_map = prop_def(AffineMapAttr) + + result = opt_result_def(TensorType[Attribute]) + + irdl_options = [AttrSizedOperandSegments(as_property=True), ParsePropInAttrDict()] + + def verify_(self): + assert isa(self.source.type, MemRefType[Attribute] | TensorType[Attribute]) + assert isa(self.vector.type, VectorType[Attribute]) + + if len(self.indices) != self.source.type.get_num_dims(): + raise VerifyException("Expected an index for each memref/tensor dimension.") + + if self.has_broadcast_dim(): + raise VerifyException( + f'"{self.name}" should not have broadcast dimensions.' + ) + + verify_transfer_op( + self, + self.source.type, + self.vector.type, + self.permutation_map.data, + self.in_bounds, + ) + + verify_permutation_map( + self, + self.permutation_map.data, + ) + + def __init__( + self, + vector: SSAValue | Operation, + source: SSAValue | Operation, + indices: Sequence[SSAValue | Operation], + in_bounds: ArrayAttr[BoolAttr], + mask: Sequence[SSAValue | Operation] | None = None, + permutation_map: AffineMapAttr | None = None, + result_type: TensorType[Attribute] | None = None, + ): + super().__init__( + operands=[vector, source, indices, mask], + properties={"permutation_map": permutation_map, "in_bounds": in_bounds}, + result_types=[result_type], + ) + + # override VectorTransferOp.get_permutation_map + def get_permutation_map(self): + return self.permutation_map.data + + Vector = Dialect( "vector", [ @@ -396,6 +661,8 @@ def __init__( CreatemaskOp, ExtractElementOp, InsertElementOp, + TransferReadOp, + TransferWriteOp, ], [], ) diff --git a/xdsl/ir/affine/affine_expr.py b/xdsl/ir/affine/affine_expr.py index eaa3c8178e..ef9ef8f716 100644 --- a/xdsl/ir/affine/affine_expr.py +++ b/xdsl/ir/affine/affine_expr.py @@ -271,6 +271,9 @@ def dfs(self) -> Iterator[AffineExpr]: def used_dims(self) -> set[int]: return {expr.position for expr in self.dfs() if isinstance(expr, AffineDimExpr)} + def is_function_of_dim(self, position: int) -> bool: + return position in self.used_dims() + class AffineBinaryOpKind(Enum): """Enum for the kind of storage node used in AffineExpr.""" diff --git a/xdsl/ir/affine/affine_map.py b/xdsl/ir/affine/affine_map.py index 6baf554168..2bb18a9f29 100644 --- a/xdsl/ir/affine/affine_map.py +++ b/xdsl/ir/affine/affine_map.py @@ -4,8 +4,9 @@ from collections.abc import Callable, Sequence from dataclasses import dataclass from inspect import getfullargspec +from typing import cast -from xdsl.ir.affine import AffineDimExpr, AffineExpr +from xdsl.ir.affine import AffineConstantExpr, AffineDimExpr, AffineExpr AffineExprBuilderT = AffineExpr | int @@ -169,6 +170,17 @@ def compose(self, other: AffineMap) -> AffineMap: results=results, ) + def compose_with_values(self, values: Sequence[int]) -> tuple[int, ...]: + """ + Same as SmallVector AffineMap::compose(ArrayRef values) const from AffineMap.cpp + """ + assert self.num_symbols == 0 + expressions: list[AffineExpr] = [] + for value in values: + expressions.append(AffineExpr.constant(value)) + result_map = self.compose(AffineMap(0, 0, tuple(expressions))) + return tuple(cast(AffineConstantExpr, res).value for res in result_map.results) + def inverse_permutation(self) -> AffineMap | None: """ Returns a map of codomain to domain dimensions such that the first @@ -245,6 +257,29 @@ def compress_dims(self, selectors: Sequence[bool]) -> AffineMap: new_dims, new_symbols, result_num_dims, self.num_symbols ) + def is_function_of_dim(self, position: int) -> bool: + return position in self.used_dims() + + def used_dims(self) -> set[int]: + result: set[int] = set() + + for expr in self.results: + result = result.union(expr.used_dims()) + + return result + + def unused_dims(self) -> set[int]: + dims = {i for i in range(self.num_dims)} + + return dims.difference(self.used_dims()) + + def unused_dims_bit_vector(self) -> tuple[bool, ...]: + unused_dims = self.unused_dims() + return tuple( + True if position in unused_dims else False + for position in range(self.num_dims) + ) + def __str__(self) -> str: # Create comma seperated list of dims. dims = ["d" + str(i) for i in range(self.num_dims)]