Skip to content

Commit

Permalink
Repo Sync (#631)
Browse files Browse the repository at this point in the history
  • Loading branch information
anakinxc authored Mar 29, 2024
1 parent 16d67e9 commit fa84a1c
Show file tree
Hide file tree
Showing 25 changed files with 419 additions and 121 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
- [Feature] Add minimax approximation for log
- [Feature] Support jax.lax.top_k
- [Feature] Support round to nearest even
- [Improvement] Default log approximation to minimax
- [Improvement] Improve median performance

Expand Down
1 change: 1 addition & 0 deletions libspu/compiler/passes/hlo_legalize_to_pphlo.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1331,6 +1331,7 @@ struct HloLegalizeToPPHlo
HloToPPHloOpConverter<stablehlo::ReturnOp>,
HloToPPHloOpConverter<stablehlo::RngOp>,
HloToPPHloOpConverter<stablehlo::RoundOp>,
HloToPPHloOpConverter<stablehlo::RoundNearestEvenOp>,
HloToPPHloOpConverter<stablehlo::RsqrtOp>,
HloToPPHloOpConverter<stablehlo::SineOp>,
HloToPPHloOpConverter<stablehlo::SelectOp>,
Expand Down
1 change: 1 addition & 0 deletions libspu/compiler/passes/map_stablehlo_to_pphlo_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ MAP_HLO_TO_PPHLO(RemOp)
MAP_HLO_TO_PPHLO(ReshapeOp)
MAP_HLO_TO_PPHLO(ReverseOp)
MAP_HLO_TO_PPHLO(RoundOp)
MAP_HLO_TO_PPHLO(RoundNearestEvenOp)
MAP_HLO_TO_PPHLO(RngOp)
MAP_HLO_TO_PPHLO(SelectOp)
MAP_HLO_TO_PPHLO(ShiftLeftOp)
Expand Down
5 changes: 3 additions & 2 deletions libspu/device/api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -219,8 +219,9 @@ void printProfilingData(spu::SPUContext *sctx, const std::string &name,
}

// print link statistics
SPDLOG_INFO("Link details: total send bytes {}, send actions {}",
comm_stats.send_bytes, comm_stats.send_actions);
SPDLOG_INFO(
"Link details: total send bytes {}, recv bytes {}, send actions {}",
comm_stats.send_bytes, comm_stats.recv_bytes, comm_stats.send_actions);
}

void SPUErrorHandler(void *use_data, const char *reason, bool gen_crash_diag) {
Expand Down
1 change: 1 addition & 0 deletions libspu/device/pphlo/pphlo_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,7 @@ STANDARD_UNARY_OP_EXEC_IMPL(NotOp, Not)
STANDARD_UNARY_OP_EXEC_IMPL(RsqrtOp, Rsqrt)
STANDARD_UNARY_OP_EXEC_IMPL(SqrtOp, Sqrt)
STANDARD_UNARY_OP_EXEC_IMPL(RoundOp, Round_AFZ)
STANDARD_UNARY_OP_EXEC_IMPL(RoundNearestEvenOp, Round_RNTE)
STANDARD_UNARY_OP_EXEC_IMPL(SineOp, Sine)
STANDARD_UNARY_OP_EXEC_IMPL(CosineOp, Cosine)

Expand Down
1 change: 1 addition & 0 deletions libspu/device/pphlo/pphlo_verifier.cc
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,7 @@ UNARY_VERIFIER(ExpOp, evalExponentialOp)
UNARY_VERIFIER(RsqrtOp, evalRsqrtOp)
UNARY_VERIFIER(SqrtOp, evalSqrtOp)
UNARY_VERIFIER(RoundOp, evalRoundOp)
UNARY_VERIFIER(RoundNearestEvenOp, evalRoundNearestEvenOp)
UNARY_VERIFIER(SignOp, evalSignOp)
UNARY_VERIFIER(Log1pOp, evalLog1pOp)
UNARY_VERIFIER(Expm1Op, evalExpm1Op)
Expand Down
1 change: 1 addition & 0 deletions libspu/device/pphlo/pphlo_verifier.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ class PPHloVerifier {
VERIFY_DECL(SignOp)
VERIFY_DECL(SqrtOp)
VERIFY_DECL(RoundOp)
VERIFY_DECL(RoundNearestEvenOp)

// Simple binary
VERIFY_DECL(AddOp)
Expand Down
14 changes: 14 additions & 0 deletions libspu/dialect/pphlo/ops.td
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,20 @@ def PPHLO_RoundOp
}];
}

// Element-wise round-to-nearest with ties broken towards the even integer.
// Operand and result share the same floating-point tensor type.
def PPHLO_RoundNearestEvenOp
    : PPHLO_UnaryElementwiseOpWithTypeInfer<"round_nearest_even",
          [SameOperandsAndResultType], PPHLO_FpTensor> {
  let summary = "RoundNearestEven operation";
  let description = [{
    Performs element-wise rounding towards the nearest integer, breaking ties
    towards the even integer, on the `operand` tensor and produces a `result`
    tensor.

    Ref:
    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#round_nearest_even
  }];
}

def PPHLO_RsqrtOp
: PPHLO_UnaryElementwiseOpWithTypeInfer<"rsqrt", [SameOperandsAndResultType], PPHLO_FpTensor> {
let summary = "Reciprocal of square-root operator";
Expand Down
58 changes: 53 additions & 5 deletions libspu/kernel/hal/polymorphic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -330,11 +330,59 @@ Value min(SPUContext* ctx, const Value& x, const Value& y) {
Value power(SPUContext* ctx, const Value& x, const Value& y) {
SPU_TRACE_HAL_DISP(ctx, x, y);

if (x.isInt() || y.isInt()) {
auto x_f = dtype_cast(ctx, x, DT_F32);
auto y_f = dtype_cast(ctx, y, DT_F32);
auto ret = power(ctx, x_f, y_f);
return ret;
if (x.isInt()) {
// ref:
// https://github.com/openxla/stablehlo/blob/main/stablehlo/reference/Element.cpp#L912
// Although there are some "strange" semantics in stablehlo, we still follow
// them:
// 1. when x is int, then the return value must be int type.
// 2. if x is int, then y must be int
// 3. if x is int and y<0, then
// a. when |x|!=1, then always return 0;
// b. when |x|=1, then y=|y|;
//
// However, for jax.numpy.power, it behaves differently:
// 1. if any x or y is float, then both x and y will be upcast to float.
// 2. if both x and y are int, then y must be non-negative.
SPU_ENFORCE(y.isInt(), "when base is int, then y must be int.");
auto k0 = _constant(ctx, 0, x.shape());
auto k1 = _constant(ctx, 1, x.shape());
const auto bit_width = SizeOf(ctx->getField()) * 8;

auto y_b = _prefer_b(ctx, y);
auto msb_y = _rshift(ctx, y_b, bit_width - 1);
auto x_abs1 = _equal(ctx, abs(ctx, x), k1);

auto ret = _constant(ctx, 1, x.shape());
// To compute ret = x^y,
// although y has `bit_width` bits, we only consider `y_bits` bits here.
// The reasons are twofold (recall that both x and y are int):
// 1. if |x|>1, then `ret` will OVERFLOW/UNDERFLOW if y>63 (e.g. FM64),
// which means the valid bits of y can't exceed `log(bit_width - 1)` .
// 2. if |x|=1:
// a). x=1, then we always get `ret`=1;
// b). x=-1, then the sign of `ret` is decided on the LSB of y;
// So we can "truncate" y to `y_bits` bits safely.
const size_t y_bits = Log2Ceil(bit_width - 1);

auto base = x;
// TODO: do this in parallel
// To compute x^y, it is necessary to compute all x^(2^idx), we use base
// (init as `x`) to store it, update base to base*base till last
// iteration, and multiply all these numbers according to y_{idx}.
// e.g. y=0101, then ret = (x) * (1) * (x^(2^2)) * (1) = x^5
for (size_t idx = 0; idx < y_bits; idx++) {
// x^(2^idx) * y_{idx}
auto cur_pow = _mux(ctx, _and(ctx, _rshift(ctx, y_b, idx), k1), base, k1);
ret = _mul(ctx, cur_pow, ret);
if (idx < y_bits - 1) {
base = _mul(ctx, base, base);
}
}

// when x=-1 and y<0, we can still get a correct result
return _mux(ctx, _and(ctx, msb_y, _not(ctx, x_abs1)), k0, ret)
.setDtype(x.dtype());
}
if (x.isPublic() && y.isPublic()) {
return f_pow_p(ctx, x, y);
Expand Down
21 changes: 14 additions & 7 deletions libspu/kernel/hal/polymorphic_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -406,18 +406,22 @@ TYPED_TEST(MathTest, Pow) {
using LHS_VT = typename std::tuple_element<1, TypeParam>::type;
using RHS_DT = typename std::tuple_element<2, TypeParam>::type;
using RHS_VT = typename std::tuple_element<3, TypeParam>::type;
// using RES_DT = typename std::tuple_element<4, TypeParam>::type;
using RES_DT = typename std::tuple_element<4, TypeParam>::type;

if constexpr (!std::is_same_v<LHS_DT, RHS_DT>) {
return;
}

// GIVEN
xt::xarray<LHS_DT> x;
xt::xarray<RHS_DT> y;
{
// random test
x = test::xt_random<LHS_DT>({5, 6}, 0, 100);
y = test::xt_random<RHS_DT>({5, 6}, -2, 2);
y = test::xt_random<RHS_DT>({5, 6}, 0, 2);

// WHAT
auto z = test::evalBinaryOp<float>(LHS_VT(), RHS_VT(), power, x, y);
auto z = test::evalBinaryOp<RHS_DT>(LHS_VT(), RHS_VT(), power, x, y);

// THEN
auto expected = xt::pow(x, y);
Expand All @@ -429,14 +433,17 @@ TYPED_TEST(MathTest, Pow) {

{
// some fixed corner case
x = {-1, -1, -3, 1, -3, 0, 1, 1, 5, 0};
y = {1, 0, -3, -3, 3, 0, 0, 2, 5, 2};
x = {-1, -1, -1, -1, -3, 1, -3, 0, 1, 1, 5, 0, 3, 2, -2};
y = {1, 0, -3, -4, -3, -3, 3, 0, 0, 2, 5, 2, -3, -1, -1};

// WHAT
auto z = test::evalBinaryOp<float>(LHS_VT(), RHS_VT(), power, x, y);
auto z = test::evalBinaryOp<RES_DT>(LHS_VT(), RHS_VT(), power, x, y);

// THEN
auto expected = xt::pow(x, y);
// when x is int and x=-3, y=-3, we should get 0.
// when x is int and x=3, y=-3, we should get 0.
xt::xarray<RES_DT> expected = xt::pow(x, y);

EXPECT_TRUE(xt::allclose(expected, z, 0.3, 0.03)) << x << std::endl
<< y << std::endl
<< expected << std::endl
Expand Down
40 changes: 40 additions & 0 deletions libspu/kernel/hlo/basic_unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "libspu/kernel/hal/complex.h"
#include "libspu/kernel/hal/constants.h"
#include "libspu/kernel/hal/polymorphic.h"
#include "libspu/kernel/hal/ring.h"
#include "libspu/kernel/hal/type_cast.h"

namespace spu::kernel::hlo {
Expand Down Expand Up @@ -101,4 +102,43 @@ spu::Value Round_AFZ(SPUContext *ctx, const spu::Value &in) {
return hal::dtype_cast(ctx, hal::dtype_cast(ctx, round, DT_I64), in.dtype());
}

// Rounds a fixed-point value to the nearest integer, ties to even (RNTE).
//
// View the input bit pattern as  ...i.h r...r  where
//   i      = lowest integer bit              (bit position fxp_bits)
//   h      = highest fractional ("half") bit (bit position fxp_bits - 1)
//   sticky = OR of the remaining fractional bits (positions 0..fxp_bits - 2)
//
// The result is floor(in) + carry with carry = h && (sticky || i):
//   h = 0                      -> carry = 0
//   h = 1, sticky = 1          -> carry = 1
//   h = 1, sticky = 0, i = 1   -> carry = 1 (tie, move to the even neighbour)
//   h = 1, sticky = 0, i = 0   -> carry = 0 (tie, floor is already even)
//
// Only non-complex fixed-point inputs are accepted; anything else trips an
// SPU_ENFORCE. The returned value carries the input's dtype.
spu::Value Round_RNTE(SPUContext *ctx, const spu::Value &in) {
  SPU_ENFORCE(!in.isComplex());
  SPU_ENFORCE(in.isFxp(), "Round only supports fxp");

  const auto frac_bits = ctx->getFxpBits();
  const auto one = hal::_constant(ctx, 1U, in.shape());

  auto bits = hal::_prefer_b(ctx, in);
  auto floored = hal::floor(ctx, bits);

  // Isolates the bit at position `pos` of the input as a 0/1 value.
  auto bit_at = [&](auto pos) {
    return hal::_and(ctx, hal::_rshift(ctx, bits, pos), one);
  };

  auto int_lsb = bit_at(frac_bits);
  auto half_bit = bit_at(frac_bits - 1);

  // Gather every fractional bit below the half bit, then OR them together.
  std::vector<Value> low_bits;
  low_bits.reserve(frac_bits - 1);
  for (size_t pos = 0; pos < frac_bits - 1; ++pos) {
    low_bits.push_back(bit_at(pos));
  }
  auto sticky = vreduce(low_bits.begin(), low_bits.end(),
                        [&](const Value &lhs, const Value &rhs) {
                          return hal::_or(ctx, lhs, rhs);
                        });

  auto carry = hal::_and(ctx, half_bit, hal::_or(ctx, sticky, int_lsb));

  // The carry is a single bit; record that on boolean shares so the
  // boolean-to-arithmetic conversion has less work to do.
  if (carry.storage_type().isa<BShare>()) {
    const_cast<Type &>(carry.storage_type()).as<BShare>()->setNbits(1);
  }

  return hal::add(ctx, floored, carry.setDtype(DT_I64)).setDtype(in.dtype());
}

} // namespace spu::kernel::hlo
1 change: 1 addition & 0 deletions libspu/kernel/hlo/basic_unary.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ SIMPLE_UNARY_KERNEL_DECL(Sign)
SIMPLE_UNARY_KERNEL_DECL(Round_AFZ)
SIMPLE_UNARY_KERNEL_DECL(Real)
SIMPLE_UNARY_KERNEL_DECL(Imag)
SIMPLE_UNARY_KERNEL_DECL(Round_RNTE)

#undef SIMPLE_UNARY_KERNEL_DECL

Expand Down
1 change: 1 addition & 0 deletions libspu/kernel/hlo/basic_unary_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ UNARY_EMPTY_TEST(Rsqrt)
UNARY_EMPTY_TEST(Sqrt)
UNARY_EMPTY_TEST(Sign)
UNARY_EMPTY_TEST(Round_AFZ)
UNARY_EMPTY_TEST(Round_RNTE)

INSTANTIATE_TEST_SUITE_P(
UnaryTestInstances, UnaryTest,
Expand Down
1 change: 1 addition & 0 deletions libspu/mpc/cheetah/arith/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ spu_cc_library(
hdrs = ["matmat_prot.h"],
deps = [
":arith_comm",
"//libspu/mpc/cheetah/rlwe:lwe",
],
)

Expand Down
Loading

0 comments on commit fa84a1c

Please sign in to comment.