Repo sync (#616)

anakinxc authored Mar 20, 2024
1 parent 7c7f863 commit 82a8bd6
Showing 6 changed files with 108 additions and 141 deletions.
6 changes: 6 additions & 0 deletions libspu/kernel/hal/fxp_cleartext.cc
@@ -145,4 +145,10 @@ Value f_erf_p(SPUContext* ctx, const Value& in) {
   return applyFloatingPointFn(ctx, in, [](float x) { return std::erf(x); });
 }
 
+Value f_pow_p(SPUContext* ctx, const Value& x, const Value& y) {
+  SPU_TRACE_HAL_DISP(ctx, x, y);
+  return applyFloatingPointFn(ctx, x, y,
+                              [](float a, float b) { return std::pow(a, b); });
+}
+
 }  // namespace spu::kernel::hal
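Note: f_pow_p follows the same pattern as the neighbouring cleartext kernels: decode the fixed-point operands to floats, apply a scalar function element-wise, and re-encode. A minimal standalone sketch of that element-wise pattern; apply_binary is a hypothetical helper for illustration, not SPU's actual applyFloatingPointFn:

#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

// apply_binary is a hypothetical stand-in for applyFloatingPointFn's binary
// overload: apply `fn` element-wise over two equally sized float buffers.
template <typename Fn>
std::vector<float> apply_binary(const std::vector<float>& xs,
                                const std::vector<float>& ys, Fn fn) {
  std::vector<float> out(xs.size());
  for (std::size_t i = 0; i < xs.size(); ++i) {
    out[i] = fn(xs[i], ys[i]);
  }
  return out;
}

int main() {
  const std::vector<float> x = {2.0F, 3.0F, 4.0F};
  const std::vector<float> y = {3.0F, 2.0F, 0.5F};
  // The same lambda as f_pow_p above: element-wise std::pow.
  auto z = apply_binary(x, y, [](float a, float b) { return std::pow(a, b); });
  for (float v : z) {
    std::printf("%g ", v);  // prints: 8 9 2
  }
  return 0;
}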
2 changes: 2 additions & 0 deletions libspu/kernel/hal/fxp_cleartext.h
@@ -40,4 +40,6 @@ Value f_cosine_p(SPUContext* ctx, const Value& in);
 
 Value f_erf_p(SPUContext* ctx, const Value& in);
 
+Value f_pow_p(SPUContext* ctx, const Value& x, const Value& y);
+
 }  // namespace spu::kernel::hal
28 changes: 25 additions & 3 deletions libspu/kernel/hal/polymorphic.cc
@@ -19,6 +19,7 @@
 #include "libspu/core/trace.h"
 #include "libspu/kernel/hal/fxp_approx.h"
 #include "libspu/kernel/hal/fxp_base.h"
+#include "libspu/kernel/hal/fxp_cleartext.h"
 #include "libspu/kernel/hal/integer.h"
 #include "libspu/kernel/hal/ring.h"  // for fast fxp x int
 #include "libspu/kernel/hal/type_cast.h"
@@ -329,15 +330,36 @@ Value min(SPUContext* ctx, const Value& x, const Value& y) {
 Value power(SPUContext* ctx, const Value& x, const Value& y) {
   SPU_TRACE_HAL_DISP(ctx, x, y);
 
-  if (x.isInt() && y.isInt()) {
+  if (x.isInt() || y.isInt()) {
     auto x_f = dtype_cast(ctx, x, DT_F32);
     auto y_f = dtype_cast(ctx, y, DT_F32);
     auto ret = power(ctx, x_f, y_f);
-    return dtype_cast(ctx, ret, x.dtype());
+    return ret;
   }
+  if (x.isPublic() && y.isPublic()) {
+    return f_pow_p(ctx, x, y);
+  }
 
+  auto msb = _msb(ctx, x);
+  auto msb_a = _prefer_a(ctx, msb);
+  auto x_abs = _mux(ctx, msb_a, _negate(ctx, x), x).setDtype(x.dtype());
+
+  // If a public x equals 0, log(x) yields -inf and the product with y gives a
+  // wrong output. So we force x to be secret; log(x) then evaluates to a small
+  // negative number, and exp(y*log(x)) rounds to 0.
+  auto x_s = x.isPublic() ? hal::seal(ctx, x_abs) : x_abs;
   // x^y = e^(y*ln(x))
-  return exp(ctx, mul(ctx, y, log(ctx, x)));
+  // The precision depends heavily on the precision of exp and log, so we
+  // choose the most precise methods here.
+  auto val = detail::exp_pade(ctx, mul(ctx, y, detail::log_minmax(ctx, x_s)));
+
+  // The final sign is decided by both the sign of x and the parity of y:
+  // when x < 0 and y is odd, e.g. (-2)^3 = -8.
+  auto odd = _and(ctx, _rshift(ctx, y, ctx->getFxpBits()),
+                  _constant(ctx, 1, y.shape()));
+  auto sign = _and(ctx, msb, odd);
+
+  return _mux(ctx, sign, _negate(ctx, val), val).setDtype(x.dtype());
 }
 
 Value idiv(SPUContext* ctx, const Value& x, const Value& y) {
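Note: the rewritten power kernel evaluates x^y for fractional exponents as exp(y*ln(|x|)) and then patches the sign: the result is negated exactly when x < 0 and the integer part of y is odd ((-2)^3 = -8, but (-2)^2 = 4). The parity test works because a fixed-point value stores y*2^f, so _rshift by f = getFxpBits() recovers y's integer part and the AND with 1 tests its low bit. A cleartext sketch of the same logic, with plain doubles standing in for secret fixed-point values:

#include <cmath>
#include <cstdint>
#include <cstdio>

// Cleartext analogue of the secret power kernel above; doubles stand in
// for fixed-point shares, and fxp_bits plays the role of ctx->getFxpBits().
double power_via_exp_log(double x, double y, int fxp_bits = 18) {
  const double x_abs = std::fabs(x);             // _mux(msb_a, -x, x)
  const double val = std::exp(y * std::log(x_abs));

  // Parity of y's integer part, extracted the fixed-point way: y is stored
  // as y * 2^f, so an arithmetic right shift by f recovers its integer part.
  const auto y_fxp = static_cast<int64_t>(y * (1LL << fxp_bits));
  const bool y_odd = ((y_fxp >> fxp_bits) & 1) != 0;

  const bool x_neg = x < 0;                      // _msb(ctx, x)
  return (x_neg && y_odd) ? -val : val;          // _mux(sign, -val, val)
}

int main() {
  std::printf("%g\n", power_via_exp_log(-2.0, 3.0));  // ~ -8
  std::printf("%g\n", power_via_exp_log(-2.0, 2.0));  // ~  4
  return 0;
}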
46 changes: 31 additions & 15 deletions libspu/kernel/hal/polymorphic_test.cc
@@ -406,26 +406,42 @@ TYPED_TEST(MathTest, Pow) {
   using LHS_VT = typename std::tuple_element<1, TypeParam>::type;
   using RHS_DT = typename std::tuple_element<2, TypeParam>::type;
   using RHS_VT = typename std::tuple_element<3, TypeParam>::type;
-  using RES_DT = typename std::tuple_element<4, TypeParam>::type;
+  // using RES_DT = typename std::tuple_element<4, TypeParam>::type;
 
   if constexpr (!std::is_same_v<LHS_DT, RHS_DT> ||
                 !std::is_same_v<LHS_VT, RHS_VT> || std::is_integral_v<RHS_DT>) {
     return;
   }
 
   // GIVEN
-  const xt::xarray<LHS_DT> x = test::xt_random<LHS_DT>({5, 6}, 0, 100);
-  const xt::xarray<RHS_DT> y = test::xt_random<RHS_DT>({5, 6}, 0, 2);
-
-  // WHAT
-  auto z = test::evalBinaryOp<RES_DT>(LHS_VT(), RHS_VT(), power, x, y);
-
-  // THEN
-  auto expected = xt::pow(x, y);
-  EXPECT_TRUE(xt::allclose(expected, z, 0.3, 0.03)) << x << std::endl
-                                                    << y << std::endl
-                                                    << expected << std::endl
-                                                    << z << std::endl;
+  xt::xarray<LHS_DT> x;
+  xt::xarray<RHS_DT> y;
+  {
+    // random test
+    x = test::xt_random<LHS_DT>({5, 6}, 0, 100);
+    y = test::xt_random<RHS_DT>({5, 6}, -2, 2);
+
+    // WHAT
+    auto z = test::evalBinaryOp<float>(LHS_VT(), RHS_VT(), power, x, y);
+
+    // THEN
+    auto expected = xt::pow(x, y);
+    EXPECT_TRUE(xt::allclose(expected, z, 0.3, 0.03)) << x << std::endl
+                                                      << y << std::endl
+                                                      << expected << std::endl
+                                                      << z << std::endl;
+  }
+
+  {
+    // some fixed corner cases
+    x = {-1, -1, -3, 1, -3, 0, 1, 1, 5, 0};
+    y = {1, 0, -3, -3, 3, 0, 0, 2, 5, 2};
+
+    // WHAT
+    auto z = test::evalBinaryOp<float>(LHS_VT(), RHS_VT(), power, x, y);
+
+    // THEN
+    auto expected = xt::pow(x, y);
+    EXPECT_TRUE(xt::allclose(expected, z, 0.3, 0.03)) << x << std::endl
+                                                      << y << std::endl
+                                                      << expected << std::endl
+                                                      << z << std::endl;
+  }
 }
 
 using MathUnaryTestTypes = ::testing::Types<
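Note: both test blocks accept loose tolerances, xt::allclose(expected, z, 0.3, 0.03), i.e. rtol = 0.3 and atol = 0.03, because the secret branch of power goes through fixed-point exp/log approximations. Assuming xtensor follows numpy's closeness convention, the predicate is |a - b| <= atol + rtol * |b| element-wise; a hand-rolled scalar version:

#include <cmath>
#include <cstdio>

// Scalar version of the numpy-style closeness test: |a - b| <= atol + rtol*|b|.
bool is_close(double a, double b, double rtol = 0.3, double atol = 0.03) {
  return std::fabs(a - b) <= atol + rtol * std::fabs(b);
}

int main() {
  // (-3)^3 = -27: an approximate result of -26 passes the loose tolerance.
  std::printf("%d\n", is_close(-27.0, -26.0));    // 1
  // 5^5 = 3125: a result of 2000 does not.
  std::printf("%d\n", is_close(3125.0, 2000.0));  // 0
  return 0;
}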
163 changes: 41 additions & 122 deletions libspu/mpc/cheetah/state.cc
@@ -28,7 +28,7 @@ namespace spu::mpc::cheetah {
 namespace {
 // Return num_workers for the given size of jobs
 size_t InitOTState(KernelEvalContext* ctx, size_t njobs) {
-  constexpr size_t kMinWorkSize = 5000;
+  constexpr size_t kMinWorkSize = 2048;
   if (njobs == 0) {
     return 0;
   }
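Note: kMinWorkSize drops from 5000 to 2048 elements per OT worker, which pairs with the kMaxOTParallel bump from 24 to 48 in state.h below: smaller per-worker batches, more workers. The rest of InitOTState is elided in this hunk; the following is a hypothetical sketch of such a sizing heuristic, an assumption rather than the actual elided body:

#include <algorithm>
#include <cstddef>

// Hypothetical sketch (the real body is elided above): roughly one worker
// per kMinWorkSize jobs, clamped to a parallelism cap.
size_t NumWorkers(size_t njobs, size_t min_work = 2048, size_t max_par = 48) {
  if (njobs == 0) {
    return 0;
  }
  return std::min(max_par, (njobs + min_work - 1) / min_work);
}
// NumWorkers(1000) == 1, NumWorkers(10000) == 5, NumWorkers(1u << 20) == 48.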
@@ -139,86 +139,44 @@ std::array<NdArrayRef, 3> CheetahMulState::TakeCachedBeaver(FieldType field,

NdArrayRef TiledDispatchOTFunc(KernelEvalContext* ctx, const NdArrayRef& x,
OTUnaryFunc func) {
Shape shape = x.shape();
const Shape& shape = x.shape();
SPU_ENFORCE(shape.numel() > 0);
// (lazy) init OT
int64_t numel = x.numel();
int64_t nworker = InitOTState(ctx, numel);
int64_t workload = nworker == 0 ? 0 : CeilDiv(numel, nworker);

int64_t slicing_dim = -1;
int64_t slice_numel = 1;
for (int64_t dim = shape.size() - 1; dim >= 0; dim--) {
slice_numel *= shape[dim];
if (slice_numel > workload) {
slice_numel /= shape[dim];
slicing_dim = dim;
break;
}
}

// get the slice num in the left outer dimensions
int64_t num_slice = 1;
for (int64_t dim = 0; dim < slicing_dim; dim++) {
num_slice *= shape[dim];
}

int64_t slice_stride = (workload + slice_numel - 1) / slice_numel;
if (slice_stride == 1) {
return func(x, ctx->getState<CheetahOTState>()->get(0));
}

int64_t num_slice_dim = shape[slicing_dim] / slice_stride +
((shape[slicing_dim] % slice_stride) != 0 ? 1 : 0);

// initialize slice indices
Index start_indices(shape.size());
Index end_indices(shape.begin(), shape.end());
end_indices[slicing_dim] = slice_stride;
for (int64_t dim = slicing_dim - 1; dim >= 0; dim--) {
end_indices[dim] = 1;
if (shape.ndim() != 1) {
// TiledDispatchOTFunc over flatten input
return TiledDispatchOTFunc(ctx, x.reshape({numel}), func)
.reshape(x.shape());
}

SPU_ENFORCE_LE(num_slice * num_slice_dim, nworker);
nworker = num_slice * num_slice_dim;

std::vector<NdArrayRef> outs(nworker);
std::vector<std::future<void>> futures;

Index sidx = start_indices;
Index eidx = end_indices;
for (int64_t wi = 0; wi < nworker; ++wi) {
auto slice_input = x.slice(sidx, eidx, {});
int64_t slice_end = 0;
for (int64_t wi = 0; wi + 1 < nworker; ++wi) {
int64_t slice_bgn = wi * workload;
slice_end = std::min(numel, slice_bgn + workload);
auto slice_input = x.slice({slice_bgn}, {slice_end}, {});
futures.emplace_back(std::async(
[&](int64_t idx, const NdArrayRef& input) {
auto ot_instance = ctx->getState<CheetahOTState>()->get(idx);
outs[idx] = func(input, ot_instance);
},
wi, slice_input));

// update indices
if (0 == (eidx[slicing_dim] % shape[slicing_dim])) {
// carray out
sidx[slicing_dim] = 0;
eidx[slicing_dim] = slice_stride;
for (int64_t dim = slicing_dim - 1; dim >= 0; dim--) {
sidx[dim] = (sidx[dim] + 1) % shape[dim];
eidx[dim] = eidx[dim] % shape[dim] + 1;
if (eidx[dim] != 1) {
break;
}
}
} else {
sidx[slicing_dim] += slice_stride;
eidx[slicing_dim] += slice_stride;
eidx[slicing_dim] = std::min(shape[slicing_dim], eidx[slicing_dim]);
}
}

auto slice_input = x.slice({slice_end}, {numel}, {1});
auto ot_instance = ctx->getState<CheetahOTState>()->get(nworker - 1);
outs[nworker - 1] = func(slice_input, ot_instance);

for (auto&& f : futures) {
f.get();
}

NdArrayRef out(x.eltype(), x.shape());
NdArrayRef out(outs[0].eltype(), x.shape());
int64_t offset = 0;

for (auto& out_slice : outs) {
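Note: the rewrite drops the multi-dimensional slicing machinery entirely. Any tensor with ndim != 1 is flattened, each of the first nworker-1 workers gets a contiguous [wi*workload, wi*workload + workload) slice via std::async, the tail slice runs on the calling thread, and the 1-D result is reshaped back at the end. A self-contained sketch of the pattern, with std::vector standing in for NdArrayRef:

#include <algorithm>
#include <cstdint>
#include <future>
#include <vector>

// Simplified analogue of TiledDispatchOTFunc (assumes nworker >= 1): give
// each of the first nworker-1 workers a contiguous slice asynchronously,
// run the tail slice on the calling thread, then concatenate in order.
std::vector<int> tiled_dispatch(const std::vector<int>& x, int64_t nworker,
                                int (*func)(int)) {
  const int64_t numel = static_cast<int64_t>(x.size());
  const int64_t workload = (numel + nworker - 1) / nworker;  // CeilDiv

  std::vector<std::vector<int>> outs(nworker);
  std::vector<std::future<void>> futures;

  int64_t slice_end = 0;
  for (int64_t wi = 0; wi + 1 < nworker; ++wi) {
    const int64_t slice_bgn = wi * workload;
    slice_end = std::min(numel, slice_bgn + workload);
    futures.emplace_back(
        std::async(std::launch::async, [&, wi, slice_bgn, slice_end] {
          for (int64_t i = slice_bgn; i < slice_end; ++i) {
            outs[wi].push_back(func(x[i]));
          }
        }));
  }
  // Tail slice [slice_end, numel) runs synchronously, as in the diff.
  for (int64_t i = slice_end; i < numel; ++i) {
    outs[nworker - 1].push_back(func(x[i]));
  }
  for (auto&& f : futures) {
    f.get();
  }

  std::vector<int> out;  // reassemble worker outputs in order
  for (const auto& o : outs) {
    out.insert(out.end(), o.begin(), o.end());
  }
  return out;
}

With numel = 10 and nworker = 4, workload is 3 and the slices are [0,3), [3,6), [6,9) and [9,10), the last handled inline.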
@@ -232,89 +190,50 @@ NdArrayRef TiledDispatchOTFunc(KernelEvalContext* ctx, const NdArrayRef& x,

 NdArrayRef TiledDispatchOTFunc(KernelEvalContext* ctx, const NdArrayRef& x,
                                const NdArrayRef& y, OTBinaryFunc func) {
-  Shape shape = x.shape();
-  SPU_ENFORCE_EQ(x.shape(), y.shape());
+  const Shape& shape = x.shape();
+  SPU_ENFORCE(shape.numel() > 0);
+  SPU_ENFORCE_EQ(shape, y.shape());
   // (lazy) init OT
   int64_t numel = x.numel();
   int64_t nworker = InitOTState(ctx, numel);
   int64_t workload = nworker == 0 ? 0 : CeilDiv(numel, nworker);
 
-  int64_t slicing_dim = -1;
-  int64_t slice_numel = 1;
-  for (int64_t dim = shape.size() - 1; dim >= 0; dim--) {
-    slice_numel *= shape[dim];
-    if (slice_numel > workload) {
-      slice_numel /= shape[dim];
-      slicing_dim = dim;
-      break;
-    }
-  }
-
-  // get the slice num in the left outer dimensions
-  int64_t num_slice = 1;
-  for (int64_t dim = 0; dim < slicing_dim; dim++) {
-    num_slice *= shape[dim];
-  }
-
-  int64_t slice_stride = (workload + slice_numel - 1) / slice_numel;
-  if (slice_stride == 1) {
-    return func(x, y, ctx->getState<CheetahOTState>()->get(0));
-  }
-
-  int64_t num_slice_dim = shape[slicing_dim] / slice_stride +
-                          ((shape[slicing_dim] % slice_stride) != 0 ? 1 : 0);
-
-  // initialize slice indices
-  Index start_indices(shape.size());
-  Index end_indices(shape.begin(), shape.end());
-  end_indices[slicing_dim] = slice_stride;
-  for (int64_t dim = slicing_dim - 1; dim >= 0; dim--) {
-    end_indices[dim] = 1;
+  if (shape.ndim() != 1) {
+    // dispatch over the flattened inputs
+    return TiledDispatchOTFunc(ctx, x.reshape({numel}), y.reshape({numel}),
+                               func)
+        .reshape(x.shape());
   }
 
-  SPU_ENFORCE_LE(num_slice * num_slice_dim, nworker);
-  nworker = num_slice * num_slice_dim;
-
   std::vector<NdArrayRef> outs(nworker);
   std::vector<std::future<void>> futures;
 
-  Index sidx = start_indices;
-  Index eidx = end_indices;
-  for (int64_t wi = 0; wi < nworker; ++wi) {
-    auto x_slice = x.slice(sidx, eidx, {});
-    auto y_slice = y.slice(sidx, eidx, {});
-
+  int64_t slice_end = 0;
+  for (int64_t wi = 0; wi + 1 < nworker; ++wi) {
+    int64_t slice_bgn = wi * workload;
+    slice_end = std::min(numel, slice_bgn + workload);
+    auto x_slice = x.slice({slice_bgn}, {slice_end}, {1});
+    auto y_slice = y.slice({slice_bgn}, {slice_end}, {1});
     futures.emplace_back(std::async(
-        [&](int64_t idx, const NdArrayRef& input0, const NdArrayRef& input1) {
+        [&](int64_t idx, const NdArrayRef& inp0, const NdArrayRef& inp1) {
          auto ot_instance = ctx->getState<CheetahOTState>()->get(idx);
-          outs[idx] = func(input0, input1, ot_instance);
+          outs[idx] = func(inp0, inp1, ot_instance);
        },
        wi, x_slice, y_slice));
-
-    // update indices
-    if (0 == (eidx[slicing_dim] % shape[slicing_dim])) {
-      // carray out
-      sidx[slicing_dim] = 0;
-      eidx[slicing_dim] = slice_stride;
-      for (int64_t dim = slicing_dim - 1; dim >= 0; dim--) {
-        sidx[dim] = (sidx[dim] + 1) % shape[dim];
-        eidx[dim] = eidx[dim] % shape[dim] + 1;
-        if (eidx[dim] != 1) {
-          break;
-        }
-      }
-    } else {
-      sidx[slicing_dim] += slice_stride;
-      eidx[slicing_dim] += slice_stride;
-      eidx[slicing_dim] = std::min(shape[slicing_dim], eidx[slicing_dim]);
-    }
   }
 
+  auto x_slice = x.slice({slice_end}, {numel}, {});
+  auto y_slice = y.slice({slice_end}, {numel}, {});
+  auto ot_instance = ctx->getState<CheetahOTState>()->get(nworker - 1);
+  outs[nworker - 1] = func(x_slice, y_slice, ot_instance);
+
   for (auto&& f : futures) {
     f.get();
   }
 
-  NdArrayRef out(x.eltype(), x.shape());
+  NdArrayRef out(outs[0].eltype(), x.shape());
   int64_t offset = 0;
 
   for (auto& out_slice : outs) {
     std::memcpy(out.data<std::byte>() + offset, out_slice.data(),
                 out_slice.numel() * out.elsize());
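Note: a subtle fix rides along in both functions: the output buffer is now typed from outs[0].eltype() rather than x.eltype(), since func may produce a different element width than its input; the loop above then memcpy-s each worker slice into the flat result at a running byte offset. A small sketch of that final step, assuming contiguous, ordered slices:

#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// Sketch of the reassembly step: copy each worker's slice into one flat
// buffer at a running byte offset, sized from the output element width.
std::vector<std::byte> concat_slices(
    const std::vector<std::vector<std::uint64_t>>& outs) {
  std::size_t total_bytes = 0;
  for (const auto& o : outs) {
    total_bytes += o.size() * sizeof(std::uint64_t);
  }
  std::vector<std::byte> out(total_bytes);
  std::size_t offset = 0;
  for (const auto& o : outs) {
    std::memcpy(out.data() + offset, o.data(), o.size() * sizeof(std::uint64_t));
    offset += o.size() * sizeof(std::uint64_t);
  }
  return out;
}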
4 changes: 3 additions & 1 deletion libspu/mpc/cheetah/state.h
@@ -25,6 +25,8 @@
 #include "libspu/mpc/cheetah/ot/basic_ot_prot.h"
 #include "libspu/mpc/cheetah/rlwe/utils.h"
 
+#include "libspu/spu.pb.h"
+
 namespace spu::mpc::cheetah {
 
 using OTUnaryFunc = std::function<NdArrayRef(
@@ -101,7 +103,7 @@ class CheetahOTState : public State {

   mutable std::mutex lock_;
 
-  static constexpr size_t kMaxOTParallel = 24;
+  static constexpr size_t kMaxOTParallel = 48;
 
   size_t maximum_instances_ = 0;
   std::vector<ProtPtr> basic_ot_prot_;
