Skip to content

Commit

Permalink
Merge pull request #14 from shaojian-ant/main
Browse files Browse the repository at this point in the history
repo-sync-2024-08-08T19:24:12+0800
  • Loading branch information
shaojian-ant authored Aug 8, 2024
2 parents 3a359da + 2736319 commit 4dc46f7
Show file tree
Hide file tree
Showing 9 changed files with 68 additions and 45 deletions.
42 changes: 28 additions & 14 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ interconnection-impl 引用了 spu 仓库代码,需要根据[ spu 构建前提
然后执行以下构建指令:

```shell
bazel build ic_impl/ic_main
bazel build -c opt ic_impl/ic_main
```

## 运行 ECDH-PSI
Expand All @@ -19,13 +19,13 @@ bazel build ic_impl/ic_main
本地同时执行以下两条指令:

```shell
bazel run ic_impl/ic_main -- -rank=0 -algo=ECDH_PSI -protocol_families=ECC \
bazel run -c opt ic_impl/ic_main -- -rank=0 -algo=ECDH_PSI -protocol_families=ECC \
-in_path ic_impl/data/psi_1.csv -field_names id -out_path /tmp/p1.out \
-parties=127.0.0.1:9530,127.0.0.1:9531
```

```shell
bazel run ic_impl/ic_main -- -rank=1 -algo=ECDH_PSI -protocol_families=ECC \
bazel run -c opt ic_impl/ic_main -- -rank=1 -algo=ECDH_PSI -protocol_families=ECC \
-in_path ic_impl/data/psi_2.csv -field_names id -out_path /tmp/p2.out \
-parties=127.0.0.1:9530,127.0.0.1:9531
```
Expand All @@ -38,7 +38,7 @@ bazel run ic_impl/ic_main -- -rank=1 -algo=ECDH_PSI -protocol_families=ECC \

程序运行需要关闭握手过程:
```shell
bazel run ic_impl/ic_main -- -disable_handshake=1
bazel run -c opt ic_impl/ic_main -- -disable_handshake=1
```

ECDH-PSI 算法配置的环境变量如下表所示。环境变量设置可参考 [ecdh-psi-env-alice.sh](./ic_impl/env/ecdh-psi-env-alice.sh)[ecdh-psi-env-bob.sh](./ic_impl/env/ecdh-psi-env-bob.sh)
Expand All @@ -62,30 +62,43 @@ ECDH-PSI 算法配置的环境变量如下表所示。环境变量设置可参

### 启动 Beaver 服务

运行 SS-LR 之前,需要先启动 Beaver 服务。Beaver 服务的代码位于 SPU 仓库中,需要将 SPU 代码克隆到本地,然后编译并启动 Beaver
服务:
运行 SS-LR 之前,需要先启动 Beaver 服务。Beaver 服务的代码位于 SPU 仓库中,需要将 SPU 代码克隆到本地并编译:

```shell
git clone [email protected]:secretflow/spu.git
cd spu && bazel run libspu/mpc/semi2k/beaver/ttp_server:beaver_server_main -- -port=9449
cd spu && bazel build -c opt libspu/mpc/semi2k/beaver/beaver_impl/ttp_server:beaver_server_main
```

然后生成 Beaver 服务的公钥和私钥:

```
bazel-bin/libspu/mpc/semi2k/beaver/beaver_impl/ttp_server/beaver_server_main -gen_key=true
```

最后启动 Beaver 服务,将上一步生成的私钥通过命令行参数传递给 Beaver 服务:

```
bazel-bin/libspu/mpc/semi2k/beaver/beaver_impl/ttp_server/beaver_server_main -port=9449 -server_private_key=LS0tLS1CRUdJTiBQUklWQVRFIEtFWS0tLS0tCk1JR0lBZ0VBTUJRR0NDcUJITTlWQVlJdEJnZ3FnUnpQVlFHQ0xRUnRNR3NDQVFFRUlJVnRVS1JEalVERFptZ3cKL0xUd0dYUmZXVFM5MStTSEhqODAwNnc2SUUxNW9VUURRZ0FFdER5RHNLM0RQN3YyWmdEdjZYNVQySnMzdGtmNQpPYXVBUEdXTHErTlhuMW1HYkd5N3pIZEVaa0FvNERDSGZyRmVuRWFCckxXMFZxUUtUY3QxUzJUYXpnPT0KLS0tLS1FTkQgUFJJVkFURSBLRVktLS0tLQo=
```

### 命令行传参

启动 Beaver 服务后,本地同时执行以下两条指令:
启动 Beaver 服务后,本地同时执行以下两条指令,命令行参数包括上述 Beaver 服务生成的公钥

```shell
bazel run ic_impl/ic_main -- -rank=0 -algo=SS_LR -protocol_families=SS \
bazel run -c opt ic_impl/ic_main -- -rank=0 -algo=SS_LR -protocol_families=SS \
-dataset=ic_impl/data/perfect_logit_a.csv -has_label=true \
-use_ttp=true -ttp_server_host=127.0.0.1:9449 \
-parties=127.0.0.1:9530,127.0.0.1:9531
-parties=127.0.0.1:9530,127.0.0.1:9531 -ttp_asym_crypto_schema=sm2 \
-ttp_public_key=LS0tLS1CRUdJTiBQVUJMSUMgS0VZLS0tLS0KTUZvd0ZBWUlLb0VjejFVQmdpMEdDQ3FCSE05VkFZSXRBMElBQkxROGc3Q3R3eis3OW1ZQTcrbCtVOWliTjdaSAorVG1yZ0R4bGk2dmpWNTlaaG14c3U4eDNSR1pBS09Bd2gzNnhYcHhHZ2F5MXRGYWtDazNMZFV0azJzND0KLS0tLS1FTkQgUFVCTElDIEtFWS0tLS0tCg==
```

```shell
bazel run ic_impl/ic_main -- -rank=1 -algo=SS_LR -protocol_families=SS \
bazel run -c opt ic_impl/ic_main -- -rank=1 -algo=SS_LR -protocol_families=SS \
-dataset=ic_impl/data/perfect_logit_b.csv -has_label=false \
-use_ttp=true -ttp_server_host=127.0.0.1:9449 \
-parties=127.0.0.1:9530,127.0.0.1:9531
-parties=127.0.0.1:9530,127.0.0.1:9531 -ttp_asym_crypto_schema=sm2 \
-ttp_public_key=LS0tLS1CRUdJTiBQVUJMSUMgS0VZLS0tLS0KTUZvd0ZBWUlLb0VjejFVQmdpMEdDQ3FCSE05VkFZSXRBMElBQkxROGc3Q3R3eis3OW1ZQTcrbCtVOWliTjdaSAorVG1yZ0R4bGk2dmpWNTlaaG14c3U4eDNSR1pBS09Bd2gzNnhYcHhHZ2F5MXRGYWtDazNMZFV0azJzND0KLS0tLS1FTkQgUFVCTElDIEtFWS0tLS0tCg==
```

### 环境变量传参
Expand All @@ -96,7 +109,7 @@ bazel run ic_impl/ic_main -- -rank=1 -algo=SS_LR -protocol_families=SS \

程序运行需要关闭握手过程:
```shell
bazel run ic_impl/ic_main -- -disable_handshake=1
bazel run -c opt ic_impl/ic_main -- -disable_handshake=1
```

SS-LR 算法配置的环境变量如下表所示。环境变量设置可参考 [ss-lr-env-alice.sh](./ic_impl/env/ss-lr-env-alice.sh)[ss-lr-env-bob.sh](./ic_impl/env/ss-lr-env-bob.sh)
Expand All @@ -121,7 +134,8 @@ SS-LR 算法配置的环境变量如下表所示。环境变量设置可参考 [
| runtime.component.parameter.shard_serialize_format | raw | serialization format used for communicating secret shares |
| runtime.component.parameter.use_ttp | true | whether to use beaver service |
| runtime.component.parameter.ttp_server_host | ip:port | remote ip:port or load-balance uri of beaver service |
| runtime.component.parameter.ttp_session_id | interconnection-root | session id of beaver service |
| runtime.component.parameter.ttp_asym_crypto_schema | sm2 | asym_crypto_schema of beaver service |
| runtime.component.parameter.ttp_public_key | | public key of beaver service |
| runtime.component.parameter.ttp_adjust_rank | 0 | which rank do adjust rpc call to beaver service |
| system.storage.host.url | file://path/to/root | root path of input/output file |
| runtime.component.input.train_data | {"namespace":"data","name":"perfect_logit_a.csv"} | relative path and name of input file |
Expand Down
7 changes: 3 additions & 4 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,13 @@ protocol_deps()
load("//bazel:repositories.bzl", "ic_impl_deps")
ic_impl_deps()

load("@psi//bazel:repositories.bzl", "psi_deps")

psi_deps()

# spu
load("@spulib//bazel:repositories.bzl", "spu_deps")
spu_deps()

load("@psi//bazel:repositories.bzl", "psi_deps")
psi_deps()

# yacl
load("@yacl//bazel:repositories.bzl", "yacl_deps")
yacl_deps()
Expand Down
22 changes: 9 additions & 13 deletions bazel/repositories.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,11 @@

load("@bazel_tools//tools/build_defs/repo:utils.bzl", "maybe")
load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository")
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")

SECRETFLOW_GIT = "https://github.com/secretflow"

SPU_COMMIT_ID = "8bf3c97da503f1cffd1292c8e365ecbc30675400"

PSI_COMMIT_ID = "330623e9eb42d92ef9701d352af947b70c2c9e7c"
SPU_COMMIT_ID = "74c6d54e5bff9e5d35ab5a73dda98cd9ccf0bfc8"

IC_COMMIT_ID = "30e4220b7444d0bb077a9040f1b428632124e31a"

Expand All @@ -35,13 +34,6 @@ def ic_impl_deps():
remote = "{}/{}.git".format(SECRETFLOW_GIT, SPU_REPOSITORY),
)

maybe(
git_repository,
name = "psi",
commit = PSI_COMMIT_ID,
remote = "{}/psi.git".format(SECRETFLOW_GIT),
)

def protocol_deps():
maybe(
git_repository,
Expand All @@ -52,8 +44,12 @@ def protocol_deps():

def _com_github_nlohmann_json():
maybe(
git_repository,
http_archive,
name = "com_github_nlohmann_json",
commit = "5d2754306d67d1e654a1a34e1d2e74439a9d53b3",
remote = "[email protected]:nlohmann/json.git",
sha256 = "0d8ef5af7f9794e3263480193c491549b2ba6cc74bb018906202ada498a79406",
strip_prefix = "json-3.11.3",
type = "tar.gz",
urls = [
"https://github.com/nlohmann/json/archive/refs/tags/v3.11.3.tar.gz",
],
)
9 changes: 4 additions & 5 deletions ic_impl/algo/lr/lr_handler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,6 @@ bool LrHandler::ProcessHandshakeResponse(const HandshakeResponseV2& response) {
YACL_ENFORCE(ss_param.triple_config().version() ==
ctx_->ttp_config.ttp_server_version);
ctx_->ttp_config.ttp_server_host = ss_param.triple_config().server_host();
ctx_->ttp_config.ttp_session_id = ss_param.triple_config().session_id();
ctx_->ttp_config.ttp_adjust_rank = ss_param.triple_config().adjust_rank();

return true;
Expand Down Expand Up @@ -642,8 +641,6 @@ HandshakeResponseV2 LrHandler::BuildHandshakeResponse() {
ctx_->ttp_config.ttp_server_host);
ss_param.mutable_triple_config()->set_version(
ctx_->ttp_config.ttp_server_version);
ss_param.mutable_triple_config()->set_session_id(
ctx_->ttp_config.ttp_session_id);
ss_param.mutable_triple_config()->set_adjust_rank(
ctx_->ttp_config.ttp_adjust_rank);
response.add_protocol_family_params()->PackFrom(ss_param);
Expand Down Expand Up @@ -755,10 +752,12 @@ std::unique_ptr<spu::SPUContext> LrHandler::MakeSpuContext() {
config.set_beaver_type(spu::RuntimeConfig_BeaverType_TrustedThirdParty);
config.mutable_ttp_beaver_config()->set_server_host(
ctx_->ttp_config.ttp_server_host);
config.mutable_ttp_beaver_config()->set_asym_crypto_schema(
ctx_->ttp_config.ttp_asym_crypto_schema);
config.mutable_ttp_beaver_config()->set_server_public_key(
ctx_->ttp_config.ttp_public_key);
config.mutable_ttp_beaver_config()->set_adjust_rank(
ctx_->ttp_config.ttp_adjust_rank);
config.mutable_ttp_beaver_config()->set_session_id(
ctx_->ttp_config.ttp_session_id);
} else {
config.set_beaver_type(spu::RuntimeConfig_BeaverType_TrustedFirstParty);
}
Expand Down
3 changes: 2 additions & 1 deletion ic_impl/env/ss-lr-env-alice.sh
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ exec env \
"runtime.component.parameter.shard_serialize_format=raw" \
"runtime.component.parameter.use_ttp=true" \
"runtime.component.parameter.ttp_server_host=127.0.0.1:9449" \
"runtime.component.parameter.ttp_session_id=interconnection-root-1" \
"runtime.component.parameter.ttp_asym_crypto_schema=sm2" \
"runtime.component.parameter.ttp_public_key=LS0tLS1CRUdJTiBQVUJMSUMgS0VZLS0tLS0KTUZvd0ZBWUlLb0VjejFVQmdpMEdDQ3FCSE05VkFZSXRBMElBQkxROGc3Q3R3eis3OW1ZQTcrbCtVOWliTjdaSAorVG1yZ0R4bGk2dmpWNTlaaG14c3U4eDNSR1pBS09Bd2gzNnhYcHhHZ2F5MXRGYWtDazNMZFV0azJzND0KLS0tLS1FTkQgUFVCTElDIEtFWS0tLS0tCg==" \
"runtime.component.parameter.ttp_adjust_rank=0" \
"runtime.component.parameter.label_owner=host.0" \
'runtime.component.parameter.feature_nums={"host.0":10, "guest.0":10}' \
Expand Down
3 changes: 2 additions & 1 deletion ic_impl/env/ss-lr-env-bob.sh
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ exec env \
"runtime.component.parameter.shard_serialize_format=raw" \
"runtime.component.parameter.use_ttp=true" \
"runtime.component.parameter.ttp_server_host=127.0.0.1:9449" \
"runtime.component.parameter.ttp_session_id=interconnection-root-1" \
"runtime.component.parameter.ttp_asym_crypto_schema=sm2" \
"runtime.component.parameter.ttp_public_key=LS0tLS1CRUdJTiBQVUJMSUMgS0VZLS0tLS0KTUZvd0ZBWUlLb0VjejFVQmdpMEdDQ3FCSE05VkFZSXRBMElBQkxROGc3Q3R3eis3OW1ZQTcrbCtVOWliTjdaSAorVG1yZ0R4bGk2dmpWNTlaaG14c3U4eDNSR1pBS09Bd2gzNnhYcHhHZ2F5MXRGYWtDazNMZFV0azJzND0KLS0tLS1FTkQgUFVCTElDIEtFWS0tLS0tCg==" \
"runtime.component.parameter.ttp_adjust_rank=0" \
"runtime.component.parameter.label_owner=host.0" \
'runtime.component.parameter.feature_nums={"host.0":10, "guest.0":10}' \
Expand Down
1 change: 1 addition & 0 deletions ic_impl/protocol_family/ss/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ cc_library(
deps = [
"//ic_impl:util",
"//ic_impl:handshake_cc_proto",
"@com_github_brpc_brpc//:brpc",
"@com_github_gflags_gflags//:gflags",
]
)
21 changes: 16 additions & 5 deletions ic_impl/protocol_family/ss/ss.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include "ic_impl/protocol_family/ss/ss.h"

#include "butil/base64.h"
#include "gflags/gflags.h"

#include "ic_impl/util.h"
Expand All @@ -33,8 +34,9 @@ DEFINE_bool(use_ttp, false, "whether use trusted third party's beaver service");
DEFINE_string(
ttp_server_host, "127.0.0.1:9449",
"trustedThirdParty beaver server's remote ip:port or load-balance uri");
DEFINE_string(ttp_session_id, "interconnection-root",
"trustedThirdParty beaver server's session id");
DEFINE_string(ttp_asym_crypto_schema, "sm2",
"asym_crypto_schema: support [\"SM2\"]");
DEFINE_string(ttp_public_key, "", "TTP public key");
DEFINE_int32(ttp_adjust_rank, 0, "which rank do adjust rpc call");

namespace ic_impl::protocol_family::ss {
Expand Down Expand Up @@ -77,8 +79,16 @@ std::string SuggestedTtpServerHost() {
return util::GetParamEnv("ttp_server_host", FLAGS_ttp_server_host);
}

std::string SuggestedTtpSessionId() {
return util::GetParamEnv("ttp_session_id", FLAGS_ttp_session_id);
std::string SuggestedTtpAsymCryptoSchema() {
return util::GetParamEnv("ttp_asym_crypto_schema",
FLAGS_ttp_asym_crypto_schema);
}

std::string SuggestedTtpPublicKey() {
std::string ret;
auto pk = util::GetParamEnv("ttp_public_key", FLAGS_ttp_public_key);
YACL_ENFORCE(butil::Base64Decode(pk, &ret));
return ret;
}

int32_t SuggestedTtpAdjustRank() {
Expand All @@ -102,7 +112,8 @@ TrustedThirdPartyConfig SuggestedTtpConfig() {
TrustedThirdPartyConfig ttp_config;
ttp_config.use_ttp = SuggestedUseTtp();
ttp_config.ttp_server_host = SuggestedTtpServerHost();
ttp_config.ttp_session_id = SuggestedTtpSessionId();
ttp_config.ttp_asym_crypto_schema = SuggestedTtpAsymCryptoSchema();
ttp_config.ttp_public_key = SuggestedTtpPublicKey();
ttp_config.ttp_adjust_rank = SuggestedTtpAdjustRank();

return ttp_config;
Expand Down
5 changes: 3 additions & 2 deletions ic_impl/protocol_family/ss/ss.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,9 @@ struct SsProtocolParam {
struct TrustedThirdPartyConfig {
bool use_ttp;
std::string ttp_server_host;
int32_t ttp_server_version = 1;
std::string ttp_session_id;
int32_t ttp_server_version = 2;
std::string ttp_asym_crypto_schema;
std::string ttp_public_key;
int32_t ttp_adjust_rank;
};

Expand Down

0 comments on commit 4dc46f7

Please sign in to comment.