diff --git a/libspu/dialect/pphlo/transforms/partial_sort_to_topk.cc b/libspu/dialect/pphlo/transforms/partial_sort_to_topk.cc index 139f2650..8d10e7a5 100644 --- a/libspu/dialect/pphlo/transforms/partial_sort_to_topk.cc +++ b/libspu/dialect/pphlo/transforms/partial_sort_to_topk.cc @@ -172,8 +172,7 @@ struct SortConversion : public OpRewritePattern { } // rewrite all slices - for (const auto &use : uses) { - auto slice = mlir::dyn_cast(use.getOwner()); + for (auto &slice : slices_to_rewrite) { auto offset = slice.getStartIndices()[sort_dim] - start; llvm::SmallVector new_start(slice.getStartIndices().begin(), slice.getStartIndices().end()); diff --git a/libspu/mpc/aby3/oram.cc b/libspu/mpc/aby3/oram.cc index 395a05ba..c668c745 100644 --- a/libspu/mpc/aby3/oram.cc +++ b/libspu/mpc/aby3/oram.cc @@ -48,7 +48,7 @@ NdArrayRef OramOneHotAA::proc(KernelEvalContext *ctx, const NdArrayRef &in, // generate aeskey for dpf auto [self_aes_keys, next_aes_keys] = oram::genAesKey(ctx, 1); - auto *octx = new oram::OramContext(s); + auto octx = oram::OramContext(s); for (int64_t j = 0; j < 3; j++) { // in round (rank - 1), as helper @@ -64,16 +64,16 @@ NdArrayRef OramOneHotAA::proc(KernelEvalContext *ctx, const NdArrayRef &in, auto target_point = dpf_rank ? target_idxs_[0][0] ^ target_idxs_[0][1] : target_idxs_[0][0]; // dpf gen - octx->genDpf(ctx, static_cast(j), aes_key, - target_point); + octx.genDpf(ctx, static_cast(j), aes_key, + target_point); // B2A - octx->onehotB2A(ctx, static_cast(j)); + octx.onehotB2A(ctx, static_cast(j)); } } pforeach(0, s, [&](int64_t k) { for (int64_t j = 0; j < 2; j++) { - out_[k][j] = octx->dpf_e[j][k]; + out_[k][j] = octx.dpf_e[j][k]; } }); }); @@ -124,15 +124,16 @@ NdArrayRef OramOneHotAP::proc(KernelEvalContext *ctx, const NdArrayRef &in, comm->sendAsync(dst_rank, {aes_key}, "aes_key"); aes_key += comm->recv(dst_rank, "aes_key")[0]; - auto *octx = new oram::OramContext(s); + auto octx = oram::OramContext(s); + // dpf gen - octx->genDpf(ctx, static_cast(1), aes_key, - target_point_2pc_[0]); + octx.genDpf(ctx, static_cast(1), aes_key, + target_point_2pc_[0]); // B2A - octx->onehotB2A(ctx, static_cast(1)); + octx.onehotB2A(ctx, static_cast(1)); int64_t j = comm->getRank() == 0 ? 1 : 0; - pforeach(0, s, [&](int64_t k) { out_[k] = octx->dpf_e[j][k]; }); + pforeach(0, s, [&](int64_t k) { out_[k] = octx.dpf_e[j][k]; }); } }); @@ -480,9 +481,9 @@ void OramContext::genDpf(KernelEvalContext *ctx, DpfGenCtrl ctrl, uint128_t aes_key, uint128_t target_point) { auto *comm = ctx->getState(); - auto *odpf = new OramDpf(dpf_size_, yacl::crypto::SecureRandU128(), aes_key, - static_cast(target_point)); - odpf->gen(ctx, ctrl); + auto odpf = OramDpf(dpf_size_, yacl::crypto::SecureRandU128(), aes_key, + static_cast(target_point)); + odpf.gen(ctx, ctrl); auto dpf_rank = comm->getRank() == static_cast(ctrl); int64_t dpf_idx = dpf_rank ? 0 : 1; @@ -490,10 +491,10 @@ void OramContext::genDpf(KernelEvalContext *ctx, DpfGenCtrl ctrl, // cast e and v to T type and convert v to arith // leave convert e outside - std::transform(odpf->final_e.begin(), odpf->final_e.begin() + dpf_size_, + std::transform(odpf.final_e.begin(), odpf.final_e.begin() + dpf_size_, dpf_e[dpf_idx].begin(), [&](uint8_t x) { return neg_flag * static_cast(x); }); - std::transform(odpf->final_v.begin(), odpf->final_v.begin() + dpf_size_, + std::transform(odpf.final_v.begin(), odpf.final_v.begin() + dpf_size_, convert_help_v[dpf_idx].begin(), [&](uint128_t x) { return neg_flag * static_cast(x); }); };