Skip to content

Commit

Permalink
[HW][circt-synth] Implement AggregateToComb pass and add to circt-syinth
Browse files Browse the repository at this point in the history
pipeline
  • Loading branch information
uenoku committed Jan 15, 2025
1 parent 2d87d21 commit b43e6f8
Show file tree
Hide file tree
Showing 11 changed files with 176 additions and 56 deletions.
9 changes: 9 additions & 0 deletions include/circt/Dialect/Comb/CombOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,15 @@ Value createOrFoldNot(Location loc, Value value, OpBuilder &builder,
Value createOrFoldNot(Value value, ImplicitLocOpBuilder &builder,
bool twoState = false);

/// Extract bits from a value.
void extractBits(OpBuilder &builder, Value val, SmallVectorImpl<Value> &bits);

/// Construct a mux tree for given leaf nodes. `selectors` is the selector for
/// each level of the tree. Currently the selector is tested from MSB to LSB.
Value constructMuxTree(OpBuilder &builder, Location loc,
ArrayRef<Value> selectors, ArrayRef<Value> leafNodes,
Value outOfBoundsValue);

} // namespace comb
} // namespace circt

Expand Down
1 change: 1 addition & 0 deletions include/circt/Dialect/HW/HWPasses.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ std::unique_ptr<mlir::Pass> createFlattenIOPass(bool recursiveFlag = true,
std::unique_ptr<mlir::Pass> createVerifyInnerRefNamespacePass();
std::unique_ptr<mlir::Pass> createFlattenModulesPass();
std::unique_ptr<mlir::Pass> createFooWiresPass();
std::unique_ptr<mlir::Pass> createHWAggregateToCombPass();

/// Generate the code for registering passes.
#define GEN_PASS_REGISTRATION
Expand Down
12 changes: 12 additions & 0 deletions include/circt/Dialect/HW/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -87,4 +87,16 @@ def FooWires : Pass<"hw-foo-wires", "hw::HWModuleOp"> {
let constructor = "circt::hw::createFooWiresPass()";
}

def HWAggregateToComb : Pass<"hw-aggregate-to-comb", "hw::HWModuleOp"> {
let summary = "Lower aggregate operations to comb operations";
let constructor = "circt::hw::createHWAggregateToCombPass()";

let description = [{
This pass lowers aggregate *operations* to comb operations within modules.
This pass does not lower ports, as ports are handled by FlattenIO. This pass
will also change the behavior of out-of-bounds access of arrays.
}];
let dependentDialects = ["comb::CombDialect"];
}

#endif // CIRCT_DIALECT_HW_PASSES_TD
18 changes: 17 additions & 1 deletion integration_test/circt-synth/comb-lowering-lec.mlir
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// REQUIRES: libz3
// REQUIRES: circt-lec-jit

// RUN: circt-opt %s --convert-comb-to-aig --convert-aig-to-comb -o %t.mlir
// RUN: circt-opt %s --hw-aggregate-to-comb --convert-comb-to-aig --convert-aig-to-comb -o %t.mlir
// RUN: circt-lec %t.mlir %s -c1=bit_logical -c2=bit_logical --shared-libs=%libz3 | FileCheck %s --check-prefix=COMB_BIT_LOGICAL
// COMB_BIT_LOGICAL: c1 == c2
hw.module @bit_logical(in %arg0: i32, in %arg1: i32, in %arg2: i32, in %arg3: i32,
Expand Down Expand Up @@ -78,3 +78,19 @@ hw.module @shift5(in %lhs: i5, in %rhs: i5, out out_shl: i5, out out_shr: i5, ou
%2 = comb.shrs %lhs, %rhs : i5
hw.output %0, %1, %2 : i5, i5, i5
}

// RUN: circt-lec %t.mlir %s -c1=array -c2=array --shared-libs=%libz3 | FileCheck %s --check-prefix=COMB_ARRAY
// COMB_ARRAY: c1 == c2
hw.module @array(in %arg0: i2, in %arg1: i2, in %arg2: i2, in %arg3: i2, in %sel1: i2, in %sel2: i2, out out1: i2, out out2: i2) {
%0 = hw.array_create %arg0, %arg1, %arg2, %arg3 : i2
%1 = hw.array_get %0[%sel1] : !hw.array<4xi2>, i2
%2 = hw.array_create %arg0, %arg1, %arg2 : i2
%c3_i2 = hw.constant 3 : i2
// NOTE: If the index is out of bounds, the result value is undefined.
// In LEC such value is lowered into unbounded SMT variable and cause
// the LEC to fail. So just asssume that the index is in bounds.
%inbound = comb.icmp ult %sel2, %c3_i2 : i2
verif.assume %inbound : i1
%3 = hw.array_get %2[%sel2] : !hw.array<3xi2>, i2
hw.output %1, %3 : i2, i2
}
71 changes: 16 additions & 55 deletions lib/Conversion/CombToAIG/CombToAIG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,64 +29,13 @@ using namespace comb;
// Utility Functions
//===----------------------------------------------------------------------===//

// Extract individual bits from a value
static SmallVector<Value> extractBits(ConversionPatternRewriter &rewriter,
Value val) {
assert(val.getType().isInteger() && "expected integer");
auto width = val.getType().getIntOrFloatBitWidth();
// A wrapper for comb::extractBits that returns a SmallVector<Value>.
static SmallVector<Value> extractBits(OpBuilder &builder, Value val) {
SmallVector<Value> bits;
bits.reserve(width);

// Check if we can reuse concat operands
if (auto concat = val.getDefiningOp<comb::ConcatOp>()) {
if (concat.getNumOperands() == width &&
llvm::all_of(concat.getOperandTypes(), [](Type type) {
return type.getIntOrFloatBitWidth() == 1;
})) {
// Reverse the operands to match the bit order
bits.append(std::make_reverse_iterator(concat.getOperands().end()),
std::make_reverse_iterator(concat.getOperands().begin()));
return bits;
}
}

// Extract individual bits
for (int64_t i = 0; i < width; ++i)
bits.push_back(
rewriter.createOrFold<comb::ExtractOp>(val.getLoc(), val, i, 1));

comb::extractBits(builder, val, bits);
return bits;
}

// Construct a mux tree for given leaf nodes. `selectors` is the selector for
// each level of the tree. Currently the selector is tested from MSB to LSB.
static Value constructMuxTree(ConversionPatternRewriter &rewriter, Location loc,
ArrayRef<Value> selectors,
ArrayRef<Value> leafNodes,
Value outOfBoundsValue) {
// Recursive helper function to construct the mux tree
std::function<Value(size_t, size_t)> constructTreeHelper =
[&](size_t id, size_t level) -> Value {
// Base case: at the lowest level, return the result
if (level == 0) {
// Return the result for the given index. If the index is out of bounds,
// return the out-of-bound value.
return id < leafNodes.size() ? leafNodes[id] : outOfBoundsValue;
}

auto selector = selectors[level - 1];

// Recursive case: create muxes for true and false branches
auto trueVal = constructTreeHelper(2 * id + 1, level - 1);
auto falseVal = constructTreeHelper(2 * id, level - 1);

// Combine the results with a mux
return rewriter.createOrFold<comb::MuxOp>(loc, selector, trueVal, falseVal);
};

return constructTreeHelper(0, llvm::Log2_64_Ceil(leafNodes.size()));
}

// Construct a mux tree for shift operations. `isLeftShift` controls the
// direction of the shift operation and is used to determine order of the
// padding and extracted bits. Callbacks `getPadding` and `getExtract` are used
Expand Down Expand Up @@ -128,7 +77,8 @@ static Value createShiftLogic(ConversionPatternRewriter &rewriter, Location loc,
assert(outOfBoundsValue && "outOfBoundsValue must be valid");

// Construct mux tree for shift operation
auto result = constructMuxTree(rewriter, loc, bits, nodes, outOfBoundsValue);
auto result =
comb::constructMuxTree(rewriter, loc, bits, nodes, outOfBoundsValue);

// Add bounds checking
auto inBound = rewriter.createOrFold<comb::ICmpOp>(
Expand Down Expand Up @@ -667,10 +617,21 @@ static void populateCombToAIGConversionPatterns(RewritePatternSet &patterns) {

void ConvertCombToAIGPass::runOnOperation() {
ConversionTarget target(getContext());

// Comb is source dialect.
target.addIllegalDialect<comb::CombDialect>();
// Keep data movement operations like Extract, Concat and Replicate.
target.addLegalOp<comb::ExtractOp, comb::ConcatOp, comb::ReplicateOp,
hw::BitcastOp, hw::ConstantOp>();

// Treat array operations as illegal. Strictly speaking, other than array get
// operation with non-const index are legal in AIG but array types prevent a
// bunch of optimizations so just lower them to integer operations. It's
// required to run HWAggregateToComb pass before this pass.
target.addIllegalOp<hw::ArrayGetOp, hw::ArrayCreateOp, hw::ArrayConcatOp,
hw::AggregateConstantOp>();

// AIG is target dialect.
target.addLegalDialect<aig::AIGDialect>();

// This is a test only option to add logical ops.
Expand Down
55 changes: 55 additions & 0 deletions lib/Dialect/Comb/CombOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,61 @@ Value comb::createOrFoldNot(Value value, ImplicitLocOpBuilder &builder,
return createOrFoldNot(builder.getLoc(), value, builder, twoState);
}

// Extract individual bits from a value
void comb::extractBits(OpBuilder &builder, Value val,
SmallVectorImpl<Value> &bits) {
assert(val.getType().isInteger() && "expected integer");
auto width = val.getType().getIntOrFloatBitWidth();
bits.reserve(width);

// Check if we can reuse concat operands
if (auto concat = val.getDefiningOp<comb::ConcatOp>()) {
if (concat.getNumOperands() == width &&
llvm::all_of(concat.getOperandTypes(), [](Type type) {
return type.getIntOrFloatBitWidth() == 1;
})) {
// Reverse the operands to match the bit order
bits.append(std::make_reverse_iterator(concat.getOperands().end()),
std::make_reverse_iterator(concat.getOperands().begin()));
return;
}
}

// Extract individual bits
for (int64_t i = 0; i < width; ++i)
bits.push_back(
builder.createOrFold<comb::ExtractOp>(val.getLoc(), val, i, 1));
}

// Construct a mux tree for given leaf nodes. `selectors` is the selector for
// each level of the tree. Currently the selector is tested from MSB to LSB.
Value comb::constructMuxTree(OpBuilder &builder, Location loc,
ArrayRef<Value> selectors,
ArrayRef<Value> leafNodes,
Value outOfBoundsValue) {
// Recursive helper function to construct the mux tree
std::function<Value(size_t, size_t)> constructTreeHelper =
[&](size_t id, size_t level) -> Value {
// Base case: at the lowest level, return the result
if (level == 0) {
// Return the result for the given index. If the index is out of bounds,
// return the out-of-bound value.
return id < leafNodes.size() ? leafNodes[id] : outOfBoundsValue;
}

auto selector = selectors[level - 1];

// Recursive case: create muxes for true and false branches
auto trueVal = constructTreeHelper(2 * id + 1, level - 1);
auto falseVal = constructTreeHelper(2 * id, level - 1);

// Combine the results with a mux
return builder.createOrFold<comb::MuxOp>(loc, selector, trueVal, falseVal);
};

return constructTreeHelper(0, llvm::Log2_64_Ceil(leafNodes.size()));
}

//===----------------------------------------------------------------------===//
// ICmpOp
//===----------------------------------------------------------------------===//
Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/HW/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
add_circt_dialect_library(CIRCTHWTransforms
HWAggregateToComb.cpp
HWPrintInstanceGraph.cpp
HWSpecialize.cpp
PrintHWModuleGraph.cpp
Expand Down
1 change: 1 addition & 0 deletions test/Conversion/CombToAIG/comb-to-aig-arith.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -173,3 +173,4 @@ hw.module @shift2(in %lhs: i2, in %rhs: i2, out out_shl: i2, out out_shr: i2, ou
// ALLOW_ICMP-NEXT: hw.output %[[L_SHIFT_WITH_BOUND_CHECK]], %[[R_SHIFT_WITH_BOUND_CHECK]], %[[R_SIGNED_SHIFT]]
hw.output %0, %1, %2 : i2, i2, i2
}

61 changes: 61 additions & 0 deletions test/Dialect/HW/hw-aggregate-to-comb.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
// RUN: circt-opt %s -hw-aggregate-to-comb | FileCheck %s


// CHECK-LABEL: @agg_const
hw.module @agg_const(out out: !hw.array<4xi4>) {
// CHECK: %[[CONST:.+]] = comb.concat %c0_i4, %c1_i4, %c-2_i4, %c-1_i4 : i4, i4, i4, i4
// CHECK-NEXT: %[[BITCAST:.+]] = hw.bitcast %[[CONST]] : (i16) -> !hw.array<4xi4>
// CHECK-NEXT: hw.output %[[BITCAST]] : !hw.array<4xi4>
%0 = hw.aggregate_constant [0 : i4, 1 : i4, -2 : i4, -1 : i4] : !hw.array<4xi4>
hw.output %0 : !hw.array<4xi4>
}

// CHECK-LABEL: @array_get_for_port
hw.module @array_get_for_port(in %in: !hw.array<5xi4>, out out: i4) {
%c_i2 = hw.constant 3 : i3
// CHECK-NEXT: %[[BITCAST_IN:.+]] = hw.bitcast %in : (!hw.array<5xi4>) -> i20
// CHECK: %[[EXTRACT:.+]] = comb.extract %[[BITCAST_IN]] from 12 : (i20) -> i4
// CHECK: hw.output %[[EXTRACT]] : i4
%1 = hw.array_get %in[%c_i2] : !hw.array<5xi4>, i3
hw.output %1 : i4
}

// CHECK-LABEL: @array_concat
hw.module @array_concat(in %lhs: !hw.array<2xi4>, in %rhs: !hw.array<3xi4>, out out: i4) {
%0 = hw.array_concat %lhs, %rhs : !hw.array<2xi4>, !hw.array<3xi4>
%c_i2 = hw.constant 3 : i3
// CHECK-NEXT: %[[BITCAST_RHS:.+]] = hw.bitcast %rhs : (!hw.array<3xi4>) -> i12
// CHECK-NEXT: %[[BITCAST_LHS:.+]] = hw.bitcast %lhs : (!hw.array<2xi4>) -> i8
// CHECK-NEXT: %[[CONCAT:.+]] = comb.concat %[[BITCAST_LHS]], %[[BITCAST_RHS]] : i8, i12
// CHECK: %[[EXTRACT:.+]] = comb.extract %[[CONCAT]] from 12 : (i20) -> i4
// CHECK: hw.output %[[EXTRACT]] : i4
%1 = hw.array_get %0[%c_i2] : !hw.array<5xi4>, i3
hw.output %1 : i4
}

hw.module.extern @foo(in %in: !hw.array<4xi2>, out out: !hw.array<4xi2>)
// CHECK-LABEL: @array_instance(
hw.module @array_instance(in %in: !hw.array<4xi2>, out out: !hw.array<4xi2>) {
// CHECK-NEXT: hw.instance "foo" @foo(in: %in: !hw.array<4xi2>) -> (out: !hw.array<4xi2>)
%0 = hw.instance "foo" @foo(in: %in: !hw.array<4xi2>) -> (out: !hw.array<4xi2>)
hw.output %0 : !hw.array<4xi2>
}

// CHECK-LABEL: @array(
hw.module @array(in %arg0: i2, in %arg1: i2, in %arg2: i2, in %arg3: i2, out out: !hw.array<4xi2>, in %sel: i2, out out_get: i2) {
%0 = hw.array_create %arg0, %arg1, %arg2, %arg3 : i2
%1 = hw.array_get %0[%sel] : !hw.array<4xi2>, i2
// CHECK-NEXT: %[[CONCAT:.+]] = comb.concat %arg0, %arg1, %arg2, %arg3 : i2, i2, i2, i2
// CHECK-NEXT: %[[BITCAST:.+]] = hw.bitcast %[[CONCAT]] : (i8) -> !hw.array<4xi2>
// CHECK-NEXT: %[[EXTRACT_0:.+]] = comb.extract %[[CONCAT]] from 0 : (i8) -> i2
// CHECK-NEXT: %[[EXTRACT_2:.+]] = comb.extract %[[CONCAT]] from 2 : (i8) -> i2
// CHECK-NEXT: %[[EXTRACT_4:.+]] = comb.extract %[[CONCAT]] from 4 : (i8) -> i2
// CHECK-NEXT: %[[EXTRACT_6:.+]] = comb.extract %[[CONCAT]] from 6 : (i8) -> i2
// CHECK-NEXT: %[[EXTRACT_SEL:.+]] = comb.extract %sel from 0
// CHECK-NEXT: %[[EXTRACT_SEL_1:.+]] = comb.extract %sel from 1
// CHECK-NEXT: %[[MUX_0:.+]] = comb.mux %[[EXTRACT_SEL]], %[[EXTRACT_6]], %[[EXTRACT_4]]
// CHECK-NEXT: %[[MUX_1:.+]] = comb.mux %[[EXTRACT_SEL]], %[[EXTRACT_2]], %[[EXTRACT_0]]
// CHECK-NEXT: %[[MUX_2:.+]] = comb.mux %[[EXTRACT_SEL_1]], %[[MUX_0]], %[[MUX_1]]
// CHECK-NEXT: hw.output %[[BITCAST]], %[[MUX_2]]
hw.output %0, %1 : !hw.array<4xi2>, i2
}
1 change: 1 addition & 0 deletions tools/circt-synth/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ target_link_libraries(circt-synth
CIRCTComb
CIRCTCombToAIG
CIRCTHW
CIRCTHWTransforms
CIRCTSupport
MLIRIR
MLIRParser
Expand Down
2 changes: 2 additions & 0 deletions tools/circt-synth/circt-synth.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "circt/Dialect/Comb/CombDialect.h"
#include "circt/Dialect/HW/HWDialect.h"
#include "circt/Dialect/HW/HWOps.h"
#include "circt/Dialect/HW/HWPasses.h"
#include "circt/Support/Passes.h"
#include "circt/Support/Version.h"
#include "mlir/IR/Diagnostics.h"
Expand Down Expand Up @@ -107,6 +108,7 @@ static void populateSynthesisPipeline(PassManager &pm) {
});

auto &mpm = pm.nest<hw::HWModuleOp>();
mpm.addPass(circt::hw::createHWAggregateToCombPass());
mpm.addPass(circt::createConvertCombToAIG());
mpm.addPass(createCSEPass());
if (untilReached(UntilAIGLowering))
Expand Down

0 comments on commit b43e6f8

Please sign in to comment.