Skip to content

Commit

Permalink
Repo sync (#520)
Browse files Browse the repository at this point in the history
  • Loading branch information
anakinxc authored Jan 24, 2024
1 parent 8e99b4c commit 42e6278
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 11 deletions.
2 changes: 1 addition & 1 deletion spu/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ pybind_extension(
":version_script.lds",
"@psi//psi/pir",
"@psi//psi/psi:bucket_psi",
"@psi//psi/psi:factory",
"@psi//psi/psi:launch",
"@psi//psi/psi:memory_psi",
"@yacl//yacl/link",
],
Expand Down
23 changes: 17 additions & 6 deletions spu/libpsi.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

#include "psi/pir/pir.h"
#include "psi/psi/bucket_psi.h"
#include "psi/psi/factory.h"
#include "psi/psi/launch.h"
#include "psi/psi/memory_psi.h"
#include "psi/psi/utils/progress.h"

Expand Down Expand Up @@ -84,13 +84,24 @@ void BindLibs(py::module& m) {
psi::v2::PsiConfig psi_config;
YACL_ENFORCE(psi_config.ParseFromString(config_pb));

std::unique_ptr<psi::AbstractPsiParty> psi_party =
psi::createPsiParty(psi_config, lctx);
auto report = psi_party->Run();
auto report = psi::RunPsi(psi_config, lctx);
return report.SerializeAsString();
},
py::arg("psi_config"), py::arg("link_context") = nullptr,
"Run PSI with v2 API.", NO_GIL);
py::arg("psi_config"), py::arg("link_context"), "Run PSI with v2 API.",
NO_GIL);

m.def(
"ub_psi",
[](const std::string& config_pb,
const std::shared_ptr<yacl::link::Context>& lctx) -> py::bytes {
psi::v2::UbPsiConfig ub_psi_config;
YACL_ENFORCE(ub_psi_config.ParseFromString(config_pb));

auto report = psi::RunUbPsi(ub_psi_config, lctx);
return report.SerializeAsString();
},
py::arg("ub_psi_config"), py::arg("link_context") = nullptr,
"Run UB PSI with v2 API.", NO_GIL);

m.def(
"pir_setup",
Expand Down
42 changes: 38 additions & 4 deletions spu/psi.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,20 @@
PsiResultReport,
PsiType,
)
from .psi_v2_pb2 import PsiConfig
from .psi_v2_pb2 import (
DebugOptions,
EcdhConfig,
IoConfig,
IoType,
KkrtConfig,
Protocol,
ProtocolConfig,
PsiConfig,
RecoveryConfig,
Role,
Rr22Config,
UbPsiConfig,
)


def mem_psi(
Expand Down Expand Up @@ -92,10 +105,11 @@ def gen_cache_for_2pc_ub_psi(config: BucketPsiConfig) -> PsiResultReport:

def psi_v2(
config: PsiConfig,
link: Context = None,
) -> PsiReport:
link: Context,
) -> PsiResultReport:
"""
Run PSI with v2 API.
Check PsiConfig at https://www.secretflow.org.cn/docs/psi/latest/en-US/reference/psi_v2_config#psiconfig.
:param config: psi config
:param link: the transport layer
:return: statistical results
Expand All @@ -104,6 +118,26 @@ def psi_v2(
config.SerializeToString(),
link,
)
report = PsiReport()
report = PsiResultReport()
report.ParseFromString(report_str)
return report


def ub_psi(
config: UbPsiConfig,
link: Context = None,
) -> PsiResultReport:
"""
Run PSI with v2 API.
Check UbPsiConfig at https://www.secretflow.org.cn/docs/psi/latest/en-US/reference/psi_v2_config#ubpsiconfig.
:param config: ub psi config
:param link: the transport layer
:return: statistical results
"""
report_str = libpsi.libs.ub_psi(
config.SerializeToString(),
link,
)
report = PsiResultReport()
report.ParseFromString(report_str)
return report

0 comments on commit 42e6278

Please sign in to comment.