Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimizations for Shamir-based Protocol #880

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions .vscode/launch.json
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please delete this file

Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
{
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个文件不需要进主线吧?

// 使用 IntelliSense 了解相关属性。
// 悬停以查看现有属性的描述。
// 欲了解更多信息,请访问: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"type": "lldb",
"request": "launch",
"name": "Debug",
"program": "${workspaceFolder}/bazel-bin/libspu/mpc/shamir/protocol_test",
"args": ["--gtest_filter=Shamir/ArithmeticTest.A2P/FM32x3"],
"cwd": "${workspaceFolder}"
}
]
}
5 changes: 5 additions & 0 deletions libspu/core/config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,11 @@ void populateRuntimeConfig(RuntimeConfig& cfg) {
cfg.set_sigmoid_mode(RuntimeConfig::SIGMOID_REAL);
}

// shamir threshold
if (cfg.protocol() == ProtocolKind::SHAMIR && cfg.sss_threshold() == 0) {
SPU_THROW("shamir secret sharing threshold must be set");
}

// MPC related configurations
// trunc_allow_msb_error // by pass.
}
Expand Down
136 changes: 136 additions & 0 deletions libspu/core/encoding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -167,4 +167,140 @@ void decodeFromRing(const NdArrayRef& src, DataType in_dtype, size_t fxp_bits,
});
}

NdArrayRef encodeToGfmp(const PtBufferView& bv, FieldType field,
size_t fxp_bits, DataType* out_dtype) {
const PtType pt_type = bv.pt_type;
const size_t numel = bv.shape.numel();
NdArrayRef dst(makeType<GfmpTy>(field), bv.shape);
const auto* dst_ty = dst.eltype().as<GfmpTy>();

if (out_dtype != nullptr) {
*out_dtype = getEncodeType(pt_type);
}

if (pt_type == PT_F32 || pt_type == PT_F64 || pt_type == PT_F16) {
DISPATCH_FLOAT_PT_TYPES(pt_type, [&]() {
DISPATCH_ALL_FIELDS(field, [&]() {
using Float = ScalarT;

using U = ring2k_t;
using S = std::make_signed_t<ring2k_t>;
const auto p = static_cast<U>(dst_ty->p());
const S max_positve = p >> 1;
auto min_negetive = -max_positve;

// We have a Mersenne prime like p = 2^k -1, then encode integer in
// range [-2^(k-1)-1,2^(k-1)-1] to [0, 2^k -2]

const S kScale = S(1) << fxp_bits;
const auto kFlpUpper =
static_cast<Float>(static_cast<double>(max_positve) / kScale);
const auto kFlpLower =
static_cast<Float>(static_cast<double>(min_negetive) / kScale);

auto _dst = NdArrayView<S>(dst);

pforeach(0, numel, [&](int64_t idx) {
auto src_value = bv.get<Float>(idx);
S dst_val;
if (std::isnan(src_value)) {
// see numpy.nan_to_num
// note(jint) I dont know why nan could be
// encoded as zero..
dst_val = 0;
} else if (src_value >= kFlpUpper) {
dst_val = max_positve;
} else if (src_value <= kFlpLower) {
dst_val = min_negetive;
} else {
dst_val = src_value * kScale;
}
dst_val = dst_val >= 0 ? dst_val : dst_val + p;
_dst[idx] = static_cast<U>(dst_val);
});
});
});

return dst;
} else {
// handle integer & boolean
DISPATCH_INT_PT_TYPES(pt_type, [&]() {
DISPATCH_ALL_FIELDS(field, [&]() {
using Integer = ScalarT;
SPU_ENFORCE(sizeof(ring2k_t) >= sizeof(Integer),
"integer encoding failed, ring={} could not represent {}",
field, pt_type);
using U = ring2k_t;
using S = std::make_signed_t<ring2k_t>;
const auto p = static_cast<U>(dst_ty->p());
const S max_positve = p >> 1;
auto min_negetive = -max_positve;

// We have a Mersenne prime like p = 2^k -1, then encode integer in
// range [-2^(k-1)-1, 2^(k-1)-1] to [0, 2^k - 2]

auto _dst = NdArrayView<U>(dst);
pforeach(0, numel, [&](int64_t idx) {
// the cast is safe for all valid inputs
auto src_value = static_cast<S>(bv.get<Integer>(idx));
src_value = std::clamp<S>(src_value, min_negetive, max_positve);
src_value = src_value >= 0 ? src_value : src_value + p;
_dst[idx] = static_cast<U>(src_value);
});
});
});

return dst;
}

SPU_THROW("should not be here");
}

void decodeFromGfmp(const NdArrayRef& src, DataType in_dtype, size_t fxp_bits,
PtBufferView* out_bv, PtType* out_pt_type) {
const Type& src_type = src.eltype();
SPU_ENFORCE(src_type.isa<GfmpTy>(), "should be gfmp type but got={}",
src_type);
const FieldType field = src_type.as<Ring2k>()->field();
const PtType pt_type = getDecodeType(in_dtype);
const size_t numel = src.numel();

if (out_pt_type != nullptr) {
*out_pt_type = pt_type;
}

DISPATCH_ALL_FIELDS(field, [&]() {
DISPATCH_ALL_PT_TYPES(pt_type, [&]() {
using U = ring2k_t;
using S = std::make_signed_t<ring2k_t>;
const auto p = static_cast<U>(src_type.as<GfmpTy>()->p());
const auto max_positve = p >> 1;

auto _src = NdArrayView<U>(src);

if (in_dtype == DT_I1) {
pforeach(0, numel, [&](int64_t idx) {
bool value = !((_src[idx] & 0x1) == 0);
out_bv->set<bool>(idx, value);
});
} else if (in_dtype == DT_F32 || in_dtype == DT_F64 ||
in_dtype == DT_F16) {
const S kScale = S(1) << fxp_bits;
pforeach(0, numel, [&](int64_t idx) {
S dst_val = _src[idx] > max_positve ? _src[idx] - p : _src[idx];
auto value =
static_cast<ScalarT>(static_cast<double>(dst_val) / kScale);
out_bv->set<ScalarT>(idx, value);
});
} else {
pforeach(0, numel, [&](int64_t idx) {
S dst_val = _src[idx] > max_positve ? _src[idx] - p : _src[idx];
auto value = static_cast<ScalarT>(dst_val);
out_bv->set<ScalarT>(idx, value);
});
}
});
});
}

} // namespace spu
6 changes: 6 additions & 0 deletions libspu/core/encoding.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,4 +90,10 @@ NdArrayRef encodeToRing(const PtBufferView& src, FieldType field,
void decodeFromRing(const NdArrayRef& src, DataType in_dtype, size_t fxp_bits,
PtBufferView* out_bv, PtType* out_pt_type = nullptr);

NdArrayRef encodeToGfmp(const PtBufferView& src, FieldType field,
size_t fxp_bits, DataType* out_dtype = nullptr);

void decodeFromGfmp(const NdArrayRef& src, DataType in_dtype, size_t fxp_bits,
PtBufferView* out_bv, PtType* out_pt_type = nullptr);

} // namespace spu
Loading