From c5cff395e74cd38089d418dc4c5a9c44a4894230 Mon Sep 17 00:00:00 2001 From: Hideto Ueno Date: Sat, 18 Jan 2025 00:03:17 -0800 Subject: [PATCH] Address comments --- include/circt/Dialect/HW/Passes.td | 4 +-- .../HW/Transforms/HWAggregateToComb.cpp | 31 ++++++++++--------- test/Dialect/HW/hw-aggregate-to-comb.mlir | 12 +++---- 3 files changed, 23 insertions(+), 24 deletions(-) diff --git a/include/circt/Dialect/HW/Passes.td b/include/circt/Dialect/HW/Passes.td index adab3b271f2d..223f864564e1 100644 --- a/include/circt/Dialect/HW/Passes.td +++ b/include/circt/Dialect/HW/Passes.td @@ -93,8 +93,8 @@ def HWAggregateToComb : Pass<"hw-aggregate-to-comb", "hw::HWModuleOp"> { let description = [{ This pass lowers aggregate *operations* to comb operations within modules. - This pass does not lower ports, as ports are handled by FlattenIO. This pass - will also change the behavior of out-of-bounds access of arrays. + Note that this pass does not lower ports. Ports lowering is handled + by FlattenIO. }]; let dependentDialects = ["comb::CombDialect"]; } diff --git a/lib/Dialect/HW/Transforms/HWAggregateToComb.cpp b/lib/Dialect/HW/Transforms/HWAggregateToComb.cpp index 8a2583d9f16a..b579536d564d 100644 --- a/lib/Dialect/HW/Transforms/HWAggregateToComb.cpp +++ b/lib/Dialect/HW/Transforms/HWAggregateToComb.cpp @@ -11,6 +11,7 @@ #include "circt/Dialect/HW/HWPasses.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/APInt.h" namespace circt { namespace hw { @@ -23,6 +24,8 @@ using namespace mlir; using namespace circt; namespace { + +// Lower hw.array_create and hw.array_concat to comb.concat. template struct HWArrayCreateLikeOpConversion : OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -30,12 +33,7 @@ struct HWArrayCreateLikeOpConversion : OpConversionPattern { LogicalResult matchAndRewrite(OpTy op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - // Lower to concat. - auto inputs = adaptor.getInputs(); - SmallVector results; - for (auto input : inputs) - results.push_back(rewriter.getRemappedValue(input)); - rewriter.replaceOpWithNewOp(op, results); + rewriter.replaceOpWithNewOp(op, adaptor.getInputs()); return success(); } }; @@ -46,9 +44,10 @@ struct HWAggregateConstantOpConversion static LogicalResult peelAttribute(Location loc, Attribute attr, ConversionPatternRewriter &rewriter, - SmallVector &results) { + APInt &intVal) { SmallVector worklist; worklist.push_back(attr); + unsigned nextInsertion = intVal.getBitWidth(); while (!worklist.empty()) { auto current = worklist.pop_back_val(); @@ -59,7 +58,9 @@ struct HWAggregateConstantOpConversion } if (auto intAttr = dyn_cast(current)) { - results.push_back(rewriter.create(loc, intAttr)); + auto chunk = intAttr.getValue(); + nextInsertion -= chunk.getBitWidth(); + intVal.insertBits(chunk, nextInsertion); continue; } @@ -74,10 +75,13 @@ struct HWAggregateConstantOpConversion ConversionPatternRewriter &rewriter) const override { // Lower to concat. SmallVector results; + auto bitWidth = hw::getBitWidth(op.getType()); + assert(bitWidth >= 0 && "bit width must be known for constant"); + APInt intVal(bitWidth, 0); if (failed(peelAttribute(op.getLoc(), adaptor.getFieldsAttr(), rewriter, - results))) + intVal))) return failure(); - rewriter.replaceOpWithNewOp(op, results); + rewriter.replaceOpWithNewOp(op, intVal); return success(); } }; @@ -96,10 +100,7 @@ struct HWArrayGetOpConversion : OpConversionPattern { if (elemWidth < 0) return rewriter.notifyMatchFailure(op.getLoc(), "unknown element width"); - auto lowered = rewriter.getRemappedValue(op.getInput()); - if (!lowered) - return failure(); - + auto lowered = adaptor.getInput(); for (size_t i = 0; i < numElements; ++i) results.push_back(rewriter.createOrFold( op.getLoc(), lowered, i * elemWidth, elemWidth)); @@ -166,7 +167,7 @@ struct HWAggregateToCombPass void HWAggregateToCombPass::runOnOperation() { ConversionTarget target(getContext()); - // TODO: Add structs. + // TODO: Add ArraySliceOp and struct operatons as well. target.addIllegalOp(); diff --git a/test/Dialect/HW/hw-aggregate-to-comb.mlir b/test/Dialect/HW/hw-aggregate-to-comb.mlir index 98c09c151da4..ff89ff7cf694 100644 --- a/test/Dialect/HW/hw-aggregate-to-comb.mlir +++ b/test/Dialect/HW/hw-aggregate-to-comb.mlir @@ -3,7 +3,7 @@ // CHECK-LABEL: @agg_const hw.module @agg_const(out out: !hw.array<4xi4>) { - // CHECK: %[[CONST:.+]] = comb.concat %c0_i4, %c1_i4, %c-2_i4, %c-1_i4 : i4, i4, i4, i4 + // CHECK: %[[CONST:.+]] = hw.constant 495 : i16 // CHECK-NEXT: %[[BITCAST:.+]] = hw.bitcast %[[CONST]] : (i16) -> !hw.array<4xi4> // CHECK-NEXT: hw.output %[[BITCAST]] : !hw.array<4xi4> %0 = hw.aggregate_constant [0 : i4, 1 : i4, -2 : i4, -1 : i4] : !hw.array<4xi4> @@ -21,16 +21,14 @@ hw.module @array_get_for_port(in %in: !hw.array<5xi4>, out out: i4) { } // CHECK-LABEL: @array_concat -hw.module @array_concat(in %lhs: !hw.array<2xi4>, in %rhs: !hw.array<3xi4>, out out: i4) { +hw.module @array_concat(in %lhs: !hw.array<2xi4>, in %rhs: !hw.array<3xi4>, out out: !hw.array<5xi4>) { %0 = hw.array_concat %lhs, %rhs : !hw.array<2xi4>, !hw.array<3xi4> - %c_i2 = hw.constant 3 : i3 // CHECK-NEXT: %[[BITCAST_RHS:.+]] = hw.bitcast %rhs : (!hw.array<3xi4>) -> i12 // CHECK-NEXT: %[[BITCAST_LHS:.+]] = hw.bitcast %lhs : (!hw.array<2xi4>) -> i8 // CHECK-NEXT: %[[CONCAT:.+]] = comb.concat %[[BITCAST_LHS]], %[[BITCAST_RHS]] : i8, i12 - // CHECK: %[[EXTRACT:.+]] = comb.extract %[[CONCAT]] from 12 : (i20) -> i4 - // CHECK: hw.output %[[EXTRACT]] : i4 - %1 = hw.array_get %0[%c_i2] : !hw.array<5xi4>, i3 - hw.output %1 : i4 + // CHECK-NEXT: %[[BITCAST_OUT:.+]] = hw.bitcast %[[CONCAT]] : (i20) -> !hw.array<5xi4> + // CHECK: hw.output %[[BITCAST_OUT]] + hw.output %0 : !hw.array<5xi4> } hw.module.extern @foo(in %in: !hw.array<4xi2>, out out: !hw.array<4xi2>)