diff --git a/cinn/ir/ir_schedule.cc b/cinn/ir/ir_schedule.cc index eb2d934e0f..3664092061 100644 --- a/cinn/ir/ir_schedule.cc +++ b/cinn/ir/ir_schedule.cc @@ -50,8 +50,10 @@ namespace ir { class ScheduleImpl { public: ScheduleImpl() = default; - explicit ScheduleImpl(const ModuleExpr& module_expr, bool debug_flag = false) - : module_expr_(module_expr), debug_flag_(debug_flag) {} + explicit ScheduleImpl(const ModuleExpr& module_expr, + bool debug_flag = false, + ScheduleErrorMessageLevel err_msg_level = ScheduleErrorMessageLevel::kBlank) + : module_expr_(module_expr), debug_flag_(debug_flag), err_msg_level_(err_msg_level) {} explicit ScheduleImpl(ModuleExpr&& module_expr) : module_expr_(std::move(module_expr)) {} //! Set the debug flag. @@ -114,8 +116,32 @@ class ScheduleImpl { ModuleExpr module_expr_; bool debug_flag_{false}; + ScheduleErrorMessageLevel err_msg_level_; }; +/** \brief A macro that guards the beginning of each implementation of schedule */ +#define CINN_IR_SCHEDULE_BEGIN() try { +/** + * \brief A macro that pairs with `CINN_IR_SCHEDULE_BEGIN`, handling potential errors and error + * message printing + * \param primitive A string representing the kind of schedule primitive + * \param err_msg_level A ScheduleErrorMessageLevel enum, level of error message printing + */ +#define CINN_IR_SCHEDULE_END(primitive, err_msg_level) \ + } \ + catch (const IRScheduleErrorHandler& err_hanlder) { \ + switch (err_msg_level) { \ + case ScheduleErrorMessageLevel::kDetailed: \ + throw std::runtime_error(err_hanlder.FormatErrorMessage(primitive)); \ + case ScheduleErrorMessageLevel::kGenearl: \ + throw std::runtime_error(err_hanlder.GeneralErrorMessage()); \ + case ScheduleErrorMessageLevel::kBlank: \ + throw std::runtime_error("IRScheduleError occurred! (No more error message)"); \ + default: \ + throw std::runtime_error("IRScheduleError occurred! (No more error message)"); \ + } \ + } + std::vector ScheduleImpl::Split(const Expr& loop, const std::vector& factors) { CHECK(loop.As()) << "Expr param of Split must be For node! Please check."; auto* for_node = loop.As(); @@ -126,8 +152,10 @@ std::vector ScheduleImpl::Split(const Expr& loop, const std::vector& VLOG(3) << "Try Split loop from (" << for_node->loop_var->name << ", 0, " << tot_extent << ") to (" << cinn::utils::Join(factors, ", ") << ") at loop:\n" << loop; - - auto processed_factors = ValidateFactors(factors, tot_extent); + std::vector processed_factors; + CINN_IR_SCHEDULE_BEGIN(); + processed_factors = ValidateFactors(factors, tot_extent); + CINN_IR_SCHEDULE_END("split", this->err_msg_level_); int prod_size = std::accumulate(processed_factors.begin(), processed_factors.end(), 1, std::multiplies()); std::vector new_loop_vars; Expr substitute_value(0); @@ -1971,8 +1999,11 @@ Expr ScheduleImpl::SampleCategorical(utils::LinearRandomEngine::StateType* rand_ IRSchedule::IRSchedule() {} -IRSchedule::IRSchedule(const ModuleExpr& module_expr, utils::LinearRandomEngine::StateType rand_seed, bool debug_flag) { - impl_ = std::make_unique(module_expr, debug_flag); +IRSchedule::IRSchedule(const ModuleExpr& module_expr, + utils::LinearRandomEngine::StateType rand_seed, + bool debug_flag, + ScheduleErrorMessageLevel err_msg_level) { + impl_ = std::make_unique(module_expr, debug_flag, err_msg_level); this->InitSeed(rand_seed); } diff --git a/cinn/ir/ir_schedule.h b/cinn/ir/ir_schedule.h index 6b7b252a57..2bcc41ebcb 100644 --- a/cinn/ir/ir_schedule.h +++ b/cinn/ir/ir_schedule.h @@ -22,6 +22,7 @@ #include "cinn/ir/ir.h" #include "cinn/ir/ir_base.h" #include "cinn/ir/ir_mutator.h" +#include "cinn/ir/ir_schedule_error.h" #include "cinn/ir/schedule_desc.h" #include "cinn/ir/tensor.h" #include "cinn/utils/random_engine.h" @@ -67,7 +68,8 @@ class IRSchedule { IRSchedule(); explicit IRSchedule(const ModuleExpr& modexpr, utils::LinearRandomEngine::StateType rand_seed = -1, - bool debug_flag = false); + bool debug_flag = false, + ScheduleErrorMessageLevel err_msg_level = ScheduleErrorMessageLevel::kBlank); IRSchedule(ir::ModuleExpr&& mod_expr, ScheduleDesc&& trace, utils::LinearRandomEngine::StateType rand_seed = -1); IRSchedule(const IRSchedule& other); IRSchedule& operator=(const IRSchedule& src); diff --git a/cinn/ir/ir_schedule_error.cc b/cinn/ir/ir_schedule_error.cc index 4d3f4e3524..c54433e17b 100644 --- a/cinn/ir/ir_schedule_error.cc +++ b/cinn/ir/ir_schedule_error.cc @@ -17,11 +17,11 @@ namespace cinn { namespace ir { -std::string IRScheduleErrorHandler::FormatErrorMessage(const std::string &primitive) { +std::string IRScheduleErrorHandler::FormatErrorMessage(const std::string &primitive) const { std::ostringstream os; std::string err_msg = DetailedErrorMessage(); - os << "[IRScheduleError] An error occurred in the schedue primitive <" << primitive << ">. " << std::endl; + os << "[IRScheduleError] An error occurred in the scheduel primitive <" << primitive << ">. " << std::endl; os << "Error info: " << err_msg; return os.str(); } diff --git a/cinn/ir/ir_schedule_error.h b/cinn/ir/ir_schedule_error.h index 273dd08f7c..1cdc29efc0 100644 --- a/cinn/ir/ir_schedule_error.h +++ b/cinn/ir/ir_schedule_error.h @@ -36,7 +36,7 @@ enum class ScheduleErrorMessageLevel : int32_t { }; /** - * This handler is to deal with the errors happens in in the current Scheduling. + * This handler is dealing with the errors happen in in the current Scheduling. */ class IRScheduleErrorHandler : public std::runtime_error { public: @@ -50,17 +50,17 @@ class IRScheduleErrorHandler : public std::runtime_error { /** * \brief Returns a detailed error message corresponding to the kDetailed error level. */ - std::string FormatErrorMessage(const std::string &primitive); + std::string FormatErrorMessage(const std::string &primitive) const; /** - * \brief Returns a detailed error message corresponding to the kDetailed error level. + * \brief Returns a short error message corresponding to the kGeneral error level. */ - virtual std::string DetailedErrorMessage() const = 0; + virtual std::string GeneralErrorMessage() const = 0; /** - * \brief Returns a short error message corresponding to the kGeneral error level. + * \brief Returns a detailed error message corresponding to the kDetailed error level. */ - virtual std::string GeneralErrorMessage() const = 0; + virtual std::string DetailedErrorMessage() const = 0; }; } // namespace ir diff --git a/cinn/ir/ir_schedule_util.cc b/cinn/ir/ir_schedule_util.cc index 054e05dee0..af1b1adf01 100644 --- a/cinn/ir/ir_schedule_util.cc +++ b/cinn/ir/ir_schedule_util.cc @@ -29,6 +29,7 @@ #include "cinn/ir/ir.h" #include "cinn/ir/ir_operators.h" #include "cinn/ir/ir_printer.h" +#include "cinn/ir/ir_schedule_error.h" #include "cinn/ir/ir_visitor.h" #include "cinn/lang/compute.h" #include "cinn/optim/ir_copy.h" @@ -196,14 +197,66 @@ void ReplaceExpr(Expr* source, const std::vector& replaced, const std::vect } std::vector ValidateFactors(const std::vector& factors, int total_extent) { + class NegativeFactorErrorHandler : public IRScheduleErrorHandler { + public: + explicit NegativeFactorErrorHandler(int64_t factor, size_t idx) : factor_(factor), idx_(idx) {} + + std::string GeneralErrorMessage() const final { + return "[IRScheduleError]: The params in factors of Split should be positive. However, some " + "factor is zero or negative."; + } + + std::string DetailedErrorMessage() const final { + std::ostringstream os; + os << "The params in factors of Split should be positive. However, the factor at position " << idx_ << " is " + << factor_; + return os.str(); + } + + private: + int64_t factor_; + size_t idx_; + }; + + class InferFactorErrorHandler : public IRScheduleErrorHandler { + public: + std::string GeneralErrorMessage() const final { + return "[IRScheduleError]: The params in factors of Split should not be less than -1 or have more than one -1!"; + } + + std::string DetailedErrorMessage() const final { + std::ostringstream os; + os << "The params in factors of Split should not be less than -1 or have more than one -1!"; + return os.str(); + } + }; + + class FactorProductErrorHandler : public IRScheduleErrorHandler { + public: + std::string GeneralErrorMessage() const final { + return "[IRScheduleError]: In Split, the factors' product should be not larger than or equal to original loop's " + "extent!"; + } + + std::string DetailedErrorMessage() const final { + std::ostringstream os; + os << "In Split, the factors' product should be not larger than or equal to original loop's extent!"; + return os.str(); + } + }; + CHECK(!factors.empty()) << "The factors param of Split should not be empty! Please check."; bool has_minus_one = false; int product = 1; + int idx = -1; for (auto& i : factors) { - CHECK(i != 0) << "The params in factors of Split should not be 0! Please check."; - CHECK(i >= -1) << "The params in factors of Split should not be less than -1! Please check."; - if (i == -1) { - CHECK(!has_minus_one) << "The params in factors of Split should not have more than one -1! Please check."; + idx++; + if (i == 0 || i < -1) { + throw NegativeFactorErrorHandler(i, idx); + } else if (i == -1) { + if (has_minus_one) { + throw InferFactorErrorHandler(); + } has_minus_one = true; } else { product *= i; @@ -211,12 +264,14 @@ std::vector ValidateFactors(const std::vector& factors, int total_exte } std::vector validated_factors = factors; if (!has_minus_one) { - CHECK_GE(product, total_extent) - << "In Split, the factors' product should be equal to original loop's extent! Please check."; + if (product < total_extent) { + throw FactorProductErrorHandler(); + } return validated_factors; } else { - CHECK_LE(product, total_extent) << "In Split, when there is -1 in factors, the other factors' product should be <= " - "original loop's extent! Please check."; + if (product > total_extent) { + throw FactorProductErrorHandler(); + } int minus_one_candidate = (int)ceil((double)total_extent / (double)product); for (int i = 0; i < validated_factors.size(); ++i) { if (validated_factors[i] == -1) {