diff --git a/include/circt/Dialect/Comb/CombOps.h b/include/circt/Dialect/Comb/CombOps.h index b1fd1acf568e..8414c347c459 100644 --- a/include/circt/Dialect/Comb/CombOps.h +++ b/include/circt/Dialect/Comb/CombOps.h @@ -57,6 +57,15 @@ Value createOrFoldNot(Location loc, Value value, OpBuilder &builder, Value createOrFoldNot(Value value, ImplicitLocOpBuilder &builder, bool twoState = false); +/// Extract bits from a value. +void extractBits(OpBuilder &builder, Value val, SmallVectorImpl &bits); + +/// Construct a mux tree for given leaf nodes. `selectors` is the selector for +/// each level of the tree. Currently the selector is tested from MSB to LSB. +Value constructMuxTree(OpBuilder &builder, Location loc, + ArrayRef selectors, ArrayRef leafNodes, + Value outOfBoundsValue); + } // namespace comb } // namespace circt diff --git a/include/circt/Dialect/HW/HWPasses.h b/include/circt/Dialect/HW/HWPasses.h index d986736e2fd1..d693a237508a 100644 --- a/include/circt/Dialect/HW/HWPasses.h +++ b/include/circt/Dialect/HW/HWPasses.h @@ -33,6 +33,7 @@ std::unique_ptr createFlattenIOPass(bool recursiveFlag = true, std::unique_ptr createVerifyInnerRefNamespacePass(); std::unique_ptr createFlattenModulesPass(); std::unique_ptr createFooWiresPass(); +std::unique_ptr createHWAggregateToCombPass(); /// Generate the code for registering passes. #define GEN_PASS_REGISTRATION diff --git a/include/circt/Dialect/HW/Passes.td b/include/circt/Dialect/HW/Passes.td index a1445d370679..adab3b271f2d 100644 --- a/include/circt/Dialect/HW/Passes.td +++ b/include/circt/Dialect/HW/Passes.td @@ -87,4 +87,16 @@ def FooWires : Pass<"hw-foo-wires", "hw::HWModuleOp"> { let constructor = "circt::hw::createFooWiresPass()"; } +def HWAggregateToComb : Pass<"hw-aggregate-to-comb", "hw::HWModuleOp"> { + let summary = "Lower aggregate operations to comb operations"; + let constructor = "circt::hw::createHWAggregateToCombPass()"; + + let description = [{ + This pass lowers aggregate *operations* to comb operations within modules. + This pass does not lower ports, as ports are handled by FlattenIO. This pass + will also change the behavior of out-of-bounds access of arrays. + }]; + let dependentDialects = ["comb::CombDialect"]; +} + #endif // CIRCT_DIALECT_HW_PASSES_TD diff --git a/integration_test/circt-synth/comb-lowering-lec.mlir b/integration_test/circt-synth/comb-lowering-lec.mlir index 895c1d81d4a0..4cfe98afea61 100644 --- a/integration_test/circt-synth/comb-lowering-lec.mlir +++ b/integration_test/circt-synth/comb-lowering-lec.mlir @@ -1,7 +1,7 @@ // REQUIRES: libz3 // REQUIRES: circt-lec-jit -// RUN: circt-opt %s --convert-comb-to-aig --convert-aig-to-comb -o %t.mlir +// RUN: circt-opt %s --hw-aggregate-to-comb --convert-comb-to-aig --convert-aig-to-comb -o %t.mlir // RUN: circt-lec %t.mlir %s -c1=bit_logical -c2=bit_logical --shared-libs=%libz3 | FileCheck %s --check-prefix=COMB_BIT_LOGICAL // COMB_BIT_LOGICAL: c1 == c2 hw.module @bit_logical(in %arg0: i32, in %arg1: i32, in %arg2: i32, in %arg3: i32, @@ -78,3 +78,19 @@ hw.module @shift5(in %lhs: i5, in %rhs: i5, out out_shl: i5, out out_shr: i5, ou %2 = comb.shrs %lhs, %rhs : i5 hw.output %0, %1, %2 : i5, i5, i5 } + +// RUN: circt-lec %t.mlir %s -c1=array -c2=array --shared-libs=%libz3 | FileCheck %s --check-prefix=COMB_ARRAY +// COMB_ARRAY: c1 == c2 +hw.module @array(in %arg0: i2, in %arg1: i2, in %arg2: i2, in %arg3: i2, in %sel1: i2, in %sel2: i2, out out1: i2, out out2: i2) { + %0 = hw.array_create %arg0, %arg1, %arg2, %arg3 : i2 + %1 = hw.array_get %0[%sel1] : !hw.array<4xi2>, i2 + %2 = hw.array_create %arg0, %arg1, %arg2 : i2 + %c3_i2 = hw.constant 3 : i2 + // NOTE: If the index is out of bounds, the result value is undefined. + // In LEC such value is lowered into unbounded SMT variable and cause + // the LEC to fail. So just asssume that the index is in bounds. + %inbound = comb.icmp ult %sel2, %c3_i2 : i2 + verif.assume %inbound : i1 + %3 = hw.array_get %2[%sel2] : !hw.array<3xi2>, i2 + hw.output %1, %3 : i2, i2 +} diff --git a/lib/Conversion/CombToAIG/CombToAIG.cpp b/lib/Conversion/CombToAIG/CombToAIG.cpp index 701692ff3626..1e0412372efb 100644 --- a/lib/Conversion/CombToAIG/CombToAIG.cpp +++ b/lib/Conversion/CombToAIG/CombToAIG.cpp @@ -29,64 +29,13 @@ using namespace comb; // Utility Functions //===----------------------------------------------------------------------===// -// Extract individual bits from a value -static SmallVector extractBits(ConversionPatternRewriter &rewriter, - Value val) { - assert(val.getType().isInteger() && "expected integer"); - auto width = val.getType().getIntOrFloatBitWidth(); +// A wrapper for comb::extractBits that returns a SmallVector. +static SmallVector extractBits(OpBuilder &builder, Value val) { SmallVector bits; - bits.reserve(width); - - // Check if we can reuse concat operands - if (auto concat = val.getDefiningOp()) { - if (concat.getNumOperands() == width && - llvm::all_of(concat.getOperandTypes(), [](Type type) { - return type.getIntOrFloatBitWidth() == 1; - })) { - // Reverse the operands to match the bit order - bits.append(std::make_reverse_iterator(concat.getOperands().end()), - std::make_reverse_iterator(concat.getOperands().begin())); - return bits; - } - } - - // Extract individual bits - for (int64_t i = 0; i < width; ++i) - bits.push_back( - rewriter.createOrFold(val.getLoc(), val, i, 1)); - + comb::extractBits(builder, val, bits); return bits; } -// Construct a mux tree for given leaf nodes. `selectors` is the selector for -// each level of the tree. Currently the selector is tested from MSB to LSB. -static Value constructMuxTree(ConversionPatternRewriter &rewriter, Location loc, - ArrayRef selectors, - ArrayRef leafNodes, - Value outOfBoundsValue) { - // Recursive helper function to construct the mux tree - std::function constructTreeHelper = - [&](size_t id, size_t level) -> Value { - // Base case: at the lowest level, return the result - if (level == 0) { - // Return the result for the given index. If the index is out of bounds, - // return the out-of-bound value. - return id < leafNodes.size() ? leafNodes[id] : outOfBoundsValue; - } - - auto selector = selectors[level - 1]; - - // Recursive case: create muxes for true and false branches - auto trueVal = constructTreeHelper(2 * id + 1, level - 1); - auto falseVal = constructTreeHelper(2 * id, level - 1); - - // Combine the results with a mux - return rewriter.createOrFold(loc, selector, trueVal, falseVal); - }; - - return constructTreeHelper(0, llvm::Log2_64_Ceil(leafNodes.size())); -} - // Construct a mux tree for shift operations. `isLeftShift` controls the // direction of the shift operation and is used to determine order of the // padding and extracted bits. Callbacks `getPadding` and `getExtract` are used @@ -128,7 +77,8 @@ static Value createShiftLogic(ConversionPatternRewriter &rewriter, Location loc, assert(outOfBoundsValue && "outOfBoundsValue must be valid"); // Construct mux tree for shift operation - auto result = constructMuxTree(rewriter, loc, bits, nodes, outOfBoundsValue); + auto result = + comb::constructMuxTree(rewriter, loc, bits, nodes, outOfBoundsValue); // Add bounds checking auto inBound = rewriter.createOrFold( @@ -667,10 +617,21 @@ static void populateCombToAIGConversionPatterns(RewritePatternSet &patterns) { void ConvertCombToAIGPass::runOnOperation() { ConversionTarget target(getContext()); + + // Comb is source dialect. target.addIllegalDialect(); // Keep data movement operations like Extract, Concat and Replicate. target.addLegalOp(); + + // Treat array operations as illegal. Strictly speaking, other than array get + // operation with non-const index are legal in AIG but array types prevent a + // bunch of optimizations so just lower them to integer operations. It's + // required to run HWAggregateToComb pass before this pass. + target.addIllegalOp(); + + // AIG is target dialect. target.addLegalDialect(); // This is a test only option to add logical ops. diff --git a/lib/Dialect/Comb/CombOps.cpp b/lib/Dialect/Comb/CombOps.cpp index 26312306acfe..652ec40fba4e 100644 --- a/lib/Dialect/Comb/CombOps.cpp +++ b/lib/Dialect/Comb/CombOps.cpp @@ -56,6 +56,61 @@ Value comb::createOrFoldNot(Value value, ImplicitLocOpBuilder &builder, return createOrFoldNot(builder.getLoc(), value, builder, twoState); } +// Extract individual bits from a value +void comb::extractBits(OpBuilder &builder, Value val, + SmallVectorImpl &bits) { + assert(val.getType().isInteger() && "expected integer"); + auto width = val.getType().getIntOrFloatBitWidth(); + bits.reserve(width); + + // Check if we can reuse concat operands + if (auto concat = val.getDefiningOp()) { + if (concat.getNumOperands() == width && + llvm::all_of(concat.getOperandTypes(), [](Type type) { + return type.getIntOrFloatBitWidth() == 1; + })) { + // Reverse the operands to match the bit order + bits.append(std::make_reverse_iterator(concat.getOperands().end()), + std::make_reverse_iterator(concat.getOperands().begin())); + return; + } + } + + // Extract individual bits + for (int64_t i = 0; i < width; ++i) + bits.push_back( + builder.createOrFold(val.getLoc(), val, i, 1)); +} + +// Construct a mux tree for given leaf nodes. `selectors` is the selector for +// each level of the tree. Currently the selector is tested from MSB to LSB. +Value comb::constructMuxTree(OpBuilder &builder, Location loc, + ArrayRef selectors, + ArrayRef leafNodes, + Value outOfBoundsValue) { + // Recursive helper function to construct the mux tree + std::function constructTreeHelper = + [&](size_t id, size_t level) -> Value { + // Base case: at the lowest level, return the result + if (level == 0) { + // Return the result for the given index. If the index is out of bounds, + // return the out-of-bound value. + return id < leafNodes.size() ? leafNodes[id] : outOfBoundsValue; + } + + auto selector = selectors[level - 1]; + + // Recursive case: create muxes for true and false branches + auto trueVal = constructTreeHelper(2 * id + 1, level - 1); + auto falseVal = constructTreeHelper(2 * id, level - 1); + + // Combine the results with a mux + return builder.createOrFold(loc, selector, trueVal, falseVal); + }; + + return constructTreeHelper(0, llvm::Log2_64_Ceil(leafNodes.size())); +} + //===----------------------------------------------------------------------===// // ICmpOp //===----------------------------------------------------------------------===// diff --git a/lib/Dialect/HW/Transforms/CMakeLists.txt b/lib/Dialect/HW/Transforms/CMakeLists.txt index 093632602fa7..f06b5ba5bb13 100644 --- a/lib/Dialect/HW/Transforms/CMakeLists.txt +++ b/lib/Dialect/HW/Transforms/CMakeLists.txt @@ -1,4 +1,5 @@ add_circt_dialect_library(CIRCTHWTransforms + HWAggregateToComb.cpp HWPrintInstanceGraph.cpp HWSpecialize.cpp PrintHWModuleGraph.cpp diff --git a/test/Conversion/CombToAIG/comb-to-aig-arith.mlir b/test/Conversion/CombToAIG/comb-to-aig-arith.mlir index 618ce276e6e6..6cabd5cf32ce 100644 --- a/test/Conversion/CombToAIG/comb-to-aig-arith.mlir +++ b/test/Conversion/CombToAIG/comb-to-aig-arith.mlir @@ -173,3 +173,4 @@ hw.module @shift2(in %lhs: i2, in %rhs: i2, out out_shl: i2, out out_shr: i2, ou // ALLOW_ICMP-NEXT: hw.output %[[L_SHIFT_WITH_BOUND_CHECK]], %[[R_SHIFT_WITH_BOUND_CHECK]], %[[R_SIGNED_SHIFT]] hw.output %0, %1, %2 : i2, i2, i2 } + diff --git a/test/Dialect/HW/hw-aggregate-to-comb.mlir b/test/Dialect/HW/hw-aggregate-to-comb.mlir new file mode 100644 index 000000000000..98c09c151da4 --- /dev/null +++ b/test/Dialect/HW/hw-aggregate-to-comb.mlir @@ -0,0 +1,61 @@ +// RUN: circt-opt %s -hw-aggregate-to-comb | FileCheck %s + + +// CHECK-LABEL: @agg_const +hw.module @agg_const(out out: !hw.array<4xi4>) { + // CHECK: %[[CONST:.+]] = comb.concat %c0_i4, %c1_i4, %c-2_i4, %c-1_i4 : i4, i4, i4, i4 + // CHECK-NEXT: %[[BITCAST:.+]] = hw.bitcast %[[CONST]] : (i16) -> !hw.array<4xi4> + // CHECK-NEXT: hw.output %[[BITCAST]] : !hw.array<4xi4> + %0 = hw.aggregate_constant [0 : i4, 1 : i4, -2 : i4, -1 : i4] : !hw.array<4xi4> + hw.output %0 : !hw.array<4xi4> +} + +// CHECK-LABEL: @array_get_for_port +hw.module @array_get_for_port(in %in: !hw.array<5xi4>, out out: i4) { + %c_i2 = hw.constant 3 : i3 + // CHECK-NEXT: %[[BITCAST_IN:.+]] = hw.bitcast %in : (!hw.array<5xi4>) -> i20 + // CHECK: %[[EXTRACT:.+]] = comb.extract %[[BITCAST_IN]] from 12 : (i20) -> i4 + // CHECK: hw.output %[[EXTRACT]] : i4 + %1 = hw.array_get %in[%c_i2] : !hw.array<5xi4>, i3 + hw.output %1 : i4 +} + +// CHECK-LABEL: @array_concat +hw.module @array_concat(in %lhs: !hw.array<2xi4>, in %rhs: !hw.array<3xi4>, out out: i4) { + %0 = hw.array_concat %lhs, %rhs : !hw.array<2xi4>, !hw.array<3xi4> + %c_i2 = hw.constant 3 : i3 + // CHECK-NEXT: %[[BITCAST_RHS:.+]] = hw.bitcast %rhs : (!hw.array<3xi4>) -> i12 + // CHECK-NEXT: %[[BITCAST_LHS:.+]] = hw.bitcast %lhs : (!hw.array<2xi4>) -> i8 + // CHECK-NEXT: %[[CONCAT:.+]] = comb.concat %[[BITCAST_LHS]], %[[BITCAST_RHS]] : i8, i12 + // CHECK: %[[EXTRACT:.+]] = comb.extract %[[CONCAT]] from 12 : (i20) -> i4 + // CHECK: hw.output %[[EXTRACT]] : i4 + %1 = hw.array_get %0[%c_i2] : !hw.array<5xi4>, i3 + hw.output %1 : i4 +} + +hw.module.extern @foo(in %in: !hw.array<4xi2>, out out: !hw.array<4xi2>) +// CHECK-LABEL: @array_instance( +hw.module @array_instance(in %in: !hw.array<4xi2>, out out: !hw.array<4xi2>) { + // CHECK-NEXT: hw.instance "foo" @foo(in: %in: !hw.array<4xi2>) -> (out: !hw.array<4xi2>) + %0 = hw.instance "foo" @foo(in: %in: !hw.array<4xi2>) -> (out: !hw.array<4xi2>) + hw.output %0 : !hw.array<4xi2> +} + +// CHECK-LABEL: @array( +hw.module @array(in %arg0: i2, in %arg1: i2, in %arg2: i2, in %arg3: i2, out out: !hw.array<4xi2>, in %sel: i2, out out_get: i2) { + %0 = hw.array_create %arg0, %arg1, %arg2, %arg3 : i2 + %1 = hw.array_get %0[%sel] : !hw.array<4xi2>, i2 + // CHECK-NEXT: %[[CONCAT:.+]] = comb.concat %arg0, %arg1, %arg2, %arg3 : i2, i2, i2, i2 + // CHECK-NEXT: %[[BITCAST:.+]] = hw.bitcast %[[CONCAT]] : (i8) -> !hw.array<4xi2> + // CHECK-NEXT: %[[EXTRACT_0:.+]] = comb.extract %[[CONCAT]] from 0 : (i8) -> i2 + // CHECK-NEXT: %[[EXTRACT_2:.+]] = comb.extract %[[CONCAT]] from 2 : (i8) -> i2 + // CHECK-NEXT: %[[EXTRACT_4:.+]] = comb.extract %[[CONCAT]] from 4 : (i8) -> i2 + // CHECK-NEXT: %[[EXTRACT_6:.+]] = comb.extract %[[CONCAT]] from 6 : (i8) -> i2 + // CHECK-NEXT: %[[EXTRACT_SEL:.+]] = comb.extract %sel from 0 + // CHECK-NEXT: %[[EXTRACT_SEL_1:.+]] = comb.extract %sel from 1 + // CHECK-NEXT: %[[MUX_0:.+]] = comb.mux %[[EXTRACT_SEL]], %[[EXTRACT_6]], %[[EXTRACT_4]] + // CHECK-NEXT: %[[MUX_1:.+]] = comb.mux %[[EXTRACT_SEL]], %[[EXTRACT_2]], %[[EXTRACT_0]] + // CHECK-NEXT: %[[MUX_2:.+]] = comb.mux %[[EXTRACT_SEL_1]], %[[MUX_0]], %[[MUX_1]] + // CHECK-NEXT: hw.output %[[BITCAST]], %[[MUX_2]] + hw.output %0, %1 : !hw.array<4xi2>, i2 +} diff --git a/tools/circt-synth/CMakeLists.txt b/tools/circt-synth/CMakeLists.txt index 5ad952b3c81f..59ddad086e84 100644 --- a/tools/circt-synth/CMakeLists.txt +++ b/tools/circt-synth/CMakeLists.txt @@ -7,6 +7,7 @@ target_link_libraries(circt-synth CIRCTComb CIRCTCombToAIG CIRCTHW + CIRCTHWTransforms CIRCTSupport MLIRIR MLIRParser diff --git a/tools/circt-synth/circt-synth.cpp b/tools/circt-synth/circt-synth.cpp index 6f556b71c93a..36c7bed564cc 100644 --- a/tools/circt-synth/circt-synth.cpp +++ b/tools/circt-synth/circt-synth.cpp @@ -18,6 +18,7 @@ #include "circt/Dialect/Comb/CombDialect.h" #include "circt/Dialect/HW/HWDialect.h" #include "circt/Dialect/HW/HWOps.h" +#include "circt/Dialect/HW/HWPasses.h" #include "circt/Support/Passes.h" #include "circt/Support/Version.h" #include "mlir/IR/Diagnostics.h" @@ -107,6 +108,7 @@ static void populateSynthesisPipeline(PassManager &pm) { }); auto &mpm = pm.nest(); + mpm.addPass(circt::hw::createHWAggregateToCombPass()); mpm.addPass(circt::createConvertCombToAIG()); mpm.addPass(createCSEPass()); if (untilReached(UntilAIGLowering))