Skip to content

Commit

Permalink
Add a proc-scoped channel tests for scheduling and optimization pipel…
Browse files Browse the repository at this point in the history
…ine.

These are basic checks that the tests do not crash the optimizer/scheduler.

PiperOrigin-RevId: 716802371
  • Loading branch information
meheffernan authored and copybara-github committed Jan 17, 2025
1 parent 61dc280 commit 0a21a7c
Show file tree
Hide file tree
Showing 14 changed files with 230 additions and 27 deletions.
5 changes: 5 additions & 0 deletions xls/ir/proc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -849,6 +849,11 @@ absl::Status Proc::RemoveInterfaceChannel(ChannelReference* channel_ref) {
return absl::OkStatus();
}

bool Proc::IsInterfaceChannel(ChannelReference* channel_ref) {
return std::find(interface_.begin(), interface_.end(), channel_ref) !=
interface_.end();
}

absl::StatusOr<ProcInstantiation*> Proc::AddProcInstantiation(
std::string_view name, absl::Span<ChannelReference* const> channel_args,
Proc* proc) {
Expand Down
4 changes: 4 additions & 0 deletions xls/ir/proc.h
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,10 @@ class Proc : public FunctionBase {
// than `channel_ref` in the interface are shifted down.
absl::Status RemoveInterfaceChannel(ChannelReference* channel_ref);

// Returns true if the given ChannelReference refers to an element of the
// interface of the proc.
bool IsInterfaceChannel(ChannelReference* channel_ref);

// Add an input/output channel to the interface of the proc. Only can be
// called for new style procs.
absl::StatusOr<ReceiveChannelReference*> AddInputChannelReference(
Expand Down
10 changes: 8 additions & 2 deletions xls/ir/proc_elaboration.cc
Original file line number Diff line number Diff line change
Expand Up @@ -280,9 +280,11 @@ absl::Status ProcElaboration::BuildInstanceMaps(ProcInstance* proc_instance) {
std::make_unique<ChannelInstance>(ChannelInstance{
.channel = elaboration.interface_channels_.back().get(),
.path = std::nullopt}));
ChannelInstance* channel_instance =
elaboration.interface_channel_instances_.back().get();
elaboration.interface_channel_instance_set_.insert(channel_instance);
interface_bindings.push_back(ChannelBinding{
.instance = elaboration.interface_channel_instances_.back().get(),
.parent_reference = std::nullopt});
.instance = channel_instance, .parent_reference = std::nullopt});
}
XLS_ASSIGN_OR_RETURN(
elaboration.top_,
Expand Down Expand Up @@ -437,6 +439,10 @@ ProcElaboration::GetInstancesOfChannelReference(
return instances_of_channel_reference_.at(channel_reference);
}

bool ProcElaboration::IsTopInterfaceChannel(ChannelInstance* channel) const {
return interface_channel_instance_set_.contains(channel);
}

absl::StatusOr<ProcInstance*> ProcElaboration::GetUniqueInstance(
Proc* proc) const {
absl::Span<ProcInstance* const> instances = GetInstances(proc);
Expand Down
6 changes: 6 additions & 0 deletions xls/ir/proc_elaboration.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/log/check.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
Expand Down Expand Up @@ -253,6 +254,10 @@ class ProcElaboration {
absl::Span<ChannelInstance* const> GetInstancesOfChannelReference(
ChannelReference* channel_reference) const;

// Returns whether the given channel reference binds to a channel on the top
// interface.
bool IsTopInterfaceChannel(ChannelInstance* channel) const;

// Return the unique instance of the given proc/channel. Returns an error if
// there is not exactly one instance associated with the IR object.
absl::StatusOr<ProcInstance*> GetUniqueInstance(Proc* proc) const;
Expand Down Expand Up @@ -304,6 +309,7 @@ class ProcElaboration {

// Channel instances for the interface channels.
std::vector<std::unique_ptr<ChannelInstance>> interface_channel_instances_;
absl::flat_hash_set<ChannelInstance*> interface_channel_instance_set_;

// All proc instances in the elaboration indexed by instantiation path.
absl::flat_hash_map<ProcInstantiationPath, ProcInstance*>
Expand Down
8 changes: 8 additions & 0 deletions xls/ir/proc_elaboration_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -179,16 +179,24 @@ TEST_F(ElaborationTest, ProcInstantiatingProc) {
EXPECT_THAT(elab.GetChannelInstance("the_ch", "top_proc"),
IsOkAndHolds(elab.top()->channels().front().get()));

EXPECT_TRUE(elab.IsTopInterfaceChannel(
elab.GetChannelInstance("in_ch", "top_proc").value()));
EXPECT_FALSE(elab.IsTopInterfaceChannel(
elab.GetChannelInstance("the_ch", "top_proc").value()));

ProcInstance* leaf_instance = elab.top()->instantiated_procs().front().get();
EXPECT_THAT(elab.GetProcInstance("top_proc::leaf_inst->leaf"),
IsOkAndHolds(leaf_instance));

XLS_ASSERT_OK_AND_ASSIGN(ChannelInstance * leaf_ch0_instance,
leaf_instance->GetChannelInstance("leaf_ch0"));
EXPECT_EQ(leaf_ch0_instance->channel->name(), "the_ch");
EXPECT_FALSE(elab.IsTopInterfaceChannel(leaf_ch0_instance));

XLS_ASSERT_OK_AND_ASSIGN(ChannelInstance * leaf_ch1_instance,
leaf_instance->GetChannelInstance("leaf_ch1"));
EXPECT_EQ(leaf_ch1_instance->channel->name(), "in_ch");
EXPECT_TRUE(elab.IsTopInterfaceChannel(leaf_ch1_instance));

EXPECT_THAT(
elab.GetChannelInstance("leaf_ch0", leaf_instance->path().value()),
Expand Down
2 changes: 2 additions & 0 deletions xls/passes/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -266,10 +266,12 @@ cc_test(
"//xls/examples:sample_packages",
"//xls/ir",
"//xls/ir:bits",
"//xls/ir:channel",
"//xls/ir:function_builder",
"//xls/ir:ir_matcher",
"//xls/ir:ir_test_base",
"//xls/ir:op",
"//xls/ir:value",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
"@com_google_googletest//:gtest",
Expand Down
49 changes: 49 additions & 0 deletions xls/passes/optimization_pass_pipeline_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,15 @@
#include "xls/common/status/matchers.h"
#include "xls/examples/sample_packages.h"
#include "xls/ir/bits.h"
#include "xls/ir/channel.h"
#include "xls/ir/function.h"
#include "xls/ir/function_builder.h"
#include "xls/ir/ir_matcher.h"
#include "xls/ir/ir_test_base.h"
#include "xls/ir/nodes.h"
#include "xls/ir/op.h"
#include "xls/ir/package.h"
#include "xls/ir/value.h"
#include "xls/passes/optimization_pass.h"

namespace m = ::xls::op_matchers;
Expand Down Expand Up @@ -276,5 +278,52 @@ TEST_F(OptimizationPipelineTest, LogicCombining) {
EXPECT_THAT(f->return_value(), m::Param("x"));
}

TEST_F(OptimizationPipelineTest, ProcScopedChannels) {
auto p = CreatePackage();

// Create leaf proc which adds one to its input.
Proc* leaf;
{
TokenlessProcBuilder pb(NewStyleProc(), "myleaf", "tkn", p.get());
XLS_ASSERT_OK_AND_ASSIGN(ReceiveChannelReference * in,
pb.AddInputChannel("in", p->GetBitsType(32)));
XLS_ASSERT_OK_AND_ASSIGN(SendChannelReference * out,
pb.AddOutputChannel("out", p->GetBitsType(32)));

// Create an optimization opportunity (constant folding).
BValue one = pb.Add(pb.Literal(UBits(0, 32)), pb.Literal(UBits(1, 32)));

pb.Send(out, pb.Add(pb.Receive(in), one));
XLS_ASSERT_OK_AND_ASSIGN(leaf, pb.Build());
}

// Create a top proc which instantiates two leaf procs and sends an input
// value through the chain, accumulates it and then sends to the output.
Proc* top;
{
TokenlessProcBuilder pb(NewStyleProc(), "myproc", "tkn", p.get());
XLS_ASSERT_OK_AND_ASSIGN(ReceiveChannelReference * in,
pb.AddInputChannel("in", p->GetBitsType(32)));
XLS_ASSERT_OK_AND_ASSIGN(SendChannelReference * out,
pb.AddOutputChannel("out", p->GetBitsType(32)));

XLS_ASSERT_OK_AND_ASSIGN(ChannelReferences tmp0_ch,
pb.AddChannel("tmp0", p->GetBitsType(32)));
XLS_ASSERT_OK_AND_ASSIGN(ChannelReferences tmp1_ch,
pb.AddChannel("tmp1", p->GetBitsType(32)));

BValue accum = pb.StateElement("accum", Value(UBits(0, 32)));
XLS_ASSERT_OK(pb.InstantiateProc("inst0", leaf, {in, tmp0_ch.send_ref}));
XLS_ASSERT_OK(pb.InstantiateProc("inst1", leaf,
{tmp0_ch.receive_ref, tmp1_ch.send_ref}));
BValue next_accum = pb.Add(pb.Receive(tmp1_ch.receive_ref), accum);
pb.Send(out, next_accum);
XLS_ASSERT_OK(pb.SetAsTop());
XLS_ASSERT_OK_AND_ASSIGN(top, pb.Build({next_accum}));
}

ASSERT_THAT(Run(p.get()), IsOkAndHolds(true));
}

} // namespace
} // namespace xls
3 changes: 3 additions & 0 deletions xls/scheduling/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,7 @@ cc_library(
"//xls/ir:channel_ops",
"//xls/ir:node_util",
"//xls/ir:op",
"//xls/ir:proc_elaboration",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base:log_severity",
"@com_google_absl//absl/container:flat_hash_map",
Expand Down Expand Up @@ -602,6 +603,7 @@ cc_library(
"//xls/common/status:ret_check",
"//xls/common/status:status_macros",
"//xls/ir",
"//xls/ir:proc_elaboration",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/status:statusor",
],
Expand All @@ -626,6 +628,7 @@ cc_test(
"//xls/fdo:synthesizer",
"//xls/ir",
"//xls/ir:bits",
"//xls/ir:channel",
"//xls/ir:channel_ops",
"//xls/ir:foreign_function",
"//xls/ir:foreign_function_data_cc_proto",
Expand Down
3 changes: 1 addition & 2 deletions xls/scheduling/pipeline_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -370,8 +370,7 @@ absl::Status PipelineSchedule::VerifyConstraints(
int64_t last_cycle = 0;
for (Node* node : TopoSort(function_base_)) {
if (node->Is<Receive>() || node->Is<Send>()) {
XLS_ASSIGN_OR_RETURN(Channel * channel, GetChannelUsedByNode(node));
channel_to_nodes[channel->name()].push_back(node);
channel_to_nodes[node->As<ChannelNode>()->channel_name()].push_back(node);
}
last_cycle = std::max(last_cycle, cycle_map_.at(node));

Expand Down
21 changes: 19 additions & 2 deletions xls/scheduling/pipeline_scheduling_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "xls/scheduling/pipeline_scheduling_pass.h"

#include <cstdint>
#include <optional>
#include <utility>
#include <vector>

Expand All @@ -23,6 +24,8 @@
#include "xls/common/status/ret_check.h"
#include "xls/common/status/status_macros.h"
#include "xls/ir/node.h"
#include "xls/ir/proc.h"
#include "xls/ir/proc_elaboration.h"
#include "xls/scheduling/pipeline_schedule.h"
#include "xls/scheduling/run_pipeline_schedule.h"
#include "xls/scheduling/scheduling_options.h"
Expand Down Expand Up @@ -51,6 +54,16 @@ absl::StatusOr<bool> PipelineSchedulingPass::RunInternal(

XLS_ASSIGN_OR_RETURN(std::vector<FunctionBase*> schedulable_functions,
unit->GetSchedulableFunctions());

// Scheduling of procs with proc-scoped channels requires an elaboration.
std::optional<ProcElaboration> elab;
if (unit->GetPackage()->ChannelsAreProcScoped() &&
!schedulable_functions.empty() &&
schedulable_functions.front()->IsProc()) {
XLS_ASSIGN_OR_RETURN(Proc * top, unit->GetPackage()->GetTopAsProc());
XLS_ASSIGN_OR_RETURN(elab, ProcElaboration::Elaborate(top));
}

for (FunctionBase* f : schedulable_functions) {
if (f->ForeignFunctionData().has_value()) {
continue;
Expand All @@ -68,8 +81,12 @@ absl::StatusOr<bool> PipelineSchedulingPass::RunInternal(

XLS_ASSIGN_OR_RETURN(
PipelineSchedule schedule,
RunPipelineSchedule(f, *options.delay_estimator, scheduling_options,
options.synthesizer));
options.synthesizer == nullptr
? RunPipelineSchedule(f, *options.delay_estimator,
scheduling_options, elab)
: RunPipelineScheduleWithFdo(f, *options.delay_estimator,
scheduling_options,
*options.synthesizer, elab));

// Compute `changed` before moving schedule into unit->schedules.
changed = changed || (schedule_cycle_map_before != schedule.GetCycleMap());
Expand Down
51 changes: 51 additions & 0 deletions xls/scheduling/pipeline_scheduling_pass_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include "xls/common/status/status_macros.h"
#include "xls/fdo/synthesizer.h"
#include "xls/ir/bits.h"
#include "xls/ir/channel.h"
#include "xls/ir/channel_ops.h"
#include "xls/ir/foreign_function.h"
#include "xls/ir/foreign_function_data.pb.h"
Expand Down Expand Up @@ -370,5 +371,55 @@ TEST_F(PipelineSchedulingPassTest, MultiProcWithFFI) {
IsOkAndHolds(Pair(true, SchedulingUnitWithElements(UnorderedElementsAre(
Pair(caller, VerifiedPipelineSchedule()))))));
}

TEST_F(PipelineSchedulingPassTest, MultiProcScopedChannels) {
auto p = CreatePackage();

// Create leaf proc which adds one to its input.
Proc* leaf;
{
TokenlessProcBuilder pb(NewStyleProc(), "myleaf", "tkn", p.get());
XLS_ASSERT_OK_AND_ASSIGN(ReceiveChannelReference * in,
pb.AddInputChannel("in", p->GetBitsType(32)));
XLS_ASSERT_OK_AND_ASSIGN(SendChannelReference * out,
pb.AddOutputChannel("out", p->GetBitsType(32)));

// Create an optimization opportunity (constant folding).
BValue one = pb.Add(pb.Literal(UBits(0, 32)), pb.Literal(UBits(1, 32)));

pb.Send(out, pb.Add(pb.Receive(in), one));
XLS_ASSERT_OK_AND_ASSIGN(leaf, pb.Build());
}

// Create a top proc which instantiates two leaf procs and sends an input
// value through the chain, accumulates it and then sends to the output.
Proc* top;
{
TokenlessProcBuilder pb(NewStyleProc(), "myproc", "tkn", p.get());
XLS_ASSERT_OK_AND_ASSIGN(ReceiveChannelReference * in,
pb.AddInputChannel("in", p->GetBitsType(32)));
XLS_ASSERT_OK_AND_ASSIGN(SendChannelReference * out,
pb.AddOutputChannel("out", p->GetBitsType(32)));

XLS_ASSERT_OK_AND_ASSIGN(ChannelReferences tmp0_ch,
pb.AddChannel("tmp0", p->GetBitsType(32)));
XLS_ASSERT_OK_AND_ASSIGN(ChannelReferences tmp1_ch,
pb.AddChannel("tmp1", p->GetBitsType(32)));

BValue accum = pb.StateElement("accum", Value(UBits(0, 32)));
XLS_ASSERT_OK(pb.InstantiateProc("inst0", leaf, {in, tmp0_ch.send_ref}));
XLS_ASSERT_OK(pb.InstantiateProc("inst1", leaf,
{tmp0_ch.receive_ref, tmp1_ch.send_ref}));
BValue next_accum = pb.Add(pb.Receive(tmp1_ch.receive_ref), accum);
pb.Send(out, next_accum);
XLS_ASSERT_OK(pb.SetAsTop());
XLS_ASSERT_OK_AND_ASSIGN(top, pb.Build({next_accum}));
}

XLS_ASSERT_OK(RunPipelineSchedulingPass(
p.get(),
SchedulingOptions().pipeline_stages(2).schedule_all_procs(true)));
}

} // namespace
} // namespace xls
Loading

0 comments on commit 0a21a7c

Please sign in to comment.