Skip to content

Commit

Permalink
Repo sync (#823)
Browse files Browse the repository at this point in the history
  • Loading branch information
anakinxc authored Aug 16, 2024
1 parent 0332ce2 commit c826c48
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 17 deletions.
3 changes: 1 addition & 2 deletions libspu/dialect/pphlo/transforms/partial_sort_to_topk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,7 @@ struct SortConversion : public OpRewritePattern<SimpleSortOp> {
}

// rewrite all slices
for (const auto &use : uses) {
auto slice = mlir::dyn_cast<SliceOp>(use.getOwner());
for (auto &slice : slices_to_rewrite) {
auto offset = slice.getStartIndices()[sort_dim] - start;
llvm::SmallVector<int64_t> new_start(slice.getStartIndices().begin(),
slice.getStartIndices().end());
Expand Down
31 changes: 16 additions & 15 deletions libspu/mpc/aby3/oram.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ NdArrayRef OramOneHotAA::proc(KernelEvalContext *ctx, const NdArrayRef &in,
// generate aeskey for dpf
auto [self_aes_keys, next_aes_keys] = oram::genAesKey(ctx, 1);

auto *octx = new oram::OramContext<el_t>(s);
auto octx = oram::OramContext<el_t>(s);

for (int64_t j = 0; j < 3; j++) {
// in round (rank - 1), as helper
Expand All @@ -64,16 +64,16 @@ NdArrayRef OramOneHotAA::proc(KernelEvalContext *ctx, const NdArrayRef &in,
auto target_point = dpf_rank ? target_idxs_[0][0] ^ target_idxs_[0][1]
: target_idxs_[0][0];
// dpf gen
octx->genDpf(ctx, static_cast<oram::DpfGenCtrl>(j), aes_key,
target_point);
octx.genDpf(ctx, static_cast<oram::DpfGenCtrl>(j), aes_key,
target_point);
// B2A
octx->onehotB2A(ctx, static_cast<oram::DpfGenCtrl>(j));
octx.onehotB2A(ctx, static_cast<oram::DpfGenCtrl>(j));
}
}

pforeach(0, s, [&](int64_t k) {
for (int64_t j = 0; j < 2; j++) {
out_[k][j] = octx->dpf_e[j][k];
out_[k][j] = octx.dpf_e[j][k];
}
});
});
Expand Down Expand Up @@ -124,15 +124,16 @@ NdArrayRef OramOneHotAP::proc(KernelEvalContext *ctx, const NdArrayRef &in,
comm->sendAsync<uint128_t>(dst_rank, {aes_key}, "aes_key");
aes_key += comm->recv<uint128_t>(dst_rank, "aes_key")[0];

auto *octx = new oram::OramContext<el_t>(s);
auto octx = oram::OramContext<el_t>(s);

// dpf gen
octx->genDpf(ctx, static_cast<oram::DpfGenCtrl>(1), aes_key,
target_point_2pc_[0]);
octx.genDpf(ctx, static_cast<oram::DpfGenCtrl>(1), aes_key,
target_point_2pc_[0]);
// B2A
octx->onehotB2A(ctx, static_cast<oram::DpfGenCtrl>(1));
octx.onehotB2A(ctx, static_cast<oram::DpfGenCtrl>(1));

int64_t j = comm->getRank() == 0 ? 1 : 0;
pforeach(0, s, [&](int64_t k) { out_[k] = octx->dpf_e[j][k]; });
pforeach(0, s, [&](int64_t k) { out_[k] = octx.dpf_e[j][k]; });
}
});

Expand Down Expand Up @@ -480,20 +481,20 @@ void OramContext<T>::genDpf(KernelEvalContext *ctx, DpfGenCtrl ctrl,
uint128_t aes_key, uint128_t target_point) {
auto *comm = ctx->getState<Communicator>();

auto *odpf = new OramDpf(dpf_size_, yacl::crypto::SecureRandU128(), aes_key,
static_cast<uint128_t>(target_point));
odpf->gen(ctx, ctrl);
auto odpf = OramDpf(dpf_size_, yacl::crypto::SecureRandU128(), aes_key,
static_cast<uint128_t>(target_point));
odpf.gen(ctx, ctrl);

auto dpf_rank = comm->getRank() == static_cast<size_t>(ctrl);
int64_t dpf_idx = dpf_rank ? 0 : 1;
T neg_flag = dpf_rank ? -1 : 1;

// cast e and v to T type and convert v to arith
// leave convert e outside
std::transform(odpf->final_e.begin(), odpf->final_e.begin() + dpf_size_,
std::transform(odpf.final_e.begin(), odpf.final_e.begin() + dpf_size_,
dpf_e[dpf_idx].begin(),
[&](uint8_t x) { return neg_flag * static_cast<T>(x); });
std::transform(odpf->final_v.begin(), odpf->final_v.begin() + dpf_size_,
std::transform(odpf.final_v.begin(), odpf.final_v.begin() + dpf_size_,
convert_help_v[dpf_idx].begin(),
[&](uint128_t x) { return neg_flag * static_cast<T>(x); });
};
Expand Down

0 comments on commit c826c48

Please sign in to comment.