From 25b59fde9a7e85b06f840a68df5fe89a39a1a92f Mon Sep 17 00:00:00 2001 From: zyfncg Date: Fri, 9 Jun 2023 03:29:16 +0000 Subject: [PATCH 01/66] add tensor_interface class --- cinn/hlir/framework/tensor_interface.h | 40 +++++++++++++++++++++ cinn/hlir/framework/tensor_interface_list.h | 30 ++++++++++++++++ 2 files changed, 70 insertions(+) create mode 100644 cinn/hlir/framework/tensor_interface.h create mode 100644 cinn/hlir/framework/tensor_interface_list.h diff --git a/cinn/hlir/framework/tensor_interface.h b/cinn/hlir/framework/tensor_interface.h new file mode 100644 index 0000000000..843cd0cd3e --- /dev/null +++ b/cinn/hlir/framework/tensor_interface.h @@ -0,0 +1,40 @@ +// Copyright (c) 2023 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +namespace cinn { +namespace hlir { +namespace framework { + +class ShapeInterface; + +class TensorInterface { + public: + // Get the shape of tensor. 
+ virtual const ShapeInterface& shape() const = 0; + + protected: + TensorInterface() = default; + TensorInterface(const TensorInterface&) = delete; + TensorInterface(TensorInterface&&) = delete; +}; + +using TensorInterfacePtr = std::shared_ptr; + +} // namespace framework +} // namespace hlir +} // namespace cinn diff --git a/cinn/hlir/framework/tensor_interface_list.h b/cinn/hlir/framework/tensor_interface_list.h new file mode 100644 index 0000000000..ddd0c221e4 --- /dev/null +++ b/cinn/hlir/framework/tensor_interface_list.h @@ -0,0 +1,30 @@ +// Copyright (c) 2023 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include + +#include "cinn/hlir/framework/tensor_interface.h" +#include "cinn/utils/small_vector.h" + +namespace cinn { +namespace hlir { +namespace framework { + +using TensorInterfaceList = cinn::utils::SmallVector; + +} // namespace framework +} // namespace hlir +} // namespace cinn From dee0847792f4d8b5c0b0b71331854d77449e3108 Mon Sep 17 00:00:00 2001 From: zyfncg Date: Fri, 9 Jun 2023 04:00:23 +0000 Subject: [PATCH 02/66] add '+=' for TensorInterfaceList --- cinn/hlir/framework/tensor_interface_list.h | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/cinn/hlir/framework/tensor_interface_list.h b/cinn/hlir/framework/tensor_interface_list.h index ddd0c221e4..244461c8cd 100644 --- a/cinn/hlir/framework/tensor_interface_list.h +++ b/cinn/hlir/framework/tensor_interface_list.h @@ -15,6 +15,7 @@ #pragma once #include +#include #include "cinn/hlir/framework/tensor_interface.h" #include "cinn/utils/small_vector.h" @@ -25,6 +26,21 @@ namespace framework { using TensorInterfaceList = cinn::utils::SmallVector; +class TensorInterfaceList : public cinn::utils::SmallVector { + public: + using cinn::utils::SmallVector::SmallVector; + + TensorInterfaceList& operator+=(const TensorInterfaceList& other) { + std::unordered_set tensor_set(this->begin(), this->end()); + for (const auto& tensor_if : other) { + if (tensor_set.find(tensor_if) == tensor_set.end()) { + this->push_back(tensor_if); + tensor_set.insert(tensor_if); + } + } + } +}; + } // namespace framework } // namespace hlir } // namespace cinn From 6c58e0d37dfa6a100960cd3bf77ba587a64887eb Mon Sep 17 00:00:00 2001 From: zyfncg Date: Fri, 9 Jun 2023 04:07:25 +0000 Subject: [PATCH 03/66] polish code --- cinn/hlir/framework/tensor_interface_list.h | 2 -- 1 file changed, 2 deletions(-) diff --git a/cinn/hlir/framework/tensor_interface_list.h b/cinn/hlir/framework/tensor_interface_list.h index 244461c8cd..a378f7a732 100644 --- a/cinn/hlir/framework/tensor_interface_list.h +++ 
b/cinn/hlir/framework/tensor_interface_list.h @@ -24,8 +24,6 @@ namespace cinn { namespace hlir { namespace framework { -using TensorInterfaceList = cinn::utils::SmallVector; - class TensorInterfaceList : public cinn::utils::SmallVector { public: using cinn::utils::SmallVector::SmallVector; From 37ce34b1ad84e044b153d46c257cb342ce153e08 Mon Sep 17 00:00:00 2001 From: zyfncg Date: Fri, 9 Jun 2023 04:13:44 +0000 Subject: [PATCH 04/66] add OpGroupInterface class --- cinn/hlir/framework/tensor_interface_list.h | 1 + 1 file changed, 1 insertion(+) diff --git a/cinn/hlir/framework/tensor_interface_list.h b/cinn/hlir/framework/tensor_interface_list.h index a378f7a732..beb3f32877 100644 --- a/cinn/hlir/framework/tensor_interface_list.h +++ b/cinn/hlir/framework/tensor_interface_list.h @@ -36,6 +36,7 @@ class TensorInterfaceList : public cinn::utils::SmallVector Date: Fri, 9 Jun 2023 04:14:14 +0000 Subject: [PATCH 05/66] add OpGroupInterface class --- cinn/hlir/framework/op_group_interface.h | 30 ++++++++++++++++++++++++ 1 file changed, 30 insertions(+) create mode 100644 cinn/hlir/framework/op_group_interface.h diff --git a/cinn/hlir/framework/op_group_interface.h b/cinn/hlir/framework/op_group_interface.h new file mode 100644 index 0000000000..52ba5f1005 --- /dev/null +++ b/cinn/hlir/framework/op_group_interface.h @@ -0,0 +1,30 @@ +// Copyright (c) 2023 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include + +#include "cinn/hlir/framework/tensor_interface.h" + +namespace cinn { +namespace hlir { +namespace framework { + +class OpGroupInterface {}; + +} // namespace framework +} // namespace hlir +} // namespace cinn From a1c2b701708fbe4216b0d63e3334d4b9f62e5905 Mon Sep 17 00:00:00 2001 From: jiahongyu Date: Fri, 9 Jun 2023 05:27:40 +0000 Subject: [PATCH 06/66] Leave function parameter unchanged --- cinn/hlir/framework/graph.h | 17 +-- cinn/hlir/framework/op_lowering.cc | 2 - cinn/hlir/pass/fusion_merge_pass.cc | 165 +++++++++++++++------------- cinn/hlir/pass/op_fusion_pass.cc | 18 +-- 4 files changed, 101 insertions(+), 101 deletions(-) diff --git a/cinn/hlir/framework/graph.h b/cinn/hlir/framework/graph.h index 484d551508..f041930742 100644 --- a/cinn/hlir/framework/graph.h +++ b/cinn/hlir/framework/graph.h @@ -24,6 +24,7 @@ #include "cinn/common/graph_utils.h" #include "cinn/frontend/syntax.h" #include "cinn/hlir/framework/node.h" +#include "cinn/hlir/framework/tensor_interface_list.h" namespace cinn { namespace hlir { @@ -79,24 +80,14 @@ class Graph : public cinn::common::Graph { // master node for schedule std::unordered_set master_nodes; - struct SharedGroupHasher { - size_t operator()(const std::shared_ptr& group) const noexcept { - return std::hash()(reinterpret_cast(group.get())); - } - }; - struct SharedGroupComparator { - bool operator()(const std::shared_ptr& first, const std::shared_ptr& second) const noexcept { - return first.get() == second.get(); - } - }; // input groups - std::unordered_set, SharedGroupHasher, SharedGroupComparator> producer_groups; + std::unordered_map, std::shared_ptr> producer_groups; // output grous - std::unordered_set, SharedGroupHasher, SharedGroupComparator> consumer_groups; + std::unordered_map, std::shared_ptr> consumer_groups; // fused sub-groups, used for fusion merge pass std::vector> fused_sub_groups; // if as sub-group, used for belong groups. 
- std::unordered_set, SharedGroupHasher, SharedGroupComparator> belong_groups; + std::unordered_set> belong_groups; // for op lowering. std::vector input_names; diff --git a/cinn/hlir/framework/op_lowering.cc b/cinn/hlir/framework/op_lowering.cc index 9a213242a1..ee47c4a7ca 100644 --- a/cinn/hlir/framework/op_lowering.cc +++ b/cinn/hlir/framework/op_lowering.cc @@ -41,8 +41,6 @@ using common::GraphNode; using common::Type; using namespace lang; -using Comparator = Graph::Group::SharedGroupComparator; -using Hasher = Graph::Group::SharedGroupHasher; using cinn::hlir::op::ExternalApiRegistry; OpLowerer::OpLowerer(const absl::flat_hash_map& type_dict, diff --git a/cinn/hlir/pass/fusion_merge_pass.cc b/cinn/hlir/pass/fusion_merge_pass.cc index 7a63e12c4b..4907fe4515 100644 --- a/cinn/hlir/pass/fusion_merge_pass.cc +++ b/cinn/hlir/pass/fusion_merge_pass.cc @@ -29,11 +29,9 @@ using framework::shape_t; using common::GraphEdge; using common::GraphNode; -using Comparator = Graph::Group::SharedGroupComparator; -using Hasher = Graph::Group::SharedGroupHasher; - using GroupPtr = std::shared_ptr; using GroupList = std::vector; +using GroupIter = std::unordered_map, std::shared_ptr>::iterator; using ConditionFunction = std::function; @@ -61,10 +59,12 @@ class FusionMergePassHelper : public FusionHelperBase { for (auto& sub_group : group->fused_sub_groups) { VLOG(3) << " Fused Sub-Group -> " << sub_group->group_id; } - for (auto& producer : group->producer_groups) { + for (const auto& pair : group->producer_groups) { + const auto& producer = pair.first; VLOG(3) << " Producer -> " << producer->group_id; } - for (auto& consumer : group->consumer_groups) { + for (const auto& pair : group->consumer_groups) { + const auto& consumer = pair.first; VLOG(3) << " Consumer -> " << consumer->group_id; } } @@ -130,7 +130,7 @@ class FusionMergePassHelper : public FusionHelperBase { void UpdateFusionGroup() { VLOG(3) << "UpdateFusionGroup..."; GroupList fusion_groups; - std::unordered_set 
fusion_groups_set; + std::unordered_set fusion_groups_set; // update fusion_groups_ for (auto& group : fusion_groups_) { if (!group->belong_groups.size()) { @@ -150,7 +150,8 @@ class FusionMergePassHelper : public FusionHelperBase { } bool exist = false; - for (auto& producer : group->producer_groups) { + for (const auto& pair : group->producer_groups) { + const auto& producer = pair.first; if (fusion_groups_set.count(producer)) { VLOG(4) << group->group_id << " " << producer->group_id; exist = true; @@ -173,13 +174,13 @@ class FusionMergePassHelper : public FusionHelperBase { } } - bool HorizontalFusion(GroupPtr producer, std::unordered_set& consumers) { + bool HorizontalFusion(GroupPtr producer, std::unordered_set& consumers) { VLOG(3) << "HorizontalFusion...!"; if (consumers.size() <= 1) { return false; } - std::unordered_set candidates; + std::unordered_set candidates; for (auto& consumer : consumers) { // relation auto& relation = fusion_relation_map_[consumer->op_pattern_kind]; @@ -249,7 +250,7 @@ class FusionMergePassHelper : public FusionHelperBase { auto fused_group = std::make_shared(); // As recompute exist which may case sub-group used by more than one time. std::vector repeat_sub_groups; - std::unordered_set sub_group_set; + std::unordered_set sub_group_set; // find the first consumer. GroupPtr first_consumer(nullptr); // fuse all group into fusion group. 
@@ -315,18 +316,20 @@ class FusionMergePassHelper : public FusionHelperBase { fused_group->fused_sub_groups.push_back(consumer); } // producer group - for (auto& producer : consumer->producer_groups) { - fused_group->producer_groups.insert(producer); + for (const auto& producer_and_list : consumer->producer_groups) { + *(fused_group->producer_groups[producer_and_list.first]) += *producer_and_list.second; // update producer's consumer - producer->consumer_groups.erase(consumer); - producer->consumer_groups.insert(fused_group); + producer_and_list.first->consumer_groups.erase(consumer); + // TODO: Do not add any TensorInterface into any TensorInterfaceList in this file which will be deprecated. + *(producer_and_list.first->consumer_groups[fused_group]) += {}; } // consumer group - for (auto& gconsumer : consumer->consumer_groups) { - fused_group->consumer_groups.insert(gconsumer); + for (const auto& gconsumer_and_list : consumer->consumer_groups) { + *(fused_group->consumer_groups[gconsumer_and_list.first]) += *gconsumer_and_list.second; // update consumer's producer - gconsumer->producer_groups.erase(consumer); - gconsumer->producer_groups.insert(fused_group); + gconsumer_and_list.first->producer_groups.erase(consumer); + // TODO: Do not add any TensorInterface into any TensorInterfaceList in this file which will be deprecated. 
+ *(gconsumer_and_list.first->producer_groups[fused_group]) += {}; } // belongs group consumer->belong_groups.insert(fused_group); @@ -384,7 +387,7 @@ class FusionMergePassHelper : public FusionHelperBase { CHECK(fused_group->output_nodes.size()) << "No output node is found, " << fused_group->group_id; } - bool VerticalFusion(GroupPtr& producer, std::unordered_set& consumers, bool recompute) { + bool VerticalFusion(GroupPtr& producer, std::unordered_set& consumers, bool recompute) { VLOG(3) << "VerticalFusion, Number of Consumers : " << consumers.size(); auto& relation = fusion_relation_map_[producer->op_pattern_kind]; // if producer can't fuse others @@ -392,8 +395,8 @@ class FusionMergePassHelper : public FusionHelperBase { return false; } - std::unordered_set fuse_consumers_unsafe; - std::unordered_set fuse_consumers; + std::unordered_set fuse_consumers_unsafe; + std::unordered_set fuse_consumers; for (auto& consumer : consumers) { VLOG(4) << "Check consuemr " << consumer->group_id << " can fuse to producer " << producer->group_id; // if can't fuse @@ -456,7 +459,7 @@ class FusionMergePassHelper : public FusionHelperBase { return false; } - void VerticalFuse(GroupPtr& producer, std::unordered_set& fusionable_consumers) { + void VerticalFuse(GroupPtr& producer, std::unordered_set& fusionable_consumers) { VLOG(3) << "VerticalFuse...!"; GroupList fused_groups; GroupPtr master_fuesd_group(nullptr); @@ -499,11 +502,12 @@ class FusionMergePassHelper : public FusionHelperBase { } // producer groups - for (auto& group : producer->producer_groups) { - fused_group->producer_groups.insert(group); + for (const auto& group_and_list : producer->producer_groups) { + *(fused_group->producer_groups[group_and_list.first]) += *group_and_list.second; // update producer's producer's consumer - group->consumer_groups.erase(producer); - group->consumer_groups.insert(fused_group); + group_and_list.first->consumer_groups.erase(producer); + // TODO: Do not add any TensorInterface into 
any TensorInterfaceList in this file which will be deprecated. + *(group_and_list.first->consumer_groups[fused_group]) += {}; } // sub groups @@ -549,20 +553,23 @@ class FusionMergePassHelper : public FusionHelperBase { } // producer nodes - for (auto& group : consumer->producer_groups) { - if (group.get() != producer.get()) { - fused_group->producer_groups.insert(group); + for (const auto& group_and_list : consumer->producer_groups) { + if (group_and_list.first.get() != producer.get()) { + *(fused_group->producer_groups[group_and_list.first]) += *group_and_list.second; // update consumer's producer's consumer - group->consumer_groups.erase(consumer); - group->consumer_groups.insert(fused_group); + group_and_list.first->consumer_groups.erase(consumer); + // TODO: Do not add any TensorInterface into any TensorInterfaceList in this file which will be deprecated. + *(group_and_list.first->consumer_groups[fused_group]) += {}; } } + // consumer nodes - for (auto& group : consumer->consumer_groups) { - fused_group->consumer_groups.insert(group); + for (const auto& group_and_list : consumer->consumer_groups) { + *(fused_group->consumer_groups[group_and_list.first]) += *group_and_list.second; // update consumer's consumer's producer - group->producer_groups.erase(consumer); - group->producer_groups.insert(fused_group); + group_and_list.first->producer_groups.erase(consumer); + // TODO: Do not add any TensorInterface into any TensorInterfaceList in this file which will be deprecated. + *(group_and_list.first->producer_groups[fused_group]) += {}; } // sub group @@ -596,16 +603,16 @@ class FusionMergePassHelper : public FusionHelperBase { for (auto& node : producer->output_nodes) { bool be_output = true; - for (auto& consumer : producer->consumer_groups) { + for (const auto& consumer_and_list : producer->consumer_groups) { // if consumer is in fusionable. 
- if (fusionable_consumers.count(consumer)) { - if (consumer->input_nodes.count(node)) { + if (fusionable_consumers.count(consumer_and_list.first)) { + if (consumer_and_list.first->input_nodes.count(node)) { be_output = false; } continue; } // if consumer is not in fusionable. - if (consumer->input_nodes.count(node)) { + if (consumer_and_list.first->input_nodes.count(node)) { be_output = true; break; } @@ -622,29 +629,28 @@ class FusionMergePassHelper : public FusionHelperBase { } } // insert unfusionable consumer groups - for (auto& consumer : producer->consumer_groups) { - if (fusionable_consumers.count(consumer)) { + for (const auto& consumer_and_list : producer->consumer_groups) { + if (fusionable_consumers.count(consumer_and_list.first)) { continue; } - master_fuesd_group->consumer_groups.insert(consumer); + *(master_fuesd_group->consumer_groups[consumer_and_list.first]) += *consumer_and_list.second; // update consumer's producer - consumer->producer_groups.erase(producer); - consumer->producer_groups.insert(master_fuesd_group); + consumer_and_list.first->producer_groups.erase(producer); + // TODO: Do not add any TensorInterface into any TensorInterfaceList in this file which will be deprecated. + *(consumer_and_list.first->producer_groups[master_fuesd_group]) += {}; } } - void RecomputeEleGraph(const GroupPtr& producer, - std::unordered_set& fusionable_consumers) { + void RecomputeEleGraph(const GroupPtr& producer, std::unordered_set& fusionable_consumers) { if (producer->op_pattern_kind != framework::kElementWise) { SelectConsumerToFuse(producer, fusionable_consumers); } } - void SelectConsumerToFuse(const GroupPtr& producer, - std::unordered_set& fusionable_consumers) { + void SelectConsumerToFuse(const GroupPtr& producer, std::unordered_set& fusionable_consumers) { // if is const op if (is_const_group(this, producer)) { - std::unordered_set candidates; + std::unordered_set candidates; for (auto& consumer : fusionable_consumers) { // if can be output node. 
if (is_same_shape(this, producer, consumer)) { @@ -714,7 +720,7 @@ class FusionMergePassHelper : public FusionHelperBase { fusionable_consumers.insert(*candidates.begin()); } } else { - std::unordered_set candidates; + std::unordered_set candidates; for (auto& consumer : fusionable_consumers) { if (consumer->op_pattern_kind == framework::kElementWise) { candidates.insert(consumer); @@ -739,24 +745,24 @@ class FusionMergePassHelper : public FusionHelperBase { bool IsDependency(const GroupPtr& producer_g, const GroupPtr& consumer, - const std::unordered_set& consumers) { + const std::unordered_set& consumers) { std::queue candidates; candidates.push(consumer); - std::unordered_set visited_set; + std::unordered_set visited_set; while (!candidates.empty()) { auto& candidate = candidates.front(); candidates.pop(); - for (auto& producer : candidate->producer_groups) { - if (producer.get() == producer_g.get()) { + for (const auto& producer_and_list : candidate->producer_groups) { + if (producer_and_list.first.get() == producer_g.get()) { continue; } - if (consumers.count(producer)) { + if (consumers.count(producer_and_list.first)) { return true; } - if (!visited_set.count(producer)) { - visited_set.insert(producer); - candidates.push(producer); + if (!visited_set.count(producer_and_list.first)) { + visited_set.insert(producer_and_list.first); + candidates.push(producer_and_list.first); } } } @@ -765,28 +771,28 @@ class FusionMergePassHelper : public FusionHelperBase { bool IsDependencySimplify(const GroupPtr& producer_g, const GroupPtr& consumer, - const std::unordered_set& consumers) { + const std::unordered_set& consumers) { std::queue candidates; candidates.push(consumer); // check upper. int check_upper_depth = producer_g.get() ? 
producer_g->max_depth : INT_MAX; - std::unordered_set visited_set; + std::unordered_set visited_set; while (!candidates.empty()) { auto& candidate = candidates.front(); candidates.pop(); - for (auto& producer : candidate->producer_groups) { - if (producer.get() == producer_g.get()) { + for (auto& producer_and_list : candidate->producer_groups) { + if (producer_and_list.first.get() == producer_g.get()) { continue; } - if (producer->min_depth > check_upper_depth) { + if (producer_and_list.first->min_depth > check_upper_depth) { continue; } - if (consumers.count(producer)) { + if (consumers.count(producer_and_list.first)) { return true; } - if (!visited_set.count(producer)) { - visited_set.insert(producer); - candidates.push(producer); + if (!visited_set.count(producer_and_list.first)) { + visited_set.insert(producer_and_list.first); + candidates.push(producer_and_list.first); } } } @@ -818,7 +824,7 @@ class FusionMergePassHelper : public FusionHelperBase { void UpdateInputToConsumers() { for (auto& input_consumers : input_to_consumers_) { auto& consumers = input_consumers.second; - std::unordered_set updated_consumers; + std::unordered_set updated_consumers; for (auto& consumer : consumers) { std::queue fused_groups; fused_groups.push(consumer); @@ -887,16 +893,19 @@ class FusionMergePassHelper : public FusionHelperBase { // update producer and consumer. for (auto& group : fusion_groups_) { - std::unordered_set producers; - std::unordered_set consumers; + std::unordered_map> producers; + std::unordered_map> consumers; - for (auto& producer : group->producer_groups) { - CHECK(producer->belong_groups.size()); - producers.insert(*producer->belong_groups.begin()); + for (auto& producer_and_list : group->producer_groups) { + CHECK(producer_and_list.first->belong_groups.size()); + // TODO: Do not add any TensorInterface into any TensorInterfaceList in this file which will be deprecated. 
+ *(producers[*producer_and_list.first->belong_groups.begin()]) += {}; } - for (auto& consumer : group->consumer_groups) { - CHECK(consumer->belong_groups.size()); - consumers.insert(*consumer->belong_groups.begin()); + + for (auto& consumer_and_list : group->consumer_groups) { + CHECK(consumer_and_list.first->belong_groups.size()); + // TODO: Do not add any TensorInterface into any TensorInterfaceList in this file which will be deprecated. + *(consumers[*consumer_and_list.first->belong_groups.begin()]) += {}; } CHECK_EQ(group->producer_groups.size(), producers.size()); CHECK_EQ(group->consumer_groups.size(), consumers.size()); @@ -996,8 +1005,8 @@ class FusionMergePassHelper : public FusionHelperBase { } GroupList fusion_groups_; - std::unordered_map fusion_groups_index_; - std::unordered_map> input_to_consumers_; + std::unordered_map fusion_groups_index_; + std::unordered_map> input_to_consumers_; struct Relation { std::unordered_map vertical_relation; diff --git a/cinn/hlir/pass/op_fusion_pass.cc b/cinn/hlir/pass/op_fusion_pass.cc index 23fa6bcb14..749da7d2c6 100644 --- a/cinn/hlir/pass/op_fusion_pass.cc +++ b/cinn/hlir/pass/op_fusion_pass.cc @@ -100,16 +100,18 @@ class OpFusionPassHelper : public FusionHelperBase { for (auto& consumer : fusion_groups) { for (auto& input_node : consumer->input_nodes) { auto& producer = fusion_groups_[input_node.first]; - consumer->producer_groups.insert(producer); - producer->consumer_groups.insert(consumer); + // TODO: Do not add any TensorInterface into any TensorInterfaceList in this file which will be deprecated. + consumer->producer_groups[producer] += {}; + // TODO: Do not add any TensorInterface into any TensorInterfaceList in this file which will be deprecated. + producer->consumer_groups[consumer] += {}; } } // init group depth. for (auto& group : fusion_groups) { - for (auto& consumer : group->consumer_groups) { + for (const auto& consumer_and_list : group->consumer_groups) { // update depth. 
- group->depth = std::max(group->depth, consumer->depth + 1); + group->depth = std::max(group->depth, consumer_and_list.first->depth + 1); } } @@ -348,11 +350,11 @@ void OpFusionPassInternal(Graph* graph) { for (auto& group : graph->fusion_groups) { VLOG(3) << "Group Id : " << group->group_id; - for (auto& producer : group->producer_groups) { - VLOG(3) << " producer group -> " << producer->group_id; + for (const auto& producer_and_list : group->producer_groups) { + VLOG(3) << " producer group -> " << producer_and_list.first->group_id; } - for (auto& consumer : group->consumer_groups) { - VLOG(3) << " consumer group -> " << consumer->group_id; + for (const auto& consumer_and_list : group->consumer_groups) { + VLOG(3) << " consumer group -> " << consumer_and_list.first->group_id; } } VLOG(3) << "OpFusionPass Finish...!"; From 530b98082c5be6fbd1d29fc0cb273c1265b55966 Mon Sep 17 00:00:00 2001 From: jiahongyu Date: Fri, 9 Jun 2023 05:45:47 +0000 Subject: [PATCH 07/66] Fix vertical/horizontal fuse funcion parameter --- cinn/hlir/framework/graph.h | 8 ++++++++ cinn/hlir/pass/fusion_merge_pass.cc | 15 +++++++-------- cinn/hlir/pass/op_fusion_pass.cc | 4 ++-- 3 files changed, 17 insertions(+), 10 deletions(-) diff --git a/cinn/hlir/framework/graph.h b/cinn/hlir/framework/graph.h index f041930742..27289bd352 100644 --- a/cinn/hlir/framework/graph.h +++ b/cinn/hlir/framework/graph.h @@ -93,6 +93,14 @@ class Graph : public cinn::common::Graph { std::vector input_names; std::vector output_names; + std::unordered_set> CollectConsumerGroups() { + std::unordered_set> groups; + for (const auto& consumer_and_list : consumer_groups) { + groups.insert(consumer_and_list.first); + } + return groups; + } + std::vector CollectNodes() { if (fused_sub_groups.size()) { std::vector tmp_nodes; diff --git a/cinn/hlir/pass/fusion_merge_pass.cc b/cinn/hlir/pass/fusion_merge_pass.cc index 4907fe4515..0ae5608967 100644 --- a/cinn/hlir/pass/fusion_merge_pass.cc +++ 
b/cinn/hlir/pass/fusion_merge_pass.cc @@ -31,7 +31,6 @@ using common::GraphNode; using GroupPtr = std::shared_ptr; using GroupList = std::vector; -using GroupIter = std::unordered_map, std::shared_ptr>::iterator; using ConditionFunction = std::function; @@ -93,7 +92,7 @@ class FusionMergePassHelper : public FusionHelperBase { continue; } // do horizontal fusion. - updated |= HorizontalFusion(producer, producer->consumer_groups); + updated |= HorizontalFusion(producer, producer->CollectConsumerGroups()); } if (updated) { @@ -114,9 +113,9 @@ class FusionMergePassHelper : public FusionHelperBase { } // do horizontal fusion. if (!recompute) { - updated |= HorizontalFusion(producer, producer->consumer_groups); + updated |= HorizontalFusion(producer, producer->CollectConsumerGroups()); } - updated |= VerticalFusion(producer, producer->consumer_groups, recompute); + updated |= VerticalFusion(producer, producer->CollectConsumerGroups(), recompute); } // fuse input consumers updated |= FuseInputToConsumers(); @@ -174,14 +173,14 @@ class FusionMergePassHelper : public FusionHelperBase { } } - bool HorizontalFusion(GroupPtr producer, std::unordered_set& consumers) { + bool HorizontalFusion(GroupPtr producer, const std::unordered_set& consumers) { VLOG(3) << "HorizontalFusion...!"; if (consumers.size() <= 1) { return false; } std::unordered_set candidates; - for (auto& consumer : consumers) { + for (const auto& consumer : consumers) { // relation auto& relation = fusion_relation_map_[consumer->op_pattern_kind]; // check horizontal relation exist @@ -387,7 +386,7 @@ class FusionMergePassHelper : public FusionHelperBase { CHECK(fused_group->output_nodes.size()) << "No output node is found, " << fused_group->group_id; } - bool VerticalFusion(GroupPtr& producer, std::unordered_set& consumers, bool recompute) { + bool VerticalFusion(GroupPtr& producer, const std::unordered_set& consumers, bool recompute) { VLOG(3) << "VerticalFusion, Number of Consumers : " << consumers.size(); 
auto& relation = fusion_relation_map_[producer->op_pattern_kind]; // if producer can't fuse others @@ -397,7 +396,7 @@ class FusionMergePassHelper : public FusionHelperBase { std::unordered_set fuse_consumers_unsafe; std::unordered_set fuse_consumers; - for (auto& consumer : consumers) { + for (const auto& consumer : consumers) { VLOG(4) << "Check consuemr " << consumer->group_id << " can fuse to producer " << producer->group_id; // if can't fuse if (!relation.vertical_relation.count(consumer->op_pattern_kind)) { diff --git a/cinn/hlir/pass/op_fusion_pass.cc b/cinn/hlir/pass/op_fusion_pass.cc index 749da7d2c6..d56385ebcb 100644 --- a/cinn/hlir/pass/op_fusion_pass.cc +++ b/cinn/hlir/pass/op_fusion_pass.cc @@ -101,9 +101,9 @@ class OpFusionPassHelper : public FusionHelperBase { for (auto& input_node : consumer->input_nodes) { auto& producer = fusion_groups_[input_node.first]; // TODO: Do not add any TensorInterface into any TensorInterfaceList in this file which will be deprecated. - consumer->producer_groups[producer] += {}; + *(consumer->producer_groups[producer]) += {}; // TODO: Do not add any TensorInterface into any TensorInterfaceList in this file which will be deprecated. 
- producer->consumer_groups[consumer] += {}; + *(producer->consumer_groups[consumer]) += {}; } } From 8cd38d72053d9b2ed2fe86acdd5979232928d257 Mon Sep 17 00:00:00 2001 From: jiahongyu Date: Fri, 9 Jun 2023 06:27:15 +0000 Subject: [PATCH 08/66] fix shared_ptr bug --- cinn/hlir/framework/graph.h | 4 ++-- cinn/hlir/pass/fusion_merge_pass.cc | 32 ++++++++++++++--------------- cinn/hlir/pass/op_fusion_pass.cc | 4 ++-- 3 files changed, 20 insertions(+), 20 deletions(-) diff --git a/cinn/hlir/framework/graph.h b/cinn/hlir/framework/graph.h index 27289bd352..4e7a9f7d1b 100644 --- a/cinn/hlir/framework/graph.h +++ b/cinn/hlir/framework/graph.h @@ -81,9 +81,9 @@ class Graph : public cinn::common::Graph { std::unordered_set master_nodes; // input groups - std::unordered_map, std::shared_ptr> producer_groups; + std::unordered_map, TensorInterfaceList> producer_groups; // output grous - std::unordered_map, std::shared_ptr> consumer_groups; + std::unordered_map, TensorInterfaceList> consumer_groups; // fused sub-groups, used for fusion merge pass std::vector> fused_sub_groups; // if as sub-group, used for belong groups. diff --git a/cinn/hlir/pass/fusion_merge_pass.cc b/cinn/hlir/pass/fusion_merge_pass.cc index 0ae5608967..dd7508d299 100644 --- a/cinn/hlir/pass/fusion_merge_pass.cc +++ b/cinn/hlir/pass/fusion_merge_pass.cc @@ -316,19 +316,19 @@ class FusionMergePassHelper : public FusionHelperBase { } // producer group for (const auto& producer_and_list : consumer->producer_groups) { - *(fused_group->producer_groups[producer_and_list.first]) += *producer_and_list.second; + fused_group->producer_groups[producer_and_list.first] += producer_and_list.second; // update producer's consumer producer_and_list.first->consumer_groups.erase(consumer); // TODO: Do not add any TensorInterface into any TensorInterfaceList in this file which will be deprecated. 
- *(producer_and_list.first->consumer_groups[fused_group]) += {}; + producer_and_list.first->consumer_groups[fused_group] += {}; } // consumer group for (const auto& gconsumer_and_list : consumer->consumer_groups) { - *(fused_group->consumer_groups[gconsumer_and_list.first]) += *gconsumer_and_list.second; + fused_group->consumer_groups[gconsumer_and_list.first] += gconsumer_and_list.second; // update consumer's producer gconsumer_and_list.first->producer_groups.erase(consumer); // TODO: Do not add any TensorInterface into any TensorInterfaceList in this file which will be deprecated. - *(gconsumer_and_list.first->producer_groups[fused_group]) += {}; + gconsumer_and_list.first->producer_groups[fused_group] += {}; } // belongs group consumer->belong_groups.insert(fused_group); @@ -502,11 +502,11 @@ class FusionMergePassHelper : public FusionHelperBase { // producer groups for (const auto& group_and_list : producer->producer_groups) { - *(fused_group->producer_groups[group_and_list.first]) += *group_and_list.second; + fused_group->producer_groups[group_and_list.first] += group_and_list.second; // update producer's producer's consumer group_and_list.first->consumer_groups.erase(producer); // TODO: Do not add any TensorInterface into any TensorInterfaceList in this file which will be deprecated. 
- *(group_and_list.first->consumer_groups[fused_group]) += {}; + group_and_list.first->consumer_groups[fused_group] += {}; } // sub groups @@ -554,21 +554,21 @@ class FusionMergePassHelper : public FusionHelperBase { // producer nodes for (const auto& group_and_list : consumer->producer_groups) { if (group_and_list.first.get() != producer.get()) { - *(fused_group->producer_groups[group_and_list.first]) += *group_and_list.second; + fused_group->producer_groups[group_and_list.first] += group_and_list.second; // update consumer's producer's consumer group_and_list.first->consumer_groups.erase(consumer); // TODO: Do not add any TensorInterface into any TensorInterfaceList in this file which will be deprecated. - *(group_and_list.first->consumer_groups[fused_group]) += {}; + group_and_list.first->consumer_groups[fused_group] += {}; } } // consumer nodes for (const auto& group_and_list : consumer->consumer_groups) { - *(fused_group->consumer_groups[group_and_list.first]) += *group_and_list.second; + fused_group->consumer_groups[group_and_list.first] += group_and_list.second; // update consumer's consumer's producer group_and_list.first->producer_groups.erase(consumer); // TODO: Do not add any TensorInterface into any TensorInterfaceList in this file which will be deprecated. - *(group_and_list.first->producer_groups[fused_group]) += {}; + group_and_list.first->producer_groups[fused_group] += {}; } // sub group @@ -632,11 +632,11 @@ class FusionMergePassHelper : public FusionHelperBase { if (fusionable_consumers.count(consumer_and_list.first)) { continue; } - *(master_fuesd_group->consumer_groups[consumer_and_list.first]) += *consumer_and_list.second; + master_fuesd_group->consumer_groups[consumer_and_list.first] += consumer_and_list.second; // update consumer's producer consumer_and_list.first->producer_groups.erase(producer); // TODO: Do not add any TensorInterface into any TensorInterfaceList in this file which will be deprecated. 
- *(consumer_and_list.first->producer_groups[master_fuesd_group]) += {}; + consumer_and_list.first->producer_groups[master_fuesd_group] += {}; } } @@ -892,19 +892,19 @@ class FusionMergePassHelper : public FusionHelperBase { // update producer and consumer. for (auto& group : fusion_groups_) { - std::unordered_map> producers; - std::unordered_map> consumers; + std::unordered_map producers; + std::unordered_map consumers; for (auto& producer_and_list : group->producer_groups) { CHECK(producer_and_list.first->belong_groups.size()); // TODO: Do not add any TensorInterface into any TensorInterfaceList in this file which will be deprecated. - *(producers[*producer_and_list.first->belong_groups.begin()]) += {}; + producers[*producer_and_list.first->belong_groups.begin()] += {}; } for (auto& consumer_and_list : group->consumer_groups) { CHECK(consumer_and_list.first->belong_groups.size()); // TODO: Do not add any TensorInterface into any TensorInterfaceList in this file which will be deprecated. - *(consumers[*consumer_and_list.first->belong_groups.begin()]) += {}; + consumers[*consumer_and_list.first->belong_groups.begin()] += {}; } CHECK_EQ(group->producer_groups.size(), producers.size()); CHECK_EQ(group->consumer_groups.size(), consumers.size()); diff --git a/cinn/hlir/pass/op_fusion_pass.cc b/cinn/hlir/pass/op_fusion_pass.cc index d56385ebcb..749da7d2c6 100644 --- a/cinn/hlir/pass/op_fusion_pass.cc +++ b/cinn/hlir/pass/op_fusion_pass.cc @@ -101,9 +101,9 @@ class OpFusionPassHelper : public FusionHelperBase { for (auto& input_node : consumer->input_nodes) { auto& producer = fusion_groups_[input_node.first]; // TODO: Do not add any TensorInterface into any TensorInterfaceList in this file which will be deprecated. - *(consumer->producer_groups[producer]) += {}; + consumer->producer_groups[producer] += {}; // TODO: Do not add any TensorInterface into any TensorInterfaceList in this file which will be deprecated. 
- *(producer->consumer_groups[consumer]) += {}; + producer->consumer_groups[consumer] += {}; } } From 73fce39bec480df8a8bfee6d96f651d8144da9d9 Mon Sep 17 00:00:00 2001 From: zyfncg Date: Fri, 9 Jun 2023 08:09:49 +0000 Subject: [PATCH 09/66] add FusePassContext --- cinn/hlir/framework/fuse_pass_context.h | 40 ++++++++++++++++++++++++ cinn/hlir/framework/op_group_interface.h | 16 +++++++++- 2 files changed, 55 insertions(+), 1 deletion(-) create mode 100644 cinn/hlir/framework/fuse_pass_context.h diff --git a/cinn/hlir/framework/fuse_pass_context.h b/cinn/hlir/framework/fuse_pass_context.h new file mode 100644 index 0000000000..2be1b4e907 --- /dev/null +++ b/cinn/hlir/framework/fuse_pass_context.h @@ -0,0 +1,40 @@ +// Copyright (c) 2023 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "cinn/hlir/framework/op_group_interface.h" + +namespace cinn { +namespace hlir { +namespace framework { + + +class FusePassContext { + public: + FusePassContext() = default; + + std::shared_ptr PickGroup(); + + void EnableRecompute(const OpGroupInterface& op_group); + + void EnableVerticalFuse(const OpGroupInterface& first_op_group, const OpGroupInterface& second_op_group); + + void EnableHorizontalFuse(const OpGroupInterface& first_op_group, const OpGroupInterface& second_op_group); + +}; + +} // namespace framework +} // namespace hlir +} // namespace cinn diff --git a/cinn/hlir/framework/op_group_interface.h b/cinn/hlir/framework/op_group_interface.h index 52ba5f1005..a7ef09d46d 100644 --- a/cinn/hlir/framework/op_group_interface.h +++ b/cinn/hlir/framework/op_group_interface.h @@ -18,12 +18,26 @@ #include #include "cinn/hlir/framework/tensor_interface.h" +#include "cinn/hlir/framework/tensor_interface_list.h" namespace cinn { namespace hlir { namespace framework { -class OpGroupInterface {}; + +class OpGroupInterface { + public: + virtual const TensorInterfaceList& input_tensors() const = 0; + + virtual const TensorInterfaceList& output_tensors() const = 0; + + virtual const std::unordered_set> producers() const = 0; + + virtual const std::unordered_set> consumers() const = 0; + + protect: + OpGroupInterface() = default; +}; } // namespace framework } // namespace hlir From a6d741392949cb1fa08cf1a98f2a2458c7e31800 Mon Sep 17 00:00:00 2001 From: jiahongyu Date: Mon, 12 Jun 2023 06:38:01 +0000 Subject: [PATCH 10/66] Add topology relevant algorithms: dfs,bfs,scc,topo,is_reachable --- cinn/common/CMakeLists.txt | 4 + cinn/common/bfs_visitor.h | 69 ++++++++++++ cinn/common/dfs_visitor.h | 90 ++++++++++++++++ cinn/common/dfs_visitor_test.cc | 70 ++++++++++++ cinn/common/is_reachable_predicator.h | 73 +++++++++++++ cinn/common/is_reachable_predicator_test.cc | 36 +++++++ cinn/common/scc_visitor.h | 89 ++++++++++++++++ 
cinn/common/scc_visitor_test.cc | 111 ++++++++++++++++++++ cinn/common/topo_visitor.h | 78 ++++++++++++++ cinn/common/topo_visitor_test.cc | 48 +++++++++ 10 files changed, 668 insertions(+) create mode 100644 cinn/common/bfs_visitor.h create mode 100644 cinn/common/dfs_visitor.h create mode 100644 cinn/common/dfs_visitor_test.cc create mode 100644 cinn/common/is_reachable_predicator.h create mode 100644 cinn/common/is_reachable_predicator_test.cc create mode 100644 cinn/common/scc_visitor.h create mode 100644 cinn/common/scc_visitor_test.cc create mode 100644 cinn/common/topo_visitor.h create mode 100644 cinn/common/topo_visitor_test.cc diff --git a/cinn/common/CMakeLists.txt b/cinn/common/CMakeLists.txt index f45e281296..568ff2a3e9 100644 --- a/cinn/common/CMakeLists.txt +++ b/cinn/common/CMakeLists.txt @@ -22,6 +22,10 @@ gather_srcs(cinnapi_src SRCS message(STATUS "srcs: ${cinnapi_src}") +cc_test(test_dfs_visitor SRCS dfs_visitor_test.cc DEPS gtest glog) +cc_test(test_is_reachable_predicator SRCS is_reachable_predicator_test.cc DEPS gtest glog) +cc_test(test_scc_visitor SRCS scc_visitor_test.cc DEPS gtest glog) +cc_test(test_topo_visitor SRCS topo_visitor_test.cc DEPS gtest glog) cc_test(test_cinn_value SRCS cinn_value_test.cc DEPS cinncore) cc_test(test_shared SRCS shared_test.cc DEPS cinncore) cc_test(test_graph_utils SRCS graph_utils_test.cc DEPS cinncore) diff --git a/cinn/common/bfs_visitor.h b/cinn/common/bfs_visitor.h new file mode 100644 index 0000000000..032191dff0 --- /dev/null +++ b/cinn/common/bfs_visitor.h @@ -0,0 +1,69 @@ +// Copyright (c) 2023 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +#include "llvm/ADT/SmallVector.h" + +namespace cinn { +namespace common { + +// breadth-first search visitor +template +class BfsVisitor final { + public: + BfsVisitor(const BfsVisitor&) = delete; + BfsVisitor(BfsVisitor&&) = delete; + + using NodeHandlerType = std::function; + using NodesVisitorType = std::function; + + BfsVisitor(const NodesVisitorType& VisitNextNodes) : VisitNextNodes_(VisitNextNodes) {} + + void operator()(NodeType node, const NodeHandlerType& NodeHandler) const { + llvm::SmallVector nodes{node}; + (*this)(nodes.begin(), nodes.end(), NodeHandler); + } + + template + void operator()(NodeIt begin, NodeIt end, const NodeHandlerType& NodeHandler) const { + std::queue node_queue; + std::unordered_set queued_nodes; + const auto& TryEnqueueNode = [&](NodeType node) { + if (queued_nodes.count(node) == 0) { + node_queue.push(node); + queued_nodes.insert(node); + } + }; + for (NodeIt iter = begin; iter != end; ++iter) { + TryEnqueueNode(*iter); + } + while (!node_queue.empty()) { + NodeType node = node_queue.front(); + node_queue.pop(); + NodeHandler(node); + VisitNextNodes_(node, TryEnqueueNode); + } + } + + private: + NodesVisitorType VisitNextNodes_; +}; + +} // namespace common +} // namespace cinn diff --git a/cinn/common/dfs_visitor.h b/cinn/common/dfs_visitor.h new file mode 100644 index 0000000000..74df11089e --- /dev/null +++ b/cinn/common/dfs_visitor.h @@ -0,0 +1,90 @@ +// Copyright (c) 2023 CINN Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include + +#include "llvm/ADT/SmallVector.h" + +namespace cinn { +namespace common { + +// depth-first search visitor +template +class DfsVisitor final { + public: + DfsVisitor(const DfsVisitor&) = delete; + DfsVisitor(DfsVisitor&&) = delete; + + using NodeHandlerType = std::function; + using NodesVisitorType = std::function; + + DfsVisitor(const NodesVisitorType& VisitNextNodes) : VisitNextNodes_(VisitNextNodes) {} + + void operator()(NodeType node, const NodeHandlerType& NodeHandler) const { + llvm::SmallVector nodes{node}; + (*this)(nodes.begin(), nodes.end(), NodeHandler, [&](NodeType) {}); + } + + template + void operator()(NodeIt begin, NodeIt end, const NodeHandlerType& NodeHandler) const { + (*this)(begin, end, NodeHandler, [&](NodeType) {}); + } + + // https://en.wikipedia.org/wiki/Depth-first_search + template + void operator()(NodeIt begin, + NodeIt end, + const NodeHandlerType& NodeHandlerOnPush, + const NodeHandlerType& NodeHandlerOnPop) const { + std::unordered_set discovered; + struct Neighbours { + NodeType producer; + std::queue consumers; + }; + std::stack stack; + const auto& TryPush = [&](NodeType node) { + if (discovered.count(node) == 0) { + discovered.insert(node); + NodeHandlerOnPush(node); + stack.push(Neighbours{.producer = node}); + VisitNextNodes_(node, [&](NodeType next_node) { stack.top().consumers.push(next_node); }); + 
} + }; + for (NodeIt node_iter = begin; node_iter != end; ++node_iter) { + TryPush(*node_iter); + while (!stack.empty()) { + auto* neighbours = &stack.top(); + if (neighbours->consumers.empty()) { + NodeHandlerOnPop(neighbours->producer); + stack.pop(); + } else { + TryPush(neighbours->consumers.front()); + neighbours->consumers.pop(); + } + } + } + } + + private: + NodesVisitorType VisitNextNodes_; +}; + +} // namespace common +} // namespace cinn diff --git a/cinn/common/dfs_visitor_test.cc b/cinn/common/dfs_visitor_test.cc new file mode 100644 index 0000000000..4fed2c03cd --- /dev/null +++ b/cinn/common/dfs_visitor_test.cc @@ -0,0 +1,70 @@ +// Copyright (c) 2023 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "cinn/common/dfs_visitor.h" + +#include +#include + +namespace cinn { +namespace common { + +TEST(BfsVisitor, simple_on_push) { + DfsVisitor visitor([](int node, const std::function& NodeHandler) { + if (node == 0) { + NodeHandler(3); + } else if (node == 1) { + NodeHandler(2); + NodeHandler(3); + } else if (node == 2 || node == 3) { + NodeHandler(4); + } + }); + std::vector sources{0, 1}; + std::vector outputs; + visitor(sources.begin(), sources.end(), [&](int node) { + LOG(ERROR) << node; + outputs.push_back(node); + }); + std::vector expected{0, 3, 4, 1, 2}; + EXPECT_TRUE((outputs == expected)); +} + +TEST(BfsVisitor, simple_on_pop) { + DfsVisitor visitor([](int node, const std::function& NodeHandler) { + if (node == 0) { + NodeHandler(3); + } else if (node == 1) { + NodeHandler(2); + NodeHandler(3); + } else if (node == 2 || node == 3) { + NodeHandler(4); + } + }); + std::vector sources{0, 1}; + std::vector outputs; + visitor( + sources.begin(), + sources.end(), + [](int) {}, + [&](int node) { + LOG(ERROR) << node; + outputs.push_back(node); + }); + std::vector expected{4, 3, 0, 2, 1}; + EXPECT_TRUE((outputs == expected)); +} + +} // namespace common +} // namespace cinn \ No newline at end of file diff --git a/cinn/common/is_reachable_predicator.h b/cinn/common/is_reachable_predicator.h new file mode 100644 index 0000000000..38211daec1 --- /dev/null +++ b/cinn/common/is_reachable_predicator.h @@ -0,0 +1,73 @@ +// Copyright (c) 2023 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +#include "cinn/common/bfs_visitor.h" + +namespace cinn { +namespace common { + +template +class IsReachablePredicator final { + public: + IsReachablePredicator(const IsReachablePredicator&) = delete; + IsReachablePredicator(IsReachablePredicator&&) = delete; + + using NodeHandlerType = std::function; + using NodesVisitorType = std::function; + using NodeDepthGetterType = std::function; + + IsReachablePredicator(const NodeDepthGetterType& MinDepth4Node, + const NodeDepthGetterType& MaxDepth4Node, + const NodesVisitorType& VisitNextNodes) + : MinDepth4Node_(MinDepth4Node), MaxDepth4Node_(MaxDepth4Node), VisitNextNodes_(VisitNextNodes) {} + + bool operator()(NodeType src, NodeType dst, const NodeHandlerType& HandleVisited) const { + const size_t dst_max_depth = MaxDepth4Node_(dst); + bool detect_reachable = false; + BfsVisitor bfs_visitor([&](NodeType node, const NodeHandlerType& Handler) { + VisitNextNodes_(node, [&](NodeType out_node) { + if (dst_max_depth < MinDepth4Node_(out_node)) { + // Pruned. + // Do nothing. + } else if (detect_reachable) { + // Pruned. + // Reachability is detected. + } else { + Handler(out_node); + } + }); + }); + std::array starts{src}; + bfs_visitor(starts.begin(), starts.end(), [&](NodeType node) { + HandleVisited(node); + if (node == dst) { + detect_reachable = true; + } + }); + return detect_reachable; + } + + private: + NodeDepthGetterType MinDepth4Node_; + NodeDepthGetterType MaxDepth4Node_; + NodesVisitorType VisitNextNodes_; +}; + +} // namespace common +} // namespace cinn diff --git a/cinn/common/is_reachable_predicator_test.cc b/cinn/common/is_reachable_predicator_test.cc new file mode 100644 index 0000000000..81da3b202d --- /dev/null +++ b/cinn/common/is_reachable_predicator_test.cc @@ -0,0 +1,36 @@ +// Copyright (c) 2023 CINN Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cinn/common/is_reachable_predicator.h" + +#include +#include + +namespace cinn { +namespace common { + +TEST(IsReachablePredicator, simple) { + IsReachablePredicator IsReachable( + // Get min depth + [](int x) { return std::abs(x); }, + // Get max depth + [](int x) { return std::abs(x); }, + // visit next node + [](int x, const std::function& Handler) { Handler(x + (x / std::abs(x))); }); + EXPECT_TRUE(IsReachable(33, 99, [](int) {})); + EXPECT_FALSE(IsReachable(33, -99, [](int) {})); +} + +} // namespace common +} // namespace cinn \ No newline at end of file diff --git a/cinn/common/scc_visitor.h b/cinn/common/scc_visitor.h new file mode 100644 index 0000000000..f1949fad8d --- /dev/null +++ b/cinn/common/scc_visitor.h @@ -0,0 +1,89 @@ +// Copyright (c) 2023 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include + +#include +#include +#include +#include +#include +#include + +#include "cinn/common/dfs_visitor.h" + +namespace cinn { +namespace common { + +// strongly connected components visitor +template +class SccVisitor final { + public: + SccVisitor(const SccVisitor&) = delete; + SccVisitor(SccVisitor&&) = delete; + + using NodeHandlerType = std::function; + using NodesVisitorType = std::function; + + SccVisitor(const NodesVisitorType& VisitPrevNodes, const NodesVisitorType& VisitNextNodes) + : VisitPrevNodes_(VisitPrevNodes), VisitNextNodes_(VisitNextNodes) {} + + using SccHandlerType = std::function&)>; + + // https://en.wikipedia.org/wiki/Kosaraju%27s_algorithm + template + void operator()(NodeIt begin, NodeIt end, const SccHandlerType& SccHandler) const { + const std::list& dfs_ordered_nodes = [&]() { + std::list dfs_ordered_nodes; + DfsVisitor visitor(VisitNextNodes_); + visitor( + begin, + end, + /*on push*/ [](NodeType) {}, + /*on pop*/ [&](NodeType node) { dfs_ordered_nodes.push_front(node); }); + return dfs_ordered_nodes; + }(); + std::unordered_map node2root; + const auto& VisitPrevNode = [&](NodeType node, const NodeHandlerType& NodeHandler) { + VisitPrevNodes_(node, [&](NodeType prev_node) { + if (node2root.count(prev_node) == 0) { + NodeHandler(prev_node); + } + }); + }; + for (NodeType root : dfs_ordered_nodes) { + if (node2root.count(root) > 0) { + continue; + } + std::vector scc; + // Use node2root immutably inside dfs visitor. + DfsVisitor visitor(VisitPrevNode); + visitor(root, [&](NodeType node) { scc.push_back(node); }); + SccHandler(scc); + // Update node2root outside dfs visitor. 
+ for (NodeType node : scc) { + CHECK(node2root.emplace(node, root).second); + } + } + } + + private: + NodesVisitorType VisitPrevNodes_; + NodesVisitorType VisitNextNodes_; +}; + +} // namespace common +} // namespace cinn diff --git a/cinn/common/scc_visitor_test.cc b/cinn/common/scc_visitor_test.cc new file mode 100644 index 0000000000..fe062dce88 --- /dev/null +++ b/cinn/common/scc_visitor_test.cc @@ -0,0 +1,111 @@ +// Copyright (c) 2023 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "cinn/common/scc_visitor.h" + +#include +#include + +namespace cinn { +namespace common { + +TEST(SccVisitor, trivial) { + std::list> edges{{0, 3}, {1, 2}, {1, 3}, {2, 4}, {3, 4}}; + + SccVisitor visitor( + [&](int node, const std::function& NodeHandler) { + for (const auto& pair : edges) { + if (pair.second == node) { + NodeHandler(pair.first); + } + } + }, + [&](int node, const std::function& NodeHandler) { + for (const auto& pair : edges) { + if (pair.first == node) { + NodeHandler(pair.second); + } + } + }); + std::vector sources{0, 1}; + std::vector> outputs; + visitor(sources.begin(), sources.end(), [&](const auto& nodes) { outputs.push_back(nodes); }); + std::vector> expected{{1}, {2}, {0}, {3}, {4}}; + EXPECT_TRUE((outputs == expected)); +} + +TEST(SccVisitor, circle) { + std::list> edges{ + {0, 1}, + {1, 2}, + {2, 3}, + {3, 4}, + {4, 0}, + }; + + SccVisitor visitor( + [&](int node, const std::function& NodeHandler) { + for (const auto& pair : edges) { + if (pair.second == node) { + NodeHandler(pair.first); + } + } + }, + [&](int node, const std::function& NodeHandler) { + for (const auto& pair : edges) { + if (pair.first == node) { + NodeHandler(pair.second); + } + } + }); + std::vector sources{0}; + std::vector> outputs; + visitor(sources.begin(), sources.end(), [&](const auto& nodes) { outputs.push_back(nodes); }); + std::vector> expected{{0, 4, 3, 2, 1}}; + EXPECT_TRUE((outputs == expected)); +} + +TEST(SccVisitor, double_circle) { + std::list> edges{ + {0, 1}, + {1, 0}, + {1, 2}, + {2, 3}, + {3, 2}, + }; + + SccVisitor visitor( + [&](int node, const std::function& NodeHandler) { + for (const auto& pair : edges) { + if (pair.second == node) { + NodeHandler(pair.first); + } + } + }, + [&](int node, const std::function& NodeHandler) { + for (const auto& pair : edges) { + if (pair.first == node) { + NodeHandler(pair.second); + } + } + }); + std::vector sources{0}; + std::vector> outputs; + visitor(sources.begin(), sources.end(), [&](const auto& 
nodes) { outputs.push_back(nodes); }); + std::vector> expected{{0, 1}, {2, 3}}; + EXPECT_TRUE((outputs == expected)); +} + +} // namespace common +} // namespace cinn \ No newline at end of file diff --git a/cinn/common/topo_visitor.h b/cinn/common/topo_visitor.h new file mode 100644 index 0000000000..3eeb417cc6 --- /dev/null +++ b/cinn/common/topo_visitor.h @@ -0,0 +1,78 @@ +// Copyright (c) 2023 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include +#include + +#include "llvm/ADT/SmallVector.h" + +namespace cinn { +namespace common { + +// Topological order visitor +template +class TopoVisitor final { + public: + TopoVisitor(const TopoVisitor&) = delete; + TopoVisitor(TopoVisitor&&) = delete; + + using NodeHandlerType = std::function; + using NodesVisitorType = std::function; + + TopoVisitor(const NodesVisitorType& VisitPrevNodes, const NodesVisitorType& VisitNextNodes) + : VisitPrevNodes_(VisitPrevNodes), VisitNextNodes_(VisitNextNodes) {} + + void operator()(NodeType node, const NodeHandlerType& NodeHandler) const { + llvm::SmallVector nodes{node}; + (*this)(nodes.begin(), nodes.end(), NodeHandler); + } + + template + void operator()(NodeIt begin, NodeIt end, const NodeHandlerType& NodeHandler) const { + std::queue node_queue; + std::unordered_set queued_nodes; + const auto& TryEnqueueNode = [&](NodeType node) { + if (queued_nodes.count(node) == 0) { + node_queue.push(node); + queued_nodes.insert(node); + } + }; + for (NodeIt iter = begin; iter != end; ++iter) { + TryEnqueueNode(*iter); + } + while (!node_queue.empty()) { + NodeType node = node_queue.front(); + node_queue.pop(); + NodeHandler(node); + VisitNextNodes_(node, [&](NodeType node) { + size_t num_unfinished_inputs = 0; + VisitPrevNodes_(node, + [&](NodeType in_node) { num_unfinished_inputs += (queued_nodes.count(in_node) > 0 ? 0 : 1); }); + if (num_unfinished_inputs == 0) { + TryEnqueueNode(node); + } + }); + } + } + + private: + NodesVisitorType VisitPrevNodes_; + NodesVisitorType VisitNextNodes_; +}; + +} // namespace common +} // namespace cinn diff --git a/cinn/common/topo_visitor_test.cc b/cinn/common/topo_visitor_test.cc new file mode 100644 index 0000000000..1c8cb91deb --- /dev/null +++ b/cinn/common/topo_visitor_test.cc @@ -0,0 +1,48 @@ +// Copyright (c) 2023 CINN Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cinn/common/topo_visitor.h" + +#include +#include + +namespace cinn { +namespace common { + +TEST(TopoVisitor, simple) { + std::vector> edges{{0, 3}, {1, 2}, {1, 3}, {2, 3}, {3, 4}}; + TopoVisitor visitor( + [&](int node, const std::function& NodeHandler) { + for (const auto& pair : edges) { + if (pair.second == node) { + NodeHandler(pair.first); + } + } + }, + [&](int node, const std::function& NodeHandler) { + for (const auto& pair : edges) { + if (pair.first == node) { + NodeHandler(pair.second); + } + } + }); + std::vector sources{0, 1}; + std::vector outputs; + visitor(sources.begin(), sources.end(), [&](int node) { outputs.push_back(node); }); + std::vector expected{0, 1, 2, 3, 4}; + EXPECT_TRUE((outputs == expected)); +} + +} // namespace common +} // namespace cinn \ No newline at end of file From f9832670252d6dd7cbdde5a27ffb9ccbee553bd9 Mon Sep 17 00:00:00 2001 From: jiahongyu Date: Mon, 12 Jun 2023 11:24:32 +0000 Subject: [PATCH 11/66] horizontal fuse --- cinn/frontend/optimize.cc | 7 +- cinn/hlir/framework/graph.h | 30 +- cinn/hlir/pass/general_fusion_merge_pass.cc | 1237 +++++++++++++++++++ cinn/runtime/flags.cc | 4 + 4 files changed, 1272 insertions(+), 6 deletions(-) create mode 100644 cinn/hlir/pass/general_fusion_merge_pass.cc diff --git a/cinn/frontend/optimize.cc b/cinn/frontend/optimize.cc index b93326b806..dc4ce886d3 100644 --- a/cinn/frontend/optimize.cc +++ 
b/cinn/frontend/optimize.cc @@ -37,6 +37,7 @@ DECLARE_bool(cinn_use_custom_call); DECLARE_bool(use_reduce_split_pass); DECLARE_bool(cinn_use_dense_merge_pass); DECLARE_string(cinn_custom_call_deny_ops); +DECLARE_bool(general_fusion_merge_pass); namespace cinn { namespace frontend { @@ -95,7 +96,11 @@ OptimizeOptions DefaultTrainingOptimizeOptions() { if (FLAGS_cinn_use_op_fusion) { options.graph_passes.emplace_back("OpFusionPass"); - options.graph_passes.emplace_back("FusionMergePass"); + if (FLAGS_general_fusion_merge_pass) { + options.graph_passes.emplace_back("GeneralFusionMergePass"); + } else { + options.graph_passes.emplace_back("FusionMergePass"); + } } else { options.graph_passes.emplace_back("BuildNonFusedGroupsPass"); } diff --git a/cinn/hlir/framework/graph.h b/cinn/hlir/framework/graph.h index 4e7a9f7d1b..935383aa2c 100644 --- a/cinn/hlir/framework/graph.h +++ b/cinn/hlir/framework/graph.h @@ -21,6 +21,7 @@ #include #include +#include "cinn/api/cpp/op_group_interface.h" #include "cinn/common/graph_utils.h" #include "cinn/frontend/syntax.h" #include "cinn/hlir/framework/node.h" @@ -56,7 +57,7 @@ class Graph : public cinn::common::Graph { absl::flat_hash_map> attrs; std::vector> groups; - struct Group { + struct Group final : public OpGroupInterface { // distance to last group. int depth{0}; int max_depth{0}; @@ -80,10 +81,6 @@ class Graph : public cinn::common::Graph { // master node for schedule std::unordered_set master_nodes; - // input groups - std::unordered_map, TensorInterfaceList> producer_groups; - // output grous - std::unordered_map, TensorInterfaceList> consumer_groups; // fused sub-groups, used for fusion merge pass std::vector> fused_sub_groups; // if as sub-group, used for belong groups. 
@@ -125,6 +122,29 @@ class Graph : public cinn::common::Graph { std::unordered_set GetOutputNodeDatas(); std::string GetFuncName() { return "fn_" + group_id + unique_id; } + + public: + const std::unordered_map, TensorInterfaceList>& producer_groups() const override { + return producer_groups_; + } + + const std::unordered_map, TensorInterfaceList>& consumer_groups() const override { + return consumer_groups_; + } + + std::unordered_map, TensorInterfaceList>* mut_producer_groups() { + return &producer_groups_; + } + + std::unordered_map, TensorInterfaceList>* mut_consumer_groups() { + return &consumer_groups_; + } + + private: + // input groups + std::unordered_map, TensorInterfaceList> producer_groups_; + // output grous + std::unordered_map, TensorInterfaceList> consumer_groups_; }; std::vector> fusion_groups; diff --git a/cinn/hlir/pass/general_fusion_merge_pass.cc b/cinn/hlir/pass/general_fusion_merge_pass.cc new file mode 100644 index 0000000000..cbf33de10d --- /dev/null +++ b/cinn/hlir/pass/general_fusion_merge_pass.cc @@ -0,0 +1,1237 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "cinn/hlir/pass/fusion_merge_pass_util.h" + +DECLARE_bool(enhance_vertical_fusion_with_recompute); + +namespace cinn { +namespace hlir { +namespace pass { +namespace { + +using framework::Graph; +using framework::Node; +using framework::NodeData; +using framework::OpPatternKind; +using framework::shape_t; + +using common::GraphEdge; +using common::GraphNode; + +using GroupPtr = std::shared_ptr; +using GroupList = std::vector; + +using OpGroupPtr = std::shared_ptr; +using OpGroupList = std::vector; + +using ConditionFunction = std::function; + +// Op Fusion Pass which performs Ops fusion, Ops are fused +// "vertically", meaning producing Ops are fused into their consumers +// with the intent that the loops which compute their values will be fused in +// code generation. +class FusionMergePassHelper : public FusionHelperBase { + public: + FusionMergePassHelper(const Graph* graph) : FusionHelperBase(graph) { + fusion_groups_ = graph->fusion_groups; + // init fusion relation. + InitFusionRelation(); + // init input to consumers. + InitInputToConsumers(); + // init fusion group index. + InitFusionGroupsAndIndex(); + } + + GroupList operator()() { + // run fusion merge untill no update. 
+ DoFusionMerge(); + for (auto& group : fusion_groups_) { + VLOG(3) << "Fusion Group -> " << group->group_id; + for (auto& sub_group : group->fused_sub_groups) { + VLOG(3) << " Fused Sub-Group -> " << sub_group->group_id; + } + for (const auto& pair : group->producer_groups) { + const auto& producer = pair.first; + VLOG(3) << " Producer -> " << producer->group_id; + } + for (const auto& pair : group->consumer_groups) { + const auto& consumer = pair.first; + VLOG(3) << " Consumer -> " << consumer->group_id; + } + } + return fusion_groups_; + } + + private: + void DoFusionMerge() { + VLOG(3) << "DoFusionMerge...!"; + while (DoGeneralHorizontalFusion()) { + } + while (DoVerticalFusion(/* recompute=*/false)) { + } + while (DoVerticalFusion(/* recompute=*/true)) { + } + } + + bool DoGeneralHorizontalFusion() { + VLOG(3) << "DoGeneralHorizontalFusion...!"; + bool updated = false; + for (int idx = 0; idx < fusion_groups_.size(); ++idx) { + auto producer = fusion_groups_[idx]; + VLOG(3) << "Fusion Producer Group -> " << producer->group_id; + // if producer is sub group. + if (producer->belong_groups.size()) { + continue; + } + // do horizontal fusion. + updated |= GeneralHorizontalFuse(producer); + } + + if (updated) { + UpdateFusionGroup(); + } + return updated; + } + + bool DoVerticalFusion(bool recompute) { + VLOG(3) << "DoVerticalFusion...!"; + bool updated = false; + for (int idx = 0; idx < fusion_groups_.size(); ++idx) { + auto producer = fusion_groups_[idx]; + VLOG(3) << "Fusion Producer Group -> " << producer->group_id; + // if producer is sub group. + if (producer->belong_groups.size()) { + continue; + } + // do horizontal fusion. 
+ if (!recompute) { + updated |= HorizontalFusion(producer, producer->CollectConsumerGroups()); + } + updated |= VerticalFusion(producer, producer->CollectConsumerGroups(), recompute); + } + // fuse input consumers + updated |= FuseInputToConsumers(); + + if (updated) { + UpdateFusionGroup(); + } + return updated; + } + + void UpdateFusionGroup() { + VLOG(3) << "UpdateFusionGroup..."; + GroupList fusion_groups; + std::unordered_set fusion_groups_set; + // update fusion_groups_ + for (auto& group : fusion_groups_) { + if (!group->belong_groups.size()) { + fusion_groups.push_back(group); + fusion_groups_set.insert(group); + } + } + // keep group in order + fusion_groups_.clear(); + fusion_groups_index_.clear(); + while (!fusion_groups_set.empty()) { + bool is_ring = true; + for (int idx = 0; idx < fusion_groups.size(); ++idx) { + auto& group = fusion_groups[idx]; + if (!group.get()) { + continue; + } + + bool exist = false; + for (const auto& pair : group->producer_groups) { + const auto& producer = pair.first; + if (fusion_groups_set.count(producer)) { + VLOG(4) << group->group_id << " " << producer->group_id; + exist = true; + break; + } + } + + if (!exist) { + fusion_groups_index_[group] = fusion_groups_.size(); + fusion_groups_.push_back(group); + fusion_groups_set.erase(group); + group.reset(); + is_ring = false; + continue; + } + } + if (is_ring) { + LOG(FATAL) << "Exists Ring, Please Check!"; + } + } + } + + struct LightwareFusePassCtx; + + struct FuseHelper final { + public: + explicit FuseHelper(LightwareFusePassCtx* ctx) : ctx_(ctx) {} + + bool AllOutputsSameSize(const OpGroupPtr& first, const OpGroupPtr& second) const { TODO(); } + + bool HorizontalElementwiseFuseReduce(const OpGroupPtr& src, const OpGroupPtr& dst) { TODO(); } + + bool ReduceFuseReduce(const OpGroupPtr& src, const OpGroupPtr& dst) { TODO(); } + + bool DetectCycleIfFuse(const OpGroupPtr& src, const OpGroupPtr& dst) const { TODO(); } + + private: + LightwareFusePassCtx* ctx_; + }; + + struct 
LightwareFusePassCtx final { + public: + LightwareFusePassCtx(const OpGroupPtr& group, + const std::function& EnableFuse) + : group_(group), EnableFuse_(EnableFuse), fuse_helper_(this) {} + + const OpGroupPtr& PickOpGroup() const { return group_; } + + const FuseHelper& fuse_helper() const { return fuse_helper_; } + + void EnableFuse(const OpGroupPtr& first, const OpGroupPtr& second) { EnableFuse_(first, second); } + + private: + const OpGroupPtr group_; + const std::function EnableFuse_; + const FuseHelper fuse_helper_; + }; + + class FusePass { + public: + virtual ~FusePass() = default; + + virtual void operator()(LightwareFusePassCtx* ctx) const = 0; + + protected: + FusePass(); + }; + + class DefautlHorizontalFusePass final : public FusePass { + public: + DefautlHorizontalFusePass() : FusePass() {} + + void operator()(LightwareFusePassCtx* ctx) const override { + const auto& producer = ctx->PickOpGroup(); + const OpGroupList consumers = [&]() { + OpGroupList consumers; + for (const auto& pair : producer->consumer2outputs()) { + consumers.insert(pair.first); + } + return consumers; + }(); + if (consumers.size() <= 1) { + return; + } + for (int i = 0; i < consumers.size(); ++i) { + const auto& src = consumers.at(i); + for (int j = i + 1; j < consumers.size(); ++j) { + const auto& dst = consumers.at(j); + if (ctx->fuse_helper().DetectCycleIfFuse(src, dst)) { + continue; + } + if (!DetectFusabilityByKind(ctx, src, dst)) { + continue; + } + ctx->EnableFuse(src, dst); + return; + } + } + } + + using KindKeyT = std::pair; + + bool DetectFusabilityByKind(LightwareFusePassCtx* ctx, const OpGroupPtr& src, const OpGroupPtr& dst) const { + const KindKeyT kind_pair(src->kind(), dst->kind()); + const auto& map = GetConditionMap(); + const auto& iter = map.find(kind_pair); + if (iter == map.end()) { + return false; + } + return iter->second(ctx, src, dst); + } + + typedef bool (*ConditionT)(LightwareFusePassCtx* ctx, const OpGroupPtr& src, const OpGroupPtr& dst); + + const 
std::unordered_map& GetConditionMap() const { + thread_local static std::unordered_map map(RawConditionMap()); + return map; + } + + std::unordered_map RawConditionMap() const { + return std::unordered_map{ + {{OpPatternKind::kElementWise, framework::kElementWise}, &DefautlHorizontalFusePass::IsSameSize}, + {{OpPatternKind::kElementWise, framework::kBroadcast}, &DefautlHorizontalFusePass::IsSameSize}, + {{OpPatternKind::kElementWise, framework::kInjective}, &DefautlHorizontalFusePass::IsSameSize}, + {{OpPatternKind::kElementWise, framework::kReduction}, + &DefautlHorizontalFusePass::HorizontalElementwiseFuseReduce}, + + {{OpPatternKind::kBroadcast, framework::kElementWise}, &DefautlHorizontalFusePass::IsSameSize}, + {{OpPatternKind::kBroadcast, framework::kBroadcast}, &DefautlHorizontalFusePass::IsSameSize}, + {{OpPatternKind::kBroadcast, framework::kInjective}, &DefautlHorizontalFusePass::IsSameSize}, + {{OpPatternKind::kBroadcast, framework::kReduction}, &DefautlHorizontalFusePass::IsSameSize}, + + {{OpPatternKind::kInjective, framework::kElementWise}, &DefautlHorizontalFusePass::IsSameSize}, + {{OpPatternKind::kInjective, framework::kBroadcast}, &DefautlHorizontalFusePass::IsSameSize}, + {{OpPatternKind::kInjective, framework::kInjective}, &DefautlHorizontalFusePass::IsSameSize}, + {{OpPatternKind::kInjective, framework::kReduction}, &DefautlHorizontalFusePass::IsSameSize}, + + {{OpPatternKind::kReduction, framework::kElementWise}, + &DefautlHorizontalFusePass::HorizontalElementwiseFuseReduce}, + {{OpPatternKind::kReduction, framework::kBroadcast}, &DefautlHorizontalFusePass::IsSameSize}, + {{OpPatternKind::kReduction, framework::kInjective}, &DefautlHorizontalFusePass::IsSameSize}, + {{OpPatternKind::kReduction, framework::kReduction}, &DefautlHorizontalFusePass::ReduceFuseReduce}, + }; + } + + static bool IsSameSize(LightwareFusePassCtx* ctx, const OpGroupPtr& src, const OpGroupPtr& dst) { + return ctx->fuse_helper().AllOutputsSameSize(src, dst); + } + + 
static bool HorizontalElementwiseFuseReduce(LightwareFusePassCtx* ctx, + const OpGroupPtr& src, + const OpGroupPtr& dst) { + return ctx->fuse_helper().HorizontalElementwiseFuseReduce(src, dst); + } + + static bool ReduceFuseReduce(LightwareFusePassCtx* ctx, const OpGroupPtr& src, const OpGroupPtr& dst) { + return ctx->fuse_helper().ReduceFuseReduce(src, dst); + } + }; + + std::vector> RawHorizontalFusePasses() const { + std::vector>{ + std::shared_ptr(new DefautlHorizontalFusePass{}), + }; + return ret; + } + + const std::vector>& GetHorizontalFusePasses() const { + thread_local static std::vector> fuse_passes = RawHorizontalFusePasses(); + return fuse_passes; + } + + void TagHorizontalGroups(LightwareFusePassCtx* ctx) const { + const auto& producer = ctx->PickOpGroup(); + if (producer->consumer2outputs().size() <= 1) { + return; + } + const auto& fuse_passes = GetHorizontalFusePasses(); + for (const auto& fuse_pass : fuse_passes) { + (*fuse_pass)(ctx); + } + } + + bool GeneralHorizontalFuse(const GroupPtr& producer) const { + VLOG(3) << "GeneralHorizontalFuse...!"; + using GroupSets = std::set>; + const auto& GetFusableConsumerGroupSets = [&]() -> GroupSets { + GroupSets tagged_sets; + const auto& EnableFuse = [&](const GroupPtr& first, const GroupPtr& second) { + tagged_sets.insert(std::set{first, second}); + }; + LightwareFusePassCtx fuse_ctx(producer, EnableFuse); + TagHorizontalGroups(&fuse_ctx); + return tagged_sets; + }; + const auto& GetFusableConsumerGroupList = [&]() -> GroupList { + const auto& group_sets = GetFusableConsumerGroupSets(); + if (group_sets.empty()) { + return GroupList{}; + } + return GroupList{group_sets.begin()->begin(), group_sets.begin()->end()}; + }; + size_t fuse_count = 0; + while (true) { + const auto& groups = GetFusableConsumerGroupList(); + if (groups.size() <= 1) { + break; + } + fuse_count += HorizontalFuse(groups); + } + return fuse_count > 0; + } + + bool HorizontalFusion(GroupPtr producer, const std::unordered_set& 
consumers) { + VLOG(3) << "HorizontalFusion...!"; + if (consumers.size() <= 1) { + return false; + } + + std::unordered_set candidates; + for (const auto& consumer : consumers) { + // relation + auto& relation = fusion_relation_map_[consumer->op_pattern_kind]; + // check horizontal relation exist + if (!relation.horizontal_relation.size()) { + continue; + } + candidates.insert(consumer); + } + + std::vector fusionable_consumers; + for (auto& candidate : candidates) { + // check dependency + if (IsDependencySimplify(producer, candidate, candidates)) { + VLOG(4) << "IsDependencySimplify, Can't fuse " << candidate->group_id << ", As it depency others!"; + continue; + } + + if (IsDependency(producer, candidate, candidates)) { + VLOG(4) << "IsDependency, Can't fuse " << candidate->group_id << ", As it depency others!"; + continue; + } + + if (!fusionable_consumers.size()) { + fusionable_consumers.push_back({candidate}); + continue; + } + + // check each fusionable groups + bool fusionable = false; + auto& relation = fusion_relation_map_[candidate->op_pattern_kind]; + for (auto& groups : fusionable_consumers) { + auto& last = groups.back(); + if (!relation.horizontal_relation.count(last->op_pattern_kind)) { + continue; + } + + if (!relation.horizontal_relation[last->op_pattern_kind](this, candidate, last)) { + continue; + } + + groups.push_back(candidate); + fusionable = true; + break; + } + + // if can't fuse to othors Groups, new Groups. + if (!fusionable) { + fusionable_consumers.push_back({candidate}); + } + } + + bool updated = false; + for (auto& groups : fusionable_consumers) { + if (groups.size() > 1) { + updated = true; + HorizontalFuse(groups); + } + } + + return updated; + } + + void HorizontalFuse(const GroupList& consumers) { + VLOG(3) << "HorizontalFuse Groups..."; + // create fusion group + auto fused_group = std::make_shared(); + // As recompute exist which may case sub-group used by more than one time. 
+ std::vector repeat_sub_groups; + std::unordered_set sub_group_set; + // find the first consumer. + GroupPtr first_consumer(nullptr); + // fuse all group into fusion group. + for (const auto& consumer : consumers) { + VLOG(3) << "fuse consumer " << consumer->group_id << " into fused_group!"; + // update depth + fused_group->max_depth = std::max(fused_group->max_depth, consumer->max_depth); + fused_group->min_depth = std::min(fused_group->min_depth, consumer->min_depth); + // update group id + if (fused_group->group_id.size()) { + fused_group->group_id += "_" + consumer->group_id; + } else { + fused_group->group_id = consumer->group_id; + } + // set op pattern kind + fused_group->op_pattern_kind = + static_cast(fused_group->op_pattern_kind) >= static_cast(consumer->op_pattern_kind) + ? fused_group->op_pattern_kind + : consumer->op_pattern_kind; + // input nodes + for (auto& node : consumer->input_nodes) { + if (fused_group->input_nodes.count(node.first)) { + fused_group->input_nodes[node.first] += node.second; + } else { + fused_group->input_nodes.insert(node); + } + } + // output node + for (auto& node : consumer->output_nodes) { + fused_group->output_nodes.insert(node); + } + // internal node + if (consumer->fused_sub_groups.size()) { + for (auto& node : consumer->internal_nodes) { + fused_group->internal_nodes.insert(node); + } + } + // master node + for (auto& node : consumer->master_nodes) { + if (GetOpKind(node) == framework::kReduction) { + fused_group->master_nodes.insert(node); + } + } + // insert sub group + if (consumer->fused_sub_groups.size()) { + for (auto& sub_group : consumer->fused_sub_groups) { + // check sub group is repeat. + if (sub_group_set.count(sub_group)) { + VLOG(3) << sub_group->group_id << " is repeated!"; + repeat_sub_groups.push_back(sub_group); + continue; + } + // record sub group + sub_group_set.insert(sub_group); + + // insert to fused sub group. 
+ fused_group->fused_sub_groups.push_back(sub_group); + // update belongs group + sub_group->belong_groups.erase(consumer); + sub_group->belong_groups.insert(fused_group); + } + } else { + fused_group->fused_sub_groups.push_back(consumer); + } + // producer group + for (const auto& producer_and_list : consumer->producer_groups) { + fused_group->producer_groups[producer_and_list.first] += producer_and_list.second; + // update producer's consumer + producer_and_list.first->consumer_groups.erase(consumer); + // TODO: Do not add any TensorInterface into any TensorInterfaceList in this file which will be deprecated. + producer_and_list.first->consumer_groups[fused_group] += {}; + } + // consumer group + for (const auto& gconsumer_and_list : consumer->consumer_groups) { + fused_group->consumer_groups[gconsumer_and_list.first] += gconsumer_and_list.second; + // update consumer's producer + gconsumer_and_list.first->producer_groups.erase(consumer); + // TODO: Do not add any TensorInterface into any TensorInterfaceList in this file which will be deprecated. + gconsumer_and_list.first->producer_groups[fused_group] += {}; + } + // belongs group + consumer->belong_groups.insert(fused_group); + + // find the first consumer. + CHECK(fusion_groups_index_.count(consumer)) + << "Can't find consumer " << consumer->group_id << " index in fusion_groups_index_!"; + if (first_consumer.get()) { + if (fusion_groups_index_[consumer] < fusion_groups_index_[first_consumer]) { + first_consumer = consumer; + } + } else { + first_consumer = consumer; + } + } + + // if node is output nodes of sub_group, check it can't be internal node. + for (auto& sub_group : repeat_sub_groups) { + // check each output node in sub_group. + for (auto& node : sub_group->output_nodes) { + // if node is not output node of fused_group. 
+ if (!fused_group->output_nodes.count(node)) { + fused_group->internal_nodes.insert(node); + } + } + } + + if (static_cast(framework::kReduction) > static_cast((consumers.back())->op_pattern_kind)) { + auto consumer = consumers.back(); + + for (auto& node : consumer->master_nodes) { + fused_group->master_nodes.insert(node); + } + } else { + for (auto consumer = consumers.rbegin(); consumer != consumers.rend(); ++consumer) { + Node* master_node = nullptr; + for (auto& node : (*consumer)->master_nodes) { + if (GetOpKind(node) != framework::kReduction) { + master_node = node; + break; + } + } + if (master_node) { + VLOG(3) << "Insert Master node : " << master_node->id() << " into group : " << fused_group->group_id; + fused_group->master_nodes.insert(master_node); + break; + } + } + } + + auto postion = fusion_groups_index_[first_consumer]; + fusion_groups_[postion] = fused_group; + fusion_groups_index_[fused_group] = postion; + + CHECK(fused_group->output_nodes.size()) << "No output node is found, " << fused_group->group_id; + } + + bool VerticalFusion(GroupPtr& producer, const std::unordered_set& consumers, bool recompute) { + VLOG(3) << "VerticalFusion, Number of Consumers : " << consumers.size(); + auto& relation = fusion_relation_map_[producer->op_pattern_kind]; + // if producer can't fuse others + if (!relation.vertical_relation.size()) { + return false; + } + + std::unordered_set fuse_consumers_unsafe; + std::unordered_set fuse_consumers; + for (const auto& consumer : consumers) { + VLOG(4) << "Check consuemr " << consumer->group_id << " can fuse to producer " << producer->group_id; + // if can't fuse + if (!relation.vertical_relation.count(consumer->op_pattern_kind)) { + VLOG(4) << "Can't fuse producer " << producer->group_id << " consumer " << consumer->group_id; + continue; + } + + // if condition function is false + if (!relation.vertical_relation[consumer->op_pattern_kind](this, producer, consumer)) { + VLOG(4) << "Can't fuse producer " << 
producer->group_id << " consumer " << consumer->group_id; + continue; + } + + fuse_consumers_unsafe.insert(consumer); + + if (IsDependencySimplify(producer, consumer, consumers)) { + VLOG(4) << "IsDependencySimplify, Consumer " << consumer->group_id << " can't be master fused group!"; + continue; + } + + if (IsDependency(producer, consumer, consumers)) { + VLOG(4) << "IsDependency, Consumer " << consumer->group_id << " can't be master fused group!"; + continue; + } + + fuse_consumers.insert(consumer); + } + + VLOG(3) << "VerticalFusion, Number of fuse Consumers : " << fuse_consumers.size(); + VLOG(3) << "VerticalFusion, Number of unsafe fuse Consumers : " << fuse_consumers.size(); + + if (fuse_consumers.size() == 0) { + return false; + } + // if can_fuse_consumers == consumers + // if producer op kind == kElementwise + // if use recompute + if (fuse_consumers_unsafe.size() == producer->consumer_groups.size() && + producer->op_pattern_kind == framework::kElementWise) { + if (!recompute) { + return false; + } else { + RecomputeEleGraph(producer, fuse_consumers_unsafe); + VerticalFuse(producer, fuse_consumers_unsafe); + return true; + } + } + + if (fuse_consumers.size()) { + SelectConsumerToFuse(producer, fuse_consumers); + } + + // if fusionable consumers exist + if (fuse_consumers.size()) { + VerticalFuse(producer, fuse_consumers); + return true; + } + + return false; + } + + void VerticalFuse(GroupPtr& producer, std::unordered_set& fusionable_consumers) { + VLOG(3) << "VerticalFuse...!"; + GroupList fused_groups; + GroupPtr master_fuesd_group(nullptr); + for (auto& consumer : fusionable_consumers) { + auto fused_group = std::make_shared(); + // update depth using consumer depth. 
+ fused_group->max_depth = std::max(producer->max_depth, consumer->max_depth); + fused_group->min_depth = std::min(producer->min_depth, consumer->min_depth); + // update group id + fused_group->group_id = producer->group_id + "_" + consumer->group_id; + VLOG(3) << "fuse producer " << producer->group_id << " into consumer " << consumer->group_id; + // fuse producer into fusion group + fused_group->op_pattern_kind = + static_cast(producer->op_pattern_kind) >= static_cast(consumer->op_pattern_kind) + ? producer->op_pattern_kind + : consumer->op_pattern_kind; + // input nodes + fused_group->input_nodes = producer->input_nodes; + + // internal nodes + if (producer->fused_sub_groups.size()) { + for (auto& node : producer->internal_nodes) { + fused_group->internal_nodes.insert(node); + } + } + // convert producer's output node to internal. + for (auto node : producer->output_nodes) { + // if node is used more than 1 time. + if (consumer->input_nodes.count(node)) { + if (consumer->input_nodes[node] > 1 && node->inlinks().size() > 0) { + fused_group->internal_nodes.insert(node); + } + } + } + // master nodes + for (auto& node : producer->master_nodes) { + if (GetOpKind(node) == framework::kReduction) { + fused_group->master_nodes.insert(node); + } + } + + // producer groups + for (const auto& group_and_list : producer->producer_groups) { + fused_group->producer_groups[group_and_list.first] += group_and_list.second; + // update producer's producer's consumer + group_and_list.first->consumer_groups.erase(producer); + // TODO: Do not add any TensorInterface into any TensorInterfaceList in this file which will be deprecated. 
+ group_and_list.first->consumer_groups[fused_group] += {}; + } + + // sub groups + if (producer->fused_sub_groups.size()) { + for (auto& group : producer->fused_sub_groups) { + fused_group->fused_sub_groups.push_back(group); + // update belong group + group->belong_groups.erase(producer); + group->belong_groups.insert(fused_group); + } + } else { + fused_group->fused_sub_groups.push_back(producer); + } + producer->belong_groups.insert(fused_group); + + // input nodes + for (auto& input_node : consumer->input_nodes) { + // if input node not in producer output. + if (!producer->output_nodes.count(input_node.first)) { + if (fused_group->input_nodes.count(input_node.first)) { + fused_group->input_nodes[input_node.first] += input_node.second; + } else { + fused_group->input_nodes.insert(input_node); + } + } + } + + // output nodes + for (auto& node : consumer->output_nodes) { + fused_group->output_nodes.insert(node); + } + + // internal nodes + if (consumer->fused_sub_groups.size()) { + for (auto& node : consumer->internal_nodes) { + fused_group->internal_nodes.insert(node); + } + } + + // master nodes + for (auto& node : consumer->master_nodes) { + fused_group->master_nodes.insert(node); + } + + // producer nodes + for (const auto& group_and_list : consumer->producer_groups) { + if (group_and_list.first.get() != producer.get()) { + fused_group->producer_groups[group_and_list.first] += group_and_list.second; + // update consumer's producer's consumer + group_and_list.first->consumer_groups.erase(consumer); + // TODO: Do not add any TensorInterface into any TensorInterfaceList in this file which will be deprecated. 
+ group_and_list.first->consumer_groups[fused_group] += {}; + } + } + + // consumer nodes + for (const auto& group_and_list : consumer->consumer_groups) { + fused_group->consumer_groups[group_and_list.first] += group_and_list.second; + // update consumer's consumer's producer + group_and_list.first->producer_groups.erase(consumer); + // TODO: Do not add any TensorInterface into any TensorInterfaceList in this file which will be deprecated. + group_and_list.first->producer_groups[fused_group] += {}; + } + + // sub group + if (consumer->fused_sub_groups.size()) { + for (auto& sub_group : consumer->fused_sub_groups) { + if (std::find(fused_group->fused_sub_groups.begin(), fused_group->fused_sub_groups.end(), sub_group) == + fused_group->fused_sub_groups.end()) { + fused_group->fused_sub_groups.push_back(sub_group); + } + // update belong group + sub_group->belong_groups.erase(consumer); + sub_group->belong_groups.insert(fused_group); + } + } else { + fused_group->fused_sub_groups.push_back(consumer); + } + consumer->belong_groups.insert(fused_group); + + fused_groups.push_back(fused_group); + CHECK(fusion_groups_index_.count(consumer)) + << "Can't find consumer " << consumer->group_id << " index in fusion_groups_index_!"; + auto postion = fusion_groups_index_[consumer]; + fusion_groups_[postion] = fused_group; + fusion_groups_index_[fused_group] = postion; + + if (!master_fuesd_group.get()) { + master_fuesd_group = fused_group; + } + CHECK(fused_group->output_nodes.size()) << "No output node is found, " << fused_group->group_id; + } + + for (auto& node : producer->output_nodes) { + bool be_output = true; + for (const auto& consumer_and_list : producer->consumer_groups) { + // if consumer is in fusionable. + if (fusionable_consumers.count(consumer_and_list.first)) { + if (consumer_and_list.first->input_nodes.count(node)) { + be_output = false; + } + continue; + } + // if consumer is not in fusionable. 
+ if (consumer_and_list.first->input_nodes.count(node)) { + be_output = true; + break; + } + // others node is as graph output. + } + + if (output_nodes_set_.count(node)) { + be_output = true; + } + + if (be_output) { + VLOG(4) << "Insert Id " << node->id() << " Into Group " << master_fuesd_group->group_id; + master_fuesd_group->output_nodes.insert(node); + } + } + // insert unfusionable consumer groups + for (const auto& consumer_and_list : producer->consumer_groups) { + if (fusionable_consumers.count(consumer_and_list.first)) { + continue; + } + master_fuesd_group->consumer_groups[consumer_and_list.first] += consumer_and_list.second; + // update consumer's producer + consumer_and_list.first->producer_groups.erase(producer); + // TODO: Do not add any TensorInterface into any TensorInterfaceList in this file which will be deprecated. + consumer_and_list.first->producer_groups[master_fuesd_group] += {}; + } + } + + void RecomputeEleGraph(const GroupPtr& producer, std::unordered_set& fusionable_consumers) { + if (producer->op_pattern_kind != framework::kElementWise) { + SelectConsumerToFuse(producer, fusionable_consumers); + } + } + + void SelectConsumerToFuse(const GroupPtr& producer, std::unordered_set& fusionable_consumers) { + // if is const op + if (is_const_group(this, producer)) { + std::unordered_set candidates; + for (auto& consumer : fusionable_consumers) { + // if can be output node. + if (is_same_shape(this, producer, consumer)) { + candidates.insert(consumer); + } else { + VLOG(4) << "Fuse Producer : " << producer->group_id << " into Consumer : " << consumer->group_id; + consumer->group_id = producer->group_id + "_" + consumer->group_id; + // just merge the node into group. 
+ auto& sub_group = consumer->fused_sub_groups.front(); + sub_group->group_id = producer->group_id + "_" + sub_group->group_id; + sub_group->nodes.insert(sub_group->nodes.begin(), producer->CollectNodes()[0]); + sub_group->nodes_set.insert(producer->CollectNodes()[0]); + // remove depency. + consumer->input_nodes.erase(producer->CollectNodes()[0]); + consumer->producer_groups.erase(producer); + producer->consumer_groups.erase(consumer); + } + } + + CHECK_GE(producer->consumer_groups.size(), candidates.size()); + if (producer->consumer_groups.size() == 0 && candidates.size() == 0 && + output_nodes_set_.count(producer->CollectNodes()[0]) == 0) { + producer->belong_groups.insert(*fusionable_consumers.begin()); + } + + fusionable_consumers = candidates; + return; + } + // 1 to 1 fusion. + if (producer->consumer_groups.size() == 1) { + return; + } + + if (FLAGS_enhance_vertical_fusion_with_recompute) { + std::vector candidates; + for (auto& consumer : fusionable_consumers) { + if (consumer->op_pattern_kind == framework::kElementWise) { + candidates.push_back(consumer); + continue; + } + + auto producer_output_shape = this->GetNodeDataShape(*producer->output_nodes.begin()); + auto consumer_output_shape = this->GetNodeDataShape(*consumer->output_nodes.begin()); + auto consumer_master_input_shape = this->GetNodeInputShape(*(consumer->master_nodes.begin())); + int producer_output_numel = + std::accumulate(producer_output_shape.begin(), producer_output_shape.end(), 1, std::multiplies()); + int consumer_output_numel = + std::accumulate(consumer_output_shape.begin(), consumer_output_shape.end(), 1, std::multiplies()); + int consumer_master_input_numel = std::accumulate( + consumer_master_input_shape.begin(), consumer_master_input_shape.end(), 1, std::multiplies()); + if (producer_output_numel == consumer_output_numel) { + candidates.push_back(consumer); + continue; + } + + if (producer->op_pattern_kind != framework::kInjective && consumer->op_pattern_kind == 
framework::kReduction && + producer_output_numel == consumer_master_input_numel) { + candidates.push_back(consumer); + } + } + sort(candidates.begin(), candidates.end(), [](const auto& lhs, const auto& rhs) { + return lhs->op_pattern_kind < rhs->op_pattern_kind; + }); + + fusionable_consumers.clear(); + if (candidates.size()) { + fusionable_consumers.insert(*candidates.begin()); + } + } else { + std::unordered_set candidates; + for (auto& consumer : fusionable_consumers) { + if (consumer->op_pattern_kind == framework::kElementWise) { + candidates.insert(consumer); + continue; + } + + auto shape0 = this->GetNodeDataShape(*producer->output_nodes.begin()); + auto shape1 = this->GetNodeDataShape(*consumer->output_nodes.begin()); + + if (std::accumulate(shape0.begin(), shape0.end(), 1, std::multiplies()) == + std::accumulate(shape1.begin(), shape1.end(), 1, std::multiplies())) { + candidates.insert(consumer); + } + } + + fusionable_consumers.clear(); + if (candidates.size()) { + fusionable_consumers.insert(*candidates.begin()); + } + } + } + + bool IsDependency(const GroupPtr& producer_g, + const GroupPtr& consumer, + const std::unordered_set& consumers) { + std::queue candidates; + candidates.push(consumer); + + std::unordered_set visited_set; + while (!candidates.empty()) { + auto& candidate = candidates.front(); + candidates.pop(); + for (const auto& producer_and_list : candidate->producer_groups) { + if (producer_and_list.first.get() == producer_g.get()) { + continue; + } + if (consumers.count(producer_and_list.first)) { + return true; + } + if (!visited_set.count(producer_and_list.first)) { + visited_set.insert(producer_and_list.first); + candidates.push(producer_and_list.first); + } + } + } + return false; + } + + bool IsDependencySimplify(const GroupPtr& producer_g, + const GroupPtr& consumer, + const std::unordered_set& consumers) { + std::queue candidates; + candidates.push(consumer); + // check upper. + int check_upper_depth = producer_g.get() ? 
producer_g->max_depth : INT_MAX; + std::unordered_set visited_set; + while (!candidates.empty()) { + auto& candidate = candidates.front(); + candidates.pop(); + for (auto& producer_and_list : candidate->producer_groups) { + if (producer_and_list.first.get() == producer_g.get()) { + continue; + } + if (producer_and_list.first->min_depth > check_upper_depth) { + continue; + } + if (consumers.count(producer_and_list.first)) { + return true; + } + if (!visited_set.count(producer_and_list.first)) { + visited_set.insert(producer_and_list.first); + candidates.push(producer_and_list.first); + } + } + } + return false; + } + + bool FuseInputToConsumers() { + VLOG(3) << "FuseInputToConsumers...!"; + auto updated = false; + UpdateInputToConsumers(); + GroupPtr producer(nullptr); + for (auto& input_consumers : input_to_consumers_) { + // if group set size == 1. + if (input_consumers.second.size() == 1) { + continue; + } + // do horizontal fusion. + auto st = HorizontalFusion(producer, input_consumers.second); + if (st) { + // fused consumers, update + UpdateInputToConsumers(); + } + updated |= st; + } + + return updated; + } + + void UpdateInputToConsumers() { + for (auto& input_consumers : input_to_consumers_) { + auto& consumers = input_consumers.second; + std::unordered_set updated_consumers; + for (auto& consumer : consumers) { + std::queue fused_groups; + fused_groups.push(consumer); + while (!fused_groups.empty()) { + auto& cur = fused_groups.front(); + fused_groups.pop(); + // if group is sub group + if (cur->belong_groups.empty()) { + updated_consumers.insert(cur); + } else { + for (auto& belong_group : cur->belong_groups) { + if (belong_group->group_id == cur->group_id) { + updated_consumers.insert(belong_group); + } else { + fused_groups.push(belong_group); + } + } + } + } + } + consumers = updated_consumers; + } + } + + void InitInputToConsumers() { + VLOG(3) << "InitInputToConsumers...!"; + // init input data node -> fusion group map. 
+ for (auto& group : fusion_groups_) { + for (auto& node : group->nodes_set) { + // collect producer node data. + auto producer_node_datas = GetProducerNodeData(node); + for (auto& node_data : producer_node_datas) { + // node data's source node is null. + if (!node_data->source_node.get()) { + // insert group to set. + input_to_consumers_[node_data].insert(group); + } + } + } + } + } + + void InitFusionGroupsAndIndex() { + VLOG(3) << "InitFusionGroupsAndIndex...!"; + // init the postion of groups in fusion groups. + for (int idx = 0; idx < fusion_groups_.size(); ++idx) { + auto group = fusion_groups_[idx]; + auto belong_group = std::make_shared(); + // copy from group. + belong_group->max_depth = group->depth; + belong_group->min_depth = group->depth; + belong_group->group_id = group->group_id; + belong_group->input_nodes = group->input_nodes; + belong_group->output_nodes = group->output_nodes; + belong_group->op_pattern_kind = group->op_pattern_kind; + belong_group->master_nodes = group->master_nodes; + belong_group->producer_groups = group->producer_groups; + belong_group->consumer_groups = group->consumer_groups; + belong_group->fused_sub_groups.push_back(group); + group->belong_groups.insert(belong_group); + // replace group to fused_group + fusion_groups_[idx] = belong_group; + // record idx + fusion_groups_index_[belong_group] = idx; + } + + // update producer and consumer. + for (auto& group : fusion_groups_) { + std::unordered_map producers; + std::unordered_map consumers; + + for (auto& producer_and_list : group->producer_groups) { + CHECK(producer_and_list.first->belong_groups.size()); + // TODO: Do not add any TensorInterface into any TensorInterfaceList in this file which will be deprecated. 
+ producers[*producer_and_list.first->belong_groups.begin()] += {}; + } + + for (auto& consumer_and_list : group->consumer_groups) { + CHECK(consumer_and_list.first->belong_groups.size()); + // TODO: Do not add any TensorInterface into any TensorInterfaceList in this file which will be deprecated. + consumers[*consumer_and_list.first->belong_groups.begin()] += {}; + } + CHECK_EQ(group->producer_groups.size(), producers.size()); + CHECK_EQ(group->consumer_groups.size(), consumers.size()); + group->producer_groups = producers; + group->consumer_groups = consumers; + } + } + + void InitFusionRelation() { + VLOG(3) << "InitFusionRelation...!"; + // kElementWise + { + auto& relation = fusion_relation_map_[OpPatternKind::kElementWise]; + // horizontal + relation.horizontal_relation = {{framework::kElementWise, is_same_size}, + // element-wise and broadcast op must be horizontal relation. + {OpPatternKind::kBroadcast, is_same_size}, + // element-wise and injective op must be horizontal relation. + {OpPatternKind::kInjective, is_same_size}, + // element-wise and reduce op must be horizontal relation. + {OpPatternKind::kReduction, honrizontal_elementwise_fuse_reduce}}; + // vertical + relation.vertical_relation = {{OpPatternKind::kElementWise, is_same_size}, + // element-wise and broadcast can be vertical/horizontal relation. + {OpPatternKind::kBroadcast, elementwise_fuse_broadcast}, + // element-wise and injective op must be horizontal relation. + {OpPatternKind::kInjective, horizontal_with_injective}, + // element-wise and reduce can be vertical/horizontal relation. + {OpPatternKind::kReduction, elementwise_fuse_reduce}}; + } + // kBroadcast + { + auto& relation = fusion_relation_map_[OpPatternKind::kBroadcast]; + // horizontal + relation.horizontal_relation = {// broadcast and element-wise op must be horizontal relation. + {framework::kElementWise, is_same_size}, + // broadcast and broadcast op must be horizontal relation. 
+ {framework::kBroadcast, is_same_size}, + // broadcast and injective op must be horizontal relation. + {OpPatternKind::kInjective, is_same_size}, + // broadcast and reduce op must be horizontal relation. + {OpPatternKind::kReduction, is_same_size}}; + // vertical + relation.vertical_relation = {// broadcast and element-wise op must be vertical relation. + {OpPatternKind::kElementWise, is_same_size}, + // broadcast and broadcast op must be horizontal relation. + {OpPatternKind::kBroadcast, is_same_size}, + // broadcast and injective op must be horizontal relation. + {OpPatternKind::kInjective, horizontal_with_injective}, + // broadcast and reduce must be vertical relation. + {OpPatternKind::kReduction, broadcast_fuse_reduce}}; + } + // kInjective + { + auto& relation = fusion_relation_map_[OpPatternKind::kInjective]; + // horizontal + relation.horizontal_relation = {// injective and element-wise op must be horizontal relation. + {OpPatternKind::kElementWise, is_same_size}, + // injective and broadcast op must be horizontal relation. + {OpPatternKind::kBroadcast, is_same_size}, + // injective and injective op must be horizontal relation. + {OpPatternKind::kInjective, is_same_size}, + // injective and reduce must be horizontal relation. + {OpPatternKind::kReduction, is_same_size}}; + // vertical + relation.vertical_relation = {// injective and element-wise op must be horizontal relation. + {OpPatternKind::kElementWise, is_same_size}, + // injective and broadcast op must be horizontal relation. + {OpPatternKind::kBroadcast, is_same_size}, + // injective and injective op must be horizontal relation. + {OpPatternKind::kInjective, horizontal_with_injective}, + // injective and reduce can be horizontal/vertical relation. 
+ {OpPatternKind::kReduction, injective_horizontal_with_reduce}}; + } + // kReduction + { + auto& relation = fusion_relation_map_[OpPatternKind::kReduction]; + // horizontal + relation.horizontal_relation = {// reduce and element-wise op must be horizontal relation. + {OpPatternKind::kElementWise, honrizontal_elementwise_fuse_reduce}, + // reduce and broadcast op must be horizontal relation. + {OpPatternKind::kBroadcast, is_same_size}, + // reduce and injective op must be horizontal relation. + {OpPatternKind::kInjective, is_same_size}, + // reduce and reduce must be horizontal relation. + {OpPatternKind::kReduction, reduce_fuse_reduce}}; + // vertical + relation.vertical_relation = {// reduce and elementwise can be horizontal/vertical relation. + {OpPatternKind::kElementWise, reduce_fuse_elementwise}, + // reduce and broadcast op must be horizontal relation. + {OpPatternKind::kBroadcast, reduce_fuse_broadcast}, + // reduce and injective op must be horizontal relation. + {OpPatternKind::kInjective, horizontal_with_injective}, + // reduce and reduce must be horizontal relation. 
+ {OpPatternKind::kReduction, reduce_fuse_reduce}}; + } + } + + GroupList fusion_groups_; + std::unordered_map fusion_groups_index_; + std::unordered_map> input_to_consumers_; + + struct Relation { + std::unordered_map vertical_relation; + std::unordered_map horizontal_relation; + }; + std::unordered_map fusion_relation_map_; +}; + +} // namespace + +void GeneralFusionMergePassInternal(Graph* graph) { + if (graph->fusion_groups.size() <= 1) { + VLOG(3) << "Don't do Fusoin Merge Pass...!"; + return; + } + + FusionMergePassHelper fusion_merge_pass_helper(graph); + graph->fusion_groups = fusion_merge_pass_helper(); +} + +} // namespace pass +} // namespace hlir +} // namespace cinn + +CINN_REGISTER_HELPER(GeneralFusionMergePass) { + CINN_REGISTER_PASS(GeneralFusionMergePass) + .describe( + "Fusion Merge Pass which performs Fusion-Ops fusion, Producer Fusion-Ops are fused into Consumer Fusion-Ops " + "with certain conditions.") + .set_change_structure(false) + .set_body(cinn::hlir::pass::GeneralFusionMergePassInternal); + + return true; +} diff --git a/cinn/runtime/flags.cc b/cinn/runtime/flags.cc index 649baced8d..03227caf7c 100644 --- a/cinn/runtime/flags.cc +++ b/cinn/runtime/flags.cc @@ -52,6 +52,10 @@ DEFINE_int32(cinn_parallel_compile_thread, DEFINE_bool(cinn_use_op_fusion, BoolFromEnv("FLAGS_cinn_use_op_fusion", true), "Whether to use op fusion pass."); +DEFINE_bool(general_fusion_merge_pass, + BoolFromEnv("FLAGS_general_fusion_merge_pass", true), + "Whether to use general fusion_merge pass."); + DEFINE_bool(cinn_use_common_subexpression_elimination, BoolFromEnv("FLAGS_cinn_use_common_subexpression_elimination", false), "Whether to use common subexpression elimination pass."); From 8b0f9487e91ab69a93ccb57569d489ced1f4df22 Mon Sep 17 00:00:00 2001 From: jiahongyu Date: Mon, 12 Jun 2023 13:15:05 +0000 Subject: [PATCH 12/66] horizontal fuse, code complete, wait for debug --- .../framework => api}/fuse_pass_context.h | 10 +- cinn/api/op_group_interface.h | 46 ++ 
.../framework => api}/tensor_interface.h | 6 +- .../framework => api}/tensor_interface_list.h | 8 +- cinn/common/CMakeLists.txt | 2 +- cinn/hlir/framework/graph.h | 11 +- cinn/hlir/framework/op_group_interface.h | 44 -- cinn/hlir/pass/fusion_merge_pass.cc | 160 +++--- cinn/hlir/pass/general_fusion_merge_pass.cc | 499 ++++++++++-------- cinn/hlir/pass/op_fusion_pass.cc | 17 +- 10 files changed, 443 insertions(+), 360 deletions(-) rename cinn/{hlir/framework => api}/fuse_pass_context.h (88%) create mode 100644 cinn/api/op_group_interface.h rename cinn/{hlir/framework => api}/tensor_interface.h (92%) rename cinn/{hlir/framework => api}/tensor_interface_list.h (90%) delete mode 100644 cinn/hlir/framework/op_group_interface.h diff --git a/cinn/hlir/framework/fuse_pass_context.h b/cinn/api/fuse_pass_context.h similarity index 88% rename from cinn/hlir/framework/fuse_pass_context.h rename to cinn/api/fuse_pass_context.h index 2be1b4e907..ed702fabec 100644 --- a/cinn/hlir/framework/fuse_pass_context.h +++ b/cinn/api/fuse_pass_context.h @@ -14,12 +14,10 @@ #pragma once -#include "cinn/hlir/framework/op_group_interface.h" +#include "cinn/api/op_group_interface.h" namespace cinn { -namespace hlir { -namespace framework { - +namespace api { class FusePassContext { public: @@ -32,9 +30,7 @@ class FusePassContext { void EnableVerticalFuse(const OpGroupInterface& first_op_group, const OpGroupInterface& second_op_group); void EnableHorizontalFuse(const OpGroupInterface& first_op_group, const OpGroupInterface& second_op_group); - }; -} // namespace framework -} // namespace hlir +} // namespace api } // namespace cinn diff --git a/cinn/api/op_group_interface.h b/cinn/api/op_group_interface.h new file mode 100644 index 0000000000..30d6e8db1a --- /dev/null +++ b/cinn/api/op_group_interface.h @@ -0,0 +1,46 @@ +// Copyright (c) 2023 CINN Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +#include "cinn/api/tensor_interface.h" +#include "cinn/api/tensor_interface_list.h" + +namespace cinn { +namespace api { + +class OpGroupInterface { + public: + // virtual const TensorInterfaceList& input_tensors() const = 0; + + // virtual const TensorInterfaceList& output_tensors() const = 0; + + // virtual const std::unordered_set> producers() const = 0; + + // virtual const std::unordered_set> consumers() const = 0; + + virtual const std::unordered_map, TensorInterfaceList>& producer_groups() const = 0; + + virtual const std::unordered_map, TensorInterfaceList>& consumer_groups() const = 0; + + protected: + OpGroupInterface() = default; +}; + +} // namespace api +} // namespace cinn diff --git a/cinn/hlir/framework/tensor_interface.h b/cinn/api/tensor_interface.h similarity index 92% rename from cinn/hlir/framework/tensor_interface.h rename to cinn/api/tensor_interface.h index 843cd0cd3e..2bca4a62df 100644 --- a/cinn/hlir/framework/tensor_interface.h +++ b/cinn/api/tensor_interface.h @@ -17,8 +17,7 @@ #include namespace cinn { -namespace hlir { -namespace framework { +namespace api { class ShapeInterface; @@ -35,6 +34,5 @@ class TensorInterface { using TensorInterfacePtr = std::shared_ptr; -} // namespace framework -} // namespace hlir +} // namespace api } // namespace cinn diff --git a/cinn/hlir/framework/tensor_interface_list.h 
b/cinn/api/tensor_interface_list.h similarity index 90% rename from cinn/hlir/framework/tensor_interface_list.h rename to cinn/api/tensor_interface_list.h index beb3f32877..0a0a2121f3 100644 --- a/cinn/hlir/framework/tensor_interface_list.h +++ b/cinn/api/tensor_interface_list.h @@ -17,12 +17,11 @@ #include #include -#include "cinn/hlir/framework/tensor_interface.h" +#include "cinn/api/tensor_interface.h" #include "cinn/utils/small_vector.h" namespace cinn { -namespace hlir { -namespace framework { +namespace api { class TensorInterfaceList : public cinn::utils::SmallVector { public: @@ -40,6 +39,5 @@ class TensorInterfaceList : public cinn::utils::SmallVector #include -#include "cinn/api/cpp/op_group_interface.h" +#include "cinn/api/op_group_interface.h" +#include "cinn/api/tensor_interface_list.h" #include "cinn/common/graph_utils.h" #include "cinn/frontend/syntax.h" #include "cinn/hlir/framework/node.h" -#include "cinn/hlir/framework/tensor_interface_list.h" namespace cinn { namespace hlir { namespace framework { +using OpGroupInterface = cinn::api::OpGroupInterface; +using TensorInterfaceList = cinn::api::TensorInterfaceList; + /** * \brief Symbolic computation graph. * This is the intermediate representation for optimization pass. @@ -92,8 +95,8 @@ class Graph : public cinn::common::Graph { std::unordered_set> CollectConsumerGroups() { std::unordered_set> groups; - for (const auto& consumer_and_list : consumer_groups) { - groups.insert(consumer_and_list.first); + for (const auto& consumer_and_list : consumer_groups_) { + groups.insert(std::dynamic_pointer_cast(consumer_and_list.first)); } return groups; } diff --git a/cinn/hlir/framework/op_group_interface.h b/cinn/hlir/framework/op_group_interface.h deleted file mode 100644 index a7ef09d46d..0000000000 --- a/cinn/hlir/framework/op_group_interface.h +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright (c) 2023 CINN Authors. All Rights Reserved. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include -#include - -#include "cinn/hlir/framework/tensor_interface.h" -#include "cinn/hlir/framework/tensor_interface_list.h" - -namespace cinn { -namespace hlir { -namespace framework { - - -class OpGroupInterface { - public: - virtual const TensorInterfaceList& input_tensors() const = 0; - - virtual const TensorInterfaceList& output_tensors() const = 0; - - virtual const std::unordered_set> producers() const = 0; - - virtual const std::unordered_set> consumers() const = 0; - - protect: - OpGroupInterface() = default; -}; - -} // namespace framework -} // namespace hlir -} // namespace cinn diff --git a/cinn/hlir/pass/fusion_merge_pass.cc b/cinn/hlir/pass/fusion_merge_pass.cc index dd7508d299..957447942d 100644 --- a/cinn/hlir/pass/fusion_merge_pass.cc +++ b/cinn/hlir/pass/fusion_merge_pass.cc @@ -32,6 +32,9 @@ using common::GraphNode; using GroupPtr = std::shared_ptr; using GroupList = std::vector; +using OpGroupPtr = std::shared_ptr; +using OpGroupList = std::vector; + using ConditionFunction = std::function; // Op Fusion Pass which performs Ops fusion, Ops are fused @@ -58,13 +61,13 @@ class FusionMergePassHelper : public FusionHelperBase { for (auto& sub_group : group->fused_sub_groups) { VLOG(3) << " Fused Sub-Group -> " << sub_group->group_id; } - for (const auto& pair : group->producer_groups) { + for (const auto& pair : group->producer_groups()) { const auto& producer = 
pair.first; - VLOG(3) << " Producer -> " << producer->group_id; + VLOG(3) << " Producer -> " << std::dynamic_pointer_cast(producer)->group_id; } - for (const auto& pair : group->consumer_groups) { + for (const auto& pair : group->consumer_groups()) { const auto& consumer = pair.first; - VLOG(3) << " Consumer -> " << consumer->group_id; + VLOG(3) << " Consumer -> " << std::dynamic_pointer_cast(consumer)->group_id; } } return fusion_groups_; @@ -149,8 +152,8 @@ class FusionMergePassHelper : public FusionHelperBase { } bool exist = false; - for (const auto& pair : group->producer_groups) { - const auto& producer = pair.first; + for (const auto& pair : group->producer_groups()) { + const auto& producer = std::dynamic_pointer_cast(pair.first); if (fusion_groups_set.count(producer)) { VLOG(4) << group->group_id << " " << producer->group_id; exist = true; @@ -315,20 +318,22 @@ class FusionMergePassHelper : public FusionHelperBase { fused_group->fused_sub_groups.push_back(consumer); } // producer group - for (const auto& producer_and_list : consumer->producer_groups) { - fused_group->producer_groups[producer_and_list.first] += producer_and_list.second; + for (const auto& producer_and_list : consumer->producer_groups()) { + GroupPtr producer = std::dynamic_pointer_cast(producer_and_list.first); + (*fused_group->mut_producer_groups())[producer] += producer_and_list.second; // update producer's consumer - producer_and_list.first->consumer_groups.erase(consumer); + producer->mut_consumer_groups()->erase(consumer); // TODO: Do not add any TensorInterface into any TensorInterfaceList in this file which will be deprecated. 
- producer_and_list.first->consumer_groups[fused_group] += {}; + (*producer->mut_consumer_groups())[fused_group] += {}; } // consumer group - for (const auto& gconsumer_and_list : consumer->consumer_groups) { - fused_group->consumer_groups[gconsumer_and_list.first] += gconsumer_and_list.second; + for (const auto& gconsumer_and_list : consumer->consumer_groups()) { + GroupPtr gconsumer = std::dynamic_pointer_cast(gconsumer_and_list.first); + (*fused_group->mut_consumer_groups())[gconsumer] += gconsumer_and_list.second; // update consumer's producer - gconsumer_and_list.first->producer_groups.erase(consumer); + gconsumer->mut_producer_groups()->erase(consumer); // TODO: Do not add any TensorInterface into any TensorInterfaceList in this file which will be deprecated. - gconsumer_and_list.first->producer_groups[fused_group] += {}; + (*gconsumer->mut_producer_groups())[fused_group] += {}; } // belongs group consumer->belong_groups.insert(fused_group); @@ -434,7 +439,7 @@ class FusionMergePassHelper : public FusionHelperBase { // if can_fuse_consumers == consumers // if producer op kind == kElementwise // if use recompute - if (fuse_consumers_unsafe.size() == producer->consumer_groups.size() && + if (fuse_consumers_unsafe.size() == producer->consumer_groups().size() && producer->op_pattern_kind == framework::kElementWise) { if (!recompute) { return false; @@ -501,12 +506,13 @@ class FusionMergePassHelper : public FusionHelperBase { } // producer groups - for (const auto& group_and_list : producer->producer_groups) { - fused_group->producer_groups[group_and_list.first] += group_and_list.second; + for (const auto& group_and_list : producer->producer_groups()) { + (*fused_group->mut_producer_groups())[group_and_list.first] += group_and_list.second; + const auto& group = std::dynamic_pointer_cast(group_and_list.first); // update producer's producer's consumer - group_and_list.first->consumer_groups.erase(producer); + group->mut_consumer_groups()->erase(producer); // TODO: 
Do not add any TensorInterface into any TensorInterfaceList in this file which will be deprecated. - group_and_list.first->consumer_groups[fused_group] += {}; + (*group->mut_consumer_groups())[fused_group] += {}; } // sub groups @@ -552,23 +558,25 @@ class FusionMergePassHelper : public FusionHelperBase { } // producer nodes - for (const auto& group_and_list : consumer->producer_groups) { + for (const auto& group_and_list : consumer->producer_groups()) { if (group_and_list.first.get() != producer.get()) { - fused_group->producer_groups[group_and_list.first] += group_and_list.second; + (*fused_group->mut_producer_groups())[group_and_list.first] += group_and_list.second; + const GroupPtr& group = std::dynamic_pointer_cast(group_and_list.first); // update consumer's producer's consumer - group_and_list.first->consumer_groups.erase(consumer); + group->mut_consumer_groups()->erase(consumer); // TODO: Do not add any TensorInterface into any TensorInterfaceList in this file which will be deprecated. - group_and_list.first->consumer_groups[fused_group] += {}; + (*group->mut_consumer_groups())[fused_group] += {}; } } // consumer nodes - for (const auto& group_and_list : consumer->consumer_groups) { - fused_group->consumer_groups[group_and_list.first] += group_and_list.second; + for (const auto& group_and_list : consumer->consumer_groups()) { + (*fused_group->mut_consumer_groups())[group_and_list.first] += group_and_list.second; + const GroupPtr& group = std::dynamic_pointer_cast(group_and_list.first); // update consumer's consumer's producer - group_and_list.first->producer_groups.erase(consumer); + group->mut_producer_groups()->erase(consumer); // TODO: Do not add any TensorInterface into any TensorInterfaceList in this file which will be deprecated. 
- group_and_list.first->producer_groups[fused_group] += {}; + (*group->mut_producer_groups())[fused_group] += {}; } // sub group @@ -602,16 +610,17 @@ class FusionMergePassHelper : public FusionHelperBase { for (auto& node : producer->output_nodes) { bool be_output = true; - for (const auto& consumer_and_list : producer->consumer_groups) { + for (const auto& consumer_and_list : producer->consumer_groups()) { + const auto& consumer = std::dynamic_pointer_cast(consumer_and_list.first); // if consumer is in fusionable. - if (fusionable_consumers.count(consumer_and_list.first)) { - if (consumer_and_list.first->input_nodes.count(node)) { + if (fusionable_consumers.count(consumer)) { + if (consumer->input_nodes.count(node)) { be_output = false; } continue; } // if consumer is not in fusionable. - if (consumer_and_list.first->input_nodes.count(node)) { + if (consumer->input_nodes.count(node)) { be_output = true; break; } @@ -628,15 +637,16 @@ class FusionMergePassHelper : public FusionHelperBase { } } // insert unfusionable consumer groups - for (const auto& consumer_and_list : producer->consumer_groups) { - if (fusionable_consumers.count(consumer_and_list.first)) { + for (const auto& consumer_and_list : producer->consumer_groups()) { + const auto& consumer = std::dynamic_pointer_cast(consumer_and_list.first); + if (fusionable_consumers.count(consumer)) { continue; } - master_fuesd_group->consumer_groups[consumer_and_list.first] += consumer_and_list.second; + (*master_fuesd_group->mut_consumer_groups())[consumer_and_list.first] += consumer_and_list.second; // update consumer's producer - consumer_and_list.first->producer_groups.erase(producer); + consumer->mut_producer_groups()->erase(producer); // TODO: Do not add any TensorInterface into any TensorInterfaceList in this file which will be deprecated. 
- consumer_and_list.first->producer_groups[master_fuesd_group] += {}; + (*consumer->mut_producer_groups())[master_fuesd_group] += {}; } } @@ -664,13 +674,13 @@ class FusionMergePassHelper : public FusionHelperBase { sub_group->nodes_set.insert(producer->CollectNodes()[0]); // remove depency. consumer->input_nodes.erase(producer->CollectNodes()[0]); - consumer->producer_groups.erase(producer); - producer->consumer_groups.erase(consumer); + consumer->mut_producer_groups()->erase(producer); + producer->mut_consumer_groups()->erase(consumer); } } - CHECK_GE(producer->consumer_groups.size(), candidates.size()); - if (producer->consumer_groups.size() == 0 && candidates.size() == 0 && + CHECK_GE(producer->consumer_groups().size(), candidates.size()); + if (producer->consumer_groups().size() == 0 && candidates.size() == 0 && output_nodes_set_.count(producer->CollectNodes()[0]) == 0) { producer->belong_groups.insert(*fusionable_consumers.begin()); } @@ -679,7 +689,7 @@ class FusionMergePassHelper : public FusionHelperBase { return; } // 1 to 1 fusion. 
- if (producer->consumer_groups.size() == 1) { + if (producer->consumer_groups().size() == 1) { return; } @@ -752,16 +762,17 @@ class FusionMergePassHelper : public FusionHelperBase { while (!candidates.empty()) { auto& candidate = candidates.front(); candidates.pop(); - for (const auto& producer_and_list : candidate->producer_groups) { + for (const auto& producer_and_list : candidate->producer_groups()) { if (producer_and_list.first.get() == producer_g.get()) { continue; } - if (consumers.count(producer_and_list.first)) { + const auto& producer = std::dynamic_pointer_cast(producer_and_list.first); + if (consumers.count(producer)) { return true; } - if (!visited_set.count(producer_and_list.first)) { - visited_set.insert(producer_and_list.first); - candidates.push(producer_and_list.first); + if (!visited_set.count(producer)) { + visited_set.insert(producer); + candidates.push(producer); } } } @@ -779,19 +790,20 @@ class FusionMergePassHelper : public FusionHelperBase { while (!candidates.empty()) { auto& candidate = candidates.front(); candidates.pop(); - for (auto& producer_and_list : candidate->producer_groups) { + for (auto& producer_and_list : candidate->producer_groups()) { if (producer_and_list.first.get() == producer_g.get()) { continue; } - if (producer_and_list.first->min_depth > check_upper_depth) { + const auto& producer = std::dynamic_pointer_cast(producer_and_list.first); + if (producer->min_depth > check_upper_depth) { continue; } - if (consumers.count(producer_and_list.first)) { + if (consumers.count(producer)) { return true; } - if (!visited_set.count(producer_and_list.first)) { - visited_set.insert(producer_and_list.first); - candidates.push(producer_and_list.first); + if (!visited_set.count(producer)) { + visited_set.insert(producer); + candidates.push(producer); } } } @@ -873,15 +885,15 @@ class FusionMergePassHelper : public FusionHelperBase { auto group = fusion_groups_[idx]; auto belong_group = std::make_shared(); // copy from group. 
- belong_group->max_depth = group->depth; - belong_group->min_depth = group->depth; - belong_group->group_id = group->group_id; - belong_group->input_nodes = group->input_nodes; - belong_group->output_nodes = group->output_nodes; - belong_group->op_pattern_kind = group->op_pattern_kind; - belong_group->master_nodes = group->master_nodes; - belong_group->producer_groups = group->producer_groups; - belong_group->consumer_groups = group->consumer_groups; + belong_group->max_depth = group->depth; + belong_group->min_depth = group->depth; + belong_group->group_id = group->group_id; + belong_group->input_nodes = group->input_nodes; + belong_group->output_nodes = group->output_nodes; + belong_group->op_pattern_kind = group->op_pattern_kind; + belong_group->master_nodes = group->master_nodes; + (*belong_group->mut_producer_groups()) = group->producer_groups(); + (*belong_group->mut_consumer_groups()) = group->consumer_groups(); belong_group->fused_sub_groups.push_back(group); group->belong_groups.insert(belong_group); // replace group to fused_group @@ -892,24 +904,26 @@ class FusionMergePassHelper : public FusionHelperBase { // update producer and consumer. for (auto& group : fusion_groups_) { - std::unordered_map producers; - std::unordered_map consumers; + std::unordered_map producers; + std::unordered_map consumers; - for (auto& producer_and_list : group->producer_groups) { - CHECK(producer_and_list.first->belong_groups.size()); + for (auto& producer_and_list : group->producer_groups()) { + const auto& producer = std::dynamic_pointer_cast(producer_and_list.first); + CHECK(producer->belong_groups.size()); // TODO: Do not add any TensorInterface into any TensorInterfaceList in this file which will be deprecated. 
- producers[*producer_and_list.first->belong_groups.begin()] += {}; + producers[*producer->belong_groups.begin()] += {}; } - for (auto& consumer_and_list : group->consumer_groups) { - CHECK(consumer_and_list.first->belong_groups.size()); + for (auto& consumer_and_list : group->consumer_groups()) { + const auto& consumer = std::dynamic_pointer_cast(consumer_and_list.first); + CHECK(consumer->belong_groups.size()); // TODO: Do not add any TensorInterface into any TensorInterfaceList in this file which will be deprecated. - consumers[*consumer_and_list.first->belong_groups.begin()] += {}; + consumers[*consumer->belong_groups.begin()] += {}; } - CHECK_EQ(group->producer_groups.size(), producers.size()); - CHECK_EQ(group->consumer_groups.size(), consumers.size()); - group->producer_groups = producers; - group->consumer_groups = consumers; + CHECK_EQ(group->producer_groups().size(), producers.size()); + CHECK_EQ(group->consumer_groups().size(), consumers.size()); + (*group->mut_producer_groups()) = producers; + (*group->mut_consumer_groups()) = consumers; } } diff --git a/cinn/hlir/pass/general_fusion_merge_pass.cc b/cinn/hlir/pass/general_fusion_merge_pass.cc index cbf33de10d..d920436869 100644 --- a/cinn/hlir/pass/general_fusion_merge_pass.cc +++ b/cinn/hlir/pass/general_fusion_merge_pass.cc @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+#include "cinn/api/op_group_interface.h" +#include "cinn/common/is_reachable_predicator.h" #include "cinn/hlir/pass/fusion_merge_pass_util.h" DECLARE_bool(enhance_vertical_fusion_with_recompute); @@ -38,6 +40,202 @@ using OpGroupList = std::vector; using ConditionFunction = std::function; +class GraphGroupLightwareFusePassCtx; +class FuseHelper { + public: + virtual ~FuseHelper() = default; + + virtual bool AllOutputsSameSize(const OpGroupPtr& first, const OpGroupPtr& second) const = 0; + + virtual bool HorizontalElementwiseFuseReduce(const OpGroupPtr& src, const OpGroupPtr& dst) const = 0; + + virtual bool ReduceFuseReduce(const OpGroupPtr& src, const OpGroupPtr& dst) const = 0; + + virtual bool DetectCycleIfFuse(const OpGroupPtr& src, const OpGroupPtr& dst) const = 0; + + protected: + FuseHelper() = default; +}; + +class GraphGroupFuseHelper final : public FuseHelper { + public: + explicit GraphGroupFuseHelper(const GraphGroupLightwareFusePassCtx* ctx) : ctx_(ctx) {} + + bool AllOutputsSameSize(const OpGroupPtr& first, const OpGroupPtr& second) const override { + return is_same_size(&ctx_->graph_group_fusion_helper(), first, second); + } + + bool HorizontalElementwiseFuseReduce(const OpGroupPtr& src, const OpGroupPtr& dst) const override { + return horizontal_elementwise_fuse_reduce(&ctx_->graph_group_fusion_helper(), src, dst); + } + + bool ReduceFuseReduce(const OpGroupPtr& src, const OpGroupPtr& dst) const override { + return reduce_fuse_reduce(&ctx_->graph_group_fusion_helper(), src, dst); + } + + bool DetectCycleIfFuse(const OpGroupPtr& lhs, const OpGroupPtr& rhs) const override { + return ReachableIfDirectEdgeIgnored(lhs, rhs) || ReachableIfDirectEdgeIgnored(rhs, lhs); + } + + private: + bool ReachableIfDirectEdgeIgnored(const OpGroupPtr& src, const OpGroupPtr& dst) const override { + const auto& MinDepth4Node = [&](OpGroupPtr node) { return std::dynamic_pointer_cast(node)->min_depth; }; + const auto& MaxDepth4Node = [&](OpGroupPtr node) { return 
std::dynamic_pointer_cast(node)->max_depth; }; + const auto& VisitNextNodes = [&](OpGroupPtr node, const std::function& Visit) { + for (const auto& pair : node->consumer2outputs()) { + if (node == src && pair.first == dst) { + continue; + } + Visit(pair.first); + } + }; + common::IsReachablePredicator is_reachable(MinDepth4Node, MaxDepth4Node, VisitNextNodes); + return is_reachable(src, dst, [](OpGroupPtr) {}); + } + + const GraphGroupLightwareFusePassCtx* ctx_; +}; + +class LightwareFusePassCtx { + public: + virtual ~LightwareFusePassCtx() {} + + virtual const OpGroupPtr& PickOpGroup() const = 0; + + virtual const FuseHelper& fuse_helper() const = 0; + + virtual void EnableFuse(const OpGroupPtr& first, const OpGroupPtr& second) = 0; + + protected: + LightwareFusePassCtx() = default; +}; + +class GraphGroupLightwareFusePassCtx final : public LightwareFusePassCtx { + public: + GraphGroupLightwareFusePassCtx( + const FusionHelperBase* graph_group_fusion_helper, + const OpGroupPtr& group, + const std::function& EnableFuse) + : graph_group_fusion_helper_(graph_group_fusion_helper), + group_(group), + EnableFuse_(EnableFuse), + fuse_helper_(this) {} + + const OpGroupPtr& PickOpGroup() const override { return group_; } + + const FuseHelper& fuse_helper() const override { return fuse_helper_; } + + void EnableFuse(const OpGroupPtr& first, const OpGroupPtr& second) override { EnableFuse_(first, second); } + + const FusionHelperBase& graph_group_fusion_helper() const { return *graph_group_fusion_helper_; } + + private: + const FusionHelperBase* graph_group_fusion_helper_; + const OpGroupPtr group_; + const std::function EnableFuse_; + const GraphGroupFuseHelper fuse_helper_; +}; + +class FusePass { + public: + virtual ~FusePass() = default; + + virtual void operator()(LightwareFusePassCtx* ctx) const = 0; + + protected: + FusePass(); +}; + +class DefautlHorizontalFusePass final : public FusePass { + public: + DefautlHorizontalFusePass() : FusePass() {} + + void 
operator()(LightwareFusePassCtx* ctx) const override { + const auto& producer = ctx->PickOpGroup(); + const OpGroupList consumers = [&]() { + OpGroupList consumers; + for (const auto& pair : producer->consumer2outputs()) { + consumers.insert(pair.first); + } + return consumers; + }(); + if (consumers.size() <= 1) { + return; + } + for (int i = 0; i < consumers.size(); ++i) { + const auto& src = consumers.at(i); + for (int j = i + 1; j < consumers.size(); ++j) { + const auto& dst = consumers.at(j); + if (ctx->fuse_helper().DetectCycleIfFuse(src, dst)) { + continue; + } + if (!DetectFusabilityByKind(ctx, src, dst)) { + continue; + } + ctx->EnableFuse(src, dst); + return; + } + } + } + + using KindKeyT = std::pair; + + bool DetectFusabilityByKind(LightwareFusePassCtx* ctx, const OpGroupPtr& src, const OpGroupPtr& dst) const { + const KindKeyT kind_pair(src->kind(), dst->kind()); + const auto& map = GetConditionMap(); + const auto& iter = map.find(kind_pair); + if (iter == map.end()) { + return false; + } + return iter->second(ctx, src, dst); + } + + typedef bool (*ConditionT)(LightwareFusePassCtx* ctx, const OpGroupPtr& src, const OpGroupPtr& dst); + + const std::unordered_map& GetConditionMap() const { + thread_local static std::unordered_map map(RawConditionMap()); + return map; + } + + std::unordered_map RawConditionMap() const { + return std::unordered_map{ + {{OpPatternKind::kElementWise, framework::kElementWise}, &DefautlHorizontalFusePass::IsSameSize}, + {{OpPatternKind::kElementWise, framework::kBroadcast}, &DefautlHorizontalFusePass::IsSameSize}, + {{OpPatternKind::kElementWise, framework::kInjective}, &DefautlHorizontalFusePass::IsSameSize}, + {{OpPatternKind::kElementWise, framework::kReduction}, + &DefautlHorizontalFusePass::HorizontalElementwiseFuseReduce}, + + {{OpPatternKind::kBroadcast, framework::kElementWise}, &DefautlHorizontalFusePass::IsSameSize}, + {{OpPatternKind::kBroadcast, framework::kBroadcast}, &DefautlHorizontalFusePass::IsSameSize}, + 
{{OpPatternKind::kBroadcast, framework::kInjective}, &DefautlHorizontalFusePass::IsSameSize}, + {{OpPatternKind::kBroadcast, framework::kReduction}, &DefautlHorizontalFusePass::IsSameSize}, + + {{OpPatternKind::kInjective, framework::kElementWise}, &DefautlHorizontalFusePass::IsSameSize}, + {{OpPatternKind::kInjective, framework::kBroadcast}, &DefautlHorizontalFusePass::IsSameSize}, + {{OpPatternKind::kInjective, framework::kInjective}, &DefautlHorizontalFusePass::IsSameSize}, + {{OpPatternKind::kInjective, framework::kReduction}, &DefautlHorizontalFusePass::IsSameSize}, + + {{OpPatternKind::kReduction, framework::kElementWise}, + &DefautlHorizontalFusePass::HorizontalElementwiseFuseReduce}, + {{OpPatternKind::kReduction, framework::kBroadcast}, &DefautlHorizontalFusePass::IsSameSize}, + {{OpPatternKind::kReduction, framework::kInjective}, &DefautlHorizontalFusePass::IsSameSize}, + {{OpPatternKind::kReduction, framework::kReduction}, &DefautlHorizontalFusePass::ReduceFuseReduce}, + }; + } + + static bool IsSameSize(LightwareFusePassCtx* ctx, const OpGroupPtr& src, const OpGroupPtr& dst) { + return ctx->fuse_helper().AllOutputsSameSize(src, dst); + } + + static bool HorizontalElementwiseFuseReduce(LightwareFusePassCtx* ctx, const OpGroupPtr& src, const OpGroupPtr& dst) { + return ctx->fuse_helper().HorizontalElementwiseFuseReduce(src, dst); + } + + static bool ReduceFuseReduce(LightwareFusePassCtx* ctx, const OpGroupPtr& src, const OpGroupPtr& dst) { + return ctx->fuse_helper().ReduceFuseReduce(src, dst); + } +}; + // Op Fusion Pass which performs Ops fusion, Ops are fused // "vertically", meaning producing Ops are fused into their consumers // with the intent that the loops which compute their values will be fused in @@ -62,12 +260,12 @@ class FusionMergePassHelper : public FusionHelperBase { for (auto& sub_group : group->fused_sub_groups) { VLOG(3) << " Fused Sub-Group -> " << sub_group->group_id; } - for (const auto& pair : group->producer_groups) { - const auto& 
producer = pair.first; + for (const auto& pair : group->producer_groups()) { + const auto& producer = std::dynamic_pointer_cast(pair.first); VLOG(3) << " Producer -> " << producer->group_id; } - for (const auto& pair : group->consumer_groups) { - const auto& consumer = pair.first; + for (const auto& pair : group->consumer_groups()) { + const auto& consumer = std::dynamic_pointer_cast(pair.first); VLOG(3) << " Consumer -> " << consumer->group_id; } } @@ -153,8 +351,8 @@ class FusionMergePassHelper : public FusionHelperBase { } bool exist = false; - for (const auto& pair : group->producer_groups) { - const auto& producer = pair.first; + for (const auto& pair : group->producer_groups()) { + const auto& producer = std::dynamic_pointer_cast(pair.first); if (fusion_groups_set.count(producer)) { VLOG(4) << group->group_id << " " << producer->group_id; exist = true; @@ -177,144 +375,6 @@ class FusionMergePassHelper : public FusionHelperBase { } } - struct LightwareFusePassCtx; - - struct FuseHelper final { - public: - explicit FuseHelper(LightwareFusePassCtx* ctx) : ctx_(ctx) {} - - bool AllOutputsSameSize(const OpGroupPtr& first, const OpGroupPtr& second) const { TODO(); } - - bool HorizontalElementwiseFuseReduce(const OpGroupPtr& src, const OpGroupPtr& dst) { TODO(); } - - bool ReduceFuseReduce(const OpGroupPtr& src, const OpGroupPtr& dst) { TODO(); } - - bool DetectCycleIfFuse(const OpGroupPtr& src, const OpGroupPtr& dst) const { TODO(); } - - private: - LightwareFusePassCtx* ctx_; - }; - - struct LightwareFusePassCtx final { - public: - LightwareFusePassCtx(const OpGroupPtr& group, - const std::function& EnableFuse) - : group_(group), EnableFuse_(EnableFuse), fuse_helper_(this) {} - - const OpGroupPtr& PickOpGroup() const { return group_; } - - const FuseHelper& fuse_helper() const { return fuse_helper_; } - - void EnableFuse(const OpGroupPtr& first, const OpGroupPtr& second) { EnableFuse_(first, second); } - - private: - const OpGroupPtr group_; - const std::function 
EnableFuse_; - const FuseHelper fuse_helper_; - }; - - class FusePass { - public: - virtual ~FusePass() = default; - - virtual void operator()(LightwareFusePassCtx* ctx) const = 0; - - protected: - FusePass(); - }; - - class DefautlHorizontalFusePass final : public FusePass { - public: - DefautlHorizontalFusePass() : FusePass() {} - - void operator()(LightwareFusePassCtx* ctx) const override { - const auto& producer = ctx->PickOpGroup(); - const OpGroupList consumers = [&]() { - OpGroupList consumers; - for (const auto& pair : producer->consumer2outputs()) { - consumers.insert(pair.first); - } - return consumers; - }(); - if (consumers.size() <= 1) { - return; - } - for (int i = 0; i < consumers.size(); ++i) { - const auto& src = consumers.at(i); - for (int j = i + 1; j < consumers.size(); ++j) { - const auto& dst = consumers.at(j); - if (ctx->fuse_helper().DetectCycleIfFuse(src, dst)) { - continue; - } - if (!DetectFusabilityByKind(ctx, src, dst)) { - continue; - } - ctx->EnableFuse(src, dst); - return; - } - } - } - - using KindKeyT = std::pair; - - bool DetectFusabilityByKind(LightwareFusePassCtx* ctx, const OpGroupPtr& src, const OpGroupPtr& dst) const { - const KindKeyT kind_pair(src->kind(), dst->kind()); - const auto& map = GetConditionMap(); - const auto& iter = map.find(kind_pair); - if (iter == map.end()) { - return false; - } - return iter->second(ctx, src, dst); - } - - typedef bool (*ConditionT)(LightwareFusePassCtx* ctx, const OpGroupPtr& src, const OpGroupPtr& dst); - - const std::unordered_map& GetConditionMap() const { - thread_local static std::unordered_map map(RawConditionMap()); - return map; - } - - std::unordered_map RawConditionMap() const { - return std::unordered_map{ - {{OpPatternKind::kElementWise, framework::kElementWise}, &DefautlHorizontalFusePass::IsSameSize}, - {{OpPatternKind::kElementWise, framework::kBroadcast}, &DefautlHorizontalFusePass::IsSameSize}, - {{OpPatternKind::kElementWise, framework::kInjective}, 
&DefautlHorizontalFusePass::IsSameSize}, - {{OpPatternKind::kElementWise, framework::kReduction}, - &DefautlHorizontalFusePass::HorizontalElementwiseFuseReduce}, - - {{OpPatternKind::kBroadcast, framework::kElementWise}, &DefautlHorizontalFusePass::IsSameSize}, - {{OpPatternKind::kBroadcast, framework::kBroadcast}, &DefautlHorizontalFusePass::IsSameSize}, - {{OpPatternKind::kBroadcast, framework::kInjective}, &DefautlHorizontalFusePass::IsSameSize}, - {{OpPatternKind::kBroadcast, framework::kReduction}, &DefautlHorizontalFusePass::IsSameSize}, - - {{OpPatternKind::kInjective, framework::kElementWise}, &DefautlHorizontalFusePass::IsSameSize}, - {{OpPatternKind::kInjective, framework::kBroadcast}, &DefautlHorizontalFusePass::IsSameSize}, - {{OpPatternKind::kInjective, framework::kInjective}, &DefautlHorizontalFusePass::IsSameSize}, - {{OpPatternKind::kInjective, framework::kReduction}, &DefautlHorizontalFusePass::IsSameSize}, - - {{OpPatternKind::kReduction, framework::kElementWise}, - &DefautlHorizontalFusePass::HorizontalElementwiseFuseReduce}, - {{OpPatternKind::kReduction, framework::kBroadcast}, &DefautlHorizontalFusePass::IsSameSize}, - {{OpPatternKind::kReduction, framework::kInjective}, &DefautlHorizontalFusePass::IsSameSize}, - {{OpPatternKind::kReduction, framework::kReduction}, &DefautlHorizontalFusePass::ReduceFuseReduce}, - }; - } - - static bool IsSameSize(LightwareFusePassCtx* ctx, const OpGroupPtr& src, const OpGroupPtr& dst) { - return ctx->fuse_helper().AllOutputsSameSize(src, dst); - } - - static bool HorizontalElementwiseFuseReduce(LightwareFusePassCtx* ctx, - const OpGroupPtr& src, - const OpGroupPtr& dst) { - return ctx->fuse_helper().HorizontalElementwiseFuseReduce(src, dst); - } - - static bool ReduceFuseReduce(LightwareFusePassCtx* ctx, const OpGroupPtr& src, const OpGroupPtr& dst) { - return ctx->fuse_helper().ReduceFuseReduce(src, dst); - } - }; - std::vector> RawHorizontalFusePasses() const { std::vector>{ std::shared_ptr(new 
DefautlHorizontalFusePass{}), @@ -327,7 +387,7 @@ class FusionMergePassHelper : public FusionHelperBase { return fuse_passes; } - void TagHorizontalGroups(LightwareFusePassCtx* ctx) const { + void EnableFusedHorizontalGroups(LightwareFusePassCtx* ctx) const { const auto& producer = ctx->PickOpGroup(); if (producer->consumer2outputs().size() <= 1) { return; @@ -346,8 +406,8 @@ class FusionMergePassHelper : public FusionHelperBase { const auto& EnableFuse = [&](const GroupPtr& first, const GroupPtr& second) { tagged_sets.insert(std::set{first, second}); }; - LightwareFusePassCtx fuse_ctx(producer, EnableFuse); - TagHorizontalGroups(&fuse_ctx); + GraphGroupLightwareFusePassCtx fuse_ctx(producer, EnableFuse); + EnableFusedHorizontalGroups(&fuse_ctx); return tagged_sets; }; const auto& GetFusableConsumerGroupList = [&]() -> GroupList { @@ -510,20 +570,22 @@ class FusionMergePassHelper : public FusionHelperBase { fused_group->fused_sub_groups.push_back(consumer); } // producer group - for (const auto& producer_and_list : consumer->producer_groups) { - fused_group->producer_groups[producer_and_list.first] += producer_and_list.second; + for (const auto& producer_and_list : consumer->producer_groups()) { + GroupPtr producer = std::dynamic_pointer_cast(producer_and_list.first); + (*fused_group->mut_producer_groups())[producer] += producer_and_list.second; // update producer's consumer - producer_and_list.first->consumer_groups.erase(consumer); + producer->mut_consumer_groups()->erase(consumer); // TODO: Do not add any TensorInterface into any TensorInterfaceList in this file which will be deprecated. 
- producer_and_list.first->consumer_groups[fused_group] += {}; + (*producer->mut_consumer_groups())[fused_group] += {}; } // consumer group - for (const auto& gconsumer_and_list : consumer->consumer_groups) { - fused_group->consumer_groups[gconsumer_and_list.first] += gconsumer_and_list.second; + for (const auto& gconsumer_and_list : consumer->consumer_groups()) { + GroupPtr gconsumer = std::dynamic_pointer_cast(gconsumer_and_list.first); + (*fused_group->mut_consumer_groups())[gconsumer] += gconsumer_and_list.second; // update consumer's producer - gconsumer_and_list.first->producer_groups.erase(consumer); + gconsumer->mut_producer_groups()->erase(consumer); // TODO: Do not add any TensorInterface into any TensorInterfaceList in this file which will be deprecated. - gconsumer_and_list.first->producer_groups[fused_group] += {}; + (*gconsumer->mut_producer_groups())[fused_group] += {}; } // belongs group consumer->belong_groups.insert(fused_group); @@ -629,7 +691,7 @@ class FusionMergePassHelper : public FusionHelperBase { // if can_fuse_consumers == consumers // if producer op kind == kElementwise // if use recompute - if (fuse_consumers_unsafe.size() == producer->consumer_groups.size() && + if (fuse_consumers_unsafe.size() == producer->consumer_groups().size() && producer->op_pattern_kind == framework::kElementWise) { if (!recompute) { return false; @@ -696,12 +758,13 @@ class FusionMergePassHelper : public FusionHelperBase { } // producer groups - for (const auto& group_and_list : producer->producer_groups) { - fused_group->producer_groups[group_and_list.first] += group_and_list.second; + for (const auto& group_and_list : producer->producer_groups()) { + (*fused_group->mut_producer_groups())[group_and_list.first] += group_and_list.second; + const auto& group = std::dynamic_pointer_cast(group_and_list.first); // update producer's producer's consumer - group_and_list.first->consumer_groups.erase(producer); + group->mut_consumer_groups()->erase(producer); // TODO: 
Do not add any TensorInterface into any TensorInterfaceList in this file which will be deprecated. - group_and_list.first->consumer_groups[fused_group] += {}; + (*group->mut_consumer_groups())[fused_group] += {}; } // sub groups @@ -747,23 +810,25 @@ class FusionMergePassHelper : public FusionHelperBase { } // producer nodes - for (const auto& group_and_list : consumer->producer_groups) { + for (const auto& group_and_list : consumer->producer_groups()) { if (group_and_list.first.get() != producer.get()) { - fused_group->producer_groups[group_and_list.first] += group_and_list.second; + (*fused_group->mut_producer_groups())[group_and_list.first] += group_and_list.second; + const GroupPtr& group = std::dynamic_pointer_cast(group_and_list.first); // update consumer's producer's consumer - group_and_list.first->consumer_groups.erase(consumer); + group->mut_consumer_groups()->erase(consumer); // TODO: Do not add any TensorInterface into any TensorInterfaceList in this file which will be deprecated. - group_and_list.first->consumer_groups[fused_group] += {}; + (*group->mut_consumer_groups())[fused_group] += {}; } } // consumer nodes - for (const auto& group_and_list : consumer->consumer_groups) { - fused_group->consumer_groups[group_and_list.first] += group_and_list.second; + for (const auto& group_and_list : consumer->consumer_groups()) { + (*fused_group->mut_consumer_groups())[group_and_list.first] += group_and_list.second; + const GroupPtr& group = std::dynamic_pointer_cast(group_and_list.first); // update consumer's consumer's producer - group_and_list.first->producer_groups.erase(consumer); + group->mut_producer_groups()->erase(consumer); // TODO: Do not add any TensorInterface into any TensorInterfaceList in this file which will be deprecated. 
- group_and_list.first->producer_groups[fused_group] += {}; + (*group->mut_producer_groups())[fused_group] += {}; } // sub group @@ -797,16 +862,17 @@ class FusionMergePassHelper : public FusionHelperBase { for (auto& node : producer->output_nodes) { bool be_output = true; - for (const auto& consumer_and_list : producer->consumer_groups) { + for (const auto& consumer_and_list : producer->consumer_groups()) { + const auto& consumer = std::dynamic_pointer_cast(consumer_and_list.first); // if consumer is in fusionable. - if (fusionable_consumers.count(consumer_and_list.first)) { - if (consumer_and_list.first->input_nodes.count(node)) { + if (fusionable_consumers.count(consumer)) { + if (consumer->input_nodes.count(node)) { be_output = false; } continue; } // if consumer is not in fusionable. - if (consumer_and_list.first->input_nodes.count(node)) { + if (consumer->input_nodes.count(node)) { be_output = true; break; } @@ -823,15 +889,16 @@ class FusionMergePassHelper : public FusionHelperBase { } } // insert unfusionable consumer groups - for (const auto& consumer_and_list : producer->consumer_groups) { - if (fusionable_consumers.count(consumer_and_list.first)) { + for (const auto& consumer_and_list : producer->consumer_groups()) { + const auto& consumer = std::dynamic_pointer_cast(consumer_and_list.first); + if (fusionable_consumers.count(consumer)) { continue; } - master_fuesd_group->consumer_groups[consumer_and_list.first] += consumer_and_list.second; + (*master_fuesd_group->mut_consumer_groups())[consumer_and_list.first] += consumer_and_list.second; // update consumer's producer - consumer_and_list.first->producer_groups.erase(producer); + consumer->mut_producer_groups()->erase(producer); // TODO: Do not add any TensorInterface into any TensorInterfaceList in this file which will be deprecated. 
- consumer_and_list.first->producer_groups[master_fuesd_group] += {}; + (*consumer->mut_producer_groups())[master_fuesd_group] += {}; } } @@ -859,13 +926,13 @@ class FusionMergePassHelper : public FusionHelperBase { sub_group->nodes_set.insert(producer->CollectNodes()[0]); // remove depency. consumer->input_nodes.erase(producer->CollectNodes()[0]); - consumer->producer_groups.erase(producer); - producer->consumer_groups.erase(consumer); + consumer->mut_producer_groups()->erase(producer); + producer->mut_consumer_groups()->erase(consumer); } } - CHECK_GE(producer->consumer_groups.size(), candidates.size()); - if (producer->consumer_groups.size() == 0 && candidates.size() == 0 && + CHECK_GE(producer->consumer_groups().size(), candidates.size()); + if (producer->consumer_groups().size() == 0 && candidates.size() == 0 && output_nodes_set_.count(producer->CollectNodes()[0]) == 0) { producer->belong_groups.insert(*fusionable_consumers.begin()); } @@ -874,7 +941,7 @@ class FusionMergePassHelper : public FusionHelperBase { return; } // 1 to 1 fusion. 
- if (producer->consumer_groups.size() == 1) { + if (producer->consumer_groups().size() == 1) { return; } @@ -947,16 +1014,17 @@ class FusionMergePassHelper : public FusionHelperBase { while (!candidates.empty()) { auto& candidate = candidates.front(); candidates.pop(); - for (const auto& producer_and_list : candidate->producer_groups) { + for (const auto& producer_and_list : candidate->producer_groups()) { if (producer_and_list.first.get() == producer_g.get()) { continue; } - if (consumers.count(producer_and_list.first)) { + const auto& producer = std::dynamic_pointer_cast(producer_and_list.first); + if (consumers.count(producer)) { return true; } - if (!visited_set.count(producer_and_list.first)) { - visited_set.insert(producer_and_list.first); - candidates.push(producer_and_list.first); + if (!visited_set.count(producer)) { + visited_set.insert(producer); + candidates.push(producer); } } } @@ -974,19 +1042,20 @@ class FusionMergePassHelper : public FusionHelperBase { while (!candidates.empty()) { auto& candidate = candidates.front(); candidates.pop(); - for (auto& producer_and_list : candidate->producer_groups) { + for (auto& producer_and_list : candidate->producer_groups()) { if (producer_and_list.first.get() == producer_g.get()) { continue; } - if (producer_and_list.first->min_depth > check_upper_depth) { + const auto& producer = std::dynamic_pointer_cast(producer_and_list.first); + if (producer->min_depth > check_upper_depth) { continue; } - if (consumers.count(producer_and_list.first)) { + if (consumers.count(producer)) { return true; } - if (!visited_set.count(producer_and_list.first)) { - visited_set.insert(producer_and_list.first); - candidates.push(producer_and_list.first); + if (!visited_set.count(producer)) { + visited_set.insert(producer); + candidates.push(producer); } } } @@ -1068,15 +1137,15 @@ class FusionMergePassHelper : public FusionHelperBase { auto group = fusion_groups_[idx]; auto belong_group = std::make_shared(); // copy from group. 
- belong_group->max_depth = group->depth; - belong_group->min_depth = group->depth; - belong_group->group_id = group->group_id; - belong_group->input_nodes = group->input_nodes; - belong_group->output_nodes = group->output_nodes; - belong_group->op_pattern_kind = group->op_pattern_kind; - belong_group->master_nodes = group->master_nodes; - belong_group->producer_groups = group->producer_groups; - belong_group->consumer_groups = group->consumer_groups; + belong_group->max_depth = group->depth; + belong_group->min_depth = group->depth; + belong_group->group_id = group->group_id; + belong_group->input_nodes = group->input_nodes; + belong_group->output_nodes = group->output_nodes; + belong_group->op_pattern_kind = group->op_pattern_kind; + belong_group->master_nodes = group->master_nodes; + (*belong_group->mut_producer_groups()) = group->producer_groups(); + (*belong_group->mut_consumer_groups()) = group->consumer_groups(); belong_group->fused_sub_groups.push_back(group); group->belong_groups.insert(belong_group); // replace group to fused_group @@ -1087,24 +1156,26 @@ class FusionMergePassHelper : public FusionHelperBase { // update producer and consumer. for (auto& group : fusion_groups_) { - std::unordered_map producers; - std::unordered_map consumers; + std::unordered_map producers; + std::unordered_map consumers; - for (auto& producer_and_list : group->producer_groups) { - CHECK(producer_and_list.first->belong_groups.size()); + for (auto& producer_and_list : group->producer_groups()) { + const auto& producer = std::dynamic_pointer_cast(producer_and_list.first); + CHECK(producer->belong_groups.size()); // TODO: Do not add any TensorInterface into any TensorInterfaceList in this file which will be deprecated. 
- producers[*producer_and_list.first->belong_groups.begin()] += {}; + producers[*producer->belong_groups.begin()] += {}; } - for (auto& consumer_and_list : group->consumer_groups) { - CHECK(consumer_and_list.first->belong_groups.size()); + for (auto& consumer_and_list : group->consumer_groups()) { + const auto& consumer = std::dynamic_pointer_cast(consumer_and_list.first); + CHECK(consumer->belong_groups.size()); // TODO: Do not add any TensorInterface into any TensorInterfaceList in this file which will be deprecated. - consumers[*consumer_and_list.first->belong_groups.begin()] += {}; + consumers[*consumer->belong_groups.begin()] += {}; } - CHECK_EQ(group->producer_groups.size(), producers.size()); - CHECK_EQ(group->consumer_groups.size(), consumers.size()); - group->producer_groups = producers; - group->consumer_groups = consumers; + CHECK_EQ(group->producer_groups().size(), producers.size()); + CHECK_EQ(group->consumer_groups().size(), consumers.size()); + (*group->mut_producer_groups()) = producers; + (*group->mut_consumer_groups()) = consumers; } } diff --git a/cinn/hlir/pass/op_fusion_pass.cc b/cinn/hlir/pass/op_fusion_pass.cc index 749da7d2c6..759ca24d25 100644 --- a/cinn/hlir/pass/op_fusion_pass.cc +++ b/cinn/hlir/pass/op_fusion_pass.cc @@ -101,17 +101,18 @@ class OpFusionPassHelper : public FusionHelperBase { for (auto& input_node : consumer->input_nodes) { auto& producer = fusion_groups_[input_node.first]; // TODO: Do not add any TensorInterface into any TensorInterfaceList in this file which will be deprecated. - consumer->producer_groups[producer] += {}; + (*consumer->mut_producer_groups())[producer] += {}; // TODO: Do not add any TensorInterface into any TensorInterfaceList in this file which will be deprecated. - producer->consumer_groups[consumer] += {}; + (*producer->mut_consumer_groups())[consumer] += {}; } } // init group depth. 
for (auto& group : fusion_groups) { - for (const auto& consumer_and_list : group->consumer_groups) { + for (const auto& consumer_and_list : group->consumer_groups()) { // update depth. - group->depth = std::max(group->depth, consumer_and_list.first->depth + 1); + const auto& consumer = std::dynamic_pointer_cast(consumer_and_list.first); + group->depth = std::max(group->depth, consumer->depth + 1); } } @@ -350,11 +351,11 @@ void OpFusionPassInternal(Graph* graph) { for (auto& group : graph->fusion_groups) { VLOG(3) << "Group Id : " << group->group_id; - for (const auto& producer_and_list : group->producer_groups) { - VLOG(3) << " producer group -> " << producer_and_list.first->group_id; + for (const auto& producer_and_list : group->producer_groups()) { + VLOG(3) << " producer group -> " << std::dynamic_pointer_cast(producer_and_list.first)->group_id; } - for (const auto& consumer_and_list : group->consumer_groups) { - VLOG(3) << " consumer group -> " << consumer_and_list.first->group_id; + for (const auto& consumer_and_list : group->consumer_groups()) { + VLOG(3) << " consumer group -> " << std::dynamic_pointer_cast(consumer_and_list.first)->group_id; } } VLOG(3) << "OpFusionPass Finish...!"; From 59e3e4d2fe19a3dc0e84a0391d3fbfebca0a6a61 Mon Sep 17 00:00:00 2001 From: jiahongyu Date: Mon, 12 Jun 2023 14:15:14 +0000 Subject: [PATCH 13/66] Pass Horizontal Fuse Test --- cinn/api/op_group_interface.h | 7 + cinn/hlir/framework/graph.h | 2 + cinn/hlir/pass/CMakeLists.txt | 1 + .../pass/check_fusion_accuracy_pass_test.cc | 528 ++++++++++++++++++ cinn/hlir/pass/fusion_merge_pass_util.h | 2 +- cinn/hlir/pass/general_fusion_merge_pass.cc | 86 +-- cinn/hlir/pass/use_pass.h | 1 + 7 files changed, 594 insertions(+), 33 deletions(-) diff --git a/cinn/api/op_group_interface.h b/cinn/api/op_group_interface.h index 30d6e8db1a..506b562f99 100644 --- a/cinn/api/op_group_interface.h +++ b/cinn/api/op_group_interface.h @@ -20,12 +20,15 @@ #include "cinn/api/tensor_interface.h" #include 
"cinn/api/tensor_interface_list.h" +#include "cinn/hlir/framework/op.h" namespace cinn { namespace api { class OpGroupInterface { public: + virtual hlir::framework::OpPatternKind kind() const = 0; + // virtual const TensorInterfaceList& input_tensors() const = 0; // virtual const TensorInterfaceList& output_tensors() const = 0; @@ -38,6 +41,10 @@ class OpGroupInterface { virtual const std::unordered_map, TensorInterfaceList>& consumer_groups() const = 0; + const std::unordered_map, TensorInterfaceList>& consumer2outputs() const { + return consumer_groups(); + } + protected: OpGroupInterface() = default; }; diff --git a/cinn/hlir/framework/graph.h b/cinn/hlir/framework/graph.h index 9ec872813b..6e15af3cea 100644 --- a/cinn/hlir/framework/graph.h +++ b/cinn/hlir/framework/graph.h @@ -143,6 +143,8 @@ class Graph : public cinn::common::Graph { return &consumer_groups_; } + hlir::framework::OpPatternKind kind() const override { return op_pattern_kind; } + private: // input groups std::unordered_map, TensorInterfaceList> producer_groups_; diff --git a/cinn/hlir/pass/CMakeLists.txt b/cinn/hlir/pass/CMakeLists.txt index ac48b9a153..20b5a1f071 100644 --- a/cinn/hlir/pass/CMakeLists.txt +++ b/cinn/hlir/pass/CMakeLists.txt @@ -8,6 +8,7 @@ gather_srcs(cinnapi_src SRCS const_propagate.cc op_fusion_pass.cc fusion_merge_pass.cc + general_fusion_merge_pass.cc dot_merger.cc check_fusion_accuracy_pass.cc custom_call_pass.cc diff --git a/cinn/hlir/pass/check_fusion_accuracy_pass_test.cc b/cinn/hlir/pass/check_fusion_accuracy_pass_test.cc index 3db0a1ff21..3283964ea9 100644 --- a/cinn/hlir/pass/check_fusion_accuracy_pass_test.cc +++ b/cinn/hlir/pass/check_fusion_accuracy_pass_test.cc @@ -92,6 +92,40 @@ TEST(CheckFusionAccuracyPass, ElementWise_Fusion) { RunTest(target, graph, {"A", "B", "C", "D"}); } +TEST(CheckFusionAccuracyPass, General_ElementWise_Fusion) { + int h = 32, w = 32; + NetBuilder net_builder("ElementWise_Fusion_0"); + std::unordered_set fetch_ids; + // create model + { 
+ auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); + auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); + auto C = net_builder.CreateInput(Float(32), {h, w}, "C"); + auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); + auto E = net_builder.Add(A, B); + auto F = net_builder.Add(E, C); + auto G = net_builder.Add(E, D); + + fetch_ids = {F->id, G->id}; + } + + auto program = net_builder.Build(); + auto target = common::DefaultTarget(); + + auto graph = std::make_shared(program, target); + hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "GeneralFusionMergePass"}); + + int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); + + VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); + VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + + CHECK_EQ(graph->fusion_groups.size(), group_size_after); + + RunTest(target, graph, {"A", "B", "C", "D"}); +} + TEST(CheckFusionAccuracyPass, ElementWise_Fusion_1) { int h = 32, w = 32; NetBuilder net_builder("ElementWise_Fusion_1"); @@ -125,6 +159,39 @@ TEST(CheckFusionAccuracyPass, ElementWise_Fusion_1) { RunTest(target, graph, {"A", "B", "C", "D"}); } +TEST(CheckFusionAccuracyPass, General_ElementWise_Fusion_1) { + int h = 32, w = 32; + NetBuilder net_builder("ElementWise_Fusion_1"); + // create model + { + auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); + auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); + auto C = net_builder.CreateInput(Float(32), {h, w}, "C"); + auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); + auto E = net_builder.Add(A, B); + auto F = net_builder.Add(C, D); + auto G = net_builder.Add(E, F); + auto I = net_builder.Add(E, F); + } + + auto program = net_builder.Build(); + auto target = common::DefaultTarget(); + + auto graph = 
std::make_shared(program, target); + + hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "GeneralFusionMergePass"}); + + int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); + + VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); + VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + + CHECK_EQ(graph->fusion_groups.size(), group_size_after); + + RunTest(target, graph, {"A", "B", "C", "D"}); +} + TEST(CheckFusionAccuracyPass, ElementWise_Fusion_2) { int h = 32, w = 32; NetBuilder net_builder("ElementWise_Fusion_2"); @@ -161,6 +228,42 @@ TEST(CheckFusionAccuracyPass, ElementWise_Fusion_2) { RunTest(target, graph, {"A", "B", "C", "D", "E", "F"}); } +TEST(CheckFusionAccuracyPass, General_ElementWise_Fusion_2) { + int h = 32, w = 32; + NetBuilder net_builder("ElementWise_Fusion_2"); + // create model + { + auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); + auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); + auto C = net_builder.CreateInput(Float(32), {h, w}, "C"); + auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); + auto E = net_builder.CreateInput(Float(32), {h, w}, "E"); + auto F = net_builder.CreateInput(Float(32), {h, w}, "F"); + auto G = net_builder.Add(A, B); + auto H = net_builder.Add(C, D); + auto I = net_builder.Add(E, G); + auto J = net_builder.Add(G, H); + auto K = net_builder.Add(H, F); + } + + auto program = net_builder.Build(); + auto target = common::DefaultTarget(); + + auto graph = std::make_shared(program, target); + + hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "GeneralFusionMergePass"}); + + int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); + + VLOG(1) << "Before CheckFusionAccuracyPass:\n" << 
graph->DebugGroupedGraph(std::unordered_set{}); + hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); + VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + + CHECK_EQ(graph->fusion_groups.size(), group_size_after); + + RunTest(target, graph, {"A", "B", "C", "D", "E", "F"}); +} + TEST(CheckFusionAccuracyPass, ElementWise_Fusion_3) { int h = 32, w = 32; NetBuilder net_builder("ElementWise_Fusion_3"); @@ -197,6 +300,42 @@ TEST(CheckFusionAccuracyPass, ElementWise_Fusion_3) { RunTest(target, graph, {"A", "B", "C", "D", "E", "F"}); } +TEST(CheckFusionAccuracyPass, General_ElementWise_Fusion_3) { + int h = 32, w = 32; + NetBuilder net_builder("ElementWise_Fusion_3"); + // create model + { + auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); + auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); + auto C = net_builder.CreateInput(Float(32), {h, w}, "C"); + auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); + auto E = net_builder.CreateInput(Float(32), {h, w}, "E"); + auto F = net_builder.CreateInput(Float(32), {h, w}, "F"); + auto G = net_builder.Add(A, B); + auto H = net_builder.Add(G, C); + auto I = net_builder.Add(G, D); + auto J = net_builder.Add(G, E); + auto K = net_builder.Add(G, F); + } + + auto program = net_builder.Build(); + auto target = common::DefaultTarget(); + + auto graph = std::make_shared(program, target); + + hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "GeneralFusionMergePass"}); + + int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); + + VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); + VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + + CHECK_EQ(graph->fusion_groups.size(), group_size_after); + 
+ RunTest(target, graph, {"A", "B", "C", "D", "E", "F"}); +} + TEST(CheckFusionAccuracyPass, ElementWise_Fusion_4) { int h = 32, w = 32; NetBuilder net_builder("ElementWise_Fusion_4"); @@ -233,6 +372,42 @@ TEST(CheckFusionAccuracyPass, ElementWise_Fusion_4) { RunTest(target, graph, {"A", "B", "C", "D", "E", "F"}); } +TEST(CheckFusionAccuracyPass, General_ElementWise_Fusion_4) { + int h = 32, w = 32; + NetBuilder net_builder("ElementWise_Fusion_4"); + // create model + { + auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); + auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); + auto C = net_builder.CreateInput(Float(32), {h, w}, "C"); + auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); + auto E = net_builder.CreateInput(Float(32), {h, w}, "E"); + auto F = net_builder.CreateInput(Float(32), {h, w}, "F"); + auto G = net_builder.Add(A, B); + auto H = net_builder.Add(G, C); + auto I = net_builder.Add(G, D); + auto J = net_builder.Add(I, E); + auto K = net_builder.Add(I, F); + } + + auto program = net_builder.Build(); + auto target = common::DefaultTarget(); + + auto graph = std::make_shared(program, target); + + hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "GeneralFusionMergePass"}); + + int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); + + VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); + VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + + CHECK_EQ(graph->fusion_groups.size(), group_size_after); + + RunTest(target, graph, {"A", "B", "C", "D", "E", "F"}); +} + TEST(CheckFusionAccuracyPass, ElementWise_Fusion_5) { int h = 32, w = 32; NetBuilder net_builder("ElementWise_Fusion_5"); @@ -262,6 +437,35 @@ TEST(CheckFusionAccuracyPass, ElementWise_Fusion_5) { RunTest(target, graph, {"A", "B"}); } 
+TEST(CheckFusionAccuracyPass, General_ElementWise_Fusion_5) { + int h = 32, w = 32; + NetBuilder net_builder("ElementWise_Fusion_5"); + // create model + { + auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); + auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); + auto C = net_builder.Add(A, B); + auto D = net_builder.Add(A, B); + } + + auto program = net_builder.Build(); + auto target = common::DefaultTarget(); + + auto graph = std::make_shared(program, target); + + hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "GeneralFusionMergePass"}); + + int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); + + VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); + VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + + CHECK_EQ(graph->fusion_groups.size(), group_size_after); + + RunTest(target, graph, {"A", "B"}); +} + TEST(CheckFusionAccuracyPass, Broadcast_Test_0) { int h = 32, w = 32; NetBuilder net_builder("Broadcast_Test_0"); @@ -294,6 +498,38 @@ TEST(CheckFusionAccuracyPass, Broadcast_Test_0) { RunTest(target, graph, {"A", "B", "C", "D"}); } +TEST(CheckFusionAccuracyPass, General_Broadcast_Test_0) { + int h = 32, w = 32; + NetBuilder net_builder("Broadcast_Test_0"); + // create model + { + auto A = net_builder.CreateInput(Float(32), {w}, "A"); + auto B = net_builder.CreateInput(Float(32), {w}, "B"); + auto C = net_builder.CreateInput(Float(32), {h, w}, "C"); + auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); + auto E = net_builder.Add(A, B); + auto F = net_builder.Add(C, D); + auto G = net_builder.Add(F, E); + } + + auto program = net_builder.Build(); + auto target = common::DefaultTarget(); + + auto graph = std::make_shared(program, target); + + hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", 
"GeneralFusionMergePass"}); + + int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); + + VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); + VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + + CHECK_EQ(graph->fusion_groups.size(), group_size_after); + + RunTest(target, graph, {"A", "B", "C", "D"}); +} + TEST(CheckFusionAccuracyPass, Broadcast_Test_2) { int h = 32, w = 32; NetBuilder net_builder("Broadcast_Test_2"); @@ -326,6 +562,38 @@ TEST(CheckFusionAccuracyPass, Broadcast_Test_2) { RunTest(target, graph, {"A", "B", "C", "D"}); } +TEST(CheckFusionAccuracyPass, General_Broadcast_Test_2) { + int h = 32, w = 32; + NetBuilder net_builder("Broadcast_Test_2"); + // create model + { + auto A = net_builder.CreateInput(Float(32), {w}, "A"); + auto B = net_builder.CreateInput(Float(32), {w}, "B"); + auto C = net_builder.CreateInput(Float(32), {w}, "C"); + auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); + auto E = net_builder.Add(A, B); + auto F = net_builder.Add(C, E); + auto G = net_builder.Add(D, E); + } + + auto program = net_builder.Build(); + auto target = common::DefaultTarget(); + + auto graph = std::make_shared(program, target); + + hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "GeneralFusionMergePass"}); + + int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); + + VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); + VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + + CHECK_EQ(graph->fusion_groups.size(), group_size_after); + + RunTest(target, graph, {"A", "B", "C", "D"}); +} + 
TEST(CheckFusionAccuracyPass, Broadcast_Test_4) { int h = 32, w = 32; NetBuilder net_builder("Broadcast_Test_4"); @@ -360,6 +628,40 @@ TEST(CheckFusionAccuracyPass, Broadcast_Test_4) { RunTest(target, graph, {"A", "B", "C", "D", "E"}); } +TEST(CheckFusionAccuracyPass, General_Broadcast_Test_4) { + int h = 32, w = 32; + NetBuilder net_builder("Broadcast_Test_4"); + // create model + { + auto A = net_builder.CreateInput(Float(32), {w}, "A"); + auto B = net_builder.CreateInput(Float(32), {w}, "B"); + auto C = net_builder.CreateInput(Float(32), {w}, "C"); + auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); + auto E = net_builder.CreateInput(Float(32), {h, w}, "E"); + auto F = net_builder.Add(A, B); + auto G = net_builder.Add(C, F); + auto H = net_builder.Add(D, F); + auto I = net_builder.Add(E, F); + } + + auto program = net_builder.Build(); + auto target = common::DefaultTarget(); + + auto graph = std::make_shared(program, target); + + hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "GeneralFusionMergePass"}); + + int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); + + VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); + VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + + CHECK_EQ(graph->fusion_groups.size(), group_size_after); + + RunTest(target, graph, {"A", "B", "C", "D", "E"}); +} + TEST(CheckFusionAccuracyPass, Broadcast_Test_5) { int h = 32, w = 32; NetBuilder net_builder("Broadcast_Test_5"); @@ -394,6 +696,40 @@ TEST(CheckFusionAccuracyPass, Broadcast_Test_5) { RunTest(target, graph, {"A", "B", "C", "D", "E"}); } +TEST(CheckFusionAccuracyPass, General_Broadcast_Test_5) { + int h = 32, w = 32; + NetBuilder net_builder("Broadcast_Test_5"); + // create model + { + auto A = net_builder.CreateInput(Float(32), {w}, 
"A"); + auto B = net_builder.CreateInput(Float(32), {w}, "B"); + auto C = net_builder.CreateInput(Float(32), {w}, "C"); + auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); + auto E = net_builder.CreateInput(Float(32), {h * w, w}, "E"); + auto F = net_builder.Add(A, B); + auto G = net_builder.Add(C, F); + auto H = net_builder.Add(D, F); + auto I = net_builder.Add(E, F); + } + + auto program = net_builder.Build(); + auto target = common::DefaultTarget(); + + auto graph = std::make_shared(program, target); + + hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "GeneralFusionMergePass"}); + + int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); + + VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); + VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + + CHECK_EQ(graph->fusion_groups.size(), group_size_after); + + RunTest(target, graph, {"A", "B", "C", "D", "E"}); +} + TEST(CheckFusionAccuracyPass, Reduce_Test_0) { int h = 32, w = 32; NetBuilder net_builder("Reduce_Test_0"); @@ -425,6 +761,37 @@ TEST(CheckFusionAccuracyPass, Reduce_Test_0) { RunTest(target, graph, {"A", "B"}); } +TEST(CheckFusionAccuracyPass, General_Reduce_Test_0) { + int h = 32, w = 32; + NetBuilder net_builder("Reduce_Test_0"); + // create model + { + auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); + auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); + auto C = net_builder.Add(A, B); + auto D = net_builder.ReduceSum(C, {0}); + auto E = net_builder.ReduceSum(C, {0}); + auto F = net_builder.ReduceSum(C, {0}); + } + + auto program = net_builder.Build(); + auto target = common::DefaultTarget(); + + auto graph = std::make_shared(program, target); + + hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "GeneralFusionMergePass"}); + + int 
group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); + + VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); + VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + + CHECK_EQ(graph->fusion_groups.size(), group_size_after); + + RunTest(target, graph, {"A", "B"}); +} + TEST(CheckFusionAccuracyPass, Reduce_Test_1) { int h = 32, w = 32; NetBuilder net_builder("Reduce_Test_1"); @@ -455,6 +822,36 @@ TEST(CheckFusionAccuracyPass, Reduce_Test_1) { RunTest(target, graph, {"A", "B"}); } +TEST(CheckFusionAccuracyPass, General_Reduce_Test_1) { + int h = 32, w = 32; + NetBuilder net_builder("Reduce_Test_1"); + // create model + { + auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); + auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); + auto C = net_builder.Add(A, B); + auto D = net_builder.ReduceSum(C, {0}); + auto E = net_builder.ReduceSum(C, {1}); + } + + auto program = net_builder.Build(); + auto target = common::DefaultTarget(); + + auto graph = std::make_shared(program, target); + + hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "GeneralFusionMergePass"}); + + int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); + + VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); + VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + + CHECK_EQ(graph->fusion_groups.size(), group_size_after); + + RunTest(target, graph, {"A", "B"}); +} + TEST(CheckFusionAccuracyPass, Reduce_Test_2) { int h = 32, w = 32; NetBuilder net_builder("Reduce_Test_2"); @@ -488,6 +885,39 @@ TEST(CheckFusionAccuracyPass, Reduce_Test_2) { 
RunTest(target, graph, {"A", "B", "C"}); } +TEST(CheckFusionAccuracyPass, General_Reduce_Test_2) { + int h = 32, w = 32; + NetBuilder net_builder("Reduce_Test_2"); + // create model + { + auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); + auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); + auto C = net_builder.CreateInput(Float(32), {w}, "C"); + auto D = net_builder.Add(A, B); + auto E = net_builder.ReduceSum(D, {0}); + auto F = net_builder.ReduceSum(D, {1}); + auto G = net_builder.Add(C, E); + auto H = net_builder.Add(C, F); + } + + auto program = net_builder.Build(); + auto target = common::DefaultTarget(); + + auto graph = std::make_shared(program, target); + + hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "GeneralFusionMergePass"}); + + int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); + + VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); + VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + + CHECK_EQ(graph->fusion_groups.size(), group_size_after); + + RunTest(target, graph, {"A", "B", "C"}); +} + TEST(CheckFusionAccuracyPass, Reduce_Test_3) { int h = 32, w = 32; NetBuilder net_builder("Reduce_Test_3"); @@ -521,6 +951,39 @@ TEST(CheckFusionAccuracyPass, Reduce_Test_3) { RunTest(target, graph, {"A", "B", "C", "D"}); } +TEST(CheckFusionAccuracyPass, General_Reduce_Test_3) { + int h = 32, w = 32; + NetBuilder net_builder("Reduce_Test_3"); + // create model + { + auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); + auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); + auto C = net_builder.CreateInput(Float(32), {w}, "C"); + auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); + auto E = net_builder.Add(A, B); + auto F = net_builder.ReduceSum(E, {0}); + auto G = net_builder.Add(C, F); + 
auto H = net_builder.Add(D, F); + } + + auto program = net_builder.Build(); + auto target = common::DefaultTarget(); + + auto graph = std::make_shared(program, target); + + hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "GeneralFusionMergePass"}); + + int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); + + VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); + VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + + CHECK_EQ(graph->fusion_groups.size(), group_size_after); + + RunTest(target, graph, {"A", "B", "C", "D"}); +} + TEST(CheckFusionAccuracyPass, Reduce_Test_4) { int h = 32, w = 32; NetBuilder net_builder("Reduce_Test_4"); @@ -555,6 +1018,40 @@ TEST(CheckFusionAccuracyPass, Reduce_Test_4) { RunTest(target, graph, {"A", "B", "C", "D"}); } +TEST(CheckFusionAccuracyPass, General_Reduce_Test_4) { + int h = 32, w = 32; + NetBuilder net_builder("Reduce_Test_4"); + // create model + { + auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); + auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); + auto C = net_builder.CreateInput(Float(32), {w}, "C"); + auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); + auto E = net_builder.Add(A, B); + auto F = net_builder.ReduceSum(E, {0}); + auto G = net_builder.Add(C, F); + auto H = net_builder.Add(D, F); + auto I = net_builder.Add(D, F); + } + + auto program = net_builder.Build(); + auto target = common::DefaultTarget(); + + auto graph = std::make_shared(program, target); + + hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "GeneralFusionMergePass"}); + + int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); + + VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + 
hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); + VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + + CHECK_EQ(graph->fusion_groups.size(), group_size_after); + + RunTest(target, graph, {"A", "B", "C", "D"}); +} + TEST(CheckFusionAccuracyPass, Reduce_Test_5) { int h = 128, w = 128; NetBuilder net_builder("Reduce_Test_5"); @@ -586,4 +1083,35 @@ TEST(CheckFusionAccuracyPass, Reduce_Test_5) { RunTest(target, graph, {"A", "B"}); } +TEST(CheckFusionAccuracyPass, General_Reduce_Test_5) { + int h = 128, w = 128; + NetBuilder net_builder("Reduce_Test_5"); + // create model + { + auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); + auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); + auto C = net_builder.Add(A, B); + auto D = net_builder.ReduceSum(A, {1}); + auto E = net_builder.ReduceSum(B, {1}); + auto F = net_builder.ReduceSum(C, {1}); + } + + auto program = net_builder.Build(); + auto target = common::DefaultTarget(); + + auto graph = std::make_shared(program, target); + + hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "GeneralFusionMergePass"}); + + int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); + + VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); + VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + + CHECK_EQ(graph->fusion_groups.size(), group_size_after); + + RunTest(target, graph, {"A", "B"}); +} + } // namespace cinn::frontend diff --git a/cinn/hlir/pass/fusion_merge_pass_util.h b/cinn/hlir/pass/fusion_merge_pass_util.h index 3857db0564..834fc31ad7 100644 --- a/cinn/hlir/pass/fusion_merge_pass_util.h +++ b/cinn/hlir/pass/fusion_merge_pass_util.h @@ -71,7 +71,7 @@ CONDITION_FUNC(is_same_size) { return size_0 
== size_1; } -bool is_const_group(const FusionHelperBase* helper, const std::shared_ptr& group) { +inline bool is_const_group(const FusionHelperBase* helper, const std::shared_ptr& group) { return group->CollectNodes().size() == 1 && helper->IsConstOp(group->CollectNodes()[0]); }; diff --git a/cinn/hlir/pass/general_fusion_merge_pass.cc b/cinn/hlir/pass/general_fusion_merge_pass.cc index d920436869..8cc2814057 100644 --- a/cinn/hlir/pass/general_fusion_merge_pass.cc +++ b/cinn/hlir/pass/general_fusion_merge_pass.cc @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include + #include "cinn/api/op_group_interface.h" #include "cinn/common/is_reachable_predicator.h" #include "cinn/hlir/pass/fusion_merge_pass_util.h" @@ -61,26 +63,24 @@ class GraphGroupFuseHelper final : public FuseHelper { public: explicit GraphGroupFuseHelper(const GraphGroupLightwareFusePassCtx* ctx) : ctx_(ctx) {} - bool AllOutputsSameSize(const OpGroupPtr& first, const OpGroupPtr& second) const override { - return is_same_size(&ctx_->graph_group_fusion_helper(), first, second); - } + bool AllOutputsSameSize(const OpGroupPtr& first, const OpGroupPtr& second) const override; - bool HorizontalElementwiseFuseReduce(const OpGroupPtr& src, const OpGroupPtr& dst) const override { - return horizontal_elementwise_fuse_reduce(&ctx_->graph_group_fusion_helper(), src, dst); - } + bool HorizontalElementwiseFuseReduce(const OpGroupPtr& src, const OpGroupPtr& dst) const override; - bool ReduceFuseReduce(const OpGroupPtr& src, const OpGroupPtr& dst) const override { - return reduce_fuse_reduce(&ctx_->graph_group_fusion_helper(), src, dst); - } + bool ReduceFuseReduce(const OpGroupPtr& src, const OpGroupPtr& dst) const override; bool DetectCycleIfFuse(const OpGroupPtr& lhs, const OpGroupPtr& rhs) const override { return ReachableIfDirectEdgeIgnored(lhs, rhs) || ReachableIfDirectEdgeIgnored(rhs, lhs); } private: - bool 
ReachableIfDirectEdgeIgnored(const OpGroupPtr& src, const OpGroupPtr& dst) const override { - const auto& MinDepth4Node = [&](OpGroupPtr node) { return std::dynamic_pointer_cast(node)->min_depth; }; - const auto& MaxDepth4Node = [&](OpGroupPtr node) { return std::dynamic_pointer_cast(node)->max_depth; }; + bool ReachableIfDirectEdgeIgnored(const OpGroupPtr& src, const OpGroupPtr& dst) const { + const auto& MinDepth4Node = [&](OpGroupPtr node) { + return std::dynamic_pointer_cast(node)->min_depth; + }; + const auto& MaxDepth4Node = [&](OpGroupPtr node) { + return std::dynamic_pointer_cast(node)->max_depth; + }; const auto& VisitNextNodes = [&](OpGroupPtr node, const std::function& Visit) { for (const auto& pair : node->consumer2outputs()) { if (node == src && pair.first == dst) { @@ -132,10 +132,28 @@ class GraphGroupLightwareFusePassCtx final : public LightwareFusePassCtx { private: const FusionHelperBase* graph_group_fusion_helper_; const OpGroupPtr group_; - const std::function EnableFuse_; + const std::function EnableFuse_; const GraphGroupFuseHelper fuse_helper_; }; +bool GraphGroupFuseHelper::AllOutputsSameSize(const OpGroupPtr& first, const OpGroupPtr& second) const { + return is_same_size(&ctx_->graph_group_fusion_helper(), + std::dynamic_pointer_cast(first), + std::dynamic_pointer_cast(second)); +} + +bool GraphGroupFuseHelper::HorizontalElementwiseFuseReduce(const OpGroupPtr& src, const OpGroupPtr& dst) const { + return honrizontal_elementwise_fuse_reduce(&ctx_->graph_group_fusion_helper(), + std::dynamic_pointer_cast(src), + std::dynamic_pointer_cast(dst)); +} + +bool GraphGroupFuseHelper::ReduceFuseReduce(const OpGroupPtr& src, const OpGroupPtr& dst) const { + return reduce_fuse_reduce(&ctx_->graph_group_fusion_helper(), + std::dynamic_pointer_cast(src), + std::dynamic_pointer_cast(dst)); +} + class FusePass { public: virtual ~FusePass() = default; @@ -143,7 +161,7 @@ class FusePass { virtual void operator()(LightwareFusePassCtx* ctx) const = 0; 
protected: - FusePass(); + FusePass() = default; }; class DefautlHorizontalFusePass final : public FusePass { @@ -155,7 +173,7 @@ class DefautlHorizontalFusePass final : public FusePass { const OpGroupList consumers = [&]() { OpGroupList consumers; for (const auto& pair : producer->consumer2outputs()) { - consumers.insert(pair.first); + consumers.push_back(pair.first); } return consumers; }(); @@ -192,13 +210,13 @@ class DefautlHorizontalFusePass final : public FusePass { typedef bool (*ConditionT)(LightwareFusePassCtx* ctx, const OpGroupPtr& src, const OpGroupPtr& dst); - const std::unordered_map& GetConditionMap() const { - thread_local static std::unordered_map map(RawConditionMap()); + const std::map& GetConditionMap() const { + thread_local static std::map map(RawConditionMap()); return map; } - std::unordered_map RawConditionMap() const { - return std::unordered_map{ + std::map RawConditionMap() const { + return std::map{ {{OpPatternKind::kElementWise, framework::kElementWise}, &DefautlHorizontalFusePass::IsSameSize}, {{OpPatternKind::kElementWise, framework::kBroadcast}, &DefautlHorizontalFusePass::IsSameSize}, {{OpPatternKind::kElementWise, framework::kInjective}, &DefautlHorizontalFusePass::IsSameSize}, @@ -376,10 +394,9 @@ class FusionMergePassHelper : public FusionHelperBase { } std::vector> RawHorizontalFusePasses() const { - std::vector>{ + return std::vector>{ std::shared_ptr(new DefautlHorizontalFusePass{}), }; - return ret; } const std::vector>& GetHorizontalFusePasses() const { @@ -398,15 +415,15 @@ class FusionMergePassHelper : public FusionHelperBase { } } - bool GeneralHorizontalFuse(const GroupPtr& producer) const { + bool GeneralHorizontalFuse(const GroupPtr& producer) { VLOG(3) << "GeneralHorizontalFuse...!"; - using GroupSets = std::set>; - const auto& GetFusableConsumerGroupSets = [&]() -> GroupSets { - GroupSets tagged_sets; - const auto& EnableFuse = [&](const GroupPtr& first, const GroupPtr& second) { - tagged_sets.insert(std::set{first, 
second}); + using OpGroupSets = std::set>; + const auto& GetFusableConsumerGroupSets = [&]() -> OpGroupSets { + OpGroupSets tagged_sets; + const auto& EnableFuse = [&](const OpGroupPtr& first, const OpGroupPtr& second) { + tagged_sets.insert(std::set{first, second}); }; - GraphGroupLightwareFusePassCtx fuse_ctx(producer, EnableFuse); + GraphGroupLightwareFusePassCtx fuse_ctx(this, producer, EnableFuse); EnableFusedHorizontalGroups(&fuse_ctx); return tagged_sets; }; @@ -415,17 +432,22 @@ class FusionMergePassHelper : public FusionHelperBase { if (group_sets.empty()) { return GroupList{}; } - return GroupList{group_sets.begin()->begin(), group_sets.begin()->end()}; + GroupList ret; + for (const auto& group : *group_sets.begin()) { + ret.push_back(std::dynamic_pointer_cast(group)); + } + return ret; }; - size_t fuse_count = 0; + bool update = false; while (true) { const auto& groups = GetFusableConsumerGroupList(); if (groups.size() <= 1) { break; } - fuse_count += HorizontalFuse(groups); + HorizontalFuse(groups); + update = true; } - return fuse_count > 0; + return update; } bool HorizontalFusion(GroupPtr producer, const std::unordered_set& consumers) { diff --git a/cinn/hlir/pass/use_pass.h b/cinn/hlir/pass/use_pass.h index dc44fb0869..048c50faff 100644 --- a/cinn/hlir/pass/use_pass.h +++ b/cinn/hlir/pass/use_pass.h @@ -25,6 +25,7 @@ CINN_USE_REGISTER(DCE) CINN_USE_REGISTER(DotMerger) CINN_USE_REGISTER(OpFusionPass) CINN_USE_REGISTER(FusionMergePass) +CINN_USE_REGISTER(GeneralFusionMergePass) CINN_USE_REGISTER(CheckFusionAccuracyPass) CINN_USE_REGISTER(CommonSubexpressionEliminationPass) From 8dfd12c519d481875e6560f509e9ec8794d66bda Mon Sep 17 00:00:00 2001 From: zyfncg Date: Mon, 12 Jun 2023 14:16:34 +0000 Subject: [PATCH 14/66] temp save --- cinn/hlir/pass/general_fusion_merge_pass.cc | 367 +++++++++++--------- 1 file changed, 195 insertions(+), 172 deletions(-) diff --git a/cinn/hlir/pass/general_fusion_merge_pass.cc b/cinn/hlir/pass/general_fusion_merge_pass.cc 
index cbf33de10d..186d8a470a 100644 --- a/cinn/hlir/pass/general_fusion_merge_pass.cc +++ b/cinn/hlir/pass/general_fusion_merge_pass.cc @@ -79,9 +79,9 @@ class FusionMergePassHelper : public FusionHelperBase { VLOG(3) << "DoFusionMerge...!"; while (DoGeneralHorizontalFusion()) { } - while (DoVerticalFusion(/* recompute=*/false)) { + while (DoGeneralVerticalFusion(/* recompute=*/false)) { } - while (DoVerticalFusion(/* recompute=*/true)) { + while (DoGeneralVerticalFusion(/* recompute=*/true)) { } } @@ -105,7 +105,7 @@ class FusionMergePassHelper : public FusionHelperBase { return updated; } - bool DoVerticalFusion(bool recompute) { + bool DoGeneralVerticalFusion(bool recompute) { VLOG(3) << "DoVerticalFusion...!"; bool updated = false; for (int idx = 0; idx < fusion_groups_.size(); ++idx) { @@ -117,12 +117,10 @@ class FusionMergePassHelper : public FusionHelperBase { } // do horizontal fusion. if (!recompute) { - updated |= HorizontalFusion(producer, producer->CollectConsumerGroups()); + updated |= GeneralHorizontalFuse(producer); } - updated |= VerticalFusion(producer, producer->CollectConsumerGroups(), recompute); + updated |= GeneralVerticalFuse(producer); } - // fuse input consumers - updated |= FuseInputToConsumers(); if (updated) { UpdateFusionGroup(); @@ -187,6 +185,20 @@ class FusionMergePassHelper : public FusionHelperBase { bool HorizontalElementwiseFuseReduce(const OpGroupPtr& src, const OpGroupPtr& dst) { TODO(); } + bool ElementwiseFuseBroadcast(const OpGroupPtr& src, const OpGroupPtr& dst) { TODO(); } + + bool HorizontalWithInjective(const OpGroupPtr& src, const OpGroupPtr& dst) { TODO(); } + + bool ElementwiseFuseReduce(const OpGroupPtr& src, const OpGroupPtr& dst) { TODO(); } + + bool BroadcastFuseReduce(const OpGroupPtr& src, const OpGroupPtr& dst) { TODO(); } + + bool InjectiveHorizontalWithReduce(const OpGroupPtr& src, const OpGroupPtr& dst) { TODO(); } + + bool ReduceFuseElementwise(const OpGroupPtr& src, const OpGroupPtr& dst) { TODO(); } + + 
bool ReduceFuseBroadcast(const OpGroupPtr& src, const OpGroupPtr& dst) { TODO(); } + bool ReduceFuseReduce(const OpGroupPtr& src, const OpGroupPtr& dst) { TODO(); } bool DetectCycleIfFuse(const OpGroupPtr& src, const OpGroupPtr& dst) const { TODO(); } @@ -581,83 +593,196 @@ class FusionMergePassHelper : public FusionHelperBase { CHECK(fused_group->output_nodes.size()) << "No output node is found, " << fused_group->group_id; } - bool VerticalFusion(GroupPtr& producer, const std::unordered_set& consumers, bool recompute) { - VLOG(3) << "VerticalFusion, Number of Consumers : " << consumers.size(); - auto& relation = fusion_relation_map_[producer->op_pattern_kind]; - // if producer can't fuse others - if (!relation.vertical_relation.size()) { + class DefautlVerticalFusePass final : public FusePass { + public: + DefautlVerticalFusePass() : FusePass() {} + + void operator()(LightwareFusePassCtx* ctx) const override { + const auto& producer = ctx->PickOpGroup(); + const OpGroupList consumers = [&]() { + OpGroupList consumers; + for (const auto& pair : producer->consumer2outputs()) { + consumers.insert(pair.first); + } + return consumers; + }(); + if (consumers.size() == 0) { + return; + } + for (int i = 0; i < consumers.size(); ++i) { + const auto& consumer = consumers.at(i); + if (!DetectFusabilityByKind(ctx, producer, consumer)) { + continue; + } + if (ctx->fuse_helper().DetectCycleIfFuse(producer, consumer)) { + continue; + } + if (ctx->IsConstGroup(producer)) { + if (!ctx->IsSameShape(producer, consumer)) { + continue; + } + } + ctx->EnableFuse(src, dst); + } + } + + using KindKeyT = std::pair; + bool DetectFusabilityByKind(LightwareFusePassCtx* ctx, const OpGroupPtr& src, const OpGroupPtr& dst) const { + const KindKeyT kind_pair(src->kind(), dst->kind()); + const auto& map = GetConditionMap(); + const auto& iter = map.find(kind_pair); + if (iter == map.end()) { return false; } + return iter->second(ctx, src, dst); + } - std::unordered_set fuse_consumers_unsafe; - 
std::unordered_set fuse_consumers; - for (const auto& consumer : consumers) { - VLOG(4) << "Check consuemr " << consumer->group_id << " can fuse to producer " << producer->group_id; - // if can't fuse - if (!relation.vertical_relation.count(consumer->op_pattern_kind)) { - VLOG(4) << "Can't fuse producer " << producer->group_id << " consumer " << consumer->group_id; - continue; - } + typedef bool (*ConditionT)(LightwareFusePassCtx* ctx, const OpGroupPtr& src, const OpGroupPtr& dst); - // if condition function is false - if (!relation.vertical_relation[consumer->op_pattern_kind](this, producer, consumer)) { - VLOG(4) << "Can't fuse producer " << producer->group_id << " consumer " << consumer->group_id; - continue; - } + const std::unordered_map& GetConditionMap() const { + thread_local static std::unordered_map map(RawConditionMap()); + return map; + } - fuse_consumers_unsafe.insert(consumer); + std::unordered_map RawConditionMap() const { + return std::unordered_map{ + {{OpPatternKind::kElementWise, framework::kElementWise}, &DefautlHorizontalFusePass::IsSameSize}, + {{OpPatternKind::kElementWise, framework::kBroadcast}, &DefautlHorizontalFusePass::ElementwiseFuseBroadcast}, + {{OpPatternKind::kElementWise, framework::kInjective}, &DefautlHorizontalFusePass::HorizontalWithInjective}, + {{OpPatternKind::kElementWise, framework::kReduction}, + &DefautlHorizontalFusePass::ElementwiseFuseReduce}, - if (IsDependencySimplify(producer, consumer, consumers)) { - VLOG(4) << "IsDependencySimplify, Consumer " << consumer->group_id << " can't be master fused group!"; - continue; - } + {{OpPatternKind::kBroadcast, framework::kElementWise}, &DefautlHorizontalFusePass::IsSameSize}, + {{OpPatternKind::kBroadcast, framework::kBroadcast}, &DefautlHorizontalFusePass::IsSameSize}, + {{OpPatternKind::kBroadcast, framework::kInjective}, &DefautlHorizontalFusePass::HorizontalWithInjective}, + {{OpPatternKind::kBroadcast, framework::kReduction}, 
&DefautlHorizontalFusePass::BroadcastFuseReduce}, - if (IsDependency(producer, consumer, consumers)) { - VLOG(4) << "IsDependency, Consumer " << consumer->group_id << " can't be master fused group!"; - continue; - } + {{OpPatternKind::kInjective, framework::kElementWise}, &DefautlHorizontalFusePass::IsSameSize}, + {{OpPatternKind::kInjective, framework::kBroadcast}, &DefautlHorizontalFusePass::IsSameSize}, + {{OpPatternKind::kInjective, framework::kInjective}, &DefautlHorizontalFusePass::HorizontalWithInjective}, + {{OpPatternKind::kInjective, framework::kReduction}, &DefautlHorizontalFusePass::InjectiveHorizontalWithReduce}, + + {{OpPatternKind::kReduction, framework::kElementWise}, + &DefautlHorizontalFusePass::ReduceFuseElementwise}, + {{OpPatternKind::kReduction, framework::kBroadcast}, &DefautlHorizontalFusePass::ReduceFuseBroadcast}, + {{OpPatternKind::kReduction, framework::kInjective}, &DefautlHorizontalFusePass::HorizontalWithInjective}, + {{OpPatternKind::kReduction, framework::kReduction}, &DefautlHorizontalFusePass::ReduceFuseReduce}, + }; + } - fuse_consumers.insert(consumer); + static bool IsSameSize(LightwareFusePassCtx* ctx, const OpGroupPtr& src, const OpGroupPtr& dst) { + return ctx->fuse_helper().AllOutputsSameSize(src, dst); } - VLOG(3) << "VerticalFusion, Number of fuse Consumers : " << fuse_consumers.size(); - VLOG(3) << "VerticalFusion, Number of unsafe fuse Consumers : " << fuse_consumers.size(); + static bool ElementwiseFuseBroadcast(LightwareFusePassCtx* ctx, + const OpGroupPtr& src, + const OpGroupPtr& dst) { + return ctx->fuse_helper().ElementwiseFuseBroadcast(src, dst); + } - if (fuse_consumers.size() == 0) { - return false; + static bool HorizontalWithInjective(LightwareFusePassCtx* ctx, + const OpGroupPtr& src, + const OpGroupPtr& dst) { + return ctx->fuse_helper().HorizontalWithInjective(src, dst); } - // if can_fuse_consumers == consumers - // if producer op kind == kElementwise - // if use recompute - if 
(fuse_consumers_unsafe.size() == producer->consumer_groups.size() && - producer->op_pattern_kind == framework::kElementWise) { - if (!recompute) { - return false; - } else { - RecomputeEleGraph(producer, fuse_consumers_unsafe); - VerticalFuse(producer, fuse_consumers_unsafe); - return true; - } + + static bool ElementwiseFuseReduce(LightwareFusePassCtx* ctx, + const OpGroupPtr& src, + const OpGroupPtr& dst) { + return ctx->fuse_helper().ElementwiseFuseReduce(src, dst); + } + + static bool BroadcastFuseReduce(LightwareFusePassCtx* ctx, + const OpGroupPtr& src, + const OpGroupPtr& dst) { + return ctx->fuse_helper().BroadcastFuseReduce(src, dst); } - if (fuse_consumers.size()) { - SelectConsumerToFuse(producer, fuse_consumers); + static bool InjectiveHorizontalWithReduce(LightwareFusePassCtx* ctx, + const OpGroupPtr& src, + const OpGroupPtr& dst) { + return ctx->fuse_helper().InjectiveHorizontalWithReduce(src, dst); } - // if fusionable consumers exist - if (fuse_consumers.size()) { - VerticalFuse(producer, fuse_consumers); - return true; + static bool ReduceFuseElementwise(LightwareFusePassCtx* ctx, + const OpGroupPtr& src, + const OpGroupPtr& dst) { + return ctx->fuse_helper().ReduceFuseElementwise(src, dst); } - return false; + static bool ReduceFuseBroadcast(LightwareFusePassCtx* ctx, + const OpGroupPtr& src, + const OpGroupPtr& dst) { + return ctx->fuse_helper().ReduceFuseBroadcast(src, dst); + } + + static bool ReduceFuseReduce(LightwareFusePassCtx* ctx, const OpGroupPtr& src, const OpGroupPtr& dst) { + return ctx->fuse_helper().ReduceFuseReduce(src, dst); + } + }; + + std::vector> RawVerticalFusePasses() const { + std::vector>{ + std::shared_ptr(new DefautlVerticalFusePass{}), + }; + return ret; + } + + const std::vector>& GetVerticalFusePasses() const { + thread_local static std::vector> fuse_passes = RawVerticalFusePasses(); + return fuse_passes; + } + + void TagVerticalGroups(LightwareFusePassCtx* ctx) const { + const auto& producer = ctx->PickOpGroup(); + 
if (producer->consumer2outputs().empty()) { + return; + } + const auto& fuse_passes = GetVerticalFusePasses(); + for (const auto& fuse_pass : fuse_passes) { + (*fuse_pass)(ctx); + } } - void VerticalFuse(GroupPtr& producer, std::unordered_set& fusionable_consumers) { + bool GeneralVerticalFuse(const GroupPtr& producer) const { + VLOG(3) << "GeneralVerticalFuse...!"; + using GroupSets = std::set>; + const auto& GetFusableConsumerGroupSets = [&]() -> GroupSets { + GroupSets tagged_sets; + const auto& EnableFuse = [&](const GroupPtr& first, const GroupPtr& second) { + tagged_sets.insert(std::set{first, second}); + }; + LightwareFusePassCtx fuse_ctx(producer, EnableFuse); + TagVerticalGroups(&fuse_ctx); + return tagged_sets; + }; + + const auto& GetFusableConsumerGroupList = [&]() -> GroupList { + const auto& group_sets = GetFusableConsumerGroupSets(); + if (group_sets.empty()) { + return GroupList{}; + } + return GroupList{group_sets.begin()->begin(), group_sets.begin()->end()}; + }; + + size_t fuse_count = 0; + while (true) { + const auto& consumer_groups = GetFusableConsumerGroupList(); + if (groups.empty()) { + break; + } + fuse_count += VerticalFuse(producer, consumer_groups); + } + return fuse_count > 0; + } + + void VerticalFuse(GroupPtr& producer_group, std::unordered_set& fusionable_consumers) { VLOG(3) << "VerticalFuse...!"; + auto producer = std::shared_dynamic_cast(producer_group); GroupList fused_groups; GroupPtr master_fuesd_group(nullptr); - for (auto& consumer : fusionable_consumers) { + for (auto& consumer_group : fusionable_consumers) { + auto consumer = std::shared_dynamic_cast(consumer_group); auto fused_group = std::make_shared(); // update depth using consumer depth. 
fused_group->max_depth = std::max(producer->max_depth, consumer->max_depth); @@ -747,9 +872,9 @@ class FusionMergePassHelper : public FusionHelperBase { } // producer nodes - for (const auto& group_and_list : consumer->producer_groups) { + for (const auto& group_and_list : consumer->producer_groups()) { if (group_and_list.first.get() != producer.get()) { - fused_group->producer_groups[group_and_list.first] += group_and_list.second; + *(fused_group->mut_producer_groups())[group_and_list.first] += group_and_list.second; // update consumer's producer's consumer group_and_list.first->consumer_groups.erase(consumer); // TODO: Do not add any TensorInterface into any TensorInterfaceList in this file which will be deprecated. @@ -761,9 +886,9 @@ class FusionMergePassHelper : public FusionHelperBase { for (const auto& group_and_list : consumer->consumer_groups) { fused_group->consumer_groups[group_and_list.first] += group_and_list.second; // update consumer's consumer's producer - group_and_list.first->producer_groups.erase(consumer); + group_and_list.first->mut_producer_groups()->erase(consumer); // TODO: Do not add any TensorInterface into any TensorInterfaceList in this file which will be deprecated. - group_and_list.first->producer_groups[fused_group] += {}; + *(group_and_list.first->mut_producer_groups())[fused_group] += {}; } // sub group @@ -797,7 +922,7 @@ class FusionMergePassHelper : public FusionHelperBase { for (auto& node : producer->output_nodes) { bool be_output = true; - for (const auto& consumer_and_list : producer->consumer_groups) { + for (const auto& consumer_and_list : producer->consumer_groups()) { // if consumer is in fusionable. 
if (fusionable_consumers.count(consumer_and_list.first)) { if (consumer_and_list.first->input_nodes.count(node)) { @@ -823,117 +948,15 @@ class FusionMergePassHelper : public FusionHelperBase { } } // insert unfusionable consumer groups - for (const auto& consumer_and_list : producer->consumer_groups) { + for (const auto& consumer_and_list : producer->consumer_groups()) { if (fusionable_consumers.count(consumer_and_list.first)) { continue; } - master_fuesd_group->consumer_groups[consumer_and_list.first] += consumer_and_list.second; + *(master_fuesd_group-mut_consumer_groups())[consumer_and_list.first] += consumer_and_list.second; // update consumer's producer - consumer_and_list.first->producer_groups.erase(producer); + consumer_and_list.first->mut_producer_groups()->erase(producer); // TODO: Do not add any TensorInterface into any TensorInterfaceList in this file which will be deprecated. - consumer_and_list.first->producer_groups[master_fuesd_group] += {}; - } - } - - void RecomputeEleGraph(const GroupPtr& producer, std::unordered_set& fusionable_consumers) { - if (producer->op_pattern_kind != framework::kElementWise) { - SelectConsumerToFuse(producer, fusionable_consumers); - } - } - - void SelectConsumerToFuse(const GroupPtr& producer, std::unordered_set& fusionable_consumers) { - // if is const op - if (is_const_group(this, producer)) { - std::unordered_set candidates; - for (auto& consumer : fusionable_consumers) { - // if can be output node. - if (is_same_shape(this, producer, consumer)) { - candidates.insert(consumer); - } else { - VLOG(4) << "Fuse Producer : " << producer->group_id << " into Consumer : " << consumer->group_id; - consumer->group_id = producer->group_id + "_" + consumer->group_id; - // just merge the node into group. 
- auto& sub_group = consumer->fused_sub_groups.front(); - sub_group->group_id = producer->group_id + "_" + sub_group->group_id; - sub_group->nodes.insert(sub_group->nodes.begin(), producer->CollectNodes()[0]); - sub_group->nodes_set.insert(producer->CollectNodes()[0]); - // remove depency. - consumer->input_nodes.erase(producer->CollectNodes()[0]); - consumer->producer_groups.erase(producer); - producer->consumer_groups.erase(consumer); - } - } - - CHECK_GE(producer->consumer_groups.size(), candidates.size()); - if (producer->consumer_groups.size() == 0 && candidates.size() == 0 && - output_nodes_set_.count(producer->CollectNodes()[0]) == 0) { - producer->belong_groups.insert(*fusionable_consumers.begin()); - } - - fusionable_consumers = candidates; - return; - } - // 1 to 1 fusion. - if (producer->consumer_groups.size() == 1) { - return; - } - - if (FLAGS_enhance_vertical_fusion_with_recompute) { - std::vector candidates; - for (auto& consumer : fusionable_consumers) { - if (consumer->op_pattern_kind == framework::kElementWise) { - candidates.push_back(consumer); - continue; - } - - auto producer_output_shape = this->GetNodeDataShape(*producer->output_nodes.begin()); - auto consumer_output_shape = this->GetNodeDataShape(*consumer->output_nodes.begin()); - auto consumer_master_input_shape = this->GetNodeInputShape(*(consumer->master_nodes.begin())); - int producer_output_numel = - std::accumulate(producer_output_shape.begin(), producer_output_shape.end(), 1, std::multiplies()); - int consumer_output_numel = - std::accumulate(consumer_output_shape.begin(), consumer_output_shape.end(), 1, std::multiplies()); - int consumer_master_input_numel = std::accumulate( - consumer_master_input_shape.begin(), consumer_master_input_shape.end(), 1, std::multiplies()); - if (producer_output_numel == consumer_output_numel) { - candidates.push_back(consumer); - continue; - } - - if (producer->op_pattern_kind != framework::kInjective && consumer->op_pattern_kind == 
framework::kReduction && - producer_output_numel == consumer_master_input_numel) { - candidates.push_back(consumer); - } - } - sort(candidates.begin(), candidates.end(), [](const auto& lhs, const auto& rhs) { - return lhs->op_pattern_kind < rhs->op_pattern_kind; - }); - - fusionable_consumers.clear(); - if (candidates.size()) { - fusionable_consumers.insert(*candidates.begin()); - } - } else { - std::unordered_set candidates; - for (auto& consumer : fusionable_consumers) { - if (consumer->op_pattern_kind == framework::kElementWise) { - candidates.insert(consumer); - continue; - } - - auto shape0 = this->GetNodeDataShape(*producer->output_nodes.begin()); - auto shape1 = this->GetNodeDataShape(*consumer->output_nodes.begin()); - - if (std::accumulate(shape0.begin(), shape0.end(), 1, std::multiplies()) == - std::accumulate(shape1.begin(), shape1.end(), 1, std::multiplies())) { - candidates.insert(consumer); - } - } - - fusionable_consumers.clear(); - if (candidates.size()) { - fusionable_consumers.insert(*candidates.begin()); - } + *(consumer_and_list.first->mut_producer_groups())[master_fuesd_group] += {}; } } From 835889cc1a28e76e748938c092ab2a6a4b40732f Mon Sep 17 00:00:00 2001 From: zyfncg Date: Tue, 13 Jun 2023 07:50:44 +0000 Subject: [PATCH 15/66] revert for debug --- cinn/hlir/pass/general_fusion_merge_pass.cc | 80 +++++++++++++++++++-- 1 file changed, 74 insertions(+), 6 deletions(-) diff --git a/cinn/hlir/pass/general_fusion_merge_pass.cc b/cinn/hlir/pass/general_fusion_merge_pass.cc index 7f05725e89..e06319f361 100644 --- a/cinn/hlir/pass/general_fusion_merge_pass.cc +++ b/cinn/hlir/pass/general_fusion_merge_pass.cc @@ -857,6 +857,78 @@ class FusionMergePassHelper : public FusionHelperBase { CHECK(fused_group->output_nodes.size()) << "No output node is found, " << fused_group->group_id; } + bool VerticalFusion(GroupPtr& producer, const std::unordered_set& consumers, bool recompute) { + VLOG(3) << "VerticalFusion, Number of Consumers : " << consumers.size(); 
+ auto& relation = fusion_relation_map_[producer->op_pattern_kind]; + // if producer can't fuse others + if (!relation.vertical_relation.size()) { + return false; + } + + std::unordered_set fuse_consumers_unsafe; + std::unordered_set fuse_consumers; + for (const auto& consumer : consumers) { + VLOG(4) << "Check consuemr " << consumer->group_id << " can fuse to producer " << producer->group_id; + // if can't fuse + if (!relation.vertical_relation.count(consumer->op_pattern_kind)) { + VLOG(4) << "Can't fuse producer " << producer->group_id << " consumer " << consumer->group_id; + continue; + } + + // if condition function is false + if (!relation.vertical_relation[consumer->op_pattern_kind](this, producer, consumer)) { + VLOG(4) << "Can't fuse producer " << producer->group_id << " consumer " << consumer->group_id; + continue; + } + + fuse_consumers_unsafe.insert(consumer); + + if (IsDependencySimplify(producer, consumer, consumers)) { + VLOG(4) << "IsDependencySimplify, Consumer " << consumer->group_id << " can't be master fused group!"; + continue; + } + + if (IsDependency(producer, consumer, consumers)) { + VLOG(4) << "IsDependency, Consumer " << consumer->group_id << " can't be master fused group!"; + continue; + } + + fuse_consumers.insert(consumer); + } + + VLOG(3) << "VerticalFusion, Number of fuse Consumers : " << fuse_consumers.size(); + VLOG(3) << "VerticalFusion, Number of unsafe fuse Consumers : " << fuse_consumers.size(); + + if (fuse_consumers.size() == 0) { + return false; + } + // if can_fuse_consumers == consumers + // if producer op kind == kElementwise + // if use recompute + if (fuse_consumers_unsafe.size() == producer->consumer_groups().size() && + producer->op_pattern_kind == framework::kElementWise) { + if (!recompute) { + return false; + } else { + RecomputeEleGraph(producer, fuse_consumers_unsafe); + VerticalFuse(producer, fuse_consumers_unsafe); + return true; + } + } + + if (fuse_consumers.size()) { + SelectConsumerToFuse(producer, 
fuse_consumers); + } + + // if fusionable consumers exist + if (fuse_consumers.size()) { + VerticalFuse(producer, fuse_consumers); + return true; + } + + return false; + } + std::vector> RawVerticalFusePasses() const { return std::vector>{ std::shared_ptr(new DefaultVerticalFusePass()), @@ -905,12 +977,8 @@ class FusionMergePassHelper : public FusionHelperBase { }; bool update = false; - // Maybe this loop is no need. - while (true) { - auto consumer_groups = GetFusableConsumerGroupSet(); - if (consumer_groups.empty()) { - break; - } + auto consumer_groups = GetFusableConsumerGroupSet(); + if (consumer_groups.size() > 0) { VerticalFuse(producer, consumer_groups); update = true; } From 740819a0154cf8d2436d3302ddfd57065819596d Mon Sep 17 00:00:00 2001 From: zyfncg Date: Tue, 13 Jun 2023 07:56:53 +0000 Subject: [PATCH 16/66] debug change --- cinn/hlir/pass/general_fusion_merge_pass.cc | 29 +++++++++++++++++++-- 1 file changed, 27 insertions(+), 2 deletions(-) diff --git a/cinn/hlir/pass/general_fusion_merge_pass.cc b/cinn/hlir/pass/general_fusion_merge_pass.cc index e06319f361..bff92d42b9 100644 --- a/cinn/hlir/pass/general_fusion_merge_pass.cc +++ b/cinn/hlir/pass/general_fusion_merge_pass.cc @@ -489,9 +489,9 @@ class FusionMergePassHelper : public FusionHelperBase { VLOG(3) << "DoFusionMerge...!"; while (DoGeneralHorizontalFusion()) { } - while (DoGeneralVerticalFusion(/* recompute=*/false)) { + while (DoVerticalFusion(/* recompute=*/false)) { } - while (DoGeneralVerticalFusion(/* recompute=*/true)) { + while (DoVerticalFusion(/* recompute=*/true)) { } } @@ -515,6 +515,31 @@ class FusionMergePassHelper : public FusionHelperBase { return updated; } + bool DoVerticalFusion(bool recompute) { + VLOG(3) << "DoVerticalFusion...!"; + bool updated = false; + for (int idx = 0; idx < fusion_groups_.size(); ++idx) { + auto producer = fusion_groups_[idx]; + VLOG(3) << "Fusion Producer Group -> " << producer->group_id; + // if producer is sub group. 
+ if (producer->belong_groups.size()) { + continue; + } + // do horizontal fusion. + if (!recompute) { + updated |= HorizontalFusion(producer, producer->CollectConsumerGroups()); + } + updated |= VerticalFusion(producer, producer->CollectConsumerGroups(), recompute); + } + // fuse input consumers + updated |= FuseInputToConsumers(); + + if (updated) { + UpdateFusionGroup(); + } + return updated; + } + bool DoGeneralVerticalFusion(bool recompute) { VLOG(3) << "DoVerticalFusion...!"; bool updated = false; From 23f6ac4e91f64aa0b157b6d28459ea9daefcfa7c Mon Sep 17 00:00:00 2001 From: zyfncg Date: Tue, 13 Jun 2023 08:37:35 +0000 Subject: [PATCH 17/66] fix vertical fuse bug --- cinn/hlir/pass/general_fusion_merge_pass.cc | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/cinn/hlir/pass/general_fusion_merge_pass.cc b/cinn/hlir/pass/general_fusion_merge_pass.cc index bff92d42b9..f728b9bc41 100644 --- a/cinn/hlir/pass/general_fusion_merge_pass.cc +++ b/cinn/hlir/pass/general_fusion_merge_pass.cc @@ -489,7 +489,7 @@ class FusionMergePassHelper : public FusionHelperBase { VLOG(3) << "DoFusionMerge...!"; while (DoGeneralHorizontalFusion()) { } - while (DoVerticalFusion(/* recompute=*/false)) { + while (DoGeneralVerticalFusion()) { } while (DoVerticalFusion(/* recompute=*/true)) { } @@ -540,7 +540,7 @@ class FusionMergePassHelper : public FusionHelperBase { return updated; } - bool DoGeneralVerticalFusion(bool recompute) { + bool DoGeneralVerticalFusion() { VLOG(3) << "DoVerticalFusion...!"; bool updated = false; for (int idx = 0; idx < fusion_groups_.size(); ++idx) { @@ -551,12 +551,15 @@ class FusionMergePassHelper : public FusionHelperBase { continue; } // do horizontal fusion. 
- if (!recompute) { - updated |= GeneralHorizontalFuse(producer); - } - updated |= GeneralVerticalFuse(producer); + updated |= GeneralHorizontalFuse(producer); + updated |= VerticalFusion(producer, producer->CollectConsumerGroups(), false); + + // updated |= GeneralVerticalFuse(producer); } + // fuse input consumers + updated |= FuseInputToConsumers(); + if (updated) { UpdateFusionGroup(); } @@ -1003,6 +1006,9 @@ class FusionMergePassHelper : public FusionHelperBase { bool update = false; auto consumer_groups = GetFusableConsumerGroupSet(); + if (consumer_groups.size()) { + SelectConsumerToFuse(producer, consumer_groups); + } if (consumer_groups.size() > 0) { VerticalFuse(producer, consumer_groups); update = true; From d15e3f7906ec2245cc8559523c7a35bf4b996cb2 Mon Sep 17 00:00:00 2001 From: zyfncg Date: Tue, 13 Jun 2023 08:41:27 +0000 Subject: [PATCH 18/66] polish code --- cinn/hlir/pass/general_fusion_merge_pass.cc | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/cinn/hlir/pass/general_fusion_merge_pass.cc b/cinn/hlir/pass/general_fusion_merge_pass.cc index f728b9bc41..fa425325f9 100644 --- a/cinn/hlir/pass/general_fusion_merge_pass.cc +++ b/cinn/hlir/pass/general_fusion_merge_pass.cc @@ -552,9 +552,7 @@ class FusionMergePassHelper : public FusionHelperBase { } // do horizontal fusion. 
updated |= GeneralHorizontalFuse(producer); - updated |= VerticalFusion(producer, producer->CollectConsumerGroups(), false); - - // updated |= GeneralVerticalFuse(producer); + updated |= GeneralVerticalFuse(producer); } // fuse input consumers From d921da6db1dd094e7e732404bdb9399bf773dd3b Mon Sep 17 00:00:00 2001 From: zyfncg Date: Tue, 13 Jun 2023 10:04:16 +0000 Subject: [PATCH 19/66] Refactor Recompute Fuse --- cinn/hlir/pass/general_fusion_merge_pass.cc | 136 +++++++++++++++++++- 1 file changed, 132 insertions(+), 4 deletions(-) diff --git a/cinn/hlir/pass/general_fusion_merge_pass.cc b/cinn/hlir/pass/general_fusion_merge_pass.cc index fa425325f9..68d64b8ca8 100644 --- a/cinn/hlir/pass/general_fusion_merge_pass.cc +++ b/cinn/hlir/pass/general_fusion_merge_pass.cc @@ -366,12 +366,12 @@ class DefaultVerticalFusePass final : public FusePass { typedef bool (*ConditionT)(LightwareFusePassCtx* ctx, const OpGroupPtr& src, const OpGroupPtr& dst); - const std::map& GetConditionMap() const { + static const std::map& GetConditionMap() { thread_local static std::map map(RawConditionMap()); return map; } - std::map RawConditionMap() const { + static std::map RawConditionMap() { return std::map{ {{OpPatternKind::kElementWise, framework::kElementWise}, &DefaultVerticalFusePass::IsSameSize}, {{OpPatternKind::kElementWise, framework::kBroadcast}, &DefaultVerticalFusePass::ElementwiseFuseBroadcast}, @@ -448,6 +448,50 @@ class DefaultVerticalFusePass final : public FusePass { } }; +class DefaultRecomputeFusePass final : public FusePass { + public: + DefaultRecomputeFusePass() : FusePass() {} + + void operator()(LightwareFusePassCtx* ctx) const override { + const auto& producer = ctx->PickOpGroup(); + const OpGroupList consumers = [&]() { + OpGroupList consumers; + for (const auto& pair : producer->consumer2outputs()) { + consumers.push_back(pair.first); + } + return consumers; + }(); + if (consumers.size() <= 1) { + return; + } + std::unordered_set candidates; + for (int i = 0; i 
< consumers.size(); ++i) { + const auto& consumer = consumers.at(i); + if (!DetectFusabilityByKind(ctx, producer, consumer)) { + continue; + } + candidates.insert(consumer); + } + if (candidates.size() == consumers.size() && producer->kind() == framework::kElementWise) { + for (const auto& consumer : consumers) { + ctx->EnableFuse(producer, consumer); + } + } + } + + using KindKeyT = std::pair; + bool DetectFusabilityByKind(LightwareFusePassCtx* ctx, const OpGroupPtr& src, const OpGroupPtr& dst) const { + const KindKeyT kind_pair(src->kind(), dst->kind()); + const auto& map = DefaultVerticalFusePass::GetConditionMap(); + const auto& iter = map.find(kind_pair); + if (iter == map.end()) { + return false; + } + return iter->second(ctx, src, dst); + } + +}; + // Op Fusion Pass which performs Ops fusion, Ops are fused // "vertically", meaning producing Ops are fused into their consumers // with the intent that the loops which compute their values will be fused in @@ -491,7 +535,7 @@ class FusionMergePassHelper : public FusionHelperBase { } while (DoGeneralVerticalFusion()) { } - while (DoVerticalFusion(/* recompute=*/true)) { + while (DoGeneralRecomputeAndVerticalFusion()) { } } @@ -541,7 +585,7 @@ class FusionMergePassHelper : public FusionHelperBase { } bool DoGeneralVerticalFusion() { - VLOG(3) << "DoVerticalFusion...!"; + VLOG(3) << "DoGeneralVerticalFusion...!"; bool updated = false; for (int idx = 0; idx < fusion_groups_.size(); ++idx) { auto producer = fusion_groups_[idx]; @@ -564,6 +608,30 @@ class FusionMergePassHelper : public FusionHelperBase { return updated; } + bool DoGeneralRecomputeAndVerticalFusion() { + VLOG(3) << "DoGeneralRecomputeAndVerticalFusion...!"; + bool updated = false; + for (int idx = 0; idx < fusion_groups_.size(); ++idx) { + auto producer = fusion_groups_[idx]; + VLOG(3) << "Fusion Producer Group -> " << producer->group_id; + // if producer is sub group. + if (producer->belong_groups.size()) { + continue; + } + // do horizontal fusion. 
+ updated |= GeneralRecomputeFuse(producer); + updated |= GeneralVerticalFuse(producer); + } + + // fuse input consumers + updated |= FuseInputToConsumers(); + + if (updated) { + UpdateFusionGroup(); + } + return updated; + } + void UpdateFusionGroup() { VLOG(3) << "UpdateFusionGroup..."; GroupList fusion_groups; @@ -1204,6 +1272,66 @@ class FusionMergePassHelper : public FusionHelperBase { } } + std::vector> RawRecomputeFusePasses() const { + return std::vector>{ + std::shared_ptr(new DefaultRecomputeFusePass()), + }; + } + + const std::vector>& GetRecomputeFusePasses() const { + thread_local static std::vector> fuse_passes = RawRecomputeFusePasses(); + return fuse_passes; + } + + void TagRecomputeGroups(LightwareFusePassCtx* ctx) const { + const auto& producer = ctx->PickOpGroup(); + if (producer->consumer2outputs().size() <= 1) { + return; + } + const auto& fuse_passes = GetRecomputeFusePasses(); + for (const auto& fuse_pass : fuse_passes) { + (*fuse_pass)(ctx); + } + } + + bool GeneralRecomputeFuse(GroupPtr& producer) { + VLOG(3) << "GeneralRecomputeFuse...!"; + using GroupSets = std::set>; + const auto& GetFusableConsumerOpGroupSets = [&]() -> GroupSets { + GroupSets tagged_sets; + const auto& EnableFuse = [&](const OpGroupPtr& first, const OpGroupPtr& second) { + tagged_sets.insert(std::make_pair(first, second)); + }; + GraphGroupLightwareFusePassCtx fuse_ctx(this, producer, EnableFuse); + TagRecomputeGroups(&fuse_ctx); + return tagged_sets; + }; + + auto GetFusableConsumerGroupSet = [&]() -> std::unordered_set { + const auto& group_sets = GetFusableConsumerOpGroupSets(); + if (group_sets.empty()) { + return {}; + } + std::unordered_set ret; + for (const auto& group_pair : group_sets) { + ret.insert(std::dynamic_pointer_cast(group_pair.second)); + } + return ret; + }; + + bool update = false; + auto consumer_groups = GetFusableConsumerGroupSet(); + if (consumer_groups.size() > 0) { + RecomputeFuse(producer, consumer_groups); + update = true; + } + return 
update; + } + + void RecomputeFuse(GroupPtr& producer, std::unordered_set& fusionable_consumers) { + VerticalFuse(producer, fusionable_consumers); + } + void RecomputeEleGraph(const GroupPtr& producer, std::unordered_set& fusionable_consumers) { if (producer->op_pattern_kind != framework::kElementWise) { SelectConsumerToFuse(producer, fusionable_consumers); From 1969a79d5c989700503277979aa409cc4f8e43aa Mon Sep 17 00:00:00 2001 From: jiahongyu Date: Tue, 13 Jun 2023 17:03:07 +0000 Subject: [PATCH 20/66] Support InputFuse --- cinn/hlir/pass/general_fusion_merge_pass.cc | 549 +++++++++++++------- 1 file changed, 361 insertions(+), 188 deletions(-) diff --git a/cinn/hlir/pass/general_fusion_merge_pass.cc b/cinn/hlir/pass/general_fusion_merge_pass.cc index fa425325f9..a44e6288ba 100644 --- a/cinn/hlir/pass/general_fusion_merge_pass.cc +++ b/cinn/hlir/pass/general_fusion_merge_pass.cc @@ -73,9 +73,10 @@ class FuseHelper { FuseHelper() = default; }; +template class GraphGroupFuseHelper final : public FuseHelper { public: - explicit GraphGroupFuseHelper(const GraphGroupLightwareFusePassCtx* ctx) : ctx_(ctx) {} + explicit GraphGroupFuseHelper(const FusePassCtxT* ctx) : ctx_(ctx) {} bool AllOutputsSameSize(const OpGroupPtr& first, const OpGroupPtr& second) const override; @@ -121,10 +122,22 @@ class GraphGroupFuseHelper final : public FuseHelper { return is_reachable(src, dst, [](OpGroupPtr) {}); } - const GraphGroupLightwareFusePassCtx* ctx_; + const FusePassCtxT* ctx_; }; -class LightwareFusePassCtx { +class FusePassCtx { + public: + virtual ~FusePassCtx() {} + + virtual const FuseHelper& fuse_helper() const = 0; + + virtual void EnableFuse(const OpGroupPtr& first, const OpGroupPtr& second) = 0; + + protected: + FusePassCtx() = default; +}; + +class LightwareFusePassCtx : public FusePassCtx { public: virtual ~LightwareFusePassCtx() {} @@ -147,11 +160,11 @@ class GraphGroupLightwareFusePassCtx final : public LightwareFusePassCtx { : 
graph_group_fusion_helper_(graph_group_fusion_helper), group_(group), EnableFuse_(EnableFuse), - fuse_helper_(this) {} + fuse_helper_(new GraphGroupFuseHelper(this)) {} const OpGroupPtr& PickOpGroup() const override { return group_; } - const FuseHelper& fuse_helper() const override { return fuse_helper_; } + const FuseHelper& fuse_helper() const override { return *fuse_helper_; } void EnableFuse(const OpGroupPtr& first, const OpGroupPtr& second) override { EnableFuse_(first, second); } @@ -161,83 +174,242 @@ class GraphGroupLightwareFusePassCtx final : public LightwareFusePassCtx { const FusionHelperBase* graph_group_fusion_helper_; const OpGroupPtr group_; const std::function EnableFuse_; - const GraphGroupFuseHelper fuse_helper_; + const std::unique_ptr fuse_helper_; +}; + +class InputFusePassCtx : public FusePassCtx { + public: + virtual ~InputFusePassCtx() {} + + virtual const std::unordered_set& PickConsumersWithSameInputs() const = 0; + + virtual const FuseHelper& fuse_helper() const = 0; + + virtual void EnableFuse(const OpGroupPtr& first, const OpGroupPtr& second) = 0; + + protected: + InputFusePassCtx() = default; }; -bool GraphGroupFuseHelper::AllOutputsSameSize(const OpGroupPtr& first, const OpGroupPtr& second) const { +class GraphGroupInputFusePassCtx final : public InputFusePassCtx { + public: + GraphGroupInputFusePassCtx(const FusionHelperBase* graph_group_fusion_helper, + const std::unordered_set& groups, + const std::function& EnableFuse) + : graph_group_fusion_helper_(graph_group_fusion_helper), + groups_(groups), + EnableFuse_(EnableFuse), + fuse_helper_(new GraphGroupFuseHelper(this)) {} + + const std::unordered_set& PickConsumersWithSameInputs() const override { return groups_; } + + const FuseHelper& fuse_helper() const override { return *fuse_helper_; } + + void EnableFuse(const OpGroupPtr& first, const OpGroupPtr& second) override { EnableFuse_(first, second); } + + const FusionHelperBase& graph_group_fusion_helper() const { return 
*graph_group_fusion_helper_; } + + private: + const FusionHelperBase* graph_group_fusion_helper_; + const std::unordered_set& groups_; + const std::function EnableFuse_; + const std::unique_ptr fuse_helper_; +}; + +template +bool GraphGroupFuseHelper::AllOutputsSameSize(const OpGroupPtr& first, const OpGroupPtr& second) const { return is_same_size(&ctx_->graph_group_fusion_helper(), std::dynamic_pointer_cast(first), std::dynamic_pointer_cast(second)); } -bool GraphGroupFuseHelper::HorizontalElementwiseFuseReduce(const OpGroupPtr& src, const OpGroupPtr& dst) const { +template +bool GraphGroupFuseHelper::HorizontalElementwiseFuseReduce(const OpGroupPtr& src, + const OpGroupPtr& dst) const { return honrizontal_elementwise_fuse_reduce(&ctx_->graph_group_fusion_helper(), std::dynamic_pointer_cast(src), std::dynamic_pointer_cast(dst)); } -bool GraphGroupFuseHelper::ElementwiseFuseBroadcast(const OpGroupPtr& src, const OpGroupPtr& dst) const { +template +bool GraphGroupFuseHelper::ElementwiseFuseBroadcast(const OpGroupPtr& src, const OpGroupPtr& dst) const { return elementwise_fuse_broadcast(&ctx_->graph_group_fusion_helper(), - std::dynamic_pointer_cast(src), - std::dynamic_pointer_cast(dst)); + std::dynamic_pointer_cast(src), + std::dynamic_pointer_cast(dst)); } -bool GraphGroupFuseHelper::HorizontalWithInjective(const OpGroupPtr& src, const OpGroupPtr& dst) const { +template +bool GraphGroupFuseHelper::HorizontalWithInjective(const OpGroupPtr& src, const OpGroupPtr& dst) const { return horizontal_with_injective(&ctx_->graph_group_fusion_helper(), - std::dynamic_pointer_cast(src), - std::dynamic_pointer_cast(dst)); + std::dynamic_pointer_cast(src), + std::dynamic_pointer_cast(dst)); } -bool GraphGroupFuseHelper::ElementwiseFuseReduce(const OpGroupPtr& src, const OpGroupPtr& dst) const { +template +bool GraphGroupFuseHelper::ElementwiseFuseReduce(const OpGroupPtr& src, const OpGroupPtr& dst) const { return elementwise_fuse_reduce(&ctx_->graph_group_fusion_helper(), - 
std::dynamic_pointer_cast(src), - std::dynamic_pointer_cast(dst)); + std::dynamic_pointer_cast(src), + std::dynamic_pointer_cast(dst)); } -bool GraphGroupFuseHelper::BroadcastFuseReduce(const OpGroupPtr& src, const OpGroupPtr& dst) const { +template +bool GraphGroupFuseHelper::BroadcastFuseReduce(const OpGroupPtr& src, const OpGroupPtr& dst) const { return broadcast_fuse_reduce(&ctx_->graph_group_fusion_helper(), - std::dynamic_pointer_cast(src), - std::dynamic_pointer_cast(dst)); + std::dynamic_pointer_cast(src), + std::dynamic_pointer_cast(dst)); } -bool GraphGroupFuseHelper::InjectiveHorizontalWithReduce(const OpGroupPtr& src, const OpGroupPtr& dst) const { +template +bool GraphGroupFuseHelper::InjectiveHorizontalWithReduce(const OpGroupPtr& src, + const OpGroupPtr& dst) const { return injective_horizontal_with_reduce(&ctx_->graph_group_fusion_helper(), - std::dynamic_pointer_cast(src), - std::dynamic_pointer_cast(dst)); + std::dynamic_pointer_cast(src), + std::dynamic_pointer_cast(dst)); } -bool GraphGroupFuseHelper::ReduceFuseElementwise(const OpGroupPtr& src, const OpGroupPtr& dst) const { +template +bool GraphGroupFuseHelper::ReduceFuseElementwise(const OpGroupPtr& src, const OpGroupPtr& dst) const { return reduce_fuse_elementwise(&ctx_->graph_group_fusion_helper(), - std::dynamic_pointer_cast(src), - std::dynamic_pointer_cast(dst)); + std::dynamic_pointer_cast(src), + std::dynamic_pointer_cast(dst)); } -bool GraphGroupFuseHelper::ReduceFuseBroadcast(const OpGroupPtr& src, const OpGroupPtr& dst) const { +template +bool GraphGroupFuseHelper::ReduceFuseBroadcast(const OpGroupPtr& src, const OpGroupPtr& dst) const { return reduce_fuse_broadcast(&ctx_->graph_group_fusion_helper(), - std::dynamic_pointer_cast(src), - std::dynamic_pointer_cast(dst)); + std::dynamic_pointer_cast(src), + std::dynamic_pointer_cast(dst)); } - -bool GraphGroupFuseHelper::ReduceFuseReduce(const OpGroupPtr& src, const OpGroupPtr& dst) const { +template +bool 
GraphGroupFuseHelper::ReduceFuseReduce(const OpGroupPtr& src, const OpGroupPtr& dst) const { return reduce_fuse_reduce(&ctx_->graph_group_fusion_helper(), std::dynamic_pointer_cast(src), std::dynamic_pointer_cast(dst)); } +template +struct HorizontalFuseUtil { + using KindKeyT = std::pair; + + static bool DetectFusabilityByKind(FusePassCtxT* ctx, const OpGroupPtr& src, const OpGroupPtr& dst) { + const KindKeyT kind_pair(src->kind(), dst->kind()); + const auto& map = GetConditionMap(); + const auto& iter = map.find(kind_pair); + if (iter == map.end()) { + return false; + } + return iter->second(ctx, src, dst); + } + + typedef bool (*ConditionT)(FusePassCtxT* ctx, const OpGroupPtr& src, const OpGroupPtr& dst); + + static const std::map& GetConditionMap() { + thread_local static std::map map(RawConditionMap()); + return map; + } + + static std::map RawConditionMap() { + return std::map{ + {{OpPatternKind::kElementWise, framework::kElementWise}, &IsSameSize}, + {{OpPatternKind::kElementWise, framework::kBroadcast}, &IsSameSize}, + {{OpPatternKind::kElementWise, framework::kInjective}, &IsSameSize}, + {{OpPatternKind::kElementWise, framework::kReduction}, &HorizontalElementwiseFuseReduce}, + + {{OpPatternKind::kBroadcast, framework::kElementWise}, &IsSameSize}, + {{OpPatternKind::kBroadcast, framework::kBroadcast}, &IsSameSize}, + {{OpPatternKind::kBroadcast, framework::kInjective}, &IsSameSize}, + {{OpPatternKind::kBroadcast, framework::kReduction}, &IsSameSize}, + + {{OpPatternKind::kInjective, framework::kElementWise}, &IsSameSize}, + {{OpPatternKind::kInjective, framework::kBroadcast}, &IsSameSize}, + {{OpPatternKind::kInjective, framework::kInjective}, &IsSameSize}, + {{OpPatternKind::kInjective, framework::kReduction}, &IsSameSize}, + + {{OpPatternKind::kReduction, framework::kElementWise}, &HorizontalElementwiseFuseReduce}, + {{OpPatternKind::kReduction, framework::kBroadcast}, &IsSameSize}, + {{OpPatternKind::kReduction, framework::kInjective}, &IsSameSize}, + 
{{OpPatternKind::kReduction, framework::kReduction}, &ReduceFuseReduce}, + }; + } + + static bool IsSameSize(FusePassCtxT* ctx, const OpGroupPtr& src, const OpGroupPtr& dst) { + return ctx->fuse_helper().AllOutputsSameSize(src, dst); + } + + static bool HorizontalElementwiseFuseReduce(FusePassCtxT* ctx, const OpGroupPtr& src, const OpGroupPtr& dst) { + return ctx->fuse_helper().HorizontalElementwiseFuseReduce(src, dst); + } + + static bool ReduceFuseReduce(FusePassCtxT* ctx, const OpGroupPtr& src, const OpGroupPtr& dst) { + return ctx->fuse_helper().ReduceFuseReduce(src, dst); + } +}; + class FusePass { public: virtual ~FusePass() = default; + protected: + FusePass() = default; +}; + +class InputFusePass : public FusePass { + public: + virtual ~InputFusePass() = default; + + virtual void operator()(InputFusePassCtx* ctx) const = 0; + + protected: + InputFusePass() = default; +}; + +class DefautlInputFusePass final : public InputFusePass { + public: + DefautlInputFusePass() : InputFusePass() {} + + void operator()(InputFusePassCtx* ctx) const override { + const auto& consumer_set = ctx->PickConsumersWithSameInputs(); + if (consumer_set.size() <= 1) { + return; + } + const OpGroupList consumers = [&]() { + OpGroupList ret; + for (const auto& consumer : consumer_set) { + ret.push_back(consumer); + } + return ret; + }(); + for (int i = 0; i < consumers.size(); ++i) { + const auto& src = consumers.at(i); + for (int j = i + 1; j < consumers.size(); ++j) { + const auto& dst = consumers.at(j); + if (ctx->fuse_helper().DetectCycleIfFuse(src, dst)) { + continue; + } + if (!HorizontalFuseUtil::DetectFusabilityByKind(ctx, src, dst)) { + continue; + } + ctx->EnableFuse(src, dst); + return; + } + } + } +}; + +class LightwareFusePass : public FusePass { + public: + virtual ~LightwareFusePass() = default; + virtual void operator()(LightwareFusePassCtx* ctx) const = 0; protected: - FusePass() = default; + LightwareFusePass() = default; }; -class DefautlHorizontalFusePass final : 
public FusePass { +class DefautlHorizontalFusePass final : public LightwareFusePass { public: - DefautlHorizontalFusePass() : FusePass() {} + DefautlHorizontalFusePass() : LightwareFusePass() {} void operator()(LightwareFusePassCtx* ctx) const override { const auto& producer = ctx->PickOpGroup(); @@ -258,7 +430,7 @@ class DefautlHorizontalFusePass final : public FusePass { if (ctx->fuse_helper().DetectCycleIfFuse(src, dst)) { continue; } - if (!DetectFusabilityByKind(ctx, src, dst)) { + if (!HorizontalFuseUtil::DetectFusabilityByKind(ctx, src, dst)) { continue; } ctx->EnableFuse(src, dst); @@ -266,9 +438,37 @@ class DefautlHorizontalFusePass final : public FusePass { } } } +}; - using KindKeyT = std::pair; +class DefaultVerticalFusePass final : public LightwareFusePass { + public: + DefaultVerticalFusePass() : LightwareFusePass() {} + + void operator()(LightwareFusePassCtx* ctx) const override { + const auto& producer = ctx->PickOpGroup(); + const OpGroupList consumers = [&]() { + OpGroupList consumers; + for (const auto& pair : producer->consumer2outputs()) { + consumers.push_back(pair.first); + } + return consumers; + }(); + if (consumers.size() == 0) { + return; + } + for (int i = 0; i < consumers.size(); ++i) { + const auto& consumer = consumers.at(i); + if (!DetectFusabilityByKind(ctx, producer, consumer)) { + continue; + } + if (ctx->fuse_helper().DetectCycleIfFuse(producer, consumer)) { + continue; + } + ctx->EnableFuse(producer, consumer); + } + } + using KindKeyT = std::pair; bool DetectFusabilityByKind(LightwareFusePassCtx* ctx, const OpGroupPtr& src, const OpGroupPtr& dst) const { const KindKeyT kind_pair(src->kind(), dst->kind()); const auto& map = GetConditionMap(); @@ -288,27 +488,25 @@ class DefautlHorizontalFusePass final : public FusePass { std::map RawConditionMap() const { return std::map{ - {{OpPatternKind::kElementWise, framework::kElementWise}, &DefautlHorizontalFusePass::IsSameSize}, - {{OpPatternKind::kElementWise, framework::kBroadcast}, 
&DefautlHorizontalFusePass::IsSameSize}, - {{OpPatternKind::kElementWise, framework::kInjective}, &DefautlHorizontalFusePass::IsSameSize}, - {{OpPatternKind::kElementWise, framework::kReduction}, - &DefautlHorizontalFusePass::HorizontalElementwiseFuseReduce}, - - {{OpPatternKind::kBroadcast, framework::kElementWise}, &DefautlHorizontalFusePass::IsSameSize}, - {{OpPatternKind::kBroadcast, framework::kBroadcast}, &DefautlHorizontalFusePass::IsSameSize}, - {{OpPatternKind::kBroadcast, framework::kInjective}, &DefautlHorizontalFusePass::IsSameSize}, - {{OpPatternKind::kBroadcast, framework::kReduction}, &DefautlHorizontalFusePass::IsSameSize}, - - {{OpPatternKind::kInjective, framework::kElementWise}, &DefautlHorizontalFusePass::IsSameSize}, - {{OpPatternKind::kInjective, framework::kBroadcast}, &DefautlHorizontalFusePass::IsSameSize}, - {{OpPatternKind::kInjective, framework::kInjective}, &DefautlHorizontalFusePass::IsSameSize}, - {{OpPatternKind::kInjective, framework::kReduction}, &DefautlHorizontalFusePass::IsSameSize}, - - {{OpPatternKind::kReduction, framework::kElementWise}, - &DefautlHorizontalFusePass::HorizontalElementwiseFuseReduce}, - {{OpPatternKind::kReduction, framework::kBroadcast}, &DefautlHorizontalFusePass::IsSameSize}, - {{OpPatternKind::kReduction, framework::kInjective}, &DefautlHorizontalFusePass::IsSameSize}, - {{OpPatternKind::kReduction, framework::kReduction}, &DefautlHorizontalFusePass::ReduceFuseReduce}, + {{OpPatternKind::kElementWise, framework::kElementWise}, &DefaultVerticalFusePass::IsSameSize}, + {{OpPatternKind::kElementWise, framework::kBroadcast}, &DefaultVerticalFusePass::ElementwiseFuseBroadcast}, + {{OpPatternKind::kElementWise, framework::kInjective}, &DefaultVerticalFusePass::HorizontalWithInjective}, + {{OpPatternKind::kElementWise, framework::kReduction}, &DefaultVerticalFusePass::ElementwiseFuseReduce}, + + {{OpPatternKind::kBroadcast, framework::kElementWise}, &DefaultVerticalFusePass::IsSameSize}, + 
{{OpPatternKind::kBroadcast, framework::kBroadcast}, &DefaultVerticalFusePass::IsSameSize}, + {{OpPatternKind::kBroadcast, framework::kInjective}, &DefaultVerticalFusePass::HorizontalWithInjective}, + {{OpPatternKind::kBroadcast, framework::kReduction}, &DefaultVerticalFusePass::BroadcastFuseReduce}, + + {{OpPatternKind::kInjective, framework::kElementWise}, &DefaultVerticalFusePass::IsSameSize}, + {{OpPatternKind::kInjective, framework::kBroadcast}, &DefaultVerticalFusePass::IsSameSize}, + {{OpPatternKind::kInjective, framework::kInjective}, &DefaultVerticalFusePass::HorizontalWithInjective}, + {{OpPatternKind::kInjective, framework::kReduction}, &DefaultVerticalFusePass::InjectiveHorizontalWithReduce}, + + {{OpPatternKind::kReduction, framework::kElementWise}, &DefaultVerticalFusePass::ReduceFuseElementwise}, + {{OpPatternKind::kReduction, framework::kBroadcast}, &DefaultVerticalFusePass::ReduceFuseBroadcast}, + {{OpPatternKind::kReduction, framework::kInjective}, &DefaultVerticalFusePass::HorizontalWithInjective}, + {{OpPatternKind::kReduction, framework::kReduction}, &DefaultVerticalFusePass::ReduceFuseReduce}, }; } @@ -316,136 +514,37 @@ class DefautlHorizontalFusePass final : public FusePass { return ctx->fuse_helper().AllOutputsSameSize(src, dst); } - static bool HorizontalElementwiseFuseReduce(LightwareFusePassCtx* ctx, const OpGroupPtr& src, const OpGroupPtr& dst) { - return ctx->fuse_helper().HorizontalElementwiseFuseReduce(src, dst); + static bool ElementwiseFuseBroadcast(LightwareFusePassCtx* ctx, const OpGroupPtr& src, const OpGroupPtr& dst) { + return ctx->fuse_helper().ElementwiseFuseBroadcast(src, dst); } - static bool ReduceFuseReduce(LightwareFusePassCtx* ctx, const OpGroupPtr& src, const OpGroupPtr& dst) { - return ctx->fuse_helper().ReduceFuseReduce(src, dst); + static bool HorizontalWithInjective(LightwareFusePassCtx* ctx, const OpGroupPtr& src, const OpGroupPtr& dst) { + return ctx->fuse_helper().HorizontalWithInjective(src, dst); } -}; - 
-class DefaultVerticalFusePass final : public FusePass { - public: - DefaultVerticalFusePass() : FusePass() {} - void operator()(LightwareFusePassCtx* ctx) const override { - const auto& producer = ctx->PickOpGroup(); - const OpGroupList consumers = [&]() { - OpGroupList consumers; - for (const auto& pair : producer->consumer2outputs()) { - consumers.push_back(pair.first); - } - return consumers; - }(); - if (consumers.size() == 0) { - return; - } - for (int i = 0; i < consumers.size(); ++i) { - const auto& consumer = consumers.at(i); - if (!DetectFusabilityByKind(ctx, producer, consumer)) { - continue; - } - if (ctx->fuse_helper().DetectCycleIfFuse(producer, consumer)) { - continue; - } - ctx->EnableFuse(producer, consumer); - } + static bool ElementwiseFuseReduce(LightwareFusePassCtx* ctx, const OpGroupPtr& src, const OpGroupPtr& dst) { + return ctx->fuse_helper().ElementwiseFuseReduce(src, dst); } - using KindKeyT = std::pair; - bool DetectFusabilityByKind(LightwareFusePassCtx* ctx, const OpGroupPtr& src, const OpGroupPtr& dst) const { - const KindKeyT kind_pair(src->kind(), dst->kind()); - const auto& map = GetConditionMap(); - const auto& iter = map.find(kind_pair); - if (iter == map.end()) { - return false; - } - return iter->second(ctx, src, dst); + static bool BroadcastFuseReduce(LightwareFusePassCtx* ctx, const OpGroupPtr& src, const OpGroupPtr& dst) { + return ctx->fuse_helper().BroadcastFuseReduce(src, dst); } - typedef bool (*ConditionT)(LightwareFusePassCtx* ctx, const OpGroupPtr& src, const OpGroupPtr& dst); - - const std::map& GetConditionMap() const { - thread_local static std::map map(RawConditionMap()); - return map; + static bool InjectiveHorizontalWithReduce(LightwareFusePassCtx* ctx, const OpGroupPtr& src, const OpGroupPtr& dst) { + return ctx->fuse_helper().InjectiveHorizontalWithReduce(src, dst); } - std::map RawConditionMap() const { - return std::map{ - {{OpPatternKind::kElementWise, framework::kElementWise}, 
&DefaultVerticalFusePass::IsSameSize}, - {{OpPatternKind::kElementWise, framework::kBroadcast}, &DefaultVerticalFusePass::ElementwiseFuseBroadcast}, - {{OpPatternKind::kElementWise, framework::kInjective}, &DefaultVerticalFusePass::HorizontalWithInjective}, - {{OpPatternKind::kElementWise, framework::kReduction}, - &DefaultVerticalFusePass::ElementwiseFuseReduce}, - - {{OpPatternKind::kBroadcast, framework::kElementWise}, &DefaultVerticalFusePass::IsSameSize}, - {{OpPatternKind::kBroadcast, framework::kBroadcast}, &DefaultVerticalFusePass::IsSameSize}, - {{OpPatternKind::kBroadcast, framework::kInjective}, &DefaultVerticalFusePass::HorizontalWithInjective}, - {{OpPatternKind::kBroadcast, framework::kReduction}, &DefaultVerticalFusePass::BroadcastFuseReduce}, - - {{OpPatternKind::kInjective, framework::kElementWise}, &DefaultVerticalFusePass::IsSameSize}, - {{OpPatternKind::kInjective, framework::kBroadcast}, &DefaultVerticalFusePass::IsSameSize}, - {{OpPatternKind::kInjective, framework::kInjective}, &DefaultVerticalFusePass::HorizontalWithInjective}, - {{OpPatternKind::kInjective, framework::kReduction}, &DefaultVerticalFusePass::InjectiveHorizontalWithReduce}, - - {{OpPatternKind::kReduction, framework::kElementWise}, - &DefaultVerticalFusePass::ReduceFuseElementwise}, - {{OpPatternKind::kReduction, framework::kBroadcast}, &DefaultVerticalFusePass::ReduceFuseBroadcast}, - {{OpPatternKind::kReduction, framework::kInjective}, &DefaultVerticalFusePass::HorizontalWithInjective}, - {{OpPatternKind::kReduction, framework::kReduction}, &DefaultVerticalFusePass::ReduceFuseReduce}, - }; + static bool ReduceFuseElementwise(LightwareFusePassCtx* ctx, const OpGroupPtr& src, const OpGroupPtr& dst) { + return ctx->fuse_helper().ReduceFuseElementwise(src, dst); } - static bool IsSameSize(LightwareFusePassCtx* ctx, const OpGroupPtr& src, const OpGroupPtr& dst) { - return ctx->fuse_helper().AllOutputsSameSize(src, dst); - } - - static bool 
ElementwiseFuseBroadcast(LightwareFusePassCtx* ctx, - const OpGroupPtr& src, - const OpGroupPtr& dst) { - return ctx->fuse_helper().ElementwiseFuseBroadcast(src, dst); - } - - static bool HorizontalWithInjective(LightwareFusePassCtx* ctx, - const OpGroupPtr& src, - const OpGroupPtr& dst) { - return ctx->fuse_helper().HorizontalWithInjective(src, dst); - } - - static bool ElementwiseFuseReduce(LightwareFusePassCtx* ctx, - const OpGroupPtr& src, - const OpGroupPtr& dst) { - return ctx->fuse_helper().ElementwiseFuseReduce(src, dst); - } - - static bool BroadcastFuseReduce(LightwareFusePassCtx* ctx, - const OpGroupPtr& src, - const OpGroupPtr& dst) { - return ctx->fuse_helper().BroadcastFuseReduce(src, dst); - } - - static bool InjectiveHorizontalWithReduce(LightwareFusePassCtx* ctx, - const OpGroupPtr& src, - const OpGroupPtr& dst) { - return ctx->fuse_helper().InjectiveHorizontalWithReduce(src, dst); - } - - static bool ReduceFuseElementwise(LightwareFusePassCtx* ctx, - const OpGroupPtr& src, - const OpGroupPtr& dst) { - return ctx->fuse_helper().ReduceFuseElementwise(src, dst); - } - - static bool ReduceFuseBroadcast(LightwareFusePassCtx* ctx, - const OpGroupPtr& src, - const OpGroupPtr& dst) { - return ctx->fuse_helper().ReduceFuseBroadcast(src, dst); - } + static bool ReduceFuseBroadcast(LightwareFusePassCtx* ctx, const OpGroupPtr& src, const OpGroupPtr& dst) { + return ctx->fuse_helper().ReduceFuseBroadcast(src, dst); + } - static bool ReduceFuseReduce(LightwareFusePassCtx* ctx, const OpGroupPtr& src, const OpGroupPtr& dst) { - return ctx->fuse_helper().ReduceFuseReduce(src, dst); - } + static bool ReduceFuseReduce(LightwareFusePassCtx* ctx, const OpGroupPtr& src, const OpGroupPtr& dst) { + return ctx->fuse_helper().ReduceFuseReduce(src, dst); + } }; // Op Fusion Pass which performs Ops fusion, Ops are fused @@ -556,7 +655,7 @@ class FusionMergePassHelper : public FusionHelperBase { } // fuse input consumers - updated |= FuseInputToConsumers(); + updated |= 
GeneralInputFuse(); if (updated) { UpdateFusionGroup(); @@ -611,14 +710,14 @@ class FusionMergePassHelper : public FusionHelperBase { } } - std::vector> RawHorizontalFusePasses() const { - return std::vector>{ - std::shared_ptr(new DefautlHorizontalFusePass{}), + std::vector> RawHorizontalFusePasses() const { + return std::vector>{ + std::shared_ptr(new DefautlHorizontalFusePass{}), }; } - const std::vector>& GetHorizontalFusePasses() const { - thread_local static std::vector> fuse_passes = RawHorizontalFusePasses(); + const std::vector>& GetHorizontalFusePasses() const { + thread_local static std::vector> fuse_passes = RawHorizontalFusePasses(); return fuse_passes; } @@ -668,6 +767,59 @@ class FusionMergePassHelper : public FusionHelperBase { return update; } + std::vector> RawInputFusePasses() const { + return std::vector>{ + std::shared_ptr(new DefautlInputFusePass{}), + }; + } + + const std::vector>& GetInputFusePasses() const { + thread_local static std::vector> fuse_passes = RawInputFusePasses(); + return fuse_passes; + } + + void EnableFusedInputGroups(InputFusePassCtx* ctx) const { + const auto& fuse_passes = GetInputFusePasses(); + for (const auto& fuse_pass : fuse_passes) { + (*fuse_pass)(ctx); + } + } + + bool CallGeneralInputFusePass(const std::unordered_set& consumers) { + VLOG(3) << "CallGeneralInputFusePass...!"; + using OpGroupSets = std::set>; + const auto& GetFusableConsumerGroupSets = [&]() -> OpGroupSets { + OpGroupSets tagged_sets; + const auto& EnableFuse = [&](const OpGroupPtr& first, const OpGroupPtr& second) { + tagged_sets.insert(std::set{first, second}); + }; + GraphGroupInputFusePassCtx fuse_ctx(this, consumers, EnableFuse); + EnableFusedInputGroups(&fuse_ctx); + return tagged_sets; + }; + const auto& GetFusableConsumerGroupList = [&]() -> GroupList { + const auto& group_sets = GetFusableConsumerGroupSets(); + if (group_sets.empty()) { + return GroupList{}; + } + GroupList ret; + for (const auto& group : *group_sets.begin()) { + 
ret.push_back(std::dynamic_pointer_cast(group)); + } + return ret; + }; + bool update = false; + while (true) { + const auto& groups = GetFusableConsumerGroupList(); + if (groups.size() <= 1) { + break; + } + HorizontalFuse(groups); + update = true; + } + return update; + } + bool HorizontalFusion(GroupPtr producer, const std::unordered_set& consumers) { VLOG(3) << "HorizontalFusion...!"; if (consumers.size() <= 1) { @@ -955,18 +1107,18 @@ class FusionMergePassHelper : public FusionHelperBase { return false; } - std::vector> RawVerticalFusePasses() const { - return std::vector>{ - std::shared_ptr(new DefaultVerticalFusePass()), + std::vector> RawVerticalFusePasses() const { + return std::vector>{ + std::shared_ptr(new DefaultVerticalFusePass()), }; } - const std::vector>& GetVerticalFusePasses() const { - thread_local static std::vector> fuse_passes = RawVerticalFusePasses(); + const std::vector>& GetVerticalFusePasses() const { + thread_local static std::vector> fuse_passes = RawVerticalFusePasses(); return fuse_passes; } - void TagVerticalGroups(LightwareFusePassCtx* ctx) const { + void TagVerticalGroups(LightwareFusePassCtx* ctx) const { const auto& producer = ctx->PickOpGroup(); if (producer->consumer2outputs().empty()) { return; @@ -979,7 +1131,7 @@ class FusionMergePassHelper : public FusionHelperBase { bool GeneralVerticalFuse(GroupPtr& producer) { VLOG(3) << "GeneralVerticalFuse...!"; - using GroupSets = std::set>; + using GroupSets = std::set>; const auto& GetFusableConsumerOpGroupSets = [&]() -> GroupSets { GroupSets tagged_sets; const auto& EnableFuse = [&](const OpGroupPtr& first, const OpGroupPtr& second) { @@ -1002,7 +1154,7 @@ class FusionMergePassHelper : public FusionHelperBase { return ret; }; - bool update = false; + bool update = false; auto consumer_groups = GetFusableConsumerGroupSet(); if (consumer_groups.size()) { SelectConsumerToFuse(producer, consumer_groups); @@ -1386,6 +1538,27 @@ class FusionMergePassHelper : public FusionHelperBase { 
return updated; } + bool GeneralInputFuse() { + VLOG(3) << "GeneralInputFuse...!"; + auto updated = false; + UpdateInputToConsumers(); + for (auto& input_consumers : input_to_consumers_) { + // if group set size == 1. + if (input_consumers.second.size() == 1) { + continue; + } + // do input fusion. + auto st = CallGeneralInputFusePass(input_consumers.second); + if (st) { + // fused consumers, update + UpdateInputToConsumers(); + } + updated |= st; + } + + return updated; + } + void UpdateInputToConsumers() { for (auto& input_consumers : input_to_consumers_) { auto& consumers = input_consumers.second; From 5fa2c7da976e32a24666f2d4669d2eb145101024 Mon Sep 17 00:00:00 2001 From: jiahongyu Date: Tue, 13 Jun 2023 17:24:02 +0000 Subject: [PATCH 21/66] update recompute codes --- cinn/hlir/pass/general_fusion_merge_pass.cc | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/cinn/hlir/pass/general_fusion_merge_pass.cc b/cinn/hlir/pass/general_fusion_merge_pass.cc index 9874fd97ad..0986a8f0a5 100644 --- a/cinn/hlir/pass/general_fusion_merge_pass.cc +++ b/cinn/hlir/pass/general_fusion_merge_pass.cc @@ -481,12 +481,12 @@ class DefaultVerticalFusePass final : public LightwareFusePass { typedef bool (*ConditionT)(LightwareFusePassCtx* ctx, const OpGroupPtr& src, const OpGroupPtr& dst); - const std::map& GetConditionMap() const { + static const std::map& GetConditionMap() { thread_local static std::map map(RawConditionMap()); return map; } - std::map RawConditionMap() const { + static std::map RawConditionMap() { return std::map{ {{OpPatternKind::kElementWise, framework::kElementWise}, &DefaultVerticalFusePass::IsSameSize}, {{OpPatternKind::kElementWise, framework::kBroadcast}, &DefaultVerticalFusePass::ElementwiseFuseBroadcast}, @@ -547,9 +547,9 @@ class DefaultVerticalFusePass final : public LightwareFusePass { } }; -class DefaultRecomputeFusePass final : public FusePass { +class DefaultRecomputeFusePass final : public LightwareFusePass { 
public: - DefaultRecomputeFusePass() : FusePass() {} + DefaultRecomputeFusePass() : LightwareFusePass() {} void operator()(LightwareFusePassCtx* ctx) const override { const auto& producer = ctx->PickOpGroup(); @@ -1423,14 +1423,14 @@ class FusionMergePassHelper : public FusionHelperBase { } } - std::vector> RawRecomputeFusePasses() const { - return std::vector>{ - std::shared_ptr(new DefaultRecomputeFusePass()), + std::vector> RawRecomputeFusePasses() const { + return std::vector>{ + std::shared_ptr(new DefaultRecomputeFusePass()), }; } - const std::vector>& GetRecomputeFusePasses() const { - thread_local static std::vector> fuse_passes = RawRecomputeFusePasses(); + const std::vector>& GetRecomputeFusePasses() const { + thread_local static std::vector> fuse_passes = RawRecomputeFusePasses(); return fuse_passes; } From bc7d3ad9b7091d00bd0668a4bf42081f0696b4ba Mon Sep 17 00:00:00 2001 From: jiahongyu Date: Tue, 13 Jun 2023 17:41:05 +0000 Subject: [PATCH 22/66] Fix loop error --- cinn/hlir/pass/general_fusion_merge_pass.cc | 19 +++++++------------ 1 file changed, 7 insertions(+), 12 deletions(-) diff --git a/cinn/hlir/pass/general_fusion_merge_pass.cc b/cinn/hlir/pass/general_fusion_merge_pass.cc index 0986a8f0a5..7247b6213d 100644 --- a/cinn/hlir/pass/general_fusion_merge_pass.cc +++ b/cinn/hlir/pass/general_fusion_merge_pass.cc @@ -875,16 +875,12 @@ class FusionMergePassHelper : public FusionHelperBase { } return ret; }; - bool update = false; - while (true) { - const auto& groups = GetFusableConsumerGroupList(); - if (groups.size() <= 1) { - break; - } - HorizontalFuse(groups); - update = true; + const auto& groups = GetFusableConsumerGroupList(); + if (groups.size() <= 1) { + return false; } - return update; + HorizontalFuse(groups); + return true; } bool HorizontalFusion(GroupPtr producer, const std::unordered_set& consumers) { @@ -1675,12 +1671,11 @@ class FusionMergePassHelper : public FusionHelperBase { continue; } // do input fusion. 
- auto st = CallGeneralInputFusePass(input_consumers.second); - if (st) { + while (CallGeneralInputFusePass(input_consumers.second)) { // fused consumers, update UpdateInputToConsumers(); + updated = true; } - updated |= st; } return updated; From 31caa331e8fcc9aeebc06557711ad53d7b07f500 Mon Sep 17 00:00:00 2001 From: jiahongyu Date: Wed, 14 Jun 2023 03:09:15 +0000 Subject: [PATCH 23/66] Change to GeneralInputFuse in DoGeneralRecomputeAndVerticalFusion --- cinn/hlir/pass/general_fusion_merge_pass.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cinn/hlir/pass/general_fusion_merge_pass.cc b/cinn/hlir/pass/general_fusion_merge_pass.cc index 7247b6213d..84583da9de 100644 --- a/cinn/hlir/pass/general_fusion_merge_pass.cc +++ b/cinn/hlir/pass/general_fusion_merge_pass.cc @@ -722,7 +722,7 @@ class FusionMergePassHelper : public FusionHelperBase { } // fuse input consumers - updated |= FuseInputToConsumers(); + updated |= GeneralInputFuse(); if (updated) { UpdateFusionGroup(); From 314f01c2622c8c5c21080070efeee3567ac08486 Mon Sep 17 00:00:00 2001 From: zyfncg Date: Wed, 14 Jun 2023 05:05:02 +0000 Subject: [PATCH 24/66] fix DetectCycle --- cinn/api/op_group_interface.h | 4 ++++ cinn/hlir/pass/general_fusion_merge_pass.cc | 22 ++++++++++++++++----- 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/cinn/api/op_group_interface.h b/cinn/api/op_group_interface.h index 506b562f99..49c3cd5272 100644 --- a/cinn/api/op_group_interface.h +++ b/cinn/api/op_group_interface.h @@ -41,6 +41,10 @@ class OpGroupInterface { virtual const std::unordered_map, TensorInterfaceList>& consumer_groups() const = 0; + const std::unordered_map, TensorInterfaceList>& producer2inputs() const { + return producer_groups(); + } + const std::unordered_map, TensorInterfaceList>& consumer2outputs() const { return consumer_groups(); } diff --git a/cinn/hlir/pass/general_fusion_merge_pass.cc b/cinn/hlir/pass/general_fusion_merge_pass.cc index 68d64b8ca8..9e46031b72 100644 --- 
a/cinn/hlir/pass/general_fusion_merge_pass.cc +++ b/cinn/hlir/pass/general_fusion_merge_pass.cc @@ -102,7 +102,7 @@ class GraphGroupFuseHelper final : public FuseHelper { } private: - bool ReachableIfDirectEdgeIgnored(const OpGroupPtr& src, const OpGroupPtr& dst) const { + bool ReachableIfDirectEdgeIgnored(const OpGroupPtr& producer, const OpGroupPtr& consumer) const { const auto& MinDepth4Node = [&](OpGroupPtr node) { return std::dynamic_pointer_cast(node)->min_depth; }; @@ -110,15 +110,15 @@ class GraphGroupFuseHelper final : public FuseHelper { return std::dynamic_pointer_cast(node)->max_depth; }; const auto& VisitNextNodes = [&](OpGroupPtr node, const std::function& Visit) { - for (const auto& pair : node->consumer2outputs()) { - if (node == src && pair.first == dst) { + for (const auto& pair : node->producer2inputs()) { + if (node == consumer && pair.first == producer) { continue; } Visit(pair.first); } }; common::IsReachablePredicator is_reachable(MinDepth4Node, MaxDepth4Node, VisitNextNodes); - return is_reachable(src, dst, [](OpGroupPtr) {}); + return is_reachable(consumer, producer, [](OpGroupPtr) {}); } const GraphGroupLightwareFusePassCtx* ctx_; @@ -341,15 +341,27 @@ class DefaultVerticalFusePass final : public FusePass { if (consumers.size() == 0) { return; } + + std::unordered_set candidates; for (int i = 0; i < consumers.size(); ++i) { const auto& consumer = consumers.at(i); if (!DetectFusabilityByKind(ctx, producer, consumer)) { continue; } if (ctx->fuse_helper().DetectCycleIfFuse(producer, consumer)) { + VLOG(4) << "Can't fuse because detect cycle"; continue; } - ctx->EnableFuse(producer, consumer); + candidates.insert(consumer); + } + + // Jump for Recompute + if (candidates.size() == consumers.size() && producer->kind() == framework::kElementWise) { + return; + } + // Add Tag + for (const auto& candidate : candidates) { + ctx->EnableFuse(producer, candidate); } } From 5507b5ad8e111629755d8a109da153909981351e Mon Sep 17 00:00:00 2001 From: zyfncg 
Date: Wed, 14 Jun 2023 09:08:33 +0000 Subject: [PATCH 25/66] fix some random bug --- cinn/hlir/pass/general_fusion_merge_pass.cc | 34 ++++++++------------- 1 file changed, 13 insertions(+), 21 deletions(-) diff --git a/cinn/hlir/pass/general_fusion_merge_pass.cc b/cinn/hlir/pass/general_fusion_merge_pass.cc index 9e46031b72..cefd13c356 100644 --- a/cinn/hlir/pass/general_fusion_merge_pass.cc +++ b/cinn/hlir/pass/general_fusion_merge_pass.cc @@ -327,9 +327,9 @@ class DefautlHorizontalFusePass final : public FusePass { class DefaultVerticalFusePass final : public FusePass { public: - DefaultVerticalFusePass() : FusePass() {} + DefaultVerticalFusePass() : FusePass() {} - void operator()(LightwareFusePassCtx* ctx) const override { + void operator()(LightwareFusePassCtx* ctx) const override { const auto& producer = ctx->PickOpGroup(); const OpGroupList consumers = [&]() { OpGroupList consumers; @@ -342,7 +342,6 @@ class DefaultVerticalFusePass final : public FusePass { return; } - std::unordered_set candidates; for (int i = 0; i < consumers.size(); ++i) { const auto& consumer = consumers.at(i); if (!DetectFusabilityByKind(ctx, producer, consumer)) { @@ -352,18 +351,9 @@ class DefaultVerticalFusePass final : public FusePass { VLOG(4) << "Can't fuse because detect cycle"; continue; } - candidates.insert(consumer); - } - - // Jump for Recompute - if (candidates.size() == consumers.size() && producer->kind() == framework::kElementWise) { - return; - } - // Add Tag - for (const auto& candidate : candidates) { - ctx->EnableFuse(producer, candidate); + ctx->EnableFuse(producer, consumer); } - } + } using KindKeyT = std::pair; bool DetectFusabilityByKind(LightwareFusePassCtx* ctx, const OpGroupPtr& src, const OpGroupPtr& dst) const { @@ -476,13 +466,13 @@ class DefaultRecomputeFusePass final : public FusePass { if (consumers.size() <= 1) { return; } - std::unordered_set candidates; + std::vector candidates; for (int i = 0; i < consumers.size(); ++i) { const auto& consumer = 
consumers.at(i); if (!DetectFusabilityByKind(ctx, producer, consumer)) { continue; } - candidates.insert(consumer); + candidates.push_back(consumer); } if (candidates.size() == consumers.size() && producer->kind() == framework::kElementWise) { for (const auto& consumer : consumers) { @@ -632,7 +622,9 @@ class FusionMergePassHelper : public FusionHelperBase { } // do horizontal fusion. updated |= GeneralRecomputeFuse(producer); - updated |= GeneralVerticalFuse(producer); + if (!updated) { + updated |= GeneralVerticalFuse(producer); + } } // fuse input consumers @@ -1423,10 +1415,10 @@ class FusionMergePassHelper : public FusionHelperBase { fusionable_consumers.insert(*candidates.begin()); } } else { - std::unordered_set candidates; + std::vector candidates; for (auto& consumer : fusionable_consumers) { if (consumer->op_pattern_kind == framework::kElementWise) { - candidates.insert(consumer); + candidates.push_back(consumer); continue; } @@ -1435,13 +1427,13 @@ class FusionMergePassHelper : public FusionHelperBase { if (std::accumulate(shape0.begin(), shape0.end(), 1, std::multiplies()) == std::accumulate(shape1.begin(), shape1.end(), 1, std::multiplies())) { - candidates.insert(consumer); + candidates.push_back(consumer); } } fusionable_consumers.clear(); if (candidates.size()) { - fusionable_consumers.insert(*candidates.begin()); + fusionable_consumers.insert(candidates.front()); } } } From 0888e1ae869a03f58efcf20e4c8b3e021a9e6980 Mon Sep 17 00:00:00 2001 From: jiahongyu Date: Thu, 15 Jun 2023 13:35:00 +0000 Subject: [PATCH 26/66] Support pass register --- cinn/common/macros.h | 29 +++ cinn/frontend/decomposer/test_helper.h | 1 + cinn/frontend/interpreter.cc | 1 + cinn/frontend/optimize.cc | 1 + cinn/frontend/pass/test_helper.h | 1 + cinn/hlir/framework/pass.cc | 1 + cinn/hlir/pass/general_fusion_merge_pass.cc | 269 +++++++++++++++----- cmake/core.cmake | 22 ++ 8 files changed, 263 insertions(+), 62 deletions(-) mode change 100755 => 100644 
cinn/frontend/interpreter.cc diff --git a/cinn/common/macros.h b/cinn/common/macros.h index fce0d19292..4d48c5e096 100644 --- a/cinn/common/macros.h +++ b/cinn/common/macros.h @@ -49,3 +49,32 @@ #else #define CINN_NODISCARD #endif + +#define DISABLE_COPY_AND_ASSIGN(classname) \ + private: \ + classname(const classname&) = delete; \ + classname(classname&&) = delete; \ + classname& operator=(const classname&) = delete; \ + classname& operator=(classname&&) = delete + +/** + * check if MACRO is used in GLOBAL NAMESPACE. + */ +#define STATIC_ASSERT_GLOBAL_NAMESPACE(uniq_name, msg) \ + struct __test_global_namespace_##uniq_name##__ {}; \ + static_assert( \ + std::is_same<::__test_global_namespace_##uniq_name##__, __test_global_namespace_##uniq_name##__>::value, msg) + +#define CINN_REGISTER_FUSION_PASS(pass_name, pass_class) \ + STATIC_ASSERT_GLOBAL_NAMESPACE(__reg_pass__##pass_name, \ + "CINN_REGISTER_FUSION_PASS must be called in global namespace"); \ + static ::cinn::hlir::pass::FusionPassRegistrar __pass_registrar_##pass_name##__(#pass_name); \ + int TouchFusionPassRegistrar_##pass_name() { \ + __pass_registrar_##pass_name##__.Touch(); \ + return 0; \ + } + +#define USE_FUSION_PASS(pass_name) \ + STATIC_ASSERT_GLOBAL_NAMESPACE(__use_fusion_pass_##pass_name, "USE_OP_ITSELF must be called in global namespace"); \ + extern int TouchFusionPassRegistrar_##pass_name(); \ + [[maybe_unused]] static int __use_fusion_pass_##pass_name##_ = TouchFusionPassRegistrar_##pass_name() diff --git a/cinn/frontend/decomposer/test_helper.h b/cinn/frontend/decomposer/test_helper.h index a00e94b82f..f9fd8d61d7 100644 --- a/cinn/frontend/decomposer/test_helper.h +++ b/cinn/frontend/decomposer/test_helper.h @@ -30,6 +30,7 @@ #include "cinn/hlir/framework/pass.h" #include "cinn/hlir/framework/tensor.h" #include "cinn/hlir/op/use_ops.h" +#include "cinn/hlir/pass/use_general_pass.h" #include "cinn/hlir/pass/use_pass.h" namespace cinn::frontend { diff --git a/cinn/frontend/interpreter.cc 
b/cinn/frontend/interpreter.cc old mode 100755 new mode 100644 index 5b151ed098..82448a2612 --- a/cinn/frontend/interpreter.cc +++ b/cinn/frontend/interpreter.cc @@ -21,6 +21,7 @@ #include "cinn/hlir/framework/graph.h" #include "cinn/hlir/framework/pass.h" #include "cinn/hlir/op/use_ops.h" +#include "cinn/hlir/pass/use_general_pass.h" #include "cinn/hlir/pass/use_pass.h" #include "cinn/runtime/flags.h" diff --git a/cinn/frontend/optimize.cc b/cinn/frontend/optimize.cc index dc4ce886d3..fe16f31c2e 100644 --- a/cinn/frontend/optimize.cc +++ b/cinn/frontend/optimize.cc @@ -26,6 +26,7 @@ #include "cinn/hlir/framework/graph.h" #include "cinn/hlir/framework/pass.h" #include "cinn/hlir/framework/visualize_helper.h" +#include "cinn/hlir/pass/use_general_pass.h" #include "cinn/hlir/pass/use_pass.h" #include "cinn/runtime/flags.h" diff --git a/cinn/frontend/pass/test_helper.h b/cinn/frontend/pass/test_helper.h index d68d876dfe..7086cf76dd 100644 --- a/cinn/frontend/pass/test_helper.h +++ b/cinn/frontend/pass/test_helper.h @@ -24,6 +24,7 @@ #include "cinn/frontend/program_pass.h" #include "cinn/hlir/framework/graph_compiler.h" #include "cinn/hlir/framework/pass.h" +#include "cinn/hlir/pass/use_general_pass.h" #include "cinn/hlir/pass/use_pass.h" namespace cinn::frontend { diff --git a/cinn/hlir/framework/pass.cc b/cinn/hlir/framework/pass.cc index ad4d6152e5..01d9618277 100644 --- a/cinn/hlir/framework/pass.cc +++ b/cinn/hlir/framework/pass.cc @@ -15,6 +15,7 @@ #include "cinn/hlir/framework/pass.h" #include "cinn/hlir/framework/visualize_helper.h" +#include "cinn/hlir/pass/use_general_pass.h" #include "cinn/hlir/pass/use_pass.h" namespace cinn { diff --git a/cinn/hlir/pass/general_fusion_merge_pass.cc b/cinn/hlir/pass/general_fusion_merge_pass.cc index 6add2c8758..a47c37ce73 100644 --- a/cinn/hlir/pass/general_fusion_merge_pass.cc +++ b/cinn/hlir/pass/general_fusion_merge_pass.cc @@ -13,9 +13,11 @@ // limitations under the License. 
#include +#include #include "cinn/api/op_group_interface.h" #include "cinn/common/is_reachable_predicator.h" +#include "cinn/common/macros.h" #include "cinn/hlir/pass/fusion_merge_pass_util.h" DECLARE_bool(enhance_vertical_fusion_with_recompute); @@ -23,7 +25,6 @@ DECLARE_bool(enhance_vertical_fusion_with_recompute); namespace cinn { namespace hlir { namespace pass { -namespace { using framework::Graph; using framework::Node; @@ -42,7 +43,6 @@ using OpGroupList = std::vector; using ConditionFunction = std::function; -class GraphGroupLightwareFusePassCtx; class FuseHelper { public: virtual ~FuseHelper() = default; @@ -350,6 +350,10 @@ class FusePass { public: virtual ~FusePass() = default; + virtual const std::string FuseMode() const = 0; + + virtual int Benefit() const = 0; + protected: FusePass() = default; }; @@ -360,15 +364,22 @@ class InputFusePass : public FusePass { virtual void operator()(InputFusePassCtx* ctx) const = 0; + virtual const std::string FuseMode() const override final { return "InputFuse"; } + + virtual int Benefit() const = 0; + protected: InputFusePass() = default; }; -class DefautlInputFusePass final : public InputFusePass { +class DefaultInputFusePass final : public InputFusePass { public: - DefautlInputFusePass() : InputFusePass() {} + DefaultInputFusePass() : InputFusePass() {} + + int Benefit() const override { return 100; } void operator()(InputFusePassCtx* ctx) const override { + VLOG(1) << "DefaultInputFusePass"; const auto& consumer_set = ctx->PickConsumersWithSameInputs(); if (consumer_set.size() <= 1) { return; @@ -403,15 +414,36 @@ class LightwareFusePass : public FusePass { virtual void operator()(LightwareFusePassCtx* ctx) const = 0; + virtual const std::string FuseMode() const = 0; + + virtual int Benefit() const = 0; + protected: LightwareFusePass() = default; }; -class DefautlHorizontalFusePass final : public LightwareFusePass { +class HorizontalFusePass : public LightwareFusePass { + public: + virtual ~HorizontalFusePass() = 
default; + + virtual void operator()(LightwareFusePassCtx* ctx) const = 0; + + virtual const std::string FuseMode() const override final { return "HorizontalFuse"; } + + virtual int Benefit() const = 0; + + protected: + HorizontalFusePass() = default; +}; + +class DefaultHorizontalFusePass final : public HorizontalFusePass { public: - DefautlHorizontalFusePass() : LightwareFusePass() {} + DefaultHorizontalFusePass() : HorizontalFusePass() {} + + int Benefit() const override { return 100; } void operator()(LightwareFusePassCtx* ctx) const override { + VLOG(1) << "DefaultHorizontalFusePass"; const auto& producer = ctx->PickOpGroup(); const OpGroupList consumers = [&]() { OpGroupList consumers; @@ -440,35 +472,52 @@ class DefautlHorizontalFusePass final : public LightwareFusePass { } }; -class DefaultVerticalFusePass final : public LightwareFusePass { - public: - DefaultVerticalFusePass() : LightwareFusePass() {} +class VerticalFusePass : public LightwareFusePass { + public: + virtual ~VerticalFusePass() = default; - void operator()(LightwareFusePassCtx* ctx) const override { - const auto& producer = ctx->PickOpGroup(); - const OpGroupList consumers = [&]() { - OpGroupList consumers; - for (const auto& pair : producer->consumer2outputs()) { - consumers.push_back(pair.first); - } - return consumers; - }(); - if (consumers.size() == 0) { - return; + virtual void operator()(LightwareFusePassCtx* ctx) const = 0; + + virtual const std::string FuseMode() const override final { return "VerticalFuse"; } + + virtual int Benefit() const = 0; + + protected: + VerticalFusePass() = default; +}; + +class DefaultVerticalFusePass final : public VerticalFusePass { + public: + DefaultVerticalFusePass() : VerticalFusePass() {} + + int Benefit() const override { return 100; } + + void operator()(LightwareFusePassCtx* ctx) const override { + VLOG(1) << "DefaultVerticalFusePass"; + const auto& producer = ctx->PickOpGroup(); + const OpGroupList consumers = [&]() { + OpGroupList consumers; + 
for (const auto& pair : producer->consumer2outputs()) { + consumers.push_back(pair.first); } + return consumers; + }(); + if (consumers.size() == 0) { + return; + } - for (int i = 0; i < consumers.size(); ++i) { - const auto& consumer = consumers.at(i); - if (!DetectFusabilityByKind(ctx, producer, consumer)) { - continue; - } - if (ctx->fuse_helper().DetectCycleIfFuse(producer, consumer)) { - VLOG(4) << "Can't fuse because detect cycle"; - continue; - } - ctx->EnableFuse(producer, consumer); + for (int i = 0; i < consumers.size(); ++i) { + const auto& consumer = consumers.at(i); + if (!DetectFusabilityByKind(ctx, producer, consumer)) { + continue; + } + if (ctx->fuse_helper().DetectCycleIfFuse(producer, consumer)) { + VLOG(4) << "Can't fuse because detect cycle"; + continue; } - } + ctx->EnableFuse(producer, consumer); + } + } using KindKeyT = std::pair; bool DetectFusabilityByKind(LightwareFusePassCtx* ctx, const OpGroupPtr& src, const OpGroupPtr& dst) const { @@ -549,11 +598,28 @@ class DefaultVerticalFusePass final : public LightwareFusePass { } }; -class DefaultRecomputeFusePass final : public LightwareFusePass { +class RecomputeFusePass : public LightwareFusePass { public: - DefaultRecomputeFusePass() : LightwareFusePass() {} + virtual ~RecomputeFusePass() = default; + + virtual void operator()(LightwareFusePassCtx* ctx) const = 0; + + virtual const std::string FuseMode() const override final { return "RecomputeFuse"; } + + virtual int Benefit() const = 0; + + protected: + RecomputeFusePass() = default; +}; + +class DefaultRecomputeFusePass final : public RecomputeFusePass { + public: + DefaultRecomputeFusePass() : RecomputeFusePass() {} + + int Benefit() const override { return 100; } void operator()(LightwareFusePassCtx* ctx) const override { + VLOG(1) << "DefaultRecomputeFusePass"; const auto& producer = ctx->PickOpGroup(); const OpGroupList consumers = [&]() { OpGroupList consumers; @@ -592,13 +658,97 @@ class DefaultRecomputeFusePass final : public 
LightwareFusePass { } }; +struct LightwareFusePassComparator { + bool operator()(const std::shared_ptr& lhs, const std::shared_ptr& rhs) { + return lhs->Benefit() > rhs->Benefit(); + } +}; + +struct InputFusePassComparator { + bool operator()(const std::shared_ptr& lhs, const std::shared_ptr& rhs) { + return lhs->Benefit() > rhs->Benefit(); + } +}; + +class FusionPassMap { + public: + static FusionPassMap& Instance() { + static FusionPassMap global_fusion_pass_map; + return global_fusion_pass_map; + } + + bool Has(const std::string& pass_name) const { return map_.find(pass_name) != map_.end(); } + + void Insert(const std::string& pass_name, const std::shared_ptr& pass) { + CHECK(!Has(pass_name)) << "FusePass " << pass_name << " has already been registered."; + map_.insert({pass_name, pass}); + } + + std::shared_ptr Get(const std::string& pass_name) const { + auto it = map_.find(pass_name); + CHECK(it != map_.end()) << "FusePass " << pass_name << " has not been registered."; + return it->second; + } + + // fuse_mode: HorizontalFuse, VerticalFuse, RecomputeFuse + std::vector> GetLightwareFusePassesByMode(const std::string& fuse_mode) const { + CHECK(fuse_mode == "HorizontalFuse" || fuse_mode == "VerticalFuse" || fuse_mode == "RecomputeFuse") + << "fuse_mode only supports HorizontalFuse, VerticalFuse and RecomputeFuse. 
Please check your input modes = " + << fuse_mode; + std::set, LightwareFusePassComparator> candidate_passes; + for (const auto iter : map_) { + if (fuse_mode == iter.second->FuseMode()) { + candidate_passes.insert(std::dynamic_pointer_cast(iter.second)); + } + } + return std::vector>(candidate_passes.begin(), candidate_passes.end()); + } + + std::vector> GetInputFusePasses() const { + std::set, InputFusePassComparator> candidate_passes; + for (const auto iter : map_) { + if (iter.second->FuseMode() == "InputFuse") { + candidate_passes.insert(std::dynamic_pointer_cast(iter.second)); + } + } + return std::vector>(candidate_passes.begin(), candidate_passes.end()); + } + + private: + FusionPassMap() = default; + std::unordered_map> map_; + + DISABLE_COPY_AND_ASSIGN(FusionPassMap); +}; + +class Registrar { + public: + // In our design, various kinds of classes, e.g., operators and kernels, + // have their corresponding registry and registrar. The action of + // registration is in the constructor of a global registrar variable, which + // are not used in the code that calls package framework, and would + // be removed from the generated binary file by the linker. To avoid such + // removal, we add Touch to all registrar classes and make USE_OP macros to + // call this method. So, as long as the callee code calls USE_OP, the global + // registrar variable won't be removed by the linker. + void Touch() {} +}; + +template +class FusionPassRegistrar final : public Registrar { + public: + explicit FusionPassRegistrar(const std::string& pass_name) { + FusionPassMap::Instance().Insert(pass_name, std::shared_ptr(new PassClassT())); + } +}; + // Op Fusion Pass which performs Ops fusion, Ops are fused // "vertically", meaning producing Ops are fused into their consumers // with the intent that the loops which compute their values will be fused in // code generation. 
-class FusionMergePassHelper : public FusionHelperBase { +class GeneralFusionMergePassHelper : public FusionHelperBase { public: - FusionMergePassHelper(const Graph* graph) : FusionHelperBase(graph) { + GeneralFusionMergePassHelper(const Graph* graph) : FusionHelperBase(graph) { fusion_groups_ = graph->fusion_groups; // init fusion relation. InitFusionRelation(); @@ -781,14 +931,12 @@ class FusionMergePassHelper : public FusionHelperBase { } } - std::vector> RawHorizontalFusePasses() const { - return std::vector>{ - std::shared_ptr(new DefautlHorizontalFusePass{}), - }; + std::vector> RawHorizontalFusePasses() const { + return FusionPassMap::Instance().GetLightwareFusePassesByMode("HorizontalFuse"); } - const std::vector>& GetHorizontalFusePasses() const { - thread_local static std::vector> fuse_passes = RawHorizontalFusePasses(); + const std::vector>& GetHorizontalFusePasses() const { + thread_local static std::vector> fuse_passes = RawHorizontalFusePasses(); return fuse_passes; } @@ -838,14 +986,12 @@ class FusionMergePassHelper : public FusionHelperBase { return update; } - std::vector> RawInputFusePasses() const { - return std::vector>{ - std::shared_ptr(new DefautlInputFusePass{}), - }; + std::vector> RawInputFusePasses() const { + return FusionPassMap::Instance().GetInputFusePasses(); } - const std::vector>& GetInputFusePasses() const { - thread_local static std::vector> fuse_passes = RawInputFusePasses(); + const std::vector>& GetInputFusePasses() const { + thread_local static std::vector> fuse_passes = RawInputFusePasses(); return fuse_passes; } @@ -1174,14 +1320,12 @@ class FusionMergePassHelper : public FusionHelperBase { return false; } - std::vector> RawVerticalFusePasses() const { - return std::vector>{ - std::shared_ptr(new DefaultVerticalFusePass()), - }; + std::vector> RawVerticalFusePasses() const { + return FusionPassMap::Instance().GetLightwareFusePassesByMode("VerticalFuse"); } - const std::vector>& GetVerticalFusePasses() const { - thread_local 
static std::vector> fuse_passes = RawVerticalFusePasses(); + const std::vector>& GetVerticalFusePasses() const { + thread_local static std::vector> fuse_passes = RawVerticalFusePasses(); return fuse_passes; } @@ -1423,14 +1567,12 @@ class FusionMergePassHelper : public FusionHelperBase { } } - std::vector> RawRecomputeFusePasses() const { - return std::vector>{ - std::shared_ptr(new DefaultRecomputeFusePass()), - }; + std::vector> RawRecomputeFusePasses() const { + return FusionPassMap::Instance().GetLightwareFusePassesByMode("RecomputeFuse"); } - const std::vector>& GetRecomputeFusePasses() const { - thread_local static std::vector> fuse_passes = RawRecomputeFusePasses(); + const std::vector>& GetRecomputeFusePasses() const { + thread_local static std::vector> fuse_passes = RawRecomputeFusePasses(); return fuse_passes; } @@ -1881,15 +2023,13 @@ class FusionMergePassHelper : public FusionHelperBase { std::unordered_map fusion_relation_map_; }; -} // namespace - void GeneralFusionMergePassInternal(Graph* graph) { if (graph->fusion_groups.size() <= 1) { VLOG(3) << "Don't do Fusoin Merge Pass...!"; return; } - FusionMergePassHelper fusion_merge_pass_helper(graph); + GeneralFusionMergePassHelper fusion_merge_pass_helper(graph); graph->fusion_groups = fusion_merge_pass_helper(); } @@ -1907,3 +2047,8 @@ CINN_REGISTER_HELPER(GeneralFusionMergePass) { return true; } + +CINN_REGISTER_FUSION_PASS(DefaultHorizontalFusePass, cinn::hlir::pass::DefaultHorizontalFusePass); +CINN_REGISTER_FUSION_PASS(DefaultVerticalFusePass, cinn::hlir::pass::DefaultVerticalFusePass); +CINN_REGISTER_FUSION_PASS(DefaultRecomputeFusePass, cinn::hlir::pass::DefaultRecomputeFusePass); +CINN_REGISTER_FUSION_PASS(DefaultInputFusePass, cinn::hlir::pass::DefaultInputFusePass); diff --git a/cmake/core.cmake b/cmake/core.cmake index 72afc1a596..6da7b3e714 100644 --- a/cmake/core.cmake +++ b/cmake/core.cmake @@ -378,6 +378,27 @@ function(download_and_uncompress INSTALL_DIR URL FILENAME) ) endfunction() 
+set(fusion_pass_file ${CINN_BINARY_DIR}/cinn/hlir/pass/use_general_pass.h CACHE INTERNAL "use_general_pass.h file") +file(WRITE ${fusion_pass_file} "#include \"cinn/common/macros.h\" // Generated by the cinn/hlir/pass/CMakeLists.txt. DO NOT EDIT!\n\n") + +function(find_fusion_pass_register FILENAME ADD_PATH PATTERN) + # set op_name to OUTPUT + file(READ ${FILENAME} CONTENT) + string( + REGEX + MATCHALL + "${PATTERN}\\([a-zA-Z0-9_]*," + fusion_pass_patterns + "${CONTENT}") + if(NOT fusion_pass_patterns STREQUAL "") + foreach(pass_pattern ${fusion_pass_patterns}) + string(REPLACE "${PATTERN}(" "" pass_pattern "${pass_pattern}") + string(REPLACE "," "" pass_pattern "${pass_pattern}") + file(APPEND ${ADD_PATH} "USE_FUSION_PASS(${pass_pattern});\n") + endforeach() + endif() +endfunction() + function(gather_srcs SRC_GROUP) set(options) set(oneValueArgs) @@ -385,6 +406,7 @@ function(gather_srcs SRC_GROUP) cmake_parse_arguments(prefix "" "" "${multiValueArgs}" ${ARGN}) foreach(cpp ${prefix_SRCS}) set(${SRC_GROUP} "${${SRC_GROUP}};${CMAKE_CURRENT_SOURCE_DIR}/${cpp}" CACHE INTERNAL "") + find_fusion_pass_register("${CMAKE_CURRENT_SOURCE_DIR}/${cpp}" ${fusion_pass_file} "CINN_REGISTER_FUSION_PASS") endforeach() endfunction() From ac5615179ca77791a5f437690cefb86302f6db55 Mon Sep 17 00:00:00 2001 From: zyfncg Date: Mon, 19 Jun 2023 09:28:08 +0000 Subject: [PATCH 27/66] update node inferface --- cinn/api/op_group.h | 68 +++++++++++++++++++++++++++++++++++++++++ cinn/api/op_interface.h | 55 +++++++++++++++++++++++++++++++++ cinn/api/op_node.cc | 33 ++++++++++++++++++++ cinn/api/op_node.h | 64 ++++++++++++++++++++++++++++++++++++++ cinn/api/tensor_node.cc | 39 +++++++++++++++++++++++ cinn/api/tensor_node.h | 48 +++++++++++++++++++++++++++++ 6 files changed, 307 insertions(+) create mode 100644 cinn/api/op_group.h create mode 100644 cinn/api/op_interface.h create mode 100644 cinn/api/op_node.cc create mode 100644 cinn/api/op_node.h create mode 100644 cinn/api/tensor_node.cc create 
mode 100644 cinn/api/tensor_node.h diff --git a/cinn/api/op_group.h b/cinn/api/op_group.h new file mode 100644 index 0000000000..3fd5b3742d --- /dev/null +++ b/cinn/api/op_group.h @@ -0,0 +1,68 @@ +// Copyright (c) 2023 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "cinn/api/op_node.h" + +#include "cinn/hlir/framework/graph.h" +#include "cinn/hlir/pass/fusion_helper_base.h" + +namespace cinn { +namespace api { + +class OpGroup { + public: + OpNode(const hlir::pass::FusionHelperBase* helper, const hlir::framework::Graph::Group* group) : helper_(helper), group_(group) {} + + size_t OpSize() const { + return group->CollectNodes().size(); + } + + OpNode GetOp(size_t index) const { + return group->CollectNodes()[index]; + } + + size_t ProducerSize() const { + return group->producer_groups().size(); + } + OpGroup GetProducer(size_t index) const { + std::vector producer_groups; + producer_groups.reserve(ProducerSize()); + for(const auto& producer : group->producer_groups()) { + producer_groups.push_back(producer.first.get()); + } + return OpGroup(helper_, producer_groups[index]); + } + + size_t ConsumerSize() const { + return group->consumer_groups().size(); + } + + OpGroup GetConsumer(size_t index) const { + std::vector consumer_groups; + consumer_groups.reserve(ConsumerSize()); + for(const auto& consumer : group->consumer_groups()) { + consumer_groups.push_back(consumer.first.get()); + } + return 
OpGroup(helper_, consumer_groups[index]); + } + + private: + const hlir::pass::FusionHelperBase* helper_; + const hlir::framework::Graph::Group* group_; +}; + +} // namespace api +} // namespace cinn diff --git a/cinn/api/op_interface.h b/cinn/api/op_interface.h new file mode 100644 index 0000000000..b6c2abb69c --- /dev/null +++ b/cinn/api/op_interface.h @@ -0,0 +1,55 @@ +// Copyright (c) 2023 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "cinn/api/tensor_interface.h" +#include "cinn/utils/type_defs.h" +#include "cinn/hlir/framework/op.h" + +namespace cinn { +namespace api { + +using OpPatternKind = cinn::hlir::framework::OpPatternKind; +using Attribute = cinn::utils::Attribute; + +class OpInterface { + public: + virtual OpPatternKind kind () = 0; + + virtual size_t InputsSize() const = 0; + virtual TensorInterface Inputs(size_t i) const = 0; + + virtual const TensorInterfaceList& Inputs() = 0; + virtual const TensorInterfaceList& Outputs() = 0; + + template + const T& GetAttr(const std::string& attr_name) const { + return absl::get(GetAttr(attr_name)); + } + + protected: + OpInterface() = default; + OpInterface(const OpInterface&) = delete; + OpInterface(OpInterface&&) = delete; + + virtual const Attribute& GetAttr(const std::string& attr_name) = 0; +}; + +using OpInterfacePtr = std::shared_ptr; + +} // namespace api +} // namespace cinn diff --git a/cinn/api/op_node.cc 
b/cinn/api/op_node.cc new file mode 100644 index 0000000000..6a74f3d20e --- /dev/null +++ b/cinn/api/op_node.cc @@ -0,0 +1,33 @@ +// Copyright (c) 2023 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cinn/api/op_node.h" + +#include "cinn/api/tensor_node.h" + +namespace cinn { +namespace api { + +TensorNode OpNode::GetInput(size_t i) const { + auto edges = node_->inlinks_in_order(); + return TensorNode(helper_, edges[i]->safe_as()); +} + +TensorNode OpNode::GetOutput(size_t i) const { + auto edges = node_->outlinks_in_order(); + return TensorNode(helper_, edges[i]->safe_as()); +} + +} // namespace api +} // namespace cinn \ No newline at end of file diff --git a/cinn/api/op_node.h b/cinn/api/op_node.h new file mode 100644 index 0000000000..1291644aae --- /dev/null +++ b/cinn/api/op_node.h @@ -0,0 +1,64 @@ +// Copyright (c) 2023 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "cinn/hlir/framework/node.h" +#include "cinn/hlir/pass/fusion_helper_base.h" + +namespace cinn { +namespace api { + +using OpPatternKind = cinn::hlir::framework::OpPatternKind; +using Attribute = cinn::utils::Attribute; + +class TensorNode; + +class OpNode { + public: + OpNode(const hlir::pass::FusionHelperBase* helper, const hlir::framework::Node* node) : helper_(helper), node_(node) {} + + OpPatternKind kind () { + return helper_->GetOpKind(node_); + } + + size_t InputsSize() const { + return node_->inlinks.size(); + } + + size_t OutputsSize() const { + return node_->outlinks.size(); + } + + TensorNode GetInput(size_t i) const; + + TensorNode GetOutput(size_t i) const; + + template + const T& GetAttr(const std::string& attr_name) const { + return absl::get(GetAttr(attr_name)); + } + + private: + const Attribute& GetAttr(const std::string& attr_name) { + return node_->attrs.attr_store.at(attr_name); + } + + const hlir::pass::FusionHelperBase* helper_; + const hlir::framework::Node* node_; +}; + +} // namespace api +} // namespace cinn diff --git a/cinn/api/tensor_node.cc b/cinn/api/tensor_node.cc new file mode 100644 index 0000000000..5f717d4d37 --- /dev/null +++ b/cinn/api/tensor_node.cc @@ -0,0 +1,39 @@ +// Copyright (c) 2023 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "cinn/api/tensor_node.h" + +#include "cinn/api/op_node.h" + +namespace cinn { +namespace api { + +OpNode TensorNode::Producer() const { + return OpNode(helper_, node_data_->source_node.get()); +} + +OpNode TensorNode::Consumer(size_t index) const { + std::vector consumer_nodes; + for (auto& link : node_data_->outlinks()) { + auto consumer = link->sink()->safe_as(); + consumer_nodes.push_back(consumer); + } + return OpNode(helper_, consumer_nodes[index]); +} + + + + +} // namespace api +} // namespace cinn \ No newline at end of file diff --git a/cinn/api/tensor_node.h b/cinn/api/tensor_node.h new file mode 100644 index 0000000000..8de0f4ba53 --- /dev/null +++ b/cinn/api/tensor_node.h @@ -0,0 +1,48 @@ +// Copyright (c) 2023 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "cinn/hlir/framework/node.h" +#include "cinn/hlir/pass/fusion_helper_base.h" + +namespace cinn { +namespace api { + +class OpNode; + +class TensorNode { + public: + TensorNode(const hlir::pass::FusionHelperBase* helper, const hlir::framework::NodeData* node_data) : helper_(helper), node_data_(node_data) {} + + // Get the shape of tensor. 
+ const shpae_t& Shape() const { + return helper_->GetNodeDataShape(node_data_) + } + + OpNode Producer() const; + + size_t ConsumerSize() const { + return node_data_->outlinks().size(); + } + + OpNode Consumer(size_t index) const; + + private: + const hlir::pass::FusionHelperBase* helper_; + const hlir::framework::NodeData* node_data_; +}; + +} // namespace api +} // namespace cinn From 052749e9c51be5625189aadceb6989d1f14f25ac Mon Sep 17 00:00:00 2001 From: zyfncg Date: Tue, 20 Jun 2023 07:09:23 +0000 Subject: [PATCH 28/66] update --- cinn/api/op_group.h | 93 ++++++++++++---- cinn/api/op_node.h | 4 +- cinn/hlir/framework/graph.h | 16 +-- cinn/hlir/pass/fusion_merge_pass.cc | 4 +- cinn/hlir/pass/general_fusion_merge_pass.cc | 115 +++++++++++--------- 5 files changed, 145 insertions(+), 87 deletions(-) diff --git a/cinn/api/op_group.h b/cinn/api/op_group.h index 3fd5b3742d..190dd9f8da 100644 --- a/cinn/api/op_group.h +++ b/cinn/api/op_group.h @@ -14,6 +14,8 @@ #pragma once +#include + #include "cinn/api/op_node.h" #include "cinn/hlir/framework/graph.h" @@ -24,44 +26,93 @@ namespace api { class OpGroup { public: - OpNode(const hlir::pass::FusionHelperBase* helper, const hlir::framework::Graph::Group* group) : helper_(helper), group_(group) {} + OpGroup(const hlir::pass::FusionHelperBase* helper, const std::shared_ptr& group) : helper_(helper), group_(group) {} + + OpGroup(const OpGroup& other) = default; + + class iterator { + public: + iterator(std::unordered_map, TensorInterfaceList>::iterator it, const hlir::pass::FusionHelperBase* helper) : iter_(it), helper_(helper) {} + + iterator& operator++() { + ++iter_; + return *this; + } + + iterator operator++(int) { + iterator tmp = *this; + ++iter_; + return tmp; + } + + OpGroup operator*() { + return OpGroup(helper_, iter_->first); + } + + bool operator==(const iterator& other) const { + return iter_ == other.iter_; + } + + bool operator!=(const iterator& other) const { + return !(*this == other); + } + + private: + 
std::unordered_map, TensorInterfaceList>::iterator iter_; + const hlir::pass::FusionHelperBase* helper_; + }; + + hlir::framework::OpPatternKind kind() const { return group_->kind(); } size_t OpSize() const { - return group->CollectNodes().size(); + return group_->CollectNodes().size(); } OpNode GetOp(size_t index) const { - return group->CollectNodes()[index]; + return OpNode(helper_, group_->CollectNodes()[index]); } size_t ProducerSize() const { - return group->producer_groups().size(); - } - OpGroup GetProducer(size_t index) const { - std::vector producer_groups; - producer_groups.reserve(ProducerSize()); - for(const auto& producer : group->producer_groups()) { - producer_groups.push_back(producer.first.get()); - } - return OpGroup(helper_, producer_groups[index]); + return group_->producer_groups().size(); } size_t ConsumerSize() const { - return group->consumer_groups().size(); + return group_->consumer_groups().size(); } - OpGroup GetConsumer(size_t index) const { - std::vector consumer_groups; - consumer_groups.reserve(ConsumerSize()); - for(const auto& consumer : group->consumer_groups()) { - consumer_groups.push_back(consumer.first.get()); - } - return OpGroup(helper_, consumer_groups[index]); + iterator ProducerBegin() const { + return iterator(group_->mut_producer_groups()->begin(), helper_); + } + + iterator ProducerEnd() const { + return iterator(group_->mut_producer_groups()->end(), helper_); + } + + iterator ConsumerBegin() const { + return iterator(group_->mut_consumer_groups()->begin(), helper_); } + iterator ConsumerEnd() const { + return iterator(group_->mut_consumer_groups()->end(), helper_); + } + + std::shared_ptr GetGroup() const { + return group_; + } + + bool operator==(const OpGroup& other) const { + return group_.get() == other.group_.get(); + } + + // struct OpGroupHash { + // std::size_t operator()(const OpGroup& obj) const { + // return std::hash{}(obj.GetGroup().get()); + // } + // }; + private: const hlir::pass::FusionHelperBase* 
helper_; - const hlir::framework::Graph::Group* group_; + const std::shared_ptr group_; }; } // namespace api diff --git a/cinn/api/op_node.h b/cinn/api/op_node.h index 1291644aae..a62994d7fc 100644 --- a/cinn/api/op_node.h +++ b/cinn/api/op_node.h @@ -35,11 +35,11 @@ class OpNode { } size_t InputsSize() const { - return node_->inlinks.size(); + return node_->inlinks().size(); } size_t OutputsSize() const { - return node_->outlinks.size(); + return node_->outlinks().size(); } TensorNode GetInput(size_t i) const; diff --git a/cinn/hlir/framework/graph.h b/cinn/hlir/framework/graph.h index 6e15af3cea..ba18a5ac6a 100644 --- a/cinn/hlir/framework/graph.h +++ b/cinn/hlir/framework/graph.h @@ -60,7 +60,7 @@ class Graph : public cinn::common::Graph { absl::flat_hash_map> attrs; std::vector> groups; - struct Group final : public OpGroupInterface { + struct Group { // distance to last group. int depth{0}; int max_depth{0}; @@ -127,29 +127,29 @@ class Graph : public cinn::common::Graph { std::string GetFuncName() { return "fn_" + group_id + unique_id; } public: - const std::unordered_map, TensorInterfaceList>& producer_groups() const override { + const std::unordered_map, TensorInterfaceList>& producer_groups() const { return producer_groups_; } - const std::unordered_map, TensorInterfaceList>& consumer_groups() const override { + const std::unordered_map, TensorInterfaceList>& consumer_groups() const { return consumer_groups_; } - std::unordered_map, TensorInterfaceList>* mut_producer_groups() { + std::unordered_map, TensorInterfaceList>* mut_producer_groups() { return &producer_groups_; } - std::unordered_map, TensorInterfaceList>* mut_consumer_groups() { + std::unordered_map, TensorInterfaceList>* mut_consumer_groups() { return &consumer_groups_; } - hlir::framework::OpPatternKind kind() const override { return op_pattern_kind; } + hlir::framework::OpPatternKind kind() const { return op_pattern_kind; } private: // input groups - std::unordered_map, TensorInterfaceList> 
producer_groups_; + std::unordered_map, TensorInterfaceList> producer_groups_; // output grous - std::unordered_map, TensorInterfaceList> consumer_groups_; + std::unordered_map, TensorInterfaceList> consumer_groups_; }; std::vector> fusion_groups; diff --git a/cinn/hlir/pass/fusion_merge_pass.cc b/cinn/hlir/pass/fusion_merge_pass.cc index 957447942d..22e16eb5b2 100644 --- a/cinn/hlir/pass/fusion_merge_pass.cc +++ b/cinn/hlir/pass/fusion_merge_pass.cc @@ -904,8 +904,8 @@ class FusionMergePassHelper : public FusionHelperBase { // update producer and consumer. for (auto& group : fusion_groups_) { - std::unordered_map producers; - std::unordered_map consumers; + std::unordered_map producers; + std::unordered_map consumers; for (auto& producer_and_list : group->producer_groups()) { const auto& producer = std::dynamic_pointer_cast(producer_and_list.first); diff --git a/cinn/hlir/pass/general_fusion_merge_pass.cc b/cinn/hlir/pass/general_fusion_merge_pass.cc index a47c37ce73..11a9bdb946 100644 --- a/cinn/hlir/pass/general_fusion_merge_pass.cc +++ b/cinn/hlir/pass/general_fusion_merge_pass.cc @@ -16,6 +16,7 @@ #include #include "cinn/api/op_group_interface.h" +#include "cinn/api/op_group.h" #include "cinn/common/is_reachable_predicator.h" #include "cinn/common/macros.h" #include "cinn/hlir/pass/fusion_merge_pass_util.h" @@ -38,7 +39,8 @@ using common::GraphNode; using GroupPtr = std::shared_ptr; using GroupList = std::vector; -using OpGroupPtr = std::shared_ptr; +// using OpGroupPtr = std::shared_ptr; +using OpGroupPtr = api::OpGroup; using OpGroupList = std::vector; using ConditionFunction = std::function; @@ -105,17 +107,17 @@ class GraphGroupFuseHelper final : public FuseHelper { private: bool ReachableIfDirectEdgeIgnored(const OpGroupPtr& producer, const OpGroupPtr& consumer) const { const auto& MinDepth4Node = [&](OpGroupPtr node) { - return std::dynamic_pointer_cast(node)->min_depth; + return node.GetGroup()->min_depth; }; const auto& MaxDepth4Node = [&](OpGroupPtr 
node) { - return std::dynamic_pointer_cast(node)->max_depth; + return node.GetGroup()->max_depth; }; const auto& VisitNextNodes = [&](OpGroupPtr node, const std::function& Visit) { - for (const auto& pair : node->producer2inputs()) { - if (node == consumer && pair.first == producer) { + for(auto iter = node.ProducerBegin(); iter != node.ProducerEnd(); ++iter) { + if (node == consumer && *iter == producer) { continue; } - Visit(pair.first); + Visit(*iter); } }; common::IsReachablePredicator is_reachable(MinDepth4Node, MaxDepth4Node, VisitNextNodes); @@ -172,7 +174,7 @@ class GraphGroupLightwareFusePassCtx final : public LightwareFusePassCtx { private: const FusionHelperBase* graph_group_fusion_helper_; - const OpGroupPtr group_; + OpGroupPtr group_; const std::function EnableFuse_; const std::unique_ptr fuse_helper_; }; @@ -181,7 +183,7 @@ class InputFusePassCtx : public FusePassCtx { public: virtual ~InputFusePassCtx() {} - virtual const std::unordered_set& PickConsumersWithSameInputs() const = 0; + virtual const OpGroupList& PickConsumersWithSameInputs() const = 0; virtual const FuseHelper& fuse_helper() const = 0; @@ -194,14 +196,14 @@ class InputFusePassCtx : public FusePassCtx { class GraphGroupInputFusePassCtx final : public InputFusePassCtx { public: GraphGroupInputFusePassCtx(const FusionHelperBase* graph_group_fusion_helper, - const std::unordered_set& groups, + const OpGroupList& groups, const std::function& EnableFuse) : graph_group_fusion_helper_(graph_group_fusion_helper), groups_(groups), EnableFuse_(EnableFuse), fuse_helper_(new GraphGroupFuseHelper(this)) {} - const std::unordered_set& PickConsumersWithSameInputs() const override { return groups_; } + const OpGroupList& PickConsumersWithSameInputs() const override { return groups_; } const FuseHelper& fuse_helper() const override { return *fuse_helper_; } @@ -211,7 +213,7 @@ class GraphGroupInputFusePassCtx final : public InputFusePassCtx { private: const FusionHelperBase* graph_group_fusion_helper_; 
- const std::unordered_set& groups_; + const OpGroupList& groups_; const std::function EnableFuse_; const std::unique_ptr fuse_helper_; }; @@ -219,73 +221,73 @@ class GraphGroupInputFusePassCtx final : public InputFusePassCtx { template bool GraphGroupFuseHelper::AllOutputsSameSize(const OpGroupPtr& first, const OpGroupPtr& second) const { return is_same_size(&ctx_->graph_group_fusion_helper(), - std::dynamic_pointer_cast(first), - std::dynamic_pointer_cast(second)); + first.GetGroup(), + second.GetGroup()); } template bool GraphGroupFuseHelper::HorizontalElementwiseFuseReduce(const OpGroupPtr& src, const OpGroupPtr& dst) const { return honrizontal_elementwise_fuse_reduce(&ctx_->graph_group_fusion_helper(), - std::dynamic_pointer_cast(src), - std::dynamic_pointer_cast(dst)); + src.GetGroup(), + dst.GetGroup()); } template bool GraphGroupFuseHelper::ElementwiseFuseBroadcast(const OpGroupPtr& src, const OpGroupPtr& dst) const { return elementwise_fuse_broadcast(&ctx_->graph_group_fusion_helper(), - std::dynamic_pointer_cast(src), - std::dynamic_pointer_cast(dst)); + src.GetGroup(), + dst.GetGroup()); } template bool GraphGroupFuseHelper::HorizontalWithInjective(const OpGroupPtr& src, const OpGroupPtr& dst) const { return horizontal_with_injective(&ctx_->graph_group_fusion_helper(), - std::dynamic_pointer_cast(src), - std::dynamic_pointer_cast(dst)); + src.GetGroup(), + dst.GetGroup()); } template bool GraphGroupFuseHelper::ElementwiseFuseReduce(const OpGroupPtr& src, const OpGroupPtr& dst) const { return elementwise_fuse_reduce(&ctx_->graph_group_fusion_helper(), - std::dynamic_pointer_cast(src), - std::dynamic_pointer_cast(dst)); + src.GetGroup(), + dst.GetGroup()); } template bool GraphGroupFuseHelper::BroadcastFuseReduce(const OpGroupPtr& src, const OpGroupPtr& dst) const { return broadcast_fuse_reduce(&ctx_->graph_group_fusion_helper(), - std::dynamic_pointer_cast(src), - std::dynamic_pointer_cast(dst)); + src.GetGroup(), + dst.GetGroup()); } template bool 
GraphGroupFuseHelper::InjectiveHorizontalWithReduce(const OpGroupPtr& src, const OpGroupPtr& dst) const { return injective_horizontal_with_reduce(&ctx_->graph_group_fusion_helper(), - std::dynamic_pointer_cast(src), - std::dynamic_pointer_cast(dst)); + src.GetGroup(), + dst.GetGroup()); } template bool GraphGroupFuseHelper::ReduceFuseElementwise(const OpGroupPtr& src, const OpGroupPtr& dst) const { return reduce_fuse_elementwise(&ctx_->graph_group_fusion_helper(), - std::dynamic_pointer_cast(src), - std::dynamic_pointer_cast(dst)); + src.GetGroup(), + dst.GetGroup()); } template bool GraphGroupFuseHelper::ReduceFuseBroadcast(const OpGroupPtr& src, const OpGroupPtr& dst) const { return reduce_fuse_broadcast(&ctx_->graph_group_fusion_helper(), - std::dynamic_pointer_cast(src), - std::dynamic_pointer_cast(dst)); + src.GetGroup(), + dst.GetGroup()); } template bool GraphGroupFuseHelper::ReduceFuseReduce(const OpGroupPtr& src, const OpGroupPtr& dst) const { return reduce_fuse_reduce(&ctx_->graph_group_fusion_helper(), - std::dynamic_pointer_cast(src), - std::dynamic_pointer_cast(dst)); + src.GetGroup(), + dst.GetGroup()); } template @@ -293,7 +295,7 @@ struct HorizontalFuseUtil { using KindKeyT = std::pair; static bool DetectFusabilityByKind(FusePassCtxT* ctx, const OpGroupPtr& src, const OpGroupPtr& dst) { - const KindKeyT kind_pair(src->kind(), dst->kind()); + const KindKeyT kind_pair(src.kind(), dst.kind()); const auto& map = GetConditionMap(); const auto& iter = map.find(kind_pair); if (iter == map.end()) { @@ -447,8 +449,8 @@ class DefaultHorizontalFusePass final : public HorizontalFusePass { const auto& producer = ctx->PickOpGroup(); const OpGroupList consumers = [&]() { OpGroupList consumers; - for (const auto& pair : producer->consumer2outputs()) { - consumers.push_back(pair.first); + for(auto iter = producer.ConsumerBegin(); iter!= producer.ConsumerEnd(); ++iter) { + consumers.push_back(*iter); } return consumers; }(); @@ -497,8 +499,8 @@ class 
DefaultVerticalFusePass final : public VerticalFusePass { const auto& producer = ctx->PickOpGroup(); const OpGroupList consumers = [&]() { OpGroupList consumers; - for (const auto& pair : producer->consumer2outputs()) { - consumers.push_back(pair.first); + for(auto iter = producer.ConsumerBegin(); iter!= producer.ConsumerEnd(); ++iter) { + consumers.push_back(*iter); } return consumers; }(); @@ -521,7 +523,7 @@ class DefaultVerticalFusePass final : public VerticalFusePass { using KindKeyT = std::pair; bool DetectFusabilityByKind(LightwareFusePassCtx* ctx, const OpGroupPtr& src, const OpGroupPtr& dst) const { - const KindKeyT kind_pair(src->kind(), dst->kind()); + const KindKeyT kind_pair(src.kind(), dst.kind()); const auto& map = GetConditionMap(); const auto& iter = map.find(kind_pair); if (iter == map.end()) { @@ -623,8 +625,8 @@ class DefaultRecomputeFusePass final : public RecomputeFusePass { const auto& producer = ctx->PickOpGroup(); const OpGroupList consumers = [&]() { OpGroupList consumers; - for (const auto& pair : producer->consumer2outputs()) { - consumers.push_back(pair.first); + for(auto iter = producer.ConsumerBegin(); iter!= producer.ConsumerEnd(); ++iter) { + consumers.push_back(*iter); } return consumers; }(); @@ -639,7 +641,7 @@ class DefaultRecomputeFusePass final : public RecomputeFusePass { } candidates.push_back(consumer); } - if (candidates.size() == consumers.size() && producer->kind() == framework::kElementWise) { + if (candidates.size() == consumers.size() && producer.kind() == framework::kElementWise) { for (const auto& consumer : consumers) { ctx->EnableFuse(producer, consumer); } @@ -648,7 +650,7 @@ class DefaultRecomputeFusePass final : public RecomputeFusePass { using KindKeyT = std::pair; bool DetectFusabilityByKind(LightwareFusePassCtx* ctx, const OpGroupPtr& src, const OpGroupPtr& dst) const { - const KindKeyT kind_pair(src->kind(), dst->kind()); + const KindKeyT kind_pair(src.kind(), dst.kind()); const auto& map = 
DefaultVerticalFusePass::GetConditionMap(); const auto& iter = map.find(kind_pair); if (iter == map.end()) { @@ -942,7 +944,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { void EnableFusedHorizontalGroups(LightwareFusePassCtx* ctx) const { const auto& producer = ctx->PickOpGroup(); - if (producer->consumer2outputs().size() <= 1) { + if (producer.ConsumerSize() <= 1) { return; } const auto& fuse_passes = GetHorizontalFusePasses(); @@ -959,7 +961,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { const auto& EnableFuse = [&](const OpGroupPtr& first, const OpGroupPtr& second) { tagged_sets.insert(std::set{first, second}); }; - GraphGroupLightwareFusePassCtx fuse_ctx(this, producer, EnableFuse); + GraphGroupLightwareFusePassCtx fuse_ctx(this, api::OpGroup(this, producer), EnableFuse); EnableFusedHorizontalGroups(&fuse_ctx); return tagged_sets; }; @@ -970,7 +972,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { } GroupList ret; for (const auto& group : *group_sets.begin()) { - ret.push_back(std::dynamic_pointer_cast(group)); + ret.push_back(group.GetGroup()); } return ret; }; @@ -1010,7 +1012,12 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { const auto& EnableFuse = [&](const OpGroupPtr& first, const OpGroupPtr& second) { tagged_sets.insert(std::set{first, second}); }; - GraphGroupInputFusePassCtx fuse_ctx(this, consumers, EnableFuse); + OpGroupList consumer_groups; + consumer_groups.reserve(consumers.size()); + for(auto& consumer : consumers) { + consumer_groups.emplace_back(this, consumer); + } + GraphGroupInputFusePassCtx fuse_ctx(this, consumer_groups, EnableFuse); EnableFusedInputGroups(&fuse_ctx); return tagged_sets; }; @@ -1021,7 +1028,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { } GroupList ret; for (const auto& group : *group_sets.begin()) { - ret.push_back(std::dynamic_pointer_cast(group)); + ret.push_back(group.GetGroup()); } return ret; }; @@ -1331,7 +1338,7 @@ 
class GeneralFusionMergePassHelper : public FusionHelperBase { void TagVerticalGroups(LightwareFusePassCtx* ctx) const { const auto& producer = ctx->PickOpGroup(); - if (producer->consumer2outputs().empty()) { + if (producer.ConsumerSize() == 0) { return; } const auto& fuse_passes = GetVerticalFusePasses(); @@ -1348,7 +1355,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { const auto& EnableFuse = [&](const OpGroupPtr& first, const OpGroupPtr& second) { tagged_sets.insert(std::make_pair(first, second)); }; - GraphGroupLightwareFusePassCtx fuse_ctx(this, producer, EnableFuse); + GraphGroupLightwareFusePassCtx fuse_ctx(this, api::OpGroup(this, producer), EnableFuse); TagVerticalGroups(&fuse_ctx); return tagged_sets; }; @@ -1360,7 +1367,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { } std::unordered_set ret; for (const auto& group_pair : group_sets) { - ret.insert(std::dynamic_pointer_cast(group_pair.second)); + ret.insert(group_pair.second.GetGroup()); } return ret; }; @@ -1578,7 +1585,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { void TagRecomputeGroups(LightwareFusePassCtx* ctx) const { const auto& producer = ctx->PickOpGroup(); - if (producer->consumer2outputs().size() <= 1) { + if (producer.ConsumerSize() <= 1) { return; } const auto& fuse_passes = GetRecomputeFusePasses(); @@ -1595,7 +1602,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { const auto& EnableFuse = [&](const OpGroupPtr& first, const OpGroupPtr& second) { tagged_sets.insert(std::make_pair(first, second)); }; - GraphGroupLightwareFusePassCtx fuse_ctx(this, producer, EnableFuse); + GraphGroupLightwareFusePassCtx fuse_ctx(this, api::OpGroup(this, producer), EnableFuse); TagRecomputeGroups(&fuse_ctx); return tagged_sets; }; @@ -1607,7 +1614,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { } std::unordered_set ret; for (const auto& group_pair : group_sets) { - 
ret.insert(std::dynamic_pointer_cast(group_pair.second)); + ret.insert(group_pair.second.GetGroup()); } return ret; }; @@ -1899,8 +1906,8 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { // update producer and consumer. for (auto& group : fusion_groups_) { - std::unordered_map producers; - std::unordered_map consumers; + std::unordered_map producers; + std::unordered_map consumers; for (auto& producer_and_list : group->producer_groups()) { const auto& producer = std::dynamic_pointer_cast(producer_and_list.first); From f2650ce3b2e2d760977e0f33bbae76c966fbfc19 Mon Sep 17 00:00:00 2001 From: zyfncg Date: Tue, 20 Jun 2023 08:36:41 +0000 Subject: [PATCH 29/66] update --- cinn/api/op_group.h | 10 +- cinn/api/op_node.h | 2 +- cinn/hlir/pass/general_fusion_merge_pass.cc | 114 +++++++++----------- 3 files changed, 60 insertions(+), 66 deletions(-) diff --git a/cinn/api/op_group.h b/cinn/api/op_group.h index 190dd9f8da..69eb658c25 100644 --- a/cinn/api/op_group.h +++ b/cinn/api/op_group.h @@ -45,8 +45,8 @@ class OpGroup { return tmp; } - OpGroup operator*() { - return OpGroup(helper_, iter_->first); + std::shared_ptr operator*() { + return std::make_shared(helper_, iter_->first); } bool operator==(const iterator& other) const { @@ -100,10 +100,14 @@ class OpGroup { return group_; } - bool operator==(const OpGroup& other) const { + bool operator == (const OpGroup& other) const { return group_.get() == other.group_.get(); } + bool operator < (const OpGroup& other) const { + return group_.get() < other.group_.get(); + } + // struct OpGroupHash { // std::size_t operator()(const OpGroup& obj) const { // return std::hash{}(obj.GetGroup().get()); diff --git a/cinn/api/op_node.h b/cinn/api/op_node.h index a62994d7fc..36627bf186 100644 --- a/cinn/api/op_node.h +++ b/cinn/api/op_node.h @@ -52,7 +52,7 @@ class OpNode { } private: - const Attribute& GetAttr(const std::string& attr_name) { + const Attribute& GetAttr(const std::string& attr_name) const { return 
node_->attrs.attr_store.at(attr_name); } diff --git a/cinn/hlir/pass/general_fusion_merge_pass.cc b/cinn/hlir/pass/general_fusion_merge_pass.cc index 11a9bdb946..283b34b46a 100644 --- a/cinn/hlir/pass/general_fusion_merge_pass.cc +++ b/cinn/hlir/pass/general_fusion_merge_pass.cc @@ -39,8 +39,8 @@ using common::GraphNode; using GroupPtr = std::shared_ptr; using GroupList = std::vector; -// using OpGroupPtr = std::shared_ptr; -using OpGroupPtr = api::OpGroup; +using OpGroupPtr = std::shared_ptr; +// using OpGroupPtr = api::OpGroup; using OpGroupList = std::vector; using ConditionFunction = std::function; @@ -107,13 +107,13 @@ class GraphGroupFuseHelper final : public FuseHelper { private: bool ReachableIfDirectEdgeIgnored(const OpGroupPtr& producer, const OpGroupPtr& consumer) const { const auto& MinDepth4Node = [&](OpGroupPtr node) { - return node.GetGroup()->min_depth; + return node->GetGroup()->min_depth; }; const auto& MaxDepth4Node = [&](OpGroupPtr node) { - return node.GetGroup()->max_depth; + return node->GetGroup()->max_depth; }; const auto& VisitNextNodes = [&](OpGroupPtr node, const std::function& Visit) { - for(auto iter = node.ProducerBegin(); iter != node.ProducerEnd(); ++iter) { + for(auto iter = node->ProducerEnd(); iter != node->ProducerEnd(); ++iter) { if (node == consumer && *iter == producer) { continue; } @@ -221,73 +221,73 @@ class GraphGroupInputFusePassCtx final : public InputFusePassCtx { template bool GraphGroupFuseHelper::AllOutputsSameSize(const OpGroupPtr& first, const OpGroupPtr& second) const { return is_same_size(&ctx_->graph_group_fusion_helper(), - first.GetGroup(), - second.GetGroup()); + first->GetGroup(), + second->GetGroup()); } template bool GraphGroupFuseHelper::HorizontalElementwiseFuseReduce(const OpGroupPtr& src, const OpGroupPtr& dst) const { return honrizontal_elementwise_fuse_reduce(&ctx_->graph_group_fusion_helper(), - src.GetGroup(), - dst.GetGroup()); + src->GetGroup(), + dst->GetGroup()); } template bool 
GraphGroupFuseHelper::ElementwiseFuseBroadcast(const OpGroupPtr& src, const OpGroupPtr& dst) const { return elementwise_fuse_broadcast(&ctx_->graph_group_fusion_helper(), - src.GetGroup(), - dst.GetGroup()); + src->GetGroup(), + dst->GetGroup()); } template bool GraphGroupFuseHelper::HorizontalWithInjective(const OpGroupPtr& src, const OpGroupPtr& dst) const { return horizontal_with_injective(&ctx_->graph_group_fusion_helper(), - src.GetGroup(), - dst.GetGroup()); + src->GetGroup(), + dst->GetGroup()); } template bool GraphGroupFuseHelper::ElementwiseFuseReduce(const OpGroupPtr& src, const OpGroupPtr& dst) const { return elementwise_fuse_reduce(&ctx_->graph_group_fusion_helper(), - src.GetGroup(), - dst.GetGroup()); + src->GetGroup(), + dst->GetGroup()); } template bool GraphGroupFuseHelper::BroadcastFuseReduce(const OpGroupPtr& src, const OpGroupPtr& dst) const { return broadcast_fuse_reduce(&ctx_->graph_group_fusion_helper(), - src.GetGroup(), - dst.GetGroup()); + src->GetGroup(), + dst->GetGroup()); } template bool GraphGroupFuseHelper::InjectiveHorizontalWithReduce(const OpGroupPtr& src, const OpGroupPtr& dst) const { return injective_horizontal_with_reduce(&ctx_->graph_group_fusion_helper(), - src.GetGroup(), - dst.GetGroup()); + src->GetGroup(), + dst->GetGroup()); } template bool GraphGroupFuseHelper::ReduceFuseElementwise(const OpGroupPtr& src, const OpGroupPtr& dst) const { return reduce_fuse_elementwise(&ctx_->graph_group_fusion_helper(), - src.GetGroup(), - dst.GetGroup()); + src->GetGroup(), + dst->GetGroup()); } template bool GraphGroupFuseHelper::ReduceFuseBroadcast(const OpGroupPtr& src, const OpGroupPtr& dst) const { return reduce_fuse_broadcast(&ctx_->graph_group_fusion_helper(), - src.GetGroup(), - dst.GetGroup()); + src->GetGroup(), + dst->GetGroup()); } template bool GraphGroupFuseHelper::ReduceFuseReduce(const OpGroupPtr& src, const OpGroupPtr& dst) const { return reduce_fuse_reduce(&ctx_->graph_group_fusion_helper(), - src.GetGroup(), - 
dst.GetGroup()); + src->GetGroup(), + dst->GetGroup()); } template @@ -295,7 +295,7 @@ struct HorizontalFuseUtil { using KindKeyT = std::pair; static bool DetectFusabilityByKind(FusePassCtxT* ctx, const OpGroupPtr& src, const OpGroupPtr& dst) { - const KindKeyT kind_pair(src.kind(), dst.kind()); + const KindKeyT kind_pair(src->kind(), dst->kind()); const auto& map = GetConditionMap(); const auto& iter = map.find(kind_pair); if (iter == map.end()) { @@ -382,17 +382,11 @@ class DefaultInputFusePass final : public InputFusePass { void operator()(InputFusePassCtx* ctx) const override { VLOG(1) << "DefaultInputFusePass"; - const auto& consumer_set = ctx->PickConsumersWithSameInputs(); - if (consumer_set.size() <= 1) { + const auto& consumers = ctx->PickConsumersWithSameInputs(); + if (consumers.size() <= 1) { return; } - const OpGroupList consumers = [&]() { - OpGroupList ret; - for (const auto& consumer : consumer_set) { - ret.push_back(consumer); - } - return ret; - }(); + for (int i = 0; i < consumers.size(); ++i) { const auto& src = consumers.at(i); for (int j = i + 1; j < consumers.size(); ++j) { @@ -449,7 +443,7 @@ class DefaultHorizontalFusePass final : public HorizontalFusePass { const auto& producer = ctx->PickOpGroup(); const OpGroupList consumers = [&]() { OpGroupList consumers; - for(auto iter = producer.ConsumerBegin(); iter!= producer.ConsumerEnd(); ++iter) { + for(auto iter = producer->ConsumerBegin(); iter!= producer->ConsumerEnd(); ++iter) { consumers.push_back(*iter); } return consumers; @@ -499,7 +493,7 @@ class DefaultVerticalFusePass final : public VerticalFusePass { const auto& producer = ctx->PickOpGroup(); const OpGroupList consumers = [&]() { OpGroupList consumers; - for(auto iter = producer.ConsumerBegin(); iter!= producer.ConsumerEnd(); ++iter) { + for(auto iter = producer->ConsumerBegin(); iter!= producer->ConsumerEnd(); ++iter) { consumers.push_back(*iter); } return consumers; @@ -523,7 +517,7 @@ class DefaultVerticalFusePass final : public 
VerticalFusePass { using KindKeyT = std::pair; bool DetectFusabilityByKind(LightwareFusePassCtx* ctx, const OpGroupPtr& src, const OpGroupPtr& dst) const { - const KindKeyT kind_pair(src.kind(), dst.kind()); + const KindKeyT kind_pair(src->kind(), dst->kind()); const auto& map = GetConditionMap(); const auto& iter = map.find(kind_pair); if (iter == map.end()) { @@ -625,7 +619,7 @@ class DefaultRecomputeFusePass final : public RecomputeFusePass { const auto& producer = ctx->PickOpGroup(); const OpGroupList consumers = [&]() { OpGroupList consumers; - for(auto iter = producer.ConsumerBegin(); iter!= producer.ConsumerEnd(); ++iter) { + for(auto iter = producer->ConsumerBegin(); iter!= producer->ConsumerEnd(); ++iter) { consumers.push_back(*iter); } return consumers; @@ -641,7 +635,7 @@ class DefaultRecomputeFusePass final : public RecomputeFusePass { } candidates.push_back(consumer); } - if (candidates.size() == consumers.size() && producer.kind() == framework::kElementWise) { + if (candidates.size() == consumers.size() && producer->kind() == framework::kElementWise) { for (const auto& consumer : consumers) { ctx->EnableFuse(producer, consumer); } @@ -650,7 +644,7 @@ class DefaultRecomputeFusePass final : public RecomputeFusePass { using KindKeyT = std::pair; bool DetectFusabilityByKind(LightwareFusePassCtx* ctx, const OpGroupPtr& src, const OpGroupPtr& dst) const { - const KindKeyT kind_pair(src.kind(), dst.kind()); + const KindKeyT kind_pair(src->kind(), dst->kind()); const auto& map = DefaultVerticalFusePass::GetConditionMap(); const auto& iter = map.find(kind_pair); if (iter == map.end()) { @@ -944,7 +938,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { void EnableFusedHorizontalGroups(LightwareFusePassCtx* ctx) const { const auto& producer = ctx->PickOpGroup(); - if (producer.ConsumerSize() <= 1) { + if (producer->ConsumerSize() <= 1) { return; } const auto& fuse_passes = GetHorizontalFusePasses(); @@ -955,13 +949,13 @@ class 
GeneralFusionMergePassHelper : public FusionHelperBase { bool GeneralHorizontalFuse(const GroupPtr& producer) { VLOG(3) << "GeneralHorizontalFuse...!"; - using OpGroupSets = std::set>; + using OpGroupSets = std::set>; const auto& GetFusableConsumerGroupSets = [&]() -> OpGroupSets { OpGroupSets tagged_sets; const auto& EnableFuse = [&](const OpGroupPtr& first, const OpGroupPtr& second) { - tagged_sets.insert(std::set{first, second}); + tagged_sets.insert(std::make_pair(first, second)); }; - GraphGroupLightwareFusePassCtx fuse_ctx(this, api::OpGroup(this, producer), EnableFuse); + GraphGroupLightwareFusePassCtx fuse_ctx(this, std::make_shared(this, producer), EnableFuse); EnableFusedHorizontalGroups(&fuse_ctx); return tagged_sets; }; @@ -970,10 +964,8 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { if (group_sets.empty()) { return GroupList{}; } - GroupList ret; - for (const auto& group : *group_sets.begin()) { - ret.push_back(group.GetGroup()); - } + const auto& group_pair = *group_sets.begin(); + GroupList ret{group_pair.first->GetGroup(), group_pair.second->GetGroup()}; return ret; }; bool update = false; @@ -1006,16 +998,16 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { bool CallGeneralInputFusePass(const std::unordered_set& consumers) { VLOG(3) << "CallGeneralInputFusePass...!"; - using OpGroupSets = std::set>; + using OpGroupSets = std::set>; const auto& GetFusableConsumerGroupSets = [&]() -> OpGroupSets { OpGroupSets tagged_sets; const auto& EnableFuse = [&](const OpGroupPtr& first, const OpGroupPtr& second) { - tagged_sets.insert(std::set{first, second}); + tagged_sets.insert(std::make_pair(first, second)); }; OpGroupList consumer_groups; consumer_groups.reserve(consumers.size()); for(auto& consumer : consumers) { - consumer_groups.emplace_back(this, consumer); + consumer_groups.push_back(std::make_shared(this, consumer)); } GraphGroupInputFusePassCtx fuse_ctx(this, consumer_groups, EnableFuse); 
EnableFusedInputGroups(&fuse_ctx); @@ -1026,10 +1018,8 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { if (group_sets.empty()) { return GroupList{}; } - GroupList ret; - for (const auto& group : *group_sets.begin()) { - ret.push_back(group.GetGroup()); - } + const auto& group_pair = *group_sets.begin(); + GroupList ret{group_pair.first->GetGroup(), group_pair.second->GetGroup()}; return ret; }; const auto& groups = GetFusableConsumerGroupList(); @@ -1338,7 +1328,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { void TagVerticalGroups(LightwareFusePassCtx* ctx) const { const auto& producer = ctx->PickOpGroup(); - if (producer.ConsumerSize() == 0) { + if (producer->ConsumerSize() == 0) { return; } const auto& fuse_passes = GetVerticalFusePasses(); @@ -1355,7 +1345,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { const auto& EnableFuse = [&](const OpGroupPtr& first, const OpGroupPtr& second) { tagged_sets.insert(std::make_pair(first, second)); }; - GraphGroupLightwareFusePassCtx fuse_ctx(this, api::OpGroup(this, producer), EnableFuse); + GraphGroupLightwareFusePassCtx fuse_ctx(this, std::make_shared(this, producer), EnableFuse); TagVerticalGroups(&fuse_ctx); return tagged_sets; }; @@ -1367,7 +1357,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { } std::unordered_set ret; for (const auto& group_pair : group_sets) { - ret.insert(group_pair.second.GetGroup()); + ret.insert(group_pair.second->GetGroup()); } return ret; }; @@ -1585,7 +1575,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { void TagRecomputeGroups(LightwareFusePassCtx* ctx) const { const auto& producer = ctx->PickOpGroup(); - if (producer.ConsumerSize() <= 1) { + if (producer->ConsumerSize() <= 1) { return; } const auto& fuse_passes = GetRecomputeFusePasses(); @@ -1602,7 +1592,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { const auto& EnableFuse = [&](const OpGroupPtr& first, const OpGroupPtr& 
second) { tagged_sets.insert(std::make_pair(first, second)); }; - GraphGroupLightwareFusePassCtx fuse_ctx(this, api::OpGroup(this, producer), EnableFuse); + GraphGroupLightwareFusePassCtx fuse_ctx(this, std::make_shared(this, producer), EnableFuse); TagRecomputeGroups(&fuse_ctx); return tagged_sets; }; @@ -1614,7 +1604,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { } std::unordered_set ret; for (const auto& group_pair : group_sets) { - ret.insert(group_pair.second.GetGroup()); + ret.insert(group_pair.second->GetGroup()); } return ret; }; From 110a4955203a330e1090be0eb197f31ddbe9dc37 Mon Sep 17 00:00:00 2001 From: zyfncg Date: Tue, 20 Jun 2023 11:30:02 +0000 Subject: [PATCH 30/66] update --- cinn/api/op_group.h | 15 ++++++++------ cinn/hlir/framework/graph.h | 23 +++++++++++++++------ cinn/hlir/pass/fusion_merge_pass.cc | 11 ++++++---- cinn/hlir/pass/general_fusion_merge_pass.cc | 11 ++++++---- 4 files changed, 40 insertions(+), 20 deletions(-) diff --git a/cinn/api/op_group.h b/cinn/api/op_group.h index 69eb658c25..f176ada16b 100644 --- a/cinn/api/op_group.h +++ b/cinn/api/op_group.h @@ -108,12 +108,6 @@ class OpGroup { return group_.get() < other.group_.get(); } - // struct OpGroupHash { - // std::size_t operator()(const OpGroup& obj) const { - // return std::hash{}(obj.GetGroup().get()); - // } - // }; - private: const hlir::pass::FusionHelperBase* helper_; const std::shared_ptr group_; @@ -121,3 +115,12 @@ class OpGroup { } // namespace api } // namespace cinn + +namespace std { + template <> + struct hash { + size_t operator()(const cinn::api::OpGroup& obj) const { + return std::hash{}(obj.GetGroup().get()); + } + }; +} diff --git a/cinn/hlir/framework/graph.h b/cinn/hlir/framework/graph.h index ba18a5ac6a..9eb4e3abba 100644 --- a/cinn/hlir/framework/graph.h +++ b/cinn/hlir/framework/graph.h @@ -93,6 +93,17 @@ class Graph : public cinn::common::Graph { std::vector input_names; std::vector output_names; + struct SharedGroupHasher { + size_t 
operator()(const std::shared_ptr& group) const noexcept { + return std::hash()(reinterpret_cast(group.get())); + } + }; + struct SharedGroupComparator { + bool operator()(const std::shared_ptr& first, const std::shared_ptr& second) const noexcept { + return first.get() == second.get(); + } + }; + std::unordered_set> CollectConsumerGroups() { std::unordered_set> groups; for (const auto& consumer_and_list : consumer_groups_) { @@ -127,19 +138,19 @@ class Graph : public cinn::common::Graph { std::string GetFuncName() { return "fn_" + group_id + unique_id; } public: - const std::unordered_map, TensorInterfaceList>& producer_groups() const { + const std::unordered_map, TensorInterfaceList, SharedGroupHasher, SharedGroupComparator>& producer_groups() const { return producer_groups_; } - const std::unordered_map, TensorInterfaceList>& consumer_groups() const { + const std::unordered_map, TensorInterfaceList, SharedGroupHasher, SharedGroupComparator>& consumer_groups() const { return consumer_groups_; } - std::unordered_map, TensorInterfaceList>* mut_producer_groups() { + std::unordered_map, TensorInterfaceList, SharedGroupHasher, SharedGroupComparator>* mut_producer_groups() { return &producer_groups_; } - std::unordered_map, TensorInterfaceList>* mut_consumer_groups() { + std::unordered_map, TensorInterfaceList, SharedGroupHasher, SharedGroupComparator>* mut_consumer_groups() { return &consumer_groups_; } @@ -147,9 +158,9 @@ class Graph : public cinn::common::Graph { private: // input groups - std::unordered_map, TensorInterfaceList> producer_groups_; + std::unordered_map, TensorInterfaceList, SharedGroupHasher, SharedGroupComparator> producer_groups_; // output grous - std::unordered_map, TensorInterfaceList> consumer_groups_; + std::unordered_map, TensorInterfaceList, SharedGroupHasher, SharedGroupComparator> consumer_groups_; }; std::vector> fusion_groups; diff --git a/cinn/hlir/pass/fusion_merge_pass.cc b/cinn/hlir/pass/fusion_merge_pass.cc index 
22e16eb5b2..5a509375d2 100644 --- a/cinn/hlir/pass/fusion_merge_pass.cc +++ b/cinn/hlir/pass/fusion_merge_pass.cc @@ -32,6 +32,9 @@ using common::GraphNode; using GroupPtr = std::shared_ptr; using GroupList = std::vector; +using Comparator = Graph::Group::SharedGroupComparator; +using Hasher = Graph::Group::SharedGroupHasher; + using OpGroupPtr = std::shared_ptr; using OpGroupList = std::vector; @@ -904,11 +907,11 @@ class FusionMergePassHelper : public FusionHelperBase { // update producer and consumer. for (auto& group : fusion_groups_) { - std::unordered_map producers; - std::unordered_map consumers; + std::unordered_map producers; + std::unordered_map consumers; - for (auto& producer_and_list : group->producer_groups()) { - const auto& producer = std::dynamic_pointer_cast(producer_and_list.first); + for (const auto& producer_and_list : group->producer_groups()) { + const auto& producer = producer_and_list.first; CHECK(producer->belong_groups.size()); // TODO: Do not add any TensorInterface into any TensorInterfaceList in this file which will be deprecated. 
producers[*producer->belong_groups.begin()] += {}; diff --git a/cinn/hlir/pass/general_fusion_merge_pass.cc b/cinn/hlir/pass/general_fusion_merge_pass.cc index 283b34b46a..440e96e98b 100644 --- a/cinn/hlir/pass/general_fusion_merge_pass.cc +++ b/cinn/hlir/pass/general_fusion_merge_pass.cc @@ -39,6 +39,9 @@ using common::GraphNode; using GroupPtr = std::shared_ptr; using GroupList = std::vector; +using Comparator = Graph::Group::SharedGroupComparator; +using Hasher = Graph::Group::SharedGroupHasher; + using OpGroupPtr = std::shared_ptr; // using OpGroupPtr = api::OpGroup; using OpGroupList = std::vector; @@ -1339,11 +1342,11 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { bool GeneralVerticalFuse(GroupPtr& producer) { VLOG(3) << "GeneralVerticalFuse...!"; - using GroupSets = std::set>; + using GroupSets = std::vector>; const auto& GetFusableConsumerOpGroupSets = [&]() -> GroupSets { GroupSets tagged_sets; const auto& EnableFuse = [&](const OpGroupPtr& first, const OpGroupPtr& second) { - tagged_sets.insert(std::make_pair(first, second)); + tagged_sets.push_back(std::make_pair(first, second)); }; GraphGroupLightwareFusePassCtx fuse_ctx(this, std::make_shared(this, producer), EnableFuse); TagVerticalGroups(&fuse_ctx); @@ -1896,8 +1899,8 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { // update producer and consumer. 
for (auto& group : fusion_groups_) { - std::unordered_map producers; - std::unordered_map consumers; + std::unordered_map producers; + std::unordered_map consumers; for (auto& producer_and_list : group->producer_groups()) { const auto& producer = std::dynamic_pointer_cast(producer_and_list.first); From 2c767e40fff74d30d9279559c0772834e970643a Mon Sep 17 00:00:00 2001 From: zyfncg Date: Tue, 20 Jun 2023 12:06:21 +0000 Subject: [PATCH 31/66] change shared_ptr to OpGroup object in iterator --- cinn/api/op_group.h | 20 +- cinn/hlir/framework/graph.h | 4 +- .../pass/check_fusion_accuracy_pass_test.cc | 1804 ++++++++--------- cinn/hlir/pass/fusion_merge_pass.cc | 36 +- cinn/hlir/pass/general_fusion_merge_pass.cc | 130 +- 5 files changed, 998 insertions(+), 996 deletions(-) diff --git a/cinn/api/op_group.h b/cinn/api/op_group.h index f176ada16b..5bdad0765e 100644 --- a/cinn/api/op_group.h +++ b/cinn/api/op_group.h @@ -45,8 +45,8 @@ class OpGroup { return tmp; } - std::shared_ptr operator*() { - return std::make_shared(helper_, iter_->first); + OpGroup operator*() { + return OpGroup(helper_, iter_->first); } bool operator==(const iterator& other) const { @@ -117,10 +117,12 @@ class OpGroup { } // namespace cinn namespace std { - template <> - struct hash { - size_t operator()(const cinn::api::OpGroup& obj) const { - return std::hash{}(obj.GetGroup().get()); - } - }; -} + +template <> +struct hash { + size_t operator()(const cinn::api::OpGroup& obj) const { + return std::hash()(reinterpret_cast(obj.GetGroup().get())); + } +}; + +} // namespace std diff --git a/cinn/hlir/framework/graph.h b/cinn/hlir/framework/graph.h index 9eb4e3abba..73396ac888 100644 --- a/cinn/hlir/framework/graph.h +++ b/cinn/hlir/framework/graph.h @@ -104,8 +104,8 @@ class Graph : public cinn::common::Graph { } }; - std::unordered_set> CollectConsumerGroups() { - std::unordered_set> groups; + std::unordered_set, SharedGroupHasher, SharedGroupComparator> CollectConsumerGroups() { + std::unordered_set, 
SharedGroupHasher, SharedGroupComparator> groups; for (const auto& consumer_and_list : consumer_groups_) { groups.insert(std::dynamic_pointer_cast(consumer_and_list.first)); } diff --git a/cinn/hlir/pass/check_fusion_accuracy_pass_test.cc b/cinn/hlir/pass/check_fusion_accuracy_pass_test.cc index 3283964ea9..72dc0a72cc 100644 --- a/cinn/hlir/pass/check_fusion_accuracy_pass_test.cc +++ b/cinn/hlir/pass/check_fusion_accuracy_pass_test.cc @@ -92,1026 +92,1026 @@ TEST(CheckFusionAccuracyPass, ElementWise_Fusion) { RunTest(target, graph, {"A", "B", "C", "D"}); } -TEST(CheckFusionAccuracyPass, General_ElementWise_Fusion) { - int h = 32, w = 32; - NetBuilder net_builder("ElementWise_Fusion_0"); - std::unordered_set fetch_ids; - // create model - { - auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); - auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); - auto C = net_builder.CreateInput(Float(32), {h, w}, "C"); - auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); - auto E = net_builder.Add(A, B); - auto F = net_builder.Add(E, C); - auto G = net_builder.Add(E, D); - - fetch_ids = {F->id, G->id}; - } - - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); - - auto graph = std::make_shared(program, target); - hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "GeneralFusionMergePass"}); - - int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); - - VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); - VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - - CHECK_EQ(graph->fusion_groups.size(), group_size_after); - - RunTest(target, graph, {"A", "B", "C", "D"}); -} - -TEST(CheckFusionAccuracyPass, ElementWise_Fusion_1) { - int h = 32, w = 32; - NetBuilder net_builder("ElementWise_Fusion_1"); - // create 
model - { - auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); - auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); - auto C = net_builder.CreateInput(Float(32), {h, w}, "C"); - auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); - auto E = net_builder.Add(A, B); - auto F = net_builder.Add(C, D); - auto G = net_builder.Add(E, F); - auto I = net_builder.Add(E, F); - } - - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); - - auto graph = std::make_shared(program, target); - - hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "FusionMergePass"}); - - int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); - - VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); - VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - - CHECK_EQ(graph->fusion_groups.size(), group_size_after); - - RunTest(target, graph, {"A", "B", "C", "D"}); -} - -TEST(CheckFusionAccuracyPass, General_ElementWise_Fusion_1) { - int h = 32, w = 32; - NetBuilder net_builder("ElementWise_Fusion_1"); - // create model - { - auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); - auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); - auto C = net_builder.CreateInput(Float(32), {h, w}, "C"); - auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); - auto E = net_builder.Add(A, B); - auto F = net_builder.Add(C, D); - auto G = net_builder.Add(E, F); - auto I = net_builder.Add(E, F); - } - - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); - - auto graph = std::make_shared(program, target); - - hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "GeneralFusionMergePass"}); - - int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); - - VLOG(1) << "Before 
CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); - VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - - CHECK_EQ(graph->fusion_groups.size(), group_size_after); - - RunTest(target, graph, {"A", "B", "C", "D"}); -} - -TEST(CheckFusionAccuracyPass, ElementWise_Fusion_2) { - int h = 32, w = 32; - NetBuilder net_builder("ElementWise_Fusion_2"); - // create model - { - auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); - auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); - auto C = net_builder.CreateInput(Float(32), {h, w}, "C"); - auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); - auto E = net_builder.CreateInput(Float(32), {h, w}, "E"); - auto F = net_builder.CreateInput(Float(32), {h, w}, "F"); - auto G = net_builder.Add(A, B); - auto H = net_builder.Add(C, D); - auto I = net_builder.Add(E, G); - auto J = net_builder.Add(G, H); - auto K = net_builder.Add(H, F); - } - - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); - - auto graph = std::make_shared(program, target); - - hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "FusionMergePass"}); - - int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); - - VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); - VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - - CHECK_EQ(graph->fusion_groups.size(), group_size_after); - - RunTest(target, graph, {"A", "B", "C", "D", "E", "F"}); -} - -TEST(CheckFusionAccuracyPass, General_ElementWise_Fusion_2) { - int h = 32, w = 32; - NetBuilder net_builder("ElementWise_Fusion_2"); - // create model - { - auto A = 
net_builder.CreateInput(Float(32), {h, w}, "A"); - auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); - auto C = net_builder.CreateInput(Float(32), {h, w}, "C"); - auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); - auto E = net_builder.CreateInput(Float(32), {h, w}, "E"); - auto F = net_builder.CreateInput(Float(32), {h, w}, "F"); - auto G = net_builder.Add(A, B); - auto H = net_builder.Add(C, D); - auto I = net_builder.Add(E, G); - auto J = net_builder.Add(G, H); - auto K = net_builder.Add(H, F); - } - - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); - - auto graph = std::make_shared(program, target); - - hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "GeneralFusionMergePass"}); - - int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); - - VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); - VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - - CHECK_EQ(graph->fusion_groups.size(), group_size_after); - - RunTest(target, graph, {"A", "B", "C", "D", "E", "F"}); -} - -TEST(CheckFusionAccuracyPass, ElementWise_Fusion_3) { - int h = 32, w = 32; - NetBuilder net_builder("ElementWise_Fusion_3"); - // create model - { - auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); - auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); - auto C = net_builder.CreateInput(Float(32), {h, w}, "C"); - auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); - auto E = net_builder.CreateInput(Float(32), {h, w}, "E"); - auto F = net_builder.CreateInput(Float(32), {h, w}, "F"); - auto G = net_builder.Add(A, B); - auto H = net_builder.Add(G, C); - auto I = net_builder.Add(G, D); - auto J = net_builder.Add(G, E); - auto K = net_builder.Add(G, F); - } - - auto program = net_builder.Build(); - auto 
target = common::DefaultTarget(); - - auto graph = std::make_shared(program, target); - - hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "FusionMergePass"}); - - int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); - - VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); - VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - - CHECK_EQ(graph->fusion_groups.size(), group_size_after); - - RunTest(target, graph, {"A", "B", "C", "D", "E", "F"}); -} - -TEST(CheckFusionAccuracyPass, General_ElementWise_Fusion_3) { - int h = 32, w = 32; - NetBuilder net_builder("ElementWise_Fusion_3"); - // create model - { - auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); - auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); - auto C = net_builder.CreateInput(Float(32), {h, w}, "C"); - auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); - auto E = net_builder.CreateInput(Float(32), {h, w}, "E"); - auto F = net_builder.CreateInput(Float(32), {h, w}, "F"); - auto G = net_builder.Add(A, B); - auto H = net_builder.Add(G, C); - auto I = net_builder.Add(G, D); - auto J = net_builder.Add(G, E); - auto K = net_builder.Add(G, F); - } - - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); - - auto graph = std::make_shared(program, target); - - hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "GeneralFusionMergePass"}); - - int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); - - VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); - VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); 
- - CHECK_EQ(graph->fusion_groups.size(), group_size_after); - - RunTest(target, graph, {"A", "B", "C", "D", "E", "F"}); -} - -TEST(CheckFusionAccuracyPass, ElementWise_Fusion_4) { - int h = 32, w = 32; - NetBuilder net_builder("ElementWise_Fusion_4"); - // create model - { - auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); - auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); - auto C = net_builder.CreateInput(Float(32), {h, w}, "C"); - auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); - auto E = net_builder.CreateInput(Float(32), {h, w}, "E"); - auto F = net_builder.CreateInput(Float(32), {h, w}, "F"); - auto G = net_builder.Add(A, B); - auto H = net_builder.Add(G, C); - auto I = net_builder.Add(G, D); - auto J = net_builder.Add(I, E); - auto K = net_builder.Add(I, F); - } - - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); - - auto graph = std::make_shared(program, target); - - hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "FusionMergePass"}); - - int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); - - VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); - VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - - CHECK_EQ(graph->fusion_groups.size(), group_size_after); - - RunTest(target, graph, {"A", "B", "C", "D", "E", "F"}); -} - -TEST(CheckFusionAccuracyPass, General_ElementWise_Fusion_4) { - int h = 32, w = 32; - NetBuilder net_builder("ElementWise_Fusion_4"); - // create model - { - auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); - auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); - auto C = net_builder.CreateInput(Float(32), {h, w}, "C"); - auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); - auto E = net_builder.CreateInput(Float(32), {h, w}, 
"E"); - auto F = net_builder.CreateInput(Float(32), {h, w}, "F"); - auto G = net_builder.Add(A, B); - auto H = net_builder.Add(G, C); - auto I = net_builder.Add(G, D); - auto J = net_builder.Add(I, E); - auto K = net_builder.Add(I, F); - } - - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); - - auto graph = std::make_shared(program, target); - - hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "GeneralFusionMergePass"}); - - int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); - - VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); - VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - - CHECK_EQ(graph->fusion_groups.size(), group_size_after); - - RunTest(target, graph, {"A", "B", "C", "D", "E", "F"}); -} - -TEST(CheckFusionAccuracyPass, ElementWise_Fusion_5) { - int h = 32, w = 32; - NetBuilder net_builder("ElementWise_Fusion_5"); - // create model - { - auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); - auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); - auto C = net_builder.Add(A, B); - auto D = net_builder.Add(A, B); - } - - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); - - auto graph = std::make_shared(program, target); - - hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "FusionMergePass"}); - - int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); - - VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); - VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - - CHECK_EQ(graph->fusion_groups.size(), group_size_after); - 
- RunTest(target, graph, {"A", "B"}); -} - -TEST(CheckFusionAccuracyPass, General_ElementWise_Fusion_5) { - int h = 32, w = 32; - NetBuilder net_builder("ElementWise_Fusion_5"); - // create model - { - auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); - auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); - auto C = net_builder.Add(A, B); - auto D = net_builder.Add(A, B); - } - - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); - - auto graph = std::make_shared(program, target); - - hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "GeneralFusionMergePass"}); - - int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); - - VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); - VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - - CHECK_EQ(graph->fusion_groups.size(), group_size_after); - - RunTest(target, graph, {"A", "B"}); -} - -TEST(CheckFusionAccuracyPass, Broadcast_Test_0) { - int h = 32, w = 32; - NetBuilder net_builder("Broadcast_Test_0"); - // create model - { - auto A = net_builder.CreateInput(Float(32), {w}, "A"); - auto B = net_builder.CreateInput(Float(32), {w}, "B"); - auto C = net_builder.CreateInput(Float(32), {h, w}, "C"); - auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); - auto E = net_builder.Add(A, B); - auto F = net_builder.Add(C, D); - auto G = net_builder.Add(F, E); - } - - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); - - auto graph = std::make_shared(program, target); - - hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "FusionMergePass"}); - - int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); - - VLOG(1) << "Before CheckFusionAccuracyPass:\n" << 
graph->DebugGroupedGraph(std::unordered_set{}); - hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); - VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - - CHECK_EQ(graph->fusion_groups.size(), group_size_after); - - RunTest(target, graph, {"A", "B", "C", "D"}); -} - -TEST(CheckFusionAccuracyPass, General_Broadcast_Test_0) { - int h = 32, w = 32; - NetBuilder net_builder("Broadcast_Test_0"); - // create model - { - auto A = net_builder.CreateInput(Float(32), {w}, "A"); - auto B = net_builder.CreateInput(Float(32), {w}, "B"); - auto C = net_builder.CreateInput(Float(32), {h, w}, "C"); - auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); - auto E = net_builder.Add(A, B); - auto F = net_builder.Add(C, D); - auto G = net_builder.Add(F, E); - } - - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); - - auto graph = std::make_shared(program, target); - - hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "GeneralFusionMergePass"}); - - int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); - - VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); - VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - - CHECK_EQ(graph->fusion_groups.size(), group_size_after); - - RunTest(target, graph, {"A", "B", "C", "D"}); -} - -TEST(CheckFusionAccuracyPass, Broadcast_Test_2) { - int h = 32, w = 32; - NetBuilder net_builder("Broadcast_Test_2"); - // create model - { - auto A = net_builder.CreateInput(Float(32), {w}, "A"); - auto B = net_builder.CreateInput(Float(32), {w}, "B"); - auto C = net_builder.CreateInput(Float(32), {w}, "C"); - auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); - auto E = net_builder.Add(A, B); - auto F 
= net_builder.Add(C, E); - auto G = net_builder.Add(D, E); - } - - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); - - auto graph = std::make_shared(program, target); - - hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "FusionMergePass"}); - - int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); - - VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); - VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - - CHECK_EQ(graph->fusion_groups.size(), group_size_after); - - RunTest(target, graph, {"A", "B", "C", "D"}); -} - -TEST(CheckFusionAccuracyPass, General_Broadcast_Test_2) { - int h = 32, w = 32; - NetBuilder net_builder("Broadcast_Test_2"); - // create model - { - auto A = net_builder.CreateInput(Float(32), {w}, "A"); - auto B = net_builder.CreateInput(Float(32), {w}, "B"); - auto C = net_builder.CreateInput(Float(32), {w}, "C"); - auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); - auto E = net_builder.Add(A, B); - auto F = net_builder.Add(C, E); - auto G = net_builder.Add(D, E); - } - - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); - - auto graph = std::make_shared(program, target); - - hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "GeneralFusionMergePass"}); - - int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); - - VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); - VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - - CHECK_EQ(graph->fusion_groups.size(), group_size_after); - - RunTest(target, graph, {"A", "B", "C", 
"D"}); -} - -TEST(CheckFusionAccuracyPass, Broadcast_Test_4) { - int h = 32, w = 32; - NetBuilder net_builder("Broadcast_Test_4"); - // create model - { - auto A = net_builder.CreateInput(Float(32), {w}, "A"); - auto B = net_builder.CreateInput(Float(32), {w}, "B"); - auto C = net_builder.CreateInput(Float(32), {w}, "C"); - auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); - auto E = net_builder.CreateInput(Float(32), {h, w}, "E"); - auto F = net_builder.Add(A, B); - auto G = net_builder.Add(C, F); - auto H = net_builder.Add(D, F); - auto I = net_builder.Add(E, F); - } - - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); - - auto graph = std::make_shared(program, target); - - hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "FusionMergePass"}); - - int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); - - VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); - VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - - CHECK_EQ(graph->fusion_groups.size(), group_size_after); - - RunTest(target, graph, {"A", "B", "C", "D", "E"}); -} +// TEST(CheckFusionAccuracyPass, General_ElementWise_Fusion) { +// int h = 32, w = 32; +// NetBuilder net_builder("ElementWise_Fusion_0"); +// std::unordered_set fetch_ids; +// // create model +// { +// auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); +// auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); +// auto C = net_builder.CreateInput(Float(32), {h, w}, "C"); +// auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); +// auto E = net_builder.Add(A, B); +// auto F = net_builder.Add(E, C); +// auto G = net_builder.Add(E, D); + +// fetch_ids = {F->id, G->id}; +// } + +// auto program = net_builder.Build(); +// auto target = common::DefaultTarget(); + 
+// auto graph = std::make_shared(program, target); +// hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "GeneralFusionMergePass"}); + +// int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); + +// VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); +// hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); +// VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + +// CHECK_EQ(graph->fusion_groups.size(), group_size_after); + +// RunTest(target, graph, {"A", "B", "C", "D"}); +// } + +// TEST(CheckFusionAccuracyPass, ElementWise_Fusion_1) { +// int h = 32, w = 32; +// NetBuilder net_builder("ElementWise_Fusion_1"); +// // create model +// { +// auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); +// auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); +// auto C = net_builder.CreateInput(Float(32), {h, w}, "C"); +// auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); +// auto E = net_builder.Add(A, B); +// auto F = net_builder.Add(C, D); +// auto G = net_builder.Add(E, F); +// auto I = net_builder.Add(E, F); +// } + +// auto program = net_builder.Build(); +// auto target = common::DefaultTarget(); + +// auto graph = std::make_shared(program, target); + +// hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "FusionMergePass"}); + +// int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); + +// VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); +// hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); +// VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + +// CHECK_EQ(graph->fusion_groups.size(), group_size_after); + +// RunTest(target, graph, {"A", "B", "C", "D"}); +// } + +// 
TEST(CheckFusionAccuracyPass, General_ElementWise_Fusion_1) { +// int h = 32, w = 32; +// NetBuilder net_builder("ElementWise_Fusion_1"); +// // create model +// { +// auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); +// auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); +// auto C = net_builder.CreateInput(Float(32), {h, w}, "C"); +// auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); +// auto E = net_builder.Add(A, B); +// auto F = net_builder.Add(C, D); +// auto G = net_builder.Add(E, F); +// auto I = net_builder.Add(E, F); +// } + +// auto program = net_builder.Build(); +// auto target = common::DefaultTarget(); + +// auto graph = std::make_shared(program, target); + +// hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "GeneralFusionMergePass"}); + +// int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); + +// VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); +// hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); +// VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + +// CHECK_EQ(graph->fusion_groups.size(), group_size_after); + +// RunTest(target, graph, {"A", "B", "C", "D"}); +// } + +// TEST(CheckFusionAccuracyPass, ElementWise_Fusion_2) { +// int h = 32, w = 32; +// NetBuilder net_builder("ElementWise_Fusion_2"); +// // create model +// { +// auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); +// auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); +// auto C = net_builder.CreateInput(Float(32), {h, w}, "C"); +// auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); +// auto E = net_builder.CreateInput(Float(32), {h, w}, "E"); +// auto F = net_builder.CreateInput(Float(32), {h, w}, "F"); +// auto G = net_builder.Add(A, B); +// auto H = net_builder.Add(C, D); +// auto I = net_builder.Add(E, G); +// auto J = net_builder.Add(G, H); +// 
auto K = net_builder.Add(H, F); +// } + +// auto program = net_builder.Build(); +// auto target = common::DefaultTarget(); + +// auto graph = std::make_shared(program, target); + +// hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "FusionMergePass"}); + +// int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); + +// VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); +// hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); +// VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + +// CHECK_EQ(graph->fusion_groups.size(), group_size_after); + +// RunTest(target, graph, {"A", "B", "C", "D", "E", "F"}); +// } + +// TEST(CheckFusionAccuracyPass, General_ElementWise_Fusion_2) { +// int h = 32, w = 32; +// NetBuilder net_builder("ElementWise_Fusion_2"); +// // create model +// { +// auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); +// auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); +// auto C = net_builder.CreateInput(Float(32), {h, w}, "C"); +// auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); +// auto E = net_builder.CreateInput(Float(32), {h, w}, "E"); +// auto F = net_builder.CreateInput(Float(32), {h, w}, "F"); +// auto G = net_builder.Add(A, B); +// auto H = net_builder.Add(C, D); +// auto I = net_builder.Add(E, G); +// auto J = net_builder.Add(G, H); +// auto K = net_builder.Add(H, F); +// } + +// auto program = net_builder.Build(); +// auto target = common::DefaultTarget(); + +// auto graph = std::make_shared(program, target); + +// hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "GeneralFusionMergePass"}); + +// int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); + +// VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); +// 
hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); +// VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + +// CHECK_EQ(graph->fusion_groups.size(), group_size_after); + +// RunTest(target, graph, {"A", "B", "C", "D", "E", "F"}); +// } + +// TEST(CheckFusionAccuracyPass, ElementWise_Fusion_3) { +// int h = 32, w = 32; +// NetBuilder net_builder("ElementWise_Fusion_3"); +// // create model +// { +// auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); +// auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); +// auto C = net_builder.CreateInput(Float(32), {h, w}, "C"); +// auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); +// auto E = net_builder.CreateInput(Float(32), {h, w}, "E"); +// auto F = net_builder.CreateInput(Float(32), {h, w}, "F"); +// auto G = net_builder.Add(A, B); +// auto H = net_builder.Add(G, C); +// auto I = net_builder.Add(G, D); +// auto J = net_builder.Add(G, E); +// auto K = net_builder.Add(G, F); +// } + +// auto program = net_builder.Build(); +// auto target = common::DefaultTarget(); + +// auto graph = std::make_shared(program, target); + +// hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "FusionMergePass"}); + +// int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); + +// VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); +// hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); +// VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + +// CHECK_EQ(graph->fusion_groups.size(), group_size_after); + +// RunTest(target, graph, {"A", "B", "C", "D", "E", "F"}); +// } + +// TEST(CheckFusionAccuracyPass, General_ElementWise_Fusion_3) { +// int h = 32, w = 32; +// NetBuilder net_builder("ElementWise_Fusion_3"); +// // create model +// { +// auto A = 
net_builder.CreateInput(Float(32), {h, w}, "A"); +// auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); +// auto C = net_builder.CreateInput(Float(32), {h, w}, "C"); +// auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); +// auto E = net_builder.CreateInput(Float(32), {h, w}, "E"); +// auto F = net_builder.CreateInput(Float(32), {h, w}, "F"); +// auto G = net_builder.Add(A, B); +// auto H = net_builder.Add(G, C); +// auto I = net_builder.Add(G, D); +// auto J = net_builder.Add(G, E); +// auto K = net_builder.Add(G, F); +// } + +// auto program = net_builder.Build(); +// auto target = common::DefaultTarget(); + +// auto graph = std::make_shared(program, target); + +// hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "GeneralFusionMergePass"}); + +// int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); + +// VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); +// hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); +// VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + +// CHECK_EQ(graph->fusion_groups.size(), group_size_after); + +// RunTest(target, graph, {"A", "B", "C", "D", "E", "F"}); +// } + +// TEST(CheckFusionAccuracyPass, ElementWise_Fusion_4) { +// int h = 32, w = 32; +// NetBuilder net_builder("ElementWise_Fusion_4"); +// // create model +// { +// auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); +// auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); +// auto C = net_builder.CreateInput(Float(32), {h, w}, "C"); +// auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); +// auto E = net_builder.CreateInput(Float(32), {h, w}, "E"); +// auto F = net_builder.CreateInput(Float(32), {h, w}, "F"); +// auto G = net_builder.Add(A, B); +// auto H = net_builder.Add(G, C); +// auto I = net_builder.Add(G, D); +// auto J = net_builder.Add(I, E); +// auto K 
= net_builder.Add(I, F); +// } + +// auto program = net_builder.Build(); +// auto target = common::DefaultTarget(); + +// auto graph = std::make_shared(program, target); + +// hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "FusionMergePass"}); + +// int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); + +// VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); +// hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); +// VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + +// CHECK_EQ(graph->fusion_groups.size(), group_size_after); + +// RunTest(target, graph, {"A", "B", "C", "D", "E", "F"}); +// } + +// TEST(CheckFusionAccuracyPass, General_ElementWise_Fusion_4) { +// int h = 32, w = 32; +// NetBuilder net_builder("ElementWise_Fusion_4"); +// // create model +// { +// auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); +// auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); +// auto C = net_builder.CreateInput(Float(32), {h, w}, "C"); +// auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); +// auto E = net_builder.CreateInput(Float(32), {h, w}, "E"); +// auto F = net_builder.CreateInput(Float(32), {h, w}, "F"); +// auto G = net_builder.Add(A, B); +// auto H = net_builder.Add(G, C); +// auto I = net_builder.Add(G, D); +// auto J = net_builder.Add(I, E); +// auto K = net_builder.Add(I, F); +// } + +// auto program = net_builder.Build(); +// auto target = common::DefaultTarget(); + +// auto graph = std::make_shared(program, target); -TEST(CheckFusionAccuracyPass, General_Broadcast_Test_4) { - int h = 32, w = 32; - NetBuilder net_builder("Broadcast_Test_4"); - // create model - { - auto A = net_builder.CreateInput(Float(32), {w}, "A"); - auto B = net_builder.CreateInput(Float(32), {w}, "B"); - auto C = net_builder.CreateInput(Float(32), {w}, "C"); - auto D = 
net_builder.CreateInput(Float(32), {h, w}, "D"); - auto E = net_builder.CreateInput(Float(32), {h, w}, "E"); - auto F = net_builder.Add(A, B); - auto G = net_builder.Add(C, F); - auto H = net_builder.Add(D, F); - auto I = net_builder.Add(E, F); - } - - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); +// hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "GeneralFusionMergePass"}); + +// int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); + +// VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); +// hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); +// VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + +// CHECK_EQ(graph->fusion_groups.size(), group_size_after); + +// RunTest(target, graph, {"A", "B", "C", "D", "E", "F"}); +// } + +// TEST(CheckFusionAccuracyPass, ElementWise_Fusion_5) { +// int h = 32, w = 32; +// NetBuilder net_builder("ElementWise_Fusion_5"); +// // create model +// { +// auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); +// auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); +// auto C = net_builder.Add(A, B); +// auto D = net_builder.Add(A, B); +// } - auto graph = std::make_shared(program, target); +// auto program = net_builder.Build(); +// auto target = common::DefaultTarget(); - hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "GeneralFusionMergePass"}); +// auto graph = std::make_shared(program, target); - int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); +// hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "FusionMergePass"}); - VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); - VLOG(1) << "After 
CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); +// int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); - CHECK_EQ(graph->fusion_groups.size(), group_size_after); +// VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); +// hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); +// VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - RunTest(target, graph, {"A", "B", "C", "D", "E"}); -} +// CHECK_EQ(graph->fusion_groups.size(), group_size_after); + +// RunTest(target, graph, {"A", "B"}); +// } + +// TEST(CheckFusionAccuracyPass, General_ElementWise_Fusion_5) { +// int h = 32, w = 32; +// NetBuilder net_builder("ElementWise_Fusion_5"); +// // create model +// { +// auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); +// auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); +// auto C = net_builder.Add(A, B); +// auto D = net_builder.Add(A, B); +// } -TEST(CheckFusionAccuracyPass, Broadcast_Test_5) { - int h = 32, w = 32; - NetBuilder net_builder("Broadcast_Test_5"); - // create model - { - auto A = net_builder.CreateInput(Float(32), {w}, "A"); - auto B = net_builder.CreateInput(Float(32), {w}, "B"); - auto C = net_builder.CreateInput(Float(32), {w}, "C"); - auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); - auto E = net_builder.CreateInput(Float(32), {h * w, w}, "E"); - auto F = net_builder.Add(A, B); - auto G = net_builder.Add(C, F); - auto H = net_builder.Add(D, F); - auto I = net_builder.Add(E, F); - } +// auto program = net_builder.Build(); +// auto target = common::DefaultTarget(); - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); +// auto graph = std::make_shared(program, target); - auto graph = std::make_shared(program, target); +// hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", 
"GeneralFusionMergePass"}); - hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "FusionMergePass"}); +// int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); - int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); +// VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); +// hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); +// VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); - VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); +// CHECK_EQ(graph->fusion_groups.size(), group_size_after); - CHECK_EQ(graph->fusion_groups.size(), group_size_after); +// RunTest(target, graph, {"A", "B"}); +// } - RunTest(target, graph, {"A", "B", "C", "D", "E"}); -} +// TEST(CheckFusionAccuracyPass, Broadcast_Test_0) { +// int h = 32, w = 32; +// NetBuilder net_builder("Broadcast_Test_0"); +// // create model +// { +// auto A = net_builder.CreateInput(Float(32), {w}, "A"); +// auto B = net_builder.CreateInput(Float(32), {w}, "B"); +// auto C = net_builder.CreateInput(Float(32), {h, w}, "C"); +// auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); +// auto E = net_builder.Add(A, B); +// auto F = net_builder.Add(C, D); +// auto G = net_builder.Add(F, E); +// } -TEST(CheckFusionAccuracyPass, General_Broadcast_Test_5) { - int h = 32, w = 32; - NetBuilder net_builder("Broadcast_Test_5"); - // create model - { - auto A = net_builder.CreateInput(Float(32), {w}, "A"); - auto B = net_builder.CreateInput(Float(32), {w}, "B"); - auto C = net_builder.CreateInput(Float(32), {w}, "C"); - auto D = net_builder.CreateInput(Float(32), {h, 
w}, "D"); - auto E = net_builder.CreateInput(Float(32), {h * w, w}, "E"); - auto F = net_builder.Add(A, B); - auto G = net_builder.Add(C, F); - auto H = net_builder.Add(D, F); - auto I = net_builder.Add(E, F); - } +// auto program = net_builder.Build(); +// auto target = common::DefaultTarget(); - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); +// auto graph = std::make_shared(program, target); - auto graph = std::make_shared(program, target); +// hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "FusionMergePass"}); - hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "GeneralFusionMergePass"}); +// int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); - int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); +// VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); +// hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); +// VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); - VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); +// CHECK_EQ(graph->fusion_groups.size(), group_size_after); - CHECK_EQ(graph->fusion_groups.size(), group_size_after); +// RunTest(target, graph, {"A", "B", "C", "D"}); +// } - RunTest(target, graph, {"A", "B", "C", "D", "E"}); -} +// TEST(CheckFusionAccuracyPass, General_Broadcast_Test_0) { +// int h = 32, w = 32; +// NetBuilder net_builder("Broadcast_Test_0"); +// // create model +// { +// auto A = net_builder.CreateInput(Float(32), {w}, "A"); +// auto B = net_builder.CreateInput(Float(32), {w}, "B"); +// auto C = net_builder.CreateInput(Float(32), 
{h, w}, "C"); +// auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); +// auto E = net_builder.Add(A, B); +// auto F = net_builder.Add(C, D); +// auto G = net_builder.Add(F, E); +// } + +// auto program = net_builder.Build(); +// auto target = common::DefaultTarget(); -TEST(CheckFusionAccuracyPass, Reduce_Test_0) { - int h = 32, w = 32; - NetBuilder net_builder("Reduce_Test_0"); - // create model - { - auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); - auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); - auto C = net_builder.Add(A, B); - auto D = net_builder.ReduceSum(C, {0}); - auto E = net_builder.ReduceSum(C, {0}); - auto F = net_builder.ReduceSum(C, {0}); - } +// auto graph = std::make_shared(program, target); - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); +// hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "GeneralFusionMergePass"}); - auto graph = std::make_shared(program, target); +// int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); - hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "FusionMergePass"}); +// VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); +// hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); +// VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); +// CHECK_EQ(graph->fusion_groups.size(), group_size_after); - VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); - VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); +// RunTest(target, graph, {"A", "B", "C", "D"}); +// } - CHECK_EQ(graph->fusion_groups.size(), 
group_size_after); +// TEST(CheckFusionAccuracyPass, Broadcast_Test_2) { +// int h = 32, w = 32; +// NetBuilder net_builder("Broadcast_Test_2"); +// // create model +// { +// auto A = net_builder.CreateInput(Float(32), {w}, "A"); +// auto B = net_builder.CreateInput(Float(32), {w}, "B"); +// auto C = net_builder.CreateInput(Float(32), {w}, "C"); +// auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); +// auto E = net_builder.Add(A, B); +// auto F = net_builder.Add(C, E); +// auto G = net_builder.Add(D, E); +// } + +// auto program = net_builder.Build(); +// auto target = common::DefaultTarget(); - RunTest(target, graph, {"A", "B"}); -} +// auto graph = std::make_shared(program, target); -TEST(CheckFusionAccuracyPass, General_Reduce_Test_0) { - int h = 32, w = 32; - NetBuilder net_builder("Reduce_Test_0"); - // create model - { - auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); - auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); - auto C = net_builder.Add(A, B); - auto D = net_builder.ReduceSum(C, {0}); - auto E = net_builder.ReduceSum(C, {0}); - auto F = net_builder.ReduceSum(C, {0}); - } +// hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "FusionMergePass"}); - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); +// int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); - auto graph = std::make_shared(program, target); +// VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); +// hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); +// VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "GeneralFusionMergePass"}); +// CHECK_EQ(graph->fusion_groups.size(), group_size_after); - int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); +// 
RunTest(target, graph, {"A", "B", "C", "D"}); +// } - VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); - VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); +// TEST(CheckFusionAccuracyPass, General_Broadcast_Test_2) { +// int h = 32, w = 32; +// NetBuilder net_builder("Broadcast_Test_2"); +// // create model +// { +// auto A = net_builder.CreateInput(Float(32), {w}, "A"); +// auto B = net_builder.CreateInput(Float(32), {w}, "B"); +// auto C = net_builder.CreateInput(Float(32), {w}, "C"); +// auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); +// auto E = net_builder.Add(A, B); +// auto F = net_builder.Add(C, E); +// auto G = net_builder.Add(D, E); +// } + +// auto program = net_builder.Build(); +// auto target = common::DefaultTarget(); + +// auto graph = std::make_shared(program, target); + +// hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "GeneralFusionMergePass"}); - CHECK_EQ(graph->fusion_groups.size(), group_size_after); +// int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); - RunTest(target, graph, {"A", "B"}); -} +// VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); +// hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); +// VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + +// CHECK_EQ(graph->fusion_groups.size(), group_size_after); + +// RunTest(target, graph, {"A", "B", "C", "D"}); +// } + +// TEST(CheckFusionAccuracyPass, Broadcast_Test_4) { +// int h = 32, w = 32; +// NetBuilder net_builder("Broadcast_Test_4"); +// // create model +// { +// auto A = net_builder.CreateInput(Float(32), {w}, "A"); +// auto B = net_builder.CreateInput(Float(32), {w}, "B"); +// auto C = 
net_builder.CreateInput(Float(32), {w}, "C"); +// auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); +// auto E = net_builder.CreateInput(Float(32), {h, w}, "E"); +// auto F = net_builder.Add(A, B); +// auto G = net_builder.Add(C, F); +// auto H = net_builder.Add(D, F); +// auto I = net_builder.Add(E, F); +// } + +// auto program = net_builder.Build(); +// auto target = common::DefaultTarget(); + +// auto graph = std::make_shared(program, target); + +// hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "FusionMergePass"}); + +// int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); + +// VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); +// hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); +// VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + +// CHECK_EQ(graph->fusion_groups.size(), group_size_after); + +// RunTest(target, graph, {"A", "B", "C", "D", "E"}); +// } + +// TEST(CheckFusionAccuracyPass, General_Broadcast_Test_4) { +// int h = 32, w = 32; +// NetBuilder net_builder("Broadcast_Test_4"); +// // create model +// { +// auto A = net_builder.CreateInput(Float(32), {w}, "A"); +// auto B = net_builder.CreateInput(Float(32), {w}, "B"); +// auto C = net_builder.CreateInput(Float(32), {w}, "C"); +// auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); +// auto E = net_builder.CreateInput(Float(32), {h, w}, "E"); +// auto F = net_builder.Add(A, B); +// auto G = net_builder.Add(C, F); +// auto H = net_builder.Add(D, F); +// auto I = net_builder.Add(E, F); +// } + +// auto program = net_builder.Build(); +// auto target = common::DefaultTarget(); + +// auto graph = std::make_shared(program, target); + +// hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "GeneralFusionMergePass"}); + +// int group_size_after = graph->fusion_groups.size() + 
CountAfterPassNodeSize(graph.get()); + +// VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); +// hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); +// VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + +// CHECK_EQ(graph->fusion_groups.size(), group_size_after); + +// RunTest(target, graph, {"A", "B", "C", "D", "E"}); +// } + +// TEST(CheckFusionAccuracyPass, Broadcast_Test_5) { +// int h = 32, w = 32; +// NetBuilder net_builder("Broadcast_Test_5"); +// // create model +// { +// auto A = net_builder.CreateInput(Float(32), {w}, "A"); +// auto B = net_builder.CreateInput(Float(32), {w}, "B"); +// auto C = net_builder.CreateInput(Float(32), {w}, "C"); +// auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); +// auto E = net_builder.CreateInput(Float(32), {h * w, w}, "E"); +// auto F = net_builder.Add(A, B); +// auto G = net_builder.Add(C, F); +// auto H = net_builder.Add(D, F); +// auto I = net_builder.Add(E, F); +// } + +// auto program = net_builder.Build(); +// auto target = common::DefaultTarget(); -TEST(CheckFusionAccuracyPass, Reduce_Test_1) { - int h = 32, w = 32; - NetBuilder net_builder("Reduce_Test_1"); - // create model - { - auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); - auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); - auto C = net_builder.Add(A, B); - auto D = net_builder.ReduceSum(C, {0}); - auto E = net_builder.ReduceSum(C, {1}); - } +// auto graph = std::make_shared(program, target); - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); +// hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "FusionMergePass"}); + +// int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); + +// VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); +// 
hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); +// VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + +// CHECK_EQ(graph->fusion_groups.size(), group_size_after); + +// RunTest(target, graph, {"A", "B", "C", "D", "E"}); +// } + +// TEST(CheckFusionAccuracyPass, General_Broadcast_Test_5) { +// int h = 32, w = 32; +// NetBuilder net_builder("Broadcast_Test_5"); +// // create model +// { +// auto A = net_builder.CreateInput(Float(32), {w}, "A"); +// auto B = net_builder.CreateInput(Float(32), {w}, "B"); +// auto C = net_builder.CreateInput(Float(32), {w}, "C"); +// auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); +// auto E = net_builder.CreateInput(Float(32), {h * w, w}, "E"); +// auto F = net_builder.Add(A, B); +// auto G = net_builder.Add(C, F); +// auto H = net_builder.Add(D, F); +// auto I = net_builder.Add(E, F); +// } - auto graph = std::make_shared(program, target); +// auto program = net_builder.Build(); +// auto target = common::DefaultTarget(); - hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "FusionMergePass"}); +// auto graph = std::make_shared(program, target); - int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); +// hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "GeneralFusionMergePass"}); - VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); - VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); +// int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); - CHECK_EQ(graph->fusion_groups.size(), group_size_after); +// VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); +// hlir::framework::ApplyPasses(graph.get(), 
{"CheckFusionAccuracyPass", "TransToCustomCallPass"}); +// VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - RunTest(target, graph, {"A", "B"}); -} +// CHECK_EQ(graph->fusion_groups.size(), group_size_after); + +// RunTest(target, graph, {"A", "B", "C", "D", "E"}); +// } + +// TEST(CheckFusionAccuracyPass, Reduce_Test_0) { +// int h = 32, w = 32; +// NetBuilder net_builder("Reduce_Test_0"); +// // create model +// { +// auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); +// auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); +// auto C = net_builder.Add(A, B); +// auto D = net_builder.ReduceSum(C, {0}); +// auto E = net_builder.ReduceSum(C, {0}); +// auto F = net_builder.ReduceSum(C, {0}); +// } -TEST(CheckFusionAccuracyPass, General_Reduce_Test_1) { - int h = 32, w = 32; - NetBuilder net_builder("Reduce_Test_1"); - // create model - { - auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); - auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); - auto C = net_builder.Add(A, B); - auto D = net_builder.ReduceSum(C, {0}); - auto E = net_builder.ReduceSum(C, {1}); - } +// auto program = net_builder.Build(); +// auto target = common::DefaultTarget(); - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); +// auto graph = std::make_shared(program, target); - auto graph = std::make_shared(program, target); +// hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "FusionMergePass"}); - hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "GeneralFusionMergePass"}); +// int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); - int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); +// VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); +// hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); +// VLOG(1) << "After 
CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); - VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); +// CHECK_EQ(graph->fusion_groups.size(), group_size_after); - CHECK_EQ(graph->fusion_groups.size(), group_size_after); +// RunTest(target, graph, {"A", "B"}); +// } + +// TEST(CheckFusionAccuracyPass, General_Reduce_Test_0) { +// int h = 32, w = 32; +// NetBuilder net_builder("Reduce_Test_0"); +// // create model +// { +// auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); +// auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); +// auto C = net_builder.Add(A, B); +// auto D = net_builder.ReduceSum(C, {0}); +// auto E = net_builder.ReduceSum(C, {0}); +// auto F = net_builder.ReduceSum(C, {0}); +// } - RunTest(target, graph, {"A", "B"}); -} +// auto program = net_builder.Build(); +// auto target = common::DefaultTarget(); -TEST(CheckFusionAccuracyPass, Reduce_Test_2) { - int h = 32, w = 32; - NetBuilder net_builder("Reduce_Test_2"); - // create model - { - auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); - auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); - auto C = net_builder.CreateInput(Float(32), {w}, "C"); - auto D = net_builder.Add(A, B); - auto E = net_builder.ReduceSum(D, {0}); - auto F = net_builder.ReduceSum(D, {1}); - auto G = net_builder.Add(C, E); - auto H = net_builder.Add(C, F); - } +// auto graph = std::make_shared(program, target); - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); +// hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "GeneralFusionMergePass"}); - auto graph = std::make_shared(program, target); +// int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); - 
hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "FusionMergePass"}); +// VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); +// hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); +// VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); +// CHECK_EQ(graph->fusion_groups.size(), group_size_after); - VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); - VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); +// RunTest(target, graph, {"A", "B"}); +// } - CHECK_EQ(graph->fusion_groups.size(), group_size_after); +// TEST(CheckFusionAccuracyPass, Reduce_Test_1) { +// int h = 32, w = 32; +// NetBuilder net_builder("Reduce_Test_1"); +// // create model +// { +// auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); +// auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); +// auto C = net_builder.Add(A, B); +// auto D = net_builder.ReduceSum(C, {0}); +// auto E = net_builder.ReduceSum(C, {1}); +// } - RunTest(target, graph, {"A", "B", "C"}); -} +// auto program = net_builder.Build(); +// auto target = common::DefaultTarget(); -TEST(CheckFusionAccuracyPass, General_Reduce_Test_2) { - int h = 32, w = 32; - NetBuilder net_builder("Reduce_Test_2"); - // create model - { - auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); - auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); - auto C = net_builder.CreateInput(Float(32), {w}, "C"); - auto D = net_builder.Add(A, B); - auto E = net_builder.ReduceSum(D, {0}); - auto F = net_builder.ReduceSum(D, {1}); - auto G = net_builder.Add(C, E); - auto H = net_builder.Add(C, F); - } +// auto graph = 
std::make_shared(program, target); - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); +// hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "FusionMergePass"}); - auto graph = std::make_shared(program, target); +// int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); - hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "GeneralFusionMergePass"}); +// VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); +// hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); +// VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); +// CHECK_EQ(graph->fusion_groups.size(), group_size_after); - VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); - VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); +// RunTest(target, graph, {"A", "B"}); +// } - CHECK_EQ(graph->fusion_groups.size(), group_size_after); +// TEST(CheckFusionAccuracyPass, General_Reduce_Test_1) { +// int h = 32, w = 32; +// NetBuilder net_builder("Reduce_Test_1"); +// // create model +// { +// auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); +// auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); +// auto C = net_builder.Add(A, B); +// auto D = net_builder.ReduceSum(C, {0}); +// auto E = net_builder.ReduceSum(C, {1}); +// } + +// auto program = net_builder.Build(); +// auto target = common::DefaultTarget(); - RunTest(target, graph, {"A", "B", "C"}); -} +// auto graph = std::make_shared(program, target); -TEST(CheckFusionAccuracyPass, Reduce_Test_3) { - int h = 32, w = 32; - NetBuilder net_builder("Reduce_Test_3"); - // 
create model - { - auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); - auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); - auto C = net_builder.CreateInput(Float(32), {w}, "C"); - auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); - auto E = net_builder.Add(A, B); - auto F = net_builder.ReduceSum(E, {0}); - auto G = net_builder.Add(C, F); - auto H = net_builder.Add(D, F); - } +// hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "GeneralFusionMergePass"}); - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); +// int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); - auto graph = std::make_shared(program, target); +// VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); +// hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); +// VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "FusionMergePass"}); +// CHECK_EQ(graph->fusion_groups.size(), group_size_after); - int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); +// RunTest(target, graph, {"A", "B"}); +// } - VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); - VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); +// TEST(CheckFusionAccuracyPass, Reduce_Test_2) { +// int h = 32, w = 32; +// NetBuilder net_builder("Reduce_Test_2"); +// // create model +// { +// auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); +// auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); +// auto C = net_builder.CreateInput(Float(32), {w}, "C"); +// auto D = net_builder.Add(A, B); +// auto E = 
net_builder.ReduceSum(D, {0}); +// auto F = net_builder.ReduceSum(D, {1}); +// auto G = net_builder.Add(C, E); +// auto H = net_builder.Add(C, F); +// } + +// auto program = net_builder.Build(); +// auto target = common::DefaultTarget(); - CHECK_EQ(graph->fusion_groups.size(), group_size_after); +// auto graph = std::make_shared(program, target); - RunTest(target, graph, {"A", "B", "C", "D"}); -} +// hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "FusionMergePass"}); -TEST(CheckFusionAccuracyPass, General_Reduce_Test_3) { - int h = 32, w = 32; - NetBuilder net_builder("Reduce_Test_3"); - // create model - { - auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); - auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); - auto C = net_builder.CreateInput(Float(32), {w}, "C"); - auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); - auto E = net_builder.Add(A, B); - auto F = net_builder.ReduceSum(E, {0}); - auto G = net_builder.Add(C, F); - auto H = net_builder.Add(D, F); - } +// int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); +// VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); +// hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); +// VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - auto graph = std::make_shared(program, target); +// CHECK_EQ(graph->fusion_groups.size(), group_size_after); - hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "GeneralFusionMergePass"}); +// RunTest(target, graph, {"A", "B", "C"}); +// } - int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); +// TEST(CheckFusionAccuracyPass, General_Reduce_Test_2) { +// int h = 32, w = 32; +// NetBuilder net_builder("Reduce_Test_2"); +// // create model +// { +// auto A = 
net_builder.CreateInput(Float(32), {h, w}, "A"); +// auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); +// auto C = net_builder.CreateInput(Float(32), {w}, "C"); +// auto D = net_builder.Add(A, B); +// auto E = net_builder.ReduceSum(D, {0}); +// auto F = net_builder.ReduceSum(D, {1}); +// auto G = net_builder.Add(C, E); +// auto H = net_builder.Add(C, F); +// } + +// auto program = net_builder.Build(); +// auto target = common::DefaultTarget(); - VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); - VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); +// auto graph = std::make_shared(program, target); - CHECK_EQ(graph->fusion_groups.size(), group_size_after); +// hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "GeneralFusionMergePass"}); - RunTest(target, graph, {"A", "B", "C", "D"}); -} +// int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); -TEST(CheckFusionAccuracyPass, Reduce_Test_4) { - int h = 32, w = 32; - NetBuilder net_builder("Reduce_Test_4"); - // create model - { - auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); - auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); - auto C = net_builder.CreateInput(Float(32), {w}, "C"); - auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); - auto E = net_builder.Add(A, B); - auto F = net_builder.ReduceSum(E, {0}); - auto G = net_builder.Add(C, F); - auto H = net_builder.Add(D, F); - auto I = net_builder.Add(D, F); - } +// VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); +// hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); +// VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - auto program = net_builder.Build(); - 
auto target = common::DefaultTarget(); +// CHECK_EQ(graph->fusion_groups.size(), group_size_after); - auto graph = std::make_shared(program, target); +// RunTest(target, graph, {"A", "B", "C"}); +// } - hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "FusionMergePass"}); +// TEST(CheckFusionAccuracyPass, Reduce_Test_3) { +// int h = 32, w = 32; +// NetBuilder net_builder("Reduce_Test_3"); +// // create model +// { +// auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); +// auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); +// auto C = net_builder.CreateInput(Float(32), {w}, "C"); +// auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); +// auto E = net_builder.Add(A, B); +// auto F = net_builder.ReduceSum(E, {0}); +// auto G = net_builder.Add(C, F); +// auto H = net_builder.Add(D, F); +// } + +// auto program = net_builder.Build(); +// auto target = common::DefaultTarget(); + +// auto graph = std::make_shared(program, target); + +// hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "FusionMergePass"}); + +// int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); + +// VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); +// hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); +// VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + +// CHECK_EQ(graph->fusion_groups.size(), group_size_after); + +// RunTest(target, graph, {"A", "B", "C", "D"}); +// } + +// TEST(CheckFusionAccuracyPass, General_Reduce_Test_3) { +// int h = 32, w = 32; +// NetBuilder net_builder("Reduce_Test_3"); +// // create model +// { +// auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); +// auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); +// auto C = net_builder.CreateInput(Float(32), {w}, "C"); +// auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); +// auto E = 
net_builder.Add(A, B); +// auto F = net_builder.ReduceSum(E, {0}); +// auto G = net_builder.Add(C, F); +// auto H = net_builder.Add(D, F); +// } + +// auto program = net_builder.Build(); +// auto target = common::DefaultTarget(); + +// auto graph = std::make_shared(program, target); + +// hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "GeneralFusionMergePass"}); + +// int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); + +// VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); +// hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); +// VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + +// CHECK_EQ(graph->fusion_groups.size(), group_size_after); + +// RunTest(target, graph, {"A", "B", "C", "D"}); +// } + +// TEST(CheckFusionAccuracyPass, Reduce_Test_4) { +// int h = 32, w = 32; +// NetBuilder net_builder("Reduce_Test_4"); +// // create model +// { +// auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); +// auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); +// auto C = net_builder.CreateInput(Float(32), {w}, "C"); +// auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); +// auto E = net_builder.Add(A, B); +// auto F = net_builder.ReduceSum(E, {0}); +// auto G = net_builder.Add(C, F); +// auto H = net_builder.Add(D, F); +// auto I = net_builder.Add(D, F); +// } - int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); +// auto program = net_builder.Build(); +// auto target = common::DefaultTarget(); - VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); - VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); +// auto graph = 
std::make_shared(program, target); - CHECK_EQ(graph->fusion_groups.size(), group_size_after); +// hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "FusionMergePass"}); - RunTest(target, graph, {"A", "B", "C", "D"}); -} +// int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); -TEST(CheckFusionAccuracyPass, General_Reduce_Test_4) { - int h = 32, w = 32; - NetBuilder net_builder("Reduce_Test_4"); - // create model - { - auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); - auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); - auto C = net_builder.CreateInput(Float(32), {w}, "C"); - auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); - auto E = net_builder.Add(A, B); - auto F = net_builder.ReduceSum(E, {0}); - auto G = net_builder.Add(C, F); - auto H = net_builder.Add(D, F); - auto I = net_builder.Add(D, F); - } +// VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); +// hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); +// VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); +// CHECK_EQ(graph->fusion_groups.size(), group_size_after); - auto graph = std::make_shared(program, target); +// RunTest(target, graph, {"A", "B", "C", "D"}); +// } + +// TEST(CheckFusionAccuracyPass, General_Reduce_Test_4) { +// int h = 32, w = 32; +// NetBuilder net_builder("Reduce_Test_4"); +// // create model +// { +// auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); +// auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); +// auto C = net_builder.CreateInput(Float(32), {w}, "C"); +// auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); +// auto E = net_builder.Add(A, B); +// auto F = net_builder.ReduceSum(E, {0}); +// auto G = net_builder.Add(C, F); +// auto H = net_builder.Add(D, F); +// auto I 
= net_builder.Add(D, F); +// } - hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "GeneralFusionMergePass"}); +// auto program = net_builder.Build(); +// auto target = common::DefaultTarget(); - int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); +// auto graph = std::make_shared(program, target); - VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); - VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); +// hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "GeneralFusionMergePass"}); - CHECK_EQ(graph->fusion_groups.size(), group_size_after); +// int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); - RunTest(target, graph, {"A", "B", "C", "D"}); -} +// VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); +// hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); +// VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); -TEST(CheckFusionAccuracyPass, Reduce_Test_5) { - int h = 128, w = 128; - NetBuilder net_builder("Reduce_Test_5"); - // create model - { - auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); - auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); - auto C = net_builder.Add(A, B); - auto D = net_builder.ReduceSum(A, {1}); - auto E = net_builder.ReduceSum(B, {1}); - auto F = net_builder.ReduceSum(C, {1}); - } +// CHECK_EQ(graph->fusion_groups.size(), group_size_after); + +// RunTest(target, graph, {"A", "B", "C", "D"}); +// } + +// TEST(CheckFusionAccuracyPass, Reduce_Test_5) { +// int h = 128, w = 128; +// NetBuilder net_builder("Reduce_Test_5"); +// // create model +// { +// auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); +// 
auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); +// auto C = net_builder.Add(A, B); +// auto D = net_builder.ReduceSum(A, {1}); +// auto E = net_builder.ReduceSum(B, {1}); +// auto F = net_builder.ReduceSum(C, {1}); +// } - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); +// auto program = net_builder.Build(); +// auto target = common::DefaultTarget(); - auto graph = std::make_shared(program, target); +// auto graph = std::make_shared(program, target); - hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "FusionMergePass"}); +// hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "FusionMergePass"}); - int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); +// int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); - VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); - VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); +// VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); +// hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); +// VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - CHECK_EQ(graph->fusion_groups.size(), group_size_after); +// CHECK_EQ(graph->fusion_groups.size(), group_size_after); - RunTest(target, graph, {"A", "B"}); -} +// RunTest(target, graph, {"A", "B"}); +// } -TEST(CheckFusionAccuracyPass, General_Reduce_Test_5) { - int h = 128, w = 128; - NetBuilder net_builder("Reduce_Test_5"); - // create model - { - auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); - auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); - auto C = net_builder.Add(A, B); - auto D = net_builder.ReduceSum(A, {1}); - 
auto E = net_builder.ReduceSum(B, {1}); - auto F = net_builder.ReduceSum(C, {1}); - } +// TEST(CheckFusionAccuracyPass, General_Reduce_Test_5) { +// int h = 128, w = 128; +// NetBuilder net_builder("Reduce_Test_5"); +// // create model +// { +// auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); +// auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); +// auto C = net_builder.Add(A, B); +// auto D = net_builder.ReduceSum(A, {1}); +// auto E = net_builder.ReduceSum(B, {1}); +// auto F = net_builder.ReduceSum(C, {1}); +// } - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); +// auto program = net_builder.Build(); +// auto target = common::DefaultTarget(); - auto graph = std::make_shared(program, target); +// auto graph = std::make_shared(program, target); - hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "GeneralFusionMergePass"}); +// hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "GeneralFusionMergePass"}); - int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); +// int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); - VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); - VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); +// VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); +// hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); +// VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - CHECK_EQ(graph->fusion_groups.size(), group_size_after); +// CHECK_EQ(graph->fusion_groups.size(), group_size_after); - RunTest(target, graph, {"A", "B"}); -} +// RunTest(target, graph, {"A", "B"}); +// } } // namespace 
cinn::frontend diff --git a/cinn/hlir/pass/fusion_merge_pass.cc b/cinn/hlir/pass/fusion_merge_pass.cc index 5a509375d2..ffa1eeaf30 100644 --- a/cinn/hlir/pass/fusion_merge_pass.cc +++ b/cinn/hlir/pass/fusion_merge_pass.cc @@ -135,7 +135,7 @@ class FusionMergePassHelper : public FusionHelperBase { void UpdateFusionGroup() { VLOG(3) << "UpdateFusionGroup..."; GroupList fusion_groups; - std::unordered_set fusion_groups_set; + std::unordered_set fusion_groups_set; // update fusion_groups_ for (auto& group : fusion_groups_) { if (!group->belong_groups.size()) { @@ -179,13 +179,13 @@ class FusionMergePassHelper : public FusionHelperBase { } } - bool HorizontalFusion(GroupPtr producer, const std::unordered_set& consumers) { + bool HorizontalFusion(GroupPtr producer, const std::unordered_set& consumers) { VLOG(3) << "HorizontalFusion...!"; if (consumers.size() <= 1) { return false; } - std::unordered_set candidates; + std::unordered_set candidates; for (const auto& consumer : consumers) { // relation auto& relation = fusion_relation_map_[consumer->op_pattern_kind]; @@ -255,7 +255,7 @@ class FusionMergePassHelper : public FusionHelperBase { auto fused_group = std::make_shared(); // As recompute exist which may case sub-group used by more than one time. std::vector repeat_sub_groups; - std::unordered_set sub_group_set; + std::unordered_set sub_group_set; // find the first consumer. GroupPtr first_consumer(nullptr); // fuse all group into fusion group. 
@@ -394,7 +394,7 @@ class FusionMergePassHelper : public FusionHelperBase { CHECK(fused_group->output_nodes.size()) << "No output node is found, " << fused_group->group_id; } - bool VerticalFusion(GroupPtr& producer, const std::unordered_set& consumers, bool recompute) { + bool VerticalFusion(GroupPtr& producer, const std::unordered_set& consumers, bool recompute) { VLOG(3) << "VerticalFusion, Number of Consumers : " << consumers.size(); auto& relation = fusion_relation_map_[producer->op_pattern_kind]; // if producer can't fuse others @@ -402,8 +402,8 @@ class FusionMergePassHelper : public FusionHelperBase { return false; } - std::unordered_set fuse_consumers_unsafe; - std::unordered_set fuse_consumers; + std::unordered_set fuse_consumers_unsafe; + std::unordered_set fuse_consumers; for (const auto& consumer : consumers) { VLOG(4) << "Check consuemr " << consumer->group_id << " can fuse to producer " << producer->group_id; // if can't fuse @@ -466,7 +466,7 @@ class FusionMergePassHelper : public FusionHelperBase { return false; } - void VerticalFuse(GroupPtr& producer, std::unordered_set& fusionable_consumers) { + void VerticalFuse(GroupPtr& producer, std::unordered_set& fusionable_consumers) { VLOG(3) << "VerticalFuse...!"; GroupList fused_groups; GroupPtr master_fuesd_group(nullptr); @@ -653,16 +653,16 @@ class FusionMergePassHelper : public FusionHelperBase { } } - void RecomputeEleGraph(const GroupPtr& producer, std::unordered_set& fusionable_consumers) { + void RecomputeEleGraph(const GroupPtr& producer, std::unordered_set& fusionable_consumers) { if (producer->op_pattern_kind != framework::kElementWise) { SelectConsumerToFuse(producer, fusionable_consumers); } } - void SelectConsumerToFuse(const GroupPtr& producer, std::unordered_set& fusionable_consumers) { + void SelectConsumerToFuse(const GroupPtr& producer, std::unordered_set& fusionable_consumers) { // if is const op if (is_const_group(this, producer)) { - std::unordered_set candidates; + 
std::unordered_set candidates; for (auto& consumer : fusionable_consumers) { // if can be output node. if (is_same_shape(this, producer, consumer)) { @@ -732,7 +732,7 @@ class FusionMergePassHelper : public FusionHelperBase { fusionable_consumers.insert(*candidates.begin()); } } else { - std::unordered_set candidates; + std::unordered_set candidates; for (auto& consumer : fusionable_consumers) { if (consumer->op_pattern_kind == framework::kElementWise) { candidates.insert(consumer); @@ -757,11 +757,11 @@ class FusionMergePassHelper : public FusionHelperBase { bool IsDependency(const GroupPtr& producer_g, const GroupPtr& consumer, - const std::unordered_set& consumers) { + const std::unordered_set& consumers) { std::queue candidates; candidates.push(consumer); - std::unordered_set visited_set; + std::unordered_set visited_set; while (!candidates.empty()) { auto& candidate = candidates.front(); candidates.pop(); @@ -784,12 +784,12 @@ class FusionMergePassHelper : public FusionHelperBase { bool IsDependencySimplify(const GroupPtr& producer_g, const GroupPtr& consumer, - const std::unordered_set& consumers) { + const std::unordered_set& consumers) { std::queue candidates; candidates.push(consumer); // check upper. int check_upper_depth = producer_g.get() ? 
producer_g->max_depth : INT_MAX; - std::unordered_set visited_set; + std::unordered_set visited_set; while (!candidates.empty()) { auto& candidate = candidates.front(); candidates.pop(); @@ -838,7 +838,7 @@ class FusionMergePassHelper : public FusionHelperBase { void UpdateInputToConsumers() { for (auto& input_consumers : input_to_consumers_) { auto& consumers = input_consumers.second; - std::unordered_set updated_consumers; + std::unordered_set updated_consumers; for (auto& consumer : consumers) { std::queue fused_groups; fused_groups.push(consumer); @@ -1022,7 +1022,7 @@ class FusionMergePassHelper : public FusionHelperBase { GroupList fusion_groups_; std::unordered_map fusion_groups_index_; - std::unordered_map> input_to_consumers_; + std::unordered_map> input_to_consumers_; struct Relation { std::unordered_map vertical_relation; diff --git a/cinn/hlir/pass/general_fusion_merge_pass.cc b/cinn/hlir/pass/general_fusion_merge_pass.cc index 440e96e98b..0633966a3f 100644 --- a/cinn/hlir/pass/general_fusion_merge_pass.cc +++ b/cinn/hlir/pass/general_fusion_merge_pass.cc @@ -42,7 +42,7 @@ using GroupList = std::vector; using Comparator = Graph::Group::SharedGroupComparator; using Hasher = Graph::Group::SharedGroupHasher; -using OpGroupPtr = std::shared_ptr; +using OpGroupPtr = api::OpGroup; // using OpGroupPtr = api::OpGroup; using OpGroupList = std::vector; @@ -110,13 +110,13 @@ class GraphGroupFuseHelper final : public FuseHelper { private: bool ReachableIfDirectEdgeIgnored(const OpGroupPtr& producer, const OpGroupPtr& consumer) const { const auto& MinDepth4Node = [&](OpGroupPtr node) { - return node->GetGroup()->min_depth; + return node.GetGroup()->min_depth; }; const auto& MaxDepth4Node = [&](OpGroupPtr node) { - return node->GetGroup()->max_depth; + return node.GetGroup()->max_depth; }; const auto& VisitNextNodes = [&](OpGroupPtr node, const std::function& Visit) { - for(auto iter = node->ProducerEnd(); iter != node->ProducerEnd(); ++iter) { + for(auto iter = 
node.ProducerEnd(); iter != node.ProducerEnd(); ++iter) { if (node == consumer && *iter == producer) { continue; } @@ -224,73 +224,73 @@ class GraphGroupInputFusePassCtx final : public InputFusePassCtx { template bool GraphGroupFuseHelper::AllOutputsSameSize(const OpGroupPtr& first, const OpGroupPtr& second) const { return is_same_size(&ctx_->graph_group_fusion_helper(), - first->GetGroup(), - second->GetGroup()); + first.GetGroup(), + second.GetGroup()); } template bool GraphGroupFuseHelper::HorizontalElementwiseFuseReduce(const OpGroupPtr& src, const OpGroupPtr& dst) const { return honrizontal_elementwise_fuse_reduce(&ctx_->graph_group_fusion_helper(), - src->GetGroup(), - dst->GetGroup()); + src.GetGroup(), + dst.GetGroup()); } template bool GraphGroupFuseHelper::ElementwiseFuseBroadcast(const OpGroupPtr& src, const OpGroupPtr& dst) const { return elementwise_fuse_broadcast(&ctx_->graph_group_fusion_helper(), - src->GetGroup(), - dst->GetGroup()); + src.GetGroup(), + dst.GetGroup()); } template bool GraphGroupFuseHelper::HorizontalWithInjective(const OpGroupPtr& src, const OpGroupPtr& dst) const { return horizontal_with_injective(&ctx_->graph_group_fusion_helper(), - src->GetGroup(), - dst->GetGroup()); + src.GetGroup(), + dst.GetGroup()); } template bool GraphGroupFuseHelper::ElementwiseFuseReduce(const OpGroupPtr& src, const OpGroupPtr& dst) const { return elementwise_fuse_reduce(&ctx_->graph_group_fusion_helper(), - src->GetGroup(), - dst->GetGroup()); + src.GetGroup(), + dst.GetGroup()); } template bool GraphGroupFuseHelper::BroadcastFuseReduce(const OpGroupPtr& src, const OpGroupPtr& dst) const { return broadcast_fuse_reduce(&ctx_->graph_group_fusion_helper(), - src->GetGroup(), - dst->GetGroup()); + src.GetGroup(), + dst.GetGroup()); } template bool GraphGroupFuseHelper::InjectiveHorizontalWithReduce(const OpGroupPtr& src, const OpGroupPtr& dst) const { return injective_horizontal_with_reduce(&ctx_->graph_group_fusion_helper(), - src->GetGroup(), - 
dst->GetGroup()); + src.GetGroup(), + dst.GetGroup()); } template bool GraphGroupFuseHelper::ReduceFuseElementwise(const OpGroupPtr& src, const OpGroupPtr& dst) const { return reduce_fuse_elementwise(&ctx_->graph_group_fusion_helper(), - src->GetGroup(), - dst->GetGroup()); + src.GetGroup(), + dst.GetGroup()); } template bool GraphGroupFuseHelper::ReduceFuseBroadcast(const OpGroupPtr& src, const OpGroupPtr& dst) const { return reduce_fuse_broadcast(&ctx_->graph_group_fusion_helper(), - src->GetGroup(), - dst->GetGroup()); + src.GetGroup(), + dst.GetGroup()); } template bool GraphGroupFuseHelper::ReduceFuseReduce(const OpGroupPtr& src, const OpGroupPtr& dst) const { return reduce_fuse_reduce(&ctx_->graph_group_fusion_helper(), - src->GetGroup(), - dst->GetGroup()); + src.GetGroup(), + dst.GetGroup()); } template @@ -298,7 +298,7 @@ struct HorizontalFuseUtil { using KindKeyT = std::pair; static bool DetectFusabilityByKind(FusePassCtxT* ctx, const OpGroupPtr& src, const OpGroupPtr& dst) { - const KindKeyT kind_pair(src->kind(), dst->kind()); + const KindKeyT kind_pair(src.kind(), dst.kind()); const auto& map = GetConditionMap(); const auto& iter = map.find(kind_pair); if (iter == map.end()) { @@ -446,7 +446,7 @@ class DefaultHorizontalFusePass final : public HorizontalFusePass { const auto& producer = ctx->PickOpGroup(); const OpGroupList consumers = [&]() { OpGroupList consumers; - for(auto iter = producer->ConsumerBegin(); iter!= producer->ConsumerEnd(); ++iter) { + for(auto iter = producer.ConsumerBegin(); iter!= producer.ConsumerEnd(); ++iter) { consumers.push_back(*iter); } return consumers; @@ -496,7 +496,7 @@ class DefaultVerticalFusePass final : public VerticalFusePass { const auto& producer = ctx->PickOpGroup(); const OpGroupList consumers = [&]() { OpGroupList consumers; - for(auto iter = producer->ConsumerBegin(); iter!= producer->ConsumerEnd(); ++iter) { + for(auto iter = producer.ConsumerBegin(); iter!= producer.ConsumerEnd(); ++iter) { 
consumers.push_back(*iter); } return consumers; @@ -520,7 +520,7 @@ class DefaultVerticalFusePass final : public VerticalFusePass { using KindKeyT = std::pair; bool DetectFusabilityByKind(LightwareFusePassCtx* ctx, const OpGroupPtr& src, const OpGroupPtr& dst) const { - const KindKeyT kind_pair(src->kind(), dst->kind()); + const KindKeyT kind_pair(src.kind(), dst.kind()); const auto& map = GetConditionMap(); const auto& iter = map.find(kind_pair); if (iter == map.end()) { @@ -622,7 +622,7 @@ class DefaultRecomputeFusePass final : public RecomputeFusePass { const auto& producer = ctx->PickOpGroup(); const OpGroupList consumers = [&]() { OpGroupList consumers; - for(auto iter = producer->ConsumerBegin(); iter!= producer->ConsumerEnd(); ++iter) { + for(auto iter = producer.ConsumerBegin(); iter!= producer.ConsumerEnd(); ++iter) { consumers.push_back(*iter); } return consumers; @@ -638,7 +638,7 @@ class DefaultRecomputeFusePass final : public RecomputeFusePass { } candidates.push_back(consumer); } - if (candidates.size() == consumers.size() && producer->kind() == framework::kElementWise) { + if (candidates.size() == consumers.size() && producer.kind() == framework::kElementWise) { for (const auto& consumer : consumers) { ctx->EnableFuse(producer, consumer); } @@ -647,7 +647,7 @@ class DefaultRecomputeFusePass final : public RecomputeFusePass { using KindKeyT = std::pair; bool DetectFusabilityByKind(LightwareFusePassCtx* ctx, const OpGroupPtr& src, const OpGroupPtr& dst) const { - const KindKeyT kind_pair(src->kind(), dst->kind()); + const KindKeyT kind_pair(src.kind(), dst.kind()); const auto& map = DefaultVerticalFusePass::GetConditionMap(); const auto& iter = map.find(kind_pair); if (iter == map.end()) { @@ -886,7 +886,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { void UpdateFusionGroup() { VLOG(3) << "UpdateFusionGroup..."; GroupList fusion_groups; - std::unordered_set fusion_groups_set; + std::unordered_set fusion_groups_set; // update 
fusion_groups_ for (auto& group : fusion_groups_) { if (!group->belong_groups.size()) { @@ -941,7 +941,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { void EnableFusedHorizontalGroups(LightwareFusePassCtx* ctx) const { const auto& producer = ctx->PickOpGroup(); - if (producer->ConsumerSize() <= 1) { + if (producer.ConsumerSize() <= 1) { return; } const auto& fuse_passes = GetHorizontalFusePasses(); @@ -958,7 +958,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { const auto& EnableFuse = [&](const OpGroupPtr& first, const OpGroupPtr& second) { tagged_sets.insert(std::make_pair(first, second)); }; - GraphGroupLightwareFusePassCtx fuse_ctx(this, std::make_shared(this, producer), EnableFuse); + GraphGroupLightwareFusePassCtx fuse_ctx(this, api::OpGroup(this, producer), EnableFuse); EnableFusedHorizontalGroups(&fuse_ctx); return tagged_sets; }; @@ -968,7 +968,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { return GroupList{}; } const auto& group_pair = *group_sets.begin(); - GroupList ret{group_pair.first->GetGroup(), group_pair.second->GetGroup()}; + GroupList ret{group_pair.first.GetGroup(), group_pair.second.GetGroup()}; return ret; }; bool update = false; @@ -999,7 +999,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { } } - bool CallGeneralInputFusePass(const std::unordered_set& consumers) { + bool CallGeneralInputFusePass(const std::unordered_set& consumers) { VLOG(3) << "CallGeneralInputFusePass...!"; using OpGroupSets = std::set>; const auto& GetFusableConsumerGroupSets = [&]() -> OpGroupSets { @@ -1010,7 +1010,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { OpGroupList consumer_groups; consumer_groups.reserve(consumers.size()); for(auto& consumer : consumers) { - consumer_groups.push_back(std::make_shared(this, consumer)); + consumer_groups.push_back(api::OpGroup(this, consumer)); } GraphGroupInputFusePassCtx fuse_ctx(this, consumer_groups, EnableFuse); 
EnableFusedInputGroups(&fuse_ctx); @@ -1022,7 +1022,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { return GroupList{}; } const auto& group_pair = *group_sets.begin(); - GroupList ret{group_pair.first->GetGroup(), group_pair.second->GetGroup()}; + GroupList ret{group_pair.first.GetGroup(), group_pair.second.GetGroup()}; return ret; }; const auto& groups = GetFusableConsumerGroupList(); @@ -1033,13 +1033,13 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { return true; } - bool HorizontalFusion(GroupPtr producer, const std::unordered_set& consumers) { + bool HorizontalFusion(GroupPtr producer, const std::unordered_set& consumers) { VLOG(3) << "HorizontalFusion...!"; if (consumers.size() <= 1) { return false; } - std::unordered_set candidates; + std::unordered_set candidates; for (const auto& consumer : consumers) { // relation auto& relation = fusion_relation_map_[consumer->op_pattern_kind]; @@ -1109,7 +1109,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { auto fused_group = std::make_shared(); // As recompute exist which may case sub-group used by more than one time. std::vector repeat_sub_groups; - std::unordered_set sub_group_set; + std::unordered_set sub_group_set; // find the first consumer. GroupPtr first_consumer(nullptr); // fuse all group into fusion group. 
@@ -1248,7 +1248,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { CHECK(fused_group->output_nodes.size()) << "No output node is found, " << fused_group->group_id; } - bool VerticalFusion(GroupPtr& producer, const std::unordered_set& consumers, bool recompute) { + bool VerticalFusion(GroupPtr& producer, const std::unordered_set& consumers, bool recompute) { VLOG(3) << "VerticalFusion, Number of Consumers : " << consumers.size(); auto& relation = fusion_relation_map_[producer->op_pattern_kind]; // if producer can't fuse others @@ -1256,8 +1256,8 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { return false; } - std::unordered_set fuse_consumers_unsafe; - std::unordered_set fuse_consumers; + std::unordered_set fuse_consumers_unsafe; + std::unordered_set fuse_consumers; for (const auto& consumer : consumers) { VLOG(4) << "Check consuemr " << consumer->group_id << " can fuse to producer " << producer->group_id; // if can't fuse @@ -1331,7 +1331,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { void TagVerticalGroups(LightwareFusePassCtx* ctx) const { const auto& producer = ctx->PickOpGroup(); - if (producer->ConsumerSize() == 0) { + if (producer.ConsumerSize() == 0) { return; } const auto& fuse_passes = GetVerticalFusePasses(); @@ -1348,19 +1348,19 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { const auto& EnableFuse = [&](const OpGroupPtr& first, const OpGroupPtr& second) { tagged_sets.push_back(std::make_pair(first, second)); }; - GraphGroupLightwareFusePassCtx fuse_ctx(this, std::make_shared(this, producer), EnableFuse); + GraphGroupLightwareFusePassCtx fuse_ctx(this, api::OpGroup(this, producer), EnableFuse); TagVerticalGroups(&fuse_ctx); return tagged_sets; }; - auto GetFusableConsumerGroupSet = [&]() -> std::unordered_set { + auto GetFusableConsumerGroupSet = [&]() -> std::unordered_set { const auto& group_sets = GetFusableConsumerOpGroupSets(); if (group_sets.empty()) { return {}; } - 
std::unordered_set ret; + std::unordered_set ret; for (const auto& group_pair : group_sets) { - ret.insert(group_pair.second->GetGroup()); + ret.insert(group_pair.second.GetGroup()); } return ret; }; @@ -1377,7 +1377,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { return update; } - void VerticalFuse(GroupPtr& producer, std::unordered_set& fusionable_consumers) { + void VerticalFuse(GroupPtr& producer, std::unordered_set& fusionable_consumers) { VLOG(3) << "VerticalFuse...!"; GroupList fused_groups; GroupPtr master_fuesd_group(nullptr); @@ -1578,7 +1578,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { void TagRecomputeGroups(LightwareFusePassCtx* ctx) const { const auto& producer = ctx->PickOpGroup(); - if (producer->ConsumerSize() <= 1) { + if (producer.ConsumerSize() <= 1) { return; } const auto& fuse_passes = GetRecomputeFusePasses(); @@ -1595,19 +1595,19 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { const auto& EnableFuse = [&](const OpGroupPtr& first, const OpGroupPtr& second) { tagged_sets.insert(std::make_pair(first, second)); }; - GraphGroupLightwareFusePassCtx fuse_ctx(this, std::make_shared(this, producer), EnableFuse); + GraphGroupLightwareFusePassCtx fuse_ctx(this, api::OpGroup(this, producer), EnableFuse); TagRecomputeGroups(&fuse_ctx); return tagged_sets; }; - auto GetFusableConsumerGroupSet = [&]() -> std::unordered_set { + auto GetFusableConsumerGroupSet = [&]() -> std::unordered_set { const auto& group_sets = GetFusableConsumerOpGroupSets(); if (group_sets.empty()) { return {}; } - std::unordered_set ret; + std::unordered_set ret; for (const auto& group_pair : group_sets) { - ret.insert(group_pair.second->GetGroup()); + ret.insert(group_pair.second.GetGroup()); } return ret; }; @@ -1621,20 +1621,20 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { return update; } - void RecomputeFuse(GroupPtr& producer, std::unordered_set& fusionable_consumers) { + void 
RecomputeFuse(GroupPtr& producer, std::unordered_set& fusionable_consumers) { VerticalFuse(producer, fusionable_consumers); } - void RecomputeEleGraph(const GroupPtr& producer, std::unordered_set& fusionable_consumers) { + void RecomputeEleGraph(const GroupPtr& producer, std::unordered_set& fusionable_consumers) { if (producer->op_pattern_kind != framework::kElementWise) { SelectConsumerToFuse(producer, fusionable_consumers); } } - void SelectConsumerToFuse(const GroupPtr& producer, std::unordered_set& fusionable_consumers) { + void SelectConsumerToFuse(const GroupPtr& producer, std::unordered_set& fusionable_consumers) { // if is const op if (is_const_group(this, producer)) { - std::unordered_set candidates; + std::unordered_set candidates; for (auto& consumer : fusionable_consumers) { // if can be output node. if (is_same_shape(this, producer, consumer)) { @@ -1729,11 +1729,11 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { bool IsDependency(const GroupPtr& producer_g, const GroupPtr& consumer, - const std::unordered_set& consumers) { + const std::unordered_set& consumers) { std::queue candidates; candidates.push(consumer); - std::unordered_set visited_set; + std::unordered_set visited_set; while (!candidates.empty()) { auto& candidate = candidates.front(); candidates.pop(); @@ -1756,12 +1756,12 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { bool IsDependencySimplify(const GroupPtr& producer_g, const GroupPtr& consumer, - const std::unordered_set& consumers) { + const std::unordered_set& consumers) { std::queue candidates; candidates.push(consumer); // check upper. int check_upper_depth = producer_g.get() ? 
producer_g->max_depth : INT_MAX; - std::unordered_set visited_set; + std::unordered_set visited_set; while (!candidates.empty()) { auto& candidate = candidates.front(); candidates.pop(); @@ -1830,7 +1830,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { void UpdateInputToConsumers() { for (auto& input_consumers : input_to_consumers_) { auto& consumers = input_consumers.second; - std::unordered_set updated_consumers; + std::unordered_set updated_consumers; for (auto& consumer : consumers) { std::queue fused_groups; fused_groups.push(consumer); @@ -2014,7 +2014,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { GroupList fusion_groups_; std::unordered_map fusion_groups_index_; - std::unordered_map> input_to_consumers_; + std::unordered_map> input_to_consumers_; struct Relation { std::unordered_map vertical_relation; From c5a2560aae8e1bb5a255e0b0083cb04c0e1f5c41 Mon Sep 17 00:00:00 2001 From: zyfncg Date: Tue, 20 Jun 2023 16:34:38 +0000 Subject: [PATCH 32/66] modify is_same_size by new interface --- cinn/CMakeLists.txt | 1 + cinn/api/CMakeLists.txt | 8 + cinn/api/op_node.cc | 6 +- cinn/api/op_node.h | 4 +- cinn/api/tensor_node.h | 9 +- .../pass/check_fusion_accuracy_pass_test.cc | 1804 ++++++++--------- cinn/hlir/pass/general_fusion_merge_pass.cc | 56 +- 7 files changed, 977 insertions(+), 911 deletions(-) create mode 100644 cinn/api/CMakeLists.txt diff --git a/cinn/CMakeLists.txt b/cinn/CMakeLists.txt index 16c70714d7..36d7e4a516 100644 --- a/cinn/CMakeLists.txt +++ b/cinn/CMakeLists.txt @@ -2,6 +2,7 @@ if (WITH_TESTING) cc_library(cinn_gtest_main SRCS gtest_main.cc DEPS gtest) endif() +add_subdirectory(api) add_subdirectory(auto_schedule) add_subdirectory(common) add_subdirectory(utils) diff --git a/cinn/api/CMakeLists.txt b/cinn/api/CMakeLists.txt new file mode 100644 index 0000000000..0b9bd92e91 --- /dev/null +++ b/cinn/api/CMakeLists.txt @@ -0,0 +1,8 @@ +core_gather_headers() + +gather_srcs(cinnapi_src SRCS + op_node.cc + 
tensor_node.cc + ) + +message(STATUS "srcs: ${cinnapi_src}") diff --git a/cinn/api/op_node.cc b/cinn/api/op_node.cc index 6a74f3d20e..6265143e04 100644 --- a/cinn/api/op_node.cc +++ b/cinn/api/op_node.cc @@ -14,19 +14,17 @@ #include "cinn/api/op_node.h" -#include "cinn/api/tensor_node.h" - namespace cinn { namespace api { TensorNode OpNode::GetInput(size_t i) const { auto edges = node_->inlinks_in_order(); - return TensorNode(helper_, edges[i]->safe_as()); + return TensorNode(helper_, edges[i]->source()->safe_as()); } TensorNode OpNode::GetOutput(size_t i) const { auto edges = node_->outlinks_in_order(); - return TensorNode(helper_, edges[i]->safe_as()); + return TensorNode(helper_, edges[i]->sink()->safe_as()); } } // namespace api diff --git a/cinn/api/op_node.h b/cinn/api/op_node.h index 36627bf186..bc471e4ba1 100644 --- a/cinn/api/op_node.h +++ b/cinn/api/op_node.h @@ -17,6 +17,8 @@ #include #include "cinn/hlir/framework/node.h" #include "cinn/hlir/pass/fusion_helper_base.h" +#include "cinn/api/tensor_node.h" + namespace cinn { namespace api { @@ -24,8 +26,6 @@ namespace api { using OpPatternKind = cinn::hlir::framework::OpPatternKind; using Attribute = cinn::utils::Attribute; -class TensorNode; - class OpNode { public: OpNode(const hlir::pass::FusionHelperBase* helper, const hlir::framework::Node* node) : helper_(helper), node_(node) {} diff --git a/cinn/api/tensor_node.h b/cinn/api/tensor_node.h index 8de0f4ba53..ae2ffa6add 100644 --- a/cinn/api/tensor_node.h +++ b/cinn/api/tensor_node.h @@ -15,20 +15,25 @@ #pragma once #include "cinn/hlir/framework/node.h" +#include "cinn/utils/type_defs.h" #include "cinn/hlir/pass/fusion_helper_base.h" + namespace cinn { namespace api { class OpNode; +using shape_t = utils::ShapeType; + class TensorNode { public: TensorNode(const hlir::pass::FusionHelperBase* helper, const hlir::framework::NodeData* node_data) : helper_(helper), node_data_(node_data) {} // Get the shape of tensor. 
- const shpae_t& Shape() const { - return helper_->GetNodeDataShape(node_data_) + const shape_t& Shape() const { + CHECK(helper_->shape_dict_.count(node_data_->id())) << "Can't find " << node_data_->id() << " 's shape!"; + return helper_->shape_dict_.at(node_data_->id()); } OpNode Producer() const; diff --git a/cinn/hlir/pass/check_fusion_accuracy_pass_test.cc b/cinn/hlir/pass/check_fusion_accuracy_pass_test.cc index 72dc0a72cc..3283964ea9 100644 --- a/cinn/hlir/pass/check_fusion_accuracy_pass_test.cc +++ b/cinn/hlir/pass/check_fusion_accuracy_pass_test.cc @@ -92,1026 +92,1026 @@ TEST(CheckFusionAccuracyPass, ElementWise_Fusion) { RunTest(target, graph, {"A", "B", "C", "D"}); } -// TEST(CheckFusionAccuracyPass, General_ElementWise_Fusion) { -// int h = 32, w = 32; -// NetBuilder net_builder("ElementWise_Fusion_0"); -// std::unordered_set fetch_ids; -// // create model -// { -// auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); -// auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); -// auto C = net_builder.CreateInput(Float(32), {h, w}, "C"); -// auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); -// auto E = net_builder.Add(A, B); -// auto F = net_builder.Add(E, C); -// auto G = net_builder.Add(E, D); - -// fetch_ids = {F->id, G->id}; -// } - -// auto program = net_builder.Build(); -// auto target = common::DefaultTarget(); - -// auto graph = std::make_shared(program, target); -// hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "GeneralFusionMergePass"}); - -// int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); - -// VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); -// hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); -// VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - -// CHECK_EQ(graph->fusion_groups.size(), group_size_after); - -// RunTest(target, 
graph, {"A", "B", "C", "D"}); -// } - -// TEST(CheckFusionAccuracyPass, ElementWise_Fusion_1) { -// int h = 32, w = 32; -// NetBuilder net_builder("ElementWise_Fusion_1"); -// // create model -// { -// auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); -// auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); -// auto C = net_builder.CreateInput(Float(32), {h, w}, "C"); -// auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); -// auto E = net_builder.Add(A, B); -// auto F = net_builder.Add(C, D); -// auto G = net_builder.Add(E, F); -// auto I = net_builder.Add(E, F); -// } - -// auto program = net_builder.Build(); -// auto target = common::DefaultTarget(); - -// auto graph = std::make_shared(program, target); - -// hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "FusionMergePass"}); - -// int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); - -// VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); -// hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); -// VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - -// CHECK_EQ(graph->fusion_groups.size(), group_size_after); - -// RunTest(target, graph, {"A", "B", "C", "D"}); -// } - -// TEST(CheckFusionAccuracyPass, General_ElementWise_Fusion_1) { -// int h = 32, w = 32; -// NetBuilder net_builder("ElementWise_Fusion_1"); -// // create model -// { -// auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); -// auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); -// auto C = net_builder.CreateInput(Float(32), {h, w}, "C"); -// auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); -// auto E = net_builder.Add(A, B); -// auto F = net_builder.Add(C, D); -// auto G = net_builder.Add(E, F); -// auto I = net_builder.Add(E, F); -// } - -// auto program = net_builder.Build(); -// auto target = common::DefaultTarget(); - 
-// auto graph = std::make_shared(program, target); - -// hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "GeneralFusionMergePass"}); - -// int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); - -// VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); -// hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); -// VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - -// CHECK_EQ(graph->fusion_groups.size(), group_size_after); - -// RunTest(target, graph, {"A", "B", "C", "D"}); -// } - -// TEST(CheckFusionAccuracyPass, ElementWise_Fusion_2) { -// int h = 32, w = 32; -// NetBuilder net_builder("ElementWise_Fusion_2"); -// // create model -// { -// auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); -// auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); -// auto C = net_builder.CreateInput(Float(32), {h, w}, "C"); -// auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); -// auto E = net_builder.CreateInput(Float(32), {h, w}, "E"); -// auto F = net_builder.CreateInput(Float(32), {h, w}, "F"); -// auto G = net_builder.Add(A, B); -// auto H = net_builder.Add(C, D); -// auto I = net_builder.Add(E, G); -// auto J = net_builder.Add(G, H); -// auto K = net_builder.Add(H, F); -// } - -// auto program = net_builder.Build(); -// auto target = common::DefaultTarget(); - -// auto graph = std::make_shared(program, target); - -// hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "FusionMergePass"}); - -// int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); - -// VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); -// hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); -// VLOG(1) << "After CheckFusionAccuracyPass:\n" << 
graph->DebugGroupedGraph(std::unordered_set{}); - -// CHECK_EQ(graph->fusion_groups.size(), group_size_after); - -// RunTest(target, graph, {"A", "B", "C", "D", "E", "F"}); -// } - -// TEST(CheckFusionAccuracyPass, General_ElementWise_Fusion_2) { -// int h = 32, w = 32; -// NetBuilder net_builder("ElementWise_Fusion_2"); -// // create model -// { -// auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); -// auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); -// auto C = net_builder.CreateInput(Float(32), {h, w}, "C"); -// auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); -// auto E = net_builder.CreateInput(Float(32), {h, w}, "E"); -// auto F = net_builder.CreateInput(Float(32), {h, w}, "F"); -// auto G = net_builder.Add(A, B); -// auto H = net_builder.Add(C, D); -// auto I = net_builder.Add(E, G); -// auto J = net_builder.Add(G, H); -// auto K = net_builder.Add(H, F); -// } - -// auto program = net_builder.Build(); -// auto target = common::DefaultTarget(); - -// auto graph = std::make_shared(program, target); - -// hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "GeneralFusionMergePass"}); - -// int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); - -// VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); -// hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); -// VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - -// CHECK_EQ(graph->fusion_groups.size(), group_size_after); - -// RunTest(target, graph, {"A", "B", "C", "D", "E", "F"}); -// } - -// TEST(CheckFusionAccuracyPass, ElementWise_Fusion_3) { -// int h = 32, w = 32; -// NetBuilder net_builder("ElementWise_Fusion_3"); -// // create model -// { -// auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); -// auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); -// auto C = 
net_builder.CreateInput(Float(32), {h, w}, "C"); -// auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); -// auto E = net_builder.CreateInput(Float(32), {h, w}, "E"); -// auto F = net_builder.CreateInput(Float(32), {h, w}, "F"); -// auto G = net_builder.Add(A, B); -// auto H = net_builder.Add(G, C); -// auto I = net_builder.Add(G, D); -// auto J = net_builder.Add(G, E); -// auto K = net_builder.Add(G, F); -// } - -// auto program = net_builder.Build(); -// auto target = common::DefaultTarget(); - -// auto graph = std::make_shared(program, target); - -// hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "FusionMergePass"}); - -// int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); - -// VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); -// hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); -// VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - -// CHECK_EQ(graph->fusion_groups.size(), group_size_after); - -// RunTest(target, graph, {"A", "B", "C", "D", "E", "F"}); -// } - -// TEST(CheckFusionAccuracyPass, General_ElementWise_Fusion_3) { -// int h = 32, w = 32; -// NetBuilder net_builder("ElementWise_Fusion_3"); -// // create model -// { -// auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); -// auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); -// auto C = net_builder.CreateInput(Float(32), {h, w}, "C"); -// auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); -// auto E = net_builder.CreateInput(Float(32), {h, w}, "E"); -// auto F = net_builder.CreateInput(Float(32), {h, w}, "F"); -// auto G = net_builder.Add(A, B); -// auto H = net_builder.Add(G, C); -// auto I = net_builder.Add(G, D); -// auto J = net_builder.Add(G, E); -// auto K = net_builder.Add(G, F); -// } - -// auto program = net_builder.Build(); -// auto target = common::DefaultTarget(); - -// 
auto graph = std::make_shared(program, target); - -// hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "GeneralFusionMergePass"}); - -// int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); - -// VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); -// hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); -// VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - -// CHECK_EQ(graph->fusion_groups.size(), group_size_after); - -// RunTest(target, graph, {"A", "B", "C", "D", "E", "F"}); -// } - -// TEST(CheckFusionAccuracyPass, ElementWise_Fusion_4) { -// int h = 32, w = 32; -// NetBuilder net_builder("ElementWise_Fusion_4"); -// // create model -// { -// auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); -// auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); -// auto C = net_builder.CreateInput(Float(32), {h, w}, "C"); -// auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); -// auto E = net_builder.CreateInput(Float(32), {h, w}, "E"); -// auto F = net_builder.CreateInput(Float(32), {h, w}, "F"); -// auto G = net_builder.Add(A, B); -// auto H = net_builder.Add(G, C); -// auto I = net_builder.Add(G, D); -// auto J = net_builder.Add(I, E); -// auto K = net_builder.Add(I, F); -// } - -// auto program = net_builder.Build(); -// auto target = common::DefaultTarget(); - -// auto graph = std::make_shared(program, target); - -// hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "FusionMergePass"}); - -// int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); - -// VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); -// hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); -// VLOG(1) << "After CheckFusionAccuracyPass:\n" << 
graph->DebugGroupedGraph(std::unordered_set{}); - -// CHECK_EQ(graph->fusion_groups.size(), group_size_after); - -// RunTest(target, graph, {"A", "B", "C", "D", "E", "F"}); -// } - -// TEST(CheckFusionAccuracyPass, General_ElementWise_Fusion_4) { -// int h = 32, w = 32; -// NetBuilder net_builder("ElementWise_Fusion_4"); -// // create model -// { -// auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); -// auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); -// auto C = net_builder.CreateInput(Float(32), {h, w}, "C"); -// auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); -// auto E = net_builder.CreateInput(Float(32), {h, w}, "E"); -// auto F = net_builder.CreateInput(Float(32), {h, w}, "F"); -// auto G = net_builder.Add(A, B); -// auto H = net_builder.Add(G, C); -// auto I = net_builder.Add(G, D); -// auto J = net_builder.Add(I, E); -// auto K = net_builder.Add(I, F); -// } - -// auto program = net_builder.Build(); -// auto target = common::DefaultTarget(); - -// auto graph = std::make_shared(program, target); +TEST(CheckFusionAccuracyPass, General_ElementWise_Fusion) { + int h = 32, w = 32; + NetBuilder net_builder("ElementWise_Fusion_0"); + std::unordered_set fetch_ids; + // create model + { + auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); + auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); + auto C = net_builder.CreateInput(Float(32), {h, w}, "C"); + auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); + auto E = net_builder.Add(A, B); + auto F = net_builder.Add(E, C); + auto G = net_builder.Add(E, D); + + fetch_ids = {F->id, G->id}; + } + + auto program = net_builder.Build(); + auto target = common::DefaultTarget(); + + auto graph = std::make_shared(program, target); + hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "GeneralFusionMergePass"}); + + int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); + + VLOG(1) << "Before CheckFusionAccuracyPass:\n" << 
graph->DebugGroupedGraph(std::unordered_set{}); + hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); + VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + + CHECK_EQ(graph->fusion_groups.size(), group_size_after); + + RunTest(target, graph, {"A", "B", "C", "D"}); +} + +TEST(CheckFusionAccuracyPass, ElementWise_Fusion_1) { + int h = 32, w = 32; + NetBuilder net_builder("ElementWise_Fusion_1"); + // create model + { + auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); + auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); + auto C = net_builder.CreateInput(Float(32), {h, w}, "C"); + auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); + auto E = net_builder.Add(A, B); + auto F = net_builder.Add(C, D); + auto G = net_builder.Add(E, F); + auto I = net_builder.Add(E, F); + } + + auto program = net_builder.Build(); + auto target = common::DefaultTarget(); + + auto graph = std::make_shared(program, target); + + hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "FusionMergePass"}); + + int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); + + VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); + VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + + CHECK_EQ(graph->fusion_groups.size(), group_size_after); + + RunTest(target, graph, {"A", "B", "C", "D"}); +} + +TEST(CheckFusionAccuracyPass, General_ElementWise_Fusion_1) { + int h = 32, w = 32; + NetBuilder net_builder("ElementWise_Fusion_1"); + // create model + { + auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); + auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); + auto C = net_builder.CreateInput(Float(32), {h, w}, "C"); + auto D = 
net_builder.CreateInput(Float(32), {h, w}, "D"); + auto E = net_builder.Add(A, B); + auto F = net_builder.Add(C, D); + auto G = net_builder.Add(E, F); + auto I = net_builder.Add(E, F); + } + + auto program = net_builder.Build(); + auto target = common::DefaultTarget(); + + auto graph = std::make_shared(program, target); + + hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "GeneralFusionMergePass"}); + + int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); + + VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); + VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + + CHECK_EQ(graph->fusion_groups.size(), group_size_after); + + RunTest(target, graph, {"A", "B", "C", "D"}); +} + +TEST(CheckFusionAccuracyPass, ElementWise_Fusion_2) { + int h = 32, w = 32; + NetBuilder net_builder("ElementWise_Fusion_2"); + // create model + { + auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); + auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); + auto C = net_builder.CreateInput(Float(32), {h, w}, "C"); + auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); + auto E = net_builder.CreateInput(Float(32), {h, w}, "E"); + auto F = net_builder.CreateInput(Float(32), {h, w}, "F"); + auto G = net_builder.Add(A, B); + auto H = net_builder.Add(C, D); + auto I = net_builder.Add(E, G); + auto J = net_builder.Add(G, H); + auto K = net_builder.Add(H, F); + } + + auto program = net_builder.Build(); + auto target = common::DefaultTarget(); + + auto graph = std::make_shared(program, target); + + hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "FusionMergePass"}); + + int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); + + VLOG(1) << "Before CheckFusionAccuracyPass:\n" << 
graph->DebugGroupedGraph(std::unordered_set{}); + hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); + VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + + CHECK_EQ(graph->fusion_groups.size(), group_size_after); + + RunTest(target, graph, {"A", "B", "C", "D", "E", "F"}); +} + +TEST(CheckFusionAccuracyPass, General_ElementWise_Fusion_2) { + int h = 32, w = 32; + NetBuilder net_builder("ElementWise_Fusion_2"); + // create model + { + auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); + auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); + auto C = net_builder.CreateInput(Float(32), {h, w}, "C"); + auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); + auto E = net_builder.CreateInput(Float(32), {h, w}, "E"); + auto F = net_builder.CreateInput(Float(32), {h, w}, "F"); + auto G = net_builder.Add(A, B); + auto H = net_builder.Add(C, D); + auto I = net_builder.Add(E, G); + auto J = net_builder.Add(G, H); + auto K = net_builder.Add(H, F); + } + + auto program = net_builder.Build(); + auto target = common::DefaultTarget(); + + auto graph = std::make_shared(program, target); + + hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "GeneralFusionMergePass"}); + + int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); + + VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); + VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + + CHECK_EQ(graph->fusion_groups.size(), group_size_after); + + RunTest(target, graph, {"A", "B", "C", "D", "E", "F"}); +} + +TEST(CheckFusionAccuracyPass, ElementWise_Fusion_3) { + int h = 32, w = 32; + NetBuilder net_builder("ElementWise_Fusion_3"); + // create model + { + auto A = 
net_builder.CreateInput(Float(32), {h, w}, "A"); + auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); + auto C = net_builder.CreateInput(Float(32), {h, w}, "C"); + auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); + auto E = net_builder.CreateInput(Float(32), {h, w}, "E"); + auto F = net_builder.CreateInput(Float(32), {h, w}, "F"); + auto G = net_builder.Add(A, B); + auto H = net_builder.Add(G, C); + auto I = net_builder.Add(G, D); + auto J = net_builder.Add(G, E); + auto K = net_builder.Add(G, F); + } + + auto program = net_builder.Build(); + auto target = common::DefaultTarget(); + + auto graph = std::make_shared(program, target); + + hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "FusionMergePass"}); + + int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); + + VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); + VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + + CHECK_EQ(graph->fusion_groups.size(), group_size_after); + + RunTest(target, graph, {"A", "B", "C", "D", "E", "F"}); +} + +TEST(CheckFusionAccuracyPass, General_ElementWise_Fusion_3) { + int h = 32, w = 32; + NetBuilder net_builder("ElementWise_Fusion_3"); + // create model + { + auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); + auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); + auto C = net_builder.CreateInput(Float(32), {h, w}, "C"); + auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); + auto E = net_builder.CreateInput(Float(32), {h, w}, "E"); + auto F = net_builder.CreateInput(Float(32), {h, w}, "F"); + auto G = net_builder.Add(A, B); + auto H = net_builder.Add(G, C); + auto I = net_builder.Add(G, D); + auto J = net_builder.Add(G, E); + auto K = net_builder.Add(G, F); + } + + auto program = net_builder.Build(); + auto 
target = common::DefaultTarget(); + + auto graph = std::make_shared(program, target); + + hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "GeneralFusionMergePass"}); + + int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); + + VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); + VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + + CHECK_EQ(graph->fusion_groups.size(), group_size_after); + + RunTest(target, graph, {"A", "B", "C", "D", "E", "F"}); +} + +TEST(CheckFusionAccuracyPass, ElementWise_Fusion_4) { + int h = 32, w = 32; + NetBuilder net_builder("ElementWise_Fusion_4"); + // create model + { + auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); + auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); + auto C = net_builder.CreateInput(Float(32), {h, w}, "C"); + auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); + auto E = net_builder.CreateInput(Float(32), {h, w}, "E"); + auto F = net_builder.CreateInput(Float(32), {h, w}, "F"); + auto G = net_builder.Add(A, B); + auto H = net_builder.Add(G, C); + auto I = net_builder.Add(G, D); + auto J = net_builder.Add(I, E); + auto K = net_builder.Add(I, F); + } + + auto program = net_builder.Build(); + auto target = common::DefaultTarget(); + + auto graph = std::make_shared(program, target); + + hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "FusionMergePass"}); + + int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); + + VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); + VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + + 
CHECK_EQ(graph->fusion_groups.size(), group_size_after); + + RunTest(target, graph, {"A", "B", "C", "D", "E", "F"}); +} + +TEST(CheckFusionAccuracyPass, General_ElementWise_Fusion_4) { + int h = 32, w = 32; + NetBuilder net_builder("ElementWise_Fusion_4"); + // create model + { + auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); + auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); + auto C = net_builder.CreateInput(Float(32), {h, w}, "C"); + auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); + auto E = net_builder.CreateInput(Float(32), {h, w}, "E"); + auto F = net_builder.CreateInput(Float(32), {h, w}, "F"); + auto G = net_builder.Add(A, B); + auto H = net_builder.Add(G, C); + auto I = net_builder.Add(G, D); + auto J = net_builder.Add(I, E); + auto K = net_builder.Add(I, F); + } + + auto program = net_builder.Build(); + auto target = common::DefaultTarget(); + + auto graph = std::make_shared(program, target); + + hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "GeneralFusionMergePass"}); + + int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); + + VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); + VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + + CHECK_EQ(graph->fusion_groups.size(), group_size_after); + + RunTest(target, graph, {"A", "B", "C", "D", "E", "F"}); +} + +TEST(CheckFusionAccuracyPass, ElementWise_Fusion_5) { + int h = 32, w = 32; + NetBuilder net_builder("ElementWise_Fusion_5"); + // create model + { + auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); + auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); + auto C = net_builder.Add(A, B); + auto D = net_builder.Add(A, B); + } + + auto program = net_builder.Build(); + auto target = common::DefaultTarget(); + + auto graph = 
std::make_shared(program, target); + + hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "FusionMergePass"}); + + int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); + + VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); + VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + + CHECK_EQ(graph->fusion_groups.size(), group_size_after); + + RunTest(target, graph, {"A", "B"}); +} + +TEST(CheckFusionAccuracyPass, General_ElementWise_Fusion_5) { + int h = 32, w = 32; + NetBuilder net_builder("ElementWise_Fusion_5"); + // create model + { + auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); + auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); + auto C = net_builder.Add(A, B); + auto D = net_builder.Add(A, B); + } + + auto program = net_builder.Build(); + auto target = common::DefaultTarget(); + + auto graph = std::make_shared(program, target); + + hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "GeneralFusionMergePass"}); + + int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); + + VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); + VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + + CHECK_EQ(graph->fusion_groups.size(), group_size_after); + + RunTest(target, graph, {"A", "B"}); +} + +TEST(CheckFusionAccuracyPass, Broadcast_Test_0) { + int h = 32, w = 32; + NetBuilder net_builder("Broadcast_Test_0"); + // create model + { + auto A = net_builder.CreateInput(Float(32), {w}, "A"); + auto B = net_builder.CreateInput(Float(32), {w}, "B"); + auto C = net_builder.CreateInput(Float(32), {h, 
w}, "C"); + auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); + auto E = net_builder.Add(A, B); + auto F = net_builder.Add(C, D); + auto G = net_builder.Add(F, E); + } + + auto program = net_builder.Build(); + auto target = common::DefaultTarget(); + + auto graph = std::make_shared(program, target); + + hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "FusionMergePass"}); + + int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); + + VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); + VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + + CHECK_EQ(graph->fusion_groups.size(), group_size_after); + + RunTest(target, graph, {"A", "B", "C", "D"}); +} + +TEST(CheckFusionAccuracyPass, General_Broadcast_Test_0) { + int h = 32, w = 32; + NetBuilder net_builder("Broadcast_Test_0"); + // create model + { + auto A = net_builder.CreateInput(Float(32), {w}, "A"); + auto B = net_builder.CreateInput(Float(32), {w}, "B"); + auto C = net_builder.CreateInput(Float(32), {h, w}, "C"); + auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); + auto E = net_builder.Add(A, B); + auto F = net_builder.Add(C, D); + auto G = net_builder.Add(F, E); + } + + auto program = net_builder.Build(); + auto target = common::DefaultTarget(); + + auto graph = std::make_shared(program, target); + + hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "GeneralFusionMergePass"}); + + int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); + + VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); + VLOG(1) << "After CheckFusionAccuracyPass:\n" << 
graph->DebugGroupedGraph(std::unordered_set{}); + + CHECK_EQ(graph->fusion_groups.size(), group_size_after); + + RunTest(target, graph, {"A", "B", "C", "D"}); +} + +TEST(CheckFusionAccuracyPass, Broadcast_Test_2) { + int h = 32, w = 32; + NetBuilder net_builder("Broadcast_Test_2"); + // create model + { + auto A = net_builder.CreateInput(Float(32), {w}, "A"); + auto B = net_builder.CreateInput(Float(32), {w}, "B"); + auto C = net_builder.CreateInput(Float(32), {w}, "C"); + auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); + auto E = net_builder.Add(A, B); + auto F = net_builder.Add(C, E); + auto G = net_builder.Add(D, E); + } + + auto program = net_builder.Build(); + auto target = common::DefaultTarget(); + + auto graph = std::make_shared(program, target); + + hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "FusionMergePass"}); + + int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); + + VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); + VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + + CHECK_EQ(graph->fusion_groups.size(), group_size_after); + + RunTest(target, graph, {"A", "B", "C", "D"}); +} + +TEST(CheckFusionAccuracyPass, General_Broadcast_Test_2) { + int h = 32, w = 32; + NetBuilder net_builder("Broadcast_Test_2"); + // create model + { + auto A = net_builder.CreateInput(Float(32), {w}, "A"); + auto B = net_builder.CreateInput(Float(32), {w}, "B"); + auto C = net_builder.CreateInput(Float(32), {w}, "C"); + auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); + auto E = net_builder.Add(A, B); + auto F = net_builder.Add(C, E); + auto G = net_builder.Add(D, E); + } + + auto program = net_builder.Build(); + auto target = common::DefaultTarget(); + + auto graph = std::make_shared(program, target); + + 
hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "GeneralFusionMergePass"}); + + int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); + + VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); + VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + + CHECK_EQ(graph->fusion_groups.size(), group_size_after); + + RunTest(target, graph, {"A", "B", "C", "D"}); +} + +TEST(CheckFusionAccuracyPass, Broadcast_Test_4) { + int h = 32, w = 32; + NetBuilder net_builder("Broadcast_Test_4"); + // create model + { + auto A = net_builder.CreateInput(Float(32), {w}, "A"); + auto B = net_builder.CreateInput(Float(32), {w}, "B"); + auto C = net_builder.CreateInput(Float(32), {w}, "C"); + auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); + auto E = net_builder.CreateInput(Float(32), {h, w}, "E"); + auto F = net_builder.Add(A, B); + auto G = net_builder.Add(C, F); + auto H = net_builder.Add(D, F); + auto I = net_builder.Add(E, F); + } + + auto program = net_builder.Build(); + auto target = common::DefaultTarget(); + + auto graph = std::make_shared(program, target); + + hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "FusionMergePass"}); + + int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); + + VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); + VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + + CHECK_EQ(graph->fusion_groups.size(), group_size_after); + + RunTest(target, graph, {"A", "B", "C", "D", "E"}); +} -// hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "GeneralFusionMergePass"}); - -// int 
group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); - -// VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); -// hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); -// VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - -// CHECK_EQ(graph->fusion_groups.size(), group_size_after); - -// RunTest(target, graph, {"A", "B", "C", "D", "E", "F"}); -// } - -// TEST(CheckFusionAccuracyPass, ElementWise_Fusion_5) { -// int h = 32, w = 32; -// NetBuilder net_builder("ElementWise_Fusion_5"); -// // create model -// { -// auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); -// auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); -// auto C = net_builder.Add(A, B); -// auto D = net_builder.Add(A, B); -// } +TEST(CheckFusionAccuracyPass, General_Broadcast_Test_4) { + int h = 32, w = 32; + NetBuilder net_builder("Broadcast_Test_4"); + // create model + { + auto A = net_builder.CreateInput(Float(32), {w}, "A"); + auto B = net_builder.CreateInput(Float(32), {w}, "B"); + auto C = net_builder.CreateInput(Float(32), {w}, "C"); + auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); + auto E = net_builder.CreateInput(Float(32), {h, w}, "E"); + auto F = net_builder.Add(A, B); + auto G = net_builder.Add(C, F); + auto H = net_builder.Add(D, F); + auto I = net_builder.Add(E, F); + } + + auto program = net_builder.Build(); + auto target = common::DefaultTarget(); -// auto program = net_builder.Build(); -// auto target = common::DefaultTarget(); + auto graph = std::make_shared(program, target); -// auto graph = std::make_shared(program, target); + hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "GeneralFusionMergePass"}); -// hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "FusionMergePass"}); + int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); 
-// int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); + VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); + VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); -// VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); -// hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); -// VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + CHECK_EQ(graph->fusion_groups.size(), group_size_after); -// CHECK_EQ(graph->fusion_groups.size(), group_size_after); - -// RunTest(target, graph, {"A", "B"}); -// } - -// TEST(CheckFusionAccuracyPass, General_ElementWise_Fusion_5) { -// int h = 32, w = 32; -// NetBuilder net_builder("ElementWise_Fusion_5"); -// // create model -// { -// auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); -// auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); -// auto C = net_builder.Add(A, B); -// auto D = net_builder.Add(A, B); -// } + RunTest(target, graph, {"A", "B", "C", "D", "E"}); +} -// auto program = net_builder.Build(); -// auto target = common::DefaultTarget(); +TEST(CheckFusionAccuracyPass, Broadcast_Test_5) { + int h = 32, w = 32; + NetBuilder net_builder("Broadcast_Test_5"); + // create model + { + auto A = net_builder.CreateInput(Float(32), {w}, "A"); + auto B = net_builder.CreateInput(Float(32), {w}, "B"); + auto C = net_builder.CreateInput(Float(32), {w}, "C"); + auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); + auto E = net_builder.CreateInput(Float(32), {h * w, w}, "E"); + auto F = net_builder.Add(A, B); + auto G = net_builder.Add(C, F); + auto H = net_builder.Add(D, F); + auto I = net_builder.Add(E, F); + } -// auto graph = std::make_shared(program, 
target); + auto program = net_builder.Build(); + auto target = common::DefaultTarget(); -// hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "GeneralFusionMergePass"}); + auto graph = std::make_shared(program, target); -// int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); + hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "FusionMergePass"}); -// VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); -// hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); -// VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); -// CHECK_EQ(graph->fusion_groups.size(), group_size_after); + VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); + VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); -// RunTest(target, graph, {"A", "B"}); -// } + CHECK_EQ(graph->fusion_groups.size(), group_size_after); -// TEST(CheckFusionAccuracyPass, Broadcast_Test_0) { -// int h = 32, w = 32; -// NetBuilder net_builder("Broadcast_Test_0"); -// // create model -// { -// auto A = net_builder.CreateInput(Float(32), {w}, "A"); -// auto B = net_builder.CreateInput(Float(32), {w}, "B"); -// auto C = net_builder.CreateInput(Float(32), {h, w}, "C"); -// auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); -// auto E = net_builder.Add(A, B); -// auto F = net_builder.Add(C, D); -// auto G = net_builder.Add(F, E); -// } + RunTest(target, graph, {"A", "B", "C", "D", "E"}); +} -// auto program = net_builder.Build(); -// auto target = common::DefaultTarget(); +TEST(CheckFusionAccuracyPass, General_Broadcast_Test_5) { + int h = 32, w = 32; + 
NetBuilder net_builder("Broadcast_Test_5"); + // create model + { + auto A = net_builder.CreateInput(Float(32), {w}, "A"); + auto B = net_builder.CreateInput(Float(32), {w}, "B"); + auto C = net_builder.CreateInput(Float(32), {w}, "C"); + auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); + auto E = net_builder.CreateInput(Float(32), {h * w, w}, "E"); + auto F = net_builder.Add(A, B); + auto G = net_builder.Add(C, F); + auto H = net_builder.Add(D, F); + auto I = net_builder.Add(E, F); + } -// auto graph = std::make_shared(program, target); + auto program = net_builder.Build(); + auto target = common::DefaultTarget(); -// hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "FusionMergePass"}); + auto graph = std::make_shared(program, target); -// int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); + hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "GeneralFusionMergePass"}); -// VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); -// hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); -// VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); -// CHECK_EQ(graph->fusion_groups.size(), group_size_after); + VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); + VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); -// RunTest(target, graph, {"A", "B", "C", "D"}); -// } + CHECK_EQ(graph->fusion_groups.size(), group_size_after); -// TEST(CheckFusionAccuracyPass, General_Broadcast_Test_0) { -// int h = 32, w = 32; -// NetBuilder net_builder("Broadcast_Test_0"); -// // create model -// { -// auto A = 
net_builder.CreateInput(Float(32), {w}, "A"); -// auto B = net_builder.CreateInput(Float(32), {w}, "B"); -// auto C = net_builder.CreateInput(Float(32), {h, w}, "C"); -// auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); -// auto E = net_builder.Add(A, B); -// auto F = net_builder.Add(C, D); -// auto G = net_builder.Add(F, E); -// } - -// auto program = net_builder.Build(); -// auto target = common::DefaultTarget(); + RunTest(target, graph, {"A", "B", "C", "D", "E"}); +} -// auto graph = std::make_shared(program, target); +TEST(CheckFusionAccuracyPass, Reduce_Test_0) { + int h = 32, w = 32; + NetBuilder net_builder("Reduce_Test_0"); + // create model + { + auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); + auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); + auto C = net_builder.Add(A, B); + auto D = net_builder.ReduceSum(C, {0}); + auto E = net_builder.ReduceSum(C, {0}); + auto F = net_builder.ReduceSum(C, {0}); + } -// hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "GeneralFusionMergePass"}); + auto program = net_builder.Build(); + auto target = common::DefaultTarget(); -// int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); + auto graph = std::make_shared(program, target); -// VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); -// hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); -// VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "FusionMergePass"}); -// CHECK_EQ(graph->fusion_groups.size(), group_size_after); + int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); -// RunTest(target, graph, {"A", "B", "C", "D"}); -// } + VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + 
hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); + VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); -// TEST(CheckFusionAccuracyPass, Broadcast_Test_2) { -// int h = 32, w = 32; -// NetBuilder net_builder("Broadcast_Test_2"); -// // create model -// { -// auto A = net_builder.CreateInput(Float(32), {w}, "A"); -// auto B = net_builder.CreateInput(Float(32), {w}, "B"); -// auto C = net_builder.CreateInput(Float(32), {w}, "C"); -// auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); -// auto E = net_builder.Add(A, B); -// auto F = net_builder.Add(C, E); -// auto G = net_builder.Add(D, E); -// } - -// auto program = net_builder.Build(); -// auto target = common::DefaultTarget(); + CHECK_EQ(graph->fusion_groups.size(), group_size_after); -// auto graph = std::make_shared(program, target); + RunTest(target, graph, {"A", "B"}); +} -// hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "FusionMergePass"}); +TEST(CheckFusionAccuracyPass, General_Reduce_Test_0) { + int h = 32, w = 32; + NetBuilder net_builder("Reduce_Test_0"); + // create model + { + auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); + auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); + auto C = net_builder.Add(A, B); + auto D = net_builder.ReduceSum(C, {0}); + auto E = net_builder.ReduceSum(C, {0}); + auto F = net_builder.ReduceSum(C, {0}); + } -// int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); + auto program = net_builder.Build(); + auto target = common::DefaultTarget(); -// VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); -// hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); -// VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + auto graph = std::make_shared(program, target); -// 
CHECK_EQ(graph->fusion_groups.size(), group_size_after); + hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "GeneralFusionMergePass"}); -// RunTest(target, graph, {"A", "B", "C", "D"}); -// } + int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); -// TEST(CheckFusionAccuracyPass, General_Broadcast_Test_2) { -// int h = 32, w = 32; -// NetBuilder net_builder("Broadcast_Test_2"); -// // create model -// { -// auto A = net_builder.CreateInput(Float(32), {w}, "A"); -// auto B = net_builder.CreateInput(Float(32), {w}, "B"); -// auto C = net_builder.CreateInput(Float(32), {w}, "C"); -// auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); -// auto E = net_builder.Add(A, B); -// auto F = net_builder.Add(C, E); -// auto G = net_builder.Add(D, E); -// } - -// auto program = net_builder.Build(); -// auto target = common::DefaultTarget(); - -// auto graph = std::make_shared(program, target); - -// hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "GeneralFusionMergePass"}); + VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); + VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); -// int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); + CHECK_EQ(graph->fusion_groups.size(), group_size_after); -// VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); -// hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); -// VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - -// CHECK_EQ(graph->fusion_groups.size(), group_size_after); - -// RunTest(target, graph, {"A", "B", "C", "D"}); -// } - -// TEST(CheckFusionAccuracyPass, Broadcast_Test_4) { -// int h = 32, w = 32; -// 
NetBuilder net_builder("Broadcast_Test_4"); -// // create model -// { -// auto A = net_builder.CreateInput(Float(32), {w}, "A"); -// auto B = net_builder.CreateInput(Float(32), {w}, "B"); -// auto C = net_builder.CreateInput(Float(32), {w}, "C"); -// auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); -// auto E = net_builder.CreateInput(Float(32), {h, w}, "E"); -// auto F = net_builder.Add(A, B); -// auto G = net_builder.Add(C, F); -// auto H = net_builder.Add(D, F); -// auto I = net_builder.Add(E, F); -// } - -// auto program = net_builder.Build(); -// auto target = common::DefaultTarget(); - -// auto graph = std::make_shared(program, target); - -// hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "FusionMergePass"}); - -// int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); - -// VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); -// hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); -// VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - -// CHECK_EQ(graph->fusion_groups.size(), group_size_after); - -// RunTest(target, graph, {"A", "B", "C", "D", "E"}); -// } - -// TEST(CheckFusionAccuracyPass, General_Broadcast_Test_4) { -// int h = 32, w = 32; -// NetBuilder net_builder("Broadcast_Test_4"); -// // create model -// { -// auto A = net_builder.CreateInput(Float(32), {w}, "A"); -// auto B = net_builder.CreateInput(Float(32), {w}, "B"); -// auto C = net_builder.CreateInput(Float(32), {w}, "C"); -// auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); -// auto E = net_builder.CreateInput(Float(32), {h, w}, "E"); -// auto F = net_builder.Add(A, B); -// auto G = net_builder.Add(C, F); -// auto H = net_builder.Add(D, F); -// auto I = net_builder.Add(E, F); -// } - -// auto program = net_builder.Build(); -// auto target = common::DefaultTarget(); - -// auto graph = 
std::make_shared(program, target); - -// hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "GeneralFusionMergePass"}); - -// int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); - -// VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); -// hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); -// VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - -// CHECK_EQ(graph->fusion_groups.size(), group_size_after); - -// RunTest(target, graph, {"A", "B", "C", "D", "E"}); -// } - -// TEST(CheckFusionAccuracyPass, Broadcast_Test_5) { -// int h = 32, w = 32; -// NetBuilder net_builder("Broadcast_Test_5"); -// // create model -// { -// auto A = net_builder.CreateInput(Float(32), {w}, "A"); -// auto B = net_builder.CreateInput(Float(32), {w}, "B"); -// auto C = net_builder.CreateInput(Float(32), {w}, "C"); -// auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); -// auto E = net_builder.CreateInput(Float(32), {h * w, w}, "E"); -// auto F = net_builder.Add(A, B); -// auto G = net_builder.Add(C, F); -// auto H = net_builder.Add(D, F); -// auto I = net_builder.Add(E, F); -// } - -// auto program = net_builder.Build(); -// auto target = common::DefaultTarget(); + RunTest(target, graph, {"A", "B"}); +} -// auto graph = std::make_shared(program, target); +TEST(CheckFusionAccuracyPass, Reduce_Test_1) { + int h = 32, w = 32; + NetBuilder net_builder("Reduce_Test_1"); + // create model + { + auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); + auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); + auto C = net_builder.Add(A, B); + auto D = net_builder.ReduceSum(C, {0}); + auto E = net_builder.ReduceSum(C, {1}); + } -// hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "FusionMergePass"}); - -// int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); 
- -// VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); -// hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); -// VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - -// CHECK_EQ(graph->fusion_groups.size(), group_size_after); - -// RunTest(target, graph, {"A", "B", "C", "D", "E"}); -// } - -// TEST(CheckFusionAccuracyPass, General_Broadcast_Test_5) { -// int h = 32, w = 32; -// NetBuilder net_builder("Broadcast_Test_5"); -// // create model -// { -// auto A = net_builder.CreateInput(Float(32), {w}, "A"); -// auto B = net_builder.CreateInput(Float(32), {w}, "B"); -// auto C = net_builder.CreateInput(Float(32), {w}, "C"); -// auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); -// auto E = net_builder.CreateInput(Float(32), {h * w, w}, "E"); -// auto F = net_builder.Add(A, B); -// auto G = net_builder.Add(C, F); -// auto H = net_builder.Add(D, F); -// auto I = net_builder.Add(E, F); -// } + auto program = net_builder.Build(); + auto target = common::DefaultTarget(); -// auto program = net_builder.Build(); -// auto target = common::DefaultTarget(); + auto graph = std::make_shared(program, target); -// auto graph = std::make_shared(program, target); + hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "FusionMergePass"}); -// hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "GeneralFusionMergePass"}); + int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); -// int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); + VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); + VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); -// VLOG(1) << "Before 
CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); -// hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); -// VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + CHECK_EQ(graph->fusion_groups.size(), group_size_after); -// CHECK_EQ(graph->fusion_groups.size(), group_size_after); - -// RunTest(target, graph, {"A", "B", "C", "D", "E"}); -// } - -// TEST(CheckFusionAccuracyPass, Reduce_Test_0) { -// int h = 32, w = 32; -// NetBuilder net_builder("Reduce_Test_0"); -// // create model -// { -// auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); -// auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); -// auto C = net_builder.Add(A, B); -// auto D = net_builder.ReduceSum(C, {0}); -// auto E = net_builder.ReduceSum(C, {0}); -// auto F = net_builder.ReduceSum(C, {0}); -// } + RunTest(target, graph, {"A", "B"}); +} -// auto program = net_builder.Build(); -// auto target = common::DefaultTarget(); +TEST(CheckFusionAccuracyPass, General_Reduce_Test_1) { + int h = 32, w = 32; + NetBuilder net_builder("Reduce_Test_1"); + // create model + { + auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); + auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); + auto C = net_builder.Add(A, B); + auto D = net_builder.ReduceSum(C, {0}); + auto E = net_builder.ReduceSum(C, {1}); + } -// auto graph = std::make_shared(program, target); + auto program = net_builder.Build(); + auto target = common::DefaultTarget(); -// hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "FusionMergePass"}); + auto graph = std::make_shared(program, target); -// int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); + hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "GeneralFusionMergePass"}); -// VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); -// 
hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); -// VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); -// CHECK_EQ(graph->fusion_groups.size(), group_size_after); + VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); + VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); -// RunTest(target, graph, {"A", "B"}); -// } - -// TEST(CheckFusionAccuracyPass, General_Reduce_Test_0) { -// int h = 32, w = 32; -// NetBuilder net_builder("Reduce_Test_0"); -// // create model -// { -// auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); -// auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); -// auto C = net_builder.Add(A, B); -// auto D = net_builder.ReduceSum(C, {0}); -// auto E = net_builder.ReduceSum(C, {0}); -// auto F = net_builder.ReduceSum(C, {0}); -// } + CHECK_EQ(graph->fusion_groups.size(), group_size_after); -// auto program = net_builder.Build(); -// auto target = common::DefaultTarget(); + RunTest(target, graph, {"A", "B"}); +} -// auto graph = std::make_shared(program, target); +TEST(CheckFusionAccuracyPass, Reduce_Test_2) { + int h = 32, w = 32; + NetBuilder net_builder("Reduce_Test_2"); + // create model + { + auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); + auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); + auto C = net_builder.CreateInput(Float(32), {w}, "C"); + auto D = net_builder.Add(A, B); + auto E = net_builder.ReduceSum(D, {0}); + auto F = net_builder.ReduceSum(D, {1}); + auto G = net_builder.Add(C, E); + auto H = net_builder.Add(C, F); + } -// hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "GeneralFusionMergePass"}); + auto program = 
net_builder.Build(); + auto target = common::DefaultTarget(); -// int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); + auto graph = std::make_shared(program, target); -// VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); -// hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); -// VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "FusionMergePass"}); -// CHECK_EQ(graph->fusion_groups.size(), group_size_after); + int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); -// RunTest(target, graph, {"A", "B"}); -// } + VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); + VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); -// TEST(CheckFusionAccuracyPass, Reduce_Test_1) { -// int h = 32, w = 32; -// NetBuilder net_builder("Reduce_Test_1"); -// // create model -// { -// auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); -// auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); -// auto C = net_builder.Add(A, B); -// auto D = net_builder.ReduceSum(C, {0}); -// auto E = net_builder.ReduceSum(C, {1}); -// } + CHECK_EQ(graph->fusion_groups.size(), group_size_after); -// auto program = net_builder.Build(); -// auto target = common::DefaultTarget(); + RunTest(target, graph, {"A", "B", "C"}); +} -// auto graph = std::make_shared(program, target); +TEST(CheckFusionAccuracyPass, General_Reduce_Test_2) { + int h = 32, w = 32; + NetBuilder net_builder("Reduce_Test_2"); + // create model + { + auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); + auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); + auto 
C = net_builder.CreateInput(Float(32), {w}, "C"); + auto D = net_builder.Add(A, B); + auto E = net_builder.ReduceSum(D, {0}); + auto F = net_builder.ReduceSum(D, {1}); + auto G = net_builder.Add(C, E); + auto H = net_builder.Add(C, F); + } -// hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "FusionMergePass"}); + auto program = net_builder.Build(); + auto target = common::DefaultTarget(); -// int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); + auto graph = std::make_shared(program, target); -// VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); -// hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); -// VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "GeneralFusionMergePass"}); -// CHECK_EQ(graph->fusion_groups.size(), group_size_after); + int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); -// RunTest(target, graph, {"A", "B"}); -// } + VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); + VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); -// TEST(CheckFusionAccuracyPass, General_Reduce_Test_1) { -// int h = 32, w = 32; -// NetBuilder net_builder("Reduce_Test_1"); -// // create model -// { -// auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); -// auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); -// auto C = net_builder.Add(A, B); -// auto D = net_builder.ReduceSum(C, {0}); -// auto E = net_builder.ReduceSum(C, {1}); -// } - -// auto program = net_builder.Build(); -// auto target = common::DefaultTarget(); + CHECK_EQ(graph->fusion_groups.size(), group_size_after); -// auto 
graph = std::make_shared(program, target); + RunTest(target, graph, {"A", "B", "C"}); +} -// hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "GeneralFusionMergePass"}); +TEST(CheckFusionAccuracyPass, Reduce_Test_3) { + int h = 32, w = 32; + NetBuilder net_builder("Reduce_Test_3"); + // create model + { + auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); + auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); + auto C = net_builder.CreateInput(Float(32), {w}, "C"); + auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); + auto E = net_builder.Add(A, B); + auto F = net_builder.ReduceSum(E, {0}); + auto G = net_builder.Add(C, F); + auto H = net_builder.Add(D, F); + } -// int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); + auto program = net_builder.Build(); + auto target = common::DefaultTarget(); -// VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); -// hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); -// VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + auto graph = std::make_shared(program, target); -// CHECK_EQ(graph->fusion_groups.size(), group_size_after); + hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "FusionMergePass"}); -// RunTest(target, graph, {"A", "B"}); -// } + int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); -// TEST(CheckFusionAccuracyPass, Reduce_Test_2) { -// int h = 32, w = 32; -// NetBuilder net_builder("Reduce_Test_2"); -// // create model -// { -// auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); -// auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); -// auto C = net_builder.CreateInput(Float(32), {w}, "C"); -// auto D = net_builder.Add(A, B); -// auto E = net_builder.ReduceSum(D, {0}); -// auto F = net_builder.ReduceSum(D, {1}); -// auto G = net_builder.Add(C, E); 
-// auto H = net_builder.Add(C, F); -// } - -// auto program = net_builder.Build(); -// auto target = common::DefaultTarget(); + VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); + VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); -// auto graph = std::make_shared(program, target); + CHECK_EQ(graph->fusion_groups.size(), group_size_after); -// hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "FusionMergePass"}); + RunTest(target, graph, {"A", "B", "C", "D"}); +} -// int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); +TEST(CheckFusionAccuracyPass, General_Reduce_Test_3) { + int h = 32, w = 32; + NetBuilder net_builder("Reduce_Test_3"); + // create model + { + auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); + auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); + auto C = net_builder.CreateInput(Float(32), {w}, "C"); + auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); + auto E = net_builder.Add(A, B); + auto F = net_builder.ReduceSum(E, {0}); + auto G = net_builder.Add(C, F); + auto H = net_builder.Add(D, F); + } -// VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); -// hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); -// VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + auto program = net_builder.Build(); + auto target = common::DefaultTarget(); -// CHECK_EQ(graph->fusion_groups.size(), group_size_after); + auto graph = std::make_shared(program, target); -// RunTest(target, graph, {"A", "B", "C"}); -// } + hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "GeneralFusionMergePass"}); -// TEST(CheckFusionAccuracyPass, General_Reduce_Test_2) { -// int h = 
32, w = 32; -// NetBuilder net_builder("Reduce_Test_2"); -// // create model -// { -// auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); -// auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); -// auto C = net_builder.CreateInput(Float(32), {w}, "C"); -// auto D = net_builder.Add(A, B); -// auto E = net_builder.ReduceSum(D, {0}); -// auto F = net_builder.ReduceSum(D, {1}); -// auto G = net_builder.Add(C, E); -// auto H = net_builder.Add(C, F); -// } - -// auto program = net_builder.Build(); -// auto target = common::DefaultTarget(); + int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); -// auto graph = std::make_shared(program, target); + VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); + VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); -// hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "GeneralFusionMergePass"}); + CHECK_EQ(graph->fusion_groups.size(), group_size_after); -// int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); + RunTest(target, graph, {"A", "B", "C", "D"}); +} -// VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); -// hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); -// VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); +TEST(CheckFusionAccuracyPass, Reduce_Test_4) { + int h = 32, w = 32; + NetBuilder net_builder("Reduce_Test_4"); + // create model + { + auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); + auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); + auto C = net_builder.CreateInput(Float(32), {w}, "C"); + auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); + auto E = net_builder.Add(A, B); 
+ auto F = net_builder.ReduceSum(E, {0}); + auto G = net_builder.Add(C, F); + auto H = net_builder.Add(D, F); + auto I = net_builder.Add(D, F); + } -// CHECK_EQ(graph->fusion_groups.size(), group_size_after); + auto program = net_builder.Build(); + auto target = common::DefaultTarget(); -// RunTest(target, graph, {"A", "B", "C"}); -// } + auto graph = std::make_shared(program, target); -// TEST(CheckFusionAccuracyPass, Reduce_Test_3) { -// int h = 32, w = 32; -// NetBuilder net_builder("Reduce_Test_3"); -// // create model -// { -// auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); -// auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); -// auto C = net_builder.CreateInput(Float(32), {w}, "C"); -// auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); -// auto E = net_builder.Add(A, B); -// auto F = net_builder.ReduceSum(E, {0}); -// auto G = net_builder.Add(C, F); -// auto H = net_builder.Add(D, F); -// } - -// auto program = net_builder.Build(); -// auto target = common::DefaultTarget(); - -// auto graph = std::make_shared(program, target); - -// hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "FusionMergePass"}); - -// int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); - -// VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); -// hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); -// VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - -// CHECK_EQ(graph->fusion_groups.size(), group_size_after); - -// RunTest(target, graph, {"A", "B", "C", "D"}); -// } - -// TEST(CheckFusionAccuracyPass, General_Reduce_Test_3) { -// int h = 32, w = 32; -// NetBuilder net_builder("Reduce_Test_3"); -// // create model -// { -// auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); -// auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); -// auto C = 
net_builder.CreateInput(Float(32), {w}, "C"); -// auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); -// auto E = net_builder.Add(A, B); -// auto F = net_builder.ReduceSum(E, {0}); -// auto G = net_builder.Add(C, F); -// auto H = net_builder.Add(D, F); -// } - -// auto program = net_builder.Build(); -// auto target = common::DefaultTarget(); - -// auto graph = std::make_shared(program, target); - -// hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "GeneralFusionMergePass"}); - -// int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); - -// VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); -// hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); -// VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - -// CHECK_EQ(graph->fusion_groups.size(), group_size_after); - -// RunTest(target, graph, {"A", "B", "C", "D"}); -// } - -// TEST(CheckFusionAccuracyPass, Reduce_Test_4) { -// int h = 32, w = 32; -// NetBuilder net_builder("Reduce_Test_4"); -// // create model -// { -// auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); -// auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); -// auto C = net_builder.CreateInput(Float(32), {w}, "C"); -// auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); -// auto E = net_builder.Add(A, B); -// auto F = net_builder.ReduceSum(E, {0}); -// auto G = net_builder.Add(C, F); -// auto H = net_builder.Add(D, F); -// auto I = net_builder.Add(D, F); -// } + hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "FusionMergePass"}); -// auto program = net_builder.Build(); -// auto target = common::DefaultTarget(); + int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); -// auto graph = std::make_shared(program, target); + VLOG(1) << "Before CheckFusionAccuracyPass:\n" << 
graph->DebugGroupedGraph(std::unordered_set{}); + hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); + VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); -// hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "FusionMergePass"}); + CHECK_EQ(graph->fusion_groups.size(), group_size_after); -// int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); + RunTest(target, graph, {"A", "B", "C", "D"}); +} -// VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); -// hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); -// VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); +TEST(CheckFusionAccuracyPass, General_Reduce_Test_4) { + int h = 32, w = 32; + NetBuilder net_builder("Reduce_Test_4"); + // create model + { + auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); + auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); + auto C = net_builder.CreateInput(Float(32), {w}, "C"); + auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); + auto E = net_builder.Add(A, B); + auto F = net_builder.ReduceSum(E, {0}); + auto G = net_builder.Add(C, F); + auto H = net_builder.Add(D, F); + auto I = net_builder.Add(D, F); + } -// CHECK_EQ(graph->fusion_groups.size(), group_size_after); + auto program = net_builder.Build(); + auto target = common::DefaultTarget(); -// RunTest(target, graph, {"A", "B", "C", "D"}); -// } - -// TEST(CheckFusionAccuracyPass, General_Reduce_Test_4) { -// int h = 32, w = 32; -// NetBuilder net_builder("Reduce_Test_4"); -// // create model -// { -// auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); -// auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); -// auto C = net_builder.CreateInput(Float(32), {w}, "C"); -// auto D = net_builder.CreateInput(Float(32), {h, w}, "D"); 
-// auto E = net_builder.Add(A, B); -// auto F = net_builder.ReduceSum(E, {0}); -// auto G = net_builder.Add(C, F); -// auto H = net_builder.Add(D, F); -// auto I = net_builder.Add(D, F); -// } + auto graph = std::make_shared(program, target); -// auto program = net_builder.Build(); -// auto target = common::DefaultTarget(); + hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "GeneralFusionMergePass"}); -// auto graph = std::make_shared(program, target); + int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); -// hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "GeneralFusionMergePass"}); + VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); + VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); -// int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); + CHECK_EQ(graph->fusion_groups.size(), group_size_after); -// VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); -// hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); -// VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + RunTest(target, graph, {"A", "B", "C", "D"}); +} -// CHECK_EQ(graph->fusion_groups.size(), group_size_after); - -// RunTest(target, graph, {"A", "B", "C", "D"}); -// } - -// TEST(CheckFusionAccuracyPass, Reduce_Test_5) { -// int h = 128, w = 128; -// NetBuilder net_builder("Reduce_Test_5"); -// // create model -// { -// auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); -// auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); -// auto C = net_builder.Add(A, B); -// auto D = net_builder.ReduceSum(A, {1}); -// auto E = net_builder.ReduceSum(B, {1}); -// auto F = 
net_builder.ReduceSum(C, {1}); -// } +TEST(CheckFusionAccuracyPass, Reduce_Test_5) { + int h = 128, w = 128; + NetBuilder net_builder("Reduce_Test_5"); + // create model + { + auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); + auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); + auto C = net_builder.Add(A, B); + auto D = net_builder.ReduceSum(A, {1}); + auto E = net_builder.ReduceSum(B, {1}); + auto F = net_builder.ReduceSum(C, {1}); + } -// auto program = net_builder.Build(); -// auto target = common::DefaultTarget(); + auto program = net_builder.Build(); + auto target = common::DefaultTarget(); -// auto graph = std::make_shared(program, target); + auto graph = std::make_shared(program, target); -// hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "FusionMergePass"}); + hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "FusionMergePass"}); -// int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); + int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); -// VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); -// hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); -// VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); + VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); -// CHECK_EQ(graph->fusion_groups.size(), group_size_after); + CHECK_EQ(graph->fusion_groups.size(), group_size_after); -// RunTest(target, graph, {"A", "B"}); -// } + RunTest(target, graph, {"A", "B"}); +} -// TEST(CheckFusionAccuracyPass, General_Reduce_Test_5) { -// int h = 128, w = 128; -// NetBuilder 
net_builder("Reduce_Test_5"); -// // create model -// { -// auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); -// auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); -// auto C = net_builder.Add(A, B); -// auto D = net_builder.ReduceSum(A, {1}); -// auto E = net_builder.ReduceSum(B, {1}); -// auto F = net_builder.ReduceSum(C, {1}); -// } +TEST(CheckFusionAccuracyPass, General_Reduce_Test_5) { + int h = 128, w = 128; + NetBuilder net_builder("Reduce_Test_5"); + // create model + { + auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); + auto B = net_builder.CreateInput(Float(32), {h, w}, "B"); + auto C = net_builder.Add(A, B); + auto D = net_builder.ReduceSum(A, {1}); + auto E = net_builder.ReduceSum(B, {1}); + auto F = net_builder.ReduceSum(C, {1}); + } -// auto program = net_builder.Build(); -// auto target = common::DefaultTarget(); + auto program = net_builder.Build(); + auto target = common::DefaultTarget(); -// auto graph = std::make_shared(program, target); + auto graph = std::make_shared(program, target); -// hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "GeneralFusionMergePass"}); + hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "GeneralFusionMergePass"}); -// int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); + int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); -// VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); -// hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); -// VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); + VLOG(1) << "After CheckFusionAccuracyPass:\n" << 
graph->DebugGroupedGraph(std::unordered_set{}); -// CHECK_EQ(graph->fusion_groups.size(), group_size_after); + CHECK_EQ(graph->fusion_groups.size(), group_size_after); -// RunTest(target, graph, {"A", "B"}); -// } + RunTest(target, graph, {"A", "B"}); +} } // namespace cinn::frontend diff --git a/cinn/hlir/pass/general_fusion_merge_pass.cc b/cinn/hlir/pass/general_fusion_merge_pass.cc index 0633966a3f..fd583d55f6 100644 --- a/cinn/hlir/pass/general_fusion_merge_pass.cc +++ b/cinn/hlir/pass/general_fusion_merge_pass.cc @@ -138,6 +138,10 @@ class FusePassCtx { virtual void EnableFuse(const OpGroupPtr& first, const OpGroupPtr& second) = 0; + // User can cache some group info in context by using this function. + // The group info can be any data and need to create by create_fn. + // virtual absl::any* FindOrCreateCachedGroupInfo(const OpGroupPtr& op_group, const std::function& create_fn) = 0; + protected: FusePassCtx() = default; }; @@ -152,6 +156,8 @@ class LightwareFusePassCtx : public FusePassCtx { virtual void EnableFuse(const OpGroupPtr& first, const OpGroupPtr& second) = 0; + // virtual absl::any* FindOrCreateCachedGroupInfo(const OpGroupPtr& op_group, const std::function& create_fn) = 0; + protected: LightwareFusePassCtx() = default; }; @@ -173,9 +179,17 @@ class GraphGroupLightwareFusePassCtx final : public LightwareFusePassCtx { void EnableFuse(const OpGroupPtr& first, const OpGroupPtr& second) override { EnableFuse_(first, second); } + // absl::any* FindOrCreateCachedGroupInfo(const OpGroupPtr& op_group, const std::function& create_fn) { + // if (cache_data_.find(op_group) == cache_data_.end()) { + // cache_data_[op_group] = create_fn(op_group); + // } + // return &cache_data_[op_group]; + // } + const FusionHelperBase& graph_group_fusion_helper() const { return *graph_group_fusion_helper_; } private: + // static std::unordered_map cache_data_; const FusionHelperBase* graph_group_fusion_helper_; OpGroupPtr group_; const std::function EnableFuse_; @@ -192,6 
+206,8 @@ class InputFusePassCtx : public FusePassCtx { virtual void EnableFuse(const OpGroupPtr& first, const OpGroupPtr& second) = 0; + // virtual absl::any* FindOrCreateCachedGroupInfo(const OpGroupPtr& op_group, const std::function& create_fn) = 0; + protected: InputFusePassCtx() = default; }; @@ -214,7 +230,15 @@ class GraphGroupInputFusePassCtx final : public InputFusePassCtx { const FusionHelperBase& graph_group_fusion_helper() const { return *graph_group_fusion_helper_; } + // absl::any* FindOrCreateCachedGroupInfo(const OpGroupPtr& op_group, const std::function& create_fn) { + // if (cache_data_.find(op_group) == cache_data_.end()) { + // cache_data_[op_group] = create_fn(op_group); + // } + // return &cache_data_[op_group]; + // } + private: + // static std::unordered_map cache_data_; const FusionHelperBase* graph_group_fusion_helper_; const OpGroupList& groups_; const std::function EnableFuse_; @@ -338,8 +362,38 @@ struct HorizontalFuseUtil { }; } + static api::OpNode GetMasterNode(FusePassCtxT* ctx, const OpGroupPtr& op_group) { + VLOG(1) << "####### GetMasterNode"; + size_t op_num = op_group.OpSize(); + for (size_t i = 0; i < op_num; ++i) { + VLOG(1) << "####### GetMasterNode GetOp : " << i; + api::OpNode node = op_group.GetOp(i); + if (node.kind() == OpPatternKind::kReduction) { + return node; + } + } + VLOG(1) << "####### return GetOp : 0"; + return op_group.GetOp(0); + } + static bool IsSameSize(FusePassCtxT* ctx, const OpGroupPtr& src, const OpGroupPtr& dst) { - return ctx->fuse_helper().AllOutputsSameSize(src, dst); + api::OpNode src_master_node = GetMasterNode(ctx, src); + api::OpNode dst_master_node = GetMasterNode(ctx, dst); + VLOG(1) << "#### GetMasterNode Finish"; + + const auto& output_var_0 = src_master_node.GetOutput(0).Shape(); + const auto& output_var_1 = dst_master_node.GetOutput(0).Shape(); + VLOG(1) << "##### output_var_0 " << output_var_0.size(); + VLOG(1) << "##### output_var_1 " << output_var_1.size(); + if (output_var_0 == 
output_var_1) { + return true; + } + + auto size_0 = std::accumulate(output_var_0.begin(), output_var_0.end(), 1, std::multiplies()); + auto size_1 = std::accumulate(output_var_1.begin(), output_var_1.end(), 1, std::multiplies()); + VLOG(1) << "##### size_0 " << size_0; + VLOG(1) << "##### size_1 " << size_1; + return size_0 == size_1; } static bool HorizontalElementwiseFuseReduce(FusePassCtxT* ctx, const OpGroupPtr& src, const OpGroupPtr& dst) { From 0245b9bc7c3ab0823acd371efac02b5d2d59fce8 Mon Sep 17 00:00:00 2001 From: zyfncg Date: Sun, 25 Jun 2023 06:27:21 +0000 Subject: [PATCH 33/66] refactor HorizontalElementwiseFuseReduce --- cinn/hlir/pass/general_fusion_merge_pass.cc | 41 +++++++++++++++++---- 1 file changed, 33 insertions(+), 8 deletions(-) diff --git a/cinn/hlir/pass/general_fusion_merge_pass.cc b/cinn/hlir/pass/general_fusion_merge_pass.cc index fd583d55f6..1f869ff178 100644 --- a/cinn/hlir/pass/general_fusion_merge_pass.cc +++ b/cinn/hlir/pass/general_fusion_merge_pass.cc @@ -363,41 +363,66 @@ struct HorizontalFuseUtil { } static api::OpNode GetMasterNode(FusePassCtxT* ctx, const OpGroupPtr& op_group) { - VLOG(1) << "####### GetMasterNode"; size_t op_num = op_group.OpSize(); for (size_t i = 0; i < op_num; ++i) { - VLOG(1) << "####### GetMasterNode GetOp : " << i; api::OpNode node = op_group.GetOp(i); if (node.kind() == OpPatternKind::kReduction) { return node; } } - VLOG(1) << "####### return GetOp : 0"; return op_group.GetOp(0); } static bool IsSameSize(FusePassCtxT* ctx, const OpGroupPtr& src, const OpGroupPtr& dst) { api::OpNode src_master_node = GetMasterNode(ctx, src); api::OpNode dst_master_node = GetMasterNode(ctx, dst); - VLOG(1) << "#### GetMasterNode Finish"; const auto& output_var_0 = src_master_node.GetOutput(0).Shape(); const auto& output_var_1 = dst_master_node.GetOutput(0).Shape(); - VLOG(1) << "##### output_var_0 " << output_var_0.size(); - VLOG(1) << "##### output_var_1 " << output_var_1.size(); if (output_var_0 == output_var_1) { 
return true; } auto size_0 = std::accumulate(output_var_0.begin(), output_var_0.end(), 1, std::multiplies()); auto size_1 = std::accumulate(output_var_1.begin(), output_var_1.end(), 1, std::multiplies()); - VLOG(1) << "##### size_0 " << size_0; - VLOG(1) << "##### size_1 " << size_1; return size_0 == size_1; } static bool HorizontalElementwiseFuseReduce(FusePassCtxT* ctx, const OpGroupPtr& src, const OpGroupPtr& dst) { return ctx->fuse_helper().HorizontalElementwiseFuseReduce(src, dst); + // if same shape with horizontal relation + if (IsSameSize(ctx, src, dst)) { + return true; + } + + const OpGroupPtr* ele_group; + const OpGroupPtr* reduce_group; + + if (src.kind() == framework::kReduction) { + ele_group = &dst; + reduce_group = &src; + } else { + ele_group = &src; + reduce_group = &dst; + } + + shape_t ele_node_shape = GetMasterNode(ctx, *ele_group).GetOutput(0).Shape(); + int32_t size_ele = std::accumulate(ele_node_shape.begin(), ele_node_shape.end(), 1, std::multiplies()); + + size_t op_num = reduce_group->OpSize(); + for (size_t i = 0; i < op_num; ++i) { + api::OpNode node = reduce_group->GetOp(i); + if (node.kind() == OpPatternKind::kReduction) { + shape_t master_node_shape = node.GetOutput(0).Shape(); + int32_t size_master = + std::accumulate(master_node_shape.begin(), master_node_shape.end(), 1, std::multiplies()); + if (size_ele == size_master) { + return true; + } + } + } + + return false; } static bool ReduceFuseReduce(FusePassCtxT* ctx, const OpGroupPtr& src, const OpGroupPtr& dst) { From f8bb7135f048906702c4963c635428b499a7bdf2 Mon Sep 17 00:00:00 2001 From: zyfncg Date: Sun, 25 Jun 2023 08:58:28 +0000 Subject: [PATCH 34/66] update --- cinn/hlir/pass/general_fusion_merge_pass.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cinn/hlir/pass/general_fusion_merge_pass.cc b/cinn/hlir/pass/general_fusion_merge_pass.cc index a47c37ce73..862126c18e 100644 --- a/cinn/hlir/pass/general_fusion_merge_pass.cc +++ 
b/cinn/hlir/pass/general_fusion_merge_pass.cc @@ -659,13 +659,13 @@ class DefaultRecomputeFusePass final : public RecomputeFusePass { }; struct LightwareFusePassComparator { - bool operator()(const std::shared_ptr& lhs, const std::shared_ptr& rhs) { + bool operator()(const std::shared_ptr& lhs, const std::shared_ptr& rhs) const{ return lhs->Benefit() > rhs->Benefit(); } }; struct InputFusePassComparator { - bool operator()(const std::shared_ptr& lhs, const std::shared_ptr& rhs) { + bool operator()(const std::shared_ptr& lhs, const std::shared_ptr& rhs) const { return lhs->Benefit() > rhs->Benefit(); } }; From 929b1591112956ec3bc655923d3db1d92cfb5e8a Mon Sep 17 00:00:00 2001 From: zyfncg Date: Sun, 25 Jun 2023 09:21:40 +0000 Subject: [PATCH 35/66] fix bug --- cinn/hlir/pass/general_fusion_merge_pass.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cinn/hlir/pass/general_fusion_merge_pass.cc b/cinn/hlir/pass/general_fusion_merge_pass.cc index b409211e7f..cc103ea21f 100644 --- a/cinn/hlir/pass/general_fusion_merge_pass.cc +++ b/cinn/hlir/pass/general_fusion_merge_pass.cc @@ -116,7 +116,7 @@ class GraphGroupFuseHelper final : public FuseHelper { return node.GetGroup()->max_depth; }; const auto& VisitNextNodes = [&](OpGroupPtr node, const std::function& Visit) { - for(auto iter = node.ProducerEnd(); iter != node.ProducerEnd(); ++iter) { + for(auto iter = node.ProducerBegin(); iter != node.ProducerEnd(); ++iter) { if (node == consumer && *iter == producer) { continue; } From fe122273ae374148e645692588cd9ea90229c887 Mon Sep 17 00:00:00 2001 From: jiahongyu Date: Mon, 26 Jun 2023 05:02:17 +0000 Subject: [PATCH 36/66] fix api headerfile not found error --- cinn/CMakeLists.txt | 1 + cinn/api/CMakeLists.txt | 1 + 2 files changed, 2 insertions(+) create mode 100644 cinn/api/CMakeLists.txt diff --git a/cinn/CMakeLists.txt b/cinn/CMakeLists.txt index 16c70714d7..36d7e4a516 100644 --- a/cinn/CMakeLists.txt +++ b/cinn/CMakeLists.txt @@ -2,6 +2,7 @@ if 
(WITH_TESTING) cc_library(cinn_gtest_main SRCS gtest_main.cc DEPS gtest) endif() +add_subdirectory(api) add_subdirectory(auto_schedule) add_subdirectory(common) add_subdirectory(utils) diff --git a/cinn/api/CMakeLists.txt b/cinn/api/CMakeLists.txt new file mode 100644 index 0000000000..9bcce6cab3 --- /dev/null +++ b/cinn/api/CMakeLists.txt @@ -0,0 +1 @@ +core_gather_headers() From cb3fa1f85c205dcd4f9bdc7e02b637d426f23e40 Mon Sep 17 00:00:00 2001 From: zyfncg Date: Mon, 26 Jun 2023 06:14:25 +0000 Subject: [PATCH 37/66] update interface of op and tensor --- cinn/api/op_group.h | 137 +++++++++++++------- cinn/api/op_node.cc | 10 +- cinn/api/op_node.h | 67 +++++++++- cinn/api/tensor_node.cc | 14 +- cinn/api/tensor_node.h | 65 +++++++++- cinn/hlir/pass/general_fusion_merge_pass.cc | 76 +++++++---- 6 files changed, 266 insertions(+), 103 deletions(-) diff --git a/cinn/api/op_group.h b/cinn/api/op_group.h index 5bdad0765e..7e0191f400 100644 --- a/cinn/api/op_group.h +++ b/cinn/api/op_group.h @@ -24,76 +24,121 @@ namespace cinn { namespace api { +using Comparator = hlir::framework::Graph::Group::SharedGroupComparator; +using Hasher = hlir::framework::Graph::Group::SharedGroupHasher; + class OpGroup { public: - OpGroup(const hlir::pass::FusionHelperBase* helper, const std::shared_ptr& group) : helper_(helper), group_(group) {} + OpGroup(const hlir::framework::Graph* graph, const std::shared_ptr& group) : graph_(graph), group_(group) {} OpGroup(const OpGroup& other) = default; - class iterator { + class OpNodeListView { public: - iterator(std::unordered_map, TensorInterfaceList>::iterator it, const hlir::pass::FusionHelperBase* helper) : iter_(it), helper_(helper) {} + explicit OpNodeListView(std::vector op_nodes, const cinn::hlir::framework::Graph* graph) : op_nodes_(std::move(op_nodes)), graph_(graph) {} + + class Iterator { + public: + Iterator(std::vector::const_iterator it, const hlir::framework::Graph* graph) : iter_(it), graph_(graph) {} + + Iterator& operator++() { + 
++iter_; + return *this; + } - iterator& operator++() { - ++iter_; - return *this; - } + Iterator operator++(int) { + Iterator tmp = *this; + ++iter_; + return tmp; + } - iterator operator++(int) { - iterator tmp = *this; - ++iter_; - return tmp; - } + bool operator==(const Iterator& other) const { + return iter_ == other.iter_; + } - OpGroup operator*() { - return OpGroup(helper_, iter_->first); - } + bool operator!=(const Iterator& other) const { + return !(*this == other); + } - bool operator==(const iterator& other) const { - return iter_ == other.iter_; - } + OpNode operator*() const { + return OpNode(graph_, *iter_); + } - bool operator!=(const iterator& other) const { - return !(*this == other); - } + private: + std::vector::const_iterator iter_; + const hlir::framework::Graph* graph_; + }; + size_t size() const { return op_nodes_.size(); } + + Iterator begin() { return Iterator(op_nodes_.begin(), graph_); } + + Iterator end() { return Iterator(op_nodes_.begin(), graph_); } private: - std::unordered_map, TensorInterfaceList>::iterator iter_; - const hlir::pass::FusionHelperBase* helper_; + std::vector op_nodes_; + const cinn::hlir::framework::Graph* graph_; }; - hlir::framework::OpPatternKind kind() const { return group_->kind(); } + class OpGroupListView { + public: + OpGroupListView(const std::unordered_map, TensorInterfaceList, Hasher, Comparator>& group_map, const hlir::framework::Graph* graph) : op_group_map_(group_map), graph_(graph) {} + class Iterator { + public: + Iterator(std::unordered_map, TensorInterfaceList, Hasher, Comparator>::const_iterator it, const hlir::framework::Graph* graph) : iter_(it), graph_(graph) {} - size_t OpSize() const { - return group_->CollectNodes().size(); - } + Iterator& operator++() { + ++iter_; + return *this; + } - OpNode GetOp(size_t index) const { - return OpNode(helper_, group_->CollectNodes()[index]); - } + Iterator operator++(int) { + Iterator tmp = *this; + ++iter_; + return tmp; + } - size_t ProducerSize() const 
{ - return group_->producer_groups().size(); - } + bool operator==(const Iterator& other) const { + return iter_ == other.iter_; + } - size_t ConsumerSize() const { - return group_->consumer_groups().size(); - } + bool operator!=(const Iterator& other) const { + return !(*this == other); + } - iterator ProducerBegin() const { - return iterator(group_->mut_producer_groups()->begin(), helper_); - } + OpGroup operator*() const{ + return OpGroup(graph_, iter_->first); + } + + private: + std::unordered_map, TensorInterfaceList, Hasher, Comparator>::const_iterator iter_; + const hlir::framework::Graph* graph_; + }; + + size_t size() const { return op_group_map_.size(); } + + Iterator begin() { return Iterator(op_group_map_.begin(), graph_); } + + Iterator end() { return Iterator(op_group_map_.begin(), graph_); } + + private: + const std::unordered_map, TensorInterfaceList, Hasher, Comparator>& op_group_map_; + const cinn::hlir::framework::Graph* graph_; + }; + + + + hlir::framework::OpPatternKind kind() const { return group_->kind(); } - iterator ProducerEnd() const { - return iterator(group_->mut_producer_groups()->end(), helper_); + OpNodeListView Ops() const { + return OpNodeListView(group_->CollectNodes(), graph_); } - iterator ConsumerBegin() const { - return iterator(group_->mut_consumer_groups()->begin(), helper_); + OpGroupListView Producers() const { + return OpGroupListView(group_->producer_groups(), graph_); } - iterator ConsumerEnd() const { - return iterator(group_->mut_consumer_groups()->end(), helper_); + OpGroupListView Consumers() const { + return OpGroupListView(group_->consumer_groups(), graph_); } std::shared_ptr GetGroup() const { @@ -109,7 +154,7 @@ class OpGroup { } private: - const hlir::pass::FusionHelperBase* helper_; + const hlir::framework::Graph* graph_; const std::shared_ptr group_; }; diff --git a/cinn/api/op_node.cc b/cinn/api/op_node.cc index 6265143e04..54c51b4d84 100644 --- a/cinn/api/op_node.cc +++ b/cinn/api/op_node.cc @@ -17,14 
+17,12 @@ namespace cinn { namespace api { -TensorNode OpNode::GetInput(size_t i) const { - auto edges = node_->inlinks_in_order(); - return TensorNode(helper_, edges[i]->source()->safe_as()); +TensorNode OpNode::InputTensorListView::operator[](size_t index) const { + return TensorNode(graph_, edges_[index]->source()->safe_as()); } -TensorNode OpNode::GetOutput(size_t i) const { - auto edges = node_->outlinks_in_order(); - return TensorNode(helper_, edges[i]->sink()->safe_as()); +TensorNode OpNode::OutputTensorListView::operator[](size_t index) const { + return TensorNode(graph_, edges_[index]->sink()->safe_as()); } } // namespace api diff --git a/cinn/api/op_node.h b/cinn/api/op_node.h index bc471e4ba1..a5bf9fa58d 100644 --- a/cinn/api/op_node.h +++ b/cinn/api/op_node.h @@ -15,10 +15,10 @@ #pragma once #include -#include "cinn/hlir/framework/node.h" +#include "cinn/hlir/framework/graph.h" #include "cinn/hlir/pass/fusion_helper_base.h" #include "cinn/api/tensor_node.h" - +#include "cinn/hlir/framework/op.h" namespace cinn { namespace api { @@ -28,12 +28,58 @@ using Attribute = cinn::utils::Attribute; class OpNode { public: - OpNode(const hlir::pass::FusionHelperBase* helper, const hlir::framework::Node* node) : helper_(helper), node_(node) {} + OpNode(const hlir::framework::Graph* graph, const hlir::framework::Node* node) : graph_(graph), node_(node) { + input_edges_ = node->inlinks_in_order(); + output_edges_ = node->outlinks_in_order(); + } OpPatternKind kind () { - return helper_->GetOpKind(node_); + thread_local const static hlir::framework::OpValueType& op_pattern_dict = hlir::framework::Operator::GetAttrs("OpPattern"); + auto kind = op_pattern_dict[node_->op()]; + + if (kind == hlir::framework::kBroadcast) { + // As binary op was defined as broadcast, actually it should be element-wise. 
+ if (node_->op()->name != "broadcast_to") { + return hlir::framework::kElementWise; + } + } + return kind; } + class InputTensorListView { + public: + InputTensorListView(const hlir::framework::Graph* graph, const std::vector>& edges) : graph_(graph), edges_(edges) {} + + InputTensorListView(const InputTensorListView& other) = delete; + + InputTensorListView(InputTensorListView&& other) = delete; + + size_t size() const { return edges_.size(); } + + TensorNode operator[](size_t index) const; + + private: + const hlir::framework::Graph* graph_; + const std::vector>& edges_; + }; + + class OutputTensorListView { + public: + OutputTensorListView(const hlir::framework::Graph* graph, const std::vector>& edges) : graph_(graph), edges_(edges) {} + + OutputTensorListView(const OutputTensorListView& other) = delete; + + OutputTensorListView(OutputTensorListView&& other) = delete; + + size_t size() const { return edges_.size(); } + + TensorNode operator[](size_t index) const; + + private: + const hlir::framework::Graph* graph_; + const std::vector>& edges_; + }; + size_t InputsSize() const { return node_->inlinks().size(); } @@ -42,9 +88,13 @@ class OpNode { return node_->outlinks().size(); } - TensorNode GetInput(size_t i) const; + InputTensorListView Inputs() const { + return InputTensorListView(graph_, input_edges_); + } - TensorNode GetOutput(size_t i) const; + OutputTensorListView Outputs() const { + return OutputTensorListView(graph_, output_edges_); + } template const T& GetAttr(const std::string& attr_name) const { @@ -56,8 +106,11 @@ class OpNode { return node_->attrs.attr_store.at(attr_name); } - const hlir::pass::FusionHelperBase* helper_; + const hlir::framework::Graph* graph_; const hlir::framework::Node* node_; + + std::vector> input_edges_; + std::vector> output_edges_; }; } // namespace api diff --git a/cinn/api/tensor_node.cc b/cinn/api/tensor_node.cc index 5f717d4d37..40674e33f1 100644 --- a/cinn/api/tensor_node.cc +++ b/cinn/api/tensor_node.cc @@ -20,20 
+20,12 @@ namespace cinn { namespace api { OpNode TensorNode::Producer() const { - return OpNode(helper_, node_data_->source_node.get()); + return OpNode(graph_, node_data_->source_node.get()); } -OpNode TensorNode::Consumer(size_t index) const { - std::vector consumer_nodes; - for (auto& link : node_data_->outlinks()) { - auto consumer = link->sink()->safe_as(); - consumer_nodes.push_back(consumer); - } - return OpNode(helper_, consumer_nodes[index]); +OpNode TensorNode::ConsumerOpListView::Iterator::operator * () const{ + return OpNode(graph_, (*iter_)->sink()->safe_as()); } - - - } // namespace api } // namespace cinn \ No newline at end of file diff --git a/cinn/api/tensor_node.h b/cinn/api/tensor_node.h index ae2ffa6add..94b2fb998a 100644 --- a/cinn/api/tensor_node.h +++ b/cinn/api/tensor_node.h @@ -14,11 +14,10 @@ #pragma once -#include "cinn/hlir/framework/node.h" +#include "cinn/hlir/framework/graph.h" #include "cinn/utils/type_defs.h" #include "cinn/hlir/pass/fusion_helper_base.h" - namespace cinn { namespace api { @@ -28,24 +27,76 @@ using shape_t = utils::ShapeType; class TensorNode { public: - TensorNode(const hlir::pass::FusionHelperBase* helper, const hlir::framework::NodeData* node_data) : helper_(helper), node_data_(node_data) {} + TensorNode(const hlir::framework::Graph* graph, const hlir::framework::NodeData* node_data) : graph_(graph), node_data_(node_data) {} // Get the shape of tensor. 
const shape_t& Shape() const { - CHECK(helper_->shape_dict_.count(node_data_->id())) << "Can't find " << node_data_->id() << " 's shape!"; - return helper_->shape_dict_.at(node_data_->id()); + const auto& shape_dict = graph_->GetAttrs>("infershape"); + CHECK(shape_dict.count(node_data_->id())) << "Can't find " << node_data_->id() << " 's shape!"; + return shape_dict.at(node_data_->id()); } OpNode Producer() const; + class ConsumerOpListView { + public: + ConsumerOpListView(const std::set, common::GraphEdgeCompare>& edges, const hlir::framework::Graph* graph) : edges_(edges), graph_(graph) {} + + class Iterator { + public: + Iterator(std::set, common::GraphEdgeCompare>::const_iterator it, const hlir::framework::Graph* graph) : iter_(it), graph_(graph) {} + + Iterator& operator++() { + ++iter_; + return *this; + } + + Iterator operator++(int) { + Iterator tmp = *this; + ++iter_; + return tmp; + } + + bool operator==(const Iterator& other) const { + return iter_ == other.iter_; + } + + bool operator!=(const Iterator& other) const { + return !(*this == other); + } + + OpNode operator*() const; + + private: + std::set, common::GraphEdgeCompare>::const_iterator iter_; + const hlir::framework::Graph* graph_; + }; + + size_t size() const { return edges_.size(); } + + Iterator begin() const { + return Iterator(this->edges_.begin(), graph_); + } + + Iterator end() const { + return Iterator(this->edges_.end(), graph_); + } + + private: + const std::set, common::GraphEdgeCompare>& edges_; + const hlir::framework::Graph* graph_; + }; + size_t ConsumerSize() const { return node_data_->outlinks().size(); } - OpNode Consumer(size_t index) const; + ConsumerOpListView Consumers() const { + return ConsumerOpListView(node_data_->outlinks(), graph_); + } private: - const hlir::pass::FusionHelperBase* helper_; + const hlir::framework::Graph* graph_; const hlir::framework::NodeData* node_data_; }; diff --git a/cinn/hlir/pass/general_fusion_merge_pass.cc 
b/cinn/hlir/pass/general_fusion_merge_pass.cc index cc103ea21f..8667f462ee 100644 --- a/cinn/hlir/pass/general_fusion_merge_pass.cc +++ b/cinn/hlir/pass/general_fusion_merge_pass.cc @@ -116,7 +116,8 @@ class GraphGroupFuseHelper final : public FuseHelper { return node.GetGroup()->max_depth; }; const auto& VisitNextNodes = [&](OpGroupPtr node, const std::function& Visit) { - for(auto iter = node.ProducerBegin(); iter != node.ProducerEnd(); ++iter) { + auto producer_groups = producer.Producers(); + for(auto iter = producer_groups.begin(); iter != producer_groups.end(); ++iter) { if (node == consumer && *iter == producer) { continue; } @@ -317,6 +318,25 @@ bool GraphGroupFuseHelper::ReduceFuseReduce(const OpGroupPtr& src, dst.GetGroup()); } +// limit the group args number to less equal 512, as args stack size is 4K. +// static bool limit_args(const OpGroupPtr& first, const OpGroupPtr& second) { +// std::unordered_set args; +// for (auto& group : {first, second}) { +// for (auto node : group->input_nodes) { +// args.insert(node.first); +// } +// for (auto node : group->output_nodes) { +// args.insert(node); +// } +// } + +// if (args.size() > 512) { +// return false; +// } else { +// return true; +// } +// } + template struct HorizontalFuseUtil { using KindKeyT = std::pair; @@ -363,22 +383,22 @@ struct HorizontalFuseUtil { } static api::OpNode GetMasterNode(FusePassCtxT* ctx, const OpGroupPtr& op_group) { - size_t op_num = op_group.OpSize(); - for (size_t i = 0; i < op_num; ++i) { - api::OpNode node = op_group.GetOp(i); + auto ops = op_group.Ops(); + for (auto iter = ops.begin(); iter != ops.end(); ++iter) { + api::OpNode node = *iter; if (node.kind() == OpPatternKind::kReduction) { return node; } } - return op_group.GetOp(0); + return *ops.begin(); } static bool IsSameSize(FusePassCtxT* ctx, const OpGroupPtr& src, const OpGroupPtr& dst) { api::OpNode src_master_node = GetMasterNode(ctx, src); api::OpNode dst_master_node = GetMasterNode(ctx, dst); - const auto& 
output_var_0 = src_master_node.GetOutput(0).Shape(); - const auto& output_var_1 = dst_master_node.GetOutput(0).Shape(); + const auto& output_var_0 = src_master_node.Outputs()[0].Shape(); + const auto& output_var_1 = dst_master_node.Outputs()[0].Shape(); if (output_var_0 == output_var_1) { return true; } @@ -389,14 +409,13 @@ struct HorizontalFuseUtil { } static bool HorizontalElementwiseFuseReduce(FusePassCtxT* ctx, const OpGroupPtr& src, const OpGroupPtr& dst) { - return ctx->fuse_helper().HorizontalElementwiseFuseReduce(src, dst); // if same shape with horizontal relation if (IsSameSize(ctx, src, dst)) { return true; } - const OpGroupPtr* ele_group; - const OpGroupPtr* reduce_group; + const OpGroupPtr* ele_group = nullptr; + const OpGroupPtr* reduce_group = nullptr; if (src.kind() == framework::kReduction) { ele_group = &dst; @@ -406,14 +425,14 @@ struct HorizontalFuseUtil { reduce_group = &dst; } - shape_t ele_node_shape = GetMasterNode(ctx, *ele_group).GetOutput(0).Shape(); + shape_t ele_node_shape = GetMasterNode(ctx, *ele_group).Outputs()[0].Shape(); int32_t size_ele = std::accumulate(ele_node_shape.begin(), ele_node_shape.end(), 1, std::multiplies()); - size_t op_num = reduce_group->OpSize(); - for (size_t i = 0; i < op_num; ++i) { - api::OpNode node = reduce_group->GetOp(i); + auto ops = reduce_group->Ops(); + for (auto iter = ops.begin(); iter!= ops.end(); ++iter) { + api::OpNode node = *iter; if (node.kind() == OpPatternKind::kReduction) { - shape_t master_node_shape = node.GetOutput(0).Shape(); + shape_t master_node_shape = node.Outputs()[0].Shape(); int32_t size_master = std::accumulate(master_node_shape.begin(), master_node_shape.end(), 1, std::multiplies()); if (size_ele == size_master) { @@ -483,6 +502,7 @@ class DefaultInputFusePass final : public InputFusePass { return; } } + VLOG(1) << "DefaultInputFusePass Finish"; } }; @@ -525,7 +545,8 @@ class DefaultHorizontalFusePass final : public HorizontalFusePass { const auto& producer = 
ctx->PickOpGroup(); const OpGroupList consumers = [&]() { OpGroupList consumers; - for(auto iter = producer.ConsumerBegin(); iter!= producer.ConsumerEnd(); ++iter) { + auto consumer_groups = producer.Consumers(); + for(auto iter = consumer_groups.begin(); iter != consumer_groups.end(); ++iter) { consumers.push_back(*iter); } return consumers; @@ -575,7 +596,8 @@ class DefaultVerticalFusePass final : public VerticalFusePass { const auto& producer = ctx->PickOpGroup(); const OpGroupList consumers = [&]() { OpGroupList consumers; - for(auto iter = producer.ConsumerBegin(); iter!= producer.ConsumerEnd(); ++iter) { + auto consumer_groups = producer.Consumers(); + for(auto iter = consumer_groups.begin(); iter != consumer_groups.end(); ++iter) { consumers.push_back(*iter); } return consumers; @@ -701,7 +723,8 @@ class DefaultRecomputeFusePass final : public RecomputeFusePass { const auto& producer = ctx->PickOpGroup(); const OpGroupList consumers = [&]() { OpGroupList consumers; - for(auto iter = producer.ConsumerBegin(); iter!= producer.ConsumerEnd(); ++iter) { + auto consumer_groups = producer.Consumers(); + for(auto iter = consumer_groups.begin(); iter != consumer_groups.end(); ++iter) { consumers.push_back(*iter); } return consumers; @@ -826,7 +849,7 @@ class FusionPassRegistrar final : public Registrar { // code generation. class GeneralFusionMergePassHelper : public FusionHelperBase { public: - GeneralFusionMergePassHelper(const Graph* graph) : FusionHelperBase(graph) { + GeneralFusionMergePassHelper(const Graph* graph) : FusionHelperBase(graph), graph_(graph) { fusion_groups_ = graph->fusion_groups; // init fusion relation. 
InitFusionRelation(); @@ -1020,7 +1043,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { void EnableFusedHorizontalGroups(LightwareFusePassCtx* ctx) const { const auto& producer = ctx->PickOpGroup(); - if (producer.ConsumerSize() <= 1) { + if (producer.Consumers().size() <= 1) { return; } const auto& fuse_passes = GetHorizontalFusePasses(); @@ -1037,7 +1060,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { const auto& EnableFuse = [&](const OpGroupPtr& first, const OpGroupPtr& second) { tagged_sets.insert(std::make_pair(first, second)); }; - GraphGroupLightwareFusePassCtx fuse_ctx(this, api::OpGroup(this, producer), EnableFuse); + GraphGroupLightwareFusePassCtx fuse_ctx(this, api::OpGroup(this->graph_, producer), EnableFuse); EnableFusedHorizontalGroups(&fuse_ctx); return tagged_sets; }; @@ -1089,7 +1112,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { OpGroupList consumer_groups; consumer_groups.reserve(consumers.size()); for(auto& consumer : consumers) { - consumer_groups.push_back(api::OpGroup(this, consumer)); + consumer_groups.push_back(api::OpGroup(this->graph_, consumer)); } GraphGroupInputFusePassCtx fuse_ctx(this, consumer_groups, EnableFuse); EnableFusedInputGroups(&fuse_ctx); @@ -1410,7 +1433,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { void TagVerticalGroups(LightwareFusePassCtx* ctx) const { const auto& producer = ctx->PickOpGroup(); - if (producer.ConsumerSize() == 0) { + if (producer.Consumers().size() == 0) { return; } const auto& fuse_passes = GetVerticalFusePasses(); @@ -1427,7 +1450,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { const auto& EnableFuse = [&](const OpGroupPtr& first, const OpGroupPtr& second) { tagged_sets.push_back(std::make_pair(first, second)); }; - GraphGroupLightwareFusePassCtx fuse_ctx(this, api::OpGroup(this, producer), EnableFuse); + GraphGroupLightwareFusePassCtx fuse_ctx(this, api::OpGroup(this->graph_, producer), 
EnableFuse); TagVerticalGroups(&fuse_ctx); return tagged_sets; }; @@ -1657,7 +1680,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { void TagRecomputeGroups(LightwareFusePassCtx* ctx) const { const auto& producer = ctx->PickOpGroup(); - if (producer.ConsumerSize() <= 1) { + if (producer.Consumers().size() <= 1) { return; } const auto& fuse_passes = GetRecomputeFusePasses(); @@ -1674,7 +1697,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { const auto& EnableFuse = [&](const OpGroupPtr& first, const OpGroupPtr& second) { tagged_sets.insert(std::make_pair(first, second)); }; - GraphGroupLightwareFusePassCtx fuse_ctx(this, api::OpGroup(this, producer), EnableFuse); + GraphGroupLightwareFusePassCtx fuse_ctx(this, api::OpGroup(this->graph_, producer), EnableFuse); TagRecomputeGroups(&fuse_ctx); return tagged_sets; }; @@ -2091,6 +2114,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { } } + const Graph* graph_; GroupList fusion_groups_; std::unordered_map fusion_groups_index_; std::unordered_map> input_to_consumers_; From 96a0e035aaa29d964228edb33511413d40436fbb Mon Sep 17 00:00:00 2001 From: zyfncg Date: Mon, 26 Jun 2023 06:26:09 +0000 Subject: [PATCH 38/66] refine code --- cinn/api/op_group.h | 15 +++++++++++---- cinn/api/op_node.cc | 4 ++-- cinn/api/op_node.h | 6 +++--- cinn/api/tensor_node.cc | 4 ++-- cinn/api/tensor_node.h | 7 +++++-- cinn/hlir/pass/general_fusion_merge_pass.cc | 8 ++++---- 6 files changed, 27 insertions(+), 17 deletions(-) diff --git a/cinn/api/op_group.h b/cinn/api/op_group.h index 7e0191f400..c56e91b944 100644 --- a/cinn/api/op_group.h +++ b/cinn/api/op_group.h @@ -29,7 +29,7 @@ using Hasher = hlir::framework::Graph::Group::SharedGroupHasher; class OpGroup { public: - OpGroup(const hlir::framework::Graph* graph, const std::shared_ptr& group) : graph_(graph), group_(group) {} + OpGroup(const std::shared_ptr& group, const hlir::framework::Graph* graph) : group_(group), graph_(graph) {} 
OpGroup(const OpGroup& other) = default; @@ -37,6 +37,9 @@ class OpGroup { public: explicit OpNodeListView(std::vector op_nodes, const cinn::hlir::framework::Graph* graph) : op_nodes_(std::move(op_nodes)), graph_(graph) {} + OpNodeListView(const OpNodeListView& other) = delete; + OpNodeListView(OpNodeListView&& other) = delete; + class Iterator { public: Iterator(std::vector::const_iterator it, const hlir::framework::Graph* graph) : iter_(it), graph_(graph) {} @@ -61,7 +64,7 @@ class OpGroup { } OpNode operator*() const { - return OpNode(graph_, *iter_); + return OpNode(*iter_, graph_); } private: @@ -82,6 +85,10 @@ class OpGroup { class OpGroupListView { public: OpGroupListView(const std::unordered_map, TensorInterfaceList, Hasher, Comparator>& group_map, const hlir::framework::Graph* graph) : op_group_map_(group_map), graph_(graph) {} + + OpGroupListView(const OpGroupListView& other) = delete; + OpGroupListView(OpGroupListView&& other) = delete; + class Iterator { public: Iterator(std::unordered_map, TensorInterfaceList, Hasher, Comparator>::const_iterator it, const hlir::framework::Graph* graph) : iter_(it), graph_(graph) {} @@ -106,7 +113,7 @@ class OpGroup { } OpGroup operator*() const{ - return OpGroup(graph_, iter_->first); + return OpGroup(iter_->first, graph_); } private: @@ -154,8 +161,8 @@ class OpGroup { } private: - const hlir::framework::Graph* graph_; const std::shared_ptr group_; + const hlir::framework::Graph* graph_; }; } // namespace api diff --git a/cinn/api/op_node.cc b/cinn/api/op_node.cc index 54c51b4d84..2444243016 100644 --- a/cinn/api/op_node.cc +++ b/cinn/api/op_node.cc @@ -18,11 +18,11 @@ namespace cinn { namespace api { TensorNode OpNode::InputTensorListView::operator[](size_t index) const { - return TensorNode(graph_, edges_[index]->source()->safe_as()); + return TensorNode(edges_[index]->source()->safe_as(), graph_); } TensorNode OpNode::OutputTensorListView::operator[](size_t index) const { - return TensorNode(graph_, 
edges_[index]->sink()->safe_as()); + return TensorNode(edges_[index]->sink()->safe_as(), graph_); } } // namespace api diff --git a/cinn/api/op_node.h b/cinn/api/op_node.h index a5bf9fa58d..c9221936d6 100644 --- a/cinn/api/op_node.h +++ b/cinn/api/op_node.h @@ -28,12 +28,12 @@ using Attribute = cinn::utils::Attribute; class OpNode { public: - OpNode(const hlir::framework::Graph* graph, const hlir::framework::Node* node) : graph_(graph), node_(node) { + OpNode(const hlir::framework::Node* node, const hlir::framework::Graph* graph) : node_(node), graph_(graph) { input_edges_ = node->inlinks_in_order(); output_edges_ = node->outlinks_in_order(); } - OpPatternKind kind () { + OpPatternKind kind () const { thread_local const static hlir::framework::OpValueType& op_pattern_dict = hlir::framework::Operator::GetAttrs("OpPattern"); auto kind = op_pattern_dict[node_->op()]; @@ -106,8 +106,8 @@ class OpNode { return node_->attrs.attr_store.at(attr_name); } - const hlir::framework::Graph* graph_; const hlir::framework::Node* node_; + const hlir::framework::Graph* graph_; std::vector> input_edges_; std::vector> output_edges_; diff --git a/cinn/api/tensor_node.cc b/cinn/api/tensor_node.cc index 40674e33f1..833756139e 100644 --- a/cinn/api/tensor_node.cc +++ b/cinn/api/tensor_node.cc @@ -20,11 +20,11 @@ namespace cinn { namespace api { OpNode TensorNode::Producer() const { - return OpNode(graph_, node_data_->source_node.get()); + return OpNode(node_data_->source_node.get(), graph_); } OpNode TensorNode::ConsumerOpListView::Iterator::operator * () const{ - return OpNode(graph_, (*iter_)->sink()->safe_as()); + return OpNode((*iter_)->sink()->safe_as(), graph_); } } // namespace api diff --git a/cinn/api/tensor_node.h b/cinn/api/tensor_node.h index 94b2fb998a..f7c315532e 100644 --- a/cinn/api/tensor_node.h +++ b/cinn/api/tensor_node.h @@ -27,7 +27,7 @@ using shape_t = utils::ShapeType; class TensorNode { public: - TensorNode(const hlir::framework::Graph* graph, const 
hlir::framework::NodeData* node_data) : graph_(graph), node_data_(node_data) {} + TensorNode(const hlir::framework::NodeData* node_data, const hlir::framework::Graph* graph) : node_data_(node_data), graph_(graph) {} // Get the shape of tensor. const shape_t& Shape() const { @@ -42,6 +42,9 @@ class TensorNode { public: ConsumerOpListView(const std::set, common::GraphEdgeCompare>& edges, const hlir::framework::Graph* graph) : edges_(edges), graph_(graph) {} + ConsumerOpListView(const ConsumerOpListView& other) = delete; + ConsumerOpListView(ConsumerOpListView&& other) = delete; + class Iterator { public: Iterator(std::set, common::GraphEdgeCompare>::const_iterator it, const hlir::framework::Graph* graph) : iter_(it), graph_(graph) {} @@ -96,8 +99,8 @@ class TensorNode { } private: - const hlir::framework::Graph* graph_; const hlir::framework::NodeData* node_data_; + const hlir::framework::Graph* graph_; }; } // namespace api diff --git a/cinn/hlir/pass/general_fusion_merge_pass.cc b/cinn/hlir/pass/general_fusion_merge_pass.cc index 8667f462ee..03cb265f73 100644 --- a/cinn/hlir/pass/general_fusion_merge_pass.cc +++ b/cinn/hlir/pass/general_fusion_merge_pass.cc @@ -1060,7 +1060,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { const auto& EnableFuse = [&](const OpGroupPtr& first, const OpGroupPtr& second) { tagged_sets.insert(std::make_pair(first, second)); }; - GraphGroupLightwareFusePassCtx fuse_ctx(this, api::OpGroup(this->graph_, producer), EnableFuse); + GraphGroupLightwareFusePassCtx fuse_ctx(this, api::OpGroup(producer, this->graph_), EnableFuse); EnableFusedHorizontalGroups(&fuse_ctx); return tagged_sets; }; @@ -1112,7 +1112,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { OpGroupList consumer_groups; consumer_groups.reserve(consumers.size()); for(auto& consumer : consumers) { - consumer_groups.push_back(api::OpGroup(this->graph_, consumer)); + consumer_groups.push_back(api::OpGroup(consumer, this->graph_)); } 
GraphGroupInputFusePassCtx fuse_ctx(this, consumer_groups, EnableFuse); EnableFusedInputGroups(&fuse_ctx); @@ -1450,7 +1450,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { const auto& EnableFuse = [&](const OpGroupPtr& first, const OpGroupPtr& second) { tagged_sets.push_back(std::make_pair(first, second)); }; - GraphGroupLightwareFusePassCtx fuse_ctx(this, api::OpGroup(this->graph_, producer), EnableFuse); + GraphGroupLightwareFusePassCtx fuse_ctx(this, api::OpGroup(producer, this->graph_), EnableFuse); TagVerticalGroups(&fuse_ctx); return tagged_sets; }; @@ -1697,7 +1697,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { const auto& EnableFuse = [&](const OpGroupPtr& first, const OpGroupPtr& second) { tagged_sets.insert(std::make_pair(first, second)); }; - GraphGroupLightwareFusePassCtx fuse_ctx(this, api::OpGroup(this->graph_, producer), EnableFuse); + GraphGroupLightwareFusePassCtx fuse_ctx(this, api::OpGroup(producer, this->graph_), EnableFuse); TagRecomputeGroups(&fuse_ctx); return tagged_sets; }; From 2efd7c3955d7178b5f5818b566a3c52365b9f246 Mon Sep 17 00:00:00 2001 From: zyfncg Date: Mon, 26 Jun 2023 07:36:42 +0000 Subject: [PATCH 39/66] add Shape class --- cinn/api/op_group.h | 6 ++- cinn/api/op_node.h | 3 +- cinn/api/tensor_node.h | 60 ++++++++++++++++++--- cinn/hlir/pass/general_fusion_merge_pass.cc | 20 +++---- 4 files changed, 66 insertions(+), 23 deletions(-) diff --git a/cinn/api/op_group.h b/cinn/api/op_group.h index c56e91b944..0e00647f90 100644 --- a/cinn/api/op_group.h +++ b/cinn/api/op_group.h @@ -40,6 +40,8 @@ class OpGroup { OpNodeListView(const OpNodeListView& other) = delete; OpNodeListView(OpNodeListView&& other) = delete; + OpNodeListView& operator=(const OpNodeListView& other) = delete; + class Iterator { public: Iterator(std::vector::const_iterator it, const hlir::framework::Graph* graph) : iter_(it), graph_(graph) {} @@ -89,6 +91,8 @@ class OpGroup { OpGroupListView(const OpGroupListView& other) = 
delete; OpGroupListView(OpGroupListView&& other) = delete; + OpGroupListView& operator=(const OpGroupListView& other) = delete; + class Iterator { public: Iterator(std::unordered_map, TensorInterfaceList, Hasher, Comparator>::const_iterator it, const hlir::framework::Graph* graph) : iter_(it), graph_(graph) {} @@ -136,7 +140,7 @@ class OpGroup { hlir::framework::OpPatternKind kind() const { return group_->kind(); } - OpNodeListView Ops() const { + OpNodeListView ops() const { return OpNodeListView(group_->CollectNodes(), graph_); } diff --git a/cinn/api/op_node.h b/cinn/api/op_node.h index c9221936d6..efd9791deb 100644 --- a/cinn/api/op_node.h +++ b/cinn/api/op_node.h @@ -51,9 +51,10 @@ class OpNode { InputTensorListView(const hlir::framework::Graph* graph, const std::vector>& edges) : graph_(graph), edges_(edges) {} InputTensorListView(const InputTensorListView& other) = delete; - InputTensorListView(InputTensorListView&& other) = delete; + InputTensorListView& operator=(const InputTensorListView& other) = delete; + size_t size() const { return edges_.size(); } TensorNode operator[](size_t index) const; diff --git a/cinn/api/tensor_node.h b/cinn/api/tensor_node.h index f7c315532e..4e8822286c 100644 --- a/cinn/api/tensor_node.h +++ b/cinn/api/tensor_node.h @@ -14,26 +14,69 @@ #pragma once +#include + #include "cinn/hlir/framework/graph.h" #include "cinn/utils/type_defs.h" #include "cinn/hlir/pass/fusion_helper_base.h" +#include "cinn/utils/small_vector.h" namespace cinn { namespace api { -class OpNode; - using shape_t = utils::ShapeType; -class TensorNode { +class OpNode; + +class Shape final { public: - TensorNode(const hlir::framework::NodeData* node_data, const hlir::framework::Graph* graph) : node_data_(node_data), graph_(graph) {} + explicit Shape(const utils::ShapeType& shape) : shape_(shape) {} - // Get the shape of tensor. 
- const shape_t& Shape() const { + Shape(const Shape& other) = delete; + Shape(Shape&& other) = delete; + + Shape& operator=(const Shape& other) = delete; + + bool operator == (const Shape& other) const { + if (shape_.size() != other.shape_.size()) { + return false; + } + return std::equal(shape_.begin(), shape_.end(), other.shape_.begin()); + } + + const size_t& operator[] (size_t index) const { + return shape_[index]; + } + + size_t at(size_t index) const { + return shape_.at(index); + } + + size_t size() const { + return shape_.size(); + } + + // Returns the total number of elements in the shape. + size_t numel() const { + return std::accumulate(shape_.begin(), shape_.end(), 1, std::multiplies());; + } + + private: + const shape_t& shape_; +}; + + +class TensorNode final { + public: + TensorNode(const hlir::framework::NodeData* node_data, const hlir::framework::Graph* graph) : node_data_(node_data), graph_(graph) { const auto& shape_dict = graph_->GetAttrs>("infershape"); CHECK(shape_dict.count(node_data_->id())) << "Can't find " << node_data_->id() << " 's shape!"; - return shape_dict.at(node_data_->id()); + shape_.reset(new Shape(shape_dict.find(node_data_->id())->second)); + } + + // Get the shape of tensor. 
+ const Shape& shape() const { + return *shape_; } OpNode Producer() const; @@ -45,6 +88,8 @@ class TensorNode { ConsumerOpListView(const ConsumerOpListView& other) = delete; ConsumerOpListView(ConsumerOpListView&& other) = delete; + ConsumerOpListView& operator=(const ConsumerOpListView& other) = delete; + class Iterator { public: Iterator(std::set, common::GraphEdgeCompare>::const_iterator it, const hlir::framework::Graph* graph) : iter_(it), graph_(graph) {} @@ -101,6 +146,7 @@ class TensorNode { private: const hlir::framework::NodeData* node_data_; const hlir::framework::Graph* graph_; + std::unique_ptr shape_; }; } // namespace api diff --git a/cinn/hlir/pass/general_fusion_merge_pass.cc b/cinn/hlir/pass/general_fusion_merge_pass.cc index 03cb265f73..bb4cfe1a7c 100644 --- a/cinn/hlir/pass/general_fusion_merge_pass.cc +++ b/cinn/hlir/pass/general_fusion_merge_pass.cc @@ -383,7 +383,7 @@ struct HorizontalFuseUtil { } static api::OpNode GetMasterNode(FusePassCtxT* ctx, const OpGroupPtr& op_group) { - auto ops = op_group.Ops(); + auto ops = op_group.ops(); for (auto iter = ops.begin(); iter != ops.end(); ++iter) { api::OpNode node = *iter; if (node.kind() == OpPatternKind::kReduction) { @@ -397,14 +397,9 @@ struct HorizontalFuseUtil { api::OpNode src_master_node = GetMasterNode(ctx, src); api::OpNode dst_master_node = GetMasterNode(ctx, dst); - const auto& output_var_0 = src_master_node.Outputs()[0].Shape(); - const auto& output_var_1 = dst_master_node.Outputs()[0].Shape(); - if (output_var_0 == output_var_1) { - return true; - } + auto size_0 = src_master_node.Outputs()[0].shape().numel(); + auto size_1 = dst_master_node.Outputs()[0].shape().numel(); - auto size_0 = std::accumulate(output_var_0.begin(), output_var_0.end(), 1, std::multiplies()); - auto size_1 = std::accumulate(output_var_1.begin(), output_var_1.end(), 1, std::multiplies()); return size_0 == size_1; } @@ -425,16 +420,13 @@ struct HorizontalFuseUtil { reduce_group = &dst; } - shape_t ele_node_shape 
= GetMasterNode(ctx, *ele_group).Outputs()[0].Shape(); - int32_t size_ele = std::accumulate(ele_node_shape.begin(), ele_node_shape.end(), 1, std::multiplies()); + size_t size_ele = GetMasterNode(ctx, *ele_group).Outputs()[0].shape().numel(); - auto ops = reduce_group->Ops(); + auto ops = reduce_group->ops(); for (auto iter = ops.begin(); iter!= ops.end(); ++iter) { api::OpNode node = *iter; if (node.kind() == OpPatternKind::kReduction) { - shape_t master_node_shape = node.Outputs()[0].Shape(); - int32_t size_master = - std::accumulate(master_node_shape.begin(), master_node_shape.end(), 1, std::multiplies()); + size_t size_master = node.Outputs()[0].shape().numel(); if (size_ele == size_master) { return true; } From 6d30239143b7fb343f3b63db3dd778fabce62937 Mon Sep 17 00:00:00 2001 From: zyfncg Date: Mon, 26 Jun 2023 08:28:31 +0000 Subject: [PATCH 40/66] refine view interface --- cinn/api/op_group.h | 38 ++++++++++++++------- cinn/api/op_node.h | 38 ++++++++------------- cinn/api/tensor_node.cc | 2 +- cinn/api/tensor_node.h | 18 +++++----- cinn/hlir/pass/general_fusion_merge_pass.cc | 34 +++++++++--------- 5 files changed, 66 insertions(+), 64 deletions(-) diff --git a/cinn/api/op_group.h b/cinn/api/op_group.h index 0e00647f90..28e5332156 100644 --- a/cinn/api/op_group.h +++ b/cinn/api/op_group.h @@ -29,9 +29,16 @@ using Hasher = hlir::framework::Graph::Group::SharedGroupHasher; class OpGroup { public: - OpGroup(const std::shared_ptr& group, const hlir::framework::Graph* graph) : group_(group), graph_(graph) {} + OpGroup(const std::shared_ptr& group, const hlir::framework::Graph* graph) + : group_(group), graph_(graph), + producers_(group_->producer_groups(), graph_), + consumers_(group_->consumer_groups(), graph_), + ops_(group_->CollectNodes(), graph_) {} - OpGroup(const OpGroup& other) = default; + OpGroup(const OpGroup& other) : group_(other.group_), graph_(other.graph_), + producers_(group_->producer_groups(), graph_), + consumers_(group_->consumer_groups(), 
graph_), + ops_(group_->CollectNodes(), graph_) {} class OpNodeListView { public: @@ -76,11 +83,11 @@ class OpGroup { size_t size() const { return op_nodes_.size(); } - Iterator begin() { return Iterator(op_nodes_.begin(), graph_); } + Iterator begin() const { return Iterator(op_nodes_.begin(), graph_); } - Iterator end() { return Iterator(op_nodes_.begin(), graph_); } + Iterator end() const { return Iterator(op_nodes_.begin(), graph_); } private: - std::vector op_nodes_; + const std::vector op_nodes_; const cinn::hlir::framework::Graph* graph_; }; @@ -127,9 +134,9 @@ class OpGroup { size_t size() const { return op_group_map_.size(); } - Iterator begin() { return Iterator(op_group_map_.begin(), graph_); } + Iterator begin() const { return Iterator(op_group_map_.begin(), graph_); } - Iterator end() { return Iterator(op_group_map_.begin(), graph_); } + Iterator end() const { return Iterator(op_group_map_.begin(), graph_); } private: const std::unordered_map, TensorInterfaceList, Hasher, Comparator>& op_group_map_; @@ -140,16 +147,16 @@ class OpGroup { hlir::framework::OpPatternKind kind() const { return group_->kind(); } - OpNodeListView ops() const { - return OpNodeListView(group_->CollectNodes(), graph_); + const OpNodeListView& ops() const { + return ops_; } - OpGroupListView Producers() const { - return OpGroupListView(group_->producer_groups(), graph_); + const OpGroupListView& producers() const { + return producers_; } - OpGroupListView Consumers() const { - return OpGroupListView(group_->consumer_groups(), graph_); + const OpGroupListView &consumers() const { + return consumers_; } std::shared_ptr GetGroup() const { @@ -167,6 +174,11 @@ class OpGroup { private: const std::shared_ptr group_; const hlir::framework::Graph* graph_; + + const OpGroupListView producers_; + const OpGroupListView consumers_; + + const OpNodeListView ops_; }; } // namespace api diff --git a/cinn/api/op_node.h b/cinn/api/op_node.h index efd9791deb..655d238bba 100644 --- 
a/cinn/api/op_node.h +++ b/cinn/api/op_node.h @@ -28,10 +28,9 @@ using Attribute = cinn::utils::Attribute; class OpNode { public: - OpNode(const hlir::framework::Node* node, const hlir::framework::Graph* graph) : node_(node), graph_(graph) { - input_edges_ = node->inlinks_in_order(); - output_edges_ = node->outlinks_in_order(); - } + OpNode(const hlir::framework::Node* node, const hlir::framework::Graph* graph) : node_(node), graph_(graph), input_tensors_(node->inlinks_in_order(), graph_), output_tensors_(node->outlinks_in_order(), graph_) {} + + OpNode(const OpNode& other) : node_(other.node_), graph_(other.graph_), input_tensors_(node_->inlinks_in_order(), graph_), output_tensors_(node_->outlinks_in_order(), graph_) {} OpPatternKind kind () const { thread_local const static hlir::framework::OpValueType& op_pattern_dict = hlir::framework::Operator::GetAttrs("OpPattern"); @@ -48,7 +47,7 @@ class OpNode { class InputTensorListView { public: - InputTensorListView(const hlir::framework::Graph* graph, const std::vector>& edges) : graph_(graph), edges_(edges) {} + InputTensorListView(const std::vector>& edges, const hlir::framework::Graph* graph) : edges_(edges), graph_(graph) {} InputTensorListView(const InputTensorListView& other) = delete; InputTensorListView(InputTensorListView&& other) = delete; @@ -60,41 +59,34 @@ class OpNode { TensorNode operator[](size_t index) const; private: + std::vector> edges_; const hlir::framework::Graph* graph_; - const std::vector>& edges_; }; class OutputTensorListView { public: - OutputTensorListView(const hlir::framework::Graph* graph, const std::vector>& edges) : graph_(graph), edges_(edges) {} + OutputTensorListView(const std::vector>& edges, const hlir::framework::Graph* graph) : edges_(edges), graph_(graph) {} OutputTensorListView(const OutputTensorListView& other) = delete; - OutputTensorListView(OutputTensorListView&& other) = delete; + OutputTensorListView& operator=(const OutputTensorListView& other) = delete; + size_t 
size() const { return edges_.size(); } TensorNode operator[](size_t index) const; private: + std::vector> edges_; const hlir::framework::Graph* graph_; - const std::vector>& edges_; }; - size_t InputsSize() const { - return node_->inlinks().size(); - } - - size_t OutputsSize() const { - return node_->outlinks().size(); - } - - InputTensorListView Inputs() const { - return InputTensorListView(graph_, input_edges_); + const InputTensorListView& inputs() const { + return input_tensors_; } - OutputTensorListView Outputs() const { - return OutputTensorListView(graph_, output_edges_); + const OutputTensorListView& outputs() const { + return output_tensors_; } template @@ -110,8 +102,8 @@ class OpNode { const hlir::framework::Node* node_; const hlir::framework::Graph* graph_; - std::vector> input_edges_; - std::vector> output_edges_; + const InputTensorListView input_tensors_; + const OutputTensorListView output_tensors_; }; } // namespace api diff --git a/cinn/api/tensor_node.cc b/cinn/api/tensor_node.cc index 833756139e..7a77009aee 100644 --- a/cinn/api/tensor_node.cc +++ b/cinn/api/tensor_node.cc @@ -19,7 +19,7 @@ namespace cinn { namespace api { -OpNode TensorNode::Producer() const { +OpNode TensorNode::producer() const { return OpNode(node_data_->source_node.get(), graph_); } diff --git a/cinn/api/tensor_node.h b/cinn/api/tensor_node.h index 4e8822286c..0be223e432 100644 --- a/cinn/api/tensor_node.h +++ b/cinn/api/tensor_node.h @@ -44,8 +44,8 @@ class Shape final { return std::equal(shape_.begin(), shape_.end(), other.shape_.begin()); } - const size_t& operator[] (size_t index) const { - return shape_[index]; + size_t operator[] (size_t index) const { + return shape_.at(index); } size_t at(size_t index) const { @@ -68,7 +68,7 @@ class Shape final { class TensorNode final { public: - TensorNode(const hlir::framework::NodeData* node_data, const hlir::framework::Graph* graph) : node_data_(node_data), graph_(graph) { + TensorNode(const hlir::framework::NodeData* 
node_data, const hlir::framework::Graph* graph) : node_data_(node_data), graph_(graph), consumers_(node_data_->outlinks(), graph_) { const auto& shape_dict = graph_->GetAttrs>("infershape"); CHECK(shape_dict.count(node_data_->id())) << "Can't find " << node_data_->id() << " 's shape!"; shape_.reset(new Shape(shape_dict.find(node_data_->id())->second)); @@ -79,7 +79,7 @@ class TensorNode final { return *shape_; } - OpNode Producer() const; + OpNode producer() const; class ConsumerOpListView { public: @@ -135,18 +135,16 @@ class TensorNode final { const hlir::framework::Graph* graph_; }; - size_t ConsumerSize() const { - return node_data_->outlinks().size(); - } - - ConsumerOpListView Consumers() const { - return ConsumerOpListView(node_data_->outlinks(), graph_); + const ConsumerOpListView& consumers() const { + return consumers_; } private: const hlir::framework::NodeData* node_data_; const hlir::framework::Graph* graph_; + std::unique_ptr shape_; + const ConsumerOpListView consumers_; }; } // namespace api diff --git a/cinn/hlir/pass/general_fusion_merge_pass.cc b/cinn/hlir/pass/general_fusion_merge_pass.cc index bb4cfe1a7c..d2a0f4e5d5 100644 --- a/cinn/hlir/pass/general_fusion_merge_pass.cc +++ b/cinn/hlir/pass/general_fusion_merge_pass.cc @@ -109,14 +109,14 @@ class GraphGroupFuseHelper final : public FuseHelper { private: bool ReachableIfDirectEdgeIgnored(const OpGroupPtr& producer, const OpGroupPtr& consumer) const { - const auto& MinDepth4Node = [&](OpGroupPtr node) { + const auto& MinDepth4Node = [&](const OpGroupPtr& node) { return node.GetGroup()->min_depth; }; - const auto& MaxDepth4Node = [&](OpGroupPtr node) { + const auto& MaxDepth4Node = [&](const OpGroupPtr& node) { return node.GetGroup()->max_depth; }; - const auto& VisitNextNodes = [&](OpGroupPtr node, const std::function& Visit) { - auto producer_groups = producer.Producers(); + const auto& VisitNextNodes = [&](const OpGroupPtr& node, const std::function& Visit) { + const auto& producer_groups = 
producer.producers(); for(auto iter = producer_groups.begin(); iter != producer_groups.end(); ++iter) { if (node == consumer && *iter == producer) { continue; @@ -192,7 +192,7 @@ class GraphGroupLightwareFusePassCtx final : public LightwareFusePassCtx { private: // static std::unordered_map cache_data_; const FusionHelperBase* graph_group_fusion_helper_; - OpGroupPtr group_; + const OpGroupPtr& group_; const std::function EnableFuse_; const std::unique_ptr fuse_helper_; }; @@ -383,7 +383,7 @@ struct HorizontalFuseUtil { } static api::OpNode GetMasterNode(FusePassCtxT* ctx, const OpGroupPtr& op_group) { - auto ops = op_group.ops(); + const auto& ops = op_group.ops(); for (auto iter = ops.begin(); iter != ops.end(); ++iter) { api::OpNode node = *iter; if (node.kind() == OpPatternKind::kReduction) { @@ -397,8 +397,8 @@ struct HorizontalFuseUtil { api::OpNode src_master_node = GetMasterNode(ctx, src); api::OpNode dst_master_node = GetMasterNode(ctx, dst); - auto size_0 = src_master_node.Outputs()[0].shape().numel(); - auto size_1 = dst_master_node.Outputs()[0].shape().numel(); + auto size_0 = src_master_node.outputs()[0].shape().numel(); + auto size_1 = dst_master_node.outputs()[0].shape().numel(); return size_0 == size_1; } @@ -420,13 +420,13 @@ struct HorizontalFuseUtil { reduce_group = &dst; } - size_t size_ele = GetMasterNode(ctx, *ele_group).Outputs()[0].shape().numel(); + size_t size_ele = GetMasterNode(ctx, *ele_group).outputs()[0].shape().numel(); - auto ops = reduce_group->ops(); + const auto& ops = reduce_group->ops(); for (auto iter = ops.begin(); iter!= ops.end(); ++iter) { api::OpNode node = *iter; if (node.kind() == OpPatternKind::kReduction) { - size_t size_master = node.Outputs()[0].shape().numel(); + size_t size_master = node.outputs()[0].shape().numel(); if (size_ele == size_master) { return true; } @@ -537,7 +537,7 @@ class DefaultHorizontalFusePass final : public HorizontalFusePass { const auto& producer = ctx->PickOpGroup(); const OpGroupList 
consumers = [&]() { OpGroupList consumers; - auto consumer_groups = producer.Consumers(); + const auto& consumer_groups = producer.consumers(); for(auto iter = consumer_groups.begin(); iter != consumer_groups.end(); ++iter) { consumers.push_back(*iter); } @@ -588,7 +588,7 @@ class DefaultVerticalFusePass final : public VerticalFusePass { const auto& producer = ctx->PickOpGroup(); const OpGroupList consumers = [&]() { OpGroupList consumers; - auto consumer_groups = producer.Consumers(); + const auto& consumer_groups = producer.consumers(); for(auto iter = consumer_groups.begin(); iter != consumer_groups.end(); ++iter) { consumers.push_back(*iter); } @@ -715,7 +715,7 @@ class DefaultRecomputeFusePass final : public RecomputeFusePass { const auto& producer = ctx->PickOpGroup(); const OpGroupList consumers = [&]() { OpGroupList consumers; - auto consumer_groups = producer.Consumers(); + const auto& consumer_groups = producer.consumers(); for(auto iter = consumer_groups.begin(); iter != consumer_groups.end(); ++iter) { consumers.push_back(*iter); } @@ -1035,7 +1035,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { void EnableFusedHorizontalGroups(LightwareFusePassCtx* ctx) const { const auto& producer = ctx->PickOpGroup(); - if (producer.Consumers().size() <= 1) { + if (producer.consumers().size() <= 1) { return; } const auto& fuse_passes = GetHorizontalFusePasses(); @@ -1425,7 +1425,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { void TagVerticalGroups(LightwareFusePassCtx* ctx) const { const auto& producer = ctx->PickOpGroup(); - if (producer.Consumers().size() == 0) { + if (producer.consumers().size() == 0) { return; } const auto& fuse_passes = GetVerticalFusePasses(); @@ -1672,7 +1672,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { void TagRecomputeGroups(LightwareFusePassCtx* ctx) const { const auto& producer = ctx->PickOpGroup(); - if (producer.Consumers().size() <= 1) { + if (producer.consumers().size() <= 
1) { return; } const auto& fuse_passes = GetRecomputeFusePasses(); From 495f0b41d36b20fcff5c81d4b5f47293541d1128 Mon Sep 17 00:00:00 2001 From: zyfncg Date: Mon, 26 Jun 2023 12:55:50 +0000 Subject: [PATCH 41/66] modify shared_ptr of Group --- cinn/api/op_group.h | 103 +++++++++++--------- cinn/api/op_node.cc | 2 +- cinn/api/op_node.h | 21 ++++ cinn/api/tensor_node.cc | 2 +- cinn/api/tensor_node.h | 12 +-- cinn/hlir/pass/general_fusion_merge_pass.cc | 19 ++-- 6 files changed, 99 insertions(+), 60 deletions(-) diff --git a/cinn/api/op_group.h b/cinn/api/op_group.h index 28e5332156..053239e0ad 100644 --- a/cinn/api/op_group.h +++ b/cinn/api/op_group.h @@ -30,15 +30,9 @@ using Hasher = hlir::framework::Graph::Group::SharedGroupHasher; class OpGroup { public: OpGroup(const std::shared_ptr& group, const hlir::framework::Graph* graph) - : group_(group), graph_(graph), - producers_(group_->producer_groups(), graph_), - consumers_(group_->consumer_groups(), graph_), - ops_(group_->CollectNodes(), graph_) {} + : group_(group), graph_(graph) {} - OpGroup(const OpGroup& other) : group_(other.group_), graph_(other.graph_), - producers_(group_->producer_groups(), graph_), - consumers_(group_->consumer_groups(), graph_), - ops_(group_->CollectNodes(), graph_) {} + OpGroup(const OpGroup& other) = default; class OpNodeListView { public: @@ -91,35 +85,26 @@ class OpGroup { const cinn::hlir::framework::Graph* graph_; }; - class OpGroupListView { - public: - OpGroupListView(const std::unordered_map, TensorInterfaceList, Hasher, Comparator>& group_map, const hlir::framework::Graph* graph) : op_group_map_(group_map), graph_(graph) {} - - OpGroupListView(const OpGroupListView& other) = delete; - OpGroupListView(OpGroupListView&& other) = delete; - - OpGroupListView& operator=(const OpGroupListView& other) = delete; - - class Iterator { + class OpGroupListIterator { public: - Iterator(std::unordered_map, TensorInterfaceList, Hasher, Comparator>::const_iterator it, const 
hlir::framework::Graph* graph) : iter_(it), graph_(graph) {} + OpGroupListIterator(std::unordered_map, TensorInterfaceList, Hasher, Comparator>::const_iterator it, const hlir::framework::Graph* graph) : iter_(it), graph_(graph) {} - Iterator& operator++() { + OpGroupListIterator& operator++() { ++iter_; return *this; } - Iterator operator++(int) { - Iterator tmp = *this; + OpGroupListIterator operator++(int) { + OpGroupListIterator tmp = *this; ++iter_; return tmp; } - bool operator==(const Iterator& other) const { + bool operator==(const OpGroupListIterator& other) const { return iter_ == other.iter_; } - bool operator!=(const Iterator& other) const { + bool operator!=(const OpGroupListIterator& other) const { return !(*this == other); } @@ -130,55 +115,83 @@ class OpGroup { private: std::unordered_map, TensorInterfaceList, Hasher, Comparator>::const_iterator iter_; const hlir::framework::Graph* graph_; - }; + }; + + class ProducerOpGroupListView { + public: + ProducerOpGroupListView(const std::weak_ptr& group, const hlir::framework::Graph* graph) : group_(group), graph_(graph) {} + + ProducerOpGroupListView(const ProducerOpGroupListView& other) = delete; + ProducerOpGroupListView(ProducerOpGroupListView&& other) = delete; + + ProducerOpGroupListView& operator=(const ProducerOpGroupListView& other) = delete; - size_t size() const { return op_group_map_.size(); } + using const_iterator = OpGroupListIterator; - Iterator begin() const { return Iterator(op_group_map_.begin(), graph_); } + size_t size() const { return group_.lock()->producer_groups().size(); } - Iterator end() const { return Iterator(op_group_map_.begin(), graph_); } + const_iterator begin() const { return const_iterator(group_.lock()->producer_groups().begin(), graph_); } + + const_iterator end() const { return const_iterator(group_.lock()->producer_groups().begin(), graph_); } private: - const std::unordered_map, TensorInterfaceList, Hasher, Comparator>& op_group_map_; + const std::weak_ptr group_; 
const cinn::hlir::framework::Graph* graph_; }; + class ConsumerOpGroupListView { + public: + ConsumerOpGroupListView(const std::weak_ptr& group, const hlir::framework::Graph* graph) : group_(group), graph_(graph) {} + + ConsumerOpGroupListView(const ConsumerOpGroupListView& other) = delete; + ConsumerOpGroupListView(ConsumerOpGroupListView&& other) = delete; + + ConsumerOpGroupListView& operator=(const ConsumerOpGroupListView& other) = delete; + using const_iterator = OpGroupListIterator; - hlir::framework::OpPatternKind kind() const { return group_->kind(); } + size_t size() const { return group_.lock()->consumer_groups().size(); } - const OpNodeListView& ops() const { - return ops_; + const_iterator begin() const { return const_iterator(group_.lock()->consumer_groups().begin(), graph_); } + + const_iterator end() const { return const_iterator(group_.lock()->consumer_groups().begin(), graph_); } + + private: + const std::weak_ptr group_; + const cinn::hlir::framework::Graph* graph_; + }; + + + + hlir::framework::OpPatternKind kind() const { return group_.lock()->kind(); } + + OpNodeListView ops() const { + return OpNodeListView(group_.lock()->CollectNodes(), graph_); } - const OpGroupListView& producers() const { - return producers_; + ProducerOpGroupListView producers() const { + return ProducerOpGroupListView(group_, graph_); } - const OpGroupListView &consumers() const { - return consumers_; + ConsumerOpGroupListView consumers() const { + return ConsumerOpGroupListView(group_, graph_); } std::shared_ptr GetGroup() const { - return group_; + return group_.lock(); } bool operator == (const OpGroup& other) const { - return group_.get() == other.group_.get(); + return group_.lock().get() == other.group_.lock().get(); } bool operator < (const OpGroup& other) const { - return group_.get() < other.group_.get(); + return group_.lock().get() < other.group_.lock().get(); } private: - const std::shared_ptr group_; + const std::weak_ptr group_; const 
hlir::framework::Graph* graph_; - - const OpGroupListView producers_; - const OpGroupListView consumers_; - - const OpNodeListView ops_; }; } // namespace api diff --git a/cinn/api/op_node.cc b/cinn/api/op_node.cc index 2444243016..b54f8c9c43 100644 --- a/cinn/api/op_node.cc +++ b/cinn/api/op_node.cc @@ -26,4 +26,4 @@ TensorNode OpNode::OutputTensorListView::operator[](size_t index) const { } } // namespace api -} // namespace cinn \ No newline at end of file +} // namespace cinn diff --git a/cinn/api/op_node.h b/cinn/api/op_node.h index 655d238bba..7e254ca678 100644 --- a/cinn/api/op_node.h +++ b/cinn/api/op_node.h @@ -81,6 +81,14 @@ class OpNode { const hlir::framework::Graph* graph_; }; + bool operator == (const OpNode& other) const { + return node_ == other.node_; + } + + bool operator < (const OpNode& other) const { + return node_ < other.node_; + } + const InputTensorListView& inputs() const { return input_tensors_; } @@ -99,6 +107,8 @@ class OpNode { return node_->attrs.attr_store.at(attr_name); } + friend struct std::hash; + const hlir::framework::Node* node_; const hlir::framework::Graph* graph_; @@ -108,3 +118,14 @@ class OpNode { } // namespace api } // namespace cinn + +namespace std { + +template <> +struct hash { + size_t operator()(const cinn::api::OpNode& obj) const { + return std::hash()(reinterpret_cast(obj.node_)); + } +}; + +} // namespace std diff --git a/cinn/api/tensor_node.cc b/cinn/api/tensor_node.cc index 7a77009aee..66f2f0e321 100644 --- a/cinn/api/tensor_node.cc +++ b/cinn/api/tensor_node.cc @@ -28,4 +28,4 @@ OpNode TensorNode::ConsumerOpListView::Iterator::operator * () const{ } } // namespace api -} // namespace cinn \ No newline at end of file +} // namespace cinn diff --git a/cinn/api/tensor_node.h b/cinn/api/tensor_node.h index 0be223e432..5e43992d69 100644 --- a/cinn/api/tensor_node.h +++ b/cinn/api/tensor_node.h @@ -30,7 +30,7 @@ class OpNode; class Shape final { public: - explicit Shape(const utils::ShapeType& shape) : 
shape_(shape) {} + explicit Shape(const utils::ShapeType& shape) : shape_(shape.begin(), shape.end()) {} Shape(const Shape& other) = delete; Shape(Shape&& other) = delete; @@ -45,11 +45,11 @@ class Shape final { } size_t operator[] (size_t index) const { - return shape_.at(index); + return shape_[index]; } size_t at(size_t index) const { - return shape_.at(index); + return shape_[index]; } size_t size() const { @@ -62,7 +62,7 @@ class Shape final { } private: - const shape_t& shape_; + cinn::utils::SmallVector shape_; }; @@ -71,7 +71,7 @@ class TensorNode final { TensorNode(const hlir::framework::NodeData* node_data, const hlir::framework::Graph* graph) : node_data_(node_data), graph_(graph), consumers_(node_data_->outlinks(), graph_) { const auto& shape_dict = graph_->GetAttrs>("infershape"); CHECK(shape_dict.count(node_data_->id())) << "Can't find " << node_data_->id() << " 's shape!"; - shape_.reset(new Shape(shape_dict.find(node_data_->id())->second)); + shape_ = std::make_shared(shape_dict.find(node_data_->id())->second); } // Get the shape of tensor. 
@@ -143,7 +143,7 @@ class TensorNode final { const hlir::framework::NodeData* node_data_; const hlir::framework::Graph* graph_; - std::unique_ptr shape_; + std::shared_ptr shape_; const ConsumerOpListView consumers_; }; diff --git a/cinn/hlir/pass/general_fusion_merge_pass.cc b/cinn/hlir/pass/general_fusion_merge_pass.cc index d2a0f4e5d5..a1534e9460 100644 --- a/cinn/hlir/pass/general_fusion_merge_pass.cc +++ b/cinn/hlir/pass/general_fusion_merge_pass.cc @@ -116,7 +116,7 @@ class GraphGroupFuseHelper final : public FuseHelper { return node.GetGroup()->max_depth; }; const auto& VisitNextNodes = [&](const OpGroupPtr& node, const std::function& Visit) { - const auto& producer_groups = producer.producers(); + const auto producer_groups = producer.producers(); for(auto iter = producer_groups.begin(); iter != producer_groups.end(); ++iter) { if (node == consumer && *iter == producer) { continue; @@ -318,7 +318,12 @@ bool GraphGroupFuseHelper::ReduceFuseReduce(const OpGroupPtr& src, dst.GetGroup()); } -// limit the group args number to less equal 512, as args stack size is 4K. +// static std::vector GetInputOps(const OpGroupPtr& op_group) { +// std::unordered_set ops_set(op_group.ops().begin(), op_group.ops().end()); + +// } + +// // limit the group args number to less equal 512, as args stack size is 4K. 
// static bool limit_args(const OpGroupPtr& first, const OpGroupPtr& second) { // std::unordered_set args; // for (auto& group : {first, second}) { @@ -383,7 +388,7 @@ struct HorizontalFuseUtil { } static api::OpNode GetMasterNode(FusePassCtxT* ctx, const OpGroupPtr& op_group) { - const auto& ops = op_group.ops(); + const auto ops = op_group.ops(); for (auto iter = ops.begin(); iter != ops.end(); ++iter) { api::OpNode node = *iter; if (node.kind() == OpPatternKind::kReduction) { @@ -422,7 +427,7 @@ struct HorizontalFuseUtil { size_t size_ele = GetMasterNode(ctx, *ele_group).outputs()[0].shape().numel(); - const auto& ops = reduce_group->ops(); + const auto ops = reduce_group->ops(); for (auto iter = ops.begin(); iter!= ops.end(); ++iter) { api::OpNode node = *iter; if (node.kind() == OpPatternKind::kReduction) { @@ -537,7 +542,7 @@ class DefaultHorizontalFusePass final : public HorizontalFusePass { const auto& producer = ctx->PickOpGroup(); const OpGroupList consumers = [&]() { OpGroupList consumers; - const auto& consumer_groups = producer.consumers(); + const auto consumer_groups = producer.consumers(); for(auto iter = consumer_groups.begin(); iter != consumer_groups.end(); ++iter) { consumers.push_back(*iter); } @@ -588,7 +593,7 @@ class DefaultVerticalFusePass final : public VerticalFusePass { const auto& producer = ctx->PickOpGroup(); const OpGroupList consumers = [&]() { OpGroupList consumers; - const auto& consumer_groups = producer.consumers(); + const auto consumer_groups = producer.consumers(); for(auto iter = consumer_groups.begin(); iter != consumer_groups.end(); ++iter) { consumers.push_back(*iter); } @@ -715,7 +720,7 @@ class DefaultRecomputeFusePass final : public RecomputeFusePass { const auto& producer = ctx->PickOpGroup(); const OpGroupList consumers = [&]() { OpGroupList consumers; - const auto& consumer_groups = producer.consumers(); + const auto consumer_groups = producer.consumers(); for(auto iter = consumer_groups.begin(); iter != 
consumer_groups.end(); ++iter) { consumers.push_back(*iter); } From 46ace80862282e5025fd0ba13fc6a5f39f03b0c0 Mon Sep 17 00:00:00 2001 From: jiahongyu Date: Tue, 27 Jun 2023 02:28:52 +0000 Subject: [PATCH 42/66] fix-accuracy-test-bug --- cinn/hlir/pass/general_fusion_merge_pass.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/cinn/hlir/pass/general_fusion_merge_pass.cc b/cinn/hlir/pass/general_fusion_merge_pass.cc index 862126c18e..07849d5277 100644 --- a/cinn/hlir/pass/general_fusion_merge_pass.cc +++ b/cinn/hlir/pass/general_fusion_merge_pass.cc @@ -1192,6 +1192,8 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { // TODO: Do not add any TensorInterface into any TensorInterfaceList in this file which will be deprecated. (*gconsumer->mut_producer_groups())[fused_group] += {}; } + fused_group->mut_consumer_groups()->erase(fused_group); + fused_group->mut_producer_groups()->erase(fused_group); // belongs group consumer->belong_groups.insert(fused_group); From 007d67f309254080627f3e2daf3122bed0fcba5d Mon Sep 17 00:00:00 2001 From: zyfncg Date: Tue, 27 Jun 2023 13:13:43 +0000 Subject: [PATCH 43/66] fix bug --- cinn/api/op_group.h | 25 +- cinn/api/tensor_node.h | 5 + cinn/hlir/pass/general_fusion_merge_pass.cc | 259 +++++++++++++++++--- 3 files changed, 241 insertions(+), 48 deletions(-) diff --git a/cinn/api/op_group.h b/cinn/api/op_group.h index 053239e0ad..59101d44a8 100644 --- a/cinn/api/op_group.h +++ b/cinn/api/op_group.h @@ -43,26 +43,26 @@ class OpGroup { OpNodeListView& operator=(const OpNodeListView& other) = delete; - class Iterator { + class iterator { public: - Iterator(std::vector::const_iterator it, const hlir::framework::Graph* graph) : iter_(it), graph_(graph) {} + iterator(std::vector::const_iterator it, const hlir::framework::Graph* graph) : iter_(it), graph_(graph) {} - Iterator& operator++() { + iterator& operator++() { ++iter_; return *this; } - Iterator operator++(int) { - Iterator tmp = *this; + iterator operator++(int) { + 
iterator tmp = *this; ++iter_; return tmp; } - bool operator==(const Iterator& other) const { + bool operator==(const iterator& other) const { return iter_ == other.iter_; } - bool operator!=(const Iterator& other) const { + bool operator!=(const iterator& other) const { return !(*this == other); } @@ -77,9 +77,9 @@ class OpGroup { size_t size() const { return op_nodes_.size(); } - Iterator begin() const { return Iterator(op_nodes_.begin(), graph_); } + iterator begin() const { return iterator(op_nodes_.begin(), graph_); } - Iterator end() const { return Iterator(op_nodes_.begin(), graph_); } + iterator end() const { return iterator(op_nodes_.end(), graph_); } private: const std::vector op_nodes_; const cinn::hlir::framework::Graph* graph_; @@ -132,7 +132,7 @@ class OpGroup { const_iterator begin() const { return const_iterator(group_.lock()->producer_groups().begin(), graph_); } - const_iterator end() const { return const_iterator(group_.lock()->producer_groups().begin(), graph_); } + const_iterator end() const { return const_iterator(group_.lock()->producer_groups().end(), graph_); } private: const std::weak_ptr group_; @@ -154,13 +154,16 @@ class OpGroup { const_iterator begin() const { return const_iterator(group_.lock()->consumer_groups().begin(), graph_); } - const_iterator end() const { return const_iterator(group_.lock()->consumer_groups().begin(), graph_); } + const_iterator end() const { return const_iterator(group_.lock()->consumer_groups().end(), graph_); } private: const std::weak_ptr group_; const cinn::hlir::framework::Graph* graph_; }; + const std::string& group_id() const { + return group_.lock()->group_id; + } hlir::framework::OpPatternKind kind() const { return group_.lock()->kind(); } diff --git a/cinn/api/tensor_node.h b/cinn/api/tensor_node.h index 5e43992d69..8e315b7e7d 100644 --- a/cinn/api/tensor_node.h +++ b/cinn/api/tensor_node.h @@ -79,6 +79,11 @@ class TensorNode final { return *shape_; } + // Input data has no producer. 
+ bool HasProducer() const { + return node_data_->source_node.get() != nullptr; + } + OpNode producer() const; class ConsumerOpListView { diff --git a/cinn/hlir/pass/general_fusion_merge_pass.cc b/cinn/hlir/pass/general_fusion_merge_pass.cc index a1534e9460..641ba6b52c 100644 --- a/cinn/hlir/pass/general_fusion_merge_pass.cc +++ b/cinn/hlir/pass/general_fusion_merge_pass.cc @@ -116,12 +116,11 @@ class GraphGroupFuseHelper final : public FuseHelper { return node.GetGroup()->max_depth; }; const auto& VisitNextNodes = [&](const OpGroupPtr& node, const std::function& Visit) { - const auto producer_groups = producer.producers(); - for(auto iter = producer_groups.begin(); iter != producer_groups.end(); ++iter) { - if (node == consumer && *iter == producer) { + for(const auto& node_producer : node.producers()) { + if (node == consumer && node_producer == producer) { continue; } - Visit(*iter); + Visit(node_producer); } }; common::IsReachablePredicator is_reachable(MinDepth4Node, MaxDepth4Node, VisitNextNodes); @@ -318,29 +317,218 @@ bool GraphGroupFuseHelper::ReduceFuseReduce(const OpGroupPtr& src, dst.GetGroup()); } -// static std::vector GetInputOps(const OpGroupPtr& op_group) { -// std::unordered_set ops_set(op_group.ops().begin(), op_group.ops().end()); - -// } - -// // limit the group args number to less equal 512, as args stack size is 4K. 
-// static bool limit_args(const OpGroupPtr& first, const OpGroupPtr& second) { -// std::unordered_set args; -// for (auto& group : {first, second}) { -// for (auto node : group->input_nodes) { -// args.insert(node.first); -// } -// for (auto node : group->output_nodes) { -// args.insert(node); -// } -// } - -// if (args.size() > 512) { -// return false; -// } else { -// return true; -// } -// } +static std::unordered_set GetInputOps(const OpGroupPtr& op_group) { + const auto& ops = op_group.ops(); + std::unordered_set ops_set; + for (const auto& op : ops) { + ops_set.insert(op); + } + std::unordered_set input_ops; + for (const auto& op : ops) { + const auto& input_tensors = op.inputs(); + for (size_t i = 0; i < input_tensors.size(); ++i) { + if(input_tensors[i].HasProducer()) { + api::OpNode producer = input_tensors[i].producer(); + if (ops_set.find(producer) == ops_set.end()) { + input_ops.insert(producer); + } + } + } + } + return input_ops; +} + +static std::unordered_set GetOutputOps(const OpGroupPtr& op_group) { + auto ops = op_group.ops(); + std::unordered_set ops_set; + for (const auto& op : ops) { + ops_set.insert(op); + } + std::unordered_set output_ops; + for (const auto& op : ops) { + const auto& output_tensors = op.outputs(); + for (size_t i = 0; i < output_tensors.size(); ++i) { + const auto& consumers = output_tensors[i].consumers(); + for (const auto& consumer : consumers) { + if (ops_set.find(consumer) == ops_set.end()) { + output_ops.insert(consumer); + break; + } + } + } + } + return output_ops; +} + +// limit the group args number to less equal 512, as args stack size is 4K. 
+static bool limit_args(const OpGroupPtr& first, const OpGroupPtr& second) { + std::unordered_set args; + for (auto& group : {first, second}) { + for (const auto& node : GetInputOps(group)) { + args.insert(node); + } + for (const auto& node : GetOutputOps(group)) { + args.insert(node); + } + } + + if (args.size() > 512) { + return false; + } else { + return true; + } +} + +bool WithoutLastDimInReduce(const api::Shape& inshape, const std::vector& axes) { + // if last axis is in reduce. + if (std::find(axes.begin(), axes.end(), inshape.size() - 1) != axes.end() || + std::find(axes.begin(), axes.end(), -1) != axes.end()) { + return false; + } + + int sum_last_axes = 1; + for (int idx = axes.back() + 1; idx < inshape.size(); ++idx) { + sum_last_axes *= inshape[idx]; + } + + if (sum_last_axes > 1) { + return true; + } else { + return false; + } +} + +static int GetSharedSize(const api::OpNode& op_node) { + const auto& producers = op_node.inputs(); + CHECK_GT(producers.size(), 0); + const auto& inshape =producers[0].shape(); + const auto& axes = op_node.GetAttr>("dim"); + if (WithoutLastDimInReduce(inshape, axes)) { + int lane = 1; + for (int idx = axes.back() + 1; idx < inshape.size(); ++idx) { + lane = inshape[idx]; + } + int max_num_threads = common::DefaultNVGPUTarget().max_num_threads(); + if (lane > max_num_threads / 2) { + return 0; + } + int index = axes.size() - 1; + for (; index >= 0; --index) { + if (index + 1 < axes.size() && axes[index] != axes[index + 1] - 1) { + break; + } + lane *= inshape[axes[index]]; + if (lane > max_num_threads / 2) { + break; + } + } + // if lane > (max_num_threads / 2),the loop break from lane > max_num_threads / 2. + int axis = lane > (max_num_threads / 2) ? 
axes[index] : axes[index + 1]; + if (lane <= max_num_threads) { + return lane * sizeof(float); + } else { + int prefix = inshape[axis]; + int tail = lane / prefix; + for (int idx = max_num_threads / tail; idx > ((max_num_threads / 2) / tail); --idx) { + if (prefix % idx == 0) { + return idx * tail * sizeof(float); + } + } + int num = max_num_threads / tail; + return num * tail * sizeof(float); + } + } + return 0; +} + +static bool CanReduceFuseReduce(const OpGroupPtr& first, const OpGroupPtr& second) { + if (!limit_args(first, second)) { + return false; + } + std::unique_ptr reducer_0 = nullptr; + for (const auto& op : first.ops()) { + if (op.kind() == OpPatternKind::kReduction) { + reducer_0.reset(new api::OpNode(op)); + break; + } + } + CHECK(reducer_0) << "Can't find reduce op in group " << first.group_id(); + + std::unique_ptr reducer_1 = nullptr; + for (const auto& op : second.ops()) { + if (op.kind() == OpPatternKind::kReduction) { + reducer_1.reset(new api::OpNode(op)); + break; + } + } + CHECK(reducer_1) << "Can't find reduce op in group " << second.group_id(); + + // check reduce has same input shape and output shape + const auto& reducer_0_input_shape = reducer_0->inputs()[0].shape(); + const auto& reducer_0_output_shape = reducer_0->outputs()[0].shape(); + + const auto& reducer_1_input_shape = reducer_1->inputs()[0].shape(); + const auto& reducer_1_output_shape = reducer_1->outputs()[0].shape(); + + auto reducer_0_reduce_dim = reducer_0->GetAttr>("dim"); + auto reducer_1_reduce_dim = reducer_1->GetAttr>("dim"); + + for (auto& dim : reducer_0_reduce_dim) { + // if dim = -1, set as shape.size() - 1 + if (dim == -1) { + dim = reducer_0_reduce_dim.size() - 1; + } + } + + for (auto& dim : reducer_1_reduce_dim) { + // if dim = -1, set as shape.size() - 1 + if (dim == -1) { + dim = reducer_1_reduce_dim.size() - 1; + } + } + + // check shape is same + if (reducer_0_input_shape == reducer_1_input_shape && reducer_0_output_shape == reducer_1_output_shape && + 
reducer_0_reduce_dim == reducer_1_reduce_dim) { + auto shared_size = 0; + for (auto& fusion_group : {first, second}) { + for (const auto& node : fusion_group.ops()) { + if (node.kind() == OpPatternKind::kReduction) { + shared_size += GetSharedSize(*reducer_0); + } + } + } + +#define MAX_AVAILABLE_SHREAD 32 * 1024 + if (shared_size > MAX_AVAILABLE_SHREAD) { + return false; + } +#undef MAX_AVAILABLE_SHREAD + return true; + } + + if (WithoutLastDimInReduce(reducer_0_input_shape, reducer_0_reduce_dim) && + WithoutLastDimInReduce(reducer_1_input_shape, reducer_1_reduce_dim) && + reducer_0_output_shape == reducer_1_output_shape && reducer_0_reduce_dim == reducer_1_reduce_dim) { + auto shared_size = 0; + for (auto& fusion_group : {first, second}) { + for (const auto& node : fusion_group.ops()) { + if (node.kind() == OpPatternKind::kReduction) { + shared_size += GetSharedSize(*reducer_0); + } + } + } + +#define MAX_AVAILABLE_SHREAD 32 * 1024 + if (shared_size > MAX_AVAILABLE_SHREAD) { + return false; + } +#undef MAX_AVAILABLE_SHREAD + return true; + } + + return false; +} template struct HorizontalFuseUtil { @@ -442,7 +630,7 @@ struct HorizontalFuseUtil { } static bool ReduceFuseReduce(FusePassCtxT* ctx, const OpGroupPtr& src, const OpGroupPtr& dst) { - return ctx->fuse_helper().ReduceFuseReduce(src, dst); + return CanReduceFuseReduce(src, dst); } }; @@ -542,9 +730,8 @@ class DefaultHorizontalFusePass final : public HorizontalFusePass { const auto& producer = ctx->PickOpGroup(); const OpGroupList consumers = [&]() { OpGroupList consumers; - const auto consumer_groups = producer.consumers(); - for(auto iter = consumer_groups.begin(); iter != consumer_groups.end(); ++iter) { - consumers.push_back(*iter); + for (const auto& consumer : producer.consumers()) { + consumers.push_back(consumer); } return consumers; }(); @@ -593,9 +780,8 @@ class DefaultVerticalFusePass final : public VerticalFusePass { const auto& producer = ctx->PickOpGroup(); const OpGroupList consumers = [&]() 
{ OpGroupList consumers; - const auto consumer_groups = producer.consumers(); - for(auto iter = consumer_groups.begin(); iter != consumer_groups.end(); ++iter) { - consumers.push_back(*iter); + for (const auto& consumer : producer.consumers()) { + consumers.push_back(consumer); } return consumers; }(); @@ -720,9 +906,8 @@ class DefaultRecomputeFusePass final : public RecomputeFusePass { const auto& producer = ctx->PickOpGroup(); const OpGroupList consumers = [&]() { OpGroupList consumers; - const auto consumer_groups = producer.consumers(); - for(auto iter = consumer_groups.begin(); iter != consumer_groups.end(); ++iter) { - consumers.push_back(*iter); + for (const auto& consumer : producer.consumers()) { + consumers.push_back(consumer); } return consumers; }(); From d50f1bece97abf6427b22f02d20af4e920cc409c Mon Sep 17 00:00:00 2001 From: zyfncg Date: Tue, 27 Jun 2023 13:52:18 +0000 Subject: [PATCH 44/66] add graph point into group --- cinn/api/op_group.h | 30 +++++++++------------ cinn/hlir/framework/graph.h | 7 +++++ cinn/hlir/pass/general_fusion_merge_pass.cc | 15 +++++------ cinn/hlir/pass/op_fusion_pass.cc | 2 +- 4 files changed, 28 insertions(+), 26 deletions(-) diff --git a/cinn/api/op_group.h b/cinn/api/op_group.h index 59101d44a8..12b7733b8d 100644 --- a/cinn/api/op_group.h +++ b/cinn/api/op_group.h @@ -29,8 +29,8 @@ using Hasher = hlir::framework::Graph::Group::SharedGroupHasher; class OpGroup { public: - OpGroup(const std::shared_ptr& group, const hlir::framework::Graph* graph) - : group_(group), graph_(graph) {} + OpGroup(const std::shared_ptr& group) + : group_(group) {} OpGroup(const OpGroup& other) = default; @@ -87,7 +87,7 @@ class OpGroup { class OpGroupListIterator { public: - OpGroupListIterator(std::unordered_map, TensorInterfaceList, Hasher, Comparator>::const_iterator it, const hlir::framework::Graph* graph) : iter_(it), graph_(graph) {} + OpGroupListIterator(std::unordered_map, TensorInterfaceList, Hasher, Comparator>::const_iterator it) : 
iter_(it) {} OpGroupListIterator& operator++() { ++iter_; @@ -109,17 +109,16 @@ class OpGroup { } OpGroup operator*() const{ - return OpGroup(iter_->first, graph_); + return OpGroup(iter_->first); } private: std::unordered_map, TensorInterfaceList, Hasher, Comparator>::const_iterator iter_; - const hlir::framework::Graph* graph_; }; class ProducerOpGroupListView { public: - ProducerOpGroupListView(const std::weak_ptr& group, const hlir::framework::Graph* graph) : group_(group), graph_(graph) {} + ProducerOpGroupListView(const std::weak_ptr& group) : group_(group) {} ProducerOpGroupListView(const ProducerOpGroupListView& other) = delete; ProducerOpGroupListView(ProducerOpGroupListView&& other) = delete; @@ -130,18 +129,17 @@ class OpGroup { size_t size() const { return group_.lock()->producer_groups().size(); } - const_iterator begin() const { return const_iterator(group_.lock()->producer_groups().begin(), graph_); } + const_iterator begin() const { return const_iterator(group_.lock()->producer_groups().begin()); } - const_iterator end() const { return const_iterator(group_.lock()->producer_groups().end(), graph_); } + const_iterator end() const { return const_iterator(group_.lock()->producer_groups().end()); } private: const std::weak_ptr group_; - const cinn::hlir::framework::Graph* graph_; }; class ConsumerOpGroupListView { public: - ConsumerOpGroupListView(const std::weak_ptr& group, const hlir::framework::Graph* graph) : group_(group), graph_(graph) {} + ConsumerOpGroupListView(const std::weak_ptr& group) : group_(group) {} ConsumerOpGroupListView(const ConsumerOpGroupListView& other) = delete; ConsumerOpGroupListView(ConsumerOpGroupListView&& other) = delete; @@ -152,13 +150,12 @@ class OpGroup { size_t size() const { return group_.lock()->consumer_groups().size(); } - const_iterator begin() const { return const_iterator(group_.lock()->consumer_groups().begin(), graph_); } + const_iterator begin() const { return 
const_iterator(group_.lock()->consumer_groups().begin()); } - const_iterator end() const { return const_iterator(group_.lock()->consumer_groups().end(), graph_); } + const_iterator end() const { return const_iterator(group_.lock()->consumer_groups().end()); } private: const std::weak_ptr group_; - const cinn::hlir::framework::Graph* graph_; }; const std::string& group_id() const { @@ -169,15 +166,15 @@ class OpGroup { hlir::framework::OpPatternKind kind() const { return group_.lock()->kind(); } OpNodeListView ops() const { - return OpNodeListView(group_.lock()->CollectNodes(), graph_); + return OpNodeListView(group_.lock()->CollectNodes(), group_.lock()->graph_); } ProducerOpGroupListView producers() const { - return ProducerOpGroupListView(group_, graph_); + return ProducerOpGroupListView(group_); } ConsumerOpGroupListView consumers() const { - return ConsumerOpGroupListView(group_, graph_); + return ConsumerOpGroupListView(group_); } std::shared_ptr GetGroup() const { @@ -194,7 +191,6 @@ class OpGroup { private: const std::weak_ptr group_; - const hlir::framework::Graph* graph_; }; } // namespace api diff --git a/cinn/hlir/framework/graph.h b/cinn/hlir/framework/graph.h index 73396ac888..7dca7e42ae 100644 --- a/cinn/hlir/framework/graph.h +++ b/cinn/hlir/framework/graph.h @@ -61,6 +61,13 @@ class Graph : public cinn::common::Graph { std::vector> groups; struct Group { + Group() = default; + + Group(const Graph* graph) : graph_(graph) {} + + // The graph that group belongs to. + const Graph* graph_ = nullptr; + // distance to last group. 
int depth{0}; int max_depth{0}; diff --git a/cinn/hlir/pass/general_fusion_merge_pass.cc b/cinn/hlir/pass/general_fusion_merge_pass.cc index 641ba6b52c..c167d7de50 100644 --- a/cinn/hlir/pass/general_fusion_merge_pass.cc +++ b/cinn/hlir/pass/general_fusion_merge_pass.cc @@ -43,7 +43,6 @@ using Comparator = Graph::Group::SharedGroupComparator; using Hasher = Graph::Group::SharedGroupHasher; using OpGroupPtr = api::OpGroup; -// using OpGroupPtr = api::OpGroup; using OpGroupList = std::vector; using ConditionFunction = std::function; @@ -1242,7 +1241,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { const auto& EnableFuse = [&](const OpGroupPtr& first, const OpGroupPtr& second) { tagged_sets.insert(std::make_pair(first, second)); }; - GraphGroupLightwareFusePassCtx fuse_ctx(this, api::OpGroup(producer, this->graph_), EnableFuse); + GraphGroupLightwareFusePassCtx fuse_ctx(this, api::OpGroup(producer), EnableFuse); EnableFusedHorizontalGroups(&fuse_ctx); return tagged_sets; }; @@ -1294,7 +1293,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { OpGroupList consumer_groups; consumer_groups.reserve(consumers.size()); for(auto& consumer : consumers) { - consumer_groups.push_back(api::OpGroup(consumer, this->graph_)); + consumer_groups.push_back(api::OpGroup(consumer)); } GraphGroupInputFusePassCtx fuse_ctx(this, consumer_groups, EnableFuse); EnableFusedInputGroups(&fuse_ctx); @@ -1390,7 +1389,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { void HorizontalFuse(const GroupList& consumers) { VLOG(3) << "HorizontalFuse Groups..."; // create fusion group - auto fused_group = std::make_shared(); + auto fused_group = std::make_shared(graph_); // As recompute exist which may case sub-group used by more than one time. 
std::vector repeat_sub_groups; std::unordered_set sub_group_set; @@ -1632,7 +1631,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { const auto& EnableFuse = [&](const OpGroupPtr& first, const OpGroupPtr& second) { tagged_sets.push_back(std::make_pair(first, second)); }; - GraphGroupLightwareFusePassCtx fuse_ctx(this, api::OpGroup(producer, this->graph_), EnableFuse); + GraphGroupLightwareFusePassCtx fuse_ctx(this, api::OpGroup(producer), EnableFuse); TagVerticalGroups(&fuse_ctx); return tagged_sets; }; @@ -1666,7 +1665,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { GroupList fused_groups; GroupPtr master_fuesd_group(nullptr); for (auto& consumer : fusionable_consumers) { - auto fused_group = std::make_shared(); + auto fused_group = std::make_shared(graph_); // update depth using consumer depth. fused_group->max_depth = std::max(producer->max_depth, consumer->max_depth); fused_group->min_depth = std::min(producer->min_depth, consumer->min_depth); @@ -1879,7 +1878,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { const auto& EnableFuse = [&](const OpGroupPtr& first, const OpGroupPtr& second) { tagged_sets.insert(std::make_pair(first, second)); }; - GraphGroupLightwareFusePassCtx fuse_ctx(this, api::OpGroup(producer, this->graph_), EnableFuse); + GraphGroupLightwareFusePassCtx fuse_ctx(this, api::OpGroup(producer), EnableFuse); TagRecomputeGroups(&fuse_ctx); return tagged_sets; }; @@ -2162,7 +2161,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { // init the postion of groups in fusion groups. for (int idx = 0; idx < fusion_groups_.size(); ++idx) { auto group = fusion_groups_[idx]; - auto belong_group = std::make_shared(); + auto belong_group = std::make_shared(graph_); // copy from group. 
belong_group->max_depth = group->depth; belong_group->min_depth = group->depth; diff --git a/cinn/hlir/pass/op_fusion_pass.cc b/cinn/hlir/pass/op_fusion_pass.cc index 759ca24d25..9e8abb812d 100644 --- a/cinn/hlir/pass/op_fusion_pass.cc +++ b/cinn/hlir/pass/op_fusion_pass.cc @@ -48,7 +48,7 @@ class OpFusionPassHelper : public FusionHelperBase { auto node = graph_node->safe_as(); if (node) { nodes_.push_back(node); - auto group = std::make_shared(); + auto group = std::make_shared(graph); // init group group->nodes.push_back(node); group->nodes_set.insert(node); From 0a9b7b2e3e844fdfac5b228a7e384e2902731c4f Mon Sep 17 00:00:00 2001 From: zyfncg Date: Tue, 27 Jun 2023 15:28:54 +0000 Subject: [PATCH 45/66] add shape.h file --- cinn/api/op_interface.h | 55 ---------------------------------- cinn/api/shape.h | 65 +++++++++++++++++++++++++++++++++++++++++ cinn/api/tensor_node.h | 39 +------------------------ 3 files changed, 66 insertions(+), 93 deletions(-) delete mode 100644 cinn/api/op_interface.h create mode 100644 cinn/api/shape.h diff --git a/cinn/api/op_interface.h b/cinn/api/op_interface.h deleted file mode 100644 index b6c2abb69c..0000000000 --- a/cinn/api/op_interface.h +++ /dev/null @@ -1,55 +0,0 @@ -// Copyright (c) 2023 CINN Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#pragma once - -#include - -#include "cinn/api/tensor_interface.h" -#include "cinn/utils/type_defs.h" -#include "cinn/hlir/framework/op.h" - -namespace cinn { -namespace api { - -using OpPatternKind = cinn::hlir::framework::OpPatternKind; -using Attribute = cinn::utils::Attribute; - -class OpInterface { - public: - virtual OpPatternKind kind () = 0; - - virtual size_t InputsSize() const = 0; - virtual TensorInterface Inputs(size_t i) const = 0; - - virtual const TensorInterfaceList& Inputs() = 0; - virtual const TensorInterfaceList& Outputs() = 0; - - template - const T& GetAttr(const std::string& attr_name) const { - return absl::get(GetAttr(attr_name)); - } - - protected: - OpInterface() = default; - OpInterface(const OpInterface&) = delete; - OpInterface(OpInterface&&) = delete; - - virtual const Attribute& GetAttr(const std::string& attr_name) = 0; -}; - -using OpInterfacePtr = std::shared_ptr; - -} // namespace api -} // namespace cinn diff --git a/cinn/api/shape.h b/cinn/api/shape.h new file mode 100644 index 0000000000..7fd9be2f77 --- /dev/null +++ b/cinn/api/shape.h @@ -0,0 +1,65 @@ +// Copyright (c) 2023 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include + +#include "cinn/hlir/framework/graph.h" +#include "cinn/utils/type_defs.h" +#include "cinn/hlir/pass/fusion_helper_base.h" +#include "cinn/utils/small_vector.h" + +namespace cinn { +namespace api { + +class Shape final { + public: + explicit Shape(const utils::ShapeType& shape) : shape_(shape.begin(), shape.end()) {} + + Shape(const Shape& other) = delete; + Shape(Shape&& other) = delete; + + Shape& operator=(const Shape& other) = delete; + + bool operator == (const Shape& other) const { + if (shape_.size() != other.shape_.size()) { + return false; + } + return std::equal(shape_.begin(), shape_.end(), other.shape_.begin()); + } + + size_t operator[] (size_t index) const { + return shape_[index]; + } + + size_t at(size_t index) const { + return shape_[index]; + } + + size_t size() const { + return shape_.size(); + } + + // Returns the total number of elements in the shape. + size_t numel() const { + return std::accumulate(shape_.begin(), shape_.end(), 1, std::multiplies());; + } + + private: + cinn::utils::SmallVector shape_; +}; + +} // namespace api +} // namespace cinn diff --git a/cinn/api/tensor_node.h b/cinn/api/tensor_node.h index 8e315b7e7d..57ab6c5e31 100644 --- a/cinn/api/tensor_node.h +++ b/cinn/api/tensor_node.h @@ -20,6 +20,7 @@ #include "cinn/utils/type_defs.h" #include "cinn/hlir/pass/fusion_helper_base.h" #include "cinn/utils/small_vector.h" +#include "cinn/api/shape.h" namespace cinn { namespace api { @@ -28,44 +29,6 @@ using shape_t = utils::ShapeType; class OpNode; -class Shape final { - public: - explicit Shape(const utils::ShapeType& shape) : shape_(shape.begin(), shape.end()) {} - - Shape(const Shape& other) = delete; - Shape(Shape&& other) = delete; - - Shape& operator=(const Shape& other) = delete; - - bool operator == (const Shape& other) const { - if (shape_.size() != other.shape_.size()) { - return false; - } - return std::equal(shape_.begin(), shape_.end(), other.shape_.begin()); - } - - size_t operator[] 
(size_t index) const { - return shape_[index]; - } - - size_t at(size_t index) const { - return shape_[index]; - } - - size_t size() const { - return shape_.size(); - } - - // Returns the total number of elements in the shape. - size_t numel() const { - return std::accumulate(shape_.begin(), shape_.end(), 1, std::multiplies());; - } - - private: - cinn::utils::SmallVector shape_; -}; - - class TensorNode final { public: TensorNode(const hlir::framework::NodeData* node_data, const hlir::framework::Graph* graph) : node_data_(node_data), graph_(graph), consumers_(node_data_->outlinks(), graph_) { From 8e16c502dc80dc7e281c25469ad8e47989ad707b Mon Sep 17 00:00:00 2001 From: zyfncg Date: Wed, 28 Jun 2023 04:09:58 +0000 Subject: [PATCH 46/66] revert producer and consumer group from map to set in group --- cinn/api/op_group.h | 6 +- cinn/hlir/framework/graph.h | 15 ++-- cinn/hlir/pass/fusion_merge_pass.cc | 78 ++++++--------------- cinn/hlir/pass/general_fusion_merge_pass.cc | 73 +++++-------------- cinn/hlir/pass/op_fusion_pass.cc | 15 ++-- 5 files changed, 55 insertions(+), 132 deletions(-) diff --git a/cinn/api/op_group.h b/cinn/api/op_group.h index 12b7733b8d..4a086488d8 100644 --- a/cinn/api/op_group.h +++ b/cinn/api/op_group.h @@ -87,7 +87,7 @@ class OpGroup { class OpGroupListIterator { public: - OpGroupListIterator(std::unordered_map, TensorInterfaceList, Hasher, Comparator>::const_iterator it) : iter_(it) {} + OpGroupListIterator(std::unordered_set, Hasher, Comparator>::const_iterator it) : iter_(it) {} OpGroupListIterator& operator++() { ++iter_; @@ -109,11 +109,11 @@ class OpGroup { } OpGroup operator*() const{ - return OpGroup(iter_->first); + return OpGroup(*iter_); } private: - std::unordered_map, TensorInterfaceList, Hasher, Comparator>::const_iterator iter_; + std::unordered_set, Hasher, Comparator>::const_iterator iter_; }; class ProducerOpGroupListView { diff --git a/cinn/hlir/framework/graph.h b/cinn/hlir/framework/graph.h index 7dca7e42ae..f85075c39a 
100644 --- a/cinn/hlir/framework/graph.h +++ b/cinn/hlir/framework/graph.h @@ -22,7 +22,6 @@ #include #include "cinn/api/op_group_interface.h" -#include "cinn/api/tensor_interface_list.h" #include "cinn/common/graph_utils.h" #include "cinn/frontend/syntax.h" #include "cinn/hlir/framework/node.h" @@ -114,7 +113,7 @@ class Graph : public cinn::common::Graph { std::unordered_set, SharedGroupHasher, SharedGroupComparator> CollectConsumerGroups() { std::unordered_set, SharedGroupHasher, SharedGroupComparator> groups; for (const auto& consumer_and_list : consumer_groups_) { - groups.insert(std::dynamic_pointer_cast(consumer_and_list.first)); + groups.insert(std::dynamic_pointer_cast(consumer_and_list)); } return groups; } @@ -145,19 +144,19 @@ class Graph : public cinn::common::Graph { std::string GetFuncName() { return "fn_" + group_id + unique_id; } public: - const std::unordered_map, TensorInterfaceList, SharedGroupHasher, SharedGroupComparator>& producer_groups() const { + const std::unordered_set, SharedGroupHasher, SharedGroupComparator>& producer_groups() const { return producer_groups_; } - const std::unordered_map, TensorInterfaceList, SharedGroupHasher, SharedGroupComparator>& consumer_groups() const { + const std::unordered_set, SharedGroupHasher, SharedGroupComparator>& consumer_groups() const { return consumer_groups_; } - std::unordered_map, TensorInterfaceList, SharedGroupHasher, SharedGroupComparator>* mut_producer_groups() { + std::unordered_set, SharedGroupHasher, SharedGroupComparator>* mut_producer_groups() { return &producer_groups_; } - std::unordered_map, TensorInterfaceList, SharedGroupHasher, SharedGroupComparator>* mut_consumer_groups() { + std::unordered_set, SharedGroupHasher, SharedGroupComparator>* mut_consumer_groups() { return &consumer_groups_; } @@ -165,9 +164,9 @@ class Graph : public cinn::common::Graph { private: // input groups - std::unordered_map, TensorInterfaceList, SharedGroupHasher, SharedGroupComparator> producer_groups_; + 
std::unordered_set, SharedGroupHasher, SharedGroupComparator> producer_groups_; // output grous - std::unordered_map, TensorInterfaceList, SharedGroupHasher, SharedGroupComparator> consumer_groups_; + std::unordered_set, SharedGroupHasher, SharedGroupComparator> consumer_groups_; }; std::vector> fusion_groups; diff --git a/cinn/hlir/pass/fusion_merge_pass.cc b/cinn/hlir/pass/fusion_merge_pass.cc index ffa1eeaf30..356ce6e867 100644 --- a/cinn/hlir/pass/fusion_merge_pass.cc +++ b/cinn/hlir/pass/fusion_merge_pass.cc @@ -64,13 +64,11 @@ class FusionMergePassHelper : public FusionHelperBase { for (auto& sub_group : group->fused_sub_groups) { VLOG(3) << " Fused Sub-Group -> " << sub_group->group_id; } - for (const auto& pair : group->producer_groups()) { - const auto& producer = pair.first; - VLOG(3) << " Producer -> " << std::dynamic_pointer_cast(producer)->group_id; + for (const auto& producer : group->producer_groups()) { + VLOG(3) << " Producer -> " << producer->group_id; } - for (const auto& pair : group->consumer_groups()) { - const auto& consumer = pair.first; - VLOG(3) << " Consumer -> " << std::dynamic_pointer_cast(consumer)->group_id; + for (const auto& consumer : group->consumer_groups()) { + VLOG(3) << " Consumer -> " << consumer->group_id; } } return fusion_groups_; @@ -155,8 +153,7 @@ class FusionMergePassHelper : public FusionHelperBase { } bool exist = false; - for (const auto& pair : group->producer_groups()) { - const auto& producer = std::dynamic_pointer_cast(pair.first); + for (const auto& producer : group->producer_groups()) { if (fusion_groups_set.count(producer)) { VLOG(4) << group->group_id << " " << producer->group_id; exist = true; @@ -321,22 +318,14 @@ class FusionMergePassHelper : public FusionHelperBase { fused_group->fused_sub_groups.push_back(consumer); } // producer group - for (const auto& producer_and_list : consumer->producer_groups()) { - GroupPtr producer = std::dynamic_pointer_cast(producer_and_list.first); - 
(*fused_group->mut_producer_groups())[producer] += producer_and_list.second; + for (const auto& producer : consumer->producer_groups()) { // update producer's consumer producer->mut_consumer_groups()->erase(consumer); - // TODO: Do not add any TensorInterface into any TensorInterfaceList in this file which will be deprecated. - (*producer->mut_consumer_groups())[fused_group] += {}; } // consumer group - for (const auto& gconsumer_and_list : consumer->consumer_groups()) { - GroupPtr gconsumer = std::dynamic_pointer_cast(gconsumer_and_list.first); - (*fused_group->mut_consumer_groups())[gconsumer] += gconsumer_and_list.second; + for (const auto& gconsumer : consumer->consumer_groups()) { // update consumer's producer gconsumer->mut_producer_groups()->erase(consumer); - // TODO: Do not add any TensorInterface into any TensorInterfaceList in this file which will be deprecated. - (*gconsumer->mut_producer_groups())[fused_group] += {}; } // belongs group consumer->belong_groups.insert(fused_group); @@ -509,13 +498,9 @@ class FusionMergePassHelper : public FusionHelperBase { } // producer groups - for (const auto& group_and_list : producer->producer_groups()) { - (*fused_group->mut_producer_groups())[group_and_list.first] += group_and_list.second; - const auto& group = std::dynamic_pointer_cast(group_and_list.first); + for (const auto& group : producer->producer_groups()) { // update producer's producer's consumer group->mut_consumer_groups()->erase(producer); - // TODO: Do not add any TensorInterface into any TensorInterfaceList in this file which will be deprecated. 
- (*group->mut_consumer_groups())[fused_group] += {}; } // sub groups @@ -561,25 +546,17 @@ class FusionMergePassHelper : public FusionHelperBase { } // producer nodes - for (const auto& group_and_list : consumer->producer_groups()) { - if (group_and_list.first.get() != producer.get()) { - (*fused_group->mut_producer_groups())[group_and_list.first] += group_and_list.second; - const GroupPtr& group = std::dynamic_pointer_cast(group_and_list.first); + for (const auto& group : consumer->producer_groups()) { + if (group.get() != producer.get()) { // update consumer's producer's consumer group->mut_consumer_groups()->erase(consumer); - // TODO: Do not add any TensorInterface into any TensorInterfaceList in this file which will be deprecated. - (*group->mut_consumer_groups())[fused_group] += {}; } } // consumer nodes - for (const auto& group_and_list : consumer->consumer_groups()) { - (*fused_group->mut_consumer_groups())[group_and_list.first] += group_and_list.second; - const GroupPtr& group = std::dynamic_pointer_cast(group_and_list.first); + for (const auto& group : consumer->consumer_groups()) { // update consumer's consumer's producer group->mut_producer_groups()->erase(consumer); - // TODO: Do not add any TensorInterface into any TensorInterfaceList in this file which will be deprecated. - (*group->mut_producer_groups())[fused_group] += {}; } // sub group @@ -613,8 +590,7 @@ class FusionMergePassHelper : public FusionHelperBase { for (auto& node : producer->output_nodes) { bool be_output = true; - for (const auto& consumer_and_list : producer->consumer_groups()) { - const auto& consumer = std::dynamic_pointer_cast(consumer_and_list.first); + for (const auto& consumer : producer->consumer_groups()) { // if consumer is in fusionable. 
if (fusionable_consumers.count(consumer)) { if (consumer->input_nodes.count(node)) { @@ -640,16 +616,12 @@ class FusionMergePassHelper : public FusionHelperBase { } } // insert unfusionable consumer groups - for (const auto& consumer_and_list : producer->consumer_groups()) { - const auto& consumer = std::dynamic_pointer_cast(consumer_and_list.first); + for (const auto& consumer : producer->consumer_groups()) { if (fusionable_consumers.count(consumer)) { continue; } - (*master_fuesd_group->mut_consumer_groups())[consumer_and_list.first] += consumer_and_list.second; // update consumer's producer consumer->mut_producer_groups()->erase(producer); - // TODO: Do not add any TensorInterface into any TensorInterfaceList in this file which will be deprecated. - (*consumer->mut_producer_groups())[master_fuesd_group] += {}; } } @@ -765,11 +737,10 @@ class FusionMergePassHelper : public FusionHelperBase { while (!candidates.empty()) { auto& candidate = candidates.front(); candidates.pop(); - for (const auto& producer_and_list : candidate->producer_groups()) { - if (producer_and_list.first.get() == producer_g.get()) { + for (const auto& producer : candidate->producer_groups()) { + if (producer.get() == producer_g.get()) { continue; } - const auto& producer = std::dynamic_pointer_cast(producer_and_list.first); if (consumers.count(producer)) { return true; } @@ -793,11 +764,10 @@ class FusionMergePassHelper : public FusionHelperBase { while (!candidates.empty()) { auto& candidate = candidates.front(); candidates.pop(); - for (auto& producer_and_list : candidate->producer_groups()) { - if (producer_and_list.first.get() == producer_g.get()) { + for (auto& producer : candidate->producer_groups()) { + if (producer.get() == producer_g.get()) { continue; } - const auto& producer = std::dynamic_pointer_cast(producer_and_list.first); if (producer->min_depth > check_upper_depth) { continue; } @@ -907,21 +877,15 @@ class FusionMergePassHelper : public FusionHelperBase { // update producer 
and consumer. for (auto& group : fusion_groups_) { - std::unordered_map producers; - std::unordered_map consumers; + std::unordered_set producers; + std::unordered_set consumers; - for (const auto& producer_and_list : group->producer_groups()) { - const auto& producer = producer_and_list.first; + for (const auto& producer : group->producer_groups()) { CHECK(producer->belong_groups.size()); - // TODO: Do not add any TensorInterface into any TensorInterfaceList in this file which will be deprecated. - producers[*producer->belong_groups.begin()] += {}; } - for (auto& consumer_and_list : group->consumer_groups()) { - const auto& consumer = std::dynamic_pointer_cast(consumer_and_list.first); + for (auto& consumer : group->consumer_groups()) { CHECK(consumer->belong_groups.size()); - // TODO: Do not add any TensorInterface into any TensorInterfaceList in this file which will be deprecated. - consumers[*consumer->belong_groups.begin()] += {}; } CHECK_EQ(group->producer_groups().size(), producers.size()); CHECK_EQ(group->consumer_groups().size(), consumers.size()); diff --git a/cinn/hlir/pass/general_fusion_merge_pass.cc b/cinn/hlir/pass/general_fusion_merge_pass.cc index c167d7de50..23f2507d89 100644 --- a/cinn/hlir/pass/general_fusion_merge_pass.cc +++ b/cinn/hlir/pass/general_fusion_merge_pass.cc @@ -15,7 +15,6 @@ #include #include -#include "cinn/api/op_group_interface.h" #include "cinn/api/op_group.h" #include "cinn/common/is_reachable_predicator.h" #include "cinn/common/macros.h" @@ -1048,12 +1047,10 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { for (auto& sub_group : group->fused_sub_groups) { VLOG(3) << " Fused Sub-Group -> " << sub_group->group_id; } - for (const auto& pair : group->producer_groups()) { - const auto& producer = std::dynamic_pointer_cast(pair.first); + for (const auto& producer : group->producer_groups()) { VLOG(3) << " Producer -> " << producer->group_id; } - for (const auto& pair : group->consumer_groups()) { - const auto& 
consumer = std::dynamic_pointer_cast(pair.first); + for (const auto& consumer : group->consumer_groups()) { VLOG(3) << " Consumer -> " << consumer->group_id; } } @@ -1189,8 +1186,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { } bool exist = false; - for (const auto& pair : group->producer_groups()) { - const auto& producer = std::dynamic_pointer_cast(pair.first); + for (const auto& producer : group->producer_groups()) { if (fusion_groups_set.count(producer)) { VLOG(4) << group->group_id << " " << producer->group_id; exist = true; @@ -1458,22 +1454,14 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { fused_group->fused_sub_groups.push_back(consumer); } // producer group - for (const auto& producer_and_list : consumer->producer_groups()) { - GroupPtr producer = std::dynamic_pointer_cast(producer_and_list.first); - (*fused_group->mut_producer_groups())[producer] += producer_and_list.second; + for (const auto& producer : consumer->producer_groups()) { // update producer's consumer producer->mut_consumer_groups()->erase(consumer); - // TODO: Do not add any TensorInterface into any TensorInterfaceList in this file which will be deprecated. - (*producer->mut_consumer_groups())[fused_group] += {}; } // consumer group - for (const auto& gconsumer_and_list : consumer->consumer_groups()) { - GroupPtr gconsumer = std::dynamic_pointer_cast(gconsumer_and_list.first); - (*fused_group->mut_consumer_groups())[gconsumer] += gconsumer_and_list.second; + for (const auto& gconsumer : consumer->consumer_groups()) { // update consumer's producer gconsumer->mut_producer_groups()->erase(consumer); - // TODO: Do not add any TensorInterface into any TensorInterfaceList in this file which will be deprecated. 
- (*gconsumer->mut_producer_groups())[fused_group] += {}; } // belongs group consumer->belong_groups.insert(fused_group); @@ -1703,13 +1691,9 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { } // producer groups - for (const auto& group_and_list : producer->producer_groups()) { - (*fused_group->mut_producer_groups())[group_and_list.first] += group_and_list.second; - const auto& group = std::dynamic_pointer_cast(group_and_list.first); + for (const auto& group : producer->producer_groups()) { // update producer's producer's consumer group->mut_consumer_groups()->erase(producer); - // TODO: Do not add any TensorInterface into any TensorInterfaceList in this file which will be deprecated. - (*group->mut_consumer_groups())[fused_group] += {}; } // delete consumer group in producer @@ -1758,25 +1742,17 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { } // producer nodes - for (const auto& group_and_list : consumer->producer_groups()) { - if (group_and_list.first.get() != producer.get()) { - (*fused_group->mut_producer_groups())[group_and_list.first] += group_and_list.second; - const GroupPtr& group = std::dynamic_pointer_cast(group_and_list.first); + for (const auto& group : consumer->producer_groups()) { + if (group.get() != producer.get()) { // update consumer's producer's consumer group->mut_consumer_groups()->erase(consumer); - // TODO: Do not add any TensorInterface into any TensorInterfaceList in this file which will be deprecated. 
- (*group->mut_consumer_groups())[fused_group] += {}; } } // consumer nodes - for (const auto& group_and_list : consumer->consumer_groups()) { - (*fused_group->mut_consumer_groups())[group_and_list.first] += group_and_list.second; - const GroupPtr& group = std::dynamic_pointer_cast(group_and_list.first); + for (const auto& group : consumer->consumer_groups()) { // update consumer's consumer's producer group->mut_producer_groups()->erase(consumer); - // TODO: Do not add any TensorInterface into any TensorInterfaceList in this file which will be deprecated. - (*group->mut_producer_groups())[fused_group] += {}; } // sub group @@ -1810,8 +1786,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { for (auto& node : producer->output_nodes) { bool be_output = true; - for (const auto& consumer_and_list : producer->consumer_groups()) { - const auto& consumer = std::dynamic_pointer_cast(consumer_and_list.first); + for (const auto& consumer : producer->consumer_groups()) { // if consumer is in fusionable. if (fusionable_consumers.count(consumer)) { if (consumer->input_nodes.count(node)) { @@ -1837,16 +1812,12 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { } } // insert unfusionable consumer groups - for (const auto& consumer_and_list : producer->consumer_groups()) { - const auto& consumer = std::dynamic_pointer_cast(consumer_and_list.first); + for (const auto& consumer : producer->consumer_groups()) { if (fusionable_consumers.count(consumer)) { continue; } - (*master_fuesd_group->mut_consumer_groups())[consumer_and_list.first] += consumer_and_list.second; // update consumer's producer consumer->mut_producer_groups()->erase(producer); - // TODO: Do not add any TensorInterface into any TensorInterfaceList in this file which will be deprecated. 
- (*consumer->mut_producer_groups())[master_fuesd_group] += {}; } } @@ -2021,10 +1992,10 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { auto& candidate = candidates.front(); candidates.pop(); for (const auto& producer_and_list : candidate->producer_groups()) { - if (producer_and_list.first.get() == producer_g.get()) { + if (producer_and_list.get() == producer_g.get()) { continue; } - const auto& producer = std::dynamic_pointer_cast(producer_and_list.first); + const auto& producer = std::dynamic_pointer_cast(producer_and_list); if (consumers.count(producer)) { return true; } @@ -2049,10 +2020,10 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { auto& candidate = candidates.front(); candidates.pop(); for (auto& producer_and_list : candidate->producer_groups()) { - if (producer_and_list.first.get() == producer_g.get()) { + if (producer_and_list.get() == producer_g.get()) { continue; } - const auto& producer = std::dynamic_pointer_cast(producer_and_list.first); + const auto& producer = std::dynamic_pointer_cast(producer_and_list); if (producer->min_depth > check_upper_depth) { continue; } @@ -2182,21 +2153,15 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { // update producer and consumer. for (auto& group : fusion_groups_) { - std::unordered_map producers; - std::unordered_map consumers; + std::unordered_set producers; + std::unordered_set consumers; - for (auto& producer_and_list : group->producer_groups()) { - const auto& producer = std::dynamic_pointer_cast(producer_and_list.first); + for (auto& producer : group->producer_groups()) { CHECK(producer->belong_groups.size()); - // TODO: Do not add any TensorInterface into any TensorInterfaceList in this file which will be deprecated. 
- producers[*producer->belong_groups.begin()] += {}; } - for (auto& consumer_and_list : group->consumer_groups()) { - const auto& consumer = std::dynamic_pointer_cast(consumer_and_list.first); + for (auto& consumer : group->consumer_groups()) { CHECK(consumer->belong_groups.size()); - // TODO: Do not add any TensorInterface into any TensorInterfaceList in this file which will be deprecated. - consumers[*consumer->belong_groups.begin()] += {}; } CHECK_EQ(group->producer_groups().size(), producers.size()); CHECK_EQ(group->consumer_groups().size(), consumers.size()); diff --git a/cinn/hlir/pass/op_fusion_pass.cc b/cinn/hlir/pass/op_fusion_pass.cc index 9e8abb812d..a7969aaa23 100644 --- a/cinn/hlir/pass/op_fusion_pass.cc +++ b/cinn/hlir/pass/op_fusion_pass.cc @@ -100,18 +100,13 @@ class OpFusionPassHelper : public FusionHelperBase { for (auto& consumer : fusion_groups) { for (auto& input_node : consumer->input_nodes) { auto& producer = fusion_groups_[input_node.first]; - // TODO: Do not add any TensorInterface into any TensorInterfaceList in this file which will be deprecated. - (*consumer->mut_producer_groups())[producer] += {}; - // TODO: Do not add any TensorInterface into any TensorInterfaceList in this file which will be deprecated. - (*producer->mut_consumer_groups())[consumer] += {}; } } // init group depth. for (auto& group : fusion_groups) { - for (const auto& consumer_and_list : group->consumer_groups()) { + for (const auto& consumer : group->consumer_groups()) { // update depth. 
- const auto& consumer = std::dynamic_pointer_cast(consumer_and_list.first); group->depth = std::max(group->depth, consumer->depth + 1); } } @@ -351,11 +346,11 @@ void OpFusionPassInternal(Graph* graph) { for (auto& group : graph->fusion_groups) { VLOG(3) << "Group Id : " << group->group_id; - for (const auto& producer_and_list : group->producer_groups()) { - VLOG(3) << " producer group -> " << std::dynamic_pointer_cast(producer_and_list.first)->group_id; + for (const auto& producer : group->producer_groups()) { + VLOG(3) << " producer group -> " << producer->group_id; } - for (const auto& consumer_and_list : group->consumer_groups()) { - VLOG(3) << " consumer group -> " << std::dynamic_pointer_cast(consumer_and_list.first)->group_id; + for (const auto& consumer : group->consumer_groups()) { + VLOG(3) << " consumer group -> " << consumer->group_id; } } VLOG(3) << "OpFusionPass Finish...!"; From 0274349158395bf81183a67d76c3c8c690cc5812 Mon Sep 17 00:00:00 2001 From: zyfncg Date: Wed, 28 Jun 2023 08:17:28 +0000 Subject: [PATCH 47/66] replace ops with WalkOpNodes in op_graph --- cinn/api/op_group.h | 74 ++++++------------ cinn/api/op_node.h | 4 +- cinn/hlir/framework/graph.h | 14 ++++ cinn/hlir/pass/general_fusion_merge_pass.cc | 84 ++++++++++----------- 4 files changed, 78 insertions(+), 98 deletions(-) diff --git a/cinn/api/op_group.h b/cinn/api/op_group.h index 4a086488d8..40df305eac 100644 --- a/cinn/api/op_group.h +++ b/cinn/api/op_group.h @@ -34,57 +34,6 @@ class OpGroup { OpGroup(const OpGroup& other) = default; - class OpNodeListView { - public: - explicit OpNodeListView(std::vector op_nodes, const cinn::hlir::framework::Graph* graph) : op_nodes_(std::move(op_nodes)), graph_(graph) {} - - OpNodeListView(const OpNodeListView& other) = delete; - OpNodeListView(OpNodeListView&& other) = delete; - - OpNodeListView& operator=(const OpNodeListView& other) = delete; - - class iterator { - public: - iterator(std::vector::const_iterator it, const 
hlir::framework::Graph* graph) : iter_(it), graph_(graph) {} - - iterator& operator++() { - ++iter_; - return *this; - } - - iterator operator++(int) { - iterator tmp = *this; - ++iter_; - return tmp; - } - - bool operator==(const iterator& other) const { - return iter_ == other.iter_; - } - - bool operator!=(const iterator& other) const { - return !(*this == other); - } - - OpNode operator*() const { - return OpNode(*iter_, graph_); - } - - private: - std::vector::const_iterator iter_; - const hlir::framework::Graph* graph_; - }; - - size_t size() const { return op_nodes_.size(); } - - iterator begin() const { return iterator(op_nodes_.begin(), graph_); } - - iterator end() const { return iterator(op_nodes_.end(), graph_); } - private: - const std::vector op_nodes_; - const cinn::hlir::framework::Graph* graph_; - }; - class OpGroupListIterator { public: OpGroupListIterator(std::unordered_set, Hasher, Comparator>::const_iterator it) : iter_(it) {} @@ -165,8 +114,27 @@ class OpGroup { hlir::framework::OpPatternKind kind() const { return group_.lock()->kind(); } - OpNodeListView ops() const { - return OpNodeListView(group_.lock()->CollectNodes(), group_.lock()->graph_); + // The WalkOpNodes function is used to traverse the op_nodes in the group and execute + // the VisitOpNode function for each OpNode. This function is equivalent to for loop + // for op_nodes in graph. + // + // In order to avoid unnecessary memory copies, we use WalkOpNodes function instead of + // providing a function to get all op_nodes directly. + // + // Example: Get the all Reduction op_nodes in the group. + // OpGroup group = ...; + // std::set reduce_ op_set; + // // The lambda funtion of VisitOpNode to get reduction op_nodes. 
+ // auto get_reduce_op = [&reduce_op_set](const api::OpNode& op){ + // if (op.kind() == OpPatternKind::kReduction) { + // reduce_op_set.insert(op); + // } + // }; + // group.WalkOpNodes(get_reduce_op); + void WalkOpNodes(const std::function& VisitOpNode) const { + group_.lock()->WalkNodes([&](const hlir::framework::Node* node){ + VisitOpNode(OpNode(node, group_.lock()->graph_)); + }); } ProducerOpGroupListView producers() const { diff --git a/cinn/api/op_node.h b/cinn/api/op_node.h index 7e254ca678..14ee82293c 100644 --- a/cinn/api/op_node.h +++ b/cinn/api/op_node.h @@ -28,7 +28,9 @@ using Attribute = cinn::utils::Attribute; class OpNode { public: - OpNode(const hlir::framework::Node* node, const hlir::framework::Graph* graph) : node_(node), graph_(graph), input_tensors_(node->inlinks_in_order(), graph_), output_tensors_(node->outlinks_in_order(), graph_) {} + OpNode(const hlir::framework::Node* node, const hlir::framework::Graph* graph) : node_(node), graph_(graph), input_tensors_(node->inlinks_in_order(), graph_), output_tensors_(node->outlinks_in_order(), graph_) { + VLOG(1) << "[OpNode] node: " << node->id(); + } OpNode(const OpNode& other) : node_(other.node_), graph_(other.graph_), input_tensors_(node_->inlinks_in_order(), graph_), output_tensors_(node_->outlinks_in_order(), graph_) {} diff --git a/cinn/hlir/framework/graph.h b/cinn/hlir/framework/graph.h index f85075c39a..56f7846de3 100644 --- a/cinn/hlir/framework/graph.h +++ b/cinn/hlir/framework/graph.h @@ -130,6 +130,20 @@ class Graph : public cinn::common::Graph { } } + void WalkNodes(const std::function& VisitNode) const { + if (fused_sub_groups.size()) { + for (auto& group : fused_sub_groups) { + for (const auto* node : group->nodes) { + VisitNode(node); + } + } + } else { + for (const auto* node : nodes) { + VisitNode(node); + } + } + } + std::unordered_set NodeSet() { std::unordered_set node_set; for (auto node : CollectNodes()) { diff --git a/cinn/hlir/pass/general_fusion_merge_pass.cc 
b/cinn/hlir/pass/general_fusion_merge_pass.cc index 23f2507d89..f51fb70a54 100644 --- a/cinn/hlir/pass/general_fusion_merge_pass.cc +++ b/cinn/hlir/pass/general_fusion_merge_pass.cc @@ -316,13 +316,13 @@ bool GraphGroupFuseHelper::ReduceFuseReduce(const OpGroupPtr& src, } static std::unordered_set GetInputOps(const OpGroupPtr& op_group) { - const auto& ops = op_group.ops(); std::unordered_set ops_set; - for (const auto& op : ops) { - ops_set.insert(op); - } + op_group.WalkOpNodes([&ops_set](const api::OpNode& op_node){ + ops_set.insert(op_node); + }); + std::unordered_set input_ops; - for (const auto& op : ops) { + op_group.WalkOpNodes([&](const api::OpNode& op){ const auto& input_tensors = op.inputs(); for (size_t i = 0; i < input_tensors.size(); ++i) { if(input_tensors[i].HasProducer()) { @@ -332,18 +332,17 @@ static std::unordered_set GetInputOps(const OpGroupPtr& op_group) { } } } - } + }); return input_ops; } static std::unordered_set GetOutputOps(const OpGroupPtr& op_group) { - auto ops = op_group.ops(); std::unordered_set ops_set; - for (const auto& op : ops) { - ops_set.insert(op); - } + op_group.WalkOpNodes([&ops_set](const api::OpNode& op_node){ + ops_set.insert(op_node); + }); std::unordered_set output_ops; - for (const auto& op : ops) { + op_group.WalkOpNodes([&](const api::OpNode& op){ const auto& output_tensors = op.outputs(); for (size_t i = 0; i < output_tensors.size(); ++i) { const auto& consumers = output_tensors[i].consumers(); @@ -354,7 +353,7 @@ static std::unordered_set GetOutputOps(const OpGroupPtr& op_group) } } } - } + }); return output_ops; } @@ -399,7 +398,7 @@ bool WithoutLastDimInReduce(const api::Shape& inshape, const std::vector& a static int GetSharedSize(const api::OpNode& op_node) { const auto& producers = op_node.inputs(); CHECK_GT(producers.size(), 0); - const auto& inshape =producers[0].shape(); + const auto& inshape = producers[0].shape(); const auto& axes = op_node.GetAttr>("dim"); if (WithoutLastDimInReduce(inshape, axes)) { 
int lane = 1; @@ -444,21 +443,20 @@ static bool CanReduceFuseReduce(const OpGroupPtr& first, const OpGroupPtr& secon return false; } std::unique_ptr reducer_0 = nullptr; - for (const auto& op : first.ops()) { - if (op.kind() == OpPatternKind::kReduction) { + first.WalkOpNodes([&](const api::OpNode& op){ + if (!reducer_0 && op.kind() == OpPatternKind::kReduction) { reducer_0.reset(new api::OpNode(op)); - break; } - } + }); CHECK(reducer_0) << "Can't find reduce op in group " << first.group_id(); std::unique_ptr reducer_1 = nullptr; - for (const auto& op : second.ops()) { - if (op.kind() == OpPatternKind::kReduction) { + second.WalkOpNodes([&](const api::OpNode& op){ + if (!reducer_1 && op.kind() == OpPatternKind::kReduction) { reducer_1.reset(new api::OpNode(op)); - break; } - } + }); + CHECK(reducer_1) << "Can't find reduce op in group " << second.group_id(); // check reduce has same input shape and output shape @@ -490,11 +488,11 @@ static bool CanReduceFuseReduce(const OpGroupPtr& first, const OpGroupPtr& secon reducer_0_reduce_dim == reducer_1_reduce_dim) { auto shared_size = 0; for (auto& fusion_group : {first, second}) { - for (const auto& node : fusion_group.ops()) { - if (node.kind() == OpPatternKind::kReduction) { - shared_size += GetSharedSize(*reducer_0); + fusion_group.WalkOpNodes([&](const api::OpNode& op){ + if (op.kind() == OpPatternKind::kReduction) { + shared_size += GetSharedSize(op); } - } + }); } #define MAX_AVAILABLE_SHREAD 32 * 1024 @@ -510,11 +508,11 @@ static bool CanReduceFuseReduce(const OpGroupPtr& first, const OpGroupPtr& secon reducer_0_output_shape == reducer_1_output_shape && reducer_0_reduce_dim == reducer_1_reduce_dim) { auto shared_size = 0; for (auto& fusion_group : {first, second}) { - for (const auto& node : fusion_group.ops()) { - if (node.kind() == OpPatternKind::kReduction) { - shared_size += GetSharedSize(*reducer_0); + fusion_group.WalkOpNodes([&](const api::OpNode& op){ + if (op.kind() == OpPatternKind::kReduction) { + 
shared_size += GetSharedSize(op); } - } + }); } #define MAX_AVAILABLE_SHREAD 32 * 1024 @@ -574,14 +572,13 @@ struct HorizontalFuseUtil { } static api::OpNode GetMasterNode(FusePassCtxT* ctx, const OpGroupPtr& op_group) { - const auto ops = op_group.ops(); - for (auto iter = ops.begin(); iter != ops.end(); ++iter) { - api::OpNode node = *iter; - if (node.kind() == OpPatternKind::kReduction) { - return node; + std::vector master_nodes; + op_group.WalkOpNodes([&](const api::OpNode& op){ + if (master_nodes.empty() || op.kind() == OpPatternKind::kReduction) { + master_nodes.push_back(op); } - } - return *ops.begin(); + }); + return master_nodes.back(); } static bool IsSameSize(FusePassCtxT* ctx, const OpGroupPtr& src, const OpGroupPtr& dst) { @@ -613,18 +610,17 @@ struct HorizontalFuseUtil { size_t size_ele = GetMasterNode(ctx, *ele_group).outputs()[0].shape().numel(); - const auto ops = reduce_group->ops(); - for (auto iter = ops.begin(); iter!= ops.end(); ++iter) { - api::OpNode node = *iter; - if (node.kind() == OpPatternKind::kReduction) { - size_t size_master = node.outputs()[0].shape().numel(); + bool can_fuse = false; + reduce_group->WalkOpNodes([&](const api::OpNode& op) { + if (op.kind() == OpPatternKind::kReduction) { + size_t size_master = op.outputs()[0].shape().numel(); if (size_ele == size_master) { - return true; + can_fuse = true; } } - } + }); - return false; + return can_fuse; } static bool ReduceFuseReduce(FusePassCtxT* ctx, const OpGroupPtr& src, const OpGroupPtr& dst) { From 692551714b31534d79c9c7815308f839b47c058f Mon Sep 17 00:00:00 2001 From: zyfncg Date: Wed, 28 Jun 2023 09:02:36 +0000 Subject: [PATCH 48/66] delete unused file --- cinn/api/README.md | 0 cinn/api/fuse_pass_context.h | 36 ------------------ cinn/api/op_group_interface.h | 57 ----------------------------- cinn/api/tensor_interface.h | 38 ------------------- cinn/api/tensor_interface_list.h | 43 ---------------------- cinn/hlir/framework/graph.h | 4 -- 
cinn/hlir/pass/fusion_merge_pass.cc | 3 -- 7 files changed, 181 deletions(-) create mode 100644 cinn/api/README.md delete mode 100644 cinn/api/fuse_pass_context.h delete mode 100644 cinn/api/op_group_interface.h delete mode 100644 cinn/api/tensor_interface.h delete mode 100644 cinn/api/tensor_interface_list.h diff --git a/cinn/api/README.md b/cinn/api/README.md new file mode 100644 index 0000000000..e69de29bb2 diff --git a/cinn/api/fuse_pass_context.h b/cinn/api/fuse_pass_context.h deleted file mode 100644 index ed702fabec..0000000000 --- a/cinn/api/fuse_pass_context.h +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright (c) 2023 CINN Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include "cinn/api/op_group_interface.h" - -namespace cinn { -namespace api { - -class FusePassContext { - public: - FusePassContext() = default; - - std::shared_ptr PickGroup(); - - void EnableRecompute(const OpGroupInterface& op_group); - - void EnableVerticalFuse(const OpGroupInterface& first_op_group, const OpGroupInterface& second_op_group); - - void EnableHorizontalFuse(const OpGroupInterface& first_op_group, const OpGroupInterface& second_op_group); -}; - -} // namespace api -} // namespace cinn diff --git a/cinn/api/op_group_interface.h b/cinn/api/op_group_interface.h deleted file mode 100644 index 49c3cd5272..0000000000 --- a/cinn/api/op_group_interface.h +++ /dev/null @@ -1,57 +0,0 @@ -// Copyright (c) 2023 CINN Authors. 
All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include -#include -#include - -#include "cinn/api/tensor_interface.h" -#include "cinn/api/tensor_interface_list.h" -#include "cinn/hlir/framework/op.h" - -namespace cinn { -namespace api { - -class OpGroupInterface { - public: - virtual hlir::framework::OpPatternKind kind() const = 0; - - // virtual const TensorInterfaceList& input_tensors() const = 0; - - // virtual const TensorInterfaceList& output_tensors() const = 0; - - // virtual const std::unordered_set> producers() const = 0; - - // virtual const std::unordered_set> consumers() const = 0; - - virtual const std::unordered_map, TensorInterfaceList>& producer_groups() const = 0; - - virtual const std::unordered_map, TensorInterfaceList>& consumer_groups() const = 0; - - const std::unordered_map, TensorInterfaceList>& producer2inputs() const { - return producer_groups(); - } - - const std::unordered_map, TensorInterfaceList>& consumer2outputs() const { - return consumer_groups(); - } - - protected: - OpGroupInterface() = default; -}; - -} // namespace api -} // namespace cinn diff --git a/cinn/api/tensor_interface.h b/cinn/api/tensor_interface.h deleted file mode 100644 index 2bca4a62df..0000000000 --- a/cinn/api/tensor_interface.h +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright (c) 2023 CINN Authors. All Rights Reserved. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include - -namespace cinn { -namespace api { - -class ShapeInterface; - -class TensorInterface { - public: - // Get the shape of tensor. - virtual const ShapeInterface& shape() const = 0; - - protected: - TensorInterface() = default; - TensorInterface(const TensorInterface&) = delete; - TensorInterface(TensorInterface&&) = delete; -}; - -using TensorInterfacePtr = std::shared_ptr; - -} // namespace api -} // namespace cinn diff --git a/cinn/api/tensor_interface_list.h b/cinn/api/tensor_interface_list.h deleted file mode 100644 index 0a0a2121f3..0000000000 --- a/cinn/api/tensor_interface_list.h +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright (c) 2023 CINN Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#pragma once - -#include -#include - -#include "cinn/api/tensor_interface.h" -#include "cinn/utils/small_vector.h" - -namespace cinn { -namespace api { - -class TensorInterfaceList : public cinn::utils::SmallVector { - public: - using cinn::utils::SmallVector::SmallVector; - - TensorInterfaceList& operator+=(const TensorInterfaceList& other) { - std::unordered_set tensor_set(this->begin(), this->end()); - for (const auto& tensor_if : other) { - if (tensor_set.find(tensor_if) == tensor_set.end()) { - this->push_back(tensor_if); - tensor_set.insert(tensor_if); - } - } - return *this; - } -}; - -} // namespace api -} // namespace cinn diff --git a/cinn/hlir/framework/graph.h b/cinn/hlir/framework/graph.h index 56f7846de3..55caf788ff 100644 --- a/cinn/hlir/framework/graph.h +++ b/cinn/hlir/framework/graph.h @@ -21,7 +21,6 @@ #include #include -#include "cinn/api/op_group_interface.h" #include "cinn/common/graph_utils.h" #include "cinn/frontend/syntax.h" #include "cinn/hlir/framework/node.h" @@ -30,9 +29,6 @@ namespace cinn { namespace hlir { namespace framework { -using OpGroupInterface = cinn::api::OpGroupInterface; -using TensorInterfaceList = cinn::api::TensorInterfaceList; - /** * \brief Symbolic computation graph. * This is the intermediate representation for optimization pass. 
diff --git a/cinn/hlir/pass/fusion_merge_pass.cc b/cinn/hlir/pass/fusion_merge_pass.cc index 356ce6e867..b9c06fbe1b 100644 --- a/cinn/hlir/pass/fusion_merge_pass.cc +++ b/cinn/hlir/pass/fusion_merge_pass.cc @@ -35,9 +35,6 @@ using GroupList = std::vector; using Comparator = Graph::Group::SharedGroupComparator; using Hasher = Graph::Group::SharedGroupHasher; -using OpGroupPtr = std::shared_ptr; -using OpGroupList = std::vector; - using ConditionFunction = std::function; // Op Fusion Pass which performs Ops fusion, Ops are fused From ad177015c0139b26143e51ee7e237172248726a3 Mon Sep 17 00:00:00 2001 From: jiahongyu Date: Wed, 28 Jun 2023 09:13:03 +0000 Subject: [PATCH 49/66] fix some bugs --- cinn/hlir/framework/graph.cc | 1 + cinn/hlir/pass/fusion_merge_pass.cc | 34 +++++---- cinn/hlir/pass/general_fusion_merge_pass.cc | 81 ++++++++++++++------- 3 files changed, 75 insertions(+), 41 deletions(-) diff --git a/cinn/hlir/framework/graph.cc b/cinn/hlir/framework/graph.cc index 2d79f10781..02be3af1bb 100644 --- a/cinn/hlir/framework/graph.cc +++ b/cinn/hlir/framework/graph.cc @@ -284,6 +284,7 @@ void Graph::VisualizeGroupedGraph(const std::vector>& origin_ { // create base Directory viz_path_ = utils::StringFormat("%s/fusion_groups_%d/", FLAGS_cinn_fusion_groups_graphviz_dir.c_str(), viz_id); + VLOG(1) << "DEBUG Visualize directory id = " << viz_id; if (!MakeDirectory(viz_path_, S_IRWXU | S_IRGRP | S_IXGRP | S_IROTH | S_IXOTH)) { LOG_IF(WARNING, viz_id == 0) << "Failed to make directory: \"" << viz_path_ << "\", the CINN subgraph's fusion group information will not print."; diff --git a/cinn/hlir/pass/fusion_merge_pass.cc b/cinn/hlir/pass/fusion_merge_pass.cc index 957447942d..4ca2128a32 100644 --- a/cinn/hlir/pass/fusion_merge_pass.cc +++ b/cinn/hlir/pass/fusion_merge_pass.cc @@ -76,10 +76,10 @@ class FusionMergePassHelper : public FusionHelperBase { private: void DoFusionMerge() { VLOG(3) << "DoFusionMerge...!"; - while (DoHorizontalFusion()) { - } - while 
(DoVerticalFusion(/* recompute=*/false)) { - } + // while (DoHorizontalFusion()) { + // } + // while (DoVerticalFusion(/* recompute=*/false)) { + // } while (DoVerticalFusion(/* recompute=*/true)) { } } @@ -109,7 +109,7 @@ class FusionMergePassHelper : public FusionHelperBase { bool updated = false; for (int idx = 0; idx < fusion_groups_.size(); ++idx) { auto producer = fusion_groups_[idx]; - VLOG(3) << "Fusion Producer Group -> " << producer->group_id; + VLOG(1) << "Fusion Producer idx " << idx << " Group -> " << producer->group_id; // if producer is sub group. if (producer->belong_groups.size()) { continue; @@ -401,6 +401,7 @@ class FusionMergePassHelper : public FusionHelperBase { std::unordered_set fuse_consumers_unsafe; std::unordered_set fuse_consumers; + VLOG(1) << "DEBUG VerticalFusion, begin check : " << producer->group_id; for (const auto& consumer : consumers) { VLOG(4) << "Check consuemr " << consumer->group_id << " can fuse to producer " << producer->group_id; // if can't fuse @@ -418,12 +419,12 @@ class FusionMergePassHelper : public FusionHelperBase { fuse_consumers_unsafe.insert(consumer); if (IsDependencySimplify(producer, consumer, consumers)) { - VLOG(4) << "IsDependencySimplify, Consumer " << consumer->group_id << " can't be master fused group!"; + VLOG(1) << "DEBUG consumer " << consumer->group_id << " has loop"; continue; } if (IsDependency(producer, consumer, consumers)) { - VLOG(4) << "IsDependency, Consumer " << consumer->group_id << " can't be master fused group!"; + VLOG(1) << "DEBUG consumer " << consumer->group_id << " has loop"; continue; } @@ -434,8 +435,10 @@ class FusionMergePassHelper : public FusionHelperBase { VLOG(3) << "VerticalFusion, Number of unsafe fuse Consumers : " << fuse_consumers.size(); if (fuse_consumers.size() == 0) { + VLOG(1) << "DEBUG fuse_consumers.empty(), exit fuse group " << producer->group_id; return false; } + VLOG(1) << "DEBUG fuse_consumers_unsafe.size() = " << fuse_consumers_unsafe.size(); // if 
can_fuse_consumers == consumers // if producer op kind == kElementwise // if use recompute @@ -444,21 +447,22 @@ class FusionMergePassHelper : public FusionHelperBase { if (!recompute) { return false; } else { + VLOG(1) << "DEBUG begin recompute fuse group " << producer->group_id; RecomputeEleGraph(producer, fuse_consumers_unsafe); VerticalFuse(producer, fuse_consumers_unsafe); return true; } } - if (fuse_consumers.size()) { - SelectConsumerToFuse(producer, fuse_consumers); - } + // if (fuse_consumers.size()) { + // SelectConsumerToFuse(producer, fuse_consumers); + // } - // if fusionable consumers exist - if (fuse_consumers.size()) { - VerticalFuse(producer, fuse_consumers); - return true; - } + // // if fusionable consumers exist + // if (fuse_consumers.size()) { + // VerticalFuse(producer, fuse_consumers); + // return true; + // } return false; } diff --git a/cinn/hlir/pass/general_fusion_merge_pass.cc b/cinn/hlir/pass/general_fusion_merge_pass.cc index 07849d5277..052e123909 100644 --- a/cinn/hlir/pass/general_fusion_merge_pass.cc +++ b/cinn/hlir/pass/general_fusion_merge_pass.cc @@ -69,6 +69,8 @@ class FuseHelper { virtual bool DetectCycleIfFuse(const OpGroupPtr& src, const OpGroupPtr& dst) const = 0; + virtual bool IsReachable(const OpGroupPtr& lhs, const OpGroupPtr& rhs) const = 0; + protected: FuseHelper() = default; }; @@ -98,11 +100,31 @@ class GraphGroupFuseHelper final : public FuseHelper { bool ReduceFuseReduce(const OpGroupPtr& src, const OpGroupPtr& dst) const override; + bool IsReachable(const OpGroupPtr& lhs, const OpGroupPtr& rhs) const override { + return IsReachableInDag(lhs, rhs) || IsReachableInDag(rhs, lhs); + } + bool DetectCycleIfFuse(const OpGroupPtr& lhs, const OpGroupPtr& rhs) const override { return ReachableIfDirectEdgeIgnored(lhs, rhs) || ReachableIfDirectEdgeIgnored(rhs, lhs); } private: + bool IsReachableInDag(const OpGroupPtr& producer, const OpGroupPtr& consumer) const { + const auto& MinDepth4Node = [&](OpGroupPtr node) { + 
return std::dynamic_pointer_cast(node)->min_depth; + }; + const auto& MaxDepth4Node = [&](OpGroupPtr node) { + return std::dynamic_pointer_cast(node)->max_depth; + }; + const auto& VisitNextNodes = [&](OpGroupPtr node, const std::function& Visit) { + for (const auto& pair : node->producer2inputs()) { + Visit(pair.first); + } + }; + common::IsReachablePredicator is_reachable(MinDepth4Node, MaxDepth4Node, VisitNextNodes); + return is_reachable(consumer, producer, [](OpGroupPtr) {}); + } + bool ReachableIfDirectEdgeIgnored(const OpGroupPtr& producer, const OpGroupPtr& consumer) const { const auto& MinDepth4Node = [&](OpGroupPtr node) { return std::dynamic_pointer_cast(node)->min_depth; @@ -395,7 +417,7 @@ class DefaultInputFusePass final : public InputFusePass { const auto& src = consumers.at(i); for (int j = i + 1; j < consumers.size(); ++j) { const auto& dst = consumers.at(j); - if (ctx->fuse_helper().DetectCycleIfFuse(src, dst)) { + if (ctx->fuse_helper().IsReachable(src, dst)) { continue; } if (!HorizontalFuseUtil::DetectFusabilityByKind(ctx, src, dst)) { @@ -459,7 +481,7 @@ class DefaultHorizontalFusePass final : public HorizontalFusePass { const auto& src = consumers.at(i); for (int j = i + 1; j < consumers.size(); ++j) { const auto& dst = consumers.at(j); - if (ctx->fuse_helper().DetectCycleIfFuse(src, dst)) { + if (ctx->fuse_helper().IsReachable(src, dst)) { continue; } if (!HorizontalFuseUtil::DetectFusabilityByKind(ctx, src, dst)) { @@ -506,13 +528,24 @@ class DefaultVerticalFusePass final : public VerticalFusePass { return; } + std::vector candidates; + for (int i = 0; i < consumers.size(); ++i) { + const auto& consumer = consumers.at(i); + if (!DetectFusabilityByKind(ctx, producer, consumer)) { + break; + } + candidates.push_back(consumer); + } + if (candidates.size() == consumers.size() && producer->kind() == framework::kElementWise) { + return; + } + for (int i = 0; i < consumers.size(); ++i) { const auto& consumer = consumers.at(i); if 
(!DetectFusabilityByKind(ctx, producer, consumer)) { continue; } if (ctx->fuse_helper().DetectCycleIfFuse(producer, consumer)) { - VLOG(4) << "Can't fuse because detect cycle"; continue; } ctx->EnableFuse(producer, consumer); @@ -619,7 +652,6 @@ class DefaultRecomputeFusePass final : public RecomputeFusePass { int Benefit() const override { return 100; } void operator()(LightwareFusePassCtx* ctx) const override { - VLOG(1) << "DefaultRecomputeFusePass"; const auto& producer = ctx->PickOpGroup(); const OpGroupList consumers = [&]() { OpGroupList consumers; @@ -628,18 +660,22 @@ class DefaultRecomputeFusePass final : public RecomputeFusePass { } return consumers; }(); - if (consumers.size() <= 1) { - return; - } + // Borrows unsafe_candidates and candidates concept from origin fusion_merge_pass + std::vector unsafe_candidates; std::vector candidates; for (int i = 0; i < consumers.size(); ++i) { const auto& consumer = consumers.at(i); if (!DetectFusabilityByKind(ctx, producer, consumer)) { continue; } + unsafe_candidates.push_back(consumer); + if (ctx->fuse_helper().DetectCycleIfFuse(producer, consumer)) { + continue; + } candidates.push_back(consumer); } - if (candidates.size() == consumers.size() && producer->kind() == framework::kElementWise) { + + if (!candidates.empty() && unsafe_candidates.size() == consumers.size() && producer->kind() == framework::kElementWise) { for (const auto& consumer : consumers) { ctx->EnableFuse(producer, consumer); } @@ -781,10 +817,10 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { private: void DoFusionMerge() { VLOG(3) << "DoFusionMerge...!"; - while (DoGeneralHorizontalFusion()) { - } - while (DoGeneralVerticalFusion()) { - } + // while (DoGeneralHorizontalFusion()) { + // } + // while (DoGeneralVerticalFusion()) { + // } while (DoGeneralRecomputeAndVerticalFusion()) { } } @@ -863,16 +899,17 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { bool updated = false; for (int idx = 0; idx < 
fusion_groups_.size(); ++idx) { auto producer = fusion_groups_[idx]; - VLOG(3) << "Fusion Producer Group -> " << producer->group_id; + VLOG(1) << "Fusion Producer idx " << idx << " Group -> " << producer->group_id; // if producer is sub group. if (producer->belong_groups.size()) { continue; } // do horizontal fusion. - updated |= GeneralRecomputeFuse(producer); - if (!updated) { - updated |= GeneralVerticalFuse(producer); - } + bool recompute_success = GeneralRecomputeFuse(producer); + // updated |= recompute_success; + // if (!recompute_success) { + // updated |= GeneralVerticalFuse(producer); + // } } // fuse input consumers @@ -1192,8 +1229,6 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { // TODO: Do not add any TensorInterface into any TensorInterfaceList in this file which will be deprecated. (*gconsumer->mut_producer_groups())[fused_group] += {}; } - fused_group->mut_consumer_groups()->erase(fused_group); - fused_group->mut_producer_groups()->erase(fused_group); // belongs group consumer->belong_groups.insert(fused_group); @@ -1431,9 +1466,6 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { (*group->mut_consumer_groups())[fused_group] += {}; } - // delete consumer group in producer - producer->mut_consumer_groups()->erase(consumer); - // sub groups if (producer->fused_sub_groups.size()) { for (auto& group : producer->fused_sub_groups) { @@ -1579,10 +1611,6 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { } void TagRecomputeGroups(LightwareFusePassCtx* ctx) const { - const auto& producer = ctx->PickOpGroup(); - if (producer->consumer2outputs().size() <= 1) { - return; - } const auto& fuse_passes = GetRecomputeFusePasses(); for (const auto& fuse_pass : fuse_passes) { (*fuse_pass)(ctx); @@ -1617,6 +1645,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { bool update = false; auto consumer_groups = GetFusableConsumerGroupSet(); if (consumer_groups.size() > 0) { + CHECK(consumer_groups.size() == 
producer->mut_consumer_groups()->size()) << "Recompute requires fuse all consumers!"; RecomputeFuse(producer, consumer_groups); update = true; } From 5406478712385ba6b984a39766076171d9e6282c Mon Sep 17 00:00:00 2001 From: zyfncg Date: Wed, 28 Jun 2023 09:46:37 +0000 Subject: [PATCH 50/66] delele unused code --- cinn/hlir/pass/general_fusion_merge_pass.cc | 287 -------------------- 1 file changed, 287 deletions(-) diff --git a/cinn/hlir/pass/general_fusion_merge_pass.cc b/cinn/hlir/pass/general_fusion_merge_pass.cc index f51fb70a54..d635f41fea 100644 --- a/cinn/hlir/pass/general_fusion_merge_pass.cc +++ b/cinn/hlir/pass/general_fusion_merge_pass.cc @@ -1027,8 +1027,6 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { public: GeneralFusionMergePassHelper(const Graph* graph) : FusionHelperBase(graph), graph_(graph) { fusion_groups_ = graph->fusion_groups; - // init fusion relation. - InitFusionRelation(); // init input to consumers. InitInputToConsumers(); // init fusion group index. @@ -1084,31 +1082,6 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { return updated; } - bool DoVerticalFusion(bool recompute) { - VLOG(3) << "DoVerticalFusion...!"; - bool updated = false; - for (int idx = 0; idx < fusion_groups_.size(); ++idx) { - auto producer = fusion_groups_[idx]; - VLOG(3) << "Fusion Producer Group -> " << producer->group_id; - // if producer is sub group. - if (producer->belong_groups.size()) { - continue; - } - // do horizontal fusion. 
- if (!recompute) { - updated |= HorizontalFusion(producer, producer->CollectConsumerGroups()); - } - updated |= VerticalFusion(producer, producer->CollectConsumerGroups(), recompute); - } - // fuse input consumers - updated |= FuseInputToConsumers(); - - if (updated) { - UpdateFusionGroup(); - } - return updated; - } - bool DoGeneralVerticalFusion() { VLOG(3) << "DoGeneralVerticalFusion...!"; bool updated = false; @@ -1308,76 +1281,6 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { return true; } - bool HorizontalFusion(GroupPtr producer, const std::unordered_set& consumers) { - VLOG(3) << "HorizontalFusion...!"; - if (consumers.size() <= 1) { - return false; - } - - std::unordered_set candidates; - for (const auto& consumer : consumers) { - // relation - auto& relation = fusion_relation_map_[consumer->op_pattern_kind]; - // check horizontal relation exist - if (!relation.horizontal_relation.size()) { - continue; - } - candidates.insert(consumer); - } - - std::vector fusionable_consumers; - for (auto& candidate : candidates) { - // check dependency - if (IsDependencySimplify(producer, candidate, candidates)) { - VLOG(4) << "IsDependencySimplify, Can't fuse " << candidate->group_id << ", As it depency others!"; - continue; - } - - if (IsDependency(producer, candidate, candidates)) { - VLOG(4) << "IsDependency, Can't fuse " << candidate->group_id << ", As it depency others!"; - continue; - } - - if (!fusionable_consumers.size()) { - fusionable_consumers.push_back({candidate}); - continue; - } - - // check each fusionable groups - bool fusionable = false; - auto& relation = fusion_relation_map_[candidate->op_pattern_kind]; - for (auto& groups : fusionable_consumers) { - auto& last = groups.back(); - if (!relation.horizontal_relation.count(last->op_pattern_kind)) { - continue; - } - - if (!relation.horizontal_relation[last->op_pattern_kind](this, candidate, last)) { - continue; - } - - groups.push_back(candidate); - fusionable = true; - break; - } - - 
// if can't fuse to othors Groups, new Groups. - if (!fusionable) { - fusionable_consumers.push_back({candidate}); - } - } - - bool updated = false; - for (auto& groups : fusionable_consumers) { - if (groups.size() > 1) { - updated = true; - HorizontalFuse(groups); - } - } - - return updated; - } - void HorizontalFuse(const GroupList& consumers) { VLOG(3) << "HorizontalFuse Groups..."; // create fusion group @@ -1515,78 +1418,6 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { CHECK(fused_group->output_nodes.size()) << "No output node is found, " << fused_group->group_id; } - bool VerticalFusion(GroupPtr& producer, const std::unordered_set& consumers, bool recompute) { - VLOG(3) << "VerticalFusion, Number of Consumers : " << consumers.size(); - auto& relation = fusion_relation_map_[producer->op_pattern_kind]; - // if producer can't fuse others - if (!relation.vertical_relation.size()) { - return false; - } - - std::unordered_set fuse_consumers_unsafe; - std::unordered_set fuse_consumers; - for (const auto& consumer : consumers) { - VLOG(4) << "Check consuemr " << consumer->group_id << " can fuse to producer " << producer->group_id; - // if can't fuse - if (!relation.vertical_relation.count(consumer->op_pattern_kind)) { - VLOG(4) << "Can't fuse producer " << producer->group_id << " consumer " << consumer->group_id; - continue; - } - - // if condition function is false - if (!relation.vertical_relation[consumer->op_pattern_kind](this, producer, consumer)) { - VLOG(4) << "Can't fuse producer " << producer->group_id << " consumer " << consumer->group_id; - continue; - } - - fuse_consumers_unsafe.insert(consumer); - - if (IsDependencySimplify(producer, consumer, consumers)) { - VLOG(4) << "IsDependencySimplify, Consumer " << consumer->group_id << " can't be master fused group!"; - continue; - } - - if (IsDependency(producer, consumer, consumers)) { - VLOG(4) << "IsDependency, Consumer " << consumer->group_id << " can't be master fused group!"; - 
continue; - } - - fuse_consumers.insert(consumer); - } - - VLOG(3) << "VerticalFusion, Number of fuse Consumers : " << fuse_consumers.size(); - VLOG(3) << "VerticalFusion, Number of unsafe fuse Consumers : " << fuse_consumers.size(); - - if (fuse_consumers.size() == 0) { - return false; - } - // if can_fuse_consumers == consumers - // if producer op kind == kElementwise - // if use recompute - if (fuse_consumers_unsafe.size() == producer->consumer_groups().size() && - producer->op_pattern_kind == framework::kElementWise) { - if (!recompute) { - return false; - } else { - RecomputeEleGraph(producer, fuse_consumers_unsafe); - VerticalFuse(producer, fuse_consumers_unsafe); - return true; - } - } - - if (fuse_consumers.size()) { - SelectConsumerToFuse(producer, fuse_consumers); - } - - // if fusionable consumers exist - if (fuse_consumers.size()) { - VerticalFuse(producer, fuse_consumers); - return true; - } - - return false; - } - std::vector> RawVerticalFusePasses() const { return FusionPassMap::Instance().GetLightwareFusePassesByMode("VerticalFuse"); } @@ -2035,28 +1866,6 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { return false; } - bool FuseInputToConsumers() { - VLOG(3) << "FuseInputToConsumers...!"; - auto updated = false; - UpdateInputToConsumers(); - GroupPtr producer(nullptr); - for (auto& input_consumers : input_to_consumers_) { - // if group set size == 1. - if (input_consumers.second.size() == 1) { - continue; - } - // do horizontal fusion. 
- auto st = HorizontalFusion(producer, input_consumers.second); - if (st) { - // fused consumers, update - UpdateInputToConsumers(); - } - updated |= st; - } - - return updated; - } - bool GeneralInputFuse() { VLOG(3) << "GeneralInputFuse...!"; auto updated = false; @@ -2166,106 +1975,10 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { } } - void InitFusionRelation() { - VLOG(3) << "InitFusionRelation...!"; - // kElementWise - { - auto& relation = fusion_relation_map_[OpPatternKind::kElementWise]; - // horizontal - relation.horizontal_relation = {{framework::kElementWise, is_same_size}, - // element-wise and broadcast op must be horizontal relation. - {OpPatternKind::kBroadcast, is_same_size}, - // element-wise and injective op must be horizontal relation. - {OpPatternKind::kInjective, is_same_size}, - // element-wise and reduce op must be horizontal relation. - {OpPatternKind::kReduction, honrizontal_elementwise_fuse_reduce}}; - // vertical - relation.vertical_relation = {{OpPatternKind::kElementWise, is_same_size}, - // element-wise and broadcast can be vertical/horizontal relation. - {OpPatternKind::kBroadcast, elementwise_fuse_broadcast}, - // element-wise and injective op must be horizontal relation. - {OpPatternKind::kInjective, horizontal_with_injective}, - // element-wise and reduce can be vertical/horizontal relation. - {OpPatternKind::kReduction, elementwise_fuse_reduce}}; - } - // kBroadcast - { - auto& relation = fusion_relation_map_[OpPatternKind::kBroadcast]; - // horizontal - relation.horizontal_relation = {// broadcast and element-wise op must be horizontal relation. - {framework::kElementWise, is_same_size}, - // broadcast and broadcast op must be horizontal relation. - {framework::kBroadcast, is_same_size}, - // broadcast and injective op must be horizontal relation. - {OpPatternKind::kInjective, is_same_size}, - // broadcast and reduce op must be horizontal relation. 
- {OpPatternKind::kReduction, is_same_size}}; - // vertical - relation.vertical_relation = {// broadcast and element-wise op must be vertical relation. - {OpPatternKind::kElementWise, is_same_size}, - // broadcast and broadcast op must be horizontal relation. - {OpPatternKind::kBroadcast, is_same_size}, - // broadcast and injective op must be horizontal relation. - {OpPatternKind::kInjective, horizontal_with_injective}, - // broadcast and reduce must be vertical relation. - {OpPatternKind::kReduction, broadcast_fuse_reduce}}; - } - // kInjective - { - auto& relation = fusion_relation_map_[OpPatternKind::kInjective]; - // horizontal - relation.horizontal_relation = {// injective and element-wise op must be horizontal relation. - {OpPatternKind::kElementWise, is_same_size}, - // injective and broadcast op must be horizontal relation. - {OpPatternKind::kBroadcast, is_same_size}, - // injective and injective op must be horizontal relation. - {OpPatternKind::kInjective, is_same_size}, - // injective and reduce must be horizontal relation. - {OpPatternKind::kReduction, is_same_size}}; - // vertical - relation.vertical_relation = {// injective and element-wise op must be horizontal relation. - {OpPatternKind::kElementWise, is_same_size}, - // injective and broadcast op must be horizontal relation. - {OpPatternKind::kBroadcast, is_same_size}, - // injective and injective op must be horizontal relation. - {OpPatternKind::kInjective, horizontal_with_injective}, - // injective and reduce can be horizontal/vertical relation. - {OpPatternKind::kReduction, injective_horizontal_with_reduce}}; - } - // kReduction - { - auto& relation = fusion_relation_map_[OpPatternKind::kReduction]; - // horizontal - relation.horizontal_relation = {// reduce and element-wise op must be horizontal relation. - {OpPatternKind::kElementWise, honrizontal_elementwise_fuse_reduce}, - // reduce and broadcast op must be horizontal relation. 
- {OpPatternKind::kBroadcast, is_same_size}, - // reduce and injective op must be horizontal relation. - {OpPatternKind::kInjective, is_same_size}, - // reduce and reduce must be horizontal relation. - {OpPatternKind::kReduction, reduce_fuse_reduce}}; - // vertical - relation.vertical_relation = {// reduce and elementwise can be horizontal/vertical relation. - {OpPatternKind::kElementWise, reduce_fuse_elementwise}, - // reduce and broadcast op must be horizontal relation. - {OpPatternKind::kBroadcast, reduce_fuse_broadcast}, - // reduce and injective op must be horizontal relation. - {OpPatternKind::kInjective, horizontal_with_injective}, - // reduce and reduce must be horizontal relation. - {OpPatternKind::kReduction, reduce_fuse_reduce}}; - } - } - const Graph* graph_; GroupList fusion_groups_; std::unordered_map fusion_groups_index_; std::unordered_map> input_to_consumers_; - - struct Relation { - std::unordered_map vertical_relation; - std::unordered_map horizontal_relation; - }; - std::unordered_map fusion_relation_map_; }; void GeneralFusionMergePassInternal(Graph* graph) { From ca648f9f6c1e7c43daba06b95e836da498508932 Mon Sep 17 00:00:00 2001 From: zyfncg Date: Wed, 28 Jun 2023 12:05:54 +0000 Subject: [PATCH 51/66] develop comment --- cinn/hlir/pass/general_fusion_merge_pass.cc | 115 ++++++-------------- 1 file changed, 34 insertions(+), 81 deletions(-) diff --git a/cinn/hlir/pass/general_fusion_merge_pass.cc b/cinn/hlir/pass/general_fusion_merge_pass.cc index d635f41fea..4e29a8a0c9 100644 --- a/cinn/hlir/pass/general_fusion_merge_pass.cc +++ b/cinn/hlir/pass/general_fusion_merge_pass.cc @@ -107,14 +107,10 @@ class GraphGroupFuseHelper final : public FuseHelper { private: bool ReachableIfDirectEdgeIgnored(const OpGroupPtr& producer, const OpGroupPtr& consumer) const { - const auto& MinDepth4Node = [&](const OpGroupPtr& node) { - return node.GetGroup()->min_depth; - }; - const auto& MaxDepth4Node = [&](const OpGroupPtr& node) { - return 
node.GetGroup()->max_depth; - }; + const auto& MinDepth4Node = [&](const OpGroupPtr& node) { return node.GetGroup()->min_depth; }; + const auto& MaxDepth4Node = [&](const OpGroupPtr& node) { return node.GetGroup()->max_depth; }; const auto& VisitNextNodes = [&](const OpGroupPtr& node, const std::function& Visit) { - for(const auto& node_producer : node.producers()) { + for (const auto& node_producer : node.producers()) { if (node == consumer && node_producer == producer) { continue; } @@ -136,10 +132,6 @@ class FusePassCtx { virtual void EnableFuse(const OpGroupPtr& first, const OpGroupPtr& second) = 0; - // User can cache some group info in context by using this function. - // The group info can be any data and need to create by create_fn. - // virtual absl::any* FindOrCreateCachedGroupInfo(const OpGroupPtr& op_group, const std::function& create_fn) = 0; - protected: FusePassCtx() = default; }; @@ -154,8 +146,6 @@ class LightwareFusePassCtx : public FusePassCtx { virtual void EnableFuse(const OpGroupPtr& first, const OpGroupPtr& second) = 0; - // virtual absl::any* FindOrCreateCachedGroupInfo(const OpGroupPtr& op_group, const std::function& create_fn) = 0; - protected: LightwareFusePassCtx() = default; }; @@ -177,17 +167,9 @@ class GraphGroupLightwareFusePassCtx final : public LightwareFusePassCtx { void EnableFuse(const OpGroupPtr& first, const OpGroupPtr& second) override { EnableFuse_(first, second); } - // absl::any* FindOrCreateCachedGroupInfo(const OpGroupPtr& op_group, const std::function& create_fn) { - // if (cache_data_.find(op_group) == cache_data_.end()) { - // cache_data_[op_group] = create_fn(op_group); - // } - // return &cache_data_[op_group]; - // } - const FusionHelperBase& graph_group_fusion_helper() const { return *graph_group_fusion_helper_; } private: - // static std::unordered_map cache_data_; const FusionHelperBase* graph_group_fusion_helper_; const OpGroupPtr& group_; const std::function EnableFuse_; @@ -204,7 +186,8 @@ class 
InputFusePassCtx : public FusePassCtx { virtual void EnableFuse(const OpGroupPtr& first, const OpGroupPtr& second) = 0; - // virtual absl::any* FindOrCreateCachedGroupInfo(const OpGroupPtr& op_group, const std::function& create_fn) = 0; + // virtual absl::any* FindOrCreateCachedGroupInfo(const OpGroupPtr& op_group, const std::function& create_fn) = 0; protected: InputFusePassCtx() = default; @@ -228,15 +211,7 @@ class GraphGroupInputFusePassCtx final : public InputFusePassCtx { const FusionHelperBase& graph_group_fusion_helper() const { return *graph_group_fusion_helper_; } - // absl::any* FindOrCreateCachedGroupInfo(const OpGroupPtr& op_group, const std::function& create_fn) { - // if (cache_data_.find(op_group) == cache_data_.end()) { - // cache_data_[op_group] = create_fn(op_group); - // } - // return &cache_data_[op_group]; - // } - private: - // static std::unordered_map cache_data_; const FusionHelperBase* graph_group_fusion_helper_; const OpGroupList& groups_; const std::function EnableFuse_; @@ -245,87 +220,65 @@ class GraphGroupInputFusePassCtx final : public InputFusePassCtx { template bool GraphGroupFuseHelper::AllOutputsSameSize(const OpGroupPtr& first, const OpGroupPtr& second) const { - return is_same_size(&ctx_->graph_group_fusion_helper(), - first.GetGroup(), - second.GetGroup()); + return is_same_size(&ctx_->graph_group_fusion_helper(), first.GetGroup(), second.GetGroup()); } template bool GraphGroupFuseHelper::HorizontalElementwiseFuseReduce(const OpGroupPtr& src, const OpGroupPtr& dst) const { - return honrizontal_elementwise_fuse_reduce(&ctx_->graph_group_fusion_helper(), - src.GetGroup(), - dst.GetGroup()); + return honrizontal_elementwise_fuse_reduce(&ctx_->graph_group_fusion_helper(), src.GetGroup(), dst.GetGroup()); } template bool GraphGroupFuseHelper::ElementwiseFuseBroadcast(const OpGroupPtr& src, const OpGroupPtr& dst) const { - return elementwise_fuse_broadcast(&ctx_->graph_group_fusion_helper(), - src.GetGroup(), - dst.GetGroup()); + 
return elementwise_fuse_broadcast(&ctx_->graph_group_fusion_helper(), src.GetGroup(), dst.GetGroup()); } template bool GraphGroupFuseHelper::HorizontalWithInjective(const OpGroupPtr& src, const OpGroupPtr& dst) const { - return horizontal_with_injective(&ctx_->graph_group_fusion_helper(), - src.GetGroup(), - dst.GetGroup()); + return horizontal_with_injective(&ctx_->graph_group_fusion_helper(), src.GetGroup(), dst.GetGroup()); } template bool GraphGroupFuseHelper::ElementwiseFuseReduce(const OpGroupPtr& src, const OpGroupPtr& dst) const { - return elementwise_fuse_reduce(&ctx_->graph_group_fusion_helper(), - src.GetGroup(), - dst.GetGroup()); + return elementwise_fuse_reduce(&ctx_->graph_group_fusion_helper(), src.GetGroup(), dst.GetGroup()); } template bool GraphGroupFuseHelper::BroadcastFuseReduce(const OpGroupPtr& src, const OpGroupPtr& dst) const { - return broadcast_fuse_reduce(&ctx_->graph_group_fusion_helper(), - src.GetGroup(), - dst.GetGroup()); + return broadcast_fuse_reduce(&ctx_->graph_group_fusion_helper(), src.GetGroup(), dst.GetGroup()); } template bool GraphGroupFuseHelper::InjectiveHorizontalWithReduce(const OpGroupPtr& src, const OpGroupPtr& dst) const { - return injective_horizontal_with_reduce(&ctx_->graph_group_fusion_helper(), - src.GetGroup(), - dst.GetGroup()); + return injective_horizontal_with_reduce(&ctx_->graph_group_fusion_helper(), src.GetGroup(), dst.GetGroup()); } template bool GraphGroupFuseHelper::ReduceFuseElementwise(const OpGroupPtr& src, const OpGroupPtr& dst) const { - return reduce_fuse_elementwise(&ctx_->graph_group_fusion_helper(), - src.GetGroup(), - dst.GetGroup()); + return reduce_fuse_elementwise(&ctx_->graph_group_fusion_helper(), src.GetGroup(), dst.GetGroup()); } template bool GraphGroupFuseHelper::ReduceFuseBroadcast(const OpGroupPtr& src, const OpGroupPtr& dst) const { - return reduce_fuse_broadcast(&ctx_->graph_group_fusion_helper(), - src.GetGroup(), - dst.GetGroup()); + return 
reduce_fuse_broadcast(&ctx_->graph_group_fusion_helper(), src.GetGroup(), dst.GetGroup()); } template bool GraphGroupFuseHelper::ReduceFuseReduce(const OpGroupPtr& src, const OpGroupPtr& dst) const { - return reduce_fuse_reduce(&ctx_->graph_group_fusion_helper(), - src.GetGroup(), - dst.GetGroup()); + return reduce_fuse_reduce(&ctx_->graph_group_fusion_helper(), src.GetGroup(), dst.GetGroup()); } static std::unordered_set GetInputOps(const OpGroupPtr& op_group) { std::unordered_set ops_set; - op_group.WalkOpNodes([&ops_set](const api::OpNode& op_node){ - ops_set.insert(op_node); - }); + op_group.WalkOpNodes([&ops_set](const api::OpNode& op_node) { ops_set.insert(op_node); }); std::unordered_set input_ops; - op_group.WalkOpNodes([&](const api::OpNode& op){ + op_group.WalkOpNodes([&](const api::OpNode& op) { const auto& input_tensors = op.inputs(); for (size_t i = 0; i < input_tensors.size(); ++i) { - if(input_tensors[i].HasProducer()) { + if (input_tensors[i].HasProducer()) { api::OpNode producer = input_tensors[i].producer(); if (ops_set.find(producer) == ops_set.end()) { input_ops.insert(producer); @@ -338,11 +291,9 @@ static std::unordered_set GetInputOps(const OpGroupPtr& op_group) { static std::unordered_set GetOutputOps(const OpGroupPtr& op_group) { std::unordered_set ops_set; - op_group.WalkOpNodes([&ops_set](const api::OpNode& op_node){ - ops_set.insert(op_node); - }); + op_group.WalkOpNodes([&ops_set](const api::OpNode& op_node) { ops_set.insert(op_node); }); std::unordered_set output_ops; - op_group.WalkOpNodes([&](const api::OpNode& op){ + op_group.WalkOpNodes([&](const api::OpNode& op) { const auto& output_tensors = op.outputs(); for (size_t i = 0; i < output_tensors.size(); ++i) { const auto& consumers = output_tensors[i].consumers(); @@ -443,7 +394,7 @@ static bool CanReduceFuseReduce(const OpGroupPtr& first, const OpGroupPtr& secon return false; } std::unique_ptr reducer_0 = nullptr; - first.WalkOpNodes([&](const api::OpNode& op){ + 
first.WalkOpNodes([&](const api::OpNode& op) { if (!reducer_0 && op.kind() == OpPatternKind::kReduction) { reducer_0.reset(new api::OpNode(op)); } @@ -451,7 +402,7 @@ static bool CanReduceFuseReduce(const OpGroupPtr& first, const OpGroupPtr& secon CHECK(reducer_0) << "Can't find reduce op in group " << first.group_id(); std::unique_ptr reducer_1 = nullptr; - second.WalkOpNodes([&](const api::OpNode& op){ + second.WalkOpNodes([&](const api::OpNode& op) { if (!reducer_1 && op.kind() == OpPatternKind::kReduction) { reducer_1.reset(new api::OpNode(op)); } @@ -461,10 +412,10 @@ static bool CanReduceFuseReduce(const OpGroupPtr& first, const OpGroupPtr& secon // check reduce has same input shape and output shape const auto& reducer_0_input_shape = reducer_0->inputs()[0].shape(); - const auto& reducer_0_output_shape = reducer_0->outputs()[0].shape(); + const auto& reducer_0_output_shape = reducer_0->outputs()[0].shape(); const auto& reducer_1_input_shape = reducer_1->inputs()[0].shape(); - const auto& reducer_1_output_shape = reducer_1->outputs()[0].shape(); + const auto& reducer_1_output_shape = reducer_1->outputs()[0].shape(); auto reducer_0_reduce_dim = reducer_0->GetAttr>("dim"); auto reducer_1_reduce_dim = reducer_1->GetAttr>("dim"); @@ -488,7 +439,7 @@ static bool CanReduceFuseReduce(const OpGroupPtr& first, const OpGroupPtr& secon reducer_0_reduce_dim == reducer_1_reduce_dim) { auto shared_size = 0; for (auto& fusion_group : {first, second}) { - fusion_group.WalkOpNodes([&](const api::OpNode& op){ + fusion_group.WalkOpNodes([&](const api::OpNode& op) { if (op.kind() == OpPatternKind::kReduction) { shared_size += GetSharedSize(op); } @@ -508,7 +459,7 @@ static bool CanReduceFuseReduce(const OpGroupPtr& first, const OpGroupPtr& secon reducer_0_output_shape == reducer_1_output_shape && reducer_0_reduce_dim == reducer_1_reduce_dim) { auto shared_size = 0; for (auto& fusion_group : {first, second}) { - fusion_group.WalkOpNodes([&](const api::OpNode& op){ + 
fusion_group.WalkOpNodes([&](const api::OpNode& op) { if (op.kind() == OpPatternKind::kReduction) { shared_size += GetSharedSize(op); } @@ -573,7 +524,7 @@ struct HorizontalFuseUtil { static api::OpNode GetMasterNode(FusePassCtxT* ctx, const OpGroupPtr& op_group) { std::vector master_nodes; - op_group.WalkOpNodes([&](const api::OpNode& op){ + op_group.WalkOpNodes([&](const api::OpNode& op) { if (master_nodes.empty() || op.kind() == OpPatternKind::kReduction) { master_nodes.push_back(op); } @@ -597,7 +548,7 @@ struct HorizontalFuseUtil { return true; } - const OpGroupPtr* ele_group = nullptr; + const OpGroupPtr* ele_group = nullptr; const OpGroupPtr* reduce_group = nullptr; if (src.kind() == framework::kReduction) { @@ -936,7 +887,7 @@ class DefaultRecomputeFusePass final : public RecomputeFusePass { }; struct LightwareFusePassComparator { - bool operator()(const std::shared_ptr& lhs, const std::shared_ptr& rhs) const{ + bool operator()(const std::shared_ptr& lhs, const std::shared_ptr& rhs) const { return lhs->Benefit() > rhs->Benefit(); } }; @@ -1257,7 +1208,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { }; OpGroupList consumer_groups; consumer_groups.reserve(consumers.size()); - for(auto& consumer : consumers) { + for (auto& consumer : consumers) { consumer_groups.push_back(api::OpGroup(consumer)); } GraphGroupInputFusePassCtx fuse_ctx(this, consumer_groups, EnableFuse); @@ -1706,13 +1657,15 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { VerticalFuse(producer, fusionable_consumers); } - void RecomputeEleGraph(const GroupPtr& producer, std::unordered_set& fusionable_consumers) { + void RecomputeEleGraph(const GroupPtr& producer, + std::unordered_set& fusionable_consumers) { if (producer->op_pattern_kind != framework::kElementWise) { SelectConsumerToFuse(producer, fusionable_consumers); } } - void SelectConsumerToFuse(const GroupPtr& producer, std::unordered_set& fusionable_consumers) { + void SelectConsumerToFuse(const 
GroupPtr& producer, + std::unordered_set& fusionable_consumers) { // if is const op if (is_const_group(this, producer)) { std::unordered_set candidates; From 602ba9dcef69422af05985ac484b1687de3320f2 Mon Sep 17 00:00:00 2001 From: jiahongyu Date: Wed, 28 Jun 2023 12:43:02 +0000 Subject: [PATCH 52/66] save workspace --- cinn/hlir/pass/fusion_merge_pass.cc | 22 +++--- cinn/hlir/pass/general_fusion_merge_pass.cc | 88 +++++++++++++++------ 2 files changed, 77 insertions(+), 33 deletions(-) diff --git a/cinn/hlir/pass/fusion_merge_pass.cc b/cinn/hlir/pass/fusion_merge_pass.cc index 4ca2128a32..a33a670119 100644 --- a/cinn/hlir/pass/fusion_merge_pass.cc +++ b/cinn/hlir/pass/fusion_merge_pass.cc @@ -78,10 +78,10 @@ class FusionMergePassHelper : public FusionHelperBase { VLOG(3) << "DoFusionMerge...!"; // while (DoHorizontalFusion()) { // } - // while (DoVerticalFusion(/* recompute=*/false)) { - // } - while (DoVerticalFusion(/* recompute=*/true)) { + while (DoVerticalFusion(/* recompute=*/false)) { } + // while (DoVerticalFusion(/* recompute=*/true)) { + // } } bool DoHorizontalFusion() { @@ -454,15 +454,15 @@ class FusionMergePassHelper : public FusionHelperBase { } } - // if (fuse_consumers.size()) { - // SelectConsumerToFuse(producer, fuse_consumers); - // } + if (fuse_consumers.size()) { + SelectConsumerToFuse(producer, fuse_consumers); + } - // // if fusionable consumers exist - // if (fuse_consumers.size()) { - // VerticalFuse(producer, fuse_consumers); - // return true; - // } + // if fusionable consumers exist + if (fuse_consumers.size()) { + VerticalFuse(producer, fuse_consumers); + return true; + } return false; } diff --git a/cinn/hlir/pass/general_fusion_merge_pass.cc b/cinn/hlir/pass/general_fusion_merge_pass.cc index 052e123909..eba56bf6df 100644 --- a/cinn/hlir/pass/general_fusion_merge_pass.cc +++ b/cinn/hlir/pass/general_fusion_merge_pass.cc @@ -670,11 +670,17 @@ class DefaultRecomputeFusePass final : public RecomputeFusePass { } 
unsafe_candidates.push_back(consumer); if (ctx->fuse_helper().DetectCycleIfFuse(producer, consumer)) { + VLOG(1) << "DEBUG consumer " << std::dynamic_pointer_cast(consumer)->group_id << " has loop"; continue; } candidates.push_back(consumer); } + if (candidates.empty()) { + VLOG(1) << "DEBUG fuse_consumers.empty(), exit fuse group " << std::dynamic_pointer_cast(producer)->group_id; + } + VLOG(1) << "DEBUG fuse_consumers_unsafe.size() = " << unsafe_candidates.size(); + if (!candidates.empty() && unsafe_candidates.size() == consumers.size() && producer->kind() == framework::kElementWise) { for (const auto& consumer : consumers) { ctx->EnableFuse(producer, consumer); @@ -819,10 +825,10 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { VLOG(3) << "DoFusionMerge...!"; // while (DoGeneralHorizontalFusion()) { // } - // while (DoGeneralVerticalFusion()) { - // } - while (DoGeneralRecomputeAndVerticalFusion()) { + while (DoGeneralVerticalFusion()) { } + // while (DoGeneralRecomputeAndVerticalFusion()) { + // } } bool DoGeneralHorizontalFusion() { @@ -836,7 +842,8 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { continue; } // do horizontal fusion. - updated |= GeneralHorizontalFuse(producer); + // updated |= GeneralHorizontalFuse(producer); + updated |= HorizontalFusion(producer, producer->CollectConsumerGroups()); } if (updated) { @@ -881,12 +888,14 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { continue; } // do horizontal fusion. - updated |= GeneralHorizontalFuse(producer); + // updated |= GeneralHorizontalFuse(producer); + updated |= HorizontalFusion(producer, producer->CollectConsumerGroups()); updated |= GeneralVerticalFuse(producer); } // fuse input consumers - updated |= GeneralInputFuse(); + // updated |= GeneralInputFuse(); + updated |= FuseInputToConsumers(); if (updated) { UpdateFusionGroup(); @@ -906,14 +915,15 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { } // do horizontal fusion. 
bool recompute_success = GeneralRecomputeFuse(producer); - // updated |= recompute_success; - // if (!recompute_success) { - // updated |= GeneralVerticalFuse(producer); - // } + updated |= recompute_success; + if (!recompute_success) { + updated |= GeneralVerticalFuse(producer); + } } // fuse input consumers - updated |= GeneralInputFuse(); + // updated |= GeneralInputFuse(); + updated |= FuseInputToConsumers(); if (updated) { UpdateFusionGroup(); @@ -1039,20 +1049,45 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { } } + std::unordered_set UpdateMutConsumers(const std::unordered_set& consumers) { + std::unordered_set updated_consumers; + for (auto& consumer : consumers) { + std::queue fused_groups; + fused_groups.push(consumer); + while (!fused_groups.empty()) { + auto& cur = fused_groups.front(); + fused_groups.pop(); + // if group is sub group + if (cur->belong_groups.empty()) { + updated_consumers.insert(cur); + } else { + for (auto& belong_group : cur->belong_groups) { + if (belong_group->group_id == cur->group_id) { + updated_consumers.insert(belong_group); + } else { + fused_groups.push(belong_group); + } + } + } + } + } + return updated_consumers; + } + bool CallGeneralInputFusePass(const std::unordered_set& consumers) { VLOG(3) << "CallGeneralInputFusePass...!"; using OpGroupSets = std::set>; - const auto& GetFusableConsumerGroupSets = [&]() -> OpGroupSets { + const auto& GetFusableConsumerGroupSets = [&](const std::unordered_set& input_consumers) -> OpGroupSets { OpGroupSets tagged_sets; const auto& EnableFuse = [&](const OpGroupPtr& first, const OpGroupPtr& second) { tagged_sets.insert(std::set{first, second}); }; - GraphGroupInputFusePassCtx fuse_ctx(this, consumers, EnableFuse); + GraphGroupInputFusePassCtx fuse_ctx(this, input_consumers, EnableFuse); EnableFusedInputGroups(&fuse_ctx); return tagged_sets; }; - const auto& GetFusableConsumerGroupList = [&]() -> GroupList { - const auto& group_sets = GetFusableConsumerGroupSets(); + 
const auto& GetFusableConsumerGroupList = [&](const std::unordered_set& input_consumers) -> GroupList { + const auto& group_sets = GetFusableConsumerGroupSets(input_consumers); if (group_sets.empty()) { return GroupList{}; } @@ -1062,12 +1097,18 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { } return ret; }; - const auto& groups = GetFusableConsumerGroupList(); - if (groups.size() <= 1) { - return false; + bool update = false; + std::unordered_set mut_consumers(consumers); + while (true) { + const auto& groups = GetFusableConsumerGroupList(mut_consumers); + if (groups.size() <= 1) { + return false; + } + HorizontalFuse(groups); + mut_consumers = UpdateMutConsumers(mut_consumers); + update = true; } - HorizontalFuse(groups); - return true; + return update; } bool HorizontalFusion(GroupPtr producer, const std::unordered_set& consumers) { @@ -1619,6 +1660,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { bool GeneralRecomputeFuse(GroupPtr& producer) { VLOG(3) << "GeneralRecomputeFuse...!"; + VLOG(1) << "DEBUG VerticalFusion, begin check : " << producer->group_id; using GroupSets = std::set>; const auto& GetFusableConsumerOpGroupSets = [&]() -> GroupSets { GroupSets tagged_sets; @@ -1646,6 +1688,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { auto consumer_groups = GetFusableConsumerGroupSet(); if (consumer_groups.size() > 0) { CHECK(consumer_groups.size() == producer->mut_consumer_groups()->size()) << "Recompute requires fuse all consumers!"; + VLOG(1) << "DEBUG begin recompute fuse group " << producer->group_id; RecomputeFuse(producer, consumer_groups); update = true; } @@ -1848,11 +1891,12 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { continue; } // do input fusion. 
- while (CallGeneralInputFusePass(input_consumers.second)) { + auto st = CallGeneralInputFusePass(input_consumers.second); + if (st) { // fused consumers, update UpdateInputToConsumers(); - updated = true; } + updated |= st; } return updated; From c4c4a78674bf4ed4f99194a9a92cd5b2adae5486 Mon Sep 17 00:00:00 2001 From: jiahongyu Date: Wed, 28 Jun 2023 14:55:51 +0000 Subject: [PATCH 53/66] aligned --- cinn/hlir/pass/fusion_merge_pass.cc | 8 +- cinn/hlir/pass/general_fusion_merge_pass.cc | 175 +++++++++++++++----- 2 files changed, 135 insertions(+), 48 deletions(-) diff --git a/cinn/hlir/pass/fusion_merge_pass.cc b/cinn/hlir/pass/fusion_merge_pass.cc index a33a670119..c5de3ccf33 100644 --- a/cinn/hlir/pass/fusion_merge_pass.cc +++ b/cinn/hlir/pass/fusion_merge_pass.cc @@ -76,12 +76,12 @@ class FusionMergePassHelper : public FusionHelperBase { private: void DoFusionMerge() { VLOG(3) << "DoFusionMerge...!"; - // while (DoHorizontalFusion()) { - // } + while (DoHorizontalFusion()) { + } while (DoVerticalFusion(/* recompute=*/false)) { } - // while (DoVerticalFusion(/* recompute=*/true)) { - // } + while (DoVerticalFusion(/* recompute=*/true)) { + } } bool DoHorizontalFusion() { diff --git a/cinn/hlir/pass/general_fusion_merge_pass.cc b/cinn/hlir/pass/general_fusion_merge_pass.cc index eba56bf6df..90f242e3cc 100644 --- a/cinn/hlir/pass/general_fusion_merge_pass.cc +++ b/cinn/hlir/pass/general_fusion_merge_pass.cc @@ -169,6 +169,8 @@ class LightwareFusePassCtx : public FusePassCtx { virtual void EnableFuse(const OpGroupPtr& first, const OpGroupPtr& second) = 0; + virtual void EnableFuse(const OpGroupList& candidates) = 0; + protected: LightwareFusePassCtx() = default; }; @@ -184,18 +186,30 @@ class GraphGroupLightwareFusePassCtx final : public LightwareFusePassCtx { EnableFuse_(EnableFuse), fuse_helper_(new GraphGroupFuseHelper(this)) {} + GraphGroupLightwareFusePassCtx( + const FusionHelperBase* graph_group_fusion_helper, + const OpGroupPtr& group, + const std::function& 
EnableFuseList) + : graph_group_fusion_helper_(graph_group_fusion_helper), + group_(group), + EnableFuseList_(EnableFuseList), + fuse_helper_(new GraphGroupFuseHelper(this)) {} + const OpGroupPtr& PickOpGroup() const override { return group_; } const FuseHelper& fuse_helper() const override { return *fuse_helper_; } void EnableFuse(const OpGroupPtr& first, const OpGroupPtr& second) override { EnableFuse_(first, second); } + void EnableFuse(const OpGroupList& candidates) override { EnableFuseList_(candidates); } + const FusionHelperBase& graph_group_fusion_helper() const { return *graph_group_fusion_helper_; } private: const FusionHelperBase* graph_group_fusion_helper_; const OpGroupPtr group_; const std::function EnableFuse_; + const std::function EnableFuseList_; const std::unique_ptr fuse_helper_; }; @@ -464,31 +478,99 @@ class DefaultHorizontalFusePass final : public HorizontalFusePass { int Benefit() const override { return 100; } + bool IsDependency(const OpGroupPtr& producer_g, + const OpGroupPtr& consumer, + const std::unordered_set& consumers) const { + std::queue candidates; + candidates.push(consumer); + + std::unordered_set visited_set; + while (!candidates.empty()) { + auto& candidate = candidates.front(); + candidates.pop(); + for (const auto& producer_and_list : candidate->producer_groups()) { + const auto& producer = producer_and_list.first; + if (producer == producer_g) { + continue; + } + + if (consumers.count(producer)) { + return true; + } + if (!visited_set.count(producer)) { + visited_set.insert(producer); + candidates.push(producer); + } + } + } + return false; + } + void operator()(LightwareFusePassCtx* ctx) const override { VLOG(1) << "DefaultHorizontalFusePass"; const auto& producer = ctx->PickOpGroup(); - const OpGroupList consumers = [&]() { - OpGroupList consumers; + const std::unordered_set consumer_candidates = [&]() -> std::unordered_set { + std::unordered_set consumers; for (const auto& pair : producer->consumer2outputs()) { - 
consumers.push_back(pair.first); + if (pair.first->kind() != framework::kElementWise && + pair.first->kind() != framework::kBroadcast && + pair.first->kind() != framework::kInjective && + pair.first->kind() != framework::kReduction) { + continue; + } + consumers.insert(pair.first); } return consumers; }(); - if (consumers.size() <= 1) { + if (consumer_candidates.size() <= 1) { return; } - for (int i = 0; i < consumers.size(); ++i) { - const auto& src = consumers.at(i); - for (int j = i + 1; j < consumers.size(); ++j) { - const auto& dst = consumers.at(j); - if (ctx->fuse_helper().IsReachable(src, dst)) { - continue; - } - if (!HorizontalFuseUtil::DetectFusabilityByKind(ctx, src, dst)) { + + std::vector fusionable_consumers; + for (auto& candidate : consumer_candidates) { + + // bool reachable = false; + // for (const auto& tmp: consumer_candidates) { + // if (tmp == candidate) { + // continue; + // } + // if (ctx->fuse_helper().IsReachable(candidate, tmp)) { + // reachable = true; + // break; + // } + // } + // if (reachable) { + // continue; + // } + + if (IsDependency(producer, candidate, consumer_candidates)) { + continue; + } + if (fusionable_consumers.empty()) { + fusionable_consumers.push_back({candidate}); + continue; + } + // check each fusionable groups + bool fusionable = false; + for (auto& groups : fusionable_consumers) { + auto& last = groups.back(); + if (!HorizontalFuseUtil::DetectFusabilityByKind(ctx, candidate, last)) { continue; } - ctx->EnableFuse(src, dst); - return; + groups.push_back(candidate); + fusionable = true; + break; + } + + // if can't fuse to othors Groups, new Groups. 
+ if (!fusionable) { + fusionable_consumers.push_back({candidate}); + } + } + + for (const auto& groups: fusionable_consumers) { + if (groups.size() > 1) { + ctx->EnableFuse(groups); } } } @@ -823,12 +905,12 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { private: void DoFusionMerge() { VLOG(3) << "DoFusionMerge...!"; - // while (DoGeneralHorizontalFusion()) { - // } + while (DoGeneralHorizontalFusion()) { + } while (DoGeneralVerticalFusion()) { } - // while (DoGeneralRecomputeAndVerticalFusion()) { - // } + while (DoGeneralRecomputeAndVerticalFusion()) { + } } bool DoGeneralHorizontalFusion() { @@ -842,8 +924,8 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { continue; } // do horizontal fusion. - // updated |= GeneralHorizontalFuse(producer); - updated |= HorizontalFusion(producer, producer->CollectConsumerGroups()); + updated |= GeneralHorizontalFuse(producer); + // updated |= HorizontalFusion(producer, producer->CollectConsumerGroups()); } if (updated) { @@ -888,8 +970,8 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { continue; } // do horizontal fusion. 
- // updated |= GeneralHorizontalFuse(producer); - updated |= HorizontalFusion(producer, producer->CollectConsumerGroups()); + updated |= GeneralHorizontalFuse(producer); + // updated |= HorizontalFusion(producer, producer->CollectConsumerGroups()); updated |= GeneralVerticalFuse(producer); } @@ -1001,36 +1083,41 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { bool GeneralHorizontalFuse(const GroupPtr& producer) { VLOG(3) << "GeneralHorizontalFuse...!"; using OpGroupSets = std::set>; - const auto& GetFusableConsumerGroupSets = [&]() -> OpGroupSets { - OpGroupSets tagged_sets; - const auto& EnableFuse = [&](const OpGroupPtr& first, const OpGroupPtr& second) { - tagged_sets.insert(std::set{first, second}); + + const auto& GetFusableConsumerGroupLists = [&]() -> std::vector { + std::vector tagged_lists; + const auto& EnableFuse = [&](const OpGroupList& candidates) { + tagged_lists.push_back(candidates); }; GraphGroupLightwareFusePassCtx fuse_ctx(this, producer, EnableFuse); EnableFusedHorizontalGroups(&fuse_ctx); - return tagged_sets; + return tagged_lists; }; - const auto& GetFusableConsumerGroupList = [&]() -> GroupList { - const auto& group_sets = GetFusableConsumerGroupSets(); - if (group_sets.empty()) { - return GroupList{}; - } - GroupList ret; - for (const auto& group : *group_sets.begin()) { - ret.push_back(std::dynamic_pointer_cast(group)); + const auto& GetFusableConsumerGroupList = [&]() -> std::vector { + const auto& group_lists = GetFusableConsumerGroupLists(); + if (group_lists.empty()) { + return std::vector{}; + } + std::vector ret; + for (const auto& group_list : group_lists) { + GroupList tmp; + for (const auto& group: group_list) { + tmp.push_back(std::dynamic_pointer_cast(group)); + } + ret.push_back(tmp); } return ret; }; - bool update = false; - while (true) { - const auto& groups = GetFusableConsumerGroupList(); - if (groups.size() <= 1) { - break; - } - HorizontalFuse(groups); - update = true; + + const auto& group_lists = 
GetFusableConsumerGroupList(); + if (group_lists.empty()) { + return false; } - return update; + for (const auto& group_list: group_lists) { + HorizontalFuse(group_list); + } + + return true; } std::vector> RawInputFusePasses() const { From 35b40ba83d7a3488cc0849c3b4e13726c8651a4b Mon Sep 17 00:00:00 2001 From: jiahongyu Date: Thu, 29 Jun 2023 03:14:05 +0000 Subject: [PATCH 54/66] fully aligned --- cinn/hlir/pass/general_fusion_merge_pass.cc | 173 ++++++++++++++------ 1 file changed, 125 insertions(+), 48 deletions(-) diff --git a/cinn/hlir/pass/general_fusion_merge_pass.cc b/cinn/hlir/pass/general_fusion_merge_pass.cc index 90f242e3cc..81622ab30f 100644 --- a/cinn/hlir/pass/general_fusion_merge_pass.cc +++ b/cinn/hlir/pass/general_fusion_merge_pass.cc @@ -223,6 +223,8 @@ class InputFusePassCtx : public FusePassCtx { virtual void EnableFuse(const OpGroupPtr& first, const OpGroupPtr& second) = 0; + virtual void EnableFuse(const OpGroupList& candidates) = 0; + protected: InputFusePassCtx() = default; }; @@ -237,18 +239,29 @@ class GraphGroupInputFusePassCtx final : public InputFusePassCtx { EnableFuse_(EnableFuse), fuse_helper_(new GraphGroupFuseHelper(this)) {} + GraphGroupInputFusePassCtx(const FusionHelperBase* graph_group_fusion_helper, + const std::unordered_set& groups, + const std::function& EnableFuseList) + : graph_group_fusion_helper_(graph_group_fusion_helper), + groups_(groups), + EnableFuseList_(EnableFuseList), + fuse_helper_(new GraphGroupFuseHelper(this)) {} + const std::unordered_set& PickConsumersWithSameInputs() const override { return groups_; } const FuseHelper& fuse_helper() const override { return *fuse_helper_; } void EnableFuse(const OpGroupPtr& first, const OpGroupPtr& second) override { EnableFuse_(first, second); } + void EnableFuse(const OpGroupList& candidates) override { EnableFuseList_(candidates); } + const FusionHelperBase& graph_group_fusion_helper() const { return *graph_group_fusion_helper_; } private: const FusionHelperBase* 
graph_group_fusion_helper_; const std::unordered_set& groups_; const std::function EnableFuse_; + const std::function EnableFuseList_; const std::unique_ptr fuse_helper_; }; @@ -414,31 +427,95 @@ class DefaultInputFusePass final : public InputFusePass { int Benefit() const override { return 100; } + bool IsDependency(const OpGroupPtr& consumer, + const std::unordered_set& consumers) const { + std::queue candidates; + candidates.push(consumer); + + std::unordered_set visited_set; + while (!candidates.empty()) { + auto& candidate = candidates.front(); + candidates.pop(); + for (const auto& producer_and_list : candidate->producer_groups()) { + const auto& producer = producer_and_list.first; + if (consumers.count(producer)) { + return true; + } + if (!visited_set.count(producer)) { + visited_set.insert(producer); + candidates.push(producer); + } + } + } + return false; + } + void operator()(InputFusePassCtx* ctx) const override { VLOG(1) << "DefaultInputFusePass"; const auto& consumer_set = ctx->PickConsumersWithSameInputs(); - if (consumer_set.size() <= 1) { - return; - } - const OpGroupList consumers = [&]() { - OpGroupList ret; + + const std::unordered_set consumer_candidates = [&]() -> std::unordered_set { + std::unordered_set consumers; for (const auto& consumer : consumer_set) { - ret.push_back(consumer); + if (consumer->kind() != framework::kElementWise && + consumer->kind() != framework::kBroadcast && + consumer->kind() != framework::kInjective && + consumer->kind() != framework::kReduction) { + continue; + } + consumers.insert(consumer); } - return ret; + return consumers; }(); - for (int i = 0; i < consumers.size(); ++i) { - const auto& src = consumers.at(i); - for (int j = i + 1; j < consumers.size(); ++j) { - const auto& dst = consumers.at(j); - if (ctx->fuse_helper().IsReachable(src, dst)) { - continue; - } - if (!HorizontalFuseUtil::DetectFusabilityByKind(ctx, src, dst)) { + if (consumer_candidates.size() <= 1) { + return; + } + + std::vector 
fusionable_consumers; + for (auto& candidate : consumer_candidates) { + + // bool reachable = false; + // for (const auto& tmp: consumer_candidates) { + // if (tmp == candidate) { + // continue; + // } + // if (ctx->fuse_helper().IsReachable(candidate, tmp)) { + // reachable = true; + // break; + // } + // } + // if (reachable) { + // continue; + // } + + if (IsDependency(candidate, consumer_candidates)) { + continue; + } + if (fusionable_consumers.empty()) { + fusionable_consumers.push_back({candidate}); + continue; + } + // check each fusionable groups + bool fusionable = false; + for (auto& groups : fusionable_consumers) { + auto& last = groups.back(); + if (!HorizontalFuseUtil::DetectFusabilityByKind(ctx, candidate, last)) { continue; } - ctx->EnableFuse(src, dst); - return; + groups.push_back(candidate); + fusionable = true; + break; + } + + // if can't fuse to othors Groups, new Groups. + if (!fusionable) { + fusionable_consumers.push_back({candidate}); + } + } + + for (const auto& groups: fusionable_consumers) { + if (groups.size() > 1) { + ctx->EnableFuse(groups); } } } @@ -976,8 +1053,8 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { } // fuse input consumers - // updated |= GeneralInputFuse(); - updated |= FuseInputToConsumers(); + updated |= GeneralInputFuse(); + // updated |= FuseInputToConsumers(); if (updated) { UpdateFusionGroup(); @@ -1004,8 +1081,8 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { } // fuse input consumers - // updated |= GeneralInputFuse(); - updated |= FuseInputToConsumers(); + updated |= GeneralInputFuse(); + // updated |= FuseInputToConsumers(); if (updated) { UpdateFusionGroup(); @@ -1081,9 +1158,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { } bool GeneralHorizontalFuse(const GroupPtr& producer) { - VLOG(3) << "GeneralHorizontalFuse...!"; - using OpGroupSets = std::set>; - + VLOG(3) << "GeneralHorizontalFuse...!"; const auto& GetFusableConsumerGroupLists = [&]() -> 
std::vector { std::vector tagged_lists; const auto& EnableFuse = [&](const OpGroupList& candidates) { @@ -1164,38 +1239,40 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { bool CallGeneralInputFusePass(const std::unordered_set& consumers) { VLOG(3) << "CallGeneralInputFusePass...!"; using OpGroupSets = std::set>; - const auto& GetFusableConsumerGroupSets = [&](const std::unordered_set& input_consumers) -> OpGroupSets { - OpGroupSets tagged_sets; - const auto& EnableFuse = [&](const OpGroupPtr& first, const OpGroupPtr& second) { - tagged_sets.insert(std::set{first, second}); + const auto& GetFusableConsumerGroupLists = [&]() -> std::vector { + std::vector tagged_lists; + const auto& EnableFuse = [&](const OpGroupList& candidates) { + tagged_lists.push_back(candidates); }; - GraphGroupInputFusePassCtx fuse_ctx(this, input_consumers, EnableFuse); + GraphGroupInputFusePassCtx fuse_ctx(this, consumers, EnableFuse); EnableFusedInputGroups(&fuse_ctx); - return tagged_sets; + return tagged_lists; }; - const auto& GetFusableConsumerGroupList = [&](const std::unordered_set& input_consumers) -> GroupList { - const auto& group_sets = GetFusableConsumerGroupSets(input_consumers); - if (group_sets.empty()) { - return GroupList{}; + const auto& GetFusableConsumerGroupList = [&]() -> std::vector { + const auto& group_lists = GetFusableConsumerGroupLists(); + if (group_lists.empty()) { + return std::vector{}; } - GroupList ret; - for (const auto& group : *group_sets.begin()) { - ret.push_back(std::dynamic_pointer_cast(group)); + std::vector ret; + for (const auto& group_list : group_lists) { + GroupList tmp; + for (const auto& group: group_list) { + tmp.push_back(std::dynamic_pointer_cast(group)); + } + ret.push_back(tmp); } return ret; }; - bool update = false; - std::unordered_set mut_consumers(consumers); - while (true) { - const auto& groups = GetFusableConsumerGroupList(mut_consumers); - if (groups.size() <= 1) { - return false; - } - HorizontalFuse(groups); - 
mut_consumers = UpdateMutConsumers(mut_consumers); - update = true; + + const auto& group_lists = GetFusableConsumerGroupList(); + if (group_lists.empty()) { + return false; } - return update; + for (const auto& group_list: group_lists) { + HorizontalFuse(group_list); + } + + return true; } bool HorizontalFusion(GroupPtr producer, const std::unordered_set& consumers) { From d4c5e9ca7cb2179d368223acc0a172246351415b Mon Sep 17 00:00:00 2001 From: zyfncg Date: Thu, 29 Jun 2023 06:49:44 +0000 Subject: [PATCH 55/66] fix bug --- cinn/hlir/framework/graph.h | 14 +-- cinn/hlir/pass/fusion_merge_pass.cc | 85 +++++++++++-------- cinn/hlir/pass/general_fusion_merge_pass.cc | 94 ++++++++++++--------- cinn/hlir/pass/op_fusion_pass.cc | 4 +- 4 files changed, 112 insertions(+), 85 deletions(-) diff --git a/cinn/hlir/framework/graph.h b/cinn/hlir/framework/graph.h index 55caf788ff..742056e893 100644 --- a/cinn/hlir/framework/graph.h +++ b/cinn/hlir/framework/graph.h @@ -106,14 +106,6 @@ class Graph : public cinn::common::Graph { } }; - std::unordered_set, SharedGroupHasher, SharedGroupComparator> CollectConsumerGroups() { - std::unordered_set, SharedGroupHasher, SharedGroupComparator> groups; - for (const auto& consumer_and_list : consumer_groups_) { - groups.insert(std::dynamic_pointer_cast(consumer_and_list)); - } - return groups; - } - std::vector CollectNodes() { if (fused_sub_groups.size()) { std::vector tmp_nodes; @@ -154,11 +146,13 @@ class Graph : public cinn::common::Graph { std::string GetFuncName() { return "fn_" + group_id + unique_id; } public: - const std::unordered_set, SharedGroupHasher, SharedGroupComparator>& producer_groups() const { + const std::unordered_set, SharedGroupHasher, SharedGroupComparator>& producer_groups() + const { return producer_groups_; } - const std::unordered_set, SharedGroupHasher, SharedGroupComparator>& consumer_groups() const { + const std::unordered_set, SharedGroupHasher, SharedGroupComparator>& consumer_groups() + const { return 
consumer_groups_; } diff --git a/cinn/hlir/pass/fusion_merge_pass.cc b/cinn/hlir/pass/fusion_merge_pass.cc index b9c06fbe1b..f5e39d493d 100644 --- a/cinn/hlir/pass/fusion_merge_pass.cc +++ b/cinn/hlir/pass/fusion_merge_pass.cc @@ -29,12 +29,12 @@ using framework::shape_t; using common::GraphEdge; using common::GraphNode; -using GroupPtr = std::shared_ptr; -using GroupList = std::vector; - using Comparator = Graph::Group::SharedGroupComparator; using Hasher = Graph::Group::SharedGroupHasher; +using GroupPtr = std::shared_ptr; +using GroupList = std::vector; + using ConditionFunction = std::function; // Op Fusion Pass which performs Ops fusion, Ops are fused @@ -43,7 +43,7 @@ using ConditionFunction = std::functionfusion_groups; // init fusion relation. InitFusionRelation(); @@ -93,7 +93,7 @@ class FusionMergePassHelper : public FusionHelperBase { continue; } // do horizontal fusion. - updated |= HorizontalFusion(producer, producer->CollectConsumerGroups()); + updated |= HorizontalFusion(producer, producer->consumer_groups()); } if (updated) { @@ -114,9 +114,9 @@ class FusionMergePassHelper : public FusionHelperBase { } // do horizontal fusion. 
if (!recompute) { - updated |= HorizontalFusion(producer, producer->CollectConsumerGroups()); + updated |= HorizontalFusion(producer, producer->consumer_groups()); } - updated |= VerticalFusion(producer, producer->CollectConsumerGroups(), recompute); + updated |= VerticalFusion(producer, producer->consumer_groups(), recompute); } // fuse input consumers updated |= FuseInputToConsumers(); @@ -243,7 +243,7 @@ class FusionMergePassHelper : public FusionHelperBase { return updated; } - void HorizontalFuse(GroupList& consumers) { + void HorizontalFuse(const GroupList& consumers) { VLOG(3) << "HorizontalFuse Groups..."; // create fusion group auto fused_group = std::make_shared(); @@ -315,14 +315,18 @@ class FusionMergePassHelper : public FusionHelperBase { fused_group->fused_sub_groups.push_back(consumer); } // producer group - for (const auto& producer : consumer->producer_groups()) { + for (auto& producer : *consumer->mut_producer_groups()) { + fused_group->mut_producer_groups()->insert(producer); // update producer's consumer producer->mut_consumer_groups()->erase(consumer); + producer->mut_consumer_groups()->insert(fused_group); } // consumer group - for (const auto& gconsumer : consumer->consumer_groups()) { + for (auto& gconsumer : *consumer->mut_consumer_groups()) { + fused_group->mut_consumer_groups()->insert(gconsumer); // update consumer's producer gconsumer->mut_producer_groups()->erase(consumer); + gconsumer->mut_producer_groups()->insert(fused_group); } // belongs group consumer->belong_groups.insert(fused_group); @@ -380,7 +384,9 @@ class FusionMergePassHelper : public FusionHelperBase { CHECK(fused_group->output_nodes.size()) << "No output node is found, " << fused_group->group_id; } - bool VerticalFusion(GroupPtr& producer, const std::unordered_set& consumers, bool recompute) { + bool VerticalFusion(const GroupPtr& producer, + const std::unordered_set& consumers, + bool recompute) { VLOG(3) << "VerticalFusion, Number of Consumers : " << consumers.size(); 
auto& relation = fusion_relation_map_[producer->op_pattern_kind]; // if producer can't fuse others @@ -433,14 +439,14 @@ class FusionMergePassHelper : public FusionHelperBase { if (!recompute) { return false; } else { - RecomputeEleGraph(producer, fuse_consumers_unsafe); + RecomputeEleGraph(producer, &fuse_consumers_unsafe); VerticalFuse(producer, fuse_consumers_unsafe); return true; } } if (fuse_consumers.size()) { - SelectConsumerToFuse(producer, fuse_consumers); + SelectConsumerToFuse(producer, &fuse_consumers); } // if fusionable consumers exist @@ -452,7 +458,8 @@ class FusionMergePassHelper : public FusionHelperBase { return false; } - void VerticalFuse(GroupPtr& producer, std::unordered_set& fusionable_consumers) { + void VerticalFuse(const GroupPtr& producer, + const std::unordered_set& fusionable_consumers) { VLOG(3) << "VerticalFuse...!"; GroupList fused_groups; GroupPtr master_fuesd_group(nullptr); @@ -495,9 +502,11 @@ class FusionMergePassHelper : public FusionHelperBase { } // producer groups - for (const auto& group : producer->producer_groups()) { + for (auto& group : *producer->mut_producer_groups()) { + fused_group->mut_producer_groups()->insert(group); // update producer's producer's consumer group->mut_consumer_groups()->erase(producer); + group->mut_consumer_groups()->insert(fused_group); } // sub groups @@ -543,17 +552,20 @@ class FusionMergePassHelper : public FusionHelperBase { } // producer nodes - for (const auto& group : consumer->producer_groups()) { + for (auto& group : *consumer->mut_producer_groups()) { if (group.get() != producer.get()) { + fused_group->mut_producer_groups()->insert(group); // update consumer's producer's consumer group->mut_consumer_groups()->erase(consumer); + group->mut_consumer_groups()->insert(fused_group); } } - // consumer nodes - for (const auto& group : consumer->consumer_groups()) { + for (auto& group : *consumer->mut_consumer_groups()) { + fused_group->mut_consumer_groups()->insert(group); // update 
consumer's consumer's producer group->mut_producer_groups()->erase(consumer); + group->mut_producer_groups()->insert(fused_group); } // sub group @@ -613,26 +625,30 @@ class FusionMergePassHelper : public FusionHelperBase { } } // insert unfusionable consumer groups - for (const auto& consumer : producer->consumer_groups()) { + for (auto& consumer : *producer->mut_consumer_groups()) { if (fusionable_consumers.count(consumer)) { continue; } + master_fuesd_group->mut_consumer_groups()->insert(consumer); // update consumer's producer consumer->mut_producer_groups()->erase(producer); + consumer->mut_producer_groups()->insert(master_fuesd_group); } } - void RecomputeEleGraph(const GroupPtr& producer, std::unordered_set& fusionable_consumers) { + void RecomputeEleGraph(const GroupPtr& producer, + std::unordered_set* fusionable_consumers) { if (producer->op_pattern_kind != framework::kElementWise) { SelectConsumerToFuse(producer, fusionable_consumers); } } - void SelectConsumerToFuse(const GroupPtr& producer, std::unordered_set& fusionable_consumers) { + void SelectConsumerToFuse(const GroupPtr& producer, + std::unordered_set* fusionable_consumers) { // if is const op if (is_const_group(this, producer)) { std::unordered_set candidates; - for (auto& consumer : fusionable_consumers) { + for (auto& consumer : *fusionable_consumers) { // if can be output node. if (is_same_shape(this, producer, consumer)) { candidates.insert(consumer); @@ -654,10 +670,10 @@ class FusionMergePassHelper : public FusionHelperBase { CHECK_GE(producer->consumer_groups().size(), candidates.size()); if (producer->consumer_groups().size() == 0 && candidates.size() == 0 && output_nodes_set_.count(producer->CollectNodes()[0]) == 0) { - producer->belong_groups.insert(*fusionable_consumers.begin()); + producer->belong_groups.insert(*fusionable_consumers->begin()); } - fusionable_consumers = candidates; + *fusionable_consumers = candidates; return; } // 1 to 1 fusion. 
@@ -667,7 +683,7 @@ class FusionMergePassHelper : public FusionHelperBase { if (FLAGS_enhance_vertical_fusion_with_recompute) { std::vector candidates; - for (auto& consumer : fusionable_consumers) { + for (auto& consumer : *fusionable_consumers) { if (consumer->op_pattern_kind == framework::kElementWise) { candidates.push_back(consumer); continue; @@ -696,13 +712,13 @@ class FusionMergePassHelper : public FusionHelperBase { return lhs->op_pattern_kind < rhs->op_pattern_kind; }); - fusionable_consumers.clear(); + fusionable_consumers->clear(); if (candidates.size()) { - fusionable_consumers.insert(*candidates.begin()); + fusionable_consumers->insert(*candidates.begin()); } } else { std::unordered_set candidates; - for (auto& consumer : fusionable_consumers) { + for (auto& consumer : *fusionable_consumers) { if (consumer->op_pattern_kind == framework::kElementWise) { candidates.insert(consumer); continue; @@ -717,9 +733,9 @@ class FusionMergePassHelper : public FusionHelperBase { } } - fusionable_consumers.clear(); + fusionable_consumers->clear(); if (candidates.size()) { - fusionable_consumers.insert(*candidates.begin()); + fusionable_consumers->insert(*candidates.begin()); } } } @@ -879,10 +895,12 @@ class FusionMergePassHelper : public FusionHelperBase { for (const auto& producer : group->producer_groups()) { CHECK(producer->belong_groups.size()); + producers.insert(*producer->belong_groups.begin()); } - for (auto& consumer : group->consumer_groups()) { + for (auto& consumer : *group->mut_consumer_groups()) { CHECK(consumer->belong_groups.size()); + consumers.insert(*consumer->belong_groups.begin()); } CHECK_EQ(group->producer_groups().size(), producers.size()); CHECK_EQ(group->consumer_groups().size(), consumers.size()); @@ -982,7 +1000,7 @@ class FusionMergePassHelper : public FusionHelperBase { } GroupList fusion_groups_; - std::unordered_map fusion_groups_index_; + std::unordered_map fusion_groups_index_; std::unordered_map> input_to_consumers_; struct 
Relation { @@ -1009,7 +1027,8 @@ void FusionMergePassInternal(Graph* graph) { CINN_REGISTER_HELPER(FusionMergePass) { CINN_REGISTER_PASS(FusionMergePass) .describe( - "Fusion Merge Pass which performs Fusion-Ops fusion, Producer Fusion-Ops are fused into Consumer Fusion-Ops " + "Fusion Merge Pass which performs Fusion-Ops fusion, Producer " + "Fusion-Ops are fused into Consumer Fusion-Ops " "with certain conditions.") .set_change_structure(false) .set_body(cinn::hlir::pass::FusionMergePassInternal); diff --git a/cinn/hlir/pass/general_fusion_merge_pass.cc b/cinn/hlir/pass/general_fusion_merge_pass.cc index 4e29a8a0c9..a7f5d7f7f4 100644 --- a/cinn/hlir/pass/general_fusion_merge_pass.cc +++ b/cinn/hlir/pass/general_fusion_merge_pass.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2022 CINN Authors. All Rights Reserved. +// Copyright (c) 2023 CINN Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -186,8 +186,8 @@ class InputFusePassCtx : public FusePassCtx { virtual void EnableFuse(const OpGroupPtr& first, const OpGroupPtr& second) = 0; - // virtual absl::any* FindOrCreateCachedGroupInfo(const OpGroupPtr& op_group, const std::function& create_fn) = 0; + // virtual absl::any* FindOrCreateCachedGroupInfo(const OpGroupPtr& op_group, + // const std::function& create_fn) = 0; protected: InputFusePassCtx() = default; @@ -370,7 +370,8 @@ static int GetSharedSize(const api::OpNode& op_node) { break; } } - // if lane > (max_num_threads / 2),the loop break from lane > max_num_threads / 2. + // if lane > (max_num_threads / 2),the loop break from lane > + // max_num_threads / 2. int axis = lane > (max_num_threads / 2) ? 
axes[index] : axes[index + 1]; if (lane <= max_num_threads) { return lane * sizeof(float); @@ -597,7 +598,7 @@ class InputFusePass : public FusePass { virtual void operator()(InputFusePassCtx* ctx) const = 0; - virtual const std::string FuseMode() const override final { return "InputFuse"; } + const std::string FuseMode() const final { return "InputFuse"; } virtual int Benefit() const = 0; @@ -656,7 +657,7 @@ class HorizontalFusePass : public LightwareFusePass { virtual void operator()(LightwareFusePassCtx* ctx) const = 0; - virtual const std::string FuseMode() const override final { return "HorizontalFuse"; } + const std::string FuseMode() const final { return "HorizontalFuse"; } virtual int Benefit() const = 0; @@ -706,7 +707,7 @@ class VerticalFusePass : public LightwareFusePass { virtual void operator()(LightwareFusePassCtx* ctx) const = 0; - virtual const std::string FuseMode() const override final { return "VerticalFuse"; } + const std::string FuseMode() const final { return "VerticalFuse"; } virtual int Benefit() const = 0; @@ -832,7 +833,7 @@ class RecomputeFusePass : public LightwareFusePass { virtual void operator()(LightwareFusePassCtx* ctx) const = 0; - virtual const std::string FuseMode() const override final { return "RecomputeFuse"; } + const std::string FuseMode() const final { return "RecomputeFuse"; } virtual int Benefit() const = 0; @@ -921,7 +922,8 @@ class FusionPassMap { // fuse_mode: HorizontalFuse, VerticalFuse, RecomputeFuse std::vector> GetLightwareFusePassesByMode(const std::string& fuse_mode) const { CHECK(fuse_mode == "HorizontalFuse" || fuse_mode == "VerticalFuse" || fuse_mode == "RecomputeFuse") - << "fuse_mode only supports HorizontalFuse, VerticalFuse and RecomputeFuse. Please check your input modes = " + << "fuse_mode only supports HorizontalFuse, VerticalFuse and " + "RecomputeFuse. 
Please check your input modes = " << fuse_mode; std::set, LightwareFusePassComparator> candidate_passes; for (const auto iter : map_) { @@ -976,7 +978,7 @@ class FusionPassRegistrar final : public Registrar { // code generation. class GeneralFusionMergePassHelper : public FusionHelperBase { public: - GeneralFusionMergePassHelper(const Graph* graph) : FusionHelperBase(graph), graph_(graph) { + explicit GeneralFusionMergePassHelper(const Graph* graph) : FusionHelperBase(graph), graph_(graph) { fusion_groups_ = graph->fusion_groups; // init input to consumers. InitInputToConsumers(); @@ -1304,14 +1306,18 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { fused_group->fused_sub_groups.push_back(consumer); } // producer group - for (const auto& producer : consumer->producer_groups()) { + for (auto& producer : *consumer->mut_producer_groups()) { + fused_group->mut_producer_groups()->insert(producer); // update producer's consumer producer->mut_consumer_groups()->erase(consumer); + producer->mut_consumer_groups()->insert(fused_group); } // consumer group - for (const auto& gconsumer : consumer->consumer_groups()) { + for (auto& gconsumer : *consumer->mut_consumer_groups()) { + fused_group->mut_consumer_groups()->insert(gconsumer); // update consumer's producer gconsumer->mut_producer_groups()->erase(consumer); + gconsumer->mut_producer_groups()->insert(fused_group); } // belongs group consumer->belong_groups.insert(fused_group); @@ -1389,7 +1395,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { } } - bool GeneralVerticalFuse(GroupPtr& producer) { + bool GeneralVerticalFuse(const GroupPtr& producer) { VLOG(3) << "GeneralVerticalFuse...!"; using GroupSets = std::vector>; const auto& GetFusableConsumerOpGroupSets = [&]() -> GroupSets { @@ -1417,7 +1423,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { bool update = false; auto consumer_groups = GetFusableConsumerGroupSet(); if (consumer_groups.size()) { - 
SelectConsumerToFuse(producer, consumer_groups); + SelectConsumerToFuse(producer, &consumer_groups); } if (consumer_groups.size() > 0) { VerticalFuse(producer, consumer_groups); @@ -1426,7 +1432,8 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { return update; } - void VerticalFuse(GroupPtr& producer, std::unordered_set& fusionable_consumers) { + void VerticalFuse(const GroupPtr& producer, + const std::unordered_set& fusionable_consumers) { VLOG(3) << "VerticalFuse...!"; GroupList fused_groups; GroupPtr master_fuesd_group(nullptr); @@ -1469,9 +1476,11 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { } // producer groups - for (const auto& group : producer->producer_groups()) { + for (auto& group : *producer->mut_producer_groups()) { + fused_group->mut_producer_groups()->insert(group); // update producer's producer's consumer group->mut_consumer_groups()->erase(producer); + group->mut_consumer_groups()->insert(fused_group); } // delete consumer group in producer @@ -1520,17 +1529,21 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { } // producer nodes - for (const auto& group : consumer->producer_groups()) { + for (auto& group : *consumer->mut_producer_groups()) { if (group.get() != producer.get()) { + fused_group->mut_producer_groups()->insert(group); // update consumer's producer's consumer group->mut_consumer_groups()->erase(consumer); + group->mut_consumer_groups()->insert(fused_group); } } // consumer nodes - for (const auto& group : consumer->consumer_groups()) { + for (auto& group : *consumer->mut_consumer_groups()) { + fused_group->mut_consumer_groups()->insert(group); // update consumer's consumer's producer group->mut_producer_groups()->erase(consumer); + group->mut_producer_groups()->insert(fused_group); } // sub group @@ -1590,12 +1603,14 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { } } // insert unfusionable consumer groups - for (const auto& consumer : producer->consumer_groups()) { 
+ for (auto& consumer : *producer->mut_consumer_groups()) { if (fusionable_consumers.count(consumer)) { continue; } + master_fuesd_group->mut_consumer_groups()->insert(consumer); // update consumer's producer consumer->mut_producer_groups()->erase(producer); + consumer->mut_producer_groups()->insert(master_fuesd_group); } } @@ -1619,7 +1634,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { } } - bool GeneralRecomputeFuse(GroupPtr& producer) { + bool GeneralRecomputeFuse(const GroupPtr& producer) { VLOG(3) << "GeneralRecomputeFuse...!"; using GroupSets = std::set>; const auto& GetFusableConsumerOpGroupSets = [&]() -> GroupSets { @@ -1653,23 +1668,17 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { return update; } - void RecomputeFuse(GroupPtr& producer, std::unordered_set& fusionable_consumers) { + void RecomputeFuse(const GroupPtr& producer, + const std::unordered_set& fusionable_consumers) { VerticalFuse(producer, fusionable_consumers); } - void RecomputeEleGraph(const GroupPtr& producer, - std::unordered_set& fusionable_consumers) { - if (producer->op_pattern_kind != framework::kElementWise) { - SelectConsumerToFuse(producer, fusionable_consumers); - } - } - void SelectConsumerToFuse(const GroupPtr& producer, - std::unordered_set& fusionable_consumers) { + std::unordered_set* fusionable_consumers) { // if is const op if (is_const_group(this, producer)) { std::unordered_set candidates; - for (auto& consumer : fusionable_consumers) { + for (auto& consumer : *fusionable_consumers) { // if can be output node. 
if (is_same_shape(this, producer, consumer)) { candidates.insert(consumer); @@ -1691,10 +1700,10 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { CHECK_GE(producer->consumer_groups().size(), candidates.size()); if (producer->consumer_groups().size() == 0 && candidates.size() == 0 && output_nodes_set_.count(producer->CollectNodes()[0]) == 0) { - producer->belong_groups.insert(*fusionable_consumers.begin()); + producer->belong_groups.insert(*fusionable_consumers->begin()); } - fusionable_consumers = candidates; + *fusionable_consumers = candidates; return; } // 1 to 1 fusion. @@ -1704,7 +1713,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { if (FLAGS_enhance_vertical_fusion_with_recompute) { std::vector candidates; - for (auto& consumer : fusionable_consumers) { + for (auto& consumer : *fusionable_consumers) { if (consumer->op_pattern_kind == framework::kElementWise) { candidates.push_back(consumer); continue; @@ -1733,13 +1742,13 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { return lhs->op_pattern_kind < rhs->op_pattern_kind; }); - fusionable_consumers.clear(); + fusionable_consumers->clear(); if (candidates.size()) { - fusionable_consumers.insert(*candidates.begin()); + fusionable_consumers->insert(*candidates.begin()); } } else { std::vector candidates; - for (auto& consumer : fusionable_consumers) { + for (auto& consumer : *fusionable_consumers) { if (consumer->op_pattern_kind == framework::kElementWise) { candidates.push_back(consumer); continue; @@ -1754,9 +1763,9 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { } } - fusionable_consumers.clear(); + fusionable_consumers->clear(); if (candidates.size()) { - fusionable_consumers.insert(candidates.front()); + fusionable_consumers->insert(candidates.front()); } } } @@ -1914,12 +1923,14 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { std::unordered_set producers; std::unordered_set consumers; - for (auto& producer : 
group->producer_groups()) { + for (const auto& producer : group->producer_groups()) { CHECK(producer->belong_groups.size()); + producers.insert(*producer->belong_groups.begin()); } - for (auto& consumer : group->consumer_groups()) { + for (auto& consumer : *group->mut_consumer_groups()) { CHECK(consumer->belong_groups.size()); + consumers.insert(*consumer->belong_groups.begin()); } CHECK_EQ(group->producer_groups().size(), producers.size()); CHECK_EQ(group->consumer_groups().size(), consumers.size()); @@ -1951,7 +1962,8 @@ void GeneralFusionMergePassInternal(Graph* graph) { CINN_REGISTER_HELPER(GeneralFusionMergePass) { CINN_REGISTER_PASS(GeneralFusionMergePass) .describe( - "Fusion Merge Pass which performs Fusion-Ops fusion, Producer Fusion-Ops are fused into Consumer Fusion-Ops " + "Fusion Merge Pass which performs Fusion-Ops fusion, Producer " + "Fusion-Ops are fused into Consumer Fusion-Ops " "with certain conditions.") .set_change_structure(false) .set_body(cinn::hlir::pass::GeneralFusionMergePassInternal); diff --git a/cinn/hlir/pass/op_fusion_pass.cc b/cinn/hlir/pass/op_fusion_pass.cc index a7969aaa23..aff427cdb1 100644 --- a/cinn/hlir/pass/op_fusion_pass.cc +++ b/cinn/hlir/pass/op_fusion_pass.cc @@ -100,6 +100,8 @@ class OpFusionPassHelper : public FusionHelperBase { for (auto& consumer : fusion_groups) { for (auto& input_node : consumer->input_nodes) { auto& producer = fusion_groups_[input_node.first]; + consumer->mut_producer_groups()->insert(producer); + producer->mut_consumer_groups()->insert(consumer); } } @@ -107,7 +109,7 @@ class OpFusionPassHelper : public FusionHelperBase { for (auto& group : fusion_groups) { for (const auto& consumer : group->consumer_groups()) { // update depth. 
- group->depth = std::max(group->depth, consumer->depth + 1); + group->depth = std::max(group->depth, consumer->depth + 1); } } From 87e9a47113c1b850a05c2f8ae06f8f9139b57944 Mon Sep 17 00:00:00 2001 From: zyfncg Date: Thu, 29 Jun 2023 07:01:08 +0000 Subject: [PATCH 56/66] polish code --- cinn/api/op_group.h | 101 ++++++++------------ cinn/hlir/pass/general_fusion_merge_pass.cc | 3 - 2 files changed, 40 insertions(+), 64 deletions(-) diff --git a/cinn/api/op_group.h b/cinn/api/op_group.h index 40df305eac..2924559c54 100644 --- a/cinn/api/op_group.h +++ b/cinn/api/op_group.h @@ -17,7 +17,6 @@ #include #include "cinn/api/op_node.h" - #include "cinn/hlir/framework/graph.h" #include "cinn/hlir/pass/fusion_helper_base.h" @@ -29,40 +28,35 @@ using Hasher = hlir::framework::Graph::Group::SharedGroupHasher; class OpGroup { public: - OpGroup(const std::shared_ptr& group) - : group_(group) {} + OpGroup(const std::shared_ptr& group) : group_(group) {} OpGroup(const OpGroup& other) = default; class OpGroupListIterator { - public: - OpGroupListIterator(std::unordered_set, Hasher, Comparator>::const_iterator it) : iter_(it) {} - - OpGroupListIterator& operator++() { - ++iter_; - return *this; - } - - OpGroupListIterator operator++(int) { - OpGroupListIterator tmp = *this; - ++iter_; - return tmp; - } - - bool operator==(const OpGroupListIterator& other) const { - return iter_ == other.iter_; - } - - bool operator!=(const OpGroupListIterator& other) const { - return !(*this == other); - } - - OpGroup operator*() const{ - return OpGroup(*iter_); - } - - private: - std::unordered_set, Hasher, Comparator>::const_iterator iter_; + public: + OpGroupListIterator( + std::unordered_set, Hasher, Comparator>::const_iterator it) + : iter_(it) {} + + OpGroupListIterator& operator++() { + ++iter_; + return *this; + } + + OpGroupListIterator operator++(int) { + OpGroupListIterator tmp = *this; + ++iter_; + return tmp; + } + + bool operator==(const OpGroupListIterator& other) const { return 
iter_ == other.iter_; } + + bool operator!=(const OpGroupListIterator& other) const { return !(*this == other); } + + OpGroup operator*() const { return OpGroup(*iter_); } + + private: + std::unordered_set, Hasher, Comparator>::const_iterator iter_; }; class ProducerOpGroupListView { @@ -70,7 +64,7 @@ class OpGroup { ProducerOpGroupListView(const std::weak_ptr& group) : group_(group) {} ProducerOpGroupListView(const ProducerOpGroupListView& other) = delete; - ProducerOpGroupListView(ProducerOpGroupListView&& other) = delete; + ProducerOpGroupListView(ProducerOpGroupListView&& other) = delete; ProducerOpGroupListView& operator=(const ProducerOpGroupListView& other) = delete; @@ -91,7 +85,7 @@ class OpGroup { ConsumerOpGroupListView(const std::weak_ptr& group) : group_(group) {} ConsumerOpGroupListView(const ConsumerOpGroupListView& other) = delete; - ConsumerOpGroupListView(ConsumerOpGroupListView&& other) = delete; + ConsumerOpGroupListView(ConsumerOpGroupListView&& other) = delete; ConsumerOpGroupListView& operator=(const ConsumerOpGroupListView& other) = delete; @@ -107,10 +101,7 @@ class OpGroup { const std::weak_ptr group_; }; - const std::string& group_id() const { - return group_.lock()->group_id; - } - + const std::string& group_id() const { return group_.lock()->group_id; } hlir::framework::OpPatternKind kind() const { return group_.lock()->kind(); } @@ -123,39 +114,27 @@ class OpGroup { // // Example: Get the all Reduction op_nodes in the group. // OpGroup group = ...; - // std::set reduce_ op_set; - // // The lambda funtion of VisitOpNode to get reduction op_nodes. - // auto get_reduce_op = [&reduce_op_set](const api::OpNode& op){ + // std::set reduce_op_set; + // group.WalkOpNodes([&reduce_op_set](const api::OpNode& op){ + // // The lambda funtion of VisitOpNode to get reduction op_nodes. 
// if (op.kind() == OpPatternKind::kReduction) { // reduce_op_set.insert(op); // } - // }; - // group.WalkOpNodes(get_reduce_op); + // }); void WalkOpNodes(const std::function& VisitOpNode) const { - group_.lock()->WalkNodes([&](const hlir::framework::Node* node){ - VisitOpNode(OpNode(node, group_.lock()->graph_)); - }); + group_.lock()->WalkNodes( + [&](const hlir::framework::Node* node) { VisitOpNode(OpNode(node, group_.lock()->graph_)); }); } - ProducerOpGroupListView producers() const { - return ProducerOpGroupListView(group_); - } + ProducerOpGroupListView producers() const { return ProducerOpGroupListView(group_); } - ConsumerOpGroupListView consumers() const { - return ConsumerOpGroupListView(group_); - } + ConsumerOpGroupListView consumers() const { return ConsumerOpGroupListView(group_); } - std::shared_ptr GetGroup() const { - return group_.lock(); - } + std::shared_ptr GetGroup() const { return group_.lock(); } - bool operator == (const OpGroup& other) const { - return group_.lock().get() == other.group_.lock().get(); - } + bool operator==(const OpGroup& other) const { return group_.lock().get() == other.group_.lock().get(); } - bool operator < (const OpGroup& other) const { - return group_.lock().get() < other.group_.lock().get(); - } + bool operator<(const OpGroup& other) const { return group_.lock().get() < other.group_.lock().get(); } private: const std::weak_ptr group_; @@ -173,4 +152,4 @@ struct hash { } }; -} // namespace std +} // namespace std diff --git a/cinn/hlir/pass/general_fusion_merge_pass.cc b/cinn/hlir/pass/general_fusion_merge_pass.cc index a7f5d7f7f4..685bd0cddd 100644 --- a/cinn/hlir/pass/general_fusion_merge_pass.cc +++ b/cinn/hlir/pass/general_fusion_merge_pass.cc @@ -186,9 +186,6 @@ class InputFusePassCtx : public FusePassCtx { virtual void EnableFuse(const OpGroupPtr& first, const OpGroupPtr& second) = 0; - // virtual absl::any* FindOrCreateCachedGroupInfo(const OpGroupPtr& op_group, - // const std::function& create_fn) = 0; - 
protected: InputFusePassCtx() = default; }; From 6dc511dec12423667b5c7be2bda53391b37e8579 Mon Sep 17 00:00:00 2001 From: jiahongyu Date: Fri, 30 Jun 2023 02:23:34 +0000 Subject: [PATCH 57/66] fully aligned and add debug message --- cinn/hlir/pass/fusion_merge_pass.cc | 22 +++++++++--- cinn/hlir/pass/general_fusion_merge_pass.cc | 38 +++++++++++++-------- cinn/lang/lower_impl.cc | 14 ++++---- 3 files changed, 49 insertions(+), 25 deletions(-) diff --git a/cinn/hlir/pass/fusion_merge_pass.cc b/cinn/hlir/pass/fusion_merge_pass.cc index c5de3ccf33..dab40d9743 100644 --- a/cinn/hlir/pass/fusion_merge_pass.cc +++ b/cinn/hlir/pass/fusion_merge_pass.cc @@ -76,6 +76,10 @@ class FusionMergePassHelper : public FusionHelperBase { private: void DoFusionMerge() { VLOG(3) << "DoFusionMerge...!"; + for (int idx = 0; idx < fusion_groups_.size(); ++idx) { + auto producer = fusion_groups_[idx]; + VLOG(1) << "Fusion Producer idx " << idx << " Group -> " << producer->group_id; + } while (DoHorizontalFusion()) { } while (DoVerticalFusion(/* recompute=*/false)) { @@ -85,11 +89,11 @@ class FusionMergePassHelper : public FusionHelperBase { } bool DoHorizontalFusion() { - VLOG(3) << "DoHorizontalFusion...!"; + VLOG(1) << "****** DEBUG DoGeneralHorizontalFusion...! ********"; bool updated = false; for (int idx = 0; idx < fusion_groups_.size(); ++idx) { auto producer = fusion_groups_[idx]; - VLOG(3) << "Fusion Producer Group -> " << producer->group_id; + VLOG(1) << "Fusion Producer Group -> " << producer->group_id; // if producer is sub group. if (producer->belong_groups.size()) { continue; @@ -105,6 +109,11 @@ class FusionMergePassHelper : public FusionHelperBase { } bool DoVerticalFusion(bool recompute) { + if (recompute) { + VLOG(1) << "****** DEBUG DoGeneralRecomputeAndVerticalFusion...! ********"; + } else { + VLOG(1) << "****** DEBUG DoGeneralVerticalFusion...! 
********"; + } VLOG(3) << "DoVerticalFusion...!"; bool updated = false; for (int idx = 0; idx < fusion_groups_.size(); ++idx) { @@ -239,6 +248,7 @@ class FusionMergePassHelper : public FusionHelperBase { for (auto& groups : fusionable_consumers) { if (groups.size() > 1) { updated = true; + VLOG(1) << "DEBUG horizontal fuse group " << producer->group_id; HorizontalFuse(groups); } } @@ -256,8 +266,9 @@ class FusionMergePassHelper : public FusionHelperBase { // find the first consumer. GroupPtr first_consumer(nullptr); // fuse all group into fusion group. + VLOG(1) << "********** DEBUG Begin check Horizontal ************"; for (auto& consumer : consumers) { - VLOG(3) << "fuse consumer " << consumer->group_id << " into fused_group!"; + VLOG(1) << "DEBUG HorizontalFuse consumer " << consumer->group_id << " into fused_group!"; // update depth fused_group->max_depth = std::max(fused_group->max_depth, consumer->max_depth); fused_group->min_depth = std::min(fused_group->min_depth, consumer->min_depth); @@ -439,6 +450,8 @@ class FusionMergePassHelper : public FusionHelperBase { return false; } VLOG(1) << "DEBUG fuse_consumers_unsafe.size() = " << fuse_consumers_unsafe.size(); + VLOG(1) << "DEBUG fuse_consumers.size() = " << fuse_consumers.size(); + VLOG(1) << "DEBUG producer->consumer_groups().size() = " << producer->consumer_groups().size(); // if can_fuse_consumers == consumers // if producer op kind == kElementwise // if use recompute @@ -447,8 +460,8 @@ class FusionMergePassHelper : public FusionHelperBase { if (!recompute) { return false; } else { - VLOG(1) << "DEBUG begin recompute fuse group " << producer->group_id; RecomputeEleGraph(producer, fuse_consumers_unsafe); + VLOG(1) << "DEBUG recompute fuse group " << producer->group_id; VerticalFuse(producer, fuse_consumers_unsafe); return true; } @@ -460,6 +473,7 @@ class FusionMergePassHelper : public FusionHelperBase { // if fusionable consumers exist if (fuse_consumers.size()) { + VLOG(1) << "DEBUG Vertical fuse group 
" << producer->group_id; VerticalFuse(producer, fuse_consumers); return true; } diff --git a/cinn/hlir/pass/general_fusion_merge_pass.cc b/cinn/hlir/pass/general_fusion_merge_pass.cc index 81622ab30f..897492d6b8 100644 --- a/cinn/hlir/pass/general_fusion_merge_pass.cc +++ b/cinn/hlir/pass/general_fusion_merge_pass.cc @@ -586,6 +586,11 @@ class DefaultHorizontalFusePass final : public HorizontalFusePass { void operator()(LightwareFusePassCtx* ctx) const override { VLOG(1) << "DefaultHorizontalFusePass"; const auto& producer = ctx->PickOpGroup(); + + + + + const std::unordered_set consumer_candidates = [&]() -> std::unordered_set { std::unordered_set consumers; for (const auto& pair : producer->consumer2outputs()) { @@ -647,6 +652,7 @@ class DefaultHorizontalFusePass final : public HorizontalFusePass { for (const auto& groups: fusionable_consumers) { if (groups.size() > 1) { + VLOG(1) << "NOTICE DefaultHorizontalFusePass fuse groups.size() = " << groups.size(); ctx->EnableFuse(groups); } } @@ -981,6 +987,11 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { private: void DoFusionMerge() { + VLOG(1) << "****** DEBUG Input Groups...! ********"; + for (int idx = 0; idx < fusion_groups_.size(); ++idx) { + auto producer = fusion_groups_[idx]; + VLOG(1) << "Fusion Producer idx " << idx << " Group -> " << producer->group_id; + } VLOG(3) << "DoFusionMerge...!"; while (DoGeneralHorizontalFusion()) { } @@ -991,18 +1002,17 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { } bool DoGeneralHorizontalFusion() { - VLOG(3) << "DoGeneralHorizontalFusion...!"; + VLOG(1) << "****** DEBUG DoGeneralHorizontalFusion...! ********"; bool updated = false; for (int idx = 0; idx < fusion_groups_.size(); ++idx) { auto producer = fusion_groups_[idx]; - VLOG(3) << "Fusion Producer Group -> " << producer->group_id; + VLOG(1) << "Fusion Producer idx " << idx << " Group -> " << producer->group_id; // if producer is sub group. 
if (producer->belong_groups.size()) { continue; } // do horizontal fusion. updated |= GeneralHorizontalFuse(producer); - // updated |= HorizontalFusion(producer, producer->CollectConsumerGroups()); } if (updated) { @@ -1037,24 +1047,22 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { } bool DoGeneralVerticalFusion() { - VLOG(3) << "DoGeneralVerticalFusion...!"; + VLOG(1) << "****** DEBUG DoGeneralVerticalFusion...! ********"; bool updated = false; for (int idx = 0; idx < fusion_groups_.size(); ++idx) { auto producer = fusion_groups_[idx]; - VLOG(3) << "Fusion Producer Group -> " << producer->group_id; + VLOG(1) << "Fusion Producer idx " << idx << " Group -> " << producer->group_id; // if producer is sub group. if (producer->belong_groups.size()) { continue; } // do horizontal fusion. updated |= GeneralHorizontalFuse(producer); - // updated |= HorizontalFusion(producer, producer->CollectConsumerGroups()); updated |= GeneralVerticalFuse(producer); } // fuse input consumers updated |= GeneralInputFuse(); - // updated |= FuseInputToConsumers(); if (updated) { UpdateFusionGroup(); @@ -1063,7 +1071,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { } bool DoGeneralRecomputeAndVerticalFusion() { - VLOG(3) << "DoGeneralRecomputeAndVerticalFusion...!"; + VLOG(1) << "****** DEBUG DoGeneralRecomputeAndVerticalFusion...! 
********"; bool updated = false; for (int idx = 0; idx < fusion_groups_.size(); ++idx) { auto producer = fusion_groups_[idx]; @@ -1082,7 +1090,6 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { // fuse input consumers updated |= GeneralInputFuse(); - // updated |= FuseInputToConsumers(); if (updated) { UpdateFusionGroup(); @@ -1158,7 +1165,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { } bool GeneralHorizontalFuse(const GroupPtr& producer) { - VLOG(3) << "GeneralHorizontalFuse...!"; + VLOG(1) << "DEBUG Horizontal, begin check : " << producer->group_id; const auto& GetFusableConsumerGroupLists = [&]() -> std::vector { std::vector tagged_lists; const auto& EnableFuse = [&](const OpGroupList& candidates) { @@ -1189,6 +1196,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { return false; } for (const auto& group_list: group_lists) { + VLOG(1) << "DEBUG horizontal fuse group " << producer->group_id; HorizontalFuse(group_list); } @@ -1355,8 +1363,9 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { // find the first consumer. GroupPtr first_consumer(nullptr); // fuse all group into fusion group. 
+ VLOG(1) << "********** DEBUG Begin check Horizontal ************"; for (const auto& consumer : consumers) { - VLOG(3) << "fuse consumer " << consumer->group_id << " into fused_group!"; + VLOG(1) << "DEBUG HorizontalFuse consumer " << consumer->group_id << " into fused_group!"; // update depth fused_group->max_depth = std::max(fused_group->max_depth, consumer->max_depth); fused_group->min_depth = std::min(fused_group->min_depth, consumer->min_depth); @@ -1583,7 +1592,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { } bool GeneralVerticalFuse(GroupPtr& producer) { - VLOG(3) << "GeneralVerticalFuse...!"; + VLOG(1) << "DEBUG Vertical, begin check : " << producer->group_id; using GroupSets = std::set>; const auto& GetFusableConsumerOpGroupSets = [&]() -> GroupSets { GroupSets tagged_sets; @@ -1613,6 +1622,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { SelectConsumerToFuse(producer, consumer_groups); } if (consumer_groups.size() > 0) { + VLOG(1) << "DEBUG Vertical fuse group " << producer->group_id; VerticalFuse(producer, consumer_groups); update = true; } @@ -1824,7 +1834,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { bool GeneralRecomputeFuse(GroupPtr& producer) { VLOG(3) << "GeneralRecomputeFuse...!"; - VLOG(1) << "DEBUG VerticalFusion, begin check : " << producer->group_id; + VLOG(1) << "DEBUG Recompute, begin check : " << producer->group_id; using GroupSets = std::set>; const auto& GetFusableConsumerOpGroupSets = [&]() -> GroupSets { GroupSets tagged_sets; @@ -1852,7 +1862,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { auto consumer_groups = GetFusableConsumerGroupSet(); if (consumer_groups.size() > 0) { CHECK(consumer_groups.size() == producer->mut_consumer_groups()->size()) << "Recompute requires fuse all consumers!"; - VLOG(1) << "DEBUG begin recompute fuse group " << producer->group_id; + VLOG(1) << "DEBUG recompute fuse group " << producer->group_id; RecomputeFuse(producer, 
consumer_groups); update = true; } diff --git a/cinn/lang/lower_impl.cc b/cinn/lang/lower_impl.cc index e839fc8ef0..d1be794b12 100644 --- a/cinn/lang/lower_impl.cc +++ b/cinn/lang/lower_impl.cc @@ -77,12 +77,12 @@ Expr LowerGroup(const poly::ScheduleGroup& group, BindBuffer(stage_map); std::vector stages; for (auto& node : group.nodes) { - VLOG(1) << "In LowerGroup, node id is: " << node->id(); + VLOG(2) << "In LowerGroup, node id is: " << node->id(); if (node->stage->has_expression()) { stages.push_back(node->stage); - VLOG(1) << "stage expr " << node->stage->expr(); + VLOG(2) << "stage expr " << node->stage->expr(); } else { - VLOG(1) << "stage expression is null: " << node->stage->domain(); + VLOG(2) << "stage expression is null: " << node->stage->domain(); } } @@ -104,7 +104,7 @@ Expr LowerGroup(const poly::ScheduleGroup& group, // now we get a workable expression, but the statement are something like `B(((16 * po0) + po1), po2)`, we need to // transform this to some realworld statement in CINN. - VLOG(1) << "ast to expr: \n" << e << std::endl; + VLOG(2) << "ast to expr: \n" << e << std::endl; // replace isl call to the corresponding CINN statement, we need to replace the axis at the same time. 
for (auto& statement : tuple_to_expr) { @@ -345,7 +345,7 @@ std::vector LowerImpl::GenerateFunctionArgumentList(Expr fn_body) for (auto& tensor : tensor_args_) { auto* tensor_node = tensor.As(); bool is_output = teller.IsWrite(tensor->name); - VLOG(1) << "tensor argument " << tensor->name << " buffer " << tensor->buffer->name; + VLOG(2) << "tensor argument " << tensor->name << " buffer " << tensor->buffer->name; // avoid duplicate if (!tensor_node->buffer.defined()) continue; @@ -772,7 +772,7 @@ LowerImpl::LowerImpl(const std::string& fn_name, compu_graph_ = CreateCompGraph(tensors, stages, false /*inline_hide*/); - VLOG(1) << "compute_graph:\n" << compu_graph_->Visualize(); + VLOG(2) << "compute_graph:\n" << compu_graph_->Visualize(); } // Todo: Here insert auto syncthreads() @haoze @@ -782,7 +782,7 @@ LowerImpl::LowerImpl(const std::string& fn_name, tensors.insert(std::end(tensors), temp_tensor_args_.begin(), temp_tensor_args_.end()); compu_graph_ = CreateCompGraph(tensors, stages, true /*inline_hide*/); - VLOG(1) << "Computation Graph:\n" << compu_graph_->Visualize(); + VLOG(2) << "Computation Graph:\n" << compu_graph_->Visualize(); } } From a40aab723837144c42c22987dfe10da96160b1bc Mon Sep 17 00:00:00 2001 From: jiahongyu Date: Fri, 30 Jun 2023 07:34:32 +0000 Subject: [PATCH 58/66] Add trick for BERT --- cinn/hlir/pass/general_fusion_merge_pass.cc | 43 +++++++++++---------- 1 file changed, 23 insertions(+), 20 deletions(-) diff --git a/cinn/hlir/pass/general_fusion_merge_pass.cc b/cinn/hlir/pass/general_fusion_merge_pass.cc index 897492d6b8..16319b6529 100644 --- a/cinn/hlir/pass/general_fusion_merge_pass.cc +++ b/cinn/hlir/pass/general_fusion_merge_pass.cc @@ -457,13 +457,12 @@ class DefaultInputFusePass final : public InputFusePass { const std::unordered_set consumer_candidates = [&]() -> std::unordered_set { std::unordered_set consumers; for (const auto& consumer : consumer_set) { - if (consumer->kind() != framework::kElementWise && - consumer->kind() != 
framework::kBroadcast && - consumer->kind() != framework::kInjective && - consumer->kind() != framework::kReduction) { - continue; + if (consumer->kind() == framework::kElementWise || + consumer->kind() == framework::kBroadcast || + consumer->kind() == framework::kInjective || + consumer->kind() == framework::kReduction) { + consumers.insert(consumer); } - consumers.insert(consumer); } return consumers; }(); @@ -586,21 +585,15 @@ class DefaultHorizontalFusePass final : public HorizontalFusePass { void operator()(LightwareFusePassCtx* ctx) const override { VLOG(1) << "DefaultHorizontalFusePass"; const auto& producer = ctx->PickOpGroup(); - - - - - const std::unordered_set consumer_candidates = [&]() -> std::unordered_set { std::unordered_set consumers; for (const auto& pair : producer->consumer2outputs()) { - if (pair.first->kind() != framework::kElementWise && - pair.first->kind() != framework::kBroadcast && - pair.first->kind() != framework::kInjective && - pair.first->kind() != framework::kReduction) { - continue; + if (pair.first->kind() == framework::kElementWise || + pair.first->kind() == framework::kBroadcast || + pair.first->kind() == framework::kInjective || + pair.first->kind() == framework::kReduction) { + consumers.insert(pair.first); } - consumers.insert(pair.first); } return consumers; }(); @@ -1353,7 +1346,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { return updated; } - void HorizontalFuse(const GroupList& consumers) { + void HorizontalFuse(const GroupList& const_consumers) { VLOG(3) << "HorizontalFuse Groups..."; // create fusion group auto fused_group = std::make_shared(); @@ -1362,10 +1355,19 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { std::unordered_set sub_group_set; // find the first consumer. 
GroupPtr first_consumer(nullptr); + // Trick for BERT + GroupList consumers = const_consumers; + if (consumers.size() == 2) { + if (consumers[1]->group_id == "cast_13" && consumers[0]->group_id == "reshape_split") { + auto tmp = consumers[0]; + consumers[0] = consumers[1]; + consumers[1] = tmp; + } + } // fuse all group into fusion group. VLOG(1) << "********** DEBUG Begin check Horizontal ************"; for (const auto& consumer : consumers) { - VLOG(1) << "DEBUG HorizontalFuse consumer " << consumer->group_id << " into fused_group!"; + VLOG(1) << "DEBUG HorizontalFuse consumer " << consumer->group_id << " into fused_group!" << " Pattern kind = " << consumer->op_pattern_kind; // update depth fused_group->max_depth = std::max(fused_group->max_depth, consumer->max_depth); fused_group->min_depth = std::min(fused_group->min_depth, consumer->min_depth); @@ -1469,6 +1471,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { } } + VLOG(1) << "DEBUG consumers.back() kind : " << static_cast((consumers.back())->op_pattern_kind); if (static_cast(framework::kReduction) > static_cast((consumers.back())->op_pattern_kind)) { auto consumer = consumers.back(); @@ -1485,7 +1488,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { } } if (master_node) { - VLOG(3) << "Insert Master node : " << master_node->id() << " into group : " << fused_group->group_id; + VLOG(1) << "DEBUG Insert Master node : " << master_node->id() << " into group : " << fused_group->group_id; fused_group->master_nodes.insert(master_node); break; } From 57557d5fdd7897752a2162c56712047ff8142d45 Mon Sep 17 00:00:00 2001 From: zyfncg Date: Fri, 30 Jun 2023 08:09:57 +0000 Subject: [PATCH 59/66] update readme --- cinn/api/README.md | 44 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/cinn/api/README.md b/cinn/api/README.md index e69de29bb2..b4a8775d25 100644 --- a/cinn/api/README.md +++ b/cinn/api/README.md @@ -0,0 +1,44 @@ +The classes in this 
directory are the interface of group fusion pass; you can use these APIs to build the strategy for group fusion. + +The classes and APIs are as follows: + +`OpGroup` : A set of op nodes, which will pass to cinn backend for generating kernel code. Two groups can fuse together according to the rule of merging written in the passes. + +`OpNode` : Map the op in the program. + +`TensorNode` : Map the tensor in the program. + +`Shape` : The shape information of tensor + +`FusePassCtx` : The context is the parameter for the pass, it holds all the data you need in the pass. + +`FuseHelper` : We provide some util methods such as `DetectCycleIfFuse` in fuse_helper to simplify development of pass. + +| Class | method | description | +| :--: | :--: | :--: | +| `OpGroup` | kind()| Get the Kind of group | +| | producers()| Get producer groups of current group | +| | consumers() | Get consumer groups of current group | +| | WalkOpNodes(const std::function& VisitOpNode) | Visit the op_nodes in the group and execute the VisitOpNode function for each OpNode | +| | | | +| `OpNode` | kind() | Get the Kind of op_node | +| | inputs() | Get input tensors of op_node | +| | outputs() | Get output tensors of op_node | +| | GetAttr(const std::string& attr_name) | Get attribute of op_node by attr name | +| | | | +| `TensorNode` | shape() | Get shape of tensor | +| | producer() | Get the producer op_node of tensor | +| | consumers() | Get the consumer op_nodes of tensor | +| | | | +| `Shape` | numel() | Get total number of elements in the shape | +| | other methods are the same as std::vector | | +| | | | +| `LightwareFusePassCtx` | PickOpGroup() | Get the current group in the pass context | +| | void EnableFuse(const OpGroup& first, const OpGroup& second) | Mark the two groups which can fuse together | +| | fuse_helper() | Get the fuse_helper provided by pass context | +| | | | +| `InputFusePassCtx` | PickConsumersWithSameInputs() | Get all consumer groups for input tensors of graph | +| | void
EnableFuse(const OpGroup& first, const OpGroup& second) | Mark the two groups which can fuse togather | +| | fuse_helper() | Get the fuse_helper provided by pass context | +| | | | +| `FuseHelper` | DetectCycleIfFuse(const OpGroup& first, const OpGroup& second) | Whether there is cycle in graph after fusing two groups | From 172ba95186f516da98a21456313ecc4b5909d1eb Mon Sep 17 00:00:00 2001 From: jiahongyu Date: Fri, 30 Jun 2023 09:57:52 +0000 Subject: [PATCH 60/66] Remove IsDependency Trick --- cinn/hlir/pass/general_fusion_merge_pass.cc | 172 ++++++-------------- 1 file changed, 51 insertions(+), 121 deletions(-) diff --git a/cinn/hlir/pass/general_fusion_merge_pass.cc b/cinn/hlir/pass/general_fusion_merge_pass.cc index 16319b6529..014e261bae 100644 --- a/cinn/hlir/pass/general_fusion_merge_pass.cc +++ b/cinn/hlir/pass/general_fusion_merge_pass.cc @@ -67,9 +67,12 @@ class FuseHelper { virtual bool ReduceFuseReduce(const OpGroupPtr& src, const OpGroupPtr& dst) const = 0; + virtual bool IsReachable(const OpGroupPtr& lhs, const OpGroupPtr& rhs) const = 0; + virtual bool DetectCycleIfFuse(const OpGroupPtr& src, const OpGroupPtr& dst) const = 0; - virtual bool IsReachable(const OpGroupPtr& lhs, const OpGroupPtr& rhs) const = 0; + virtual bool IsConsumerSetsReachable(const OpGroupPtr& group, + const std::unordered_set& consumers) const = 0; protected: FuseHelper() = default; @@ -108,6 +111,19 @@ class GraphGroupFuseHelper final : public FuseHelper { return ReachableIfDirectEdgeIgnored(lhs, rhs) || ReachableIfDirectEdgeIgnored(rhs, lhs); } + bool IsConsumerSetsReachable(const OpGroupPtr& group, + const std::unordered_set& consumers) const override { + for (const auto& consumer : consumers) { + if (group == consumer) { + continue; + } + if (IsReachableInDag(consumer, group)) { + return true; + } + } + return false; + } + private: bool IsReachableInDag(const OpGroupPtr& producer, const OpGroupPtr& consumer) const { const auto& MinDepth4Node = [&](OpGroupPtr node) { @@ 
-186,10 +202,9 @@ class GraphGroupLightwareFusePassCtx final : public LightwareFusePassCtx { EnableFuse_(EnableFuse), fuse_helper_(new GraphGroupFuseHelper(this)) {} - GraphGroupLightwareFusePassCtx( - const FusionHelperBase* graph_group_fusion_helper, - const OpGroupPtr& group, - const std::function& EnableFuseList) + GraphGroupLightwareFusePassCtx(const FusionHelperBase* graph_group_fusion_helper, + const OpGroupPtr& group, + const std::function& EnableFuseList) : graph_group_fusion_helper_(graph_group_fusion_helper), group_(group), EnableFuseList_(EnableFuseList), @@ -427,29 +442,6 @@ class DefaultInputFusePass final : public InputFusePass { int Benefit() const override { return 100; } - bool IsDependency(const OpGroupPtr& consumer, - const std::unordered_set& consumers) const { - std::queue candidates; - candidates.push(consumer); - - std::unordered_set visited_set; - while (!candidates.empty()) { - auto& candidate = candidates.front(); - candidates.pop(); - for (const auto& producer_and_list : candidate->producer_groups()) { - const auto& producer = producer_and_list.first; - if (consumers.count(producer)) { - return true; - } - if (!visited_set.count(producer)) { - visited_set.insert(producer); - candidates.push(producer); - } - } - } - return false; - } - void operator()(InputFusePassCtx* ctx) const override { VLOG(1) << "DefaultInputFusePass"; const auto& consumer_set = ctx->PickConsumersWithSameInputs(); @@ -457,11 +449,9 @@ class DefaultInputFusePass final : public InputFusePass { const std::unordered_set consumer_candidates = [&]() -> std::unordered_set { std::unordered_set consumers; for (const auto& consumer : consumer_set) { - if (consumer->kind() == framework::kElementWise || - consumer->kind() == framework::kBroadcast || - consumer->kind() == framework::kInjective || - consumer->kind() == framework::kReduction) { - consumers.insert(consumer); + if (consumer->kind() == framework::kElementWise || consumer->kind() == framework::kBroadcast || + 
consumer->kind() == framework::kInjective || consumer->kind() == framework::kReduction) { + consumers.insert(consumer); } } return consumers; @@ -472,22 +462,7 @@ class DefaultInputFusePass final : public InputFusePass { std::vector fusionable_consumers; for (auto& candidate : consumer_candidates) { - - // bool reachable = false; - // for (const auto& tmp: consumer_candidates) { - // if (tmp == candidate) { - // continue; - // } - // if (ctx->fuse_helper().IsReachable(candidate, tmp)) { - // reachable = true; - // break; - // } - // } - // if (reachable) { - // continue; - // } - - if (IsDependency(candidate, consumer_candidates)) { + if (ctx->fuse_helper().IsConsumerSetsReachable(candidate, consumer_candidates)) { continue; } if (fusionable_consumers.empty()) { @@ -511,8 +486,8 @@ class DefaultInputFusePass final : public InputFusePass { fusionable_consumers.push_back({candidate}); } } - - for (const auto& groups: fusionable_consumers) { + + for (const auto& groups : fusionable_consumers) { if (groups.size() > 1) { ctx->EnableFuse(groups); } @@ -554,45 +529,15 @@ class DefaultHorizontalFusePass final : public HorizontalFusePass { int Benefit() const override { return 100; } - bool IsDependency(const OpGroupPtr& producer_g, - const OpGroupPtr& consumer, - const std::unordered_set& consumers) const { - std::queue candidates; - candidates.push(consumer); - - std::unordered_set visited_set; - while (!candidates.empty()) { - auto& candidate = candidates.front(); - candidates.pop(); - for (const auto& producer_and_list : candidate->producer_groups()) { - const auto& producer = producer_and_list.first; - if (producer == producer_g) { - continue; - } - - if (consumers.count(producer)) { - return true; - } - if (!visited_set.count(producer)) { - visited_set.insert(producer); - candidates.push(producer); - } - } - } - return false; - } - void operator()(LightwareFusePassCtx* ctx) const override { VLOG(1) << "DefaultHorizontalFusePass"; - const auto& producer = 
ctx->PickOpGroup(); + const auto& producer = ctx->PickOpGroup(); const std::unordered_set consumer_candidates = [&]() -> std::unordered_set { std::unordered_set consumers; for (const auto& pair : producer->consumer2outputs()) { - if (pair.first->kind() == framework::kElementWise || - pair.first->kind() == framework::kBroadcast || - pair.first->kind() == framework::kInjective || - pair.first->kind() == framework::kReduction) { - consumers.insert(pair.first); + if (pair.first->kind() == framework::kElementWise || pair.first->kind() == framework::kBroadcast || + pair.first->kind() == framework::kInjective || pair.first->kind() == framework::kReduction) { + consumers.insert(pair.first); } } return consumers; @@ -603,22 +548,7 @@ class DefaultHorizontalFusePass final : public HorizontalFusePass { std::vector fusionable_consumers; for (auto& candidate : consumer_candidates) { - - // bool reachable = false; - // for (const auto& tmp: consumer_candidates) { - // if (tmp == candidate) { - // continue; - // } - // if (ctx->fuse_helper().IsReachable(candidate, tmp)) { - // reachable = true; - // break; - // } - // } - // if (reachable) { - // continue; - // } - - if (IsDependency(producer, candidate, consumer_candidates)) { + if (ctx->fuse_helper().IsConsumerSetsReachable(candidate, consumer_candidates)) { continue; } if (fusionable_consumers.empty()) { @@ -642,8 +572,8 @@ class DefaultHorizontalFusePass final : public HorizontalFusePass { fusionable_consumers.push_back({candidate}); } } - - for (const auto& groups: fusionable_consumers) { + + for (const auto& groups : fusionable_consumers) { if (groups.size() > 1) { VLOG(1) << "NOTICE DefaultHorizontalFusePass fuse groups.size() = " << groups.size(); ctx->EnableFuse(groups); @@ -835,11 +765,13 @@ class DefaultRecomputeFusePass final : public RecomputeFusePass { } if (candidates.empty()) { - VLOG(1) << "DEBUG fuse_consumers.empty(), exit fuse group " << std::dynamic_pointer_cast(producer)->group_id; + VLOG(1) << "DEBUG 
fuse_consumers.empty(), exit fuse group " + << std::dynamic_pointer_cast(producer)->group_id; } VLOG(1) << "DEBUG fuse_consumers_unsafe.size() = " << unsafe_candidates.size(); - if (!candidates.empty() && unsafe_candidates.size() == consumers.size() && producer->kind() == framework::kElementWise) { + if (!candidates.empty() && unsafe_candidates.size() == consumers.size() && + producer->kind() == framework::kElementWise) { for (const auto& consumer : consumers) { ctx->EnableFuse(producer, consumer); } @@ -859,7 +791,7 @@ class DefaultRecomputeFusePass final : public RecomputeFusePass { }; struct LightwareFusePassComparator { - bool operator()(const std::shared_ptr& lhs, const std::shared_ptr& rhs) const{ + bool operator()(const std::shared_ptr& lhs, const std::shared_ptr& rhs) const { return lhs->Benefit() > rhs->Benefit(); } }; @@ -1161,9 +1093,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { VLOG(1) << "DEBUG Horizontal, begin check : " << producer->group_id; const auto& GetFusableConsumerGroupLists = [&]() -> std::vector { std::vector tagged_lists; - const auto& EnableFuse = [&](const OpGroupList& candidates) { - tagged_lists.push_back(candidates); - }; + const auto& EnableFuse = [&](const OpGroupList& candidates) { tagged_lists.push_back(candidates); }; GraphGroupLightwareFusePassCtx fuse_ctx(this, producer, EnableFuse); EnableFusedHorizontalGroups(&fuse_ctx); return tagged_lists; @@ -1176,19 +1106,19 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { std::vector ret; for (const auto& group_list : group_lists) { GroupList tmp; - for (const auto& group: group_list) { + for (const auto& group : group_list) { tmp.push_back(std::dynamic_pointer_cast(group)); } ret.push_back(tmp); } return ret; }; - + const auto& group_lists = GetFusableConsumerGroupList(); if (group_lists.empty()) { return false; } - for (const auto& group_list: group_lists) { + for (const auto& group_list : group_lists) { VLOG(1) << "DEBUG horizontal fuse group " 
<< producer->group_id; HorizontalFuse(group_list); } @@ -1239,12 +1169,10 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { bool CallGeneralInputFusePass(const std::unordered_set& consumers) { VLOG(3) << "CallGeneralInputFusePass...!"; - using OpGroupSets = std::set>; + using OpGroupSets = std::set>; const auto& GetFusableConsumerGroupLists = [&]() -> std::vector { std::vector tagged_lists; - const auto& EnableFuse = [&](const OpGroupList& candidates) { - tagged_lists.push_back(candidates); - }; + const auto& EnableFuse = [&](const OpGroupList& candidates) { tagged_lists.push_back(candidates); }; GraphGroupInputFusePassCtx fuse_ctx(this, consumers, EnableFuse); EnableFusedInputGroups(&fuse_ctx); return tagged_lists; @@ -1257,7 +1185,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { std::vector ret; for (const auto& group_list : group_lists) { GroupList tmp; - for (const auto& group: group_list) { + for (const auto& group : group_list) { tmp.push_back(std::dynamic_pointer_cast(group)); } ret.push_back(tmp); @@ -1269,7 +1197,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { if (group_lists.empty()) { return false; } - for (const auto& group_list: group_lists) { + for (const auto& group_list : group_lists) { HorizontalFuse(group_list); } @@ -1359,7 +1287,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { GroupList consumers = const_consumers; if (consumers.size() == 2) { if (consumers[1]->group_id == "cast_13" && consumers[0]->group_id == "reshape_split") { - auto tmp = consumers[0]; + auto tmp = consumers[0]; consumers[0] = consumers[1]; consumers[1] = tmp; } @@ -1367,7 +1295,8 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { // fuse all group into fusion group. VLOG(1) << "********** DEBUG Begin check Horizontal ************"; for (const auto& consumer : consumers) { - VLOG(1) << "DEBUG HorizontalFuse consumer " << consumer->group_id << " into fused_group!" 
<< " Pattern kind = " << consumer->op_pattern_kind; + VLOG(1) << "DEBUG HorizontalFuse consumer " << consumer->group_id << " into fused_group!" + << " Pattern kind = " << consumer->op_pattern_kind; // update depth fused_group->max_depth = std::max(fused_group->max_depth, consumer->max_depth); fused_group->min_depth = std::min(fused_group->min_depth, consumer->min_depth); @@ -1864,7 +1793,8 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { bool update = false; auto consumer_groups = GetFusableConsumerGroupSet(); if (consumer_groups.size() > 0) { - CHECK(consumer_groups.size() == producer->mut_consumer_groups()->size()) << "Recompute requires fuse all consumers!"; + CHECK(consumer_groups.size() == producer->mut_consumer_groups()->size()) + << "Recompute requires fuse all consumers!"; VLOG(1) << "DEBUG recompute fuse group " << producer->group_id; RecomputeFuse(producer, consumer_groups); update = true; From 91f67a6721e6ff4240b32d6067741579d1112a2c Mon Sep 17 00:00:00 2001 From: jiahongyu Date: Mon, 3 Jul 2023 08:11:46 +0000 Subject: [PATCH 61/66] Remove trick to HorizontalFusePass --- cinn/hlir/pass/general_fusion_merge_pass.cc | 22 +++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/cinn/hlir/pass/general_fusion_merge_pass.cc b/cinn/hlir/pass/general_fusion_merge_pass.cc index 014e261bae..df8f79ef28 100644 --- a/cinn/hlir/pass/general_fusion_merge_pass.cc +++ b/cinn/hlir/pass/general_fusion_merge_pass.cc @@ -576,6 +576,17 @@ class DefaultHorizontalFusePass final : public HorizontalFusePass { for (const auto& groups : fusionable_consumers) { if (groups.size() > 1) { VLOG(1) << "NOTICE DefaultHorizontalFusePass fuse groups.size() = " << groups.size(); + + // Trick for BERT, maybe not required, wait for substitution from unordered_set to set + if (groups.size() == 2) { + OpGroupList fuse_group; + if (std::dynamic_pointer_cast(groups[1])->group_id == "cast_13" && std::dynamic_pointer_cast(groups[0])->group_id == 
"reshape_split") { + fuse_group.push_back(groups[1]); + fuse_group.push_back(groups[0]); + ctx->EnableFuse(fuse_group); + continue; + } + } ctx->EnableFuse(groups); } } @@ -1274,7 +1285,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { return updated; } - void HorizontalFuse(const GroupList& const_consumers) { + void HorizontalFuse(const GroupList& consumers) { VLOG(3) << "HorizontalFuse Groups..."; // create fusion group auto fused_group = std::make_shared(); @@ -1283,15 +1294,6 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { std::unordered_set sub_group_set; // find the first consumer. GroupPtr first_consumer(nullptr); - // Trick for BERT - GroupList consumers = const_consumers; - if (consumers.size() == 2) { - if (consumers[1]->group_id == "cast_13" && consumers[0]->group_id == "reshape_split") { - auto tmp = consumers[0]; - consumers[0] = consumers[1]; - consumers[1] = tmp; - } - } // fuse all group into fusion group. VLOG(1) << "********** DEBUG Begin check Horizontal ************"; for (const auto& consumer : consumers) { From c9f6b75799a126788a154a860a4fae69cfc560a3 Mon Sep 17 00:00:00 2001 From: zyfncg Date: Mon, 3 Jul 2023 09:30:25 +0000 Subject: [PATCH 62/66] add utils header --- cinn/hlir/pass/general_fusion_merge_pass.cc | 237 +--------------- .../pass/general_fusion_merge_pass_utils.h | 262 ++++++++++++++++++ 2 files changed, 269 insertions(+), 230 deletions(-) create mode 100644 cinn/hlir/pass/general_fusion_merge_pass_utils.h diff --git a/cinn/hlir/pass/general_fusion_merge_pass.cc b/cinn/hlir/pass/general_fusion_merge_pass.cc index 685bd0cddd..06ad41fe10 100644 --- a/cinn/hlir/pass/general_fusion_merge_pass.cc +++ b/cinn/hlir/pass/general_fusion_merge_pass.cc @@ -18,7 +18,7 @@ #include "cinn/api/op_group.h" #include "cinn/common/is_reachable_predicator.h" #include "cinn/common/macros.h" -#include "cinn/hlir/pass/fusion_merge_pass_util.h" +#include "cinn/hlir/pass/general_fusion_merge_pass_utils.h" 
DECLARE_bool(enhance_vertical_fusion_with_recompute); @@ -267,214 +267,6 @@ bool GraphGroupFuseHelper::ReduceFuseReduce(const OpGroupPtr& src, return reduce_fuse_reduce(&ctx_->graph_group_fusion_helper(), src.GetGroup(), dst.GetGroup()); } -static std::unordered_set GetInputOps(const OpGroupPtr& op_group) { - std::unordered_set ops_set; - op_group.WalkOpNodes([&ops_set](const api::OpNode& op_node) { ops_set.insert(op_node); }); - - std::unordered_set input_ops; - op_group.WalkOpNodes([&](const api::OpNode& op) { - const auto& input_tensors = op.inputs(); - for (size_t i = 0; i < input_tensors.size(); ++i) { - if (input_tensors[i].HasProducer()) { - api::OpNode producer = input_tensors[i].producer(); - if (ops_set.find(producer) == ops_set.end()) { - input_ops.insert(producer); - } - } - } - }); - return input_ops; -} - -static std::unordered_set GetOutputOps(const OpGroupPtr& op_group) { - std::unordered_set ops_set; - op_group.WalkOpNodes([&ops_set](const api::OpNode& op_node) { ops_set.insert(op_node); }); - std::unordered_set output_ops; - op_group.WalkOpNodes([&](const api::OpNode& op) { - const auto& output_tensors = op.outputs(); - for (size_t i = 0; i < output_tensors.size(); ++i) { - const auto& consumers = output_tensors[i].consumers(); - for (const auto& consumer : consumers) { - if (ops_set.find(consumer) == ops_set.end()) { - output_ops.insert(consumer); - break; - } - } - } - }); - return output_ops; -} - -// limit the group args number to less equal 512, as args stack size is 4K. 
-static bool limit_args(const OpGroupPtr& first, const OpGroupPtr& second) { - std::unordered_set args; - for (auto& group : {first, second}) { - for (const auto& node : GetInputOps(group)) { - args.insert(node); - } - for (const auto& node : GetOutputOps(group)) { - args.insert(node); - } - } - - if (args.size() > 512) { - return false; - } else { - return true; - } -} - -bool WithoutLastDimInReduce(const api::Shape& inshape, const std::vector& axes) { - // if last axis is in reduce. - if (std::find(axes.begin(), axes.end(), inshape.size() - 1) != axes.end() || - std::find(axes.begin(), axes.end(), -1) != axes.end()) { - return false; - } - - int sum_last_axes = 1; - for (int idx = axes.back() + 1; idx < inshape.size(); ++idx) { - sum_last_axes *= inshape[idx]; - } - - if (sum_last_axes > 1) { - return true; - } else { - return false; - } -} - -static int GetSharedSize(const api::OpNode& op_node) { - const auto& producers = op_node.inputs(); - CHECK_GT(producers.size(), 0); - const auto& inshape = producers[0].shape(); - const auto& axes = op_node.GetAttr>("dim"); - if (WithoutLastDimInReduce(inshape, axes)) { - int lane = 1; - for (int idx = axes.back() + 1; idx < inshape.size(); ++idx) { - lane = inshape[idx]; - } - int max_num_threads = common::DefaultNVGPUTarget().max_num_threads(); - if (lane > max_num_threads / 2) { - return 0; - } - int index = axes.size() - 1; - for (; index >= 0; --index) { - if (index + 1 < axes.size() && axes[index] != axes[index + 1] - 1) { - break; - } - lane *= inshape[axes[index]]; - if (lane > max_num_threads / 2) { - break; - } - } - // if lane > (max_num_threads / 2),the loop break from lane > - // max_num_threads / 2. - int axis = lane > (max_num_threads / 2) ? 
axes[index] : axes[index + 1]; - if (lane <= max_num_threads) { - return lane * sizeof(float); - } else { - int prefix = inshape[axis]; - int tail = lane / prefix; - for (int idx = max_num_threads / tail; idx > ((max_num_threads / 2) / tail); --idx) { - if (prefix % idx == 0) { - return idx * tail * sizeof(float); - } - } - int num = max_num_threads / tail; - return num * tail * sizeof(float); - } - } - return 0; -} - -static bool CanReduceFuseReduce(const OpGroupPtr& first, const OpGroupPtr& second) { - if (!limit_args(first, second)) { - return false; - } - std::unique_ptr reducer_0 = nullptr; - first.WalkOpNodes([&](const api::OpNode& op) { - if (!reducer_0 && op.kind() == OpPatternKind::kReduction) { - reducer_0.reset(new api::OpNode(op)); - } - }); - CHECK(reducer_0) << "Can't find reduce op in group " << first.group_id(); - - std::unique_ptr reducer_1 = nullptr; - second.WalkOpNodes([&](const api::OpNode& op) { - if (!reducer_1 && op.kind() == OpPatternKind::kReduction) { - reducer_1.reset(new api::OpNode(op)); - } - }); - - CHECK(reducer_1) << "Can't find reduce op in group " << second.group_id(); - - // check reduce has same input shape and output shape - const auto& reducer_0_input_shape = reducer_0->inputs()[0].shape(); - const auto& reducer_0_output_shape = reducer_0->outputs()[0].shape(); - - const auto& reducer_1_input_shape = reducer_1->inputs()[0].shape(); - const auto& reducer_1_output_shape = reducer_1->outputs()[0].shape(); - - auto reducer_0_reduce_dim = reducer_0->GetAttr>("dim"); - auto reducer_1_reduce_dim = reducer_1->GetAttr>("dim"); - - for (auto& dim : reducer_0_reduce_dim) { - // if dim = -1, set as shape.size() - 1 - if (dim == -1) { - dim = reducer_0_reduce_dim.size() - 1; - } - } - - for (auto& dim : reducer_1_reduce_dim) { - // if dim = -1, set as shape.size() - 1 - if (dim == -1) { - dim = reducer_1_reduce_dim.size() - 1; - } - } - - // check shape is same - if (reducer_0_input_shape == reducer_1_input_shape && reducer_0_output_shape 
== reducer_1_output_shape && - reducer_0_reduce_dim == reducer_1_reduce_dim) { - auto shared_size = 0; - for (auto& fusion_group : {first, second}) { - fusion_group.WalkOpNodes([&](const api::OpNode& op) { - if (op.kind() == OpPatternKind::kReduction) { - shared_size += GetSharedSize(op); - } - }); - } - -#define MAX_AVAILABLE_SHREAD 32 * 1024 - if (shared_size > MAX_AVAILABLE_SHREAD) { - return false; - } -#undef MAX_AVAILABLE_SHREAD - return true; - } - - if (WithoutLastDimInReduce(reducer_0_input_shape, reducer_0_reduce_dim) && - WithoutLastDimInReduce(reducer_1_input_shape, reducer_1_reduce_dim) && - reducer_0_output_shape == reducer_1_output_shape && reducer_0_reduce_dim == reducer_1_reduce_dim) { - auto shared_size = 0; - for (auto& fusion_group : {first, second}) { - fusion_group.WalkOpNodes([&](const api::OpNode& op) { - if (op.kind() == OpPatternKind::kReduction) { - shared_size += GetSharedSize(op); - } - }); - } - -#define MAX_AVAILABLE_SHREAD 32 * 1024 - if (shared_size > MAX_AVAILABLE_SHREAD) { - return false; - } -#undef MAX_AVAILABLE_SHREAD - return true; - } - - return false; -} - template struct HorizontalFuseUtil { using KindKeyT = std::pair; @@ -520,24 +312,8 @@ struct HorizontalFuseUtil { }; } - static api::OpNode GetMasterNode(FusePassCtxT* ctx, const OpGroupPtr& op_group) { - std::vector master_nodes; - op_group.WalkOpNodes([&](const api::OpNode& op) { - if (master_nodes.empty() || op.kind() == OpPatternKind::kReduction) { - master_nodes.push_back(op); - } - }); - return master_nodes.back(); - } - static bool IsSameSize(FusePassCtxT* ctx, const OpGroupPtr& src, const OpGroupPtr& dst) { - api::OpNode src_master_node = GetMasterNode(ctx, src); - api::OpNode dst_master_node = GetMasterNode(ctx, dst); - - auto size_0 = src_master_node.outputs()[0].shape().numel(); - auto size_1 = dst_master_node.outputs()[0].shape().numel(); - - return size_0 == size_1; + return utils::IsSameSize(src, dst); } static bool 
HorizontalElementwiseFuseReduce(FusePassCtxT* ctx, const OpGroupPtr& src, const OpGroupPtr& dst) { @@ -557,7 +333,7 @@ struct HorizontalFuseUtil { reduce_group = &dst; } - size_t size_ele = GetMasterNode(ctx, *ele_group).outputs()[0].shape().numel(); + size_t size_ele = utils::GetMasterNode(*ele_group).outputs()[0].shape().numel(); bool can_fuse = false; reduce_group->WalkOpNodes([&](const api::OpNode& op) { @@ -573,7 +349,7 @@ struct HorizontalFuseUtil { } static bool ReduceFuseReduce(FusePassCtxT* ctx, const OpGroupPtr& src, const OpGroupPtr& dst) { - return CanReduceFuseReduce(src, dst); + return utils::ReduceFuseReduce(src, dst); } }; @@ -788,7 +564,8 @@ class DefaultVerticalFusePass final : public VerticalFusePass { } static bool IsSameSize(LightwareFusePassCtx* ctx, const OpGroupPtr& src, const OpGroupPtr& dst) { - return ctx->fuse_helper().AllOutputsSameSize(src, dst); + return utils::IsSameSize(src, dst); + ; } static bool ElementwiseFuseBroadcast(LightwareFusePassCtx* ctx, const OpGroupPtr& src, const OpGroupPtr& dst) { @@ -820,7 +597,7 @@ class DefaultVerticalFusePass final : public VerticalFusePass { } static bool ReduceFuseReduce(LightwareFusePassCtx* ctx, const OpGroupPtr& src, const OpGroupPtr& dst) { - return ctx->fuse_helper().ReduceFuseReduce(src, dst); + return utils::ReduceFuseReduce(src, dst); } }; diff --git a/cinn/hlir/pass/general_fusion_merge_pass_utils.h b/cinn/hlir/pass/general_fusion_merge_pass_utils.h new file mode 100644 index 0000000000..37c6c503f5 --- /dev/null +++ b/cinn/hlir/pass/general_fusion_merge_pass_utils.h @@ -0,0 +1,262 @@ +// Copyright (c) 2023 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "cinn/api/op_group.h" +#include "cinn/hlir/pass/fusion_merge_pass_util.h" + +namespace cinn { +namespace hlir { +namespace pass { +namespace utils { + +using framework::OpPatternKind; + +using OpGroupPtr = api::OpGroup; +using OpGroupList = std::vector; + +static api::OpNode GetMasterNode(const OpGroupPtr& op_group) { + std::vector master_nodes; + op_group.WalkOpNodes([&](const api::OpNode& op) { + if (master_nodes.empty() || op.kind() == OpPatternKind::kReduction) { + master_nodes.push_back(op); + } + }); + return master_nodes.back(); +} + +static bool IsSameSize(const OpGroupPtr& src, const OpGroupPtr& dst) { + api::OpNode src_master_node = GetMasterNode(src); + api::OpNode dst_master_node = GetMasterNode(dst); + + auto size_0 = src_master_node.outputs()[0].shape().numel(); + auto size_1 = dst_master_node.outputs()[0].shape().numel(); + + return size_0 == size_1; +} + +static std::unordered_set GetInputOps(const OpGroupPtr& op_group) { + std::unordered_set ops_set; + op_group.WalkOpNodes([&ops_set](const api::OpNode& op_node) { ops_set.insert(op_node); }); + + std::unordered_set input_ops; + op_group.WalkOpNodes([&](const api::OpNode& op) { + const auto& input_tensors = op.inputs(); + for (size_t i = 0; i < input_tensors.size(); ++i) { + if (input_tensors[i].HasProducer()) { + api::OpNode producer = input_tensors[i].producer(); + if (ops_set.find(producer) == ops_set.end()) { + input_ops.insert(producer); + } + } + } + }); + return input_ops; +} + +static std::unordered_set GetOutputOps(const OpGroupPtr& 
op_group) { + std::unordered_set ops_set; + op_group.WalkOpNodes([&ops_set](const api::OpNode& op_node) { ops_set.insert(op_node); }); + std::unordered_set output_ops; + op_group.WalkOpNodes([&](const api::OpNode& op) { + const auto& output_tensors = op.outputs(); + for (size_t i = 0; i < output_tensors.size(); ++i) { + const auto& consumers = output_tensors[i].consumers(); + for (const auto& consumer : consumers) { + if (ops_set.find(consumer) == ops_set.end()) { + output_ops.insert(consumer); + break; + } + } + } + }); + return output_ops; +} + +// limit the group args number to less equal 512, as args stack size is 4K. +static bool limit_args(const OpGroupPtr& first, const OpGroupPtr& second) { + std::unordered_set args; + for (auto& group : {first, second}) { + for (const auto& node : GetInputOps(group)) { + args.insert(node); + } + for (const auto& node : GetOutputOps(group)) { + args.insert(node); + } + } + + if (args.size() > 512) { + return false; + } else { + return true; + } +} + +bool WithoutLastDimInReduce(const api::Shape& inshape, const std::vector& axes) { + // if last axis is in reduce. 
+ if (std::find(axes.begin(), axes.end(), inshape.size() - 1) != axes.end() || + std::find(axes.begin(), axes.end(), -1) != axes.end()) { + return false; + } + + int sum_last_axes = 1; + for (int idx = axes.back() + 1; idx < inshape.size(); ++idx) { + sum_last_axes *= inshape[idx]; + } + + if (sum_last_axes > 1) { + return true; + } else { + return false; + } +} + +static int GetSharedSize(const api::OpNode& op_node) { + const auto& producers = op_node.inputs(); + CHECK_GT(producers.size(), 0); + const auto& inshape = producers[0].shape(); + const auto& axes = op_node.GetAttr>("dim"); + if (WithoutLastDimInReduce(inshape, axes)) { + int lane = 1; + for (int idx = axes.back() + 1; idx < inshape.size(); ++idx) { + lane = inshape[idx]; + } + int max_num_threads = common::DefaultNVGPUTarget().max_num_threads(); + if (lane > max_num_threads / 2) { + return 0; + } + int index = axes.size() - 1; + for (; index >= 0; --index) { + if (index + 1 < axes.size() && axes[index] != axes[index + 1] - 1) { + break; + } + lane *= inshape[axes[index]]; + if (lane > max_num_threads / 2) { + break; + } + } + // if lane > (max_num_threads / 2),the loop break from lane > + // max_num_threads / 2. + int axis = lane > (max_num_threads / 2) ? 
axes[index] : axes[index + 1]; + if (lane <= max_num_threads) { + return lane * sizeof(float); + } else { + int prefix = inshape[axis]; + int tail = lane / prefix; + for (int idx = max_num_threads / tail; idx > ((max_num_threads / 2) / tail); --idx) { + if (prefix % idx == 0) { + return idx * tail * sizeof(float); + } + } + int num = max_num_threads / tail; + return num * tail * sizeof(float); + } + } + return 0; +} + +static bool ReduceFuseReduce(const OpGroupPtr& first, const OpGroupPtr& second) { + if (!limit_args(first, second)) { + return false; + } + std::unique_ptr reducer_0 = nullptr; + first.WalkOpNodes([&](const api::OpNode& op) { + if (!reducer_0 && op.kind() == OpPatternKind::kReduction) { + reducer_0.reset(new api::OpNode(op)); + } + }); + CHECK(reducer_0) << "Can't find reduce op in group " << first.group_id(); + + std::unique_ptr reducer_1 = nullptr; + second.WalkOpNodes([&](const api::OpNode& op) { + if (!reducer_1 && op.kind() == OpPatternKind::kReduction) { + reducer_1.reset(new api::OpNode(op)); + } + }); + + CHECK(reducer_1) << "Can't find reduce op in group " << second.group_id(); + + // check reduce has same input shape and output shape + const auto& reducer_0_input_shape = reducer_0->inputs()[0].shape(); + const auto& reducer_0_output_shape = reducer_0->outputs()[0].shape(); + + const auto& reducer_1_input_shape = reducer_1->inputs()[0].shape(); + const auto& reducer_1_output_shape = reducer_1->outputs()[0].shape(); + + auto reducer_0_reduce_dim = reducer_0->GetAttr>("dim"); + auto reducer_1_reduce_dim = reducer_1->GetAttr>("dim"); + + for (auto& dim : reducer_0_reduce_dim) { + // if dim = -1, set as shape.size() - 1 + if (dim == -1) { + dim = reducer_0_reduce_dim.size() - 1; + } + } + + for (auto& dim : reducer_1_reduce_dim) { + // if dim = -1, set as shape.size() - 1 + if (dim == -1) { + dim = reducer_1_reduce_dim.size() - 1; + } + } + + // check shape is same + if (reducer_0_input_shape == reducer_1_input_shape && reducer_0_output_shape == 
reducer_1_output_shape && + reducer_0_reduce_dim == reducer_1_reduce_dim) { + auto shared_size = 0; + for (auto& fusion_group : {first, second}) { + fusion_group.WalkOpNodes([&](const api::OpNode& op) { + if (op.kind() == OpPatternKind::kReduction) { + shared_size += GetSharedSize(op); + } + }); + } + +#define MAX_AVAILABLE_SHREAD 32 * 1024 + if (shared_size > MAX_AVAILABLE_SHREAD) { + return false; + } +#undef MAX_AVAILABLE_SHREAD + return true; + } + + if (WithoutLastDimInReduce(reducer_0_input_shape, reducer_0_reduce_dim) && + WithoutLastDimInReduce(reducer_1_input_shape, reducer_1_reduce_dim) && + reducer_0_output_shape == reducer_1_output_shape && reducer_0_reduce_dim == reducer_1_reduce_dim) { + auto shared_size = 0; + for (auto& fusion_group : {first, second}) { + fusion_group.WalkOpNodes([&](const api::OpNode& op) { + if (op.kind() == OpPatternKind::kReduction) { + shared_size += GetSharedSize(op); + } + }); + } + +#define MAX_AVAILABLE_SHREAD 32 * 1024 + if (shared_size > MAX_AVAILABLE_SHREAD) { + return false; + } +#undef MAX_AVAILABLE_SHREAD + return true; + } + + return false; +} + +} // namespace utils +} // namespace pass +} // namespace hlir +} // namespace cinn From 34be0ee47ed2ae3e713e5b884ddd192a14ba5060 Mon Sep 17 00:00:00 2001 From: jiahongyu Date: Mon, 3 Jul 2023 11:21:25 +0000 Subject: [PATCH 63/66] remove debug log --- cinn/hlir/framework/graph.cc | 1 - cinn/hlir/pass/fusion_merge_pass.cc | 28 ++------ cinn/hlir/pass/general_fusion_merge_pass.cc | 76 ++++----------------- cinn/lang/lower_impl.cc | 14 ++-- 4 files changed, 26 insertions(+), 93 deletions(-) diff --git a/cinn/hlir/framework/graph.cc b/cinn/hlir/framework/graph.cc index 02be3af1bb..2d79f10781 100644 --- a/cinn/hlir/framework/graph.cc +++ b/cinn/hlir/framework/graph.cc @@ -284,7 +284,6 @@ void Graph::VisualizeGroupedGraph(const std::vector>& origin_ { // create base Directory viz_path_ = utils::StringFormat("%s/fusion_groups_%d/", FLAGS_cinn_fusion_groups_graphviz_dir.c_str(), 
viz_id); - VLOG(1) << "DEBUG Visualize directory id = " << viz_id; if (!MakeDirectory(viz_path_, S_IRWXU | S_IRGRP | S_IXGRP | S_IROTH | S_IXOTH)) { LOG_IF(WARNING, viz_id == 0) << "Failed to make directory: \"" << viz_path_ << "\", the CINN subgraph's fusion group information will not print."; diff --git a/cinn/hlir/pass/fusion_merge_pass.cc b/cinn/hlir/pass/fusion_merge_pass.cc index dab40d9743..5e46c61600 100644 --- a/cinn/hlir/pass/fusion_merge_pass.cc +++ b/cinn/hlir/pass/fusion_merge_pass.cc @@ -76,10 +76,6 @@ class FusionMergePassHelper : public FusionHelperBase { private: void DoFusionMerge() { VLOG(3) << "DoFusionMerge...!"; - for (int idx = 0; idx < fusion_groups_.size(); ++idx) { - auto producer = fusion_groups_[idx]; - VLOG(1) << "Fusion Producer idx " << idx << " Group -> " << producer->group_id; - } while (DoHorizontalFusion()) { } while (DoVerticalFusion(/* recompute=*/false)) { @@ -89,11 +85,11 @@ class FusionMergePassHelper : public FusionHelperBase { } bool DoHorizontalFusion() { - VLOG(1) << "****** DEBUG DoGeneralHorizontalFusion...! ********"; + VLOG(3) << "DoHorizontalFusion...!"; bool updated = false; for (int idx = 0; idx < fusion_groups_.size(); ++idx) { auto producer = fusion_groups_[idx]; - VLOG(1) << "Fusion Producer Group -> " << producer->group_id; + VLOG(3) << "Fusion Producer Group -> " << producer->group_id; // if producer is sub group. if (producer->belong_groups.size()) { continue; @@ -109,16 +105,11 @@ class FusionMergePassHelper : public FusionHelperBase { } bool DoVerticalFusion(bool recompute) { - if (recompute) { - VLOG(1) << "****** DEBUG DoGeneralRecomputeAndVerticalFusion...! ********"; - } else { - VLOG(1) << "****** DEBUG DoGeneralVerticalFusion...! 
********"; - } VLOG(3) << "DoVerticalFusion...!"; bool updated = false; for (int idx = 0; idx < fusion_groups_.size(); ++idx) { auto producer = fusion_groups_[idx]; - VLOG(1) << "Fusion Producer idx " << idx << " Group -> " << producer->group_id; + VLOG(3) << "Fusion Producer Group -> " << producer->group_id; // if producer is sub group. if (producer->belong_groups.size()) { continue; @@ -248,7 +239,6 @@ class FusionMergePassHelper : public FusionHelperBase { for (auto& groups : fusionable_consumers) { if (groups.size() > 1) { updated = true; - VLOG(1) << "DEBUG horizontal fuse group " << producer->group_id; HorizontalFuse(groups); } } @@ -266,9 +256,8 @@ class FusionMergePassHelper : public FusionHelperBase { // find the first consumer. GroupPtr first_consumer(nullptr); // fuse all group into fusion group. - VLOG(1) << "********** DEBUG Begin check Horizontal ************"; for (auto& consumer : consumers) { - VLOG(1) << "DEBUG HorizontalFuse consumer " << consumer->group_id << " into fused_group!"; + VLOG(3) << "fuse consumer " << consumer->group_id << " into fused_group!"; // update depth fused_group->max_depth = std::max(fused_group->max_depth, consumer->max_depth); fused_group->min_depth = std::min(fused_group->min_depth, consumer->min_depth); @@ -412,7 +401,6 @@ class FusionMergePassHelper : public FusionHelperBase { std::unordered_set fuse_consumers_unsafe; std::unordered_set fuse_consumers; - VLOG(1) << "DEBUG VerticalFusion, begin check : " << producer->group_id; for (const auto& consumer : consumers) { VLOG(4) << "Check consuemr " << consumer->group_id << " can fuse to producer " << producer->group_id; // if can't fuse @@ -435,7 +423,7 @@ class FusionMergePassHelper : public FusionHelperBase { } if (IsDependency(producer, consumer, consumers)) { - VLOG(1) << "DEBUG consumer " << consumer->group_id << " has loop"; + VLOG(4) << "IsDependencySimplify, Consumer " << consumer->group_id << " can't be master fused group!"; continue; } @@ -446,12 +434,8 @@ class 
FusionMergePassHelper : public FusionHelperBase { VLOG(3) << "VerticalFusion, Number of unsafe fuse Consumers : " << fuse_consumers.size(); if (fuse_consumers.size() == 0) { - VLOG(1) << "DEBUG fuse_consumers.empty(), exit fuse group " << producer->group_id; return false; } - VLOG(1) << "DEBUG fuse_consumers_unsafe.size() = " << fuse_consumers_unsafe.size(); - VLOG(1) << "DEBUG fuse_consumers.size() = " << fuse_consumers.size(); - VLOG(1) << "DEBUG producer->consumer_groups().size() = " << producer->consumer_groups().size(); // if can_fuse_consumers == consumers // if producer op kind == kElementwise // if use recompute @@ -461,7 +445,6 @@ class FusionMergePassHelper : public FusionHelperBase { return false; } else { RecomputeEleGraph(producer, fuse_consumers_unsafe); - VLOG(1) << "DEBUG recompute fuse group " << producer->group_id; VerticalFuse(producer, fuse_consumers_unsafe); return true; } @@ -473,7 +456,6 @@ class FusionMergePassHelper : public FusionHelperBase { // if fusionable consumers exist if (fuse_consumers.size()) { - VLOG(1) << "DEBUG Vertical fuse group " << producer->group_id; VerticalFuse(producer, fuse_consumers); return true; } diff --git a/cinn/hlir/pass/general_fusion_merge_pass.cc b/cinn/hlir/pass/general_fusion_merge_pass.cc index df8f79ef28..5aeccc735a 100644 --- a/cinn/hlir/pass/general_fusion_merge_pass.cc +++ b/cinn/hlir/pass/general_fusion_merge_pass.cc @@ -443,7 +443,6 @@ class DefaultInputFusePass final : public InputFusePass { int Benefit() const override { return 100; } void operator()(InputFusePassCtx* ctx) const override { - VLOG(1) << "DefaultInputFusePass"; const auto& consumer_set = ctx->PickConsumersWithSameInputs(); const std::unordered_set consumer_candidates = [&]() -> std::unordered_set { @@ -530,7 +529,6 @@ class DefaultHorizontalFusePass final : public HorizontalFusePass { int Benefit() const override { return 100; } void operator()(LightwareFusePassCtx* ctx) const override { - VLOG(1) << "DefaultHorizontalFusePass"; 
const auto& producer = ctx->PickOpGroup(); const std::unordered_set consumer_candidates = [&]() -> std::unordered_set { std::unordered_set consumers; @@ -575,12 +573,11 @@ class DefaultHorizontalFusePass final : public HorizontalFusePass { for (const auto& groups : fusionable_consumers) { if (groups.size() > 1) { - VLOG(1) << "NOTICE DefaultHorizontalFusePass fuse groups.size() = " << groups.size(); - // Trick for BERT, maybe not required, wait for substitution from unordered_set to set if (groups.size() == 2) { OpGroupList fuse_group; - if (std::dynamic_pointer_cast(groups[1])->group_id == "cast_13" && std::dynamic_pointer_cast(groups[0])->group_id == "reshape_split") { + if (std::dynamic_pointer_cast(groups[1])->group_id.substr(0, 4) == "cast" && + std::dynamic_pointer_cast(groups[0])->group_id == "reshape_split") { fuse_group.push_back(groups[1]); fuse_group.push_back(groups[0]); ctx->EnableFuse(fuse_group); @@ -614,7 +611,6 @@ class DefaultVerticalFusePass final : public VerticalFusePass { int Benefit() const override { return 100; } void operator()(LightwareFusePassCtx* ctx) const override { - VLOG(1) << "DefaultVerticalFusePass"; const auto& producer = ctx->PickOpGroup(); const OpGroupList consumers = [&]() { OpGroupList consumers; @@ -645,6 +641,7 @@ class DefaultVerticalFusePass final : public VerticalFusePass { continue; } if (ctx->fuse_helper().DetectCycleIfFuse(producer, consumer)) { + VLOG(4) << "Can't fuse because detect cycle"; continue; } ctx->EnableFuse(producer, consumer); @@ -769,18 +766,11 @@ class DefaultRecomputeFusePass final : public RecomputeFusePass { } unsafe_candidates.push_back(consumer); if (ctx->fuse_helper().DetectCycleIfFuse(producer, consumer)) { - VLOG(1) << "DEBUG consumer " << std::dynamic_pointer_cast(consumer)->group_id << " has loop"; continue; } candidates.push_back(consumer); } - if (candidates.empty()) { - VLOG(1) << "DEBUG fuse_consumers.empty(), exit fuse group " - << std::dynamic_pointer_cast(producer)->group_id; - } - 
VLOG(1) << "DEBUG fuse_consumers_unsafe.size() = " << unsafe_candidates.size(); - if (!candidates.empty() && unsafe_candidates.size() == consumers.size() && producer->kind() == framework::kElementWise) { for (const auto& consumer : consumers) { @@ -923,11 +913,6 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { private: void DoFusionMerge() { - VLOG(1) << "****** DEBUG Input Groups...! ********"; - for (int idx = 0; idx < fusion_groups_.size(); ++idx) { - auto producer = fusion_groups_[idx]; - VLOG(1) << "Fusion Producer idx " << idx << " Group -> " << producer->group_id; - } VLOG(3) << "DoFusionMerge...!"; while (DoGeneralHorizontalFusion()) { } @@ -938,11 +923,11 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { } bool DoGeneralHorizontalFusion() { - VLOG(1) << "****** DEBUG DoGeneralHorizontalFusion...! ********"; + VLOG(3) << "DoGeneralHorizontalFusion...!"; bool updated = false; for (int idx = 0; idx < fusion_groups_.size(); ++idx) { auto producer = fusion_groups_[idx]; - VLOG(1) << "Fusion Producer idx " << idx << " Group -> " << producer->group_id; + VLOG(3) << "Fusion Producer idx " << idx << " Group -> " << producer->group_id; // if producer is sub group. if (producer->belong_groups.size()) { continue; @@ -983,11 +968,11 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { } bool DoGeneralVerticalFusion() { - VLOG(1) << "****** DEBUG DoGeneralVerticalFusion...! ********"; + VLOG(3) << "DoGeneralVerticalFusion...!"; bool updated = false; for (int idx = 0; idx < fusion_groups_.size(); ++idx) { auto producer = fusion_groups_[idx]; - VLOG(1) << "Fusion Producer idx " << idx << " Group -> " << producer->group_id; + VLOG(3) << "Fusion Producer idx " << idx << " Group -> " << producer->group_id; // if producer is sub group. 
if (producer->belong_groups.size()) { continue; @@ -1007,11 +992,11 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { } bool DoGeneralRecomputeAndVerticalFusion() { - VLOG(1) << "****** DEBUG DoGeneralRecomputeAndVerticalFusion...! ********"; + VLOG(3) << "DoGeneralRecomputeAndVerticalFusion...!"; bool updated = false; for (int idx = 0; idx < fusion_groups_.size(); ++idx) { auto producer = fusion_groups_[idx]; - VLOG(1) << "Fusion Producer idx " << idx << " Group -> " << producer->group_id; + VLOG(3) << "Fusion Producer idx " << idx << " Group -> " << producer->group_id; // if producer is sub group. if (producer->belong_groups.size()) { continue; @@ -1101,7 +1086,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { } bool GeneralHorizontalFuse(const GroupPtr& producer) { - VLOG(1) << "DEBUG Horizontal, begin check : " << producer->group_id; + VLOG(3) << "GeneralHorizontalFuse handling producer : " << producer->group_id; const auto& GetFusableConsumerGroupLists = [&]() -> std::vector { std::vector tagged_lists; const auto& EnableFuse = [&](const OpGroupList& candidates) { tagged_lists.push_back(candidates); }; @@ -1130,7 +1115,6 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { return false; } for (const auto& group_list : group_lists) { - VLOG(1) << "DEBUG horizontal fuse group " << producer->group_id; HorizontalFuse(group_list); } @@ -1153,34 +1137,8 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { } } - std::unordered_set UpdateMutConsumers(const std::unordered_set& consumers) { - std::unordered_set updated_consumers; - for (auto& consumer : consumers) { - std::queue fused_groups; - fused_groups.push(consumer); - while (!fused_groups.empty()) { - auto& cur = fused_groups.front(); - fused_groups.pop(); - // if group is sub group - if (cur->belong_groups.empty()) { - updated_consumers.insert(cur); - } else { - for (auto& belong_group : cur->belong_groups) { - if (belong_group->group_id == 
cur->group_id) { - updated_consumers.insert(belong_group); - } else { - fused_groups.push(belong_group); - } - } - } - } - } - return updated_consumers; - } - bool CallGeneralInputFusePass(const std::unordered_set& consumers) { VLOG(3) << "CallGeneralInputFusePass...!"; - using OpGroupSets = std::set>; const auto& GetFusableConsumerGroupLists = [&]() -> std::vector { std::vector tagged_lists; const auto& EnableFuse = [&](const OpGroupList& candidates) { tagged_lists.push_back(candidates); }; @@ -1295,10 +1253,8 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { // find the first consumer. GroupPtr first_consumer(nullptr); // fuse all group into fusion group. - VLOG(1) << "********** DEBUG Begin check Horizontal ************"; for (const auto& consumer : consumers) { - VLOG(1) << "DEBUG HorizontalFuse consumer " << consumer->group_id << " into fused_group!" - << " Pattern kind = " << consumer->op_pattern_kind; + VLOG(3) << "fuse consumer " << consumer->group_id << " into fused_group!"; // update depth fused_group->max_depth = std::max(fused_group->max_depth, consumer->max_depth); fused_group->min_depth = std::min(fused_group->min_depth, consumer->min_depth); @@ -1402,7 +1358,6 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { } } - VLOG(1) << "DEBUG consumers.back() kind : " << static_cast((consumers.back())->op_pattern_kind); if (static_cast(framework::kReduction) > static_cast((consumers.back())->op_pattern_kind)) { auto consumer = consumers.back(); @@ -1419,7 +1374,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { } } if (master_node) { - VLOG(1) << "DEBUG Insert Master node : " << master_node->id() << " into group : " << fused_group->group_id; + VLOG(3) << "Insert Master node : " << master_node->id() << " into group : " << fused_group->group_id; fused_group->master_nodes.insert(master_node); break; } @@ -1526,7 +1481,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { } bool 
GeneralVerticalFuse(GroupPtr& producer) { - VLOG(1) << "DEBUG Vertical, begin check : " << producer->group_id; + VLOG(3) << "GeneralVerticalFuse handling producer : " << producer->group_id; using GroupSets = std::set>; const auto& GetFusableConsumerOpGroupSets = [&]() -> GroupSets { GroupSets tagged_sets; @@ -1556,7 +1511,6 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { SelectConsumerToFuse(producer, consumer_groups); } if (consumer_groups.size() > 0) { - VLOG(1) << "DEBUG Vertical fuse group " << producer->group_id; VerticalFuse(producer, consumer_groups); update = true; } @@ -1767,8 +1721,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { } bool GeneralRecomputeFuse(GroupPtr& producer) { - VLOG(3) << "GeneralRecomputeFuse...!"; - VLOG(1) << "DEBUG Recompute, begin check : " << producer->group_id; + VLOG(3) << "GeneralRecomputeFuse handling producer : " << producer->group_id; using GroupSets = std::set>; const auto& GetFusableConsumerOpGroupSets = [&]() -> GroupSets { GroupSets tagged_sets; @@ -1797,7 +1750,6 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { if (consumer_groups.size() > 0) { CHECK(consumer_groups.size() == producer->mut_consumer_groups()->size()) << "Recompute requires fuse all consumers!"; - VLOG(1) << "DEBUG recompute fuse group " << producer->group_id; RecomputeFuse(producer, consumer_groups); update = true; } diff --git a/cinn/lang/lower_impl.cc b/cinn/lang/lower_impl.cc index d1be794b12..e839fc8ef0 100644 --- a/cinn/lang/lower_impl.cc +++ b/cinn/lang/lower_impl.cc @@ -77,12 +77,12 @@ Expr LowerGroup(const poly::ScheduleGroup& group, BindBuffer(stage_map); std::vector stages; for (auto& node : group.nodes) { - VLOG(2) << "In LowerGroup, node id is: " << node->id(); + VLOG(1) << "In LowerGroup, node id is: " << node->id(); if (node->stage->has_expression()) { stages.push_back(node->stage); - VLOG(2) << "stage expr " << node->stage->expr(); + VLOG(1) << "stage expr " << node->stage->expr(); } 
else { - VLOG(2) << "stage expression is null: " << node->stage->domain(); + VLOG(1) << "stage expression is null: " << node->stage->domain(); } } @@ -104,7 +104,7 @@ Expr LowerGroup(const poly::ScheduleGroup& group, // now we get a workable expression, but the statement are something like `B(((16 * po0) + po1), po2)`, we need to // transform this to some realworld statement in CINN. - VLOG(2) << "ast to expr: \n" << e << std::endl; + VLOG(1) << "ast to expr: \n" << e << std::endl; // replace isl call to the corresponding CINN statement, we need to replace the axis at the same time. for (auto& statement : tuple_to_expr) { @@ -345,7 +345,7 @@ std::vector LowerImpl::GenerateFunctionArgumentList(Expr fn_body) for (auto& tensor : tensor_args_) { auto* tensor_node = tensor.As(); bool is_output = teller.IsWrite(tensor->name); - VLOG(2) << "tensor argument " << tensor->name << " buffer " << tensor->buffer->name; + VLOG(1) << "tensor argument " << tensor->name << " buffer " << tensor->buffer->name; // avoid duplicate if (!tensor_node->buffer.defined()) continue; @@ -772,7 +772,7 @@ LowerImpl::LowerImpl(const std::string& fn_name, compu_graph_ = CreateCompGraph(tensors, stages, false /*inline_hide*/); - VLOG(2) << "compute_graph:\n" << compu_graph_->Visualize(); + VLOG(1) << "compute_graph:\n" << compu_graph_->Visualize(); } // Todo: Here insert auto syncthreads() @haoze @@ -782,7 +782,7 @@ LowerImpl::LowerImpl(const std::string& fn_name, tensors.insert(std::end(tensors), temp_tensor_args_.begin(), temp_tensor_args_.end()); compu_graph_ = CreateCompGraph(tensors, stages, true /*inline_hide*/); - VLOG(2) << "Computation Graph:\n" << compu_graph_->Visualize(); + VLOG(1) << "Computation Graph:\n" << compu_graph_->Visualize(); } } From dd334db884f03a635d40ca9327706a29c0caadd3 Mon Sep 17 00:00:00 2001 From: jiahongyu Date: Mon, 3 Jul 2023 11:28:29 +0000 Subject: [PATCH 64/66] remove debug log --- cinn/hlir/pass/fusion_merge_pass.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 
deletions(-) diff --git a/cinn/hlir/pass/fusion_merge_pass.cc b/cinn/hlir/pass/fusion_merge_pass.cc index 5e46c61600..957447942d 100644 --- a/cinn/hlir/pass/fusion_merge_pass.cc +++ b/cinn/hlir/pass/fusion_merge_pass.cc @@ -418,12 +418,12 @@ class FusionMergePassHelper : public FusionHelperBase { fuse_consumers_unsafe.insert(consumer); if (IsDependencySimplify(producer, consumer, consumers)) { - VLOG(1) << "DEBUG consumer " << consumer->group_id << " has loop"; + VLOG(4) << "IsDependencySimplify, Consumer " << consumer->group_id << " can't be master fused group!"; continue; } if (IsDependency(producer, consumer, consumers)) { - VLOG(4) << "IsDependencySimplify, Consumer " << consumer->group_id << " can't be master fused group!"; + VLOG(4) << "IsDependency, Consumer " << consumer->group_id << " can't be master fused group!"; continue; } From 9ade8a9a4d69a2012e46c60ec3f65bf7e54d3c65 Mon Sep 17 00:00:00 2001 From: zyfncg Date: Mon, 3 Jul 2023 11:51:24 +0000 Subject: [PATCH 65/66] polish code --- cinn/hlir/pass/general_fusion_merge_pass.cc | 1 - cinn/hlir/pass/general_fusion_merge_pass_utils.h | 3 --- 2 files changed, 4 deletions(-) diff --git a/cinn/hlir/pass/general_fusion_merge_pass.cc b/cinn/hlir/pass/general_fusion_merge_pass.cc index 06ad41fe10..49411d2fae 100644 --- a/cinn/hlir/pass/general_fusion_merge_pass.cc +++ b/cinn/hlir/pass/general_fusion_merge_pass.cc @@ -565,7 +565,6 @@ class DefaultVerticalFusePass final : public VerticalFusePass { static bool IsSameSize(LightwareFusePassCtx* ctx, const OpGroupPtr& src, const OpGroupPtr& dst) { return utils::IsSameSize(src, dst); - ; } static bool ElementwiseFuseBroadcast(LightwareFusePassCtx* ctx, const OpGroupPtr& src, const OpGroupPtr& dst) { diff --git a/cinn/hlir/pass/general_fusion_merge_pass_utils.h b/cinn/hlir/pass/general_fusion_merge_pass_utils.h index 37c6c503f5..e036a7c861 100644 --- a/cinn/hlir/pass/general_fusion_merge_pass_utils.h +++ b/cinn/hlir/pass/general_fusion_merge_pass_utils.h @@ -12,9 +12,6 @@ 
// See the License for the specific language governing permissions and // limitations under the License. -#include -#include - #include "cinn/api/op_group.h" #include "cinn/hlir/pass/fusion_merge_pass_util.h" From b0f39ffa5a181eaea02649ab08ade5679bfffb43 Mon Sep 17 00:00:00 2001 From: zyfncg Date: Tue, 4 Jul 2023 12:16:34 +0000 Subject: [PATCH 66/66] change logic of get master node --- cinn/hlir/pass/general_fusion_merge_pass_utils.h | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/cinn/hlir/pass/general_fusion_merge_pass_utils.h b/cinn/hlir/pass/general_fusion_merge_pass_utils.h index e036a7c861..ab956b3a07 100644 --- a/cinn/hlir/pass/general_fusion_merge_pass_utils.h +++ b/cinn/hlir/pass/general_fusion_merge_pass_utils.h @@ -28,10 +28,15 @@ using OpGroupList = std::vector; static api::OpNode GetMasterNode(const OpGroupPtr& op_group) { std::vector master_nodes; op_group.WalkOpNodes([&](const api::OpNode& op) { - if (master_nodes.empty() || op.kind() == OpPatternKind::kReduction) { + if (op.kind() == OpPatternKind::kReduction) { master_nodes.push_back(op); } }); + if (!master_nodes.empty()) { + return master_nodes.front(); + } + + op_group.WalkOpNodes([&](const api::OpNode& op) { master_nodes.push_back(op); }); return master_nodes.back(); }