From fa84a1cf4ecd3ab889958a2dd80b98e1f027a241 Mon Sep 17 00:00:00 2001 From: anakinxc <103552181+anakinxc@users.noreply.github.com> Date: Fri, 29 Mar 2024 10:09:19 +0800 Subject: [PATCH] Repo Sync (#631) --- CHANGELOG.md | 1 + .../compiler/passes/hlo_legalize_to_pphlo.cc | 1 + .../passes/map_stablehlo_to_pphlo_op.h | 1 + libspu/device/api.cc | 5 +- libspu/device/pphlo/pphlo_executor.cc | 1 + libspu/device/pphlo/pphlo_verifier.cc | 1 + libspu/device/pphlo/pphlo_verifier.h | 1 + libspu/dialect/pphlo/ops.td | 14 ++ libspu/kernel/hal/polymorphic.cc | 58 +++++- libspu/kernel/hal/polymorphic_test.cc | 21 +- libspu/kernel/hlo/basic_unary.cc | 40 ++++ libspu/kernel/hlo/basic_unary.h | 1 + libspu/kernel/hlo/basic_unary_test.cc | 1 + libspu/mpc/cheetah/arith/BUILD.bazel | 1 + libspu/mpc/cheetah/arith/cheetah_dot.cc | 179 +++++++++++++----- libspu/mpc/cheetah/arith/cheetah_dot_test.cc | 1 + libspu/mpc/cheetah/arith/matmat_prot.cc | 84 +++++++- libspu/mpc/cheetah/arith/matmat_prot.h | 28 ++- libspu/mpc/cheetah/arithmetic.cc | 20 ++ libspu/mpc/cheetah/arithmetic.h | 13 +- libspu/mpc/cheetah/protocol.cc | 1 + libspu/mpc/cheetah/rlwe/packlwes.cc | 41 ++-- libspu/mpc/cheetah/rlwe/packlwes.h | 8 +- libspu/mpc/cheetah/rlwe/packlwes_test.cc | 8 +- spu/tests/jnp_testbase.py | 10 +- 25 files changed, 419 insertions(+), 121 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6b20a818..3c9a5e8e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ - [Feature] Add minimax approximation for log - [Feature] Support jax.lax.top_k +- [Feature] Support round to nearest even - [Improvement] Default log approximation to minmax - [Improvement] Improve median performance diff --git a/libspu/compiler/passes/hlo_legalize_to_pphlo.cc b/libspu/compiler/passes/hlo_legalize_to_pphlo.cc index d44c1e6a..381cbe15 100644 --- a/libspu/compiler/passes/hlo_legalize_to_pphlo.cc +++ b/libspu/compiler/passes/hlo_legalize_to_pphlo.cc @@ -1331,6 +1331,7 @@ struct HloLegalizeToPPHlo HloToPPHloOpConverter, HloToPPHloOpConverter, HloToPPHloOpConverter, + HloToPPHloOpConverter, HloToPPHloOpConverter, HloToPPHloOpConverter, HloToPPHloOpConverter, diff --git a/libspu/compiler/passes/map_stablehlo_to_pphlo_op.h b/libspu/compiler/passes/map_stablehlo_to_pphlo_op.h index e5398c5e..122f5f06 100644 --- a/libspu/compiler/passes/map_stablehlo_to_pphlo_op.h +++ b/libspu/compiler/passes/map_stablehlo_to_pphlo_op.h @@ -77,6 +77,7 @@ MAP_HLO_TO_PPHLO(RemOp) MAP_HLO_TO_PPHLO(ReshapeOp) MAP_HLO_TO_PPHLO(ReverseOp) MAP_HLO_TO_PPHLO(RoundOp) +MAP_HLO_TO_PPHLO(RoundNearestEvenOp) MAP_HLO_TO_PPHLO(RngOp) MAP_HLO_TO_PPHLO(SelectOp) MAP_HLO_TO_PPHLO(ShiftLeftOp) diff --git a/libspu/device/api.cc b/libspu/device/api.cc index 8a2fd7cf..64f939bd 100644 --- a/libspu/device/api.cc +++ b/libspu/device/api.cc @@ -219,8 +219,9 @@ void printProfilingData(spu::SPUContext *sctx, const std::string &name, } // print link statistics - SPDLOG_INFO("Link details: total send bytes {}, send actions {}", - comm_stats.send_bytes, comm_stats.send_actions); + SPDLOG_INFO( + "Link details: total send bytes {}, recv bytes {}, send actions {}", + comm_stats.send_bytes, comm_stats.recv_bytes, comm_stats.send_actions); } void SPUErrorHandler(void *use_data, const char *reason, bool gen_crash_diag) { diff --git a/libspu/device/pphlo/pphlo_executor.cc b/libspu/device/pphlo/pphlo_executor.cc index 40d8b2fd..8bf9ecde 100644 --- a/libspu/device/pphlo/pphlo_executor.cc +++ b/libspu/device/pphlo/pphlo_executor.cc @@ -264,6 +264,7 @@ STANDARD_UNARY_OP_EXEC_IMPL(NotOp, Not) STANDARD_UNARY_OP_EXEC_IMPL(RsqrtOp, Rsqrt) STANDARD_UNARY_OP_EXEC_IMPL(SqrtOp, Sqrt) STANDARD_UNARY_OP_EXEC_IMPL(RoundOp, Round_AFZ) +STANDARD_UNARY_OP_EXEC_IMPL(RoundNearestEvenOp, Round_RNTE) STANDARD_UNARY_OP_EXEC_IMPL(SineOp, Sine) STANDARD_UNARY_OP_EXEC_IMPL(CosineOp, Cosine) diff --git a/libspu/device/pphlo/pphlo_verifier.cc b/libspu/device/pphlo/pphlo_verifier.cc index b3fb164a..dd4190d7 100644 --- a/libspu/device/pphlo/pphlo_verifier.cc +++ b/libspu/device/pphlo/pphlo_verifier.cc @@ -268,6 +268,7 @@ UNARY_VERIFIER(ExpOp, evalExponentialOp) UNARY_VERIFIER(RsqrtOp, evalRsqrtOp) UNARY_VERIFIER(SqrtOp, evalSqrtOp) UNARY_VERIFIER(RoundOp, evalRoundOp) +UNARY_VERIFIER(RoundNearestEvenOp, evalRoundNearestEvenOp) UNARY_VERIFIER(SignOp, evalSignOp) UNARY_VERIFIER(Log1pOp, evalLog1pOp) UNARY_VERIFIER(Expm1Op, evalExpm1Op) diff --git a/libspu/device/pphlo/pphlo_verifier.h b/libspu/device/pphlo/pphlo_verifier.h index 3c53a199..712b0412 100644 --- a/libspu/device/pphlo/pphlo_verifier.h +++ b/libspu/device/pphlo/pphlo_verifier.h @@ -59,6 +59,7 @@ class PPHloVerifier { VERIFY_DECL(SignOp) VERIFY_DECL(SqrtOp) VERIFY_DECL(RoundOp) + VERIFY_DECL(RoundNearestEvenOp) // Simple binary VERIFY_DECL(AddOp) diff --git a/libspu/dialect/pphlo/ops.td b/libspu/dialect/pphlo/ops.td index 58ba702a..812736b4 100644 --- a/libspu/dialect/pphlo/ops.td +++ b/libspu/dialect/pphlo/ops.td @@ -274,6 +274,20 @@ def PPHLO_RoundOp }]; } +def PPHLO_RoundNearestEvenOp: PPHLO_UnaryElementwiseOpWithTypeInfer<"round_nearest_even", + [SameOperandsAndResultType], PPHLO_FpTensor> { + let summary = "RoundNearestEven operation"; + let description = [{ + Performs element-wise rounding towards the nearest integer, breaking ties + towards the even integer, on the `operand` tensor and produces a `result` + tensor. + + Ref: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#round_nearest_even + ``` + }]; +} + def PPHLO_RsqrtOp : PPHLO_UnaryElementwiseOpWithTypeInfer<"rsqrt", [SameOperandsAndResultType], PPHLO_FpTensor> { let summary = "Reciprocal of square-root operator"; diff --git a/libspu/kernel/hal/polymorphic.cc b/libspu/kernel/hal/polymorphic.cc index f63dacbe..53eea1c7 100644 --- a/libspu/kernel/hal/polymorphic.cc +++ b/libspu/kernel/hal/polymorphic.cc @@ -330,11 +330,59 @@ Value min(SPUContext* ctx, const Value& x, const Value& y) { Value power(SPUContext* ctx, const Value& x, const Value& y) { SPU_TRACE_HAL_DISP(ctx, x, y); - if (x.isInt() || y.isInt()) { - auto x_f = dtype_cast(ctx, x, DT_F32); - auto y_f = dtype_cast(ctx, y, DT_F32); - auto ret = power(ctx, x_f, y_f); - return ret; + if (x.isInt()) { + // ref: + // https://github.com/openxla/stablehlo/blob/main/stablehlo/reference/Element.cpp#L912 + // Although there are some "strange" semantics in stablehlo, we still follow + // them yet: + // 1. when x is int, then the return value must be int type. + // 2. if x is int, then y must be int + // 3. if x is int and y<0, then + // a. when |x|!=1, then always return 0; + // b. when |x|=1, then y=|y|; + // + // However, for jax.numpy.power, it behaves differently: + // 1. if any x or y is float, then both x and y will be upcast to float. + // 2. if both x and y are int, then y must be non-negative. + SPU_ENFORCE(y.isInt(), "when base is int, then y must be int."); + auto k0 = _constant(ctx, 0, x.shape()); + auto k1 = _constant(ctx, 1, x.shape()); + const auto bit_width = SizeOf(ctx->getField()) * 8; + + auto y_b = _prefer_b(ctx, y); + auto msb_y = _rshift(ctx, y_b, bit_width - 1); + auto x_abs1 = _equal(ctx, abs(ctx, x), k1); + + auto ret = _constant(ctx, 1, x.shape()); + // To compute ret = x^y, + // although y has `bit_width` bits, we only consider `y_bits` bits here. + // The reason are two folds (recall that both x and y are int): + // 1. if |x|>1, then `ret` will OVERFLOW/UNDERFLOW if y>63 (e.g. FM64), + // which means the valid bits of y can't exceed `log(bit_width - 1)` . + // 2. if |x|=1: + // a). x=1, then we always get `ret`=1; + // b). x=-1, then the sign of `ret` is decided on the LSB of y; + // So we can "truncate" y to `y_bits` bits safely. + const size_t y_bits = Log2Ceil(bit_width - 1); + + auto base = x; + // TODO: do this in parallel + // To compute x^y, it is necessary to compute all x^(2^idx), we use base + // (init as `x`) to store it, update base to base*base till last + // iteration, and multiply all these numbers according to y_{idx}. + // e.g. y=0101, then ret = (x) * (1) * (x^(2^2)) * (1) = x^5 + for (size_t idx = 0; idx < y_bits; idx++) { + // x^(2^idx) * y_{idx} + auto cur_pow = _mux(ctx, _and(ctx, _rshift(ctx, y_b, idx), k1), base, k1); + ret = _mul(ctx, cur_pow, ret); + if (idx < y_bits - 1) { + base = _mul(ctx, base, base); + } + } + + // when x=-1 and y<0, we can still get a correct result + return _mux(ctx, _and(ctx, msb_y, _not(ctx, x_abs1)), k0, ret) + .setDtype(x.dtype()); } if (x.isPublic() && y.isPublic()) { return f_pow_p(ctx, x, y); diff --git a/libspu/kernel/hal/polymorphic_test.cc b/libspu/kernel/hal/polymorphic_test.cc index e4a7014e..fc64d9cc 100644 --- a/libspu/kernel/hal/polymorphic_test.cc +++ b/libspu/kernel/hal/polymorphic_test.cc @@ -406,7 +406,11 @@ TYPED_TEST(MathTest, Pow) { using LHS_VT = typename std::tuple_element<1, TypeParam>::type; using RHS_DT = typename std::tuple_element<2, TypeParam>::type; using RHS_VT = typename std::tuple_element<3, TypeParam>::type; - // using RES_DT = typename std::tuple_element<4, TypeParam>::type; + using RES_DT = typename std::tuple_element<4, TypeParam>::type; + + if constexpr (!std::is_same_v) { + return; + } // GIVEN xt::xarray x; @@ -414,10 +418,10 @@ TYPED_TEST(MathTest, Pow) { { // random test x = test::xt_random({5, 6}, 0, 100); - y = test::xt_random({5, 6}, -2, 2); + y = test::xt_random({5, 6}, 0, 2); // WHAT - auto z = test::evalBinaryOp(LHS_VT(), RHS_VT(), power, x, y); + auto z = test::evalBinaryOp(LHS_VT(), RHS_VT(), power, x, y); // THEN auto expected = xt::pow(x, y); @@ -429,14 +433,17 @@ TYPED_TEST(MathTest, Pow) { { // some fixed corner case - x = {-1, -1, -3, 1, -3, 0, 1, 1, 5, 0}; - y = {1, 0, -3, -3, 3, 0, 0, 2, 5, 2}; + x = {-1, -1, -1, -1, -3, 1, -3, 0, 1, 1, 5, 0, 3, 2, -2}; + y = {1, 0, -3, -4, -3, -3, 3, 0, 0, 2, 5, 2, -3, -1, -1}; // WHAT - auto z = test::evalBinaryOp(LHS_VT(), RHS_VT(), power, x, y); + auto z = test::evalBinaryOp(LHS_VT(), RHS_VT(), power, x, y); // THEN - auto expected = xt::pow(x, y); + // when x is int and x=-3, y=-3, we should get 0. + // when x is int and x=3, y=-3, we should get 0. + xt::xarray expected = xt::pow(x, y); + EXPECT_TRUE(xt::allclose(expected, z, 0.3, 0.03)) << x << std::endl << y << std::endl << expected << std::endl diff --git a/libspu/kernel/hlo/basic_unary.cc b/libspu/kernel/hlo/basic_unary.cc index f819cf0c..2b7d8d64 100644 --- a/libspu/kernel/hlo/basic_unary.cc +++ b/libspu/kernel/hlo/basic_unary.cc @@ -17,6 +17,7 @@ #include "libspu/kernel/hal/complex.h" #include "libspu/kernel/hal/constants.h" #include "libspu/kernel/hal/polymorphic.h" +#include "libspu/kernel/hal/ring.h" #include "libspu/kernel/hal/type_cast.h" namespace spu::kernel::hlo { @@ -101,4 +102,43 @@ spu::Value Round_AFZ(SPUContext *ctx, const spu::Value &in) { return hal::dtype_cast(ctx, hal::dtype_cast(ctx, round, DT_I64), in.dtype()); } +spu::Value Round_RNTE(SPUContext *ctx, const spu::Value &in) { + // RNTE: Round to nearest, ties to even + // let x^' = *****a.b##### be origin fxp number + // x = *****a.bc ( c = reduce_or(#####) ), y = *****a + // then ret = y + comp (comp = 0 or 1), where + // 1) if b=0, then comp=0 + // 2) if b=1, c=1, then comp=1 + // 3) if b=1, c=0, a=1, then comp=1 + // 4) if b=1, c=0, a=0, then comp=0 + // so comp = b && (c || a) + SPU_ENFORCE(!in.isComplex()); + SPU_ENFORCE(in.isFxp(), "Round only supports fxp"); + const auto fxp_bits = ctx->getFxpBits(); + const auto k1 = hal::_constant(ctx, 1U, in.shape()); + + auto x_prime = hal::_prefer_b(ctx, in); + auto y = hal::floor(ctx, x_prime); + + auto a = hal::_and(ctx, hal::_rshift(ctx, x_prime, fxp_bits), k1); + auto b = hal::_and(ctx, hal::_rshift(ctx, x_prime, fxp_bits - 1), k1); + + std::vector cs; + cs.reserve(fxp_bits - 1); + for (size_t idx = 0; idx < fxp_bits - 1; idx++) { + auto x_ = hal::_and(ctx, hal::_rshift(ctx, x_prime, idx), k1); + cs.push_back(std::move(x_)); + } + auto c = vreduce(cs.begin(), cs.end(), [&](const Value &a, const Value &b) { + return hal::_or(ctx, a, b); + }); + auto comp = hal::_and(ctx, b, hal::_or(ctx, c, a)); + // set nbits to improve b2a + if (comp.storage_type().isa()) { + const_cast(comp.storage_type()).as()->setNbits(1); + } + + return hal::add(ctx, y, comp.setDtype(DT_I64)).setDtype(in.dtype()); +} + } // namespace spu::kernel::hlo diff --git a/libspu/kernel/hlo/basic_unary.h b/libspu/kernel/hlo/basic_unary.h index 46578ee7..c5770cd7 100644 --- a/libspu/kernel/hlo/basic_unary.h +++ b/libspu/kernel/hlo/basic_unary.h @@ -45,6 +45,7 @@ SIMPLE_UNARY_KERNEL_DECL(Sign) SIMPLE_UNARY_KERNEL_DECL(Round_AFZ) SIMPLE_UNARY_KERNEL_DECL(Real) SIMPLE_UNARY_KERNEL_DECL(Imag) +SIMPLE_UNARY_KERNEL_DECL(Round_RNTE) #undef SIMPLE_UNARY_KERNEL_DECL diff --git a/libspu/kernel/hlo/basic_unary_test.cc b/libspu/kernel/hlo/basic_unary_test.cc index 7e4d2e0a..e0ebe466 100644 --- a/libspu/kernel/hlo/basic_unary_test.cc +++ b/libspu/kernel/hlo/basic_unary_test.cc @@ -58,6 +58,7 @@ UNARY_EMPTY_TEST(Rsqrt) UNARY_EMPTY_TEST(Sqrt) UNARY_EMPTY_TEST(Sign) UNARY_EMPTY_TEST(Round_AFZ) +UNARY_EMPTY_TEST(Round_RNTE) INSTANTIATE_TEST_SUITE_P( UnaryTestInstances, UnaryTest, diff --git a/libspu/mpc/cheetah/arith/BUILD.bazel b/libspu/mpc/cheetah/arith/BUILD.bazel index 9842fad0..c2a77950 100644 --- a/libspu/mpc/cheetah/arith/BUILD.bazel +++ b/libspu/mpc/cheetah/arith/BUILD.bazel @@ -49,6 +49,7 @@ spu_cc_library( hdrs = ["matmat_prot.h"], deps = [ ":arith_comm", + "//libspu/mpc/cheetah/rlwe:lwe", ], ) diff --git a/libspu/mpc/cheetah/arith/cheetah_dot.cc b/libspu/mpc/cheetah/arith/cheetah_dot.cc index 67c0a62d..92c40e46 100644 --- a/libspu/mpc/cheetah/arith/cheetah_dot.cc +++ b/libspu/mpc/cheetah/arith/cheetah_dot.cc @@ -36,6 +36,7 @@ #include "libspu/mpc/cheetah/arith/common.h" #include "libspu/mpc/cheetah/arith/matmat_prot.h" +#include "libspu/mpc/cheetah/rlwe/lwe_ct.h" #include "libspu/mpc/cheetah/rlwe/modswitch_helper.h" #include "libspu/mpc/cheetah/rlwe/packlwes.h" #include "libspu/mpc/cheetah/rlwe/utils.h" @@ -43,13 +44,26 @@ namespace spu::mpc::cheetah { +enum class CipherPackingType { + lwes, + rlwes, + none, +}; + +static std::string ToString(CipherPackingType type) { + switch (type) { + case CipherPackingType::rlwes: + return "interleave"; + case CipherPackingType::lwes: + return "pack_lwes"; + default: + case CipherPackingType::none: + return "none"; + } +} + struct CheetahDot::Impl : public EnableCPRNG { public: - enum class CipherPackingType { - rlwes, - none, - }; - const bool kUseModDownOptimization = true; static constexpr size_t kCtAsyncParallel = 16; @@ -97,7 +111,14 @@ struct CheetahDot::Impl : public EnableCPRNG { yacl::link::Context *conn, size_t bytes_recv); + bool IsPackingEnabled(uint32_t ring_bitlen) const { + return DecideSEALParameters(ring_bitlen).use_special_prime(); + } + seal::EncryptionParameters DecideSEALParameters(uint32_t ring_bitlen) const { + auto scheme_type = seal::scheme_type::ckks; + auto parms = seal::EncryptionParameters(scheme_type); + size_t poly_deg; std::vector modulus_bits; // NOTE(lwj): we need Q=sum(modulus_bits) > 2*k for multiplying two @@ -112,20 +133,19 @@ struct CheetahDot::Impl : public EnableCPRNG { poly_deg = 4096; // ~ 64 + 32 bit modulus_bits = {59, 37}; + parms.set_use_special_prime(false); } else if (ring_bitlen <= 64) { poly_deg = 8192; // ~ 128 + 32 bit - modulus_bits = {59, 55, 49}; + modulus_bits = {59, 55, 49, 49}; + parms.set_use_special_prime(true); } else { poly_deg = 16384; // ~ 256 + 30 bit - modulus_bits = {59, 59, 59, 59, 49}; + modulus_bits = {59, 59, 59, 59, 49, 49}; + parms.set_use_special_prime(true); } - auto scheme_type = seal::scheme_type::ckks; - auto parms = seal::EncryptionParameters(scheme_type); - - parms.set_use_special_prime(false); parms.set_poly_modulus_degree(poly_deg); parms.set_coeff_modulus(seal::CoeffModulus::Create(poly_deg, modulus_bits)); return parms; @@ -183,7 +203,6 @@ struct CheetahDot::Impl : public EnableCPRNG { // field_bitlen -> functor mapping std::unordered_map> seal_cntxts_; - std::unordered_map galoi_cntxts_; std::unordered_map> secret_keys_; std::unordered_map> peer_pub_keys_; std::unordered_map> @@ -196,7 +215,7 @@ struct CheetahDot::Impl : public EnableCPRNG { }; void CheetahDot::Impl::LazyInitGaloisKey(size_t field_bitlen) { - if (galoi_cntxts_.find(field_bitlen) != galoi_cntxts_.end()) { + if (peer_galois_keys_.find(field_bitlen) != peer_galois_keys_.end()) { return; } auto kv = seal_cntxts_.find(field_bitlen); @@ -205,13 +224,8 @@ void CheetahDot::Impl::LazyInitGaloisKey(size_t field_bitlen) { const auto &this_rlwe_sk = *secret_keys_.find(field_bitlen)->second; seal::GaloisKeys gk; - auto gk_parms = this_context.key_context_data()->parms(); - gk_parms.set_use_special_prime(true); - - seal::SEALContext gk_context(gk_parms, true, seal::sec_level_type::none); - GenerateGaloisKeyForPacking(gk_context, this_rlwe_sk, + GenerateGaloisKeyForPacking(this_context, this_rlwe_sk, /*seed*/ true, &gk); - galoi_cntxts_.emplace(field_bitlen, gk_context); auto gk_buf = EncodeSEALObject(gk); int nxt_rank = lctx_->NextRank(); @@ -259,8 +273,8 @@ void CheetahDot::Impl::LazyInit(size_t field_bitlen, bool need_galois_keys) { DecodeSEALObject(recv_pk, *this_context, peer_public_key.get()); } - auto modulus = parms.coeff_modulus(); - size_t ecd_modulus_sze = modulus.size(); + auto modulus = this_context->first_context_data()->parms().coeff_modulus(); + size_t enc_modulus_sze = modulus.size(); parms.set_coeff_modulus(modulus); seal::SEALContext ecd_ms_context(parms, false, seal::sec_level_type::none); @@ -273,6 +287,7 @@ void CheetahDot::Impl::LazyInit(size_t field_bitlen, bool need_galois_keys) { } size_t dcd_modulus_sze = modulus.size(); + parms.set_use_special_prime(false); parms.set_coeff_modulus(modulus); seal::SEALContext dcd_ms_context(parms, false, seal::sec_level_type::none); @@ -294,7 +309,7 @@ void CheetahDot::Impl::LazyInit(size_t field_bitlen, bool need_galois_keys) { if (lctx_->Rank() == 0) { SPDLOG_INFO( "CheetahDot uses {}@{} modulus {} degree for {} bit ring (packing={})", - ecd_modulus_sze, dcd_modulus_sze, parms.poly_modulus_degree(), + enc_modulus_sze, dcd_modulus_sze, parms.poly_modulus_degree(), field_bitlen, need_galois_keys ? "enabled" : "disabled"); } } @@ -375,21 +390,20 @@ NdArrayRef CheetahDot::Impl::doDotOLEReceiverSendStep( size_t out_n = ct_array_to_pack.size(); size_t num_ct_response = 0; - double pack_time = 0.0; + yacl::ElapsedTimer pack_timer; if (cptype == CipherPackingType::none) { SPU_ENFORCE(batch_size == 1, "impossible dispatch here"); num_ct_response = out_n; - } else { - yacl::ElapsedTimer _timer; - const auto &this_galois_context = galoi_cntxts_.find(field_bitlen)->second; + } else if (cptype == CipherPackingType::rlwes) { + // BumbleBee's interleave const auto &this_galois_key = *(peer_galois_keys_.find(field_bitlen)->second); const size_t gap = subshape[1]; const size_t pack_stride = gap; - PackingHelper pack_helper(gap, this_galois_key, this_galois_context, - this_context); + PackingHelper pack_helper(gap, this_ecd_msh.coeff_modulus_size(), + this_galois_key, this_context); for (size_t i = 0; i < out_n; i += pack_stride) { size_t this_batch = std::min(out_n - i, pack_stride); @@ -398,10 +412,48 @@ NdArrayRef CheetahDot::Impl::doDotOLEReceiverSendStep( ct_array_to_pack.subspan(i, this_batch), ct_array_to_pack[packed_idx]); } - pack_time = _timer.CountMs(); num_ct_response = CeilDiv(out_n, pack_stride); + } else { + // Chen Hao et al's PackLWEs + SPU_ENFORCE(batch_size == 1, "not implemented yet"); + const auto &this_galois_key = + *(peer_galois_keys_.find(field_bitlen)->second); + + // Drop some modulus before packing + yacl::parallel_for( + 0, ct_array_to_pack.size(), [&](int64_t bgn, int64_t end) { + for (int64_t i = bgn; i < end; ++i) { + InvNttInplace(ct_array_to_pack[i], this_context); + ModulusSwtichInplace(ct_array_to_pack[i], + this_dcd_msh.coeff_modulus_size(), + this_context); + } + }); + + // Extract Phantom LWECt from RLWEs + size_t out_numel = meta.dims[0] * meta.dims[2]; + size_t aligned_onumel = absl::bit_ceil(out_numel); + std::vector _lwes(aligned_onumel); + auto lwes = absl::MakeSpan(_lwes); + matmat_prot.WrapPhantomLWEs(meta, ct_array_to_pack, + lwes.subspan(0, out_numel)); + + size_t pack_stride = lwes[0].poly_modulus_degree(); + std::vector packed_lwes(CeilDiv(lwes.size(), pack_stride)); + + // NOTE: we can not modify the ct_array_to_pack inplace because it wraps the + // PhantomLWECt + PackLWEs(lwes, this_galois_key, this_context, absl::MakeSpan(packed_lwes)); + + num_ct_response = packed_lwes.size(); + + // copy back + for (size_t i = 0; i < num_ct_response; ++i) { + ct_array_to_pack[i] = packed_lwes[i]; + } } + double pack_time = pack_timer.CountMs(); // 4. Random masking to conver HE to AShr std::vector rnd_polys(num_ct_response); @@ -420,20 +472,21 @@ NdArrayRef CheetahDot::Impl::doDotOLEReceiverSendStep( } bytes_sent = conn->GetStats()->sent_bytes - bytes_sent; - if (conn->Rank() == 0) { - SPDLOG_INFO( - "{}@{}x{}x{} => {}x{}x{} Recv {} MiB, Response {} MiB Pack {} ms", - batch_size, meta.dims[0], meta.dims[1], meta.dims[2], subshape[0], - subshape[1], subshape[2], - std::roundf(bytes_recv / 1024. / 1024. * 1000) / 1000., - std::roundf(bytes_sent / 1024. / 1024. * 1000) / 1000., - std::roundf(pack_time * 1000) / 1000.); - } + SPDLOG_INFO( + "{}@{}x{}x{} => {}x{}x{} Recv {} MiB, Response {} MiB Pack {} ms ({})", + batch_size, meta.dims[0], meta.dims[1], meta.dims[2], subshape[0], + subshape[1], subshape[2], + std::roundf(bytes_recv / 1024. / 1024. * 1000) / 1000., + std::roundf(bytes_sent / 1024. / 1024. * 1000) / 1000., + std::roundf(pack_time * 1000) / 1000., ToString(cptype)); switch (cptype) { case CipherPackingType::none: return matmat_prot.ParseResult( field, meta, absl::MakeConstSpan(rnd_polys), this_dcd_msh); + case CipherPackingType::lwes: + return matmat_prot.ParsePackLWEsResult( + field, meta, absl::MakeConstSpan(rnd_polys), this_dcd_msh); case CipherPackingType::rlwes: default: return matmat_prot.ParseBatchPackedResult(field, batch_size, meta, @@ -539,6 +592,9 @@ NdArrayRef CheetahDot::Impl::doDotOLESenderRecvStep(FieldType field, case CipherPackingType::none: return matmat_prot.ParseResult( field, meta, absl::MakeConstSpan(result_poly), this_dcd_msh); + case CipherPackingType::lwes: + return matmat_prot.ParsePackLWEsResult( + field, meta, absl::MakeConstSpan(result_poly), this_dcd_msh); case CipherPackingType::rlwes: default: return matmat_prot.ParseBatchPackedResult( @@ -583,8 +639,9 @@ NdArrayRef CheetahDot::Impl::BatchDotOLE(const NdArrayRef &prv_mat, SPU_ENFORCE_EQ(prv_mat.numel(), dim4[0] * dim4[2] * dim4[3]); } - if (eltype.template as()->field() == FM32) { - // FM32 not supportting Packing. + auto field = eltype.template as()->field(); + if (not IsPackingEnabled(8 * SizeOf(field))) { + // Packing is not supported. // Thus just call multiple DotOLEs Shape3D dim3 = {dim4[1], dim4[2], dim4[3]}; int64_t out_numel = dim4[1] * dim4[3]; @@ -616,7 +673,6 @@ NdArrayRef CheetahDot::Impl::doBatchDotOLE(const NdArrayRef &prv_mat, bool is_self_lhs) { auto eltype = prv_mat.eltype(); auto field = eltype.template as()->field(); - SPU_ENFORCE(field != FM32, "Not support BatchDotOLE for FM32"); const size_t field_bitlen = SizeOf(field) * 8; size_t poly_deg = DecideSEALParameters(field_bitlen).poly_modulus_degree(); @@ -691,9 +747,11 @@ NdArrayRef CheetahDot::Impl::doDotOLE(const NdArrayRef &prv_mat, MatMatProtocol::Meta meta = {.dims = dim3}; // No cipher packing for small HE - CipherPackingType cptype = (field == FM32 || disable_pack_) - ? CipherPackingType::none - : CipherPackingType::rlwes; + CipherPackingType cptype = + (IsPackingEnabled(8 * SizeOf(field)) and not disable_pack_) + ? CipherPackingType::rlwes + : CipherPackingType::none; + Shape3D subshape; size_t blk[3]; if (cptype != CipherPackingType::none) { @@ -702,9 +760,18 @@ NdArrayRef CheetahDot::Impl::doDotOLE(const NdArrayRef &prv_mat, for (int i : {0, 1, 2}) { blk[i] = CeilDiv(meta.dims[i], subshape[i]); } - // If there is only 1 resultant RLWE; then we just skip any packing - cptype = blk[0] * blk[2] <= 1 ? CipherPackingType::none - : CipherPackingType::rlwes; + + if (blk[0] * blk[2] <= 1) { + // If there is only 1 resultant RLWE; + // then we just skip any packing + cptype = CipherPackingType::none; + } else { + // dynamic packing type + double pack_rlwes_cost = subshape[1]; + double pack_lwes_cost = meta.dims[0] * meta.dims[2]; + cptype = pack_rlwes_cost < pack_lwes_cost ? CipherPackingType::rlwes + : CipherPackingType::lwes; + } } LazyInit(field_bitlen, cptype != CipherPackingType::none); @@ -726,10 +793,17 @@ NdArrayRef CheetahDot::Impl::doDotOLE(const NdArrayRef &prv_mat, doDotOLESenderSendStep(prv_mat, dim3, is_self_lhs, cptype, conn); size_t num_ct_to_recv = 0; - if (cptype == CipherPackingType::rlwes) { - num_ct_to_recv = CeilDiv(blk[0] * blk[2], subshape[1]); - } else { - num_ct_to_recv = blk[0] * blk[2]; + switch (cptype) { + case CipherPackingType::rlwes: + num_ct_to_recv = CeilDiv(blk[0] * blk[2], subshape[1]); + break; + case CipherPackingType::lwes: + num_ct_to_recv = CeilDiv(meta.dims[0] * meta.dims[2], poly_deg); + break; + default: + case CipherPackingType::none: + num_ct_to_recv = blk[0] * blk[2]; + break; } return doDotOLESenderRecvStep(field, /*batch*/ 1, meta, num_ct_to_recv, @@ -760,7 +834,8 @@ CheetahDot::~CheetahDot() = default; void CheetahDot::LazyInitKeys(FieldType field) { SPU_ENFORCE(impl_ != nullptr); - return impl_->LazyInit(SizeOf(field) * 8, /*create_galois*/ true); + return impl_->LazyInit(SizeOf(field) * 8, + impl_->IsPackingEnabled(SizeOf(field) * 8)); } NdArrayRef CheetahDot::DotOLE(const NdArrayRef &inp, yacl::link::Context *conn, diff --git a/libspu/mpc/cheetah/arith/cheetah_dot_test.cc b/libspu/mpc/cheetah/arith/cheetah_dot_test.cc index 7d083411..8f0e5ac1 100644 --- a/libspu/mpc/cheetah/arith/cheetah_dot_test.cc +++ b/libspu/mpc/cheetah/arith/cheetah_dot_test.cc @@ -31,6 +31,7 @@ INSTANTIATE_TEST_SUITE_P( testing::Values(Shape3D{8, 7, 5}, Shape3D{57, 30, 1}, Shape3D{30, 57, 1}, Shape3D{18, 8, 41}, Shape3D{500, 13, 25}, + Shape3D{1, 20480, 768}, Shape3D{18, 768, 78})), [](const testing::TestParamInfo& p) { return fmt::format("{}x{}x{}x{}", std::get<0>(std::get<1>(p.param)), diff --git a/libspu/mpc/cheetah/arith/matmat_prot.cc b/libspu/mpc/cheetah/arith/matmat_prot.cc index 8e49b1b0..a56e522f 100644 --- a/libspu/mpc/cheetah/arith/matmat_prot.cc +++ b/libspu/mpc/cheetah/arith/matmat_prot.cc @@ -25,6 +25,7 @@ #include "yacl/utils/platform_utils.h" #include "libspu/mpc/cheetah/arith/vector_encoder.h" +#include "libspu/mpc/cheetah/rlwe/lwe_ct.h" #include "libspu/mpc/cheetah/rlwe/utils.h" #include "libspu/mpc/utils/ring_ops.h" @@ -521,6 +522,87 @@ NdArrayRef MatMatProtocol::ParseResult( return ParseResult(field, meta, ans_poly, msh_); } +NdArrayRef MatMatProtocol::ParsePackLWEsResult( + FieldType field, const Meta& meta, absl::Span ans_poly, + const ModulusSwitchHelper& msh) const { + const size_t packing_width = poly_deg_; + const size_t num_total_vals = meta.dims[0] * meta.dims[2]; + SPU_ENFORCE_EQ(ans_poly.size(), CeilDiv(num_total_vals, packing_width)); + + std::vector decoded_vectors(ans_poly.size()); + for (size_t i = 0; i < ans_poly.size(); ++i) { + decoded_vectors[i] = + msh.ModulusDownRNS(field, {(int64_t)packing_width}, + {ans_poly[i].data(), ans_poly[i].coeff_count()}); + } + + NdArrayRef matmat = ring_zeros(field, {meta.dims[0] * meta.dims[2]}); + + DISPATCH_ALL_FIELDS(field, "pack_lwes_results", [&]() { + NdArrayView xmatmat(matmat); + + for (size_t i = 0; i < ans_poly.size(); ++i) { + NdArrayView decoded_vec(decoded_vectors[i]); + auto coeff_bgn = std::min(i * poly_deg_, num_total_vals); + auto coeff_end = std::min(coeff_bgn + poly_deg_, num_total_vals); + + size_t num_coeff = coeff_end - coeff_bgn; + size_t aligned_num_coeff = absl::bit_ceil(num_coeff); + size_t packed_gap = packing_width / aligned_num_coeff; + + for (size_t j = coeff_bgn; j < coeff_end; ++j) { + size_t row = j / meta.dims[2]; + size_t col = j % meta.dims[2]; + xmatmat[row * meta.dims[2] + col] = + decoded_vec[(j - coeff_bgn) * packed_gap]; + } + } + }); + + return matmat.reshape({meta.dims[0], meta.dims[2]}); +} + +void MatMatProtocol::WrapPhantomLWEs(const Meta& meta, + absl::Span rlwes, + absl::Span lwes) const { + auto subdims = GetSubMatShape(meta); + size_t num_rlwes = GetOutSize(meta, subdims); + size_t num_lwes = meta.dims[0] * meta.dims[2]; + SPU_ENFORCE_EQ(rlwes.size(), num_rlwes, "expected {} got {}", num_rlwes, + rlwes.size()); + SPU_ENFORCE_EQ(lwes.size(), num_lwes, "expected {} got {}", num_lwes, + lwes.size()); + + ResultIndexer ans_indexer(subdims); + Shape2D out_blks = {CeilDiv(meta.dims[0], subdims[0]), + CeilDiv(meta.dims[2], subdims[2])}; + + for (int64_t rblk = 0; rblk < out_blks[0]; ++rblk) { + for (int64_t cblk = 0; cblk < out_blks[1]; ++cblk) { + const int64_t row_start = rblk * subdims[0]; + const int64_t row_end = std::min(row_start + subdims[0], meta.dims[0]); + const int64_t row_ext = row_end - row_start; + + const int64_t col_start = cblk * subdims[2]; + const int64_t col_end = std::min(col_start + subdims[2], meta.dims[2]); + const int64_t col_ext = col_end - col_start; + + const int64_t rlwe_idx = rblk * out_blks[1] + cblk; + const RLWECt& this_rlwe = rlwes.at(rlwe_idx); + SPU_ENFORCE(not this_rlwe.is_ntt_form()); + + for (int64_t r = 0; r < row_ext; ++r) { + int64_t lwes_row = row_start + r; + for (int64_t c = 0; c < col_ext; ++c) { + int64_t lwes_col = col_start + c; + int64_t lwe_idx = lwes_row * meta.dims[2] + lwes_col; + lwes[lwe_idx].WrapIt(this_rlwe, ans_indexer.get(r, c)); + } + } + } + } +} + void MatMatProtocol::ExtractLWEsInplace(const Meta& meta, absl::Span out) const { auto subdims = GetSubMatShape(meta); @@ -648,4 +730,4 @@ void MatMatProtocol::Compute(absl::Span lhs_mat, DoCompute(lhs_mat, rhs_mat, meta, out_mat); } -} // namespace spu::mpc::cheetah \ No newline at end of file +} // namespace spu::mpc::cheetah diff --git a/libspu/mpc/cheetah/arith/matmat_prot.h b/libspu/mpc/cheetah/arith/matmat_prot.h index 38e6874e..4dce4aa8 100644 --- a/libspu/mpc/cheetah/arith/matmat_prot.h +++ b/libspu/mpc/cheetah/arith/matmat_prot.h @@ -61,28 +61,42 @@ class MatMatProtocol { bool IsValidMeta(const Meta& meta) const; + // Parse the polynomails into matmul result without any packing NdArrayRef ParseResult(FieldType field, const Meta& meta, absl::Span ans_poly) const; + // Parse the polynomails into matmul result without any packing + // Use the specified ModulusSwitchHelper. NdArrayRef ParseResult(FieldType field, const Meta& meta, absl::Span ans_poly, const ModulusSwitchHelper& msh) const; - // Coefficients via Packed Batched MatMul + // Parse the polynomails into matmul result which are packed using BumbleBee's + // interleave packing. + // output shape dims[0] x dims[2] + NdArrayRef ParsePackedResult(FieldType field, const Meta& meta, + absl::Span ans_poly, + const ModulusSwitchHelper& msh) const; + + // Parse the polynomails into matmul result which are packed using BumbleBee's + // interleave packing in a batch mode. // output shape batch_size x dims[0] x dims[2] NdArrayRef ParseBatchPackedResult(FieldType field, size_t batch_size, const Meta& meta, absl::Span polys, const ModulusSwitchHelper& msh) const; - // Coefficients via Packed MatMul - // output shape dims[0] x dims[2] - NdArrayRef ParsePackedResult(FieldType field, const Meta& meta, - absl::Span ans_poly, - const ModulusSwitchHelper& msh) const; + // Parse the polynomails into matmul result which are packed using Chen hao's + // PackLWEs packing. + NdArrayRef ParsePackLWEsResult(FieldType field, const Meta& meta, + absl::Span ans_poly, + const ModulusSwitchHelper& msh) const; void ExtractLWEsInplace(const Meta& meta, absl::Span rlwe) const; + void WrapPhantomLWEs(const Meta& meta, absl::Span rlwes, + absl::Span lwes) const; + // LHS_mat * RHS_mat // LHS = RLWECt, RHS = RLWEPt (when LHS is smaller) // LHS = RLWEPt, RHS = RLWECt (when RHS is smaller) @@ -125,4 +139,4 @@ class MatMatProtocol { std::unique_ptr vencoder_{nullptr}; }; -} // namespace spu::mpc::cheetah \ No newline at end of file +} // namespace spu::mpc::cheetah diff --git a/libspu/mpc/cheetah/arithmetic.cc b/libspu/mpc/cheetah/arithmetic.cc index db898411..a59f8c53 100644 --- a/libspu/mpc/cheetah/arithmetic.cc +++ b/libspu/mpc/cheetah/arithmetic.cc @@ -249,6 +249,26 @@ NdArrayRef MulAA::mulDirectly(KernelEvalContext* ctx, const NdArrayRef& x, return ring_add(x0y1, ring_add(x1y0, ring_mul(x, y))).as(x.eltype()); } +NdArrayRef MatMulVVS::proc(KernelEvalContext* ctx, const NdArrayRef& x, + const NdArrayRef& y) const { + auto out_type = makeType(ctx->sctx()->getField()); + if (0 == x.numel() || 0 == y.numel()) { + return NdArrayRef(out_type, {x.shape()[0], y.shape()[1]}); + } + auto* comm = ctx->getState(); + auto* dot_prot = ctx->getState()->get(); + + const int self_rank = comm->getRank(); + auto lhs_owner = x.eltype().as()->owner(); + + const Shape3D dim3 = {x.shape()[0], x.shape()[1], y.shape()[1]}; + if (self_rank == lhs_owner) { + return dot_prot->DotOLE(x, dim3, /*is_lhs*/ true).as(out_type); + } else { + return dot_prot->DotOLE(y, dim3, /*is_lhs*/ false).as(out_type); + } +} + // A is (M, K); B is (K, N) NdArrayRef MatMulAA::proc(KernelEvalContext* ctx, const NdArrayRef& x, const NdArrayRef& y) const { diff --git a/libspu/mpc/cheetah/arithmetic.h b/libspu/mpc/cheetah/arithmetic.h index 3af3713f..21c8a004 100644 --- a/libspu/mpc/cheetah/arithmetic.h +++ b/libspu/mpc/cheetah/arithmetic.h @@ -166,6 +166,17 @@ class MatMulAV : public MatmulKernel { const NdArrayRef& y) const override; }; +class MatMulVVS : public MatmulKernel { + public: + static constexpr char kBindName[] = "mmul_vvs"; + + Kind kind() const override { return Kind::Dynamic; } + // LHS: m x k + // RHS: k x n + NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& x, + const NdArrayRef& y) const override; +}; + class MatMulAA : public MatmulKernel { public: static constexpr char kBindName[] = "mmul_aa"; @@ -275,4 +286,4 @@ class LessPA : public BinaryKernel { const NdArrayRef& y) const override; }; -} // namespace spu::mpc::cheetah \ No newline at end of file +} // namespace spu::mpc::cheetah diff --git a/libspu/mpc/cheetah/protocol.cc b/libspu/mpc/cheetah/protocol.cc index 995b7f6e..ce4f90af 100644 --- a/libspu/mpc/cheetah/protocol.cc +++ b/libspu/mpc/cheetah/protocol.cc @@ -68,6 +68,7 @@ void regCheetahProtocol(SPUContext* ctx, cheetah::MulAP, cheetah::MulAA, cheetah::MulA1B, // cheetah::EqualAA, cheetah::EqualAP, // cheetah::MatMulAP, cheetah::MatMulAA, cheetah::MatMulAV, // + cheetah::MatMulVVS, // cheetah::LShiftA, cheetah::ARShiftB, cheetah::LShiftB, cheetah::RShiftB, // cheetah::BitrevB, // diff --git a/libspu/mpc/cheetah/rlwe/packlwes.cc b/libspu/mpc/cheetah/rlwe/packlwes.cc index 85602278..8e486bbf 100644 --- a/libspu/mpc/cheetah/rlwe/packlwes.cc +++ b/libspu/mpc/cheetah/rlwe/packlwes.cc @@ -36,25 +36,21 @@ namespace spu::mpc::cheetah { static void NegacyclicRightShiftInplace(RLWECt &ct, size_t shift, const seal::SEALContext &context); -PackingHelper::PackingHelper(size_t gap, const seal::GaloisKeys &galois_keys, - const seal::SEALContext &gk_context, +PackingHelper::PackingHelper(size_t gap, size_t num_modulus_for_packing, + const seal::GaloisKeys &galois_keys, const seal::SEALContext &context) : gap_(gap), + num_modulus_for_packing_(num_modulus_for_packing), galois_keys_(galois_keys), - gk_context_(gk_context), context_(context) { - SPU_ENFORCE(gk_context_.parameters_set()); - SPU_ENFORCE(seal::is_metadata_valid_for(galois_keys, gk_context)); + SPU_ENFORCE(seal::is_metadata_valid_for(galois_keys, context)); SPU_ENFORCE(context_.parameters_set()); SPU_ENFORCE(gap > 0 && absl::has_single_bit(gap), "invalid gap={}", gap); + SPU_ENFORCE(num_modulus_for_packing_ > 0); + SPU_ENFORCE(num_modulus_for_packing_ <= + context_.first_context_data()->parms().coeff_modulus().size()); - // NOTE(lwj): dirty hack on SEAL's parms_id - if (context.key_parms_id() != gk_context_.key_parms_id()) { - SPU_ENFORCE_GT(context_.first_context_data()->chain_index(), - gk_context_.first_context_data()->chain_index()); - } - - auto n = gk_context.key_context_data()->parms().poly_modulus_degree(); + auto n = context.key_context_data()->parms().poly_modulus_degree(); size_t ks_level = absl::bit_width(gap) - 1; for (size_t i = 0; i < ks_level; ++i) { @@ -124,23 +120,18 @@ void PackingHelper::doPackingRLWEs(absl::Span rlwes, SPU_ENFORCE(num_ct > 0 && num_ct <= (int)gap_, fmt::format("invalid #rlwes = {} for gap = {}", num_ct, gap_)); - size_t modulus_for_keyswitch = - gk_context_.first_context_data()->chain_index() + 1; - yacl::parallel_for(0, num_ct, [&](int64_t bgn, int64_t end) { for (int64_t i = bgn; i < end; ++i) { InvNttInplace(rlwes[i], context_, true); // multiply gap^{-1} mod Q MultiplyFixedScalarInplace(rlwes[i]); // drop some modulus aiming a lighter KeySwitch - ModulusSwtichInplace(rlwes[i], modulus_for_keyswitch, context_); - // change pid to galois_context for KS - rlwes[i].parms_id() = gk_context_.first_parms_id(); + ModulusSwtichInplace(rlwes[i], num_modulus_for_packing_, context_); } }); // FFT-like method to merge RLWEs into one RLWE. - seal::Evaluator evaluator(gk_context_); + seal::Evaluator evaluator(context_); const int64_t logn = absl::bit_width(gap_) - 1; for (int64_t k = logn; k >= 1; --k) { int64_t h = 1 << (k - 1); @@ -158,7 +149,7 @@ void PackingHelper::doPackingRLWEs(absl::Span rlwes, continue; } - NegacyclicRightShiftInplace(ct_odd, h, gk_context_); + NegacyclicRightShiftInplace(ct_odd, h, context_); if (!is_even_empty) { seal::Ciphertext tmp = ct_even; @@ -184,14 +175,6 @@ void PackingHelper::doPackingRLWEs(absl::Span rlwes, SPU_ENFORCE(rlwes[0].size() > 0, fmt::format("all empty RLWEs are invalid")); out = rlwes[0]; - - out.parms_id() = [&]() -> seal::parms_id_type { - auto cntxt = context_.first_context_data(); - while ((cntxt->chain_index() + 1) > modulus_for_keyswitch) { - cntxt = cntxt->next_context_data(); - } - return cntxt->parms_id(); - }(); } void GenerateGaloisKeyForPacking(const seal::SEALContext &context, @@ -458,4 +441,4 @@ size_t PackLWEs(absl::Span lwes, const GaloisKeys &galois, return out_sze; } -} // namespace spu::mpc::cheetah \ No newline at end of file +} // namespace spu::mpc::cheetah diff --git a/libspu/mpc/cheetah/rlwe/packlwes.h b/libspu/mpc/cheetah/rlwe/packlwes.h index c9337cfb..d5a1a6db 100644 --- a/libspu/mpc/cheetah/rlwe/packlwes.h +++ b/libspu/mpc/cheetah/rlwe/packlwes.h @@ -26,8 +26,8 @@ void GenerateGaloisKeyForPacking(const seal::SEALContext &context, GaloisKeys *out); class PackingHelper { public: - PackingHelper(size_t gap, const seal::GaloisKeys &galois_keys, - const seal::SEALContext &gk_context, + PackingHelper(size_t gap, size_t num_modulus_for_packing, + const seal::GaloisKeys &galois_keys, const seal::SEALContext &context); // require ct_array.size() == gap @@ -39,8 +39,8 @@ class PackingHelper { void doPackingRLWEs(absl::Span rlwes, RLWECt &out) const; size_t gap_; + size_t num_modulus_for_packing_; const seal::GaloisKeys &galois_keys_; - const seal::SEALContext &gk_context_; const seal::SEALContext &context_; std::vector inv_gap_; @@ -53,4 +53,4 @@ size_t PackLWEs(absl::Span lwes, const GaloisKeys &galois, size_t PackLWEs(absl::Span lwes, const GaloisKeys &galois, const seal::SEALContext &context, absl::Span rlwes); -} // namespace spu::mpc::cheetah \ No newline at end of file +} // namespace spu::mpc::cheetah diff --git a/libspu/mpc/cheetah/rlwe/packlwes_test.cc b/libspu/mpc/cheetah/rlwe/packlwes_test.cc index 296c5a17..201f8f66 100644 --- a/libspu/mpc/cheetah/rlwe/packlwes_test.cc +++ b/libspu/mpc/cheetah/rlwe/packlwes_test.cc @@ -4,14 +4,13 @@ // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // -// http://www.apache.org/licenses/LICENSE-2.0 +// http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. - #include "libspu/mpc/cheetah/rlwe/packlwes.h" #include @@ -120,7 +119,8 @@ TEST_P(PackLWEsTest, PackRLWEs) { InvNttInplace(rlwes[i], *N_context_); } - PackingHelper ph(num_rlwes, *galois_, *N_context_, *N_context_); + PackingHelper ph(num_rlwes, N_ms_helper_->coeff_modulus_size(), *galois_, + *N_context_); RLWECt packed; ph.PackingWithModulusDrop(absl::MakeSpan(rlwes), packed); @@ -371,4 +371,4 @@ void VectorEncoder::Backward(const NdArrayRef &vec, RLWEPt *out, out->parms_id() = msh_->parms_id(); out->scale() = 1.; } -} // namespace spu::mpc::cheetah::test \ No newline at end of file +} // namespace spu::mpc::cheetah::test diff --git a/spu/tests/jnp_testbase.py b/spu/tests/jnp_testbase.py index 662ddcbc..0fef91f4 100644 --- a/spu/tests/jnp_testbase.py +++ b/spu/tests/jnp_testbase.py @@ -320,15 +320,7 @@ def post(x): REC("remainder", 2, number_dtypes, all_shapes, jtu.rand_small), REC("mod", 2, number_dtypes, all_shapes, jtu.rand_nonzero), REC("modf", 1, number_dtypes, all_shapes, rand_default), - REC( - "rint", - 1, - float_dtypes, - all_shapes, - rand_default, - Status.SysError, - "stablehlo.round_nearest_even", - ), # FIXME: stablehlo.round_nearest_even + REC("rint", 1, float_dtypes, all_shapes, rand_default), REC("sign", 1, number_dtypes, all_shapes, jtu.rand_default), REC( "copysign", 2, number_dtypes, all_shapes, rand_default, Status.SysError, "shift"