Commit
repo-sync-2024-09-20T11:43:39+0800 (#860)
anakinxc authored Sep 20, 2024
1 parent aad59d2 commit 69dadb2
Showing 55 changed files with 867 additions and 444 deletions.
6 changes: 3 additions & 3 deletions docs/development/add_protocols.rst
@@ -105,7 +105,7 @@ member function and a member variable of an Object, respectively.
// register customized kernels
template <typename KernelT>
void regKernel() {
- regKernel(KernelT::kBindName, std::make_unique<KernelT>());
+ regKernel(KernelT::kBindName(), std::make_unique<KernelT>());
}

template <typename KernelT>
@@ -116,7 +116,7 @@ member function and a member variable of an Object, respectively.
// add customized states
template <typename StateT, typename... Args>
void addState(Args&&... args) {
- addState(StateT::kBindName,
+ addState(StateT::kBindName(),
std::make_unique<StateT>(std::forward<Args>(args)...));
}
...
@@ -205,7 +205,7 @@ As a result, the ABY3 developer can directly register these kernels through the
class AndPP : public BinaryKernel {
public:
// kernel name for dynamic binding
- static constexpr char kBindName[] = "and_pp";
+ static constexpr const char* kBindName() { return "and_pp"; }

// define cost model
ce::CExpr latency() const override { return ce::Const(0); }
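
One plausible motivation for the kBindName change above (the sync commit itself gives no rationale): a static constexpr data member like kBindName[] only became implicitly inline in C++17 and otherwise needs an out-of-line definition once odr-used, while a static constexpr member function is implicitly inline in every standard and keeps the same dynamic-binding usage. A minimal standalone sketch of the new pattern (MulPP is a hypothetical kernel, not from this commit):

#include <cstdio>

// Hypothetical kernel, for illustration only.
struct MulPP {
  // A constexpr member function is implicitly inline, so no out-of-line
  // definition is ever required, unlike a pre-C++17 constexpr array member.
  static constexpr const char* kBindName() { return "mul_pp"; }
};

template <typename KernelT>
void regKernel() {
  // Mirrors the updated call sites: KernelT::kBindName -> KernelT::kBindName().
  std::printf("registering kernel: %s\n", KernelT::kBindName());
}

int main() {
  regKernel<MulPP>();  // prints: registering kernel: mul_pp
  return 0;
}
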
9 changes: 5 additions & 4 deletions libspu/core/object.h
@@ -112,7 +112,7 @@ class Object final {

template <typename KernelT>
void regKernel() {
- regKernel(KernelT::kBindName, std::make_unique<KernelT>());
+ regKernel(KernelT::kBindName(), std::make_unique<KernelT>());
}

template <typename KernelT, typename OtherKernelT, typename... MoreKernelT>
@@ -137,14 +137,15 @@ class Object final {

template <typename StateT, typename... Args>
void addState(Args&&... args) {
- addState(StateT::kBindName,
+ addState(StateT::kBindName(),
std::make_unique<StateT>(std::forward<Args>(args)...));
}

template <typename StateT>
StateT* getState() {
- const auto& itr = states_.find(StateT::kBindName);
- SPU_ENFORCE(itr != states_.end(), "state={} not found", StateT::kBindName);
+ const auto& itr = states_.find(StateT::kBindName());
+ SPU_ENFORCE(itr != states_.end(), "state={} not found",
+             StateT::kBindName());
return dynamic_cast<StateT*>(itr->second.get());
}

1 change: 0 additions & 1 deletion libspu/kernel/hal/BUILD.bazel
@@ -232,7 +232,6 @@ spu_cc_library(
hdrs = ["utils.h"],
deps = [
":constants",
":polymorphic",
":ring",
":shape_ops",
"//libspu/core:prelude",
12 changes: 0 additions & 12 deletions libspu/kernel/hal/fxp_approx.cc
@@ -29,18 +29,6 @@
namespace spu::kernel::hal {
namespace detail {

- Value EvaluatePolynomial(SPUContext* ctx, const Value& x,
-                          absl::Span<const float> coefficients) {
-   auto poly = constant(ctx, coefficients[0], x.dtype(), x.shape());
-
-   for (size_t i = 1; i < coefficients.size(); ++i) {
-     auto c = constant(ctx, coefficients[i], x.dtype(), x.shape());
-     poly = f_mul(ctx, poly, x);
-     poly = f_add(ctx, poly, c);
-   }
-   return poly;
- }
-
Value log_minmax_normalized(SPUContext* ctx, const Value& x) {
static std::array<float, 9> kLogCoefficient{
0.0, 0.9999964239, -0.4998741238, 0.3317990258, -0.2407338084,
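
The deleted EvaluatePolynomial was a straight Horner evaluation with coefficients ordered from the highest degree down: each step multiplies the accumulator by x and adds the next coefficient, so every f_mul depends on the previous one. A minimal plain-double sketch of what it computed (standalone illustration, not SPU's fixed-point API):

#include <cstdio>
#include <vector>

// Horner's rule: poly = (...((c0 * x + c1) * x + c2)...) * x + cn.
// The multiplies form a chain, so the multiplicative depth is linear in the
// degree -- costly when every multiply is a communication round in MPC.
double evaluate_polynomial(double x, const std::vector<double>& coeffs) {
  double poly = coeffs[0];
  for (size_t i = 1; i < coeffs.size(); ++i) {
    poly = poly * x + coeffs[i];
  }
  return poly;
}

int main() {
  // 2*x^2 + 3*x + 1 at x = 2 gives 15.
  std::printf("%g\n", evaluate_polynomial(2.0, {2.0, 3.0, 1.0}));
  return 0;
}

The rework of detail::polynomial in fxp_base.cc below replaces this linear-depth chain with a log-depth power schedule, which presumably is why the overlapping helper could be dropped here.
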
39 changes: 23 additions & 16 deletions libspu/kernel/hal/fxp_base.cc
@@ -34,26 +34,34 @@ Value polynomial(SPUContext* ctx, const Value& x,
SPU_ENFORCE(x.isFxp());
SPU_ENFORCE(!coeffs.empty());

- if (coeffs.size() == 1U) {
+ if (coeffs.size() == 1U || x.numel() == 0) {
return coeffs[0];
}
- Value x_pow = constant(ctx, 1.0F, x.dtype(), x.shape());
- Value res = _mul(ctx, x_pow, coeffs[0]);
+ // Use a parallel circuit to calculate x, x^2, x^3, ..., x^n.
+ // The general log(n) algorithm:
+ //   Step 0. x
+ //   Step 1. x, x2
+ //   Step 2. x, x2, x3, x4
+ //   ...
+ std::vector<spu::Value> x_prefix(1, x);
+ size_t degree = coeffs.size() - 1;
+ for (int64_t i = 0; i < Log2Ceil(degree); ++i) {
+   size_t x_size = std::min(x_prefix.size(), degree - x_prefix.size());
+   std::vector<spu::Value> x_pow(x_size, x_prefix.back());
+   // TODO: this can be further optimized to use sign hint
+   vmap(x_prefix.begin(), x_prefix.begin() + x_size, x_pow.begin(),
+        x_pow.end(), std::back_inserter(x_prefix),
+        [ctx, sign_x](const Value& a, const Value& b) {
+          return f_mul(ctx, a, b, sign_x);
+        });
+ }
+
+ Value res = _mul(ctx, constant(ctx, 1.0F, x.dtype(), x.shape()), coeffs[0]);

const auto fbits = ctx->getFxpBits();
for (size_t i = 1; i < coeffs.size(); i++) {
if ((i & 1) == 0U) {
// x^{even order} is always positive
x_pow = _trunc(ctx, _mul(ctx, x_pow, x), fbits, SignType::Positive);
} else {
if (i > 1) {
x_pow = _trunc(ctx, _mul(ctx, x_pow, x), fbits, sign_x);
} else {
// i=1, then save a _trunc
x_pow = x;
}
}
res = _add(ctx, res, _mul(ctx, x_pow, coeffs[i]));
res = _add(ctx, res, _mul(ctx, x_prefix[i - 1], coeffs[i]));
}

return _trunc(ctx, res, fbits, sign_ret).setDtype(x.dtype());
@@ -93,7 +101,6 @@ Value maskNumberOfBits(SPUContext* ctx, const Value& in, size_t nbits) {
}

namespace {
-
Value reciprocal_goldschmidt_normalized_approx(SPUContext* ctx,
const Value& b_abs,
const Value& factor) {
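
The new detail::polynomial body above prepares every power x^1 .. x^degree in Log2Ceil(degree) rounds: each round multiplies the first few entries of x_prefix by the largest power computed so far, roughly doubling how many powers are available, and vmap batches the round's f_mul calls so they can share communication. A standalone plain-double sketch of the same schedule (illustrative reconstruction, not the SPU API):

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Compute x^1 .. x^degree with ceil(log2(degree)) rounds of multiplies:
// {x} -> {x, x^2} -> {x, x^2, x^3, x^4} -> ...
std::vector<double> all_powers(double x, size_t degree) {
  std::vector<double> x_prefix(1, x);
  const int rounds =
      static_cast<int>(std::ceil(std::log2(static_cast<double>(degree))));
  for (int i = 0; i < rounds; ++i) {
    // Multiply the first x_size powers by the top power; these products are
    // mutually independent, which is what vmap exploits in the real code.
    const size_t x_size = std::min(x_prefix.size(), degree - x_prefix.size());
    const double top = x_prefix.back();
    for (size_t j = 0; j < x_size; ++j) {
      x_prefix.push_back(x_prefix[j] * top);
    }
  }
  return x_prefix;
}

int main() {
  for (double p : all_powers(2.0, 5)) {
    std::printf("%g ", p);  // prints: 2 4 8 16 32
  }
  std::printf("\n");
  return 0;
}

The coefficient loop then just reads the i-th power from x_prefix, so the old version's data-dependent chain of multiply-and-truncate steps disappears.
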
2 changes: 1 addition & 1 deletion libspu/kernel/hal/permute.cc
@@ -1159,7 +1159,7 @@ std::vector<spu::Value> permute(SPUContext *ctx,
for (auto const &input : inputs) {
auto transposed = hal::transpose(ctx, input, perm);
auto reshaped = hal::reshape(ctx, transposed, {N, W});
- inputs2d.push_back(reshaped);
+ inputs2d.push_back(std::move(reshaped));
}

// Call permute1d for each dim to permute.
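
The permute.cc change is a small copy elision: reshaped is dead after the push_back, so moving it into inputs2d hands over the underlying value instead of copying it (presumably cheap either way for a spu::Value, but the move is free).
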
117 changes: 77 additions & 40 deletions libspu/kernel/hal/utils.h
@@ -16,13 +16,29 @@

#include "libspu/core/context.h"
#include "libspu/core/value.h"
#include "libspu/core/vectorize.h"
#include "libspu/kernel/hal/constants.h"
#include "libspu/kernel/hal/polymorphic.h"
#include "libspu/kernel/hal/ring.h"
#include "libspu/kernel/hal/shape_ops.h"

namespace spu::kernel::hal {

+ //////////////////////////////////////////////////////////////////////////////
+ // Shape utils
+ //////////////////////////////////////////////////////////////////////////////
+
+ /// the squeeze function, i.e., removes dimensions of size 1 from the shape of
+ /// a tensor.
+ // @param in, the input
+ // @param dim, the dimension to be squeezed
+ Value squeeze(SPUContext* ctx, const Value& in, int64_t dim = 0);
+
+ /// the unsqueeze function, i.e., expands a tensor with a length 1 axis
+ /// inserted at index axis.
+ // @param in, the input
+ // @param dim, the dimension to be unsqueezed
+ Value unsqueeze(SPUContext* ctx, const Value& in, int64_t dim = 0);
+
// This is SPU's version of JAX's associative_scan
// See:
// https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.associative_scan.html
@@ -32,61 +48,82 @@ namespace spu::kernel::hal {
// for the detailed algorithm explanation
//
// fn: an associative binary Function
- // in: a 1-d tensor
+ // in: a tensor, scan the last axis
template <typename Fn>
spu::Value associative_scan(Fn&& fn, SPUContext* ctx, const Value& in) {
- SPU_ENFORCE(in.shape().ndim() == 1U, "input should be 1d");
- const auto numel = in.numel();
- if (numel < 2) {
+ SPU_ENFORCE(in.shape().ndim() >= 1U, "input should not be scalar");
+ // First reshape to 2D {M, N} tensor, scan each N elements
+ const Shape shape = in.shape();
+ const auto N = shape.back();
+ // in case some empty tensors
+ if (N < 2 || shape.numel() == 0) {
return in;
}
+ const auto M = shape.numel() / N;
+ spu::Value in_2d = hal::reshape(ctx, in, {M, N});

// merge consecutive even/odd index elements
- auto reduced_elems = fn(ctx, hal::slice(ctx, in, {0}, {numel - 1}, {2}),
-                         hal::slice(ctx, in, {1}, {numel}, {2}));
- // process half elements recursively and get odd index elements
- auto odd_elems = associative_scan(fn, ctx, reduced_elems);
+ spu::Value odd_elems;
+ std::vector<spu::Value> odd_vec;
+ std::vector<spu::Value> even_vec;
+ {
+   for (int64_t i = 0; i < M; ++i) {
+     odd_vec.push_back(hal::slice(ctx, in_2d, {i, 0}, {i + 1, N - 1}, {1, 2}));
+     even_vec.push_back(hal::slice(ctx, in_2d, {i, 1}, {i + 1, N}, {1, 2}));
+   }
+   std::vector<spu::Value> reduced_elems_vec;
+   vmap(odd_vec.begin(), odd_vec.end(), even_vec.begin(), even_vec.end(),
+        std::back_inserter(reduced_elems_vec),
+        [&](const spu::Value& odd, const spu::Value& even) {
+          return fn(ctx, odd, even);
+        });
+
+   auto concat_reduced_elems = hal::concatenate(ctx, reduced_elems_vec, 0);
+
+   // process half elements recursively and get odd index elements
+   odd_elems = associative_scan(fn, ctx, concat_reduced_elems);
+ }

// get even index elements
+ odd_vec.clear();
+ even_vec.clear();
spu::Value even_elems;
- if (numel % 2 == 0) {
-   even_elems =
-       fn(ctx, hal::slice(ctx, odd_elems, {0}, {odd_elems.numel() - 1}, {1}),
-          hal::slice(ctx, in, {2}, {numel}, {2}));
- } else {
-   even_elems = fn(ctx, odd_elems, hal::slice(ctx, in, {2}, {numel}, {2}));
+ {
+   std::vector<spu::Value> even_elems_vec;
+   for (int64_t i = 0; i < M; ++i) {
+     if (N % 2 == 0) {
+       odd_vec.push_back(hal::slice(ctx, odd_elems, {i, 0},
+                                    {i + 1, odd_elems.shape().back() - 1},
+                                    {1, 1}));
+     } else {
+       odd_vec.push_back(hal::slice(ctx, odd_elems, {i, 0},
+                                    {i + 1, odd_elems.shape().back()}, {}));
+     }
+     even_vec.push_back(hal::slice(ctx, in_2d, {i, 2}, {i + 1, N}, {1, 2}));
}
+   vmap(odd_vec.begin(), odd_vec.end(), even_vec.begin(), even_vec.end(),
+        std::back_inserter(even_elems_vec),
+        [&](const spu::Value& odd, const spu::Value& even) {
+          return fn(ctx, odd, even);
+        });
+
+   even_elems = hal::concatenate(ctx, even_elems_vec, 0);
+ }
// concat the 0th element
- auto final_even_elems =
-     hal::concatenate(ctx, {hal::slice(ctx, in, {0}, {1}), even_elems}, 0);
+ auto final_even_elems = hal::concatenate(
+     ctx, {hal::slice(ctx, in_2d, {0, 0}, {M, 1}), even_elems}, 1);

// concat even and odd elems interleavely
auto zero = hal::constant(ctx, 0U, in.dtype(), {1});
- auto pad_even =
-     hal::pad(ctx, final_even_elems, zero, {0},
-              {final_even_elems.numel() == odd_elems.numel() ? 1 : 0}, {1});
- auto pad_odd =
-     hal::pad(ctx, odd_elems, zero, {1},
-              {final_even_elems.numel() == odd_elems.numel() ? 0 : 1}, {1});
+ auto pad_even = hal::pad(
+     ctx, final_even_elems, zero, {0, 0},
+     {0, final_even_elems.numel() == odd_elems.numel() ? 1 : 0}, {0, 1});
+ auto pad_odd = hal::pad(
+     ctx, odd_elems, zero, {0, 1},
+     {0, final_even_elems.numel() == odd_elems.numel() ? 0 : 1}, {0, 1});

auto ret = hal::_add(ctx, pad_even, pad_odd).setDtype(in.dtype());
- return ret;
+ return hal::reshape(ctx, ret, in.shape());
}

- //////////////////////////////////////////////////////////////////////////////
- // Shape utils
- //////////////////////////////////////////////////////////////////////////////
-
- /// the squeeze function, i.e., removes dimensions of size 1 from the shape of a
- /// tensor.
- // @param in, the input
- // @param dim, the dimension to be squeezed
- Value squeeze(SPUContext* ctx, const Value& in, int64_t dim = 0);
-
- /// the unsqueeze function, i.e., expands a tensor with a length 1 axis inserted
- /// at index axis.
- // @param in, the input
- // @param dim, the dimension to be unsqueezed
- Value unsqueeze(SPUContext* ctx, const Value& in, int64_t dim = 0);
-
} // namespace spu::kernel::hal
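
The reworked associative_scan above generalizes the old 1-d routine: any input is reshaped to a 2-D {M, N} tensor, the work-efficient odd/even recursion runs along the last axis, and the M per-row combine steps are batched through vmap, so each recursion level costs a fixed number of batched fn invocations independent of M. A standalone plain-int sketch of the per-row recursion (illustrative reconstruction over std::vector, not SPU types):

#include <cstdio>
#include <functional>
#include <vector>

// Work-efficient scan: combine adjacent pairs, scan the half-size sequence
// recursively (this yields the results at odd indices), then combine once
// more to fill in the even indices.
std::vector<int> associative_scan(const std::vector<int>& in,
                                  const std::function<int(int, int)>& fn) {
  const size_t n = in.size();
  if (n < 2) {
    return in;
  }

  // merge consecutive even/odd index elements: reduced[i] = fn(in[2i], in[2i+1])
  std::vector<int> reduced;
  for (size_t i = 0; i + 1 < n; i += 2) {
    reduced.push_back(fn(in[i], in[i + 1]));
  }

  // odd_elems[i] is the finished scan value at original index 2i+1
  const std::vector<int> odd_elems = associative_scan(reduced, fn);

  std::vector<int> out(n);
  out[0] = in[0];  // index 0 passes through unchanged
  for (size_t i = 0; 2 * i + 1 < n; ++i) {
    out[2 * i + 1] = odd_elems[i];
  }
  for (size_t i = 0; 2 * i + 2 < n; ++i) {
    out[2 * i + 2] = fn(odd_elems[i], in[2 * i + 2]);
  }
  return out;
}

int main() {
  for (int v : associative_scan({1, 1, 1, 1, 1},
                                [](int a, int b) { return a + b; })) {
    std::printf("%d ", v);  // prints: 1 2 3 4 5
  }
  std::printf("\n");
  return 0;
}

With fn = add this reproduces the prefix sums exercised by the new associative_scan_2d test below; the SPU version additionally restores the original shape at the end with hal::reshape.
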
58 changes: 57 additions & 1 deletion libspu/kernel/hal/utils_test.cc
@@ -24,7 +24,7 @@
namespace spu::kernel::hal {
namespace {

- TEST(UtilsTest, associative_scan) {
+ TEST(UtilsTest, associative_scan_1d) {
SPUContext ctx = test::makeSPUContext();

{
@@ -82,6 +82,62 @@ TEST(UtilsTest, associative_scan) {
}
}

+ TEST(UtilsTest, associative_scan_2d) {
+   SPUContext ctx = test::makeSPUContext();
+
+   {
+     const xt::xarray<int32_t> x = {{1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}};
+     const xt::xarray<int32_t> prefix_sum = {{1, 2, 3, 4, 5}, {1, 2, 3, 4, 5}};
+     Value a = test::makeValue(&ctx, x, VIS_SECRET);
+     Value b = associative_scan(hal::add, &ctx, a);
+     auto ret = dump_public_as<int32_t>(&ctx, hal::reveal(&ctx, b));
+     EXPECT_TRUE(prefix_sum == ret) << x << std::endl
+                                    << prefix_sum << std::endl
+                                    << ret;
+   }
+
+   {
+     const xt::xarray<int32_t> x = {{1, 2, 3, 4, 5}, {1, 2, 3, 4, 5}};
+     const xt::xarray<int32_t> prefix_prod = {{1, 2, 6, 24, 120},
+                                              {1, 2, 6, 24, 120}};
+     Value a = test::makeValue(&ctx, x, VIS_SECRET);
+     Value b = associative_scan(hal::mul, &ctx, a);
+     auto ret = dump_public_as<int32_t>(&ctx, hal::reveal(&ctx, b));
+     EXPECT_TRUE(prefix_prod == ret) << x << std::endl
+                                     << prefix_prod << std::endl
+                                     << ret;
+   }
+
+   {
+     const xt::xarray<bool> x = {{true, true, true, false, true, false},
+                                 {true, true, true, false, true, false}};
+     const xt::xarray<bool> prefix_and = {
+         {true, true, true, false, false, false},
+         {true, true, true, false, false, false}};
+
+     Value a = test::makeValue(&ctx, x, VIS_SECRET);
+     Value b = associative_scan(hal::bitwise_and, &ctx, a);
+     auto ret = dump_public_as<bool>(&ctx, hal::reveal(&ctx, b));
+     EXPECT_TRUE(prefix_and == ret) << x << std::endl
+                                    << prefix_and << std::endl
+                                    << ret;
+   }
+
+   {
+     const xt::xarray<bool> x = {{true, true, true, false, true, false},
+                                 {true, true, true, false, true, false}};
+     const xt::xarray<bool> prefix_or = {{true, true, true, true, true, true},
+                                         {true, true, true, true, true, true}};
+
+     Value a = test::makeValue(&ctx, x, VIS_SECRET);
+     Value b = associative_scan(hal::bitwise_or, &ctx, a);
+     auto ret = dump_public_as<bool>(&ctx, hal::reveal(&ctx, b));
+     EXPECT_TRUE(prefix_or == ret) << x << std::endl
+                                   << prefix_or << std::endl
+                                   << ret;
+   }
+ }
+
TEST(UtilsTest, Squeeze) {
// GIVEN
xt::xarray<int32_t> x = xt::ones<int32_t>({2, 1, 2, 1, 2});