diff --git a/runtime/include/brt/core/framework/execution_plan.h b/runtime/include/brt/core/framework/execution_plan.h index b4412debd..a5cd9d7c7 100644 --- a/runtime/include/brt/core/framework/execution_plan.h +++ b/runtime/include/brt/core/framework/execution_plan.h @@ -105,7 +105,7 @@ class ExecutionPlan { * It selects ExecutionProviders based ranking. */ -class StaticBRTExecutionPlan final : public ExecutionPlan { +class StaticBRTExecutionPlan : public ExecutionPlan { public: StaticBRTExecutionPlan(brt::ir::ByREHandle &); @@ -162,4 +162,46 @@ class StaticBRTExecutionPlan final : public ExecutionPlan { std::vector compute_op_kernels_; }; +class MultiStreamExecutionPlan : public StaticBRTExecutionPlan { + +public: + MultiStreamExecutionPlan(brt::ir::ByREHandle &); + + common::Status ProloguePerSession( + const std::unordered_map> + &allocators, + const std::vector> &providers, + const Device dev, const DeviceAPI *device_api) override; + + common::Status EpiloguePerSession() override; + + void CreateWorkQueue(std::unique_ptr *wq, int rank = 0) override; + + void CreateExecutinFrame(std::unique_ptr *frame) override; + + common::Status ProloguePerFrame(const ExecutionContext &) override; + common::Status EpiloguePerFrame(const ExecutionContext &) override; + + common::Status Run(const ExecutionContext &) override; + + using PartitionGraphMethod = std::function; + + void SetPartitionGraphMethod(PartitionGraphMethod method) { + partition_graph_method_ = method; + } + +private: + std::unordered_map kernel_stream_map_; + std::vector> logical_streams_; + std::vector cuda_streams_; + + void PartitionOpKernels(PartitionGraphMethod method); + std::unordered_map kernel_to_event_index; + std::unordered_map> kernel_to_wait_events; + std::vector event_list_; + + void AnalyzeStreamDependency(); + + int max_stream_num_; + } // namespace brt diff --git a/runtime/lib/core/framework/execution_plan.cc b/runtime/lib/core/framework/execution_plan.cc index 1126706e2..ba702f917 100644 --- a/runtime/lib/core/framework/execution_plan.cc +++ b/runtime/lib/core/framework/execution_plan.cc @@ -699,4 +699,91 @@ void StaticBRTExecutionPlan::IterateOpKernels( } } +MultiStreamExecutionPlan::MultiStreamExecutionPlan(ByREHandle &graph) + : StaticBRTExecutionPlan(graph) {} + +void MultiStreamExecutionPlan::PartitionOpKernels(PartitionGraphMethod method) { + int64_t max_stream_id_ = 0; + for (auto kernel : compute_op_kernels_) { + int stream_id = method(kernel); + max_stream_id_ = std::max(max_stream_id_, stream_id); + kernel_stream_map_[kernel] = stream_id; + } + + num_streams_ = max_stream_id_ + 1; + logical_streams_.resize(num_streams_); + cuda_streams_.resize(num_streams_); + + for (auto kernel : compute_op_kernels_) { + int stream_id = kernel_stream_map_[kernel]; + logical_streams_[stream_id].push_back(kernel); + } + + for (int i = 0; i < num_streams_; ++i) { + cudaStreamCreate(&cuda_streams_[i]); + } +} + +void MultiStreamExecutionPlan::AnalyzeStreamDependency() { + for (int i = 0; i < num_streams_; ++i) { + for (auto kernel : logical_streams_[i]) { + for (auto dep_index : kernel->GetDependencyList()) { + auto dep_kernel = op_kernels_[dep_index]; + int dep_stream_id = kernel_stream_map_[dep_kernel]; + if (dep_stream_id != i) { + if (kernel_to_event_index.find(dep_kernel) == + kernel_to_event_index.end()) { + kernel_to_event_index[dep_kernel] = event_list.size(); + event_list.push_back(cudaEvent_t()); + } + if (kernel_to_wait_events.find(kernel) == + kernel_to_wait_events.end()) { + kernel_to_wait_events[kernel] = {}; + } + kernel_to_wait_events[kernel].push_back( + kernel_to_event_index[dep_kernel]); + } + } + } + } +} + +void MultiStreamExecutionPlan::Run(const ExecutionContext &context) { + context.event_listener_manager->SignalEvent( + {}); + + std::vector stream_contexts(num_streams_, context); + for (int i = 0; i < num_streams_; ++i) { + stream_contexts[i].stream = cuda_streams_[i]; + } + + for (auto op : shape_op_kernels_) { + common::Status status = op->Run(context); + if (!status.IsOK()) { + return status; + } + } + + for (auto op : compute_op_kernels_) { + int stream_id = kernel_stream_map_[op]; + if (kernel_to_wait_events.find(op) != kernel_to_wait_events.end()) { + for (auto event_index : kernel_to_wait_events[op]) { + cudaStreamWaitEvent(cuda_streams_[stream_id], event_list[event_index], + 0); + } + } + common::Status status = op->Run(stream_contexts[stream_id]); + if (!status.IsOK()) { + return status; + } + if (kernel_to_event_index.find(op) != kernel_to_event_index.end()) { + cudaEventCreate(&event_list[kernel_to_event_index[op]]); + cudaEventRecord(event_list[kernel_to_event_index[op]], + cuda_streams_[stream_id]); + } + } + + context.event_listener_manager->SignalEvent( + {}); + } // namespace brt