From cc7868780355053c528ebb7ec9c1350d3f179d32 Mon Sep 17 00:00:00 2001 From: Andrew Young Date: Fri, 31 Jan 2025 16:23:32 -0800 Subject: [PATCH] [FIRRTL] LowerXMR: process all modules This changes LowerXMR to process all modules instead of just those that are reachable from the top level module. --- lib/Dialect/FIRRTL/Transforms/LowerXMR.cpp | 114 +++++++++++---------- llvm | 2 +- test/Dialect/FIRRTL/lowerXMR.mlir | 23 +++++ 3 files changed, 83 insertions(+), 56 deletions(-) diff --git a/lib/Dialect/FIRRTL/Transforms/LowerXMR.cpp b/lib/Dialect/FIRRTL/Transforms/LowerXMR.cpp index ca89f47bce49..d71e7456c0cf 100644 --- a/lib/Dialect/FIRRTL/Transforms/LowerXMR.cpp +++ b/lib/Dialect/FIRRTL/Transforms/LowerXMR.cpp @@ -300,69 +300,73 @@ class LowerXMRPass : public circt::firrtl::impl::LowerXMRBase { SmallVector publicModules; // Traverse the modules in post order. - for (auto node : llvm::post_order(&instanceGraph)) { - auto module = dyn_cast(*node->getModule()); - if (!module) - continue; - LLVM_DEBUG(llvm::dbgs() - << "Traversing module:" << module.getModuleNameAttr() << "\n"); - moduleStates.insert({module, ModuleState(module)}); + DenseSet visited; + for (auto *root : instanceGraph) { + for (auto *node : llvm::post_order_ext(root, visited)) { + auto module = dyn_cast(*node->getModule()); + if (!module) + continue; + LLVM_DEBUG(llvm::dbgs() << "Traversing module:" + << module.getModuleNameAttr() << "\n"); - if (module.isPublic()) - publicModules.push_back(module); + moduleStates.insert({module, ModuleState(module)}); - auto result = module.walk([&](Operation *op) { - if (transferFunc(op).failed()) - return WalkResult::interrupt(); - return WalkResult::advance(); - }); + if (module.isPublic()) + publicModules.push_back(module); - if (result.wasInterrupted()) - return signalPassFailure(); + auto result = module.walk([&](Operation *op) { + if (transferFunc(op).failed()) + return WalkResult::interrupt(); + return WalkResult::advance(); + }); - // Clear any enabled layers. - module.setLayersAttr(ArrayAttr::get(module.getContext(), {})); - - // Since we walk operations pre-order and not along dataflow edges, - // ref.sub may not be resolvable when we encounter them (they're not just - // unification). This can happen when refs go through an output port or - // input instance result and back into the design. Handle these by walking - // them, resolving what we can, until all are handled or nothing can be - // resolved. - while (!indexingOps.empty()) { - // Grab the set of unresolved ref.sub's. - decltype(indexingOps) worklist; - worklist.swap(indexingOps); - - for (auto op : worklist) { - auto inputEntry = - getRemoteRefSend(op.getInput(), /*errorIfNotFound=*/false); - // If we can't resolve, add back and move on. - if (!inputEntry) - indexingOps.push_back(op); - else - addReachingSendsEntry(op.getResult(), op.getOperation(), - inputEntry); - } - // If nothing was resolved, give up. - if (worklist.size() == indexingOps.size()) { - auto op = worklist.front(); - getRemoteRefSend(op.getInput()); - op.emitError( - "indexing through probe of unknown origin (input probe?)") - .attachNote(op.getInput().getLoc()) - .append("indexing through this reference"); + if (result.wasInterrupted()) return signalPassFailure(); - } - } - // Record all the RefType ports to be removed later. - size_t numPorts = module.getNumPorts(); - for (size_t portNum = 0; portNum < numPorts; ++portNum) - if (isa(module.getPortType(portNum))) { - setPortToRemove(module, portNum, numPorts); + // Clear any enabled layers. + module.setLayersAttr(ArrayAttr::get(module.getContext(), {})); + + // Since we walk operations pre-order and not along dataflow edges, + // ref.sub may not be resolvable when we encounter them (they're not + // just unification). This can happen when refs go through an output + // port or input instance result and back into the design. Handle these + // by walking them, resolving what we can, until all are handled or + // nothing can be resolved. + while (!indexingOps.empty()) { + // Grab the set of unresolved ref.sub's. + decltype(indexingOps) worklist; + worklist.swap(indexingOps); + + for (auto op : worklist) { + auto inputEntry = + getRemoteRefSend(op.getInput(), /*errorIfNotFound=*/false); + // If we can't resolve, add back and move on. + if (!inputEntry) + indexingOps.push_back(op); + else + addReachingSendsEntry(op.getResult(), op.getOperation(), + inputEntry); + } + // If nothing was resolved, give up. + if (worklist.size() == indexingOps.size()) { + auto op = worklist.front(); + getRemoteRefSend(op.getInput()); + op.emitError( + "indexing through probe of unknown origin (input probe?)") + .attachNote(op.getInput().getLoc()) + .append("indexing through this reference"); + return signalPassFailure(); + } } + + // Record all the RefType ports to be removed later. + size_t numPorts = module.getNumPorts(); + for (size_t portNum = 0; portNum < numPorts; ++portNum) + if (isa(module.getPortType(portNum))) { + setPortToRemove(module, portNum, numPorts); + } + } } LLVM_DEBUG({ diff --git a/llvm b/llvm index aa580c2ec5eb..ebc7efbab5c5 160000 --- a/llvm +++ b/llvm @@ -1 +1 @@ -Subproject commit aa580c2ec5eb4217c945a47a561181be7e7b1032 +Subproject commit ebc7efbab5c58b46f7215d63be6d0208cb588192 diff --git a/test/Dialect/FIRRTL/lowerXMR.mlir b/test/Dialect/FIRRTL/lowerXMR.mlir index 1e55aeb1c69e..b32ec08a3627 100644 --- a/test/Dialect/FIRRTL/lowerXMR.mlir +++ b/test/Dialect/FIRRTL/lowerXMR.mlir @@ -796,3 +796,26 @@ firrtl.circuit "Foo" { } } } + +// ----- +// Test that all modules are reached and updated. + +firrtl.circuit "PF" { + // CHECK: @Child() + firrtl.module @Child(out %p: !firrtl.probe>) { + %c1_ui1 = firrtl.constant 1 : !firrtl.uint<1> + %0 = firrtl.ref.send %c1_ui1 : !firrtl.uint<1> + firrtl.ref.define %p, %0 : !firrtl.probe> + } + // CHECK: @PF() + firrtl.module @PF(out %p: !firrtl.probe>) { + %c_p = firrtl.instance c @Child(out p: !firrtl.probe>) + %c_p = firrtl.instance c @Child(out p: !firrtl.probe>) + firrtl.ref.define %p, %c_p : !firrtl.probe> + } + // CHECK: @Other() + firrtl.module @Other(out %p: !firrtl.probe>) { + %c_p = firrtl.instance c @Child(out p: !firrtl.probe>) + firrtl.ref.define %p, %c_p : !firrtl.probe> + } +}