Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Add multi stream execute plan #340

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 43 additions & 1 deletion runtime/include/brt/core/framework/execution_plan.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ class ExecutionPlan {
* It selects ExecutionProviders based ranking.
*/

class StaticBRTExecutionPlan final : public ExecutionPlan {
class StaticBRTExecutionPlan : public ExecutionPlan {
public:
StaticBRTExecutionPlan(brt::ir::ByREHandle &);

Expand Down Expand Up @@ -162,4 +162,46 @@ class StaticBRTExecutionPlan final : public ExecutionPlan {
std::vector<OpKernel *> compute_op_kernels_;
};

class MultiStreamExecutionPlan : public StaticBRTExecutionPlan {

public:
MultiStreamExecutionPlan(brt::ir::ByREHandle &);

common::Status ProloguePerSession(
const std::unordered_map<std::string, std::unique_ptr<IAllocator>>
&allocators,
const std::vector<std::unique_ptr<ExecutionProvider>> &providers,
const Device dev, const DeviceAPI *device_api) override;

common::Status EpiloguePerSession() override;

void CreateWorkQueue(std::unique_ptr<WorkQueue> *wq, int rank = 0) override;

void CreateExecutinFrame(std::unique_ptr<ExecutionFrame> *frame) override;

common::Status ProloguePerFrame(const ExecutionContext &) override;
common::Status EpiloguePerFrame(const ExecutionContext &) override;

common::Status Run(const ExecutionContext &) override;

using PartitionGraphMethod = std::function<int(OpKernel *kernel)>;

void SetPartitionGraphMethod(PartitionGraphMethod method) {
partition_graph_method_ = method;
}

private:
std::unordered_map<OpKernel *, int> kernel_stream_map_;
std::vector<std::vector<OpKernel *>> logical_streams_;
std::vector<cudaStream_t> cuda_streams_;

void PartitionOpKernels(PartitionGraphMethod method);
std::unordered_map<OpKernel *, int64_t> kernel_to_event_index;
std::unordered_map<OpKernel *, std::vector<int64_t>> kernel_to_wait_events;
std::vector<cudaEvent_t> event_list_;

void AnalyzeStreamDependency();

int max_stream_num_;

} // namespace brt
87 changes: 87 additions & 0 deletions runtime/lib/core/framework/execution_plan.cc
Original file line number Diff line number Diff line change
Expand Up @@ -699,4 +699,91 @@ void StaticBRTExecutionPlan::IterateOpKernels(
}
}

MultiStreamExecutionPlan::MultiStreamExecutionPlan(ByREHandle &graph)
: StaticBRTExecutionPlan(graph) {}

void MultiStreamExecutionPlan::PartitionOpKernels(PartitionGraphMethod method) {
int64_t max_stream_id_ = 0;
for (auto kernel : compute_op_kernels_) {
int stream_id = method(kernel);
max_stream_id_ = std::max(max_stream_id_, stream_id);
kernel_stream_map_[kernel] = stream_id;
}

num_streams_ = max_stream_id_ + 1;
logical_streams_.resize(num_streams_);
cuda_streams_.resize(num_streams_);

for (auto kernel : compute_op_kernels_) {
int stream_id = kernel_stream_map_[kernel];
logical_streams_[stream_id].push_back(kernel);
}

for (int i = 0; i < num_streams_; ++i) {
cudaStreamCreate(&cuda_streams_[i]);
}
}

void MultiStreamExecutionPlan::AnalyzeStreamDependency() {
for (int i = 0; i < num_streams_; ++i) {
for (auto kernel : logical_streams_[i]) {
for (auto dep_index : kernel->GetDependencyList()) {
auto dep_kernel = op_kernels_[dep_index];
int dep_stream_id = kernel_stream_map_[dep_kernel];
if (dep_stream_id != i) {
if (kernel_to_event_index.find(dep_kernel) ==
kernel_to_event_index.end()) {
kernel_to_event_index[dep_kernel] = event_list.size();
event_list.push_back(cudaEvent_t());
}
if (kernel_to_wait_events.find(kernel) ==
kernel_to_wait_events.end()) {
kernel_to_wait_events[kernel] = {};
}
kernel_to_wait_events[kernel].push_back(
kernel_to_event_index[dep_kernel]);
}
}
}
}
}

void MultiStreamExecutionPlan::Run(const ExecutionContext &context) {
context.event_listener_manager->SignalEvent<Events::BeforeExecutionPlanRun>(
{});

std::vector<ExecutionContext> stream_contexts(num_streams_, context);
for (int i = 0; i < num_streams_; ++i) {
stream_contexts[i].stream = cuda_streams_[i];
}

for (auto op : shape_op_kernels_) {
common::Status status = op->Run(context);
if (!status.IsOK()) {
return status;
}
}

for (auto op : compute_op_kernels_) {
int stream_id = kernel_stream_map_[op];
if (kernel_to_wait_events.find(op) != kernel_to_wait_events.end()) {
for (auto event_index : kernel_to_wait_events[op]) {
cudaStreamWaitEvent(cuda_streams_[stream_id], event_list[event_index],
0);
}
}
common::Status status = op->Run(stream_contexts[stream_id]);
if (!status.IsOK()) {
return status;
}
if (kernel_to_event_index.find(op) != kernel_to_event_index.end()) {
cudaEventCreate(&event_list[kernel_to_event_index[op]]);
cudaEventRecord(event_list[kernel_to_event_index[op]],
cuda_streams_[stream_id]);
}
}

context.event_listener_manager->SignalEvent<Events::AfterExecutionPlanRun>(
{});

} // namespace brt
Loading