Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[xla:hlo] Templetize HloEvaluator on a hot path to remove std::function overheads #23217

Merged
merged 1 commit into from
Feb 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions xla/hlo/evaluator/hlo_evaluator.h
Original file line number Diff line number Diff line change
Expand Up @@ -462,11 +462,13 @@ class HloEvaluator : public ConstDfsHloVisitorWithDefault {
bool use_fast_path_reduce_ = true;

private:
template <typename ReturnT, typename NativeT>
template <typename ReturnT, typename NativeT, typename UnaryOp>
static absl::StatusOr<Literal> ElementWiseUnaryOpImpl(
const HloInstruction* instruction,
const std::function<ReturnT(NativeT)>& unary_op,
const HloInstruction* instruction, UnaryOp&& unary_op,
const Literal& operand_literal) {
static_assert(std::is_invocable_r_v<ReturnT, UnaryOp, NativeT>,
"Invalid UnaryOp signature");

const Shape& shape = instruction->shape();
const auto* operand = instruction->operand(0);
TF_RET_CHECK(ShapeUtil::SameDimensions(shape, operand->shape()));
Expand Down
81 changes: 46 additions & 35 deletions xla/hlo/evaluator/hlo_evaluator_typed_visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -141,25 +141,29 @@ class HloEvaluatorTypedVisitor : public ConstDfsHloVisitorWithDefault {
public:
explicit HloEvaluatorTypedVisitor(HloEvaluator* p) : parent_(p) {}

// The following higher-order functions convert a function with ElementwiseT
// to a function with ReturnT.
std::function<ReturnT(ReturnT)> ConvertUnaryFunction(
const std::function<ElementwiseT(ElementwiseT)>& unary_op) {
// Converts a UnaryOp from a unary function on ElementwiseT to a unary
// function on ReturnT types.
template <typename UnaryOp>
auto ConvertUnaryFunction(const UnaryOp& unary_op) {
return [&unary_op](ReturnT arg) {
return static_cast<ReturnT>(unary_op(static_cast<ElementwiseT>(arg)));
};
}
std::function<ReturnT(ReturnT, ReturnT)> ConvertBinaryFunction(
const std::function<ElementwiseT(ElementwiseT, ElementwiseT)>&
binary_op) {

// Converts a BinaryOp from a binary function on ElementwiseT to a binary
// function on ReturnT types.
template <typename BinaryOp>
auto ConvertBinaryFunction(const BinaryOp& binary_op) {
return [&binary_op](ReturnT arg1, ReturnT arg2) {
return static_cast<ReturnT>(binary_op(static_cast<ElementwiseT>(arg1),
static_cast<ElementwiseT>(arg2)));
};
}
std::function<ReturnT(ReturnT, ReturnT, ReturnT)> ConvertTernaryFunction(
const std::function<ElementwiseT(ElementwiseT, ElementwiseT,
ElementwiseT)>& ternary_op) {

// Converts a TernaryOp from a ternary function on ElementwiseT to a ternary
// function on ReturnT types.
template <typename TernaryOp>
auto ConvertTernaryFunction(const TernaryOp& ternary_op) {
return [&ternary_op](ReturnT arg1, ReturnT arg2, ReturnT arg3) {
return static_cast<ReturnT>(ternary_op(static_cast<ElementwiseT>(arg1),
static_cast<ElementwiseT>(arg2),
Expand Down Expand Up @@ -748,10 +752,9 @@ class HloEvaluatorTypedVisitor : public ConstDfsHloVisitorWithDefault {
}
return std::min(high, std::max(value, low));
};
TF_ASSIGN_OR_RETURN(
parent_->evaluated_[clamp],
ElementwiseTernaryOp(clamp,
std::move(ConvertTernaryFunction(clamp_op))));
TF_ASSIGN_OR_RETURN(parent_->evaluated_[clamp],
(ElementwiseTernaryOp<ReturnT, ReturnT, ReturnT>(
clamp, ConvertTernaryFunction(clamp_op))));
return absl::OkStatus();
}
return UnsupportedTypeError(clamp);
Expand All @@ -760,15 +763,12 @@ class HloEvaluatorTypedVisitor : public ConstDfsHloVisitorWithDefault {
absl::Status HandleSelect(const HloInstruction* select) override {
CHECK(!ShapeUtil::IsScalar(select->operand(0)->shape()));
CHECK(select->shape().IsArray());
std::function<ReturnT(bool, ReturnT, ReturnT)> select_op =
[](bool pred, ReturnT on_true, ReturnT on_false) {
if (pred) {
return on_true;
}
return on_false;
};
auto select_op = [](bool pred, ReturnT on_true, ReturnT on_false) {
return pred ? on_true : on_false;
};
TF_ASSIGN_OR_RETURN(parent_->evaluated_[select],
ElementwiseTernaryOp(select, std::move(select_op)));
(ElementwiseTernaryOp<bool, ReturnT, ReturnT>(
select, std::move(select_op))));
return absl::OkStatus();
}

Expand Down Expand Up @@ -1646,23 +1646,29 @@ class HloEvaluatorTypedVisitor : public ConstDfsHloVisitorWithDefault {
}

private:
absl::StatusOr<Literal> ElementWiseUnaryOp(
const HloInstruction* instruction,
const std::function<ElementwiseT(ElementwiseT)>& unary_op) {
template <typename UnaryOp>
absl::StatusOr<Literal> ElementWiseUnaryOp(const HloInstruction* instruction,
UnaryOp&& unary_op) {
static_assert(std::is_invocable_r_v<ElementwiseT, UnaryOp, ElementwiseT>,
"Invalid UnaryOp signature");

const Literal& operand_literal =
parent_->GetEvaluatedLiteralFor(instruction->operand(0));
TF_ASSIGN_OR_RETURN(
auto result_literal,
(HloEvaluator::ElementWiseUnaryOpImpl<ReturnT, ReturnT>(
instruction, ConvertUnaryFunction(unary_op), operand_literal)));

return std::move(result_literal);
return result_literal;
}

absl::StatusOr<Literal> ElementWiseBinaryOp(
const HloInstruction* instruction,
const std::function<ElementwiseT(ElementwiseT, ElementwiseT)>&
binary_op) {
template <typename BinaryOp>
absl::StatusOr<Literal> ElementWiseBinaryOp(const HloInstruction* instruction,
BinaryOp&& binary_op) {
static_assert(std::is_invocable_r_v<ElementwiseT, BinaryOp, ElementwiseT,
ElementwiseT>,
"Invalid BinaryOp signature");

const auto& shape = instruction->shape();
const auto* lhs = instruction->operand(0);
const auto* rhs = instruction->operand(1);
Expand Down Expand Up @@ -1694,13 +1700,18 @@ class HloEvaluatorTypedVisitor : public ConstDfsHloVisitorWithDefault {
rhs_literal.Get<ReturnT>(multi_index));
}));
}
return std::move(result);

return result;
}

template <typename LhsType, typename RhsType, typename EhsType>
template <typename LhsType, typename RhsType, typename EhsType,
typename TernaryOp>
absl::StatusOr<Literal> ElementwiseTernaryOp(
const HloInstruction* instruction,
const std::function<ReturnT(LhsType, RhsType, EhsType)>& ternary_op) {
const HloInstruction* instruction, TernaryOp&& ternary_op) {
static_assert(
std::is_invocable_r_v<ReturnT, TernaryOp, LhsType, RhsType, EhsType>,
"Invalid TernaryOp signature");

const auto& shape = instruction->shape();
const auto* lhs = instruction->operand(0);
const auto* rhs = instruction->operand(1);
Expand Down Expand Up @@ -1739,7 +1750,7 @@ class HloEvaluatorTypedVisitor : public ConstDfsHloVisitorWithDefault {
}));
}

return std::move(result);
return result;
}

template <typename NativeT>
Expand Down
Loading