Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[HW][circt-synth] Implement AggregateToComb pass and add to circt-synth pipeline #8068

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions include/circt/Dialect/Comb/CombOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,15 @@ Value createOrFoldNot(Location loc, Value value, OpBuilder &builder,
Value createOrFoldNot(Value value, ImplicitLocOpBuilder &builder,
bool twoState = false);

/// Extract bits from a value.
void extractBits(OpBuilder &builder, Value val, SmallVectorImpl<Value> &bits);

/// Construct a mux tree for given leaf nodes. `selectors` is the selector for
/// each level of the tree. Currently the selector is tested from MSB to LSB.
Value constructMuxTree(OpBuilder &builder, Location loc,
ArrayRef<Value> selectors, ArrayRef<Value> leafNodes,
Value outOfBoundsValue);

} // namespace comb
} // namespace circt

Expand Down
1 change: 1 addition & 0 deletions include/circt/Dialect/HW/HWPasses.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ std::unique_ptr<mlir::Pass> createFlattenIOPass(bool recursiveFlag = true,
std::unique_ptr<mlir::Pass> createVerifyInnerRefNamespacePass();
std::unique_ptr<mlir::Pass> createFlattenModulesPass();
std::unique_ptr<mlir::Pass> createFooWiresPass();
std::unique_ptr<mlir::Pass> createHWAggregateToCombPass();

/// Generate the code for registering passes.
#define GEN_PASS_REGISTRATION
Expand Down
12 changes: 12 additions & 0 deletions include/circt/Dialect/HW/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -87,4 +87,16 @@ def FooWires : Pass<"hw-foo-wires", "hw::HWModuleOp"> {
let constructor = "circt::hw::createFooWiresPass()";
}

def HWAggregateToComb : Pass<"hw-aggregate-to-comb", "hw::HWModuleOp"> {
let summary = "Lower aggregate operations to comb operations";
let constructor = "circt::hw::createHWAggregateToCombPass()";

let description = [{
This pass lowers aggregate *operations* to comb operations within modules.
This pass does not lower ports, as ports are handled by FlattenIO. This pass
will also change the behavior of out-of-bounds access of arrays.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we mention here how it changes the behavior? If I understand correctly it's just a refinement, so we're all good and wouldn't even need to mention it here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we mention here how it changes the behavior?

Yes. I'll do.

If I understand correctly it's just a refinement

Yes, it's refinement so it should be fine from the semantics perspective but i would at least mention in the doc. Maybe in the future we might want to run post-synthesis simulation on arcilator but only pre-synthesis IR can catch out-of-bounds access in this case.

}];
let dependentDialects = ["comb::CombDialect"];
}

#endif // CIRCT_DIALECT_HW_PASSES_TD
18 changes: 17 additions & 1 deletion integration_test/circt-synth/comb-lowering-lec.mlir
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// REQUIRES: libz3
// REQUIRES: circt-lec-jit

// RUN: circt-opt %s --convert-comb-to-aig --convert-aig-to-comb -o %t.mlir
// RUN: circt-opt %s --hw-aggregate-to-comb --convert-comb-to-aig --convert-aig-to-comb -o %t.mlir
// RUN: circt-lec %t.mlir %s -c1=bit_logical -c2=bit_logical --shared-libs=%libz3 | FileCheck %s --check-prefix=COMB_BIT_LOGICAL
// COMB_BIT_LOGICAL: c1 == c2
hw.module @bit_logical(in %arg0: i32, in %arg1: i32, in %arg2: i32, in %arg3: i32,
Expand Down Expand Up @@ -78,3 +78,19 @@ hw.module @shift5(in %lhs: i5, in %rhs: i5, out out_shl: i5, out out_shr: i5, ou
%2 = comb.shrs %lhs, %rhs : i5
hw.output %0, %1, %2 : i5, i5, i5
}

// RUN: circt-lec %t.mlir %s -c1=array -c2=array --shared-libs=%libz3 | FileCheck %s --check-prefix=COMB_ARRAY
// COMB_ARRAY: c1 == c2
hw.module @array(in %arg0: i2, in %arg1: i2, in %arg2: i2, in %arg3: i2, in %sel1: i2, in %sel2: i2, out out1: i2, out out2: i2) {
%0 = hw.array_create %arg0, %arg1, %arg2, %arg3 : i2
%1 = hw.array_get %0[%sel1] : !hw.array<4xi2>, i2
%2 = hw.array_create %arg0, %arg1, %arg2 : i2
%c3_i2 = hw.constant 3 : i2
// NOTE: If the index is out of bounds, the result value is undefined.
// In LEC such value is lowered into unbounded SMT variable and cause
// the LEC to fail. So just asssume that the index is in bounds.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, for these kind of things, a translation validation tool would be quite nice 🙂

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah without circt-lec it has been almost impossible to implement lowering pattern with confidence for the correctness. Thank you @maerhart @frog-in-the-well @TaoBi22 for outstanding work!

%inbound = comb.icmp ult %sel2, %c3_i2 : i2
verif.assume %inbound : i1
%3 = hw.array_get %2[%sel2] : !hw.array<3xi2>, i2
hw.output %1, %3 : i2, i2
}
71 changes: 16 additions & 55 deletions lib/Conversion/CombToAIG/CombToAIG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,64 +29,13 @@ using namespace comb;
// Utility Functions
//===----------------------------------------------------------------------===//

// Extract individual bits from a value
static SmallVector<Value> extractBits(ConversionPatternRewriter &rewriter,
Value val) {
assert(val.getType().isInteger() && "expected integer");
auto width = val.getType().getIntOrFloatBitWidth();
// A wrapper for comb::extractBits that returns a SmallVector<Value>.
static SmallVector<Value> extractBits(OpBuilder &builder, Value val) {
SmallVector<Value> bits;
bits.reserve(width);

// Check if we can reuse concat operands
if (auto concat = val.getDefiningOp<comb::ConcatOp>()) {
if (concat.getNumOperands() == width &&
llvm::all_of(concat.getOperandTypes(), [](Type type) {
return type.getIntOrFloatBitWidth() == 1;
})) {
// Reverse the operands to match the bit order
bits.append(std::make_reverse_iterator(concat.getOperands().end()),
std::make_reverse_iterator(concat.getOperands().begin()));
return bits;
}
}

// Extract individual bits
for (int64_t i = 0; i < width; ++i)
bits.push_back(
rewriter.createOrFold<comb::ExtractOp>(val.getLoc(), val, i, 1));

comb::extractBits(builder, val, bits);
return bits;
}

// Construct a mux tree for given leaf nodes. `selectors` is the selector for
// each level of the tree. Currently the selector is tested from MSB to LSB.
static Value constructMuxTree(ConversionPatternRewriter &rewriter, Location loc,
ArrayRef<Value> selectors,
ArrayRef<Value> leafNodes,
Value outOfBoundsValue) {
// Recursive helper function to construct the mux tree
std::function<Value(size_t, size_t)> constructTreeHelper =
[&](size_t id, size_t level) -> Value {
// Base case: at the lowest level, return the result
if (level == 0) {
// Return the result for the given index. If the index is out of bounds,
// return the out-of-bound value.
return id < leafNodes.size() ? leafNodes[id] : outOfBoundsValue;
}

auto selector = selectors[level - 1];

// Recursive case: create muxes for true and false branches
auto trueVal = constructTreeHelper(2 * id + 1, level - 1);
auto falseVal = constructTreeHelper(2 * id, level - 1);

// Combine the results with a mux
return rewriter.createOrFold<comb::MuxOp>(loc, selector, trueVal, falseVal);
};

return constructTreeHelper(0, llvm::Log2_64_Ceil(leafNodes.size()));
}

// Construct a mux tree for shift operations. `isLeftShift` controls the
// direction of the shift operation and is used to determine order of the
// padding and extracted bits. Callbacks `getPadding` and `getExtract` are used
Expand Down Expand Up @@ -128,7 +77,8 @@ static Value createShiftLogic(ConversionPatternRewriter &rewriter, Location loc,
assert(outOfBoundsValue && "outOfBoundsValue must be valid");

// Construct mux tree for shift operation
auto result = constructMuxTree(rewriter, loc, bits, nodes, outOfBoundsValue);
auto result =
comb::constructMuxTree(rewriter, loc, bits, nodes, outOfBoundsValue);

// Add bounds checking
auto inBound = rewriter.createOrFold<comb::ICmpOp>(
Expand Down Expand Up @@ -667,10 +617,21 @@ static void populateCombToAIGConversionPatterns(RewritePatternSet &patterns) {

void ConvertCombToAIGPass::runOnOperation() {
ConversionTarget target(getContext());

// Comb is source dialect.
target.addIllegalDialect<comb::CombDialect>();
// Keep data movement operations like Extract, Concat and Replicate.
target.addLegalOp<comb::ExtractOp, comb::ConcatOp, comb::ReplicateOp,
hw::BitcastOp, hw::ConstantOp>();

// Treat array operations as illegal. Strictly speaking, other than array get
// operation with non-const index are legal in AIG but array types prevent a
// bunch of optimizations so just lower them to integer operations. It's
// required to run HWAggregateToComb pass before this pass.
target.addIllegalOp<hw::ArrayGetOp, hw::ArrayCreateOp, hw::ArrayConcatOp,
hw::AggregateConstantOp>();

// AIG is target dialect.
target.addLegalDialect<aig::AIGDialect>();

// This is a test only option to add logical ops.
Expand Down
55 changes: 55 additions & 0 deletions lib/Dialect/Comb/CombOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,61 @@ Value comb::createOrFoldNot(Value value, ImplicitLocOpBuilder &builder,
return createOrFoldNot(builder.getLoc(), value, builder, twoState);
}

// Extract individual bits from a value
void comb::extractBits(OpBuilder &builder, Value val,
SmallVectorImpl<Value> &bits) {
assert(val.getType().isInteger() && "expected integer");
auto width = val.getType().getIntOrFloatBitWidth();
bits.reserve(width);

// Check if we can reuse concat operands
if (auto concat = val.getDefiningOp<comb::ConcatOp>()) {
if (concat.getNumOperands() == width &&
llvm::all_of(concat.getOperandTypes(), [](Type type) {
return type.getIntOrFloatBitWidth() == 1;
})) {
// Reverse the operands to match the bit order
bits.append(std::make_reverse_iterator(concat.getOperands().end()),
std::make_reverse_iterator(concat.getOperands().begin()));
return;
}
}

// Extract individual bits
for (int64_t i = 0; i < width; ++i)
bits.push_back(
builder.createOrFold<comb::ExtractOp>(val.getLoc(), val, i, 1));
}

// Construct a mux tree for given leaf nodes. `selectors` is the selector for
// each level of the tree. Currently the selector is tested from MSB to LSB.
Value comb::constructMuxTree(OpBuilder &builder, Location loc,
ArrayRef<Value> selectors,
ArrayRef<Value> leafNodes,
Value outOfBoundsValue) {
// Recursive helper function to construct the mux tree
std::function<Value(size_t, size_t)> constructTreeHelper =
[&](size_t id, size_t level) -> Value {
// Base case: at the lowest level, return the result
if (level == 0) {
// Return the result for the given index. If the index is out of bounds,
// return the out-of-bound value.
return id < leafNodes.size() ? leafNodes[id] : outOfBoundsValue;
}

auto selector = selectors[level - 1];

// Recursive case: create muxes for true and false branches
auto trueVal = constructTreeHelper(2 * id + 1, level - 1);
auto falseVal = constructTreeHelper(2 * id, level - 1);

// Combine the results with a mux
return builder.createOrFold<comb::MuxOp>(loc, selector, trueVal, falseVal);
};

return constructTreeHelper(0, llvm::Log2_64_Ceil(leafNodes.size()));
}

//===----------------------------------------------------------------------===//
// ICmpOp
//===----------------------------------------------------------------------===//
Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/HW/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
add_circt_dialect_library(CIRCTHWTransforms
HWAggregateToComb.cpp
HWPrintInstanceGraph.cpp
HWSpecialize.cpp
PrintHWModuleGraph.cpp
Expand Down
Loading
Loading