Skip to content

Commit

Permalink
Repo sync (#586)
Browse files Browse the repository at this point in the history
# Pull Request

## What problem does this PR solve?

Issue Number: Fixed #

## Possible side effects?

- Performance:

- Backward compatibility:
  • Loading branch information
anakinxc authored Mar 4, 2024
1 parent dd968fa commit cb69a32
Show file tree
Hide file tree
Showing 14 changed files with 178 additions and 129 deletions.
6 changes: 3 additions & 3 deletions bazel/repositories.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,10 @@ def _libpsi():
http_archive,
name = "psi",
urls = [
"https://github.com/secretflow/psi/archive/refs/tags/v0.3.0.dev240222.tar.gz",
"https://github.com/secretflow/psi/archive/refs/tags/v0.3.0.dev240304.tar.gz",
],
strip_prefix = "psi-0.3.0.dev240222",
sha256 = "a7319040510a1581741f05ac4b31e3d887ba8ba4766154736f96d76970d00de5",
strip_prefix = "psi-0.3.0.dev240304",
sha256 = "6e56dceaffbe67f7d17fbb32a5486ec31c6f17156aadb9ac57f47e4c7fe6b384",
)

def _rules_proto_grpc():
Expand Down
1 change: 1 addition & 0 deletions libspu/core/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ spu_cc_library(
deps = [
"//libspu:spu_cc_proto",
"//libspu/core:prelude",
"@yacl//yacl/utils:parallel",
],
)

Expand Down
5 changes: 5 additions & 0 deletions libspu/core/config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "libspu/core/config.h"

#include "prelude.h"
#include "yacl/utils/parallel.h"

namespace spu {
namespace {
Expand Down Expand Up @@ -43,6 +44,10 @@ void populateRuntimeConfig(RuntimeConfig& cfg) {
SPU_ENFORCE(cfg.protocol() != ProtocolKind::PROT_INVALID);
SPU_ENFORCE(cfg.field() != FieldType::FT_INVALID);

if (cfg.max_concurrency() == 0) {
cfg.set_max_concurrency(yacl::get_num_threads());
}

//
if (cfg.fxp_fraction_bits() == 0) {
cfg.set_fxp_fraction_bits(defaultFxpBits(cfg.field()));
Expand Down
24 changes: 23 additions & 1 deletion libspu/core/context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@

#include "libspu/core/context.h"

#include "yacl/link/algorithm/allgather.h"
#include "yacl/utils/parallel.h"

#include "libspu/core/trace.h"

namespace spu {
Expand All @@ -35,7 +38,26 @@ SPUContext::SPUContext(const RuntimeConfig& config,
const std::shared_ptr<yacl::link::Context>& lctx)
: config_(config),
prot_(std::make_unique<Object>(genRootObjectId(lctx))),
lctx_(lctx) {}
lctx_(lctx),
max_cluster_level_concurrency_(yacl::get_num_threads()) {
// Limit number of threads
if (config.max_concurrency() > 0) {
yacl::set_num_threads(config.max_concurrency());
max_cluster_level_concurrency_ = std::min<int32_t>(
max_cluster_level_concurrency_, config.max_concurrency());
}

if (lctx_) {
auto other_max = yacl::link::AllGather(
lctx, {&max_cluster_level_concurrency_, sizeof(int32_t)}, "num_cores");

// Comupte min
for (const auto& o : other_max) {
max_cluster_level_concurrency_ = std::min<int32_t>(
max_cluster_level_concurrency_, o.data<int32_t>()[0]);
}
}
}

std::unique_ptr<SPUContext> SPUContext::fork() const {
std::shared_ptr<yacl::link::Context> new_lctx =
Expand Down
9 changes: 9 additions & 0 deletions libspu/core/context.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ class SPUContext final {
// TODO(jint): do we really need a link here? how about a FHE context.
std::shared_ptr<yacl::link::Context> lctx_;

// Min number of cores in SPU cluster
int32_t max_cluster_level_concurrency_;

public:
explicit SPUContext(const RuntimeConfig& config,
const std::shared_ptr<yacl::link::Context>& lctx);
Expand Down Expand Up @@ -81,6 +84,12 @@ class SPUContext final {
StateT* getState() {
return prot_->template getState<StateT>();
}

// If any task assumes same level of parallelism across all instances,
// this is the max number of tasks to launch at the same time.
int32_t getClusterLevelMaxConcurrency() const {
return max_cluster_level_concurrency_;
}
};

class KernelEvalContext final {
Expand Down
3 changes: 3 additions & 0 deletions libspu/spu.proto
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,9 @@ message RuntimeConfig {
// 0(default) indicates implementation defined.
int64 fxp_fraction_bits = 3;

// Max number of cores
int32 max_concurrency = 4;

///////////////////////////////////////
// Advanced
///////////////////////////////////////
Expand Down
4 changes: 2 additions & 2 deletions spu/libpsi.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ void BindLibs(py::module& m) {
psi::BucketPsiConfig config;
YACL_ENFORCE(config.ParseFromString(config_pb));

psi::BucketPsi psi(config, lctx, ic_mode);
auto r = psi.Run(std::move(progress_callbacks), callbacks_interval_ms);
auto r = psi::RunLegacyPsi(config, lctx, std::move(progress_callbacks),
callbacks_interval_ms, ic_mode);
return r.SerializeAsString();
},
py::arg("link_context"), py::arg("psi_config"),
Expand Down
2 changes: 2 additions & 0 deletions spu/tests/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@ py_test(
name = "link_test",
srcs = ["link_test.py"],
deps = [
":utils",
"//spu:api",
],
)
Expand Down Expand Up @@ -308,6 +309,7 @@ py_test(
"exclusive-if-local",
],
deps = [
":utils",
"//spu/utils:distributed",
],
)
Expand Down
21 changes: 7 additions & 14 deletions spu/tests/distributed_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,20 +27,13 @@

import spu.utils.distributed as ppd
from spu import spu_pb2


def unused_tcp_port() -> int:
"""Return an unused port"""
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
sock.bind(("localhost", 0))
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
return cast(int, sock.getsockname()[1])
from spu.tests.utils import get_free_port


TEST_NODES_DEF = {
"node:0": f"127.0.0.1:{unused_tcp_port()}",
"node:1": f"127.0.0.1:{unused_tcp_port()}",
"node:2": f"127.0.0.1:{unused_tcp_port()}",
"node:0": f"127.0.0.1:{get_free_port()}",
"node:1": f"127.0.0.1:{get_free_port()}",
"node:2": f"127.0.0.1:{get_free_port()}",
}


Expand All @@ -50,9 +43,9 @@ def unused_tcp_port() -> int:
"config": {
"node_ids": ["node:0", "node:1", "node:2"],
"spu_internal_addrs": [
f"127.0.0.1:{unused_tcp_port()}",
f"127.0.0.1:{unused_tcp_port()}",
f"127.0.0.1:{unused_tcp_port()}",
f"127.0.0.1:{get_free_port()}",
f"127.0.0.1:{get_free_port()}",
f"127.0.0.1:{get_free_port()}",
],
"runtime_config": {
"protocol": "ABY3",
Expand Down
23 changes: 9 additions & 14 deletions spu/tests/link_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,14 @@

import spu.libspu.link as link
from socket import socket


def _rand_port():
with socket() as s:
s.bind(("localhost", 0))
return s.getsockname()[1]
from spu.tests.utils import get_free_port


class UnitTests(unittest.TestCase):
def test_link_brpc(self):
desc = link.Desc()
desc.add_party("alice", f"127.0.0.1:{_rand_port()}")
desc.add_party("bob", f"127.0.0.1:{_rand_port()}")
desc.add_party("alice", f"127.0.0.1:{get_free_port()}")
desc.add_party("bob", f"127.0.0.1:{get_free_port()}")

def proc(rank):
data = "hello" if rank == 0 else "world"
Expand Down Expand Up @@ -90,8 +85,8 @@ def thread(rank):

def test_link_send_recv(self):
desc = link.Desc()
desc.add_party("alice", f"127.0.0.1:{_rand_port()}")
desc.add_party("bob", f"127.0.0.1:{_rand_port()}")
desc.add_party("alice", f"127.0.0.1:{get_free_port()}")
desc.add_party("bob", f"127.0.0.1:{get_free_port()}")

def proc(rank):
lctx = link.create_brpc(desc, rank)
Expand All @@ -116,8 +111,8 @@ def proc(rank):

def test_link_send_async(self):
desc = link.Desc()
desc.add_party("alice", f"127.0.0.1:{_rand_port()}")
desc.add_party("bob", f"127.0.0.1:{_rand_port()}")
desc.add_party("alice", f"127.0.0.1:{get_free_port()}")
desc.add_party("bob", f"127.0.0.1:{get_free_port()}")

def proc(rank):
lctx = link.create_brpc(desc, rank)
Expand All @@ -140,8 +135,8 @@ def proc(rank):

def test_link_next_rank(self):
desc = link.Desc()
desc.add_party("alice", f"127.0.0.1:{_rand_port()}")
desc.add_party("bob", f"127.0.0.1:{_rand_port()}")
desc.add_party("alice", f"127.0.0.1:{get_free_port()}")
desc.add_party("bob", f"127.0.0.1:{get_free_port()}")

def proc(rank):
lctx = link.create_brpc(desc, rank)
Expand Down
68 changes: 39 additions & 29 deletions spu/tests/pir_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,28 @@

import spu.libspu.link as link
import spu.psi as psi
from spu.tests.utils import create_clean_folder, create_link_desc, wc_count
from spu.tests.utils import create_link_desc, wc_count
from tempfile import TemporaryDirectory


class UnitTests(unittest.TestCase):
def setUp(self) -> None:
self.tempdir_ = TemporaryDirectory()
return super().setUp()

def tearDown(self) -> None:
self.tempdir_.cleanup()
return super().tearDown()

def test_pir(self):
# setup stage

server_setup_config = '''
{
server_setup_config = f'''
{{
"mode": "MODE_SERVER_SETUP",
"pir_protocol": "PIR_PROTOCOL_KEYWORD_PIR_APSI",
"pir_server_config": {
"pir_server_config": {{
"input_path": "spu/tests/data/alice.csv",
"setup_path": "/tmp/spu_test_pir_pir_server_setup",
"setup_path": "{self.tempdir_.name}/spu_test_pir_pir_server_setup",
"key_columns": [
"id"
],
Expand All @@ -42,56 +50,56 @@ def test_pir(self):
],
"label_max_len": 288,
"bucket_size": 1000000,
"apsi_server_config": {
"oprf_key_path": "/tmp/spu_test_pir_server_secret_key.bin",
"apsi_server_config": {{
"oprf_key_path": "{self.tempdir_.name}/spu_test_pir_server_secret_key.bin",
"num_per_query": 1,
"compressed": false
}
}
}
}}
}}
}}
'''

with open("/tmp/spu_test_pir_server_secret_key.bin", 'wb') as f:
with open(
f"{self.tempdir_.name}/spu_test_pir_server_secret_key.bin", 'wb'
) as f:
f.write(
bytes.fromhex(
"000102030405060708090a0b0c0d0e0ff0e0d0c0b0a090807060504030201000"
)
)

create_clean_folder("/tmp/spu_test_pir_pir_server_setup")

psi.pir(json_format.ParseDict(json.loads(server_setup_config), psi.PirConfig()))

server_online_config = '''
{
server_online_config = f'''
{{
"mode": "MODE_SERVER_ONLINE",
"pir_protocol": "PIR_PROTOCOL_KEYWORD_PIR_APSI",
"pir_server_config": {
"setup_path": "/tmp/spu_test_pir_pir_server_setup"
}
}
"pir_server_config": {{
"setup_path": "{self.tempdir_.name}/spu_test_pir_pir_server_setup"
}}
}}
'''

client_online_config = '''
{
client_online_config = f'''
{{
"mode": "MODE_CLIENT",
"pir_protocol": "PIR_PROTOCOL_KEYWORD_PIR_APSI",
"pir_client_config": {
"input_path": "/tmp/spu_test_pir_pir_client.csv",
"pir_client_config": {{
"input_path": "{self.tempdir_.name}/spu_test_pir_pir_client.csv",
"key_columns": [
"id"
],
"output_path": "/tmp/spu_test_pir_pir_output.csv"
}
}
"output_path": "{self.tempdir_.name}/spu_test_pir_pir_output.csv"
}}
}}
'''

pir_client_input_content = '''id
user808
xxx
'''

with open("/tmp/spu_test_pir_pir_client.csv", 'w') as f:
with open(f"{self.tempdir_.name}/spu_test_pir_pir_client.csv", 'w') as f:
f.write(pir_client_input_content)

configs = [
Expand All @@ -118,7 +126,9 @@ def wrap(rank, link_desc, configs):
self.assertEqual(job.exitcode, 0)

# including title, actual matched item cnt is 1.
self.assertEqual(wc_count("/tmp/spu_test_pir_pir_output.csv"), 2)
self.assertEqual(
wc_count(f"{self.tempdir_.name}/spu_test_pir_pir_output.csv"), 2
)


if __name__ == '__main__':
Expand Down
Loading

0 comments on commit cb69a32

Please sign in to comment.