Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
uenoku committed Jan 18, 2025
1 parent 547b1c6 commit c5cff39
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 24 deletions.
4 changes: 2 additions & 2 deletions include/circt/Dialect/HW/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,8 @@ def HWAggregateToComb : Pass<"hw-aggregate-to-comb", "hw::HWModuleOp"> {

let description = [{
This pass lowers aggregate *operations* to comb operations within modules.
This pass does not lower ports, as ports are handled by FlattenIO. This pass
will also change the behavior of out-of-bounds access of arrays.
Note that this pass does not lower ports. Ports lowering is handled
by FlattenIO.
}];
let dependentDialects = ["comb::CombDialect"];
}
Expand Down
31 changes: 16 additions & 15 deletions lib/Dialect/HW/Transforms/HWAggregateToComb.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "circt/Dialect/HW/HWPasses.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/APInt.h"

namespace circt {
namespace hw {
Expand All @@ -23,19 +24,16 @@ using namespace mlir;
using namespace circt;

namespace {

// Lower hw.array_create and hw.array_concat to comb.concat.
template <typename OpTy>
struct HWArrayCreateLikeOpConversion : OpConversionPattern<OpTy> {
using OpConversionPattern<OpTy>::OpConversionPattern;
using OpAdaptor = typename OpConversionPattern<OpTy>::OpAdaptor;
LogicalResult
matchAndRewrite(OpTy op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Lower to concat.
auto inputs = adaptor.getInputs();
SmallVector<Value> results;
for (auto input : inputs)
results.push_back(rewriter.getRemappedValue(input));
rewriter.replaceOpWithNewOp<comb::ConcatOp>(op, results);
rewriter.replaceOpWithNewOp<comb::ConcatOp>(op, adaptor.getInputs());
return success();
}
};
Expand All @@ -46,9 +44,10 @@ struct HWAggregateConstantOpConversion

static LogicalResult peelAttribute(Location loc, Attribute attr,
ConversionPatternRewriter &rewriter,
SmallVector<Value> &results) {
APInt &intVal) {
SmallVector<Attribute> worklist;
worklist.push_back(attr);
unsigned nextInsertion = intVal.getBitWidth();

while (!worklist.empty()) {
auto current = worklist.pop_back_val();
Expand All @@ -59,7 +58,9 @@ struct HWAggregateConstantOpConversion
}

if (auto intAttr = dyn_cast<IntegerAttr>(current)) {
results.push_back(rewriter.create<hw::ConstantOp>(loc, intAttr));
auto chunk = intAttr.getValue();
nextInsertion -= chunk.getBitWidth();
intVal.insertBits(chunk, nextInsertion);
continue;
}

Expand All @@ -74,10 +75,13 @@ struct HWAggregateConstantOpConversion
ConversionPatternRewriter &rewriter) const override {
// Lower to concat.
SmallVector<Value> results;
auto bitWidth = hw::getBitWidth(op.getType());
assert(bitWidth >= 0 && "bit width must be known for constant");
APInt intVal(bitWidth, 0);
if (failed(peelAttribute(op.getLoc(), adaptor.getFieldsAttr(), rewriter,
results)))
intVal)))
return failure();
rewriter.replaceOpWithNewOp<comb::ConcatOp>(op, results);
rewriter.replaceOpWithNewOp<hw::ConstantOp>(op, intVal);
return success();
}
};
Expand All @@ -96,10 +100,7 @@ struct HWArrayGetOpConversion : OpConversionPattern<hw::ArrayGetOp> {
if (elemWidth < 0)
return rewriter.notifyMatchFailure(op.getLoc(), "unknown element width");

auto lowered = rewriter.getRemappedValue(op.getInput());
if (!lowered)
return failure();

auto lowered = adaptor.getInput();
for (size_t i = 0; i < numElements; ++i)
results.push_back(rewriter.createOrFold<comb::ExtractOp>(
op.getLoc(), lowered, i * elemWidth, elemWidth));
Expand Down Expand Up @@ -166,7 +167,7 @@ struct HWAggregateToCombPass
void HWAggregateToCombPass::runOnOperation() {
ConversionTarget target(getContext());

// TODO: Add structs.
// TODO: Add ArraySliceOp and struct operatons as well.
target.addIllegalOp<hw::ArrayGetOp, hw::ArrayCreateOp, hw::ArrayConcatOp,
hw::AggregateConstantOp>();

Expand Down
12 changes: 5 additions & 7 deletions test/Dialect/HW/hw-aggregate-to-comb.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

// CHECK-LABEL: @agg_const
hw.module @agg_const(out out: !hw.array<4xi4>) {
// CHECK: %[[CONST:.+]] = comb.concat %c0_i4, %c1_i4, %c-2_i4, %c-1_i4 : i4, i4, i4, i4
// CHECK: %[[CONST:.+]] = hw.constant 495 : i16
// CHECK-NEXT: %[[BITCAST:.+]] = hw.bitcast %[[CONST]] : (i16) -> !hw.array<4xi4>
// CHECK-NEXT: hw.output %[[BITCAST]] : !hw.array<4xi4>
%0 = hw.aggregate_constant [0 : i4, 1 : i4, -2 : i4, -1 : i4] : !hw.array<4xi4>
Expand All @@ -21,16 +21,14 @@ hw.module @array_get_for_port(in %in: !hw.array<5xi4>, out out: i4) {
}

// CHECK-LABEL: @array_concat
hw.module @array_concat(in %lhs: !hw.array<2xi4>, in %rhs: !hw.array<3xi4>, out out: i4) {
hw.module @array_concat(in %lhs: !hw.array<2xi4>, in %rhs: !hw.array<3xi4>, out out: !hw.array<5xi4>) {
%0 = hw.array_concat %lhs, %rhs : !hw.array<2xi4>, !hw.array<3xi4>
%c_i2 = hw.constant 3 : i3
// CHECK-NEXT: %[[BITCAST_RHS:.+]] = hw.bitcast %rhs : (!hw.array<3xi4>) -> i12
// CHECK-NEXT: %[[BITCAST_LHS:.+]] = hw.bitcast %lhs : (!hw.array<2xi4>) -> i8
// CHECK-NEXT: %[[CONCAT:.+]] = comb.concat %[[BITCAST_LHS]], %[[BITCAST_RHS]] : i8, i12
// CHECK: %[[EXTRACT:.+]] = comb.extract %[[CONCAT]] from 12 : (i20) -> i4
// CHECK: hw.output %[[EXTRACT]] : i4
%1 = hw.array_get %0[%c_i2] : !hw.array<5xi4>, i3
hw.output %1 : i4
// CHECK-NEXT: %[[BITCAST_OUT:.+]] = hw.bitcast %[[CONCAT]] : (i20) -> !hw.array<5xi4>
// CHECK: hw.output %[[BITCAST_OUT]]
hw.output %0 : !hw.array<5xi4>
}

hw.module.extern @foo(in %in: !hw.array<4xi2>, out out: !hw.array<4xi2>)
Expand Down

0 comments on commit c5cff39

Please sign in to comment.