diff --git a/xla/hlo/evaluator/hlo_evaluator.h b/xla/hlo/evaluator/hlo_evaluator.h index 6bfd4d7aa75c8..03e0f64bed9da 100644 --- a/xla/hlo/evaluator/hlo_evaluator.h +++ b/xla/hlo/evaluator/hlo_evaluator.h @@ -462,11 +462,13 @@ class HloEvaluator : public ConstDfsHloVisitorWithDefault { bool use_fast_path_reduce_ = true; private: - template + template static absl::StatusOr ElementWiseUnaryOpImpl( - const HloInstruction* instruction, - const std::function& unary_op, + const HloInstruction* instruction, UnaryOp&& unary_op, const Literal& operand_literal) { + static_assert(std::is_invocable_r_v, + "Invalid UnaryOp signature"); + const Shape& shape = instruction->shape(); const auto* operand = instruction->operand(0); TF_RET_CHECK(ShapeUtil::SameDimensions(shape, operand->shape())); diff --git a/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h b/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h index e063a90f6d951..3b0cbb9038eaa 100644 --- a/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h +++ b/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h @@ -141,25 +141,29 @@ class HloEvaluatorTypedVisitor : public ConstDfsHloVisitorWithDefault { public: explicit HloEvaluatorTypedVisitor(HloEvaluator* p) : parent_(p) {} - // The following higher-order functions convert a function with ElementwiseT - // to a function with ReturnT. - std::function ConvertUnaryFunction( - const std::function& unary_op) { + // Converts a UnaryOp from a unary function on ElementwiseT to a unary + // function on ReturnT types. + template + auto ConvertUnaryFunction(const UnaryOp& unary_op) { return [&unary_op](ReturnT arg) { return static_cast(unary_op(static_cast(arg))); }; } - std::function ConvertBinaryFunction( - const std::function& - binary_op) { + + // Converts a BinaryOp from a binary function on ElementwiseT to a binary + // function on ReturnT types. + template + auto ConvertBinaryFunction(const BinaryOp& binary_op) { return [&binary_op](ReturnT arg1, ReturnT arg2) { return static_cast(binary_op(static_cast(arg1), static_cast(arg2))); }; } - std::function ConvertTernaryFunction( - const std::function& ternary_op) { + + // Converts a TernaryOp from a ternary function on ElementwiseT to a ternary + // function on ReturnT types. + template + auto ConvertTernaryFunction(const TernaryOp& ternary_op) { return [&ternary_op](ReturnT arg1, ReturnT arg2, ReturnT arg3) { return static_cast(ternary_op(static_cast(arg1), static_cast(arg2), @@ -748,10 +752,9 @@ class HloEvaluatorTypedVisitor : public ConstDfsHloVisitorWithDefault { } return std::min(high, std::max(value, low)); }; - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[clamp], - ElementwiseTernaryOp(clamp, - std::move(ConvertTernaryFunction(clamp_op)))); + TF_ASSIGN_OR_RETURN(parent_->evaluated_[clamp], + (ElementwiseTernaryOp( + clamp, ConvertTernaryFunction(clamp_op)))); return absl::OkStatus(); } return UnsupportedTypeError(clamp); @@ -760,15 +763,12 @@ class HloEvaluatorTypedVisitor : public ConstDfsHloVisitorWithDefault { absl::Status HandleSelect(const HloInstruction* select) override { CHECK(!ShapeUtil::IsScalar(select->operand(0)->shape())); CHECK(select->shape().IsArray()); - std::function select_op = - [](bool pred, ReturnT on_true, ReturnT on_false) { - if (pred) { - return on_true; - } - return on_false; - }; + auto select_op = [](bool pred, ReturnT on_true, ReturnT on_false) { + return pred ? on_true : on_false; + }; TF_ASSIGN_OR_RETURN(parent_->evaluated_[select], - ElementwiseTernaryOp(select, std::move(select_op))); + (ElementwiseTernaryOp( + select, std::move(select_op)))); return absl::OkStatus(); } @@ -1646,9 +1646,12 @@ class HloEvaluatorTypedVisitor : public ConstDfsHloVisitorWithDefault { } private: - absl::StatusOr ElementWiseUnaryOp( - const HloInstruction* instruction, - const std::function& unary_op) { + template + absl::StatusOr ElementWiseUnaryOp(const HloInstruction* instruction, + UnaryOp&& unary_op) { + static_assert(std::is_invocable_r_v, + "Invalid UnaryOp signature"); + const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(instruction->operand(0)); TF_ASSIGN_OR_RETURN( @@ -1656,13 +1659,16 @@ class HloEvaluatorTypedVisitor : public ConstDfsHloVisitorWithDefault { (HloEvaluator::ElementWiseUnaryOpImpl( instruction, ConvertUnaryFunction(unary_op), operand_literal))); - return std::move(result_literal); + return result_literal; } - absl::StatusOr ElementWiseBinaryOp( - const HloInstruction* instruction, - const std::function& - binary_op) { + template + absl::StatusOr ElementWiseBinaryOp(const HloInstruction* instruction, + BinaryOp&& binary_op) { + static_assert(std::is_invocable_r_v, + "Invalid BinaryOp signature"); + const auto& shape = instruction->shape(); const auto* lhs = instruction->operand(0); const auto* rhs = instruction->operand(1); @@ -1694,13 +1700,18 @@ class HloEvaluatorTypedVisitor : public ConstDfsHloVisitorWithDefault { rhs_literal.Get(multi_index)); })); } - return std::move(result); + + return result; } - template + template absl::StatusOr ElementwiseTernaryOp( - const HloInstruction* instruction, - const std::function& ternary_op) { + const HloInstruction* instruction, TernaryOp&& ternary_op) { + static_assert( + std::is_invocable_r_v, + "Invalid TernaryOp signature"); + const auto& shape = instruction->shape(); const auto* lhs = instruction->operand(0); const auto* rhs = instruction->operand(1); @@ -1739,7 +1750,7 @@ class HloEvaluatorTypedVisitor : public ConstDfsHloVisitorWithDefault { })); } - return std::move(result); + return result; } template