diff --git a/src/processing/plugins/heu/generate_key b/src/processing/plugins/heu/generate_key new file mode 100755 index 000000000000..dace948f0241 Binary files /dev/null and b/src/processing/plugins/heu/generate_key differ diff --git a/src/processing/plugins/heu/generate_key.cc b/src/processing/plugins/heu/generate_key.cc new file mode 100644 index 000000000000..fd5b4ee64d88 --- /dev/null +++ b/src/processing/plugins/heu/generate_key.cc @@ -0,0 +1,45 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "gflags/gflags.h" +#include "heu/library/phe/phe.h" + +int GenerateFile(std::string_view file_name, std::string_view buf) { + int fd = open(file_name.data(), O_CREAT | O_TRUNC | O_WRONLY, 0664); + YACL_ENFORCE(fd != -1, "errno {}, {}", errno, strerror(errno)); + + auto ret = write(fd, buf.data(), buf.size()); + YACL_ENFORCE(ret != -1, "errno {}, {}", errno, strerror(errno)); + close(fd); + return 0; +} + +DEFINE_string(schema, "ou", "Schema"); +DEFINE_int32(key_size, 2048, "Key size of phe schema."); + +int main(int argc, char **argv) { + google::ParseCommandLineFlags(&argc, &argv, true); + fmt::print("schema: {}, key_size: {}\n", FLAGS_schema, FLAGS_key_size); + auto schema_type = heu::lib::phe::ParseSchemaType(FLAGS_schema); + auto he_kit = + std::make_unique(schema_type, FLAGS_key_size); + auto pk = he_kit->GetPublicKey()->Serialize(); + auto sk = he_kit->GetSecretKey()->Serialize(); + GenerateFile("public-key", pk); + GenerateFile("secret-key", sk); + fmt::print("generate key files done\n"); + return 0; +} diff --git a/src/processing/plugins/heu/heu_processor.cc b/src/processing/plugins/heu/heu_processor.cc new file mode 100644 index 000000000000..b0297355b0c8 --- /dev/null +++ b/src/processing/plugins/heu/heu_processor.cc @@ -0,0 +1,227 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "processing/heu_processor.h" + +#include + +#include "absl/strings/numbers.h" +#include "heu/library/phe/encoding/batch_encoder.h" + +namespace processing { + +namespace { + +using heu::lib::numpy::DestinationHeKit; +using heu::lib::numpy::HeKit; + +using heu::lib::numpy::CMatrix; +using heu::lib::numpy::PMatrix; + +using heu::lib::phe::BatchEncoder; +using heu::lib::phe::Ciphertext; +using heu::lib::phe::Plaintext; + +using std::string; + +const char kPublicKeyPath[] = "PUBLIC_KEY_PATH"; +const char kSecretKeyPath[] = "SECRET_KEY_PATH"; +const char pk_file[] = "public-key"; +const char sk_file[] = "secret-key"; +const int64_t kDefaultScale = 1e6; + +yacl::Buffer ReadFile(std::string_view file_name) { + int fd = open(file_name.data(), O_RDONLY); + YACL_ENFORCE(fd != -1, "errno {}, {}", errno, strerror(errno)); + + yacl::Buffer buf; + const int cnt = 100; + buf.reserve(cnt); + ssize_t num_read; + while ((num_read = read(fd, buf.data() + buf.size(), cnt)) > 0) { + YACL_ENFORCE(num_read != -1, "errno {}, {}", errno, strerror(errno)); + buf.resize(buf.size() + num_read); + buf.reserve(buf.size() + cnt); + } + + close(fd); + return buf; +} + +yacl::Buffer GetPublicKey(const std::map ¶ms) { + auto it = params.find(kPublicKeyPath); + if (it == params.end()) { + return ReadFile(pk_file); + } else { + return ReadFile(absl::StrCat(it->second, "/", pk_file)); + } +} + +yacl::Buffer GetSecretKey(const std::map ¶ms) { + auto it = params.find(kSecretKeyPath); + if (it == params.end()) { + return ReadFile(sk_file); + } else { + return ReadFile(absl::StrCat(it->second, "/", sk_file)); + } +} + +int64_t GetScale(const std::map ¶ms) { + auto it = params.find("scale"); + if (it == params.end()) { + return kDefaultScale; + } + + int64_t scale; + YACL_ENFORCE(absl::SimpleAtoi(it->second, &scale)); + YACL_ENFORCE(scale > 0); + return scale; +} + +} // namespace + +void HeuProcessor::Initialize(bool active, std::map params) { + active_ = active; + + auto pk_buffer = GetPublicKey(params); + if (active_) { + auto sk_buffer = GetSecretKey(params); + he_kit_ = + std::make_unique(heu::lib::phe::HeKit(pk_buffer, sk_buffer)); + scale_ = GetScale(params); + } else { + dest_he_kit_ = std::make_unique( + heu::lib::phe::DestinationHeKit(pk_buffer)); + } +} + +void HeuProcessor::Shutdown() { + this->cuts_.clear(); + this->slots_.clear(); + + he_kit_ = nullptr; + dest_he_kit_ = nullptr; + gh_ = nullptr; + scale_ = 0; +} + +void HeuProcessor::FreeBuffer(void *buffer) { free(buffer); } + +void *HeuProcessor::ProcessGHPairs(size_t *size, + const std::vector &pairs) { + YACL_ENFORCE(active_, "only active party allowed to call this function"); + YACL_ENFORCE(he_kit_, "he_kit equals to nullptr"); + YACL_ENFORCE(scale_ > 0, "scale not set"); + + auto encoder = he_kit_->GetEncoder(scale_); + PMatrix gh(pairs.size() / 2); + + gh.ForEach([&](int64_t row, int64_t, Plaintext *pt) { + *pt = encoder.Encode(pairs[2 * row], pairs[2 * row + 1]); + }); + + auto encryptor = he_kit_->GetEncryptor(); + gh_ = std::make_unique(encryptor->Encrypt(gh)); + auto buf = gh_->Serialize(); + *size = buf.size(); + return buf.release(); +} + +void *HeuProcessor::HandleGHPairs(size_t *size, void *buffer, size_t buf_size) { + *size = buf_size; + gh_ = std::make_unique( + CMatrix::LoadFrom(yacl::ByteContainerView(buffer, buf_size))); + + return buffer; // TODO: directly return buffer? +} + +void HeuProcessor::InitAggregationContext(const std::vector &cuts, + const std::vector &slots) { + this->cuts_ = cuts; + if (this->slots_.empty()) { + this->slots_ = slots; + } +} + +void *HeuProcessor::ProcessAggregation(size_t *size, + std::map> nodes) { + YACL_ENFORCE(dest_he_kit_, "dest_he_kit equals to nullptr"); + YACL_ENFORCE(gh_, "GH ciphertext matrix not set"); + + auto evaluator = dest_he_kit_->GetEvaluator(); + auto encryptor = dest_he_kit_->GetEncryptor(); + int total_bin_size = cuts_.back(); + auto feature_num = cuts_.size() - 1; + + CMatrix histograms(nodes.size(), total_bin_size); + auto zero = encryptor->EncryptZero(); + histograms.ForEach([&](int64_t, int64_t, Ciphertext *pt) { *pt = zero; }); + + int histo_i = 0; + for (const auto &node : nodes) { + const auto &rows = node.second; + for (int row_id : rows) { + yacl::parallel_for(0, feature_num, 1, [&](int64_t beg, int64_t end) { + for (int64_t f = beg; f < end; ++f) { + int slot = slots_[f + feature_num * row_id]; + if ((slot < 0) || (slot >= total_bin_size)) { + continue; + } + const auto &gh = (*gh_)(row_id); + evaluator->AddInplace(&histograms(histo_i, slot), gh); + } + }); + } + ++histo_i; + } + + auto buf = histograms.Serialize(); + *size = buf.size(); + return buf.release(); +} + +std::vector HeuProcessor::HandleAggregation(void *buffer, + size_t buf_size) { + YACL_ENFORCE(active_, "only active party allowed to call this function"); + YACL_ENFORCE(he_kit_, "he_kit equals to nullptr"); + YACL_ENFORCE(scale_ > 0, "scale not set"); + + auto decryptor = he_kit_->GetDecryptor(); + auto encoder = he_kit_->GetEncoder(scale_); + size_t offset = 0; + std::vector result; + + while (offset != buf_size) { + auto histogram = CMatrix::LoadFrom( + yacl::ByteContainerView(buffer, buf_size), + heu::lib::numpy::MatrixSerializeFormat::Best, &offset); + auto plaintexts = decryptor->Decrypt(histogram); + for (int i = 0; i < plaintexts.rows(); ++i) { + for (int j = 0; j < plaintexts.cols(); ++j) { + result.push_back(encoder.Decode(plaintexts(i, j))); + result.push_back(encoder.Decode(plaintexts(i, j))); + } + } + } + + return result; +} + +extern "C" { +Processor *LoadProcessor(const char *) { + return new processing::HeuProcessor(); // TODO: on heap? +} +} + +} // namespace processing diff --git a/src/processing/plugins/heu/heu_processor.h b/src/processing/plugins/heu/heu_processor.h new file mode 100644 index 000000000000..025d17fb066a --- /dev/null +++ b/src/processing/plugins/heu/heu_processor.h @@ -0,0 +1,62 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "heu/library/numpy/numpy.h" + +#include "processing/processor.h" + +namespace processing { + +class HeuProcessor : public Processor { + public: + void Initialize(bool active, + std::map params) override; + + void Shutdown() override; + + void FreeBuffer(void *buffer) override; + + void *ProcessGHPairs(size_t *size, const std::vector &pairs) override; + + void *HandleGHPairs(size_t *size, void *buffer, size_t buf_size) override; + + void InitAggregationContext(const std::vector &cuts, + const std::vector &slots) override; + + void *ProcessAggregation(size_t *size, + std::map> nodes) override; + + std::vector HandleAggregation(void *buffer, size_t buf_size) override; + + void *ProcessHistograms(size_t *, const std::vector &) override { + YACL_THROW("not implemented"); + } + + std::vector HandleHistograms(void *, size_t) override { + YACL_THROW("not implemented"); + } + + private: + bool active_ = false; + int64_t scale_ = 0; + std::vector cuts_; + std::vector slots_; + std::unique_ptr gh_ = nullptr; + std::unique_ptr he_kit_ = nullptr; + std::unique_ptr dest_he_kit_ = nullptr; +}; + +} // namespace processing diff --git a/src/processing/plugins/heu/libproc_heu.so b/src/processing/plugins/heu/libproc_heu.so new file mode 100755 index 000000000000..65777849b69c Binary files /dev/null and b/src/processing/plugins/heu/libproc_heu.so differ diff --git a/src/processing/plugins/heu/processor_test.cc b/src/processing/plugins/heu/processor_test.cc new file mode 100644 index 000000000000..545c83ee6f0d --- /dev/null +++ b/src/processing/plugins/heu/processor_test.cc @@ -0,0 +1,99 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "processor.h" + +#include "gtest/gtest.h" + +namespace processing::test { + +class ProcessorTest : public testing::Test { + public: + void SetUp() override { + std::map params = { + {kLibraryPath, "processing/plugins"}}; + auto loader = processing::ProcessorLoader(params); + + active_processor_ = loader.Load("heu"); + active_processor_->Initialize(true, + {{"PUBLIC_KEY_PATH", "processing/plugins"}, + {"SECRET_KEY_PATH", "processing/plugins"}, + {"scale", "1000000"}}); + + passive_processor_ = loader.Load("heu"); + passive_processor_->Initialize(false, + {{"PUBLIC_KEY_PATH", "processing/plugins"}}); + } + + void TearDown() override { + active_processor_->Shutdown(); + active_processor_ = nullptr; // TODO: free? + + passive_processor_->Shutdown(); + passive_processor_ = nullptr; // TODO: free? + } + + protected: + processing::Processor *active_processor_ = nullptr; + processing::Processor *passive_processor_ = nullptr; + + // clang-format off + // Test data, 4 Rows, 2 Features + std::vector gh_pairs_ = { + 1.1, 2.1, + 3.1, 4.1, + 5.1, 6.1, + 7.1, 8.1 + }; // 4 Rows, 8 GH Pairs + std::vector cuts_ = {0, 4, 10}; // 2 features, one has 4 bins, another 6 + std::vector slots_ = { + 0, 4, + 1, 9, + 3, 7, + 0, 4 + }; + + std::vector node0_ = {0, 2}; + std::vector node1_ = {1, 3}; + + std::map> nodes_ = {{0, node0_}, + {1, node1_}}; + // clang-format on +}; + +TEST_F(ProcessorTest, TestAggregation) { + size_t buf_size; + auto *buffer = active_processor_->ProcessGHPairs(&buf_size, gh_pairs_); + passive_processor_->HandleGHPairs(&buf_size, buffer, buf_size); + active_processor_->FreeBuffer(buffer); + + passive_processor_->InitAggregationContext(cuts_, slots_); + buffer = passive_processor_->ProcessAggregation(&buf_size, nodes_); + auto histograms = active_processor_->HandleAggregation(buffer, buf_size); + passive_processor_->FreeBuffer(buffer); + + std::vector expected_result = { + 1.1, 2.1, 0, 0, 0, 0, 5.1, 6.1, 1.1, 2.1, 0, 0, 0, 0, + 5.1, 6.1, 0, 0, 0, 0, 7.1, 8.1, 3.1, 4.1, 0, 0, 0, 0, + 7.1, 8.1, 0, 0, 0, 0, 0, 0, 0, 0, 3.1, 4.1}; + + ASSERT_EQ(histograms.size(), expected_result.size()) + << "Histograms have different sizes"; + + for (size_t i = 0; i < histograms.size(); ++i) { + ASSERT_NEAR(histograms[i], expected_result[i], 1e-6); + } +} + +} // namespace processing::test