From 27363193d29f224d6285c1949d5932afeccc140b Mon Sep 17 00:00:00 2001 From: shaojian Date: Thu, 8 Aug 2024 19:24:22 +0800 Subject: [PATCH] repo-sync-2024-08-08T19:24:12+0800 --- README.md | 42 +++++++++++++++++--------- WORKSPACE | 7 ++--- bazel/repositories.bzl | 22 ++++++-------- ic_impl/algo/lr/lr_handler.cc | 9 +++--- ic_impl/env/ss-lr-env-alice.sh | 3 +- ic_impl/env/ss-lr-env-bob.sh | 3 +- ic_impl/protocol_family/ss/BUILD.bazel | 1 + ic_impl/protocol_family/ss/ss.cc | 21 ++++++++++--- ic_impl/protocol_family/ss/ss.h | 5 +-- 9 files changed, 68 insertions(+), 45 deletions(-) diff --git a/README.md b/README.md index f47c59c..8f0af2a 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,7 @@ interconnection-impl 引用了 spu 仓库代码,需要根据[ spu 构建前提 然后执行以下构建指令: ```shell -bazel build ic_impl/ic_main +bazel build -c opt ic_impl/ic_main ``` ## 运行 ECDH-PSI @@ -19,13 +19,13 @@ bazel build ic_impl/ic_main 本地同时执行以下两条指令: ```shell -bazel run ic_impl/ic_main -- -rank=0 -algo=ECDH_PSI -protocol_families=ECC \ +bazel run -c opt ic_impl/ic_main -- -rank=0 -algo=ECDH_PSI -protocol_families=ECC \ -in_path ic_impl/data/psi_1.csv -field_names id -out_path /tmp/p1.out \ -parties=127.0.0.1:9530,127.0.0.1:9531 ``` ```shell -bazel run ic_impl/ic_main -- -rank=1 -algo=ECDH_PSI -protocol_families=ECC \ +bazel run -c opt ic_impl/ic_main -- -rank=1 -algo=ECDH_PSI -protocol_families=ECC \ -in_path ic_impl/data/psi_2.csv -field_names id -out_path /tmp/p2.out \ -parties=127.0.0.1:9530,127.0.0.1:9531 ``` @@ -38,7 +38,7 @@ bazel run ic_impl/ic_main -- -rank=1 -algo=ECDH_PSI -protocol_families=ECC \ 程序运行需要关闭握手过程: ```shell -bazel run ic_impl/ic_main -- -disable_handshake=1 +bazel run -c opt ic_impl/ic_main -- -disable_handshake=1 ``` ECDH-PSI 算法配置的环境变量如下表所示。环境变量设置可参考 [ecdh-psi-env-alice.sh](./ic_impl/env/ecdh-psi-env-alice.sh) 和 [ecdh-psi-env-bob.sh](./ic_impl/env/ecdh-psi-env-bob.sh) @@ -62,30 +62,43 @@ ECDH-PSI 算法配置的环境变量如下表所示。环境变量设置可参 ### 启动 Beaver 服务 -运行 SS-LR 之前,需要先启动 Beaver 服务。Beaver 服务的代码位于 SPU 仓库中,需要将 SPU 代码克隆到本地,然后编译并启动 Beaver -服务: +运行 SS-LR 之前,需要先启动 Beaver 服务。Beaver 服务的代码位于 SPU 仓库中,需要将 SPU 代码克隆到本地并编译: ```shell git clone git@github.com:secretflow/spu.git -cd spu && bazel run libspu/mpc/semi2k/beaver/ttp_server:beaver_server_main -- -port=9449 +cd spu && bazel build -c opt libspu/mpc/semi2k/beaver/beaver_impl/ttp_server:beaver_server_main +``` + +然后生成 Beaver 服务的公钥和私钥: + +``` +bazel-bin/libspu/mpc/semi2k/beaver/beaver_impl/ttp_server/beaver_server_main -gen_key=true +``` + +最后启动 Beaver 服务,将上一步生成的私钥通过命令行参数传递给 Beaver 服务: + +``` +bazel-bin/libspu/mpc/semi2k/beaver/beaver_impl/ttp_server/beaver_server_main -port=9449 -server_private_key=LS0tLS1CRUdJTiBQUklWQVRFIEtFWS0tLS0tCk1JR0lBZ0VBTUJRR0NDcUJITTlWQVlJdEJnZ3FnUnpQVlFHQ0xRUnRNR3NDQVFFRUlJVnRVS1JEalVERFptZ3cKL0xUd0dYUmZXVFM5MStTSEhqODAwNnc2SUUxNW9VUURRZ0FFdER5RHNLM0RQN3YyWmdEdjZYNVQySnMzdGtmNQpPYXVBUEdXTHErTlhuMW1HYkd5N3pIZEVaa0FvNERDSGZyRmVuRWFCckxXMFZxUUtUY3QxUzJUYXpnPT0KLS0tLS1FTkQgUFJJVkFURSBLRVktLS0tLQo= ``` ### 命令行传参 -启动 Beaver 服务后,本地同时执行以下两条指令: +启动 Beaver 服务后,本地同时执行以下两条指令,命令行参数包括上述 Beaver 服务生成的公钥: ```shell -bazel run ic_impl/ic_main -- -rank=0 -algo=SS_LR -protocol_families=SS \ +bazel run -c opt ic_impl/ic_main -- -rank=0 -algo=SS_LR -protocol_families=SS \ -dataset=ic_impl/data/perfect_logit_a.csv -has_label=true \ -use_ttp=true -ttp_server_host=127.0.0.1:9449 \ - -parties=127.0.0.1:9530,127.0.0.1:9531 + -parties=127.0.0.1:9530,127.0.0.1:9531 -ttp_asym_crypto_schema=sm2 \ + -ttp_public_key=LS0tLS1CRUdJTiBQVUJMSUMgS0VZLS0tLS0KTUZvd0ZBWUlLb0VjejFVQmdpMEdDQ3FCSE05VkFZSXRBMElBQkxROGc3Q3R3eis3OW1ZQTcrbCtVOWliTjdaSAorVG1yZ0R4bGk2dmpWNTlaaG14c3U4eDNSR1pBS09Bd2gzNnhYcHhHZ2F5MXRGYWtDazNMZFV0azJzND0KLS0tLS1FTkQgUFVCTElDIEtFWS0tLS0tCg== ``` ```shell -bazel run ic_impl/ic_main -- -rank=1 -algo=SS_LR -protocol_families=SS \ +bazel run -c opt ic_impl/ic_main -- -rank=1 -algo=SS_LR -protocol_families=SS \ -dataset=ic_impl/data/perfect_logit_b.csv -has_label=false \ -use_ttp=true -ttp_server_host=127.0.0.1:9449 \ - -parties=127.0.0.1:9530,127.0.0.1:9531 + -parties=127.0.0.1:9530,127.0.0.1:9531 -ttp_asym_crypto_schema=sm2 \ + -ttp_public_key=LS0tLS1CRUdJTiBQVUJMSUMgS0VZLS0tLS0KTUZvd0ZBWUlLb0VjejFVQmdpMEdDQ3FCSE05VkFZSXRBMElBQkxROGc3Q3R3eis3OW1ZQTcrbCtVOWliTjdaSAorVG1yZ0R4bGk2dmpWNTlaaG14c3U4eDNSR1pBS09Bd2gzNnhYcHhHZ2F5MXRGYWtDazNMZFV0azJzND0KLS0tLS1FTkQgUFVCTElDIEtFWS0tLS0tCg== ``` ### 环境变量传参 @@ -96,7 +109,7 @@ bazel run ic_impl/ic_main -- -rank=1 -algo=SS_LR -protocol_families=SS \ 程序运行需要关闭握手过程: ```shell -bazel run ic_impl/ic_main -- -disable_handshake=1 +bazel run -c opt ic_impl/ic_main -- -disable_handshake=1 ``` SS-LR 算法配置的环境变量如下表所示。环境变量设置可参考 [ss-lr-env-alice.sh](./ic_impl/env/ss-lr-env-alice.sh) 和 [ss-lr-env-bob.sh](./ic_impl/env/ss-lr-env-bob.sh) @@ -121,7 +134,8 @@ SS-LR 算法配置的环境变量如下表所示。环境变量设置可参考 [ | runtime.component.parameter.shard_serialize_format | raw | serialization format used for communicating secret shares | | runtime.component.parameter.use_ttp | true | whether to use beaver service | | runtime.component.parameter.ttp_server_host | ip:port | remote ip:port or load-balance uri of beaver service | -| runtime.component.parameter.ttp_session_id | interconnection-root | session id of beaver service | +| runtime.component.parameter.ttp_asym_crypto_schema | sm2 | asym_crypto_schema of beaver service | +| runtime.component.parameter.ttp_public_key | | public key of beaver service | | runtime.component.parameter.ttp_adjust_rank | 0 | which rank do adjust rpc call to beaver service | | system.storage.host.url | file://path/to/root | root path of input/output file | | runtime.component.input.train_data | {"namespace":"data","name":"perfect_logit_a.csv"} | relative path and name of input file | diff --git a/WORKSPACE b/WORKSPACE index 64cedb8..ca92819 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -20,14 +20,13 @@ protocol_deps() load("//bazel:repositories.bzl", "ic_impl_deps") ic_impl_deps() -load("@psi//bazel:repositories.bzl", "psi_deps") - -psi_deps() - # spu load("@spulib//bazel:repositories.bzl", "spu_deps") spu_deps() +load("@psi//bazel:repositories.bzl", "psi_deps") +psi_deps() + # yacl load("@yacl//bazel:repositories.bzl", "yacl_deps") yacl_deps() diff --git a/bazel/repositories.bzl b/bazel/repositories.bzl index ad86f63..8399dcc 100644 --- a/bazel/repositories.bzl +++ b/bazel/repositories.bzl @@ -14,12 +14,11 @@ load("@bazel_tools//tools/build_defs/repo:utils.bzl", "maybe") load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository") +load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") SECRETFLOW_GIT = "https://github.com/secretflow" -SPU_COMMIT_ID = "8bf3c97da503f1cffd1292c8e365ecbc30675400" - -PSI_COMMIT_ID = "330623e9eb42d92ef9701d352af947b70c2c9e7c" +SPU_COMMIT_ID = "74c6d54e5bff9e5d35ab5a73dda98cd9ccf0bfc8" IC_COMMIT_ID = "30e4220b7444d0bb077a9040f1b428632124e31a" @@ -35,13 +34,6 @@ def ic_impl_deps(): remote = "{}/{}.git".format(SECRETFLOW_GIT, SPU_REPOSITORY), ) - maybe( - git_repository, - name = "psi", - commit = PSI_COMMIT_ID, - remote = "{}/psi.git".format(SECRETFLOW_GIT), - ) - def protocol_deps(): maybe( git_repository, @@ -52,8 +44,12 @@ def protocol_deps(): def _com_github_nlohmann_json(): maybe( - git_repository, + http_archive, name = "com_github_nlohmann_json", - commit = "5d2754306d67d1e654a1a34e1d2e74439a9d53b3", - remote = "git@github.com:nlohmann/json.git", + sha256 = "0d8ef5af7f9794e3263480193c491549b2ba6cc74bb018906202ada498a79406", + strip_prefix = "json-3.11.3", + type = "tar.gz", + urls = [ + "https://github.com/nlohmann/json/archive/refs/tags/v3.11.3.tar.gz", + ], ) diff --git a/ic_impl/algo/lr/lr_handler.cc b/ic_impl/algo/lr/lr_handler.cc index 4fe639f..1f66c04 100644 --- a/ic_impl/algo/lr/lr_handler.cc +++ b/ic_impl/algo/lr/lr_handler.cc @@ -286,7 +286,6 @@ bool LrHandler::ProcessHandshakeResponse(const HandshakeResponseV2& response) { YACL_ENFORCE(ss_param.triple_config().version() == ctx_->ttp_config.ttp_server_version); ctx_->ttp_config.ttp_server_host = ss_param.triple_config().server_host(); - ctx_->ttp_config.ttp_session_id = ss_param.triple_config().session_id(); ctx_->ttp_config.ttp_adjust_rank = ss_param.triple_config().adjust_rank(); return true; @@ -642,8 +641,6 @@ HandshakeResponseV2 LrHandler::BuildHandshakeResponse() { ctx_->ttp_config.ttp_server_host); ss_param.mutable_triple_config()->set_version( ctx_->ttp_config.ttp_server_version); - ss_param.mutable_triple_config()->set_session_id( - ctx_->ttp_config.ttp_session_id); ss_param.mutable_triple_config()->set_adjust_rank( ctx_->ttp_config.ttp_adjust_rank); response.add_protocol_family_params()->PackFrom(ss_param); @@ -755,10 +752,12 @@ std::unique_ptr LrHandler::MakeSpuContext() { config.set_beaver_type(spu::RuntimeConfig_BeaverType_TrustedThirdParty); config.mutable_ttp_beaver_config()->set_server_host( ctx_->ttp_config.ttp_server_host); + config.mutable_ttp_beaver_config()->set_asym_crypto_schema( + ctx_->ttp_config.ttp_asym_crypto_schema); + config.mutable_ttp_beaver_config()->set_server_public_key( + ctx_->ttp_config.ttp_public_key); config.mutable_ttp_beaver_config()->set_adjust_rank( ctx_->ttp_config.ttp_adjust_rank); - config.mutable_ttp_beaver_config()->set_session_id( - ctx_->ttp_config.ttp_session_id); } else { config.set_beaver_type(spu::RuntimeConfig_BeaverType_TrustedFirstParty); } diff --git a/ic_impl/env/ss-lr-env-alice.sh b/ic_impl/env/ss-lr-env-alice.sh index 6633ee0..9180549 100644 --- a/ic_impl/env/ss-lr-env-alice.sh +++ b/ic_impl/env/ss-lr-env-alice.sh @@ -22,7 +22,8 @@ exec env \ "runtime.component.parameter.shard_serialize_format=raw" \ "runtime.component.parameter.use_ttp=true" \ "runtime.component.parameter.ttp_server_host=127.0.0.1:9449" \ -"runtime.component.parameter.ttp_session_id=interconnection-root-1" \ +"runtime.component.parameter.ttp_asym_crypto_schema=sm2" \ +"runtime.component.parameter.ttp_public_key=LS0tLS1CRUdJTiBQVUJMSUMgS0VZLS0tLS0KTUZvd0ZBWUlLb0VjejFVQmdpMEdDQ3FCSE05VkFZSXRBMElBQkxROGc3Q3R3eis3OW1ZQTcrbCtVOWliTjdaSAorVG1yZ0R4bGk2dmpWNTlaaG14c3U4eDNSR1pBS09Bd2gzNnhYcHhHZ2F5MXRGYWtDazNMZFV0azJzND0KLS0tLS1FTkQgUFVCTElDIEtFWS0tLS0tCg==" \ "runtime.component.parameter.ttp_adjust_rank=0" \ "runtime.component.parameter.label_owner=host.0" \ 'runtime.component.parameter.feature_nums={"host.0":10, "guest.0":10}' \ diff --git a/ic_impl/env/ss-lr-env-bob.sh b/ic_impl/env/ss-lr-env-bob.sh index d66b0e2..14bb7ec 100644 --- a/ic_impl/env/ss-lr-env-bob.sh +++ b/ic_impl/env/ss-lr-env-bob.sh @@ -22,7 +22,8 @@ exec env \ "runtime.component.parameter.shard_serialize_format=raw" \ "runtime.component.parameter.use_ttp=true" \ "runtime.component.parameter.ttp_server_host=127.0.0.1:9449" \ -"runtime.component.parameter.ttp_session_id=interconnection-root-1" \ +"runtime.component.parameter.ttp_asym_crypto_schema=sm2" \ +"runtime.component.parameter.ttp_public_key=LS0tLS1CRUdJTiBQVUJMSUMgS0VZLS0tLS0KTUZvd0ZBWUlLb0VjejFVQmdpMEdDQ3FCSE05VkFZSXRBMElBQkxROGc3Q3R3eis3OW1ZQTcrbCtVOWliTjdaSAorVG1yZ0R4bGk2dmpWNTlaaG14c3U4eDNSR1pBS09Bd2gzNnhYcHhHZ2F5MXRGYWtDazNMZFV0azJzND0KLS0tLS1FTkQgUFVCTElDIEtFWS0tLS0tCg==" \ "runtime.component.parameter.ttp_adjust_rank=0" \ "runtime.component.parameter.label_owner=host.0" \ 'runtime.component.parameter.feature_nums={"host.0":10, "guest.0":10}' \ diff --git a/ic_impl/protocol_family/ss/BUILD.bazel b/ic_impl/protocol_family/ss/BUILD.bazel index 6fb2251..c8f2866 100644 --- a/ic_impl/protocol_family/ss/BUILD.bazel +++ b/ic_impl/protocol_family/ss/BUILD.bazel @@ -23,6 +23,7 @@ cc_library( deps = [ "//ic_impl:util", "//ic_impl:handshake_cc_proto", + "@com_github_brpc_brpc//:brpc", "@com_github_gflags_gflags//:gflags", ] ) diff --git a/ic_impl/protocol_family/ss/ss.cc b/ic_impl/protocol_family/ss/ss.cc index e991406..e33dc89 100644 --- a/ic_impl/protocol_family/ss/ss.cc +++ b/ic_impl/protocol_family/ss/ss.cc @@ -14,6 +14,7 @@ #include "ic_impl/protocol_family/ss/ss.h" +#include "butil/base64.h" #include "gflags/gflags.h" #include "ic_impl/util.h" @@ -33,8 +34,9 @@ DEFINE_bool(use_ttp, false, "whether use trusted third party's beaver service"); DEFINE_string( ttp_server_host, "127.0.0.1:9449", "trustedThirdParty beaver server's remote ip:port or load-balance uri"); -DEFINE_string(ttp_session_id, "interconnection-root", - "trustedThirdParty beaver server's session id"); +DEFINE_string(ttp_asym_crypto_schema, "sm2", + "asym_crypto_schema: support [\"SM2\"]"); +DEFINE_string(ttp_public_key, "", "TTP public key"); DEFINE_int32(ttp_adjust_rank, 0, "which rank do adjust rpc call"); namespace ic_impl::protocol_family::ss { @@ -77,8 +79,16 @@ std::string SuggestedTtpServerHost() { return util::GetParamEnv("ttp_server_host", FLAGS_ttp_server_host); } -std::string SuggestedTtpSessionId() { - return util::GetParamEnv("ttp_session_id", FLAGS_ttp_session_id); +std::string SuggestedTtpAsymCryptoSchema() { + return util::GetParamEnv("ttp_asym_crypto_schema", + FLAGS_ttp_asym_crypto_schema); +} + +std::string SuggestedTtpPublicKey() { + std::string ret; + auto pk = util::GetParamEnv("ttp_public_key", FLAGS_ttp_public_key); + YACL_ENFORCE(butil::Base64Decode(pk, &ret)); + return ret; } int32_t SuggestedTtpAdjustRank() { @@ -102,7 +112,8 @@ TrustedThirdPartyConfig SuggestedTtpConfig() { TrustedThirdPartyConfig ttp_config; ttp_config.use_ttp = SuggestedUseTtp(); ttp_config.ttp_server_host = SuggestedTtpServerHost(); - ttp_config.ttp_session_id = SuggestedTtpSessionId(); + ttp_config.ttp_asym_crypto_schema = SuggestedTtpAsymCryptoSchema(); + ttp_config.ttp_public_key = SuggestedTtpPublicKey(); ttp_config.ttp_adjust_rank = SuggestedTtpAdjustRank(); return ttp_config; diff --git a/ic_impl/protocol_family/ss/ss.h b/ic_impl/protocol_family/ss/ss.h index 68bed14..3fc1de0 100644 --- a/ic_impl/protocol_family/ss/ss.h +++ b/ic_impl/protocol_family/ss/ss.h @@ -29,8 +29,9 @@ struct SsProtocolParam { struct TrustedThirdPartyConfig { bool use_ttp; std::string ttp_server_host; - int32_t ttp_server_version = 1; - std::string ttp_session_id; + int32_t ttp_server_version = 2; + std::string ttp_asym_crypto_schema; + std::string ttp_public_key; int32_t ttp_adjust_rank; };